Skip to content
Permalink
Browse files

Convert column types to ColumnType (#25)

  • Loading branch information...
deliahu committed Feb 28, 2019
1 parent 515e1b9 commit 23f4bda671b66405bd17cdc2a971ff51fd1ec49f
@@ -414,8 +414,7 @@ func describeAPI(name string, resourcesRes *schema.GetResourcesResponse) (string
var samplePlaceholderFields []string
for _, colName := range ctx.RawColumnInputNames(model) {
column := ctx.GetColumn(colName)
columnType := userconfig.ColumnTypeFromString(column.GetType())
fieldStr := `"` + colName + `": ` + columnType.JSONPlaceholder()
fieldStr := `"` + colName + `": ` + column.GetType().JSONPlaceholder()
samplePlaceholderFields = append(samplePlaceholderFields, fieldStr)
}
samplesPlaceholderStr := `{ "samples": [ { ` + strings.Join(samplePlaceholderFields, ", ") + " } ] }"
@@ -30,7 +30,7 @@ type Columns map[string]Column

type Column interface {
ComputedResource
GetType() string
GetType() userconfig.ColumnType
IsRaw() bool
GetInputRawColumnNames() []string
}
@@ -105,7 +105,7 @@ func GetColumnRuntimeTypes(
}

if rawColumnNames, ok := cast.InterfaceToStrSlice(columnInputValue); ok {
rawColumnTypes := make([]string, len(rawColumnNames))
rawColumnTypes := make([]userconfig.ColumnType, len(rawColumnNames))
for i, rawColumnName := range rawColumnNames {
rawColumn, ok := rawColumns[rawColumnName]
if !ok {
@@ -33,39 +33,39 @@ func TestGetColumnRuntimeTypes(t *testing.T) {
rawColumns := context.RawColumns{
"rfInt": &context.RawIntColumn{
RawIntColumn: &userconfig.RawIntColumn{
Type: "INT_COLUMN",
Type: userconfig.IntegerColumnType,
},
},
"rfFloat": &context.RawFloatColumn{
RawFloatColumn: &userconfig.RawFloatColumn{
Type: "FLOAT_COLUMN",
Type: userconfig.FloatColumnType,
},
},
"rfStr": &context.RawStringColumn{
RawStringColumn: &userconfig.RawStringColumn{
Type: "STRING_COLUMN",
Type: userconfig.StringColumnType,
},
},
}

columnInputValues = cr.MustReadYAMLStrMap("in: rfInt")
expected = map[string]interface{}{"in": "INT_COLUMN"}
expected = map[string]interface{}{"in": userconfig.IntegerColumnType}
checkTestGetColumnRuntimeTypes(columnInputValues, rawColumns, expected, t)

columnInputValues = cr.MustReadYAMLStrMap("in: rfStr")
expected = map[string]interface{}{"in": "STRING_COLUMN"}
expected = map[string]interface{}{"in": userconfig.StringColumnType}
checkTestGetColumnRuntimeTypes(columnInputValues, rawColumns, expected, t)

columnInputValues = cr.MustReadYAMLStrMap("in: [rfFloat]")
expected = map[string]interface{}{"in": []string{"FLOAT_COLUMN"}}
expected = map[string]interface{}{"in": []userconfig.ColumnType{userconfig.FloatColumnType}}
checkTestGetColumnRuntimeTypes(columnInputValues, rawColumns, expected, t)

columnInputValues = cr.MustReadYAMLStrMap("in: [rfInt, rfFloat, rfStr, rfInt]")
expected = map[string]interface{}{"in": []string{"INT_COLUMN", "FLOAT_COLUMN", "STRING_COLUMN", "INT_COLUMN"}}
expected = map[string]interface{}{"in": []userconfig.ColumnType{userconfig.IntegerColumnType, userconfig.FloatColumnType, userconfig.StringColumnType, userconfig.IntegerColumnType}}
checkTestGetColumnRuntimeTypes(columnInputValues, rawColumns, expected, t)

columnInputValues = cr.MustReadYAMLStrMap("in1: [rfInt, rfFloat]\nin2: rfStr")
expected = map[string]interface{}{"in1": []string{"INT_COLUMN", "FLOAT_COLUMN"}, "in2": "STRING_COLUMN"}
expected = map[string]interface{}{"in1": []userconfig.ColumnType{userconfig.IntegerColumnType, userconfig.FloatColumnType}, "in2": userconfig.StringColumnType}
checkTestGetColumnRuntimeTypes(columnInputValues, rawColumns, expected, t)

columnInputValues = cr.MustReadYAMLStrMap("in: 1")
@@ -77,8 +77,7 @@ func (models Models) GetTrainingDatasets() TrainingDatasets {
return trainingDatasets
}

func ValidateModelTargetType(targetDataTypeStr string, modelType string) error {
targetType := userconfig.ColumnTypeFromString(targetDataTypeStr)
func ValidateModelTargetType(targetType userconfig.ColumnType, modelType string) error {
switch modelType {
case "classification":
if targetType != userconfig.IntegerColumnType {
@@ -27,10 +27,10 @@ type TransformedColumns map[string]*TransformedColumn
type TransformedColumn struct {
*userconfig.TransformedColumn
*ComputedResourceFields
Type string `json:"type"`
Type userconfig.ColumnType `json:"type"`
}

func (column *TransformedColumn) GetType() string {
func (column *TransformedColumn) GetType() userconfig.ColumnType {
return column.Type
}

@@ -112,6 +112,9 @@ func ErrInvalidFloat32(provided float32, allowed ...float32) string {
// ErrInvalidFloat64 formats the "invalid value" message for a float64 input,
// reporting the provided value alongside the set of allowed values.
func ErrInvalidFloat64(provided float64, allowed ...float64) string {
	got := UserStr(provided)
	want := UserStrsOr(allowed)
	return fmt.Sprintf("invalid value (got %s, must be %s)", got, want)
}
// ErrInvalidInterface formats the "invalid value" message for an input of
// arbitrary type, reporting the provided value alongside the allowed values.
func ErrInvalidInterface(provided interface{}, allowed ...interface{}) string {
	got := UserStr(provided)
	want := UserStrsOr(allowed)
	return fmt.Sprintf("invalid value (got %s, must be %s)", got, want)
}

func ErrMustHavePrefix(provided string, prefix string) string {
return fmt.Sprintf("%s must start with %s", UserStr(provided), UserStr(prefix))
@@ -130,7 +130,7 @@ func strIndent(val interface{}, indent string, currentIndent string, newlineChar
if funcVal.IsValid() {
t := funcVal.Type()
if t.NumIn() == 0 && t.NumOut() == 1 && t.Out(0).Kind() == reflect.String {
return funcVal.Call(nil)[0].Interface().(string)
return quoteStr + funcVal.Call(nil)[0].Interface().(string) + quoteStr
}
}
if _, ok := reflect.PtrTo(valueType).MethodByName("String"); ok {
@@ -140,7 +140,7 @@ func strIndent(val interface{}, indent string, currentIndent string, newlineChar
if funcVal.IsValid() {
t := funcVal.Type()
if t.NumIn() == 0 && t.NumOut() == 1 && t.Out(0).Kind() == reflect.String {
return funcVal.Call(nil)[0].Interface().(string)
return quoteStr + funcVal.Call(nil)[0].Interface().(string) + quoteStr
}
}
}
@@ -16,7 +16,9 @@ limitations under the License.

package userconfig

import "strings"
import (
"strings"
)

type (
	// ColumnType is the int-backed identifier for a column's data type.
	ColumnType int

	// ColumnTypes is a slice of ColumnType values.
	ColumnTypes []ColumnType
)
@@ -309,10 +309,10 @@ func ErrorInvalidColumnInputType(provided interface{}) error {
}
}

func ErrorInvalidColumnRuntimeType(provided interface{}) error {
func ErrorInvalidColumnRuntimeType() error {
return Error{
Kind: ErrInvalidColumnRuntimeType,
message: fmt.Sprintf("invalid column type (got %s, expected %s)", s.DataTypeStr(provided), s.StrsOr(ColumnTypeStrings())),
message: fmt.Sprintf("invalid column runtime type (expected %s)", s.StrsOr(ColumnTypeStrings())),
}
}

@@ -23,7 +23,7 @@ import (

type RawColumn interface {
Column
GetType() string
GetType() ColumnType
GetCompute() *SparkCompute
GetUserConfig() Resource
}
@@ -33,25 +33,28 @@ type RawColumns []RawColumn
var rawColumnValidation = &cr.InterfaceStructValidation{
TypeKey: "type",
TypeStructField: "Type",
InterfaceStructTypes: map[string]*cr.InterfaceStructType{
"STRING_COLUMN": {
ParsedInterfaceStructTypes: map[interface{}]*cr.InterfaceStructType{
StringColumnType: {
Type: (*RawStringColumn)(nil),
StructFieldValidations: rawStringColumnFieldValidations,
},
"INT_COLUMN": {
IntegerColumnType: {
Type: (*RawIntColumn)(nil),
StructFieldValidations: rawIntColumnFieldValidations,
},
"FLOAT_COLUMN": {
FloatColumnType: {
Type: (*RawFloatColumn)(nil),
StructFieldValidations: rawFloatColumnFieldValidations,
},
},
Parser: func(str string) (interface{}, error) {
return ColumnTypeFromString(str), nil
},
}

type RawIntColumn struct {
ResourceConfigFields
Type string `json:"type" yaml:"type"`
Type ColumnType `json:"type" yaml:"type"`
Required bool `json:"required" yaml:"required"`
Min *int64 `json:"min" yaml:"min"`
Max *int64 `json:"max" yaml:"max"`
@@ -100,7 +103,7 @@ var rawIntColumnFieldValidations = []*cr.StructFieldValidation{

type RawFloatColumn struct {
ResourceConfigFields
Type string `json:"type" yaml:"type"`
Type ColumnType `json:"type" yaml:"type"`
Required bool `json:"required" yaml:"required"`
Min *float32 `json:"min" yaml:"min"`
Max *float32 `json:"max" yaml:"max"`
@@ -149,7 +152,7 @@ var rawFloatColumnFieldValidations = []*cr.StructFieldValidation{

type RawStringColumn struct {
ResourceConfigFields
Type string `json:"type" yaml:"type"`
Type ColumnType `json:"type" yaml:"type"`
Required bool `json:"required" yaml:"required"`
Values []string `json:"values" yaml:"values"`
Compute *SparkCompute `json:"compute" yaml:"compute"`
@@ -215,15 +218,15 @@ func (rawColumns RawColumns) Get(name string) RawColumn {
return nil
}

func (column *RawIntColumn) GetType() string {
func (column *RawIntColumn) GetType() ColumnType {
return column.Type
}

func (column *RawFloatColumn) GetType() string {
func (column *RawFloatColumn) GetType() ColumnType {
return column.Type
}

func (column *RawStringColumn) GetType() string {
func (column *RawStringColumn) GetType() ColumnType {
return column.Type
}

@@ -25,9 +25,9 @@ type Transformers []*Transformer

type Transformer struct {
ResourceConfigFields
Inputs *Inputs `json:"inputs" yaml:"inputs"`
OutputType string `json:"output_type" yaml:"output_type"`
Path string `json:"path" yaml:"path"`
Inputs *Inputs `json:"inputs" yaml:"inputs"`
OutputType ColumnType `json:"output_type" yaml:"output_type"`
Path string `json:"path" yaml:"path"`
}

var transformerValidation = &cr.StructValidation{
@@ -53,6 +53,9 @@ var transformerValidation = &cr.StructValidation{
Required: true,
AllowedValues: ColumnTypeStrings(),
},
Parser: func(str string) (interface{}, error) {
return ColumnTypeFromString(str), nil
},
},
inputTypesFieldValidation,
typeFieldValidation,
@@ -26,10 +26,6 @@ import (
"github.com/cortexlabs/cortex/pkg/lib/slices"
)

func isValidColumnOutputType(columnTypeStr string) bool {
return slices.HasString(columnTypeStr, ColumnTypeStrings())
}

func isValidColumnInputType(columnTypeStr string) bool {
for _, columnTypeStrItem := range strings.Split(columnTypeStr, "|") {
if !slices.HasString(columnTypeStrItem, ColumnTypeStrings()) {
@@ -91,27 +87,28 @@ func ValidateColumnInputValues(columnInputValues map[string]interface{}) error {
}

func ValidateColumnRuntimeTypes(columnRuntimeTypes map[string]interface{}) error {
for columnInputName, columnType := range columnRuntimeTypes {
if columnTypeStr, ok := columnType.(string); ok {
if !isValidColumnOutputType(columnTypeStr) {
return errors.Wrap(ErrorInvalidColumnRuntimeType(columnTypeStr), columnInputName)
for columnInputName, columnTypeInter := range columnRuntimeTypes {
if columnType, ok := columnTypeInter.(ColumnType); ok {
if columnType == UnknownColumnType {
return errors.Wrap(ErrorInvalidColumnRuntimeType(), columnInputName) // unexpected
}
continue
}
if columnTypeStrs, ok := cast.InterfaceToStrSlice(columnType); ok {
for i, columnTypeStr := range columnTypeStrs {
if !isValidColumnOutputType(columnTypeStr) {
return errors.Wrap(ErrorInvalidColumnRuntimeType(columnTypeStr), columnInputName, s.Index(i))
if columnTypes, ok := columnTypeInter.([]ColumnType); ok {
for i, columnType := range columnTypes {
if columnType == UnknownColumnType {
return errors.Wrap(ErrorInvalidColumnRuntimeType(), columnInputName, s.Index(i)) // unexpected
}
}
continue
}
return errors.Wrap(ErrorInvalidColumnRuntimeType(columnType), columnInputName)
return errors.Wrap(ErrorInvalidColumnRuntimeType(), columnInputName) // unexpected
}

return nil
}

// columnRuntimeTypes is {string -> ColumnType or []ColumnType}, columnSchemaTypes is {string -> string or []string}
func CheckColumnRuntimeTypesMatch(columnRuntimeTypes map[string]interface{}, columnSchemaTypes map[string]interface{}) error {
err := ValidateColumnInputTypes(columnSchemaTypes)
if err != nil {
@@ -127,32 +124,32 @@ func CheckColumnRuntimeTypesMatch(columnRuntimeTypes map[string]interface{}, col
return errors.New(s.MapMustBeDefined(maps.InterfaceMapKeys(columnSchemaTypes)...))
}

columnRuntimeType, ok := columnRuntimeTypes[columnInputName]
columnRuntimeTypeInter, ok := columnRuntimeTypes[columnInputName]
if !ok {
return errors.New(columnInputName, s.ErrMustBeDefined)
}

if columnSchemaTypeStr, ok := columnSchemaType.(string); ok {
validTypes := strings.Split(columnSchemaTypeStr, "|")
columnRuntimeTypeStr, ok := columnRuntimeType.(string)
columnRuntimeType, ok := columnRuntimeTypeInter.(ColumnType)
if !ok {
return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeType, validTypes), columnInputName)
return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeTypeInter, validTypes), columnInputName)
}
if !slices.HasString(columnRuntimeTypeStr, validTypes) {
return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeTypeStr, validTypes), columnInputName)
if !slices.HasString(columnRuntimeType.String(), validTypes) {
return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeType, validTypes), columnInputName)
}
continue
}

if columnSchemaTypeStrs, ok := cast.InterfaceToStrSlice(columnSchemaType); ok {
validTypes := strings.Split(columnSchemaTypeStrs[0], "|")
columnRuntimeTypeStrs, ok := cast.InterfaceToStrSlice(columnRuntimeType)
columnRuntimeTypeSlice, ok := columnRuntimeTypeInter.([]ColumnType)
if !ok {
return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeType, columnSchemaTypeStrs), columnInputName)
return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeTypeInter, columnSchemaTypeStrs), columnInputName)
}
for i, columnRuntimeTypeStr := range columnRuntimeTypeStrs {
if !slices.HasString(columnRuntimeTypeStr, validTypes) {
return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeTypeStr, validTypes), columnInputName, s.Index(i))
for i, columnRuntimeType := range columnRuntimeTypeSlice {
if !slices.HasString(columnRuntimeType.String(), validTypes) {
return errors.Wrap(ErrorUnsupportedColumnType(columnRuntimeType, validTypes), columnInputName, s.Index(i))
}
}
continue
Oops, something went wrong.

0 comments on commit 23f4bda

Please sign in to comment.
You can’t perform that action at this time.