/
orm.go
311 lines (265 loc) · 9.07 KB
/
orm.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package orm
import (
"context"
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/beego/beego/v2/client/orm"
"github.com/goharbor/harbor/src/lib/errors"
"github.com/goharbor/harbor/src/lib/log"
tracelib "github.com/goharbor/harbor/src/lib/trace"
)
// NewCondition is an alias of orm.NewCondition from beego's ORM package.
var NewCondition = orm.NewCondition

// Aliases re-exported from beego's ORM package so callers can use them
// through this package.
type (
	// Condition is an alias of orm.Condition.
	Condition = orm.Condition
	// Params is an alias of orm.Params.
	Params = orm.Params
	// ParamsList is an alias of orm.ParamsList.
	ParamsList = orm.ParamsList
	// QuerySeter is an alias of orm.QuerySeter.
	QuerySeter = orm.QuerySeter
)
// RegisterModel registers the given models with beego's ORM; it forwards
// every argument unchanged to orm.RegisterModel.
func RegisterModel(models ...interface{}) {
	orm.RegisterModel(models...)
}
// ormKey is the context key under which NewContext stores the ORM and from
// which FromContext reads it back.
type ormKey struct{}
// valueOnlyContext wraps a parent context and inherits only its values: Value
// lookups go to the embedded parent, while Deadline, Done and Err always report
// no deadline, a nil Done channel and a nil error. The parent's cancellation
// and deadline are therefore not propagated.
type valueOnlyContext struct{ context.Context }

// Deadline always reports that no deadline is set.
func (valueOnlyContext) Deadline() (time.Time, bool) {
	var none time.Time
	return none, false
}

// Done always returns a nil channel, i.e. the context is never cancelled.
func (valueOnlyContext) Done() <-chan struct{} {
	return nil
}

// Err always returns nil because the context is never cancelled.
func (valueOnlyContext) Err() error {
	return nil
}
const (
	// tracerName is passed to tracelib.StartTrace for the spans this package starts.
	tracerName = "goharbor/harbor/src/lib/orm"

	// defaultTranscationOpName is the span name used for a transaction when no
	// non-empty operation name was stored with SetTransactionOpNameToContext.
	// The identifier's spelling is kept as-is since other code may refer to it.
	defaultTranscationOpName = "start-transaction"
)
// init turns on beego's ORM debug mode when the ORM_DEBUG environment variable
// is set to exactly "true"; any other value leaves orm.Debug untouched.
func init() {
	if os.Getenv("ORM_DEBUG") != "true" {
		return
	}
	orm.Debug = true
}
// FromContext returns the ORM stored in ctx by NewContext (directly or via
// Context, Clone or Copy). It returns an error when ctx holds no
// orm.QueryExecutor under the package's key.
func FromContext(ctx context.Context) (orm.QueryExecutor, error) {
	if executor, found := ctx.Value(ormKey{}).(orm.QueryExecutor); found {
		return executor, nil
	}
	return nil, errors.New("cannot get the ORM from context")
}
// NewContext returns a child of ctx carrying o as its ORM, retrievable with
// FromContext. A nil ctx is replaced by context.Background().
func NewContext(ctx context.Context, o orm.QueryExecutor) context.Context {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	return context.WithValue(parent, ormKey{}, o)
}
// Context returns a context with an orm
func Context() context.Context {
return NewContext(context.Background(), orm.NewOrm())
}
// Clone returns new context with orm for ctx
func Clone(ctx context.Context) context.Context {
return NewContext(ctx, orm.NewOrm())
}
// Copy returns a new context carrying a freshly created ORM (orm.NewOrm) plus
// all values of ctx, but detached from ctx's cancellation and deadline (the
// parent is wrapped in valueOnlyContext).
//
// A nil ctx is treated as context.Background(), matching NewContext. Without
// this, the non-nil wrapper would bypass NewContext's nil check, and any later
// Value lookup for a key not set here would call Value on the nil embedded
// Context and panic.
func Copy(ctx context.Context) context.Context {
	if ctx == nil {
		ctx = context.Background()
	}
	return NewContext(valueOnlyContext{Context: ctx}, orm.NewOrm())
}
// operationNameKey is the context key under which the transaction operation
// name is stored.
type operationNameKey struct{}

// SetTransactionOpNameToContext returns a child of ctx carrying name as the
// transaction operation name, which GetTransactionOpNameFromContext reads back.
// A nil ctx is treated as context.Background().
func SetTransactionOpNameToContext(ctx context.Context, name string) context.Context {
	parent := context.Background()
	if ctx != nil {
		parent = ctx
	}
	return context.WithValue(parent, operationNameKey{}, name)
}
// GetTransactionOpNameFromContext returns the transaction operation name stored
// in ctx by SetTransactionOpNameToContext. When no name is stored, the stored
// value is not a string, or it is empty, defaultTranscationOpName is returned.
func GetTransactionOpNameFromContext(ctx context.Context) string {
	// A failed comma-ok assertion yields "", so one emptiness check covers the
	// missing, wrongly typed, and empty cases alike.
	if name, _ := ctx.Value(operationNameKey{}).(string); name != "" {
		return name
	}
	return defaultTranscationOpName
}
// WithTransaction is a decorator that makes f run inside a database transaction.
//
// The returned function looks up the ORM stored in its ctx (see FromContext).
// An orm.Ormer or an orm.TxOrmer is wrapped in an ormerTx, whose Begin is then
// called. NOTE(review): how ormerTx handles an existing TxOrmer (e.g. nested
// transactions) is defined outside this file; confirm there. Any other ORM type
// is rejected with an error.
//
// f receives a context that carries the transaction's TxOrmer.
//   - If f fails, the transaction is rolled back and f's error is returned.
//   - If the rollback itself fails, the rollback error is returned instead, and
//     f's error is logged so it is not lost.
//   - If f succeeds, the transaction is committed and any commit error is returned.
//
// Every phase is recorded on a trace span named by
// GetTransactionOpNameFromContext(ctx).
func WithTransaction(f func(ctx context.Context) error) func(ctx context.Context) error {
	return func(ctx context.Context) error {
		cx, span := tracelib.StartTrace(ctx, tracerName, GetTransactionOpNameFromContext(ctx))
		defer span.End()

		o, err := FromContext(ctx)
		if err != nil {
			tracelib.RecordError(span, err, "get orm from ctx failed")
			return err
		}

		var tx ormerTx
		switch v := o.(type) {
		case orm.Ormer:
			tx = ormerTx{Ormer: v}
		case orm.TxOrmer:
			tx = ormerTx{TxOrmer: v}
		default:
			// Record this failure on the span like every other error path.
			err = errors.New("no orm found in the context")
			tracelib.RecordError(span, err, "no supported orm found in ctx")
			return err
		}

		if err := tx.Begin(); err != nil {
			tracelib.RecordError(span, err, "begin transaction failed")
			log.Errorf("begin transaction failed: %v", err)
			return err
		}

		// When set multiple times, context.WithValue returns only the last ormer.
		// To ensure that the rollback works, set TxOrmer as the ormer in the transaction.
		cx = NewContext(cx, tx.TxOrmer)
		if err := f(cx); err != nil {
			span.AddEvent("rollback transaction")
			if e := tx.Rollback(); e != nil {
				tracelib.RecordError(span, e, "rollback transaction failed")
				// The rollback error is what gets returned, so log f's error here;
				// otherwise the failure that triggered the rollback would vanish.
				log.Errorf("rollback transaction failed: %v, original error: %v", e, err)
				return e
			}
			return err
		}

		span.AddEvent("commit transaction")
		if err := tx.Commit(); err != nil {
			tracelib.RecordError(span, err, "commit transaction failed")
			log.Errorf("commit transaction failed: %v", err)
			return err
		}
		return nil
	}
}
// ReadOrCreate reads md from the database using the given columns and creates
// it when no matching row exists.
//
// md must have a GetID() int64 method; otherwise an error is returned before any
// database access. The lookup uses col1 plus any extra cols. If the read fails
// with orm.ErrNoRows, md is inserted inside its own transaction (WithTransaction).
// If that insert fails with a duplicate key error (IsDuplicateKeyError, defined
// outside this view), the row is read again; presumably another caller created it
// concurrently.
//
// Returns:
//   - created: true only when this call inserted the row.
//   - id: on success, the id returned by Insert when created, otherwise
//     md.GetID() after the read.
//   - err: any error other than the handled ErrNoRows / duplicate-key cases.
//
// NOTE(review): when err != nil after a failed create, id may still hold whatever
// Insert returned; callers should not rely on it in that case.
func ReadOrCreate(ctx context.Context, md interface{}, col1 string, cols ...string) (created bool, id int64, err error) {
	getter, ok := md.(interface {
		GetID() int64
	})
	if !ok {
		err = fmt.Errorf("missing GetID method for the model %T", md)
		return
	}
	// Runs on every return below. When the row already existed (no error and not
	// created by this call), report the id of md, which o.Read is expected to
	// have populated.
	defer func() {
		if !created && err == nil { // found in the database
			id = getter.GetID()
		}
	}()
	o, err := FromContext(ctx)
	if err != nil {
		return
	}
	cols = append([]string{col1}, cols...)
	err = o.Read(md, cols...)
	if err == nil { // found in the database
		return
	}
	if !errors.Is(err, orm.ErrNoRows) { // any other read error goes back to the caller
		return
	}
	// not found in the database, try to create one
	err = WithTransaction(func(ctx context.Context) error {
		// o and err here shadow the outer ones; id is the outer named result.
		o, err := FromContext(ctx)
		if err != nil {
			return err
		}
		id, err = o.Insert(md)
		return err
	})(ctx)
	if err == nil { // create success
		created = true
		return
	}
	// got a duplicate key error, try to read again with the outer (non-transaction) ORM
	if IsDuplicateKeyError(err) {
		err = o.Read(md, cols...)
	}
	return
}
// CreateInClause runs sql with args bound to its "?" placeholders and renders the
// returned IDs as an "IN (id1,id2,...)" clause.
//
// sql must select a single integer ID column, e.g.
// "select id from table1 where column1=?". Binding args through the prepared
// query, rather than concatenating them into the SQL, avoids SQL injection.
// The resulting IDs are emitted as integer literals, which avoids the
// too-many-arguments problem described in
// https://github.com/goharbor/harbor/issues/12269.
//
// When the query matches nothing, the clause is "IN (-1)" rather than an empty
// list.
func CreateInClause(ctx context.Context, sql string, args ...interface{}) (string, error) {
	executor, err := FromContext(ctx)
	if err != nil {
		return "", err
	}

	ids := []int64{}
	if _, err := executor.Raw(sql, args...).QueryRows(&ids); err != nil {
		return "", err
	}
	if len(ids) == 0 {
		ids = []int64{-1}
	}

	literals := make([]string, len(ids))
	for i, id := range ids {
		literals[i] = strconv.FormatInt(id, 10)
	}
	return "IN (" + strings.Join(literals, ",") + ")", nil
}
// Escape prefixes every backslash, percent sign and underscore in str with a
// backslash, so `a%b_c\d` becomes `a\%b\_c\\d`. These are the SQL LIKE
// wildcards plus the default escape character; presumably the result is meant
// to be embedded in a LIKE pattern (confirm with callers).
func Escape(str string) string {
	// All patterns are single bytes, so a one-pass replacer gives exactly the
	// same result as replacing backslashes first and then % and _.
	return strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(str)
}
// ParamPlaceholderForIn returns n comma-separated "?" placeholders for use in a
// SQL "in (...)" list, e.g. n=3 yields "?,?,?". For n <= 0 it returns "".
func ParamPlaceholderForIn(n int) string {
	if n <= 0 {
		return ""
	}
	return strings.Repeat("?,", n-1) + "?"
}
// QuoteLiteral quotes literal so it can be embedded directly in an SQL statement.
// This is intended for DDL and other statements that do not accept bind
// parameters. For example:
//
//	expDate := orm.QuoteLiteral("2023-01-05 15:00:00Z")
//	_, err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", expDate))
//
// Every single quote in literal is doubled. If literal contains a backslash,
// each backslash is doubled too, and the result is written as a PostgreSQL
// C-style escape string with a leading space: ` E'...'`.
func QuoteLiteral(literal string) string {
	// Mirrors PQEscapeStringInternal in libpq (src/interfaces/libpq/fe-exec.c):
	// https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
	escaped := strings.ReplaceAll(literal, `'`, `''`)
	if !strings.Contains(escaped, `\`) {
		return `'` + escaped + `'`
	}
	// Backslashes require the E'' form; libpq also emits a space before the E.
	return ` E'` + strings.ReplaceAll(escaped, `\`, `\\`) + `'`
}