<a href="https://colab.research.google.com/github/mobarakol/ST-MTL/blob/main/ST_MTL_Segmentation_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ST-MTL: Spatio-Temporal multitask learning model to predict scanpath while tracking instruments in robotic surgery

Representation learning of task-oriented attention while tracking instruments holds vast potential in image-guided robotic surgery. Incorporating cognitive ability to automate camera control lets the surgeon concentrate more on handling the surgical instruments. The objective is to reduce operation time and make surgery easier for both surgeons and patients. We propose an end-to-end trainable Spatio-Temporal Multi-Task Learning (ST-MTL) model with a shared encoder and spatio-temporal decoders for real-time surgical instrument segmentation and task-oriented saliency detection. In MTL models with shared parameters, optimizing multiple loss functions toward a common convergence point is still an open challenge. We tackle this problem with a novel asynchronous spatio-temporal optimization (ASTO) technique that calculates independent gradients for each decoder. We also design a competitive squeeze-and-excitation unit with a skip connection that retains weak features, excites strong features, and performs dynamic spatial and channel-wise feature recalibration. To better capture long-term spatio-temporal dependencies, we enhance the long short-term memory (LSTM) module by concatenating high-level encoder features of consecutive frames. We also introduce a Sinkhorn-regularized loss to enhance task-oriented saliency detection while preserving computational efficiency. We generate task-aware saliency maps and instrument scanpaths on the dataset of the MICCAI 2017 robotic instrument segmentation challenge. Compared with state-of-the-art segmentation and saliency methods, our model performs better on most evaluation metrics and achieves outstanding performance on the challenge.
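
To give a feel for what a Sinkhorn-based comparison of two saliency maps involves, here is a minimal NumPy sketch of the generic Sinkhorn iterations for entropy-regularized optimal transport. It is not the loss used in the paper: the helper name `sinkhorn_cost`, the values of `eps` and `n_iters`, the squared-distance cost, and the toy 1-D maps are all assumptions made for illustration.

```python
import numpy as np

def sinkhorn_cost(a, b, cost, eps=0.1, n_iters=200):
    """Entropy-regularized transport cost between histograms a and b.

    Hypothetical helper for illustration; a and b must each sum to 1.
    """
    kernel = np.exp(-cost / eps)      # Gibbs kernel of the cost matrix
    u = np.ones_like(a)
    for _ in range(n_iters):          # alternate the two marginal scalings
        v = b / (kernel.T @ u)
        u = a / (kernel @ v)
    plan = u[:, None] * kernel * v[None, :]   # approximate transport plan
    return float((plan * cost).sum()), plan

# Toy 1-D "saliency maps" over four positions (made-up values).
positions = np.linspace(0.0, 1.0, 4)
cost = (positions[:, None] - positions[None, :]) ** 2
pred = np.array([0.7, 0.1, 0.1, 0.1])
target = np.array([0.1, 0.1, 0.1, 0.7])

cost_far, plan = sinkhorn_cost(pred, target, cost)
cost_same, _ = sinkhorn_cost(pred, pred, cost)
print(cost_same < cost_far)
```

Each iteration needs only matrix-vector products, which is why this family of methods is usually considered cheap compared with solving the exact transport problem. A larger `eps` blurs the plan more and converges faster, while a smaller one tracks the unregularized cost more closely.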

Paper: [ST-MTL: Spatio-Temporal multitask learning model to predict scanpath while tracking instruments in robotic surgery](https://www.sciencedirect.com/science/article/pii/S1361841520302012) <br>
Code (architecture only): https://github.com/mobarakol/ST-MTL<br>

Instrument Classes: "Bipolar Forceps": 1, "Prograsp Forceps": 2, "Large Needle Driver": 3, "Vessel Sealer": 4, "Grasping Retractor": 5, "Monopolar Curved Scissors": 6, "Other": 7<br>


## Citation
If you use this code for your research, please cite our paper.

```
@article{islam2021st,
  title={ST-MTL: Spatio-Temporal multitask learning model to predict scanpath while tracking instruments in robotic surgery},
  author={Islam, Mobarakol and Vibashan, VS and Lim, Chwee Ming and Ren, Hongliang},
  journal={Medical Image Analysis},
  volume={67},
  pages={101837},
  year={2021},
  publisher={Elsevier}
}
```




## Download Code, Data and Trained Model

### Download Code from GitHub

In [2]:
!rm -rf ST-MTL
!git clone https://github.com/mobarakol/ST-MTL.git
%cd ST-MTL

Cloning into 'ST-MTL'...
remote: Enumerating objects: 15, done.
remote: Counting objects: 100% (15/15), done.
remote: Compressing objects: 100% (12/12), done.
remote: Total 15 (delta 1), reused 9 (delta 1), pack-reused 0
Unpacking objects: 100% (15/15), done.
/content/ST-MTL


### Download Validation Data and Trained Model

In [3]:
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [4]:
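# Google Drive file IDs, paired with the local names to save them under.
ids = ['1UqoSVJLpF6W9F5PfCitDibieZPbd4lJh', '1rx0oVv8eDoK3bXNT362Kvl6b1y4knxb-']
file_names = ['Instrument_17.zip', 'best_epoch_st-mtl.pth.tar']
for file_id, file_name in zip(ids, file_names):
    downloaded = drive.CreateFile({'id': file_id})
    downloaded.GetContentFile(file_name)
    # Only the dataset is an archive; the checkpoint is used as downloaded.
    if file_name.endswith('.zip'):
        !unzip -q $file_name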

## Demo

In [5]:
import math
import numpy as np
import argparse
import torch
from torch.utils.data import DataLoader
from model import ST_MTL_SEG
from dataset import SurgicalDataset
from utils import seed_everything, calculate_dice, calculate_confusion_matrix_from_arrays
import warnings
warnings.filterwarnings("ignore")

def validate(valid_loader, model, args):
    confusion_matrix = np.zeros(
            (args.num_classes, args.num_classes), dtype=np.uint32)
    model.eval()
    with torch.no_grad():
        for batch_idx, (inputs, labels_seg, _,_) in enumerate(valid_loader):
            inputs, labels_seg = inputs.to(device), np.array(labels_seg)
            pred_seg = model(inputs)
            pred_seg = pred_seg.data.max(1)[1].squeeze_(1).cpu().numpy()
            confusion_matrix += calculate_confusion_matrix_from_arrays(
                pred_seg, labels_seg, args.num_classes)    

    confusion_matrix = confusion_matrix[1:, 1:]  # exclude background
    dices = {'dice_{}'.format(cls + 1): dice
                for cls, dice in enumerate(calculate_dice(confusion_matrix))}
    dices_per_class = np.array(list(dices.values()))          

    return dices_per_class

def main():
    parser = argparse.ArgumentParser(description='Instrument Segmentation')
    parser.add_argument('--num_classes', default=8, type=int, help="num of classes")
    parser.add_argument('--data_root', default='Instrument_17', help="data root dir")
    parser.add_argument('--batch_size', default=2, type=int, help="batch size")
    args = parser.parse_args(args=[])
    dataset_test = SurgicalDataset(data_root=args.data_root, seq_set=[4,7], is_train=False)
    test_loader = DataLoader(dataset=dataset_test, batch_size=args.batch_size, shuffle=False, num_workers=2,
                              drop_last=True)
    
    print('Sample size of test dataset:', len(dataset_test))
    model = ST_MTL_SEG(num_classes=args.num_classes).to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
    model.load_state_dict(torch.load('best_epoch_st-mtl.pth.tar', map_location=device))
    model.eval()
    dices_per_class = validate(test_loader, model, args)
    print('Mean Avg Dice:%.4f [Bipolar Forceps:%.4f, Prograsp Forceps:%.4f, Large Needle Driver:%.4f, Vessel Sealer:%.4f]'
        %(dices_per_class[:4].mean(),dices_per_class[0], dices_per_class[1],dices_per_class[2],dices_per_class[3]))
    
    
if __name__ == '__main__':
    class_names = ["Bipolar Forceps", "Prograsp Forceps", "Large Needle Driver", "Vessel Sealer", "Grasping Retractor", "Monopolar Curved Scissors", "Other"]
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed_everything()
    main()

Sample size of test dataset: 448


Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth



Mean Avg Dice:0.7560 [Bipolar Forceps:0.7344, Prograsp Forceps:0.6825, Large Needle Driver:0.7814, Vessel Sealer:0.8255]
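
The per-class scores above come from `calculate_confusion_matrix_from_arrays` and `calculate_dice`, which live in the repository's `utils` module and are not shown in this notebook. As a rough sketch of how such a computation can work, the snippet below builds a confusion matrix from label arrays and derives Dice as 2·TP / (2·TP + FP + FN). The function names are hypothetical stand-ins, and the row/column convention (rows are ground truth, columns are predictions) is an assumption; Dice is symmetric in FP and FN, so the convention does not change the scores.

```python
import numpy as np

def confusion_matrix_from_arrays(pred, truth, num_classes):
    """Count (truth, pred) label pairs; rows are ground truth, columns predictions."""
    idx = truth.ravel().astype(np.int64) * num_classes + pred.ravel().astype(np.int64)
    counts = np.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

def dice_per_class(cm):
    """Dice score for every class of a square confusion matrix."""
    tp = np.diag(cm).astype(np.float64)
    fp = cm.sum(axis=0) - tp          # predicted as the class, labeled otherwise
    fn = cm.sum(axis=1) - tp          # labeled as the class, predicted otherwise
    denom = 2 * tp + fp + fn
    # Classes absent from both prediction and ground truth score 0.
    return np.divide(2 * tp, denom, out=np.zeros_like(tp), where=denom > 0)

# Tiny made-up label maps with three classes (0 is background).
truth = np.array([[0, 1, 1], [2, 2, 0]])
pred = np.array([[0, 1, 2], [2, 2, 0]])

cm = confusion_matrix_from_arrays(pred, truth, num_classes=3)
dices = dice_per_class(cm[1:, 1:])    # drop background, as in validate()
print(dices)
```

Cropping the matrix to `cm[1:, 1:]` before scoring, as `validate()` does, means that only confusions among instrument classes count as false positives or false negatives; pixels confused with background are left out of the score.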
