In [7]:
"""
Created on Wed Feb 26 17:01:56 2020

"""

""" IMPORTS """
import sys
import numpy as np
sys.path.append("../")
np.random.seed(1337)  # for reproducibility
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import *
from tensorflow.keras.layers import Flatten
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import *
from tensorflow.keras.datasets import mnist
from tensorflow.keras.constraints import *
from sklearn.model_selection import train_test_split
# from keras.utils import np_utils

from binary_ops import binary_tanh as binary_tanh_op
from binary_layers import BinaryDense, BinaryConv2D

import h5py
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
from lambda_layers import *
from binary_ops import *
%matplotlib qt

In [8]:
""" FUNCTIION AND VARIABLE DEFINITIONS """
def binary_tanh(x):
    """Binarized tanh activation.

    Delegates to ``binary_ops.binary_tanh`` (imported as ``binary_tanh_op``).
    Presumably wrapped in a named function here so it can be referenced as an
    activation by name — confirm.
    """
    activated = binary_tanh_op(x)
    return activated

# ---- binarization ------------------------------------------------------------
H = 1.0
kernel_lr_multiplier = 'Glorot'

# ---- network / training ------------------------------------------------------
# NOTE(review): img_rows/img_cols (30), filters, kernel_size, pool_size,
# hidden_units, classes, batch_size and epochs are not referenced by the
# visible cells, which hard-code 32x32 frames and batch_size=32 / epochs=100
# in their fit calls — confirm which values are authoritative.
batch_size = 50
epochs = 20
channels = 1
img_rows, img_cols = 30, 30
filters = 32
kernel_size = (32, 32)
pool_size = (2, 2)
hidden_units = 128
classes = 10
use_bias = False

# ---- learning-rate schedule --------------------------------------------------
# Per-epoch multiplicative factor that takes lr_start to lr_end over `epochs`
# epochs: lr_start * lr_decay**epochs == lr_end.
lr_start, lr_end = 1e-3, 1e-4
lr_decay = (lr_end / lr_start) ** (1.0 / epochs)

# ---- batch normalization -----------------------------------------------------
epsilon = 1e-6
momentum = 0.9

# ---- dropout rates -----------------------------------------------------------
p1, p2 = 0.25, 0.5

# Directory holding the "<tag>_vids_32.h5" files read by read_many_hdf5.
hdf5_dir = Path("../../data/hdf5/")

def read_many_hdf5(num_images):
    """Load the ``/images`` dataset from ``hdf5_dir / f"{num_images}_vids_32.h5"``.

    Parameters
    ----------
    num_images : str or int
        Tag interpolated into the file name. Despite the parameter name, the
        visible calls pass dataset tags such as ``"sea"`` rather than counts.

    Returns
    -------
    numpy.ndarray
        The whole ``/images`` dataset converted to float32. No labels are
        returned.
    """
    # Open read-only and close deterministically: the previous "r+" handle
    # was never closed and demanded write access to a file that is only read.
    with h5py.File(hdf5_dir / f"{num_images}_vids_32.h5", "r") as file:
        # np.array() materializes the data before the file is closed.
        return np.array(file["/images"]).astype("float32")

def np_streak(x):
    """Shift every frame down by its time index and sum over time (a streak image).

    Frame ``j`` of each sample is written starting at row ``j`` of an output
    that is ``T + H`` rows tall, and the ``T`` shifted frames are added.

    Parameters
    ----------
    x : array-like, shape (N, T, H, W, C)

    Returns
    -------
    numpy.ndarray, float64, shape (N, T + H, W, C)
        The time-summed streak. The largest shift is ``T - 1``, so rows only
        reach index ``T + H - 2``. The last row is therefore always zero; it is
        kept for shape compatibility.
    """
    x = np.asarray(x)
    n, t, h, w, c = x.shape
    streak = np.zeros((n, t + h, w, c))
    for j in range(t):
        # Accumulate directly instead of building an (N, T, T+H, W, C)
        # intermediate and summing it. The window height is the frame height H.
        # The old code sliced by W, which only worked for square frames.
        streak[:, j:j + h] += x[:, j]
    return streak

def mask(val,ims,mask):
    """Write ``ims[i, j] * mask`` into ``val[i, j]`` for every leading index pair of ``val``.

    Only the first two axes of ``val`` are iterated, so ``ims`` must be at least
    that large along them. ``mask`` is broadcast against each
    ``ims[i, j, :, :]`` slice. ``val`` is modified in place and also returned.
    Inside the body, the ``mask`` parameter shadows this function's own name.
    """
    # np.ndindex walks (i, j) in the same row-major order as nested range loops.
    for i, j in np.ndindex(*np.shape(val)[:2]):
        val[i, j, :, :] = ims[i, j, :, :] * mask
    return val


# Load the three clip collections (earlier experiments also used the "jelly"
# and "all" files), stack them along the first axis and view each sample as a
# (32, 32, 32, 1) volume.
long_vids, russ, sea = (read_many_hdf5(tag) for tag in ("long_yt", "russian", "sea"))
ims = np.concatenate((long_vids, russ, sea)).reshape(-1, 32, 32, 32, 1)

# Scale by 255 (presumably 8-bit pixel data, giving values in [0, 1]).
# `validate` is a separately scaled copy; the split below uses `ims` as both
# inputs and targets.
validate = ims / 255
ims = ims / 255

MX_train, MX_test, My_train, My_test = train_test_split(ims, ims, test_size=0.3, random_state=42)

print(np.shape(MX_test))
print(np.shape(MX_train))


# Shared training callbacks:
# - reduce_lr halves the learning rate after 30 epochs without val_loss
#   improvement, down to a floor of 1e-6.
# - early_stopping stops after 40 epochs without improvement (EarlyStopping
#   monitors val_loss by default) and restores the best weights seen.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=30,
    min_lr=1e-6,
    verbose=1,
)
early_stopping = EarlyStopping(
    patience=40,
    restore_best_weights=True,
    verbose=1,
)


def custom_loss(y_true, y_pred, ssim_weight=0.5):
    """Weighted blend of the SSIM loss and the MSE loss used for training.

    Returns ``ssim_weight * ssim_loss + (1 - ssim_weight) * mse_loss``. The
    default of 0.5 reproduces the original equal weighting.

    ``ssim_loss`` yields one value per 2-D image (the leading dimensions of
    ``tf.image.ssim``'s output), while ``mse_loss`` is a scalar mean. The MSE
    term is therefore broadcast onto every per-image value.

    Reuses the module-level ``ssim_loss`` / ``mse_loss`` helpers so each
    formula lives in one place. They are looked up at call time, i.e. after
    this cell has finished defining them.
    """
    return (ssim_weight * ssim_loss(y_true, y_pred)
            + (1.0 - ssim_weight) * mse_loss(y_true, y_pred))

def ssim_loss(y_true,y_pred):
    """Structural-dissimilarity loss ``(1 - SSIM) / 2`` computed with ``max_val=1``.

    ``max_val=1`` presumably matches inputs scaled to [0, 1]. The result has
    one entry per image, following ``tf.image.ssim``.
    """
    similarity = tf.image.ssim(y_true, y_pred, max_val=1)
    return 0.5 * (1.0 - similarity)

def mse_loss(y_true,y_pred):
    """Mean squared error averaged over every element of the batch (a scalar)."""
    error = y_pred - y_true
    return K.mean(K.square(error))

(438, 32, 32, 32, 1)
(1022, 32, 32, 32, 1)


In [9]:
""" VIDEO FUNCTIONS FOR CHECKING POST TRAINING"""

def get_mask(model,l=0, save = False, filename = "mask"):
    """Display the binarized first weight of ``model.layers[l]`` as a 32x32 image.

    Parameters
    ----------
    model : keras model
    l : int
        Index into ``model.layers`` of the layer whose ``weights[0]`` is shown.
    save : bool
        If True, also write the 32x32 array with ``np.save(filename, ...)``.
    filename : str
        Target passed to ``np.save``.
    """
    kernel = binarize(model.layers[l].weights[0])
    _, axis = plt.subplots(1, 1)
    kernel_2d = np.reshape(kernel, (32, 32))
    axis.imshow(kernel_2d, cmap="gray")
    if save:
        np.save(filename, kernel_2d)
                            
def show_video(y_pred,y_true,num,save=False, filename='test.png'):
    """Plot predicted and ground-truth frames of one video side by side.

    The figure is a 5 x 13 grid:
    - columns 0-5 show predicted frames;
    - column 6 is a blank spacer;
    - columns 7-12 show the ground-truth frames with the same indices, so every
      truth panel sits exactly 7 columns to the right of its prediction.

    The title reports the MSE and the SSIM loss ((1 - SSIM) / 2) of the
    selected video.

    Parameters
    ----------
    y_pred, y_true : array-like
        Predicted and target batches. Each must reshape to ``(-1, 32, 32, 32)``
        (presumably ``(N, 32, 32, 32, 1)`` model outputs/targets — confirm).
    num : int
        Index of the video to display.
    save : bool, optional
        Write the figure to ``filename`` when True.
    filename : str, optional
        Output path for ``save``; defaults to the previously hard-coded
        ``'test.png'``.
    """
    n_rows, per_side = 5, 6
    spacer = per_side  # column 6 separates predictions from ground truth
    yp = np.reshape(y_pred, (-1, 32, 32, 32))
    yt = np.reshape(y_true, (-1, 32, 32, 32))

    # Only the selected video is needed for the metrics; don't convert the
    # whole batch to tensors.
    yp_tensor = tf.convert_to_tensor(y_pred[num])
    yt_tensor = tf.convert_to_tensor(y_true[num])
    ssim = np.mean(ssim_loss(yt_tensor, yp_tensor))
    mse = np.mean(mse_loss(yt_tensor, yp_tensor))

    fig, ax = plt.subplots(nrows=n_rows, ncols=2 * per_side + 1, figsize=(21, 10),
                           sharex=True, sharey=True)
    # ssim_loss returns (1 - SSIM) / 2, so it is labelled as a loss.
    fig.suptitle(f'Movie: {num} MSE: {mse:.3} SSIM loss: {ssim:.3}')
    for row in range(n_rows):
        # NOTE(review): rows start 5 frames apart but show 6 frames each, so
        # each row's last frame repeats as the next row's first. Confirm
        # whether a stride of 6 was intended.
        start = 5 * row
        for offset in range(per_side):
            ax[row, offset].imshow(yp[num][start + offset], cmap="gray")
            # Same frame index as the prediction. The old `col % 6` indexing
            # shifted the truth row by one frame and wrapped its last panel
            # back to frame `start`.
            ax[row, spacer + 1 + offset].imshow(yt[num][start + offset], cmap="gray")
        ax[row, spacer].axis('off')
    if save:
        fig.savefig(filename)

def show_all_videos(videos,rows,cols, frame=3):
    """Show one frame of each of the first ``rows * cols`` videos in a grid.

    Video ``k`` is drawn at grid cell ``(k // cols, k % cols)`` (row-major).
    Cells beyond the number of available videos are left blank.

    Parameters
    ----------
    videos : array-like
        Batch that reshapes to ``(-1, 32, 32, 32)``.
    rows, cols : int
        Grid size.
    frame : int, optional
        Frame index shown for every video (default 3, the previously
        hard-coded value).
    """
    yp = np.reshape(videos, (-1, 32, 32, 32))
    # squeeze=False keeps `ax` 2-D even for a single row or column.
    _, ax = plt.subplots(nrows=rows, ncols=cols, squeeze=False)
    for row in range(rows):
        for col in range(cols):
            # Row-major index. The old `rows*row + col` skipped or repeated
            # videos whenever rows != cols.
            idx = cols * row + col
            if idx >= len(yp):
                ax[row, col].axis('off')
                continue
            ax[row, col].imshow(yp[idx][frame], cmap="gray")

In [4]:
# Quick look at frame 3 of test video 10.
plt.imshow(MX_test[10, 3].reshape((32, 32)))

<matplotlib.image.AxesImage at 0x2876d3fffc8>

In [11]:
""" FORWARD MODEL """

# Forward model, applied frame by frame and then flattened:
#   1. TimeDistributed(BinaryConv2D): one 32x32 kernel ('same' padding, no
#      bias by default via use_bias) applied to each of the 32 frames.
#   2. Reshape to (32, 32, 32).
#   3. `streak` and `integrate_ims` Lambdas (presumably from lambda_layers,
#      which is star-imported above).
#   4. Flatten + Dense(32768, relu) + Reshape back to a (32, 32, 32, 1) video.
# It is trained to reproduce its input: MX_train/My_train come from
# train_test_split(ims, ims, ...).
forward_model = Sequential()
forward_model.add(Input(shape=(32,32,32,1)))
# Sequential.layers does not list the Input, so this wrapper is layers[0] and
# the Dense below is layers[5]; the weight-export cell relies on those indices.
forward_model.add(TimeDistributed(BinaryConv2D(1, kernel_size=(32,32), input_shape=(32,32,32,1),
                   data_format='channels_last',
                   H=H, kernel_lr_multiplier=kernel_lr_multiplier,
                   padding='same', use_bias=use_bias, name='bin_conv_1')))
forward_model.add(Reshape((32,32,32)))
forward_model.add(Lambda(streak,output_shape=streak_output_shape))
forward_model.add(Lambda(integrate_ims, output_shape = integrate_ims_output_shape))
forward_model.add(Flatten())
# NOTE: the recorded run of this cell failed with OOM allocating a
# [2048, 32768] tensor. That is presumably this layer's kernel
# (2048 flattened features x 32768 units).
forward_model.add(Dense(32768, activation = 'relu'))
forward_model.add(Reshape((32,32,32,1)))
forward_model.compile(optimizer = Nadam(0.0001), loss = custom_loss, metrics = [ssim_loss,mse_loss])
forward_model.summary()
# batch_size=32 / 100 epochs are fixed here, not taken from the config-cell
# constants. early_stopping can end training sooner and restores the best
# weights.
forward_history = forward_model.fit(MX_train, My_train,
      batch_size = 32,epochs= 100,
      verbose=2,validation_data=(MX_test,My_test),callbacks=[reduce_lr, early_stopping])
    

    
    

ResourceExhaustedError: OOM when allocating tensor with shape[2048,32768] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:RandomUniform]

In [None]:
# Persist the trained forward model, then keep the weights of its two learned
# layers for the U-Net below. In the Sequential built above, layers[0] is the
# TimeDistributed(BinaryConv2D) layer and layers[5] is the Dense(32768) layer
# (Sequential.layers does not list the Input).
# Fixed: the name was misspelled `orward_model`, which raised NameError.
forward_model.save_weights("../../data/model_stuff/forward_model_weights_4_23_3d.h5")
assert isinstance(forward_model.layers[0], TimeDistributed), "expected the binary-conv wrapper at index 0"
assert isinstance(forward_model.layers[5], Dense), "expected the Dense layer at index 5"
binary_weights = forward_model.layers[0].get_weights()
inverse_weights = forward_model.layers[5].get_weights()


In [None]:
""" 
3D-UNET MODEL
Freezes the bin_conv1 and dense1 layers (trainable=False, i.e. NON TRAINABLE)
and initializes them with the weights of the forward_model trained above, to see whether that improves on the results of the previous session.

"""
def _unet_conv_block(x, filters, dropout, kernel=(3, 3, 3), first_filters=None):
    """Apply Conv3D -> Dropout -> Conv3D ('same' padding, ReLU, He-normal init).

    ``first_filters`` (if given) sets the filter count of the first
    convolution only.
    """
    conv_kwargs = dict(activation='relu', kernel_initializer='he_normal', padding='same')
    n_first = filters if first_filters is None else first_filters
    x = Conv3D(n_first, kernel, **conv_kwargs)(x)
    x = Dropout(dropout)(x)
    return Conv3D(filters, kernel, **conv_kwargs)(x)


inputs = Input(shape=(32,32,32,1))

# Frozen forward-model front end. The layer objects are kept so the trained
# weights can be loaded by handle rather than by position in CUPNET3D.layers.
bin_conv1_layer = TimeDistributed(BinaryConv2D(1, kernel_size=(32,32), input_shape=(32,32,32,1),
                                               data_format='channels_last',
                                               H=H, kernel_lr_multiplier=kernel_lr_multiplier,
                                               padding='same', use_bias=use_bias, name='bin_conv_1',
                                               trainable=False))
bin_conv1 = bin_conv1_layer(inputs)
resh1 = Reshape((32,32,32))(bin_conv1)
s = Lambda(streak, output_shape = streak_output_shape)(resh1)
i = Lambda(integrate_ims, output_shape = integrate_ims_output_shape)(s)
f = Flatten()(i)
dense1_layer = Dense(32768, activation = 'relu', trainable=False)
dense1 = dense1_layer(f)
resh2 = Reshape((32,32,32,1))(dense1)

# Encoder. Pooled tensors are named pool1..pool4 (not p1..p4) so the dropout
# constants p1/p2 from the config cell are not overwritten.
# NOTE(review): the first convolution of the top level has 1 filter while its
# partner has 16. Kept as-is; confirm whether 16 was intended.
c1 = _unet_conv_block(resh2, 16, 0.1, first_filters=1)
pool1 = MaxPooling3D((2, 2, 2), padding='same')(c1)
c2 = _unet_conv_block(pool1, 32, 0.1)
pool2 = MaxPooling3D((2, 2, 2), padding='same')(c2)
c3 = _unet_conv_block(pool2, 64, 0.2)
pool3 = MaxPooling3D((2, 2, 2), padding='same')(c3)
c4 = _unet_conv_block(pool3, 128, 0.2)
pool4 = MaxPooling3D(pool_size=(2, 2, 2), padding='same')(c4)

# Bottleneck (2x2x2 kernels).
c5 = _unet_conv_block(pool4, 256, 0.3, kernel=(2, 2, 2))

# Decoder: upsample, concatenate the matching encoder output, convolve.
u6 = Conv3DTranspose(128, (3, 3, 3), strides=(2, 2, 2), padding='same')(c5)
u6 = concatenate([u6, c4])
c6 = _unet_conv_block(u6, 128, 0.2)

u7 = Conv3DTranspose(64, (3, 3, 3), strides=(2, 2, 2), padding='same')(c6)
u7 = concatenate([u7, c3])
c7 = _unet_conv_block(u7, 64, 0.2)

u8 = Conv3DTranspose(32, (3, 3, 3), strides=(2, 2, 2), padding='same')(c7)
u8 = concatenate([u8, c2])
c8 = _unet_conv_block(u8, 32, 0.1)

# NOTE(review): unlike the other decoder levels, u9 is not concatenated with
# c1 (no full-resolution skip connection). Kept as-is; confirm this is
# intentional.
u9 = Conv3DTranspose(16, (3, 3, 3), strides=(2, 2, 2), padding='same')(c8)
c9 = _unet_conv_block(u9, 16, 0.1)

outputs = Conv3D(1, (1, 1, 1), activation='sigmoid')(c9)

CUPNET3D = Model(inputs = [inputs], outputs = [outputs])
# Load the trained forward-model weights (binary_weights / inverse_weights
# from the weight-export cell) into the frozen front-end layers.
bin_conv1_layer.set_weights(binary_weights)
dense1_layer.set_weights(inverse_weights)

# Callbacks belong to `fit`, not `compile`. Keras rejects unknown compile
# keyword arguments, so passing `callbacks=` here raised a TypeError.
CUPNET3D.compile(optimizer = Nadam(), loss = custom_loss, metrics = ['mean_squared_error'])
CUPNET3D.summary()

In [None]:
# Train the 3D U-Net with the same callbacks as the forward model:
# ReduceLROnPlateau plus EarlyStopping (restore_best_weights=True). The
# U-Net cell originally tried to attach both callbacks in `compile`.
CUPNET3D_history = CUPNET3D.fit(MX_train, My_train,
          batch_size = 32,epochs= 100,
          verbose=2,validation_data=(MX_test,My_test),callbacks=[reduce_lr, early_stopping])