/
statement.go
112 lines (97 loc) · 3.33 KB
/
statement.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
package rds
import (
"context"
"database/sql/driver"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/rdsdata"
)
// Compile-time assertions that *Statement implements the database/sql/driver
// statement interfaces it is meant to satisfy.
var (
	_ driver.Stmt             = (*Statement)(nil)
	_ driver.StmtExecContext  = (*Statement)(nil)
	_ driver.StmtQueryContext = (*Statement)(nil)
	// NOTE(review): *Statement has no CheckNamedValue method, so the
	// driver.NamedValueChecker assertion below is intentionally disabled.
	// _ driver.NamedValueChecker = (*Statement)(nil)
)
// NewStatement returns a Statement bound to connection. Each entry of sql is
// sent, in order, every time the statement is executed or queried. The
// context argument is currently unused.
func NewStatement(_ context.Context, connection *Connection, sql []string) *Statement {
	stmt := new(Statement)
	stmt.conn = connection
	stmt.queries = sql
	return stmt
}
// Statement is a driver statement executed through the RDS Data API. It holds
// one or more SQL strings that are sent one after another, each as its own
// ExecuteStatement request, on every Exec/Query call.
type Statement struct {
	// conn is the owning connection; Close sets it to nil.
	conn *Connection
	// queries are run sequentially with the same arguments.
	queries []string
}
// Close detaches the statement from its connection. Closing a statement that
// is already closed reports ErrClosed.
func (s *Statement) Close() error {
	if s.conn != nil {
		s.conn = nil
		return nil
	}
	return ErrClosed
}
// NumInput reports -1, telling database/sql that the number of placeholders
// is unknown so it skips its own argument-count check. The arguments are
// instead handed to the connection's dialect when each query is built.
func (s *Statement) NumInput() int {
	const unknownPlaceholderCount = -1
	return unknownPlaceholderCount
}
// Exec runs the statement's queries with positional arguments and returns the
// combined result. It is the legacy driver.Stmt entry point. It converts the
// values to ordinal NamedValues and delegates to ExecContext with a
// background context.
func (s *Statement) Exec(values []driver.Value) (driver.Result, error) {
	return s.ExecContext(context.Background(), s.ConvertOrdinal(values))
}
// ExecContext runs every query held by the statement, in order, with the same
// args, and returns a Result built from all of their outputs.
//
// Execution stops at the first failing query and its error is returned
// unchanged. On a closed statement it returns ErrClosed rather than panicking
// on the nil connection.
func (s *Statement) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
	if s.conn == nil {
		return nil, ErrClosed
	}
	output := make([]*rdsdata.ExecuteStatementOutput, 0, len(s.queries))
	for _, query := range s.queries {
		out, err := s.executeStatement(ctx, query, args)
		if err != nil {
			// Queries earlier in the batch have already been sent; nothing
			// here undoes them.
			return nil, err
		}
		output = append(output, out)
	}
	return NewResult(output), nil
}
// Query runs the statement's queries with positional arguments and returns
// their rows. It is the legacy driver.Stmt entry point. It converts the values
// to ordinal NamedValues and delegates to QueryContext with a background
// context.
func (s *Statement) Query(values []driver.Value) (driver.Rows, error) {
	ordinalArgs := s.ConvertOrdinal(values)
	ctx := context.Background()
	return s.QueryContext(ctx, ordinalArgs)
}
// QueryContext runs every query held by the statement, in order, with the
// same args, and returns Rows built from all of their outputs using the
// connection's dialect.
//
// Execution stops at the first failing query and its error is returned
// unchanged. On a closed statement it returns ErrClosed rather than panicking
// on the nil connection.
func (s *Statement) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
	// Guard up front: s.conn.dialect is dereferenced below even when there
	// are no queries to run.
	if s.conn == nil {
		return nil, ErrClosed
	}
	output := make([]*rdsdata.ExecuteStatementOutput, 0, len(s.queries))
	for _, query := range s.queries {
		out, err := s.executeStatement(ctx, query, args)
		if err != nil {
			return nil, err
		}
		output = append(output, out)
	}
	return NewRows(s.conn.dialect, output), nil
}
// ConvertOrdinal wraps positional driver.Values as unnamed NamedValues whose
// Ordinal is 1-based (the first value gets Ordinal 1). A nil or empty input
// yields an empty, non-nil slice.
func (s *Statement) ConvertOrdinal(values []driver.Value) []driver.NamedValue {
	namedValues := make([]driver.NamedValue, len(values))
	for i, v := range values {
		// driver.NamedValue.Ordinal counts from one; Name stays empty for
		// positional arguments.
		namedValues[i] = driver.NamedValue{Ordinal: i + 1, Value: v}
	}
	return namedValues
}
// executeStatement sends a single query to the RDS Data API on the statement's
// connection.
//
// It builds the ExecuteStatementInput from query and values via the
// connection's dialect (MigrateQuery). It then attaches the active
// transaction ID (if any), the resource/secret ARNs and the database name, and
// requests result metadata. Errors from MigrateQuery or ExecuteStatement are
// returned unchanged. On a closed statement it returns ErrClosed.
func (s *Statement) executeStatement(ctx context.Context, query string, values []driver.NamedValue) (*rdsdata.ExecuteStatementOutput, error) {
	conn := s.conn
	if conn == nil {
		return nil, ErrClosed
	}
	input, err := conn.dialect.MigrateQuery(query, values)
	if err != nil {
		return nil, err
	}
	if conn.tx != nil {
		// Run as part of the connection's open transaction.
		input.TransactionId = conn.tx.TransactionID
	}
	// NOTE(review): metadata is presumably needed by NewRows/NewResult to
	// decode columns — confirm before turning it off.
	input.IncludeResultMetadata = true
	input.ResourceArn = aws.String(conn.resourceARN)
	input.SecretArn = aws.String(conn.secretARN)
	input.Database = aws.String(conn.database)
	return conn.rds.ExecuteStatement(ctx, input)
}