-
-
Notifications
You must be signed in to change notification settings - Fork 442
/
Copy pathmnist.go
135 lines (118 loc) · 2.94 KB
/
mnist.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
// Package mnist handles the MNIST data set: it reads the idx files and converts them into tensors.
package mnist
import (
"os"
"path/filepath"
"github.com/pkg/errors"
"gorgonia.org/tensor"
)
type (
	// RawImage holds the raw pixel intensities of an image, one byte per
	// pixel. 255 is foreground (black), 0 is background (white).
	RawImage []byte

	// Label is a digit label in 0 to 9.
	Label uint8
)
// Load reads one split of the MNIST data set from the idx files in loc and
// converts it into two tensors of dtype as.
//
// typ selects the split: "train" (or its alias "dev") reads
// train-labels.idx1-ubyte / train-images.idx3-ubyte, and "test" reads
// t10k-labels.idx1-ubyte / t10k-images.idx3-ubyte. Any other value is
// rejected with an error.
//
// loc is the directory that holds the mnist files.
//
// as must be tensor.Float64 or tensor.Float32; prepareX and prepareY only
// build backing slices for those two, so any other dtype is rejected with an
// error.
//
// On success, inputs is built by prepareX (one row per image) and targets by
// prepareY (one row of 10 values per label). An error is returned if either
// file cannot be opened or decoded, if the image file yields no images, or if
// the number of labels differs from the number of images.
func Load(typ, loc string, as tensor.Dtype) (inputs, targets tensor.Tensor, err error) {
	const (
		trainLabel = "train-labels.idx1-ubyte"
		trainData  = "train-images.idx3-ubyte"
		testLabel  = "t10k-labels.idx1-ubyte"
		testData   = "t10k-images.idx3-ubyte"
	)

	var labelFile, dataFile string
	switch typ {
	case "train", "dev":
		labelFile = filepath.Join(loc, trainLabel)
		dataFile = filepath.Join(loc, trainData)
	case "test":
		labelFile = filepath.Join(loc, testLabel)
		dataFile = filepath.Join(loc, testData)
	default:
		// Without this an unknown typ would leave both paths empty and fail
		// later with an unhelpful "open : no such file" error.
		return nil, nil, errors.Errorf("unknown MNIST data set type %q (want \"train\", \"dev\" or \"test\")", typ)
	}

	if as != tensor.Float64 && as != tensor.Float32 {
		return nil, nil, errors.Errorf("unsupported dtype %v (want Float64 or Float32)", as)
	}

	// readLabelFile/readImageFile accept the (file, error) pair returned by
	// os.Open. The open error is handled here so that each file can be
	// closed when Load returns; the readers therefore receive a nil error.
	// Close errors are ignored because the files are only read.
	lf, err := os.Open(labelFile)
	if err != nil {
		return nil, nil, errors.Wrap(err, "Unable to read Labels")
	}
	defer lf.Close()

	var labelData []Label
	if labelData, err = readLabelFile(lf, nil); err != nil {
		return nil, nil, errors.Wrap(err, "Unable to read Labels")
	}

	df, err := os.Open(dataFile)
	if err != nil {
		return nil, nil, errors.Wrap(err, "Unable to read image data")
	}
	defer df.Close()

	var imageData []RawImage
	if imageData, err = readImageFile(df, nil); err != nil {
		return nil, nil, errors.Wrap(err, "Unable to read image data")
	}

	// prepareX takes its column count from the first image, and row i of
	// inputs is only meaningful alongside row i of targets, so both
	// conditions are checked before building the tensors.
	if len(imageData) == 0 {
		return nil, nil, errors.Errorf("no images read from %s", dataFile)
	}
	if len(labelData) != len(imageData) {
		return nil, nil, errors.Errorf("%s has %d labels but %s has %d images",
			labelFile, len(labelData), dataFile, len(imageData))
	}

	return prepareX(imageData, as), prepareY(labelData, as), nil
}
// pixelWeight maps a raw pixel intensity linearly from [0, pixelRange] onto
// [0.1, 1.0]: 0 becomes 0.1 and pixelRange becomes 1.0. A result of exactly
// 1.0 is replaced by 0.999, so the top intensity never yields a weight of 1.
func pixelWeight(px byte) float64 {
	w := 0.1 + 0.9*(float64(px)/pixelRange)
	if w != 1.0 {
		return w
	}
	return 0.999
}
// reversePixelWeight is the inverse of pixelWeight: it maps a weight back to
// the raw pixel intensity it came from.
//
// pixelWeight computes w = px/pixelRange*0.9 + 0.1, so the inverse is
// (w - 0.1) * pixelRange / 0.9. The result is rounded to the nearest integer.
// When pixelRange is 255, rounding also maps the 0.999 that pixelWeight
// substitutes for 1.0 back to 255. Values below 0 (and NaN) become 0, and
// values of 255 or more become 255, because converting an out-of-range
// float to byte is implementation-defined in Go.
func reversePixelWeight(px float64) byte {
	raw := (px - 0.1) * pixelRange / 0.9
	switch {
	case !(raw > 0): // also catches NaN
		return 0
	case raw >= 255:
		return 255
	}
	// raw is in (0, 255) here, so adding 0.5 and truncating rounds to nearest.
	return byte(raw + 0.5)
}
// prepareX flattens the images in M into a len(M)×len(M[0]) tensor of dtype
// dt, one image per row, converting every byte with pixelWeight.
//
// The column count is read from M[0], so M must be non-empty. Every image is
// expected to have that same length.
// NOTE(review): ragged input is not checked here; the backing length would
// then not match the shape.
// Only tensor.Float64 and tensor.Float32 build a backing slice; for any other
// dtype a nil backing is handed to tensor.New.
func prepareX(M []RawImage, dt tensor.Dtype) (retVal tensor.Tensor) {
	rows, cols := len(M), len(M[0])
	size := rows * cols

	var backing interface{}
	switch dt {
	case tensor.Float64:
		data := make([]float64, 0, size)
		for _, img := range M {
			for _, px := range img {
				data = append(data, pixelWeight(px))
			}
		}
		backing = data
	case tensor.Float32:
		data := make([]float32, 0, size)
		for _, img := range M {
			for _, px := range img {
				data = append(data, float32(pixelWeight(px)))
			}
		}
		backing = data
	}

	return tensor.New(tensor.WithShape(rows, cols), tensor.WithBacking(backing))
}
// prepareY turns the labels in N into a len(N)×10 tensor of dtype dt.
// Row i holds 0.9 in the column equal to N[i] and 0.1 in the other columns,
// so a label outside 0-9 produces a row of all 0.1.
//
// Only tensor.Float64 and tensor.Float32 build a backing slice; for any other
// dtype a nil backing is handed to tensor.New.
func prepareY(N []Label, dt tensor.Dtype) (retVal tensor.Tensor) {
	const (
		numClasses = 10
		hot        = 0.9
		cold       = 0.1
	)
	rows := len(N)

	var backing interface{}
	switch dt {
	case tensor.Float64:
		out := make([]float64, 0, rows*numClasses)
		for _, label := range N {
			for class := 0; class < numClasses; class++ {
				v := cold
				if class == int(label) {
					v = hot
				}
				out = append(out, v)
			}
		}
		backing = out
	case tensor.Float32:
		out := make([]float32, 0, rows*numClasses)
		for _, label := range N {
			for class := 0; class < numClasses; class++ {
				var v float32 = cold
				if class == int(label) {
					v = hot
				}
				out = append(out, v)
			}
		}
		backing = out
	}

	return tensor.New(tensor.WithShape(rows, numClasses), tensor.WithBacking(backing))
}