# Train autoencoder

In [None]:
from tensorflow import keras
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
%matplotlib inline
from tensorflow.keras.layers import Input,Dense
from tensorflow.keras import optimizers
from neuralgregnet import tools
from neuralgregnet.training import training
from neuralgregnet.training import load_model
cmap = plt.get_cmap("Paired")
from tensorflow.keras import backend as K
import plotly.graph_objs as go
import plotly.offline as py
import commonFunctions as cf

In [None]:
# Load the gap-gene expression data and preview a single embryo.
genes = ['Gt', 'Kni', 'Kr', 'Hb']

# Keep only samples 35..964 of the 1000-point AP grid; the margins are
# dropped to avoid all the NaN values.
cutNaN = 35
offNaN = 965
positions = np.linspace(0, 100, 1000)[cutNaN:offNaN]

# NOTE(review): presumably SortData is indexed (embryo, gene, position) and
# sortAge holds the matching ages -- confirm against cf.loadGG.
SortData, sortAge = cf.loadGG(
    f="gap_data_raw_dorsal_wt_time_series.mat",
    path="DataPetkova/Data/Gap/",
    positions=positions,
    ageSort=False,
    cutPos=False,
)

# Preview one embryo (fixed index 1, not a random draw): one curve per gene.
embryo = SortData[1]
for i, g in enumerate(genes):
    plt.plot(positions, embryo[i, :], color=cmap(i / len(genes)), label=g)
plt.xlabel("Positions (% of AP axis)")
plt.ylabel("Concentration")
plt.legend()


In [None]:
# Build a 4-2-4 autoencoder, split the data into train/test sets, train the
# network on the training samples and plot the training loss per epoch.
nb_epochs = 1
batch_size = 64
learning_rate = 0.01  # single source for both compile() and training()

# Architecture 4 -> 2 -> 4: a 4-value input, a 2-unit sigmoid bottleneck and
# a 4-unit sigmoid reconstruction layer.
layers = [keras.layers.Flatten(input_shape=(4,))]
layers += [keras.layers.Dense(2, activation="sigmoid")]
layers += [keras.layers.Dense(4, activation="sigmoid")]
autoMut = keras.Sequential(layers)
# Use `learning_rate`: the old `lr` alias is deprecated in tf.keras and is
# rejected by recent Keras releases.
autoMut.compile(optimizer=optimizers.Adam(learning_rate=learning_rate), loss="mse")

# Separate train and test data: first 20% of the shuffled array is the test
# set, the remaining 80% is the training set.
# NOTE(review): this shuffles SortData in place along its first axis, so it
# presumably no longer lines up with sortAge afterwards -- confirm nothing
# downstream relies on that pairing.
np.random.shuffle(SortData)
n_test = int(SortData.shape[0] * 0.2)
datMTest = SortData[:n_test, :, :]
datMTr = SortData[n_test:, :, :]

# Swap the last two axes and merge the first two, so that every row of
# train_data has shape[2] values (expected to be the 4 gene levels that the
# network takes as input).
train_data = np.swapaxes(datMTr, 1, 2)
shape = train_data.shape
train_data = train_data.reshape(shape[0] * shape[1], shape[2])
np.random.shuffle(train_data)
print(train_data.shape)

# Train network: input and target are the same array (autoencoder).
aeM = training(autoMut, train_data, train_data, model_name="newWTtest", clear=True,
               nb_epochs=nb_epochs, learning_rate=learning_rate,
               batch_size=batch_size, loss="mse")

# Plot the recorded training loss. The x-axis is sized from the recorded
# history so the plot still works if fewer than nb_epochs epochs were run.
loss = aeM.history.history['loss']
epochsR = range(len(loss))
plt.figure()
plt.plot(epochsR, loss, 'bo', label='Training loss')
plt.title('Training loss')  # only the training loss is plotted here
plt.legend()

In [None]:
# Visualize the autoencoder: overlay one held-out embryo with its
# reconstruction, report the test-set loss, then draw the network for a
# single input sample.

# datMTest is a slice of the shuffled SortData, so this index selects an
# arbitrary held-out embryo rather than a specific time point.
index = 15
embryo = datMTest[index]
# Transposed so that each row fed to the network is one position's values.
prediction = aeM.predict(embryo.T)

plt.figure()
for i, g in enumerate(genes):
    colour = cmap(i/4)
    # Solid line: measured profile; dashed line: autoencoder output.
    plt.plot(positions, embryo[i], color=colour, label=g)
    plt.plot(positions, prediction[:, i], color=colour, linestyle="--", label="pred"+g)

plt.legend()
plt.xlabel("Position (% of AP axis)", fontsize=15)
plt.ylabel("Concentration (au)", fontsize=15)

# Flatten the test split the same way as the training data (swap the last
# two axes, merge the first two) and evaluate reconstruction loss on it.
test_data = np.swapaxes(datMTest, 1, 2)
shape = test_data.shape
n_rows = shape[0] * shape[1]
test_data = test_data.reshape(n_rows, shape[2])
eva = aeM.evaluate(x=test_data, y=test_data, verbose=1)
print("mse:%f" % eva)

# Single sample (embryo 13, position index 371) reshaped to a (1, 4) batch
# and passed to tools.plot_model together with the trained model.
inp = datMTest[13, :, 371].reshape((1, 4))
data, layout = tools.plot_model(aeM, inp)
fig = go.Figure(data=data, layout=layout)

py.iplot(fig)