# Application of Trained VGG-11 U-Net for Predicting Ring Boundaries on Newly Acquired Images

This script applies our previously optimized VGG-11 U-Net segmentation model, as presented in Doshi & Shaw et al. 2022, to generate ring boundary masks for our newly acquired *Proteus mirabilis* colony images that were not used in the original training, validation, and testing of the model<sup>1,2</sup>. Specifically, we have used this model to segment rings on images of our copper-sensing strain pCopA-*flgM*, as well as one of our IPTG-sensing strains, pLac-*flgM*, grown at different temperatures. As in our original implementation, this script utilizes various elements from "Segmentation Models: Python library with Neural Networks for Image Segmentation based on PyTorch" (SMP), including utility functions defined in the SMP car segmentation example<sup>3</sup>. 

[1] Doshi, A.\*\, M. Shaw\*\, R. Tonea, R. Minyety, S. Moon, A. Laine, J. Guo\^\, and T. Danino\^\. A deep learning pipeline for segmentation of *Proteus mirabilis* colony patterns. in *2022 IEEE 19th
International Symposium on Biomedical Imaging (ISBI)*. 2022. IEEE. doi: 10.1109/ISBI52829.2022.9761643

[2] daninolab. mirabilis-ringboundary-seg-minimal. 2022; Available from: https://github.com/daninolab/proteus-mirabilis.

[3] Iakubovskii, P. segmentation_models.pytorch (Version 0.2.0). 2021; Available from: https://github.com/qubvel/segmentation_models.pytorch.

# Imports

In [None]:
# Earlier PyPI version (0.2.0) that we have been using: 
!pip install segmentation-models-pytorch==0.2.0
# To get the latest version from source:
#!pip install git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
import numpy as np
import cv2
import csv
import copy
import time
from tqdm import tqdm
import os
import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from torchvision import models
from torch.utils.data import Dataset, DataLoader
import glob
import matplotlib.pyplot as plt
from matplotlib import pylab as pl
import segmentation_models_pytorch as smp
from segmentation_models_pytorch import losses
from segmentation_models_pytorch.encoders import get_preprocessing_fn
import albumentations as albu
from skimage.morphology import skeletonize, thin
from skimage import data
from skimage.util import invert

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Dataset

In [None]:
# Set path to folder of new images to generate predictions for.
# Only the assigned path is used; the commented-out paths below are the other
# image folders, kept so one of them can be swapped into the assignment.
img_dir = '../input/101322metals' 
        # '../input/101522metals' 
        # '../input/101622metals' 
        # '../input/34C_gfp_flgm_chew_0i_10i'
        # '../input/36C_gfp_flgm_chew_0i_10i'
        # '../input/37C_gfp_flgm_chew_0i_10i'

In [None]:
# Get list of the file names (to print when displaying images later)
def _numeric_sort_key(name):
    """Return a sort key that orders file names by the digits they contain.

    All digit characters in ``name`` are concatenated and read as one integer
    (e.g. 'img_12.tif' -> 12). Names without any digit get -1 so they sort
    first instead of raising ValueError on int(''). The name itself is the
    tie-breaker, which makes the order deterministic regardless of the order
    in which os.listdir returns the entries.
    """
    digits = ''.join(filter(str.isdigit, name))
    return (int(digits) if digits else -1, name)


# Keep only the '.tif' files in img_dir, ordered numerically
img_list = sorted(
    (fname for fname in os.listdir(img_dir) if fname.endswith('.tif')),
    key=_numeric_sort_key,
)

In [None]:
# Print how many images we're working with
# (i.e. the number of file names collected in img_list; a quick sanity check
# that the input folder was found and is not empty)
num_imgs = len(img_list)
print(num_imgs)

In [None]:
# Dataset class
# note: this doesn't include ground truth masks anymore
class BacteriaDataset(Dataset):
    """Image-only dataset of colony images (no ground-truth masks).

    Parameters
    ----------
    img_IDs : sequence of str
        File names of the images, relative to ``img_dir``.
    img_dir : str
        Folder that contains the images.
    classes : sequence of str, optional
        Class names to use; each must be in ``CLASSES`` (case-insensitive).
        When omitted, all of ``CLASSES`` is used.
    augmentation : callable, optional
        Called as ``augmentation(image=img)``; must return a mapping with an
        ``'image'`` entry.
    preprocessing : callable, optional
        Called as ``preprocessing(image=img)`` after the augmentation; must
        return a mapping with an ``'image'`` entry.
    """

    CLASSES = ['boundaries']

    def __init__(self, img_IDs, img_dir, classes=None, augmentation=None, preprocessing=None):
        self.img_IDs = img_IDs
        self.img_dir = img_dir
        self.augmentation = augmentation         # for augmentations
        self.preprocessing = preprocessing       # preprocessing to normalize images
        self.imgs_fps = [os.path.join(self.img_dir, img_id) for img_id in self.img_IDs]

        # The signature advertises classes=None as the default, so make it
        # usable: fall back to every known class instead of iterating None.
        if classes is None:
            classes = self.CLASSES

        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_IDs)

    def __getitem__(self, i):
        """Load image ``i`` as RGB and apply augmentation, then preprocessing.

        Raises
        ------
        OSError
            If the image file cannot be read.
        """
        # read data
        img_path = self.imgs_fps[i]
        img = cv2.imread(img_path)
        # cv2.imread signals failure by returning None rather than raising;
        # fail here with the offending path instead of inside cvtColor.
        if img is None:
            raise OSError('Could not read image (missing or unreadable file): {}'.format(img_path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=img)
            img = sample['image']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=img)
            img = sample['image']

        return img

In [None]:
# Define transformations
# For training set (not used here):
def get_training_augmentation():
    """Build the augmentation pipeline for the training set (not used here).

    The pipeline pads to at least 1024x1024, then applies (each with p=0.5)
    a rotation limited to (-10, 10), a horizontal flip, a vertical flip,
    a shift-only ShiftScaleRotate and a scale-only ShiftScaleRotate.
    Every step that fills borders uses reflection (BORDER_REFLECT_101).
    """
    reflect = cv2.BORDER_REFLECT_101

    pad = albu.PadIfNeeded(min_height=1024, min_width=1024, always_apply=True, border_mode=reflect)
    rotate = albu.Rotate(limit=(-10, 10), border_mode=reflect, p=0.5)
    flip_horizontal = albu.HorizontalFlip(p=0.5)
    flip_vertical = albu.VerticalFlip(p=0.5)
    # translate
    translate = albu.ShiftScaleRotate(shift_limit=0.05, scale_limit=0, rotate_limit=0,
                                      border_mode=reflect, p=0.5)
    # zoom
    zoom = albu.ShiftScaleRotate(shift_limit=0, scale_limit=0.5, rotate_limit=0,
                                 border_mode=reflect, p=0.5)

    return albu.Compose([pad, rotate, flip_horizontal, flip_vertical, translate, zoom])

# For validation and test sets 
# (necessary for resizing images to feed into model):
def get_val_test_augmentation():
    """Build the deterministic transform for validation/test images.

    The only step is padding to at least 1024x1024 with reflected borders,
    which resizes the images so they can be fed into the model.
    """
    pad_to_model_size = albu.PadIfNeeded(
        min_height=1024,
        min_width=1024,
        always_apply=True,
        border_mode=cv2.BORDER_REFLECT_101,
    )
    return albu.Compose([pad_to_model_size])

# Necessary for feeding images into model
def to_tensor(x, **kwargs):
    """Reorder axes (0, 1, 2) -> (2, 0, 1) and cast the array to float32.

    For an image stored as (height, width, channels) this yields the
    channels-first layout. Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')

def get_preprocessing(preprocessing_fn):
    """Compose the preprocessing applied to an image before it reaches the model.

    Parameters
    ----------
    preprocessing_fn : callable
        Image-level function applied first (here: the encoder-specific
        preprocessing function).

    Returns
    -------
    An albumentations Compose that runs ``preprocessing_fn`` and then
    ``to_tensor`` on the image.
    """
    normalize = albu.Lambda(image=preprocessing_fn)
    channels_first = albu.Lambda(image=to_tensor)
    return albu.Compose([normalize, channels_first])

In [None]:
# Helper function for data visualization
def visualize(**images):
    """Plot images in one row.

    Each keyword argument becomes one panel: the keyword (underscores replaced
    by spaces, title-cased) is the panel title and its value is the image.
    Images are drawn with the 'binary' colormap on a fixed 0-1 value range.
    """
    n_panels = len(images)
    plt.figure(figsize=(15, 10))

    for position, name in enumerate(images, start=1):
        plt.subplot(1, n_panels, position)
        # Hide the axis ticks; only the picture and its title are of interest
        plt.xticks([])
        plt.yticks([])
        panel_title = name.replace('_', ' ').title()
        plt.title(panel_title)
        plt.imshow(images[name], cmap='binary', vmin=0, vmax=1)

    plt.show()

In [None]:
# Function for generating predicted mask (cropped back down to the size of original image: 1000x1000)
# & skeletonized version of cropped predicted mask
# ...given an index, a dataset, & a model

def generate_prediction_skel(n, dataset, model, orig_size=(1000, 1000)):
    """Predict the mask for one image, crop away the padding and skeletonize it.

    Parameters
    ----------
    n : int
        Index of the image in ``dataset``.
    dataset : indexable
        ``dataset[n]`` must return the padded, preprocessed image as a
        channels-first numpy array.
    model : model exposing ``predict(tensor)``
        The trained segmentation model.
    orig_size : tuple of int, optional
        (height, width) of the image before padding. The default
        (1000, 1000) reproduces the crop ``[12:1012, 12:1012]`` for a
        1024x1024 padded image.

    Returns
    -------
    cropped_pr_mask : numpy array
        Rounded (0/1) predicted mask cropped to ``orig_size``.
    skeleton : numpy array (float32)
        Skeletonized version of ``cropped_pr_mask``.
    """
    # Get transformed (padded) + preprocessed image
    image = dataset[n]

    # Send the input to the device that actually holds the model's weights.
    # Using the global `device` unconditionally fails with a device-mismatch
    # error when CUDA is available but the model was left on the CPU.
    try:
        model_device = next(model.parameters()).device
    except (StopIteration, AttributeError):
        model_device = torch.device(device)
    img_tensor = torch.from_numpy(image).unsqueeze(0).to(model_device)

    # Generate prediction
    pr_mask = model.predict(img_tensor)
    pr_mask = pr_mask.squeeze().cpu().numpy().round()

    # Crop back to the original size. This assumes the padding was split
    # evenly between both sides (any odd extra pixel at the bottom/right).
    orig_h, orig_w = orig_size
    top = max((pr_mask.shape[0] - orig_h) // 2, 0)
    left = max((pr_mask.shape[1] - orig_w) // 2, 0)
    cropped_pr_mask = pr_mask[top:top + orig_h, left:left + orig_w]

    # Skeletonize the mask
    skeleton = skeletonize(cropped_pr_mask)
    skeleton = skeleton.astype(np.float32)

    return cropped_pr_mask, skeleton

# Load in Model

In [None]:
# Variables for initializing the previously trained model 
# (read by load_model when it builds the smp.Unet)
Encoder = 'vgg11'           # passed as encoder_name
Weights = 'imagenet'        # passed as encoder_weights; also selects the preprocessing function below
ACTIVATION = 'sigmoid'      # passed as activation
Attention = None            # passed as decoder_attention_type
CLASSES = ['boundaries']    # its length sets the number of output classes
# Preprocessing function for this encoder/weights pair; later wrapped by
# get_preprocessing and applied to every image fed into the model
preprocess_input = get_preprocessing_fn(Encoder, Weights)

In [None]:
# Function for loading in the previously trained model 
def load_model(checkpoint_path):
    """Build the U-Net and load the trained weights from a checkpoint.

    Parameters
    ----------
    checkpoint_path : str
        Path to a checkpoint file whose loaded object has a
        ``'model_state_dict'`` entry.

    Returns
    -------
    The model with the checkpoint's parameters loaded, placed on ``device``
    and switched to evaluation mode.
    """
    # Initialize the model (optimizer not needed for inference)
    model = smp.Unet(
        encoder_name=Encoder, 
        encoder_weights=Weights, 
        decoder_attention_type=Attention,
        in_channels=3, 
        classes=len(CLASSES), 
        activation=ACTIVATION,
    )

    # Load in the checkpoint. map_location remaps the stored tensors onto the
    # device in use, so a checkpoint saved on a GPU also loads on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Load in the model's learned parameters
    model.load_state_dict(checkpoint['model_state_dict'])

    # Put the model on the same device the input tensors are sent to
    # (otherwise inference fails with a device mismatch when CUDA is available)
    # and make sure it is in inference mode.
    model.to(device)
    model.eval()

    return model

In [None]:
# Set the path to the previously trained model & load it in
# (load_model reads the 'model_state_dict' entry of this checkpoint file)
checkpoint_path = '../input/best-unets-earlystopping/final_models_to_test/vgg11_UNet_cp_noaug_300refined_082021_epoch_34.pth'
model = load_model(checkpoint_path)

# Generate predicted masks

In [None]:
# Dataset without transformations/preprocessing for image visualization
# (with no augmentation/preprocessing given, each item is the image as read
# from disk and converted to RGB)
dataset_vis = BacteriaDataset(img_list, img_dir, classes=['boundaries'],)

In [None]:
# Preprocessed dataset (no augmentations) for feeding into model
# - get_val_test_augmentation() only pads the image (no random transforms)
# - get_preprocessing(preprocess_input) applies the encoder's preprocessing
#   function and converts the image to a channels-first float32 array
dataset = BacteriaDataset(img_list, img_dir, classes=['boundaries'],
                          augmentation=get_val_test_augmentation(),
                          preprocessing=get_preprocessing(preprocess_input),)

# Preprocessed dataset loader
# (one image per batch, in img_list order; used below to inspect the batch
# shape -- the prediction loop indexes `dataset` directly)
test_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

In [None]:
# Explore DataLoader
# (fetch the first batch and print its shape as a sanity check)
print('\nData Info:')
dataiter = iter(test_loader)
# Use the built-in next(): the iterator's .next() method no longer exists in
# newer PyTorch releases, while next() works in every version.
data = next(dataiter)
images = data
print("shape of images : {}".format(images.shape))

In [None]:
# Create output folder for storing cropped predicted masks
pred_folder = 'predictions'
# exist_ok=True makes this a no-op when the folder is already there, without
# the separate existence check (and the gap between checking and creating)
os.makedirs(pred_folder, exist_ok=True)

In [None]:
# Create output folder for storing skeletonized cropped predicted masks
skel_folder = 'skel_predictions'
# exist_ok=True makes this a no-op when the folder is already there, without
# the separate existence check (and the gap between checking and creating)
os.makedirs(skel_folder, exist_ok=True)

In [None]:
# Generate and save cropped predicted masks (& skeletonized versions) 
# (show every 10th image)
for n, filename in enumerate(img_list):

    # generate predicted mask & skeleton
    filename_wo_ext = os.path.splitext(os.path.basename(filename))[0]
    cropped_pr_mask, skeleton = generate_prediction_skel(n, dataset, model)

    # save predicted mask as <name>_pred.tif
    # (cv2.imwrite reports failure through its return value instead of
    # raising, so check it -- otherwise a failed write goes unnoticed)
    pred_path = os.path.join(pred_folder, filename_wo_ext + '_pred.tif')
    if not cv2.imwrite(pred_path, cropped_pr_mask):
        raise OSError('Failed to write predicted mask: {}'.format(pred_path))

    # save skeleton as <name>_skel.tif
    skel_path = os.path.join(skel_folder, filename_wo_ext + '_skel.tif')
    if not cv2.imwrite(skel_path, skeleton):
        raise OSError('Failed to write skeletonized mask: {}'.format(skel_path))

    # Show every 10th image
    if n % 10 == 0:

        # So we know which image we're viewing
        print(filename)
        image_vis = dataset_vis[n]

        # Visualize original image, cropped predicted mask, & skeletonized version
        # (the original image is divided by 255 to match visualize's 0-1 display range)
        visualize(original_pattern_image=image_vis/255,
                predicted_mask=cropped_pr_mask,
                skeletonized_predicted_mask=skeleton,)
