forked from literatesnow/go-datapipe
-
Notifications
You must be signed in to change notification settings - Fork 0
/
copyin.go
129 lines (96 loc) · 2.67 KB
/
copyin.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package bulk
import (
"context"
"database/sql"
"strconv"
"github.com/juju/errors"
"github.com/lib/pq"
)
// CopyIn loads rows into a PostgreSQL table through a lib/pq COPY statement
// prepared inside a single transaction.
type CopyIn struct {
	// Database resources.
	conn *sql.Conn // connection the transaction and metadata query run on
	tx   *sql.Tx   // transaction that owns the COPY; committed by Close
	stmt *sql.Stmt // prepared COPY statement rows are sent to

	// Per-column state, all indexed like the requested column list.
	valueTypes []string      // information_schema data_type, "" if not found
	valuePtrs  []interface{} // &values[i]; the Scan destinations
	values     []interface{} // values of the row currently being copied

	totalRowCount int // rows successfully sent by Append
}
// Append scans the current row of rows into the row buffer and sends it to the
// prepared COPY statement. rows must be positioned on a row (rows.Next()
// returned true) and must have one column per column passed to NewCopyIn.
//
// Drivers may return column values as []byte. Those are converted before
// sending: to float64 when the destination column's data_type is "numeric",
// and to string otherwise. NULLs (nil) and values of any other Go type are
// sent unchanged. The count reported by Flush is incremented only after the
// row has been accepted by the statement.
func (r *CopyIn) Append(ctx context.Context, rows *sql.Rows) (err error) {
	// A failed Scan may leave r.values partially updated with the previous
	// row's data, so that row must not be sent.
	if err = rows.Scan(r.valuePtrs...); err != nil {
		return errors.Trace(err)
	}
	for i, v := range r.values {
		raw, ok := v.([]byte)
		if !ok {
			// nil (SQL NULL) or an already-typed value: pass through.
			continue
		}
		if r.valueTypes[i] == "numeric" {
			f, parseErr := strconv.ParseFloat(string(raw), 64)
			if parseErr != nil {
				// Refuse the row rather than silently copying 0.
				return errors.Annotatef(parseErr, "column %d: parsing numeric value", i)
			}
			r.values[i] = f
			continue
		}
		r.values[i] = string(raw)
	}
	if _, err = r.stmt.ExecContext(ctx, r.values...); err != nil {
		return errors.Trace(err)
	}
	r.totalRowCount++
	return nil
}
// Close closes the prepared COPY statement and then commits the transaction.
//
// If the statement cannot be closed, the transaction is rolled back before
// the error is returned, so it is never left open on the connection.
func (r *CopyIn) Close() (err error) {
	if err = r.stmt.Close(); err != nil {
		// No other method ends r.tx, so it must be ended here. The rollback
		// error is secondary to the one being returned and is deliberately
		// ignored.
		_ = r.tx.Rollback()
		return errors.Trace(err)
	}
	if err = r.tx.Commit(); err != nil {
		return errors.Trace(err)
	}
	return nil
}
// Flush executes the COPY statement with no arguments. Per the lib/pq CopyIn
// documentation, this flushes buffered data and completes the copy. It honours
// ctx, and on success returns the number of rows sent by Append. On failure
// the count is reported as 0.
func (r *CopyIn) Flush(ctx context.Context) (totalRowCount int, err error) {
	if _, err = r.stmt.ExecContext(ctx); err != nil {
		return 0, errors.Trace(err)
	}
	return r.totalRowCount, nil
}
// findColumnTypes reads information_schema.columns for schema.tableName on
// r.conn and stores each requested column's data_type in r.valueTypes, at the
// index that column has in columns. Requested columns that the query does not
// return keep an empty type.
func (r *CopyIn) findColumnTypes(ctx context.Context, schema string, tableName string, columns []string) (err error) {
	const query = "SELECT column_name AS name, data_type AS type FROM information_schema.columns WHERE table_schema = $1 AND table_name = $2"

	rows, err := r.conn.QueryContext(ctx, query, schema, tableName)
	if err != nil {
		return errors.Trace(err)
	}
	defer rows.Close()

	for rows.Next() {
		var name, dataType string
		if scanErr := rows.Scan(&name, &dataType); scanErr != nil {
			return errors.Trace(scanErr)
		}
		// A name may be requested more than once; tag every occurrence.
		for idx, want := range columns {
			if want == name {
				r.valueTypes[idx] = dataType
			}
		}
	}
	return errors.Trace(rows.Err())
}
// NewCopyIn prepares a bulk load of columns into schema.tableName on conn.
//
// It begins a transaction on conn with ctx, looks up the destination columns'
// data types, and prepares a lib/pq COPY statement in that transaction. On
// success the transaction belongs to the returned *CopyIn; presumably rows are
// then sent with Append, the copy completed with Flush, and the transaction
// committed by Close. On failure nil is returned and any transaction that was
// started has been rolled back.
func NewCopyIn(ctx context.Context, conn *sql.Conn, columns []string, schema string, tableName string) (r *CopyIn, err error) {
	colCount := len(columns)
	r = &CopyIn{
		conn:       conn,
		values:     make([]interface{}, colCount),
		valuePtrs:  make([]interface{}, colCount),
		valueTypes: make([]string, colCount),
	}
	// Scan writes through valuePtrs directly into values.
	for i := range r.values {
		r.valuePtrs[i] = &r.values[i]
	}

	if r.tx, err = r.conn.BeginTx(ctx, nil); err != nil {
		return nil, errors.Trace(err)
	}
	// From here on, a failure returns a nil *CopyIn, so the caller could never
	// end the transaction. Roll it back here instead. tx is captured because
	// the named result r is already nil when the deferred func runs.
	tx := r.tx
	defer func() {
		if err != nil {
			_ = tx.Rollback() // best effort; err already describes the failure
		}
	}()

	if err = r.findColumnTypes(ctx, schema, tableName, columns); err != nil {
		return nil, errors.Trace(err)
	}
	if r.stmt, err = r.tx.PrepareContext(ctx, pq.CopyInSchema(schema, tableName, columns...)); err != nil {
		return nil, errors.Trace(err)
	}
	return r, nil
}