From fa5e700d2eb28c0b651213ae220d54020a92a67c Mon Sep 17 00:00:00 2001 From: erezrokah Date: Thu, 1 Jun 2023 18:37:31 +0200 Subject: [PATCH 1/2] feat(scalar): Support all int variations in decimal scalar --- scalar/decimal.go | 128 +++++++++++++++++++++++++++++++++++++++++ scalar/decimal_test.go | 118 +++++++++++++++++++++++++++++++++++++ 2 files changed, 246 insertions(+) create mode 100644 scalar/decimal_test.go diff --git a/scalar/decimal.go b/scalar/decimal.go index e9ad791a86..58aece4263 100644 --- a/scalar/decimal.go +++ b/scalar/decimal.go @@ -73,8 +73,24 @@ func (s *Decimal256) Set(val any) error { return nil } return s.Set(*value) + case int: + s.Value = decimal256.FromI64(int64(value)) + case int8: + s.Value = decimal256.FromI64(int64(value)) + case int16: + s.Value = decimal256.FromI64(int64(value)) + case int32: + s.Value = decimal256.FromI64(int64(value)) case int64: s.Value = decimal256.FromI64(value) + case uint: + s.Value = decimal256.FromU64(uint64(value)) + case uint8: + s.Value = decimal256.FromU64(uint64(value)) + case uint16: + s.Value = decimal256.FromU64(uint64(value)) + case uint32: + s.Value = decimal256.FromU64(uint64(value)) case uint64: s.Value = decimal256.FromU64(value) case string: @@ -83,12 +99,60 @@ func (s *Decimal256) Set(val any) error { return err } s.Value = v + case *int: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) + case *int8: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) + case *int16: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) + case *int32: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) case *int64: if value == nil { s.Valid = false return nil } return s.Set(*value) + case *uint: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) + case *uint8: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) + case *uint16: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) + case *uint32: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) case *uint64: if value == nil { s.Valid = false @@ -167,8 +231,24 @@ func (s *Decimal128) Set(val any) error { return nil } return s.Set(*value) + case int: + s.Value = decimal128.FromI64(int64(value)) + case int8: + s.Value = decimal128.FromI64(int64(value)) + case int16: + s.Value = decimal128.FromI64(int64(value)) + case int32: + s.Value = decimal128.FromI64(int64(value)) case int64: s.Value = decimal128.FromI64(value) + case uint: + s.Value = decimal128.FromU64(uint64(value)) + case uint8: + s.Value = decimal128.FromU64(uint64(value)) + case uint16: + s.Value = decimal128.FromU64(uint64(value)) + case uint32: + s.Value = decimal128.FromU64(uint64(value)) case uint64: s.Value = decimal128.FromU64(value) case string: @@ -177,12 +257,60 @@ func (s *Decimal128) Set(val any) error { return err } s.Value = v + case *int: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) + case *int8: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) + case *int16: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) + case *int32: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) case *int64: if value == nil { s.Valid = false return nil } return s.Set(*value) + case *uint: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) + case *uint8: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) + case *uint16: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) + case *uint32: + if value == nil { + s.Valid = false + return nil + } + return s.Set(*value) case *uint64: if value == nil { s.Valid = false diff --git a/scalar/decimal_test.go b/scalar/decimal_test.go new file mode 100644 index 0000000000..2f37533e28 --- /dev/null +++ b/scalar/decimal_test.go @@ -0,0 +1,118 @@ +package scalar + +import ( + "testing" + + "github.com/apache/arrow/go/v13/arrow" + "github.com/apache/arrow/go/v13/arrow/decimal128" + "github.com/apache/arrow/go/v13/arrow/decimal256" + "github.com/stretchr/testify/require" +) + +func TestDecimal128Set(t *testing.T) { + str := "100.32" + decimalType := &arrow.Decimal128Type{Precision: 5, Scale: 2} + strDecimal, _ := decimal128.FromString(str, decimalType.Precision, decimalType.Scale) + + intVal := int(1) + int8Val := int8(1) + int16Val := int16(1) + int32Val := int32(1) + int64Val := int64(1) + uintVal := uint(1) + uint8Val := uint8(1) + uint16Val := uint16(1) + uint32Val := uint32(1) + uint64Val := uint64(1) + + successfulTests := []struct { + source any + decimalType *arrow.Decimal128Type + expect Decimal128 + }{ + {source: str, expect: Decimal128{Value: strDecimal, Valid: true, Type: decimalType}, decimalType: decimalType}, + {source: &str, expect: Decimal128{Value: strDecimal, Valid: true, Type: decimalType}, decimalType: decimalType}, + {source: intVal, expect: Decimal128{Value: decimal128.FromI64(1), Valid: true}}, + {source: int8Val, expect: Decimal128{Value: decimal128.FromI64(1), Valid: true}}, + {source: int16Val, expect: Decimal128{Value: decimal128.FromI64(1), Valid: true}}, + {source: int32Val, expect: Decimal128{Value: decimal128.FromI64(1), Valid: true}}, + {source: int64Val, expect: Decimal128{Value: decimal128.FromI64(1), Valid: true}}, + {source: uintVal, expect: Decimal128{Value: decimal128.FromU64(1), Valid: true}}, + {source: uint8Val, expect: Decimal128{Value: decimal128.FromU64(1), Valid: true}}, + {source: uint16Val, expect: Decimal128{Value: decimal128.FromU64(1), Valid: true}}, + {source: uint32Val, expect: Decimal128{Value: decimal128.FromU64(1), Valid: true}}, + {source: uint64Val, expect: Decimal128{Value: decimal128.FromU64(1), Valid: true}}, + {source: &intVal, expect: Decimal128{Value: decimal128.FromI64(1), Valid: true}}, + {source: &int8Val, expect: Decimal128{Value: decimal128.FromI64(1), Valid: true}}, + {source: &int16Val, expect: Decimal128{Value: decimal128.FromI64(1), Valid: true}}, + {source: &int32Val, expect: Decimal128{Value: decimal128.FromI64(1), Valid: true}}, + {source: &int64Val, expect: Decimal128{Value: decimal128.FromI64(1), Valid: true}}, + {source: &uintVal, expect: Decimal128{Value: decimal128.FromU64(1), Valid: true}}, + {source: &uint8Val, expect: Decimal128{Value: decimal128.FromU64(1), Valid: true}}, + {source: &uint16Val, expect: Decimal128{Value: decimal128.FromU64(1), Valid: true}}, + {source: &uint32Val, expect: Decimal128{Value: decimal128.FromU64(1), Valid: true}}, + {source: &uint64Val, expect: Decimal128{Value: decimal128.FromU64(1), Valid: true}}, + } + + for i, tt := range successfulTests { + r := Decimal128{} + r.Type = tt.decimalType + err := r.Set(tt.source) + require.NoError(t, err, "No error expected for test %d", i) + require.Equal(t, tt.expect, r, "Unexpected result for test %d", i) + } +} + +func TestDecimal256Set(t *testing.T) { + str := "100.32" + decimalType := &arrow.Decimal256Type{Precision: 5, Scale: 2} + strDecimal, _ := decimal256.FromString(str, decimalType.Precision, decimalType.Scale) + + intVal := int(1) + int8Val := int8(1) + int16Val := int16(1) + int32Val := int32(1) + int64Val := int64(1) + uintVal := uint(1) + uint8Val := uint8(1) + uint16Val := uint16(1) + uint32Val := uint32(1) + uint64Val := uint64(1) + + successfulTests := []struct { + source any + decimalType *arrow.Decimal256Type + expect Decimal256 + }{ + {source: str, expect: Decimal256{Value: strDecimal, Valid: true, Type: decimalType}, decimalType: decimalType}, + {source: &str, expect: Decimal256{Value: strDecimal, Valid: true, Type: decimalType}, decimalType: decimalType}, + {source: intVal, expect: Decimal256{Value: decimal256.FromI64(1), Valid: true}}, + {source: int8Val, expect: Decimal256{Value: decimal256.FromI64(1), Valid: true}}, + {source: int16Val, expect: Decimal256{Value: decimal256.FromI64(1), Valid: true}}, + {source: int32Val, expect: Decimal256{Value: decimal256.FromI64(1), Valid: true}}, + {source: int64Val, expect: Decimal256{Value: decimal256.FromI64(1), Valid: true}}, + {source: uintVal, expect: Decimal256{Value: decimal256.FromU64(1), Valid: true}}, + {source: uint8Val, expect: Decimal256{Value: decimal256.FromU64(1), Valid: true}}, + {source: uint16Val, expect: Decimal256{Value: decimal256.FromU64(1), Valid: true}}, + {source: uint32Val, expect: Decimal256{Value: decimal256.FromU64(1), Valid: true}}, + {source: uint64Val, expect: Decimal256{Value: decimal256.FromU64(1), Valid: true}}, + {source: &intVal, expect: Decimal256{Value: decimal256.FromI64(1), Valid: true}}, + {source: &int8Val, expect: Decimal256{Value: decimal256.FromI64(1), Valid: true}}, + {source: &int16Val, expect: Decimal256{Value: decimal256.FromI64(1), Valid: true}}, + {source: &int32Val, expect: Decimal256{Value: decimal256.FromI64(1), Valid: true}}, + {source: &int64Val, expect: Decimal256{Value: decimal256.FromI64(1), Valid: true}}, + {source: &uintVal, expect: Decimal256{Value: decimal256.FromU64(1), Valid: true}}, + {source: &uint8Val, expect: Decimal256{Value: decimal256.FromU64(1), Valid: true}}, + {source: &uint16Val, expect: Decimal256{Value: decimal256.FromU64(1), Valid: true}}, + {source: &uint32Val, expect: Decimal256{Value: decimal256.FromU64(1), Valid: true}}, + {source: &uint64Val, expect: Decimal256{Value: decimal256.FromU64(1), Valid: true}}, + } + + for i, tt := range successfulTests { + r := Decimal256{} + r.Type = tt.decimalType + err := r.Set(tt.source) + require.NoError(t, err, "No error expected for test %d", i) + require.Equal(t, tt.expect, r, "Unexpected result for test %d", i) + } +} From 4421ee0d3311e6e74e56117cfd387d96f66200b5 Mon Sep 17 00:00:00 2001 From: erezrokah Date: Thu, 1 Jun 2023 19:04:19 +0200 Subject: [PATCH 2/2] style: Fix linting --- scalar/decimal_test.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/scalar/decimal_test.go b/scalar/decimal_test.go index 2f37533e28..7f429ec901 100644 --- a/scalar/decimal_test.go +++ b/scalar/decimal_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/require" ) +// nolint:dupl func TestDecimal128Set(t *testing.T) { str := "100.32" decimalType := &arrow.Decimal128Type{Precision: 5, Scale: 2} @@ -63,6 +64,7 @@ func TestDecimal128Set(t *testing.T) { } } +// nolint:dupl func TestDecimal256Set(t *testing.T) { str := "100.32" decimalType := &arrow.Decimal256Type{Precision: 5, Scale: 2}