In [1]:
"""
Created on Wed Feb 26 17:01:56 2020

"""

""" IMPORTS """
import sys
import numpy as np
sys.path.append("../")
np.random.seed(1337)  # for reproducibility
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import *
from tensorflow.keras.layers import Flatten
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import *
from tensorflow.keras.datasets import mnist
from tensorflow.keras.constraints import *
from sklearn.model_selection import train_test_split
# from keras.utils import np_utils

from binary_ops import binary_tanh as binary_tanh_op
from binary_layers import BinaryDense, BinaryConv2D

import h5py
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
from lambda_layers import *
from binary_ops import *
%matplotlib qt

In [2]:
""" FUNCTIION AND VARIABLE DEFINITIONS """
def binary_tanh(x):
    """Return ``binary_tanh_op(x)``, i.e. ``binary_ops.binary_tanh`` applied to ``x``.

    NOTE(review): the import cell also does ``from binary_ops import *``, which
    may bring a ``binary_tanh`` into scope; this module-level definition
    shadows it with an equivalent wrapper.
    """
    activated = binary_tanh_op(x)
    return activated

# Binary-layer settings (passed to BinaryConv2D as H / kernel_lr_multiplier).
H = 1.0
kernel_lr_multiplier = 'Glorot'

# Network / input geometry.
# NOTE(review): batch_size, epochs, filters, kernel_size, pool_size,
# hidden_units, classes and img_rows/img_cols are not referenced by the visible
# model/fit cells, which pass literal values instead — confirm before relying on them.
batch_size, epochs = 50, 20
channels, img_rows, img_cols = 1, 30, 30
filters, kernel_size, pool_size = 32, (32, 32), (2, 2)
hidden_units, classes = 128, 10
use_bias = False

# Learning-rate schedule: per-epoch factor such that
# lr_start * lr_decay**epochs == lr_end.
lr_start, lr_end = 1e-3, 1e-4
lr_decay = pow(lr_end / lr_start, 1.0 / epochs)

# Batch-normalisation settings.
epsilon, momentum = 1e-6, 0.9

# Dropout rates.
p1, p2 = 0.25, 0.5

# Directory that holds the *_vids.h5 video files.
hdf5_dir = Path("../../data") / "hdf5"

def read_many_hdf5(num_images):
    """Read the ``/images`` dataset from one HDF5 video file.

    Parameters
    ----------
    num_images : str
        File label; the file opened is ``hdf5_dir / f"{num_images}_vids.h5"``.
        Despite the historical name, callers pass a label such as ``"jelly"``,
        not a count.

    Returns
    -------
    numpy.ndarray
        The whole ``/images`` dataset cast to ``float32``, in whatever shape it
        was stored (the loading cell reshapes it afterwards).
    """
    # Read-only mode: this function never writes, and "r+" would fail on
    # read-only files. The context manager closes the handle deterministically
    # (it was previously left open).
    with h5py.File(hdf5_dir / f"{num_images}_vids.h5", "r") as h5_file:
        images = np.array(h5_file["/images"]).astype("float32")
    return images

def np_streak(x):
    """Shear each video along the row axis and sum over time (NumPy version).

    Frame ``t`` of every video is shifted down by ``t`` rows and all shifted
    frames are added together::

        out[n, t + h, w, c] = sum over t of x[n, t, h, w, c]

    Parameters
    ----------
    x : array_like, shape (N, T, H, W, C)

    Returns
    -------
    numpy.ndarray, shape (N, T + H, W, C), dtype float64
    """
    x = np.asarray(x)
    n_batch, n_frames, height, width, n_chan = x.shape
    out = np.zeros((n_batch, n_frames + height, width, n_chan))
    for t in range(n_frames):
        # Each frame occupies `height` rows starting at row t. The previous code
        # used the width here, which only worked for square frames, and it
        # materialised an (N, T, T+H, W, C) intermediate just to sum it away.
        out[:, t:t + height] += x[:, t]
    return out

def mask(val, ims, mask):
    """Store ``ims[i, j] * mask`` into ``val[i, j]`` for every index pair of val's first two axes.

    ``mask`` is multiplied elementwise (with NumPy broadcasting) onto each
    ``ims[i, j, :, :]`` slice. ``val`` is modified in place and also returned.
    """
    n_outer, n_inner = np.shape(val)[:2]
    for i, j in np.ndindex(n_outer, n_inner):
        val[i, j, :, :] = ims[i, j, :, :] * mask
    return val



# Load both jelly recordings and stack them into one array of clips with
# shape (N, 30, 32, 32, 1).
jelly1 = read_many_hdf5("jelly")
jelly2 = read_many_hdf5("jelly_2")
ims = np.concatenate((jelly1, jelly2)).reshape(-1, 30, 32, 32, 1)

# Scale pixel values by 1/255 (presumably 8-bit source data -> [0, 1]).
ims = ims / 255
# NOTE(review): `validate` is an independent copy of the scaled clips; the
# visible cells never use it (the split below uses `ims` for both x and y).
validate = ims.copy()

# Inputs and targets are the same clips: every model here learns to
# reconstruct its own input.
MX_train, MX_test, My_train, My_test = train_test_split(
    ims, ims, test_size=0.3, random_state=42)

print(MX_test.shape)
print(MX_train.shape)

# Shared callbacks, passed to both the forward-model and the CUPNET fit calls.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=25,
                              min_lr=1e-6, verbose=1)
early_stopping = EarlyStopping(patience=50, restore_best_weights=True, verbose=1)


def custom_loss(y_true, y_pred):
    """Equal-weight blend of an SSIM-based loss and the mean squared error.

    Returns ``0.5 * (1 - SSIM) / 2 + 0.5 * mean((y_pred - y_true) ** 2)``,
    where SSIM comes from ``tf.image.ssim`` with ``max_val=1`` (so inputs are
    assumed to be scaled to [0, 1]). The SSIM term keeps the per-image shape
    returned by ``tf.image.ssim``; the MSE term is a scalar mean over every
    element and is broadcast onto it.
    """
    structural = (1.0 - tf.image.ssim(y_true, y_pred, 1)) / 2.0
    pixelwise = K.mean(K.square(y_pred - y_true))
    return 0.5 * structural + 0.5 * pixelwise

def ssim_loss(y_true, y_pred):
    """Map SSIM (``tf.image.ssim`` with ``max_val=1``) to a loss: ``(1 - SSIM) / 2``.

    Identical images give 0. The result has the per-image shape returned by
    ``tf.image.ssim``.
    """
    similarity = tf.image.ssim(y_true, y_pred, max_val=1)
    return (1.0 - similarity) / 2.0

def mse_loss(y_true, y_pred):
    """Mean squared difference between prediction and target, averaged over every element (scalar)."""
    residual = y_pred - y_true
    return K.mean(K.square(residual))

(381, 30, 32, 32, 1)
(888, 30, 32, 32, 1)


In [3]:
""" VIDEO FUNCTIONS FOR CHECKING POST TRAINING"""

def get_mask(model, l=0, save=False, filename="mask"):
    """Display the binarized first weight of ``model.layers[l]`` as a 32x32 image.

    The weight is passed through ``binarize`` (star-imported, presumably from
    ``binary_ops``), reshaped to 32x32 and drawn in a new grayscale figure.
    When ``save`` is true, the same 32x32 array is written with
    ``np.save(filename, ...)``. Returns None.
    """
    binary_kernel = np.reshape(binarize(model.layers[l].weights[0]), (32, 32))
    _, ax_mask = plt.subplots(1, 1)
    ax_mask.imshow(binary_kernel, cmap="gray")
    if save:
        np.save(filename, binary_kernel)
                            
def show_video(y_pred, y_true, num, save=False, filename='test.png'):
    """Show the predicted and ground-truth frames of clip ``num`` side by side.

    Layout is a 5 x 13 grid.
    - Columns 0-5 hold the 30 predicted frames in row-major order
      (frame = 6 * row + col).
    - Column 6 is an empty separator.
    - Columns 7-12 hold the matching ground-truth frames in the same order.

    The title reports ``mse_loss`` and ``ssim_loss`` for that clip, averaged
    with ``np.mean``.

    Parameters
    ----------
    y_pred, y_true : array_like
        Batches of clips; each is reshaped to (-1, 30, 32, 32) for display.
    num : int
        Index of the clip to show.
    save : bool
        If true, write the figure to ``filename``.
    filename : str
        Output path used when ``save`` is true (defaults to the previously
        hard-coded ``'test.png'``).
    """
    frames_per_row = 6
    pred_frames = np.reshape(y_pred, (-1, 30, 32, 32))[num]
    true_frames = np.reshape(y_true, (-1, 30, 32, 32))[num]

    # Only the displayed clip is needed for the metrics; don't convert the batch.
    yp_clip = tf.convert_to_tensor(np.asarray(y_pred)[num])
    yt_clip = tf.convert_to_tensor(np.asarray(y_true)[num])
    ssim = np.mean(ssim_loss(yt_clip, yp_clip))
    mse = np.mean(mse_loss(yt_clip, yp_clip))

    fig, ax = plt.subplots(nrows=5, ncols=13, figsize=(21, 10), sharex=True, sharey=True)
    # The value is ssim_loss = (1 - SSIM) / 2, so label it as a loss.
    fig.suptitle(f'Movie: {num} MSE: {mse:.3} SSIM loss: {ssim:.3}')
    for row in range(5):
        for col in range(frames_per_row):
            # Previously 5*row+col (with col % 6 on the truth side): frames
            # repeated across rows, 26-29 were never shown, and truth columns
            # were shifted relative to the predictions.
            frame = frames_per_row * row + col
            ax[row, col].imshow(pred_frames[frame], cmap="gray")
            ax[row, col + frames_per_row + 1].imshow(true_frames[frame], cmap="gray")
        ax[row, frames_per_row].axis('off')
    if save:
        fig.savefig(filename)

def show_all_videos(videos, rows, cols, frame=3):
    """Tile one frame from each of the first ``rows * cols`` clips in a grid.

    Clip ``cols * row + col`` is drawn in grid cell (row, col) (row-major order).

    Parameters
    ----------
    videos : array_like
        Clips, reshaped to (-1, 30, 32, 32); must contain at least
        ``rows * cols`` clips.
    rows, cols : int
        Grid size.
    frame : int
        Frame index shown for every clip (default 3, as before).
    """
    clips = np.reshape(videos, (-1, 30, 32, 32))
    # squeeze=False keeps a 2-D Axes array even when rows or cols is 1.
    _, axes = plt.subplots(nrows=rows, ncols=cols, squeeze=False)
    for row in range(rows):
        for col in range(cols):
            # Row-major index; the old rows*row+col overlapped/skipped clips
            # whenever rows != cols.
            axes[row, col].imshow(clips[cols * row + col][frame], cmap="gray")

In [4]:
""" FORWARD MODEL """

# Forward model:
# - A BinaryConv2D (1 filter, 32x32 kernel) applied frame by frame through TimeDistributed.
# - The `streak` and `integrate_ims` Lambda layers (defined in lambda_layers;
#   per the printed summary they map (30, 32, 32) -> (30, 62, 32) -> (62, 32)).
# - A Dense layer (30720 = 30*32*32 units) reshaped back to a (30, 32, 32, 1) clip.
# NOTE(review): later cells read layers[0] (binary conv) and layers[5] (Dense)
# by position, so keep the layer order unchanged.
forward_model = Sequential()
forward_model.add(Input(shape=(30,32,32,1)))
# NOTE(review): input_shape on the inner layer looks redundant given the Input layer above.
forward_model.add(TimeDistributed(BinaryConv2D(1, kernel_size=(32,32), input_shape=(30,32,32,1),
                   data_format='channels_last',
                   H=H, kernel_lr_multiplier=kernel_lr_multiplier,
                   padding='same', use_bias=use_bias, name='bin_conv_1')))
forward_model.add(Reshape((30,32,32)))
forward_model.add(Lambda(streak,output_shape=streak_output_shape))
forward_model.add(Lambda(integrate_ims, output_shape = integrate_ims_output_shape))
forward_model.add(Flatten())
forward_model.add(Dense(30720, activation = 'relu'))
forward_model.add(Reshape((30,32,32,1)))
forward_model.compile(optimizer = Nadam(0.0001), loss = custom_loss, metrics = [ssim_loss,mse_loss])
forward_model.summary()
# Train to reproduce the input clips (x and y come from the same train_test_split of `ims`).
forward_history = forward_model.fit(MX_train, My_train,
      batch_size = 32,epochs= 100,
    verbose=2,validation_data=(MX_test,My_test),callbacks=[reduce_lr, early_stopping])


    

(None, 30, 62, 32)
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
time_distributed (TimeDistri (None, 30, 32, 32, 1)     1024      
_________________________________________________________________
reshape (Reshape)            (None, 30, 32, 32)        0         
_________________________________________________________________
lambda (Lambda)              (None, 30, 62, 32)        0         
_________________________________________________________________
lambda_1 (Lambda)            (None, 62, 32)            0         
_________________________________________________________________
flatten (Flatten)            (None, 1984)              0         
_________________________________________________________________
dense (Dense)                (None, 30720)             60979200  
_________________________________________________________________
reshape_1 (Reshape)          (None, 3

Epoch 48/100
888/888 - 5s - loss: 0.2860 - ssim_loss: 0.4813 - mse_loss: 0.0907 - val_loss: 0.2855 - val_ssim_loss: 0.4830 - val_mse_loss: 0.0878
Epoch 49/100
888/888 - 5s - loss: 0.2860 - ssim_loss: 0.4814 - mse_loss: 0.0906 - val_loss: 0.2852 - val_ssim_loss: 0.4827 - val_mse_loss: 0.0877
Epoch 50/100
888/888 - 5s - loss: 0.2859 - ssim_loss: 0.4812 - mse_loss: 0.0904 - val_loss: 0.2846 - val_ssim_loss: 0.4815 - val_mse_loss: 0.0877
Epoch 51/100
888/888 - 5s - loss: 0.2858 - ssim_loss: 0.4811 - mse_loss: 0.0906 - val_loss: 0.2850 - val_ssim_loss: 0.4826 - val_mse_loss: 0.0873
Epoch 52/100

Epoch 00052: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-05.
Restoring model weights from the end of the best epoch.
888/888 - 5s - loss: 0.2858 - ssim_loss: 0.4811 - mse_loss: 0.0904 - val_loss: 0.2847 - val_ssim_loss: 0.4819 - val_mse_loss: 0.0875
Epoch 00052: early stopping


In [5]:
# Reconstructions of the held-out clips by the trained forward model; inspect clip 200.
y_pred_forward = forward_model.predict(MX_test)
show_video(y_pred_forward, MX_test, num=200)

In [30]:
forward_history  # History returned by forward_model.fit; per-epoch metrics live in .history

<tensorflow.python.keras.callbacks.History at 0x2177c72ad08>

In [6]:
# Save the trained forward model, then keep copies of the binary conv kernel
# (layers[0]) and the Dense layer weights (layers[5]) so they can be copied
# into CUPNET later. Finally display the binarized kernel.
forward_model.save_weights("../../data/model_stuff/forward_jelly_weights_4_23.h5")
binary_weights, inverse_weights = (forward_model.layers[idx].get_weights() for idx in (0, 5))
get_mask(forward_model, l=0)

In [4]:
""" 
UNET MODEL
Fixing the weights for the bin_conv1 layer as well as the dense1 layer, ie NON TRAINABLE
Feeding in weights from the forward_model above to see if that improves the results from previous session

"""
def _td_conv(x, filters, kernel_size=(3, 3)):
    """Frame-wise (TimeDistributed) Conv2D with ReLU, he_normal init and 'same' padding."""
    return TimeDistributed(Conv2D(filters, kernel_size, activation='relu',
                                  kernel_initializer='he_normal', padding='same'))(x)


def _td_up(x, filters):
    """Frame-wise (TimeDistributed) 3x3 Conv2DTranspose with stride 2 and 'same' padding."""
    return TimeDistributed(Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding='same'))(x)


# --- Frozen forward-model front end (trainable=False); its weights are copied
# from forward_model in a later cell via CUPNET.layers[1] and CUPNET.layers[6].
# Layer creation order is unchanged so those positional indices stay valid.
inputs = Input(shape=(30, 32, 32, 1))
bin_conv1 = TimeDistributed(BinaryConv2D(1, kernel_size=(32, 32), input_shape=(30, 32, 32, 1),
                                         data_format='channels_last',
                                         H=H, kernel_lr_multiplier=kernel_lr_multiplier,
                                         padding='same', use_bias=use_bias, name='bin_conv_1',
                                         trainable=False))(inputs)
resh1 = Reshape((30, 32, 32))(bin_conv1)
s = Lambda(streak, output_shape=streak_output_shape)(resh1)
i = Lambda(integrate_ims, output_shape=integrate_ims_output_shape)(s)
f = Flatten()(i)
dense1 = Dense(30720, activation='relu', trainable=False)(f)
resh2 = Reshape((30, 32, 32, 1))(dense1)

# --- U-Net encoder ---
# NOTE(review): p1/p2 below overwrite the dropout-rate constants p1/p2 defined
# in the configuration cell.
c1 = _td_conv(resh2, 1)
c1 = Dropout(0.1)(c1)
c1 = _td_conv(c1, 16)
p1 = TimeDistributed(MaxPooling2D((2, 2)))(c1)

c2 = _td_conv(p1, 32)
c2 = Dropout(0.1)(c2)
c2 = _td_conv(c2, 32)
p2 = TimeDistributed(MaxPooling2D((2, 2)))(c2)

c3 = _td_conv(p2, 64)
c3 = Dropout(0.2)(c3)
c3 = _td_conv(c3, 64)
p3 = TimeDistributed(MaxPooling2D((2, 2)))(c3)

c4 = _td_conv(p3, 128)
c4 = Dropout(0.2)(c4)
c4 = _td_conv(c4, 128)
p4 = TimeDistributed(MaxPooling2D(pool_size=(2, 2)))(c4)

c5 = _td_conv(p4, 256, (2, 2))
c5 = Dropout(0.3)(c5)
c5 = _td_conv(c5, 256, (2, 2))

# --- U-Net decoder with skip connections ---
u6 = _td_up(c5, 128)
u6 = concatenate([u6, c4])
c6 = _td_conv(u6, 128)
c6 = Dropout(0.2)(c6)
c6 = _td_conv(c6, 128)

u7 = _td_up(c6, 64)
u7 = concatenate([u7, c3])
c7 = _td_conv(u7, 64)
c7 = Dropout(0.2)(c7)
c7 = _td_conv(c7, 64)

u8 = _td_up(c7, 32)
u8 = concatenate([u8, c2])
c8 = _td_conv(u8, 32)
c8 = Dropout(0.1)(c8)
c8 = _td_conv(c8, 32)

# NOTE(review): unlike u6-u8, u9 is not concatenated with c1, so the top level
# has no skip connection. Adding one would change c9's input channels and make
# the saved cupnet_jelly weights incompatible, so it is left as-is.
u9 = _td_up(c8, 16)
c9 = _td_conv(u9, 16)
c9 = Dropout(0.1)(c9)
c9 = _td_conv(c9, 16)

outputs = TimeDistributed(Conv2D(1, (1, 1), activation='sigmoid'))(c9)

CUPNET = Model(inputs=[inputs], outputs=[outputs])

# Callbacks belong to fit() (the training cell passes them there). Model.compile
# has no `callbacks` parameter, and TF >= 2.2 raises TypeError for unknown
# compile keywords.
CUPNET.compile(optimizer=Nadam(), loss=custom_loss, metrics=['mean_squared_error'])
CUPNET.summary()

(None, 30, 62, 32)
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 30, 32, 32,  0                                            
__________________________________________________________________________________________________
time_distributed (TimeDistribut (None, 30, 32, 32, 1 1024        input_1[0][0]                    
__________________________________________________________________________________________________
reshape (Reshape)               (None, 30, 32, 32)   0           time_distributed[0][0]           
__________________________________________________________________________________________________
lambda (Lambda)                 (None, 30, 62, 32)   0           reshape[0][0]                    
___________________________________________________________________________

In [5]:
# Restore previously trained CUPNET weights (the same file the training section saves to).
CUPNET.load_weights("../../data/model_stuff/cupnet_jelly_weights_4_23.h5")

In [10]:
# Binarize the kernel of CUPNET.layers[1] (the TimeDistributed BinaryConv2D),
# flatten it to a 32x32 array and save it as cupnet_jelly_mask.npy.
b = np.reshape(binarize(tf.convert_to_tensor(CUPNET.layers[1].get_weights())), (32, 32))
np.save('cupnet_jelly_mask', b)

In [20]:
# Copy the trained forward-model pieces into CUPNET's frozen front end:
# layers[1] is the binary conv, layers[6] the Dense layer (both built with
# trainable=False).
for layer_index, layer_weights in ((1, binary_weights), (6, inverse_weights)):
    CUPNET.layers[layer_index].set_weights(layer_weights)
CUPNET.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 30, 32, 32,  0                                            
__________________________________________________________________________________________________
time_distributed_29 (TimeDistri (None, 30, 32, 32, 1 1024        input_3[0][0]                    
__________________________________________________________________________________________________
reshape_4 (Reshape)             (None, 30, 32, 32)   0           time_distributed_29[0][0]        
__________________________________________________________________________________________________
lambda_4 (Lambda)               (None, 30, 62, 32)   0           reshape_4[0][0]                  
____________________________________________________________________________________________

In [21]:
""" CUPNET TRAINING """
# If training is interrupted from the kernel, still bind CUPNET_history to the
# History of the completed epochs before re-raising. Otherwise the name is
# unbound (fresh kernel) or stale from an earlier run, and the history/plot
# cells below show the wrong curves.
try:
    CUPNET_history = CUPNET.fit(MX_train, My_train,
                                batch_size=32, epochs=300,
                                verbose=2, validation_data=(MX_test, My_test),
                                callbacks=[reduce_lr, early_stopping])
except KeyboardInterrupt:
    # Keras keeps the active History callback on the model as `model.history`.
    # NOTE(review): on newer TF this may be None if no epoch finished.
    CUPNET_history = CUPNET.history
    raise

Train on 888 samples, validate on 381 samples
Epoch 1/300
888/888 - 22s - loss: 0.1172 - mean_squared_error: 0.0394 - val_loss: 0.1034 - val_mean_squared_error: 0.0371
Epoch 2/300
888/888 - 13s - loss: 0.0956 - mean_squared_error: 0.0292 - val_loss: 0.0834 - val_mean_squared_error: 0.0196
Epoch 3/300
888/888 - 13s - loss: 0.0821 - mean_squared_error: 0.0157 - val_loss: 0.0746 - val_mean_squared_error: 0.0136
Epoch 4/300
888/888 - 13s - loss: 0.0776 - mean_squared_error: 0.0137 - val_loss: 0.0733 - val_mean_squared_error: 0.0125
Epoch 5/300
888/888 - 13s - loss: 0.0745 - mean_squared_error: 0.0129 - val_loss: 0.0688 - val_mean_squared_error: 0.0111
Epoch 6/300
888/888 - 13s - loss: 0.0721 - mean_squared_error: 0.0120 - val_loss: 0.0679 - val_mean_squared_error: 0.0105
Epoch 7/300
888/888 - 13s - loss: 0.0702 - mean_squared_error: 0.0109 - val_loss: 0.0664 - val_mean_squared_error: 0.0096
Epoch 8/300
888/888 - 13s - loss: 0.0684 - mean_squared_error: 0.0096 - val_loss: 0.0663 - val_mean_

Epoch 68/300
888/888 - 13s - loss: 0.0529 - mean_squared_error: 0.0037 - val_loss: 0.0534 - val_mean_squared_error: 0.0036
Epoch 69/300
888/888 - 13s - loss: 0.0530 - mean_squared_error: 0.0037 - val_loss: 0.0543 - val_mean_squared_error: 0.0037
Epoch 70/300
888/888 - 13s - loss: 0.0529 - mean_squared_error: 0.0037 - val_loss: 0.0529 - val_mean_squared_error: 0.0036
Epoch 71/300
888/888 - 13s - loss: 0.0525 - mean_squared_error: 0.0036 - val_loss: 0.0528 - val_mean_squared_error: 0.0036
Epoch 72/300
888/888 - 13s - loss: 0.0525 - mean_squared_error: 0.0036 - val_loss: 0.0527 - val_mean_squared_error: 0.0035
Epoch 73/300
888/888 - 13s - loss: 0.0525 - mean_squared_error: 0.0036 - val_loss: 0.0540 - val_mean_squared_error: 0.0038
Epoch 74/300
888/888 - 13s - loss: 0.0521 - mean_squared_error: 0.0036 - val_loss: 0.0535 - val_mean_squared_error: 0.0035
Epoch 75/300
888/888 - 13s - loss: 0.0520 - mean_squared_error: 0.0036 - val_loss: 0.0527 - val_mean_squared_error: 0.0035
Epoch 76/300
888

Epoch 134/300
888/888 - 13s - loss: 0.0463 - mean_squared_error: 0.0030 - val_loss: 0.0505 - val_mean_squared_error: 0.0032
Epoch 135/300
888/888 - 13s - loss: 0.0461 - mean_squared_error: 0.0030 - val_loss: 0.0503 - val_mean_squared_error: 0.0032
Epoch 136/300
888/888 - 13s - loss: 0.0461 - mean_squared_error: 0.0029 - val_loss: 0.0505 - val_mean_squared_error: 0.0032
Epoch 137/300
888/888 - 13s - loss: 0.0460 - mean_squared_error: 0.0029 - val_loss: 0.0502 - val_mean_squared_error: 0.0032
Epoch 138/300
888/888 - 13s - loss: 0.0460 - mean_squared_error: 0.0029 - val_loss: 0.0501 - val_mean_squared_error: 0.0032
Epoch 139/300
888/888 - 13s - loss: 0.0458 - mean_squared_error: 0.0029 - val_loss: 0.0504 - val_mean_squared_error: 0.0032
Epoch 140/300
888/888 - 13s - loss: 0.0458 - mean_squared_error: 0.0029 - val_loss: 0.0506 - val_mean_squared_error: 0.0032
Epoch 141/300
888/888 - 13s - loss: 0.0457 - mean_squared_error: 0.0029 - val_loss: 0.0505 - val_mean_squared_error: 0.0032
Epoch 14

KeyboardInterrupt: 

In [22]:
# Persist CUPNET weights; this overwrites the file that CUPNET.load_weights reads earlier.
CUPNET.save_weights("../../data/model_stuff/cupnet_jelly_weights_4_23.h5")

In [23]:
# List the metrics recorded per epoch (used as keys by the plotting cell below).
print(CUPNET_history.history.keys())

dict_keys(['loss', 'mean_squared_error', 'val_loss', 'val_mean_squared_error', 'lr'])


In [27]:
# summarize history for loss
# Draw on a fresh figure: under `%matplotlib qt` earlier figures stay open, and
# a bare plt.plot() draws onto whichever axes is current (e.g. a subplot of the
# last show_video / get_mask figure).
fig_cupnet_loss, ax_cupnet_loss = plt.subplots()
ax_cupnet_loss.plot(CUPNET_history.history['loss'])
ax_cupnet_loss.plot(CUPNET_history.history['val_loss'])
ax_cupnet_loss.set_title('model loss')
ax_cupnet_loss.set_ylabel('loss')
ax_cupnet_loss.set_xlabel('epoch')
ax_cupnet_loss.legend(['train', 'test'], loc='upper left')
plt.show()

In [32]:
# Forward-model training curves on a fresh figure (a bare plt.plot() would
# draw onto whatever axes is current while earlier Qt figures are still open).
fig_forward_loss, ax_forward_loss = plt.subplots()
ax_forward_loss.plot(forward_history.history['loss'])
ax_forward_loss.plot(forward_history.history['val_loss'])
ax_forward_loss.set_title('forward_model loss')
ax_forward_loss.set_ylabel('loss')
ax_forward_loss.set_xlabel('epoch')
ax_forward_loss.legend(['train', 'test'], loc='upper left')
plt.show()

In [33]:
# CUPNET reconstructions for the training clips and the held-out clips, in that order.
y_train_pred, y_test_pred = (CUPNET.predict(clips) for clips in (MX_train, MX_test))


In [40]:
# Compare CUPNET's reconstruction of held-out clip 200 with the ground truth.
show_video(y_test_pred, MX_test, num=200)

In [13]:
def store_many_hdf5(images, name):
    """Write ``images`` to ``hdf5_dir / f"{name}_vids.h5"`` as the dataset ``"images"``.

    The target directory is created if missing. An existing file with the same
    name is overwritten (mode ``"w"``).

    Parameters
    ----------
    images : array_like
        Data to store.
    name : str
        File label; ``"_vids.h5"`` is appended to form the file name.
    """
    # Previously: os.mkdir (os is not imported explicitly in this notebook)
    # inside a bare `except: pass`, which hid any failure — including a
    # NameError — so the directory could silently never be created.
    hdf5_dir.mkdir(parents=True, exist_ok=True)
    # The context manager closes the file even if create_dataset raises.
    with h5py.File(hdf5_dir / f"{name}_vids.h5", "w") as h5_file:
        h5_file.create_dataset("images", data=images)

In [36]:
# Save held-out CUPNET predictions as ../../data/hdf5/u_net_jelly_test_ims.npy.
np.save(hdf5_dir / "u_net_jelly_test_ims", y_test_pred)

In [15]:
# Also store the held-out predictions as u_net_predict_vids.h5.
store_many_hdf5(images=y_test_pred, name="u_net_predict")

In [16]:
import scipy.io as sio

# Load the resized jelly recording from MATLAB, move its last axis to the
# front (presumably the frame axis), and keep frames 30..59 as one
# (1, 30, 32, 32, 1) clip.
# NOTE(review): the training clips were divided by 255; this clip is used
# as-is — confirm 'sample' is already in [0, 1].
twist_jelly = sio.loadmat('../../data/hdf5/jelly_resized.mat')
jelly = np.asarray(twist_jelly['sample']).transpose(2, 0, 1)
jelly = jelly[30:60].reshape(1, 30, 32, 32, 1)
print(jelly.shape)

(1, 30, 32, 32, 1)


In [17]:
# CUPNET reconstruction of the externally loaded jelly clip.
j_pred = CUPNET.predict(x=jelly)