Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restored and refactored missing functionality required by Dolt #2393

Merged
merged 2 commits into from Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 1 addition & 2 deletions enginetest/memory_harness.go
Expand Up @@ -29,7 +29,6 @@ import (
"github.com/dolthub/go-mysql-server/memory"
"github.com/dolthub/go-mysql-server/server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/mysql_db"
)

const testNumPartitions = 5
Expand Down Expand Up @@ -103,7 +102,7 @@ func (m *MemoryHarness) SessionBuilder() server.SessionBuilder {
return func(ctx context.Context, c *mysql.Conn, addr string) (sql.Session, error) {
host := ""
user := ""
mysqlConnectionUser, ok := c.UserData.(mysql_db.MysqlConnectionUser)
mysqlConnectionUser, ok := c.UserData.(sql.MysqlConnectionUser)
if ok {
host = mysqlConnectionUser.Host
user = mysqlConnectionUser.User
Expand Down
3 changes: 1 addition & 2 deletions memory/session.go
Expand Up @@ -22,7 +22,6 @@ import (
"github.com/dolthub/vitess/go/mysql"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/mysql_db"
)

type GlobalsMap = map[string]interface{}
Expand Down Expand Up @@ -61,7 +60,7 @@ func NewSessionBuilder(pro *DbProvider) func(ctx context.Context, conn *mysql.Co
return func(ctx context.Context, conn *mysql.Conn, addr string) (sql.Session, error) {
host := ""
user := ""
mysqlConnectionUser, ok := conn.UserData.(mysql_db.MysqlConnectionUser)
mysqlConnectionUser, ok := conn.UserData.(sql.MysqlConnectionUser)
if ok {
host = mysqlConnectionUser.Host
user = mysqlConnectionUser.User
Expand Down
39 changes: 39 additions & 0 deletions server/server.go
Expand Up @@ -77,6 +77,45 @@ func NewServer(cfg Config, e *sqle.Engine, sb SessionBuilder, listener ServerEve
//handler = NewHandler_(e, sm, cfg.ConnReadTimeout, cfg.DisableClientMultiStatements, cfg.MaxLoggedQueryLen, cfg.EncodeLoggedQuery, listener)
return newServerFromHandler(cfg, e, sm, handler)
}

// HandlerWrapper provides a way for clients to wrap the mysql.Handler used by the server with a custom implementation
// that wraps it.
type HandlerWrapper func(h mysql.Handler) (mysql.Handler, error)

// NewServerWithHandler creates a Server with a handler wrapped by the provided wrapper function.
func NewServerWithHandler(
cfg Config,
e *sqle.Engine,
sb SessionBuilder,
listener ServerEventListener,
wrapper HandlerWrapper,
) (*Server, error) {
var tracer trace.Tracer
if cfg.Tracer != nil {
tracer = cfg.Tracer
} else {
tracer = sql.NoopTracer
}

sm := NewSessionManager(sb, tracer, e.Analyzer.Catalog.Database, e.MemoryManager, e.ProcessList, cfg.Address)
h := &Handler{
e: e,
sm: sm,
readTimeout: cfg.ConnReadTimeout,
disableMultiStmts: cfg.DisableClientMultiStatements,
maxLoggedQueryLen: cfg.MaxLoggedQueryLen,
encodeLoggedQuery: cfg.EncodeLoggedQuery,
sel: listener,
}

handler, err := wrapper(h)
if err != nil {
return nil, err
}

return newServerFromHandler(cfg, e, sm, handler)
}

func portInUse(hostPort string) bool {
timeout := time.Second
conn, _ := net.DialTimeout("tcp", hostPort, timeout)
Expand Down
3 changes: 1 addition & 2 deletions server/server_test.go
Expand Up @@ -16,7 +16,6 @@ import (
"github.com/dolthub/go-mysql-server/memory"
"github.com/dolthub/go-mysql-server/server"
gsql "github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/mysql_db"
)

// TestSeverCustomListener verifies a caller can provide their own net.Conn implementation for the server to use
Expand All @@ -37,7 +36,7 @@ func TestSeverCustomListener(t *testing.T) {
sessionBuilder := func(ctx context.Context, c *vsql.Conn, addr string) (gsql.Session, error) {
host := ""
user := ""
mysqlConnectionUser, ok := c.UserData.(mysql_db.MysqlConnectionUser)
mysqlConnectionUser, ok := c.UserData.(gsql.MysqlConnectionUser)
if ok {
host = mysqlConnectionUser.Host
user = mysqlConnectionUser.User
Expand Down
15 changes: 15 additions & 0 deletions sql/base_session.go
Expand Up @@ -15,11 +15,13 @@
package sql

import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"

"github.com/dolthub/vitess/go/mysql"
"github.com/sirupsen/logrus"
)

Expand Down Expand Up @@ -493,6 +495,19 @@ func (s *BaseSession) SetPrivilegeSet(newPs PrivilegeSet, counter uint64) {
s.privilegeSet = newPs
}

// BaseSessionFromConnection is a SessionBuilder that returns a base session for the given connection and remote address
func BaseSessionFromConnection(ctx context.Context, c *mysql.Conn, addr string) (*BaseSession, error) {
host := ""
user := ""
mysqlConnectionUser, ok := c.UserData.(MysqlConnectionUser)
if ok {
host = mysqlConnectionUser.Host
user = mysqlConnectionUser.User
}
client := Client{Address: host, User: user, Capabilities: c.Capabilities}
return NewBaseSessionWithClientServer(addr, client, c.ConnectionID), nil
}

// NewBaseSessionWithClientServer creates a new session with data.
func NewBaseSessionWithClientServer(server string, client Client, id uint32) *BaseSession {
// TODO: if system variable "activate_all_roles_on_login" if set, activate all roles
Expand Down
4 changes: 2 additions & 2 deletions sql/mysql_db/mysql_conn_user.go → sql/mysql_conn_user.go
@@ -1,4 +1,4 @@
// Copyright 2022 Dolthub, Inc.
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

package mysql_db
package sql

import (
"github.com/dolthub/vitess/go/mysql"
Expand Down
6 changes: 3 additions & 3 deletions sql/mysql_db/mysql_db.go
Expand Up @@ -818,7 +818,7 @@ func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, a
defer rd.Close()

if !db.Enabled() {
return MysqlConnectionUser{User: user, Host: host}, nil
return sql.MysqlConnectionUser{User: user, Host: host}, nil
}

userEntry := db.GetUser(rd, user, host, false)
Expand All @@ -834,7 +834,7 @@ func (db *MySQLDb) ValidateHash(salt []byte, user string, authResponse []byte, a
return nil, mysql.NewSQLError(mysql.ERAccessDeniedError, mysql.SSAccessDeniedError, "Access denied for user '%v'", user)
}

return MysqlConnectionUser{User: userEntry.User, Host: userEntry.Host}, nil
return sql.MysqlConnectionUser{User: userEntry.User, Host: userEntry.Host}, nil
}

// Negotiate implements the interface mysql.AuthServer. This is called when the method used is not "mysql_native_password".
Expand All @@ -857,7 +857,7 @@ func (db *MySQLDb) Negotiate(c *mysql.Conn, user string, addr net.Addr) (mysql.G
rd := db.Reader()
defer rd.Close()

connUser := MysqlConnectionUser{User: user, Host: host}
connUser := sql.MysqlConnectionUser{User: user, Host: host}
if !db.Enabled() {
return connUser, nil
}
Expand Down