diff --git a/main.go b/main.go index 9ce5eed..0686dfe 100644 --- a/main.go +++ b/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "flag" "fmt" "net/http" @@ -37,6 +38,7 @@ type mysqlConnKey struct { type timedConn struct { *mysql.Conn lastUsed time.Time + mu sync.Mutex } var ( @@ -52,20 +54,27 @@ var ( flagMySQLDbname = commandLine.String("mysql-dbname", "mysql", "MySQL database to connect to") ) +var errSessionInUse = errors.New("session already in use") + // getConn gets or dials a connection from the connection pool // connections are maintained unique for credential combos and session id // since this isn't meant to truly represent reality, it's possible you // can do things with connections locally by munging session ids or auth // that aren't allowed on PlanetScale. This is meant to just mimic the public API. -func getConn(ctx context.Context, uname, pass, session string) (*mysql.Conn, error) { +func getConn(ctx context.Context, uname, pass, session string) (*timedConn, error) { key := mysqlConnKey{uname, pass, session} // check first if there's already a connection connMu.RLock() if conn, ok := connPool[key]; ok { - connMu.RUnlock() - conn.lastUsed = time.Now() - return conn.Conn, nil + defer connMu.RUnlock() + + if conn.mu.TryLock() { + conn.lastUsed = time.Now() + return conn, nil + } else { + return nil, errSessionInUse + } } connMu.RUnlock() @@ -78,15 +87,12 @@ func getConn(ctx context.Context, uname, pass, session string) (*mysql.Conn, err // lock to write to map connMu.Lock() - connPool[key] = &timedConn{rawConn, time.Now()} + connPool[key] = &timedConn{Conn: rawConn, lastUsed: time.Now()} connMu.Unlock() // since it was parallel, the last one would have won and been written // so re-read back so we use the conn that was actually stored in the pool - connMu.RLock() - conn := connPool[key] - connMu.RUnlock() - return conn.Conn, nil + return getConn(ctx, uname, pass, session) } // dial connects to the underlying MySQL server, and switches to the underlying @@ -161,13 +167,22 @@ func (s *server) CreateSession( session := gonanoid.Must() - if _, err := getConn(context.Background(), creds.Username(), string(creds.SecretBytes()), session); err != nil { + if conn, err := getConn(context.Background(), creds.Username(), string(creds.SecretBytes()), session); err != nil { if strings.Contains(err.Error(), "Access denied for user") { ll.Error("unauthenticated", log.Error(err)) return nil, connect.NewError(connect.CodeUnauthenticated, err) + } else if err == errSessionInUse { + ll.Warn(err.Error()) + return nil, connect.NewError( + connect.CodePermissionDenied, + fmt.Errorf("%s: %s", err.Error(), session), + ) } ll.Error("failed to connect", log.Error(err)) return nil, err + } else { + // need to release the lock immediately since it's not being used. + conn.mu.Unlock() } ll.Info("ok") @@ -222,15 +237,23 @@ func (s *server) Execute( if strings.Contains(err.Error(), "Access denied for user") { ll.Error("unauthenticated", log.Error(err)) return nil, connect.NewError(connect.CodeUnauthenticated, err) + } else if err == errSessionInUse { + ll.Warn(err.Error()) + return nil, connect.NewError( + connect.CodePermissionDenied, + fmt.Errorf("%s: %s", err.Error(), session), + ) } ll.Error("failed to connect", log.Error(err)) return nil, err } + defer conn.mu.Unlock() ll.Info("ok") // This is a gross simplificiation, but is likely sufficient qr, err := conn.ExecuteFetch(query, int(*flagMySQLMaxRows), true) + return connect.NewResponse(&psdbv1alpha1.ExecuteResponse{ Session: session, Result: sqltypes.ResultToProto3(qr), @@ -279,10 +302,17 @@ func (s *server) StreamExecute( if strings.Contains(err.Error(), "Access denied for user") { ll.Error("unauthenticated", log.Error(err)) return connect.NewError(connect.CodeUnauthenticated, err) + } else if err == errSessionInUse { + ll.Warn(err.Error()) + return connect.NewError( + connect.CodePermissionDenied, + fmt.Errorf("%s: %s", err.Error(), session), + ) } ll.Error("failed to connect", log.Error(err)) return err } + defer conn.mu.Unlock() // fake a streaming response by just returning 2 messages of the same payload // far from reality, but a simple way to exercise the protocol. @@ -346,6 +376,11 @@ func initConnPool() { conn.Close() delete(connPool, key) connMu.Unlock() + + logger.Debug("closing idle connection", + log.String("username", key.username), + log.String("session_id", key.session), + ) } } }