In [None]:
import torch, torchvision
import torch.nn as nn

import numpy as np
import os, json, cv2, random
import matplotlib.pyplot as plt

import detectron2
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec
import detectron2.data.transforms as T
from detectron2.data.transforms import Transform
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor, DefaultTrainer
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog, DatasetMapper, build_detection_train_loader, build_detection_test_loader
from detectron2.modeling import build_model,build_resnet_backbone,build_backbone
from detectron2.structures import ImageList, Instances
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN
from detectron2.evaluation import COCOEvaluator
from detectron2.data.datasets import register_coco_instances
from segmentation_models_pytorch.encoders import get_encoder

from transformations import *

from detectron2.utils.logger import setup_logger
setup_logger()

## Transformation Functions

- Contrast normalization: Myeloma cells have varying levels of contrast compared with other cells and tissue, so normalizing the contrast can make them more visible
- Morphological operations: Erosion or dilation operations can smooth the edges of the cells, which can help us detect the cancer cells better
- Gradient Filters: A Sobel filter can help identify the boundaries of the cells more clearly
- Color Channels: Manipulate the color channels (RGB) of the image by suppressing or enhancing the red, green, or blue channel
- Blur Filter: Apply a Gaussian blur filter to smooth the image and reduce noise
- Resolution: Benchmark the accuracy of detecting myeloma cells at reduced image resolution to see whether smaller images can achieve on-par accuracy

In [None]:
def select_transformation(transformation_type=None):
    """
    Build the list of transforms to apply to the images.

    Args:
        transformation_type (str or None): Name of a transform class visible in
            this notebook's global namespace (e.g. one of the classes brought in
            by ``from transformations import *``, such as "GaussianBlur"). The
            class is instantiated with no arguments. ``None`` selects
            ``NoOpTransform``.

    Returns:
        list: A single-element list holding the instantiated transform, in the
        form passed to ``DatasetMapper(augmentations=...)``.

    Raises:
        ValueError: If ``transformation_type`` does not name a callable in the
            global namespace.
    """
    if transformation_type is None:
        return [NoOpTransform()]

    # Resolve the name by lookup instead of eval(): eval would execute any
    # expression smuggled in through the "type" string and gives opaque
    # NameError/TypeError messages for typos.
    transform_cls = globals().get(transformation_type)
    if transform_cls is None or not callable(transform_cls):
        raise ValueError(
            f"Unknown transformation type {transformation_type!r}; expected the "
            "name of a transform class available in this notebook's namespace."
        )

    return [transform_cls()]

In [None]:
def test_transformation(img_path="../TCIA_SegPC_dataset/coco/x/106.bmp", transformation_type=None):
    """
    Show an image next to its transformed version for a visual sanity check.

    Args:
        img_path (str): Path of the image to load with OpenCV.
        transformation_type (str or None): Transform name forwarded to
            ``select_transformation``; ``None`` selects its no-op default.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``img_path``.
    """
    # cv2.imread does not raise on a missing/unreadable file; it returns None,
    # which would otherwise only fail later inside cvtColor with an opaque error.
    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        raise FileNotFoundError(f"Could not read image at {img_path!r}")

    # OpenCV loads images as BGR; matplotlib's imshow expects RGB.
    img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    # Get the image transformer and apply it
    transform_object = select_transformation(transformation_type)[0]
    img_transformed = transform_object.apply_image(img)

    # Original on the left, transformed on the right
    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
    panels = ((img, 'Original Image'), (img_transformed, 'Transformed Image'))
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image)
        ax.set_title(title)

    plt.show()

## Model Architecture

In [None]:
# Define network architecture
@BACKBONE_REGISTRY.register()
class Effb5(Backbone):
    """
    EfficientNet-B5 backbone (``timm-efficientnet-b5`` with ``noisy-student``
    weights, built via ``get_encoder``) exposed as a detectron2 ``Backbone``.

    Each encoder stage except stage 0 goes through its own 3x3, stride-2
    convolution that projects it to 256 channels. The results are returned
    under the names ``p2`` .. ``p{N}`` (``p2`` .. ``p6`` for the depth-5
    encoder configured here).

    NOTE(review): checkpoints saved before conv[0] became an Identity contain
    unused ``conv.0.*`` weights; the model still loads them non-strictly, but
    resuming the optimizer state from such checkpoints may not match.
    """

    def __init__(self, cfg, input_shape):
        super().__init__()

        in_channels = 3
        encoder_name = 'timm-efficientnet-b5'
        encoder_depth = 5
        encoder_weights = 'noisy-student'
        self.encoder = get_encoder(encoder_name,
                in_channels=in_channels,
                depth=encoder_depth,
                weights=encoder_weights)
        self.channels = self.encoder.out_channels

        # forward() never uses encoder stage 0. A real conv there would own
        # parameters that never get gradients (which breaks
        # DistributedDataParallel with find_unused_parameters=False). The
        # parameter-free Identity placeholder keeps conv[i] aligned with
        # stage i and keeps the checkpoint keys conv.1 .. conv.N unchanged.
        self.conv = nn.ModuleList(
            [nn.Identity()]
            + [nn.Conv2d(ch, 256, 3, stride=2, padding=1) for ch in self.channels[1:]]
        )

        # One name per encoder stage (p1 .. p{len(channels)}); p1 is never emitted.
        # Derived from the channel list so forward() and output_shape() stay in sync.
        self.names = ["p" + str(i + 1) for i in range(len(self.channels))]

    def forward(self, image):
        """
        Run the encoder and project every stage but the first.

        Args:
            image: Batched input tensor, passed straight to the encoder.

        Returns:
            dict: ``{self.names[i]: conv[i](features[i])}`` for stages 1 .. N.
        """
        features = self.encoder(image)
        return {self.names[i]: self.conv[i](features[i]) for i in range(1, len(features))}

    def output_shape(self):
        """
        Describe each output: 256 channels, stride ``2 ** (i + 1)`` for
        ``self.names[i]`` (p2 -> 4, ..., p6 -> 64).

        This assumes encoder stage i has stride 2**i, which the extra stride-2
        projection doubles. TODO confirm for other encoders/depths.
        """
        return {
            self.names[i]: ShapeSpec(channels=256, stride=2 ** (i + 1))
            for i in range(1, len(self.names))
        }

## Model Training Class

In [None]:
class CocoTrainer(DefaultTrainer):
    """
    Custom class for model training.

    Uses COCO-style evaluation. Both the train and test data loaders apply the
    transformation named by the notebook-level ``TRANSFORM_TYPE`` (resolved via
    ``select_transformation``).
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """Return a non-distributed COCOEvaluator writing to ``output_folder`` (default ``coco_eval``)."""
        if output_folder is None:
            output_folder = "coco_eval"
            os.makedirs(output_folder, exist_ok=True)

        return COCOEvaluator(dataset_name, cfg, False, output_folder)

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the training loader with the selected transformation."""
        mapper = DatasetMapper(cfg, is_train=True, augmentations=select_transformation(TRANSFORM_TYPE))

        return build_detection_train_loader(cfg, mapper=mapper)

    @classmethod
    def build_test_loader(cls, cfg, dataset_name=None):
        """
        Build the evaluation loader for ``dataset_name`` with the selected transformation.

        ``DefaultTrainer.test`` calls this as ``build_test_loader(cfg, dataset_name)``,
        so the parameter is required for periodic evaluation (TEST.EVAL_PERIOD) to work.
        When omitted, the first dataset in ``cfg.DATASETS.TEST`` is used.
        """
        if dataset_name is None:
            dataset_name = cfg.DATASETS.TEST[0]

        mapper = DatasetMapper(cfg, is_train=False, augmentations=select_transformation(TRANSFORM_TYPE))

        return build_detection_test_loader(cfg, dataset_name, mapper=mapper)

## Dataset Initialization

In [None]:
SEGPC_ROOT = "../TCIA_SegPC_dataset"

# register_coco_instances refuses to register a name twice, so skip names that
# are already registered; this keeps the cell safe to re-run in the same kernel.
for split_name, split_dir in (("SegPC_train", "coco"), ("SegPC_val", "coco_val")):
    if split_name not in DatasetCatalog.list():
        register_coco_instances(
            split_name,
            {},
            f"{SEGPC_ROOT}/{split_dir}/COCO.json",
            f"{SEGPC_ROOT}/{split_dir}/x/",
        )

train_meta = MetadataCatalog.get('SegPC_train')
val_meta = MetadataCatalog.get('SegPC_val')

train_dicts = DatasetCatalog.get("SegPC_train")
val_dicts = DatasetCatalog.get("SegPC_val")

## Configuration Setup

In [None]:
## Base recipe: Cascade Mask R-CNN X152 from the detectron2 model zoo (config + weights)
zoo_config_name = "Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv.yaml"
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(zoo_config_name))
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(zoo_config_name)

## Data
cfg.DATASETS.TRAIN = ("SegPC_train",)
cfg.DATASETS.TEST = ("SegPC_val",)
cfg.DATALOADER.NUM_WORKERS = 5

## Model: custom "Effb5" backbone and a single class
cfg.MODEL.BACKBONE.NAME = "Effb5"
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 64

## Solver (presumably 0.02 / 8 linearly scales a 16-image reference LR down to 2 images per batch — confirm)
cfg.SOLVER.IMS_PER_BATCH = 2
cfg.SOLVER.BASE_LR = 0.02 / 8
cfg.SOLVER.LR_SCHEDULER_NAME = 'WarmupCosineLR'
cfg.SOLVER.WARMUP_ITERS = 100
cfg.SOLVER.MAX_ITER = 3725
# NOTE(review): STEPS/GAMMA are multi-step LR settings; the cosine scheduler
# chosen above does not appear to use them — confirm before tuning them.
cfg.SOLVER.STEPS = (1000, 1500)
cfg.SOLVER.GAMMA = 0.05
cfg.SOLVER.CHECKPOINT_PERIOD = 500

## Evaluation, runtime and output location
cfg.TEST.EVAL_PERIOD = 250
cfg.CUDNN_BENCHMARK = True
cfg.OUTPUT_DIR = "./output/"
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)

## Transformation used by the data loaders. Options: ContrastNormalization, GaussianBlur,
## CorrectColor, Dilation, Erosion, SobelFilterX, SobelFilterY, EnhanceRedColor,
## EnhanceGreenColor, EnhanceBlueColor
TRANSFORM_TYPE = "GaussianBlur"

## Testing Transformation Results

In [None]:
## Preview the selected transformation on the default sample image
test_transformation(transformation_type=TRANSFORM_TYPE);

## Train The Model

In [None]:
## Train the model. Training is long-running, so it is opt-in:
## set RUN_TRAINING = True to build the trainer and train from the configured weights.
RUN_TRAINING = False

if RUN_TRAINING:
    trainer = CocoTrainer(cfg)
    trainer.resume_or_load(resume=False)
    trainer.train()