diff --git a/bulkcopy.go b/bulkcopy.go index 3b319af8..4e01ed9b 100644 --- a/bulkcopy.go +++ b/bulkcopy.go @@ -323,7 +323,15 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) switch col.ti.TypeId { - case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN: + case typeInt1, typeInt2, typeInt4, typeInt8, typeIntN, typeMoney, typeMoneyN: + // Note: typeMoney is really int64 with a hard-coded fixed + // point convention (123456 is treated as 12.3456). In bulk + // insert it is treated as int64, and here we expect the + // caller to pass it as the underlying int64. This may be a + // bit inconsistent vs. the []byte that comes back from the + // driver on SELECT for money, but at least this solution + // allows for the possibility of doing a bulk insert. + var intvalue int64 switch val := val.(type) { @@ -334,12 +342,18 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) case int64: intvalue = val default: - err = fmt.Errorf("mssql: invalid type for int column") + if col.ti.TypeId == typeMoney || col.ti.TypeId == typeMoneyN { + err = fmt.Errorf("mssql: please pass money values as int64 for bulk copy (int64 of 12345 turns into money '1.2345')") + } else { + err = fmt.Errorf("mssql: invalid type for int column") + } return } res.buffer = make([]byte, res.ti.Size) - if col.ti.Size == 1 { + if col.ti.TypeId == typeMoney || col.ti.TypeId == typeMoneyN { + encodeMoney(res.buffer, intvalue) + } else if col.ti.Size == 1 { res.buffer[0] = byte(intvalue) } else if col.ti.Size == 2 { binary.LittleEndian.PutUint16(res.buffer, uint16(intvalue)) @@ -453,7 +467,6 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) err = fmt.Errorf("mssql: invalid type for datetime column: %s", val) } - // case typeMoney, typeMoney4, typeMoneyN: case typeDecimal, typeDecimalN, typeNumeric, typeNumericN: var value float64 switch v := val.(type) { @@ -547,6 +560,24 @@ func (b *Bulk) makeParam(val DataValue, col columnStruct) (res param, err error) } +// encodeMoney turns a 64-bit integer into the TDS wire format for the +// 'money' type in mssql. The byte ordering was deduced from +// decodeMoney in types.go; could not find it explicitly in the TDS +// documentation. The format has been tested on the wire against real +// SQL Server. +func encodeMoney(out []byte, value int64) { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(value)) + out[4] = buf[0] + out[5] = buf[1] + out[6] = buf[2] + out[7] = buf[3] + out[0] = buf[4] + out[1] = buf[5] + out[2] = buf[6] + out[3] = buf[7] +} + func (b *Bulk) dlogf(format string, v ...interface{}) { if b.Debug { b.cn.sess.log.Printf(format, v...) diff --git a/bulkcopy_test.go b/bulkcopy_test.go index c0c89885..b6852593 100644 --- a/bulkcopy_test.go +++ b/bulkcopy_test.go @@ -23,6 +23,11 @@ func TestBulkcopy(t *testing.T) { val interface{} } + type differentExpected struct { + input interface{} + expected interface{} + } + tableName := "#table_test" geom, _ := hex.DecodeString("E6100000010C00000000000034400000000000004440") bin, _ := hex.DecodeString("ba8b7782168d4033a299333aec17bd33") @@ -58,7 +63,6 @@ func TestBulkcopy(t *testing.T) { {"test_geom", geom}, {"test_uniqueidentifier", []byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}}, // {"test_smallmoney", 1234.56}, - // {"test_money", 1234.56}, {"test_decimal_18_0", 1234.0001}, {"test_decimal_9_2", 1234.560001}, {"test_decimal_20_0", 1234.0001}, @@ -68,6 +72,27 @@ func TestBulkcopy(t *testing.T) { {"test_varbinary_max", bin}, {"test_binary", []byte("1")}, {"test_binary_16", bin}, + + // money must be input as int64 to bulk insert, but scans back as a string on SELECT, so use `differentExpected` to provide + // different input and expected output + + // First test: We do some byte shuffling for the money type, so make sure every byte is unique in the test. + {"test_money_1", differentExpected{ + int64(-(0x01<<56 | 0x02<<48 | 0x03<<40 | 0x04<<32 | 0x05<<24 | 0x06<<16 | 0x07<<8 | 0x08)), // evaluates to 72623859790382856 + []byte("-7262385979038.2856")}}, + // maximum positive, minimum negative, and zero values + {"test_money_2", differentExpected{math.MaxInt64, []byte("922337203685477.5807")}}, + {"test_money_3", differentExpected{math.MinInt64, []byte("-922337203685477.5808")}}, + {"test_money_4", differentExpected{0, []byte("0.0000")}}, + + {"test_money_n_1", differentExpected{ + int64(-(0x01<<56 | 0x02<<48 | 0x03<<40 | 0x04<<32 | 0x05<<24 | 0x06<<16 | 0x07<<8 | 0x08)), // evaluates to 72623859790382856 + []byte("-7262385979038.2856")}}, + // maximum positive, minimum negative, and zero values + {"test_money_n_2", differentExpected{math.MaxInt64, []byte("922337203685477.5807")}}, + {"test_money_n_3", differentExpected{math.MinInt64, []byte("-922337203685477.5808")}}, + {"test_money_n_4", differentExpected{0, []byte("0.0000")}}, + {"test_money_n_5", nil}, } columns := make([]string, len(testValues)) @@ -77,7 +102,12 @@ func TestBulkcopy(t *testing.T) { values := make([]interface{}, len(testValues)) for i, val := range testValues { - values[i] = val.val + switch t := val.val.(type) { + case differentExpected: + values[i] = t.input + default: + values[i] = val.val + } } pool := open(t) @@ -149,8 +179,15 @@ func TestBulkcopy(t *testing.T) { t.Fatal(err) } for i, c := range testValues { - if !compareValue(container[i], c.val) { - t.Errorf("columns %s : expected: %v, got: %v\n", c.colname, c.val, container[i]) + var expected interface{} + switch t := c.val.(type) { + case differentExpected: + expected = t.expected + default: + expected = c.val + } + if !compareValue(container[i], expected) { + t.Errorf("columns %s : expected: %v, got: %v\n", c.colname, string(expected.([]byte)), string(container[i].([]byte))) } } } @@ -203,8 +240,6 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str [test_datetime2_3] [datetime2](3) NULL, [test_datetime2_7] [datetime2](7) NULL, [test_date] [date] NULL, - [test_smallmoney] [smallmoney] NULL, - [test_money] [money] NULL, [test_tinyint] [tinyint] NULL, [test_smallint] [smallint] NOT NULL, [test_smallintn] [smallint] NULL, @@ -224,6 +259,15 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str [test_varbinary_max] VARBINARY(max) NOT NULL, [test_binary] BINARY NOT NULL, [test_binary_16] BINARY(16) NOT NULL, + [test_money_1] MONEY NOT NULL, + [test_money_2] MONEY NOT NULL, + [test_money_3] MONEY NOT NULL, + [test_money_4] MONEY NOT NULL, + [test_money_n_1] MONEY NULL, + [test_money_n_2] MONEY NULL, + [test_money_n_3] MONEY NULL, + [test_money_n_4] MONEY NULL, + [test_money_n_5] MONEY NULL CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED ( [id] ASC