Skip to content

Commit

Permalink
supported multiple result sets in *sql.Rows result from db.QueryConte…
Browse files Browse the repository at this point in the history
…xt (#735)

* supported multiple result sets in *sql.Rows result from db.QueryContext

* Added ResultSetIterator + tests for iterate over sql.Rows with multiple result sets

* switch package sqlutil to sqlutil_test for data/sqlutil/sql_test.go

* fix check limit + add test case for reaching limit
  • Loading branch information
asmyasnikov committed Sep 20, 2023
1 parent abf7655 commit 6485d5a
Show file tree
Hide file tree
Showing 3 changed files with 439 additions and 65 deletions.
113 changes: 64 additions & 49 deletions data/sqlutil/dynamic_frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,63 +25,67 @@ func findDataTypes(rows Rows, rowLimit int64, types []*sql.ColumnType) ([]Field,

var returnData [][]interface{}

for rows.Next() {
if i == rowLimit {
break
}
for {
for rows.Next() {
if i == rowLimit {
break
}
row := make([]interface{}, len(types))
for i := range row {
row[i] = new(interface{})
}
err := rows.Scan(row)
if err != nil {
return nil, nil, err
}

row := make([]interface{}, len(types))
for i := range row {
row[i] = new(interface{})
}
err := rows.Scan(row)
if err != nil {
return nil, nil, err
}
returnData = append(returnData, row)

returnData = append(returnData, row)
if len(fields) == len(types) {
// found all data types. keep looping to load all the return data
continue
}

if len(fields) == len(types) {
// found all data types. keep looping to load all the return data
continue
}
for colIdx, col := range row {
val := *col.(*interface{})
var field Field
colType := types[colIdx]
switch val.(type) {
case time.Time, *time.Time:
field.converter = &TimeToNullableTime
field.kind = "time"
field.name = colType.Name()
case float64, float32, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.converter = &IntOrFloatToNullableFloat64
field.kind = "float64"
field.name = colType.Name()
case string:
field.converter = &converters.AnyToNullableString
field.kind = STRING
field.name = colType.Name()
case []uint8:
field.converter = &converters.Uint8ArrayToNullableString
field.kind = STRING
field.name = colType.Name()
case nil:
continue
default:
field.converter = &converters.AnyToNullableString
field.kind = STRING
field.name = colType.Name()
}

for colIdx, col := range row {
val := *col.(*interface{})
var field Field
colType := types[colIdx]
switch val.(type) {
case time.Time, *time.Time:
field.converter = &TimeToNullableTime
field.kind = "time"
field.name = colType.Name()
case float64, float32, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.converter = &IntOrFloatToNullableFloat64
field.kind = "float64"
field.name = colType.Name()
case string:
field.converter = &converters.AnyToNullableString
field.kind = STRING
field.name = colType.Name()
case []uint8:
field.converter = &converters.Uint8ArrayToNullableString
field.kind = STRING
field.name = colType.Name()
case nil:
continue
default:
field.converter = &converters.AnyToNullableString
field.kind = STRING
field.name = colType.Name()
fields[colIdx] = field
}

fields[colIdx] = field
i++
}
if i == rowLimit || !rows.NextResultSet() {
break
}

i++
}

fieldList := []Field{}
fieldList := make([]Field, len(types))
for colIdx, col := range types {
field, ok := fields[colIdx]
field.name = col.Name()
Expand All @@ -92,7 +96,7 @@ func findDataTypes(rows Rows, rowLimit int64, types []*sql.ColumnType) ([]Field,
name: col.Name(),
}
}
fieldList = append(fieldList, field)
fieldList[colIdx] = field
}

return fieldList, returnData, nil
Expand Down Expand Up @@ -172,6 +176,10 @@ type Field struct {
kind string
}

type ResultSetIterator interface {
NextResultSet() bool
}

type RowIterator interface {
Next() bool
Scan(dest ...interface{}) error
Expand All @@ -181,6 +189,13 @@ type Rows struct {
itr RowIterator
}

func (rs Rows) NextResultSet() bool {
if itr, has := rs.itr.(ResultSetIterator); has {
return itr.NextResultSet()
}
return false
}

func (rs Rows) Next() bool {
return rs.itr.Next()
}
Expand Down
38 changes: 22 additions & 16 deletions data/sqlutil/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,31 @@ func FrameFromRows(rows *sql.Rows, rowLimit int64, converters ...Converter) (*da
frame := NewFrame(names, scanRow.Converters...)

var i int64
for rows.Next() {
if i == rowLimit {
frame.AppendNotices(data.Notice{
Severity: data.NoticeSeverityWarning,
Text: fmt.Sprintf("Results have been limited to %v because the SQL row limit was reached", rowLimit),
})
break
}
for {
// first iterate over rows may be nop if not switched result set to next
for rows.Next() {
if i == rowLimit {
frame.AppendNotices(data.Notice{
Severity: data.NoticeSeverityWarning,
Text: fmt.Sprintf("Results have been limited to %v because the SQL row limit was reached", rowLimit),
})
break
}

r := scanRow.NewScannableRow()
if err := rows.Scan(r...); err != nil {
return nil, err
}
r := scanRow.NewScannableRow()
if err := rows.Scan(r...); err != nil {
return nil, err
}

if err := Append(frame, r, scanRow.Converters...); err != nil {
return nil, err
}
if err := Append(frame, r, scanRow.Converters...); err != nil {
return nil, err
}

i++
i++
}
if i == rowLimit || !rows.NextResultSet() {
break
}
}

if err := rows.Err(); err != nil {
Expand Down
Loading

0 comments on commit 6485d5a

Please sign in to comment.