forked from sjwhitworth/golearn
/
categorical.go
190 lines (167 loc) · 5.19 KB
/
categorical.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
package base
import (
"encoding/json"
"fmt"
)
// CategoricalAttribute is an Attribute implementation that stores discrete
// string values, which makes it useful for representing classes.
//
// Each distinct string is stored once; its position in the values slice is
// its system representation (the index is packed with PackU64ToBytes by the
// lookup methods of this type).
type CategoricalAttribute struct {
	// Name is the human-readable name of this attribute.
	Name string
	// values holds the defined category strings. Order matters: a string's
	// index here is its system value, so entries must not be reordered.
	values []string
}
// MarshalJSON encodes this Attribute as a JSON object of the form
//
//	{"attr": {"values": [...]}, "name": <Name>, "type": "categorical"}
//
// The category strings are nested under "attr" rather than stored at the
// top level.
func (Attr *CategoricalAttribute) MarshalJSON() ([]byte, error) {
	inner := map[string]interface{}{"values": Attr.values}
	envelope := map[string]interface{}{
		"attr": inner,
		"name": Attr.Name,
		"type": "categorical",
	}
	return json.Marshal(envelope)
}
// UnmarshalJSON restores this Attribute's values from JSON.
//
// data may be either of two forms:
//   - the inner object MarshalJSON emits under its "attr" key,
//     {"values": [...]};
//   - the complete object MarshalJSON produces,
//     {"type": "categorical", "name": ..., "attr": {"values": [...]}}.
//
// Accepting the complete object lets json.Unmarshal round-trip a marshalled
// attribute. When a top-level "name" member is present, the attribute's Name
// is set from it as well.
//
// Previously defined values are replaced, not appended to. A JSON null leaves
// the attribute unchanged, following the json.Unmarshaler convention.
// Malformed input, such as a missing "values" array or non-string entries,
// returns an error instead of panicking.
func (Attr *CategoricalAttribute) UnmarshalJSON(data []byte) error {
	// Pointer fields distinguish "absent" from "present but empty".
	var d *struct {
		Name   *string   `json:"name"`
		Values *[]string `json:"values"`
		Attr   *struct {
			Values *[]string `json:"values"`
		} `json:"attr"`
	}
	if err := json.Unmarshal(data, &d); err != nil {
		return err
	}
	if d == nil {
		// JSON null: conventionally a no-op for Unmarshalers.
		return nil
	}
	values := d.Values
	if values == nil && d.Attr != nil {
		values = d.Attr.Values
	}
	if values == nil {
		return fmt.Errorf("categorical attribute: JSON has no \"values\" array")
	}
	if d.Name != nil {
		Attr.Name = *d.Name
	}
	// The decoded slice is freshly allocated by encoding/json, so it can be
	// adopted directly without copying.
	Attr.values = *values
	return nil
}
// NewCategoricalAttribute returns an unnamed CategoricalAttribute with no
// values defined yet.
//
// The values slice is empty but non-nil, so MarshalJSON encodes it as []
// rather than null.
func NewCategoricalAttribute() *CategoricalAttribute {
	attr := new(CategoricalAttribute)
	attr.values = []string{}
	return attr
}
// GetValues returns the category strings defined so far, in system-value
// order: the string at position i is the one whose system representation is
// the packed index i.
//
// The returned slice is the attribute's own storage, not a copy, so callers
// should treat it as read-only.
func (Attr *CategoricalAttribute) GetValues() []string {
	defined := Attr.values
	return defined
}

// GetName returns the human-readable name assigned to this attribute.
func (Attr *CategoricalAttribute) GetName() string {
	name := Attr.Name
	return name
}

// SetName replaces the human-readable name of this attribute with name.
func (Attr *CategoricalAttribute) SetName(name string) {
	Attr.Name = name
}

// GetType reports CategoricalType. Callers can branch on it instead of
// performing a type assertion on the Attribute.
func (Attr *CategoricalAttribute) GetType() int {
	kind := CategoricalType
	return kind
}
// GetSysVal looks up userVal among the defined values and returns its index,
// packed with PackU64ToBytes.
//
// If userVal has not been defined, it returns nil, so a zero-length result
// means "unknown value". Unlike GetSysValFromString, this never modifies the
// attribute.
func (Attr *CategoricalAttribute) GetSysVal(userVal string) []byte {
	for pos := 0; pos < len(Attr.values); pos++ {
		if Attr.values[pos] != userVal {
			continue
		}
		return PackU64ToBytes(uint64(pos))
	}
	return nil
}
// GetUsrVal maps the packed system value sysVal back to its category string.
//
// IMPORTANT: no bounds check is performed. A system value that does not index
// into the defined values causes a runtime index-out-of-range panic.
// GetStringFromSysVal performs the same lookup with an explicit range check.
func (Attr *CategoricalAttribute) GetUsrVal(sysVal []byte) string {
	values := Attr.values
	return values[UnpackBytesToU64(sysVal)]
}
// GetSysValFromString returns the system representation of rawVal: its index
// in the values slice, packed with PackU64ToBytes.
//
// IMPORTANT: if rawVal is not yet defined, it is appended to the values slice
// first, so this method mutates the attribute. To check whether rawVal exists
// without adding it, use GetSysVal and test for a zero-length (nil) result.
//
// Example: if the CategoricalAttribute contains the values ["iris-setosa",
// "iris-virginica"] and "iris-versicolor" is provided as the argument, the
// values slice becomes ["iris-setosa", "iris-virginica", "iris-versicolor"]
// and PackU64ToBytes(2) (the packed index 2) is returned.
func (Attr *CategoricalAttribute) GetSysValFromString(rawVal string) []byte {
	for i, s := range Attr.values {
		if s == rawVal {
			return PackU64ToBytes(uint64(i))
		}
	}
	// Unseen value: it takes the next free index at the end of the slice.
	Attr.values = append(Attr.values, rawVal)
	return PackU64ToBytes(uint64(len(Attr.values) - 1))
}
// String implements fmt.Stringer.
//
// It renders the attribute's name in double quotes followed by the list of
// values it can take, e.g. CategoricalAttribute("class", [a b]).
func (Attr *CategoricalAttribute) String() string {
	const layout = "CategoricalAttribute(\"%s\", %s)"
	name, values := Attr.Name, Attr.values
	return fmt.Sprintf(layout, name, values)
}
// GetStringFromSysVal returns the human-readable value for the system
// representation rawVal, a packed index into the values slice.
//
// IMPORTANT: This function calls panic() if the index is greater than or
// equal to the number of defined values.
// TODO: Return a user-configurable default instead.
func (Attr *CategoricalAttribute) GetStringFromSysVal(rawVal []byte) string {
	// Keep the index as uint64 for the range check. Converting to int first
	// turns indices >= 2^63 negative on 64-bit platforms, and truncates on
	// 32-bit ones. Either way, out-of-range values would slip past the check.
	idx := UnpackBytesToU64(rawVal)
	if idx >= uint64(len(Attr.values)) {
		panic(fmt.Sprintf("Out of range: %d in %d (%s)", idx, len(Attr.values), Attr))
	}
	return Attr.values[idx]
}
// Equals reports whether other is a *CategoricalAttribute with the same
// name and the same values in the same order.
//
// An Attribute of any other concrete type is never equal to a
// CategoricalAttribute.
func (Attr *CategoricalAttribute) Equals(other Attribute) bool {
	that, isCategorical := other.(*CategoricalAttribute)
	switch {
	case !isCategorical:
		// Different concrete type: cannot be equal.
		return false
	case Attr.Name != that.Name:
		return false
	case len(that.values) != len(Attr.values):
		return false
	}
	// Order matters: a value's index is its system representation.
	for i := range Attr.values {
		if Attr.values[i] != that.values[i] {
			return false
		}
	}
	return true
}
// Compatible reports whether other is a *CategoricalAttribute defining exactly
// the same values in the same order.
//
// Unlike Equals, names are not compared. Two compatible attributes therefore
// map every system value to the same string.
func (Attr *CategoricalAttribute) Compatible(other Attribute) bool {
	that, isCategorical := other.(*CategoricalAttribute)
	if !isCategorical || len(that.values) != len(Attr.values) {
		return false
	}
	for i, mine := range Attr.values {
		if that.values[i] != mine {
			return false
		}
	}
	return true
}