-
Notifications
You must be signed in to change notification settings - Fork 1
/
repository.go
162 lines (139 loc) · 4.76 KB
/
repository.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
package gorm_common_repository
import (
"gorm.io/gorm"
)
// CommonRepository bundles reusable GORM-backed CRUD helpers for one database
// table. The type parameter Model is the Go type that rows of that table are
// read into and written from.
type CommonRepository[Model any] struct {
	// genericModelStruct holds a Model value that the methods return on
	// failure and pass to GORM's Delete. NewCommonRepository never sets it, so
	// for repositories built there it is Model's zero value.
	genericModelStruct Model

	// tableName is the table every query issued by this repository targets.
	tableName string

	// dbClient is the GORM handle used to execute those queries.
	dbClient *gorm.DB
}
// NewCommonRepository returns a repository for the Model type that targets
// the table named tableName and runs its queries through dbClient.
// The concrete CommonRepository[Model] value is handed back as a
// CommonRepositoryInterface[Model]; genericModelStruct is left at Model's
// zero value.
func NewCommonRepository[Model any](tableName string, dbClient *gorm.DB) CommonRepositoryInterface[Model] {
	var repo CommonRepository[Model]
	repo.dbClient = dbClient
	repo.tableName = tableName
	return repo
}
// CreateRecord inserts dataObject into the repository's table via GORM's
// Create and, on success, returns the value as Create left it (GORM may fill
// in fields such as generated keys, depending on the Model definition).
//
// If the insert fails and ParseDuplicateEntry reports a duplicate entry
// (a unique-constraint violation), gorm.ErrDuplicatedKey is returned in place
// of the driver error. Any other failure is returned unchanged. Every error
// path returns repo.genericModelStruct as the Model value.
func (repo CommonRepository[Model]) CreateRecord(dataObject Model) (Model, error) {
	insertErr := repo.dbClient.Table(repo.tableName).Create(&dataObject).Error
	if insertErr == nil {
		return dataObject, nil
	}

	// The second result of ParseDuplicateEntry is not needed here; only the
	// yes/no answer decides which error the caller sees.
	if isDuplicate, _ := ParseDuplicateEntry(insertErr); isDuplicate {
		return repo.genericModelStruct, gorm.ErrDuplicatedKey
	}
	return repo.genericModelStruct, insertErr
}
// CreateBulkRecords inserts all of dataObjects through a single GORM Create
// call and returns the slice as Create left it (element values may be updated,
// e.g. with generated keys, depending on the Model definition).
//
// An empty or nil dataObjects is a no-op: it is returned unchanged with a nil
// error. GORM's Create rejects an empty slice with an error
// (gorm.ErrEmptySlice in GORM v2) rather than inserting nothing, which is
// rarely what a bulk-insert caller wants.
//
// If the insert fails and ParseDuplicateEntry reports a duplicate entry
// (a unique-constraint violation), gorm.ErrDuplicatedKey is returned in place
// of the driver error. Any other failure is returned unchanged. Every error
// path returns a nil slice.
func (repo CommonRepository[Model]) CreateBulkRecords(dataObjects []Model) ([]Model, error) {
	if len(dataObjects) == 0 {
		return dataObjects, nil
	}

	err := repo.dbClient.
		Table(repo.tableName).
		Create(&dataObjects).Error
	if err != nil {
		// Check if the error is a duplicate entry error, detected by SQL unique constraint.
		if ok, _ := ParseDuplicateEntry(err); ok {
			return nil, gorm.ErrDuplicatedKey
		}
		return nil, err
	}
	return dataObjects, nil
}
// GetRecordByID fetches the row whose "id" column equals queryID. It delegates
// to GetRecordByAttributes with that single condition, so its results and
// errors are exactly those of GetRecordByAttributes.
func (repo CommonRepository[Model]) GetRecordByID(queryID interface{}) (Model, error) {
	byID := map[string]interface{}{"id": queryID}
	return repo.GetRecordByAttributes(byID)
}
// GetRecordByAttributes returns the first row of the repository's table that
// matches the column -> value conditions in queryParams, using GORM's First.
// If GORM reports an error (First reports a not-found error when nothing
// matches), that error is returned together with repo.genericModelStruct.
func (repo CommonRepository[Model]) GetRecordByAttributes(queryParams map[string]interface{}) (Model, error) {
	var found Model
	query := repo.dbClient.Table(repo.tableName).Where(queryParams)
	if err := query.First(&found).Error; err != nil {
		return repo.genericModelStruct, err
	}
	return found, nil
}
// GetRecordsForMultipleIDs returns every row whose "id" column is one of
// queryIDs. It delegates to GetRecordsByMultipleAttributeValues with a single
// "id" entry and returns its results unchanged.
func (repo CommonRepository[Model]) GetRecordsForMultipleIDs(queryIDs []interface{}) ([]Model, error) {
	idFilter := make(map[string][]interface{}, 1)
	idFilter["id"] = queryIDs
	return repo.GetRecordsByMultipleAttributeValues(idFilter)
}
// GetRecordsByMultipleAttributeValues returns all rows of the repository's
// table where, for every entry in queryValues, the named column holds one of
// the listed values. Each entry becomes `column IN (values...)`, and the
// per-column conditions are combined by GORM's Where.
//
// A column mapped to an empty or nil slice matches no rows (GORM renders an
// empty IN list as IN (NULL)). An empty or nil queryValues applies no filter,
// so every row is returned. On error a nil slice is returned.
func (repo CommonRepository[Model]) GetRecordsByMultipleAttributeValues(queryValues map[string][]interface{}) ([]Model, error) {
	// GORM's Where recognizes map[string]interface{} (expanding slice values
	// into IN clauses) but has no case for map[string][]interface{}, so the
	// caller's map is re-typed before it reaches the query builder.
	conditions := make(map[string]interface{}, len(queryValues))
	for column, values := range queryValues {
		conditions[column] = values
	}

	var dataObjects []Model
	err := repo.dbClient.
		Table(repo.tableName).
		Where(conditions).
		Find(&dataObjects).Error
	if err != nil {
		return nil, err
	}
	return dataObjects, nil
}
// GetRecordsByQueryParams loads rows from the repository's table.
//
// When queryParams is non-nil, its filtering (FilterByParams), pagination
// (Paginate) and sorting (SortByDirection) scopes are attached to the query,
// in that order. A nil queryParams applies no scopes, so every row is
// returned. On error a nil slice is returned.
func (repo CommonRepository[Model]) GetRecordsByQueryParams(queryParams *QueryParams) ([]Model, error) {
	dbQuery := repo.dbClient.Table(repo.tableName)
	if queryParams != nil {
		// Keep the *gorm.DB that Scopes returns. GORM's chainable API hands
		// back the instance carrying the accumulated clauses, and discarding
		// it only works while the receiver happens to be mutated in place.
		dbQuery = dbQuery.Scopes(
			queryParams.FilterByParams(repo.tableName),
			queryParams.Paginate(),
			queryParams.SortByDirection(),
		)
	}

	var dataObjects []Model
	if err := dbQuery.Find(&dataObjects).Error; err != nil {
		return nil, err
	}
	return dataObjects, nil
}
// GetRecordCount returns the number of rows in the repository's table.
//
// When queryParams is non-nil, only its filtering scope (FilterByParams) is
// applied; pagination and sorting are not. Presumably this is so the count
// covers all pages of GetRecordsByQueryParams. A nil queryParams counts every
// row. On error the count is 0.
func (repo CommonRepository[Model]) GetRecordCount(queryParams *QueryParams) (int64, error) {
	dbQuery := repo.dbClient.Table(repo.tableName)
	if queryParams != nil {
		// Use the *gorm.DB returned by Scopes rather than relying on the
		// receiver being mutated in place.
		dbQuery = dbQuery.Scopes(queryParams.FilterByParams(repo.tableName))
	}

	var totalCount int64
	if err := dbQuery.Count(&totalCount).Error; err != nil {
		return 0, err
	}
	return totalCount, nil
}
// UpdateRecordByID writes the column -> value pairs in data to the row whose
// "id" column equals queryID. It delegates to UpdateRecordsByAttributes and
// returns its error unchanged.
func (repo CommonRepository[Model]) UpdateRecordByID(queryID interface{}, data map[string]interface{}) error {
	byID := map[string]interface{}{"id": queryID}
	return repo.UpdateRecordsByAttributes(byID, data)
}
// UpdateRecordsByAttributes writes the column -> value pairs in data to every
// row matching the conditions in queryParams, using GORM's UpdateColumns.
// According to the GORM docs, UpdateColumns bypasses hooks and does not
// auto-set UpdatedAt.
//
// Only the GORM error is reported. RowsAffected is not inspected, so matching
// zero rows is not treated as an error.
func (repo CommonRepository[Model]) UpdateRecordsByAttributes(queryParams map[string]interface{}, data map[string]interface{}) error {
	return repo.dbClient.
		Table(repo.tableName).
		Where(queryParams).
		UpdateColumns(data).
		Error
}
// DeleteRecordByID removes the row whose "id" column equals queryID. It
// delegates to DeleteRecordsByAttributes, so gorm.ErrRecordNotFound is
// returned when no row was affected.
func (repo CommonRepository[Model]) DeleteRecordByID(queryID interface{}) error {
	byID := map[string]interface{}{"id": queryID}
	return repo.DeleteRecordsByAttributes(byID)
}
// DeleteRecordsByAttributes deletes every row of the repository's table that
// matches the conditions in queryParams, using GORM's Delete with
// repo.genericModelStruct as the model argument.
//
// It returns the GORM error if the statement fails, gorm.ErrRecordNotFound if
// the statement succeeded but affected no rows, and nil otherwise.
func (repo CommonRepository[Model]) DeleteRecordsByAttributes(queryParams map[string]interface{}) error {
	outcome := repo.dbClient.Table(repo.tableName).Where(queryParams).Delete(&repo.genericModelStruct)
	switch {
	case outcome.Error != nil:
		return outcome.Error
	case outcome.RowsAffected == 0:
		return gorm.ErrRecordNotFound
	default:
		return nil
	}
}