Skip to content

Commit

Permalink
Merge be2db39 into a1d5e64
Browse files Browse the repository at this point in the history
  • Loading branch information
abraithwaite committed Mar 31, 2021
2 parents a1d5e64 + be2db39 commit a15c058
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 9 deletions.
19 changes: 13 additions & 6 deletions named.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,21 +224,28 @@ func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper)
return bound, arglist, nil
}

var valueBracketReg = regexp.MustCompile(`\([^(]*.[^(]\)\s*$`)
var valueBracketReg = regexp.MustCompile(`VALUES\s+(\([^(]*.[^(]\))`)

func fixBound(bound string, loop int) string {
loc := valueBracketReg.FindStringIndex(bound)
if len(loc) != 2 {
loc := valueBracketReg.FindAllStringSubmatchIndex(bound, -1)
// Either no VALUES () found or more than one found??
if len(loc) != 1 {
return bound
}
// defensive guard. loc should be len 4 representing the starting and
// ending index for the whole regex match and the starting + ending
// index for the single inside group
if len(loc[0]) != 4 {
return bound
}
var buffer bytes.Buffer

buffer.WriteString(bound[0:loc[1]])
buffer.WriteString(bound[0:loc[0][1]])
for i := 0; i < loop-1; i++ {
buffer.WriteString(",")
buffer.WriteString(bound[loc[0]:loc[1]])
buffer.WriteString(bound[loc[0][2]:loc[0][3]])
}
buffer.WriteString(bound[loc[1]:])
buffer.WriteString(bound[loc[0][1]:])
return buffer.String()
}

Expand Down
63 changes: 60 additions & 3 deletions named_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,10 @@ func TestNamedQueries(t *testing.T) {
{FirstName: "Ngani", LastName: "Laumape", Email: "nlaumape@ab.co.nz"},
}

insert := fmt.Sprintf("INSERT INTO person (first_name, last_name, email, added_at) VALUES (:first_name, :last_name, :email, %v)\n", now)
insert := fmt.Sprintf(
"INSERT INTO person (first_name, last_name, email, added_at) VALUES (:first_name, :last_name, :email, %v)\n",
now,
)
_, err = db.NamedExec(insert, sls)
test.Error(err)

Expand All @@ -214,7 +217,7 @@ func TestNamedQueries(t *testing.T) {
}

_, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email)
VALUES (:first_name, :last_name, :email) `, slsMap)
VALUES (:first_name, :last_name, :email) ;--`, slsMap)
test.Error(err)

type A map[string]interface{}
Expand All @@ -226,7 +229,7 @@ func TestNamedQueries(t *testing.T) {
}

_, err = db.NamedExec(`INSERT INTO person (first_name, last_name, email)
VALUES (:first_name, :last_name, :email) `, typedMap)
VALUES (:first_name, :last_name, :email) ;--`, typedMap)
test.Error(err)

for _, p := range sls {
Expand Down Expand Up @@ -296,3 +299,57 @@ func TestNamedQueries(t *testing.T) {

})
}

func TestFixBounds(t *testing.T) {
table := []struct {
name, query, expect string
loop int
}{
{
name: `named syntax`,
query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`,
expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last)`,
loop: 2,
},
{
name: `mysql syntax`,
query: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
expect: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?),(?, ?, ?, ?)`,
loop: 2,
},
{
name: `named syntax w/ trailer`,
query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) ;--`,
expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last),(:name, :age, :first, :last) ;--`,
loop: 2,
},
{
name: `mysql syntax w/ trailer`,
query: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?) ;--`,
expect: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?),(?, ?, ?, ?) ;--`,
loop: 2,
},
{
name: `not found test`,
query: `INSERT INTO foo (a,b,c,d) (:name, :age, :first, :last)`,
expect: `INSERT INTO foo (a,b,c,d) (:name, :age, :first, :last)`,
loop: 2,
},
{
name: `found twice test`,
query: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`,
expect: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last) VALUES (:name, :age, :first, :last)`,
loop: 2,
},
}

for _, tc := range table {
t.Run(tc.name, func(t *testing.T) {
res := fixBound(tc.query, tc.loop)
if res != tc.expect {
t.Errorf("mismatched results")
}
})
}

}

0 comments on commit a15c058

Please sign in to comment.