Skip to content

Commit

Permalink
refactor: replaces "where" logic in DML statements to start using the…
Browse files Browse the repository at this point in the history
… primary key of the table instead of the hardcoded 'id' column
  • Loading branch information
jorgerojas26 committed Mar 10, 2024
1 parent 509f881 commit ae0c0dc
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 89 deletions.
128 changes: 103 additions & 25 deletions components/ResultsTable.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ import (
type ResultsTableState struct {
listOfDbChanges *[]models.DbDmlChange
listOfDbInserts *[]models.DbInsert
dbReference string
currentSort string
error string
currentSort string
dbReference string
records [][]string
columns [][]string
constraints [][]string
Expand Down Expand Up @@ -201,7 +201,7 @@ func (table *ResultsTable) AddInsertedRows() {
for j, cell := range row {
tableCell := tview.NewTableCell(cell)
tableCell.SetExpansion(1)
tableCell.SetReference(inserts[i].RowId)
tableCell.SetReference(inserts[i].PrimaryKeyValue)

tableCell.SetTextColor(tview.Styles.PrimaryTextColor)
tableCell.SetBackgroundColor(InsertColor)
Expand Down Expand Up @@ -397,7 +397,7 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event
for i, insertedRow := range *table.state.listOfDbInserts {
cellReference := table.GetCell(selectedRowIndex, 0).GetReference()

if cellReference != nil && insertedRow.RowId.String() == cellReference.(uuid.UUID).String() {
if cellReference != nil && insertedRow.PrimaryKeyValue.String() == cellReference.(uuid.UUID).String() {
isAnInsertedRow = true
indexOfInsertedRow = i
}
Expand Down Expand Up @@ -444,11 +444,11 @@ func (table *ResultsTable) tableInputCapture(event *tcell.EventKey) *tcell.Event
}

newInsert := models.DbInsert{
Table: table.GetDBReference(),
Columns: table.GetRecords()[0],
Values: newRow,
RowId: newRowUuid,
Option: 1,
Table: table.GetDBReference(),
Columns: table.GetRecords()[0],
Values: newRow,
PrimaryKeyValue: newRowUuid,
Option: 1,
}

*table.state.listOfDbInserts = append(*table.state.listOfDbInserts, newInsert)
Expand Down Expand Up @@ -876,6 +876,7 @@ func (table *ResultsTable) FetchRecords() [][]string {
if table.GetIsFiltering() {
table.SetIsFiltering(false)
}

columns, _ := table.DBDriver.GetTableColumns(table.Tree.GetSelectedDatabase(), tableName)
constraints, _ := table.DBDriver.GetConstraints(tableName)
foreignKeys, _ := table.DBDriver.GetForeignKeys(tableName)
Expand Down Expand Up @@ -965,7 +966,7 @@ func (table *ResultsTable) StartEditingCell(row int, col int, callback func(newV

func (table *ResultsTable) CheckIfRowIsInserted(rowId uuid.UUID) bool {
for _, insertedRow := range *table.state.listOfDbInserts {
if insertedRow.RowId == rowId {
if insertedRow.PrimaryKeyValue == rowId {
return true
}
}
Expand All @@ -975,7 +976,7 @@ func (table *ResultsTable) CheckIfRowIsInserted(rowId uuid.UUID) bool {

func (table *ResultsTable) MutateInsertedRowCell(rowId uuid.UUID, colIndex int, newValue string) {
for i, insertedRow := range *table.state.listOfDbInserts {
if insertedRow.RowId == rowId {
if insertedRow.PrimaryKeyValue == rowId {
(*table.state.listOfDbInserts)[i].Values[colIndex] = newValue
}
}
Expand All @@ -997,13 +998,13 @@ func (table *ResultsTable) AppendNewChange(changeType string, tableName string,
}

if !isInsertedRow {
selectedRowId := table.GetRecords()[rowIndex][0]
primaryKeyValue, primaryKeyColumnName := table.GetPrimaryKeyValue(rowIndex)

alreadyExists := false
indexOfChange := -1

for i, change := range *table.state.listOfDbChanges {
if change.RowId == selectedRowId && change.Column == table.GetColumnNameByIndex(colIndex) {
if change.PrimaryKeyValue == primaryKeyValue && change.Column == table.GetColumnNameByIndex(colIndex) {
alreadyExists = true
indexOfChange = i
}
Expand Down Expand Up @@ -1039,12 +1040,13 @@ func (table *ResultsTable) AppendNewChange(changeType string, tableName string,
}
} else {
newChange := models.DbDmlChange{
Type: changeType,
Table: tableName,
Column: columnName,
Value: value,
RowId: selectedRowId,
Option: 1,
Type: changeType,
Table: tableName,
Column: columnName,
Value: value,
PrimaryKeyColumnName: primaryKeyColumnName,
PrimaryKeyValue: primaryKeyValue,
Option: 1,
}

*table.state.listOfDbChanges = append(*table.state.listOfDbChanges, newChange)
Expand Down Expand Up @@ -1079,12 +1081,13 @@ func (table *ResultsTable) AppendNewChange(changeType string, tableName string,
}

newChange := models.DbDmlChange{
Type: changeType,
Table: tableName,
Column: "",
Value: "",
RowId: selectedRowId,
Option: 1,
Type: changeType,
Table: tableName,
Column: "",
Value: "",
PrimaryKeyColumnName: primaryKeyColumnName,
PrimaryKeyValue: primaryKeyValue,
Option: 1,
}

*table.state.listOfDbChanges = append(*table.state.listOfDbChanges, newChange)
Expand All @@ -1096,3 +1099,78 @@ func (table *ResultsTable) AppendNewChange(changeType string, tableName string,
}
}
}

func (table *ResultsTable) GetPrimaryKeyValue(rowIndex int) (string, string) {
provider := table.DBDriver.GetProvider()
columns := table.GetColumns()

primaryKeyColumnName := ""
primaryKeyValue := ""

switch provider {
case "mysql":
keyColumnIndex := -1
primaryKeyColumnIndex := -1

for i, col := range columns[0] {
if col == "Key" {
keyColumnIndex = i
}
}

for i, col := range columns {
if col[keyColumnIndex] == "PRI" {
primaryKeyColumnIndex = i - 1
primaryKeyColumnName = col[0]
}
}

if primaryKeyColumnIndex != -1 {
primaryKeyValue = table.GetRecords()[rowIndex][primaryKeyColumnIndex]
}

case "postgres":
keyColumnIndex := -1
primaryKeyColumnIndex := -1

for i, col := range columns[0] {
if col == "column_default" {
keyColumnIndex = i
}
}

for i, col := range columns {
if strings.Contains(col[keyColumnIndex], "nextval") {
primaryKeyColumnIndex = i - 1
primaryKeyColumnName = col[0]
}
}

if primaryKeyColumnIndex != -1 {
primaryKeyValue = table.GetRecords()[rowIndex][primaryKeyColumnIndex]
}

case "sqlite3":
keyColumnIndex := -1
primaryKeyColumnIndex := -1

for i, col := range columns[0] {
if col == "pk" {
keyColumnIndex = i
}
}

for i, col := range columns {
if col[keyColumnIndex] == "1" {
primaryKeyColumnIndex = i - 1
primaryKeyColumnName = col[0]
}
}

if primaryKeyColumnIndex != -1 {
primaryKeyValue = table.GetRecords()[rowIndex][primaryKeyColumnIndex]
}
}

return primaryKeyValue, primaryKeyColumnName
}
4 changes: 2 additions & 2 deletions drivers/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ type Driver interface {
GetForeignKeys(table string) ([][]string, error)
GetIndexes(table string) ([][]string, error)
GetRecords(table, where, sort string, offset, limit int) ([][]string, int, error)
UpdateRecord(table, column, value, id string) error
DeleteRecord(table string, id string) error
UpdateRecord(table, column, value, primaryKeyColumnName, primaryKeyValue string) error
DeleteRecord(table string, primaryKeyColumnName, primaryKeyValue string) error
ExecuteDMLStatement(query string) (string, error)
ExecuteQuery(query string) ([][]string, error)
ExecutePendingChanges(changes []models.DbDmlChange, inserts []models.DbInsert) error
Expand Down
26 changes: 14 additions & 12 deletions drivers/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -312,16 +312,16 @@ func (db *MySQL) ExecuteQuery(query string) (results [][]string, err error) {
}

// TODO: Rewrites this logic to use the primary key instead of the id
func (db *MySQL) UpdateRecord(table, column, value, id string) error {
query := fmt.Sprintf("UPDATE %s SET %s = \"%s\" WHERE id = \"%s\"", table, column, value, id)
func (db *MySQL) UpdateRecord(table, column, value, primaryKeyColumnName, primaryKeyValue string) error {
query := fmt.Sprintf("UPDATE %s SET %s = \"%s\" WHERE %s = \"%s\"", table, column, value, primaryKeyColumnName, primaryKeyValue)
_, err := db.Connection.Exec(query)

return err
}

// TODO: Rewrites this logic to use the primary key instead of the id
func (db *MySQL) DeleteRecord(table, id string) error {
query := fmt.Sprintf("DELETE FROM %s WHERE id = \"%s\"", table, id)
func (db *MySQL) DeleteRecord(table, primaryKeyColumnName, primaryKeyValue string) error {
query := fmt.Sprintf("DELETE FROM %s WHERE %s = \"%s\"", table, primaryKeyColumnName, primaryKeyValue)
_, err := db.Connection.Exec(query)

return err
Expand All @@ -348,10 +348,11 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m

// Group changes by RowId and Table
for _, change := range changes {
if change.Type == "UPDATE" {
key := fmt.Sprintf("%s|%s", change.Table, change.RowId)
switch change.Type {
case "UPDATE":
key := fmt.Sprintf("%s|%s|%s", change.Table, change.PrimaryKeyColumnName, change.PrimaryKeyValue)
groupedUpdated[key] = append(groupedUpdated[key], change)
} else if change.Type == "DELETE" {
case "DELETE":
groupedDeletes = append(groupedDeletes, change)
}
}
Expand All @@ -363,7 +364,8 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m
// Split key into table and rowId
splitted := strings.Split(key, "|")
table := splitted[0]
rowId := splitted[1]
primaryKeyColumnName := splitted[1]
primaryKeyValue := splitted[2]

for _, change := range changes {
columns = append(columns, fmt.Sprintf("%s='%s'", change.Column, change.Value))
Expand All @@ -372,8 +374,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m
// Merge all column updates
updateClause := strings.Join(columns, ", ")

// TODO: Rewrites this logic to use the primary key instead of the id
query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s';", table, updateClause, rowId)
query := fmt.Sprintf("UPDATE %s SET %s WHERE %s = '%s';", table, updateClause, primaryKeyColumnName, primaryKeyValue)

queries = append(queries, query)
}
Expand All @@ -383,8 +384,8 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m
query := ""

statementType = "DELETE FROM"
// TODO: Rewrites this logic to use the primary key instead of the id
query = fmt.Sprintf("%s %s WHERE id = \"%s\"", statementType, delete.Table, delete.RowId)

query = fmt.Sprintf("%s %s WHERE %s = \"%s\"", statementType, delete.Table, delete.PrimaryKeyColumnName, delete.PrimaryKeyValue)

if query != "" {
queries = append(queries, query)
Expand Down Expand Up @@ -415,6 +416,7 @@ func (db *MySQL) ExecutePendingChanges(changes []models.DbDmlChange, inserts []m
}

for _, query := range queries {

_, err = tx.Exec(query)

if err != nil {
Expand Down
23 changes: 11 additions & 12 deletions drivers/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ func (db *Postgres) GetDatabases() (databases []string, err error) {
return databases, nil
}

// TODO: Implement GetSchemas function
func (db *Postgres) GetTables(database string) (tables map[string][]string, err error) {
tables = make(map[string][]string)

Expand All @@ -81,7 +80,7 @@ func (db *Postgres) GetTables(database string) (tables map[string][]string, err

func (db *Postgres) GetTableColumns(database, table string) (results [][]string, error error) {
tableName := strings.Split(table, ".")[1]
rows, err := db.Connection.Query(fmt.Sprintf("SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_catalog = '%s' AND table_name = '%s'", database, tableName))
rows, err := db.Connection.Query(fmt.Sprintf("SELECT column_name, data_type, is_nullable, column_default FROM information_schema.columns WHERE table_catalog = '%s' AND table_name = '%s' ORDER by ordinal_position", database, tableName))
if err != nil {
return results, err
}
Expand Down Expand Up @@ -332,17 +331,15 @@ func (db *Postgres) GetRecords(table, where, sort string, offset, limit int) (re
return
}

// TODO: Rewrites this logic to use the primary key instead of the id
func (db *Postgres) UpdateRecord(table, column, value, id string) (err error) {
query := fmt.Sprintf("UPDATE %s SET %s = '%s' WHERE id = '%s'", table, column, value, id)
func (db *Postgres) UpdateRecord(table, column, value, primaryKeyColumnName, primaryKeyValue string) (err error) {
query := fmt.Sprintf("UPDATE %s SET %s = '%s' WHERE '%s' = '%s'", table, column, value, primaryKeyColumnName, primaryKeyValue)
_, err = db.Connection.Exec(query)

return err
}

// TODO: Rewrites this logic to use the primary key instead of the id
func (db *Postgres) DeleteRecord(table, id string) (err error) {
query := fmt.Sprintf("DELETE FROM %s WHERE id = '%s'", table, id)
func (db *Postgres) DeleteRecord(table, primaryKeyColumnName, primaryKeyValue string) (err error) {
query := fmt.Sprintf("DELETE FROM %s WHERE '%s' = '%s'", table, primaryKeyColumnName, primaryKeyValue)
_, err = db.Connection.Exec(query)

return err
Expand Down Expand Up @@ -403,7 +400,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts
for _, change := range changes {
if change.Type == "UPDATE" {
tableName := strings.Split(change.Table, ".")[1]
key := fmt.Sprintf("%s|%s", tableName, change.RowId)
key := fmt.Sprintf("%s|%s|%s", tableName, change.PrimaryKeyColumnName, change.PrimaryKeyValue)
groupedUpdated[key] = append(groupedUpdated[key], change)
} else if change.Type == "DELETE" {
groupedDeletes = append(groupedDeletes, change)
Expand All @@ -417,7 +414,8 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts
// Split key into table and rowId
splitted := strings.Split(key, "|")
table := splitted[0]
rowId := splitted[1]
PrimaryKeyColumnName := splitted[1]
primaryKeyValue := splitted[2]

for _, change := range changes {
columns = append(columns, fmt.Sprintf("%s='%s'", change.Column, change.Value))
Expand All @@ -426,7 +424,7 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts
// Merge all column updates
updateClause := strings.Join(columns, ", ")

query := fmt.Sprintf("UPDATE %s SET %s WHERE id = '%s';", table, updateClause, rowId)
query := fmt.Sprintf("UPDATE %s SET %s WHERE '%s' = '%s';", table, updateClause, PrimaryKeyColumnName, primaryKeyValue)

queries = append(queries, query)
}
Expand All @@ -437,7 +435,8 @@ func (db *Postgres) ExecutePendingChanges(changes []models.DbDmlChange, inserts

statementType = "DELETE FROM"
tableName := strings.Split(delete.Table, ".")[1]
query = fmt.Sprintf("%s %s WHERE id = \"%s\"", statementType, tableName, delete.RowId)

query = fmt.Sprintf("%s %s WHERE \"%s\" = '%s'", statementType, tableName, delete.PrimaryKeyColumnName, delete.PrimaryKeyValue)

if query != "" {
queries = append(queries, query)
Expand Down
Loading

0 comments on commit ae0c0dc

Please sign in to comment.