# Baseline classifier for zebrafish classification

The baseline classifier is a shallow convolutional network that serves as a comparison for the performance of the transfer learning classifier later on. The model consists of 4 strided convolutional layers followed by global average pooling, dropout, and a dense output layer.

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from skimage.transform import resize
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from itertools import islice
from sklearn import metrics
from collections import Counter
from pathlib import Path
import tifffile
from tqdm.auto import tqdm 
import seaborn as sns
sns.set_style('white')
from utils import create_dataset
# Seed TensorFlow's and NumPy's global random number generators (both with 42)
# so that repeated runs of the notebook are reproducible
tf.random.set_seed(42)
np.random.seed(42)

### Getting the zebrafish data

Download and unzip the training data (images and classes) into the subfolder `data`
```
wget "https://zenodo.org/record/6651752/files/data.zip?download=1" -O data.zip
unzip data.zip
```

This should result in the following folder structure:

```
data
├── fish_part_labels
│   ├── DAPT
│   ├── her1;her7
│   ├── tbx6_fss
│   └── WT
├── training
│   ├── DAPT
│   ├── her1;her7
│   ├── tbx6_fss
│   └── WT
└── validation
    ├── DAPT
    ├── her1;her7
    ├── tbx6_fss
    └── WT
```

In [None]:
# Folder that contains the unzipped dataset (see the directory tree above)
dataroot = Path("data")
# Image size forwarded to `create_dataset` as `target_size`
target_size = (450, 900)
# Number of epochs passed to `model.fit`
epochs = 50

### Loading the zebrafish data

In [None]:
# Training dataset (shuffle=True, augment=True) and validation dataset
# (shuffle=False, augment=False), both with batches of 16.
data_train = create_dataset(dataroot/"training",   target_size=target_size, batchsize=16, shuffle=True, augment=True)
data_val   = create_dataset(dataroot/"validation", target_size=target_size, batchsize=16, shuffle=False, augment=False)

print(data_train.class_name_to_id)
print(data_train.class_id_to_name)


def _plot_examples(images, labels, class_id_to_name, prefix, n):
    """Show the first ``2*n`` images of a batch in a 2 x n grid.

    Args:
        images: batch of images (indexable along the first axis).
        labels: batch of integer class ids, aligned with ``images``.
        class_id_to_name: mapping from class id to class name used for the
            titles; must be the mapping of the dataset the batch came from.
        prefix: text placed in front of ``" image - class = ..."`` in each title.
        n: number of columns of the grid.
    """
    plt.figure(figsize=(20,5))
    for i, (image, label) in enumerate(zip(images[:2*n], labels[:2*n])):
        plt.subplot(2, n, i+1)
        # clip to [0, 1] for display
        plt.imshow(np.clip(image, 0, 1))
        plt.title(f'{prefix} image - class = {class_id_to_name[int(label)]} ({int(label)})')


# plot some example training and validation images
# NOTE: `n`, `x` and `y` are kept at module level on purpose: later cells
# read `x.shape[1:]` (model input shape) and reuse `n`.
n = 5
x, y = next(iter(data_train.data))
_plot_examples(x, y, data_train.class_id_to_name, 'Training', n)

x, y = next(iter(data_val.data))
# validation labels are decoded with the validation dataset's own mapping
_plot_examples(x, y, data_val.class_id_to_name, 'validation', n)


### Building and compiling the model

In [None]:
# Shallow baseline CNN: four strided convolutions (stride 4 each), global
# average pooling, dropout and a 4-unit dense output layer.
model = tf.keras.Sequential()
# the input shape is taken from the batch `x` loaded in the previous cell (batch axis dropped)
model.add(tf.keras.layers.Conv2D(32,7,activation="relu", strides=4, padding="same", input_shape = x.shape[1:]))
for _ in range(3):
    model.add(tf.keras.layers.Conv2D(32,3,activation="relu", strides=4, padding="same"))
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dropout(0.2))
# the layer outputs raw logits: the loss below is built with from_logits=True,
# so no softmax activation is needed here
model.add(tf.keras.layers.Dense(4,))

model.summary()

learning_rate = 0.0005

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
          loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
          metrics=['sparse_categorical_accuracy'])

# preparing the weights for balanced training:
# weight = sqrt(mean class count / class count), so classes with fewer samples
# get a larger weight, with the square root damping the correction
counts = Counter(data_train.labels)
counts_mean = np.mean(tuple(counts.values()))
class_weights = {k: np.sqrt(counts_mean/v) for k, v in counts.items()}
print(f'class weights: {class_weights}')

# Halve the learning rate (down to min_lr) after 2 epochs without improvement.
# The monitored key must match the compiled metric name: the model logs
# `val_sparse_categorical_accuracy`, not `val_accuracy`.
# NOTE(review): the training cell currently passes only the checkpoint callback
# to `model.fit`; `reduce_lr` has no effect unless it is added to `callbacks` there.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_sparse_categorical_accuracy', factor=0.5,patience=2, min_lr=0.00025, verbose = 1)

checkpoint_folder = Path('checkpoints')
checkpoint_folder.mkdir(exist_ok=True, parents=True)

checkpoint_filepath = checkpoint_folder/"baseline_classifier.h5"

# keep only the weights of the epoch with the best validation loss
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='val_loss',
    save_best_only=True)

### Training the model

In [None]:
# Fit the model on the training data repeated 4 times per epoch, evaluating on
# the validation data and applying the class weights computed above.
# NOTE(review): only the checkpoint callback is passed here; `reduce_lr` is
# defined above but not used during training.
fit_callbacks = [model_checkpoint_callback]
history = model.fit(
    data_train.data.repeat(4),
    epochs=epochs,
    validation_data=data_val.data,
    class_weight=class_weights,
    callbacks=fit_callbacks,
)

# Store the per-epoch metric history (a dict) next to the checkpoint
history_path = checkpoint_folder/f"baseline_classifier_lr_{learning_rate:.4f}_epochs_{epochs}.npy"
np.save(history_path, history.history)

### Loading best weights

In [None]:
# Restore the weights written by the checkpoint callback, which keeps only the
# epoch with the best `val_loss` (save_best_only=True)
model.load_weights(checkpoint_filepath)

### Calculating the confusion matrix

In [None]:
# True class ids of the validation set and predicted class ids (argmax over the
# model outputs). `data_val` was created with shuffle=False.
# NOTE(review): this assumes `data_val.labels` is in the same order as the
# batches yielded by `data_val.data` - confirm in `create_dataset`.
y_true = np.array(data_val.labels)
y_pred = np.argmax(model.predict(data_val.data), -1)

def calculate_confusion_matrix(y_true, y_pred):
    """Plot the row-normalised confusion matrix and return the overall accuracy.

    Args:
        y_true: array of true integer class ids.
        y_pred: array of predicted integer class ids, aligned with ``y_true``.

    Returns:
        The fraction of samples whose predicted class equals the true class.

    Side effects:
        Shows a seaborn heatmap of the confusion matrix (rows: true class,
        columns: predicted class, each row normalised to sum to 1) and prints
        the accuracy.
    """
    # Fix the label set so the matrix is always 4 x 4 and lines up with the
    # class names, even if a class is missing from both y_true and y_pred.
    class_ids = list(range(4))
    classes = tuple(data_val.class_id_to_name[i] for i in class_ids)
    counts = metrics.confusion_matrix(y_true, y_pred, labels=class_ids) #rows - true, columns - predicted

    # Normalise each row by its number of true samples; rows without any
    # samples stay all-zero instead of producing NaN from a division by zero.
    row_totals = counts.sum(axis=-1, keepdims=True)
    matrix = counts / np.maximum(row_totals, 1)

    df_cm = pd.DataFrame(matrix, index=classes, columns=classes)
    sns.heatmap(df_cm, annot=True, annot_kws={"size": 16}, fmt='1.3f')# font size
    plt.ylabel("True class")
    plt.xlabel("Predicted class")
    plt.show()

    # Accuracy is taken from the raw counts. Computing it on the row-normalised
    # matrix would give the mean per-class recall, not the accuracy.
    accuracy = np.trace(counts) / max(counts.sum(), 1)
    print(f'Accuracy: {accuracy:.4f}')
    return accuracy

accuracy = calculate_confusion_matrix(y_true, y_pred)

# Apply the model to new images from a different microscope (`kings` subset)

Put the `kings` images (in bmp format) in the subfolder `data/kings`, resulting in the following folder structure:

```
data
├── fish_part_labels
│   ├── DAPT
│   ├── her1;her7
│   ├── tbx6_fss
│   └── WT
├── kings
│   ├── DAPT
│   └── WT
├── training
│   ├── DAPT
│   ├── her1;her7
│   ├── tbx6_fss
│   └── WT
└── validation
    ├── DAPT
    ├── her1;her7
    ├── tbx6_fss
    └── WT
```

In [None]:
# Dataset built from the `kings` subfolder, without augmentation or shuffling
data_new = create_dataset(dataroot/'kings', augment=False, shuffle=False, target_size=target_size)

# Show the first 2*n images of the first batch (`n` is set in the data-loading cell)
x,y = next(iter(data_new.data))
plt.figure(figsize=(20,5))
for i, (_x, _y) in enumerate(zip(x[:2*n],y[:2*n])):
    plt.subplot(2,n, i+1)
    plt.imshow(np.clip(_x, 0,1))
    # Labels of `data_new` are decoded with `data_new`'s own id -> name mapping.
    # This subset has fewer class folders than the training set, so the
    # training mapping is not guaranteed to agree.
    plt.title(f'New image - class = {data_new.class_id_to_name[int(_y)]} ({int(_y)})')


In [None]:
# True class names: the new dataset's labels decoded with its own mapping
true_names = [data_new.class_id_to_name[c] for c in data_new.labels]
y_true = np.array(true_names)

# Predicted class names: argmax over the model outputs, decoded with the
# training mapping
pred_ids = np.argmax(model.predict(data_new.data), -1)
pred_names = [data_train.class_id_to_name[c] for c in pred_ids]
y_pred = np.array(pred_names)

# Fraction of images whose predicted class name equals the true class name
accuracy = np.mean(y_true == y_pred)
print(f'Accuracy: {accuracy:.5f}')

In [None]:
# Output folder for the annotated prediction images
save_folder = Path("test_predictions")
save_folder.mkdir(exist_ok=True, parents=True)

samples = zip(data_new.images, data_new.filenames, y_true, y_pred)
for i, (x, fname, y1, y2) in enumerate(samples):
    column = i % 4
    if column == 0:
        # start a new 1 x 4 figure every four images
        plt.figure(figsize=(20,4))
    ax = plt.subplot(1, 4, column + 1)
    ax.imshow(x)

    # title: green if the prediction matches the true class, red otherwise
    correct = y1 == y2
    verdict = "correct" if correct else "wrong"
    ax.set_title(f'{verdict} ({y1} -> {y2})', fontsize=10)
    ax.title.set_color('green' if correct else 'red')
    ax.axis('off')

    # savefig writes the current figure, i.e. all panels drawn into it so far,
    # under a file name derived from this image
    out_name = f'{Path(fname).name}_{y1}_{verdict}_as_{y2}.jpg'
    plt.savefig(save_folder/out_name, dpi = 200, transparent=True)