-
Notifications
You must be signed in to change notification settings - Fork 0
/
MnistDataClass.cpp
85 lines (68 loc) · 1.81 KB
/
MnistDataClass.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
#include "MnistDataClass.h"
// Creates an empty data set with no images loaded.
// The counts are zeroed explicitly so that getNumImages() on a
// default-constructed object returns 0 instead of an indeterminate value.
// NOTE(review): the header is not visible here; if it already gives these
// members in-class initializers, this list is redundant but harmless.
MnistDataClass::MnistDataClass() : numImages(0), imageRow(0), imageCol(0) {
}
int MnistDataClass::getNumImages() {
return numImages;
}
// Returns a copy of the pixel vector for image `imageNum`.
// Each entry is 1.0 for a non-zero source pixel and 0.0 otherwise, as stored
// by the loading constructor.
// Throws std::out_of_range if `imageNum` is not a valid image index
// (including negative values, which convert to a huge unsigned index).
vector<double> MnistDataClass::getPixelData(int imageNum) {
	// .at() turns a bad index into an exception instead of the undefined
	// behavior operator[] would give.
	return inputData.at(imageNum);
}
// Returns a copy of the 10-element label vector for image `imageNum`, in
// which the entry at the digit's index is 1 and the others are 0.
// Throws std::out_of_range if `imageNum` is not a valid index. This includes
// the case where the label file could not be read and no labels were stored.
vector<double> MnistDataClass::getImageNumber(int imageNum) {
	// .at() turns a bad index into an exception instead of the undefined
	// behavior operator[] would give.
	return correctNumber.at(imageNum);
}
// Reads `pixelCount` raw bytes from `file` and appends each one to `pixels`,
// binarized: any non-zero intensity becomes 1.0 and zero stays 0.0.
// (Storing the raw 0-255 value instead would keep the grayscale information.)
// Returns false if the stream ends before `pixelCount` bytes were read.
static bool readBinarizedImage(std::ifstream &file, int pixelCount, std::vector<double> &pixels) {
	if (pixelCount > 0) {
		pixels.reserve(pixelCount);
	}
	for (int j = 0; j < pixelCount; j++) {
		const int byte = file.get();
		if (byte == std::ifstream::traits_type::eof()) {
			return false;
		}
		pixels.push_back(byte > 0 ? 1.0 : 0.0);
	}
	return true;
}

// Loads images from `dataFilename` and their labels from `answerFilename`.
//
// Image file: a header parsed by readHeader(), which sets numImages, imageRow
// and imageCol, followed by imageRow*imageCol bytes per image. Each image
// becomes one vector of doubles in inputData (see readBinarizedImage).
//
// Label file: an 8-byte header, which is skipped, followed by one byte per
// image. Each label is stored in correctNumber as a 10-element vector with a
// 1 at the label's index. A byte outside 0-9, or a missing byte, leaves that
// vector all zeros.
//
// Failure handling:
//  - If the image file cannot be opened, the object is left empty
//    (numImages == 0).
//  - If the image file is shorter than its header claims, only the
//    completely read images are kept and numImages is reduced to match.
//  - If the label file cannot be opened, correctNumber stays empty.
MnistDataClass::MnistDataClass(string dataFilename, string answerFilename) {
	// readHeader() only runs if the image file opens, so start from a
	// defined state. Otherwise the label table below would be sized from an
	// uninitialized count.
	numImages = 0;
	imageRow = 0;
	imageCol = 0;

	ifstream dataFile(dataFilename, ifstream::binary | ifstream::in);
	if (dataFile.good()) {
		readHeader(dataFile);
		const int pixelsPerImage = imageRow * imageCol;
		inputData.resize(numImages);
		for (int i = 0; i < numImages; i++) {
			if (!readBinarizedImage(dataFile, pixelsPerImage, inputData[i])) {
				// Truncated file: keep only the complete images, so every stored
				// image has exactly pixelsPerImage values.
				numImages = i;
				inputData.resize(numImages);
				break;
			}
		}
	}

	ifstream answerFile(answerFilename, ifstream::binary | ifstream::in);
	if (answerFile.good()) {
		// Skip the 8-byte label-file header.
		answerFile.ignore(8);
		correctNumber.resize(numImages);
		for (int i = 0; i < numImages; i++) {
			correctNumber[i].assign(10, 0.0);
			const int label = answerFile.get();
			// Only in-range digits are encoded. EOF or a corrupt byte would
			// otherwise index past the end of the 10-element vector.
			if (label >= 0 && label < 10) {
				correctNumber[i][label] = 1.0;
			}
		}
	}
	// Both streams are closed by their destructors.
}
// Reads four bytes from `file` and combines them with the first byte as the
// most significant (big-endian). A read past end of stream contributes 0xff,
// the same as the original byte masking did.
static int readBigEndianInt32(std::ifstream &file) {
	// Accumulate in unsigned so the left shifts cannot overflow a signed int.
	unsigned int value = 0;
	for (int i = 0; i < 4; i++) {
		value = (value << 8) | (static_cast<unsigned int>(file.get()) & 0xffu);
	}
	return static_cast<int>(value);
}

// Parses the 16-byte image-file header. It consists of four big-endian 4-byte
// integers: magic number, image count, row count and column count.
// The counts are stored in numImages, imageRow and imageCol. The magic number
// is consumed to advance the stream but is not validated.
void MnistDataClass::readHeader(ifstream &file) {
	readBigEndianInt32(file);  // magic number: consumed, not checked
	// Assign the parsed values. The previous code shifted bytes into the
	// members' existing contents, which was only correct if they started at 0.
	numImages = readBigEndianInt32(file);
	imageRow = readBigEndianInt32(file);
	imageCol = readBigEndianInt32(file);
}
// Destructor. All members are standard containers/values that clean up
// after themselves, so there is nothing extra to release.
MnistDataClass::~MnistDataClass() = default;