Skip to content

Commit

Permalink
* Save ssh_service.public_addr values to Server.PublicAddrs instead o…
Browse files Browse the repository at this point in the history
…f discarding them (#25223)

* Use Server.PublicAddrs when checking if session MFA is required

* Deprecate server PublicAddr in favor of PublicAddrs
  • Loading branch information
Joerger committed Apr 26, 2023
1 parent b10c8da commit 5ad772a
Show file tree
Hide file tree
Showing 20 changed files with 1,473 additions and 1,269 deletions.
7 changes: 5 additions & 2 deletions api/proto/teleport/legacy/types/types.proto
Expand Up @@ -700,9 +700,10 @@ message ServerV2List {
message ServerSpecV2 {
reserved 8;

// Addr is server host:port address
// Addr is a host:port address where this server can be reached.
string Addr = 1 [(gogoproto.jsontag) = "addr"];
// PublicAddr is the public address this cluster can be reached at.
// PublicAddr is the public address where this server can be reached.
// DELETE IN 15.0. (joerger) Deprecated in favor of public_addrs.
string PublicAddr = 2 [(gogoproto.jsontag) = "public_addr,omitempty"];
// Hostname is server hostname
string Hostname = 3 [(gogoproto.jsontag) = "hostname"];
Expand Down Expand Up @@ -740,6 +741,8 @@ message ServerSpecV2 {
string PeerAddr = 11 [(gogoproto.jsontag) = "peer_addr,omitempty"];
// ProxyIDs is a list of proxy IDs this server is expected to be connected to.
repeated string ProxyIDs = 12 [(gogoproto.jsontag) = "proxy_ids,omitempty"];
// PublicAddrs is a list of public addresses where this server can be reached.
repeated string public_addrs = 13;
}

// AppServerV3 represents a single proxied web app.
Expand Down
36 changes: 28 additions & 8 deletions api/types/server.go
Expand Up @@ -48,8 +48,10 @@ type Server interface {
GetCmdLabels() map[string]CommandLabel
// SetCmdLabels sets command labels.
SetCmdLabels(cmdLabels map[string]CommandLabel)
// GetPublicAddr is an optional field that returns the public address this cluster can be reached at.
// GetPublicAddr returns a public address where this server can be reached.
GetPublicAddr() string
// GetPublicAddrs returns a list of public addresses where this server can be reached.
GetPublicAddrs() []string
// GetRotation gets the state of certificate authority rotation.
GetRotation() Rotation
// SetRotation sets the state of certificate authority rotation.
Expand All @@ -62,8 +64,8 @@ type Server interface {
String() string
// SetAddr sets server address
SetAddr(addr string)
// SetPublicAddr sets the public address this cluster can be reached at.
SetPublicAddr(string)
// SetPublicAddrs sets the public addresses where this server can be reached.
SetPublicAddrs([]string)
// SetNamespace sets server namespace
SetNamespace(namespace string)
// GetApps gets the list of applications this server is proxying.
Expand Down Expand Up @@ -179,9 +181,13 @@ func (s *ServerV2) Expiry() time.Time {
return s.Metadata.Expiry()
}

// SetPublicAddr sets the public address this cluster can be reached at.
func (s *ServerV2) SetPublicAddr(addr string) {
s.Spec.PublicAddr = addr
// SetPublicAddrs sets the public proxy addresses where this server can be reached.
func (s *ServerV2) SetPublicAddrs(addrs []string) {
s.Spec.PublicAddrs = addrs
// DELETE IN 15.0. (Joerger) PublicAddr deprecated in favor of PublicAddrs
if len(addrs) != 0 {
s.Spec.PublicAddr = addrs[0]
}
}

// GetName returns server name
Expand All @@ -199,9 +205,22 @@ func (s *ServerV2) GetAddr() string {
return s.Spec.Addr
}

// GetPublicAddr is an optional field that returns the public address this cluster can be reached at.
// GetPublicAddr returns a public address where this server can be reached.
func (s *ServerV2) GetPublicAddr() string {
return s.Spec.PublicAddr
addrs := s.GetPublicAddrs()
if len(addrs) != 0 {
return addrs[0]
}
return ""
}

// GetPublicAddrs returns a list of public addresses where this server can be reached.
func (s *ServerV2) GetPublicAddrs() []string {
// DELETE IN 15.0. (Joerger) PublicAddr deprecated in favor of PublicAddrs
if len(s.Spec.PublicAddrs) == 0 && s.Spec.PublicAddr != "" {
return []string{s.Spec.PublicAddr}
}
return s.Spec.PublicAddrs
}

// GetRotation gets the state of certificate authority rotation.
Expand Down Expand Up @@ -434,6 +453,7 @@ func (s *ServerV2) MatchSearch(values []string) bool {

if s.GetKind() == KindNode {
fieldVals = append(utils.MapToStrings(s.GetAllLabels()), s.GetName(), s.GetHostname(), s.GetAddr())
fieldVals = append(fieldVals, s.GetPublicAddrs()...)

if s.GetUseTunnel() {
custom = func(val string) bool {
Expand Down
2,487 changes: 1,269 additions & 1,218 deletions api/types/types.pb.go

Large diffs are not rendered by default.

10 changes: 3 additions & 7 deletions lib/auth/auth.go
Expand Up @@ -34,7 +34,6 @@ import (
"math"
"math/big"
insecurerand "math/rand"
"net"
"os"
"sort"
"strings"
Expand Down Expand Up @@ -4358,13 +4357,10 @@ func (a *Server) isMFARequired(ctx context.Context, checker services.AccessCheck
if !ok {
continue
}
// Get the server address without port number.
addr, _, err := net.SplitHostPort(srv.GetAddr())
if err != nil {
addr = srv.GetAddr()
}

// Filter out any matches on labels before checking access
if n.GetName() != t.Node.Node && srv.GetHostname() != t.Node.Node && addr != t.Node.Node {
fieldVals := append(srv.GetPublicAddrs(), srv.GetName(), srv.GetHostname(), srv.GetAddr())
if !types.MatchSearch(fieldVals, []string{t.Node.Node}, nil) {
continue
}

Expand Down
3 changes: 1 addition & 2 deletions lib/auth/auth_with_roles_test.go
Expand Up @@ -3446,8 +3446,7 @@ func TestListResources_WithRoles(t *testing.T) {
Labels: labels,
},
Spec: types.ServerSpecV2{
Addr: addr,
PublicAddr: addr,
Addr: addr,
},
}

Expand Down
99 changes: 98 additions & 1 deletion lib/auth/grpcserver_test.go
Expand Up @@ -57,6 +57,7 @@ import (
"github.com/gravitational/teleport/api/observability/tracing"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/api/types/installers"
"github.com/gravitational/teleport/api/utils"
apiutils "github.com/gravitational/teleport/api/utils"
"github.com/gravitational/teleport/api/utils/sshutils"
"github.com/gravitational/teleport/lib/auth/mocku2f"
Expand Down Expand Up @@ -1701,7 +1702,7 @@ func TestIsMFARequired(t *testing.T) {
Kind: types.KindKubeService,
Version: types.V2,
Metadata: types.Metadata{
Name: "node-a",
Name: uuid.NewString(),
},
Spec: types.ServerSpecV2{
Hostname: "node-a",
Expand Down Expand Up @@ -1961,6 +1962,102 @@ func TestRoleVersions(t *testing.T) {
}
}

func TestIsMFARequired_NodeMatch(t *testing.T) {
modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise})

ctx := context.Background()
srv := newTestTLSServer(t)

// Register an SSH node.
node, err := types.NewServerWithLabels(uuid.NewString(), types.KindNode, types.ServerSpecV2{
Hostname: "node-a",
Addr: "127.0.0.1:3022",
PublicAddrs: []string{"node.example.com:3022", "localhost:3022"},
}, map[string]string{"foo": "bar"})
require.NoError(t, err)
_, err = srv.Auth().UpsertNode(ctx, node)
require.NoError(t, err)

// Create a fake user with per session mfa required for all nodes.
role, err := CreateRole(ctx, srv.Auth(), "mfa-user", types.RoleSpecV6{
Options: types.RoleOptions{
RequireMFAType: types.RequireMFAType_SESSION,
},
Allow: types.RoleConditions{
Logins: []string{"mfa-user"},
NodeLabels: types.Labels{types.Wildcard: utils.Strings{types.Wildcard}},
},
})
require.NoError(t, err)

user, err := CreateUser(srv.Auth(), "mfa-user", role)
require.NoError(t, err)

cl, err := srv.NewClient(TestUser(user.GetName()))
require.NoError(t, err)

for _, tc := range []struct {
desc string
// IsMFARequired only expects a host name or ip without the port.
node string
expectMatch require.BoolAssertionFunc
}{
{
desc: "OK uuid match",
node: node.GetName(),
expectMatch: require.True,
},
{
desc: "OK host name match",
node: node.GetHostname(),
expectMatch: require.True,
},
{
desc: "OK addr match",
node: node.GetAddr(),
expectMatch: require.True,
},
{
desc: "OK public addr 1 match",
node: "node.example.com",
expectMatch: require.True,
},
{
desc: "OK public addr 2 match",
node: "localhost",
expectMatch: require.True,
},
{
desc: "NOK label match",
node: "foo",
expectMatch: require.False,
},
{
desc: "NOK unknown ip",
node: "1.2.3.4",
expectMatch: require.False,
},
{
desc: "NOK unknown addr",
node: "unknown.example.com",
expectMatch: require.False,
},
} {
tc := tc
t.Run(tc.desc, func(t *testing.T) {
t.Parallel()
resp, err := cl.IsMFARequired(ctx, &proto.IsMFARequiredRequest{
Target: &proto.IsMFARequiredRequest_Node{Node: &proto.NodeLogin{
Login: user.GetName(),
Node: tc.node,
}},
})
require.NoError(t, err)
tc.expectMatch(t, resp.Required)
})
}
}

// testOriginDynamicStored tests setting a ResourceWithOrigin via the server
// API always results in the resource being stored with OriginDynamic.
func testOriginDynamicStored(t *testing.T, setWithOrigin func(*Client, string) error, getStored func(*Server) (types.ResourceWithOrigin, error)) {
Expand Down
4 changes: 2 additions & 2 deletions lib/auth/usertoken_test.go
Expand Up @@ -172,8 +172,8 @@ func TestFormatAccountName(t *testing.T) {
proxies: []types.Server{
&types.ServerV2{
Spec: types.ServerSpecV2{
PublicAddr: "foo",
Version: "bar",
PublicAddrs: []string{"foo"},
Version: "bar",
},
},
},
Expand Down
2 changes: 2 additions & 0 deletions lib/service/service.go
Expand Up @@ -2391,6 +2391,7 @@ func (process *TeleportProcess) initSSH() error {
regular.SetTracerProvider(process.TracingProvider),
regular.SetSessionController(sessionController),
regular.SetCAGetter(caGetter),
regular.SetPublicAddrs(cfg.SSH.PublicAddrs),
)
if err != nil {
return trace.Wrap(err)
Expand Down Expand Up @@ -3832,6 +3833,7 @@ func (process *TeleportProcess) initProxyEndpoint(conn *Connector) error {
regular.SetSessionController(sessionController),
regular.SetIngressReporter(ingress.SSH, ingressReporter),
regular.SetPROXYSigner(proxySigner),
regular.SetPublicAddrs(cfg.Proxy.PublicAddrs),
)
if err != nil {
return trace.Wrap(err)
Expand Down
3 changes: 1 addition & 2 deletions lib/services/local/perf_test.go
Expand Up @@ -108,8 +108,7 @@ func insertNodes(ctx context.Context, b *testing.B, svc services.Presence, nodeC
Labels: labels,
},
Spec: types.ServerSpecV2{
Addr: addr,
PublicAddr: addr,
Addr: addr,
},
}
_, err := svc.UpsertNode(ctx, node)
Expand Down
29 changes: 27 additions & 2 deletions lib/services/matchers_test.go
Expand Up @@ -139,8 +139,9 @@ func TestMatchResourceByFilters_Helper(t *testing.T) {
t.Parallel()

server, err := types.NewServerWithLabels("banana", types.KindNode, types.ServerSpecV2{
Hostname: "foo",
Addr: "bar",
Hostname: "foo",
Addr: "bar",
PublicAddrs: []string{"foo.example.com:3080"},
}, map[string]string{"env": "prod", "os": "mac"})
require.NoError(t, err)

Expand Down Expand Up @@ -177,6 +178,30 @@ func TestMatchResourceByFilters_Helper(t *testing.T) {
assertErr: require.NoError,
assertMatch: require.False,
},
{
name: "search keywords hostname match",
filters: MatchResourceFilter{
SearchKeywords: []string{"foo"},
},
assertErr: require.NoError,
assertMatch: require.True,
},
{
name: "search keywords addr match",
filters: MatchResourceFilter{
SearchKeywords: []string{"bar"},
},
assertErr: require.NoError,
assertMatch: require.True,
},
{
name: "search keywords public addr match",
filters: MatchResourceFilter{
SearchKeywords: []string{"foo.example.com"},
},
assertErr: require.NoError,
assertMatch: require.True,
},
{
name: "expression match",
filters: MatchResourceFilter{
Expand Down
7 changes: 6 additions & 1 deletion lib/services/server.go
Expand Up @@ -91,9 +91,14 @@ func compareServers(a, b types.Server) int {
if a.GetNamespace() != b.GetNamespace() {
return Different
}
if a.GetPublicAddr() != b.GetPublicAddr() {
if len(a.GetPublicAddrs()) != len(b.GetPublicAddrs()) {
return Different
}
for i := range a.GetPublicAddrs() {
if a.GetPublicAddrs()[i] != b.GetPublicAddrs()[i] {
return Different
}
}
r := a.GetRotation()
if !r.Matches(b.GetRotation()) {
return Different
Expand Down
10 changes: 5 additions & 5 deletions lib/services/servers_test.go
Expand Up @@ -73,9 +73,9 @@ func TestServersCompare(t *testing.T) {
node2.Spec.Addr = "localhost:3033"
require.Equal(t, CompareServers(node, &node2), Different)

// Public addr has changed
// Proxy addr has changed
node2 = *node
node2.Spec.PublicAddr = "localhost:3033"
node2.Spec.PublicAddrs = []string{"localhost:3033"}
require.Equal(t, CompareServers(node, &node2), Different)

// Hostname has changed
Expand Down Expand Up @@ -160,13 +160,13 @@ func TestGuessProxyHostAndVersion(t *testing.T) {
require.Equal(t, version, proxyA.Spec.Version)
require.NoError(t, err)

// At least one proxy has public address set.
// At least one proxy has proxy address set.
proxyB := types.ServerV2{}
proxyB.Spec.PublicAddr = "test-B"
proxyB.Spec.PublicAddrs = []string{"test-B"}
proxyB.Spec.Version = "test-B"

host, version, err = GuessProxyHostAndVersion([]types.Server{&proxyA, &proxyB})
require.Equal(t, host, proxyB.Spec.PublicAddr)
require.Equal(t, host, proxyB.Spec.PublicAddrs[0])
require.Equal(t, version, proxyB.Spec.Version)
require.NoError(t, err)
}
Expand Down
3 changes: 1 addition & 2 deletions lib/services/suite/suite.go
Expand Up @@ -362,8 +362,7 @@ func NewServer(kind, name, addr, namespace string) *types.ServerV2 {
Namespace: namespace,
},
Spec: types.ServerSpecV2{
Addr: addr,
PublicAddr: addr,
Addr: addr,
},
}
}
Expand Down
2 changes: 0 additions & 2 deletions lib/services/watcher.go
Expand Up @@ -1444,8 +1444,6 @@ type Node interface {
GetNamespace() string
// GetCmdLabels gets command labels
GetCmdLabels() map[string]types.CommandLabel
// GetPublicAddr is an optional field that returns the public address this cluster can be reached at.
GetPublicAddr() string
// GetRotation gets the state of certificate authority rotation.
GetRotation() types.Rotation
// GetUseTunnel gets if a reverse tunnel should be used to connect to this node.
Expand Down

0 comments on commit 5ad772a

Please sign in to comment.