Permalink
Cannot retrieve contributors at this time
// Copyright 2013 The ql Authors. All rights reserved. | |
// Use of this source code is governed by a BSD-style | |
// license that can be found in the LICENSES/QL-LICENSE file. | |
// Copyright 2015 PingCAP, Inc. | |
// | |
// Licensed under the Apache License, Version 2.0 (the "License"); | |
// you may not use this file except in compliance with the License. | |
// You may obtain a copy of the License at | |
// | |
// http://www.apache.org/licenses/LICENSE-2.0 | |
// | |
// Unless required by applicable law or agreed to in writing, software | |
// distributed under the License is distributed on an "AS IS" BASIS, | |
// See the License for the specific language governing permissions and | |
// limitations under the License. | |
package tidb | |
import ( | |
"bytes" | |
"encoding/json" | |
"fmt" | |
"strings" | |
"sync" | |
"sync/atomic" | |
"time" | |
"github.com/juju/errors" | |
"github.com/ngaut/log" | |
"github.com/pingcap/tidb/context" | |
"github.com/pingcap/tidb/field" | |
"github.com/pingcap/tidb/kv" | |
"github.com/pingcap/tidb/mysql" | |
"github.com/pingcap/tidb/privilege" | |
"github.com/pingcap/tidb/privilege/privileges" | |
"github.com/pingcap/tidb/rset" | |
"github.com/pingcap/tidb/sessionctx" | |
"github.com/pingcap/tidb/sessionctx/autocommit" | |
"github.com/pingcap/tidb/sessionctx/db" | |
"github.com/pingcap/tidb/sessionctx/forupdate" | |
"github.com/pingcap/tidb/sessionctx/variable" | |
"github.com/pingcap/tidb/stmt" | |
"github.com/pingcap/tidb/stmt/stmts" | |
"github.com/pingcap/tidb/terror" | |
"github.com/pingcap/tidb/util" | |
"github.com/pingcap/tidb/util/sqlexec" | |
"github.com/pingcap/tidb/util/types" | |
) | |
// Session context | |
type Session interface { | |
Status() uint16 // Flag of current status, such as autocommit | |
LastInsertID() uint64 // Last inserted auto_increment id | |
AffectedRows() uint64 // Affected rows by lastest executed stmt | |
Execute(sql string) ([]rset.Recordset, error) // Execute a sql statement | |
String() string // For debug | |
FinishTxn(rollback bool) error | |
// For execute prepare statement in binary protocol | |
PrepareStmt(sql string) (stmtID uint32, paramCount int, fields []*field.ResultField, err error) | |
// Execute a prepared statement | |
ExecutePreparedStmt(stmtID uint32, param ...interface{}) (rset.Recordset, error) | |
DropPreparedStmt(stmtID uint32) error | |
SetClientCapability(uint32) // Set client capability flags | |
Close() error | |
Retry() error | |
Auth(user string, auth []byte, salt []byte) bool | |
} | |
var ( | |
_ Session = (*session)(nil) | |
sessionID int64 | |
sessionMu sync.Mutex | |
) | |
type stmtRecord struct { | |
stmtID uint32 | |
st stmt.Statement | |
params []interface{} | |
} | |
type stmtHistory struct { | |
history []*stmtRecord | |
} | |
func (h *stmtHistory) add(stmtID uint32, st stmt.Statement, params ...interface{}) { | |
s := &stmtRecord{ | |
stmtID: stmtID, | |
st: st, | |
params: append(([]interface{})(nil), params...), | |
} | |
h.history = append(h.history, s) | |
} | |
func (h *stmtHistory) reset() { | |
if len(h.history) > 0 { | |
h.history = h.history[:0] | |
} | |
} | |
func (h *stmtHistory) clone() *stmtHistory { | |
nh := *h | |
nh.history = make([]*stmtRecord, len(h.history)) | |
copy(nh.history, h.history) | |
return &nh | |
} | |
type session struct { | |
txn kv.Transaction // Current transaction | |
args []interface{} // Statment execution args, this should be cleaned up after exec | |
values map[fmt.Stringer]interface{} | |
store kv.Storage | |
sid int64 | |
history stmtHistory | |
initing bool // Running bootstrap using this session. | |
} | |
func (s *session) Status() uint16 { | |
return variable.GetSessionVars(s).Status | |
} | |
func (s *session) LastInsertID() uint64 { | |
return variable.GetSessionVars(s).LastInsertID | |
} | |
func (s *session) AffectedRows() uint64 { | |
return variable.GetSessionVars(s).AffectedRows | |
} | |
func (s *session) resetHistory() { | |
s.ClearValue(forupdate.ForUpdateKey) | |
s.history.reset() | |
} | |
func (s *session) SetClientCapability(capability uint32) { | |
variable.GetSessionVars(s).ClientCapability = capability | |
} | |
func (s *session) FinishTxn(rollback bool) error { | |
// transaction has already been committed or rolled back | |
if s.txn == nil { | |
return nil | |
} | |
defer func() { | |
s.txn = nil | |
variable.GetSessionVars(s).SetStatusFlag(mysql.ServerStatusInTrans, false) | |
}() | |
if rollback { | |
return s.txn.Rollback() | |
} | |
err := s.txn.Commit() | |
if err != nil { | |
if kv.IsRetryableError(err) { | |
err = s.Retry() | |
} | |
if err != nil { | |
log.Warnf("txn:%s, %v", s.txn, err) | |
return errors.Trace(err) | |
} | |
} | |
s.resetHistory() | |
return nil | |
} | |
func (s *session) String() string { | |
// TODO: how to print binded context in values appropriately? | |
data := map[string]interface{}{ | |
"currDBName": db.GetCurrentSchema(s), | |
"sid": s.sid, | |
} | |
if s.txn != nil { | |
// if txn is committed or rolled back, txn is nil. | |
data["txn"] = s.txn.String() | |
} | |
b, _ := json.MarshalIndent(data, "", " ") | |
return string(b) | |
} | |
func needRetry(st stmt.Statement) bool { | |
switch st.(type) { | |
case *stmts.PreparedStmt, *stmts.ShowStmt, *stmts.DoStmt: | |
return false | |
default: | |
return true | |
} | |
} | |
func isPreparedStmt(st stmt.Statement) bool { | |
switch st.(type) { | |
case *stmts.PreparedStmt: | |
return true | |
default: | |
return false | |
} | |
} | |
func (s *session) Retry() error { | |
nh := s.history.clone() | |
defer func() { | |
s.history.history = nh.history | |
}() | |
if forUpdate := s.Value(forupdate.ForUpdateKey); forUpdate != nil { | |
return errors.Errorf("can not retry select for update statement") | |
} | |
var err error | |
for { | |
s.resetHistory() | |
s.FinishTxn(true) | |
success := true | |
for _, sr := range nh.history { | |
st := sr.st | |
// Skip prepare statement | |
if !needRetry(st) { | |
continue | |
} | |
log.Warnf("Retry %s", st.OriginText()) | |
_, err = runStmt(s, st) | |
if err != nil { | |
if terror.ErrorEqual(err, kv.ErrConditionNotMatch) { | |
success = false | |
break | |
} | |
log.Warnf("session:%v, err:%v", s, err) | |
return errors.Trace(err) | |
} | |
} | |
if success { | |
err = s.FinishTxn(false) | |
if !kv.IsRetryableError(err) { | |
break | |
} | |
} | |
} | |
return err | |
} | |
// ExecRestrictedSQL implements SQLHelper interface. | |
// This is used for executing some restricted sql statements. | |
func (s *session) ExecRestrictedSQL(ctx context.Context, sql string) (rset.Recordset, error) { | |
if ctx.Value(&sqlexec.RestrictedSQLExecutorKeyType{}) != nil { | |
// We do not support run this function concurrently. | |
// TODO: Maybe we should remove this restriction latter. | |
return nil, errors.New("Should not call ExecRestrictedSQL concurrently.") | |
} | |
statements, err := Compile(ctx, sql) | |
if err != nil { | |
log.Errorf("Compile %s with error: %v", sql, err) | |
return nil, errors.Trace(err) | |
} | |
if len(statements) != 1 { | |
log.Errorf("ExecRestrictedSQL only executes one statement. Too many/few statement in %s", sql) | |
return nil, errors.New("Wrong number of statement.") | |
} | |
st := statements[0] | |
// Check statement for some restriction | |
// For example only support DML on system meta table. | |
// TODO: Add more restrictions. | |
log.Debugf("Executing %s [%s]", st.OriginText(), sql) | |
ctx.SetValue(&sqlexec.RestrictedSQLExecutorKeyType{}, true) | |
defer ctx.ClearValue(&sqlexec.RestrictedSQLExecutorKeyType{}) | |
rs, err := st.Exec(ctx) | |
return rs, errors.Trace(err) | |
} | |
// getExecRet executes restricted sql and the result is one column. | |
// It returns a string value. | |
func (s *session) getExecRet(ctx context.Context, sql string) (string, error) { | |
rs, err := s.ExecRestrictedSQL(ctx, sql) | |
if err != nil { | |
return "", errors.Trace(err) | |
} | |
defer rs.Close() | |
row, err := rs.Next() | |
if err != nil { | |
return "", errors.Trace(err) | |
} | |
if row == nil { | |
return "", terror.ExecResultIsEmpty | |
} | |
value, err := types.ToString(row.Data[0]) | |
if err != nil { | |
return "", errors.Trace(err) | |
} | |
return value, nil | |
} | |
// GetGlobalStatusVar implements GlobalVarAccessor.GetGlobalStatusVar interface. | |
func (s *session) GetGlobalStatusVar(ctx context.Context, name string) (string, error) { | |
sql := fmt.Sprintf(`SELECT VARIABLE_VALUE FROM %s.%s WHERE VARIABLE_NAME="%s";`, | |
mysql.SystemDB, mysql.GlobalStatusTable, name) | |
statusVar, err := s.getExecRet(ctx, sql) | |
if err != nil { | |
if terror.ExecResultIsEmpty.Equal(err) { | |
return "", terror.ExecResultIsEmpty.Gen("unknown status variable:%s", name) | |
} | |
return "", errors.Trace(err) | |
} | |
return statusVar, nil | |
} | |
// SetGlobalStatusVar implements GlobalVarAccessor.SetGlobalStatusVar interface. | |
func (s *session) SetGlobalStatusVar(ctx context.Context, name string, value string) error { | |
sql := fmt.Sprintf(`UPDATE %s.%s SET VARIABLE_VALUE="%s" WHERE VARIABLE_NAME="%s";`, | |
mysql.SystemDB, mysql.GlobalStatusTable, value, strings.ToLower(name)) | |
_, err := s.ExecRestrictedSQL(ctx, sql) | |
return errors.Trace(err) | |
} | |
// GetGlobalSysVar implements GlobalVarAccessor.GetGlobalSysVar interface. | |
func (s *session) GetGlobalSysVar(ctx context.Context, name string) (string, error) { | |
sql := fmt.Sprintf(`SELECT VARIABLE_VALUE FROM %s.%s WHERE VARIABLE_NAME="%s";`, | |
mysql.SystemDB, mysql.GlobalVariablesTable, name) | |
sysVar, err := s.getExecRet(ctx, sql) | |
if err != nil { | |
if terror.ExecResultIsEmpty.Equal(err) { | |
return "", terror.ExecResultIsEmpty.Gen("unknown sys variable:%s", name) | |
} | |
return "", errors.Trace(err) | |
} | |
return sysVar, nil | |
} | |
// SetGlobalSysVar implements GlobalVarAccessor.SetGlobalSysVar interface. | |
func (s *session) SetGlobalSysVar(ctx context.Context, name string, value string) error { | |
sql := fmt.Sprintf(`UPDATE %s.%s SET VARIABLE_VALUE="%s" WHERE VARIABLE_NAME="%s";`, | |
mysql.SystemDB, mysql.GlobalVariablesTable, value, strings.ToLower(name)) | |
_, err := s.ExecRestrictedSQL(ctx, sql) | |
return errors.Trace(err) | |
} | |
// IsAutocommit checks if it is in the auto-commit mode. | |
func (s *session) isAutocommit(ctx context.Context) bool { | |
if ctx.Value(&sqlexec.RestrictedSQLExecutorKeyType{}) != nil { | |
return false | |
} | |
autocommit, ok := variable.GetSessionVars(ctx).Systems["autocommit"] | |
if !ok { | |
if s.initing { | |
return false | |
} | |
var err error | |
autocommit, err = s.GetGlobalSysVar(ctx, "autocommit") | |
if err != nil { | |
log.Errorf("Get global sys var error: %v", err) | |
return false | |
} | |
variable.GetSessionVars(ctx).Systems["autocommit"] = autocommit | |
ok = true | |
} | |
if ok && (autocommit == "ON" || autocommit == "on" || autocommit == "1") { | |
variable.GetSessionVars(ctx).SetStatusFlag(mysql.ServerStatusAutocommit, true) | |
return true | |
} | |
variable.GetSessionVars(ctx).SetStatusFlag(mysql.ServerStatusAutocommit, false) | |
return false | |
} | |
func (s *session) ShouldAutocommit(ctx context.Context) bool { | |
if ctx.Value(&sqlexec.RestrictedSQLExecutorKeyType{}) != nil { | |
return false | |
} | |
// With START TRANSACTION, autocommit remains disabled until you end | |
// the transaction with COMMIT or ROLLBACK. | |
if variable.GetSessionVars(ctx).Status&mysql.ServerStatusInTrans == 0 && s.isAutocommit(ctx) { | |
return true | |
} | |
return false | |
} | |
func (s *session) Execute(sql string) ([]rset.Recordset, error) { | |
statements, err := Compile(s, sql) | |
if err != nil { | |
log.Errorf("Syntax error: %s", sql) | |
log.Errorf("Error occurs at %s.", err) | |
return nil, errors.Trace(err) | |
} | |
var rs []rset.Recordset | |
for _, st := range statements { | |
r, err := runStmt(s, st) | |
if err != nil { | |
log.Warnf("session:%v, err:%v", s, err) | |
return nil, errors.Trace(err) | |
} | |
// Record executed query | |
if isPreparedStmt(st) { | |
ps := st.(*stmts.PreparedStmt) | |
s.history.add(ps.ID, st) | |
} else { | |
s.history.add(0, st) | |
} | |
if r != nil { | |
rs = append(rs, r) | |
} | |
} | |
return rs, nil | |
} | |
// For execute prepare statement in binary protocol | |
func (s *session) PrepareStmt(sql string) (stmtID uint32, paramCount int, fields []*field.ResultField, err error) { | |
return prepareStmt(s, sql) | |
} | |
// checkArgs makes sure all the arguments' types are known and can be handled. | |
// integer types are converted to int64 and uint64, time.Time is converted to mysql.Time. | |
// time.Duration is converted to mysql.Duration, other known types are leaved as it is. | |
func checkArgs(args ...interface{}) error { | |
for i, v := range args { | |
switch x := v.(type) { | |
case bool: | |
if x { | |
args[i] = int64(1) | |
} else { | |
args[i] = int64(0) | |
} | |
case int8: | |
args[i] = int64(x) | |
case int16: | |
args[i] = int64(x) | |
case int32: | |
args[i] = int64(x) | |
case int: | |
args[i] = int64(x) | |
case uint8: | |
args[i] = uint64(x) | |
case uint16: | |
args[i] = uint64(x) | |
case uint32: | |
args[i] = uint64(x) | |
case uint: | |
args[i] = uint64(x) | |
case int64: | |
case uint64: | |
case float32: | |
case float64: | |
case string: | |
case []byte: | |
case time.Duration: | |
args[i] = mysql.Duration{Duration: x} | |
case time.Time: | |
args[i] = mysql.Time{Time: x, Type: mysql.TypeDatetime} | |
case nil: | |
default: | |
return errors.Errorf("cannot use arg[%d] (type %T):unsupported type", i, v) | |
} | |
} | |
return nil | |
} | |
// Execute a prepared statement | |
func (s *session) ExecutePreparedStmt(stmtID uint32, args ...interface{}) (rset.Recordset, error) { | |
err := checkArgs(args...) | |
if err != nil { | |
return nil, err | |
} | |
st := &stmts.ExecuteStmt{ID: stmtID} | |
s.history.add(stmtID, st, args...) | |
return runStmt(s, st, args...) | |
} | |
func (s *session) DropPreparedStmt(stmtID uint32) error { | |
return dropPreparedStmt(s, stmtID) | |
} | |
// If forceNew is true, GetTxn() must return a new transaction. | |
// In this situation, if current transaction is still in progress, | |
// there will be an implicit commit and create a new transaction. | |
func (s *session) GetTxn(forceNew bool) (kv.Transaction, error) { | |
var err error | |
if s.txn == nil { | |
s.resetHistory() | |
s.txn, err = s.store.Begin() | |
if err != nil { | |
return nil, err | |
} | |
if !s.isAutocommit(s) { | |
variable.GetSessionVars(s).SetStatusFlag(mysql.ServerStatusInTrans, true) | |
} | |
log.Infof("New txn:%s in session:%d", s.txn, s.sid) | |
return s.txn, nil | |
} | |
if forceNew { | |
err = s.txn.Commit() | |
variable.GetSessionVars(s).SetStatusFlag(mysql.ServerStatusInTrans, false) | |
if err != nil { | |
if kv.IsRetryableError(err) { | |
err = s.Retry() | |
} | |
if err != nil { | |
return nil, errors.Trace(err) | |
} | |
} | |
s.resetHistory() | |
s.txn, err = s.store.Begin() | |
if err != nil { | |
return nil, err | |
} | |
if !s.isAutocommit(s) { | |
variable.GetSessionVars(s).SetStatusFlag(mysql.ServerStatusInTrans, true) | |
} | |
log.Warnf("Force new txn:%s in session:%d", s.txn, s.sid) | |
} | |
return s.txn, nil | |
} | |
func (s *session) SetValue(key fmt.Stringer, value interface{}) { | |
s.values[key] = value | |
} | |
func (s *session) Value(key fmt.Stringer) interface{} { | |
value := s.values[key] | |
return value | |
} | |
func (s *session) ClearValue(key fmt.Stringer) { | |
delete(s.values, key) | |
} | |
// Close function does some clean work when session end. | |
func (s *session) Close() error { | |
return s.FinishTxn(true) | |
} | |
func (s *session) Auth(user string, auth []byte, salt []byte) bool { | |
strs := strings.Split(user, "@") | |
if len(strs) != 2 { | |
log.Warnf("Invalid format for user: %s", user) | |
return false | |
} | |
// Get user password. | |
name := strs[0] | |
host := strs[1] | |
authSQL := fmt.Sprintf("SELECT Password FROM %s.%s WHERE User=\"%s\" and Host=\"%s\";", mysql.SystemDB, mysql.UserTable, name, host) | |
rs, err := s.Execute(authSQL) | |
if err != nil { | |
log.Warnf("Encounter error when auth user %s. Error: %v", user, err) | |
return false | |
} | |
if len(rs) == 0 { | |
return false | |
} | |
row, err := rs[0].Next() | |
if err != nil { | |
log.Warnf("Encounter error when auth user %s. Error: %v", user, err) | |
return false | |
} | |
if row == nil || len(row.Data) == 0 { | |
return false | |
} | |
pwd, ok := row.Data[0].(string) | |
if !ok { | |
return false | |
} | |
hpwd, err := util.DecodePassword(pwd) | |
if err != nil { | |
log.Errorf("Decode password string error %v", err) | |
return false | |
} | |
checkAuth := util.CalcPassword(salt, hpwd) | |
if !bytes.Equal(auth, checkAuth) { | |
return false | |
} | |
variable.GetSessionVars(s).SetCurrentUser(user) | |
return true | |
} | |
// CreateSession creates a new session environment. | |
func CreateSession(store kv.Storage) (Session, error) { | |
s := &session{ | |
values: make(map[fmt.Stringer]interface{}), | |
store: store, | |
sid: atomic.AddInt64(&sessionID, 1), | |
} | |
domain, err := domap.Get(store) | |
if err != nil { | |
return nil, err | |
} | |
sessionctx.BindDomain(s, domain) | |
variable.BindSessionVars(s) | |
variable.GetSessionVars(s).SetStatusFlag(mysql.ServerStatusAutocommit, true) | |
// session implements variable.GlobalVarAccessor. Bind it to ctx. | |
variable.BindGlobalVarAccessor(s, s) | |
// session implements autocommit.Checker. Bind it to ctx | |
autocommit.BindAutocommitChecker(s, s) | |
sessionMu.Lock() | |
defer sessionMu.Unlock() | |
_, ok := storeBootstrapped[store.UUID()] | |
if !ok { | |
s.initing = true | |
bootstrap(s) | |
s.initing = false | |
storeBootstrapped[store.UUID()] = true | |
} | |
// TODO: Add auth here | |
privChecker := &privileges.UserPrivileges{} | |
privilege.BindPrivilegeChecker(s, privChecker) | |
return s, nil | |
} |