<center><img src="https://drive.google.com/uc?id=1Z3JvAFmL2IkBnQmmt5f4uTcXVhO5f7cq"/></center>

------
<center>&copy; Research Group CAMMA, University of Strasbourg, <a href="http://camma.u-strasbg.fr">http://camma.u-strasbg.fr</a> 

<h2>Author: Deepak Alapatt </h2>
</center>

------

## Setup

# Learning Objective
# <center><font color=green> Lecture 7: Surgical Semantic segmentation </font></center>
<center><img src="https://drive.google.com/uc?id=1MrPylzmD6QIWcpe5pcadee00eeEH8GyP"/></center>


### **Objectives**: 
  1. Build a PyTorch `Dataset` and `DataLoader` for a segmentation dataset
  2. Develop a surgical semantic segmentation model
  3. Train the model to segment tool and anatomy classes in laparoscopic cholecystectomy frames
  4. Perform online inference on a sample Cholec80 surgical video


## Setup

In [2]:
# Install dependencies (uncomment the lines for packages missing from the kernel's environment).
# NOTE(review): `%pip install ...` targets the running kernel more reliably than `!pip`;
# cv2 (opencv-python) and PIL (Pillow) are imported below but are not listed here.
# !pip install numpy
# !pip install matplotlib
# !pip install torch
# !pip install torchvision
# !pip install tqdm
# !pip install ipywidgets

In [3]:
# Download and extract the course resources (~3.1 GB zip) unless ./resources
# already exists. Later cells read from resources/cholecseg8k,
# resources/cholec80_val_3vids, resources/model_segm_cholecseg8k.pth and
# resources/lut.jpg (presumably all contained in this archive).
# NOTE(review): resources.zip is not deleted after extraction.
DIR="./resources"
![ ! -d "$DIR" ] && wget https://s3.unistra.fr/camma_public/teaching/edu4sds_resources/lec7_dl-segm/resources.zip && unzip -qq resources.zip

--2022-07-21 19:54:35--  https://s3.unistra.fr/camma_public/teaching/edu4sds_resources/lec7_dl-segm/resources.zip
Resolving s3.unistra.fr (s3.unistra.fr)... 130.79.200.152
Connecting to s3.unistra.fr (s3.unistra.fr)|130.79.200.152|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3344266745 (3.1G) [application/zip]
Saving to: ‘resources.zip’


2022-07-21 20:54:16 (6.94 MB/s) - Read error at byte 2420508136/3344266745 (Success). Retrying.

--2022-07-21 20:54:17--  (try: 2)  https://s3.unistra.fr/camma_public/teaching/edu4sds_resources/lec7_dl-segm/resources.zip
Connecting to s3.unistra.fr (s3.unistra.fr)|130.79.200.152|:443... connected.
HTTP request sent, awaiting response... 206 Partial Content
Length: 3344266745 (3.1G), 923758609 (881M) remaining [application/zip]
Saving to: ‘resources.zip’

resources.zip       100%[++++++++++++++=====>]   3.11G   960KB/s    in 8m 23s  

2022-07-21 21:02:41 (1.75 MB/s) - ‘resources.zip’ saved [3344266745/3344266745]



In [5]:
import torch
import torch.nn as nn
import transforms as T
import cv2
import random
import torchvision
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import glob
import os
import ipywidgets as wd
from tqdm.notebook import tqdm
import json
import numpy as np
import shutil
from PIL import Image, ImageColor
import io
import torchvision.transforms as transforms
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# With DO_TRAINING = False the notebook skips training and loads the
# weights stored at FINAL_MODEL_PATH instead.
NUM_EPOCHS = 30
DO_TRAINING = False
FINAL_MODEL_PATH = "resources/model_segm_cholecseg8k.pth"
# Name of the torchvision segmentation constructor used to build the network.
MODEL_NAME = "fcn_resnet50"

# Semantic classes segmented in this notebook, each paired with the RGB color
# used to draw it. Entry i is class id i (this ordering is what the LUT below
# and `num_classes = len(META_DATA)` rely on).
META_DATA = [
    ("black_background", (0, 0, 0)),
    ("abdominal_wall", (33, 191, 197)),
    ("liver", (231, 126, 9)),
    ("gastrointestinal_tract", (209, 53, 84)),
    ("fat", (80, 155, 4)),
    ("grasper", (255, 207, 210)),
    ("connective_tissue", (169, 52, 199)),
    ("blood", (229, 18, 18)),
    ("cystic_duct", (149, 50, 18)),
    ("l-hook_electrocautery", (46, 43, 180)),
    ("gallbladder", (148, 55, 66)),
    ("hepatic_vein", (214, 51, 149)),
    ("liver_ligament", (240, 79, 10)),
]

# (num_classes, 3) uint8 array of RGB colors; row i belongs to class id i.
COLORS = np.array([rgb for _, rgb in META_DATA], dtype=np.uint8)


def get_lut():
    """Build a 256-entry lookup table mapping a class id to its BGR color.

    Row ``i`` holds ``COLORS[i]`` with its channels reversed (RGB -> BGR, the
    channel order OpenCV works in). Ids without a class map to black.

    Returns:
        np.ndarray of shape (256, 1, 3) and dtype uint8.
    """
    table = np.zeros((256, 1, 3), dtype=np.uint8)
    table[: len(COLORS), 0] = COLORS[:, ::-1]
    return table


LUT = get_lut()

# Inference-time preprocessing for a PIL image: ToTensor followed by
# normalization with the same mean/std the training/eval presets use.
IM2TENSOR = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# SGD hyper-parameters, plus the exponent of the polynomial learning-rate
# decay applied by the scheduler.
learning_rate, momentum = 0.00125, 0.9
weight_decay = 1e-4
power = 0.9

ModuleNotFoundError: No module named 'torch'

## Helper functions and classes

Define some reusable functions that we will use throughout this notebook.

In [6]:
def applyCustomColorMap(segmentation_mask):
    """Colorize a mask of class ids using the palette lookup table ``LUT``.

    Args:
        segmentation_mask: array of class ids, either 2-D or already
            3-channel. Presumably uint8 (cv2.LUT needs 8-bit input) - TODO confirm.

    Returns:
        The color image produced by ``cv2.LUT`` (BGR, since ``LUT`` stores BGR).
    """
    mask = segmentation_mask
    if mask.ndim == 2:
        # Replicate the single channel so the mask matches the 3-channel LUT.
        mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
    return cv2.LUT(mask, LUT)


def image_to_byte_array(image, default_format="PNG"):
    """Encode a PIL image to bytes (e.g. for an ``ipywidgets.Image`` value).

    The image's own ``format`` is reused when it is set (e.g. "JPEG" for a file
    opened with ``Image.open``). Images created in memory, such as those from
    ``Image.fromarray``, have ``format`` None, which ``save`` cannot resolve
    for a file object; ``default_format`` is used for them instead.

    Args:
        image: a PIL image.
        default_format: encoder name used when ``image.format`` is not set.

    Returns:
        The encoded image as ``bytes``.
    """
    buffer = io.BytesIO()
    image.save(buffer, format=image.format or default_format)
    return buffer.getvalue()


def cat_list(images, fill_value=0):
    """Stack tensors of different sizes into one padded batch tensor.

    The batch shape is ``len(images)`` followed by the element-wise maximum of
    the input shapes. Each tensor is copied into the top-left corner of its
    slot (last two dims), and everything else is ``fill_value``.

    Args:
        images: non-empty sequence of tensors with the same number of dims.
        fill_value: value written into the padded region.

    Returns:
        A tensor with the dtype and device of ``images[0]``.
    """
    padded_shape = [max(dims) for dims in zip(*(t.shape for t in images))]
    batch = images[0].new_full((len(images), *padded_shape), fill_value)
    for idx, tensor in enumerate(images):
        height, width = tensor.shape[-2], tensor.shape[-1]
        batch[idx][..., :height, :width].copy_(tensor)
    return batch

def collate_fn(batch):
    """DataLoader collate function for (image, target) pairs of varying size.

    Images are padded with 0 and targets with 255. Padded target pixels
    therefore carry the label that ``criterion`` ignores (``ignore_index=255``).

    Returns:
        ``(batched_images, batched_targets)`` as produced by ``cat_list``.
    """
    images = [img for img, _ in batch]
    targets = [tgt for _, tgt in batch]
    return cat_list(images, fill_value=0), cat_list(targets, fill_value=255)

def criterion(inputs, target):
    """Cross-entropy loss over a segmentation model's output dict.

    Args:
        inputs: dict of logits, keyed "out" and (optionally) "aux". Presumably
            each is (N, C, H, W) - TODO confirm.
        target: class-id tensor; pixels equal to 255 are ignored.

    Returns:
        ``loss["out"]`` when ``inputs`` has a single entry, otherwise
        ``loss["out"] + 0.5 * loss["aux"]``.
    """
    losses = {
        name: nn.functional.cross_entropy(logits, target, ignore_index=255)
        for name, logits in inputs.items()
    }
    total = losses["out"]
    if len(losses) != 1:
        total = total + 0.5 * losses["aux"]
    return total


# Helper class to compute segmentation metrics from a confusion matrix
# see: https://en.wikipedia.org/wiki/Confusion_matrix
class ConfusionMatrix:
    """Accumulates a pixel-level confusion matrix and derives metrics from it.

    Rows index the first argument of ``update`` (the ground truth in
    ``evaluate``), and columns index the second (the prediction).
    """

    def __init__(self, num_classes):
        self.num_classes = num_classes
        # (num_classes, num_classes) int64 counts, allocated lazily on the
        # device of the first tensor passed to update().
        self.mat = None

    def update(self, a, b):
        """Add one batch of flattened labels to the counts.

        Args:
            a: ground-truth class ids. Entries outside [0, num_classes), such as
               the 255 padding/ignore label, are skipped.
            b: predicted class ids, same shape as ``a``.
        """
        n = self.num_classes
        if self.mat is None:
            self.mat = torch.zeros((n, n), dtype=torch.int64, device=a.device)
        with torch.no_grad():
            k = (a >= 0) & (a < n)
            # Encode each (gt, pred) pair as one index gt * n + pred and count them.
            inds = n * a[k].to(torch.int64) + b[k]
            self.mat += torch.bincount(inds, minlength=n ** 2).reshape(n, n)

    def reset(self):
        """Zero the accumulated counts (a no-op before the first update)."""
        if self.mat is not None:
            self.mat.zero_()

    def compute(self):
        """Return ``(global accuracy, per-class accuracy, per-class IoU)`` as tensors.

        Per-class accuracy is NaN for classes with no ground-truth pixels. IoU is
        NaN for classes with neither ground-truth nor predicted pixels.

        Raises:
            RuntimeError: if called before any ``update``.
        """
        if self.mat is None:
            raise RuntimeError("ConfusionMatrix.compute() called before any update()")
        h = self.mat.float()
        acc_global = torch.diag(h).sum() / h.sum()
        acc = torch.diag(h) / h.sum(1)
        iou = torch.diag(h) / (h.sum(1) + h.sum(0) - torch.diag(h))
        return acc_global, acc, iou

    def __str__(self):
        """Overall accuracy, per-class accuracy, per-class IoU and mean IoU, in percent.

        The mean IoU averages only the defined (non-NaN) per-class IoUs. A class
        absent from both ground truth and predictions would otherwise make the
        mean NaN.
        """
        acc_global, acc, iou = self.compute()
        mean_iou = iou[~torch.isnan(iou)].mean()
        return ("global correct: {:.1f}\naverage row correct: {}\nIoU: {}\nmean IoU: {:.1f}").format(
            acc_global.item() * 100,
            [f"{i:.1f}" for i in (acc * 100).tolist()],
            [f"{i:.1f}" for i in (iou * 100).tolist()],
            mean_iou.item() * 100,
        )

## CholecSeg8k
The CholecSeg8k dataset [1] consists of a subset of Cholec80 [2]: frames from 17 video clips annotated with pixel-level semantic segmentation labels for 13 semantic classes.

<center><img src="https://drive.google.com/uc?id=1kKJrO75QDINP18gQz8m2CzZ-cDhtuCFk"/></center>


1. _Hong, W-Y., C-L. Kao, Y-H. Kuo, J-R. Wang, W-L. Chang, and C-S. Shih. "CholecSeg8k: A Semantic Segmentation Dataset for Laparoscopic Cholecystectomy Based on Cholec80." arXiv preprint arXiv:2012.12453 (2020)._

2. _Twinanda, Andru P., Sherif Shehata, Didier Mutter, Jacques Marescaux, Michel De Mathelin, and Nicolas Padoy. "Endonet: a deep architecture for recognition tasks on laparoscopic videos." IEEE transactions on medical imaging 36, no. 1 (2016): 86-97._

## Dataset class

In [None]:
# We define a dataset class that delivers images and corresponding ground truth segmentation masks
# from CholecSeg8k. Please refer to Lecture 6 for more info on torch Datasets.

class CholecDatasetSegm(torch.utils.data.Dataset):
    """Image / segmentation-mask pairs listed in a CholecSeg8k JSON split file.

    Args:
        gt_json: path to a JSON list of records, each with "file_name" (image)
            and "mask_name" (mask) paths relative to ``root_dir``.
        meta_data: list of (class_name, rgb) tuples, stored as ``self.metadata``.
        root_dir: directory the paths in ``gt_json`` are relative to.
        data_split: split name, only stored (e.g. "train" / "val").
        transforms: optional callable ``(img, target) -> (img, target)``.
    """

    def __init__(self, gt_json, meta_data, root_dir = "./resources/cholecseg8k", data_split = "train", transforms = None):
        self.gt_json = gt_json
        self.root_dir = root_dir
        self.data_split = data_split
        self.transforms = transforms
        # Close the annotation file deterministically instead of leaking the handle.
        with open(gt_json) as f:
            gt_data = json.load(f)
        self.images = [os.path.join(self.root_dir, g["file_name"]) for g in gt_data]
        self.targets = [os.path.join(self.root_dir, g["mask_name"]) for g in gt_data]
        self.metadata = meta_data

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index: int):
        """Return ``(image, target)``: the image converted to RGB and the mask
        converted to single-channel "L", passed through ``self.transforms`` if set."""
        with Image.open(self.images[index]) as im:
            img = im.convert("RGB")
        with Image.open(self.targets[index]) as mask:
            target = mask.convert("L")
        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

In [None]:
class SegmentationPresetTrain:
    """Joint (image, mask) augmentation pipeline used for training.

    The steps, from the local ``transforms`` module, are:
    RandomResize(int(0.5 * base_size), int(2.0 * base_size)), an optional
    horizontal flip with probability ``hflip_prob``, RandomCrop(crop_size), and
    conversion of the image to a normalized float tensor. The per-step
    semantics are those of transforms.py - TODO confirm there.
    """

    def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        steps = [T.RandomResize(int(0.5 * base_size), int(2.0 * base_size))]
        if hflip_prob > 0:
            steps.append(T.RandomHorizontalFlip(hflip_prob))
        steps += [
            T.RandomCrop(crop_size),
            T.PILToTensor(),
            T.ConvertImageDtype(torch.float),
            T.Normalize(mean=mean, std=std),
        ]
        self.transforms = T.Compose(steps)

    def __call__(self, img, target):
        """Apply the pipeline to an (image, mask) pair and return the transformed pair."""
        return self.transforms(img, target)


class SegmentationPresetEval:
    """(image, mask) preprocessing used for evaluation.

    Resizes with RandomResize(base_size, base_size), then converts the image
    to a normalized float tensor.
    """

    def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # min == max, so this is presumably a fixed resize to base_size (see transforms.py).
        resize = T.RandomResize(base_size, base_size)
        to_tensor = [T.PILToTensor(), T.ConvertImageDtype(torch.float), T.Normalize(mean=mean, std=std)]
        self.transforms = T.Compose([resize, *to_tensor])

    def __call__(self, img, target):
        """Apply the pipeline to an (image, mask) pair and return the transformed pair."""
        return self.transforms(img, target)

In [None]:
# Defining Data Loaders for the training and testing splits.
# Please refer to Lecture 6 for more info on torch Data Loaders.


def get_transform(train=True):
    """Return the training augmentation preset, or the eval preset when ``train`` is False."""
    return (
        SegmentationPresetTrain(base_size=512, crop_size=400)
        if train
        else SegmentationPresetEval(base_size=400)
    )
    
# Train loader: training split with the random training augmentations
# (get_transform() defaults to train=True), samples drawn in random order,
# batches of 2 padded/stacked by collate_fn; the last incomplete batch is dropped.
# NOTE(review): no random seed is set in this notebook, so sampling order and
# augmentations differ between runs.
dataset = CholecDatasetSegm("./resources/cholecseg8k/ch80_13vids_train.json", META_DATA, data_split="train", transforms=get_transform())
# One output class per META_DATA entry.
num_classes = len(META_DATA)
train_sampler = torch.utils.data.RandomSampler(dataset)
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    sampler=train_sampler,
    collate_fn=collate_fn,
    drop_last=True,
)

# Test loader: validation split with the eval transforms, visited in order
# one image per batch.
dataset_test = CholecDatasetSegm("./resources/cholecseg8k/ch80_4vids_val.json", META_DATA, data_split="val", transforms=get_transform(False))
test_sampler = torch.utils.data.SequentialSampler(dataset_test)
data_loader_test = torch.utils.data.DataLoader(dataset_test, batch_size=1, sampler=test_sampler, collate_fn=collate_fn)

## Segmentation model

In [None]:
# Build the torchvision segmentation network named by MODEL_NAME with pretrained
# weights. Replace the final layer of the main head (and of the auxiliary head,
# when there is one) so they predict `num_classes` classes.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of `weights=`.
model = torchvision.models.segmentation.__dict__[MODEL_NAME](pretrained=True)
# Reuse the existing layer's input width instead of hard-coding it (previously
# 512 for the main head and 256 for the aux head of fcn_resnet50). This lets
# MODEL_NAME be switched without editing this cell.
model.classifier[4] = nn.Conv2d(model.classifier[4].in_channels, num_classes, 1)
if getattr(model, "aux_classifier", None) is not None:
    model.aux_classifier[4] = nn.Conv2d(model.aux_classifier[4].in_channels, num_classes, 1)
model = model.to(DEVICE)

## Optimizer and learning rate scheduler

In [None]:
# SGD over the backbone and main head at `learning_rate`; the auxiliary head
# (when the model has one) is trained with a 10x larger learning rate.
params_to_optimize = [
    {"params": [p for p in model.backbone.parameters() if p.requires_grad]},
    {"params": [p for p in model.classifier.parameters() if p.requires_grad]},
]
if getattr(model, "aux_classifier", None) is not None:
    params = [p for p in model.aux_classifier.parameters() if p.requires_grad]
    params_to_optimize.append({"params": params, "lr": learning_rate * 10})

iters_per_epoch = len(data_loader)
total_iters = iters_per_epoch * NUM_EPOCHS
optimizer = torch.optim.SGD(params_to_optimize, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)


def poly_lr_factor(step):
    """Polynomial decay multiplier ``(1 - step / total_iters) ** power``, clamped at 0.

    The scheduler is stepped once per batch. Without the clamp, any step past
    ``total_iters`` raises a negative base to the fractional ``power``, which
    Python evaluates to a complex number.
    """
    return max(0.0, 1 - step / total_iters) ** power


lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, poly_lr_factor)

## Helper functions for training and validation (one epoch each)

In [None]:
# Helper function to train
def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device):
    """Run one training epoch, stepping the LR scheduler after every batch.

    Args:
        model: segmentation model; its output is passed straight to ``criterion``.
        criterion: callable ``(output, target) -> scalar loss tensor``.
        optimizer: optimizer over the model parameters.
        data_loader: iterable of ``(image, target)`` batches.
        lr_scheduler: scheduler stepped once per batch.
        device: device the batches are moved to.

    Returns:
        ``(mean training loss over the epoch, current lr of param group 0)``.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    pbar = tqdm(data_loader)
    for image, target in pbar:
        image, target = image.to(device), target.to(device)
        output = model(image)
        loss = criterion(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        running_loss += loss.item()
        num_batches += 1
        # Learning rates here are ~1e-3 and decay towards 0. A fixed-point
        # "{:.3f}" would show "0.001"/"0.000", so use scientific notation.
        pbar.set_description("train_loss: {:.3f} lr: {:.2e}".format(loss.item(),
                                                                    optimizer.param_groups[0]["lr"]))
    # Average over batches actually seen; avoids ZeroDivisionError on an empty loader.
    train_loss = running_loss / max(num_batches, 1)
    return train_loss, optimizer.param_groups[0]["lr"]

# Helper function to evaluate
def evaluate(model, data_loader, device, num_classes):
    """Evaluate ``model`` on ``data_loader`` and return the filled ConfusionMatrix.

    Predictions are the per-pixel argmax over dim 1 of ``output["out"]``.
    Target pixels outside [0, num_classes), such as the 255 padding label, are
    skipped by ``ConfusionMatrix.update``.
    """
    model.eval()
    confmat = ConfusionMatrix(num_classes)
    # Iterate over the progress bar (not data_loader directly) so it actually advances.
    pbar = tqdm(data_loader, desc="eval")
    with torch.no_grad():
        for image, target in pbar:
            image, target = image.to(device), target.to(device)
            output = model(image)["out"]
            confmat.update(target.flatten(), output.argmax(1).flatten())
    return confmat

In [None]:
if DO_TRAINING:
    pbar = tqdm(range(NUM_EPOCHS))
    # Train and evaluate after each epoch, saving one checkpoint per epoch.
    # NOTE(review): checkpoints go to the working directory, not FINAL_MODEL_PATH.
    for epoch in pbar:
        train_loss, last_lr = train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, DEVICE)
        confmat = evaluate(model, data_loader_test, device=DEVICE, num_classes=num_classes)
        acc_global, acc, iu = confmat.compute()
        # Classes absent from both ground truth and predictions have NaN IoU.
        # Average only the defined ones so the reported mean is not NaN.
        mean_iou = iu[~torch.isnan(iu)].mean()
        pbar.set_description(
            "train_loss: {:.3f} last_lr: {:.2e} acc_global: {:.3f} iou: {:.3f}".format(
                train_loss, last_lr, acc_global.item() * 100, mean_iou.item() * 100
            )
        )
        print("confmat:", confmat)
        torch.save(model.state_dict(), "model_epoch_"+str(epoch)+".pth")
else:
    # load_state_dict returns (missing_keys, unexpected_keys).
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    m, v = model.load_state_dict(torch.load(FINAL_MODEL_PATH, map_location=DEVICE))
    print("=> loaded model weights from {} \nmissing keys = {}  unexpected keys {}".format(FINAL_MODEL_PATH, m, v))
    

In [None]:
# Video used for the interactive demo (e.g. "video41" or "video42" inside ROOT_DIR).
ROOT_DIR = "resources/cholec80_val_3vids"
VIDEO_PATH_INFERENCE = os.path.join(ROOT_DIR, "video41")
# Frames are named "<index>.jpg"; sort them numerically (so "10.jpg" follows
# "9.jpg") to play them back in order.
video_frames = [
    os.path.join(VIDEO_PATH_INFERENCE, f"{idx}.jpg")
    for idx in sorted(
        int(os.path.basename(path).replace(".jpg", ""))
        for path in glob.glob(VIDEO_PATH_INFERENCE + "/*.jpg")
    )
]

In [None]:
def _read_bytes(path):
    """Return the raw bytes of the file at ``path``, closing it afterwards."""
    with open(path, "rb") as fh:
        return fh.read()


if not video_frames:
    raise FileNotFoundError(f"no .jpg frames found in {VIDEO_PATH_INFERENCE}")

# First frame, used as the initial content of both image widgets.
test_image = _read_bytes(video_frames[0])
# Look-up-table image shown next to the prediction (presumably a class-color legend).
lut_image = _read_bytes("resources/lut.jpg")
# Slider to scroll through the video frames.
slider = wd.IntSlider(value=0, min=0, max=len(video_frames) - 1)
# Play button to play the video (interval is in milliseconds).
play_button = wd.Play(
    value=0, min=0, max=len(video_frames) - 1, step=1, interval=1000
)
# Image widget for the input frame.
image_wd = wd.Image(value=test_image, width=600, height=336)
# Image widget for the blended prediction.
image_wd_out = wd.Image(value=test_image, width=600, height=336)
# Image widget for the look-up table.
image_wd_lut = wd.Image(value=lut_image, width=140, height=336)
# Link the play button's value to the slider, so playing moves the slider
# (whose observer redraws the prediction).
wd.jslink((play_button, "value"), (slider, "value"))

In [None]:
# use the model in the inference mode
model.eval()


def slider_update(change):
    """Slider callback: display frame ``change.new`` and its predicted segmentation.

    The raw frame goes to ``image_wd``. The per-pixel argmax prediction is
    colorized with ``applyCustomColorMap``, blended 50/50 over the frame, and
    written to ``image_wd_out`` as PNG.
    """
    file_name = video_frames[change.new]
    with Image.open(file_name) as inp_img:
        image_wd.value = image_to_byte_array(inp_img)
        # PIL decodes to RGB, while the LUT colors and cv2.imencode use BGR.
        inp_img_bgr = cv2.cvtColor(np.array(inp_img), cv2.COLOR_RGB2BGR)
        image = IM2TENSOR(inp_img)[None].to(DEVICE)
    with torch.no_grad():
        # Class id per pixel (argmax over the class dim of the single batch item), as uint8.
        output = model(image)["out"][0].argmax(0).byte().cpu().numpy()
    output_color = applyCustomColorMap(output)
    im_output = cv2.addWeighted(inp_img_bgr, 0.5, output_color, 0.5, 0)
    image_wd_out.value = cv2.imencode(".png", im_output)[1].tobytes()

In [None]:
# Redraw the prediction whenever the slider (or the linked Play button) moves,
# then lay out the input / prediction / look-up-table row above the controls.
slider.observe(slider_update, "value")
out = wd.Output()
viewer_row = wd.HBox([image_wd, image_wd_out, image_wd_lut])
controls_row = wd.HBox([play_button, slider])
app = wd.VBox([viewer_row, controls_row])
display(app)