/
db.go
143 lines (116 loc) · 3.59 KB
/
db.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
package testutils
import (
	"database/sql"
	"fmt"
	"os"
	"strings"
	"testing"

	"github.com/DATA-DOG/go-txdb"
	gamedb "github.com/curio-research/keystone/db"
	"github.com/curio-research/keystone/state"
	"github.com/joho/godotenv"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gorm.io/driver/mysql"
	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)
var (
	// testSQLDSN holds the value of the SQL_DSN environment variable, which
	// init may populate from ../.env before reading it.
	testSQLDSN string

	// testSQLiteDBPath is the on-disk database file used by the SQLite helpers.
	testSQLiteDBPath = "test.db"
)

// init loads ../.env on a best-effort basis, captures SQL_DSN into testSQLDSN,
// and registers a "txdb" database/sql driver that wraps the "mysql" driver
// with that DSN.
//
// The .env path is relative to the working directory, which is the package
// directory under `go test`.
func init() {
	loadErr := godotenv.Load("../.env")
	if loadErr != nil {
		// Not fatal: SQL_DSN may already be present in the real environment.
		fmt.Println("Failed to load .env file")
	}

	testSQLDSN = os.Getenv("SQL_DSN")
	txdb.Register("txdb", "mysql", testSQLDSN)
}
func SetupMySQLTestDB(t *testing.T, testGameID string, deleteTables bool, accessors map[interface{}]*state.TableBaseAccessor[any]) (*gamedb.MySQLSaveStateHandler, *gamedb.MySQLSaveTransactionHandler, *sql.DB) {
var db *sql.DB
db, err := sql.Open("txdb", testSQLDSN)
if err != nil {
require.Nil(t, err)
}
require.Nil(t, db.Ping())
if deleteTables {
deleteAllTablesMySQL(t, db)
}
sqlDialector := mysql.New(mysql.Config{Conn: db})
mySQLSaveStateHandler, mySQLSaveTxHandler, err := gamedb.SQLHandlersFromDialector(sqlDialector, testGameID, accessors)
require.Nil(t, err)
return mySQLSaveStateHandler, mySQLSaveTxHandler, db
}
// deleteAllTablesMySQL lists every table visible through db (SHOW TABLES) and
// drops each one.
//
// Listing failures stop the test. Drop failures are best effort: each failure
// is printed and the remaining tables are still attempted.
//
// NOTE(review): MySQL treats DROP TABLE as an implicit commit, so when db is a
// txdb connection these drops are presumably not undone by its rollback —
// confirm this is intended.
func deleteAllTablesMySQL(t *testing.T, db *sql.DB) {
	t.Helper()

	rows, err := db.Query("SHOW TABLES")
	require.NoError(t, err)
	defer rows.Close()

	var tables []string
	for rows.Next() {
		var table string
		require.NoError(t, rows.Scan(&table))
		tables = append(tables, table)
	}
	// rows.Next returns false both on exhaustion and on error; only rows.Err
	// distinguishes them, so without this a failed iteration would silently
	// leave tables undropped.
	require.NoError(t, rows.Err())

	for _, table := range tables {
		// Backtick-quote the identifier (doubling embedded backticks) so
		// reserved-word table names such as `groups` can still be dropped.
		quoted := "`" + strings.ReplaceAll(table, "`", "``") + "`"
		if _, err := db.Exec("DROP TABLE " + quoted); err != nil {
			fmt.Println("Failed to drop table", table, "err", err)
		}
	}
}
// SetupSQLiteTestDB opens the local SQLite test database at testSQLiteDBPath,
// optionally drops all of its user tables, and builds the game save-state and
// save-transaction handlers from a gorm SQLite dialector for the same file.
//
// The returned *sql.DB is a separate database/sql handle on the same file, and
// the caller is responsible for closing it. Any setup failure stops the test
// via require.
func SetupSQLiteTestDB(t *testing.T, testGameID string, deleteTables bool, accessors map[interface{}]*state.TableBaseAccessor[any]) (*gamedb.MySQLSaveStateHandler, *gamedb.MySQLSaveTransactionHandler, *sql.DB) {
	// Attribute require failures to the calling test, not to this helper.
	t.Helper()

	db, err := sql.Open("sqlite3", testSQLiteDBPath)
	require.NoError(t, err)
	require.NoError(t, db.Ping())

	if deleteTables {
		deleteAllTablesSQLite(t)
	}

	gormDB, err := gorm.Open(sqlite.Open(testSQLiteDBPath))
	// This error used to be overwritten unchecked by the next assignment, so
	// an open failure only surfaced later as a confusing downstream error.
	require.NoError(t, err)

	stateHandler, txHandler, err := gamedb.SQLHandlersFromDialector(gormDB.Dialector, testGameID, accessors)
	require.NoError(t, err)
	return stateHandler, txHandler, db
}
// ResetSQLiteTestDB replaces the SQLite test database file at testSQLiteDBPath
// with a fresh, empty file. A missing file is not an error.
//
// It returns any error from removing the old file (other than "does not
// exist"), from creating the new one, or from closing it.
func ResetSQLiteTestDB() error {
	// Remove directly instead of Stat-then-Remove. This avoids a check/use race,
	// and real failures such as permission errors are reported instead of
	// silently skipping the delete.
	if err := os.Remove(testSQLiteDBPath); err != nil && !os.IsNotExist(err) {
		return err
	}

	file, err := os.Create(testSQLiteDBPath)
	if err != nil {
		return err
	}
	// Return the Close error instead of discarding it in a defer.
	return file.Close()
}
// deleteAllTablesSQLite drops every user table in the SQLite database at
// testSQLiteDBPath, then asserts that no user tables remain.
//
// It opens its own gorm connection to the file and closes it before returning.
// Connection and drop failures stop the test via require.
func deleteAllTablesSQLite(t *testing.T) {
	t.Helper()

	db, err := gorm.Open(sqlite.Open(testSQLiteDBPath), &gorm.Config{})
	require.NoError(t, err, "failed to connect database")

	// Close the underlying pool when done; it was previously leaked on every
	// call, keeping the database file open.
	sqlDB, err := db.DB()
	require.NoError(t, err)
	defer func() { assert.NoError(t, sqlDB.Close()) }()

	for _, tableName := range getSQLiteTableNames(db) {
		// Double-quote the identifier (doubling embedded quotes) so unusual or
		// keyword-like table names are still dropped correctly.
		quoted := `"` + strings.ReplaceAll(tableName, `"`, `""`) + `"`
		err := db.Exec("DROP TABLE " + quoted + ";").Error
		require.NoError(t, err, "Failed to drop table "+tableName)
	}

	// Verify that no user tables are left.
	assert.Empty(t, getSQLiteTableNames(db))
}

// getSQLiteTableNames returns the names of all user tables in db and panics if
// the catalog query fails.
//
// SQLite's internal tables (names starting with "sqlite_", e.g. sqlite_sequence,
// which is created for AUTOINCREMENT columns) are excluded. They cannot be
// removed with DROP TABLE. The "_" is escaped because it is a LIKE wildcard.
func getSQLiteTableNames(db *gorm.DB) []string {
	const query = `SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite\_%' ESCAPE '\';`

	var tableNames []string
	if err := db.Raw(query).Scan(&tableNames).Error; err != nil {
		panic("Failed to fetch table names: " + err.Error())
	}
	return tableNames
}