In [1]:
import os
import cv2
import sys
import random

import json
import h5py
import itertools
import numpy as np
from PIL import Image
import argparse, pickle

import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from tqdm import tqdm
from torch import optim
import torchvision.models
from torch.utils.data import Dataset as torchDataset
from torch.utils.data import DataLoader as TorchDataLoader
from torch.utils.data.dataloader import default_collate

# caption libraries
import evaluation
import collections
from data.example import Example
from data.utils import nostdout
from data.field import ImageDetectionsField, TextField, RawField
from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory

# graph libraries
import utils.io as io
import matplotlib.pyplot as plt
from utils.g_vis_img import *
from models.graph_su import *
from evaluation.graph_eval import *

# feature extractor
from models.feature_extractor import *

# Random seeds
# One fixed seed is pushed into every RNG the notebook touches (Python, NumPy,
# torch CPU and all CUDA devices) so that a Restart & Run All is repeatable.
seed = 27
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)
# NOTE(review): the interpreter's hash seed is fixed at start-up, so setting
# PYTHONHASHSEED here presumably only affects child processes - confirm intent.
os.environ['PYTHONHASHSEED'] = str(seed)
# Turn off the cuDNN autotuner and request deterministic kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Dataset Constants

In [2]:
class SurgicalSceneConstants:
    '''
    Label vocabularies of the surgical scene dataset.

    Attributes set on every instance:
        instrument_classes: tuple of node (object) class names; the position
            of a name in the tuple is its class id.
        action_classes: tuple of interaction (edge) class names.
    '''
    def __init__(self):
        # node classes, indexed by class id
        self.instrument_classes = (
            'kidney',
            'bipolar_forceps',
            'prograsp_forceps',
            'large_needle_driver',
            'monopolar_curved_scissors',
            'ultrasound_probe',
            'suction',
            'clip_applier',
            'stapler',
            'maryland_dissector',
            'spatulated_monopolar_cautery',
        )
        # interaction classes, indexed by class id
        self.action_classes = (
            'Idle',
            'Grasping',
            'Retraction',
            'Tissue_Manipulation',
            'Tool_Manipulation',
            'Cutting',
            'Cauterization',
            'Suction',
            'Looping',
            'Suturing',
            'Clipping',
            'Staple',
            'Ultrasound_Sensing',
        )

# Cross-entropy loss with label smoothing

In [3]:
class CELossWithLS(torch.nn.Module):
    '''
    Label-smoothed cross-entropy loss with a focal modulation term, used for captioning.

    Args:
        classes: number of target classes (vocabulary size).
        smoothing: probability mass spread uniformly over all classes.
        gamma: exponent of the focal factor ``(1 - p) ** gamma``; 0 disables it.
        isCos: unused; kept only so existing call sites keep working.
        ignore_index: target value whose positions are excluded from the loss.
    '''
    def __init__(self, classes=None, smoothing=0.1, gamma=3.0, isCos=True, ignore_index=-1):
        super(CELossWithLS, self).__init__()
        self.complement = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.log_softmax = torch.nn.LogSoftmax(dim=1)
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, logits, target):
        '''
        Compute the mean loss over all positions whose target is not ``ignore_index``.

        Args:
            logits: unnormalised scores whose last dimension is the class dimension
                and whose leading dimensions match ``target``.
            target: integer class ids.

        Returns:
            Scalar tensor. When every position is ignored, a zero that is still
            attached to ``logits`` is returned instead of NaN.
        '''
        keep = target != self.ignore_index
        kept_logits = logits[keep]
        kept_target = target[keep].to(torch.int64)

        if kept_target.numel() == 0:
            # mean() over nothing would be NaN and poison the optimiser step
            return kept_logits.sum() * 0.0

        with torch.no_grad():
            # Mask BEFORE one-hot encoding: one_hot rejects negative / out-of-range
            # ids, which is exactly what ignore_index (default -1) may be.
            oh_labels = torch.nn.functional.one_hot(kept_target, num_classes=self.cls)
            smoothen_ohlabel = oh_labels * self.complement + self.smoothing / self.cls

        logs = self.log_softmax(kept_logits)
        pt = torch.exp(logs)
        return -torch.sum((1 - pt).pow(self.gamma) * logs * smoothen_ohlabel, dim=1).mean()


# Dataloader

In [4]:
class DataLoader(TorchDataLoader):
    '''
    torch DataLoader that always batches with the collate function supplied by
    the dataset itself (``dataset.collate_fn()``).
    '''
    def __init__(self, dataset, *args, **kwargs):
        batch_collate = dataset.collate_fn()
        super().__init__(dataset, *args, collate_fn=batch_collate, **kwargs)

class Dataset(object):
    '''
    Map-style dataset that yields, per frame, both the caption inputs and the
    graph (scene understanding) inputs.

    Args:
        examples: sequence of example objects exposing one attribute per field name.
        fields: mapping ``field name -> field``. A field is either ``None``
            (only meaningful for ``'image'``) or an object providing
            ``preprocess`` and ``process``.
        gsu_const: dict with the keys ``'file_dir'``, ``'img_dir'``,
            ``'dataconst'``, ``'feature_extractor'`` and ``'w2v_loc'``.
    '''
    def __init__(self, examples, fields, gsu_const):
        self.examples = examples
        self.fields = dict(fields)

        self.file_dir = gsu_const['file_dir']
        self.img_dir = gsu_const['img_dir']
        self.dataconst = gsu_const['dataconst']
        self.feature_extractor = gsu_const['feature_extractor']
        # Opened once and kept open for the lifetime of the dataset: it is read on every item.
        self.word2vec = h5py.File(gsu_const['w2v_loc'], 'r')

    # word2vec
    def _get_word2vec(self, node_ids):
        '''Return one 300-d word vector per node id, stacked row-wise (``(0, 300)`` when empty).'''
        # The empty float64 seed row keeps the result shape/dtype of the original
        # incremental stacking, but everything is stacked in a single call.
        rows = [np.empty((0, 300))]
        rows.extend(self.word2vec[self.dataconst.instrument_classes[node_id]] for node_id in node_ids)
        return np.vstack(rows)

    def __getitem__(self, i):
        '''
        Load example ``i``.

        Returns:
            dict with ``'cp_data'`` (pre-processed field values; a bare value when
            there is a single field) and ``'gsu_data'`` (dict of graph inputs read
            from the frame's feature file).
        '''
        example = self.examples[i]
        frame_path = getattr(example, 'image').split("/")
        seq_dir = frame_path[0]
        frame_id = frame_path[3].split("_")[0]

        _img_loc = os.path.join(self.file_dir, seq_dir, self.img_dir, frame_id + '.png')
        feature_file = os.path.join(self.file_dir, seq_dir, 'vsgat', self.feature_extractor, frame_id + '_features.hdf5')

        # caption data
        cp_data = []
        for field_name, field in self.fields.items():
            if field_name == 'image' and field is None:
                # no pre-extracted region features: placeholder block
                cp_data.append(np.zeros((6, 512), dtype=np.float32))
            else:
                cp_data.append(field.preprocess(getattr(example, field_name)))
        if len(cp_data) == 1: cp_data = cp_data[0]

        # graph data
        gsu_data = {}
        # The context manager closes the per-frame file once everything has been
        # copied into memory; previously one handle leaked per item.
        with h5py.File(feature_file, 'r') as frame_data:
            # `dataset[()]` replaces `dataset.value`, which h5py 3 removed.
            img_name = frame_data['img_name'][()]
            if isinstance(img_name, bytes):
                # h5py 3 returns stored strings as bytes
                img_name = img_name.decode('utf-8')
            gsu_data['img_name'] = img_name[:] + '.jpg'
            gsu_data['img_loc'] = _img_loc
            gsu_data['node_num'] = frame_data['node_num'][()]
            gsu_data['roi_labels'] = frame_data['classes'][:]
            gsu_data['det_boxes'] = frame_data['boxes'][:]
            gsu_data['edge_labels'] = frame_data['edge_labels'][:]
            gsu_data['edge_num'] = gsu_data['edge_labels'].shape[0]
            if self.fields['image'] is None:
                gsu_data['features'] = np.zeros((gsu_data['node_num'], 512), dtype=np.float32)
            else:
                gsu_data['features'] = frame_data['node_features'][:]
            gsu_data['spatial_feat'] = frame_data['spatial_features'][:]
        gsu_data['word2vec'] = self._get_word2vec(gsu_data['roi_labels'])

        data = {}
        data['cp_data'] = cp_data
        data['gsu_data'] = gsu_data
        return data

    def __len__(self):
        return len(self.examples)

    def __getattr__(self, attr):
        '''
        ``dataset.<field name>`` iterates over that attribute of every example.

        Any other missing attribute raises ``AttributeError``. The previous
        generator-function form answered *every* lookup with a generator, so
        ``hasattr`` was always true; that breaks copy/pickle and the torch
        fetcher's ``__getitems__`` probe.
        '''
        # Go through __dict__ so a half-initialised instance cannot recurse here.
        fields = self.__dict__.get('fields')
        if fields is not None and attr in fields:
            return (getattr(x, attr) for x in self.examples)
        raise AttributeError("%r object has no attribute %r" % (type(self).__name__, attr))

    def collate_fn(self):
        '''Return the function that merges a list of ``__getitem__`` results into one batch dict.'''
        # collections.Sequence was removed in Python 3.10; the ABC lives in collections.abc.
        from collections.abc import Sequence

        gsu_keys = ('img_name', 'img_loc', 'node_num', 'roi_labels', 'det_boxes',
                    'edge_labels', 'edge_num', 'features', 'spatial_feat', 'word2vec')
        # entries concatenated over all graphs of the batch; the rest stay per-image lists
        stacked_keys = ('edge_labels', 'features', 'spatial_feat', 'word2vec')

        def collate(batch):
            gsu_batch_data = {key: [data['gsu_data'][key] for data in batch] for key in gsu_keys}
            for key in stacked_keys:
                gsu_batch_data[key] = torch.FloatTensor(np.concatenate(gsu_batch_data[key], axis=0))

            cp_batch_data = [data['cp_data'] for data in batch]
            if len(self.fields) == 1: cp_batch_data = [cp_batch_data, ]
            else: cp_batch_data = list(zip(*cp_batch_data))

            tensors = []
            for field, data in zip(self.fields.values(), cp_batch_data):
                if field is None: tensor = default_collate(data)
                else: tensor = field.process(data)
                # a field may return several tensors at once; flatten those into the output
                if isinstance(tensor, Sequence) and any(isinstance(t, torch.Tensor) for t in tensor):
                    tensors.extend(tensor)
                else: tensors.append(tensor)

            if len(tensors) > 1: cp_batch_data = tensors
            else: cp_batch_data = tensors[0]

            batch_data = {}
            batch_data['gsu'] = gsu_batch_data
            batch_data['cp'] = cp_batch_data

            return batch_data

        return collate


class PairedDataset(Dataset):
    '''
    Dataset that requires both an ``'image'`` and a ``'text'`` field.

    The ``'image'`` field may be ``None``, in which case no pre-extracted
    region features are read and placeholders are used instead.
    '''
    def __init__(self, examples, fields, gsu_const):
        assert ('image' in fields)
        assert ('text' in fields)
        super(PairedDataset, self).__init__(examples, fields, gsu_const)
        # Remember the constants this dataset was built with so that derived
        # datasets do not depend on a notebook-level global of the same name.
        self.gsu_const = gsu_const
        self.image_field = self.fields['image']
        if self.image_field is None: print('no pre-extracted image featured')
        self.text_field = self.fields['text']

    def image_dictionary(self, fields=None):
        '''
        Build a plain ``Dataset`` over the same examples.

        Args:
            fields: replacement field mapping; this dataset's own fields are
                used when it is empty or ``None``.
        '''
        if not fields:
            fields = self.fields
        return Dataset(self.examples, fields, self.gsu_const)
        
class COCO(PairedDataset):
    '''
    Caption dataset with a train and a validation split.

    Args:
        image_field: field for the image input, or ``None``.
        text_field: field for the caption text.
        gsu_const: graph scene understanding constants forwarded to the base dataset.
        img_root: image root recorded for both splits.
        ann_root: folder holding ``captions_train.json`` and ``captions_val.json``.
        id_root: folder holding ``WithCaption_id_path_train.json`` and
            ``WithCaption_id_path_val.json``. Required in practice: the samples
            cannot be built without the id lists.
    '''
    def __init__(self, image_field, text_field, gsu_const, img_root, ann_root, id_root=None):
        # setting training and val root
        roots = {}
        roots['train'] = {'img': img_root, 'cap': os.path.join(ann_root, 'captions_train.json')}
        roots['val'] = {'img': img_root, 'cap': os.path.join(ann_root, 'captions_val.json')}

        # Getting the id: planning to remove this in future
        if id_root is not None:
            ids = {}
            with open(os.path.join(id_root, 'WithCaption_id_path_train.json'), 'r') as id_file:
                ids['train'] = json.load(id_file)
            with open(os.path.join(id_root, 'WithCaption_id_path_val.json'), 'r') as id_file:
                ids['val'] = json.load(id_file)
        else: ids = None

        with nostdout():
            self.train_examples, self.val_examples = self.get_samples(roots, ids)
        examples = self.train_examples + self.val_examples
        super(COCO, self).__init__(examples, {'image': image_field, 'text': text_field}, gsu_const)
        # kept on the instance so `splits` does not rely on a notebook-level global
        self.gsu_const = gsu_const

    @property
    def splits(self):
        '''Return ``(train_split, val_split)`` as two ``PairedDataset`` objects sharing this dataset's fields.'''
        train_split = PairedDataset(self.train_examples, self.fields, self.gsu_const)
        val_split = PairedDataset(self.val_examples, self.fields, self.gsu_const)
        return train_split, val_split

    @classmethod
    def get_samples(cls, roots, ids_dataset=None):
        '''
        Pair every id path with the caption at the same index of the annotation file.

        Args:
            roots: ``{'train': {'cap': path, ...}, 'val': {'cap': path, ...}}``.
            ids_dataset: ``{'train': [...], 'val': [...]}`` lists of id paths.

        Returns:
            ``(train_samples, val_samples)`` lists of ``Example``.

        Raises:
            ValueError: if ``ids_dataset`` is ``None`` (this used to surface as an
                unbound-local error inside the loop).
        '''
        if ids_dataset is None:
            raise ValueError('ids_dataset is required: pass id_root so the id-path lists can be loaded')

        samples = {'train': [], 'val': []}
        for split in ('train', 'val'):
            with open(roots[split]['cap'], 'r') as ann_file:
                anns = json.load(ann_file)

            for index, id_path in enumerate(ids_dataset[split]):
                caption = anns[index]['caption']
                example = Example.fromdict({'image': os.path.join('', id_path), 'text': caption})
                samples[split].append(example)

        return samples['train'], samples['val']

# MTL Model (Graph Scene Understanding and Captioning)

In [5]:
class mtl_model(nn.Module):
    '''
    Multi-task model : Graph Scene Understanding and Captioning

    Node features are extracted on the fly from image crops by
    ``feature_extractor`` and shared by the graph network and the caption network.
    '''
    def __init__(self, feature_extractor, graph, caption):
        super(mtl_model, self).__init__()
        self.feature_extractor = feature_extractor
        self.graph_su = graph
        self.caption = caption
        self.transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

    def forward(self, img_dir, det_boxes_all, caps_gt, node_num, features, spatial_feat, word2vec, roi_labels, val = False, text_field = None):
        '''
        Args:
            img_dir: iterable of image file paths, one per batch element.
            det_boxes_all: per-image sequences of boxes; each box is cropped as
                rows ``box[1]..box[3]`` and columns ``box[0]..box[2]`` (inclusive).
            caps_gt: ground-truth captions handed to the caption model when training.
            node_num, spatial_feat, word2vec, roi_labels: forwarded unchanged to
                the graph model.
            features: unused here; node features are recomputed from the image crops.
            val: when True, captions are produced with beam search and the graph
                model is called with ``validation=True``.
            text_field: needed when ``val`` is True (supplies the ``<eos>`` id).

        Returns:
            ``(interaction, caption_output)``.
        '''
        # Run on whatever device the extractor lives on instead of assuming .cuda().
        first_param = next(self.feature_extractor.parameters(), None)
        run_device = first_param.device if first_param is not None else torch.device('cuda')

        gsu_feats = []
        cp_feats = []
        for index, img_loc in enumerate(img_dir):
            with Image.open(img_loc) as pil_img:
                _img = np.array(pil_img.convert('RGB'))

            rois = []
            for bndbox in det_boxes_all[index]:
                roi = np.array(bndbox).astype(int)
                roi_image = _img[roi[1]:roi[3] + 1, roi[0]:roi[2] + 1, :]
                rois.append(self.transform(cv2.resize(roi_image, (224, 224), interpolation=cv2.INTER_LINEAR)))
            if not rois:
                raise ValueError('no detection boxes for image %s' % (img_loc,))

            # stack nodes images per image and send the stack to the feature extractor
            img_stack = torch.stack(rois).to(run_device)
            # Use the module owned by this model, not the notebook global `feature_network`.
            feature = self.feature_extractor(img_stack)
            feature = feature.view(feature.size(0), -1)
            gsu_feats.append(feature)

            # the caption branch takes a fixed block of 6 node slots per image; zero-pad the rest
            padding = feature.new_zeros(6 - len(feature), feature.size(1))
            cp_feats.append(torch.cat((feature, padding)).unsqueeze(0))

        gsu_node_feat = torch.cat(gsu_feats) if gsu_feats else None
        cp_node_feat = torch.cat(cp_feats) if cp_feats else None

        if val == True:
            caption_output, _ = self.caption.beam_search(cp_node_feat, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
        else:
            caption_output = self.caption(cp_node_feat, caps_gt)
        interaction = self.graph_su(node_num, gsu_node_feat, spatial_feat, word2vec, roi_labels, validation=val)
        return interaction, caption_output

# Evaluation 

In [8]:
import itertools

def eval_mtl(model, dataloader, text_field):
    '''
    Evaluate the multi-task model on a dataloader and print the metrics.

    Graph branch: accuracy, loss (mean over edges), mAP and the calibration
    metrics returned by ``calibration_metrics``. Caption branch: the scores
    returned by ``evaluation.compute_scores``.

    Args:
        model: the multi-task model; it is switched to eval mode and left there.
        dataloader: yields dicts with a ``'gsu'`` and a ``'cp'`` entry.
        text_field: text field used to decode the generated captions.

    Raises:
        ValueError: if the dataloader produced no edges at all.
    '''
    model.eval()
    gen = {}
    gts = {}

    # graph
    g_criterion = nn.MultiLabelSoftMarginLoss()
    g_edge_count = 0
    g_total_acc = 0.0
    g_total_loss = 0.0
    g_logits_list = []
    g_labels_list = []

    # tqdm wraps the loader itself so the progress bar knows the total
    for it, data in enumerate(tqdm(dataloader)):

        graph_data = data['gsu']
        cp_data = data['cp']

        # graph
        img_loc = graph_data['img_loc']
        node_num = graph_data['node_num']
        roi_labels = graph_data['roi_labels']
        det_boxes = graph_data['det_boxes']
        edge_labels = graph_data['edge_labels'].to(device)
        features = graph_data['features'].to(device)
        spatial_feat = graph_data['spatial_feat'].to(device)
        word2vec = graph_data['word2vec'].to(device)

        # caption
        _, caps_gt = cp_data

        with torch.no_grad():
            g_output, caption_out = model(img_loc, det_boxes, caps_gt, node_num, features, spatial_feat, word2vec, roi_labels, val = True, text_field = text_field)

            g_logits_list.append(g_output)
            g_labels_list.append(edge_labels)
            # loss and accuracy
            g_loss = g_criterion(g_output, edge_labels.float())
            g_acc = np.sum(np.equal(np.argmax(g_output.cpu().data.numpy(), axis=-1), np.argmax(edge_labels.cpu().data.numpy(), axis=-1)))

        # accumulate loss and accuracy of the batch (loss weighted by its edge count)
        g_total_loss += g_loss.item() * edge_labels.shape[0]
        g_total_acc  += g_acc
        g_edge_count += edge_labels.shape[0]

        caps_gen = text_field.decode(caption_out, join_words=False)

        for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
            # collapse immediately repeated words
            gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
            gen['%d_%d' % (it, i)] = [gen_i, ]
            gts['%d_%d' % (it, i)] = [gts_i,]

    if g_edge_count == 0:
        raise ValueError('eval_mtl: the dataloader produced no edges to evaluate')

    #graph evaluation
    g_total_acc = g_total_acc / g_edge_count
    # The sum above is weighted per edge, so it is normalised by the number of
    # edges (dividing by the number of batches inflated the reported loss).
    g_total_loss = g_total_loss / g_edge_count

    g_logits_all = torch.cat(g_logits_list).to(device)
    g_labels_all = torch.cat(g_labels_list).to(device)
    g_logits_all = torch.softmax(g_logits_all, dim=1)
    g_map_value, g_ece, g_sce, g_tace, g_brier, g_uce = calibration_metrics(g_logits_all, g_labels_all, 'test')

    # caption evaluation
    #if not os.path.exists('results/c_results/predict_caption'):
    #    os.makedirs('results/c_results/predict_caption')
    #json.dump(gen, open('results/c_results/predict_caption/predict_caption_val.json', 'w'))

    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)

    scores, _ = evaluation.compute_scores(gts, gen)
    print('Graph : {acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f}' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce.item()) )
    print("Caption Scores :", scores)

# Train

In [9]:
def train(epoch, lrc, model, dataloader, dict_dataloader_val, text_field):
    '''
    Fine-tune the feature extractor of the multi-task model.

    Only ``model.feature_extractor`` is handed to the optimiser. The active
    objective is the graph loss; the caption loss is computed so the
    alternative objectives noted in the loop can be switched on. After every
    epoch a checkpoint is written and the model is evaluated with ``eval_mtl``.

    Args:
        epoch: number of epochs to run.
        lrc: learning rate for SGD.
        model: the multi-task model.
        dataloader: training loader yielding dicts with ``'gsu'`` and ``'cp'``.
        dict_dataloader_val: validation loader passed to ``eval_mtl``.
        text_field: text field providing the vocabulary.
    '''
    optimizer = optim.SGD(model.feature_extractor.parameters(), lr= lrc, momentum=0.9, weight_decay=0)

    g_criterion = nn.MultiLabelSoftMarginLoss()
    # built once; it only depends on the vocabulary
    c_criterion = CELossWithLS(classes=len(text_field.vocab), smoothing=0.1, gamma=0.0, isCos=False, ignore_index=text_field.vocab.stoi['<pad>'])

    checkpoint_dir = "checkpoints/mtl_train/UDA_Graph"
    # torch.save does not create missing folders
    os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch_count in range(epoch):
        # NOTE(review): eval_mtl() below leaves the model in eval mode and nothing
        # switches it back, so later epochs (and the first one, if an evaluation
        # ran beforehand) train with dropout/batch-norm frozen. Confirm this is
        # intended before adding model.train() here - it changes the results.

        running_loss = 0.0
        running_g_acc = 0.0
        running_edge_count = 0
        iters = 0

        for data in tqdm(dataloader):
            iters += 1

            graph_data = data['gsu']
            cp_data = data['cp']

            # graph
            img_loc = graph_data['img_loc']
            node_num = graph_data['node_num']
            roi_labels = graph_data['roi_labels']
            det_boxes = graph_data['det_boxes']
            edge_labels = graph_data['edge_labels'].to(device)
            features = graph_data['features'].to(device)
            spatial_feat = graph_data['spatial_feat'].to(device)
            word2vec = graph_data['word2vec'].to(device)

            # caption: the pre-extracted node block is not an input of the model
            _, caps_gt = cp_data
            caps_gt = caps_gt.to(device)

            model.zero_grad()
            # The argument list must match mtl_model.forward; an extra positional
            # argument here silently shifted every following one.
            interaction, caption_output = model(img_loc, det_boxes, caps_gt, node_num, features, spatial_feat, word2vec, roi_labels)

            # graph loss and acc
            # NOTE(review): MultiLabelSoftMarginLoss applies a sigmoid itself and
            # eval_mtl feeds it raw logits; the softmax is kept to preserve the
            # existing training behaviour - confirm it is intended.
            interaction = torch.softmax(interaction, dim=1)
            g_loss = g_criterion(interaction, edge_labels.float())
            g_acc = np.sum(np.equal(np.argmax(interaction.cpu().data.numpy(), axis=-1), np.argmax(edge_labels.cpu().data.numpy(), axis=-1)))

            # caption loss
            c_loss = c_criterion(caption_output[:, :-1].contiguous(), caps_gt[:, 1:].contiguous())

            #uda:
            #loss = (0.5 * g_loss) + (0.5 * c_loss)
            #uda_graph:
            loss = g_loss
            #uda_caption:
            #loss = c_loss
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_g_acc += g_acc
            running_edge_count += edge_labels.shape[0]

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError
        epoch_loss = running_loss/float(max(iters, 1))
        epoch_g_acc = running_g_acc/float(max(running_edge_count, 1))
        print("[{}] Epoch: {}/{} MTL_Loss: {:0.6f} Graph_Acc: {:0.6f}".format(\
                            'MTL-Train', epoch_count+1, epoch, epoch_loss, epoch_g_acc))

        checkpoint = {'state_dict': model.state_dict()}
        save_name = checkpoint_dir + "/checkpoint_" + str(epoch_count+1) + '_epoch.pth'
        torch.save(checkpoint, save_name)

        print("=========== Evaluation ===============")
        eval_mtl(model, dict_dataloader_val, text_field)

    return

# Arguments, dataloader

In [10]:
# arguments
def _str2bool(value):
    '''
    Parse a boolean option.

    argparse's ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value - including "False" - becomes True.
    '''
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % (value,))

device = torch.device('cuda')
parser = argparse.ArgumentParser(description='Incremental domain adaptation for surgical report generation')
parser.add_argument('--batch_size', type=int, default=2)
parser.add_argument('--workers', type=int, default=0)

# caption
parser.add_argument('--exp_name', type=str, default='m2_transformer')
parser.add_argument('--m', type=int, default=40)
# NOTE: cp_cbs is a string flag; later cells compare it against the literal 'True'.
parser.add_argument('--cp_cbs', type=str, default='True')
parser.add_argument('--cp_cbs_filter', default='LOG', type=str) # Potential choice: 'gau' and 'LOG'
parser.add_argument('--cp_kernel_sizex', default=3, type=int)
parser.add_argument('--cp_kernel_sizey', default=1, type=int)
parser.add_argument('--cp_decay_epoch', default=2, type=int)
parser.add_argument('--cp_std_factor', default=0.9, type=float)

# graph
parser.add_argument('--gsu_cbs',        type=_str2bool, default=True)
parser.add_argument('--gsu_feat', type=str,  default='resnet18_09_SC_CBS')
parser.add_argument('--gsu_w2v_loc', type=str,  default='datasets/surgicalscene_word2vec.hdf5')

# feature_extractor
parser.add_argument('--fe_use_cbs',            type=_str2bool, default=True,        help='use CBS')
parser.add_argument('--fe_std',                type=float,     default=1.0,         help='')
parser.add_argument('--fe_std_factor',         type=float,     default=0.9,         help='')
parser.add_argument('--fe_cbs_epoch',          type=int,       default=5,           help='')
parser.add_argument('--fe_kernel_size',        type=int,       default=3,           help='')
parser.add_argument('--fe_fil1',               type=str,       default='LOG',       help='gau, LOG')
parser.add_argument('--fe_fil2',               type=str,       default='gau',       help='gau, LOG')
parser.add_argument('--fe_fil3',               type=str,       default='gau',       help='gau, LOG')
parser.add_argument('--fe_num_classes',        type=int,       default=11,           help='11')
parser.add_argument('--fe_use_SC',             type=_str2bool, default=True,       help='use SuperCon')

# file dirs
print('Training check for DA_ECBS_ResNet18_09_SC_ECBS')

parser.add_argument('--gsu_img_dir', type=str,  default='left_frames')
parser.add_argument('--gsu_file_dir', type=str,  default='datasets/instruments18/')

parser.add_argument('--cp_features_path', type=str, default='datasets/instruments18/')
parser.add_argument('--cp_annotation_folder', type=str, default='datasets/annotations_new/annotations_SD_inc')

# checkpoints dir
parser.add_argument('--fe_modelpath',          type=str,       default='feature_extractor/checkpoint/incremental/inc_ResNet18_SC_CBS_0_012345678.pkl')
parser.add_argument('--gsu_checkpoint', type=str,  default='checkpoints/g_checkpoints/da_ecbs_resnet18_09_SC_eCBS/da_ecbs_resnet18_09_SC_eCBS/epoch_train/checkpoint_D1230_epoch.pth')
parser.add_argument('--cp_checkpoint', type=str, default='checkpoints/IDA_MICCAI2021_checkpoints/SD_base_LOG/')


# args=[] : inside a notebook sys.argv belongs to the kernel, so only the defaults are used
args = parser.parse_args(args=[])
print(args)

# graph scene understanding constants
# Everything the datasets need to locate frames, feature files and word vectors.
gsu_const = {
    'file_dir': args.gsu_file_dir,
    'img_dir': args.gsu_img_dir,
    'dataconst': SurgicalSceneConstants(),
    'feature_extractor': args.gsu_feat,
    'w2v_loc': args.gsu_w2v_loc,
}


# Pipeline for image regions and text
# image_field is None: no pre-extracted detection features are loaded, the
# datasets fall back to zero placeholders for the image field.
#image_field = ImageDetectionsField(detections_path=args.cp_features_path, max_detections=6, load_in_tmp=False)  
image_field = None
text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy', remove_punctuation=True, nopoints=False)

# Create the dataset 
# The annotation folder is passed twice: once as ann_root and once as id_root.
dataset = COCO(image_field, text_field, gsu_const, args.cp_features_path, args.cp_annotation_folder, args.cp_annotation_folder)
train_dataset, val_dataset = dataset.splits   
print('train:', len(train_dataset))
print('val:', len(val_dataset))
    
# caption data
# The vocabulary is built once and cached on disk; later runs reuse the cache.
vocab_path = 'datasets/vocab_%s.pkl' % args.exp_name
if not os.path.isfile(vocab_path):
    print("Building vocabulary")
    text_field.build_vocab(train_dataset, val_dataset, min_freq=2)
    # `with` guarantees the file is flushed and closed
    with open(vocab_path, 'wb') as vocab_file:
        pickle.dump(text_field.vocab, vocab_file)
else:
    # NOTE: pickle.load can execute arbitrary code - only load vocab files you created yourself.
    with open(vocab_path, 'rb') as vocab_file:
        text_field.vocab = pickle.load(vocab_file)

print('vocabulary size is:', len(text_field.vocab))
print(text_field.vocab.stoi)

# dataset
# NOTE(review): the training loader is created with shuffle=False - confirm this is intended.
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
# Validation reuses the same examples but swaps the text field for a RawField.
dict_dataset_val = val_dataset.image_dictionary({'image': image_field, 'text': RawField()})
dict_dataloader_val = DataLoader(dict_dataset_val, batch_size=args.batch_size) # for caption with word GT class number

Training check for DA_ECBS_ResNet18_09_SC_ECBS
Namespace(batch_size=2, cp_annotation_folder='datasets/annotations_new/annotations_SD_inc', cp_cbs='True', cp_cbs_filter='LOG', cp_checkpoint='checkpoints/IDA_MICCAI2021_checkpoints/SD_base_LOG/', cp_decay_epoch=2, cp_features_path='datasets/instruments18/', cp_kernel_sizex=3, cp_kernel_sizey=1, cp_std_factor=0.9, exp_name='m2_transformer', fe_cbs_epoch=5, fe_fil1='LOG', fe_fil2='gau', fe_fil3='gau', fe_kernel_size=3, fe_modelpath='feature_extractor/checkpoint/incremental/inc_ResNet18_SC_CBS_0_012345678.pkl', fe_num_classes=11, fe_std=1.0, fe_std_factor=0.9, fe_use_SC=True, fe_use_cbs=True, gsu_cbs=True, gsu_checkpoint='checkpoints/g_checkpoints/da_ecbs_resnet18_09_SC_eCBS/da_ecbs_resnet18_09_SC_eCBS/epoch_train/checkpoint_D1230_epoch.pth', gsu_feat='resnet18_09_SC_CBS', gsu_file_dir='datasets/instruments18/', gsu_img_dir='left_frames', gsu_w2v_loc='datasets/surgicalscene_word2vec.hdf5', m=40, workers=0)
no pre-extracted image featured
no 

# Feature Extractor

In [9]:
# net model
# SupCon wraps the backbone in `.encoder`; the plain ResNet18 is the backbone itself.
feature_network = SupConResNet(args=args) if args.fe_use_SC else ResNet18(args)

# CBS
if args.fe_use_cbs:
    cbs_target = feature_network.encoder if args.fe_use_SC else feature_network
    cbs_target.get_new_kernels(0)

# gpu
num_gpu = torch.cuda.device_count()
if num_gpu > 0:
    device_ids = list(range(num_gpu))
    if not args.fe_use_SC:
        feature_network = nn.DataParallel(feature_network, device_ids=device_ids).cuda()
    else:
        # only the encoder is wrapped, so checkpoint keys keep the `encoder.module.` prefix
        feature_network.encoder = torch.nn.DataParallel(feature_network.encoder)
        feature_network = feature_network.cuda()

# Caption model

In [10]:
# cp_cbs is a string option, hence the comparison against the literal 'True'
use_cp_cbs = args.cp_cbs == 'True'
attention_kwargs = {'m': args.m}

if use_cp_cbs:
    from models.transformer import MemoryAugmentedEncoder_CBS
    print("MemoryAugmentedEncoder_CBS")
    encoder = MemoryAugmentedEncoder_CBS(
        3, 0,
        attention_module=ScaledDotProductAttentionMemory,
        attention_module_kwargs=attention_kwargs,
    )
else:
    print("MemoryAugmentedEncoder")
    encoder = MemoryAugmentedEncoder(
        3, 0,
        attention_module=ScaledDotProductAttentionMemory,
        attention_module_kwargs=attention_kwargs,
    )

vocab = text_field.vocab
decoder = MeshedDecoder(len(vocab), 54, 3, vocab.stoi['<pad>'])
caption_model = Transformer(vocab.stoi['<bos>'], encoder, decoder).to(device)

if use_cp_cbs:
    caption_model.encoder.get_new_kernels(
        0,
        args.cp_kernel_sizex,
        args.cp_kernel_sizey,
        args.cp_decay_epoch,
        args.cp_std_factor,
        args.cp_cbs_filter,
    )

MemoryAugmentedEncoder_CBS


# Graph Model

In [11]:
# Graph network for interaction (edge) prediction; use_cbs follows --gsu_cbs.
graph_su_model = AGRNN(bias= True, bn=False, dropout=0.3, multi_attn=False, layer=1, diff_edge=False, use_cbs = args.gsu_cbs)
if args.gsu_cbs:
    # initialise the CBS kernels of the edge function with argument 0
    # (presumably the epoch index - TODO confirm against models.graph_su)
    graph_su_model.grnn1.gnn.apply_h_h_edge.get_new_kernels(0)

# Load pre-trained weights

In [12]:
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
# caption
# checkpoint file name is '<exp_name>_best.pth' inside args.cp_checkpoint
pretrained_model = torch.load(args.cp_checkpoint+('%s_best.pth' % args.exp_name))
caption_model.load_state_dict(pretrained_model['state_dict']) 

# graph
# `pretrained_model` is reused: from here on it holds the graph checkpoint
pretrained_model = torch.load(args.gsu_checkpoint)
graph_su_model.load_state_dict(pretrained_model['state_dict'])

# feature network
# this file is loaded directly as a state dict (no 'state_dict' wrapper key)
feature_network.load_state_dict(torch.load(args.fe_modelpath))

<All keys matched successfully>

# Feature extraction layers

In [13]:
# extract the encoder layer
# NOTE(review): this cell rebinds `feature_network` to a sub-module of itself, so
# running it a second time without re-running the cells above presumably fails or
# strips further layers - re-run from the feature extractor cell instead.
if args.fe_use_SC:
    feature_network = feature_network.encoder
else:
    # `.module` unwraps the DataParallel wrapper; the trailing children are dropped
    # (two with CBS, one without)
    if args.fe_use_cbs: feature_network = nn.Sequential(*list(feature_network.module.children())[:-2])
    else: feature_network = nn.Sequential(*list(feature_network.module.children())[:-1])

feature_network = feature_network.cuda()

# Combined model

In [14]:
# Bundle the feature extractor, the graph network and the caption network into one module.
model = mtl_model(feature_network, graph_su_model, caption_model)
model = model.to(device)

In [15]:
# initial network evaluation
# Baseline metrics of the pre-trained weights on the validation loader.
# eval_mtl switches the model to eval mode.
eval_mtl(model, dict_dataloader_val, text_field)

# train for 100 epoch
# (left disabled; uncomment to run the fine-tuning)
# train(100, 0.001, model, train_dataloader, dict_dataloader_val, text_field)

0it [00:00, ?it/s]

['datasets/instruments18/seq_1/left_frames/frame047.png', 'datasets/instruments18/seq_1/left_frames/frame065.png']
node_num [4, 3]
node_feat torch.Size([7, 512])
spatial_deat torch.Size([18, 16])
word_2_vec torch.Size([7, 300])
roi_labels [array([0, 1, 2, 7]), array([0, 2, 4])]


1it [00:03,  3.71s/it]

interaction tensor([[  0.0000,  -8.2871,  -7.3069,  -5.1491,  -4.4096, -23.5156,   0.0000,
           0.0000,   0.0000, -32.4802, -24.8784, -35.0060, -13.8209],
        [ -1.9168,   0.0000,  -5.8999,   0.0000,  -4.8041, -13.6514,   0.0000,
          -4.8086,  -8.7836, -22.3640, -17.1114, -25.4892, -11.4277],
        [ -0.2609, -10.1361,  -7.8001,  -2.7731,  -3.8458,  -4.8875,   0.0000,
          -7.6525,  -5.9701,   0.0000,  -4.0276, -11.3565,   0.0000],
        [ -0.4975,  -4.6083,  -4.3027,  -0.5445, -10.7980,   0.0000, -26.4635,
         -11.1808,  -6.9873, -27.4969, -20.9212,   0.0000, -15.3676],
        [ -0.5307,  -5.8077,  -6.8055,  -3.4842,  -4.5793,   1.2136,  -6.6921,
          -9.9246,  -6.0392,   0.0000,   0.0000,  -7.1559,  -6.2763]],
       device='cuda:0')
['datasets/instruments18/seq_1/left_frames/frame094.png', 'datasets/instruments18/seq_1/left_frames/frame013.png']
node_num [4, 4]
node_feat torch.Size([8, 512])
spatial_deat torch.Size([24, 16])
word_2_vec torch.Size(

2it [00:04,  2.74s/it]

interaction tensor([[  1.7616, -10.8580,  -4.5352,  -0.7789,   0.0000, -13.2561, -21.3044,
         -23.0021, -12.5584, -23.3428, -25.0887,   0.0000, -19.4127],
        [  0.5599,  -7.7693,  -9.9580,   1.1444,   0.0000, -11.7384, -31.9997,
          -9.1883,  -5.9409, -29.3164, -15.2528,   0.0000, -16.7993],
        [  0.8450,  -9.3496,  -8.1437,  -5.5296,  -8.2990,   0.0000,   0.0000,
         -24.0040, -11.2224, -10.5505, -15.7250, -16.0700, -12.9930],
        [  0.6087,  -6.8864,  -4.3954,  -1.4028,  -8.0746,   0.0000, -24.9153,
         -17.3685,  -5.3791, -25.6285, -22.8299,   0.0000,   0.0000],
        [  0.0000,  -4.7851,  -4.7146,   0.4951,  -8.0289, -11.4949,   0.0000,
          -8.4013,  -7.4362, -19.6092, -16.5382, -17.3998, -10.8829],
        [  1.5607,  -9.5905,  -9.4825,  -3.1004,  -6.5390,   0.0000,   0.0000,
         -16.8339,  -9.1139, -11.1966,   0.0000,   0.0000, -11.1149]],
       device='cuda:0')
['datasets/instruments18/seq_1/left_frames/frame109.png', 'datasets/i

3it [00:04,  2.06s/it]

interaction tensor([[  0.0000,   0.0000,   0.0000,  -5.6596,  -9.1584, -22.0826,   0.0000,
         -15.3100,   0.0000, -33.0210, -32.8981, -34.6749, -21.1730],
        [  2.2521, -14.1995,   0.0000,  -8.0715,  -8.3112, -37.4496,   0.0000,
          -9.3564, -15.5919, -59.0969, -43.6544, -61.7319,   0.0000],
        [  1.9643, -10.2696,  -8.6664,   0.0000,  -5.8914,   0.0000,   0.0000,
           0.0000, -10.7342, -10.2574, -13.6837, -15.1018, -10.5569],
        [  0.7186,  -4.3310,   0.0000,  -2.6788,  -5.0853,   0.0000, -20.0343,
          -9.0862,  -5.6965, -23.7891,   0.0000, -21.1476, -11.9725],
        [ -0.2177,  -5.3122,  -5.7098,   0.0000,  -7.3874, -15.4011, -27.3393,
           0.0000,  -7.6885, -26.7382, -18.7629, -27.4717, -14.5518],
        [ -0.3362,  -9.5501,   0.0000,  -4.4658,  -6.1753,   1.0032,  -6.2940,
           0.0000, -10.5625,  -5.6543, -10.1116, -10.1616,  -7.4469]],
       device='cuda:0')
['datasets/instruments18/seq_1/left_frames/frame135.png', 'datasets/i

4it [00:05,  1.58s/it]

interaction tensor([[ 3.1906e-02, -9.4376e+00,  0.0000e+00, -1.0439e+00, -1.1433e+01,
         -2.3973e+01, -3.5077e+01, -1.5541e+01, -1.2406e+01, -3.4316e+01,
         -3.0824e+01,  0.0000e+00, -1.9082e+01],
        [-2.5102e+00, -6.0729e+00, -6.0301e+00,  4.1442e+00,  0.0000e+00,
         -1.6117e+01, -2.7929e+01,  0.0000e+00, -6.5825e+00,  0.0000e+00,
         -1.6959e+01, -2.4546e+01, -1.7679e+01],
        [-3.6221e-02, -6.7651e+00, -7.2599e+00, -3.8176e+00,  0.0000e+00,
         -3.6654e+00,  0.0000e+00, -3.4733e+00, -5.3455e+00,  0.0000e+00,
         -6.1087e+00, -1.0234e+01, -7.5311e+00],
        [ 0.0000e+00, -1.5780e+01, -1.3545e+01, -3.7561e+00, -6.7560e+00,
         -8.2817e+00,  0.0000e+00,  0.0000e+00, -9.8606e+00,  0.0000e+00,
          0.0000e+00, -1.6347e+01,  0.0000e+00],
        [-4.6151e-01, -4.9866e+00, -5.1990e+00,  9.4513e-01, -8.7835e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -7.6741e+00, -2.7490e+01,
         -1.8997e+01, -2.4490e+01, -1.4884e+01],
   

5it [00:05,  1.24s/it]

interaction tensor([[  0.1032,   0.0000,  -2.1698,  -2.4776,  -4.9631, -10.4636, -19.8205,
           0.0000,   0.0000, -19.9683,   0.0000, -19.4585,   0.0000],
        [ -0.7767,  -2.8128,  -3.5627,   0.0860,   0.0000,  -5.0281, -11.9112,
          -6.2590,   0.0000, -12.1248,  -8.9779,   0.0000,  -6.9479],
        [  0.8986, -10.4317,  -8.8533,  -6.6280,   0.0000,   0.0000, -15.0635,
           0.0000, -15.5472, -12.8866, -20.3637, -21.0653,   0.0000],
        [  0.0000,  -5.6923,   0.0000,  -1.2671,  -5.9448,   0.0000,   0.0000,
         -10.3988,  -8.6262,   0.0000, -19.9523, -24.6381, -15.4203],
        [  0.0000,  -7.8105,   0.0000,  -4.1904,  -8.4225,   0.0000,   0.0000,
          -9.1533,   0.0000,  -6.7394,   0.0000,   0.0000,  -7.3752]],
       device='cuda:0')
['datasets/instruments18/seq_1/left_frames/frame130.png', 'datasets/instruments18/seq_1/left_frames/frame083.png']
node_num [4, 4]
node_feat torch.Size([8, 512])
spatial_deat torch.Size([24, 16])
word_2_vec torch.Size(

5it [00:05,  1.19s/it]


KeyboardInterrupt: 