forked from vitessio/vitess
/
driver_go18.go
145 lines (126 loc) · 3.93 KB
/
driver_go18.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
// Copyright 2016, Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
// TODO(sougou): Merge this with driver.go once go1.7 is deprecated.
// Also write tests for these new functions once go1.8 becomes mainstream.
package vitessdriver
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"github.com/youtube/vitess/go/vt/vtgate/vtgateconn"
)
// Sentinel errors returned by the context-aware driver entry points in this file.
var (
	// errIsolationUnsupported is returned by BeginTx when the caller requests
	// a non-default isolation level or a read-only transaction.
	errIsolationUnsupported = errors.New("isolation levels are not supported")

	// errNoIntermixing is returned by bindVarsFromNamedValues when one call
	// mixes named and positional arguments.
	errNoIntermixing = errors.New("named and positional arguments intermixing disallowed")
)
// Compile-time assertions that conn and stmt satisfy the context-aware
// database/sql/driver interfaces implemented in this file. database/sql
// discovers these optional methods through interface checks, so a signature
// mismatch would otherwise compile and the method would simply go unused;
// these declarations turn such drift into a build error.
//
// The typed-nil form avoids constructing a value just to check the method set.
var (
	_ driver.ConnBeginTx      = (*conn)(nil)
	_ driver.QueryerContext   = (*conn)(nil)
	_ driver.ExecerContext    = (*conn)(nil)
	_ driver.StmtQueryContext = (*stmt)(nil)
	_ driver.StmtExecContext  = (*stmt)(nil)
)
// Atomicity levels for use with WithAtomicity. Each constant is a synonym for
// the identically named constant in vtgateconn.
const (
	// AtomicityMulti is the default level. It allows distributed transactions
	// with best effort commits. Partial commits are possible.
	AtomicityMulti = vtgateconn.AtomicityMulti

	// AtomicitySingle prevents a transaction from crossing the boundary of
	// a single database.
	AtomicitySingle = vtgateconn.AtomicitySingle

	// Atomicity2PC allows distributed transactions, and performs 2PC commits.
	Atomicity2PC = vtgateconn.Atomicity2PC
)
// WithAtomicity returns the context produced by vtgateconn.WithAtomicity,
// i.e. ctx with the atomicity level set to level. Use AtomicityFromContext
// to read the level back.
func WithAtomicity(ctx context.Context, level vtgateconn.Atomicity) context.Context {
	withLevel := vtgateconn.WithAtomicity(ctx, level)
	return withLevel
}
// AtomicityFromContext reports the atomicity level carried by ctx, as
// determined by vtgateconn.AtomicityFromContext.
func AtomicityFromContext(ctx context.Context) vtgateconn.Atomicity {
	level := vtgateconn.AtomicityFromContext(ctx)
	return level
}
// BeginTx starts a vtgate transaction and records it on the connection. The
// connection itself is returned as the driver.Tx.
//
// Transactions are rejected on streaming connections. Only the default
// isolation level (driver.IsolationLevel(0)) with ReadOnly unset is accepted;
// any other TxOptions yield errIsolationUnsupported. Errors from
// vtgateConn.Begin are returned unchanged, and c.tx is left untouched.
func (c *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
	switch {
	case c.Streaming:
		return nil, errors.New("transaction not allowed for streaming connection")
	case opts.ReadOnly, opts.Isolation != driver.IsolationLevel(0):
		return nil, errIsolationUnsupported
	}

	tx, err := c.vtgateConn.Begin(ctx)
	if err != nil {
		return nil, err
	}
	c.tx = tx
	return c, nil
}
// ExecContext runs query via c.exec after converting args to bind variables
// with bindVarsFromNamedValues. It reports the resulting insert id and
// affected-row count. Streaming connections reject Exec outright.
func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
	if c.Streaming {
		return nil, errors.New("Exec not allowed for streaming connections")
	}

	bindVars, err := bindVarsFromNamedValues(args)
	if err != nil {
		return nil, err
	}

	qr, err := c.exec(ctx, query, bindVars)
	if err != nil {
		return nil, err
	}

	res := result{int64(qr.InsertID), int64(qr.RowsAffected)}
	return res, nil
}
// QueryContext converts args to bind variables and runs query.
//
// On a non-streaming connection the query goes through c.exec and its result
// is wrapped by newRows. On a streaming connection the rows come from
// vtgateConn.StreamExecute using the connection's tablet type.
func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
	bindVars, err := bindVarsFromNamedValues(args)
	if err != nil {
		return nil, err
	}

	if !c.Streaming {
		qr, execErr := c.exec(ctx, query, bindVars)
		if execErr != nil {
			return nil, execErr
		}
		return newRows(qr), nil
	}

	stream, err := c.vtgateConn.StreamExecute(ctx, query, bindVars, c.tabletTypeProto, nil)
	if err != nil {
		return nil, err
	}
	return newStreamingRows(stream, nil), nil
}
// ExecContext executes the statement's query on its owning connection,
// delegating to conn.ExecContext with the supplied arguments.
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	owner, query := s.c, s.query
	return owner.ExecContext(ctx, query, args)
}
// QueryContext runs the statement's query on its owning connection,
// delegating to conn.QueryContext with the supplied arguments.
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	owner, query := s.c, s.query
	return owner.QueryContext(ctx, query, args)
}
// bindVarsFromNamedValues converts driver arguments into a bind-variable map.
//
// Arguments must be all positional (empty Name) or all named. Whichever kind
// the first argument is decides the mode, and any later argument of the other
// kind yields errNoIntermixing.
//
// Key rules:
//   - Positional arguments are keyed "v1", "v2", ... by their 1-based index
//     in args.
//   - Named arguments are keyed by their name with a single leading ':' or '@'
//     removed, so ":id", "@id" and "id" all bind to "id".
//   - If two arguments map to the same key, the later one wins.
func bindVarsFromNamedValues(args []driver.NamedValue) (map[string]interface{}, error) {
	bindVars := make(map[string]interface{}, len(args))
	named := len(args) > 0 && args[0].Name != ""

	for idx, arg := range args {
		// The first argument trivially agrees with itself; every later one
		// must match its kind.
		if (arg.Name != "") != named {
			return nil, errNoIntermixing
		}

		key := arg.Name
		switch {
		case key == "":
			key = fmt.Sprintf("v%d", idx+1)
		case key[0] == ':', key[0] == '@':
			key = key[1:]
		}
		bindVars[key] = arg.Value
	}
	return bindVars, nil
}