Skip to content

Commit

Permalink
Merge pull request #2393 from dolthub/zachmu/default-server
Browse files Browse the repository at this point in the history
Restored and refactored missing functionality required by Dolt
  • Loading branch information
zachmu committed Mar 13, 2024
2 parents 3a7d6e5 + a0c6d5f commit 4a2f230
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 11 deletions.
3 changes: 1 addition & 2 deletions enginetest/memory_harness.go
Expand Up @@ -29,7 +29,6 @@ import (
"github.com/dolthub/go-mysql-server/memory"
"github.com/dolthub/go-mysql-server/server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/mysql_db"
)

const testNumPartitions = 5
Expand Down Expand Up @@ -103,7 +102,7 @@ func (m *MemoryHarness) SessionBuilder() server.SessionBuilder {
return func(ctx context.Context, c *mysql.Conn, addr string) (sql.Session, error) {
host := ""
user := ""
mysqlConnectionUser, ok := c.UserData.(mysql_db.MysqlConnectionUser)
mysqlConnectionUser, ok := c.UserData.(sql.MysqlConnectionUser)
if ok {
host = mysqlConnectionUser.Host
user = mysqlConnectionUser.User
Expand Down
3 changes: 1 addition & 2 deletions memory/session.go
Expand Up @@ -22,7 +22,6 @@ import (
"github.com/dolthub/vitess/go/mysql"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/mysql_db"
)

type GlobalsMap = map[string]interface{}
Expand Down Expand Up @@ -61,7 +60,7 @@ func NewSessionBuilder(pro *DbProvider) func(ctx context.Context, conn *mysql.Co
return func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
host := ""
user := ""
mysqlConnectionUser, ok := conn.UserData.(mysql_db.MysqlConnectionUser)
mysqlConnectionUser, ok := conn.UserData.(sql.MysqlConnectionUser)
if ok {
host = mysqlConnectionUser.Host
user = mysqlConnectionUser.User
Expand Down
39 changes: 39 additions & 0 deletions server/server.go
Expand Up @@ -77,6 +77,45 @@ func NewServer(cfg Config, e *sqle.Engine, sb SessionBuilder, listener ServerEve
//handler = NewHandler_(e, sm, cfg.ConnReadTimeout, cfg.DisableClientMultiStatements, cfg.MaxLoggedQueryLen, cfg.EncodeLoggedQuery, listener)
return newServerFromHandler(cfg, e, sm, handler)
}

// HandlerWrapper provides a way for clients to wrap the mysql.Handler used by the server with a custom implementation
// that wraps it.
type HandlerWrapper func(h mysql.Handler) (mysql.Handler, error)

// NewServerWithHandler creates a Server with a handler wrapped by the provided wrapper function.
func NewServerWithHandler(
cfg Config,
e *sqle.Engine,
sb SessionBuilder,
listener ServerEventListener,
wrapper HandlerWrapper,
) (*Server, error) {
var tracer trace.Tracer
if cfg.Tracer != nil {
tracer = cfg.Tracer
} else {
tracer = sql.NoopTracer
}

sm := NewSessionManager(sb, tracer, e.Analyzer.Catalog.Database, e.MemoryManager, e.ProcessList, cfg.Address)
h := &Handler{
e: e,
sm: sm,
readTimeout: cfg.ConnReadTimeout,
disableMultiStmts: cfg.DisableClientMultiStatements,
maxLoggedQueryLen: cfg.MaxLoggedQueryLen,
encodeLoggedQuery: cfg.EncodeLoggedQuery,
sel: listener,
}

handler, err := wrapper(h)
if err != nil {
return nil, err
}

return newServerFromHandler(cfg, e, sm, handler)
}

func portInUse(hostPort string) bool {
timeout := time.Second
conn, _ := net.DialTimeout("tcp", hostPort, timeout)
Expand Down
3 changes: 1 addition & 2 deletions server/server_test.go
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/dolthub/go-mysql-server/memory"
"github.com/dolthub/go-mysql-server/server"
gsql "github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/mysql_db"
)

// TestSeverCustomListener verifies a caller can provide their own net.Conn implementation for the server to use
Expand All @@ -37,7 +36,7 @@ func TestSeverCustomListener(t *testing.T) {
sessionBuilder := func(ctx context.Context, c *vsql.Conn, addr string) (gsql.Session, error) {
host := ""
user := ""
mysqlConnectionUser, ok := c.UserData.(mysql_db.MysqlConnectionUser)
mysqlConnectionUser, ok := c.UserData.(gsql.MysqlConnectionUser)
if ok {
host = mysqlConnectionUser.Host
user = mysqlConnectionUser.User
Expand Down
15 changes: 15 additions & 0 deletions sql/base_session.go
Expand Up @@ -15,11 +15,13 @@
package sql

import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"

"github.com/dolthub/vitess/go/mysql"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -493,6 +495,19 @@ func (s *BaseSession) SetPrivilegeSet(newPs PrivilegeSet, counter uint64) {
s.privilegeSet = newPs
}

// BaseSessionFromConnection is a SessionBuilder that returns a base session for the given connection and remote address
func BaseSessionFromConnection(ctx context.Context, c *mysql.Conn, addr string) (*BaseSession, error) {
host := ""
user := ""
mysqlConnectionUser, ok := c.UserData.(MysqlConnectionUser)
if ok {
host = mysqlConnectionUser.Host
user = mysqlConnectionUser.User
}
client := Client{Address: host, User: user, Capabilities: c.Capabilities}
return NewBaseSessionWithClientServer(addr, client, c.ConnectionID), nil
}

// NewBaseSessionWithClientServer creates a new session with data.
func NewBaseSessionWithClientServer(server string, client Client, id uint32) *BaseSession {
// TODO: if system variable "activate_all_roles_on_login" if set, activate all roles
Expand Down
4 changes: 2 additions & 2 deletions sql/mysql_db/mysql_conn_user.go → sql/mysql_conn_user.go
@@ -1,4 +1,4 @@
// Copyright 2022 Dolthub, Inc.
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package mysql_db
package sql

import (
"github.com/dolthub/vitess/go/mysql"
Expand Down
6 changes: 3 additions & 3 deletions sql/mysql_db/mysql_db.go
Expand Up @@ -818,7 +818,7 @@ func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, a
defer rd.Close()

if !db.Enabled() {
return MysqlConnectionUser{User: user, Host: host}, nil
return sql.MysqlConnectionUser{User: user, Host: host}, nil
}

userEntry := db.GetUser(rd, user, host, false)
Expand All @@ -834,7 +834,7 @@ func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, a
return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
}

return MysqlConnectionUser{User: userEntry.User, Host: userEntry.Host}, nil
return sql.MysqlConnectionUser{User: userEntry.User, Host: userEntry.Host}, nil
}

// Negotiate implements the interface mysql.AuthServer. This is called when the method used is not "mysql_native_password".
Expand All @@ -857,7 +857,7 @@ func (db *MySQLDb) Negotiate(c *mysql.Conn, user string, addr net.Addr) (mysql.G
rd := db.Reader()
defer rd.Close()

connUser := MysqlConnectionUser{User: user, Host: host}
connUser := sql.MysqlConnectionUser{User: user, Host: host}
if !db.Enabled() {
return connUser, nil
}
Expand Down

0 comments on commit 4a2f230

Please sign in to comment.