In [None]:
import numpy as np
import keras
import matplotlib.pyplot as plt
import sys
import os
from keras.layers import Input, TimeDistributed, Lambda, Conv2D, MaxPooling2D, UpSampling2D, Concatenate
import keras.backend as K
from keras.models import Model
import tensorflow as tf
from keras.utils import Sequence
from keras.optimizers import Adam
import cv2
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from PIL import Image
from IPython.display import clear_output
import scipy.io
from copy import deepcopy
import tqdm 
import math
import random

sys.path.append('src')

from data_loading import load_datasets_multiduration
from util import get_model_by_name, create_losses

from eval import *
# from attentive_convlstm_new import AttentiveConvLSTM2D
# from dcn_resnet_new import dcn_resnet
# from gaussian_prior_new import LearningPrior
# from losses_keras2 import *
from sal_imp_utilities import *
from cb import InteractivePlot
from losses_keras2 import loss_wrapper
# #from multiduration_models import xception_3stream
# from multiduration_models import sam_xception_timedist, sam_resnet_timedist, xception_se_lstm
# from util import get_model_by_name

# from data_loading import load_datasets_multiduration, load_datasets_sal_imp

%load_ext autoreload
%autoreload 2

# Check GPU status

In [None]:
%%bash
# Show the status of the NVIDIA GPUs on this machine, to help choose the device
# assigned to CUDA_VISIBLE_DEVICES in the next cell.
nvidia-smi

In [None]:
# Restrict this kernel to GPU 0 and echo the setting back as the cell output.
# NOTE(review): presumably only effective if this runs before TensorFlow
# initialises its GPU devices (i.e. before any model is built) — confirm.
os.environ.update({"CUDA_VISIBLE_DEVICES": '0'})
os.environ.get("CUDA_VISIBLE_DEVICES")

# Load data

In [None]:
# FILL THESE IN
# Dataset name passed to load_datasets_multiduration.
dataset = "codecharts"
# Root directory holding the datasets. Set the PREDIMPORTANCE_DATASETS environment
# variable to point elsewhere instead of editing this machine-specific path.
bp = os.environ.get(
    "PREDIMPORTANCE_DATASETS",
    "/mnt/localssd2/predimportance/predimportance_shared/datasets",
)

In [None]:
# Viewing durations in milliseconds; the number of model timesteps is derived from this list.
times = [500, 3000, 5000]

data = load_datasets_multiduration(dataset, times, verbose=True, bp=bp)

# Model and training params

In [None]:
# FILL THESE IN: set training parameters

# --- Checkpoint destination ---
# NOTE(review): the directory name says "sammd" while model_name below is "md-sem";
# confirm this is where these checkpoints should go before training.
ckpt_savedir = "/data/vision/oliva/scratch/anelise/predimportance/models/ckpt/sammd_cc_clean/"

# --- Optional warm start (only used when load_weights is True) ---
# md-sem trained on salicon-md. Fine tune this on new code charts w/ccm.
load_weights = False
weightspath = "/mnt/localssd2/predimportance/predimportance_shared/models/ckpt/xception_se_lstm/xception_se_lstm_salicon_md_fixations_3match10kl-5cc-1nss_ep04_valloss-2.2668.hdf5"
#weightspath = "/data/vision/oliva/scratch/anelise/predimportance/models/ckpt/sam_resnet_salmd/salicon_md_fixations_6bc4kl-3cc-10nss_ep01_valloss-20.3203.hdf5"

# --- Optimisation ---
# The learning rate starts at init_lr and is multiplied by lr_reduce_by on a step
# schedule whose period is reduce_at_epoch epochs.
batch_size = 8
n_epochs = 15
init_lr = 1e-5
lr_reduce_by = 0.1
reduce_at_epoch = 2
opt = Adam(lr=init_lr)

# --- Losses: name -> weight ---
# NOTE(review): cc and nss get negative weights, presumably because they are
# similarity scores to maximise — confirm against losses_keras2.
losses = dict(kl=10, cc=-5, nss=-1, ccmatch=3)

# --- Model ---
model_name = "md-sem"
model_inp_size = (240, 320)  # (rows, cols) of the network input
model_out_size = (480, 640)  # size of the predicted maps
n_timesteps = len(times)     # number of viewing durations the model predicts

In [None]:
# Build the architecture named by model_name. `mode` is forwarded to the data
# generators below so they produce inputs in the layout this model expects.
model_params = dict(
    input_shape=model_inp_size + (3,),
    n_outs=len(losses),
    nb_timestep=n_timesteps,
)
model_func, mode = get_model_by_name(model_name)
model = model_func(**model_params)

# Optional warm start from weightspath.
if load_weights:
    model.load_weights(weightspath)
    print("Loaded")

In [None]:
# set up data generation and checkpoints
if not os.path.exists(ckpt_savedir):
    os.makedirs(ckpt_savedir)

# create_losses gives the Keras losses, their weights, a short tag used in checkpoint
# filenames, and (presumably) how many of the losses compare against heatmaps. By the
# convention used here, the fixation-map losses are sorted last.
l, lw, l_str, n_heatmaps = create_losses(losses, model_out_size)
n_fixmaps = len(l) - n_heatmaps
print("Loss string", l_str)

# Generators: train and val differ only in split, batch size and shuffling.
shared_gen_kwargs = dict(
    mode=mode,
    img_size=model_inp_size,
    map_size=model_out_size,
    augment=False,
    n_output_maps=n_heatmaps,
    n_output_fixs=n_fixmaps,
    fix_as_mat=data.get('fix_as_mat', False),
    fix_key=data.get('fix_key', ''),
)

gen_train = MultidurationGenerator(
    img_filenames=data['img_files_train'],
    map_filenames=data['map_files_train'],
    fix_filenames=data['fix_files_train'],
    batch_size=batch_size,
    shuffle=True,
    **shared_gen_kwargs)

gen_val = MultidurationGenerator(
    img_filenames=data['img_files_val'],
    map_filenames=data['map_files_val'],
    fix_filenames=data['fix_files_val'],
    batch_size=1,
    shuffle=False,
    **shared_gen_kwargs)

# Callbacks

# Checkpoint filename template; the {epoch} and {val_loss} fields are filled in at save time.
filepath = os.path.join(ckpt_savedir, "_".join([dataset, l_str]) + '_ep{epoch:02d}_valloss{val_loss:.4f}.hdf5')
print("Checkpoints will be saved with format %s" % filepath)

cb_chk = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_weights_only=True, period=1)
cb_plot = InteractivePlot()

def step_decay(epoch):
    """Step learning-rate schedule used by LearningRateScheduler.

    The rate is ``init_lr * lr_reduce_by ** floor((epoch + 1) / reduce_at_epoch)``,
    so it drops by a factor of ``lr_reduce_by`` every ``reduce_at_epoch`` epochs.
    Reads the module-level ``init_lr``, ``lr_reduce_by`` and ``reduce_at_epoch``
    from the training-parameters cell.

    Args:
        epoch: zero-based epoch index.

    Returns:
        The learning rate for this epoch.
    """
    n_reductions = math.floor((1 + epoch) / reduce_at_epoch)
    lrate = init_lr * math.pow(lr_reduce_by, n_reductions)
    # n_reductions increments exactly when (epoch + 1) is a multiple of reduce_at_epoch.
    # The previous `epoch % reduce_at_epoch` test only matched those epochs when
    # reduce_at_epoch == 2 (and never fired when it was 1).
    if (1 + epoch) % reduce_at_epoch == 0:
        print('Reducing lr. New lr is:', lrate)
    return lrate

cb_sched = LearningRateScheduler(step_decay)

cbs = [cb_chk, cb_sched, cb_plot]

In [None]:
# Display the loss functions returned by create_losses (fixation-map losses last, per the convention above).
l

In [None]:
# Sanity-check the training generator: pull one batch and report its structure.
img, outs = gen_train[1]
print("batch size: %d. Num inputs: %d. Num outputs: %d." % (batch_size, len(img), len(outs)))
print(outs[0].shape)

# Train

In [None]:
# Display the per-loss weights returned by create_losses (same order as `l`).
lw

In [None]:
# Compile with the weighted losses, then train using the generators and callbacks built above.
model.compile(optimizer=opt, loss=l, loss_weights=lw)

print('Ready to train')
model.fit_generator(
    gen_train,
    epochs=n_epochs,
    verbose=1,
    callbacks=cbs,
    validation_data=gen_val,
    max_queue_size=10,
    workers=5,
)

# Evaluate

In [None]:
# FILL IN
# Evaluation settings: whether to load W_eval into `model`, and the `blur` value
# (presumably a Gaussian sigma) passed to the evaluation helpers.
load_ckpt = True
blur_sigma = 7

# md-sem on code charts (full split); earlier candidates kept for reference.
#W_eval = "/mnt/localssd2/predimportance/predimportance_shared/models/ckpt/xception_se_lstm/xception_se_lstm_codecharts0_3match10kl-5cc-1nss_ep03_valloss-0.7587.hdf5"
#W_eval = "/data/vision/oliva/scratch/anelise/predimportance/models/ckpt/mdsem_cc_clean/codecharts_kl10cc-5nss-1ccmatch3_ep04_valloss-0.0789.hdf5"
# W_eval = "/data/vision/oliva/scratch/anelise/predimportance/models/ckpt/sammd_cc_clean/codecharts_kl10ccmatch3cc-5nss-1_ep08_valloss0.2516.hdf5"
W_eval = "/mnt/localssd2/predimportance/predimportance_shared/models/ckpt/xception_se_lstm/xception_se_lstm_codecharts0_3match10kl-5cc-1nss_ep04_valloss-0.7401.hdf5"

In [None]:
# Optionally replace the current weights of `model` with the evaluation checkpoint W_eval.
if load_ckpt:
    model.load_weights(W_eval)
    print("Loaded weights")

In [None]:
# Score the model on the test split, one set of metrics per viewing duration.
gen = eval_generator(
    data['img_files_test'],
    data['map_files_test'],
    data['fix_files_test'],
    None,
    inp_size=model_inp_size,
    fix_as_mat=data.get('fix_as_mat', False),
    fix_key=data.get('fix_key', ''),
)
# Trailing ';' keeps the return value out of the cell output.
get_stats_multiduration(
    model,
    gen,
    blur=blur_sigma,
    mode=mode,
    start_at=None,
    n_times=len(times),
    compare_across_times=False,
);

In [None]:
#img_path_test = "/mnt/localssd2/predimportance/predimportance_shared/datasets/cat2000/testStimuli/"
# Save blurred predictions for every image of the loaded test split.
# Uses the same input size and blur as the evaluation cells above. `img_files_test`,
# `shape_r` and `shape_c` were not defined anywhere in this notebook.
# NOTE(review): this savedir names an md-sem/mit1003 run that does not match the current
# dataset/checkpoint — update it before running so outputs are not mixed up.
savedir = "../models/pred/mdsem_mit1003/mit1003_6bc4kl-3cc-10nss_ep12_valloss-28.7681/"
predict_and_save(model, data['img_files_test'], inp_size=model_inp_size, savedir=savedir, blur=blur_sigma, test_img_base_path="", ext="jpeg")

In [None]:
#model.load_weights("/mnt/localssd2/predimportance/predimportance_shared/models/ckpt/xception_se_lstm/xception_se_lstm_codecharts0_3match10kl-5cc-1nss_ep03_valloss-0.7587.hdf5")



In [None]:
# Save one figure per test image. Each predicted duration gets a row showing the input
# image, the ground-truth map and the model's prediction.
fig_savedir = "/data/vision/oliva/scratch/anelise/predimportance/models/fig/mdsem_cc0/"
# savefig does not create missing directories.
os.makedirs(fig_savedir, exist_ok=True)

gen = eval_generator(
    data['img_files_test'],
    data['map_files_test'],
    data['fix_files_test'],
    None,
    inp_size=model_inp_size,
    fix_as_mat=data.get('fix_as_mat', False),
    fix_key=data.get('fix_key', ''),
    return_name=True
)

col_titles = ['Original', 'Gr Truth Imp', 'Predicted Imp']
for images, maps, _, _, name in gen:
    preds = model.predict(images[0])
    # preds[0] is indexed as [batch, time, row, col, channel] below; one figure row per time step
    # (previously hard-coded to 3 rows).
    n_rows = preds[0].shape[1]

    fig = plt.figure(figsize=[13, 10])
    for t in range(n_rows):
        panels = [
            reverse_preprocess(np.squeeze(images[t])),  # undo model preprocessing for display
            maps[t],
            preds[0][0, t, :, :, 0],
        ]
        for col, panel in enumerate(panels):
            ax = fig.add_subplot(n_rows, 3, t * 3 + col + 1)
            ax.imshow(panel)
            ax.set_xticks([])
            ax.set_yticks([])
            if t == 0:
                ax.set_title(col_titles[col])
    # Lay out once, after all subplots exist (was called on every row).
    fig.tight_layout()

    fname = os.path.basename(name)
    print(fname)
    fig.savefig(os.path.join(fig_savedir, fname))
    # Free the figure; many images are processed in this loop.
    plt.close(fig)

In [None]:
# Build one instance of each architecture used in the comparison below, with the
# same model_params as training.
# NOTE(review): the single-duration "sam-resnet" model also receives nb_timestep —
# confirm it accepts/ignores that argument.
model_mdsem, model_sam_onedur, model_sam = (
    get_model_by_name(arch)[0](**model_params)
    for arch in ("md-sem", "sam-resnet", "sam-md")
)

In [None]:
def compare_mdsem_to_sam_md_and_samx3(ckpt_mdsem, ckpt_sam_md, ckpts_sam_x3, img_names, gr_truths, model_mdsem, model_sam, model_sam_onedur, times=(500, 3000, 5000)):
    """Plot SAM x3, SAM-MD and MD-SEM predictions side by side, one figure per image.

    Each figure has one row per duration in ``times`` and five columns: input image,
    ground truth, SAM x3 (the single-duration model with that duration's weights),
    SAM-MD and MD-SEM.

    Args:
        ckpt_mdsem: weights file loaded into ``model_mdsem``.
        ckpt_sam_md: weights file loaded into ``model_sam``.
        ckpts_sam_x3: dict mapping every duration in ``times`` to a weights file for
            ``model_sam_onedur``.
        img_names: paths of the images to compare.
        gr_truths: ground-truth map paths, paired with ``img_names`` by position.
        model_mdsem, model_sam, model_sam_onedur: built Keras models; their weights are
            overwritten by this call.
        times: durations in ms, in row order. Each must be a key of ``ckpts_sam_x3``.

    Raises:
        KeyError: if a duration in ``times`` has no entry in ``ckpts_sam_x3``.
    """
    shape_r, shape_c = model_inp_size
    images = preprocess_images(img_names, shape_r, shape_c)
    gr_truth = preprocess_maps(gr_truths, shape_r, shape_c)

    # Look checkpoints up in `times` order, so that row t always shows the
    # single-duration model for times[t]. Previously the rows followed dict
    # iteration order, which could mislabel them.
    preds_x3 = []
    for t in times:
        print('CKPT for time %d' % t)
        model_sam_onedur.load_weights(ckpts_sam_x3[t])
        preds_x3.append(model_sam_onedur.predict(images))

    model_sam.load_weights(ckpt_sam_md)
    preds_sam_md = model_sam.predict(images)

    model_mdsem.load_weights(ckpt_mdsem)
    preds_mdsem = model_mdsem.predict(images)

    n_times = len(times)
    for i in range(len(img_names)):
        print(img_names[i])
        plt.figure(figsize=(14, 8 * n_times / 3))
        # Plot maps for each timestep (previously hard-coded to 3 rows).
        for t in range(n_times):
            panels = [
                ('Original image', reverse_preprocess(images[i])),
                ('Ground truth', gr_truth[i, :, :, 0]),
                ('SAM x3', preds_x3[t][0][i, :, :, 0]),
                ('SAM-MD', preds_sam_md[0][i, t, :, :, 0]),
                ('MD-SEM', preds_mdsem[0][i, t, :, :, 0]),
            ]
            for col, (title, panel) in enumerate(panels):
                ax = plt.subplot(n_times, 5, t * 5 + col + 1)
                ax.imshow(panel)
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_title(title)
            # Label each row with its duration on the right-hand side of the last column.
            ax.set_ylabel(str(times[t]) + 'ms')
            ax.yaxis.set_label_position("right")
        plt.tight_layout()
        # Render now so each figure appears directly under its printed image name.
        plt.show()
            
# Weights for the single-duration SAM-ResNet model, one checkpoint per viewing duration (ms).
ckpts_sam_x3 = {
    500:'/data/vision/oliva/scratch/anelise/predimportance/models/ckpt/sam_resnet_cc/codecharts500_nss-1kl10cc-5_ep11_valloss-2.3949.hdf5', 
    3000: '/data/vision/oliva/scratch/anelise/predimportance/models/ckpt/sam_resnet_cc/codecharts3000_nss-1kl10cc-5_ep06_valloss1.6970.hdf5',
    5000: '/data/vision/oliva/scratch/anelise/predimportance/models/ckpt/sam_resnet_cc/codecharts5000_nss-1kl10cc-5_ep13_valloss2.4631.hdf5',
}

ckpt_mdsem = '/mnt/localssd2/predimportance/predimportance_shared/models/ckpt/xception_se_lstm/xception_se_lstm_codecharts0_3match10kl-5cc-1nss_ep03_valloss-0.7587.hdf5'
ckpt_sam_md = "/data/vision/oliva/scratch/anelise/predimportance/models/ckpt/sammd_cc_clean/codecharts_kl10ccmatch3cc-5nss-1_ep08_valloss0.2516.hdf5"

# First 10 salicon_md_fixations validation images and their ground-truth heatmaps.
# Use dedicated names so the dataset root `bp` from the configuration cell is not
# overwritten. Reassigning it here silently broke later data-loading re-runs.
bp_img = '/data/vision/oliva/scratch/code-chart-data/training_data/salicon_md_fixations/val/raw_img/'
bp_gr = '/data/vision/oliva/scratch/code-chart-data/training_data/salicon_md_fixations/val/heatmaps/'
# NOTE(review): images and heatmaps are paired purely by sorted position, which assumes
# both directories hold matching, identically-ordered file names.
imgs = sorted(os.path.join(bp_img, f) for f in os.listdir(bp_img))[:10]
gr_truths = sorted(os.path.join(bp_gr, f) for f in os.listdir(bp_gr))[:10]

compare_mdsem_to_sam_md_and_samx3(ckpt_mdsem, ckpt_sam_md, ckpts_sam_x3, imgs, gr_truths, model_mdsem, model_sam, model_sam_onedur)