# Introduction

Plankton are microorganisms that form the base of the aquatic food chain. Photosynthetic plankton also play a crucial role in the global carbon cycle, as well as producing much of the world's oxygen. Because of their importance to the marine ecosystem, monitoring the health of the plankton population is a key component of oceanic research. One method to monitor plankton populations over time is through in-situ imaging of individual planktonic organisms. In-situ monitoring allows for the collection of not just abundance data but also morphological information on various species of plankton. 

The Martha's Vineyard Coastal Observatory (MVCO) runs an imaging flow cytometer, the MVCO FlowCytobot, that operates continuously and photographs individual plankton cells in the water column. The large volume of image data this instrument generates requires automated methods for classifying plankton images. The Woods Hole Oceanographic Institution (WHOI), which operates the FlowCytobot, has released several datasets of expert-labeled plankton images for training machine learning models to classify plankton by species. The main dataset, WHOI-plankton, contains more than 3.5 million labeled images of plankton from 103 classes. Unfortunately, the dataset is highly imbalanced, with 6 classes comprising up to 85% of the whole dataset. Imbalanced datasets can hamper the performance of machine learning models because the rarer classes contribute too few training examples for the model to learn them well. One solution is to augment the images of the rarer classes to balance the dataset.
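
To make the balancing idea concrete, here is a minimal sketch that counts how many augmented images each class would need to match the largest class. The class names and counts are made-up placeholders, not figures from WHOI-plankton, and `augmentation_targets` is a hypothetical helper written for this illustration.

```python
from collections import Counter


def augmentation_targets(labels):
    """Return, per class, how many extra images are needed to match the largest class."""
    counts = Counter(labels)
    largest = max(counts.values())
    return {name: largest - n for name, n in counts.items()}


# Hypothetical label list: names and counts are placeholders for illustration only.
labels = ["class_a"] * 850 + ["class_b"] * 100 + ["class_c"] * 50

counts = Counter(labels)
share_of_largest = max(counts.values()) / len(labels)
targets = augmentation_targets(labels)

print(f"largest class share: {share_of_largest:.0%}")
print(targets)
```

Matching the largest class is only one possible target; for a heavily skewed dataset it may be more practical to downsample the dominant classes as well.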

# Method 

In this notebook, I explored the use of a generative adversarial network (GAN) to generate synthetic images of plankton from the WHOI-plankton dataset. The synthetic images can be used to augment the training data for the rarer classes in the dataset. The GAN model was trained on WHOI22, a balanced subset of the WHOI-plankton dataset containing 22 classes and 6598 images. In addition, I used traditional image augmentation techniques to gradually replace the images of the WHOI22 dataset with augmented versions and evaluate the impact of image augmentation on transfer learning performance.

# Results

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import keras
from keras import layers
from keras import ops
import tensorflow as tf
import datetime
import pickle

## Image examples

Here are some examples of the images in the WHOI22 dataset:

In [None]:
image_size = (256, 256)
batch_size = 32
def load_data(batch_size=batch_size):
    train_data = keras.utils.image_dataset_from_directory("datasets/padded_training",
                                                            labels="inferred",
                                                            label_mode="categorical",
                                                            image_size=image_size,
                                                            batch_size=batch_size)
    test_data = keras.utils.image_dataset_from_directory("datasets/padded_testing",
                                                            labels="inferred",
                                                            label_mode="categorical",
                                                            image_size=image_size,
                                                            batch_size=batch_size)
    return train_data, test_data
train_data, test_data = load_data()
combined_data = train_data.concatenate(test_data)

In [None]:
plt.figure(figsize=(10, 10))
for images, labels in combined_data.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(np.array(images[i]).astype("uint8"))
        # The concatenated dataset has no `class_names` attribute, so read it from train_data
        plt.title(train_data.class_names[int(tf.argmax(labels[i]))])
        plt.axis("off")

## Baseline Model

To benchmark the performance of training on the balanced WHOI22 dataset, I used an EfficientNetB7 model pre-trained on the ImageNet dataset and performed transfer learning. The code is based on the Keras guide to transfer learning: https://keras.io/guides/transfer_learning/

In [None]:
from keras.applications.efficientnet import EfficientNetB7

base_model = EfficientNetB7(
    weights='imagenet',
    input_shape=image_size + (3,),
    include_top=False
)
base_model.trainable = False
inputs = keras.Input(shape=(256, 256, 3))

x = base_model(inputs, training=False)
# Convert features of shape `base_model.output_shape[1:]` to vectors
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(22)(x)
model = keras.Model(inputs, outputs)
model.compile(optimizer=keras.optimizers.Adam(),
              loss=keras.losses.CategoricalCrossentropy(from_logits=True),
              metrics=[keras.metrics.CategoricalAccuracy()])
model.summary()

In [None]:
%%script false --no-raise-error
## Training the model

This baseline model achieved an overall accuracy of ***TBD***, with average precision of ***TBD*** and average recall of ***TBD***.
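
As a rough sketch of how these three numbers could be computed, the example below derives accuracy and the unweighted (macro) average of per-class precision and recall from integer class ids. It assumes "average" means macro-averaging, which the notebook does not state, and it assumes the ids come from taking the argmax of the model's logits and of the one-hot labels. `macro_precision_recall` is a hypothetical helper and the toy labels are not results from this notebook.

```python
def macro_precision_recall(y_true, y_pred, num_classes):
    """Compute accuracy plus unweighted (macro) average precision and recall."""
    tp = [0] * num_classes  # true positives per class
    fp = [0] * num_classes  # false positives per class
    fn = [0] * num_classes  # false negatives per class
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1
            fn[t] += 1
    # A class with no predictions (or no true samples) contributes 0 instead of raising
    precisions = [tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0
                  for c in range(num_classes)]
    recalls = [tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0
               for c in range(num_classes)]
    accuracy = sum(tp) / len(y_true)
    return accuracy, sum(precisions) / num_classes, sum(recalls) / num_classes


# Toy labels for three classes; these are not results from the notebook.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
accuracy, precision, recall = macro_precision_recall(y_true, y_pred, num_classes=3)
print(f"accuracy={accuracy:.3f} precision={precision:.3f} recall={recall:.3f}")
```

On a balanced dataset such as WHOI22, macro and sample-weighted averages should be close, so the choice matters less here than it would on the full WHOI-plankton dataset.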

In [None]:
## Load saved model

## Load predictions

## Evaluate the model

## Image Augmentation

I created a function that augments a chosen proportion of the images in the WHOI22 dataset and replaces the originals with the augmented versions. I then evaluated the impact of image augmentation on the performance of the EfficientNetB7 model. Below is the code I used to create the new augmented datasets, based on the TensorFlow tutorial on data augmentation (https://www.tensorflow.org/tutorials/images/data_augmentation).
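
The replacement idea can be sketched independently of the notebook's own function, which is not shown here. The example below swaps a random proportion of a batch for flipped and rotated copies using plain NumPy; it stands in for the Keras preprocessing layers the tutorial uses, and `replace_with_augmented`, its parameters, and the random stand-in batch are all assumptions made for illustration.

```python
import numpy as np


def replace_with_augmented(images, proportion, seed=0):
    """Return a copy of `images` with a random `proportion` replaced by augmented versions."""
    rng = np.random.default_rng(seed)
    out = images.copy()
    n_replace = int(round(proportion * len(images)))
    # Pick which images to replace, without repeats
    chosen = rng.choice(len(images), size=n_replace, replace=False)
    for i in chosen:
        img = images[i]
        if rng.random() < 0.5:
            img = np.fliplr(img)  # horizontal flip
        # Rotate by 0, 90, 180 or 270 degrees (k=0 with no flip leaves the image unchanged)
        img = np.rot90(img, k=int(rng.integers(0, 4)))
        out[i] = img
    return out, np.sort(chosen)


# Random stand-in for a batch of eight 256x256 RGB images.
batch = np.random.default_rng(1).integers(0, 256, size=(8, 256, 256, 3), dtype=np.uint8)
augmented, replaced = replace_with_augmented(batch, proportion=0.5)
print("replaced indices:", replaced)
```

Because the padded images are square, 90-degree rotations keep the 256x256 shape, so the augmented copies can be written back in place of the originals.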

In [None]:
## Code to augment dataset 

In [None]:
# Display examples of augmented data

# Appendix of methods

## Data preprocessing

Images were downloaded from Olson & Sosik 2007 (https://doi.org/10.4319/lom.2007.5.195) in TIFF format. I padded each image with black pixels to a square 256x256 canvas, downscaling first when an image was larger than 256 pixels in either dimension, and saved the result as BMP.
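
The padding arithmetic is easy to check by hand. The sketch below mirrors the resize-then-pad logic on image dimensions only, with no pixel data; `padded_layout` is a hypothetical helper and the example dimensions are arbitrary.

```python
def padded_layout(height, width, size=256):
    """Return the resized (height, width) and the (top, bottom, left, right) padding."""
    # Downscale only when the image does not fit, keeping the aspect ratio
    if height > size or width > size:
        if height > width:
            height, width = size, int(width * (size / height))
        else:
            height, width = int(height * (size / width)), size
    pad_h, pad_w = size - height, size - width
    # Any odd leftover pixel goes to the bottom or right edge
    return (height, width), (pad_h // 2, pad_h - pad_h // 2,
                             pad_w // 2, pad_w - pad_w // 2)


small = padded_layout(100, 181)  # fits already, so it is only padded
tall = padded_layout(512, 128)   # too tall, so it is halved before padding
print(small)
print(tall)
```

In both cases the resized dimensions plus the padding add up to 256 on each axis, which is the invariant the padded dataset relies on.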

In [None]:
%%script false --no-raise-error
import tifffile as tif  # assumption: `tif` refers to the tifffile package
from PIL import Image

# Function that takes a TIFF image path, pads it to the desired square size, and returns a numpy array
def pad_image(image, size):
    # load the image
    img = tif.imread(image)
    # get image dimensions
    x,y = img.shape
    # calculate the padding
    x_pad = size - x
    y_pad = size - y
    # check if padding is needed, if not, resize the image maintaining aspect ratio
    if x_pad < 0 or y_pad < 0:
        if x > y:
            new_x = size
            new_y = int(y * (size/x))
        else:
            new_y = size
            new_x = int(x * (size/y))
        img = np.array(Image.fromarray(img).resize((new_y, new_x)))
        x,y = img.shape
        x_pad = size - x
        y_pad = size - y
    # pad the image
    if x_pad > 0 or y_pad > 0:
        x_pad1, x_pad2 = x_pad//2, x_pad-x_pad//2
        y_pad1, y_pad2 = y_pad//2, y_pad-y_pad//2
        img = np.pad(img, ((x_pad1, x_pad2), (y_pad1, y_pad2)), 'constant')
    return img

# Function that takes a folder and pads each image in the folder. Inputs are the input path, output path.
def pad_folder(input_path, output_path, size):
    # make the output path
    os.makedirs(output_path, exist_ok=True)
    subfolders = [f.path for f in os.scandir(input_path) if f.is_dir()]
    for folder in subfolders:
        # use base folder name as class
        class_name = os.path.basename(folder)
        # make the output class folder
        os.makedirs(os.path.join(output_path, class_name), exist_ok=True)
        # get all the images in the folder
        images = [f.path for f in os.scandir(folder) if f.is_file()]
        for img in images:
            # input image path
            img_in_path = img
            # output image path
            img_out_path = os.path.join(output_path, class_name, os.path.basename(img))
            # swap the extension for .bmp (also handles .tiff and dots elsewhere in the path)
            img_out_path = os.path.splitext(img_out_path)[0] + '.bmp'
            # print statement
            print(f"Padding {img_in_path} to {img_out_path}")
            # pad the image and save it
            try:
                img_output = pad_image(image=img_in_path, size=size)
                # write img as bmp
                im = Image.fromarray(img_output)
                im.save(img_out_path)
            except Exception as e:
                print("failed to pad image", img_in_path, e)
pad_folder(input_path='datasets/testing', output_path="datasets/padded_testing", size=256)
pad_folder(input_path='datasets/training', output_path="datasets/padded_training", size=256)