# Full Implementation of VGG-11 U-Net Segmentation Model
Training, validation, and testing of a pre-trained VGG-11 U-Net on *Proteus mirabilis* pLac-*flgM* colony images for segmenting ring boundaries where changes in environmental conditions occurred.

**References/Acknowledgments:** This implementation follows the ring boundary segmentation model we previously presented in Doshi & Shaw et al. 2022<sup>1,2</sup>. Whereas the task of the initial work was to segment all ring boundaries within a *P. mirabilis* colony image, the model is re-trained and evaluated here for the new task of segmenting only the boundaries that delineate changes in environmental conditions.  

The architecture and pretrained VGG-11 encoder of the U-Net, as well as several utility functions, are imported from "Segmentation Models: Python library with Neural Networks for Image Segmentation based on PyTorch" (SMP)<sup>3</sup>. For both ring boundary segmentation tasks, our scripts were adapted from the SMP car segmentation example, benefiting from its specified functions for data loading, augmentation, and model training.  

[1] Doshi, A.\*\, M. Shaw\*\, R. Tonea, R. Minyety, S. Moon, A. Laine, J. Guo\^\, and T. Danino\^\. A deep learning pipeline for segmentation of *Proteus mirabilis* colony patterns. in *2022 IEEE 19th
International Symposium on Biomedical Imaging (ISBI)*. 2022. IEEE. doi: 10.1109/ISBI52829.2022.9761643

[2] daninolab. mirabilis-ringboundary-seg-minimal. 2022; Available from: https://github.com/daninolab/proteus-mirabilis.

[3] Iakubovskii, P. segmentation_models.pytorch (Version 0.2.0). 2021; Available from: https://github.com/qubvel/segmentation_models.pytorch.

# Imports

In [None]:
# Earlier PyPI version (0.2.0) that we have been using: 
!pip install segmentation-models-pytorch==0.2.0
# To get the latest version from source:
#!pip install git+https://github.com/qubvel/segmentation_models.pytorch
    
!pip install openpyxl # for reading in excel file

In [None]:
import numpy as np
import pandas as pd
import cv2
import csv
import copy
import time
from tqdm import tqdm
import os
import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from torchvision import models
from torch.utils.data import Dataset, DataLoader
import glob
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import losses
from segmentation_models_pytorch.encoders import get_preprocessing_fn
import albumentations as albu
import math
from skimage.morphology import skeletonize, thin
from skimage import data
from skimage.util import invert

# Dataset

In [None]:
# Number of images assigned to the train, validation, and test splits
num_train, num_val, num_test = 13, 4, 4

In [None]:
# Folders holding the pre-processed input images and their boundary masks
img_dir, mask_dir = (
    '../input/flgm-updated-110721/x_preProc',
    '../input/flgm-updated-110721/y',
)

In [None]:
# Load in train-val-test split list (excel file).
# The resulting table is read later through its "Train", "Val", and "Test" columns.
train_val_test_path = '../input/flgm-temp-changes-trainvaltestlist-082021/flgM_TrainValTest_list.xlsx'
train_val_test_df = pd.read_excel(train_val_test_path)

In [None]:
# Function to extract img IDs for train, val, & test sets
def get_train_val_test_IDs(train_val_test_df, num_train, num_val, num_test):
    """Return the image-ID lists for the train, validation, and test splits.

    Parameters
    ----------
    train_val_test_df : pandas.DataFrame
        Table with "Train", "Val", and "Test" columns holding image IDs.
    num_train, num_val, num_test : int
        How many IDs to take from the top of each column.

    Returns
    -------
    tuple of (list of str, list of str, list of str)
        ``train_IDs, val_IDs, test_IDs``.

    Notes
    -----
    IDs are read directly from the column values. Slicing the printed
    representation of each row (``str(row)[2:-2]``) is avoided because it
    mangles IDs that contain escape characters and yields empty strings for
    non-string (e.g. numeric) IDs.
    """
    def _first_ids(column, count):
        # .loc slices by label and is end-inclusive, hence ``count - 1``.
        values = train_val_test_df.loc[0:count - 1, column]
        return [str(value) for value in values]

    train_IDs = _first_ids("Train", num_train)
    val_IDs = _first_ids("Val", num_val)
    test_IDs = _first_ids("Test", num_test)

    return train_IDs, val_IDs, test_IDs

In [None]:
# Split the ID table into the three image-ID lists
train_IDs, val_IDs, test_IDs = get_train_val_test_IDs(
    train_val_test_df, num_train, num_val, num_test
)

# Report the size of each split (train, val, test) as a sanity check
for split_ids in (train_IDs, val_IDs, test_IDs):
    print(len(split_ids))

In [None]:
# Dataset class
class BacteriaDataset(Dataset):
    """Image/mask dataset for ring-boundary segmentation.

    Each item is an ``(img, mask)`` pair read from disk with OpenCV. The mask
    file name is derived from the image ID by replacing ``".tif"`` with
    ``"_testim_boundarymask_uncrop.tif"``.

    Parameters
    ----------
    img_IDs : sequence of str
        Image file names, relative to ``img_dir``.
    img_dir, mask_dir : str
        Folders containing the images and the masks.
    classes : sequence of str, optional
        Class names to use; matched case-insensitively against ``CLASSES``.
        ``None`` selects every class in ``CLASSES``.
    augmentation : callable, optional
        Called as ``augmentation(image=..., mask=...)``; must return a mapping
        with ``'image'`` and ``'mask'`` entries.
    preprocessing : callable, optional
        Same calling convention as ``augmentation``; applied after it.
    """

    CLASSES = ['boundaries']

    def __init__(self, img_IDs, img_dir, mask_dir,
                 classes=None, augmentation=None, preprocessing=None):
        self.img_IDs = img_IDs
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.augmentation = augmentation         # for augmentations
        self.preprocessing = preprocessing       # preprocessing to normalize images

        # The signature advertises classes=None, so honour it instead of
        # failing when iterating over None.
        if classes is None:
            classes = self.CLASSES

        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_IDs)

    def __getitem__(self, i):
        """Load, binarize, and transform the ``i``-th image/mask pair.

        Returns
        -------
        tuple
            ``(img, mask)``. Before any transform, ``img`` is the RGB image and
            ``mask`` is a float32 array of 0/1 values with a trailing axis of
            length one.

        Raises
        ------
        FileNotFoundError
            If OpenCV cannot read the image or the mask file.
        """
        # read data
        img_path = os.path.join(self.img_dir, self.img_IDs[i])
        mask_path = os.path.join(self.mask_dir, self.img_IDs[i].replace(".tif", "_testim_boundarymask_uncrop.tif"))

        # cv2.imread signals failure by returning None rather than raising,
        # so check explicitly to give an actionable error message.
        img = cv2.imread(img_path)
        if img is None:
            raise FileNotFoundError("Could not read image file: {}".format(img_path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise FileNotFoundError("Could not read mask file: {}".format(mask_path))
        # Any value >= 1 counts as boundary
        mask = (mask >= 1).astype('float32')
        mask = np.expand_dims(mask, axis=2)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=img, mask=mask)
            img, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=img, mask=mask)
            img, mask = sample['image'], sample['mask']

        return img, mask

In [None]:
# Transformations definitions 
# For training set:
def get_training_augmentation():
    """Build the augmentation pipeline for the training set.

    The pipeline pads to at least 1024x1024, then applies (each with
    probability 0.5) a rotation within +/-10 degrees, a horizontal flip, a
    vertical flip, a translation, and a zoom. Borders are filled by reflection.
    """
    reflect = cv2.BORDER_REFLECT_101

    pad = albu.PadIfNeeded(min_height=1024, min_width=1024, always_apply=True, border_mode=reflect)
    rotate = albu.Rotate(limit=(-10, 10), border_mode=reflect, p=0.5)
    flip_h = albu.HorizontalFlip(p=0.5)
    flip_v = albu.VerticalFlip(p=0.5)
    # translate only: no scaling, no rotation
    translate = albu.ShiftScaleRotate(shift_limit=0.05, scale_limit=0, rotate_limit=0,
                                      border_mode=reflect, p=0.5)
    # zoom only: no shift, no rotation
    zoom = albu.ShiftScaleRotate(shift_limit=0, scale_limit=0.5, rotate_limit=0,
                                 border_mode=reflect, p=0.5)

    return albu.Compose([pad, rotate, flip_h, flip_v, translate, zoom])

# For validation and test sets:
def get_val_test_augmentation():
    """Build the transform for validation/test images: reflect-pad to at least 1024x1024."""
    pad = albu.PadIfNeeded(
        min_height=1024,
        min_width=1024,
        always_apply=True,
        border_mode=cv2.BORDER_REFLECT_101,
    )
    return albu.Compose([pad])

# Necessary for feeding images into model
def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')

def get_preprocessing(preprocessing_fn):
    """Compose the model-input preprocessing pipeline.

    ``preprocessing_fn`` is applied to the image only; afterwards both image
    and mask are passed through ``to_tensor``.
    """
    normalize = albu.Lambda(image=preprocessing_fn)
    as_tensor = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalize, as_tensor])

In [None]:
# Helper function for data visualization
def visualize(**images):
    """Plot images in one row.

    Each keyword becomes one panel; its name (underscores replaced by spaces,
    title-cased) is used as the panel title.
    """
    n_panels = len(images)
    plt.figure(figsize=(15, 10))
    for position, label in enumerate(images, start=1):
        plt.subplot(1, n_panels, position)
        # hide axis ticks
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(images[label], cmap='binary', vmin=0, vmax=1)
    plt.show()

In [None]:
# Let's look at a random image+mask pair from our train dataset
# (no augmentation or preprocessing is passed to the dataset here)
orig_train_set = BacteriaDataset(train_IDs, img_dir, mask_dir,classes=['boundaries'])
# Pick one sample index at random
rand_num = np.random.choice(len(orig_train_set))
img, mask = orig_train_set[rand_num] 
filename = train_IDs[rand_num]
print(filename)
# Divide the image by 255 for display and drop the mask's length-one channel axis
visualize(original_pattern_image=img/255, 
          ground_truth_mask=mask.squeeze()
         )

In [None]:
# Visualize 5 transformed images+masks from our train dataset
# (no training augmentations used in this run; only the val/test padding transform is applied)
train_dataset = BacteriaDataset(train_IDs, img_dir, mask_dir, classes=['boundaries'],
                                augmentation=get_val_test_augmentation(),
                                # comment out this preprocessing line, as it's only needed for loading images into model:
                                #preprocessing=get_preprocessing(preprocess_input), 
                                )

# Samples are drawn at random, so the same image may be shown more than once
for i in range(5):
    n = np.random.choice(len(train_dataset))
    img, mask = train_dataset[n]
    filename = train_IDs[n]
    print(filename)
    visualize(transformed_pattern_image=img/255, 
              transformed_ground_truth_mask=mask)

# Model architecture and hyperparameters

In [None]:
# Set some variables 

# For saving results: model_name prefixes the history CSV and checkpoint file names
date_started = '110721'
architecture_name = 'vgg11_UNet_flgm' 
model_name = architecture_name + '_' + date_started
print(model_name)

# If loading in model from previous round of training, to continue training here
load_model = False

# Specific to model implementation
Encoder = 'vgg11'         # encoder_name passed to smp.Unet
Attention = None          # decoder_attention_type passed to smp.Unet
Weights = 'imagenet'      # encoder_weights passed to smp.Unet
ACTIVATION = 'sigmoid'    # activation passed to smp.Unet
CLASSES = ['boundaries']  # one output channel per class
# Input preprocessing function matching the chosen encoder and weights
preprocess_input = get_preprocessing_fn(Encoder, Weights)
more_epochs = 60 # How many epochs/more epochs to train for
patience = 3 # For early stopping
train_batch_size = 3
val_batch_size = 1
test_batch_size = 1

In [None]:
# Create segmentation model with pretrained encoder
# https://github.com/qubvel/segmentation_models.pytorch
# 3 input channels; one output channel per entry in CLASSES
model = smp.Unet(
    encoder_name=Encoder, 
    encoder_weights=Weights, 
    decoder_attention_type=Attention,
    in_channels=3, 
    classes=len(CLASSES), 
    activation=ACTIVATION,
)

In [None]:
# Initialize loss, metrics, & optimizer:
loss = smp.utils.losses.DiceLoss()

# Metrics reported for every epoch (IoU uses a 0.5 threshold)
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
    smp.utils.metrics.Fscore(),
    smp.utils.metrics.Accuracy(),
    smp.utils.metrics.Recall(),
    smp.utils.metrics.Precision()
]

# Adam over all model parameters in a single parameter group, lr = 1e-4
optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.0001),
])

# Training & Validation

In [None]:
# CSV file that holds the per-epoch train/val history (read when resuming, written during training)
trainvalCSVname = model_name + '_TrainValcsv.csv'

if load_model is not True:
    # Fresh run: start from an empty history table
    history_columns = ['Epoch',
                       'Train Loss', 'Val Loss',
                       'Train Accuracy', 'Val Accuracy',
                       'Train Precision', 'Val Precision',
                       'Train Recall', 'Val Recall',
                       'Train IoU', 'Val IoU',
                       'Train Fscore', 'Val Fscore']
    dfTrainVal = pd.DataFrame(columns=history_columns)
    resume_epoch = len(dfTrainVal.index)
else:
    # Resuming: locate the folder holding the previous run's logs and checkpoints
    model_name_dir = model_name.replace("_", "-").lower()
    model_dir = '../input/' + model_name_dir + '-models-logs'

    # Previous history; one row per completed epoch
    dfTrainVal_path = os.path.join(model_dir, trainvalCSVname)
    dfTrainVal = pd.read_csv(dfTrainVal_path)
    resume_epoch = len(dfTrainVal.index)
    last_epoch = resume_epoch - 1

    # Restore model and optimizer state from the last saved epoch
    last_checkpoint_name = model_name + '_epoch_' + str(last_epoch) + '.pth'
    last_checkpoint_path = os.path.join(model_dir, last_checkpoint_name)
    last_checkpoint = torch.load(last_checkpoint_path)
    model.load_state_dict(last_checkpoint['model_state_dict'])
    optimizer.load_state_dict(last_checkpoint['optimizer_state_dict'])

In [None]:
# View the history dataframe, whether it's empty or filled with previous history
dfTrainVal

In [None]:
# Create transformed & preprocessed datasets
def _make_eval_dataset(img_IDs):
    """Build a dataset that applies the padding transform plus model preprocessing.

    Used for all three splits: no training augmentations are used in this run,
    so the training set gets the same standard transform as val/test.
    """
    return BacteriaDataset(img_IDs, img_dir, mask_dir, classes=['boundaries'],
                           augmentation=get_val_test_augmentation(),
                           preprocessing=get_preprocessing(preprocess_input))


train_dataset = _make_eval_dataset(train_IDs)
val_dataset = _make_eval_dataset(val_IDs)
test_dataset = _make_eval_dataset(test_IDs)

# Create dataloaders (the test loader keeps the dataset order)
train_loader = DataLoader(train_dataset, shuffle=True,
                          batch_size=train_batch_size, num_workers=12)
val_loader = DataLoader(val_dataset, shuffle=True,
                        batch_size=val_batch_size, num_workers=4)
test_loader = DataLoader(test_dataset, shuffle=False,
                         batch_size=test_batch_size, num_workers=4)

In [None]:
# Explore DataLoader: fetch one batch from each loader and report its tensor shapes
for header, loader in (('Training data Info:', train_loader),
                       ('\nValidation data Info:', val_loader),
                       ('\nTest data Info:', test_loader)):
    print(header)
    dataiter = iter(loader)
    # Use the built-in next(): the iterator's .next() method is a Python 2 style
    # alias that is not available on newer PyTorch DataLoader iterators.
    data = next(dataiter)
    images,labels = data
    print("shape of images : {}".format(images.shape))
    print("shape of labels : {}".format(labels.shape))

In [None]:
# Create epoch runners, as done in https://github.com/qubvel/segmentation_models.pytorch
# Run on the GPU when one is available, otherwise on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Training runner: receives the optimizer
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device=device,
    verbose=True,
)

# Validation runner: same model, loss, and metrics, but no optimizer
val_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device=device,
    verbose=True,
)

In [None]:
# Start (or continue) training & validating the model

EPOCHS = resume_epoch + more_epochs
es = 0  # early-stopping counter: consecutive epochs without a new best val loss

# (train column, val column, key in the epoch log dict), in the order they are recorded
logged_metrics = [
    ('Train Loss', 'Val Loss', 'dice_loss'),
    ('Train Accuracy', 'Val Accuracy', 'accuracy'),
    ('Train Precision', 'Val Precision', 'precision'),
    ('Train Recall', 'Val Recall', 'recall'),
    ('Train IoU', 'Val IoU', 'iou_score'),
    ('Train Fscore', 'Val Fscore', 'fscore'),
]

for epoch in range(resume_epoch, EPOCHS):

    print('\nEpoch: {}'.format(epoch))
    train_logs = train_epoch.run(train_loader)
    val_logs = val_epoch.run(val_loader)

    # Best val loss recorded before this epoch; at epoch 0 there is no history yet, so use 1
    min_val_loss = 1 if epoch == 0 else dfTrainVal['Val Loss'].min()

    # Record this epoch's scores in the history table
    dfTrainVal.loc[epoch, ['Epoch']] = epoch
    for train_col, val_col, log_key in logged_metrics:
        dfTrainVal.loc[epoch, [train_col]] = train_logs[log_key]
        dfTrainVal.loc[epoch, [val_col]] = val_logs[log_key]

    # Persist the history after every epoch
    dfTrainVal.to_csv(trainvalCSVname, index=False)

    # Save a checkpoint for every epoch so the best one can be reloaded later
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }
    checkpoint_path = './'+model_name+'_epoch_'+str(epoch)+'.pth'
    torch.save(checkpoint, checkpoint_path)

    # Early stopping: an improvement over the previous minimum resets the counter
    val_loss = val_logs['dice_loss']
    if val_loss < min_val_loss:
        es = 0
        continue

    # No improvement: count it, and stop once the patience budget is used up
    es += 1
    print("EarlyStopping Counter {} of {}".format(es,patience))
    if es >= patience:
        print("Early stopping with min_val_loss: ", min_val_loss, "and val_loss for this epoch: ", val_loss, "...")
        break

In [None]:
# Find early stopping point: the row label (epoch) with the lowest validation loss.
# Series.idxmin() returns that label directly; this avoids positional `[0]`
# indexing on a label-indexed Series, which recent pandas releases deprecate.
stop_pt = dfTrainVal['Val Loss'].astype(float).idxmin()
print(stop_pt)

# Determine how many epochs were completed (one history row per epoch)
epochs_completed = dfTrainVal.shape[0]
print(epochs_completed)

# Test the best saved model

In [None]:
# First, initialize the model & the optimizer
# Same architecture settings as the trained model, so saved weights can be loaded into it
best_model = smp.Unet(
    encoder_name=Encoder, 
    encoder_weights=Weights, 
    decoder_attention_type=Attention,
    in_channels=3,
    classes=len(CLASSES), 
    activation=ACTIVATION,
)

# The optimizer must track best_model's parameters (not those of `model`)
best_optimizer = torch.optim.Adam([ 
    dict(params=best_model.parameters(), lr=0.0001),
])

In [None]:
# Find the epoch with min val loss & load in checkpoints for that epoch.
# Series.idxmin() returns the row label directly (no positional `[0]` lookup needed).
best_epoch = dfTrainVal['Val Loss'].astype(float).idxmin()

if best_epoch < resume_epoch:
    # previously uploaded checkpoint (only possible when training was resumed,
    # which is when model_dir is defined)
    best_checkpoint_name = model_name + '_epoch_' + str(best_epoch) + '.pth'
    best_checkpoint_path = os.path.join(model_dir, best_checkpoint_name)
else:
    # newly saved checkpoint
    best_checkpoint_path = './' + model_name + '_epoch_' + str(best_epoch) + '.pth'

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
best_checkpoint = torch.load(best_checkpoint_path, map_location=device)
best_model.load_state_dict(best_checkpoint['model_state_dict'])

In [None]:
# Evaluate model on test set
# A validation-style runner (no optimizer) is used with the best checkpoint's weights
test_epoch = smp.utils.train.ValidEpoch(
    model=best_model,
    loss=loss,
    metrics=metrics,
    device=device,
)

# Scores for the test set, keyed by loss/metric name
test_logs = test_epoch.run(test_loader)

In [None]:
# Set up & save dataframe for storing test scores
# (output column, key in the test log dict)
test_score_columns = [
    ('Test Loss', 'dice_loss'),
    ('Test Accuracy', 'accuracy'),
    ('Test Precision', 'precision'),
    ('Test Recall', 'recall'),
    ('Test IoU', 'iou_score'),
    ('Test Fscore', 'fscore'),
]

dfTest = pd.DataFrame(columns=[column for column, _ in test_score_columns])

# Single row (label 0) holding the scores of the evaluated checkpoint
for column, log_key in test_score_columns:
    dfTest.loc[0, [column]] = test_logs[log_key]

# The file name records which epoch's checkpoint was evaluated
testCSVname = model_name + '_epoch_'+str(best_epoch)+ '_Testcsv.csv'
dfTest.to_csv(testCSVname, index=False)

In [None]:
# Show test results (single-row table of test loss and metrics)
dfTest

# Visualize predictions

In [None]:
# Function for generating predicted mask (cropped back down to the size of original image)
# & skeletonized version of cropped predicted mask
# ...given an index, a dataset, & a model

def generate_prediction_skel(n, dataset, model, orig_size=1000):
    """Predict the mask for one dataset item, crop it, and skeletonize it.

    Parameters
    ----------
    n : int
        Index of the item in ``dataset``.
    dataset : indexable
        ``dataset[n][0]`` must be the transformed (padded) + preprocessed image
        as a numpy array accepted by ``torch.from_numpy``.
    model : object
        Segmentation model exposing ``predict(tensor)``.
    orig_size : int, optional
        Side length of the un-padded image (default 1000). The prediction is
        centre-cropped to ``orig_size x orig_size``; for a 1024x1024 padded
        input this is the region ``[12:1012, 12:1012]``.

    Returns
    -------
    tuple of numpy.ndarray
        ``(cropped_pr_mask, skeleton)``: the rounded, cropped prediction and
        its skeleton as float32.
    """
    # Get transformed (padded) + preprocessed image
    image = dataset[n][0]
    img_tensor = torch.from_numpy(image).to(device).unsqueeze(0)

    # Generate prediction and round it to 0/1 values
    pr_mask = model.predict(img_tensor)
    pr_mask = pr_mask.squeeze().cpu().numpy().round()

    # Undo the padding with a centred crop.
    # NOTE(review): assumes the padding transform pads symmetrically around the
    # image - confirm if the padding settings change.
    top = max((pr_mask.shape[0] - orig_size) // 2, 0)
    left = max((pr_mask.shape[1] - orig_size) // 2, 0)
    cropped_pr_mask = pr_mask[top:top + orig_size, left:left + orig_size]

    # Skeletonize the mask
    skeleton = skeletonize(cropped_pr_mask)
    skeleton = skeleton.astype(np.float32)

    return cropped_pr_mask, skeleton

In [None]:
# Test dataset without transformations (no augmentation, no preprocessing) for image visualization
test_dataset_vis = BacteriaDataset(test_IDs, img_dir, mask_dir, classes=['boundaries'],)

In [None]:
# Create output folder for storing cropped predicted masks
# (exist_ok=True makes this a no-op when the folder is already there)
pred_folder = 'predictions'
os.makedirs(pred_folder, exist_ok=True)

In [None]:
# Create output folder for storing skeletonized cropped predicted masks
# (exist_ok=True makes this a no-op when the folder is already there)
skel_folder = 'skel_predictions'
os.makedirs(skel_folder, exist_ok=True)

In [None]:
# Visualize all predictions on the test set and save them to disk
test_size = len(test_IDs)

for i, filename in enumerate(test_IDs):

    # Print the image filename so it is clear which image is being viewed
    print(filename)
    filename_wo_ext = os.path.splitext(os.path.basename(filename))[0]

    # Output paths; the file names record the epoch of the checkpoint used
    pred_filename = filename_wo_ext + '_pred_ep' + str(best_epoch) + '.tif'
    pred_path = os.path.join(pred_folder, pred_filename)
    skel_filename = filename_wo_ext + '_skel_ep' + str(best_epoch) + '.tif'
    skel_path = os.path.join(skel_folder, skel_filename)

    # Show the untransformed+unpreprocessed input image next to its ground truth mask
    image_vis, gt_vis = test_dataset_vis[i]
    gt_vis = gt_vis.squeeze()
    visualize(original_pattern_image=image_vis/255, ground_truth_mask=gt_vis)

    # Predict from the transformed+preprocessed dataset, then save both outputs
    cropped_pr_mask, skeleton = generate_prediction_skel(i, test_dataset, best_model)
    cv2.imwrite(pred_path, cropped_pr_mask)
    cv2.imwrite(skel_path, skeleton)

    # Show the cropped predicted mask next to its skeletonized version
    visualize(predicted_mask=cropped_pr_mask, skeletonized_predicted_mask=skeleton)