diff --git a/tf/examples/Bonsai/process_usps.py b/tf/examples/Bonsai/process_usps.py index 7ff763b00..252ba11e2 100644 --- a/tf/examples/Bonsai/process_usps.py +++ b/tf/examples/Bonsai/process_usps.py @@ -28,6 +28,19 @@ def loadLibSVMFile(file): assert os.path.isfile(tsf), 'File not found: %s' % tsf train = loadLibSVMFile(trf) test = loadLibSVMFile(tsf) + + # Convert the labels from 0 to numClasses-1 + y_train = train[:, 0] + y_test = test[:, 0] + + lab = y_train.astype('uint8') + lab = np.array(lab) - min(lab) + train[:, 0] = lab + + lab = y_test.astype('uint8') + lab = np.array(lab) - min(lab) + test[:, 0] = lab + np.save(path + '/train.npy', train) np.save(path + '/test.npy', test)