In [9]:
import h5py
from sklearn.cross_validation import train_test_split
from keras.models import Sequential, load_model
from keras.layers import Dense, Dropout, Activation, Convolution2D, MaxPooling2D, Flatten
from keras.optimizers import SGD

In [2]:
# Load data
# Every dataset is sliced with [:] so it is copied into memory as a NumPy
# array and remains usable once the HDF5 file has been closed.
with h5py.File("data/data.hdf5", "r") as f:
    X = f["X"][:]
    # Reshape to the per-sample shape (1, 200, 12) that the network's first
    # layer declares as its input_shape. -1 lets NumPy infer the number of
    # samples from the data instead of hard-coding it, so the cell keeps
    # working if the dataset grows or shrinks; a size that is not a multiple
    # of 1 * 200 * 12 still raises ValueError.
    X = X.reshape(-1, 1, 200, 12)
    artist = f["artist"][:]
    song = f["song"][:]
    Y = f["Y"][:]

In [3]:
# Train test split
# A fixed random_state makes the shuffle deterministic, so the same samples
# land in the held-out set on every "Restart & Run All" and the validation
# numbers are comparable between runs. The test fraction is left at the
# library default.
X_train, X_test, y_train, y_test = train_test_split(X, Y, random_state=42)

In [4]:
# Neural network
# The whole layer stack is handed to Sequential as one list; layers are added
# in list order:
#   2 x (3x1 convolution, 20 filters -> tanh -> 2x1 max-pooling)
#   flatten
#   2 x (1000-unit dense -> tanh)
#   10-unit dense -> tanh
# Dropout is imported but intentionally not part of the stack.
model = Sequential([
    # Convolutional stage 1 (carries the input shape for the whole network)
    Convolution2D(20, 3, 1, init='uniform', border_mode='valid', input_shape=(1, 200, 12)),
    Activation('tanh'),
    MaxPooling2D(pool_size=(2, 1), border_mode='valid'),
    # Convolutional stage 2
    Convolution2D(20, 3, 1, init='uniform', border_mode='valid'),
    Activation('tanh'),
    MaxPooling2D(pool_size=(2, 1), border_mode='valid'),
    # Fully connected head
    Flatten(),
    Dense(1000, init='uniform'),
    Activation('tanh'),
    Dense(1000, init='uniform'),
    Activation('tanh'),
    Dense(10, init='uniform'),
    Activation('tanh'),
])

# Nesterov-momentum SGD minimising mean squared error.
sgd = SGD(lr=0.1, momentum=0.9, decay=1e-6, nesterov=True)
model.compile(optimizer=sgd, loss="mean_squared_error")

# Train for two epochs in mini-batches of 30; the held-out split is passed as
# validation data. The returned training history is kept in `hist`.
hist = model.fit(
    X_train,
    y_train,
    batch_size=30,
    nb_epoch=2,
    validation_data=(X_test, y_test),
)

Train on 283060 samples, validate on 94354 samples
Epoch 1/2
Epoch 2/2


In [6]:
# Save model
# Write the trained network to a single HDF5 file in the working directory.
model.save(filepath="model.h5")

In [7]:
# Load model
# Drop the in-memory network first, so any later use of `model` is served by
# the copy read back from disk rather than by the object that was trained.
del model
model = load_model(filepath='model.h5')

In [13]:
# Predictions for the first four held-out samples (bare expression so the
# notebook displays the resulting array).
model.predict(X_test[:4])

array([[ -2.09658942e-03,  -5.18888654e-03,  -6.89221197e-04,
         -2.41895061e-04,   2.99720210e-03,   2.02476326e-03,
          4.73593734e-03,  -6.40691305e-03,  -1.09251449e-03,
          1.09982828e-03],
       [ -2.98991171e-03,  -9.90920162e-05,  -3.45434924e-03,
          3.88612063e-03,   4.43011941e-03,   5.08217420e-03,
         -3.28706391e-03,   2.92978319e-03,  -1.98385445e-03,
          3.78746423e-03],
       [  2.33648135e-03,   4.53853514e-04,   1.60155934e-03,
          7.22355582e-03,  -1.31497975e-03,  -1.38273940e-03,
          3.62270721e-03,   4.86129103e-03,   4.57095943e-04,
          2.45059584e-03],
       [ -1.03534505e-04,  -6.59966492e-04,  -1.35125557e-03,
          2.11345195e-03,   7.69098056e-04,   2.98950588e-03,
          8.58568994e-04,  -8.20380519e-04,  -2.33601127e-03,
          1.67452183e-03]], dtype=float32)