Skip to content

Commit

Permalink
Prevent concurrent session usage
Browse files Browse the repository at this point in the history
Fixes #1
  • Loading branch information
mattrobenolt committed Nov 11, 2023
1 parent 620d3a9 commit 7dc275b
Showing 1 changed file with 45 additions and 10 deletions.
55 changes: 45 additions & 10 deletions main.go
Expand Up @@ -2,6 +2,7 @@ package main

import (
"context"
"errors"
"flag"
"fmt"
"net/http"
Expand Down Expand Up @@ -37,6 +38,7 @@ type mysqlConnKey struct {
type timedConn struct {
*mysql.Conn
lastUsed time.Time
mu sync.Mutex
}

var (
Expand All @@ -52,20 +54,27 @@ var (
flagMySQLDbname = commandLine.String("mysql-dbname", "mysql", "MySQL database to connect to")
)

var errSessionInUse = errors.New("session already in use")

// getConn gets or dials a connection from the connection pool
// connections are maintained unique for credential combos and session id
// since this isn't meant to truly represent reality, it's possible you
// can do things with connections locally by munging session ids or auth
// that aren't allowed on PlanetScale. This is meant to just mimic the public API.
func getConn(ctx context.Context, uname, pass, session string) (*mysql.Conn, error) {
func getConn(ctx context.Context, uname, pass, session string) (*timedConn, error) {
key := mysqlConnKey{uname, pass, session}

// check first if there's already a connection
connMu.RLock()
if conn, ok := connPool[key]; ok {
connMu.RUnlock()
conn.lastUsed = time.Now()
return conn.Conn, nil
defer connMu.RUnlock()

if conn.mu.TryLock() {
conn.lastUsed = time.Now()
return conn, nil
} else {
return nil, errSessionInUse
}
}
connMu.RUnlock()

Expand All @@ -78,15 +87,12 @@ func getConn(ctx context.Context, uname, pass, session string) (*mysql.Conn, err

// lock to write to map
connMu.Lock()
connPool[key] = &timedConn{rawConn, time.Now()}
connPool[key] = &timedConn{Conn: rawConn, lastUsed: time.Now()}
connMu.Unlock()

// since it was parallel, the last one would have won and been written
// so re-read back so we use the conn that was actually stored in the pool
connMu.RLock()
conn := connPool[key]
connMu.RUnlock()
return conn.Conn, nil
return getConn(ctx, uname, pass, session)
}

// dial connects to the underlying MySQL server, and switches to the underlying
Expand Down Expand Up @@ -161,13 +167,22 @@ func (s *server) CreateSession(

session := gonanoid.Must()

if _, err := getConn(context.Background(), creds.Username(), string(creds.SecretBytes()), session); err != nil {
if conn, err := getConn(context.Background(), creds.Username(), string(creds.SecretBytes()), session); err != nil {
if strings.Contains(err.Error(), "Access denied for user") {
ll.Error("unauthenticated", log.Error(err))
return nil, connect.NewError(connect.CodeUnauthenticated, err)
} else if err == errSessionInUse {
ll.Warn(err.Error())
return nil, connect.NewError(
connect.CodePermissionDenied,
fmt.Errorf("%s: %s", err.Error(), session),
)
}
ll.Error("failed to connect", log.Error(err))
return nil, err
} else {
// need to release the lock immediately since it's not being used.
conn.mu.Unlock()
}

ll.Info("ok")
Expand Down Expand Up @@ -222,15 +237,23 @@ func (s *server) Execute(
if strings.Contains(err.Error(), "Access denied for user") {
ll.Error("unauthenticated", log.Error(err))
return nil, connect.NewError(connect.CodeUnauthenticated, err)
} else if err == errSessionInUse {
ll.Warn(err.Error())
return nil, connect.NewError(
connect.CodePermissionDenied,
fmt.Errorf("%s: %s", err.Error(), session),
)
}
ll.Error("failed to connect", log.Error(err))
return nil, err
}
defer conn.mu.Unlock()

ll.Info("ok")

// This is a gross simplificiation, but is likely sufficient
qr, err := conn.ExecuteFetch(query, int(*flagMySQLMaxRows), true)

return connect.NewResponse(&psdbv1alpha1.ExecuteResponse{
Session: session,
Result: sqltypes.ResultToProto3(qr),
Expand Down Expand Up @@ -279,10 +302,17 @@ func (s *server) StreamExecute(
if strings.Contains(err.Error(), "Access denied for user") {
ll.Error("unauthenticated", log.Error(err))
return connect.NewError(connect.CodeUnauthenticated, err)
} else if err == errSessionInUse {
ll.Warn(err.Error())
return connect.NewError(
connect.CodePermissionDenied,
fmt.Errorf("%s: %s", err.Error(), session),
)
}
ll.Error("failed to connect", log.Error(err))
return err
}
defer conn.mu.Unlock()

// fake a streaming response by just returning 2 messages of the same payload
// far from reality, but a simple way to exercise the protocol.
Expand Down Expand Up @@ -346,6 +376,11 @@ func initConnPool() {
conn.Close()
delete(connPool, key)
connMu.Unlock()

logger.Debug("closing idle connection",
log.String("username", key.username),
log.String("session_id", key.session),
)
}
}
}
Expand Down

0 comments on commit 7dc275b

Please sign in to comment.