Skip to content

Commit

Permalink
Merge pull request #117 from mailru/add-ip-types
Browse files Browse the repository at this point in the history
Support of IPv4/IPv6 and additional testing
  • Loading branch information
DoubleDi committed Apr 14, 2021
2 parents c301c6f + 1d6a58f commit 5228ee5
Show file tree
Hide file tree
Showing 10 changed files with 173 additions and 92 deletions.
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -47,6 +47,7 @@ It is recommended use type `UInt64` which is provided by driver for such kind of
type `[]byte` are used as raw string (without quoting)
for passing value of type `[]uint8` to driver as array - please use the wrapper `clickhouse.Array`
for passing decimal value please use the wrappers `clickhouse.Decimal*`
for passing IPv4/IPv6 types use `clickhouse.IP`

## Supported request params

Expand Down
12 changes: 8 additions & 4 deletions clickhouse_test.go
Expand Up @@ -31,12 +31,16 @@ var ddls = []string{
d32 Decimal32(4),
d64 Decimal64(4),
d128 Decimal128(4),
d10 Decimal(10, 4)
d10 Decimal(10, 4),
ipv4 IPv4,
ipv6 IPv6,
fs FixedString(8),
lc LowCardinality(String)
) ENGINE = Memory`,
`INSERT INTO data VALUES
(-1, 1, 1.0, '1', '1', [1], [10], '2011-03-06', '2011-03-06 06:20:00', 'one', '10', '100', '1000', '1'),
(-2, 2, 2.0, '2', '2', [2], [20], '2012-05-31', '2012-05-31 11:20:00', 'two', '30', '300', '2000', '2'),
(-3, 3, 3.0, '3', '2', [3], [30], '2016-04-04', '2016-04-04 11:30:00', 'three', '40', '400', '3000', '3')
(-1, 1, 1.0, '1', '1', [1], [10], '2011-03-06', '2011-03-06 06:20:00', 'one', '10', '100', '1000', '1', '127.0.0.1', '2001:db8:3333:4444:5555:6666:7777:8888', '12345678', 'one'),
(-2, 2, 2.0, '2', '2', [2], [20], '2012-05-31', '2012-05-31 11:20:00', 'two', '30', '300', '2000', '2', '8.8.8.8', '2001:db8:3333:4444:CCCC:DDDD:EEEE:FFFF', '88888888', 'two'),
(-3, 3, 3.0, '3', '2', [3], [30], '2016-04-04', '2016-04-04 11:30:00', 'three', '40', '400', '3000', '3', '255.255.255.255', '::1234:5678', '87654321', 'three')
`,
}

Expand Down
2 changes: 1 addition & 1 deletion conn_go18_test.go
Expand Up @@ -46,7 +46,7 @@ func (s *connSuite) TestColumnTypes() {
expected := []string{
"Int64", "UInt64", "Float64", "String", "String", "Array(Int16)", "Array(UInt8)", "Date", "DateTime",
"Enum8('one' = 1, 'two' = 2, 'three' = 3)",
"Decimal(9, 4)", "Decimal(18, 4)", "Decimal(38, 4)", "Decimal(10, 4)",
"Decimal(9, 4)", "Decimal(18, 4)", "Decimal(38, 4)", "Decimal(10, 4)", "IPv4", "IPv6", "FixedString(8)", "LowCardinality(String)",
}
s.Require().Equal(len(expected), len(types))
for i, e := range expected {
Expand Down
22 changes: 22 additions & 0 deletions conn_go19_test.go
Expand Up @@ -43,3 +43,25 @@ func (s *connSuite) TestExecBuild19() {
s.NoError(rows.Close())
}
}

func (s *connSuite) TestQuotedStrings() {
testCases := []struct {
query, expected1, expected2 string
}{
{
`SELECT '"foo" foo', 'bar'`, `"foo" foo`, "bar",
},
{
`SELECT 'bar', '"foo" foo'`, "bar", `"foo" foo`,
},
}
for _, tc := range testCases {
var actual1, actual2 string
err := s.conn.QueryRow(tc.query).Scan(&actual1, &actual2)
if !s.NoError(err) {
continue
}
s.Equal(tc.expected1, actual1)
s.Equal(tc.expected2, actual2)
}
}
2 changes: 1 addition & 1 deletion conn_test.go
Expand Up @@ -44,7 +44,7 @@ func (s *connSuite) TestQuery() {
[]interface{}{1},
[][]interface{}{{int64(-1), uint64(1), float64(1), "1", "1", []int16{1}, []uint8{10},
parseDate("2011-03-06"), parseDateTime("2011-03-06 06:20:00"), "one",
"10.0000", "100.0000", "1000.0000", "1.0000"}},
"10.0000", "100.0000", "1000.0000", "1.0000", "127.0.0.1", "2001:db8:3333:4444:5555:6666:7777:8888", "12345678", "one"}},
},
{
"SELECT i64, count() FROM data WHERE i64<0 GROUP BY i64 WITH TOTALS ORDER BY i64",
Expand Down
170 changes: 85 additions & 85 deletions dataparser.go
Expand Up @@ -26,21 +26,85 @@ var (
reflectTypeFloat64 = reflect.TypeOf(float64(0))
)

// DataParser implements parsing of a driver value and reporting its type.
type DataParser interface {
Parse(io.RuneScanner) (driver.Value, error)
Type() reflect.Type
func readNumber(s io.RuneScanner) (string, error) {
var builder bytes.Buffer

loop:
for {
r := read(s)

switch r {
case eof:
break loop
case ',', ']', ')':
s.UnreadRune()
break loop
}

builder.WriteRune(r)
}

return builder.String(), nil
}

type stringParser struct {
unquote bool
length int
func readUnquoted(s io.RuneScanner, length int) (string, error) {
var builder bytes.Buffer

runesRead := 0
loop:
for length == 0 || runesRead < length {
r := read(s)

switch r {
case eof:
break loop
case '\\':
escaped, err := readEscaped(s)
if err != nil {
return "", fmt.Errorf("incorrect escaping in string: %v", err)
}
r = escaped
case '\'':
s.UnreadRune()
break loop
}

builder.WriteRune(r)
runesRead++
}

if length != 0 && runesRead != length {
return "", fmt.Errorf("unexpected string length %d, expected %d", runesRead, length)
}

return builder.String(), nil
}

type dateTimeParser struct {
unquote bool
format string
location *time.Location
func readString(s io.RuneScanner, length int, unquote bool) (string, error) {
if unquote {
if r := read(s); r != '\'' {
return "", fmt.Errorf("unexpected character instead of a quote")
}
}

str, err := readUnquoted(s, length)
if err != nil {
return "", fmt.Errorf("failed to read string")
}

if unquote {
if r := read(s); r != '\'' {
return "", fmt.Errorf("unexpected character instead of a quote")
}
}

return str, nil
}

// DataParser implements parsing of a driver value and reporting its type.
type DataParser interface {
Parse(io.RuneScanner) (driver.Value, error)
Type() reflect.Type
}

type nullableParser struct {
Expand Down Expand Up @@ -155,79 +219,9 @@ func (p *nullableParser) Parse(s io.RuneScanner) (driver.Value, error) {
return p.DataParser.Parse(dB)
}

func readNumber(s io.RuneScanner) (string, error) {
var builder bytes.Buffer

loop:
for {
r := read(s)

switch r {
case eof:
break loop
case ',', ']', ')':
s.UnreadRune()
break loop
}

builder.WriteRune(r)
}

return builder.String(), nil
}

func readUnquoted(s io.RuneScanner, length int) (string, error) {
var builder bytes.Buffer

runesRead := 0
loop:
for length == 0 || runesRead < length {
r := read(s)

switch r {
case eof:
break loop
case '\\':
escaped, err := readEscaped(s)
if err != nil {
return "", fmt.Errorf("incorrect escaping in string: %v", err)
}
r = escaped
case '\'':
s.UnreadRune()
break loop
}

builder.WriteRune(r)
runesRead++
}

if length != 0 && runesRead != length {
return "", fmt.Errorf("unexpected string length %d, expected %d", runesRead, length)
}

return builder.String(), nil
}

func readString(s io.RuneScanner, length int, unquote bool) (string, error) {
if unquote {
if r := read(s); r != '\'' {
return "", fmt.Errorf("unexpected character instead of a quote")
}
}

str, err := readUnquoted(s, length)
if err != nil {
return "", fmt.Errorf("failed to read string")
}

if unquote {
if r := read(s); r != '\'' {
return "", fmt.Errorf("unexpected character instead of a quote")
}
}

return str, nil
type stringParser struct {
unquote bool
length int
}

func (p *stringParser) Parse(s io.RuneScanner) (driver.Value, error) {
Expand All @@ -238,6 +232,12 @@ func (p *stringParser) Type() reflect.Type {
return reflectTypeString
}

type dateTimeParser struct {
unquote bool
format string
location *time.Location
}

func (p *dateTimeParser) Parse(s io.RuneScanner) (driver.Value, error) {
str, err := readString(s, len(p.format), p.unquote)
if err != nil {
Expand Down Expand Up @@ -549,7 +549,7 @@ func newDataParser(t *TypeDesc, unquote bool, opt *DataParserOptions) (DataParse
return &floatParser{32}, nil
case "Float64":
return &floatParser{64}, nil
case "Decimal", "String", "Enum8", "Enum16", "UUID":
case "Decimal", "String", "Enum8", "Enum16", "UUID", "IPv4", "IPv6":
return &stringParser{unquote: unquote}, nil
case "FixedString":
if len(t.Args) != 1 {
Expand Down
12 changes: 12 additions & 0 deletions dataparser_test.go
Expand Up @@ -334,6 +334,18 @@ func TestParseData(t *testing.T) {
inputdata: "123",
output: uint64(123),
},
{
name: "ipv4",
inputtype: "IPv4",
inputdata: "127.0.0.1",
output: "127.0.0.1",
},
{
name: "ipv6",
inputtype: "IPv6",
inputdata: "2a02:aa08:e000:3100::2",
output: "2a02:aa08:e000:3100::2",
},
}

for _, tc := range testCases {
Expand Down
15 changes: 15 additions & 0 deletions rows_test.go
Expand Up @@ -131,3 +131,18 @@ func TestTextRowsWithEmptyLine(t *testing.T) {
}
assert.Equal(t, []driver.Value{int32(2), ""}, dest)
}

func TestTextRowsWithEmptyQuotes(t *testing.T) {
buf := bytes.NewReader([]byte("text\nString\n\"\"\n"))
rows, err := newTextRows(&conn{}, &bufReadCloser{buf}, time.Local, false)
if !assert.NoError(t, err) {
return
}
assert.Equal(t, []string{"text"}, rows.Columns())
assert.Equal(t, []string{"String"}, rows.types)
dest := make([]driver.Value, 1)
if !assert.NoError(t, rows.Next(dest)) {
return
}
assert.Equal(t, []driver.Value{`""`}, dest)
}
15 changes: 14 additions & 1 deletion types.go
Expand Up @@ -3,6 +3,7 @@ package clickhouse
import (
"database/sql/driver"
"fmt"
"net"
"strconv"
"time"
)
Expand All @@ -17,7 +18,7 @@ func Date(t time.Time) driver.Valuer {
return date(t)
}

// UInt64 returns date for t
// UInt64 returns uint64
func UInt64(u uint64) driver.Valuer {
return bigUint64(u)
}
Expand All @@ -40,6 +41,11 @@ func Decimal128(v interface{}, s int32) driver.Valuer {
return decimal{128, s, v}
}

// IP returns compatible database format for net.IP
func IP(i net.IP) driver.Valuer {
return ip(i)
}

type array struct {
v interface{}
}
Expand Down Expand Up @@ -73,3 +79,10 @@ type decimal struct {
func (d decimal) Value() (driver.Value, error) {
return []byte(fmt.Sprintf("toDecimal%d(%v, %d)", d.p, d.v, d.s)), nil
}

type ip net.IP

// Value implements driver.Valuer
func (i ip) Value() (driver.Value, error) {
return net.IP(i).String(), nil
}
14 changes: 14 additions & 0 deletions types_test.go
Expand Up @@ -2,6 +2,7 @@ package clickhouse

import (
"database/sql/driver"
"net"
"testing"
"time"

Expand Down Expand Up @@ -61,3 +62,16 @@ func TestDecimal(t *testing.T) {
assert.Equal(t, []byte("toDecimal128(100.01, 1)"), dv)
}
}

func TestIP(t *testing.T) {
ipv4 := net.ParseIP("127.0.0.1")
assert.NotNil(t, ipv4)
ipv6 := net.ParseIP("2001:44c8:129:2632:33:0:252:2")
assert.NotNil(t, ipv6)
dv, err := IP(ipv4).Value()
assert.NoError(t, err)
assert.Equal(t, "127.0.0.1", dv)
dv, err = IP(ipv6).Value()
assert.NoError(t, err)
assert.Equal(t, "2001:44c8:129:2632:33:0:252:2", dv)
}

0 comments on commit 5228ee5

Please sign in to comment.