In [None]:
import numpy as np
import keras
import matplotlib.pyplot as plt
import sys
import os
from keras.layers import Input, TimeDistributed, Lambda, Conv2D, MaxPooling2D, UpSampling2D, Concatenate
import keras.backend as K
from keras.models import Model
import tensorflow as tf
from keras.utils import Sequence
from keras.optimizers import Adam
import cv2
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from PIL import Image
from IPython.display import clear_output
import scipy.io
from copy import deepcopy
import tqdm 
import math
import random

sys.path.append('src')

from data_loading import load_datasets_multiduration, load_datasets_sal_imp
from util import get_model_by_name, get_loss_by_name

# from eval import *
# from attentive_convlstm_new import AttentiveConvLSTM2D
# from dcn_resnet_new import dcn_resnet
# from gaussian_prior_new import LearningPrior
# from losses_keras2 import *
from sal_imp_utilities import *
from cb import InteractivePlot
from losses_keras2 import loss_wrapper
# #from multiduration_models import xception_3stream
# from multiduration_models import sam_xception_timedist, sam_resnet_timedist, xception_se_lstm
# from util import get_model_by_name

# from data_loading import load_datasets_multiduration, load_datasets_sal_imp

%load_ext autoreload
%autoreload 2

# Check GPU status

In [None]:
%%bash
# List the visible GPUs with their current memory use and processes, to help choose CUDA_VISIBLE_DEVICES.
nvidia-smi

In [None]:

# Restrict this process to GPU 0.
# NOTE(review): CUDA reads this variable when TensorFlow first initializes its GPU devices,
# so it presumably has to be set before the model is built to take effect.
os.environ.update({"CUDA_VISIBLE_DEVICES": '0'})
os.environ.get("CUDA_VISIBLE_DEVICES")

# Load data

In [None]:
# FILL THESE IN
dataset = "mit1003"
# Root directory holding the datasets. Set the PREDIMPORTANCE_DATASETS environment variable
# to point elsewhere instead of editing this machine-specific default.
bp = os.environ.get(
    "PREDIMPORTANCE_DATASETS",
    "/mnt/localssd2/predimportance/predimportance_shared/datasets",
)

In [None]:
# Viewing durations (ms) whose ground-truth maps are loaded; later cells expect the model
# to produce one prediction per entry.
times = [500, 3000, 5000]

# Returns a dict of file lists and fixation settings (keys such as 'img_files_train').
data = load_datasets_multiduration(dataset, times=times, bp=bp, verbose=True)

# Model and training parameters

In [None]:
# FILL THESE IN: set training parameters
ckpt_savedir = "ckpt"  # directory that receives the per-epoch checkpoints

# Optionally warm-start from existing weights (weightspath must be set when enabled).
load_weights = False
weightspath = ""

batch_size = 4
init_lr = 0.00001
lr_reduce_by = .1     # multiplicative LR factor applied at each step of step_decay
reduce_at_epoch = 2   # step the LR every this many epochs (see step_decay)
n_epochs = 15

opt = Adam(lr=init_lr)

# Maps loss name -> weight. Order matters: the loss functions and weights are built in
# this dict's insertion order and paired positionally with the model's outputs.
# NOTE(review): cc and nss are presumably similarity scores to be maximized, hence the
# negative weights -- confirm.
losses = {
    'binary_crossentropy': 6,
    'kl': 4,
    'cc': -3,
    'nss': -10
}

model_name = "md-sem"
model_inp_size = (240, 320)
model_out_size = (480, 640)
# One output timestep per viewing duration; derived from `times` so the two cannot drift
# apart (the visualization cell asserts they match).
n_timesteps = len(times)

In [None]:
# Build the model named by `model_name`; `mode` is forwarded to the data generators below.
model_func, mode = get_model_by_name(model_name)

n_outs_model = len(losses)  # one output head per loss term
model_params = dict(
    input_shape=model_inp_size + (3,),  # (rows, cols, 3 channels)
    n_outs=n_outs_model,
    nb_timestep=n_timesteps,
)
model = model_func(**model_params)

# Optional warm start.
if load_weights:
    model.load_weights(weightspath)

In [None]:
# set up data generation and checkpoints
# exist_ok avoids the check-then-create race and makes the cell idempotent.
os.makedirs(ckpt_savedir, exist_ok=True)
    
# Per-output loss functions and weights, both in `losses` insertion order, plus a compact
# "name+weight" tag that is embedded in checkpoint filenames.
l = [get_loss_by_name(name, model_out_size) for name in losses]
lw = list(losses.values())
loss_str = "".join(name + str(weight) for name, weight in losses.items())

# NOTE(review): the generators produce one map fewer than there are losses -- presumably
# because one loss (nss?) is computed from fixations rather than a map; confirm.
n_output_maps_gen = len(losses) - 1

# Generators
# Settings shared by both generators; they differ only in file lists, batch size and shuffling.
_shared_gen_kwargs = dict(
    mode=mode,
    img_size=model_inp_size,
    map_size=model_out_size,
    augment=False,
    n_output_maps=n_output_maps_gen,
    fix_as_mat=data.get('fix_as_mat', False),
    fix_key=data.get('fix_key', ''),
)

# Training: shuffled batches of `batch_size`.
gen_train = MultidurationGenerator(
    img_filenames=data['img_files_train'],
    map_filenames=data['map_files_train'],
    fix_filenames=data['fix_files_train'],
    batch_size=batch_size,
    shuffle=True,
    **_shared_gen_kwargs
)

# Validation: one image per batch, fixed order.
gen_val = MultidurationGenerator(
    img_filenames=data['img_files_val'],
    map_filenames=data['map_files_val'],
    fix_filenames=data['fix_files_val'],
    batch_size=1,
    shuffle=False,
    **_shared_gen_kwargs
)

# Callbacks

# where to save checkpoints
# Keras fills in {epoch} and {val_loss} each time a checkpoint is written.
_ckpt_suffix = '_ep{epoch:02d}_valloss{val_loss:.4f}.hdf5'
filepath = os.path.join(ckpt_savedir, dataset + "_" + loss_str + _ckpt_suffix)
print("Checkpoints will be saved with format %s" % filepath)

# Weights-only checkpoint after every epoch, plus a live loss plot.
cb_chk = ModelCheckpoint(
    filepath,
    monitor='val_loss',
    verbose=1,
    save_weights_only=True,
    period=1,
)
cb_plot = InteractivePlot()

def step_decay(epoch):
    """Step learning-rate schedule for ``LearningRateScheduler``.

    The rate is ``init_lr * lr_reduce_by ** k`` with ``k = (epoch + 1) // reduce_at_epoch``,
    so it drops by ``lr_reduce_by`` at 0-indexed epochs ``reduce_at_epoch - 1``,
    ``2 * reduce_at_epoch - 1``, ...

    Args:
        epoch: 0-indexed epoch number supplied by Keras.

    Returns:
        The learning rate (float) to use for this epoch.
    """
    n_steps = (epoch + 1) // reduce_at_epoch
    lrate = init_lr * math.pow(lr_reduce_by, n_steps)
    # The exponent increments exactly when (epoch + 1) is a multiple of reduce_at_epoch,
    # so only announce a reduction on those epochs (the old `epoch % reduce_at_epoch`
    # test was only right for reduce_at_epoch == 2).
    if (epoch + 1) % reduce_at_epoch == 0:
        print('Reducing lr. New lr is:', lrate)
    return lrate
cb_sched = LearningRateScheduler(schedule=step_decay)

cbs = [
    cb_chk,    # per-epoch weight checkpoints
    cb_sched,  # step learning-rate decay
    cb_plot,   # live loss plot
]

In [None]:
# Sanity-check one training batch before committing to a full training run.
img, outs = gen_train[1]
print("batch size: %d. Num inputs: %d" % (batch_size, len(img)))
print(outs[0].shape)

# Train

In [None]:
# Losses and weights are paired positionally with the model outputs.
model.compile(optimizer=opt, loss=l, loss_weights=lw)
print('Ready to train')

fit_kwargs = dict(
    epochs=n_epochs,
    verbose=1,
    callbacks=cbs,
    validation_data=gen_val,
    max_queue_size=10,
    workers=5,
)
model.fit_generator(gen_train, **fit_kwargs)

In [None]:
# Score (do not train on) the validation set; pair each value with its metric name.
val_scores = model.evaluate_generator(gen_val)
dict(zip(model.metrics_names, np.atleast_1d(val_scores)))

## Visualize predictions against ground truth

In [None]:
# Draw a pool of validation examples, each (images, maps, fixmaps, fixcoords), to sample
# from in the next cell. File lists come from `data`; the input size is the model's.
n_vis_examples = 50
gen = eval_generator(
    data['img_files_val'],
    data['map_files_val'],
    data['fix_files_val'],
    None,
    inp_size=model_inp_size)

examples = [next(gen) for _ in range(n_vis_examples)]
len(examples)

In [None]:
# Compare ground-truth maps with predictions for one randomly chosen validation example.
images, maps, fixmaps, fixcoords = random.choice(examples)
print("maps size", len(maps), maps[0].shape)

batch = 0
# model.predict returns one array per output head: take the first head, then sample `batch`.
preds = model.predict(images[0])[0][batch]
print("preds size", preds.shape)

n_times = len(preds)
assert len(times) == n_times, "expected one prediction per duration in `times`"

plt.imshow(reverse_preprocess(np.squeeze(images[0])))
plt.title("original image %d" % batch)
plt.show()

# One row per duration: ground truth on the left, prediction on the right.
fig, axes = plt.subplots(n_times, 2, figsize=[16, 10], squeeze=False)
for row, t_ms in enumerate(times):
    axes[row, 0].imshow(maps[row])
    axes[row, 0].set_title('Gt %dms' % t_ms)

    axes[row, 1].imshow(np.squeeze(preds[row]))
    axes[row, 1].set_title('Prediction %dms' % t_ms)

plt.show()

# Evaluate

In [None]:
# Evaluate a specific saved checkpoint. Set `eval_load_weights = False` to evaluate the
# weights currently held by `model` (e.g. straight after training) instead.
W = "../models/ckpt/mdsem_mit1003/mit1003_6bc4kl-3cc-10nss_ep12_valloss-28.7681.hdf5"
eval_load_weights = True
if eval_load_weights:
    model.load_weights(W)

In [None]:
# Fixation-storage settings for this dataset, read from the loader's output the same way
# the generator cell does (previously an undefined name here).
fix_as_mat = data.get('fix_as_mat', False)
fix_key = data.get('fix_key', '')
fix_as_mat, fix_key

In [None]:
# Single-duration metrics on the validation set, scored against only the first entry of the
# map/fixation file lists (presumably the first duration in `times` -- confirm).
gen_stats = eval_generator(
    data['img_files_val'],
    [data['map_files_val'][0]],
    [data['fix_files_val'][0]],
    None,
    inp_size=model_inp_size,
    fix_as_mat=data.get('fix_as_mat', False),
    fix_key=data.get('fix_key', ''),
    fixcoord_filetype='mat'
)
# Assigned (rather than left as the last expression) so the result is kept without being echoed.
stats_oneduration = get_stats_oneduration(model, gen_stats, blur=7, mode='singlestream', start_at=None)

In [None]:
# Load the MIT300 test image list for prediction.
# NOTE(review): the training-data cell treats `load_datasets_multiduration`'s return value as a
# dict (`data['img_files_train']`, `data.get(...)`), but this line unpacks it as an 8-tuple.
# Unpacking a dict yields its keys, so with a dict return this either raises or (with exactly
# 8 keys) silently binds a key string. Confirm the return type and read the test image list by
# key instead.
img_files_test, _, _, _, _, _, _, _ = load_datasets_multiduration('mit300', times=times, bp=bp)

In [None]:
# Save test-set predictions under a directory named after the checkpoint `W`
# (../models/pred/<ckpt dir>/<ckpt file stem>/), so the output location cannot drift out of
# sync with the weights path. Input size is the model's own input size.
ckpt_group = os.path.basename(os.path.dirname(W))
ckpt_stem = os.path.splitext(os.path.basename(W))[0]
savedir = os.path.join("../models/pred", ckpt_group, ckpt_stem) + os.sep
predict_and_save(model, img_files_test, inp_size=model_inp_size, savedir=savedir, blur=7, test_img_base_path="", ext="jpeg")