Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

supported multiple result sets in *sql.Rows result from db.QueryContext #735

Merged
merged 4 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 64 additions & 49 deletions data/sqlutil/dynamic_frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,63 +25,67 @@ func findDataTypes(rows Rows, rowLimit int64, types []*sql.ColumnType) ([]Field,

var returnData [][]interface{}

for rows.Next() {
if i == rowLimit {
break
}
for {
for rows.Next() {
if i == rowLimit {
break
}
row := make([]interface{}, len(types))
for i := range row {
row[i] = new(interface{})
}
err := rows.Scan(row)
if err != nil {
return nil, nil, err
}

row := make([]interface{}, len(types))
for i := range row {
row[i] = new(interface{})
}
err := rows.Scan(row)
if err != nil {
return nil, nil, err
}
returnData = append(returnData, row)

returnData = append(returnData, row)
if len(fields) == len(types) {
// found all data types. keep looping to load all the return data
continue
}

if len(fields) == len(types) {
// found all data types. keep looping to load all the return data
continue
}
for colIdx, col := range row {
val := *col.(*interface{})
var field Field
colType := types[colIdx]
switch val.(type) {
case time.Time, *time.Time:
field.converter = &TimeToNullableTime
field.kind = "time"
field.name = colType.Name()
case float64, float32, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.converter = &IntOrFloatToNullableFloat64
field.kind = "float64"
field.name = colType.Name()
case string:
field.converter = &converters.AnyToNullableString
field.kind = STRING
field.name = colType.Name()
case []uint8:
field.converter = &converters.Uint8ArrayToNullableString
field.kind = STRING
field.name = colType.Name()
case nil:
continue
default:
field.converter = &converters.AnyToNullableString
field.kind = STRING
field.name = colType.Name()
}

for colIdx, col := range row {
val := *col.(*interface{})
var field Field
colType := types[colIdx]
switch val.(type) {
case time.Time, *time.Time:
field.converter = &TimeToNullableTime
field.kind = "time"
field.name = colType.Name()
case float64, float32, int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
field.converter = &IntOrFloatToNullableFloat64
field.kind = "float64"
field.name = colType.Name()
case string:
field.converter = &converters.AnyToNullableString
field.kind = STRING
field.name = colType.Name()
case []uint8:
field.converter = &converters.Uint8ArrayToNullableString
field.kind = STRING
field.name = colType.Name()
case nil:
continue
default:
field.converter = &converters.AnyToNullableString
field.kind = STRING
field.name = colType.Name()
fields[colIdx] = field
}

fields[colIdx] = field
i++
}
if i == rowLimit || !rows.NextResultSet() {
break
}

i++
}

fieldList := []Field{}
fieldList := make([]Field, len(types))
for colIdx, col := range types {
field, ok := fields[colIdx]
field.name = col.Name()
Expand All @@ -92,7 +96,7 @@ func findDataTypes(rows Rows, rowLimit int64, types []*sql.ColumnType) ([]Field,
name: col.Name(),
}
}
fieldList = append(fieldList, field)
fieldList[colIdx] = field
}

return fieldList, returnData, nil
Expand Down Expand Up @@ -172,6 +176,10 @@ type Field struct {
kind string
}

type ResultSetIterator interface {
NextResultSet() bool
}

type RowIterator interface {
Next() bool
Scan(dest ...interface{}) error
Expand All @@ -181,6 +189,13 @@ type Rows struct {
itr RowIterator
}

func (rs Rows) NextResultSet() bool {
if itr, has := rs.itr.(ResultSetIterator); has {
return itr.NextResultSet()
}
return false
}

func (rs Rows) Next() bool {
return rs.itr.Next()
}
Expand Down
38 changes: 22 additions & 16 deletions data/sqlutil/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,31 @@ func FrameFromRows(rows *sql.Rows, rowLimit int64, converters ...Converter) (*da
frame := NewFrame(names, scanRow.Converters...)

var i int64
for rows.Next() {
if i == rowLimit {
frame.AppendNotices(data.Notice{
Severity: data.NoticeSeverityWarning,
Text: fmt.Sprintf("Results have been limited to %v because the SQL row limit was reached", rowLimit),
})
break
}
for {
// first iterate over rows may be nop if not switched result set to next
for rows.Next() {
if i == rowLimit {
frame.AppendNotices(data.Notice{
Severity: data.NoticeSeverityWarning,
Text: fmt.Sprintf("Results have been limited to %v because the SQL row limit was reached", rowLimit),
})
break
}

r := scanRow.NewScannableRow()
if err := rows.Scan(r...); err != nil {
return nil, err
}
r := scanRow.NewScannableRow()
if err := rows.Scan(r...); err != nil {
return nil, err
}

if err := Append(frame, r, scanRow.Converters...); err != nil {
return nil, err
}
if err := Append(frame, r, scanRow.Converters...); err != nil {
return nil, err
}

i++
i++
}
if i == rowLimit || !rows.NextResultSet() {
break
}
}

if err := rows.Err(); err != nil {
Expand Down
Loading