## **Deep-Learning-Cell-Seg Pipeline**
Welcome! Enclosed is a single lightweight Jupyter notebook for the segmentation of cells and organelles in volume EM datasets.

This notebook can be run locally, though your institution likely has a high-performance computing (HPC) cluster, which we recommend using. We found it convenient to run the augmentation and post-processing steps locally and the neural-network steps on an HPC cluster.

This notebook is broken up into 4 sections: image preparation, network training, network predictions, and post-processing. 

## **1: Image Preparation before Training**
To start you will need a small stack of raw images and a corresponding stack of labeled features. Our example training data is on OSF.

https://osf.io/mpysc/

We will also rotate the full raw data stack for multi-axis segmentation.

### **1a:** Training Augmentation
We use Albumentations for augmentation, performing elastic deformations, rotations, brightness shifts, contrast shifts, and adding Gaussian noise.

https://albumentations.ai/docs/

#### Augmentation Imports

In [None]:
import imageio
import albumentations as A
import numpy as np
from matplotlib import pyplot as plt
import os
from tifffile import imread, imwrite
import torchvision.transforms as T
from tqdm import tqdm
import random

#### Initialize + Load Data

In [None]:
#Load Small Stack of Raw and Labeled Images
# NOTE(review): the stacks are indexed one slice at a time further down
# (images_stack[i], masks_stack[i]), so these files presumably hold
# (slices, H, W) stacks with the same slice order — confirm.
input_image_file = r"C:\Users\baenencm\Desktop\Platelet Example Dataset\Training Images\RawPlatelets.tif"
input_mask_file = r"C:\Users\baenencm\Desktop\Platelet Example Dataset\Training Images\LabeledPlatelets.tif"

def visualize(image):
    """Show one image in a 10x10-inch matplotlib figure with the axes hidden."""
    plt.figure(figsize=(10,10))
    plt.axis('off')
    plt.imshow(image)
    plt.show()

#Note: for optimal resnet18 performance, this should be a multiple of 32
# Edge length (pixels) of the square crops produced by RandomCrop below.
patch_size = 512

# Geometric augmentation. It is applied jointly to the image and its mask in
# the loop below, so labels stay aligned: random shift/scale/rotate with
# probability 0.9, then a random patch_size x patch_size crop.
geometry = A.Compose([
    A.ShiftScaleRotate(shift_limit=(-0.05, 0.05), scale_limit=(-0.1, 0.1), rotate_limit=(-15, 15), interpolation=1, border_mode=4, value=0, mask_value=0, shift_limit_x=None, shift_limit_y=None, rotate_method="largest_box", always_apply=None, p=0.9),
    A.RandomCrop(width=patch_size,height=patch_size,p=1.0),
])

# Intensity augmentation. It is applied to the image only (the mask holds labels).
# NOTE(review): argument names here (always_apply, value, mask_value, ...) depend
# on the installed albumentations version — check if construction warns or fails.
contrast = A.Compose([
    A.ColorJitter(brightness=(0.8, 1.2), contrast=(0.8, 1.2), saturation=(0.8, 1), hue=(-0.5, 0.5), always_apply=None, p=0.90),
])

# Read both multi-page TIFFs fully into memory.
images_stack = imageio.volread(input_image_file)
masks_stack = imageio.volread(input_mask_file)

#### Loop for Random Augmentation

For each paired image and mask in the stack, this creates 50 randomly augmented versions (adjustable in the cell below). Every 10th augmentation is held out for the validation set.

ImageJ can be very helpful for viewing and managing these .tif stacks. 

In [None]:
def save_output_as_tiff(augmented_output, output_path):
    """Stack 2-D arrays into a multi-page TIFF and write it to ``output_path``.

    Args:
        augmented_output: non-empty sequence of equally shaped 2-D arrays,
            one per page.
        output_path: destination ``.tif`` file.

    Note:
        Values are cast to ``uint8`` without clipping, so data outside the
        0-255 range is not preserved.
    """
    output_stack = np.stack(augmented_output, axis=0)
    imwrite(output_path, output_stack.astype(np.uint8))


# How many augmented copies to generate from each source slice, and how often
# one of them is routed to the validation set instead of the training set.
augmentations_per_image = 50
validation_every = 10

validation_masks = []
validation_images = []

training_masks = []
training_images = []

# Use every `skip`-th slice of the labeled stack as a source pair.
skip = 1

if len(images_stack) != len(masks_stack):
    raise ValueError(
        f"Image stack has {len(images_stack)} slices but mask stack has "
        f"{len(masks_stack)}; they must be paired slice-for-slice."
    )

# NOTE(review): validation samples are augmentations of the same source slices
# that feed the training set, so the validation loss will be optimistic.
# range(0, len, skip) also visits the last slice when len is not a multiple of skip.
for idx in range(0, len(images_stack), skip):
    # The source pair is identical for every augmentation, so fetch it once.
    source_image = images_stack[idx]
    source_mask = masks_stack[idx]

    for j in range(augmentations_per_image):
        # First apply brightness and contrast; the mask holds labels, so its contrast is left alone
        image = contrast(image=source_image)['image']

        # Apply geometric transforms on both image and mask so they stay aligned
        transformed = geometry(image=image, mask=source_mask)
        transformed_mask = transformed['mask']
        transformed_image = transformed['image']

        if j % validation_every == 0:  # Hold out every `validation_every`-th augmentation for validation
            validation_images.append(transformed_image)
            validation_masks.append(transformed_mask)
        else:
            training_images.append(transformed_image)
            training_masks.append(transformed_mask)

# Save the augmented training and validation stacks to TIFF files
save_output_as_tiff(training_masks, r'C:\Users\baenencm\Desktop\Platelet Example Dataset\Augmented Training Images\Platelet_Masks.tif')
save_output_as_tiff(training_images, r'C:\Users\baenencm\Desktop\Platelet Example Dataset\Augmented Training Images\Platelet_Images.tif')
save_output_as_tiff(validation_masks, r'C:\Users\baenencm\Desktop\Platelet Example Dataset\Augmented Training Images\Platelet_Masks_validation.tif')
save_output_as_tiff(validation_images, r'C:\Users\baenencm\Desktop\Platelet Example Dataset\Augmented Training Images\Platelet_Images_validation.tif')

### **1b:** Raw Stack Rotation
Rotate the raw data stack into a new orientation for multi-axis analysis.

#### Imports + Loading Raw Data Stack

In [None]:
import numpy as np
import tifffile as tf

# Source xy stack, and destination for its rotated (yz) copy
input_tiff, output_tiff = (
    r"C:\Users\baenencm\Desktop\Platelet Example Dataset\Raw Data\CellSegPracticeData_ROI1.tif",
    r"C:\Users\baenencm\Desktop\Platelet Example Dataset\Raw Data\CellSegPracticeData_ROI1_yz.tif",
)

#### Rotating Data
To rotate the xy stack to yz, use (1,2,0).

To rotate the xy stack to xz, use (2,0,1).

In [None]:
# Load the whole raw volume into memory
with tf.TiffFile(input_tiff) as tiff_file:
    volume = tiff_file.asarray()

# Reorder the axes: (1, 2, 0) turns the xy stack into the yz orientation
rotated_slices = volume.transpose(1, 2, 0)

# Save to a new .tif as specified above
tf.imwrite(output_tiff, rotated_slices)

## **2: Neural Network Training**
Training using the augmented images.

### Neural Net Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tifffile import imread, imwrite
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from tqdm import tqdm
from torchvision.models import ResNet18_Weights

### Network Initialization and Training

We define a U-Net with a ResNet18 backbone.

In [None]:
# Handle TIFF stacks
class TiffDataset(Dataset):
    def __init__(self, image_path, mask_path, transform=None):
        self.images = imread(image_path)
        self.masks = imread(mask_path)
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        mask = self.masks[idx]
        
        # Convert single-channel image to three channels
        image = np.stack([image] * 3, axis=-1)

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        # Ensure the mask is a single channel
        mask = mask[0, :, :].unsqueeze(0)
        
        return image, mask

# U-Net model with ResNet18 backbone
class UNetResNet18(nn.Module):
    """U-Net style segmentation network with a ResNet18 encoder.

    The encoder reuses the stages of torchvision's ``resnet18`` (initialised
    from ``ResNet18_Weights.DEFAULT``). Each decoder block receives the
    previous decoder output concatenated with the encoder feature map of the
    same resolution (a skip connection). It then applies two 3x3
    conv + ReLU layers and upsamples by 2 with a transposed convolution. The
    logits are finally resized to the input's spatial size.

    Args:
        n_classes: number of output channels. Raw logits are returned
            (no activation is applied).
    """

    def __init__(self, n_classes):
        super().__init__()
        self.base_model = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        self.base_layers = list(self.base_model.children())

        # Encoder stages taken from the ResNet18 children in order
        self.encoder1 = nn.Sequential(*self.base_layers[:3])   # conv1, bn1, relu -> 64 ch
        self.encoder2 = nn.Sequential(*self.base_layers[3:5])  # maxpool, layer1 -> 64 ch
        self.encoder3 = self.base_layers[5]                    # layer2 -> 128 ch
        self.encoder4 = self.base_layers[6]                    # layer3 -> 256 ch
        self.encoder5 = self.base_layers[7]                    # layer4 -> 512 ch

        self.center = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

        # in_channels = decoder-path channels + skip-connection channels
        self.decoder5 = self._decoder_block(512 + 512, 512)
        self.decoder4 = self._decoder_block(512 + 256, 256)
        self.decoder3 = self._decoder_block(256 + 128, 128)
        self.decoder2 = self._decoder_block(128 + 64, 64)
        self.decoder1 = self._decoder_block(64 + 64, 64)

        self.final_conv = nn.Conv2d(64, n_classes, kernel_size=1)

    def _decoder_block(self, in_channels, out_channels):
        """Build two 3x3 conv + ReLU layers followed by a 2x transposed-conv upsample."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(out_channels, out_channels, kernel_size=2, stride=2)
        )

    def forward(self, x):
        """Return per-pixel logits with the same spatial size as ``x``."""
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(enc1)
        enc3 = self.encoder3(enc2)
        enc4 = self.encoder4(enc3)
        enc5 = self.encoder5(enc4)

        center = self.center(enc5)

        # Every decoder block was built for (decoder + skip) input channels, so
        # each one must receive the concatenation with its encoder skip. This
        # includes decoder5 (center + enc5) and decoder1 (dec2 + enc1).
        dec5 = self.decoder5(torch.cat([center, self._crop(enc5, center)], 1))
        dec4 = self.decoder4(torch.cat([dec5, self._crop(enc4, dec5)], 1))
        dec3 = self.decoder3(torch.cat([dec4, self._crop(enc3, dec4)], 1))
        dec2 = self.decoder2(torch.cat([dec3, self._crop(enc2, dec3)], 1))
        dec1 = self.decoder1(torch.cat([dec2, self._crop(enc1, dec2)], 1))

        final = self.final_conv(dec1)
        return self._resize(final, x.size()[2:])

    def _crop(self, enc, dec):
        """Center-crop ``enc`` to the spatial size of ``dec`` (torchvision CenterCrop)."""
        _, _, H, W = dec.size()
        enc = transforms.CenterCrop([H, W])(enc)
        return enc

    def _resize(self, input, size):
        """Bilinearly resize ``input`` to ``size`` (H, W)."""
        return nn.functional.interpolate(input, size=size, mode='bilinear', align_corners=True)

# Preprocessing shared by the training and validation datasets: convert the
# array to a channel-first tensor, then resize. Again, ensure the size is a
# multiple of 32.
_transform_steps = [
    transforms.ToTensor(),
    transforms.Resize((512, 512)),
]
transform = transforms.Compose(_transform_steps)

### GPU Availability
Make sure you are using a GPU: the cell below should print `cuda`.

In [None]:
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(device)
model = UNetResNet18(n_classes=1).to(device)

### Network Training

In [None]:
dataset = TiffDataset(
    '/gpfs/gsfs12/users/baenencm/Platelet_Images.tif', 
    '/gpfs/gsfs12/users/baenencm/Platelet_Masks.tif', 
    transform=transform
)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

validation_data = TiffDataset(
    '/gpfs/gsfs12/users/baenencm/Platelet_Images_validation.tif', 
    '/gpfs/gsfs12/users/baenencm/Platelet_Masks_validation.tif', 
    transform=transform
)
validation_dataloader = DataLoader(validation_data, batch_size=1, shuffle=True)


# Initialize the model, loss function, and optimizer
model = UNetResNet18(n_classes=1).to(device)  # Move model to GPU if available
# The model returns raw logits; BCEWithLogitsLoss applies the sigmoid internally.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)


# Choose # of training epochs
num_epochs = 7

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    val_loss = 0

    for images, masks in tqdm(dataloader):
        images, masks = images.to(device), masks.to(device)  # Move data to GPU if available

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Validate in eval mode without gradients. BatchNorm then uses its running
    # statistics instead of updating them from validation batches (of size 1),
    # and no autograd graph is built for the validation pass.
    model.eval()
    with torch.no_grad():
        for val_images, val_masks in tqdm(validation_dataloader):
            val_images, val_masks = val_images.to(device), val_masks.to(device)  # Move data to GPU if available
            val_outputs = model(val_images)
            loss = criterion(val_outputs, val_masks)
            val_loss += loss.item()

    print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {epoch_loss/len(dataloader)}')
    print(f'Epoch {epoch+1}/{num_epochs}, Validation Loss: {val_loss/len(validation_dataloader)}')
    print()
    print()

### Save the Network Output Parameters

In [None]:
checkpoint_path = '/gpfs/gsfs12/users/baenencm/Practice_Platelet_Membranes.pth'
# Persist only the learned parameters (the state_dict), not the module object
torch.save(model.state_dict(), checkpoint_path)

### Visualize some Predictions

In [None]:
# Number of original (pre-augmentation) slices that were labeled
num_labeled_images = 30  # set to however many original training images you labeled

# The augmentation cell appends validation samples grouped by source slice, so
# stepping by this stride shows one sample per original slice. A fixed stride
# of 10 would run past the end: with 50 augmentations per slice and every 10th
# held out, there are only 5 validation samples per slice.
samples_per_image = max(1, len(validation_data) // num_labeled_images)

model.eval()
with torch.no_grad():
    for i in range(num_labeled_images):
        idx = i * samples_per_image
        if idx >= len(validation_data):
            break
        image, mask = validation_data[idx]
        image = image.unsqueeze(0).to(device)  # move data to GPU if available
        output = model(image)
        output = torch.sigmoid(output).cpu().squeeze().numpy()

        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        axes[0].imshow(image.cpu().squeeze().permute(1, 2, 0).numpy(), cmap='gray')
        axes[0].set_title('Input Image')
        axes[1].imshow(mask.cpu().squeeze().numpy(), cmap='gray')
        axes[1].set_title('Ground Truth')
        # imshow scales each image to its own min/max, so no manual amplification is needed
        axes[2].imshow(output, cmap='gray')
        axes[2].set_title('Prediction')
        plt.show()

## **3: Network Predictions**
### Function to segment small patches of the image, then stitch back together
We recommend using the same patch size as the images the network was trained on in the previous section. We found that padding = patch_size/2 and stride = (patch_size/2) - remove_edge worked well for us.

In [None]:
def segment_image(image, model, device, patch_size=512, stride=244, padding=None, remove_edge=12):
    """Segment a 2-D image by running ``model`` on overlapping tiles and stitching.

    The image is padded and split into ``patch_size`` x ``patch_size`` tiles
    taken every ``stride`` pixels. Each tile's sigmoid output is written back
    with a ``remove_edge``-pixel border discarded, because tile borders are
    the least reliable. Where tiles overlap, the predictions are averaged, so
    the result is a per-pixel probability in [0, 1] whatever the ``stride``.

    Args:
        image: 2-D grayscale array or tensor (replicated to 3 channels), or an
            H x W x 3 array.
        model: network returning one logit map per input tile.
        device: torch device the tiles are moved to.
        patch_size: tile edge length; should match the training patch size.
        stride: step between tiles. ``patch_size // 2 - remove_edge`` gives
            every pixel exactly two overlapping predictions.
        padding: pad width on the top/left; the bottom/right get twice this.
            Defaults to ``patch_size // 2`` (256 for the default patch size).
        remove_edge: border width, in pixels, discarded from each tile's output.

    Returns:
        Float array with the same height and width as ``image``.
    """
    model.eval()
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    # Ensure image is a NumPy array
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()

    # Replicate a single-channel image to the 3 channels the network expects
    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)

    if padding is None:
        padding = patch_size // 2

    # Pad top/left by `padding` and bottom/right by twice that (constant value
    # 100), so the tile grid reaches past the far edges of the image.
    image = np.pad(image, ((padding, padding * 2), (padding, padding * 2), (0, 0)),
                   mode='constant', constant_values=100)

    image_tensor = transform(image).unsqueeze(0).to(device)
    _, _, H, W = image_tensor.shape

    # Running sum of tile predictions, and how many tiles contributed per pixel
    mask = np.zeros((H, W))
    counts = np.zeros((H, W))
    # Interior of a tile that is kept (also valid when remove_edge == 0)
    inner = slice(remove_edge, patch_size - remove_edge)

    with torch.no_grad():
        for i in range(0, H - patch_size + 1, stride):
            for j in range(0, W - patch_size + 1, stride):
                patch = image_tensor[:, :, i:i + patch_size, j:j + patch_size]
                output_mask = torch.sigmoid(model(patch)).cpu().squeeze().numpy()
                rows = slice(i + remove_edge, i + patch_size - remove_edge)
                cols = slice(j + remove_edge, j + patch_size - remove_edge)
                mask[rows, cols] += output_mask[inner, inner]
                counts[rows, cols] += 1

    # Average overlapping predictions; pixels that no tile reached stay 0
    np.divide(mask, counts, out=mask, where=counts > 0)

    # Crop the padding back off, returning the original image's extent
    return mask[padding:H - 2 * padding, padding:W - 2 * padding]

### Running the Trained Network

If doing multi-plane analysis, run the xy, yz, and xz stacks here.

In [None]:
model = UNetResNet18(n_classes=1).to(device)  # Move model to GPU if available
# map_location places the checkpoint tensors on `device`, so a checkpoint saved
# on a GPU node can still be loaded on a machine without CUDA.
model.load_state_dict(torch.load('/gpfs/gsfs12/users/baenencm/Practice_Platelet_Membranes.pth', map_location=device))
model.eval()


stack = imread("/gpfs/gsfs12/users/baenencm/CellSegPracticeData_ROI1.tif")
predictions = []

# Segment the stack one slice at a time
for image_slice in tqdm(stack, desc="Processing Images", unit="image"):
    predictions.append(segment_image(image_slice, model, device))

### Visualizing some predictions

In [None]:
# Show every 50th slice next to its predicted probability map
for slice_idx in range(0, len(predictions), 50):
    fig, (ax_input, ax_pred) = plt.subplots(1, 2, figsize=(15, 5))
    ax_input.imshow(stack[slice_idx], cmap='gray')
    ax_input.set_title(f'Input: Slice {slice_idx}')
    heatmap = ax_pred.imshow(predictions[slice_idx], cmap='magma')
    ax_pred.set_title('Network Prediction')
    colorbar = fig.colorbar(heatmap, ax=ax_pred)
    colorbar.set_label('Probability')

### Function to save predictions as TIFF file

In [None]:
def save_predictions_as_tiff(predictions, output_path):
    """Write a sequence of 2-D prediction maps to one multi-page float16 TIFF.

    Args:
        predictions: iterable of equally shaped 2-D arrays, one per page.
        output_path: destination ``.tif`` file.
    """
    half_precision_pages = []
    for page in tqdm(predictions, desc="Processing predictions"):
        half_precision_pages.append(np.array(page, dtype=np.float16))
    imwrite(output_path, np.stack(half_precision_pages, axis=0))

save_predictions_as_tiff(predictions, '/gpfs/gsfs12/users/baenencm/Practice_Platelet_Membranes_xy.tif')

# **4: Post-Processing**

#### Rotating Data back to xy.
To rotate the yz prediction stack back to xy, use (2,0,1).

To rotate the xz prediction stack back to xy, use (1,2,0).

In [None]:
input_tiff = r"C:\Users\baenencm\Desktop\Platelet Example Dataset\Raw Data\Practice_Platelet_Membranes_yz.tif"
output_tiff = r"C:\Users\baenencm\Desktop\Platelet Example Dataset\Raw Data\Practice_Platelet_Membranes_yz_rotated_back.tif"

# Axis order that was used in section 1b to rotate the raw xy stack into this
# orientation: (1, 2, 0) for yz, (2, 0, 1) for xz.
forward_axes = (1, 2, 0)
# Undoing it requires the inverse permutation, not the same one:
# argsort((1, 2, 0)) == (2, 0, 1) and argsort((2, 0, 1)) == (1, 2, 0).
inverse_axes = tuple(np.argsort(forward_axes))

with tf.TiffFile(input_tiff) as tif:
    volume = tif.asarray()

rotated_slices = np.transpose(volume, inverse_axes)

# Save to a new .tif as specified above
tf.imwrite(output_tiff, rotated_slices)

#### From this point, we used Amira to merge the three masks and perform the rest of the analysis. More information on this process can be found in the full paper.