Skip to content

Commit cd1d224

Browse files
Experimental JSON Array functions
1 parent 7d71771 commit cd1d224

File tree

4 files changed

+260
-0
lines changed

4 files changed

+260
-0
lines changed

expression/builtin.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,10 @@ var funcs = map[string]functionClass{
910910
ast.NextVal: &nextValFunctionClass{baseFunctionClass{ast.NextVal, 1, 1}},
911911
ast.LastVal: &lastValFunctionClass{baseFunctionClass{ast.LastVal, 1, 1}},
912912
ast.SetVal: &setValFunctionClass{baseFunctionClass{ast.SetVal, 2, 2}},
913+
914+
// Custom JSON Array function.
915+
ast.XCosineSim: &xcosinesimFunctionClass{baseFunctionClass{ast.XCosineSim, 2, 2}},
916+
ast.XDotProduct: &xdotproductFunctionClass{baseFunctionClass{ast.XDotProduct, 2, 2}},
913917
}
914918

915919
// IsFunctionSupported check if given function name is a builtin sql function.

expression/builtin_x_json_array.go

Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
package expression
2+
3+
import (
4+
"math"
5+
6+
"github.com/pingcap/errors"
7+
"github.com/pingcap/tidb/sessionctx"
8+
"github.com/pingcap/tidb/types"
9+
"github.com/pingcap/tidb/util/chunk"
10+
)
11+
12+
var (
13+
_ functionClass = &xcosinesimFunctionClass{}
14+
_ functionClass = &xdotproductFunctionClass{}
15+
)
16+
17+
var (
18+
_ builtinFunc = &builtinXcosinesimSig{}
19+
)
20+
21+
type xcosinesimFunctionClass struct {
22+
baseFunctionClass
23+
}
24+
25+
func (c *xcosinesimFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
26+
if err := c.verifyArgs(args); err != nil {
27+
return nil, err
28+
}
29+
30+
argTps := make([]types.EvalType, 0, 2)
31+
argTps = append(argTps, types.ETJson)
32+
argTps = append(argTps, types.ETJson)
33+
34+
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, argTps...)
35+
if err != nil {
36+
return nil, err
37+
}
38+
39+
types.SetBinChsClnFlag(bf.tp)
40+
sig := &builtinXcosinesimSig{bf}
41+
return sig, nil
42+
}
43+
44+
type builtinXcosinesimSig struct {
45+
baseBuiltinFunc
46+
}
47+
48+
func (b *builtinXcosinesimSig) Clone() builtinFunc {
49+
newSig := &builtinXcosinesimSig{}
50+
newSig.cloneFrom(&b.baseBuiltinFunc)
51+
return newSig
52+
}
53+
54+
func (b *builtinXcosinesimSig) evalReal(row chunk.Row) (float64, bool, error) {
55+
const zero float64 = 0.0
56+
57+
arr1, isNull, err := ExtractFloat64Array(b.ctx, b.args[0], row)
58+
if isNull || err != nil {
59+
return zero, isNull, err
60+
}
61+
62+
arr2, isNull, err := ExtractFloat64Array(b.ctx, b.args[1], row)
63+
if isNull || err != nil {
64+
return zero, isNull, err
65+
}
66+
67+
cosineSimilarity, cosErr := Cosine(arr1, arr2)
68+
69+
if cosErr != nil {
70+
cosErr = errors.Wrap(cosErr, "Invalid JSON Array: an array of non-zero numbers was expected")
71+
return zero, false, cosErr
72+
}
73+
return cosineSimilarity, false, cosErr
74+
}
75+
76+
type xdotproductFunctionClass struct {
77+
baseFunctionClass
78+
}
79+
80+
func (c *xdotproductFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
81+
if err := c.verifyArgs(args); err != nil {
82+
return nil, err
83+
}
84+
85+
argTps := make([]types.EvalType, 0, 2)
86+
argTps = append(argTps, types.ETJson)
87+
argTps = append(argTps, types.ETJson)
88+
89+
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETReal, argTps...)
90+
if err != nil {
91+
return nil, err
92+
}
93+
94+
types.SetBinChsClnFlag(bf.tp)
95+
sig := &builtinXdotproductSig{bf}
96+
return sig, nil
97+
}
98+
99+
type builtinXdotproductSig struct {
100+
baseBuiltinFunc
101+
}
102+
103+
func (b *builtinXdotproductSig) Clone() builtinFunc {
104+
newSig := &builtinXdotproductSig{}
105+
newSig.cloneFrom(&b.baseBuiltinFunc)
106+
return newSig
107+
}
108+
109+
func (b *builtinXdotproductSig) evalReal(row chunk.Row) (float64, bool, error) {
110+
const zero float64 = 0.0
111+
112+
arr1, isNull, err := ExtractFloat64Array(b.ctx, b.args[0], row)
113+
if isNull || err != nil {
114+
return zero, isNull, err
115+
}
116+
117+
arr2, isNull, err := ExtractFloat64Array(b.ctx, b.args[1], row)
118+
if isNull || err != nil {
119+
return zero, isNull, err
120+
}
121+
122+
dotProduct, cosErr := DotProduct(arr1, arr2)
123+
124+
if cosErr != nil {
125+
cosErr = errors.Wrap(cosErr, "Invalid JSON Array: an array of non-zero numbers were expected")
126+
return zero, false, cosErr
127+
}
128+
return dotProduct, false, cosErr
129+
}
130+
131+
func ExtractFloat64Array(ctx sessionctx.Context, expr Expression, row chunk.Row) (values []float64, isNull bool, err error) {
132+
json1, isNull, err := expr.EvalJSON(ctx, row)
133+
if isNull || err != nil {
134+
return nil, isNull, err
135+
}
136+
values, err = AsFloat64Array(json1)
137+
if err != nil {
138+
return nil, false, err
139+
}
140+
return values, false, nil
141+
}
142+
143+
func AsFloat64Array(binJson types.BinaryJSON) (values []float64, err error) {
144+
if binJson.TypeCode != types.JSONTypeCodeArray {
145+
err = errors.New("Invalid JSON Array: an array of numbers were expected")
146+
return nil, err
147+
}
148+
149+
var arrCount int = binJson.GetElemCount()
150+
values = make([]float64, arrCount)
151+
for i := 0; i < arrCount && err == nil; i++ {
152+
var elem = binJson.ArrayGetElem(i)
153+
values[i], err = types.ConvertJSONToFloat(fakeSctx, elem)
154+
}
155+
return values, err
156+
}
157+
158+
func DotProduct(a []float64, b []float64) (cosine float64, err error) {
159+
if len(a) != len(b) {
160+
return 0.0, errors.New("Invalid vectors: two arrays of the same length were expected")
161+
}
162+
if len(a) == 0 {
163+
return 0.0, errors.New("Invalid vectors: two non-zero length arrays were expected")
164+
}
165+
166+
sum := 0.0
167+
168+
for i := range a {
169+
sum += a[i] * b[i]
170+
}
171+
return sum, nil
172+
}
173+
174+
func Cosine(a []float64, b []float64) (cosine float64, err error) {
175+
if len(a) != len(b) {
176+
return 0.0, errors.New("Invalid vectors: two arrays of the same length were expected")
177+
}
178+
if len(a) == 0 {
179+
return 0.0, errors.New("Invalid vectors: two non-zero length arrays were expected")
180+
}
181+
182+
sum := 0.0
183+
s1 := 0.0
184+
s2 := 0.0
185+
186+
for i := range a {
187+
sum += a[i] * b[i]
188+
s1 += a[i] * a[i]
189+
s2 += b[i] * b[i]
190+
}
191+
return sum / (math.Sqrt(s1) * math.Sqrt(s2)), nil
192+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package expression
2+
3+
import (
4+
"math"
5+
"testing"
6+
7+
"github.com/pingcap/tidb/parser/ast"
8+
"github.com/pingcap/tidb/util/chunk"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestXCosineSim(t *testing.T) {
13+
const float64EqualityThreshold = 1e-9
14+
ctx := createContext(t)
15+
fc := funcs[ast.XCosineSim]
16+
tbl := []struct {
17+
Input []interface{}
18+
Expected float64
19+
}{
20+
{[]interface{}{`[1.1,1.2,1.3,1.4,1.5]`, `[1.1,1.2,1.3,1.4,1.5]`}, 1.0},
21+
{[]interface{}{`[2.1,2.2,2.3,2.4,2.5]`, `[1.1,1.2,1.3,1.4,1.5]`}, 0.9988980834329954},
22+
{[]interface{}{`[-1.0,-2.0,-3.0,-4.0,-5.0,-6.0]`, `[1.0,1.1,1.2,1.3,1.4,1.5]`}, -0.949807619836754},
23+
}
24+
dtbl := tblToDtbl(tbl)
25+
for _, tt := range dtbl {
26+
f, err := fc.getFunction(ctx, datumsToConstants(tt["Input"]))
27+
require.NoError(t, err)
28+
d, err := evalBuiltinFunc(f, chunk.Row{})
29+
require.NoError(t, err)
30+
var result = d.GetFloat64()
31+
var expected = tt["Expected"][0].GetFloat64()
32+
var approximatelyEqual = math.Abs(result-expected) <= float64EqualityThreshold
33+
require.Equal(t, true, approximatelyEqual)
34+
}
35+
}
36+
37+
func TestDotProduct(t *testing.T) {
38+
const float64EqualityThreshold = 1e-9
39+
ctx := createContext(t)
40+
fc := funcs[ast.XDotProduct]
41+
tbl := []struct {
42+
Input []interface{}
43+
Expected float64
44+
}{
45+
{[]interface{}{`[1.1,1.2,1.3,1.4,1.5]`, `[1.1,1.2,1.3,1.4,1.5]`}, 8.55},
46+
{[]interface{}{`[2.1,2.2,2.3,2.4,2.5]`, `[1.1,1.2,1.3,1.4,1.5]`}, 15.05},
47+
{[]interface{}{`[-1.0,-2.0,-3.0,-4.0,-5.0,-6.0]`, `[1.0,1.1,1.2,1.3,1.4,1.5]`}, -28.0},
48+
}
49+
dtbl := tblToDtbl(tbl)
50+
for _, tt := range dtbl {
51+
f, err := fc.getFunction(ctx, datumsToConstants(tt["Input"]))
52+
require.NoError(t, err)
53+
d, err := evalBuiltinFunc(f, chunk.Row{})
54+
require.NoError(t, err)
55+
var result = d.GetFloat64()
56+
var expected = tt["Expected"][0].GetFloat64()
57+
var approximatelyEqual = math.Abs(result-expected) <= float64EqualityThreshold
58+
require.Equal(t, true, approximatelyEqual)
59+
}
60+
}

parser/ast/functions.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,10 @@ const (
366366
NextVal = "nextval"
367367
LastVal = "lastval"
368368
SetVal = "setval"
369+
370+
// Custom JSON Array Functions
371+
XCosineSim = "x_cosine_sim"
372+
XDotProduct = "x_dot_product"
369373
)
370374

371375
type FuncCallExprType int8

0 commit comments

Comments
 (0)