Skip to content

Commit

Permalink
fix bug: get incorrect result for "in" expr in some case (#14602)
Browse files Browse the repository at this point in the history
fix bug: get incorrect result for "in" expr  in some case.

bug case:
```sql
drop table if exists t1;
create table t1(a int primary key, b int);
insert into t1 values (1,1),(2,2),(3,3);
select mo_ctl('dn', 'flush', 'select.t1');
select * from t1 where a in (3,3,3,2,1);
-- here, we should get 3 rows. but 5 rows now
```

Approved by: @badboynt1, @XuPeng-SH, @nnsgmsone, @heni02, @aunjgr
  • Loading branch information
ouyuanning committed Feb 22, 2024
1 parent d18e348 commit 862a18d
Show file tree
Hide file tree
Showing 6 changed files with 355 additions and 3 deletions.
333 changes: 333 additions & 0 deletions pkg/container/vector/vector.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package vector
import (
"bytes"
"fmt"
"slices"
"sort"
"unsafe"

Expand Down Expand Up @@ -3590,6 +3591,338 @@ func (v *Vector) GetMinMaxValue() (ok bool, minv, maxv []byte) {
return
}

// InplaceSortAndCompact @todo optimization in the future
func (v *Vector) InplaceSortAndCompact() {
switch v.GetType().Oid {
case types.T_bool:
col := MustFixedCol[bool](v)
sort.Slice(col, func(i, j int) bool {
return !col[i] && col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_bit:
col := MustFixedCol[uint64](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_int8:
col := MustFixedCol[int8](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_int16:
col := MustFixedCol[int16](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_int32:
col := MustFixedCol[int32](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_int64:
col := MustFixedCol[int64](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_uint8:
col := MustFixedCol[uint8](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_uint16:
col := MustFixedCol[uint16](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_uint32:
col := MustFixedCol[uint32](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_uint64:
col := MustFixedCol[uint64](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_float32:
col := MustFixedCol[float32](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_float64:
col := MustFixedCol[float64](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_date:
col := MustFixedCol[types.Date](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_datetime:
col := MustFixedCol[types.Datetime](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_time:
col := MustFixedCol[types.Time](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_timestamp:
col := MustFixedCol[types.Timestamp](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_enum:
col := MustFixedCol[types.Enum](v)
sort.Slice(col, func(i, j int) bool {
return col[i] < col[j]
})
newCol := slices.Compact(col)
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_decimal64:
col := MustFixedCol[types.Decimal64](v)
sort.Slice(col, func(i, j int) bool {
return col[i].Less(col[j])
})
newCol := slices.CompactFunc(col, func(a, b types.Decimal64) bool {
return a.Compare(b) == 0
})
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_decimal128:
col := MustFixedCol[types.Decimal128](v)
sort.Slice(col, func(i, j int) bool {
return col[i].Less(col[j])
})
newCol := slices.CompactFunc(col, func(a, b types.Decimal128) bool {
return a.Compare(b) == 0
})
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_TS:
col := MustFixedCol[types.TS](v)
sort.Slice(col, func(i, j int) bool {
return col[i].Less(col[j])
})
newCol := slices.CompactFunc(col, func(a, b types.TS) bool {
return a.Equal(b)
})
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_uuid:
col := MustFixedCol[types.Uuid](v)
sort.Slice(col, func(i, j int) bool {
return col[i].Lt(col[j])
})
newCol := slices.CompactFunc(col, func(a, b types.Uuid) bool {
return a.Compare(b) == 0
})
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}
case types.T_Rowid:
col := MustFixedCol[types.Rowid](v)
sort.Slice(col, func(i, j int) bool {
return col[i].Less(col[j])
})
newCol := slices.CompactFunc(col, func(a, b types.Rowid) bool {
return a.Equal(b)
})
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_char, types.T_varchar, types.T_json, types.T_binary, types.T_varbinary, types.T_blob, types.T_text:
col, area := MustVarlenaRawData(v)
sort.Slice(col, func(i, j int) bool {
return bytes.Compare(col[i].GetByteSlice(area), col[j].GetByteSlice(area)) < 0
})
newCol := slices.CompactFunc(col, func(a, b types.Varlena) bool {
return bytes.Equal(a.GetByteSlice(area), b.GetByteSlice(area))
})
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_array_float32:
col, area := MustVarlenaRawData(v)
sort.Slice(col, func(i, j int) bool {
return moarray.Compare[float32](
types.GetArray[float32](&col[i], area),
types.GetArray[float32](&col[j], area),
) < 0
})
newCol := slices.CompactFunc(col, func(a, b types.Varlena) bool {
return moarray.Compare[float32](
types.GetArray[float32](&a, area),
types.GetArray[float32](&b, area),
) == 0
})
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}

case types.T_array_float64:
col, area := MustVarlenaRawData(v)
sort.Slice(col, func(i, j int) bool {
return moarray.Compare[float64](
types.GetArray[float64](&col[i], area),
types.GetArray[float64](&col[j], area),
) < 0
})
newCol := slices.CompactFunc(col, func(a, b types.Varlena) bool {
return moarray.Compare[float64](
types.GetArray[float64](&a, area),
types.GetArray[float64](&b, area),
) == 0
})
if len(newCol) != len(col) {
v.CleanOnlyData()
v.SetSorted(true)
appendList(v, newCol, nil, nil)
}
}
}

func (v *Vector) InplaceSort() {
switch v.GetType().Oid {
case types.T_bool:
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/plan/partition_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -675,7 +675,7 @@ func PartitionFuncConstantFold(bat *batch.Batch, e *plan.Expr, proc *process.Pro
}
defer vec.Free(proc.Mp())

vec.InplaceSort()
vec.InplaceSortAndCompact()
data, err := vec.MarshalBinary()
if err != nil {
return nil, err
Expand Down
3 changes: 2 additions & 1 deletion pkg/sql/plan/rule/constant_fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ func (r *ConstantFold) constantFold(expr *plan.Expr, proc *process.Process) *pla
}
defer vec.Free(proc.Mp())

vec.InplaceSort()
vec.InplaceSortAndCompact()

data, err := vec.MarshalBinary()
if err != nil {
return expr
Expand Down
2 changes: 1 addition & 1 deletion pkg/sql/plan/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -1074,7 +1074,7 @@ func ConstantFold(bat *batch.Batch, expr *plan.Expr, proc *process.Process, varA
}
defer vec.Free(proc.Mp())

vec.InplaceSort()
vec.InplaceSortAndCompact()
data, err := vec.MarshalBinary()
if err != nil {
return nil, err
Expand Down
11 changes: 11 additions & 0 deletions test/distributed/cases/dml/select/select.result
Original file line number Diff line number Diff line change
Expand Up @@ -477,3 +477,14 @@ count(*)
select count(*) from t11 where a in (1,20000);
count(*)
2
drop table if exists t1;
create table t1(a int primary key, b int);
insert into t1 values (1,1),(2,2),(3,3);
select mo_ctl('dn', 'flush', 'select.t1');
mo_ctl(dn, flush, select.t1)
{\n "method": "Flush",\n "result": [\n {\n "returnStr": "OK"\n }\n ]\n}\n
select * from t1 where a in (3,3,3,2,1);
a b
1 1
2 2
3 3
Loading

0 comments on commit 862a18d

Please sign in to comment.