In [1]:
"""
Created on Wed Feb 26 17:01:56 2020

"""

""" IMPORTS """
import sys
import numpy as np
np.random.seed(1337)  # for reproducibility
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import *
from tensorflow.keras.layers import Flatten
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import *
from tensorflow.keras.datasets import mnist
from tensorflow.keras.constraints import *
from sklearn.model_selection import train_test_split
# from keras.utils import np_utils

from binary_ops import binary_tanh as binary_tanh_op
from binary_layers import BinaryDense, BinaryConv2D

import h5py
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
from lambda_layers import *
from binary_ops import *
%matplotlib qt

In [2]:
""" FUNCTIION AND VARIABLE DEFINITIONS """
def binary_tanh(x):
    """Forward ``x`` to ``binary_tanh_op`` and return whatever it returns.

    ``binary_tanh_op`` is the ``binary_tanh`` function imported from
    ``binary_ops``; this wrapper only gives it a notebook-local name.
    """
    activated = binary_tanh_op(x)
    return activated

# Passed as the `H` and `kernel_lr_multiplier` arguments of BinaryConv2D in the
# model cells below; their exact meaning is defined in binary_layers.
H = 1.
kernel_lr_multiplier = 'Glorot'

# # nn
# NOTE(review): most of these constants appear unused in the visible cells --
# the model cells hard-code batch size 50, kernel size (32,32) and their own
# epoch counts. Only `epochs` (in lr_decay below) and `use_bias` are referenced.
batch_size = 50
epochs = 20
channels = 1
img_rows = 30
img_cols = 30
filters = 32
kernel_size = (32, 32)
pool_size = (2, 2)
hidden_units = 128
classes = 10
use_bias = False

# # learning rate schedule
# lr_decay is the per-epoch multiplicative factor that takes lr_start to lr_end
# after `epochs` epochs: lr_start * lr_decay**epochs == lr_end.
lr_start = 1e-3
lr_end = 1e-4
lr_decay = (lr_end / lr_start)**(1. / epochs)

# # BN
epsilon = 1e-6
momentum = 0.9

# # dropout
# NOTE(review): p1 and p2 are rebound to pooling-layer tensors in the U-Net
# cell further down, so these dropout rates are lost once that cell has run.
p1 = 0.25
p2 = 0.5

# Directory containing the "<N>_vids.h5" files read by read_many_hdf5.
hdf5_dir = Path("../data/hdf5/")

def read_many_hdf5(num_images):
    """ Reads the image stack from an HDF5 file.
        Parameters:
        ---------------
        num_images   number used to build the file name
                     (hdf5_dir / "<num_images>_vids.h5")
        Returns:
        ----------
        images      full contents of the "/images" dataset as a float32
                    numpy array; only images are returned (no labels)
    """
    # Read-only access is all that is needed here, and the context manager
    # guarantees the file handle is closed even if reading the dataset fails.
    with h5py.File(hdf5_dir / f"{num_images}_vids.h5", "r") as file:
        images = np.array(file["/images"]).astype("float32")

    return images

def np_streak(x):
    """Shear a stack of videos along the height axis and integrate over time.

    Frame ``j`` of every video is shifted down by ``j`` rows and all shifted
    frames are summed.

    Parameters
    ----------
    x : array-like of shape (N, T, H, W, C)
        N videos of T frames, each frame H x W x C.

    Returns
    -------
    numpy.ndarray of shape (N, T + H, W, C), dtype float64
        Sum over frames of the row-shifted frames. The last row is always
        zero because the largest shift is T - 1.
    """
    x = np.asarray(x)
    n_frames, height = x.shape[1], x.shape[2]
    # Accumulate straight into the output instead of materialising the full
    # (N, T, T + H, W, C) tensor and summing it afterwards.
    streak_sum = np.zeros((x.shape[0], n_frames + height) + x.shape[3:])
    for j in range(n_frames):
        # The shifted window is `height` rows long. The previous code used the
        # width here, which only worked for square frames.
        streak_sum[:, j:j + height] += x[:, j]
    return streak_sum

def mask(val,ims,mask):
    """Fill ``val`` in place with ``ims`` multiplied frame-by-frame by ``mask``.

    For every index pair (i, j) over the first two axes of ``val``, the slice
    ``ims[i, j]`` is multiplied by ``mask`` and written to ``val[i, j]``.
    ``val`` is modified in place and also returned.
    """
    dims = np.shape(val)
    for video_idx in range(dims[0]):
        for frame_idx in range(dims[1]):
            masked_frame = ims[video_idx, frame_idx, :, :] * mask
            val[video_idx, frame_idx, :, :] = masked_frame
    return val



# Load the image stack stored in "784_vids.h5" as float32.
ims = read_many_hdf5(784)
# ims = np.ones((3943,30,32,32,1))
# View the data as (video, frame, height, width, channel): 30 frames of 32x32x1.
ims = np.reshape(ims, (-1,30,32,32,1))
# Keep the first 750 videos.
# NOTE(review): the buffer below assumes the file holds at least 750 -- confirm.
ims = ims[:750]
# temp = np.zeros((1,32,32,1))
# Buffer that mask() fills in place with the masked videos.
validate2 = np.zeros((750,30,32,32,1))
# One random 0/1 pattern of shape (1,32,32,1), applied to every frame of every video.
bk_temp = np.random.randint(0,2,(1,32,32,1))
validate2 = mask(validate2,ims,bk_temp)
# Shear + time-integrate the masked videos into one streak image per video.
ims2 = np_streak(validate2)


# `validate` starts as an alias of the unscaled videos; the division below
# rebinds it to a new, scaled array.
validate = ims

# Divide everything by 255. validate2 is left unscaled.
validate = validate / 255
ims2 = ims2 /255
ims = ims/255
#X_train, X_test, y_train, y_test = train_test_split(ims, validate, test_size=(1/3), random_state=42)
# Streak image -> video split.
# NOTE(review): X_*/y_* are not referenced by the visible training cells,
# which use the MX_*/My_* split instead.
X_train, X_test, y_train, y_test = train_test_split(ims2, validate, test_size=(1/3), random_state=42)

# Video -> same video split (inputs and targets are both `ims`); this is what
# both models below are trained on.
MX_train, MX_test, My_train, My_test = train_test_split(ims,ims, test_size = 1/3, random_state = 42)

# The recorded output shows (250, 62, 32, 1) and (500, 62, 32, 1).
print(np.shape(X_test))
print(np.shape(X_train))

# Halve the learning rate when val_loss has not improved for 50 epochs,
# never going below 1e-6. This instance is passed to both fit() calls below.
reduce_lr = ReduceLROnPlateau(monitor='val_loss',verbose=1, factor=0.5,
                              patience=50, min_lr=0.000001)
# Stop after 90 epochs without improvement of the monitored quantity (left at
# its default) and restore the best weights. Only the forward-model fit uses it.
early_stopping = EarlyStopping(patience=90,verbose=1,restore_best_weights=True)   


def custom_loss(y_true, y_pred):
    """Equal-weight blend of an SSIM-based term and a mean-squared-error term.

    Returns ``0.5 * (1 - ssim) / 2 + 0.5 * mse`` where ``ssim`` is
    ``tf.image.ssim(y_true, y_pred, 1)`` and ``mse`` is the mean of the
    squared difference over all elements.
    """
    # Structural term: (1 - SSIM) / 2, then weighted by 0.5.
    structural_term = 0.5 * ((1.0 - tf.image.ssim(y_true, y_pred, 1)) / 2.0)
    # Pixel-wise term: global mean squared error, weighted by 0.5.
    pixel_term = 0.5 * K.mean(K.square(y_pred - y_true))
    return structural_term + pixel_term

def ssim_loss(y_true,y_pred):
    """Return ``(1 - SSIM) / 2`` with SSIM computed by ``tf.image.ssim`` (third argument 1)."""
    similarity = tf.image.ssim(y_true,y_pred,1)
    return (1.0-similarity)/2.0

def mse_loss(y_true,y_pred):
    """Return the mean over all elements of the squared difference ``y_pred - y_true``."""
    residual = y_pred-y_true
    return K.mean(K.square(residual))



(250, 62, 32, 1)
(500, 62, 32, 1)


In [3]:
""" VIDEO FUNCTIONS FOR CHECKING POST TRAINING"""

def get_mask(model,l=0):
    """Plot the binarized first weight tensor of ``model.layers[l]`` as a 32x32 image.

    The weights are passed through ``binarize`` and reshaped to (32, 32), so
    the layer must hold exactly 1024 values in its first weight tensor.
    Nothing is returned; the image is drawn on a new figure in grayscale.
    """
    kernel = model.layers[l].weights[0]
    binary_kernel = binarize(kernel)
    figb,axb = plt.subplots(1,1)
    mask_image = np.reshape(binary_kernel,(32,32))
    axb.imshow(mask_image,cmap="gray")
    
def show_video(y_pred,y_true,num):
    """Show all 30 frames of video ``num`` for prediction and ground truth.

    Both inputs are reshaped to (-1, 30, 32, 32). Two figures are created,
    each a 5 x 6 grid of grayscale images: the first shows ``y_pred[num]``,
    the second ``y_true[num]``. Frames are laid out row by row, so the
    top-left panel is frame 0 and the bottom-right panel is frame 29.

    Parameters
    ----------
    y_pred, y_true : array-like reshapeable to (-1, 30, 32, 32)
    num : int
        Index of the video to display.
    """
    n_rows, n_cols = 5, 6
    yp = np.reshape(y_pred,(-1,30,32,32))
    yt = np.reshape(y_true,(-1,30,32,32))
    fig,ax = plt.subplots(nrows=n_rows,ncols=n_cols,sharex=True,sharey=True)
    fig2,ax2 = plt.subplots(nrows=n_rows,ncols=n_cols,sharex=True,sharey=True)
    for row in range(n_rows):
        for col in range(n_cols):
            # Row-major frame index: the stride is the number of columns. The
            # previous stride of 5 repeated frames 5, 10, 15 and 20 and never
            # reached frames 26-29.
            frame = n_cols*row + col
            ax[row,col].imshow(yp[num][frame],cmap="gray")
            ax2[row,col].imshow(yt[num][frame],cmap="gray")

def show_all_videos(videos,rows,cols):
    """Show frame 3 of the first ``rows * cols`` videos in a grid.

    ``videos`` is reshaped to (-1, 30, 32, 32). Panel (row, col) shows frame 3
    of video ``cols * row + col``, i.e. videos are laid out row by row.

    Parameters
    ----------
    videos : array-like reshapeable to (-1, 30, 32, 32)
        Must contain at least ``rows * cols`` videos.
    rows, cols : int
        Grid dimensions.
    """
    yp = np.reshape(videos, (-1,30,32,32))
    # squeeze=False keeps the axes array two-dimensional even when rows or
    # cols is 1, so the [row, col] indexing below always works.
    fig3,ax3 = plt.subplots(nrows=rows, ncols = cols, squeeze=False)
    for row in range(rows):
        for col in range(cols):
            # Row-major video index: the stride is the number of columns (it
            # was the number of rows, which repeated/skipped videos whenever
            # rows != cols).
            ax3[row,col].imshow(yp[cols*row+col][3],cmap="gray")


In [4]:
""" FORWARD MODEL """

# Stage one: a Sequential model trained to reproduce its own input video
# (MX_* and My_* are both splits of `ims`).
forward_model = Sequential()
# Fixed batch dimension of 50; each sample is 30 frames of 32x32x1.
forward_model.add(Input(shape=(30,32,32,1),batch_size = 50))
# One shared BinaryConv2D (single filter, 32x32 kernel, no bias) applied to
# every frame; the recorded summary lists 1024 parameters for this layer.
forward_model.add(TimeDistributed(BinaryConv2D(1, kernel_size=(32,32), input_shape=(30,32,32,1),
                       data_format='channels_last',
                       H=H, kernel_lr_multiplier=kernel_lr_multiplier,
                       padding='same', use_bias=use_bias, name='bin_conv_1')))
# `streak` and `integrate_ims` come from lambda_layers. Per the recorded
# summary they map (50,30,32,32,1) -> (50,30,62,32,1) -> (50,62,32,1).
forward_model.add(Lambda(streak,output_shape=streak_output_shape))
forward_model.add(Lambda(integrate_ims, output_shape = integrate_ims_output_shape))
forward_model.add(Flatten())
# 30720 = 30*32*32*1, the element count the Reshape below needs.
forward_model.add(Dense(30720, activation = 'relu'))
forward_model.add(Reshape((30,32,32,1)))
forward_model.compile(optimizer = Nadam(0.0001), loss = custom_loss, metrics = ['mean_squared_error'])
forward_model.summary()
# Up to 250 epochs; reduce_lr and early_stopping may change the LR or stop early.
forward_history = forward_model.fit(MX_train, My_train,
          batch_size = 50,epochs= 250,
          verbose=2,validation_data=(MX_test,My_test),callbacks=[reduce_lr, early_stopping])



Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
time_distributed (TimeDistri (50, 30, 32, 32, 1)       1024      
_________________________________________________________________
lambda (Lambda)              (50, 30, 62, 32, 1)       0         
_________________________________________________________________
lambda_1 (Lambda)            (50, 62, 32, 1)           0         
_________________________________________________________________
flatten (Flatten)            (50, 1984)                0         
_________________________________________________________________
dense (Dense)                (50, 30720)               60979200  
_________________________________________________________________
reshape (Reshape)            (50, 30, 32, 32, 1)       0         
Total params: 60,980,224
Trainable params: 60,980,224
Non-trainable params: 0
____________________________________________

Epoch 59/250
500/500 - 2s - loss: 0.2779 - mean_squared_error: 0.1055 - val_loss: 0.2954 - val_mean_squared_error: 0.1132
Epoch 60/250
500/500 - 2s - loss: 0.2787 - mean_squared_error: 0.1071 - val_loss: 0.2936 - val_mean_squared_error: 0.1105
Epoch 61/250
500/500 - 2s - loss: 0.2780 - mean_squared_error: 0.1062 - val_loss: 0.2941 - val_mean_squared_error: 0.1107
Epoch 62/250
500/500 - 2s - loss: 0.2771 - mean_squared_error: 0.1050 - val_loss: 0.2943 - val_mean_squared_error: 0.1111
Epoch 63/250
500/500 - 2s - loss: 0.2766 - mean_squared_error: 0.1043 - val_loss: 0.2952 - val_mean_squared_error: 0.1130
Epoch 64/250
500/500 - 2s - loss: 0.2770 - mean_squared_error: 0.1054 - val_loss: 0.2947 - val_mean_squared_error: 0.1117
Epoch 65/250
500/500 - 2s - loss: 0.2766 - mean_squared_error: 0.1050 - val_loss: 0.2915 - val_mean_squared_error: 0.1065
Epoch 66/250
500/500 - 2s - loss: 0.2754 - mean_squared_error: 0.1032 - val_loss: 0.2951 - val_mean_squared_error: 0.1129
Epoch 67/250
500/500 - 2

Epoch 126/250
500/500 - 2s - loss: 0.2643 - mean_squared_error: 0.0973 - val_loss: 0.2906 - val_mean_squared_error: 0.1061
Epoch 127/250
500/500 - 2s - loss: 0.2650 - mean_squared_error: 0.0979 - val_loss: 0.2891 - val_mean_squared_error: 0.1040
Epoch 128/250
500/500 - 2s - loss: 0.2643 - mean_squared_error: 0.0974 - val_loss: 0.2911 - val_mean_squared_error: 0.1070
Epoch 129/250
500/500 - 2s - loss: 0.2650 - mean_squared_error: 0.0982 - val_loss: 0.2900 - val_mean_squared_error: 0.1060
Epoch 130/250
500/500 - 2s - loss: 0.2638 - mean_squared_error: 0.0971 - val_loss: 0.2901 - val_mean_squared_error: 0.1053
Epoch 131/250
500/500 - 2s - loss: 0.2646 - mean_squared_error: 0.0980 - val_loss: 0.2909 - val_mean_squared_error: 0.1066
Epoch 132/250
500/500 - 2s - loss: 0.2646 - mean_squared_error: 0.0978 - val_loss: 0.2899 - val_mean_squared_error: 0.1053
Epoch 133/250
500/500 - 2s - loss: 0.2646 - mean_squared_error: 0.0981 - val_loss: 0.2890 - val_mean_squared_error: 0.1039
Epoch 134/250
50

Epoch 193/250
500/500 - 2s - loss: 0.2565 - mean_squared_error: 0.0931 - val_loss: 0.2874 - val_mean_squared_error: 0.1021
Epoch 194/250
500/500 - 2s - loss: 0.2561 - mean_squared_error: 0.0927 - val_loss: 0.2887 - val_mean_squared_error: 0.1043
Epoch 195/250
500/500 - 2s - loss: 0.2562 - mean_squared_error: 0.0932 - val_loss: 0.2865 - val_mean_squared_error: 0.1003
Epoch 196/250
500/500 - 2s - loss: 0.2562 - mean_squared_error: 0.0930 - val_loss: 0.2869 - val_mean_squared_error: 0.1009
Epoch 197/250
500/500 - 2s - loss: 0.2560 - mean_squared_error: 0.0926 - val_loss: 0.2869 - val_mean_squared_error: 0.1012
Epoch 198/250
500/500 - 2s - loss: 0.2552 - mean_squared_error: 0.0920 - val_loss: 0.2873 - val_mean_squared_error: 0.1019
Epoch 199/250
500/500 - 2s - loss: 0.2553 - mean_squared_error: 0.0923 - val_loss: 0.2873 - val_mean_squared_error: 0.1019
Epoch 200/250
500/500 - 2s - loss: 0.2555 - mean_squared_error: 0.0925 - val_loss: 0.2867 - val_mean_squared_error: 0.1008
Epoch 201/250
50

In [5]:
# Persist the trained forward model.
forward_model.save("../data/model_stuff/forward_model.h5")
# Keep the weights of the two trained layers for reuse in CUPNET. Going by the
# recorded summary order, layers[0] is the TimeDistributed binary conv and
# layers[4] is the Dense layer.
# NOTE(review): these positional indices break silently if layers are added or reordered.
binary_weights = forward_model.layers[0].get_weights()
inverse_weights = forward_model.layers[4].get_weights()

In [6]:
""" 
UNET MODEL
Fixing the weights for the bin_conv1 layer as well as the dense1 layer, ie NON TRAINABLE
Feeding in weights from the forward_model above to see if that improves the results from previous session

"""
# --- Frozen front end: same layer stack as forward_model, marked non-trainable ---
# Fixed batch dimension of 50, so every batch fed to this model must hold 50 samples.
inputs = Input(shape=(30,32,32,1),batch_size=50)
bin_conv1 = TimeDistributed(BinaryConv2D(1, kernel_size=(32,32), input_shape=(30,32,32,1),
                       data_format='channels_last',
                       H=H, kernel_lr_multiplier=kernel_lr_multiplier,
                       padding='same', use_bias=use_bias, name='bin_conv_1',trainable=False))(inputs)
s = Lambda(streak, output_shape = streak_output_shape)(bin_conv1)
i = Lambda(integrate_ims, output_shape = integrate_ims_output_shape) (s)
f = Flatten()(i)
# 30720 = 30*32*32*1, reshaped back into a video right below.
dense1 = Dense(30720, activation = 'relu',trainable=False)(f)
resh = Reshape((30,32,32,1))(dense1)
# --- Encoder: every Conv2D/MaxPooling2D is applied per frame via TimeDistributed ---
# Level 1 (32x32): 1 -> 16 channels.
c1 = TimeDistributed(Conv2D(1, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')) (resh)
c1 = Dropout(0.1) (c1)
c1 = TimeDistributed(Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') )(c1)
# NOTE(review): p1 (and p2 below) rebind the dropout-rate constants of the
# same name defined in the configuration cell.
p1 = TimeDistributed(MaxPooling2D((2, 2)))(c1)

# Level 2 (16x16): 32 channels.
c2 = TimeDistributed(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same'))(p1)
c2 = Dropout(0.1) (c2)
c2 = TimeDistributed(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') )(c2)
p2 = TimeDistributed(MaxPooling2D((2, 2)) )(c2)

# Level 3 (8x8): 64 channels.
c3 = TimeDistributed(Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') )(p2)
c3 = Dropout(0.2) (c3)
c3 = TimeDistributed(Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') )(c3)
p3 = TimeDistributed(MaxPooling2D((2, 2)) )(c3)
    
# Level 4 (4x4): 128 channels.
c4 = TimeDistributed(Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') )(p3)
c4 = Dropout(0.2) (c4)
c4 = TimeDistributed(Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') )(c4)
p4 = TimeDistributed(MaxPooling2D(pool_size=(2, 2))) (c4)

# Bottleneck (2x2): 256 channels, 2x2 kernels.
c5 = TimeDistributed(Conv2D(256, (2, 2), activation='relu', kernel_initializer='he_normal', padding='same')) (p4)
c5 = Dropout(0.3) (c5)
c5 = TimeDistributed(Conv2D(256, (2, 2), activation='relu', kernel_initializer='he_normal', padding='same')) (c5)

# --- Decoder: stride-2 transposed convs, each concatenated with the matching encoder level ---
u6 = TimeDistributed(Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same'))(c5)
u6 = concatenate([u6, c4])
c6 = TimeDistributed( Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same') )(u6)
c6 = Dropout(0.2) (c6)
c6 = TimeDistributed(Conv2D(128, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')) (c6)

u7 = TimeDistributed(Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same') )(c6)
u7 = concatenate([u7, c3])
c7 = TimeDistributed(Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')) (u7)
c7 = Dropout(0.2) (c7)
c7 = TimeDistributed(Conv2D(64, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')) (c7)
    
u8 = TimeDistributed(Conv2DTranspose(32, (3, 3), strides=(2, 2), padding='same') )(c7)
u8 = concatenate([u8, c2])
c8 = TimeDistributed(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')) (u8)
c8 = Dropout(0.1) (c8)
c8 = TimeDistributed(Conv2D(32, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')) (c8)

# NOTE(review): unlike u6-u8, u9 is NOT concatenated with its encoder
# counterpart (c1) -- confirm whether the missing skip connection is intentional.
u9 = TimeDistributed(Conv2DTranspose(16, (3, 3), strides=(2, 2), padding='same')) (c8)
c9 = TimeDistributed(Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')) (u9)
c9 = Dropout(0.1) (c9)
c9 = TimeDistributed(Conv2D(16, (3, 3), activation='relu', kernel_initializer='he_normal', padding='same')) (c9)
    
# Per-frame 1x1 conv with sigmoid: one output channel in [0, 1].
outputs = TimeDistributed(Conv2D(1, (1, 1), activation='sigmoid')) (c9)

CUPNET = Model(inputs = [inputs], outputs = [outputs])
    
# Nadam with its default learning rate (the forward model used Nadam(0.0001)).
CUPNET.compile(optimizer = Nadam(), loss = custom_loss, metrics = ['mean_squared_error'])

In [7]:
""" SET THE PREFIXED WEIGHTS FROM THE FORWARD MODEL """
# Copy the trained forward-model weights into the frozen CUPNET layers. Going
# by the recorded summary, layers[0] is the InputLayer and layers[1] is the
# TimeDistributed binary conv; layers[5] is expected to be the Dense layer
# (input, conv, lambda, lambda, flatten, dense) -- the summary output is
# truncated before that row, so confirm.
CUPNET.layers[1].set_weights(binary_weights)
CUPNET.layers[5].set_weights(inverse_weights)
CUPNET.summary()


Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(50, 30, 32, 32, 1) 0                                            
__________________________________________________________________________________________________
time_distributed_1 (TimeDistrib (50, 30, 32, 32, 1)  1024        input_2[0][0]                    
__________________________________________________________________________________________________
lambda_2 (Lambda)               (50, 30, 62, 32, 1)  0           time_distributed_1[0][0]         
__________________________________________________________________________________________________
lambda_3 (Lambda)               (50, 62, 32, 1)      0           lambda_2[0][0]                   
______________________________________________________________________________________________

In [None]:
""" CUPNET TRAINING """
# Train CUPNET on the video -> video split for up to 500 epochs. Only reduce_lr
# is attached here; early_stopping is not, so all 500 epochs run.
CUPNET_history = CUPNET.fit(MX_train, My_train,
          batch_size = 50,epochs= 500,
          verbose=2,validation_data=(MX_test,My_test),callbacks=[reduce_lr])



Train on 500 samples, validate on 250 samples
Epoch 1/500
500/500 - 21s - loss: 0.2421 - mean_squared_error: 0.0706 - val_loss: 0.2423 - val_mean_squared_error: 0.0735
Epoch 2/500
500/500 - 7s - loss: 0.2272 - mean_squared_error: 0.0675 - val_loss: 0.2364 - val_mean_squared_error: 0.0652
Epoch 3/500
500/500 - 7s - loss: 0.2123 - mean_squared_error: 0.0575 - val_loss: 0.2256 - val_mean_squared_error: 0.0473
Epoch 4/500
500/500 - 7s - loss: 0.2014 - mean_squared_error: 0.0467 - val_loss: 0.2092 - val_mean_squared_error: 0.0346
Epoch 5/500
500/500 - 7s - loss: 0.1759 - mean_squared_error: 0.0281 - val_loss: 0.2004 - val_mean_squared_error: 0.0294
Epoch 6/500
500/500 - 7s - loss: 0.1701 - mean_squared_error: 0.0257 - val_loss: 0.1999 - val_mean_squared_error: 0.0292
Epoch 7/500
500/500 - 7s - loss: 0.1638 - mean_squared_error: 0.0221 - val_loss: 0.1965 - val_mean_squared_error: 0.0280
Epoch 8/500
500/500 - 7s - loss: 0.1637 - mean_squared_error: 0.0234 - val_loss: 0.1976 - val_mean_squared

Epoch 68/500
500/500 - 7s - loss: 0.1264 - mean_squared_error: 0.0152 - val_loss: 0.1820 - val_mean_squared_error: 0.0250
Epoch 69/500
500/500 - 7s - loss: 0.1276 - mean_squared_error: 0.0161 - val_loss: 0.1811 - val_mean_squared_error: 0.0232
Epoch 70/500
500/500 - 7s - loss: 0.1263 - mean_squared_error: 0.0149 - val_loss: 0.1804 - val_mean_squared_error: 0.0232
Epoch 71/500
500/500 - 7s - loss: 0.1261 - mean_squared_error: 0.0156 - val_loss: 0.1818 - val_mean_squared_error: 0.0241
Epoch 72/500
500/500 - 7s - loss: 0.1244 - mean_squared_error: 0.0149 - val_loss: 0.1800 - val_mean_squared_error: 0.0233
Epoch 73/500
500/500 - 7s - loss: 0.1263 - mean_squared_error: 0.0151 - val_loss: 0.1804 - val_mean_squared_error: 0.0236
Epoch 74/500
500/500 - 7s - loss: 0.1253 - mean_squared_error: 0.0152 - val_loss: 0.1810 - val_mean_squared_error: 0.0242
Epoch 75/500
500/500 - 7s - loss: 0.1261 - mean_squared_error: 0.0159 - val_loss: 0.1822 - val_mean_squared_error: 0.0245
Epoch 76/500
500/500 - 7

Epoch 135/500
500/500 - 7s - loss: 0.1130 - mean_squared_error: 0.0138 - val_loss: 0.1807 - val_mean_squared_error: 0.0232
Epoch 136/500
500/500 - 7s - loss: 0.1137 - mean_squared_error: 0.0138 - val_loss: 0.1819 - val_mean_squared_error: 0.0242
Epoch 137/500
500/500 - 7s - loss: 0.1109 - mean_squared_error: 0.0132 - val_loss: 0.1824 - val_mean_squared_error: 0.0242
Epoch 138/500
500/500 - 7s - loss: 0.1109 - mean_squared_error: 0.0132 - val_loss: 0.1803 - val_mean_squared_error: 0.0235
Epoch 139/500
500/500 - 7s - loss: 0.1110 - mean_squared_error: 0.0132 - val_loss: 0.1797 - val_mean_squared_error: 0.0232
Epoch 140/500
500/500 - 7s - loss: 0.1096 - mean_squared_error: 0.0129 - val_loss: 0.1802 - val_mean_squared_error: 0.0235
Epoch 141/500
500/500 - 7s - loss: 0.1108 - mean_squared_error: 0.0137 - val_loss: 0.1831 - val_mean_squared_error: 0.0253
Epoch 142/500
500/500 - 7s - loss: 0.1139 - mean_squared_error: 0.0145 - val_loss: 0.1816 - val_mean_squared_error: 0.0238
Epoch 143/500

E

In [None]:
# Persist the trained CUPNET model next to the forward model.
CUPNET.save("../data/model_stuff/cupnet_model.h5")

In [None]:
# CUPNET's Input layer fixes the batch dimension at 50, so predict with that
# batch size instead of Keras' default of 32. The 500 training samples then
# split into ten full batches.
y_pred = CUPNET.predict(MX_train, batch_size=50)

In [None]:
# Compare predicted and true frames for training video 30.
show_video(y_pred,My_train,30)

In [None]:
# Streak image (62x32) of video 30, i.e. the masked, sheared and summed frames.
fig,ax = plt.subplots(1,1)
ax.imshow(np.reshape(ims2,(-1,62,32))[30],cmap="gray")

In [None]:
# Frame 10 of video 50 after the random binary mask was applied (unscaled values).
fig1,ax1 = plt.subplots(1,1)
ax1.imshow(np.reshape(validate2,(-1,30,32,32))[50][10],cmap="gray")

In [None]:
# The same frame (video 50, frame 10) from the unmasked videos scaled by 1/255.
fig2,ax2=plt.subplots(1,1)
ax2.imshow(np.reshape(validate,(-1,30,32,32))[50][10],cmap="gray")