# Baseline classifier for zebrafish classification

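The baseline classifier is a shallow convolutional network that serves as a point of comparison for the performance of the transfer-learning classifier later on. The model contains four convolutional layers, with a single max-pooling layer after the second one, followed by global average pooling, dropout and a dense output layer.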

In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Conv2D, MaxPooling2D, Flatten
from keras.preprocessing.image import ImageDataGenerator
from skimage.transform import resize
import seaborn as sns
sns.set_style('white')
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from itertools import islice
from sklearn import metrics
import glob
import os
from PIL import Image
import cv2
import matplotlib.image as mpimg
import shutil
from collections import Counter
from pathlib import Path
import albumentations
#check if GPU is visible
#from tensorflow.python.client import device_lib
#print(device_lib.list_local_devices())

### Getting the zebrafish data

Download and unzip the training data (images and classes) into the subfolder `data`
```
wget "https://zenodo.org/record/6651752/files/data.zip?download=1" -O data.zip
unzip data.zip
```

This should result in the following directory layout:

```
data
├── fish_part_labels
│   ├── DAPT
│   ├── her1;her7
│   ├── tbx6_fss
│   └── WT
├── training
│   ├── DAPT
│   ├── her1;her7
│   ├── tbx6_fss
│   └── WT
└── validation
    ├── DAPT
    ├── her1;her7
    ├── tbx6_fss
    └── WT
```

### Loading the zebrafish data

In [3]:
# Root folder containing the `training` and `validation` (and optionally `kings`) subfolders.
dataroot = Path('data')
# Size every image is resized to by flow_from_directory, given as (height, width).
target_size = (450, 900)

# Seed Python's `random`, NumPy and TensorFlow in one call so shuffling, augmentation
# and weight initialisation are repeatable across Restart & Run All
# (GPU kernels may still introduce some nondeterminism).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

In [None]:
# Random augmentations applied to each training image (validation images are only rescaled).
# Images are scaled to [0, 1] floats before augmenting, so the Affine fill value (1, 1, 1) pads with white.
_augmentations = [
    albumentations.HorizontalFlip(p=.5),
    albumentations.VerticalFlip(p=.5),
    albumentations.GaussianBlur(p=.3),
    albumentations.Affine(scale=(0.8, 1.2), shear=0, rotate=(-10, 10), cval=(1, 1, 1), p=.5),
    albumentations.RandomBrightnessContrast(brightness_limit=.1, contrast_limit=.1, p=.5),
    albumentations.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=.1, val_shift_limit=0, p=.5),
]
transform = albumentations.Compose(_augmentations)


def preprocessing_function(x):
    """Rescale a single image to [0, 1] and apply the random `transform` pipeline.

    ImageDataGenerator calls this once per image (a rank-3 array) and expects an
    array of the same shape back.
    """
    scaled = x / 255
    return transform(image=scaled)['image']

# Training images get the random augmentations; validation images are only rescaled to [0, 1].
path_train = dataroot / "training"
path_val = dataroot / "validation"

train_datagen_augmented = ImageDataGenerator(preprocessing_function=preprocessing_function)
val_datagen = ImageDataGenerator(rescale=1.0 / 255.)

# Both iterators yield integer ("sparse") labels in batches of 16, resized to target_size.
flow_kwargs = dict(batch_size=16, class_mode='sparse', target_size=target_size)
data_train = train_datagen_augmented.flow_from_directory(path_train, **flow_kwargs)
data_test = val_datagen.flow_from_directory(path_val, **flow_kwargs)

# Label lookup tables in both directions, derived from the training folder names.
class_name_to_id = dict(data_train.class_indices)
class_id_to_name = {idx: name for name, idx in data_train.class_indices.items()}

print(class_name_to_id)
print(class_id_to_name)

# Pull one augmented training batch and show the first w*h images with their labels.
x, y = next(data_train)

w, h = 5, 3  # grid columns, grid rows
plt.figure(figsize=(20, 10))
for panel, (image, label) in enumerate(zip(x[:w * h], y[:w * h]), start=1):
    plt.subplot(h, w, panel)
    # Clip in case augmentation left values outside [0, 1], the range imshow expects for float RGB.
    plt.imshow(np.clip(image, 0, 1))
    class_id = int(label)
    plt.title(f'class = {class_id_to_name[class_id]} ({class_id})')

### Building and compiling the model

In [None]:
# Shallow baseline CNN: two 3x3 conv layers at full resolution, 4x4 max pooling,
# two more 3x3 conv layers, then global average pooling, dropout and a linear head.
model = tf.keras.Sequential()
# Take the input shape (height, width, channels) from the training iterator rather than
# from a batch drawn in an earlier cell, so this cell does not depend on hidden state.
model.add(tf.keras.layers.Conv2D(32, 3, activation="relu", padding="same",
                                 input_shape=data_train.image_shape))
model.add(tf.keras.layers.Conv2D(32, 3, activation="relu", padding="same"))
model.add(tf.keras.layers.MaxPooling2D((4, 4)))
model.add(tf.keras.layers.Conv2D(32, 3, activation="relu", padding="same"))
model.add(tf.keras.layers.Conv2D(32, 3, activation="relu", padding="same"))
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dropout(0.2))
# One logit per class found in the training folder; the loss uses from_logits=True,
# so no softmax is needed here.
model.add(tf.keras.layers.Dense(data_train.num_classes))

model.summary()
# Training hyper-parameters.
epochs = 200
learning_rate = 0.001

# Integer labels + a logits output layer -> sparse categorical cross-entropy from logits.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['sparse_categorical_accuracy'])

# Class weights for balanced training: each class gets sqrt(mean_count / class_count),
# so under-represented classes are up-weighted, damped by the square root.
counts = Counter(data_train.classes)
counts_mean = np.mean(list(counts.values()))
class_weights = {label: np.sqrt(counts_mean / n_images) for label, n_images in counts.items()}
print(f'class weights: {class_weights}')

# Halve the learning rate (not below 2.5e-4) when validation accuracy stalls for 2 epochs.
# The compiled metric is 'sparse_categorical_accuracy', which Keras logs for validation as
# 'val_sparse_categorical_accuracy'; monitoring 'val_accuracy' would never find that key.
# NOTE(review): this callback is not passed to model.fit in the training cell; add it to
# `callbacks` there if learning-rate scheduling is intended.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_sparse_categorical_accuracy', mode='max',
                                                 factor=0.5, patience=2, min_lr=0.00025, verbose=1)

# Keep only the weights from the epoch with the lowest validation loss.
checkpoint_folder = Path('checkpoints')
checkpoint_folder.mkdir(parents=True, exist_ok=True)
checkpoint_filepath = checkpoint_folder / "baseline_classifier.h5"

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_filepath,
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=True,
)

### Training the model

In [None]:
# Train with class weighting; the checkpoint callback stores the best (lowest val_loss) weights.
history = model.fit(
    data_train,
    validation_data=data_test,
    epochs=epochs,
    class_weight=class_weights,
    callbacks=[model_checkpoint_callback],
)
# Persist the per-epoch metrics dict. np.save stores it as a pickled object array,
# so reload it with np.load(path, allow_pickle=True).item().
history_path = checkpoint_folder / f"baseline_classifier_lr_{learning_rate:.4f}_epochs_{epochs}.npy"
np.save(history_path, history.history)

### Loading best weights

In [None]:
# Restore the weights saved by the checkpoint callback (lowest validation loss).
model.load_weights(filepath=checkpoint_filepath)

### Calculating the confusion matrix

In [None]:
# Rebuild the validation iterator without shuffling so predictions come out in the same
# order as `data_test.classes`. Passing the training class order explicitly guarantees the
# label ids match the model's output indices, even if the validation folder lacked a class.
data_test = val_datagen.flow_from_directory(
    path_val,
    batch_size=16,
    class_mode='sparse',
    shuffle=False,
    target_size=target_size,
    classes=sorted(class_name_to_id, key=class_name_to_id.get),
)

y_true = data_test.classes
y_pred = np.argmax(model.predict(data_test), -1)

def calculate_confusion_matrix(y_true, y_pred, class_names=None):
    """Plot a row-normalised confusion matrix and return the overall accuracy.

    Parameters
    ----------
    y_true, y_pred : array-like of int
        True and predicted class ids (indices into the model output).
    class_names : sequence of str, optional
        Display name for each class id, in id order. Defaults to the training
        label names from ``class_id_to_name``.

    Returns
    -------
    float
        Fraction of samples whose predicted class equals the true class.
    """
    if class_names is None:
        class_names = [class_id_to_name[i] for i in range(len(class_id_to_name))]
    class_names = tuple(class_names)
    labels = list(range(len(class_names)))

    # Rows = true class, columns = predicted class. Passing `labels` keeps the matrix
    # square and aligned with `class_names` even if some class never occurs.
    counts = metrics.confusion_matrix(y_true, y_pred, labels=labels)

    # Normalise each row to per-class recall; rows without samples stay 0 instead of NaN.
    row_totals = counts.sum(axis=-1, keepdims=True)
    matrix = np.divide(counts, row_totals, out=np.zeros(counts.shape, dtype=float),
                       where=row_totals > 0)

    df_cm = pd.DataFrame(matrix, index=class_names, columns=class_names)
    # Scale fonts for this figure only, instead of resetting the global seaborn theme
    # (which would also restyle every later plot).
    with sns.plotting_context("notebook", font_scale=1.4):
        sns.heatmap(df_cm, annot=True, annot_kws={"size": 16}, fmt='1.3f')
        plt.ylabel("True class")
        plt.xlabel("Predicted class")
        plt.show()

    # Accuracy must come from the raw counts: on the row-normalised matrix, trace / sum
    # is the mean per-class recall (balanced accuracy), not the accuracy.
    accuracy = np.trace(counts) / np.sum(counts)
    present = row_totals[:, 0] > 0
    balanced_accuracy = np.mean(np.diag(matrix)[present])
    print(f'Accuracy: {accuracy:.4f}')
    print(f'Balanced accuracy (mean per-class recall): {balanced_accuracy:.4f}')
    return accuracy

accuracy = calculate_confusion_matrix(y_true, y_pred)

# Apply the model to new images from a different microscope (`kings` subset)

Put the `kings` images (in BMP format) into the subfolder `data/kings`, with one subfolder per class, resulting in:

```
data
├── fish_part_labels
│   ├── DAPT
│   ├── her1;her7
│   ├── tbx6_fss
│   └── WT
├── kings
│   ├── DAPT
│   └── WT
├── training
│   ├── DAPT
│   ├── her1;her7
│   ├── tbx6_fss
│   └── WT
└── validation
    ├── DAPT
    ├── her1;her7
    ├── tbx6_fss
    └── WT
```

In [None]:
# Evaluate on the `kings` images from a different microscope. Their folder contains only
# some of the classes, so labels are compared by class *name*, not by index.
path_new = dataroot / 'kings'
save_folder = Path("test_predictions")
save_folder.mkdir(parents=True, exist_ok=True)

new_datagen = ImageDataGenerator(rescale=1.0 / 255.)
# batch_size=1 and no shuffling: one image per batch, in the order of data_new.filenames.
data_new = new_datagen.flow_from_directory(path_new, batch_size=1, shuffle=False,
                                           class_mode='sparse', target_size=target_size)
class_id_to_name_test = {idx: name for name, idx in data_new.class_indices.items()}

# Keep every single-image batch for the plotting cell.
img = tuple(batch for batch, _labels in islice(data_new, len(data_new)))
true_ids = data_new.classes
pred_ids = np.argmax(model.predict(data_new), -1)
y_true = np.array(tuple(class_id_to_name_test[c] for c in true_ids))
y_pred = np.array(tuple(class_id_to_name[c] for c in pred_ids))

accuracy = np.mean(y_true == y_pred)
print(f'Accuracy: {accuracy:.5f}')

In [None]:
# Save one annotated figure per `kings` image; the file name records whether it was
# classified correctly and as which class.
for batch, fname, y1, y2 in zip(img, data_new.filenames, y_true, y_pred):
    status = "correct" if y1 == y2 else "wrong"
    fig, ax = plt.subplots()
    ax.imshow(batch[0])
    ax.set_title(f'{status} ({y1} -> {y2})', fontsize=10)
    # JPEG has no alpha channel, so transparent=True cannot yield a transparent file.
    fig.savefig(save_folder / f'{Path(fname).name}_{y1}_{status}_as_{y2}.jpg', dpi=200, transparent=True)
    # Show inline, then close: leaving one open figure per image grows memory with the
    # number of images and triggers matplotlib's too-many-open-figures warning.
    plt.show()
    plt.close(fig)