diff --git a/dataparser.go b/dataparser.go new file mode 100644 index 0000000..505b7c8 --- /dev/null +++ b/dataparser.go @@ -0,0 +1,384 @@ +package clickhouse + +import ( + "database/sql/driver" + "fmt" + "io" + "reflect" + "strconv" + "strings" + "time" +) + +// DataParser implements parsing of a driver value and reporting its type. +type DataParser interface { + Parse(io.RuneScanner) (driver.Value, error) + Type() reflect.Type +} + +type stringParser struct { + unquote bool + length int +} + +type dateTimeParser struct { + unquote bool + format string + location *time.Location +} + +func readNumber(s io.RuneScanner) (string, error) { + var builder strings.Builder + +loop: + for { + r := read(s) + + switch r { + case eof: + break loop + case ',', ']', ')': + s.UnreadRune() + break loop + } + + builder.WriteRune(r) + } + + return builder.String(), nil +} + +func readUnquoted(s io.RuneScanner, length int) (string, error) { + var builder strings.Builder + + runesRead := 0 +loop: + for length == 0 || runesRead < length { + r := read(s) + + switch r { + case eof: + break loop + case '\\': + escaped, err := readEscaped(s) + if err != nil { + return "", fmt.Errorf("incorrect escaping in string: %v", err) + } + r = escaped + case '\'': + s.UnreadRune() + break loop + } + + builder.WriteRune(r) + runesRead++ + } + + if length != 0 && runesRead != length { + return "", fmt.Errorf("unexpected string length %d, expected %d", runesRead, length) + } + + return builder.String(), nil +} + +func readString(s io.RuneScanner, length int, unquote bool) (string, error) { + if unquote { + if r := read(s); r != '\'' { + return "", fmt.Errorf("unexpected character instead of a quote") + } + } + + str, err := readUnquoted(s, length) + if err != nil { + return "", fmt.Errorf("failed to read string") + } + + if unquote { + if r := read(s); r != '\'' { + return "", fmt.Errorf("unexpected character instead of a quote") + } + } + + return str, nil +} + +func (p *stringParser) Parse(s io.RuneScanner) (driver.Value, error) { + return readString(s, p.length, p.unquote) +} + +func (p *stringParser) Type() reflect.Type { + return reflect.ValueOf("").Type() +} + +func (p *dateTimeParser) Parse(s io.RuneScanner) (driver.Value, error) { + str, err := readString(s, len(p.format), p.unquote) + if err != nil { + return nil, fmt.Errorf("failed to read the string representation of date or datetime: %v", err) + } + + if str == "0000-00-00" || str == "0000-00-00 00:00:00" { + return time.Time{}, nil + } + + return time.ParseInLocation(p.format, str, p.location) +} + +func (p *dateTimeParser) Type() reflect.Type { + return reflect.ValueOf(time.Time{}).Type() +} + +type arrayParser struct{ + arg DataParser +} + +func (p *arrayParser) Type() reflect.Type { + return reflect.SliceOf(p.arg.Type()) +} + +type tupleParser struct{ + args []DataParser +} + +func (p *tupleParser) Type() reflect.Type { + fields := make([]reflect.StructField, len(p.args), len(p.args)) + for i, arg := range p.args { + fields[i].Name = "Field" + strconv.Itoa(i) + fields[i].Type = arg.Type() + } + return reflect.StructOf(fields) +} + +func (p *tupleParser) Parse(s io.RuneScanner) (driver.Value, error) { + r := read(s) + if r != '(' { + return nil, fmt.Errorf("unexpected character '%c', expected '(' at the beginning of tuple", r) + } + + struc := reflect.New(p.Type()).Elem() + for i, arg := range p.args { + if i > 0 { + r := read(s) + if r != ',' { + return nil, fmt.Errorf("unexpected character '%c', expected ',' between tuple elements", r) + } + } + + v, err := arg.Parse(s) + if err != nil { + return nil, fmt.Errorf("failed to parse tuple element: %v", err) + } + + struc.Field(i).Set(reflect.ValueOf(v)) + } + + r = read(s) + if r != ')' { + return nil, fmt.Errorf("unexpected character '%c', expected ')' at the end of tuple", r) + } + + return struc.Interface(), nil +} + + +func (p *arrayParser) Parse(s io.RuneScanner) (driver.Value, error) { + r := read(s) + if r != '[' { + return nil, fmt.Errorf("unexpected character '%c', expected '[' at the beginning of array", r) + } + + slice := reflect.MakeSlice(p.Type(), 0, 0) + for i := 0;; i++ { + r := read(s) + s.UnreadRune() + if r == ']' { + break + } + + v, err := p.arg.Parse(s) + if err != nil { + return nil, fmt.Errorf("failed to parse array element: %v", err) + } + + slice = reflect.Append(slice, reflect.ValueOf(v)) + + r = read(s) + if r != ',' { + s.UnreadRune() + } + } + + r = read(s) + if r != ']' { + return nil, fmt.Errorf("unexpected character '%c', expected ']' at the end of array", r) + } + + return slice.Interface(), nil +} + + +func newDateTimeParser(format, locname string, unquote bool) (DataParser, error) { + loc, err := time.LoadLocation(locname) + if err != nil { + return nil, err + } + return &dateTimeParser{ + unquote: unquote, + format: format, + location: loc, + }, nil +} + +type intParser struct { + signed bool + bitSize int +} + +type floatParser struct { + bitSize int +} + +func (p *intParser) Parse(s io.RuneScanner) (driver.Value, error) { + repr, err := readNumber(s) + if err != nil { + return nil, err + } + + if p.signed { + v, err := strconv.ParseInt(repr, 10, p.bitSize) + switch p.bitSize { + case 8: return int8(v), err + case 16: return int16(v), err + case 32: return int32(v), err + case 64: return int64(v), err + default: panic("unsupported bit size") + } + } else { + v, err := strconv.ParseUint(repr, 10, p.bitSize) + switch p.bitSize { + case 8: return uint8(v), err + case 16: return uint16(v), err + case 32: return uint32(v), err + case 64: return uint64(v), err + default: panic("unsupported bit size") + } + } +} + +func (p *intParser) Type() reflect.Type { + if p.signed { + switch p.bitSize { + case 8: return reflect.ValueOf(int8(0)).Type() + case 16: return reflect.ValueOf(int16(0)).Type() + case 32: return reflect.ValueOf(int32(0)).Type() + case 64: return reflect.ValueOf(int64(0)).Type() + default: panic("unsupported bit size") + } + } else { + switch p.bitSize { + case 8: return reflect.ValueOf(uint8(0)).Type() + case 16: return reflect.ValueOf(uint16(0)).Type() + case 32: return reflect.ValueOf(uint32(0)).Type() + case 64: return reflect.ValueOf(uint64(0)).Type() + default: panic("unsupported bit size") + } + } +} + +func (p *floatParser) Parse(s io.RuneScanner) (driver.Value, error) { + repr, err := readNumber(s) + if err != nil { + return nil, err + } + + v, err := strconv.ParseFloat(repr, p.bitSize) + switch p.bitSize { + case 32: return float32(v), err + case 64: return float64(v), err + default: panic("unsupported bit size") + } +} + +func (p *floatParser) Type() reflect.Type { + switch p.bitSize { + case 32: return reflect.ValueOf(float32(0)).Type() + case 64: return reflect.ValueOf(float64(0)).Type() + default: panic("unsupported bit size") + } +} + +type nothingParser struct{} + +func (p *nothingParser) Parse(s io.RuneScanner) (driver.Value, error) { + return nil, nil +} + +func (p *nothingParser) Type() reflect.Type { + return reflect.ValueOf(struct{}{}).Type() +} + +// NewDataParser creates a new DataParser based on the +// given TypeDesc. +func NewDataParser(t *TypeDesc) (DataParser, error) { + return newDataParser(t, false) +} + +func newDataParser(t *TypeDesc, unquote bool) (DataParser, error) { + switch t.Name { + case "Nothing": return ¬hingParser{}, nil + case "Nullable": return nil, fmt.Errorf("Nullable types are not supported") + case "Date": + // FIXME: support custom default/override location + return newDateTimeParser("2006-01-02", "UTC", unquote) + case "DateTime": + // FIXME: support custom default/override location + locname := "UTC" + if len(t.Args) > 0 { + locname = t.Args[0].Name + } + return newDateTimeParser("2006-01-02 15:04:05", locname, unquote) + case "UInt8": return &intParser{false, 8}, nil + case "UInt16": return &intParser{false, 16}, nil + case "UInt32": return &intParser{false, 32}, nil + case "UInt64": return &intParser{false, 64}, nil + case "Int8": return &intParser{true, 8}, nil + case "Int16": return &intParser{true, 16}, nil + case "Int32": return &intParser{true, 32}, nil + case "Int64": return &intParser{true, 64}, nil + case "Float32": return &floatParser{32}, nil + case "Float64": return &floatParser{64}, nil + case "String", "Enum8", "Enum16": return &stringParser{unquote: unquote}, nil + case "FixedString": + if len(t.Args) != 1 { + return nil, fmt.Errorf("length not specified for FixedString") + } + length, err := strconv.Atoi(t.Args[0].Name) + if err != nil{ + return nil, fmt.Errorf("malformed length specified for FixedString: %v", err) + } + return &stringParser{unquote: unquote, length: length}, nil + case "Array": + if len(t.Args) != 1 { + return nil, fmt.Errorf("element type not specified for Array") + } + subParser, err := newDataParser(t.Args[0], true) + if err != nil { + return nil, fmt.Errorf("failed to create parser for array elements: %v", err) + } + return &arrayParser{subParser}, nil + case "Tuple": + if len(t.Args) < 1 { + return nil, fmt.Errorf("element types not specified for Tuple") + } + subParsers := make([]DataParser, len(t.Args), len(t.Args)) + for i, arg := range t.Args { + subParser, err := newDataParser(arg, true) + if err != nil { + return nil, fmt.Errorf("failed to create parser for tuple element: %v", err) + } + subParsers[i] = subParser + } + return &tupleParser{subParsers}, nil + default: + return nil, fmt.Errorf("type %s is not supported", t.Name) + } +} diff --git a/dataparser_test.go b/dataparser_test.go new file mode 100644 index 0000000..6fc2887 --- /dev/null +++ b/dataparser_test.go @@ -0,0 +1,284 @@ +package clickhouse + +import ( + "fmt" + "math" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestParseData(t *testing.T) { + type testCase struct { + name string + inputtype string + inputdata string + output interface{} + failParseDesc bool + failNewParser bool + failParseData bool + } + + losAngeles, err := time.LoadLocation("America/Los_Angeles") + if err != nil { + t.Fatalf("failed to load time zone: %v", err) + } + + testCases := []*testCase{ + { + name: "nullable not supported", + inputtype: "Nullable(String)", + inputdata: "NULL", + failNewParser: true, + }, + { + name: "string", + inputtype: "String", + inputdata: "hello world", + output: "hello world", + }, + { + name: "fixed string", + inputtype: "FixedString(10)", + inputdata: `hello\0\0\0\0\0`, + output: "hello\x00\x00\x00\x00\x00", + }, + { + name: "string with escaping", + inputtype: "String", + inputdata: `hello \'world`, + output: "hello 'world", + }, + { + name: "string with incorrect escaping", + inputtype: "String", + inputdata: `hello world\`, + failParseData: true, + }, + { + name: "int", + inputtype: "UInt64", + inputdata: "123", + output: uint64(123), + }, + { + name: "float", + inputtype: "Float32", + inputdata: "-inf", + output: float32(math.Inf(-1)), + }, + { + name: "date", + inputtype: "Date", + inputdata: "2018-01-02", + output: time.Date(2018, 1, 2, 0, 0, 0, 0, time.UTC), + }, + { + name: "zero date", + inputtype: "Date", + inputdata: "0000-00-00", + output: time.Time{}, + }, + { + name: "enum", + inputtype: "Enum8('hello' = 1, 'world' = 2)", + inputdata: "hello", + output: "hello", + }, + { + name: "datetime", + inputtype: "DateTime", + inputdata: "2018-01-02 12:34:56", + output: time.Date(2018, 1, 2, 12, 34, 56, 0, time.UTC), + }, + { + name: "datetime in Los Angeles", + inputtype: "DateTime('America/Los_Angeles')", + inputdata: "2018-01-02 12:34:56", + output: time.Date(2018, 1, 2, 12, 34, 56, 0, losAngeles), + }, + { + name: "datetime in nowhere", + inputtype: "DateTime('Nowhere')", + inputdata: "2018-01-02 12:34:56", + failNewParser: true, + }, + { + name: "zero datetime", + inputtype: "DateTime", + inputdata: "0000-00-00 00:00:00", + output: time.Time{}, + }, + { + name: "short datetime", + inputtype: "DateTime", + inputdata: "000-00-00 00:00:00", + output: time.Time{}, + failParseData: true, + }, + { + name: "malformed datetime", + inputtype: "DateTime", + inputdata: "a000-00-00 00:00:00", + output: time.Time{}, + failParseData: true, + }, + { + name: "tuple", + inputtype: "Tuple(String, Float64, Int16, UInt16, Int64)", + inputdata: "('hello world',32.1,-1,2,3)", + output: struct{ + Field0 string + Field1 float64 + Field2 int16 + Field3 uint16 + Field4 int64 + }{"hello world", 32.1, -1, 2, 3}, + }, + { + name: "array of strings", + inputtype: "Array(String)", + inputdata: `['hello world\',','goodbye galaxy']`, + output: []string{"hello world',", "goodbye galaxy"}, + }, + { + name: "array of unquoted strings", + inputtype: "Array(String)", + inputdata: "[hello,world]", + failParseData: true, + }, + { + name: "array with unfinished quoted string", + inputtype: "Array(String)", + inputdata: "['hello','world]", + failParseData: true, + }, + { + name: "array of ints", + inputtype: "Array(UInt64)", + inputdata: "[1,2,3]", + output: []uint64{1, 2, 3}, + }, + { + name: "array of dates", + inputtype: "Array(Date)", + inputdata: "['2018-01-02','0000-00-00']", + output: []time.Time{ + time.Date(2018, 1, 2, 0, 0, 0, 0, time.UTC), + time.Time{}, + }, + }, + { + name: "empty array of ints", + inputtype: "Array(Int8)", + inputdata: "[]", + output: []int8{}, + }, + { + name: "empty array of nothing", + inputtype: "Array(Nothing)", + inputdata: "[]", + output: []struct{}{}, + }, + { + name: "array of tuples", + inputtype: "Array(Tuple(String, Float32))", + inputdata: "[('hello world',32.1),('goodbye galaxy',42.0)]", + output: []struct{ + Field0 string + Field1 float32 + }{ + { + "hello world", + float32(32.1), + }, + { + "goodbye galaxy", + float32(42.0), + }, + }, + }, + { + name: "malformed array element", + inputtype: "Array(UInt8)", + inputdata: "[1,2,'3']", + failParseData: true, + }, + { + name: "array without left bracket", + inputtype: "Array(Int8)", + inputdata: "1,2,3]", + failParseData: true, + }, + { + name: "array without right bracket", + inputtype: "Array(UInt64)", + inputdata: "[1,2,3", + failParseData: true, + }, + { + name: "wrong character between tuple elements", + inputtype: "Tuple(String, String)", + inputdata: "('hello'.'world')", + failParseData: true, + }, + { + name: "malformed tuple element", + inputtype: "Tuple(UInt32, Int32)", + inputdata: "(1,'2')", + failParseData: true, + }, + { + name: "tuple without left paren", + inputtype: "Tuple(Int8, Int8)", + inputdata: "1,2)", + failParseData: true, + }, + { + name: "tuple without right paren", + inputtype: "Tuple(UInt8, Int8)", + inputdata: "(1,2", + failParseData: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(tt *testing.T) { + desc, err := ParseTypeDesc(tc.inputtype) + if tc.failParseDesc { + assert.Error(tt, err) + return + } else { + if !assert.NoError(tt, err) { + return + } + } + + parser, err := newDataParser(desc, false) + if tc.failNewParser { + assert.Error(tt, err) + return + } else { + if !assert.NoError(tt, err) { + return + } + } + + output, err := parser.Parse(strings.NewReader(tc.inputdata)) + if tc.failParseData { + assert.Error(tt, err) + return + } else { + if !assert.NoError(tt, err) { + return + } + } + + fmt.Printf("%T: %#v\n", output, output) + + assert.Equal(tt, tc.output, output) + }) + } +} diff --git a/encoder.go b/encoder.go index ff957a8..eab103f 100644 --- a/encoder.go +++ b/encoder.go @@ -5,16 +5,9 @@ import ( "fmt" "reflect" "strconv" - "strings" "time" ) -const ( - dateFormat = "2006-01-02" - timeFormat = "2006-01-02 15:04:05" - timeZoneBorder = "\\'" -) - var ( textEncode encoder = new(textEncoder) ) @@ -23,18 +16,9 @@ type encoder interface { Encode(value driver.Value) ([]byte, error) } -type decoder interface { - Decode(t string, value []byte) (driver.Value, error) -} - type textEncoder struct { } -type textDecoder struct { - location *time.Location - useDBLocation bool -} - // Encode encodes driver value into string // Note: there is 2 convention: // type string will be quoted @@ -123,117 +107,3 @@ func (e *textEncoder) encodeArray(value reflect.Value) ([]byte, error) { } return append(res, ']'), nil } - -func (d *textDecoder) Decode(t string, value []byte) (driver.Value, error) { - v := string(value) - switch t { - case "Date": - uv := unquote(v) - if uv == "0000-00-00" { - return time.Date(0, 0, 0, 0, 0, 0, 0, time.UTC), nil - } - return time.ParseInLocation(dateFormat, uv, d.location) - case "DateTime": - uv := unquote(v) - if uv == "0000-00-00 00:00:00" { - return time.Date(0, 0, 0, 0, 0, 0, 0, time.UTC), nil - } - return time.ParseInLocation(timeFormat, uv, d.location) - case "UInt8": - vv, err := strconv.ParseUint(v, 10, 8) - return uint8(vv), err - case "UInt16": - vv, err := strconv.ParseUint(v, 10, 16) - return uint16(vv), err - case "UInt32": - vv, err := strconv.ParseUint(v, 10, 32) - return uint32(vv), err - case "UInt64": - return strconv.ParseUint(v, 10, 64) - case "Int8": - vv, err := strconv.ParseInt(v, 10, 8) - return int8(vv), err - case "Int16": - vv, err := strconv.ParseInt(v, 10, 16) - return int16(vv), err - case "Int32": - vv, err := strconv.ParseInt(v, 10, 32) - return int32(vv), err - case "Int64": - return strconv.ParseInt(v, 10, 64) - case "Float32": - vv, err := strconv.ParseFloat(v, 64) - return float32(vv), err - case "Float64": - return strconv.ParseFloat(v, 64) - case "String": - return unescape(unquote(v)), nil - } - - // got zoned datetime - if strings.HasPrefix(t, "DateTime") { - var ( - loc *time.Location - err error - ) - - if d.useDBLocation { - left := strings.Index(t, timeZoneBorder) - if left == -1 { - return nil, fmt.Errorf("time zone not found") - } - right := strings.LastIndex(t, timeZoneBorder) - timeZoneName := t[left+len(timeZoneBorder) : right] - - loc, err = time.LoadLocation(timeZoneName) - if err != nil { - return nil, err - } - } else { - loc = d.location - } - - var t time.Time - if t, err = time.ParseInLocation(timeFormat, unquote(v), loc); err != nil { - return t, err - } - return t.In(d.location), nil - } - - if strings.HasPrefix(t, "FixedString") { - return unescape(unquote(v)), nil - } - if strings.HasPrefix(t, "Array") { - if len(v) > 0 && v[0] == '[' && v[len(v)-1] == ']' { - var items []string - subType := t[6 : len(t)-1] - // check that array is not empty ([]) - if len(v) > 2 { - // check if array of strings and not empty (['example']) - if subType == "String" || strings.HasPrefix(subType, "FixedString") { - items = strings.Split(v[2:len(v)-2], "','") - for i, v := range items { - items[i] = unescape(v) - } - } else { - items = strings.Split(v[1:len(v)-1], ",") - } - } - - r := reflect.MakeSlice(reflect.SliceOf(columnType(subType)), len(items), len(items)) - for i, item := range items { - vv, err := d.Decode(subType, []byte(item)) - if err != nil { - return nil, err - } - r.Index(i).Set(reflect.ValueOf(vv)) - } - return r.Interface(), nil - } - return nil, ErrMalformed - } - if strings.HasPrefix(t, "Enum") { - return unquote(v), nil - } - return value, nil -} diff --git a/encoder_test.go b/encoder_test.go index 2c1a03f..6891d27 100644 --- a/encoder_test.go +++ b/encoder_test.go @@ -50,62 +50,3 @@ func TestTextEncoder(t *testing.T) { } } } - -func TestTextDecoder(t *testing.T) { - dt := time.Date(2011, 3, 6, 6, 20, 0, 0, time.UTC) - d := time.Date(2012, 5, 31, 0, 0, 0, 0, time.UTC) - zerodt := time.Date(0, 0, 0, 0, 0, 0, 0, time.UTC) - testCases := []struct { - tt string - value string - expected interface{} - }{ - {"Int8", "1", int8(1)}, - {"Int16", "1", int16(1)}, - {"Int32", "1", int32(1)}, - {"Int64", "1", int64(1)}, - {"UInt8", "1", uint8(1)}, - {"UInt16", "1", uint16(1)}, - {"UInt32", "1", uint32(1)}, - {"UInt64", "1", uint64(1)}, - {"Float32", "1", float32(1)}, - {"Float64", "1", float64(1)}, - {"Date", "'2012-05-31'", d}, - {"Date", "'0000-00-00'", zerodt}, - {"DateTime", "'2011-03-06 06:20:00'", dt}, - {"DateTime", "'0000-00-00 00:00:00'", zerodt}, - {"DateTime(\\'Europe/Moscow\\')", "'2011-03-06 06:20:00'", dt}, - {"String", "'hello'", "hello"}, - {"String", `'\\\\\'hello'`, `\\'hello`}, - {"FixedString(5)", "'hello'", "hello"}, - {"FixedString(7)", `'\\\\\'hello'`, `\\'hello`}, - {"Enum8('one'=1)", "'one'", "one"}, - {"Enum16('one'=1)", "'one'", "one"}, - {"Array(UInt32)", "[1,2]", []uint32{1, 2}}, - {"Array(UInt32)", "[]", []uint32{}}, - {"Array(String)", "['one, two','one\\'']", []string{"one, two", "one'"}}, - {"Array(String)", "['']", []string{""}}, - {"Array(String)", "[]", []string{}}, - {"Array(FixedString(3)", "['1,2','2,3']", []string{"1,2", "2,3"}}, - } - - dec := &textDecoder{location: time.UTC, useDBLocation: false} - for i, tc := range testCases { - v, err := dec.Decode(tc.tt, []byte(tc.value)) - if assert.NoError(t, err, "%d", i) { - assert.Equal(t, tc.expected, v) - } - } -} - -func TestDecodeTimeWithLocation(t *testing.T) { - dt := time.Date(2011, 3, 6, 3, 20, 0, 0, time.UTC) - dataType := "DateTime(\\'Europe/Moscow\\')" - dtStr := "'2011-03-06 06:20:00'" - dec := &textDecoder{location: time.UTC, useDBLocation: true} - - v, err := dec.Decode(dataType, []byte(dtStr)) - if assert.NoError(t, err) { - assert.Equal(t, dt, v) - } -} diff --git a/helpers.go b/helpers.go index d967fcc..96e8572 100644 --- a/helpers.go +++ b/helpers.go @@ -3,35 +3,24 @@ package clickhouse import ( "bytes" "net/http" - "reflect" "strings" "time" ) var ( escaper = strings.NewReplacer(`\`, `\\`, `'`, `\'`) - unescaper = strings.NewReplacer(`\\`, `\`, `\'`, `'`) + dateFormat = "2006-01-02" + timeFormat = "2006-01-02 15:04:05" ) func escape(s string) string { return escaper.Replace(s) } -func unescape(s string) string { - return unescaper.Replace(s) -} - func quote(s string) string { return "'" + s + "'" } -func unquote(s string) string { - if len(s) > 0 && s[0] == '\'' && s[len(s)-1] == '\'' { - return s[1 : len(s)-1] - } - return s -} - func formatTime(value time.Time) string { return quote(value.Format(timeFormat)) } @@ -81,46 +70,3 @@ func splitTSV(data []byte, out []string) int { } return -1 } - -func columnType(name string) reflect.Type { - switch name { - case "Date", "DateTime": - return reflect.ValueOf(time.Time{}).Type() - case "UInt8": - return reflect.ValueOf(uint8(0)).Type() - case "UInt16": - return reflect.ValueOf(uint16(0)).Type() - case "UInt32": - return reflect.ValueOf(uint32(0)).Type() - case "UInt64": - return reflect.ValueOf(uint64(0)).Type() - case "Int8": - return reflect.ValueOf(int8(0)).Type() - case "Int16": - return reflect.ValueOf(int16(0)).Type() - case "Int32": - return reflect.ValueOf(int32(0)).Type() - case "Int64": - return reflect.ValueOf(int64(0)).Type() - case "Float32": - return reflect.ValueOf(float32(0)).Type() - case "Float64": - return reflect.ValueOf(float64(0)).Type() - case "String": - return reflect.ValueOf("").Type() - } - if strings.HasPrefix(name, "FixedString") { - return reflect.ValueOf("").Type() - } - if strings.HasPrefix(name, "Array") { - subType := columnType(name[6 : len(name)-1]) - if subType != nil { - return reflect.SliceOf(subType) - } - return nil - } - if strings.HasPrefix(name, "Enum") { - return reflect.ValueOf("").Type() - } - return nil -} diff --git a/helpers_test.go b/helpers_test.go deleted file mode 100644 index a592ed0..0000000 --- a/helpers_test.go +++ /dev/null @@ -1,38 +0,0 @@ -package clickhouse - -import ( - "reflect" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -func TestColumnType(t *testing.T) { - testCases := []struct { - tt string - expected reflect.Type - }{ - {"Int8", reflect.TypeOf(int8(0))}, - {"Int16", reflect.TypeOf(int16(0))}, - {"Int32", reflect.TypeOf(int32(0))}, - {"Int64", reflect.TypeOf(int64(0))}, - {"UInt8", reflect.TypeOf(uint8(0))}, - {"UInt16", reflect.TypeOf(uint16(0))}, - {"UInt32", reflect.TypeOf(uint32(0))}, - {"UInt64", reflect.TypeOf(uint64(0))}, - {"Float32", reflect.TypeOf(float32(0))}, - {"Float64", reflect.TypeOf(float64(0))}, - {"Date", reflect.TypeOf(time.Time{})}, - {"DateTime", reflect.TypeOf(time.Time{})}, - {"String", reflect.TypeOf("")}, - {"FixedString(5)", reflect.TypeOf("")}, - {"Enum8('one'=1)", reflect.TypeOf("")}, - {"Enum16('one'=1)", reflect.TypeOf("")}, - {"Array(UInt32)", reflect.TypeOf([]uint32{})}, - } - - for _, tc := range testCases { - assert.Equal(t, tc.expected, columnType(tc.tt)) - } -} diff --git a/rows.go b/rows.go index ead340b..b4ec7f5 100644 --- a/rows.go +++ b/rows.go @@ -3,8 +3,10 @@ package clickhouse import ( "database/sql/driver" "encoding/csv" + "fmt" "io" "reflect" + "strings" "time" ) @@ -22,13 +24,26 @@ func newTextRows(c *conn, body io.ReadCloser, location *time.Location, useDBLoca return nil, err } + parsers := make([]DataParser, len(types), len(types)) + for i, typ := range types { + desc, err := ParseTypeDesc(typ) + if err != nil { + return nil, err + } + + parsers[i], err = NewDataParser(desc) + if err != nil { + return nil, err + } + } + return &textRows{ c: c, respBody: body, tsv: tsvReader, columns: columns, types: types, - decode: &textDecoder{location: location, useDBLocation: useDBLocation}, + parsers: parsers, }, nil } @@ -38,7 +53,7 @@ type textRows struct { tsv *csv.Reader columns []string types []string - decode decoder + parsers []DataParser } func (r *textRows) Columns() []string { @@ -57,10 +72,14 @@ func (r *textRows) Next(dest []driver.Value) error { } for i, s := range row { - v, err := r.decode.Decode(r.types[i], []byte(s)) + reader := strings.NewReader(s) + v, err := r.parsers[i].Parse(reader) if err != nil { return err } + if _, _, err := reader.ReadRune(); err != io.EOF { + return fmt.Errorf("trailing data after parsing the value") + } dest[i] = v } @@ -69,7 +88,7 @@ func (r *textRows) Next(dest []driver.Value) error { // ColumnTypeScanType implements the driver.RowsColumnTypeScanType func (r *textRows) ColumnTypeScanType(index int) reflect.Type { - return columnType(r.types[index]) + return r.parsers[index].Type() } // ColumnTypeDatabaseTypeName implements the driver.RowsColumnTypeDatabaseTypeName diff --git a/rows_test.go b/rows_test.go index dccd924..8235d71 100644 --- a/rows_test.go +++ b/rows_test.go @@ -21,7 +21,7 @@ func (r *bufReadCloser) Close() error { } func TestTextRows(t *testing.T) { - buf := bytes.NewReader([]byte("Number\tText\nInt32\tString\n1\t'hello'\n2\t'world'\n")) + buf := bytes.NewReader([]byte("Number\tText\nInt32\tString\n1\thello\n2\tworld\n")) rows, err := newTextRows(&conn{}, &bufReadCloser{buf}, time.Local, false) if !assert.NoError(t, err) { return @@ -33,7 +33,6 @@ func TestTextRows(t *testing.T) { assert.Equal(t, "Int32", rows.ColumnTypeDatabaseTypeName(0)) assert.Equal(t, "String", rows.ColumnTypeDatabaseTypeName(1)) - assert.Equal(t, time.Local, rows.decode.(*textDecoder).location) dest := make([]driver.Value, 2) if !assert.NoError(t, rows.Next(dest)) { return diff --git a/tokenizer.go b/tokenizer.go new file mode 100644 index 0000000..e29218b --- /dev/null +++ b/tokenizer.go @@ -0,0 +1,150 @@ +package clickhouse + +import ( + "fmt" + "io" + "strings" +) + +const ( + eof = rune(0) +) + +type token struct { + kind rune + data string +} + +func skipWhiteSpace(s io.RuneScanner) { + for { + r := read(s) + switch r { + case ' ', '\t', '\n': + continue + case eof: + return + } + s.UnreadRune() + return + } +} + +func read(s io.RuneScanner) rune { + r, _, err := s.ReadRune() + if err != nil { + return eof + } + return r +} + +func readEscaped(s io.RuneScanner) (rune, error) { + r := read(s) + switch r { + case eof: + return 0, fmt.Errorf("unexpected eof in escaped char") + case 'b': + return '\b', nil + case 'f': + return '\f', nil + case 'r': + return '\r', nil + case 'n': + return '\n', nil + case 't': + return '\t', nil + case '0': + return '\x00', nil + default: + return r, nil + } +} + +func readQuoted(s io.RuneScanner) (*token, error) { + var data strings.Builder + +loop: + for { + r := read(s) + + switch r { + case eof: + return nil, fmt.Errorf("unexpected eof inside quoted string") + case '\\': + escaped, err := readEscaped(s) + if err != nil { + return nil, fmt.Errorf("incorrect escaping in quoted string: %v", err) + } + r = escaped + case '\'': + break loop + } + + data.WriteRune(r) + } + + return &token{'q', data.String()}, nil +} + +func readNumberOrID(s io.RuneScanner) *token { + var data strings.Builder + +loop: + for { + r := read(s) + + switch r { + case eof, ' ', '\t', '\n': + break loop + case '(', ')', ',': + s.UnreadRune() + break loop + default: + data.WriteRune(r) + } + } + + return &token{'s', data.String()} +} + +func tokenize(s io.RuneScanner) ([]*token, error) { + var tokens []*token + +loop: + for { + var t *token + var err error + + switch read(s) { + case eof: + break loop + case ' ', '\t', '\n': + skipWhiteSpace(s) + continue + case '(': + t = &token{kind: '('} + case ')': + t = &token{kind: ')'} + case ',': + t = &token{kind: ','} + case '\'': + t, err = readQuoted(s) + if err != nil { + return nil, err + } + default: + s.UnreadRune() + t = readNumberOrID(s) + } + + tokens = append(tokens, t) + } + + tokens = append(tokens, &token{kind: eof}) + return tokens, nil +} + +// tokenizeString splits a string into tokens according to ClickHouse +// formatting rules as per https://clickhouse.yandex/docs/en/interfaces/formats/#data-formatting +func tokenizeString(s string) ([]*token, error) { + return tokenize(strings.NewReader(s)) +} diff --git a/tokenizer_test.go b/tokenizer_test.go new file mode 100644 index 0000000..b7c67c5 --- /dev/null +++ b/tokenizer_test.go @@ -0,0 +1,80 @@ +package clickhouse + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTokenize(t *testing.T) { + type testCase struct { + name string + input string + output []*token + fail bool + } + testCases := []*testCase{ + { + name: "empty", + input: "", + output: []*token{{eof, ""}}, + }, + { + name: "only whitespace", + input: "", + output: []*token{{eof, ""}}, + }, + { + name: "whitespace all over the place", + input: " \t\nhello \t \n world \n", + output: []*token{ + {'s', "hello"}, + {'s', "world"}, + {eof, ""}, + }, + }, + { + name: "complex with quotes and escaping", + input: `Array(Tuple(FixedString(5), Float32, 'hello, \') world'))`, + output: []*token{ + {'s', "Array"}, + {'(', ""}, + {'s', "Tuple"}, + {'(', ""}, + {'s', "FixedString"}, + {'(', ""}, + {'s', "5"}, + {')', ""}, + {',', ""}, + {'s', "Float32"}, + {',', ""}, + {'q', `hello, ') world`}, + {')', ""}, + {')', ""}, + {eof, ""}, + }, + }, + { + name: "unclosed quote", + input: "Array(')", + fail: true, + }, + { + name: "unfinished escape", + input: `Array('\`, + fail: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(tt *testing.T) { + output, err := tokenizeString(tc.input) + if tc.fail { + assert.Error(tt, err) + } else { + assert.NoError(tt, err) + assert.Equal(tt, tc.output, output) + } + }) + } +} diff --git a/typeparser.go b/typeparser.go new file mode 100644 index 0000000..4be264c --- /dev/null +++ b/typeparser.go @@ -0,0 +1,93 @@ +package clickhouse + +import ( + "fmt" +) + +type TypeDesc struct { + Name string + Args []*TypeDesc +} + +func parseTypeDesc(tokens []*token) (*TypeDesc, []*token, error) { + var name string + if tokens[0].kind == 's' || tokens[0].kind == 'q' { + name = tokens[0].data + tokens = tokens[1:] + } else { + return nil, nil, fmt.Errorf("failed to parse type name: wrong token type '%c'", tokens[0].kind) + } + + desc := TypeDesc{Name: name} + if tokens[0].kind != '(' { + return &desc, tokens, nil + } + + tokens = tokens[1:] + + if tokens[0].kind == ')' { + return &desc, tokens[1:], nil + } + + if name == "Enum8" || name == "Enum16" { + // TODO: an Enum's arguments get completely ignored + for i := range tokens { + if tokens[i].kind == ')' { + return &desc, tokens[i+1:], nil + } + } + return nil, nil, fmt.Errorf("unfinished enum type description") + } + + for { + var arg *TypeDesc + var err error + + arg, tokens, err = parseTypeDesc(tokens) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse subtype: %v", err) + } + desc.Args = append(desc.Args, arg) + + switch tokens[0].kind { + case ',': + tokens = tokens[1:] + continue + case ')': + return &desc, tokens[1:], nil + } + } +} + +// ParseTypeDesc parses the type description that ClickHouse provides. +// +// The grammar is quite simple: +// desc +// name +// name() +// name(args) +// args +// desc +// desc, args +// +// Examples: +// String +// Nullable(Nothing) +// Array(Tuple(Tuple(String, String), Tuple(String, UInt64))) +func ParseTypeDesc(s string) (*TypeDesc, error) { + tokens, err := tokenizeString(s) + if err != nil { + return nil, fmt.Errorf("failed to tokenize type description: %v", err) + } + + desc, tail, err := parseTypeDesc(tokens) + if err != nil { + return nil, fmt.Errorf("failed to parse type description: %v", err) + } + + if len(tail) != 1 || tail[0].kind != eof { + return nil, fmt.Errorf("unexpected tail after type description") + } + + return desc, nil +} diff --git a/typeparser_test.go b/typeparser_test.go new file mode 100644 index 0000000..209e0de --- /dev/null +++ b/typeparser_test.go @@ -0,0 +1,124 @@ +package clickhouse + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseTypeDesc(t *testing.T) { + type testCase struct { + name string + input string + output *TypeDesc + fail bool + } + testCases := []*testCase{ + { + name: "plain type", + input: "String", + output: &TypeDesc{Name: "String"}, + }, + { + name: "nullable type", + input: "Nullable(Nothing)", + output: &TypeDesc{ + Name: "Nullable", + Args: []*TypeDesc{{Name: "Nothing"}}, + }, + }, + { + name: "empty arg", + input: "DateTime()", + output: &TypeDesc{Name: "DateTime"}, + }, + { + name: "numeric arg", + input: "FixedString(42)", + output: &TypeDesc{ + Name: "FixedString", + Args: []*TypeDesc{{Name: "42"}}, + }, + }, + { + name: "args are ignored for Enum", + input: "Enum8(you can = put, 'whatever' here)", + output: &TypeDesc{Name: "Enum8"}, + }, + { + name: "quoted arg", + input: "DateTime('UTC')", + output: &TypeDesc{ + Name: "DateTime", + Args: []*TypeDesc{{Name: "UTC"}}, + }, + }, + { + name: "quoted escaped arg", + input: `DateTime('UTC\b\r\n\'\f\t\0')`, + output: &TypeDesc{ + Name: "DateTime", + Args: []*TypeDesc{{Name: "UTC\b\r\n'\f\t\x00"}}, + }, + }, + { + name: "nested args", + input: "Array(Tuple(Tuple(String, String), Tuple(String, UInt64)))", + output: &TypeDesc{ + Name: "Array", + Args: []*TypeDesc{ + { + Name: "Tuple", + Args: []*TypeDesc{ + { + Name: "Tuple", + Args: []*TypeDesc{{Name: "String"}, {Name: "String"}}, + }, + { + Name: "Tuple", + Args: []*TypeDesc{{Name: "String"}, {Name: "UInt64"}}, + }, + }, + }, + }, + }, + }, + { + name: "unfinished arg list", + input: "Array(Tuple(Tuple(String, String), Tuple(String, UInt64))", + fail: true, + }, + { + name: "left paren without name", + input: "(", + fail: true, + }, + { + name: "unfinished quote", + input: "Array(')", + fail: true, + }, + { + name: "unfinished escape", + input: `Array(\`, + fail: true, + }, + { + name: "stuff after end", + input: `Array() String`, + fail: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(tt *testing.T) { + output, err := ParseTypeDesc(tc.input) + if tc.fail { + assert.Error(tt, err) + } else { + assert.NoError(tt, err) + } + assert.Equal(tt, tc.output, output) + }) + } +}