## TODO
- Brett: invertible reshaping (spatial/temporal)
- Brett: refactor model code into Python library
- Josh: simple PCA clustering

In [None]:
import os
# Expose only CUDA device index 1 to this process. This is set before TensorFlow
# is imported below so the restriction is in place when the GPU is initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import numpy as np
%matplotlib inline
from matplotlib import pyplot as plt
import seaborn as sns
# Plot styling: seaborn "poster" context, "whitegrid" style with the axes grid disabled.
sns.set_context('poster')
sns.set_style("whitegrid", {'axes.grid': False})
import h5py
import tensorflow as tf

from keras import backend as K
# Cap this process at 40% of the GPU's memory and make Keras use the session
# created with that limit.
gpu_opts = tf.ConfigProto(gpu_options=tf.GPUOptions(per_process_gpu_memory_fraction=0.4))
K.set_session(tf.Session(config=gpu_opts))
#K.set_image_data_format('channels_last')

In [None]:
# Looks for files in current directory
folder_path = './data/'

# (User) Loads all the data, this file needs to be in the path defined by folder_path.
# The file is opened exactly once and explicitly read-only. The handle `f` is left
# open on purpose: datasets are read from it after this point.
f = h5py.File(os.path.join(folder_path, 'cleaned_data.mat'), 'r')

# List the names of the top-level entries stored in the file.
for name in f:
    print(name)  # Name

In [None]:
%%time
# Read the entire 'filt_AI_mat' dataset from the open HDF5 file into memory.
loop_data = f['filt_AI_mat'][()]

In [None]:
# Display the array dimensions; the trailing note records the axis order.
loop_data.shape  # (n_T, n_x, n_y)

In [None]:
# Plot all values along the first axis at index (1, 1) of the two remaining axes.
plt.plot(loop_data[:, 1, 1])

In [None]:
# Collapse the two trailing axes of loop_data into one and move that axis to the
# front: one row per flattened trailing index, one column per first-axis step.
X = np.rollaxis(loop_data.reshape(loop_data.shape[0], -1), 1)
# Regroup into rows of 256 values.
# NOTE(review): 256 is presumably the per-sample sequence length - confirm.
X = X.reshape((-1, 256))
# Standardise with a single global mean and standard deviation. This is done out of
# place on purpose: the reshapes above can return views of loop_data, so in-place
# `-=` / `/=` could silently rescale the raw array (and would raise for integer data).
X = X - np.mean(X)
X = X / np.std(X)
# Add a trailing singleton axis: (n_samples, 256) -> (n_samples, 256, 1).
X = np.atleast_3d(X)
plt.plot(X[0, :, 0])

In [None]:
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten, RepeatVector, Reshape, GRU, LSTM, TimeDistributed

def dense_auto(size, n_step, drop_frac=0., activation='relu', **kwargs):
    """Build a fully connected sequence-to-sequence model.

    The model takes inputs of shape ``(n_step, 1)``, flattens them, passes them
    through three hidden ``Dense`` layers of width ``size`` and maps back to
    ``n_step`` linear outputs, reshaped to ``(n_step, 1)``.

    Parameters
    ----------
    size : int
        Number of units in each hidden ``Dense`` layer.
    n_step : int
        Length of the input (and output) sequence.
    drop_frac : float, optional
        Dropout rate applied after the second hidden layer; no ``Dropout``
        layer is added when this is not greater than zero.
    activation : str, optional
        Activation used by the hidden layers.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    keras.models.Sequential
        The assembled, uncompiled model.
    """
    model = Sequential()
    # The input shape belongs on the first layer of a Sequential model, and it must
    # match what the model is fed: (n_step, 1), the same shape the final Reshape emits.
    model.add(Flatten(input_shape=(n_step, 1)))
    model.add(Dense(size, activation=activation))
    model.add(Dense(size, activation=activation))
    if drop_frac > 0.:
        model.add(Dropout(drop_frac))
    model.add(Dense(size, activation=activation))
    model.add(Dense(n_step, activation='linear'))
    model.add(Reshape((n_step, 1)))

    return model

def rnn_auto(size, n_step, drop_frac=0., embedding=None, **kwargs):
    """Build a GRU encoder/decoder model for ``(n_step, 1)`` sequences.

    Two stacked GRU layers reduce the input sequence to a single vector whose
    width is ``embedding`` (or ``size`` when ``embedding`` is falsy). That
    vector is repeated ``n_step`` times, run through another GRU of width
    ``size``, and projected to one linear output per time step.

    Parameters
    ----------
    size : int
        Units in the first encoder GRU and in the decoder GRU.
    n_step : int
        Length of the input (and output) sequence.
    drop_frac : float, optional
        Dropout rate applied to the encoded vector; no ``Dropout`` layer is
        added when this is not greater than zero.
    embedding : int or None, optional
        Width of the encoded vector; falls back to ``size``.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    keras.models.Sequential
        The assembled, uncompiled model with output shape ``(n_step, 1)``.
    """
    code_width = embedding or size

    # Encoder: sequence in, one vector out.
    stack = [
        GRU(size, return_sequences=True, input_shape=(n_step, 1)),
        GRU(code_width, return_sequences=False),
    ]
    if drop_frac > 0.:
        stack.append(Dropout(drop_frac))

    # Decoder: repeat the vector for every step and unroll it back into a sequence.
    stack.append(RepeatVector(n_step))
    stack.append(GRU(size, return_sequences=True))
    stack.append(TimeDistributed(Dense(1, activation='linear')))

    model = Sequential()
    for layer in stack:
        model.add(layer)
    return model

In [None]:
import shutil
from keras.optimizers import Adam, SGD
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras_tqdm import TQDMNotebookCallback

# Train on every row of X; no validation split is active.
train = np.arange(len(X))

# Hyperparameters for this run.
model_fun = rnn_auto
size = 96
embedding = 8
drop_frac = 0.25
lr = 1e-3
batch_size = 1024

# Run name encodes the settings; 'e-' in the learning rate becomes 'm' (1e-03 -> 1m03).
run = (f"{model_fun.__name__}{size:03d}_emb{embedding:03d}_{lr:1.0e}_drop{int(100 * drop_frac)}"
       f"_batch{batch_size}").replace('e-', 'm')

# Start from an empty log directory so earlier logs of the same run name are discarded.
log_dir = os.path.join('log', run)
weights_path = os.path.join(log_dir, 'weights.h5')
print("Logging to {}".format(os.path.abspath(log_dir)))
shutil.rmtree(log_dir, ignore_errors=True)

model = model_fun(size, n_step=X.shape[1], drop_frac=drop_frac, embedding=embedding)
model.compile(Adam(lr), loss='mse')

# Progress bars, TensorBoard logging (graph omitted) and a weights checkpoint.
fit_callbacks = [
    TQDMNotebookCallback(leave_outer=True, leave_inner=False),
    TensorBoard(log_dir=log_dir, write_graph=False),
    ModelCheckpoint(weights_path),
]

# The model is trained to reproduce its own input.
history = model.fit(
    X[train], X[train],
    epochs=25,
    batch_size=batch_size,
    callbacks=fit_callbacks,
    verbose=False,
)

In [None]:
# Plot the training loss values recorded by fit(), as circle markers.
plt.plot(history.history['loss'], 'o')

In [None]:
i = 10

# Run the model once for sample i and reuse the prediction for both plots and the error.
sample = X[i]
reconstruction = model.predict(X[[i]])[0]
residual = sample - reconstruction

plt.plot(sample)
plt.plot(reconstruction)
plt.plot(residual)
# Mean squared reconstruction error for this sample (shown as the cell output).
np.mean(residual ** 2)

In [None]:
%%time
from sklearn.decomposition import PCA

# 16-component PCA.
pca_model = PCA(16)
# scikit-learn expects a 2-D (n_samples, n_features) array, while X carries a
# trailing singleton axis (added by np.atleast_3d); drop that axis before fitting.
pca_model.fit(X[train, :, 0])

In [None]:
i = 0

# Work on the 1-D signal: X has a trailing singleton axis, PCA takes 2-D
# (n_samples, n_features) input, and subtracting an (n, 1) array from an (n,)
# reconstruction would broadcast to (n, n) and give a wrong error value.
signal = X[i, :, 0]
reconstruction = pca_model.inverse_transform(pca_model.transform(signal[np.newaxis]))[0]

plt.plot(signal)
plt.plot(reconstruction)
# Mean squared reconstruction error for this sample (shown as the cell output).
np.mean((signal - reconstruction) ** 2)