forked from slicebit/qb
/
session.go
382 lines (316 loc) · 9.32 KB
/
session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
package qb
import (
"database/sql"
"errors"
"fmt"
"sync"
)
// New creates an engine for driver/dsn via NewEngine and returns a Session
// bound to it, together with the dialect, mapper and metadata registry for
// that driver. An error from NewEngine is returned unchanged with a nil
// Session.
func New(driver string, dsn string) (*Session, error) {
	engine, err := NewEngine(driver, dsn)
	if err != nil {
		return nil, err
	}

	dialect := NewDialect(driver)
	session := &Session{
		statements: []*Stmt{},
		engine:     engine,
		dialect:    dialect,
		mapper:     Mapper(dialect),
		metadata:   MetaData(dialect),
		mutex:      new(sync.Mutex),
	}
	return session, nil
}
// Session is the composition of engine connection & orm mappings.
// Write operations (Add, AddAll, Delete, AddStatement) are queued into a
// lazily started transaction and executed by Commit; select queries are
// composed with the chainable Query/Find/Filter/... methods and finished by
// One or All.
type Session struct {
	// builder is the active query composed by Query, Find and the chainable
	// select wrappers; Statement builds it and then clears it.
	builder Builder
	// filters are the conditionals collected by Filter; Statement ANDs them
	// into the WHERE clause when the builder is a select.
	filters []Conditional
	// statements are the queued write statements that Commit executes.
	statements []*Stmt
	// engine supplies the DB handle used to begin transactions and the
	// Get/Select helpers used by One and All.
	engine *Engine
	// mapper turns models into key/value maps (ToMap) and table names (ModelName).
	mapper MapperElem
	// metadata is the registry of tables known to this session.
	metadata *MetaDataElem
	// dialect is passed to Build to render SQL for the configured driver.
	dialect Dialect
	// tx is the transaction begun lazily by add; Commit resets it to nil.
	tx *sql.Tx
	// mutex serialises add's updates of tx and statements.
	mutex *sync.Mutex
}
// add queues statement for the next Commit. When no transaction is open it
// first begins one on the engine's DB and starts a fresh queue. It panics if
// the transaction cannot be started.
func (s *Session) add(statement *Stmt) {
	s.mutex.Lock()
	defer s.mutex.Unlock()

	if s.tx == nil {
		tx, err := s.engine.DB().Begin()
		s.tx = tx
		s.statements = []*Stmt{}
		if err != nil {
			panic(err)
		}
	}
	s.statements = append(s.statements, statement)
}
// Metadata Wrappers

// T returns the table registered in the session metadata under name.
// It is meant for building Query() parameters, e.g. s.T("user").C("id").
func (s *Session) T(name string) TableElem {
	md := s.metadata
	return md.Table(name)
}
// AddTable registers model in the session metadata, which maps it into a
// table object.
func (s *Session) AddTable(model interface{}) {
	md := s.metadata
	md.Add(model)
}
// CreateAll creates every table registered in the session metadata using the
// session engine, returning the metadata's error unchanged.
func (s *Session) CreateAll() error {
	engine := s.engine
	return s.metadata.CreateAll(engine)
}
// DropAll drops every table registered in the session metadata using the
// session engine, returning the metadata's error unchanged.
func (s *Session) DropAll() error {
	engine := s.engine
	return s.metadata.DropAll(engine)
}
// Engine wrappers

// Engine returns the session's (sqlx-wrapped) engine.
func (s *Session) Engine() *Engine {
	engine := s.engine
	return engine
}
// Close closes the engine's database (sqlx) connection.
// The error from the underlying Close is discarded because this method has no
// error result.
func (s *Session) Close() {
	db := s.engine.DB()
	_ = db.Close()
}
// AddStatement queues a prebuilt statement (the pointer returned by a
// builder's Build()) so that it runs on the next Commit.
func (s *Session) AddStatement(statement *Stmt) {
	s.add(statement)
}
// Dialect returns the dialect this session uses to build SQL.
func (s *Session) Dialect() Dialect {
	d := s.dialect
	return d
}
// Metadata returns the session's table registry.
func (s *Session) Metadata() *MetaDataElem {
	md := s.metadata
	return md
}
// Session Api

// Delete queues a DELETE on model's table, executed on the next Commit.
// The WHERE clause ANDs together one equality condition per key/value pair
// returned by s.mapper.ToMap(model, false).
// NOTE(review): if that map is empty, And() receives no conditions; confirm
// how an empty AND renders before relying on it.
func (s *Session) Delete(model interface{}) {
	kv := s.mapper.ToMap(model, false)
	tableName := s.mapper.ModelName(model)
	table := s.metadata.Table(tableName)

	conditions := make([]Conditional, 0, len(kv))
	for k, v := range kv {
		conditions = append(conditions, Eq(table.C(k), v))
	}

	d := Delete(table)
	stmt := d.Where(And(conditions...)).Build(s.dialect)
	// Reset the dialect after building, as Add and Statement do, so per-build
	// dialect state (presumably bind-parameter numbering) does not carry over
	// into the next statement this session builds.
	s.dialect.Reset()
	s.add(stmt)
}
// Add queues an upsert (insert or update) of model's mapped values into its
// table; the statement runs on the next Commit.
func (s *Session) Add(model interface{}) {
	values := s.mapper.ToMap(model, false)
	name := s.mapper.ModelName(model)

	upsert := Upsert(s.metadata.Table(name)).Values(values)
	built := upsert.Build(s.dialect)
	// Reset the dialect once the statement is built so no per-build state
	// carries over to the next statement.
	s.dialect.Reset()
	s.add(built)
}
// AddAll queues an upsert for each model, in argument order, exactly as Add
// does for a single model.
func (s *Session) AddAll(models ...interface{}) {
	for i := range models {
		s.Add(models[i])
	}
}
// Commit executes every queued statement inside the session's transaction,
// in the order the statements were added, and then commits the transaction.
//
// The session's transaction and queue are cleared up front, so the session
// is ready for new work whatever the outcome. If a statement fails, the
// transaction is rolled back and that statement's error is returned.
// database/sql requires every Tx to end with Commit or Rollback; otherwise
// its connection is never released.
// Committing when nothing has been queued is a no-op that returns nil.
func (s *Session) Commit() error {
	// Detach the pending work under the same lock add uses, so a concurrent
	// add starts a new transaction instead of appending to this one.
	s.mutex.Lock()
	tx, statements := s.tx, s.statements
	s.tx = nil
	s.statements = []*Stmt{}
	s.mutex.Unlock()

	if tx == nil {
		return nil
	}
	for _, statement := range statements {
		if _, err := tx.Exec(statement.SQL(), statement.Bindings()...); err != nil {
			// The Exec error is what callers need; a secondary rollback
			// failure adds nothing actionable, so it is dropped.
			_ = tx.Rollback()
			return err
		}
	}
	return tx.Commit()
}
// Rollback aborts the session's open transaction and discards the queued
// statements. Because the transaction field is cleared, the next
// Add/Delete/AddStatement begins a fresh transaction rather than appending
// to the rolled-back one. It returns an error if no transaction is open.
func (s *Session) Rollback() error {
	s.mutex.Lock()
	tx := s.tx
	s.tx = nil
	s.statements = []*Stmt{}
	s.mutex.Unlock()

	if tx == nil {
		return errors.New("Current transaction is nil")
	}
	return tx.Rollback()
}
// Find starts a select on model's table. Every key returned by
// s.mapper.ToMap(model, true) becomes a selected column. Every key with a
// non-nil value also becomes an equality condition, and those conditions are
// ANDed into the WHERE clause. Finish the query with One or All.
func (s *Session) Find(model interface{}) *Session {
	tableName := s.mapper.ModelName(model)
	fields := s.mapper.ToMap(model, true)
	table := s.T(tableName)

	cols := []Clause{}
	conds := []Conditional{}
	for name, value := range fields {
		cols = append(cols, table.C(name))
		if value != nil {
			conds = append(conds, Eq(table.C(name), value))
		}
	}

	s.builder = Select(cols...).From(table).Where(And(conds...))
	return s
}
// Statement builds the active query and returns it as a Stmt.
//
// For a select, conditions collected via Filter are first ANDed into its
// WHERE clause. After building, the dialect, the pending filters and the
// active builder are all reset, so the session is ready for the next Query
// or Find. It panics with guidance if no query is active, matching the other
// query-builder methods, instead of failing with a nil-interface dereference.
func (s *Session) Statement() *Stmt {
	if s.builder == nil {
		panic(fmt.Errorf("Please use Query() or Find() before calling Statement()"))
	}
	if sel, ok := s.builder.(SelectStmt); ok && len(s.filters) > 0 {
		s.builder = sel.Where(And(s.filters...))
	}

	statement := s.builder.Build(s.dialect)
	s.dialect.Reset()
	s.filters = []Conditional{}
	s.builder = nil
	return statement
}
// Query starts a select statement over the given clauses.
// When any clause is a ColumnElem, the FROM table is taken from the last such
// column; otherwise call From() explicitly. It panics when called with no
// clauses.
func (s *Session) Query(clauses ...Clause) *Session {
	if len(clauses) == 0 {
		panic(fmt.Errorf("You must enter one or more column or aggregate paramater(s)"))
	}

	table := ""
	for _, clause := range clauses {
		if col, ok := clause.(ColumnElem); ok {
			table = col.Table
		}
	}

	s.builder = Select(clauses...)
	if table != "" {
		s.builder = s.builder.(SelectStmt).From(s.T(table))
	}
	return s
}
// isCol reports whether clause's dynamic type is ColumnElem.
func (s *Session) isCol(clause Clause) bool {
	_, ok := clause.(ColumnElem)
	return ok
}
// isSelect reports whether the session's active builder holds a SelectStmt.
// The select-only wrappers (Filter, From, the joins, GroupBy, ...) use it to
// reject calls made before Query(). It returns false when no builder is set.
func (s *Session) isSelect() bool {
	_, ok := s.builder.(SelectStmt)
	return ok
}
// Filter appends a conditional to the active select; Statement ANDs all
// collected filters into the WHERE clause. It panics if no select is active.
// NOTE: only AND-combined filters are supported for now.
// TODO: Add OR able filters
func (s *Session) Filter(conditional Conditional) *Session {
	if _, ok := s.builder.(SelectStmt); !ok {
		panic(fmt.Errorf("Please use Query(cols ...ColumnElem) before calling Filter()"))
	}
	s.filters = append(s.filters, conditional)
	return s
}
// From sets the FROM table of the active select by wrapping SelectStmt.From.
// NOTE: it is only needed when none of the Query() clauses is a column (for
// example when selecting aggregates only), since Query otherwise derives the
// table from its column clauses. It panics if no select is active.
func (s *Session) From(table TableElem) *Session {
	sel, ok := s.builder.(SelectStmt)
	if !ok {
		panic(fmt.Errorf("Please use Query(cols ...ColumnElem) before calling From()"))
	}
	s.builder = sel.From(table)
	return s
}
// InnerJoin adds an INNER JOIN of table on fromCol/col to the active select,
// by wrapping SelectStmt.InnerJoin. It panics if no select is active.
func (s *Session) InnerJoin(table TableElem, fromCol ColumnElem, col ColumnElem) *Session {
	sel, ok := s.builder.(SelectStmt)
	if !ok {
		panic(fmt.Errorf("Please use Query(cols ...ColumnElem) before calling InnerJoin()"))
	}
	s.builder = sel.InnerJoin(table, fromCol, col)
	return s
}
// CrossJoin adds a CROSS JOIN of table to the active select, by wrapping
// SelectStmt.CrossJoin. It panics if no select is active.
func (s *Session) CrossJoin(table TableElem) *Session {
	sel, ok := s.builder.(SelectStmt)
	if !ok {
		panic(fmt.Errorf("Please use Query(cols ...ColumnElem) before calling CrossJoin()"))
	}
	s.builder = sel.CrossJoin(table)
	return s
}
// LeftJoin adds a LEFT JOIN of table on fromCol/col to the active select, by
// wrapping SelectStmt.LeftJoin. It panics if no select is active.
func (s *Session) LeftJoin(table TableElem, fromCol ColumnElem, col ColumnElem) *Session {
	sel, ok := s.builder.(SelectStmt)
	if !ok {
		panic(fmt.Errorf("Please use Query(cols ...ColumnElem) before calling LeftJoin()"))
	}
	s.builder = sel.LeftJoin(table, fromCol, col)
	return s
}
// RightJoin adds a RIGHT JOIN of table on fromCol/col to the active select,
// by wrapping SelectStmt.RightJoin. It panics if no select is active.
func (s *Session) RightJoin(table TableElem, fromCol ColumnElem, col ColumnElem) *Session {
	sel, ok := s.builder.(SelectStmt)
	if !ok {
		panic(fmt.Errorf("Please use Query(cols ...ColumnElem) before calling RightJoin()"))
	}
	s.builder = sel.RightJoin(table, fromCol, col)
	return s
}
// GroupBy sets the GROUP BY columns of the active select, by wrapping
// SelectStmt.GroupBy. It panics if no select is active.
func (s *Session) GroupBy(cols ...ColumnElem) *Session {
	sel, ok := s.builder.(SelectStmt)
	if !ok {
		panic(fmt.Errorf("Please use Query(cols ...ColumnElem) before calling GroupBy()"))
	}
	s.builder = sel.GroupBy(cols...)
	return s
}
// Having adds a HAVING condition "aggregate op value" to the active select,
// by wrapping SelectStmt.Having. It panics if no select is active.
func (s *Session) Having(aggregate AggregateClause, op string, value interface{}) *Session {
	sel, ok := s.builder.(SelectStmt)
	if !ok {
		panic(fmt.Errorf("Please use Query(cols ...ColumnElem) before calling Having()"))
	}
	s.builder = sel.Having(aggregate, op, value)
	return s
}
// OrderBy sets the ORDER BY columns of the active select and applies Asc()
// as the default direction; call Desc() afterwards to change it. It panics if
// no select is active.
func (s *Session) OrderBy(cols ...ColumnElem) *Session {
	sel, ok := s.builder.(SelectStmt)
	if !ok {
		panic(fmt.Errorf("Please use Query(cols ...ColumnElem) before calling OrderBy()"))
	}
	s.builder = sel.OrderBy(cols...).Asc()
	return s
}
// Asc applies SelectStmt.Asc to the active select. It is meant to follow
// OrderBy(), and panics if no select is active.
func (s *Session) Asc() *Session {
	sel, ok := s.builder.(SelectStmt)
	if !ok {
		panic(fmt.Errorf("Please use Query(cols ...ColumnElem) & OrderBy() before calling Asc()"))
	}
	s.builder = sel.Asc()
	return s
}
// Desc applies SelectStmt.Desc to the active select. It is meant to follow
// OrderBy(), and panics if no select is active.
func (s *Session) Desc() *Session {
	sel, ok := s.builder.(SelectStmt)
	if !ok {
		panic(fmt.Errorf("Please use Query(cols ...ColumnElem) & OrderBy() before calling Desc()"))
	}
	s.builder = sel.Desc()
	return s
}
// Limit applies SelectStmt.Limit with the given offset and row count to the
// active select. It panics if no select is active.
func (s *Session) Limit(offset int, count int) *Session {
	sel, ok := s.builder.(SelectStmt)
	if !ok {
		panic(fmt.Errorf("Please use Query(cols ...ColumnElem) before calling Limit()"))
	}
	s.builder = sel.Limit(offset, count)
	return s
}
// Active query select & (insert/delete/update) ... returning ... finishers

// One builds the active query and maps its first record into model via the
// engine's Get. model must be a pointer to a struct, not a struct value.
func (s *Session) One(model interface{}) error {
	statement := s.Statement()
	return s.engine.Get(statement, model)
}
// All builds the active query and maps every record into models via the
// engine's Select. models must be a pointer rather than a value (presumably
// a pointer to a slice of structs; confirm against Engine.Select).
func (s *Session) All(models interface{}) error {
	return s.engine.Select(s.Statement(), models)
}