<a href="https://colab.research.google.com/github/mobarakol/AP-MTL/blob/main/AP_MTL_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AP-MTL: Attention Pruned Multi-task Learning Model for Real-time Instrument Detection and Segmentation in Robot-assisted Surgery

This demo reproduces part of the paper's results. It does not include:<br>
- Detection evaluation and the post-processing that uses the detection predictions.<br>
- The Global Attention Dynamic Pruning (GADP) algorithm.


Clone the repository

In [1]:
# Fetch the AP-MTL sources, then make the checkout the working directory.
# NOTE(review): the later `dataset`, `ap_mtl` and `utils` imports presumably
# resolve against this directory — confirm. Re-running this cell after the
# `%cd` would clone a second, nested copy.
!git clone https://github.com/mobarakol/AP-MTL.git 
%cd AP-MTL

Cloning into 'AP-MTL'...
remote: Enumerating objects: 43, done.[K
remote: Counting objects: 100% (43/43), done.[K
remote: Compressing objects: 100% (35/35), done.[K
remote: Total 43 (delta 5), reused 40 (delta 5), pack-reused 0[K
Unpacking objects: 100% (43/43), done.
/content/AP-MTL


Download the trained model and the dataset

In [2]:
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [3]:
ids = ['1KHYdqSf7lMJgig_vJ5DRFkUhk4V3X4q4', '1LHMzbQpkwnqcQ1gqAQxJjWA6pMUPV4e1']
zip_files = ['Instrument_17.zip','ap_mtl.pth.tar']
for id, zip_file in zip(ids, zip_files):
    downloaded = drive.CreateFile({'id':id}) 
    downloaded.GetContentFile(zip_file)
    if zip_file[-3:] == 'zip':
        !unzip -q $zip_file

Validation without post-processing:

In [4]:
import argparse
import torch
from torch.utils.data import DataLoader
import warnings
warnings.filterwarnings('ignore')
import numpy as np

from dataset import SurgicalDataset, detection_collate
from ap_mtl import AP_MTL
from utils import  calculate_dice, calculate_confusion_matrix_from_arrays

def validate(valid_loader, model, args):
    """Compute per-class Dice scores of the segmentation output over a loader.

    Args:
        valid_loader: iterable yielding 4-tuples. The first item is the input
            batch tensor and the third is the segmentation labels (anything
            ``np.array`` accepts); the second and fourth items are ignored.
        model: module whose forward pass returns ``(pred_seg, pred_bbox)``.
            Only ``pred_seg`` is used, with the class scores along dim 1.
            The model is switched to eval mode and left in that mode.
        args: object providing ``num_classes``.

    Returns:
        ``np.ndarray`` holding the values produced by ``calculate_dice`` for
        the confusion matrix with row/column 0 (background) removed, so entry
        ``i`` corresponds to class ``i + 1``.
    """
    # Run on whatever device the model already lives on instead of assuming
    # CUDA; a parameter-less model falls back to CUDA when it is available.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    confusion_matrix = np.zeros(
        (args.num_classes, args.num_classes), dtype=np.uint32)
    model.eval()
    with torch.no_grad():
        for inputs, _, labels_seg, _ in valid_loader:
            inputs = inputs.to(device)
            labels_seg = np.array(labels_seg)
            pred_seg, _ = model(inputs)  # detection output is not evaluated here
            # Per-pixel class index = position of the highest score along dim 1.
            # NOTE(review): squeeze(1) is kept from the original; it only has an
            # effect when the dimension following the batch has size 1.
            pred_seg = pred_seg.argmax(dim=1).squeeze(1).cpu().numpy()
            confusion_matrix += calculate_confusion_matrix_from_arrays(
                pred_seg, labels_seg, args.num_classes)

    confusion_matrix = confusion_matrix[1:, 1:]  # exclude background
    return np.array(list(calculate_dice(confusion_matrix)))

if __name__ == '__main__':
    # Configuration dictionary handed to AP_MTL as `ssd_conf`.
    ssd_conf = {
        'num_classes': 8,
        'lr_steps': (80000, 100000, 120000),
        'max_iter': 120000,
        'feature_maps': [128, 64, 32, 16, 14, 12],
        'min_dim': 1024,
        'steps': [8, 16, 32, 64, 74, 84], #1024/128=8
        'min_sizes': [100, 198, 356, 518, 640, 810],
        'max_sizes': [198, 356, 518, 640, 810, 1024],
        'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
        'variance': [0.1, 0.2],
        'clip': True,
        'name': 'ssd_conf',
    }
    parser = argparse.ArgumentParser(description='AP-MTL')
    parser.add_argument('--num_classes', default=8, type=int, help="num of classes")
    parser.add_argument('--data_root', default='Instrument_17', help="data root dir")
    parser.add_argument('--batch_size', default=1, type=int, help="batch size")
    parser.add_argument('--img_size', default=1024, type=int, help="model input size")
    # Parse an empty argv so the defaults apply: inside a notebook sys.argv
    # belongs to the kernel, not to this script.
    args = parser.parse_args(args=[])

    # Use the GPU when there is one; otherwise fall back to the CPU.
    # NOTE: `validate` must place its inputs on the same device as the model.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset_test = SurgicalDataset(data_root=args.data_root, seq_set=[4,7], is_train=False)
    # drop_last=False: evaluation must see every sample, including a final
    # partial batch when batch_size > 1. Pinned memory only helps for CUDA.
    test_loader = DataLoader(dataset_test, args.batch_size, num_workers=2,
                             shuffle=False, collate_fn=detection_collate,
                             pin_memory=(device.type == 'cuda'), drop_last=False)

    model = AP_MTL(num_classes=args.num_classes, size=args.img_size, ssd_conf=ssd_conf)
    # Despite the name this is the checkpoint *file*, not a directory.
    ckpt_dir = 'ap_mtl.pth.tar'
    # SECURITY: torch.load unpickles the file; only load checkpoints obtained
    # from a trusted source.
    model.load_state_dict(torch.load(ckpt_dir, map_location=torch.device('cpu')))
    model = model.to(device)

    dices_per_class = validate(test_loader, model, args)
    # The reported mean covers only the four instruments listed by name.
    print('Mean Avg Dice:%.4f [Bipolar Forceps:%.4f, Prograsp Forceps:%.4f, Large Needle Driver:%.4f, Vessel Sealer:%.4f]'
        %(dices_per_class[:4].mean(),dices_per_class[0], dices_per_class[1],dices_per_class[2],dices_per_class[3]))

Mean Avg Dice:0.3560 [Bipolar Forceps:0.5658, Prograsp Forceps:0.2088, Large Needle Driver:0.0005, Vessel Sealer:0.6490]
