<center><img src="https://drive.google.com/uc?id=1Z3JvAFmL2IkBnQmmt5f4uTcXVhO5f7cq"/></center>

------
<center>&copy; Research Group CAMMA, University of Strasbourg, <a href="http://camma.u-strasbg.fr">http://camma.u-strasbg.fr</a> 

<h2>Author: Vinkle Srivastav </h2>
</center>

------

# <center><font color=green> Lecture 6: Surgical Workflow Recognition using PyTorch </font></center>
<center><img src="https://drive.google.com/uc?id=1M9UPbnuvU8VTQ3_QGVpzdCE-5kUgv-2U"/></center>


### **Objectives**: 
  1. PyTorch `Dataset` and `DataLoader` for a subset of cholec80
  2. Visualize sample cholec80 images using PyTorch dataloaders
  3. Develop the surgical phase classification model
  4. Extract the features for the "train" and "val" dataset
  5. Train the model on extracted features
  6. Perform online inference on a sample cholec80 surgical video

## Setup

In [1]:
# install dependencies
!pip install numpy
!pip install matplotlib
!pip install torch
!pip install torchvision
!pip install tqdm
!pip install ipywidgets


In [7]:
# download resources
DIR="./resources"
![ ! -d "$DIR" ] && wget https://s3.unistra.fr/camma_public/teaching/edu4sds_resources/lec6_surg-workflow/resources.zip && unzip -qq resources.zip

In [2]:
# imports
import os
import json
import torch
import torch.nn as nn
from torch import optim
import torchvision
import numpy as np
from torchvision import models
import torchvision.transforms as transforms
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from PIL import Image
import pickle
import ipywidgets as wd
import glob
import io

# paths
ROOT_DIR = "resources/cholec80"
TRAIN_JSON = os.path.join(ROOT_DIR, "ch80_5vids_train.json")
VAL_JSON = os.path.join(ROOT_DIR, "ch80_3vids_val.json")
FEATURES_PATH = os.path.join(ROOT_DIR, "ch80_train_val_features.pkl")
PRETRAINED_MODEL_PATH = os.path.join(ROOT_DIR, "resnet50_ch80_20p_ft.pth")
FINAL_MODEL_PATH = os.path.join(ROOT_DIR, "resnet50_ch80_surg-flow.pth")
FINAL_LOG_PATH = os.path.join(ROOT_DIR, "train_log.json")
DO_TRAINING = True

# learning parameters
NUM_EPOCHS = 200
BATCH_SIZE = 64
CHANNEL_DIMS = [2048, 128]
POS_WEIGHTS_PHASE = [1.92, 0.20, 0.99, 0.30, 1.94, 1.0, 2.18]
LEARNING_RATE = 0.003
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0005
MILE_STONES = [150, 180]
LR_GAMMA = 0.33

# others
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
META_DATA = json.load(open(TRAIN_JSON))["metadata"]

# we use two transformations for our input: transforms.ToTensor() converts images loaded by Pillow into PyTorch tensors with values in [0, 1],
# and transforms.Normalize() subtracts the per-channel mean and divides by the per-channel standard deviation (the ImageNet statistics below).
IM2TENSOR = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)

print(
    "using device={}, PyTorch version={}, torchvision version={}".format(
        DEVICE, torch.__version__, torchvision.__version__
    )
)
META_DATA["phases"]
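
As a quick sanity check of what `transforms.Normalize` computes, the sketch below applies the same per-channel formula, `(x - mean) / std`, to a single pixel in plain Python. The pixel value is hypothetical and chosen only so that the result is easy to verify by hand; the mean and standard deviation are the ones passed to `IM2TENSOR`.

```python
# Per-channel statistics passed to transforms.Normalize in IM2TENSOR.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

def normalize_pixel(rgb, mean=MEAN, std=STD):
    """Apply (x - mean) / std channel by channel to one RGB pixel in [0, 1]."""
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]

# Hypothetical pixel, already scaled to [0, 1] as transforms.ToTensor() would do.
pixel = [0.485, 0.680, 0.181]
normalized = normalize_pixel(pixel)

# The red channel sits exactly at the mean, green one std above, blue one std below.
print([round(v, 3) for v in normalized])
```

After normalization the values are no longer confined to [0, 1], which is why the visualization code later rescales each image before plotting it.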


In [9]:
# helper function to convert a PIL image to a byte array
def image_to_byte_array(image):
    img_byte_arr = io.BytesIO()
    # image.format is None for images that were not loaded from a file, so fall back to JPEG
    image.save(img_byte_arr, format=image.format or "JPEG")
    return img_byte_arr.getvalue()

# helper function to compute the accuracy from the ground truth labels and model predictions
def accuracy(labels, predictions):
    # the index of the highest logit along the class dimension is the predicted phase
    _, predicted = torch.max(predictions, 1)
    acc = (predicted == labels).float()
    return acc.mean().item()
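
To make the accuracy helper concrete, here is a small worked example on hand-written logits. The helper is repeated so that the snippet runs on its own; the logits and labels are hypothetical values chosen for illustration.

```python
import torch

def accuracy(labels, predictions):
    # same logic as the helper above: argmax over classes, then the mean of the matches
    _, predicted = torch.max(predictions, 1)
    return (predicted == labels).float().mean().item()

# Hypothetical logits for 4 frames and 3 classes.
logits = torch.tensor([
    [2.0, 0.1, 0.3],   # argmax -> class 0
    [0.2, 1.5, 0.1],   # argmax -> class 1
    [0.9, 0.8, 0.1],   # argmax -> class 0
    [0.1, 0.2, 3.0],   # argmax -> class 2
])
labels = torch.tensor([0, 1, 2, 2])  # the third frame is misclassified

acc = accuracy(labels, logits)
print(acc)  # 3 of the 4 frames are classified correctly
```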

## Cholec80 dataset
The Cholec80 dataset [1] consists of 80 videos of cholecystectomy procedures. The videos are captured at 25 fps with a resolution of 854 × 480 or 1920 × 1080. This tutorial uses a subset of Cholec80 containing 5 videos for training and 3 videos for validation, downsampled to 1 fps with a spatial resolution of 399 × 224. The surgical phases consist of 7 labels annotated by expert surgeons, as shown below.
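
The numbers above can be related with a little arithmetic. The sketch below assumes that downsampling keeps every 25th frame starting from the first one, and that frames are resized to a height of 224 with the aspect ratio preserved; both are assumptions, since the exact preprocessing is not described here. Under these assumptions, both source resolutions give a width between 398 and 399 pixels, consistent with the stated 399.

```python
def frames_kept(num_frames, src_fps=25, dst_fps=1):
    """Indices kept when sampling dst_fps frames per second from a src_fps video."""
    step = src_fps // dst_fps
    return list(range(0, num_frames, step))

def scaled_width(width, height, target_height=224):
    """Width after resizing to target_height with the aspect ratio preserved."""
    return width * target_height / height

# 4 seconds of 25 fps video (100 frames) reduce to 4 frames at 1 fps.
print(frames_kept(100))
# Width obtained for each of the two source resolutions.
print(scaled_width(854, 480))
print(scaled_width(1920, 1080))
```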

<center><img src="https://drive.google.com/uc?id=1BHfCt8Obh1iTaeMSkpt9yDq1_-WIfSbG"/></center>


_1. Twinanda, Andru P., Sherif Shehata, Didier Mutter, Jacques Marescaux, Michel De Mathelin, and Nicolas Padoy. "Endonet: a deep architecture for recognition tasks on laparoscopic videos." IEEE transactions on medical imaging 36, no. 1 (2016): 86-97._

## 1. PyTorch `Dataset` and `Dataloader` for Cholec80
PyTorch uses two basic primitives to handle data: a dataset object, built on `torch.utils.data.Dataset`, that stores the samples and their corresponding labels, and a dataloader object, built on `torch.utils.data.DataLoader`, that wraps the dataset object for easy access to batches of samples.
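
Before turning to Cholec80, the two primitives can be tried on a toy example. The sketch below defines a hypothetical `ToyDataset` (random 8-dimensional samples with labels 0 to 2, not part of the tutorial resources) and wraps it in a `DataLoader` to show what a single sample and a single batch look like.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Hypothetical dataset: random 8-dimensional samples with labels 0..2."""

    def __init__(self, num_samples=10, dim=8, num_classes=3):
        generator = torch.Generator().manual_seed(0)
        self.samples = torch.randn(num_samples, dim, generator=generator)
        self.labels = torch.arange(num_samples) % num_classes

    def __len__(self):
        # number of samples in the dataset
        return len(self.samples)

    def __getitem__(self, idx):
        # one sample and its label
        return self.samples[idx], self.labels[idx]

toy_dataset = ToyDataset()
toy_loader = DataLoader(toy_dataset, batch_size=4)

sample, label = toy_dataset[0]                          # direct indexing
batch_samples, batch_labels = next(iter(toy_loader))    # first batch of 4

print(len(toy_dataset), sample.shape, batch_samples.shape, batch_labels.shape)
```

The dataloader stacks individual samples along a new first dimension, so a sample of shape `(8,)` becomes a batch of shape `(4, 8)`.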

### PyTorch `Dataset`
Now, let's go through the details of how to set up the dataset class by extending `torch.utils.data.Dataset`. First, we write the initialization function (`__init__()`) of the class, which reads the ground truth labels and, if provided, the precomputed features of the "train" or the "val" set. The `__getitem__()` function then returns a single image or feature vector along with its ground-truth phase.

In [10]:
class CholecDataset(torch.utils.data.Dataset):
    """
    CholecDataset class to give batches of images or features from the subset of the cholec80 dataset (http://camma.u-strasbg.fr/datasets)
    Arguments:
        gt_json {str} -- path to json file containing ground truth annotations
        data_split {str} -- "train" or "val" split
        features {list of dict} -- list of dictionaries containing "file_name" and the 2048-dimensional resnet features
        extract_features {bool} -- if True, return images for feature extraction instead of precomputed features
    """
    def __init__(
        self,
        gt_json="",
        data_split="train",
        features=None,
        extract_features=False,
    ):
        self.gt_json = gt_json
        self.data_split = data_split
        self.root_dir = ROOT_DIR
        # read the annotations and close the file right after loading
        with open(gt_json) as f:
            data = json.load(f)
        self.meta_data = data["metadata"]
        self.anns = data["annotations"]
        self.extract_features = extract_features
        self.transform = IM2TENSOR
        if features is not None:
            features_dict = {f["file_name"]: f["features"] for f in features}
            self.anns = [
                dict(ann, features=features_dict[ann["file_name"]])
                for ann in self.anns
            ]
        print("=> Dataset loaded for {}".format(str(self)))

    def __len__(self):
        # give the length of the datasets
        return len(self.anns)

    def __repr__(self):
        # print the datasets
        return "CholecDataset(" + self.data_split + ")"

    def __getitem__(self, idx):
        # in feature extraction mode, return the image tensor, the file_name, and the phase label
        if self.extract_features:
            img_path = os.path.join(
                self.root_dir, self.data_split, self.anns[idx]["file_name"]
            )
            # make sure the image has 3 channels, as expected by the normalization
            image = Image.open(img_path).convert("RGB")
            image = self.transform(image)
            return image, self.anns[idx]["file_name"], self.anns[idx]["phase"]
        # else return the features of size 1x2048 and the corresponding annotations for phase
        else:
            features = torch.from_numpy(np.array(self.anns[idx]["features"]))
            phase_gt = self.anns[idx]["phase"]
            return features, phase_gt

#### PyTorch `dataloaders` for "train" and "val" set
Now, we will use `torch.utils.data.DataLoader`, which wraps the dataset object and provides batches of sampled data (optionally loading them in parallel worker processes when `num_workers` is greater than 0). The `torch.utils.data.DataLoader` takes a `torch.utils.data.Dataset` object, `batch_size`, and `shuffle` as input. We have defined the dataset object above. The `batch_size` and `shuffle` parameters are described below.

1. `batch_size` denotes the number of samples contained in each generated batch.
2. `shuffle` - If set to True, we will get a random order of samples from the dataset at each pass. Shuffling the order of examples during training helps to make our model more robust.
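
The effect of the two parameters can be seen on a tiny example. The sketch below uses a hypothetical dataset of 10 samples whose single feature equals the sample index, so the order of the samples within the batches is easy to read off.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset of 10 samples; the single feature equals the sample index.
data = torch.arange(10).float().unsqueeze(1)
dataset = TensorDataset(data)

ordered = [
    batch[0].flatten().tolist()
    for batch in DataLoader(dataset, batch_size=4, shuffle=False)
]
shuffled = [
    batch[0].flatten().tolist()
    for batch in DataLoader(dataset, batch_size=4, shuffle=True)
]

print(ordered)   # batches of 4, 4 and 2 samples in the original order
print(shuffled)  # the same samples, in a random order that changes on each pass
```

Note that the last batch is smaller than `batch_size` when the dataset size is not a multiple of it.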

In [11]:
def get_dataloaders(
    extract_features=False, features=None, shuffle=False, batch_size=BATCH_SIZE
):
    """
    get the PyTorch dataloaders for the "train" and the "val" set
    Arguments:
        extract_features {bool} -- whether to use the dataset in the feature extraction mode
        features {dict} -- resnet features for each split, stored under the keys "train" and "val"
        shuffle {bool} -- whether to shuffle the training dataset
        batch_size {int} -- batch size for both dataloaders
    Return:
        train and val dataloaders
    """
    train_dataset = CholecDataset(
        TRAIN_JSON,
        data_split="train",
        features=features["train"] if features is not None else None,
        extract_features=extract_features,
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=shuffle
    )

    val_dataset = CholecDataset(
        VAL_JSON,
        data_split="val",
        features=features["val"] if features is not None else None,
        extract_features=extract_features,
    )
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader

## 2. Let's visualize some images by iterating through dataloaders

In [12]:
def visualize_cholec80_images(num_images = 4):
    train_loader_feats, _ = get_dataloaders(extract_features=True, shuffle=True)
    image_batch, file_names, phases_id = next(iter(train_loader_feats))
    print(image_batch.shape)
    plt.figure(figsize=(15, 12))
    for i in range(num_images):
        plt.subplot(1, num_images, i + 1)
        # take the i-th image, and convert from CxHxW to HxWxC
        img = image_batch[i].cpu().numpy().transpose(1, 2, 0)
        # rescale the image in the range 0 to 1
        vmin, vmax = img.min(), img.max()
        img = (img - vmin) / (vmax - vmin)
        # show the image
        plt.imshow(img)
        title = META_DATA["phases"][phases_id[i]] 
        plt.title(title)
        plt.axis('off')
        plt.grid()
    plt.gcf().tight_layout()
    
# call the function to visualize the images (execute multiple times to see different images from the "train" set)
visualize_cholec80_images()    


## 3. Surgical phase classification model
In the following, we will develop a classification model for surgical workflow recognition. We first load the resnet50 model weights in the `feature_extractor` module that are trained on the subset of the cholec80 dataset. Then, we add a few fully connected layers in the `phase_fc` module to perform the classification. PyTorch's modular design allows independently using the `feature_extractor` and `phase_fc`.

<center><img src="https://drive.google.com/uc?id=1z_SpM23Ha2E1zDGzg9T9ImhBt82Yjh0K"/></center>

In [None]:
class CholecModel(nn.Module):
    """
    A simple classification model for the surgical workflow recognition.
    """    
    def __init__(self):
        super(CholecModel, self).__init__()
        if PRETRAINED_MODEL_PATH:
            model = models.resnet50()
            print("=> loading backbone weights")
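            # map_location="cpu" lets the checkpoint load on machines without a GPU;
            # the whole model is moved to DEVICE later
            m, v = model.load_state_dict(
                torch.load(PRETRAINED_MODEL_PATH, map_location="cpu"), strict=False
            )
            print("=> backbone weights loaded... \nmissing keys = {}  unexpected keys = {}".format(m, v))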
        else:
            # load the model with imagenet weights
            model = models.resnet50(pretrained=True)
            
        # feature_extractor module to extract the feature from a given image
        self.feature_extractor = torch.nn.Sequential(
            *(list(model.children())[:-1])
        )
        
        # develop layers for the classification module
        last_dim = CHANNEL_DIMS[0]
        layers = []
        for i, dim in enumerate(CHANNEL_DIMS[1:]):
            layers.append(nn.Linear(last_dim, dim, bias=False))
            layers.append(nn.BatchNorm1d(dim, eps=1e-05, momentum=0.1))
            layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Dropout(p=0.9))
            last_dim = dim
        # add the last layer for the phase classification
        layers.append(
            nn.Linear(CHANNEL_DIMS[-1], len(META_DATA["phases"]), bias=True)
        )
        # phase_fc module to perform the classification
        self.phase_fc = nn.Sequential(*layers)

    def forward(self, x):
        # get the features from the image
        features = self.feature_extractor(x).flatten(1)
        # get the classification logits
        phase_logits = self.phase_fc(features)
        return phase_logits

# instantiate the model once to print its architecture
print(CholecModel())
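
To illustrate how the classification head can be used on its own, the sketch below rebuilds a head with the same layer pattern as `phase_fc` and feeds it random 2048-dimensional vectors in place of real ResNet-50 features. The batch size of 4 and the random inputs are illustrative assumptions; the 7 outputs correspond to the 7 Cholec80 phases.

```python
import torch
import torch.nn as nn

CHANNEL_DIMS = [2048, 128]  # same values as in the setup cell
NUM_PHASES = 7              # Cholec80 has 7 surgical phases

# Head with the same layer pattern as phase_fc in CholecModel.
head = nn.Sequential(
    nn.Linear(CHANNEL_DIMS[0], CHANNEL_DIMS[1], bias=False),
    nn.BatchNorm1d(CHANNEL_DIMS[1]),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.9),
    nn.Linear(CHANNEL_DIMS[1], NUM_PHASES),
).eval()  # eval mode disables dropout and uses the BatchNorm running statistics

with torch.no_grad():
    fake_features = torch.randn(4, CHANNEL_DIMS[0])  # stand-in for backbone output
    logits = head(fake_features)

print(logits.shape)  # one row of 7 phase logits per input vector
```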

## Two-stage training
We will do the training in two stages: in the first stage, we extract the features for the "train" and the "val" set using the `feature_extractor` module, and in the second stage, we train only the `phase_fc` module on the extracted features. This makes training much faster, because the ResNet-50 backbone runs only once per image instead of once per image per epoch. The cached features can also serve as input to more robust models that use temporal information from the previous frames.
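
The idea can be sketched on a toy problem. In the snippet below, two small linear layers are hypothetical stand-ins for the backbone and the classifier, and random tensors replace the images and labels. It is only meant to show the pattern: the extractor runs once without gradients, and only the head is updated afterwards.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
extractor = nn.Linear(16, 8)   # hypothetical stand-in for the frozen backbone
head = nn.Linear(8, 3)         # hypothetical stand-in for the trainable classifier
inputs = torch.randn(32, 16)   # random data in place of images
targets = torch.randint(0, 3, (32,))

# Stage 1: run the frozen extractor once and cache its outputs.
extractor.eval()
with torch.no_grad():
    cached_features = extractor(inputs)
extractor_before = extractor.weight.detach().clone()
head_before = head.weight.detach().clone()

# Stage 2: train only the head on the cached features.
optimizer = torch.optim.SGD(head.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
for _ in range(5):
    optimizer.zero_grad()
    loss = loss_fn(head(cached_features), targets)
    loss.backward()
    optimizer.step()

print(torch.equal(extractor.weight, extractor_before))  # extractor is untouched
print(torch.equal(head.weight, head_before))            # head has been updated
```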

### 4. Extract the features for the "train" and "val" dataset

In [None]:
# stage 1: feature extraction
def extract_features(save=True):
    if os.path.isfile(FEATURES_PATH):
        print("loading features from : {}".format(FEATURES_PATH))
        with open(FEATURES_PATH, "rb") as f:
            features_all = pickle.load(f)
    else:
        with torch.no_grad():
            model = CholecModel().to(DEVICE).eval()
            features_all = {}
            train_loader_feats, val_loader_feats = get_dataloaders(
                extract_features=True
            )
            loaders = [train_loader_feats, val_loader_feats]
            splits = ["train", "val"]
            for loader, split in zip(loaders, splits):
                print("\nextracting features for the ", str(loader.dataset))
                features = []
                for image_batch, filenames, _ in tqdm(loader):
                    image_batch = image_batch.to(DEVICE)
                    # flatten(1) keeps the batch dimension even for a batch of size 1
                    feats = (
                        model.feature_extractor(image_batch)
                        .flatten(1)
                        .cpu()
                        .numpy()
                    )
                    features += [
                        {"file_name": file_name, "features": f}
                        for f, file_name in zip(feats, filenames)
                    ]
                features_all[split] = features
        # cache the features on disk so that later runs can skip the extraction
        if save:
            with open(FEATURES_PATH, "wb") as f:
                pickle.dump(features_all, f)
            print("saving features at : {}".format(FEATURES_PATH))
    return features_all

# call the function to extract the features
features = extract_features()

## 5. Train the model on extracted features

### a) Get the train and val dataloaders b) define the loss function c) get the model object d) define optimizer and scheduler

In [None]:
# get the loaders initialized from the extracted features
train_loader, val_loader = get_dataloaders(features=features, shuffle=True)

# Define the cross entropy loss.
loss_phase_fn = nn.CrossEntropyLoss(weight=torch.FloatTensor(POS_WEIGHTS_PHASE).to(DEVICE), ignore_index=-1)

# define the optimizer. Since we only need to train the phase_fc module, we freeze
# the feature_extractor and pass only the trainable parameters to the optimizer
cholec_model = CholecModel().to(DEVICE)
cholec_model.feature_extractor.eval()
for params in cholec_model.feature_extractor.parameters():
    params.requires_grad = False
parameters = filter(lambda p: p.requires_grad, cholec_model.parameters())
optimizer = optim.SGD(
    parameters,
    lr=LEARNING_RATE,
    momentum=MOMENTUM,
    weight_decay=WEIGHT_DECAY,
    nesterov=True,
)
scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=MILE_STONES, gamma=LR_GAMMA
)
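
For reference, the learning rate produced by a multi-step schedule can be computed by hand: it is the base learning rate multiplied by `LR_GAMMA` once for every milestone that has been reached. The plain-Python sketch below evaluates this for the values defined in the setup cell, assuming the scheduler is stepped once per epoch; `lr_at_epoch` is a hypothetical helper written for illustration.

```python
LEARNING_RATE = 0.003
MILE_STONES = [150, 180]
LR_GAMMA = 0.33

def lr_at_epoch(epoch, base_lr=LEARNING_RATE, milestones=MILE_STONES, gamma=LR_GAMMA):
    """Learning rate in effect during a 0-based epoch, with one scheduler step per epoch."""
    num_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** num_decays

# The rate stays at 0.003 until epoch 149, then drops at epochs 150 and 180.
for epoch in (0, 149, 150, 179, 180, 199):
    print(epoch, lr_at_epoch(epoch))
```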

### training

In [None]:
if DO_TRAINING:
    pbar = tqdm(range(NUM_EPOCHS))
    train_logs = []
    for epoch in pbar:
        train_loss, val_loss, acc_metric_train, acc_metric_val = 0.0, 0.0, 0.0, 0.0
        # training epoch
        cholec_model.phase_fc.train()
        for feats, phase_gt in train_loader:
            feats, phase_gt = feats.to(DEVICE), phase_gt.to(DEVICE)
            phase_logits = cholec_model.phase_fc(feats)
            loss = loss_phase_fn(phase_logits, phase_gt)
            train_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            acc_metric_train += accuracy(phase_gt, phase_logits)
        # step the scheduler once per epoch, since MILE_STONES are expressed in epochs
        scheduler.step()
        train_loss /= len(train_loader)
        acc_metric_train = 100 * acc_metric_train / len(train_loader)
        
        # validation epoch
        with torch.no_grad():
            cholec_model.phase_fc.eval()
            for feats, phase_gt in val_loader:
                feats, phase_gt = feats.to(DEVICE), phase_gt.to(DEVICE)
                phase_logits = cholec_model.phase_fc(feats)
                loss = loss_phase_fn(phase_logits, phase_gt)
                val_loss += loss.item()
                acc_metric_val += accuracy(phase_gt, phase_logits)
            val_loss /= len(val_loader)
            acc_metric_val = 100 * acc_metric_val / len(val_loader)
        
        # log and display
        train_logs.append({'train_loss':train_loss, 
                        'val_loss':val_loss,
                        'acc_metric_train':acc_metric_train,
                        'acc_metric_val':acc_metric_val})
        pbar.set_description(
            "train_loss: {:.3f} val_loss: {:.3f} train accuracy: {:.3f} val accuracy: {:.3f}".format(
                train_loss, val_loss, acc_metric_train, acc_metric_val
            )
        )
    # save the trained model and logs
    torch.save(cholec_model.state_dict(), FINAL_MODEL_PATH)
    with open(FINAL_LOG_PATH, "w") as f:
        json.dump(train_logs, f)
else:
    if os.path.isfile(FINAL_MODEL_PATH):
        m, v = cholec_model.load_state_dict(torch.load(FINAL_MODEL_PATH, map_location=DEVICE))
        print("=> loaded model weights from {} \nmissing keys = {}  unexpected keys = {}".format(FINAL_MODEL_PATH, m, v))
        with open(FINAL_LOG_PATH) as f:
            train_logs = json.load(f)
    else:
        # keep train_logs defined so that the plotting cell below does not fail
        train_logs = []
        print("=> No model weights found")
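
The loss minimized in the loop above is a class-weighted cross-entropy. With the default `reduction="mean"`, PyTorch's `nn.CrossEntropyLoss` divides the weighted sum of the per-sample losses by the sum of the weights of the target classes, not by the batch size. The plain-Python sketch below reproduces that computation for two samples and three classes; the logits and class weights are hypothetical values made up for illustration.

```python
import math

def weighted_cross_entropy(logits, targets, weights):
    """Weighted mean of -log softmax(logits)[target], normalized by the target weights."""
    weighted_sum, weight_total = 0.0, 0.0
    for row, target in zip(logits, targets):
        log_norm = math.log(sum(math.exp(z) for z in row))
        nll = log_norm - row[target]            # -log softmax(row)[target]
        weighted_sum += weights[target] * nll
        weight_total += weights[target]
    return weighted_sum / weight_total

logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]     # confident first sample, uniform second
targets = [0, 1]
weights = [2.0, 0.5, 1.0]                       # hypothetical class weights

loss = weighted_cross_entropy(logits, targets, weights)
print(loss)
```

Because the first sample belongs to the class with the larger weight, its (small) loss dominates the average, which illustrates how class weights shift the emphasis of training.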

### Plot training stats

In [None]:
loss_train = [h['train_loss'] for h in train_logs]
loss_val = [h['val_loss'] for h in train_logs]
accuracy_train = [h['acc_metric_train'] for h in train_logs]
accuracy_val = [h['acc_metric_val'] for h in train_logs]
num_epochs = len(loss_train)
print(len(train_logs))
plt.figure(figsize=(12, 6))
titles = ["Accuracy vs. Number of Training Epochs", "Loss vs. Number of Training Epochs"]
ylabels = ["Accuracy", "Loss"]
yplots = [(accuracy_train, accuracy_val), (loss_train, loss_val)]
for i, title, ylabel, yplot in zip(range(2), titles, ylabels, yplots):
    plt.subplot(1, 2, i + 1)
    plt.title(title)
    plt.xlabel("Training Epochs")
    plt.ylabel(ylabel)
    plt.plot(range(1,num_epochs+1),yplot[0],label="Train-"+ylabel)
    plt.plot(range(1,num_epochs+1),yplot[1],label="Validation-"+ylabel)
    #plt.ylim((0.3,1.))
    #plt.xticks(np.arange(1, num_epochs+1, 1.0))
    plt.legend()
plt.gcf().tight_layout()

## 6. Perform live inference on a sample cholec80 surgical video

In [None]:
# path to the video
VIDEO_PATH_INFERENCE = os.path.join(ROOT_DIR, "val/video41") # e.g. video41 or video42
# read the paths of the video frames and sort them to make it sequential
video_frames = sorted(
    [
        int(os.path.basename(a).replace(".jpg", ""))
        for a in glob.glob(VIDEO_PATH_INFERENCE + "/*.jpg")
    ]
)
video_frames = [os.path.join(VIDEO_PATH_INFERENCE, str(i) + ".jpg") for i in video_frames]

# read the ground truth file for checking the corresponding ground truth labels
with open(VAL_JSON) as f:
    VAL_GT = {
        os.path.join(ROOT_DIR, "val", p["file_name"]): p["phase"]
        for p in json.load(f)["annotations"]
    }
# raw bytes of the first frame, used as the initial image of the GUI
with open(video_frames[0], "rb") as f:
    test_image = f.read()
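
The frame names are converted to integers before sorting because a plain string sort would not give the temporal order. The sketch below shows the difference on a few hypothetical file names.

```python
# Hypothetical frame file names, as glob might return them in arbitrary order.
names = ["10.jpg", "2.jpg", "1.jpg", "100.jpg"]

lexicographic = sorted(names)
numeric = sorted(names, key=lambda n: int(n.replace(".jpg", "")))

print(lexicographic)  # ['1.jpg', '10.jpg', '100.jpg', '2.jpg']
print(numeric)        # ['1.jpg', '2.jpg', '10.jpg', '100.jpg']
```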

### GUI using ipywidgets

We will use `ipywidgets` to develop the graphical user interface (GUI) buttons to `play`, `pause`, `stop`, and `slide` the video. We will also add two text boxes to show the output of the model and the corresponding ground truth phase label.  We will use the model in the inference mode. 

We will define a callback function `slider_update` that is triggered whenever the slider value changes. Each time the slider value is updated, we read the corresponding frame, pass the image to the model, and get the model prediction. Finally, we display the image, the model prediction, and the ground truth label.

In [None]:
# slider to scroll through the video
slider = wd.IntSlider(value=0, min=0, max=len(video_frames) - 1)
# play button to play the video
play_button = wd.Play(
    value=0, min=0, max=len(video_frames) - 1, step=1, interval=1000
)
# text box to show the ground truth phase label
gt_label = wd.Textarea(
    value="ground truth: preparation",
    placeholder="",
    description="",
    disabled=False,
)
# text box to show the model prediction
pred_label = wd.Textarea(
    value="prediction: preparation",
    placeholder="",
    description="",
    disabled=False,
)
# image widget to show the image
image_wd = wd.Image(value=test_image, width=600, height=336)
# link the output of the play button to the slider
wd.jslink((play_button, "value"), (slider, "value"))

In [None]:
# use the model in the inference mode
cholec_model.eval()
def slider_update(change):
    file_name = video_frames[change.new]
    input_image = Image.open(file_name)
    image_wd.value = image_to_byte_array(input_image)
    with torch.no_grad():
        image = IM2TENSOR(input_image)[None].to(DEVICE)
        logits = cholec_model(image)
    predicted_phase = META_DATA["phases"][logits.argmax().item()]
    gt_phase = META_DATA["phases"][VAL_GT[file_name]]
    gt_label.value = "ground truth: " + gt_phase
    pred_label.value = "prediction: " + predicted_phase

In [None]:
# call the app
slider.observe(slider_update, "value")
out = wd.Output()
app = wd.HBox(
    [
        wd.VBox([image_wd, wd.HBox([play_button, slider])]),
        wd.VBox([gt_label, pred_label]),
    ]
)
display(app)