## This notebook trains the "baseline" classifier for zebrafish embryo mutants. The baseline is a shallow convolutional classifier that serves as a point of comparison for the performance of the transfer-learning classifier. The model contains 4 convolutional layers, with a max-pooling layer between the second and third, followed by global average pooling and a dense output layer.

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten
from keras.preprocessing.image import ImageDataGenerator
from skimage.transform import resize
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn import metrics
import glob
import os
from PIL import Image
import cv2
import matplotlib.image as mpimg
import shutil
from collections import Counter
from pathlib import Path
import albumentations
from ImageDataAugmentor.image_data_augmentor import *
import random
#check if GPU is visible
#from tensorflow.python.client import device_lib
#print(device_lib.list_local_devices())

### Loading the zebrafish data

In [None]:
# Training and validation images; flow_from_directory below derives the class
# labels from the sub-folder names inside each of these folders.
path_train = Path("training")
path_val = Path("validation")

# Training-time augmentation: hue/saturation jitter (value channel unchanged),
# brightness/contrast changes, small shift/scale/rotation, and random
# horizontal/vertical flips.
AUGMENTATIONS = albumentations.Compose([albumentations.HueSaturationValue(hue_shift_limit=30, sat_shift_limit=10, val_shift_limit=0, p=1),
                                        albumentations.RandomBrightnessContrast(brightness_limit = (-0.3,0.3)),
                                        albumentations.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45),
                                        albumentations.HorizontalFlip(),
                                        albumentations.VerticalFlip()])

# Both generators scale pixel values by 1/255; only the training one augments.
train_datagen_augmented = ImageDataAugmentor(rescale = 1.0/255., augment = AUGMENTATIONS, preprocess_input=None)
val_datagen = ImageDataAugmentor( rescale = 1.0/255., preprocess_input=None) #no augmentation

# class_mode='sparse' yields integer labels, matching the
# SparseCategoricalCrossentropy loss used when the model is compiled.
# Images are resized to target_size (height, width) = (450, 900).
data_train = train_datagen_augmented.flow_from_directory(path_train, batch_size = 16, class_mode = 'sparse', target_size = (450, 900))
data_test = val_datagen.flow_from_directory(path_val,  batch_size = 16, class_mode = 'sparse', target_size = (450, 900))

# Fetch one training batch; the image shape x.shape[1:] is used as the model's input_shape.
x,y = data_train.next()

In [None]:
# Display the mapping from class (sub-folder) name to integer label.
data_train.class_indices

### Building and compiling the model

In [None]:
# Baseline CNN: two 3x3 conv layers (32 filters), 4x4 max-pooling, two more
# 3x3 conv layers, then global average pooling (instead of Flatten) and
# dropout before the output layer.
num_classes = len(data_train.class_indices)  # one output unit per class sub-folder

model = tf.keras.Sequential()
model.add(tf.keras.layers.Conv2D(32,3,activation="relu", padding="same", input_shape = x.shape[1:]))
model.add(tf.keras.layers.Conv2D(32,3,activation="relu", padding="same"))
model.add(tf.keras.layers.MaxPooling2D((4,4)))
model.add(tf.keras.layers.Conv2D(32,3,activation="relu", padding="same"))
model.add(tf.keras.layers.Conv2D(32,3,activation="relu", padding="same"))
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(num_classes)) #raw logits: the loss uses from_logits=True, so no softmax here

model.summary()
epochs = 200

rate = 0.001

# `learning_rate` is the supported keyword; `lr` is a deprecated alias
# removed in newer Keras versions.
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=rate),
          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
          metrics=['accuracy'])

# Class weights inversely proportional to class frequency, so every class
# contributes the same total weight to the loss (balanced training).
counts = Counter(data_train.classes)
counts_total = sum(counts.values())
class_weights = {label: counts_total / n for label, n in counts.items()}

# Halve the learning rate (down to min_lr) when val_accuracy stops improving for 2 epochs.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy', factor=0.5,patience=2, min_lr=0.00025, verbose = 1)

checkpoint_folder = Path('checkpoints')
# The training history is also saved into this folder after fit(), so make sure it exists.
checkpoint_folder.mkdir(parents=True, exist_ok=True)
checkpoint_filepath = checkpoint_folder/"baseline_classifier"

# Keep only the weights of the epoch with the highest validation accuracy.
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_accuracy',
    mode='max',
    save_best_only=True)

### Training the model

In [None]:
# Train with class weights to compensate for class imbalance. The checkpoint
# callback keeps the best weights by val_accuracy; reduce_lr lowers the
# learning rate when val_accuracy plateaus.
history = model.fit(data_train,
                epochs=epochs,
                validation_data = data_test,
                class_weight=class_weights, callbacks=[model_checkpoint_callback, reduce_lr])
# history.history is a dict, so np.save stores it as a pickled object array;
# reload it with np.load(path, allow_pickle=True).item().
np.save(checkpoint_folder/f"baseline_classifier_lr_{rate:.4f}_epochs_{epochs}.npy",history.history)

### Loading best weights

In [None]:
# Restore the weights written by model_checkpoint_callback, which (with
# save_best_only=True) holds the epoch with the best val_accuracy.
model.load_weights(checkpoint_filepath)

### Calculating the confusion matrix

In [None]:
# Re-create the validation iterator so the pass starts at the first batch;
# shuffle=False makes the evaluation order reproducible.
data_test = val_datagen.flow_from_directory(path_val,  batch_size = 16, class_mode = 'sparse', target_size = (450, 900), shuffle = False)

true = list() #list of true labels (ints)
predicted = list() #list of predicted labels (ints)

# len(data_test) batches cover every validation image exactly once.
for _ in range(len(data_test)):
    x, y = data_test.next()
    # Predict the whole batch in one call rather than one image per predict().
    logits = model.predict(x, verbose=0)
    true.extend(int(label) for label in y)
    predicted.extend(int(label) for label in np.argmax(logits, axis=-1))

# Class names ordered by their integer label (classes[k] is label k).
classes = sorted(data_test.class_indices, key=data_test.class_indices.get)
    
def calculate_confusion_matrix(classes, true, predicted):
    """Plot the row-normalized confusion matrix and return the mean per-class recall.

    Args:
        classes: class names ordered by integer label (``classes[k]`` is label ``k``).
        true: true integer labels, one per sample.
        predicted: predicted integer labels, one per sample, in the same order as ``true``.

    Returns:
        The sum of the diagonal divided by the sum of the row-normalized matrix,
        i.e. the recall averaged over classes that have at least one true sample
        (balanced accuracy). ``nan`` if there are no samples at all.
    """
    n_classes = len(classes)
    # Fix the label set so the matrix is always n_classes x n_classes, even when a
    # class never occurs in `true` or `predicted`. Rows - true, columns - predicted.
    matrix = metrics.confusion_matrix(true, predicted, labels=list(range(n_classes))).astype(float)

    # Scale each row (true label) to sum to 1; rows without any samples stay
    # all-zero instead of becoming NaN through 0/0.
    row_sums = matrix.sum(axis=1, keepdims=True)
    matrix = np.divide(matrix, row_sums, out=np.zeros_like(matrix), where=row_sums > 0)

    df_cm = pd.DataFrame(matrix, index=list(classes), columns=list(classes))
    sns.set(font_scale=1.4) # for label size
    sns.heatmap(df_cm, annot=True, annot_kws={"size": 16}, fmt='1.3f') # annotation font size
    plt.ylabel("True class")
    plt.xlabel("Predicted class")
    plt.show()

    # Every non-empty row sums to 1, so trace / total is the mean per-class recall.
    total = matrix.sum()
    accuracy = np.trace(matrix) / total if total > 0 else float("nan")
    print(accuracy)
    return accuracy

# Plot the validation confusion matrix (rows scaled per true class). The returned
# value is the diagonal sum over the total of that row-normalized matrix, i.e.
# the per-class recall averaged over classes.
accuracy = calculate_confusion_matrix(classes, true, predicted)

### Running predictions on the test (never-seen) dataset

In [None]:
# Run the trained model on the held-out test set and save one figure per image.
# The true class, the outcome and the predicted class are encoded in the file name.
path_test = Path("test") #Folder with the test dataset (kept separate from path_val, the validation folder)
batch_size = 16
test_datagen = ImageDataGenerator(rescale = 1.0/255.) #no augmentation
# shuffle=False keeps the batch order aligned with test_generator.filenames,
# which is used below to recover each image's file name.
test_generator = test_datagen.flow_from_directory(path_test,  batch_size = batch_size, class_mode = 'categorical', target_size = (450,900), shuffle = False)

save_folder = Path("test_predictions") #Folder where images will be saved
save_folder.mkdir(parents=True, exist_ok=True)

# Class names ordered by their integer label, as assigned for the training set.
classes = sorted(data_train.class_indices, key=data_train.class_indices.get)

filenames = test_generator.filenames

for i in range(len(test_generator)):
    x, y = test_generator.next()
    # Predict the whole batch in one call instead of one image at a time.
    yhat_batch = model.predict(x, verbose=0).argmax(axis=-1)
    for j, (image, y_coded, yhat) in enumerate(zip(x, y, yhat_batch)):
        y_true = y_coded.argmax()  # labels are one-hot (class_mode='categorical')

        # filenames are relative to path_test: "<class folder>/<image name>".
        # pathlib splits on the OS separator instead of assuming Windows '\\'.
        rel_path = Path(filenames[i*batch_size+j])
        outcome = "correct_as" if yhat == y_true else "wrong_as"
        out_path = str(save_folder/rel_path.parts[0]) + outcome + classes[yhat] + "_" + rel_path.name

        fig = plt.figure()
        plt.imshow(image)  # without this the saved figure is blank
        plt.savefig(out_path, dpi = fig.dpi, transparent=True)
        plt.close(fig)  # otherwise one figure per test image stays open