From f79a9b4106454b3d1ad39dfd443b9d439013a537 Mon Sep 17 00:00:00 2001 From: Sasha Klizhentas Date: Sun, 8 Jul 2018 10:52:15 -0700 Subject: [PATCH] Introduce disconnect client logic. This commit implements #1935, fixes #2038 Auth server now supports global defaults for timeout behavior: ``` auth_service: client_idle_timeout: 15m disconnect_expired_cert: no ``` New role options were introduced: ``` kind: role version: v3 metadata: name: intern spec: options: # these two settings override the global ones: client_idle_timeout: 1m disconnect_expired_cert: yes ``` --- constants.go | 3 + integration/helpers.go | 4 +- integration/integration_test.go | 175 ++++++++++++-- lib/auth/permissions.go | 2 +- lib/auth/tls_test.go | 4 +- lib/config/configuration.go | 8 +- lib/config/configuration_test.go | 59 +++++ lib/config/fileconf.go | 224 +++++++++--------- lib/config/testdata_test.go | 2 + lib/events/api.go | 8 + lib/services/clusterconfig.go | 39 +++ lib/services/role.go | 378 +++++++++++++++++------------- lib/services/role_test.go | 92 +++++++- lib/services/suite/suite.go | 6 +- lib/srv/authhandlers.go | 4 + lib/srv/ctx.go | 168 ++++++++++++- lib/srv/forward/sshserver.go | 18 ++ lib/srv/regular/proxy.go | 12 +- lib/srv/regular/sshserver.go | 7 +- lib/srv/regular/sshserver_test.go | 4 +- lib/srv/sess.go | 1 + lib/utils/buf.go | 72 ++++++ lib/utils/utils.go | 15 +- lib/web/apiserver_test.go | 4 +- 24 files changed, 979 insertions(+), 330 deletions(-) create mode 100644 lib/utils/buf.go diff --git a/constants.go b/constants.go index dc084d1fc04..87443519c12 100644 --- a/constants.go +++ b/constants.go @@ -302,6 +302,9 @@ const ( // CertificateFormatUnspecified is used to check if the format was specified // or not. CertificateFormatUnspecified = "" + + // DurationNever is human friendly shortcut that is interpreted as a Duration of 0 + DurationNever = "never" ) const ( diff --git a/integration/helpers.go b/integration/helpers.go index 4938325cade..b4998d4dd91 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -354,7 +354,7 @@ func SetupUser(process *service.TeleportProcess, username string, roles []servic // allow tests to forward agent, still needs to be passed in client roleOptions := role.GetOptions() - roleOptions.Set(services.ForwardAgent, true) + roleOptions.ForwardAgent = services.NewBool(true) role.SetOptions(roleOptions) err = auth.UpsertRole(role, backend.Forever) @@ -510,7 +510,7 @@ func (i *TeleInstance) CreateEx(trustedSecrets []*InstanceSecrets, tconf *servic // allow tests to forward agent, still needs to be passed in client roleOptions := role.GetOptions() - roleOptions.Set(services.ForwardAgent, true) + roleOptions.ForwardAgent = services.NewBool(true) role.SetOptions(roleOptions) err = auth.UpsertRole(role, backend.Forever) diff --git a/integration/integration_test.go b/integration/integration_test.go index c85018346bb..89b3d246c75 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -491,26 +491,26 @@ func (s *IntSuite) TestInteroperability(c *check.C) { // 0 - echo "1\n2\n" | ssh localhost "cat -" // this command can be used to copy files by piping stdout to stdin over ssh. { - "cat -", - "1\n2\n", - "1\n2\n", - false, + inCommand: "cat -", + inStdin: "1\n2\n", + outContains: "1\n2\n", + outFile: false, }, // 1 - ssh -tt locahost '/bin/sh -c "mkdir -p /tmp && echo a > /tmp/file.txt"' // programs like ansible execute commands like this { - fmt.Sprintf(`/bin/sh -c "mkdir -p /tmp && echo a > %v"`, tempfile), - "", - "a", - true, + inCommand: fmt.Sprintf(`/bin/sh -c "mkdir -p /tmp && echo a > %v"`, tempfile), + inStdin: "", + outContains: "a", + outFile: true, }, // 2 - ssh localhost tty // should print "not a tty" { - "tty", - "", - "not a tty", - false, + inCommand: "tty", + inStdin: "", + outContains: "not a tty", + outFile: false, }, } @@ -521,7 +521,7 @@ func (s *IntSuite) TestInteroperability(c *check.C) { // hook up stdin and stdout to a buffer for reading and writing inbuf := bytes.NewReader([]byte(tt.inStdin)) - outbuf := &bytes.Buffer{} + outbuf := utils.NewSyncBuffer() cl.Stdin = inbuf cl.Stdout = outbuf cl.Stderr = outbuf @@ -688,6 +688,153 @@ func (s *IntSuite) TestShutdown(c *check.C) { } } +type disconnectTestCase struct { + recordingMode string + options services.RoleOptions + disconnectTimeout time.Duration +} + +// TestDisconnectScenarios tests multiple scenarios with client disconnects +func (s *IntSuite) TestDisconnectScenarios(c *check.C) { + + testCases := []disconnectTestCase{ + { + recordingMode: services.RecordAtNode, + options: services.RoleOptions{ + ClientIdleTimeout: services.NewDuration(500 * time.Millisecond), + }, + disconnectTimeout: time.Second, + }, + { + recordingMode: services.RecordAtProxy, + options: services.RoleOptions{ + ClientIdleTimeout: services.NewDuration(500 * time.Millisecond), + }, + disconnectTimeout: time.Second, + }, + { + recordingMode: services.RecordAtNode, + options: services.RoleOptions{ + DisconnectExpiredCert: services.NewBool(true), + MaxSessionTTL: services.NewDuration(2 * time.Second), + }, + disconnectTimeout: 4 * time.Second, + }, + { + recordingMode: services.RecordAtProxy, + options: services.RoleOptions{ + DisconnectExpiredCert: services.NewBool(true), + MaxSessionTTL: services.NewDuration(2 * time.Second), + }, + disconnectTimeout: 4 * time.Second, + }, + } + for _, tc := range testCases { + s.runDisconnectTest(c, tc) + } +} + +func (s *IntSuite) runDisconnectTest(c *check.C, tc disconnectTestCase) { + t := NewInstance(InstanceConfig{ + ClusterName: Site, + HostID: HostID, + NodeName: Host, + Ports: s.getPorts(5), + Priv: s.priv, + Pub: s.pub, + }) + + // devs role gets disconnected after 1 second idle time + username := s.me.Username + role, err := services.NewRole("devs", services.RoleSpecV3{ + Options: tc.options, + Allow: services.RoleConditions{ + Logins: []string{username}, + }, + }) + c.Assert(err, check.IsNil) + t.AddUserWithRole(username, role) + + clusterConfig, err := services.NewClusterConfig(services.ClusterConfigSpecV3{ + SessionRecording: services.RecordAtNode, + }) + c.Assert(err, check.IsNil) + + cfg := service.MakeDefaultConfig() + cfg.Auth.Enabled = true + cfg.Auth.ClusterConfig = clusterConfig + cfg.Proxy.DisableWebService = true + cfg.Proxy.DisableWebInterface = true + cfg.Proxy.Enabled = true + cfg.SSH.Enabled = true + + c.Assert(t.CreateEx(nil, cfg), check.IsNil) + c.Assert(t.Start(), check.IsNil) + defer t.Stop(true) + + // get a reference to site obj: + site := t.GetSiteAPI(Site) + c.Assert(site, check.NotNil) + + person := NewTerminal(250) + + // commandsC receive commands + commandsC := make(chan string, 0) + + // PersonA: SSH into the server, wait one second, then type some commands on stdin: + sessionCtx, sessionCancel := context.WithCancel(context.TODO()) + openSession := func() { + defer sessionCancel() + cl, err := t.NewClient(ClientConfig{Login: username, Cluster: Site, Host: Host, Port: t.GetPortSSHInt()}) + c.Assert(err, check.IsNil) + cl.Stdout = &person + cl.Stdin = &person + + go func() { + for command := range commandsC { + person.Type(command) + } + }() + + err = cl.SSH(context.TODO(), []string{}, false) + if err != nil && err != io.EOF { + c.Fatalf("expected EOF or nil, got %v instead", err) + } + } + + go openSession() + + retry := func(command, pattern string) { + person.Type(command) + abortTime := time.Now().Add(10 * time.Second) + var matched bool + var output string + for { + output = string(replaceNewlines(person.Output(1000))) + matched, _ = regexp.MatchString(pattern, output) + if matched { + break + } + time.Sleep(time.Millisecond * 200) + if time.Now().After(abortTime) { + c.Fatalf("failed to capture output: %v", pattern) + } + } + if !matched { + c.Fatalf("output %q does not match pattern %q", output, pattern) + } + } + + retry("echo start \r\n", ".*start.*") + time.Sleep(tc.disconnectTimeout) + select { + case <-time.After(tc.disconnectTimeout): + c.Fatalf("timeout waiting for session to exit") + case <-sessionCtx.Done(): + // session closed + } +} + // TestInvalidLogins validates that you can't login with invalid login or // with invalid 'site' parameter func (s *IntSuite) TestEnvironmentVariables(c *check.C) { @@ -1509,7 +1656,7 @@ func (s *IntSuite) TestDiscovery(c *check.C) { // attempt to allow the discovery request to be received and the connection // added to the agent pool. lb.AddBackend(mainProxyAddr) - output, err = runCommand(main, []string{"echo", "hello world"}, cfg, 10) + output, err = runCommand(main, []string{"echo", "hello world"}, cfg, 20) c.Assert(err, check.IsNil) c.Assert(output, check.Equals, "hello world\n") diff --git a/lib/auth/permissions.go b/lib/auth/permissions.go index 47510939fae..a59b3e8669b 100644 --- a/lib/auth/permissions.go +++ b/lib/auth/permissions.go @@ -387,7 +387,7 @@ func GetCheckerForBuiltinRole(clusterName string, clusterConfig services.Cluster role.String(), services.RoleSpecV3{ Options: services.RoleOptions{ - services.MaxSessionTTL: services.MaxDuration(), + MaxSessionTTL: services.MaxDuration(), }, Allow: services.RoleConditions{ Namespaces: []string{services.Wildcard}, diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index c1f2ed5dc58..37af40329ae 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -1118,7 +1118,7 @@ func (s *TLSSuite) TestGenerateCerts(c *check.C) { // now update role to permit agent forwarding roleOptions := userRole.GetOptions() - roleOptions.Set(services.ForwardAgent, true) + roleOptions.ForwardAgent = services.NewBool(true) userRole.SetOptions(roleOptions) err = s.server.Auth().UpsertRole(userRole, backend.Forever) c.Assert(err, check.IsNil) @@ -1182,7 +1182,7 @@ func (s *TLSSuite) TestCertificateFormat(c *check.C) { for _, tt := range tests { roleOptions := userRole.GetOptions() - roleOptions.Set(services.CertificateFormat, tt.inRoleCertificateFormat) + roleOptions.CertificateFormat = tt.inRoleCertificateFormat userRole.SetOptions(roleOptions) err := s.server.Auth().UpsertRole(userRole, backend.Forever) c.Assert(err, check.IsNil) diff --git a/lib/config/configuration.go b/lib/config/configuration.go index a9ee16c3861..f829f66d16b 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -461,9 +461,11 @@ func ApplyFileConfig(fc *FileConfig, cfg *service.Config) error { // build cluster config from session recording and host key checking preferences cfg.Auth.ClusterConfig, err = services.NewClusterConfig(services.ClusterConfigSpecV3{ - SessionRecording: fc.Auth.SessionRecording, - ProxyChecksHostKeys: fc.Auth.ProxyChecksHostKeys, - Audit: *auditConfig, + SessionRecording: fc.Auth.SessionRecording, + ProxyChecksHostKeys: fc.Auth.ProxyChecksHostKeys, + Audit: *auditConfig, + ClientIdleTimeout: fc.Auth.ClientIdleTimeout, + DisconnectExpiredCert: fc.Auth.DisconnectExpiredCert, }) if err != nil { return trace.Wrap(err) diff --git a/lib/config/configuration_test.go b/lib/config/configuration_test.go index 9efbfc7696c..0f816c1ae59 100644 --- a/lib/config/configuration_test.go +++ b/lib/config/configuration_test.go @@ -119,6 +119,61 @@ func (s *ConfigTestSuite) TestSampleConfig(c *check.C) { c.Assert(lib.IsInsecureDevMode(), check.Equals, false) } +// TestBooleanParsing tests that boolean options +// are parsed properly +func (s *ConfigTestSuite) TestBooleanParsing(c *check.C) { + testCases := []struct { + s string + b bool + }{ + {s: "true", b: true}, + {s: "'true'", b: true}, + {s: "yes", b: true}, + {s: "'yes'", b: true}, + {s: "'1'", b: true}, + {s: "1", b: true}, + {s: "no", b: false}, + {s: "0", b: false}, + } + for i, tc := range testCases { + comment := check.Commentf("test case %v", i) + conf, err := ReadFromString(base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf(` +teleport: + advertise_ip: 10.10.10.1 +auth_service: + enabled: yes + disconnect_expired_cert: %v +`, tc.s)))) + c.Assert(err, check.IsNil) + c.Assert(conf.Auth.DisconnectExpiredCert.Value(), check.Equals, tc.b, comment) + } +} + +// TestDurationParsing tests that duration options +// are parsed properly +func (s *ConfigTestSuite) TestDuration(c *check.C) { + testCases := []struct { + s string + d time.Duration + }{ + {s: "1s", d: time.Second}, + {s: "never", d: 0}, + {s: "'1m'", d: time.Minute}, + } + for i, tc := range testCases { + comment := check.Commentf("test case %v", i) + conf, err := ReadFromString(base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf(` +teleport: + advertise_ip: 10.10.10.1 +auth_service: + enabled: yes + client_idle_timeout: %v +`, tc.s)))) + c.Assert(err, check.IsNil) + c.Assert(conf.Auth.ClientIdleTimeout.Value(), check.Equals, tc.d, comment) + } +} + func (s *ConfigTestSuite) TestConfigReading(c *check.C) { // non-existing file: conf, err := ReadFromFile("/heaven/trees/apple.ymL") @@ -149,6 +204,8 @@ func (s *ConfigTestSuite) TestConfigReading(c *check.C) { c.Assert(conf.Auth.Enabled(), check.Equals, true) c.Assert(conf.Auth.ListenAddress, check.Equals, "tcp://auth") c.Assert(conf.Auth.LicenseFile, check.Equals, "lic.pem") + c.Assert(conf.Auth.DisconnectExpiredCert.Value(), check.Equals, true) + c.Assert(conf.Auth.ClientIdleTimeout.Value(), check.Equals, 17*time.Second) c.Assert(conf.SSH.Configured(), check.Equals, true) c.Assert(conf.SSH.Enabled(), check.Equals, true) c.Assert(conf.SSH.ListenAddress, check.Equals, "tcp://ssh") @@ -575,6 +632,8 @@ func makeConfigFixture() string { conf.Auth.EnabledFlag = "Yeah" conf.Auth.ListenAddress = "tcp://auth" conf.Auth.LicenseFile = "lic.pem" + conf.Auth.ClientIdleTimeout = services.NewDuration(17 * time.Second) + conf.Auth.DisconnectExpiredCert = services.NewBool(true) // ssh service: conf.SSH.EnabledFlag = "true" diff --git a/lib/config/fileconf.go b/lib/config/fileconf.go index 53d38348ae4..109574fb8cf 100644 --- a/lib/config/fileconf.go +++ b/lib/config/fileconf.go @@ -47,100 +47,102 @@ var ( // true = has sub-keys // false = does not have sub-keys (a leaf) validKeys = map[string]bool{ - "proxy_protocol": false, - "namespace": true, - "cluster_name": true, - "trusted_clusters": true, - "pid_file": true, - "cert_file": true, - "private_key_file": true, - "cert": true, - "private_key": true, - "checking_keys": true, - "checking_key_files": true, - "signing_keys": true, - "signing_key_files": true, - "allowed_logins": true, - "teleport": true, - "enabled": true, - "ssh_service": true, - "proxy_service": true, - "auth_service": true, - "auth_token": true, - "auth_servers": true, - "domain_name": true, - "storage": false, - "nodename": true, - "log": true, - "period": true, - "connection_limits": true, - "max_connections": true, - "max_users": true, - "rates": true, - "commands": true, - "labels": false, - "output": true, - "severity": true, - "role": true, - "name": true, - "type": true, - "data_dir": true, - "web_listen_addr": true, - "tunnel_listen_addr": true, - "ssh_listen_addr": true, - "kube_listen_addr": true, - "kube_api_addr": true, - "kube_ca_cert_file": true, - "listen_addr": true, - "https_key_file": true, - "https_cert_file": true, - "advertise_ip": true, - "authorities": true, - "keys": true, - "reverse_tunnels": true, - "addresses": true, - "oidc_connectors": true, - "id": true, - "issuer_url": true, - "client_id": true, - "client_secret": true, - "redirect_url": true, - "acr_values": true, - "provider": true, - "tokens": true, - "region": true, - "table_name": true, - "access_key": true, - "secret_key": true, - "u2f": true, - "app_id": true, - "facets": true, - "authentication": true, - "second_factor": false, - "oidc": true, - "display": false, - "scope": false, - "claims_to_roles": true, - "dynamic_config": false, - "seed_config": false, - "public_addr": false, - "cache": true, - "ttl": false, - "issuer": false, - "permit_user_env": false, - "ciphers": false, - "kex_algos": false, - "mac_algos": false, - "connector_name": false, - "session_recording": false, - "read_capacity_units": false, - "write_capacity_units": false, - "license_file": false, - "proxy_checks_host_keys": false, - "audit_table_name": false, - "audit_sessions_uri": false, - "pam": true, - "service_name": false, + "proxy_protocol": false, + "namespace": true, + "cluster_name": true, + "trusted_clusters": true, + "pid_file": true, + "cert_file": true, + "private_key_file": true, + "cert": true, + "private_key": true, + "checking_keys": true, + "checking_key_files": true, + "signing_keys": true, + "signing_key_files": true, + "allowed_logins": true, + "teleport": true, + "enabled": true, + "ssh_service": true, + "proxy_service": true, + "auth_service": true, + "auth_token": true, + "auth_servers": true, + "domain_name": true, + "storage": false, + "nodename": true, + "log": true, + "period": true, + "connection_limits": true, + "max_connections": true, + "max_users": true, + "rates": true, + "commands": true, + "labels": false, + "output": true, + "severity": true, + "role": true, + "name": true, + "type": true, + "data_dir": true, + "web_listen_addr": true, + "tunnel_listen_addr": true, + "ssh_listen_addr": true, + "kube_listen_addr": true, + "kube_api_addr": true, + "kube_ca_cert_file": true, + "listen_addr": true, + "https_key_file": true, + "https_cert_file": true, + "advertise_ip": true, + "authorities": true, + "keys": true, + "reverse_tunnels": true, + "addresses": true, + "oidc_connectors": true, + "id": true, + "issuer_url": true, + "client_id": true, + "client_secret": true, + "redirect_url": true, + "acr_values": true, + "provider": true, + "tokens": true, + "region": true, + "table_name": true, + "access_key": true, + "secret_key": true, + "u2f": true, + "app_id": true, + "facets": true, + "authentication": true, + "second_factor": false, + "oidc": true, + "display": false, + "scope": false, + "claims_to_roles": true, + "dynamic_config": false, + "seed_config": false, + "public_addr": false, + "cache": true, + "ttl": false, + "issuer": false, + "permit_user_env": false, + "ciphers": false, + "kex_algos": false, + "mac_algos": false, + "connector_name": false, + "session_recording": false, + "read_capacity_units": false, + "write_capacity_units": false, + "license_file": false, + "proxy_checks_host_keys": false, + "audit_table_name": false, + "audit_sessions_uri": false, + "pam": true, + "service_name": false, + "client_idle_timeout": false, + "disconnect_expired_cert": false, } ) @@ -386,14 +388,6 @@ type CachePolicy struct { TTL string `yaml:"ttl,omitempty"` } -func isTrue(v string) bool { - switch v { - case "yes", "yeah", "y", "true", "1": - return true - } - return false -} - func isNever(v string) bool { switch v { case "never", "no", "0": @@ -404,7 +398,11 @@ func isNever(v string) bool { // Enabled determines if a given "_service" section has been set to 'true' func (c *CachePolicy) Enabled() bool { - return c.EnabledFlag == "" || isTrue(c.EnabledFlag) + if c.EnabledFlag == "" { + return true + } + enabled, _ := utils.ParseBool(c.EnabledFlag) + return enabled } // NeverExpires returns if cache never expires by itself @@ -447,11 +445,14 @@ func (s *Service) Configured() bool { // Enabled determines if a given "_service" section has been set to 'true' func (s *Service) Enabled() bool { - switch strings.ToLower(s.EnabledFlag) { - case "", "yes", "yeah", "y", "true", "1": + if s.EnabledFlag == "" { return true } - return false + v, err := utils.ParseBool(s.EnabledFlag) + if err != nil { + return false + } + return v } // Disabled returns 'true' if the service has been deliberately turned off @@ -527,6 +528,13 @@ type Auth struct { // KubeCACertFile is a path to kubernetes certificate authority certificate file KubeCACertFile string `yaml:"kube_ca_cert_file,omitempty"` + + // ClientIdleTimeout sets global cluster default setting for client idle timeouts + ClientIdleTimeout services.Duration `yaml:"client_idle_timeout"` + + // DisconnectExpiredCert provides disconnect expired certificate setting - + // if true, connections with expired client certificates will get disconnected + DisconnectExpiredCert services.Bool `yaml:"disconnect_expired_cert"` } // TrustedCluster struct holds configuration values under "trusted_clusters" key @@ -686,9 +694,9 @@ func (p *PAM) Parse() *pam.Config { if serviceName == "" { serviceName = defaults.ServiceName } - + enabled, _ := utils.ParseBool(p.Enabled) return &pam.Config{ - Enabled: isTrue(p.Enabled), + Enabled: enabled, ServiceName: serviceName, } } diff --git a/lib/config/testdata_test.go b/lib/config/testdata_test.go index c5f07680b8c..a735712a7fd 100644 --- a/lib/config/testdata_test.go +++ b/lib/config/testdata_test.go @@ -78,6 +78,8 @@ auth_service: - domain_name: tunnel.example.org addresses: ["org-1"] public_addr: ["auth.default.svc.cluster.local:3080"] + disconnect_expired_cert: yes + client_idle_timeout: 17s ssh_service: enabled: no diff --git a/lib/events/api.go b/lib/events/api.go index 504b040164f..ba03f2635f6 100644 --- a/lib/events/api.go +++ b/lib/events/api.go @@ -88,6 +88,14 @@ const ( // SessionLeaveEvent indicates that someone left a session SessionLeaveEvent = "session.leave" + // ClientDisconnectEvent is emitted when client is disconnected + // by the server due to inactivity or any other reason + ClientDisconnectEvent = "client.disconnect" + + // Reason is a field that specifies reason for event, e.g. in disconnect + // event it explains why server disconnected the client + Reason = "reason" + // UserLoginEvent indicates that a user logged into web UI or via tsh UserLoginEvent = "user.login" // LoginMethod is the event field indicating how the login was performed diff --git a/lib/services/clusterconfig.go b/lib/services/clusterconfig.go index 92c31166857..154634ce5b9 100644 --- a/lib/services/clusterconfig.go +++ b/lib/services/clusterconfig.go @@ -62,6 +62,18 @@ type ClusterConfig interface { // SetAuditConfig sets audit config SetAuditConfig(AuditConfig) + // GetClientIdleTimeout returns client idle timeout setting + GetClientIdleTimeout() time.Duration + + // SetClientIdleTimeout sets client idle timeout setting + SetClientIdleTimeout(t time.Duration) + + // GetDisconnectExpiredCert returns disconnect expired certificate setting + GetDisconnectExpiredCert() bool + + // SetDisconnectExpiredCert sets disconnect client with expired certificate setting + SetDisconnectExpiredCert(bool) + // Copy creates a copy of the resource and returns it. Copy() ClusterConfig } @@ -183,6 +195,13 @@ type ClusterConfigSpecV3 struct { // Audit is a section with audit config Audit AuditConfig `json:"audit"` + + // ClientIdleTimeout sets global cluster default setting for client idle timeouts + ClientIdleTimeout Duration `json:"client_idle_timeout"` + + // DisconnectExpiredCert provides disconnect expired certificate setting - + // if true, connections with expired client certificates will get disconnected + DisconnectExpiredCert Bool `json:"disconnect_expired_cert"` } // GetName returns the name of the cluster. @@ -255,6 +274,26 @@ func (c *ClusterConfigV3) SetAuditConfig(cfg AuditConfig) { c.Spec.Audit = cfg } +// GetClientIdleTimeout returns client idle timeout setting +func (c *ClusterConfigV3) GetClientIdleTimeout() time.Duration { + return c.Spec.ClientIdleTimeout.Duration +} + +// SetClientIdleTimeout sets client idle timeout setting +func (c *ClusterConfigV3) SetClientIdleTimeout(d time.Duration) { + c.Spec.ClientIdleTimeout.Duration = d +} + +// GetDisconnectExpiredCert returns disconnect expired certificate setting +func (c *ClusterConfigV3) GetDisconnectExpiredCert() bool { + return c.Spec.DisconnectExpiredCert.bool +} + +// SetDisconnectExpiredCert sets disconnect client with expired certificate setting +func (c *ClusterConfigV3) SetDisconnectExpiredCert(b bool) { + c.Spec.DisconnectExpiredCert.bool = b +} + // CheckAndSetDefaults checks validity of all parameters and sets defaults. func (c *ClusterConfigV3) CheckAndSetDefaults() error { // make sure we have defaults for all metadata fields diff --git a/lib/services/role.go b/lib/services/role.go index 846e6a69f97..4a8eceb581e 100644 --- a/lib/services/role.go +++ b/lib/services/role.go @@ -93,8 +93,8 @@ func NewAdminRole() Role { Options: RoleOptions{ CertificateFormat: teleport.CertificateFormatStandard, MaxSessionTTL: NewDuration(defaults.MaxCertDuration), - PortForwarding: true, - ForwardAgent: true, + PortForwarding: NewBoolOption(true), + ForwardAgent: NewBool(true), }, Allow: RoleConditions{ Namespaces: []string{defaults.Namespace}, @@ -142,8 +142,8 @@ func RoleForUser(u User) Role { Options: RoleOptions{ CertificateFormat: teleport.CertificateFormatStandard, MaxSessionTTL: NewDuration(defaults.MaxCertDuration), - PortForwarding: true, - ForwardAgent: true, + PortForwarding: NewBoolOption(true), + ForwardAgent: NewBool(true), }, Allow: RoleConditions{ Namespaces: []string{defaults.Namespace}, @@ -206,25 +206,6 @@ type Access interface { DeleteRole(name string) error } -// TODO: [ev] can we please define a RoleOption type (instead of using strings) -// and use RoleOption prefix for naming these? It's impossible right now to find -// all possible role options. -const ( - // ForwardAgent is SSH agent forwarding. - ForwardAgent = "forward_agent" - - // MaxSessionTTL defines how long a SSH session can last for. - MaxSessionTTL = "max_session_ttl" - - // PortForwarding defines if the certificate will have "permit-port-forwarding" - // in the certificate. - PortForwarding = "port_forwarding" - - // CertificateFormat defines the format of the user certificate to allow - // compatibility with older versions of OpenSSH. - CertificateFormat = "cert_format" -) - const ( // Allow is the set of conditions that allow access. Allow RoleConditionType = true @@ -484,7 +465,7 @@ func (r *RoleV3) GetOptions() RoleOptions { // SetOptions sets role options. func (r *RoleV3) SetOptions(options RoleOptions) { - r.Spec.Options = utils.CopyStringMapInterface(options) + r.Spec.Options = options } // GetLogins gets system logins for allow or deny condition. @@ -590,12 +571,14 @@ func (r *RoleV3) CheckAndSetDefaults() error { } // make sure we have defaults for all fields - if r.Spec.Options == nil { - r.Spec.Options = map[string]interface{}{ - CertificateFormat: teleport.CertificateFormatStandard, - MaxSessionTTL: NewDuration(defaults.MaxCertDuration), - PortForwarding: true, - } + if r.Spec.Options.CertificateFormat == "" { + r.Spec.Options.CertificateFormat = teleport.CertificateFormatStandard + } + if r.Spec.Options.MaxSessionTTL.Value() == 0 { + r.Spec.Options.MaxSessionTTL = NewDuration(defaults.MaxCertDuration) + } + if r.Spec.Options.PortForwarding == nil { + r.Spec.Options.PortForwarding = NewBoolOption(true) } if r.Spec.Allow.Namespaces == nil { r.Spec.Allow.Namespaces = []string{defaults.Namespace} @@ -620,15 +603,8 @@ func (r *RoleV3) CheckAndSetDefaults() error { } // check and correct the session ttl - maxSessionTTL, err := r.Spec.Options.GetDuration(MaxSessionTTL) - if err != nil { - return trace.BadParameter("invalid duration: %v", err) - } - if maxSessionTTL.Duration == 0 { - r.Spec.Options.Set(MaxSessionTTL, NewDuration(defaults.MaxCertDuration)) - } - if maxSessionTTL.Duration < defaults.MinCertDuration { - return trace.BadParameter("maximum session TTL can not be less than, minimal certificate duration") + if r.Spec.Options.MaxSessionTTL.Value() <= 0 { + r.Spec.Options.MaxSessionTTL = NewDuration(defaults.MaxCertDuration) } // restrict wildcards @@ -673,100 +649,40 @@ type RoleSpecV3 struct { Deny RoleConditions `json:"deny,omitempty"` } -// RoleOptions are key/value pairs that always exist for a role. -type RoleOptions map[string]interface{} - -// UnmarshalJSON is used when parsing RoleV3 to convert MaxSessionTTL into the -// correct type. -func (o *RoleOptions) UnmarshalJSON(data []byte) error { - var raw map[string]interface{} - err := json.Unmarshal(data, &raw) - if err != nil { - return err - } - - rmap := make(map[string]interface{}) - for k, v := range raw { - switch k { - case MaxSessionTTL: - d, err := time.ParseDuration(v.(string)) - if err != nil { - return err - } - rmap[MaxSessionTTL] = NewDuration(d) - default: - rmap[k] = v - } - } - - *o = rmap - return nil -} - -// Set an option key/value pair. -func (o RoleOptions) Set(key string, value interface{}) { - o[key] = value -} - -// Get returns the option as an interface{}, it is the responsibility of the -// caller to convert to the correct type. -func (o RoleOptions) Get(key string) (interface{}, error) { - valueI, ok := o[key] - if !ok { - return nil, trace.NotFound("key %q not found in options", key) - } - - return valueI, nil -} - -// GetString returns the option as a string or returns an error. -func (o RoleOptions) GetString(key string) (string, error) { - valueI, ok := o[key] - if !ok { - return "", trace.NotFound("key %q not found in options", key) - } - - value, ok := valueI.(string) - if !ok { - return "", trace.BadParameter("type %T for key %q is not a string", valueI, key) - } - - return value, nil -} - -// GetBoolean returns the option as a bool or returns an error. -func (o RoleOptions) GetBoolean(key string) (bool, error) { - valueI, ok := o[key] - if !ok { - return false, trace.NotFound("key %q not found in options", key) - } +// RoleOptions is a set of role options +type RoleOptions struct { + // ForwardAgent is SSH agent forwarding. + ForwardAgent Bool `json:"forward_agent"` - value, ok := valueI.(bool) - if !ok { - return false, trace.BadParameter("type %T for key %q is not a bool", valueI, key) - } + // MaxSessionTTL defines how long a SSH session can last for. + MaxSessionTTL Duration `json:"max_session_ttl"` - return value, nil -} + // PortForwarding defines if the certificate will have "permit-port-forwarding" + // in the certificate. PortForwarding is "yes" if not set, + // that's why this is a pointer + PortForwarding *Bool `json:"port_forwarding,omitempty"` -// GetDuration returns the option as a services.Duration or returns an error. -func (o RoleOptions) GetDuration(key string) (Duration, error) { - valueI, ok := o[key] - if !ok { - return NewDuration(defaults.MinCertDuration), trace.NotFound("key %q not found in options", key) - } + // CertificateFormat defines the format of the user certificate to allow + // compatibility with older versions of OpenSSH. + CertificateFormat string `json:"cert_format"` - value, ok := valueI.(Duration) - if !ok { - return NewDuration(defaults.MinCertDuration), trace.BadParameter("type %T for key %q is not a Duration", valueI, key) - } + // ClientIdleTimeout sets disconnect clients on idle timeout behavior, + // if set to 0 means do not disconnect, otherwise is set to the idle + // duration. + ClientIdleTimeout Duration `json:"client_idle_timeout"` - return value, nil + // DisconnectExpiredCert sets disconnect clients on expired certificates. + DisconnectExpiredCert Bool `json:"disconnect_expired_cert"` } // Equals checks if all the key/values in the RoleOptions map match. func (o RoleOptions) Equals(other RoleOptions) bool { - return utils.InterfaceMapsEqual(o, other) + return (o.ForwardAgent.Value() == other.ForwardAgent.Value() && + o.MaxSessionTTL.Value() == other.MaxSessionTTL.Value() && + BoolOption(o.PortForwarding).Value() == BoolOption(other.PortForwarding).Value() && + o.CertificateFormat == other.CertificateFormat && + o.ClientIdleTimeout.Value() == other.ClientIdleTimeout.Value() && + o.DisconnectExpiredCert.Value() == other.DisconnectExpiredCert.Value()) } // RoleConditions is a set of conditions that must all match to be allowed or @@ -1212,7 +1128,7 @@ func (r *RoleV2) CheckAndSetDefaults() error { r.Spec.MaxSessionTTL.Duration = defaults.MaxCertDuration } if r.Spec.MaxSessionTTL.Duration < defaults.MinCertDuration { - return trace.BadParameter("maximum session TTL can not be less than") + return trace.BadParameter("maximum session TTL can not be less than %v", defaults.MinCertDuration) } if r.Spec.Namespaces == nil { r.Spec.Namespaces = []string{defaults.Namespace} @@ -1255,7 +1171,7 @@ func (r *RoleV2) V3() *RoleV3 { Options: RoleOptions{ CertificateFormat: teleport.CertificateFormatStandard, MaxSessionTTL: r.GetMaxSessionTTL(), - PortForwarding: true, + PortForwarding: NewBoolOption(true), }, Allow: RoleConditions{ Logins: r.GetLogins(), @@ -1268,7 +1184,7 @@ func (r *RoleV2) V3() *RoleV3 { // translate old v2 agent forwarding to a v3 option if r.CanForwardAgent() { - role.Spec.Options[ForwardAgent] = true + role.Spec.Options.ForwardAgent = NewBool(true) } // translate old v2 resources to v3 rules @@ -1349,6 +1265,15 @@ type AccessChecker interface { // for this role set, otherwise it returns ttl unchanged AdjustSessionTTL(ttl time.Duration) time.Duration + // AdjustClientIdleTimeout adjusts requested idle timeout + // to the lowest max allowed timeout, the most restricive + // option will be picked + AdjustClientIdleTimeout(ttl time.Duration) time.Duration + + // AdjustDisconnectExpiredCert adjusts the value based on the role set + // the most restrictive option will be picked + AdjustDisconnectExpiredCert(disconnect bool) bool + // CheckAgentForward checks if the role can request agent forward for this // user. CheckAgentForward(login string) error @@ -1503,18 +1428,52 @@ func (set RoleSet) HasRole(role string) bool { } // AdjustSessionTTL will reduce the requested ttl to lowest max allowed TTL -// for this role set, otherwise it returns ttl unchanges +// for this role set, otherwise it returns ttl unchanged func (set RoleSet) AdjustSessionTTL(ttl time.Duration) time.Duration { for _, role := range set { - maxSessionTTL, err := role.GetOptions().GetDuration(MaxSessionTTL) - if err != nil { + maxSessionTTL := role.GetOptions().MaxSessionTTL.Value() + if maxSessionTTL != 0 && ttl > maxSessionTTL { + ttl = maxSessionTTL + } + } + return ttl +} + +// AdjustClientIdleTimeout adjusts requested idle timeout +// to the lowest max allowed timeout, the most restrictive +// option will be picked, negative values will be assumed as 0 +func (set RoleSet) AdjustClientIdleTimeout(timeout time.Duration) time.Duration { + if timeout < 0 { + timeout = 0 + } + for _, role := range set { + roleTimeout := role.GetOptions().ClientIdleTimeout + // 0 means not set, so it can't be most restrictive, disregard it too + if roleTimeout.Duration <= 0 { continue } - if ttl > maxSessionTTL.Duration { - ttl = maxSessionTTL.Duration + switch { + // in case if timeout is 0, means that incoming value + // does not restrict the idle timeout, pick any other value + // set by the role + case timeout == 0: + timeout = roleTimeout.Duration + case roleTimeout.Duration < timeout: + timeout = roleTimeout.Duration } } - return ttl + return timeout +} + +// AdjustDisconnectExpiredCert adjusts the value based on the role set +// the most restrictive option will be picked +func (set RoleSet) AdjustDisconnectExpiredCert(disconnect bool) bool { + for _, role := range set { + if role.GetOptions().DisconnectExpiredCert.Value() { + disconnect = true + } + } + return disconnect } // CheckKubeGroups check if role can login into kubernetes @@ -1523,13 +1482,9 @@ func (set RoleSet) CheckKubeGroups(ttl time.Duration) ([]string, error) { groups := make(map[string]bool) var matchedTTL bool for _, role := range set { - maxSessionTTL, err := role.GetOptions().GetDuration(MaxSessionTTL) - if err != nil { - return nil, trace.Wrap(err) - } - if ttl <= maxSessionTTL.Duration && maxSessionTTL.Duration != 0 { + maxSessionTTL := role.GetOptions().MaxSessionTTL.Value() + if ttl <= maxSessionTTL && maxSessionTTL != 0 { matchedTTL = true - for _, group := range role.GetKubeGroups(Allow) { groups[group] = true } @@ -1554,11 +1509,8 @@ func (set RoleSet) CheckLoginDuration(ttl time.Duration) ([]string, error) { logins := make(map[string]bool) var matchedTTL bool for _, role := range set { - maxSessionTTL, err := role.GetOptions().GetDuration(MaxSessionTTL) - if err != nil { - return nil, trace.Wrap(err) - } - if ttl <= maxSessionTTL.Duration && maxSessionTTL.Duration != 0 { + maxSessionTTL := role.GetOptions().MaxSessionTTL.Value() + if ttl <= maxSessionTTL && maxSessionTTL != 0 { matchedTTL = true for _, login := range role.GetLogins(Allow) { @@ -1622,11 +1574,7 @@ func (set RoleSet) CheckAccessToServer(login string, s Server) error { // CanForwardAgents returns true if role set allows forwarding agents. func (set RoleSet) CanForwardAgents() bool { for _, role := range set { - forwardAgent, err := role.GetOptions().GetBoolean(ForwardAgent) - if err != nil { - return false - } - if forwardAgent == true { + if role.GetOptions().ForwardAgent.Value() { return true } } @@ -1636,11 +1584,7 @@ func (set RoleSet) CanForwardAgents() bool { // CanPortForward returns true if a role in the RoleSet allows port forwarding. func (set RoleSet) CanPortForward() bool { for _, role := range set { - portForwarding, err := role.GetOptions().GetBoolean(PortForwarding) - if err != nil { - return false - } - if portForwarding == true { + if BoolOption(role.GetOptions().PortForwarding).Value() { return true } } @@ -1655,8 +1599,8 @@ func (set RoleSet) CertificateFormat() string { for _, role := range set { // get the certificate format for each individual role. if a role does not // have a certificate format (like implicit roles) skip over it - certificateFormat, err := role.GetOptions().GetString(CertificateFormat) - if err != nil { + certificateFormat := role.GetOptions().CertificateFormat + if certificateFormat == "" { continue } @@ -1697,11 +1641,7 @@ func (set RoleSet) CheckAgentForward(login string) error { // in the first place. for _, role := range set { for _, l := range role.GetLogins(Allow) { - forwardAgent, err := role.GetOptions().GetBoolean(ForwardAgent) - if err != nil { - return trace.AccessDenied("unable to parse ForwardAgent: %v", err) - } - if forwardAgent && l == login { + if role.GetOptions().ForwardAgent.Value() && l == login { return nil } } @@ -1780,6 +1720,87 @@ func NewDuration(d time.Duration) Duration { return Duration{Duration: d} } +// NewBool returns Bool struct based on bool value +func NewBool(b bool) Bool { + return Bool{bool: b} +} + +// NewBoolOption returns Bool struct based on bool value +func NewBoolOption(b bool) *Bool { + return &Bool{bool: b} +} + +// BoolOption converts bool pointer to Bool value +// returns equivalent of false if not set +func BoolOption(v *Bool) Bool { + if v == nil { + return Bool{} + } + return *v +} + +// Bool is a wrapper around boolean values +type Bool struct { + bool +} + +// Value returns boolean value of the wrapper +func (b Bool) Value() bool { + return b.bool +} + +// MarshalJSON marshals Duration to string +func (b Bool) MarshalJSON() ([]byte, error) { + return json.Marshal(fmt.Sprintf("%t", b.bool)) +} + +// UnmarshalJSON unmarshals JSON from string or bool, +// in case if value is missing or not recognized, defaults to false +func (b *Bool) UnmarshalJSON(data []byte) error { + if len(data) == 0 { + return nil + } + // check if it's a bool variable + if err := json.Unmarshal(data, &b.bool); err == nil { + return nil + } + // also support string variables + var stringVar string + if err := json.Unmarshal(data, &stringVar); err != nil { + return trace.Wrap(err) + } + v, err := utils.ParseBool(stringVar) + if err != nil { + b.bool = false + return nil + } + b.bool = v + return nil +} + +// MarshalYAML marshals bool into yaml value +func (b Bool) MarshalYAML() (interface{}, error) { + return b.bool, nil +} + +func (b *Bool) UnmarshalYAML(unmarshal func(interface{}) error) error { + var boolVar bool + if err := unmarshal(&boolVar); err == nil { + b.bool = boolVar + } + var stringVar string + if err := unmarshal(&stringVar); err != nil { + return trace.Wrap(err) + } + v, err := utils.ParseBool(stringVar) + if err != nil { + b.bool = v + return nil + } + b.bool = v + return nil +} + // Duration is a wrapper around duration to set up custom marshal/unmarshal type Duration struct { time.Duration @@ -1790,6 +1811,11 @@ func (d Duration) MarshalJSON() ([]byte, error) { return json.Marshal(fmt.Sprintf("%v", d.Duration)) } +// Value returns time.Duration value of this wrapper +func (d Duration) Value() time.Duration { + return d.Duration +} + // UnmarshalJSON marshals Duration to string func (d *Duration) UnmarshalJSON(data []byte) error { if len(data) == 0 { @@ -1799,24 +1825,38 @@ func (d *Duration) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &stringVar); err != nil { return trace.Wrap(err) } - out, err := time.ParseDuration(stringVar) - if err != nil { - return trace.BadParameter(err.Error()) + if stringVar == teleport.DurationNever { + d.Duration = 0 + } else { + out, err := time.ParseDuration(stringVar) + if err != nil { + return trace.BadParameter(err.Error()) + } + d.Duration = out } - d.Duration = out return nil } +// MarshalYAML marshals duration into YAML value, +// encodes it as a string in format "1m" +func (d Duration) MarshalYAML() (interface{}, error) { + return fmt.Sprintf("%v", d.Duration), nil +} + func (d *Duration) UnmarshalYAML(unmarshal func(interface{}) error) error { var stringVar string if err := unmarshal(&stringVar); err != nil { return trace.Wrap(err) } - out, err := time.ParseDuration(stringVar) - if err != nil { - return trace.BadParameter(err.Error()) + if stringVar == teleport.DurationNever { + d.Duration = 0 + } else { + out, err := time.ParseDuration(stringVar) + if err != nil { + return trace.BadParameter(err.Error()) + } + d.Duration = out } - d.Duration = out return nil } diff --git a/lib/services/role_test.go b/lib/services/role_test.go index 1659153cec3..55c98974e20 100644 --- a/lib/services/role_test.go +++ b/lib/services/role_test.go @@ -183,19 +183,19 @@ func (s *RoleSuite) TestRoleParse(c *C) { }, { name: "role with no spec still gets defaults", - in: `{"kind": "role", "version": "v3", "metadata": {"name": "name1"}, "spec": {}}`, + in: `{"kind": "role", "version": "v3", "metadata": {"name": "defrole"}, "spec": {}}`, role: RoleV3{ Kind: KindRole, Version: V3, Metadata: Metadata{ - Name: "name1", + Name: "defrole", Namespace: defaults.Namespace, }, Spec: RoleSpecV3{ Options: RoleOptions{ CertificateFormat: teleport.CertificateFormatStandard, MaxSessionTTL: NewDuration(defaults.MaxCertDuration), - PortForwarding: true, + PortForwarding: NewBoolOption(true), }, Allow: RoleConditions{ NodeLabels: map[string]string{Wildcard: Wildcard}, @@ -218,7 +218,9 @@ func (s *RoleSuite) TestRoleParse(c *C) { "options": { "cert_format": "standard", "max_session_ttl": "20h", - "port_forwarding": true + "port_forwarding": true, + "client_idle_timeout": "17m", + "disconnect_expired_cert": "yes" }, "allow": { "node_labels": {"a": "b"}, @@ -248,9 +250,83 @@ func (s *RoleSuite) TestRoleParse(c *C) { }, Spec: RoleSpecV3{ Options: RoleOptions{ - CertificateFormat: teleport.CertificateFormatStandard, - MaxSessionTTL: NewDuration(20 * time.Hour), - PortForwarding: true, + CertificateFormat: teleport.CertificateFormatStandard, + MaxSessionTTL: NewDuration(20 * time.Hour), + PortForwarding: NewBoolOption(true), + ClientIdleTimeout: NewDuration(17 * time.Minute), + DisconnectExpiredCert: NewBool(true), + }, + Allow: RoleConditions{ + NodeLabels: map[string]string{"a": "b"}, + Namespaces: []string{"default"}, + Rules: []Rule{ + Rule{ + Resources: []string{KindRole}, + Verbs: []string{VerbRead, VerbList}, + Where: "contains(user.spec.traits[\"groups\"], \"prod\")", + Actions: []string{ + "log(\"info\", \"log entry\")", + }, + }, + }, + }, + Deny: RoleConditions{ + Namespaces: []string{defaults.Namespace}, + Logins: []string{"c"}, + }, + }, + }, + error: nil, + }, + { + name: "alternative options forma", + in: `{ + "kind": "role", + "version": "v3", + "metadata": {"name": "name1"}, + "spec": { + "options": { + "cert_format": "standard", + "max_session_ttl": "20h", + "port_forwarding": "yes", + "forward_agent": "yes", + "client_idle_timeout": "never", + "disconnect_expired_cert": "no" + }, + "allow": { + "node_labels": {"a": "b"}, + "namespaces": ["default"], + "rules": [ + { + "resources": ["role"], + "verbs": ["read", "list"], + "where": "contains(user.spec.traits[\"groups\"], \"prod\")", + "actions": [ + "log(\"info\", \"log entry\")" + ] + } + ] + }, + "deny": { + "logins": ["c"] + } + } + }`, + role: RoleV3{ + Kind: KindRole, + Version: V3, + Metadata: Metadata{ + Name: "name1", + Namespace: defaults.Namespace, + }, + Spec: RoleSpecV3{ + Options: RoleOptions{ + CertificateFormat: teleport.CertificateFormatStandard, + ForwardAgent: NewBool(true), + MaxSessionTTL: NewDuration(20 * time.Hour), + PortForwarding: NewBoolOption(true), + ClientIdleTimeout: NewDuration(0), + DisconnectExpiredCert: NewBool(false), }, Allow: RoleConditions{ NodeLabels: map[string]string{"a": "b"}, @@ -293,7 +369,7 @@ func (s *RoleSuite) TestRoleParse(c *C) { role2, err := UnmarshalRole(out) c.Assert(err, IsNil, comment) - c.Assert(*role2, DeepEquals, tc.role, comment) + fixtures.DeepCompare(c, *role2, tc.role) } } } diff --git a/lib/services/suite/suite.go b/lib/services/suite/suite.go index 095bde88e23..e11187475f6 100644 --- a/lib/services/suite/suite.go +++ b/lib/services/suite/suite.go @@ -429,7 +429,9 @@ func (s *ServicesTestSuite) RolesCRUD(c *C) { }, Spec: services.RoleSpecV3{ Options: services.RoleOptions{ - services.MaxSessionTTL: services.Duration{Duration: time.Hour}, + MaxSessionTTL: services.Duration{Duration: time.Hour}, + PortForwarding: services.NewBoolOption(true), + CertificateFormat: teleport.CertificateFormatStandard, }, Allow: services.RoleConditions{ Logins: []string{"root", "bob"}, @@ -448,7 +450,7 @@ func (s *ServicesTestSuite) RolesCRUD(c *C) { c.Assert(err, IsNil) rout, err := s.Access.GetRole(role.Metadata.Name) c.Assert(err, IsNil) - c.Assert(rout, DeepEquals, &role) + fixtures.DeepCompare(c, rout, &role) role.Spec.Allow.Logins = []string{"bob"} err = s.Access.UpsertRole(&role, backend.Forever) diff --git a/lib/srv/authhandlers.go b/lib/srv/authhandlers.go index 17a62899058..11058652d37 100644 --- a/lib/srv/authhandlers.go +++ b/lib/srv/authhandlers.go @@ -21,6 +21,7 @@ import ( "net" "os" "os/user" + "time" "golang.org/x/crypto/ssh" @@ -71,6 +72,9 @@ func (h *AuthHandlers) CreateIdentityContext(sconn *ssh.ServerConn) (IdentityCon if err != nil { return IdentityContext{}, trace.Wrap(err) } + if certificate.ValidBefore != 0 { + identity.CertValidBefore = time.Unix(int64(certificate.ValidBefore), 0) + } certAuthority, err := h.authorityForCert(services.UserCA, certificate.SignatureKey) if err != nil { diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index b82a2867f5a..649b3c52118 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -17,10 +17,12 @@ limitations under the License. package srv import ( + "context" "fmt" "io" "sync" "sync/atomic" + "time" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" @@ -33,8 +35,9 @@ import ( rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" - "github.com/gravitational/trace" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" log "github.com/sirupsen/logrus" ) @@ -75,6 +78,9 @@ type Server interface { // GetPAM returns PAM configuration for this server. GetPAM() (*pam.Config, error) + + // GetClock returns a clock setup for the server + GetClock() clockwork.Clock } // IdentityContext holds all identity information associated with the user @@ -96,6 +102,10 @@ type IdentityContext struct { // RoleSet is the roles this Teleport user is associated with. RoleSet is // used to check RBAC permissions. RoleSet services.RoleSet + + // CertValidBefore is set to the expiry time of a certificate, or + // empty, if cert does not expire + CertValidBefore time.Time } // GetCertificate parses the SSH certificate bytes and returns a *ssh.Certificate. @@ -183,6 +193,23 @@ type ServerContext struct { // RemoteSession holds a SSH session to a remote server. Only used by the // recording proxy. RemoteSession *ssh.Session + + // clientLastActive records the last time there was activity from the client + clientLastActive time.Time + + // disconnectExpiredCert is set to time when/if the certificate should + // be disconnected, set to empty if no disconect is necessary + disconnectExpiredCert time.Time + + // clientIdleTimeout is set to the timeout on + // on client inactivity, set to 0 if not setup + clientIdleTimeout time.Duration + + // cancelContext signals closure to all outstanding operations + cancelContext context.Context + + // cancel is called whenever server context is closed + cancel context.CancelFunc } // NewServerContext creates a new *ServerContext which is used to pass and @@ -193,6 +220,8 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity return nil, trace.Wrap(err) } + cancelContext, cancel := context.WithCancel(context.TODO()) + ctx := &ServerContext{ id: int(atomic.AddInt32(&ctxID, int32(1))), env: make(map[string]string), @@ -203,19 +232,38 @@ func NewServerContext(srv Server, conn *ssh.ServerConn, identityContext Identity ClusterName: conn.Permissions.Extensions[utils.CertTeleportClusterName], ClusterConfig: clusterConfig, Identity: identityContext, + clientIdleTimeout: identityContext.RoleSet.AdjustClientIdleTimeout(clusterConfig.GetClientIdleTimeout()), + cancelContext: cancelContext, + cancel: cancel, } + disconnectExpiredCert := identityContext.RoleSet.AdjustDisconnectExpiredCert(clusterConfig.GetDisconnectExpiredCert()) + if !identityContext.CertValidBefore.IsZero() && disconnectExpiredCert { + ctx.disconnectExpiredCert = identityContext.CertValidBefore + } + + fields := log.Fields{ + "local": conn.LocalAddr(), + "remote": conn.RemoteAddr(), + "login": ctx.Identity.Login, + "teleportUser": ctx.Identity.TeleportUser, + "id": ctx.id, + } + if !ctx.disconnectExpiredCert.IsZero() { + fields["cert"] = ctx.disconnectExpiredCert + } + if ctx.clientIdleTimeout != 0 { + fields["idle"] = ctx.clientIdleTimeout + } ctx.Entry = log.WithFields(log.Fields{ - trace.Component: srv.Component(), - trace.ComponentFields: log.Fields{ - "local": conn.LocalAddr(), - "remote": conn.RemoteAddr(), - "login": ctx.Identity.Login, - "teleportUser": ctx.Identity.TeleportUser, - "id": ctx.id, - }, + trace.Component: srv.Component(), + trace.ComponentFields: fields, }) + if !ctx.disconnectExpiredCert.IsZero() || ctx.clientIdleTimeout != 0 { + go ctx.periodicCheckDisconnect() + } + return ctx, nil } @@ -265,6 +313,86 @@ func (c *ServerContext) CreateOrJoinSession(reg *SessionRegistry) error { return nil } +func (c *ServerContext) periodicCheckDisconnect() { + var certTime <-chan time.Time + if !c.disconnectExpiredCert.IsZero() { + t := time.NewTimer(c.disconnectExpiredCert.Sub(c.srv.GetClock().Now().UTC())) + defer t.Stop() + certTime = t.C + } + + var idleTimer *time.Timer + var idleTime <-chan time.Time + if c.clientIdleTimeout != 0 { + idleTimer = time.NewTimer(c.clientIdleTimeout) + idleTime = idleTimer.C + } + + for { + select { + // certificate has expired, disconnect + case <-certTime: + event := events.EventFields{ + events.EventType: events.ClientDisconnectEvent, + events.EventLogin: c.Identity.Login, + events.EventUser: c.Identity.TeleportUser, + events.LocalAddr: c.Conn.LocalAddr().String(), + events.RemoteAddr: c.Conn.RemoteAddr().String(), + events.SessionServerID: c.srv.ID(), + events.Reason: fmt.Sprintf("client certificate expired at %v", c.clientLastActive), + } + c.srv.EmitAuditEvent(events.ClientDisconnectEvent, event) + c.Debugf("Disconnecting client: %v", event[events.Reason]) + c.Conn.Close() + return + case <-idleTime: + now := c.srv.GetClock().Now() + clientLastActive := c.GetClientLastActive() + c.Debugf("client last active %v, client idle timeout %v", clientLastActive, c.clientIdleTimeout) + if now.Sub(clientLastActive) >= c.clientIdleTimeout { + event := events.EventFields{ + events.EventLogin: c.Identity.Login, + events.EventUser: c.Identity.TeleportUser, + events.LocalAddr: c.Conn.LocalAddr().String(), + events.RemoteAddr: c.Conn.RemoteAddr().String(), + events.SessionServerID: c.srv.ID(), + } + if clientLastActive.IsZero() { + event[events.Reason] = "client reported no activity" + } else { + event[events.Reason] = fmt.Sprintf("client is idle for %v, exceeded idle timeout of %v", + now.Sub(clientLastActive), c.clientIdleTimeout) + } + c.Debugf("Disconnecting client: %v", event[events.Reason]) + c.srv.EmitAuditEvent(events.ClientDisconnectEvent, event) + c.Conn.Close() + return + } + c.Debugf("Next check in %v", c.clientIdleTimeout-now.Sub(clientLastActive)) + idleTimer = time.NewTimer(c.clientIdleTimeout - now.Sub(clientLastActive)) + idleTime = idleTimer.C + case <-c.cancelContext.Done(): + c.Debugf("Releasing associated resources - context has been closed.") + return + } + } +} + +// GetClientLastActive returns time when client was last active +func (c *ServerContext) GetClientLastActive() time.Time { + c.RLock() + defer c.RUnlock() + return c.clientLastActive +} + +// UpdateClientActivity sets last recorded client activity associated with this context +// either channel or session +func (c *ServerContext) UpdateClientActivity() { + c.Lock() + defer c.Unlock() + c.clientLastActive = c.srv.GetClock().Now().UTC() +} + // AddCloser adds any closer in ctx that will be called // whenever server closes session channel func (c *ServerContext) AddCloser(closer io.Closer) { @@ -348,6 +476,7 @@ func (c *ServerContext) takeClosers() []io.Closer { } func (c *ServerContext) Close() error { + c.cancel() return closeAll(c.takeClosers()...) } @@ -425,3 +554,24 @@ type closerFunc func() error func (f closerFunc) Close() error { return f() } + +// NewTrackingReader returns a new instance of +// activity tracking reader. +func NewTrackingReader(ctx *ServerContext, r io.Reader) *TrackingReader { + return &TrackingReader{ctx: ctx, r: r} +} + +// TrackingReader wraps the writer +// and every time write occurs, updates +// the activity in the server context +type TrackingReader struct { + ctx *ServerContext + r io.Reader +} + +// Read passes the read through to internal +// reader, and updates activity of the server context +func (a *TrackingReader) Read(b []byte) (int, error) { + a.ctx.UpdateClientActivity() + return a.r.Read(b) +} diff --git a/lib/srv/forward/sshserver.go b/lib/srv/forward/sshserver.go index 5bf8704e894..d53dc9e3c22 100644 --- a/lib/srv/forward/sshserver.go +++ b/lib/srv/forward/sshserver.go @@ -39,6 +39,7 @@ import ( "github.com/gravitational/teleport/lib/utils/proxy" "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" "github.com/pborman/uuid" "github.com/sirupsen/logrus" ) @@ -125,6 +126,8 @@ type Server struct { // server is closing and all blocking goroutines should unblock. closeContext context.Context closeCancel context.CancelFunc + + clock clockwork.Clock } // ServerConfig is the configuration needed to create an instance of a Server. @@ -150,6 +153,9 @@ type ServerConfig struct { // DataDir is a local data directory used for local server storage DataDir string + + // Clock is an optoinal clock to override default real time clock + Clock clockwork.Clock } // CheckDefaults makes sure all required parameters are passed in. @@ -175,6 +181,9 @@ func (s *ServerConfig) CheckDefaults() error { if s.HostCertificate == nil { return trace.BadParameter("host certificate required to act on behalf of remote host") } + if s.Clock == nil { + s.Clock = clockwork.NewRealClock() + } return nil } @@ -214,6 +223,7 @@ func New(c ServerConfig) (*Server, error) { authService: c.AuthClient, sessionServer: c.AuthClient, dataDir: c.DataDir, + clock: c.Clock, } // Set the ciphers, KEX, and MACs that the in-memory server will send to the @@ -321,6 +331,11 @@ func (s *Server) Dial() (net.Conn, error) { return s.clientConn, nil } +// GetClock returns server clock implementation +func (s *Server) GetClock() clockwork.Clock { + return s.clock +} + func (s *Server) Serve() { config := &ssh.ServerConfig{} @@ -635,6 +650,9 @@ func (s *Server) handleDirectTCPIPRequest(ch ssh.Channel, req *sshutils.DirectTC func (s *Server) handleSessionRequests(ch ssh.Channel, in <-chan *ssh.Request) { // Create context for this channel. This context will be closed when the // session request is complete. + // There is no need for the forwarding server to initiate disconnects, + // based on teleport business logic, because this logic is already + // done on the server's terminating side. ctx, err := srv.NewServerContext(s, s.sconn, s.identityContext) if err != nil { ctx.Errorf("Unable to create connection context: %v.", err) diff --git a/lib/srv/regular/proxy.go b/lib/srv/regular/proxy.go index 79507474319..1b97d053c3f 100644 --- a/lib/srv/regular/proxy.go +++ b/lib/srv/regular/proxy.go @@ -187,16 +187,16 @@ func (t *proxySubsys) Start(sconn *ssh.ServerConn, ch ssh.Channel, req *ssh.Requ site = sites[0] t.log.Debugf("Cluster not specified. connecting to default='%s'", site.GetName()) } - return t.proxyToHost(site, clientAddr, ch) + return t.proxyToHost(ctx, site, clientAddr, ch) } // connect to a site's auth server: - return t.proxyToSite(site, clientAddr, ch) + return t.proxyToSite(ctx, site, clientAddr, ch) } // proxyToSite establishes a proxy connection from the connected SSH client to the // auth server of the requested remote site func (t *proxySubsys) proxyToSite( - site reversetunnel.RemoteSite, remoteAddr net.Addr, ch ssh.Channel) error { + ctx *srv.ServerContext, site reversetunnel.RemoteSite, remoteAddr net.Addr, ch ssh.Channel) error { conn, err := site.DialAuthServer() if err != nil { @@ -218,7 +218,7 @@ func (t *proxySubsys) proxyToSite( t.close(err) }() defer conn.Close() - _, err = io.Copy(conn, ch) + _, err = io.Copy(conn, srv.NewTrackingReader(ctx, ch)) }() @@ -228,7 +228,7 @@ func (t *proxySubsys) proxyToSite( // proxyToHost establishes a proxy connection from the connected SSH client to the // requested remote node (t.host:t.port) via the given site func (t *proxySubsys) proxyToHost( - site reversetunnel.RemoteSite, remoteAddr net.Addr, ch ssh.Channel) error { + ctx *srv.ServerContext, site reversetunnel.RemoteSite, remoteAddr net.Addr, ch ssh.Channel) error { // // first, lets fetch a list of servers at the given site. this allows us to // match the given "host name" against node configuration (their 'nodename' setting) @@ -324,7 +324,7 @@ func (t *proxySubsys) proxyToHost( t.close(err) }() defer conn.Close() - _, err = io.Copy(conn, ch) + _, err = io.Copy(conn, srv.NewTrackingReader(ctx, ch)) }() return nil diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index e0a5c66030c..84e1708d62f 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -128,6 +128,11 @@ type Server struct { dataDir string } +// GetClock returns server clock implementation +func (s *Server) GetClock() clockwork.Clock { + return s.clock +} + // GetDataDir returns server data dir func (s *Server) GetDataDir() string { return s.dataDir @@ -822,7 +827,7 @@ func (s *Server) handleDirectTCPIPRequest(sconn *ssh.ServerConn, identityContext wg.Add(1) go func() { defer wg.Done() - io.Copy(conn, ch) + io.Copy(conn, srv.NewTrackingReader(ctx, ch)) conn.Close() }() wg.Wait() diff --git a/lib/srv/regular/sshserver_test.go b/lib/srv/regular/sshserver_test.go index e1e54c2941b..4935627e9e7 100644 --- a/lib/srv/regular/sshserver_test.go +++ b/lib/srv/regular/sshserver_test.go @@ -202,7 +202,7 @@ func (s *SrvSuite) TestAgentForwardPermission(c *C) { role, err := s.server.Auth().GetRole(roleName) c.Assert(err, IsNil) roleOptions := role.GetOptions() - roleOptions.Set(services.ForwardAgent, false) + roleOptions.ForwardAgent = services.NewBool(false) role.SetOptions(roleOptions) err = s.server.Auth().UpsertRole(role, backend.Forever) c.Assert(err, IsNil) @@ -228,7 +228,7 @@ func (s *SrvSuite) TestAgentForward(c *C) { role, err := s.server.Auth().GetRole(roleName) c.Assert(err, IsNil) roleOptions := role.GetOptions() - roleOptions.Set(services.ForwardAgent, true) + roleOptions.ForwardAgent = services.NewBool(true) role.SetOptions(roleOptions) err = s.server.Auth().UpsertRole(role, backend.Forever) c.Assert(err, IsNil) diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 562c697f46a..d39c222bfaa 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -988,6 +988,7 @@ func (p *party) getLastActive() time.Time { func (p *party) Read(bytes []byte) (int, error) { p.updateActivity() + p.ctx.UpdateClientActivity() return p.ch.Read(bytes) } diff --git a/lib/utils/buf.go b/lib/utils/buf.go new file mode 100644 index 00000000000..265129c8c31 --- /dev/null +++ b/lib/utils/buf.go @@ -0,0 +1,72 @@ +/* +Copyright 2018 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package utils + +import ( + "bytes" + "io" +) + +// NewSyncBuffer returns new in memory buffer +func NewSyncBuffer() *SyncBuffer { + reader, writer := io.Pipe() + buf := &bytes.Buffer{} + go func() { + io.Copy(buf, reader) + }() + return &SyncBuffer{ + reader: reader, + writer: writer, + buf: buf, + } +} + +// SyncBuffer is in memory bytes buffer that is +// safe for concurrent writes +type SyncBuffer struct { + reader *io.PipeReader + writer *io.PipeWriter + buf *bytes.Buffer +} + +func (b *SyncBuffer) Write(data []byte) (n int, err error) { + return b.writer.Write(data) +} + +// String returns contents of the buffer +// after this call, all writes will fail +func (b *SyncBuffer) String() string { + b.Close() + return b.buf.String() +} + +// Bytes returns contents of the buffer +// after this call, all writes will fail +func (b *SyncBuffer) Bytes() []byte { + b.Close() + return b.buf.Bytes() +} + +// Close closes reads and writes on the buffer +func (b *SyncBuffer) Close() error { + err := b.reader.Close() + err2 := b.writer.Close() + if err != nil { + return err + } + return err2 +} diff --git a/lib/utils/utils.go b/lib/utils/utils.go index 5c31bcc2dc2..67188bd0741 100644 --- a/lib/utils/utils.go +++ b/lib/utils/utils.go @@ -38,10 +38,23 @@ func AsBool(v string) bool { if v == "" { return false } - out, _ := strconv.ParseBool(v) + out, _ := ParseBool(v) return out } +// ParseBool parses string as boolean value, +// returns error in case if value is not recognized +func ParseBool(value string) (bool, error) { + switch strings.ToLower(value) { + case "yes", "yeah", "y", "true", "1", "on": + return true, nil + case "no", "nope", "n", "false", "0", "off": + return false, nil + default: + return false, trace.BadParameter("unsupported value: %q", value) + } +} + // ParseAdvertiseAddress validates advertise address, // makes sure it's not an unreachable or multicast address // returns address split into host and port, port could be empty diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 88cbb084972..11859bd241c 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -331,7 +331,7 @@ func (s *WebSuite) createUser(c *C, user string, login string, pass string, otpS role := services.RoleForUser(teleUser) role.SetLogins(services.Allow, []string{login}) options := role.GetOptions() - options[services.ForwardAgent] = true + options.ForwardAgent = services.NewBool(true) role.SetOptions(options) err = s.server.Auth().UpsertRole(role, backend.Forever) c.Assert(err, IsNil) @@ -438,7 +438,7 @@ func (s *WebSuite) TestSAMLSuccess(c *C) { role, err := services.NewRole(connector.GetAttributesToRoles()[0].Roles[0], services.RoleSpecV3{ Options: services.RoleOptions{ - services.MaxSessionTTL: services.NewDuration(defaults.MaxCertDuration), + MaxSessionTTL: services.NewDuration(defaults.MaxCertDuration), }, Allow: services.RoleConditions{ NodeLabels: map[string]string{services.Wildcard: services.Wildcard},