Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multitenancy, approach #1 #121

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ type Client struct {
uid string
user string
info []byte
env string
publicationsOnce sync.Once
closed bool
authenticated bool
Expand Down Expand Up @@ -1151,6 +1152,7 @@ func (c *Client) connectCmd(cmd *protocol.ConnectRequest, rw *replyWriter) *Disc
c.user = credentials.UserID
c.info = credentials.Info
c.exp = credentials.ExpireAt
c.env = credentials.Env
c.mu.Unlock()
case cmd.Token != "":
var (
Expand All @@ -1170,6 +1172,7 @@ func (c *Client) connectCmd(cmd *protocol.ConnectRequest, rw *replyWriter) *Disc
c.mu.Lock()
c.user = token.UserID
c.exp = token.ExpireAt
c.env = token.Env
c.mu.Unlock()

if len(token.Info) > 0 {
Expand Down Expand Up @@ -1495,6 +1498,11 @@ func (c *Client) validateSubscribeRequest(cmd *protocol.SubscribeRequest, server
return ChannelOptions{}, ErrorPermissionDenied, nil
}

if !serverSide && !c.node.hasValidEnv(channel, c.env) {
c.node.logger.log(newLogEntry(LogLevelInfo, "channel belongs to another environment", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid, "env": c.env}))
return ChannelOptions{}, ErrorPermissionDenied, nil
}

if !chOpts.Anonymous && c.user == "" && !insecure {
c.node.logger.log(newLogEntry(LogLevelInfo, "anonymous user is not allowed to subscribe on channel", map[string]interface{}{"channel": channel, "user": c.user, "client": c.uid}))
return ChannelOptions{}, ErrorPermissionDenied, nil
Expand Down Expand Up @@ -2051,6 +2059,11 @@ func (c *Client) publishCmd(cmd *protocol.PublishRequest) (*clientproto.PublishR

resp := &clientproto.PublishResponse{}

if !c.node.hasValidEnv(ch, c.env) {
resp.Error = ErrorPermissionDenied.toProto()
return resp, nil
}

chOpts, ok := c.node.ChannelOpts(ch)
if !ok {
c.node.logger.log(newLogEntry(LogLevelInfo, "attempt to publish to non-existing namespace", map[string]interface{}{"channel": ch, "user": c.user, "client": c.uid}))
Expand Down
9 changes: 9 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ type Config struct {
// Only users with user ID defined will subscribe to personal channels, anonymous
// users are ignored.
UserSubscribeToPersonal bool
// ChannelEnvDelimiter string. Must contain two ascii symbols. If set to "[]" then a
// channel with env set should look like "[env]$public:news"
ChannelEnvDelimiters string
}

// Validate validates config and returns error if problems found
Expand Down Expand Up @@ -135,6 +138,12 @@ func (c *Config) Validate() error {
return fmt.Errorf("namespace for user personal channel not found: %s", personalChannelNamespace)
}

asciiRegexp := regexp.MustCompile("\\w+")

if c.ChannelEnvDelimiters != "" && len(c.ChannelEnvDelimiters) != 2 && !asciiRegexp.Match([]byte(c.ChannelEnvDelimiters)) {
return errors.New("invalid channel env delimiters")
}

return nil
}

Expand Down
18 changes: 18 additions & 0 deletions config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,24 @@ func TestConfigValidateInvalidNamespaceName(t *testing.T) {
require.Error(t, err)
}

func TestConfigValidateInvalidEnvDelimiters(t *testing.T) {
c := DefaultConfig
err := c.Validate()
require.NoError(t, err)
c.ChannelEnvDelimiters = "}"
err = c.Validate()
require.Error(t, err)
c.ChannelEnvDelimiters = "ЬЬ"
err = c.Validate()
require.Error(t, err)
c.ChannelEnvDelimiters = "|||"
err = c.Validate()
require.Error(t, err)
c.ChannelEnvDelimiters = "{}"
err = c.Validate()
require.NoError(t, err)
}

func TestConfigValidateDuplicateNamespaceName(t *testing.T) {
c := DefaultConfig
c.Namespaces = []ChannelNamespace{
Expand Down
3 changes: 3 additions & 0 deletions credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ import "context"

// Credentials allows to authenticate connection when set into context.
type Credentials struct {
// Env used to set connection environment. This automatically enables using
// environment prefix for channels.
Env string
// UserID tells library an ID of connecting user.
UserID string
// ExpireAt allows to set time in future when connection must be validated.
Expand Down
30 changes: 30 additions & 0 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,7 @@ func (n *Node) Disconnect(user string, opts ...DisconnectOption) error {

// namespaceName returns namespace name from channel if exists.
func (n *Node) namespaceName(ch string) string {
ch, _ = n.stripEnv(ch)
cTrim := strings.TrimPrefix(ch, n.config.ChannelPrivatePrefix)
if n.config.ChannelNamespaceBoundary != "" && strings.Contains(cTrim, n.config.ChannelNamespaceBoundary) {
parts := strings.SplitN(cTrim, n.config.ChannelNamespaceBoundary, 2)
Expand Down Expand Up @@ -905,6 +906,35 @@ func (n *Node) privateChannel(ch string) bool {
return strings.HasPrefix(ch, n.config.ChannelPrivatePrefix)
}

func (n *Node) stripEnv(ch string) (string, bool) {
if n.config.ChannelEnvDelimiters == "" {
return ch, false
}
envStartSymbol := string(n.config.ChannelEnvDelimiters[0])
if strings.HasPrefix(ch, envStartSymbol) {
envEndSymbol := string(n.config.ChannelEnvDelimiters[1])
index := strings.Index(ch[1:], envEndSymbol)
if index > 0 {
return ch[index+2:], true
}
}
return ch, false
}

func (n *Node) hasValidEnv(ch string, env string) bool {
if env == "" {
return true
}
n.mu.RLock()
defer n.mu.RUnlock()
envStartSymbol := string(n.config.ChannelEnvDelimiters[0])
if !strings.HasPrefix(ch, envStartSymbol) || len(ch) < len(env)+2 {
return false
}
envEndSymbol := string(n.config.ChannelEnvDelimiters[1])
return strings.HasPrefix(ch, envStartSymbol+env+envEndSymbol)
}

// userAllowed checks if user can subscribe on channel - as channel
// can contain special part in the end to indicate which users allowed
// to subscribe on it.
Expand Down
96 changes: 94 additions & 2 deletions node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,59 @@ func TestUserAllowed(t *testing.T) {
require.False(t, node.userAllowed("channel#1,2", "3"))
}

func TestHasValidEnv(t *testing.T) {
node := nodeWithTestEngine()
defer func() { _ = node.Shutdown(context.Background()) }()

conf := node.Config()
conf.ChannelEnvDelimiters = "[]"
err := node.Reload(conf)
require.NoError(t, err)

require.True(t, node.hasValidEnv("[test]channel", ""))
require.True(t, node.hasValidEnv("[test]channel", "test"))
require.False(t, node.hasValidEnv("[test]channel", "test2"))
require.False(t, node.hasValidEnv("ss[test2]channel", "test2"))
}

func TestStripEnv(t *testing.T) {
node := nodeWithTestEngine()
defer func() { _ = node.Shutdown(context.Background()) }()

conf := node.Config()
conf.ChannelEnvDelimiters = "[]"
err := node.Reload(conf)
require.NoError(t, err)

ch, found := node.stripEnv("test")
require.Equal(t, "test", ch)
require.False(t, found)

ch, found = node.stripEnv("[xxx]test")
require.Equal(t, "test", ch)
require.True(t, found)

ch, found = node.stripEnv("[xxx-test")
require.Equal(t, "[xxx-test", ch)
require.False(t, found)

ch, found = node.stripEnv("xxx]test")
require.Equal(t, "xxx]test", ch)
require.False(t, found)

ch, found = node.stripEnv("[xxx]yyy[test")
require.Equal(t, "yyy[test", ch)
require.True(t, found)

ch, found = node.stripEnv("[]")
require.Equal(t, "[]", ch)
require.False(t, found)

ch, found = node.stripEnv("[")
require.Equal(t, "[", ch)
require.False(t, found)
}

func TestSetConfig(t *testing.T) {
node := nodeWithTestEngine()
defer func() { _ = node.Shutdown(context.Background()) }()
Expand Down Expand Up @@ -282,8 +335,6 @@ func BenchmarkHistory(b *testing.B) {
require.NoError(b, err)
}

b.ResetTimer()

b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := e.node.History(channel)
Expand All @@ -295,3 +346,44 @@ func BenchmarkHistory(b *testing.B) {
b.StopTimer()
b.ReportAllocs()
}

var validEnv bool

func BenchmarkHasValidEnv(b *testing.B) {
e := testMemoryEngine()
conf := e.node.Config()
conf.ChannelEnvDelimiters = "[]"
err := e.node.Reload(conf)
require.NoError(b, err)

b.ResetTimer()
for i := 0; i < b.N; i++ {
validEnv = e.node.hasValidEnv("[test]test", "test")
if !validEnv {
b.Fatal(err)
}

}
b.StopTimer()
b.ReportAllocs()
}

var strippedChannel string

func BenchmarkStripEnv(b *testing.B) {
e := testMemoryEngine()
conf := e.node.Config()
conf.ChannelEnvDelimiters = "[]"
err := e.node.Reload(conf)
require.NoError(b, err)
channel := "[test]test"
b.ResetTimer()
for i := 0; i < b.N; i++ {
strippedChannel, _ = e.node.stripEnv(channel)
if strippedChannel != "test" {
b.Fatal("env not properly stripped")
}
}
b.StopTimer()
b.ReportAllocs()
}
3 changes: 3 additions & 0 deletions token_verifier.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ type tokenVerifier interface {
}

type connectToken struct {
// Env used to set connection environment. This automatically enables using
// environment prefix for channels.
Env string
// UserID tells library an ID of connecting user.
UserID string
// ExpireAt allows to set time in future when connection must be validated.
Expand Down
2 changes: 2 additions & 0 deletions token_verifier_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ var (
)

type connectTokenClaims struct {
Env string `json:"env,omitempty"`
Info json.RawMessage `json:"info,omitempty"`
Base64Info string `json:"b64info,omitempty"`
Channels []string `json:"channels,omitempty"`
Expand Down Expand Up @@ -165,6 +166,7 @@ func (verifier *tokenVerifierJWT) VerifyConnectToken(t string) (connectToken, er
}

ct := connectToken{
Env: claims.Env,
UserID: claims.StandardClaims.Subject,
Info: claims.Info,
Channels: claims.Channels,
Expand Down