## This notebook is used for training a transfer-learning, MobileNetV2-based classifier of zebrafish embryo mutants. The training and evaluation of the classifier are followed by Class Activation Mapping (CAM) analysis, which points to the predictive features of the zebrafish embryo images.

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten
from keras.preprocessing.image import ImageDataGenerator
from skimage.transform import resize
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn import metrics
import glob
import os
from PIL import Image
import cv2
import matplotlib.image as mpimg
import shutil
from collections import Counter
from pathlib import Path
import albumentations
from ImageDataAugmentor.image_data_augmentor import *
import random
#check if GPU is visible
#from tensorflow.python.client import device_lib
#print(device_lib.list_local_devices())

### Loading the zebrafish data

In [None]:
# Root folders that the training / validation generators read from.
path_train = Path("training")
path_val = Path("validation")

# Random colour, brightness/contrast, geometric and flip augmentations
# applied to the training images only.
AUGMENTATIONS = albumentations.Compose([
    albumentations.HueSaturationValue(hue_shift_limit=30, sat_shift_limit=10, val_shift_limit=0, p=1),
    albumentations.RandomBrightnessContrast(brightness_limit=(-0.3, 0.3)),
    albumentations.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45),
    albumentations.HorizontalFlip(),
    albumentations.VerticalFlip(),
])

# Both generators rescale pixel values by 1/255; only the training one augments.
train_datagen_augmented = ImageDataAugmentor(rescale=1.0/255., augment=AUGMENTATIONS, preprocess_input=None)
val_datagen = ImageDataAugmentor(rescale=1.0/255., preprocess_input=None)  # no augmentation

# N is the number of pieces the image is cut into per dimension.
# For training on whole images it should be 1.
N = 1

# The generators load images at 450/N x 900/N. For N = 15 that would be
# (30, 60), but MobileNetV2 cannot accept any axis smaller than 32, so the
# size is redefined to (32, 64) manually from N = 15 upwards.
if N < 15:
    target_size = (int(450/N), int(900/N))
elif N >= 15:
    target_size = (32, 64)

train_generator = train_datagen_augmented.flow_from_directory(
    path_train,
    batch_size=20,
    class_mode='categorical',
    target_size=target_size,
)
validation_generator = val_datagen.flow_from_directory(
    path_val,
    batch_size=20,
    class_mode='categorical',
    target_size=target_size,
)

# Draw one training batch; its image shape is what the model input is built from.
x, y = train_generator.next()

In [None]:
# Display the generator's class-name -> integer-label mapping; the rest of the
# notebook indexes classes by these integers.
train_generator.class_indices

### Building and compiling transfer learning classifier

In [None]:
# MobileNetV2 backbone with ImageNet weights and without its classification
# head. The input shape is taken from the training batch `x` drawn earlier.
base_model = tf.keras.applications.MobileNetV2(input_shape=x.shape[1:],
                                               include_top=False,
                                               weights='imagenet')
base_model.trainable = False  # freeze the weights

# One output unit per class folder found by the training generator
# (4 for the dataset this notebook was written for).
n_classes = len(train_generator.class_indices)

# Construct the final model: frozen backbone -> global average pooling ->
# softmax layer. get_activation_map relies on this exact layer order
# (model.layers[0] is the backbone, model.layers[-1] the Dense layer).
model = tf.keras.Sequential([
    base_model,
    tf.keras.layers.GlobalAveragePooling2D(),
    Dense(n_classes, activation='softmax')])
epochs = 200

rate = 0.001

# `learning_rate` is the supported keyword; `lr` is only a deprecated alias
# that newer Keras releases no longer accept.
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=rate),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Preparing the weights for balanced training: each class is weighted by
# (total number of training images) / (number of images in that class).
counts = Counter(train_generator.classes)
counts_total = sum(counts.values())
class_weights = {k: counts_total / v for k, v in counts.items()}

# Halves the learning rate when val_accuracy stalls for 10 epochs.
# NOTE: defined here but not passed to model.fit in the training cell.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy', factor=0.5, patience=10, min_lr=0.00025, verbose=1)
checkpoint_folder = Path('checkpoints')
checkpoint_filepath = checkpoint_folder/"transfer_learning_classifier"

# Keep only the weights of the epoch with the highest validation accuracy.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

### Training the classifier

In [None]:
# Only the checkpoint callback is active; reduce_lr is not passed to fit.
fit_callbacks = [model_checkpoint_callback]

# Train with per-class weights so that all classes contribute equally.
history = model.fit(
    train_generator,
    epochs=epochs,
    validation_data=validation_generator,
    class_weight=class_weights,
    callbacks=fit_callbacks,
)

# Store the per-epoch metrics dictionary in the checkpoint folder.
history_file = checkpoint_folder/f"transfer-learning_lr_{rate:.4f}_epochs_{epochs}.npy"
np.save(history_file, history.history)

### Plotting the learning curves

In [None]:
# Per-epoch metrics recorded by model.fit.
history_dict = history.history
acc = history_dict['accuracy']
val_acc = history_dict['val_accuracy']

loss = history_dict['loss']
val_loss = history_dict['val_loss']

# Accuracy curves go in the upper half of a two-row figure.
# (loss / val_loss are extracted above but not plotted in this cell.)
plt.figure(figsize=(8, 8))
accuracy_axes = plt.subplot(2, 1, 1)
accuracy_axes.plot(acc, label='Training Accuracy')
accuracy_axes.plot(val_acc, label='Validation Accuracy')
accuracy_axes.legend(loc='upper right')
accuracy_axes.set_ylabel('Accuracy')
accuracy_axes.set_xlabel('Epochs')
accuracy_axes.set_ylim([0,1.00])
accuracy_axes.set_yticks([0, 0.25, 0.5, 0.75, 1])
accuracy_axes.set_title('Training and Validation Accuracy')

### Loading weights from the best epoch

In [None]:
# Restore the weights written by the ModelCheckpoint callback, which only saves
# when val_accuracy improves (save_best_only=True, mode='max').
model.load_weights(checkpoint_filepath)

### Run predictions on the validation set and evaluate accuracy

In [None]:
path_val = Path("validation")
batch_size = 20

test_datagen = ImageDataGenerator(rescale=1.0/255.)  # no augmentation

# Same image size as used for training (see the data-loading cell):
# MobileNetV2 needs every axis >= 32, hence the fixed (32, 64) for N >= 15.
if N < 15:
    target_size = (int(450/N), int(900/N))
else:
    target_size = (32, 64)

# shuffle=False keeps the batches in the same order as
# test_generator.filenames, which the filename lookup below relies on.
test_generator = test_datagen.flow_from_directory(
    path_val,
    batch_size=batch_size,
    class_mode='categorical',
    target_size=target_size,
    shuffle=False)

names = test_generator.filenames  # list of names of files
true = []  # list of true labels
predicted = []  # list of predicted labels
misclassified = []  # file names of the images that are misclassified

# `offset` is the position in `names` of the first image of the current
# batch, so the lookup does not depend on a hard-coded batch size and stays
# correct for a shorter final batch.
offset = 0
for _ in range(len(test_generator)):
    x, y = next(test_generator)
    batch_true = y.argmax(axis=-1)
    # One forward pass per batch instead of one per image.
    batch_predicted = model.predict(x).argmax(axis=-1)
    for j, (y_true, yhat) in enumerate(zip(batch_true, batch_predicted)):
        true.append(y_true)
        predicted.append(yhat)
        if y_true != yhat:
            misclassified.append(names[offset + j])
    offset += len(x)

# Class names, ordered as the generator's class_indices mapping iterates.
classes = list(test_generator.class_indices)

#calculate confusion matrix
#calculate confusion matrix
def calculate_confusion_matrix(classes, true, predicted):
    """Plot a row-normalised confusion matrix and return its mean diagonal.

    Args:
        classes: sequence of class names; position ``k`` names the class
            whose integer label is ``k``.
        true: true integer labels (indices into ``classes``).
        predicted: predicted integer labels (indices into ``classes``).

    Returns:
        Sum of the diagonal of the row-normalised matrix divided by the sum
        of all its entries. Because every non-empty row sums to 1, this is
        the average per-class recall over the classes that occur in ``true``
        rather than the plain fraction of correct predictions.

    Side effects:
        Sets the seaborn font scale, shows a heatmap and prints the score.
    """
    # Fixing the label set keeps the matrix len(classes) x len(classes) even
    # when some class never occurs in `true` or `predicted`.
    labels = list(range(len(classes)))

    # rows - true class, columns - predicted class
    matrix = metrics.confusion_matrix(true, predicted, labels=labels).astype(float)

    # Scaling per row (per true label). Rows without any sample are left as
    # all zeros instead of being divided by zero.
    row_totals = matrix.sum(axis=1, keepdims=True)
    np.divide(matrix, row_totals, out=matrix, where=row_totals > 0)

    class_names = list(classes)
    df_cm = pd.DataFrame(matrix, index=class_names, columns=class_names)
    sns.set(font_scale=1.4) # for label size
    sns.heatmap(df_cm, annot=True, annot_kws={"size": 16}) # font size

    plt.show()
    accuracy = np.trace(matrix) / matrix.sum()
    print(accuracy)
    return accuracy

# Show the row-normalised confusion matrix for the validation predictions
# collected above and keep the score the function returns.
accuracy = calculate_confusion_matrix(classes, true, predicted)

### Running predictions on the test (never-seen) dataset

In [None]:
path_val = Path("test") #Folder with the test dataset
batch_size = 20
test_datagen = ImageDataGenerator(rescale=1.0/255.) #no augmentation

# Same image size as used for training: MobileNetV2 needs every axis >= 32,
# hence the fixed (32, 64) for N >= 15.
if N < 15:
    target_size = (int(450/N), int(900/N))
else:
    target_size = (32, 64)

# shuffle=False keeps the batches aligned with test_generator.filenames.
test_generator = test_datagen.flow_from_directory(
    path_val,
    batch_size=batch_size,
    class_mode='categorical',
    target_size=target_size,
    shuffle=False)

save_folder = Path("test_predictions") #Folder where images will be saved
save_folder.mkdir(exist_ok=True)

# Predicted indices refer to the class order of the training generator.
classes = list(train_generator.class_indices)

filenames = test_generator.filenames

total = 0
# Position in `filenames` of the first image of the current batch.
offset = 0
for _ in range(len(test_generator)):
    x, y = next(test_generator)
    batch_true = y.argmax(axis=-1)
    # One forward pass per batch instead of one per image.
    batch_predicted = model.predict(x).argmax(axis=-1)
    for j, (y_true, yhat) in enumerate(zip(batch_true, batch_predicted)):
        image = x[j]

        # Split "<class folder>/<file name>" with the platform's own path
        # rules; splitting on a literal backslash only works on Windows.
        class_dir, file_name = os.path.split(filenames[offset + j])
        verdict = "correct_as" if yhat == y_true else "wrong_as"

        fig = plt.figure()
        # NOTE(review): the image is not drawn, so the saved figure is empty
        # and only the file name carries the result — confirm this is intended.
        #plt.imshow(image)

        # Output name: <true class><verdict><predicted class>_<file name>
        plt.savefig(str(save_folder/class_dir) + verdict + classes[yhat] + "_" + file_name, dpi=fig.dpi, transparent=True)
        # Release the figure; otherwise one figure per image stays in memory.
        plt.close(fig)
    offset += len(x)

### CAM analysis

In [None]:
#adapted from: https://nbviewer.jupyter.org/github/vincent1bt/Machine-learning-tutorials-notebooks/blob/master/activationMaps/ActivationsMaps.ipynb
#deeper explanation here: https://vincentblog.xyz/posts/class-activation-maps
#grad-CAM, although better than CAM, could not be used due to structure of MobileNetV2
def get_activation_map(image, image_class):
    """Return the class activation map (CAM) of one image for a given class.

    The map is the weighted sum of the feature maps of a late backbone
    layer, weighted by the final Dense layer's weights for `image_class`.
    It is divided by its maximum, resized to the image size, shifted by the
    absolute value of its minimum and finally scaled to sum to 1.

    Unlike earlier revisions, this function does not run a prediction,
    build a colour overlay or write "heatmap.jpg": those results were never
    used and did not influence the returned map.

    Args:
        image: batch holding exactly one image, i.e. shape (1, a, b, 3) for
            an a x b RGB image, so that it can be fed to the model directly.
        image_class: index of the class whose evidence is mapped (the
            callers pass the true class).

    Returns:
        2-D array of shape (a, b) whose entries sum to 1.
    """
    # Weight matrix of the final Dense layer: one column per class.
    class_weights = model.layers[-1].get_weights()[0]
    # NOTE(review): index 152 is hard-coded; presumably one of the last
    # layers of MobileNetV2 — verify with model.layers[0].summary(). Its
    # channel count must equal the number of rows of class_weights.
    final_conv_layer = model.layers[0].layers[152]

    get_output = tf.keras.backend.function([model.layers[0].input],
                                           [final_conv_layer.output])
    [conv_outputs] = get_output(image)
    conv_outputs = conv_outputs[0, :, :, :]

    # Weighted sum over the channel axis: cam[r, c] = sum_k w[k] * f[r, c, k]
    cam = np.dot(conv_outputs, class_weights[:, image_class]).astype(np.float32)

    # NOTE(review): yields inf/NaN if the maximum is 0 — confirm that this
    # cannot happen for real images.
    cam /= np.max(cam)
    # cv2.resize expects (width, height).
    cam = cv2.resize(cam, (image.shape[2], image.shape[1]))

    # Shift by |min| so no entry is negative (when the minimum is positive
    # this raises the values further), then normalise the map to sum to 1.
    scaled_cam = cam + np.abs(np.min(cam))
    scaled_cam = scaled_cam/np.sum(scaled_cam)

    return scaled_cam

path_val = Path("validation")
batch_size = 20

test_datagen = ImageDataGenerator(rescale=1.0/255.) #no augmentation
# shuffle=False keeps the batches aligned with test_generator.filenames.
test_generator = test_datagen.flow_from_directory(
    path_val,
    batch_size=batch_size,
    class_mode='categorical',
    target_size=(int(450/N), int(900/N)),
    shuffle=False)

#saving of the predicted images
cam_folder = Path('CAM_results')
if not os.path.exists(cam_folder):
    os.mkdir(cam_folder)

# Start from an empty sub-folder per class; results of earlier runs are deleted.
classes = []
for cl in test_generator.class_indices:
    classes.append(cl)
    if os.path.exists(cam_folder/cl):
        shutil.rmtree(cam_folder/cl)
    os.mkdir(cam_folder/cl)


filenames = test_generator.filenames

#folder with parts of the fish
folder_parts = Path("fish_part_labels")

total = 0
# Position in `filenames` of the first image of the current batch.
offset = 0
for _ in range(len(test_generator)):
    x, y = next(test_generator)
    # One forward pass per batch instead of one per image.
    batch_predicted = model.predict(x).argmax(axis=-1)
    for j in range(len(x)):
        total = total + 1
        image = x[j]
        y_true = y[j].argmax()
        y_pred = batch_predicted[j]
        relative_name = filenames[offset + j]

        # Overlays are only produced for correctly classified images.
        if y_pred != y_true:
            continue

        cam = get_activation_map(np.array([image,]), y_true)

        # The label image is looked up by the file path without its extension.
        filename = relative_name.split('.')[0]
        parts_file = glob.glob(str(folder_parts/f"{filename}*"))
        if not parts_file:
            raise FileNotFoundError(f"No fish-part label image found for {relative_name} in {folder_parts}")
        parts = Image.open(parts_file[0])
        # Everything is stretched to the shape of the loaded image;
        # PIL expects (width, height).
        resized_parts = parts.resize((image.shape[1], image.shape[0]), Image.NEAREST)
        parts_array = np.array(resized_parts)
        # 1 where the label image is 0, 0 everywhere else.
        parts_array_binarized = np.where(parts_array > 0, 0, 1)

        fig = plt.figure()
        # `image` is a single (height, width, 3) array, so it is drawn as is;
        # indexing it with [0, ...] would draw only its first pixel row.
        plt.imshow(image)
        plt.imshow(resize(cam, image.shape[0:2]), alpha=0.4, cmap='magma')
        plt.imshow(parts_array_binarized, alpha=0.1)
        plt.box(False)
        plt.axis('off')
        plt.savefig(cam_folder/relative_name, dpi=fig.dpi, transparent=True)
        # Release the figure; otherwise one figure per image stays in memory.
        # The overlays remain available as files in cam_folder.
        plt.close(fig)
    offset += len(x)


### Calculating the per-class- and per-fish-part CAM attention for images from the validation set

In [None]:
def calculate_attention(image, mask, index, normalization = 0):
    """Sum (or average) the values of `image` where `mask` equals `index`.

    Args:
        image: array of attention values (e.g. a class activation map).
        mask: array of the same shape as `image` holding one part label
            per pixel.
        index: the label value in `mask` that selects the pixels to use.
        normalization: 1 to divide the sum by the number of selected pixels
            (mean attention per pixel); any other value returns the plain
            sum.

    Returns:
        The sum of the selected values, or their mean when
        ``normalization == 1``. If no pixel of `mask` equals `index`, the
        sum is 0 and the mean is NaN (there is nothing to average over).

    Raises:
        ValueError: if `image` and `mask` do not have the same shape.
    """
    image = np.asarray(image)
    mask = np.asarray(mask)

    if image.shape != mask.shape:
        print(image.shape)
        print(mask.shape)
        raise ValueError('Image and Mask should have the same shape')

    # Boolean indexing picks exactly the pixels labelled `index`.
    selected = image[mask == index]
    attention = selected.sum()

    if normalization == 1:
        if selected.size == 0:
            # The part does not occur in this mask: avoid dividing by zero.
            return float('nan')
        attention = attention / selected.size
    return attention

#indices for zebrafish parts: (in the array as well the values on the zebrafish part paintings)
#0 - background
#1 - head
#2 - trunk
#3 - tail
#4 - yolk
#5 - yolk extension
n_parts = 6

#indices for classes are the same as in the classes array

#folder with parts of the fish
folder_parts = "fish_part_labels"

#loading the validation set
path_val = "validation"
batch_size = 20
val_datagen = ImageDataGenerator(rescale=1.0/255.) #no augmentation
# shuffle=False keeps the batches aligned with data_test.filenames.
data_test = val_datagen.flow_from_directory(
    path_val,
    batch_size=batch_size,
    class_mode='categorical',
    target_size=(450, 900),
    shuffle=False)
filenames = data_test.filenames

classes = list(data_test.class_indices)

# Attention[class, part, k] is the attention on `part` for the k-th image of
# `class`. The array is at least (4, 6, 20), the size the plotting cell
# expects, and grows when there are more classes or images per class.
# Slots that never receive a value stay NaN instead of holding arbitrary
# uninitialised memory.
images_per_class = Counter(data_test.classes)
n_classes = max(4, len(classes))
max_images = max(20, max(images_per_class.values(), default=0))
Attention = np.full((n_classes, n_parts, max_images), np.nan)

#keeping track of how many images from a certain class there were so far
class_counter = np.zeros(n_classes, dtype=int)
normalization = 1

total = 0
# Position in `filenames` of the first image of the current batch.
offset = 0
for _ in range(len(data_test)):
    x, y = next(data_test)
    for j in range(len(x)):
        total = total + 1
        image = x[j]
        y_true = int(y[j].argmax())

        cam = get_activation_map(np.array([image,]), y_true)
        resized_cam = np.array(resize(cam, image.shape[0:2]))
        #resizing changes values so we scale attention back to 1
        resized_cam = resized_cam/np.sum(resized_cam)

        slot = class_counter[y_true]
        # NOTE(review): this pattern keeps the image's file extension,
        # whereas the CAM-overlay cell strips it before globbing — confirm
        # which of the two matches the naming of the label files.
        pattern = os.path.join(folder_parts, filenames[offset + j]) + '*'
        for parts_file in glob.glob(pattern):
            parts = Image.open(parts_file)
            # Everything is stretched to the shape of the loaded image;
            # PIL expects (width, height).
            resized_parts = parts.resize((image.shape[1], image.shape[0]), Image.NEAREST)
            parts_array = np.array(resized_parts)
            for part_index in range(n_parts):
                Attention[y_true, part_index, slot] = calculate_attention(resized_cam, parts_array, part_index, normalization)

        class_counter[y_true] += 1
    offset += len(x)
        #print(class_counter)

### Plotting CAM results

In [None]:
# Fill colour shared by all violin plots.
violin_colour = np.array([185, 208, 229])/255

# Plotting per class: one figure per class, one violin per fish part.
# Column order follows the part indices 0..5.
part_columns = ['BKGD', 'Head', 'Trunk', 'Tail', 'Yolk', 'Y.E.']
for i in range(4):
    df = pd.DataFrame({label: Attention[i, part, :] for part, label in enumerate(part_columns)})

    plt.figure()
    # Values are multiplied by 100 so that the Y axis represents percentages.
    ax = sns.violinplot(data=df.iloc[:, 0:6]*100, color=violin_colour, scale='width', bw='scott')
    plt.title(classes[i])
    plt.ylim([-0.0001, 0.0011])

parts = ['BKGD', 'Head', 'Trunk', 'Tail', 'Yolk', 'Yolk Extension']

# Plotting per fish part: one figure per part, one violin per class.
# Each pair is (column label, class index in Attention), in display order.
class_columns = [('WT', 1), ('tbx6', 3), ('DAPT', 0), ('her1;her7', 2)]
for i in range(6):
    df = pd.DataFrame({label: Attention[class_index, i, :] for label, class_index in class_columns})

    plt.figure()
    # Values are multiplied by 100 so that the Y axis represents percentages.
    ax = sns.violinplot(data=df.iloc[:, 0:6]*100, color=violin_colour, scale='width', bw='scott')
    plt.title(parts[i])
    plt.ylim([-0.0001, 0.0011])