In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scipy.misc
import imageio
import skimage
import os
import time
import cv2

In [None]:
import tensorflow as tf
import tensorflow.contrib.eager as tfe
# Switch TensorFlow to eager execution: later cells call .numpy() on tensors
# and iterate tf.data datasets directly with Python loops.
tf.enable_eager_execution()

In [None]:
# Seed every random number generator the notebook draws from so that runs are
# repeatable: TensorFlow for weight initialisation, NumPy for the permutation
# drawn in shuffle_data().
SEED = 42
tf.set_random_seed(SEED)
np.random.seed(SEED)

In [None]:
# --- Run configuration -------------------------------------------------------
# Number of training images to load (files 0.png .. num_images-1.png).
num_images = 10
# Width of the classifier output layers.
num_classes = 10
# Training images and their label file (one integer label per line).
folder_path = '../data/filtered_train'
label_path = '../data/filtered_train.csv'
# Batch size used for the training dataset.
batch_size = 128
# Validation images and their label file.
valid_folder = '../data/filtered_valid'
valid_label = '../data/filtered_valid.csv'
# Compatibility function used by the attention layers:
# 'dp' = dot product, 'pc' = parameterised compatibility.
compat_fun = 'pc' #dp or pc
# How the two attention outputs are classified:
# 'concat' = one classifier on the concatenation, 'indep' = one classifier each, averaged.
class_mode = 'indep' #concat or indep

In [None]:
# Route notebook messages through TensorFlow's logger at INFO level.
logging = tf.logging
logging.set_verbosity(logging.INFO)


def log_msg(msg):
    """Log ``msg`` at INFO level, prefixed with the current wall-clock time."""
    stamped = '{}: {}'.format(time.ctime(), msg)
    logging.info(stamped)

In [None]:
def cifar_dataset(image_folder_path, label_file, b_size, num_images):
    """Load the first ``num_images`` images and labels as two batched datasets.

    Images are read from ``<image_folder_path>/<i>.png`` for ``i`` in
    ``range(num_images)`` and standardised individually (mean and standard
    deviation taken over all values of that image). Labels are read from
    ``label_file``, one integer per line; blank lines are ignored.

    Args:
        image_folder_path: Directory holding the numbered ``.png`` files.
        label_file: Path of the text file with one integer label per line.
        b_size: Batch size applied to both returned datasets.
        num_images: Number of images (and labels) to load.

    Returns:
        ``(dataset, all_labels)``: two ``tf.data.Dataset`` objects batched with
        ``b_size``; the first yields image batches reshaped to
        ``(-1, 256, 256, 3)``, the second the matching label batches.

    Raises:
        ValueError: If the label file holds fewer than ``num_images`` labels,
            or a non-blank line is not an integer.
    """
    # images
    all_images = []
    for i in range(num_images):
        image_path = image_folder_path + '/' + str(i) + '.png'
        img = imageio.imread(image_path)
        # Per-image standardisation. A constant image has a standard deviation
        # of 0 and would produce NaN/inf values here.
        img = (img - img.mean()) / img.std()
        all_images.append(img)
        if i % 1000 == 0:
            print("Processed " + str(i))

    all_images = np.array(all_images).reshape((-1, 256, 256, 3))
    dataset = tf.data.Dataset.from_tensor_slices((all_images)).batch(b_size)

    # labels
    # The context manager closes the file even if parsing fails. Filtering
    # blank lines works whether or not the file ends with a newline.
    with open(label_file, 'r') as lf:
        labels = [int(line) for line in lf.read().split('\n') if line.strip()]
    print(len(labels))
    if len(labels) < num_images:
        raise ValueError(
            f'{label_file} holds {len(labels)} labels but {num_images} images were requested'
        )
    labels = np.array(labels[:num_images])
    all_labels = tf.data.Dataset.from_tensor_slices((labels)).batch(b_size)

    return dataset, all_labels

In [None]:
class Convolution(tf.keras.Model):
    """Keras-model wrapper around a single ``tf.layers.Conv2D`` layer."""

    def __init__(self, filters, size, stride, padding, activation, initializer):
        """Create the wrapped layer; every argument is forwarded to ``Conv2D``."""
        super().__init__()
        self.conv = tf.layers.Conv2D(
            filters=filters,
            kernel_size=size,
            strides=stride,
            padding=padding,
            activation=activation,
            kernel_initializer=initializer,
        )

    def call(self, inp):
        """Apply the convolution to ``inp`` and return the result."""
        convolved = self.conv(inp)
        return convolved
        

In [None]:
class MaxPool(tf.keras.Model):
    """Keras-model wrapper around a single ``tf.layers.MaxPooling2D`` layer."""

    def __init__(self, size, stride, padding):
        """Create the wrapped layer with the given window, stride and padding."""
        super().__init__()
        self.pool = tf.layers.MaxPooling2D(
            pool_size=size,
            strides=stride,
            padding=padding,
        )

    def call(self, inp):
        """Apply max pooling to ``inp`` and return the result."""
        pooled = self.pool(inp)
        return pooled

In [None]:
class Dense(tf.keras.Model):
    """Keras-model wrapper around a single ``tf.layers.Dense`` layer."""

    def __init__(self, size, activation):
        """Create a dense layer with ``size`` units and the given activation."""
        super().__init__()
        self.fc = tf.layers.Dense(size, activation=activation)

    def call(self, inp):
        """Apply the dense layer to ``inp`` and return the result."""
        projected = self.fc(inp)
        return projected

In [None]:
def get_compatibility(v1, v2, compat_fun, u):
    """Combine per-position feature vectors ``v1`` with the global vector ``v2``.

    Args:
        v1: Rank-3 tensor indexed as (batch, position, feature).
        v2: Rank-2 tensor indexed as (batch, feature).
        compat_fun: ``'dp'`` for the dot product of every position with ``v2``
            (result indexed as (batch, position)), or ``'pc'`` for the
            element-wise sum of every position with ``v2`` (result indexed as
            (batch, position, feature)).
        u: Accepted for interface compatibility; not used by this function.

    Returns:
        The compatibility tensor described under ``compat_fun``.

    Raises:
        ValueError: If ``compat_fun`` is neither ``'dp'`` nor ``'pc'``.
    """
    if compat_fun == 'dp':
        scores = tf.einsum('bse,be->bs', tf.cast(v1, tf.float32), tf.cast(v2, tf.float32))
    elif compat_fun == 'pc':
        # Move the feature axis to the middle so that v2, reshaped to
        # (batch, feature, 1), broadcasts over positions; then move it back.
        scores = tf.add(tf.transpose(v1, perm=[0, 2, 1]), tf.reshape(v2, [-1, v2.shape[1], 1]))
        scores = tf.transpose(scores, perm=[0, 2, 1])
    else:
        # Without this branch an unknown mode only fails at `return scores`
        # with an UnboundLocalError that hides the real cause.
        raise ValueError(f"Unknown compat_fun {compat_fun!r}; expected 'dp' or 'pc'")

    return scores
    

In [None]:
class Attention(tf.keras.Model):
    """Attention pooling of a convolutional feature map against a global vector."""

    def __init__(self, size):
        """Create the two linear layers, each with ``size`` output units."""
        super().__init__()
        self.fc = Dense(size, None)  # linear transformation of the local features
        self.u = Dense(size, None)   # applied to the compatibility tensor when compat_fun == 'pc'

    def call(self, inp, g, compat_fun):
        """Return the attention-weighted sum of ``inp`` and the raw scores.

        Args:
            inp: Rank-4 feature map; axes 1 and 2 are flattened into positions.
            g: Global vector passed to ``get_compatibility``.
            compat_fun: ``'pc'`` or any other value accepted by
                ``get_compatibility``.

        Returns:
            ``(summed, c_scores)``: the weighted sum over positions of the
            untransformed features, and the scores before the softmax.
        """
        num_positions = inp.shape[1] * inp.shape[2]
        vec = tf.reshape(inp, [-1, num_positions, inp.shape[3]])

        projected = self.fc(vec)
        compat = get_compatibility(projected, g, compat_fun, self.u)

        # For 'pc' the compatibility tensor still has a feature axis: pass it
        # through `u` and sum that axis away to get one score per position.
        if compat_fun == 'pc':
            c_scores = tf.reduce_sum(self.u(compat), axis=2)
        else:
            c_scores = compat

        weights = tf.nn.softmax(c_scores)
        weights = tf.reshape(weights, [-1, weights.shape[1], 1])  # final scores, one per position

        pooled = tf.reduce_sum(tf.multiply(weights, vec), axis=1)
        return pooled, c_scores
    

In [None]:
def lrn(x, radius, alpha, beta, bias=1.0):
    """Apply ``tf.nn.lrn`` (local response normalisation) to ``x``."""
    normalised = tf.nn.lrn(
        x,
        depth_radius=radius,
        alpha=alpha,
        beta=beta,
        bias=bias,
    )
    return normalised

In [None]:
class AlexNet(tf.keras.Model):
    """AlexNet-style convolutional trunk whose dense head is replaced by attention.

    Five convolutions (max pooling after the 1st, 2nd and 5th, local response
    normalisation after the first two pools) feed a global vector ``g``. Two
    attention layers pool the outputs of conv4 and conv5 against ``g``. The
    pooled vectors are classified according to the module-level ``class_mode``:
    ``'concat'`` uses one classifier on their concatenation, ``'indep'`` uses
    one classifier each and averages the logits.
    """

    def __init__(self):
        super(AlexNet, self).__init__()
        self.conv1 = Convolution(96, 11, 4, 'VALID', tf.nn.relu, tf.truncated_normal_initializer(stddev=0.01))
        self.pool1 = MaxPool(3, 2, 'VALID')

        self.conv2 = Convolution(256, 5, 1, 'SAME', tf.nn.relu, tf.truncated_normal_initializer(stddev=0.01))
        self.pool2 = MaxPool(3, 2, 'VALID')

        self.conv3 = Convolution(384, 3, 1, 'SAME', tf.nn.relu, tf.truncated_normal_initializer(stddev=0.01))

        self.conv4 = Convolution(384, 3, 1, 'SAME', tf.nn.relu, tf.truncated_normal_initializer(stddev=0.01))

        self.conv5 = Convolution(256, 3, 1, 'SAME', tf.nn.relu, tf.truncated_normal_initializer(stddev=0.01))
        self.pool5 = MaxPool(3, 2, 'VALID')

        # Global descriptor and the two attention layers.
        self.g = Dense(256, None)
        self.att1 = Attention(256)
        self.att2 = Attention(256)

        # Classifiers; fc2 is only used when class_mode == 'indep'.
        self.fc = Dense(num_classes, None)
        self.fc2 = Dense(num_classes, None)

    def call(self, image, compat_fun):
        """Run the network on a batch of images.

        Args:
            image: Batch of input images.
            compat_fun: ``'dp'`` (dot product) or ``'pc'`` (parameterised
                compatibility); forwarded to both attention layers.

        Returns:
            ``(fc, scores1, scores2)``: class logits and the pre-softmax
            attention scores of the conv4 and conv5 attention layers.

        Raises:
            ValueError: If the module-level ``class_mode`` is neither
                ``'concat'`` nor ``'indep'``.
        """
        # Fail up front with a clear message. Otherwise an unknown mode only
        # surfaces as an UnboundLocalError on `fc` at the return statement.
        if class_mode not in ('concat', 'indep'):
            raise ValueError(f"Unknown class_mode {class_mode!r}; expected 'concat' or 'indep'")

        conv1 = self.conv1(image)
        pool1 = self.pool1(conv1)
        norm1 = lrn(tf.cast(pool1, dtype=tf.float32), 2, 2e-05, 0.75)

        conv2 = self.conv2(norm1)
        pool2 = self.pool2(conv2)
        norm2 = lrn(tf.cast(pool2, dtype=tf.float32), 2, 2e-05, 0.75)

        conv3 = self.conv3(norm2)

        conv4 = self.conv4(conv3)

        conv5 = self.conv5(conv4)
        pool5 = self.pool5(conv5)

        g = self.g(tf.layers.flatten(pool5))

        att1, scores1 = self.att1(conv4, g, compat_fun)
        att2, scores2 = self.att2(conv5, g, compat_fun)

        if class_mode == 'concat':
            concat = tf.concat([tf.layers.flatten(att1), tf.layers.flatten(att2)], axis=1)
            fc = self.fc(concat)
        else:  # 'indep'
            fc1 = self.fc(tf.layers.flatten(att1))
            fc2 = self.fc2(tf.layers.flatten(att2))
            fc = tf.add(fc1, fc2)/2

        return fc, scores1, scores2
        
    

In [None]:
def prediction_loss_fun(model, data, labels, compat_fun):
    """Return the mean sparse softmax cross-entropy of ``model`` on one batch.

    ``model`` is called as ``model(data, compat_fun)`` and must return the
    logits as the first of three values; ``labels`` are integer class ids.
    """
    logits, _, _ = model(data, compat_fun)
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels,
        logits=logits,
    )
    return tf.reduce_mean(per_example)

In [None]:
def get_accuracy(model, data, labels):
    """Return the fraction of examples in the batch that ``model`` classifies correctly.

    The compatibility function is taken from the module-level ``compat_fun``.
    """
    logits, _, _ = model(data, compat_fun)
    probs = tf.nn.softmax(logits)
    predicted = tf.argmax(probs, axis=1)
    hits = tf.cast(tf.equal(predicted, labels), dtype=tf.float32)
    batch_count = float(probs.shape[0].value)
    return tf.reduce_sum(hits)/batch_count

In [None]:
def shuffle_data(data, label):
    """Apply one random permutation to a batch and its labels.

    Returns the permuted ``(data, label)`` pair as tensors. The permutation is
    drawn from NumPy's global random state.
    """
    order = np.random.permutation(data.shape[0].value)
    shuffled_data = tf.convert_to_tensor(np.array(data)[order])
    shuffled_label = tf.convert_to_tensor(np.array(label)[order])
    return shuffled_data, shuffled_label

In [None]:
def get_valid_acc(model, dataset, labels):
    """Average loss and accuracy of ``model`` over paired data/label batches.

    Args:
        model: Model to evaluate.
        dataset: Iterable of data batches.
        labels: Iterable of label batches, aligned with ``dataset``.

    Returns:
        ``(mean_loss, mean_accuracy)``, each the unweighted mean of the
        per-batch values. A smaller final batch therefore counts as much as a
        full one.

    Raises:
        ValueError: If no batches are supplied.
    """
    total_loss = 0
    total_acc = 0
    count = 0
    for datum, lab in zip(dataset, labels):
        count += 1
        # Evaluate the model that was passed in, not the global `anet`, and
        # pass the compatibility function that prediction_loss_fun requires.
        total_loss += prediction_loss_fun(model, datum, lab, compat_fun).numpy()
        total_acc += get_accuracy(model, datum, lab).numpy()
    if count == 0:
        raise ValueError('get_valid_acc received no batches to evaluate')
    return total_loss/count, total_acc/count

In [None]:
# Training set: batched images and labels, kept as two parallel datasets.
dataset, labels = cifar_dataset(folder_path, label_path, batch_size, num_images)
# Validation set: the first 10 images as a single batch of 10.
valid_data, val_labels = cifar_dataset(valid_folder, valid_label, 10, 10)
# Materialise that single validation batch as tensors for repeated evaluation.
val_data = next(iter(valid_data))
val_lab = next(iter(val_labels))

In [None]:
# Adam optimiser with a learning rate of 1e-4; used by the training loop and stored in the checkpoint.
opt = tf.train.AdamOptimizer(learning_rate = 1e-4)

In [None]:
# The model instance that is trained, checkpointed and visualised below.
anet = AlexNet()

In [None]:
# Wrap the loss so that one call returns both the loss value and the gradients
# that the training loop hands to opt.apply_gradients.
loss_and_grads_fun = tfe.implicit_value_and_gradients(prediction_loss_fun)

In [None]:
# One checkpoint directory per (compat_fun, class_mode) combination, placed
# next to the data directory. The original prefix '..checkpoints/' lacked the
# path separator, so it named a directory inside the working directory.
checkpoint_dir = '../checkpoints/attention_alexnet_' + compat_fun + '_' + class_mode
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
# Objects saved and restored together: optimiser, model and the global step counter.
root = tfe.Checkpoint(optimizer=opt, model=anet, optimizer_step=tf.train.get_or_create_global_step())

In [None]:
# Baseline metrics on the single validation batch. valid_loss is the value the
# training loop must beat before it saves a checkpoint.
# NOTE(review): this cell runs before the checkpoint restore cell, so the
# baseline presumably reflects the freshly initialised weights — confirm that is intended.
valid_loss = prediction_loss_fun(anet, val_data, val_lab, compat_fun).numpy()
acc = get_accuracy(anet, val_data, val_lab).numpy()*100
log_msg(f'Initial Valid loss: {valid_loss: 0.4f} accuracy: {acc: f}%')

In [None]:
# Restore the optimiser, model and step counter registered on `root` from the
# latest checkpoint found in checkpoint_dir.
root.restore(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
NUM_EPOCHS = 15
# Log metrics (and possibly checkpoint) every STATS_STEPS training steps.
STATS_STEPS = 1

for epoch_num in range(NUM_EPOCHS):
    print("Epoch: " + str(epoch_num))
    step_num = 0
    for data, label in zip(dataset, labels):
        step_num += 1
        datum, lab = shuffle_data(data, label)

        loss_value, gradients = loss_and_grads_fun(anet, datum, lab, compat_fun)

        # Metrics are taken BEFORE the gradients are applied, so a checkpoint
        # saved here holds exactly the weights that were evaluated.
        if step_num % STATS_STEPS == 0:
            print("Stat step.. " + str(step_num))
            # loss_value is already the loss of this batch at the current
            # weights; reuse it instead of running a second forward pass.
            train_loss = loss_value.numpy()
            train_accuracy = get_accuracy(anet, datum, lab).numpy()*100
            log_msg(f'Epoch: {epoch_num} Step: {step_num} Train loss: {train_loss: 0.4f} accuracy: {train_accuracy: f}%')
            loss = prediction_loss_fun(anet, val_data, val_lab, compat_fun).numpy()
            accuracy = get_accuracy(anet, val_data, val_lab).numpy()*100
            log_msg(f'Epoch: {epoch_num} Step: {step_num} Valid loss: {loss: 0.4f} accuracy: {accuracy: f}%')

            # Keep only checkpoints that improve on the best validation loss so far.
            if loss < valid_loss:
                print("Improvement in validation loss. Saving..")
                valid_loss = loss
                save_path = root.save(checkpoint_prefix)

        opt.apply_gradients(gradients, global_step=tf.train.get_or_create_global_step())

    print(f'Epoch{epoch_num} Done!')

In [None]:
# anet.save_weights('../checkpoint/my_checkpoint')

In [None]:
# Visualisation: overlay the attention maps on an input image

In [None]:
def get_correct_predictions(logits, labels):
    """Return the indices of the correctly classified examples in a batch.

    The result has the form returned by ``np.where``: a tuple holding one
    array of indices.
    """
    probs = tf.nn.softmax(logits)
    is_hit = tf.equal(tf.argmax(probs, axis=1), labels)
    hit_mask = tf.cast(is_hit, dtype=tf.float32).numpy()
    return np.where(hit_mask)

In [None]:
# Fresh iterators over the training image batches and their label batches.
iterator = iter(dataset)
it2 = iter(labels)

In [None]:
# Pull the next image batch and its labels; re-run this cell to advance to the following batch.
datum = next(iterator)
lab = next(it2)

In [None]:
# Forward pass on the selected batch: class logits plus the pre-softmax attention scores of both attention layers.
logits, scores1, scores2 = anet(datum, compat_fun)

In [None]:
# Indices (within the batch) of the images the model classified correctly.
correct_images = get_correct_predictions(logits,lab)

In [None]:
# Force both score tensors to rank 2 with one row per image, keeping the size of axis 1.
scores1 = tf.reshape(scores1, [-1,scores1.shape[1]])
scores2 = tf.reshape(scores2, [-1,scores2.shape[1]])

In [None]:
# Display the indices of correctly classified images to choose image_num from.
correct_images

In [None]:
# 9,14,16 -> validation train

In [None]:
# Which image to visualise: batch_num is used together with batch_size to
# derive the image file number, image_num is the index within the batch.
batch_num = 0
image_num = 8

In [None]:
# Attention scores of the selected image laid out as a 14x14 grid
# (requires 196 scores per image).
att_map1 = scores1[image_num].numpy().reshape([14,14])
att_map2 = scores2[image_num].numpy().reshape([14,14])

In [None]:
# The attention scores were computed on a batch drawn from `dataset`, which is
# loaded from folder_path. The overlay image must come from that same folder,
# not from the validation folder.
image_path = folder_path + '/' + str(batch_num*batch_size + image_num) + '.png'

In [None]:
# Replicate each single-channel 14x14 attention map across three channels and
# write the result to disk as an image.
heatmap1 = np.repeat(att_map1[:, :, np.newaxis], 3, axis=2)
heatmap2 = np.repeat(att_map2[:, :, np.newaxis], 3, axis=2)
imageio.imsave('./heatmap1.png', heatmap1)
imageio.imsave('./heatmap2.png', heatmap2)

In [None]:
# image = cv2.resize(cv2.imread(image_path), (14,14))
# heatmap1 = cv2.imread('./heatmap1.png')
# heatmap2 = cv2.imread('./heatmap2.png')
# heatmap1 = cv2.applyColorMap(heatmap1, cv2.COLORMAP_JET)
# heatmap2 = cv2.applyColorMap(heatmap2, cv2.COLORMAP_JET)
# fin1 = cv2.addWeighted(heatmap1, 0.4, image, 0.6, 0)
# fin2 = cv2.addWeighted(heatmap2, 0.4, image, 0.6, 0)
# plt.imshow(fin1)
# plt.imshow(fin2)

In [None]:
# Method 2: keep the image at its original size and resize the attention map to match

In [None]:
# Read the saved heatmaps back with OpenCV and resize them to 256x256.
heatmap1_r = cv2.resize(cv2.imread('./heatmap1.png'), (256,256))
heatmap2_r = cv2.resize(cv2.imread('./heatmap2.png'), (256,256))

In [None]:
# Load the selected image with OpenCV (cv2.imread returns channels in BGR order).
image_r = cv2.imread(image_path)

In [None]:
# cv2.imread yields BGR, matplotlib expects RGB: convert before display so the
# colours are not swapped.
plt.imshow(cv2.cvtColor(image_r, cv2.COLOR_BGR2RGB))

In [None]:
# Colour the resized heatmaps with the JET colour map, then blend each with the
# image using weights 0.4 (heatmap) and 0.6 (image).
heatmap1_r = cv2.applyColorMap(heatmap1_r, cv2.COLORMAP_JET)
heatmap2_r = cv2.applyColorMap(heatmap2_r, cv2.COLORMAP_JET)
fin1_r = cv2.addWeighted(heatmap1_r, 0.4, image_r, 0.6, 0)
fin2_r = cv2.addWeighted(heatmap2_r, 0.4, image_r, 0.6, 0)

In [None]:
# The overlay is built from OpenCV (BGR) arrays: convert to RGB for matplotlib.
plt.imshow(cv2.cvtColor(fin1_r, cv2.COLOR_BGR2RGB))

In [None]:
# The overlay is built from OpenCV (BGR) arrays: convert to RGB for matplotlib.
plt.imshow(cv2.cvtColor(fin2_r, cv2.COLOR_BGR2RGB))

In [None]:
# Check test accuracy: load the first 500 test images as one batch of 500.
test_folder = '../data/filtered_test'
test_file = '../data/filtered_test.csv'
test_data, test_labels = cifar_dataset(test_folder, test_file, 500, 500)
# Rebind test_data from the dataset to its single batch; test_lab holds the matching labels.
test_data = next(iter(test_data))
test_lab = next(iter(test_labels))

In [None]:
# Mean cross-entropy loss on the test batch.
loss = prediction_loss_fun(anet, test_data, test_lab, compat_fun).numpy()

In [None]:
# Test accuracy as a percentage.
accuracy = get_accuracy(anet, test_data, test_lab).numpy()*100

In [None]:
# Raw test metrics: loss first, then accuracy in percent.
print(loss)
print(accuracy)

In [None]:
# Record the test metrics through the timestamped logger as well.
log_msg(f'Test loss: {loss: 0.4f} accuracy: {accuracy: f}%')

In [None]:
# image_r, fin1_r and fin2_r are OpenCV arrays in BGR channel order, while
# imageio writes arrays as RGB. Convert first so the saved files do not have
# red and blue swapped.
imageio.imsave('../images/image_r.png', cv2.cvtColor(cv2.resize(image_r, (256,256)), cv2.COLOR_BGR2RGB))
imageio.imsave('../images/fin1_r.png', cv2.cvtColor(cv2.resize(fin1_r, (256,256)), cv2.COLOR_BGR2RGB))
imageio.imsave('../images/fin2_r.png', cv2.cvtColor(cv2.resize(fin2_r, (256,256)), cv2.COLOR_BGR2RGB))