In [None]:
## Library imports
import torch, torchvision
import torch.nn as nn
import shutil
import numpy as np
import os, json, cv2, random
import matplotlib.pyplot as plt
from tqdm import tqdm
import pickle

import detectron2
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
import detectron2.data.transforms as T
from detectron2.data.transforms import Transform
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.modeling import build_model,build_resnet_backbone,build_backbone
from detectron2.structures import ImageList, Instances
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
from segmentation_models_pytorch.encoders import get_encoder
from skimage import io
from shapely.geometry import box

## Local Imports
from transformations import *
from helpers import *

## Setup logger
from detectron2.utils.logger import setup_logger
setup_logger()

## Suppress warnings
import warnings
warnings.filterwarnings("ignore")

## Transformation Functions

- Illumination: We will try to simulate the absorption of light waves (amplitude + phase shift) by simulating 12 microscope LEDs at different angles of illumination. This would be a custom layer with optimizable weights
- Contrast normalization: The myeloma cells have varying levels of contrast as compared to other cells and tissues, so normalizing the contrast can help them be more visible
- Morphological operations: Erosion or dilation operations can smooth the edges of the cells, which can help us detect the cancer cells better
- Gradient Filters: Sobel filter can help identify the boundaries of the cells better
- Color Channels: Manipulate the different color channels (RGB) of the image by enhancing the effect of either red, green, or blue channel
- Blur Filter: Try the Gaussian blur filters and Median blur filter to smooth the image and reduce the noise

In [None]:
## Filter Options:
# Valid values for TRANSFORM_TYPE, grouped by the kind of pre-processing applied:

# 1). IlluminationSimulation
# 2). ContrastNormalization
# 3). Dilation, Erosion
# 4). SobelFilter
# 5). EnhanceRedColor, EnhanceGreenColor, EnhanceBlueColor,
# 6). MedianFilter, GaussianBlur

# Single switch for the whole notebook: this name selects the weights folder,
# the prediction output directories, the submission zip name, the backbone
# variant ("IlluminationSimulation" vs. everything else) and the transformation
# returned by select_transformation().
TRANSFORM_TYPE = 'EnhanceBlueColor'

In [None]:
## Setup configuration

# Exported through the environment so the shell cells at the end of the
# notebook can read it as $transform_type (name of the zip archive).
os.environ['transform_type'] = TRANSFORM_TYPE+".zip"
# Trained weights for the selected transformation
final_weights = f"../../../sg623/BME548L-ML-and-Imaging-Final-Project/outputs/{TRANSFORM_TYPE}/model_final.pth"

# Input images and output locations
img_root = "../TCIA_SegPC_dataset/test/x/"
pred_root = f"./{TRANSFORM_TYPE}_preds/"              # raw per-instance masks + bbox pickles
final_pred_root = f"./{TRANSFORM_TYPE}_final_preds/"  # merged cytoplasm/nucleus masks
# Exported so the submission shell cell can read it as $model_final_results
os.environ['model_final_results'] = final_pred_root

# Always start from empty output directories so that masks from a previous
# run can never be mixed with the current predictions.
for out_dir in (pred_root, final_pred_root):
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir, exist_ok=False)

# Only regular files are images: os.listdir() also returns sub-directories
# (e.g. the ".ipynb_checkpoints" folder Jupyter creates), which cv2.imread
# cannot read. Sorting makes the processing order reproducible across runs.
names = sorted(
    entry for entry in os.listdir(img_root)
    if os.path.isfile(os.path.join(img_root, entry))
)
# Minimum detection score for an instance to be kept
thresh = 0.5
# Inference resolution as (height, width); it is reversed when handed to cv2.resize
res_size=(1080,1440)

## Model Architecture

In [None]:
if TRANSFORM_TYPE == "IlluminationSimulation":
    class NonNegativeConv2d(nn.Conv2d):
        """
        2-D convolution whose kernel weights are forced to be non-negative.

        The constructor takes the same leading arguments as ``nn.Conv2d``. On
        every forward pass the stored weights are clamped in place to ``>= 0``
        before the convolution is evaluated, so negative values never
        contribute to the output and do not persist in the parameter.
        """

        def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
            super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)

        def forward(self, x):
            """Clamp the weights to be non-negative, then apply the convolution to ``x``."""
            # Project the parameter onto the non-negative range in place.
            # torch.no_grad() is the supported way to edit a parameter outside of
            # autograd; going through `.data` hides the modification from
            # autograd's bookkeeping.
            with torch.no_grad():
                self.weight.clamp_min_(0.0)
            return nn.functional.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

    # Network architecture: learnable illumination layer + EfficientNet-B5 encoder
    @BACKBONE_REGISTRY.register()
    class Effb5(Backbone):
        """
        Detectron2 backbone for the "IlluminationSimulation" setting.

        A 12-channel input is mixed by a non-negative 1x1 convolution and summed
        over the channel axis into a single-channel image, which is then fed to
        an EfficientNet-B5 encoder. Each encoder feature map except the first is
        projected to 256 channels by a stride-2 3x3 convolution and returned
        under the names "p2" ... "p6".
        """

        def __init__(self, cfg, input_shape):
            super().__init__()

            # Illumination layer: 12 -> 12 channels, 1x1 kernel, no bias
            self.illumination = NonNegativeConv2d(12, 12, kernel_size=1, stride=1, padding=0, bias=False)
            torch.nn.init.normal_(self.illumination.weight, mean=0.0, std=0.05)

            # Single-channel EfficientNet-B5 encoder
            self.encoder = get_encoder(
                'timm-efficientnet-b5',
                in_channels=1,
                depth=5,
                weights='noisy-student',
            )
            self.channels = self.encoder.out_channels

            # One projection per encoder stage (the stage-0 projection is created
            # as well, although forward() never uses it).
            projections = []
            for stage_channels in self.channels:
                projections.append(nn.Conv2d(stage_channels, 256, 3, stride=2, padding=1))
            self.conv = nn.ModuleList(projections)

            self.names = [f"p{level}" for level in range(1, 7)]

        def forward(self, image):
            """Return a dict mapping feature names to 256-channel feature maps."""
            mixed = self.illumination(image)
            illuminated_image = mixed.sum(dim=1, keepdim=True)
            features = self.encoder(illuminated_image)

            out = {}
            # Index 0 of the encoder output is skipped
            for level in range(1, len(features)):
                out[self.names[level]] = self.conv[level](features[level])
            return out

        def output_shape(self):
            """Describe every returned feature map: 256 channels, stride 2**(i+1)."""
            out_shape = {}
            for level in range(1, len(self.names)):
                out_shape[self.names[level]] = ShapeSpec(channels=256, stride=2 ** (level + 1))
            return out_shape
        
else:
    # Network architecture: EfficientNet-B5 encoder on 3-channel images
    @BACKBONE_REGISTRY.register()
    class Effb5(Backbone):
        """
        Detectron2 backbone used for every transformation except
        "IlluminationSimulation".

        A 3-channel image is passed through an EfficientNet-B5 encoder. Each
        encoder feature map except the first is projected to 256 channels by a
        stride-2 3x3 convolution and returned under the names "p2" ... "p6".
        """

        def __init__(self, cfg, input_shape):
            super().__init__()

            # Three-channel EfficientNet-B5 encoder
            self.encoder = get_encoder(
                'timm-efficientnet-b5',
                in_channels=3,
                depth=5,
                weights='noisy-student',
            )
            self.channels = self.encoder.out_channels

            # One projection per encoder stage (the stage-0 projection is created
            # as well, although forward() never uses it).
            projections = []
            for stage_channels in self.channels:
                projections.append(nn.Conv2d(stage_channels, 256, 3, stride=2, padding=1))
            self.conv = nn.ModuleList(projections)

            self.names = [f"p{level}" for level in range(1, 7)]

        def forward(self, image):
            """Return a dict mapping feature names to 256-channel feature maps."""
            features = self.encoder(image)

            out = {}
            # Index 0 of the encoder output is skipped
            for level in range(1, len(features)):
                out[self.names[level]] = self.conv[level](features[level])
            return out

        def output_shape(self):
            """Describe every returned feature map: 256 channels, stride 2**(i+1)."""
            out_shape = {}
            for level in range(1, len(self.names)):
                out_shape[self.names[level]] = ShapeSpec(channels=256, stride=2 ** (level + 1))
            return out_shape

## Predictor Class

In [None]:
class DefaultPredictor:
    """
    Minimal single-image inference wrapper around a detectron2 model.

    Builds the model described by ``cfg``, loads ``cfg.MODEL.WEIGHTS`` into it
    and applies the selected image transformation before every forward pass.
    No resizing is performed here; the image is forwarded at the resolution
    the caller provides.
    """

    def __init__(self, cfg, transformation):
        """
        Args:
            cfg (CfgNode): detectron2 configuration. It is cloned, so later
                changes to the caller's object do not affect the predictor.
            transformation (sequence): only the first element is used. It is
                expected to be either a ``NoOpTransform`` instance or an object
                exposing ``get_transform(image).apply_image(image)``.
        """

        # Model configuration object
        self.cfg = cfg.clone()
        self.model = build_model(self.cfg)

        # Set model to evaluation mode
        self.model.eval()

        # Load model from checkpoint
        checkpointer = DetectionCheckpointer(self.model)
        checkpointer.load(cfg.MODEL.WEIGHTS)

        # Initialize transformation
        self.aug = transformation[0]

        # Get input format
        self.input_format = cfg.INPUT.FORMAT
        assert self.input_format in ["RGB", "BGR"], self.input_format

    def __call__(self, original_image):
        """
        Args:
            original_image (np.ndarray): an image of shape (H, W, C) (in BGR order).
        Returns:
            predictions (dict):
                the output of the model for one image only.
                See :doc:`/tutorials/models` for details about the format.
        """
        with torch.no_grad():
            # Apply pre-processing to image.
            if self.input_format == "RGB":
                # whether the model expects BGR inputs or RGB
                original_image = original_image[:, :, ::-1]

            height, width = original_image.shape[:2]

            # Use the transformation this predictor was constructed with rather
            # than a notebook-level global, so the result does not depend on
            # which cells were (re-)run after the predictor was created.
            if isinstance(self.aug, NoOpTransform):
                image = original_image
            else:
                image = self.aug.get_transform(original_image).apply_image(original_image)

            # HWC array -> float32 CHW tensor, the layout the model input uses
            image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1))

            inputs = {"image": image, "height": height, "width": width}
            predictions = self.model([inputs])[0]

            return predictions

## Configuration Setup

In [None]:
# Start from detectron2's defaults and layer the Cascade Mask R-CNN recipe on top
cfg = get_cfg()
base_config_file = model_zoo.get_config_file("Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml")
cfg.merge_from_file(base_config_file)

# Swap in the custom backbone registered above and point at the trained weights
cfg.MODEL.BACKBONE.NAME = "Effb5"
cfg.MODEL.WEIGHTS = final_weights

# ROI head settings
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 64

# The illumination model uses single-value pixel statistics
if TRANSFORM_TYPE == "IlluminationSimulation":
    cfg.MODEL.PIXEL_MEAN = [0.5]
    cfg.MODEL.PIXEL_STD = [1.0]

# Look up the image transformation that matches the selected option
transformation = select_transformation(TRANSFORM_TYPE)

# Build the predictor (creates the model and loads the weights)
predictor = DefaultPredictor(cfg, transformation)

## Generate Predictions on Test Data

In [None]:
for name in tqdm(names):
    # Get index number of file (the file name without its last 4 characters,
    # i.e. a 3-letter extension such as ".bmp")
    index = name[:-4]

    # read the image; cv2.imread signals failure by returning None
    im = cv2.imread(img_root+name)
    if im is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {img_root+name}")

    # Get dimensions of image (height, width) so masks can be mapped back
    orig_shape = im.shape[0:2]

    # Resize image to the inference resolution (cv2 expects (width, height))
    im = cv2.resize(im, res_size[::-1], interpolation=cv2.INTER_NEAREST)

    # Get predictions
    outputs = predictor(im)

    # Move the detections to the CPU once, then read every field from that copy
    instances = outputs['instances'].to('cpu')
    scores = instances.scores.numpy()
    pred_masks = instances.pred_masks.numpy()
    pred_boxes = instances.pred_boxes.tensor.numpy()
    pred_classes = instances.pred_classes.numpy()

    # Initialize counter: mask files are numbered from 1 in the same order as
    # the entries of bbox_list, so mask k belongs to bbox_list[k-1]. The
    # post-processing cell relies on this correspondence.
    count = 1
    bbox_list = []

    # Iterate through the detections
    for i in range(len(scores)):
        # skip detections whose probability is below the threshold
        if scores[i] < thresh:
            continue

        # Extract the binary mask, scale it to 0/255 and save it at the
        # original image resolution
        tmp_mask = 255 * pred_masks[i].astype('uint8')
        tmp_mask = cv2.resize(tmp_mask, orig_shape[::-1],interpolation=cv2.INTER_NEAREST)
        cv2.imwrite(pred_root+index+'_'+str(count)+'.bmp', tmp_mask)
        count += 1

        # Save the filtered Bounding Boxes together with the predicted class
        bbox_list.append((pred_boxes[i], pred_classes[i]))

    # Save the detected bounding boxes for each image
    with open(f'{pred_root}{index}.pkl', 'wb') as f:
        pickle.dump(bbox_list, f)

In [None]:
## Combine results for cytoplasm and nucleus detections
for name in tqdm(names):

    # Get index name of file
    index = name[:-4]

    # Load the bounding boxes from the pickle file
    # (written by the prediction cell of this notebook; do not point this at
    # pickle files from an untrusted source)
    with open(pred_root+index+'.pkl', 'rb') as f:
        pickle_file = pickle.load(f)

    # Get all bboxes and class predictions
    bboxes = [item[0] for item in pickle_file]
    bbox_classes = [item[1] for item in pickle_file]

    # Build each box polygon once instead of once per pair
    polygons = [box(*bbox) for bbox in bboxes]

    # Find every pair of detections whose boxes overlap with IoU > 0.2
    overlappig_bbs = []
    for i in range(len(polygons)):
        for j in range(i + 1, len(polygons)):
            union = polygons[i].union(polygons[j]).area
            # Two degenerate (zero-area) boxes have an empty union; treat them
            # as non-overlapping instead of dividing by zero.
            if union <= 0:
                continue
            intersection = polygons[i].intersection(polygons[j]).area
            if intersection / union > 0.2:
                overlappig_bbs.append((i, j))

    # Find the detections that are not part of any overlapping pair
    paired_bbs = {bb for pair in overlappig_bbs for bb in pair}
    left_bbs = [bb for bb in range(len(bbox_classes)) if bb not in paired_bbs]

    counter = 1
    # Create combined cytoplasm + nucleus masks.
    # NOTE(review): a detection that overlaps several others appears in several
    # pairs and is therefore written once per pair - confirm this is intended.
    for bb1, bb2 in overlappig_bbs:
        # Read the two masks (mask files are numbered from 1)
        img1 = io.imread(pred_root+index+'_'+str(bb1+1)+'.bmp', as_gray=True).copy()
        img2 = io.imread(pred_root+index+'_'+str(bb2+1)+'.bmp', as_gray=True).copy()

        # The class of the first detection decides the roles: a truthy class
        # means the first mask is the cytoplasm and the second the nucleus,
        # otherwise the roles are swapped. The class of the second detection
        # is not checked.
        if bbox_classes[bb1]:
            cytoplasm_mask, nucleus_mask = img1, img2
        else:
            cytoplasm_mask, nucleus_mask = img2, img1

        # cytoplasm pixels -> 20, then nucleus pixels -> 40 painted on top
        nucleus_pixels = nucleus_mask == 255
        cytoplasm_mask[cytoplasm_mask == 255] = 20
        cytoplasm_mask[nucleus_pixels] = 40

        # save the mask
        io.imsave(final_pred_root+index+'_'+str(counter)+'.bmp', cytoplasm_mask)

        # Update counter
        counter = counter + 1

    # process the single detections of cytoplasm and nucleus
    for bb_left in left_bbs:
        # read the mask image
        img1 = io.imread(pred_root+index+'_'+str(bb_left+1)+'.bmp', as_gray=True).copy()

        # cytoplasm (truthy class) pixels -> 20, nucleus pixels -> 40
        label_value = 20 if bbox_classes[bb_left] else 40
        img1[img1 == 255] = label_value

        # save the mask
        io.imsave(final_pred_root+index+'_'+str(counter)+'.bmp', img1)
        counter = counter + 1

In [None]:
## Clear any submission file left over from a previous run
## (a new one is generated for the SEG-PC Challenge 2021 below)
file_path = './submission.txt'

submission_present = os.path.exists(file_path)
if not submission_present:
    # Nothing to clean up
    print(f"File '{file_path}' does not exist.")
else:
    # Remove the stale file
    os.remove(file_path)
    print(f"File '{file_path}' deleted.")

In [None]:
## Create submissions.txt
# Runs the project's submission.py with a fixed conda-environment interpreter.
# $model_final_results is the environment variable exported in the setup cell
# (the directory holding the final merged masks).
# NOTE(review): the interpreter path is an absolute, user-specific location and
# must be adapted on other machines; the meaning of -s / -d is defined by
# submission.py, which is not shown here - presumably source dir and output dir.
! /hpc/group/aipi540-s23/sl808/miniconda3/envs/bme548_fp/bin/python submission.py -s $model_final_results -d ./

In [None]:
## Zip the submissions file
# $transform_type is the environment variable exported in the setup cell
# ("<TRANSFORM_TYPE>.zip"), so the archive is named after the selected
# transformation and contains submission.txt.
! zip --verbose $transform_type submission.txt