In [1]:
import os
import sys
import random

import json
import h5py
import itertools
import numpy as np
import argparse, pickle

import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from tqdm import tqdm
from torch import optim
from torch.utils.data import Dataset as torchDataset
from torch.utils.data import DataLoader as TorchDataLoader

# caption libraries
import evaluation
import collections
from data.example import Example
from data.utils import nostdout
from data.field import ImageDetectionsField, TextField, RawField
from models.transformer import Transformer, MemoryAugmentedEncoder, MeshedDecoder, ScaledDotProductAttentionMemory

# graph libraries
import utils.io as io
import matplotlib.pyplot as plt
from utils.g_vis_img import *
from models.graph_su import *
from evaluation.graph_eval import *


# Random seeds
seed = 1234
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

ModuleNotFoundError: No module named 'pycocotools'

# Dataset Constants

In [2]:
class SurgicalSceneConstants():
    '''
    Surgical Scene constants
    '''
    def __init__( self):
        self.instrument_classes = ('kidney', 'bipolar_forceps', 'prograsp_forceps', 'large_needle_driver',
                                'monopolar_curved_scissors', 'ultrasound_probe', 'suction', 'clip_applier',
                                'stapler', 'maryland_dissector', 'spatulated_monopolar_cautery')
        self.action_classes = ( 'Idle', 'Grasping', 'Retraction', 'Tissue_Manipulation', 
                                'Tool_Manipulation', 'Cutting', 'Cauterization', 
                                'Suction', 'Looping', 'Suturing', 'Clipping', 'Staple', 
                                'Ultrasound_Sensing')        
        #self.file_dir = 'datasets/instruments18/'
        #self.word2vec_loc = 'datasets/surgicalscene_word2vec.hdf5'

# Cross-entropy loss with label smoothing

In [3]:
class CELossWithLS(torch.nn.Module):
    '''
    label smoothing cross-entropy loss for captioning
    '''
    def __init__(self, classes=None, smoothing=0.1, gamma=3.0, isCos=True, ignore_index=-1):
        super(CELossWithLS, self).__init__()
        self.complement = 1.0 - smoothing
        self.smoothing = smoothing
        self.cls = classes
        self.log_softmax = torch.nn.LogSoftmax(dim=1)
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, logits, target):
        with torch.no_grad():
            oh_labels = F.one_hot(target.to(torch.int64), num_classes = self.cls).permute(0,1,2).contiguous()
            smoothen_ohlabel = oh_labels * self.complement + self.smoothing / self.cls

        logs = self.log_softmax(logits[target!=self.ignore_index])
        pt = torch.exp(logs)
        return -torch.sum((1-pt).pow(self.gamma)*logs * smoothen_ohlabel[target!=self.ignore_index], dim=1).mean()


# dataloader

In [4]:
''' ---------------------------------------- caption dataloader objects ----------------------------------------'''
class DataLoader(TorchDataLoader):
    def __init__(self, dataset, *args, **kwargs):
        super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs)

class Dataset(object):
    def __init__(self, examples, fields):
        self.examples = examples
        self.fields = dict(fields)  

    def collate_fn(self):
        def collate(batch):
            if len(self.fields) == 1:
                batch = [batch, ]
            else:
                batch = list(zip(*batch))

            tensors = []
            for field, data in zip(self.fields.values(), batch):
                tensor = field.process(data)
                if isinstance(tensor, collections.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor):
                    tensors.extend(tensor)
                else:
                    tensors.append(tensor)

            if len(tensors) > 1:
                return tensors
            else:
                return tensors[0]

        return collate

    def __getitem__(self, i):
        example = self.examples[i]
        data = []
        for field_name, field in self.fields.items():
            #if field_name == 'image': print(getattr(example, field_name))
            data.append(field.preprocess(getattr(example, field_name)))   

        if len(data) == 1:
            data = data[0]
        return data

    def __len__(self):
        return len(self.examples)

    def __getattr__(self, attr):
        if attr in self.fields:
            for x in self.examples:
                yield getattr(x, attr)

class ValueDataset(Dataset):
    def __init__(self, examples, fields, dictionary):
        self.dictionary = dictionary
        super(ValueDataset, self).__init__(examples, fields)

    def collate_fn(self):
        def collate(batch):
            value_batch_flattened = list(itertools.chain(*batch))
            value_tensors_flattened = super(ValueDataset, self).collate_fn()(value_batch_flattened)

            lengths = [0, ] + list(itertools.accumulate([len(x) for x in batch]))
            if isinstance(value_tensors_flattened, collections.Sequence) \
                    and any(isinstance(t, torch.Tensor) for t in value_tensors_flattened):
                value_tensors = [[vt[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])] for vt in value_tensors_flattened]
            else:
                value_tensors = [value_tensors_flattened[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])]

            return value_tensors
        return collate

    def __getitem__(self, i):
        if i not in self.dictionary:
            raise IndexError

        values_data = []
        for idx in self.dictionary[i]:
            value_data = super(ValueDataset, self).__getitem__(idx)
            values_data.append(value_data)
        return values_data

    def __len__(self):
        return len(self.dictionary)

''' ---------------------------------------- caption dataloader objects ----------------------------------------'''
''' ---------------------------------------- GraphSU dataloader objects ----------------------------------------'''    

class GraphSUDataset(torchDataset):
    
    ''' Data loader for graph Scene Understanding'''
    def __init__(self,examples, fields, file_dir, img_dir, w2v_loc, dataconst, feature_extractor):
        
        self.examples = examples
        self.fields = dict(fields)

        self.file_dir = file_dir
        self.img_dir = img_dir
        self.dataconst = dataconst
        self.feature_extractor = feature_extractor
        print('filename', w2v_loc)
        self.word2vec = h5py.File(w2v_loc, 'r')
    
    # word2vec
    def _get_word2vec(self,node_ids):
        word2vec = np.empty((0,300))
        for node_id in node_ids:
            vec = self.word2vec[self.dataconst.instrument_classes[node_id]]
            word2vec = np.vstack((word2vec, vec))
        return word2vec

    def __len__(self):
        return len(self.examples)

    def __getattr__(self, attr):
        if attr in self.fields:
            for x in self.examples:
                yield getattr(x, attr)

    def __getitem__(self, idx):
        example = self.examples[idx]
        frame_path = getattr(example, 'image')
        frame_path = frame_path.split("/")
        
        _img_loc = os.path.join(self.file_dir, frame_path[0],self.img_dir,frame_path[3].split("_")[0]+'.png')
        frame_data = h5py.File(os.path.join(self.file_dir, frame_path[0],'vsgat',self.feature_extractor, frame_path[3].split("_")[0]+'_features.hdf5'), 'r')    
        
        data = {}
        data['img_name'] = frame_data['img_name'].value[:] + '.jpg'
        data['img_loc'] = _img_loc
        
        data['node_num'] = frame_data['node_num'].value
        data['roi_labels'] = frame_data['classes'][:]
        data['det_boxes'] = frame_data['boxes'][:]
        
        data['edge_labels'] = frame_data['edge_labels'][:]
        data['edge_num'] = data['edge_labels'].shape[0]
        
        data['features'] = frame_data['node_features'][:]
        data['spatial_feat'] = frame_data['spatial_features'][:]
        
        
        data['word2vec'] = self._get_word2vec(data['roi_labels'])
        return data

    def collate_fn(self):
        def collate(batch):
            batch_data = {}
            batch_data['img_name'] = []
            batch_data['img_loc'] = []
            batch_data['node_num'] = []
            batch_data['roi_labels'] = []
            batch_data['det_boxes'] = []
            batch_data['edge_labels'] = []
            batch_data['edge_num'] = []
            batch_data['features'] = []
            batch_data['spatial_feat'] = []
            batch_data['word2vec'] = []

            for data in batch:
                batch_data['img_name'].append(data['img_name'])
                batch_data['img_loc'].append(data['img_loc'])
                batch_data['node_num'].append(data['node_num'])
                batch_data['roi_labels'].append(data['roi_labels'])
                batch_data['det_boxes'].append(data['det_boxes'])
                batch_data['edge_labels'].append(data['edge_labels'])
                batch_data['edge_num'].append(data['edge_num'])
                batch_data['features'].append(data['features'])
                batch_data['spatial_feat'].append(data['spatial_feat'])
                batch_data['word2vec'].append(data['word2vec'])

            batch_data['edge_labels'] = torch.FloatTensor(np.concatenate(batch_data['edge_labels'], axis=0))
            batch_data['features'] = torch.FloatTensor(np.concatenate(batch_data['features'], axis=0))
            batch_data['spatial_feat'] = torch.FloatTensor(np.concatenate(batch_data['spatial_feat'], axis=0))
            batch_data['word2vec'] = torch.FloatTensor(np.concatenate(batch_data['word2vec'], axis=0))
            return batch_data

        return collate
''' ---------------------------------------- GraphSU dataloader objects ----------------------------------------'''
class DictionaryDataset(Dataset):
    def __init__(self, examples, fields, key_fields, gsu_file_dir='', gsu_img_dir='', gsu_w2v_loc='', gsu_feat = ''):
        if not isinstance(key_fields, (tuple, list)):
            key_fields = (key_fields,)
        for field in key_fields:
            assert (field in fields)

        dictionary = collections.defaultdict(list)
        key_fields = {k: fields[k] for k in key_fields}
        value_fields = {k: fields[k] for k in fields.keys() if k not in key_fields}
        key_examples = []
        key_dict = dict()
        value_examples = []

        for i, e in enumerate(examples):
            key_example = Example.fromdict({k: getattr(e, k) for k in key_fields})
            value_example = Example.fromdict({v: getattr(e, v) for v in value_fields})
            if key_example not in key_dict:
                key_dict[key_example] = len(key_examples)
                key_examples.append(key_example)

            value_examples.append(value_example)
            dictionary[key_dict[key_example]].append(i)

        self.key_dataset = Dataset(key_examples, key_fields)
        self.value_dataset = ValueDataset(value_examples, value_fields, dictionary)

        dataconst = SurgicalSceneConstants()
        #self.graph_su_dataset = GraphSUDataset(key_examples, key_fields, 'left_frames', dataconst, 'resnet18_11_cbs_ts')
        self.graph_su_dataset = GraphSUDataset(key_examples, key_fields, gsu_file_dir, gsu_img_dir, gsu_w2v_loc, dataconst, gsu_feat)
                                                                        
        
        super(DictionaryDataset, self).__init__(examples, fields)

    def collate_fn(self):
        def collate(batch):
            key_batch, value_batch, graph_su_batch = list(zip(*batch))
            key_tensors = self.key_dataset.collate_fn()(key_batch)
            value_tensors = self.value_dataset.collate_fn()(value_batch)
            graph_su_tensors = self.graph_su_dataset.collate_fn()(graph_su_batch)
            return key_tensors, value_tensors, graph_su_tensors
        return collate

    def __getitem__(self, i):
        return self.key_dataset[i], self.value_dataset[i], self.graph_su_dataset[i]

    def __len__(self):
        return len(self.key_dataset)

class PairedDataset(Dataset):
    def __init__(self, examples, fields, gsu_file_dir='', gsu_img_dir='', gsu_w2v_loc='', gsu_feat=''):
        assert ('image' in fields)
        assert ('text' in fields)
        super(PairedDataset, self).__init__(examples, fields)
        self.image_field = self.fields['image']
        self.text_field = self.fields['text']
        self.gsu_file_dir = gsu_file_dir
        self.gsu_img_dir = gsu_img_dir
        self.gsu_w2v_loc = gsu_w2v_loc
        self.gsu_feat = gsu_feat
        
    def image_dictionary(self, fields=None):
        if not fields:
            fields = self.fields
        dataset = DictionaryDataset(self.examples, fields, 'image', self.gsu_file_dir, self.gsu_img_dir, self.gsu_w2v_loc, self.gsu_feat)
        return dataset


class COCO(PairedDataset):
    def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, gsu_file_dir='', gsu_img_dir='', gsu_w2v_loc='', gsu_feat=''):
        roots = {}
        roots['train'] = {
            'img': img_root,
            'cap': os.path.join(ann_root, 'captions_train.json')  
        }
        roots['val'] = {
            'img': img_root,
            'cap': os.path.join(ann_root, 'captions_val.json')
        }
        
        self.gsu_file_dir = gsu_file_dir
        self.gsu_img_dir = gsu_img_dir
        self.gsu_feat = gsu_feat
        self.gsu_w2v_loc = gsu_w2v_loc
        
        if id_root is not None:
            ids = {}
            ids['train'] = json.load(open(os.path.join(id_root, 'WithCaption_id_path_train.json'), 'r'))
            ids['val'] = json.load(open(os.path.join(id_root, 'WithCaption_id_path_val.json'), 'r'))   
        else:
            ids = None
        
        with nostdout():
            self.train_examples, self.val_examples = self.get_samples(roots, ids)
        examples = self.train_examples + self.val_examples
        super(COCO, self).__init__(examples, {'image': image_field, 'text': text_field},self.gsu_file_dir, self.gsu_img_dir, self.gsu_w2v_loc, self.gsu_feat)   

    @property
    def splits(self):
        train_split = PairedDataset(self.train_examples, self.fields, self.gsu_file_dir, self.gsu_img_dir, self.gsu_w2v_loc, self.gsu_feat) 
        val_split = PairedDataset(self.val_examples, self.fields, self.gsu_file_dir, self.gsu_img_dir, self.gsu_w2v_loc, self.gsu_feat)
        return train_split, val_split

    @classmethod
    def get_samples(cls, roots, ids_dataset=None):
        train_samples = []
        val_samples = []
   
        for split in ['train', 'val']:
            anns = json.load(open(roots[split]['cap'], 'r'))
            if ids_dataset is not None:
                ids = ids_dataset[split]
                
            for index in range(len(ids)):              
                id_path = ids[index]  
                caption = anns[index]['caption']
                #example = Example.fromdict({'image': os.path.join(roots[split]['img'], id_path), 'text': caption})
                example = Example.fromdict({'image': os.path.join('', id_path), 'text': caption})
                if split == 'train':
                    train_samples.append(example)
                elif split == 'val':
                    val_samples.append(example)
                
        return train_samples, val_samples

# MTL Model (Graph Scene Understanding and Captioning)

In [5]:
class mtl_model(nn.Module):
    '''
    Multi-task model : Graph Scene Understanding and Captioning
    '''
    def __init__(self, caption, graph):
        super(mtl_model, self).__init__()
        self.caption = caption
        self.graph_su = graph

    def forward(self, detections, captions, batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, feat, spatial_feat, word2vec):
        caption_output = self.caption(detections, captions)
        interaction = self.graph_su(batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, feat, spatial_feat, word2vec)
        return caption_output, interaction


# Evaluation Matrix : Graph (Loss, Acc, ECE), Caption (Brier, Cider)

In [6]:

def evaluate_metrics(model, device, dataloader, text_field, g_temp = 1.5, c_temp = None):
    import itertools
    
    model.caption.decoder.caption_ts = c_temp
    #model.to(device)
    
    model.eval()
    gen = {}
    gts = {}

    # graph
    # criterion and scheduler
    g_criterion = nn.MultiLabelSoftMarginLoss()                   
    g_edge_count = 0
    g_total_acc = 0.0
    g_total_loss = 0.0
    g_logits_list = []
    g_labels_list = []

    #print(model.caption.beam_search)
    
    with tqdm(desc='evaluation', unit='it', total=len(dataloader)) as pbar:
        for it, (images, caps_gt, graph_data) in enumerate(iter(dataloader)):
            
            # graph
            img_name = graph_data['img_name']
            img_loc = graph_data['img_loc']
            node_num = graph_data['node_num']
            roi_labels = graph_data['roi_labels']
            det_boxes = graph_data['det_boxes']
            edge_labels = graph_data['edge_labels']
            edge_num = graph_data['edge_num']
            features = graph_data['features']
            spatial_feat = graph_data['spatial_feat']
            word2vec = graph_data['word2vec']
            features, spatial_feat, word2vec, edge_labels = features.to(device), spatial_feat.to(device), word2vec.to(device), edge_labels.to(device)    
        
            # caption
            images = images.to(device)

            with torch.no_grad():
                # caption
                caption_out, _ = model.caption.beam_search(images, 20, text_field.vocab.stoi['<eos>'], 5, out_size=1)
                #print(caption_out)
                # graph
                g_output = model.graph_su(node_num, features, spatial_feat, word2vec, roi_labels, validation=True)
                g_output = g_output/g_temp
                g_logits_list.append(g_output)
                g_labels_list.append(edge_labels)
                # loss and accuracy
                g_loss = g_criterion(g_output, edge_labels.float())
                g_acc = np.sum(np.equal(np.argmax(g_output.cpu().data.numpy(), axis=-1), np.argmax(edge_labels.cpu().data.numpy(), axis=-1)))
            
            # accumulate loss and accuracy of the batch
            g_total_loss += g_loss.item() * edge_labels.shape[0]
            g_total_acc  += g_acc
            g_edge_count += edge_labels.shape[0]
            
            caps_gen = text_field.decode(caption_out, join_words=False)
            for i, (gts_i, gen_i) in enumerate(zip(caps_gt, caps_gen)):
                gen_i = ' '.join([k for k, g in itertools.groupby(gen_i)])
                gen['%d_%d' % (it, i)] = [gen_i, ]    
                gts['%d_%d' % (it, i)] = gts_i
            pbar.update()
    
    #graph loss
    g_logits_all = torch.cat(g_logits_list).cuda()
    g_labels_all = torch.cat(g_labels_list).cuda()
    g_total_acc = g_total_acc / g_edge_count
    g_total_loss = g_total_loss / len(dataloader)

    g_logits_all = F.softmax(g_logits_all, dim=1)
    g_map_value, g_ece, g_sce, g_tace, g_brier, g_uce = calibration_metrics(g_logits_all, g_labels_all, 'test')
    print('acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce.item()) )


    if not os.path.exists('results/c_results/predict_caption'):
        os.makedirs('results/c_results/predict_caption')
    json.dump(gen, open('results/c_results/predict_caption/predict_caption_val.json', 'w'))

    gts = evaluation.PTBTokenizer.tokenize(gts)
    gen = evaluation.PTBTokenizer.tokenize(gen)
    scores, _ = evaluation.compute_scores(gts, gen)
    return scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce.item()

# 1.0: MTL Model (CBS, Incremental Learning) Evaluation

In [7]:
#if __name__ == '__main__':

print('Validation')

# arguments
device = torch.device('cuda')
parser = argparse.ArgumentParser(description='Incremental domain adaptation for surgical report generation')
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--workers', type=int, default=0)

# caption
parser.add_argument('--exp_name', type=str, default='m2_transformer')
parser.add_argument('--m', type=int, default=40)   
parser.add_argument('--cp_cbs', type=str, default='True')
parser.add_argument('--cp_cbs_filter', default='LOG', type=str) # Potential choice: 'gau' and 'LOG'
parser.add_argument('--cp_kernel_sizex', default=3, type=int)
parser.add_argument('--cp_kernel_sizey', default=1, type=int)
parser.add_argument('--cp_decay_epoch', default=2, type=int) 
parser.add_argument('--cp_std_factor', default=0.9, type=float)
parser.add_argument('--cp_features_path', type=str, default='datasets/instruments18/') 
parser.add_argument('--cp_annotation_folder', type=str, default='datasets/annotations_new/annotations_SD_inc')
parser.add_argument('--cp_checkpoint', type=str, default='checkpoints/IDA_MICCAI2021_checkpoints/inc_SC_eCBS/')

# graph
parser.add_argument('--gsu_cbs',        type=bool, default=True)
parser.add_argument('--gsu_checkpoint', type=str,  default='checkpoints/g_checkpoints/d2g_ecbs_resnet18_11_SC_eCBS/d2g_ecbs_resnet18_11_SC_eCBS/epoch_train/checkpoint_D2F70_epoch.pth')
parser.add_argument('--gsu_file_dir', type=str,  default='datasets/instruments18/')
parser.add_argument('--gsu_img_dir', type=str,  default='left_frames')
parser.add_argument('--gsu_w2v_loc', type=str,  default='datasets/surgicalscene_word2vec.hdf5')
parser.add_argument('--gsu_feat', type=str,  default='resnet18_11_SC_CBS')

#parser.add_argument('--head', type=int, default=8)
#parser.add_argument('--warmup', type=int, default=10000)
#parser.add_argument('--features_path_DA', type=str)
#parser.add_argument('--annotation_folder_DA', type=str)
args = parser.parse_args(args=[])


print(args)
   
# Pipeline for image regions and text
image_field = ImageDetectionsField(detections_path=args.cp_features_path, max_detections=6, load_in_tmp=False)  
text_field = TextField(init_token='<bos>', eos_token='<eos>', lower=True, tokenize='spacy', remove_punctuation=True, nopoints=False)

# Create the dataset 
dataset = COCO(image_field, text_field, args.cp_features_path, args.cp_annotation_folder, args.cp_annotation_folder, args.gsu_file_dir, args.gsu_img_dir, args.gsu_w2v_loc, args.gsu_feat)
train_dataset, val_dataset = dataset.splits   
print('train:', len(train_dataset))
print('val:', len(val_dataset))
    
# caption model
if not os.path.isfile('datasets/vocab_%s.pkl' % args.exp_name):
    print("Building vocabulary")
    text_field.build_vocab(train_dataset, val_dataset, min_freq=2)  
    pickle.dump(text_field.vocab, open('datasets/vocab_%s.pkl' % args.exp_name, 'wb'))
else:
    text_field.vocab = pickle.load(open('datasets/vocab_%s.pkl' % args.exp_name, 'rb'))

print('vocabulary size is:', len(text_field.vocab))
print(text_field.vocab.stoi)

if args.cp_cbs == 'True':
    from models.transformer import MemoryAugmentedEncoder_CBS
    print("MemoryAugmentedEncoder_CBS")
    encoder = MemoryAugmentedEncoder_CBS(3, 0, attention_module=ScaledDotProductAttentionMemory, attention_module_kwargs={'m': args.m})
else:
    print("MemoryAugmentedEncoder")
    encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory, attention_module_kwargs={'m': args.m}) 

decoder = MeshedDecoder(len(text_field.vocab), 54, 3, text_field.vocab.stoi['<pad>'])
caption_model = Transformer(text_field.vocab.stoi['<bos>'], encoder, decoder).to(device)

if args.cp_cbs == 'True':
    caption_model.encoder.get_new_kernels(0, args.cp_kernel_sizex, args.cp_kernel_sizey, args.cp_decay_epoch, args.cp_std_factor, args.cp_cbs_filter) 
            

# graph model    
graph_su_model = AGRNN(bias= True, bn=False, dropout=0.3, multi_attn=False, layer=1, diff_edge=False, use_cbs = args.gsu_cbs)
if args.gsu_cbs:
    graph_su_model.grnn1.gnn.apply_h_h_edge.get_new_kernels(0)

# loading pre-trained model for graph and caption
pretrained_model = torch.load(args.cp_checkpoint+('%s_best.pth' % args.exp_name))
caption_model.load_state_dict(pretrained_model['state_dict']) 
pretrained_model = torch.load(args.gsu_checkpoint)
graph_su_model.load_state_dict(pretrained_model['state_dict'])
    
# model
model = mtl_model(caption_model, graph_su_model)
model = model.to(device)
    
# dataset
dict_dataset_val = val_dataset.image_dictionary({'image': image_field, 'text': RawField()})
dataloader_val = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
dict_dataloader_val = DataLoader(dict_dataset_val, batch_size=args.batch_size) # for caption with word GT class number


''' 1.0 mutlitask model base evaluation ==========================================================================='''
scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce = evaluate_metrics(model, device, dict_dataloader_val, text_field)
print('Initial Graph SU: acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce))
print("Initial Caption scores :", scores)

Validation
Namespace(batch_size=1, cp_annotation_folder='datasets/annotations_new/annotations_SD_inc', cp_cbs='True', cp_cbs_filter='LOG', cp_checkpoint='checkpoints/IDA_MICCAI2021_checkpoints/inc_SC_eCBS/', cp_decay_epoch=2, cp_features_path='datasets/instruments18/', cp_kernel_sizex=3, cp_kernel_sizey=1, cp_std_factor=0.9, exp_name='m2_transformer', gsu_cbs=True, gsu_checkpoint='checkpoints/g_checkpoints/d2g_ecbs_resnet18_11_SC_eCBS/d2g_ecbs_resnet18_11_SC_eCBS/epoch_train/checkpoint_D2F70_epoch.pth', gsu_feat='resnet18_11_SC_CBS', gsu_file_dir='datasets/instruments18/', gsu_img_dir='left_frames', gsu_w2v_loc='datasets/surgicalscene_word2vec.hdf5', m=40, workers=0)
train: 1560
val: 447
vocabulary size is: 41
defaultdict(<function _default_unk_index at 0x7f78addd9a60>, {'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3, 'is': 4, 'tissue': 5, 'forceps': 6, 'monopolar': 7, 'curved': 8, 'scissors': 9, 'bipolar': 10, 'manipulating': 11, 'and': 12, 'are': 13, 'prograsp': 14, 'cutting': 15, 'i

evaluation:   0%|          | 0/447 [00:00<?, ?it/s]

filename datasets/surgicalscene_word2vec.hdf5


evaluation: 100%|██████████| 447/447 [01:54<00:00,  3.90it/s]


acc: 0.571059 map: 0.321686 loss: 0.509024, ece:0.219497, sce:0.047773, tace:0.049699, brier:0.649657, uce:0.283210
Initial Graph SU: acc: 0.571059 map: 0.321686 loss: 0.509024, ece:0.219497, sce:0.047773, tace:0.049699, brier:0.649657, uce:0.283210
Initial Caption scores : {'BLEU': array([0.5498, 0.4714, 0.4238, 0.3801]), 'METEOR': 0.2861, 'ROUGE': 0.57, 'CIDEr': 2.7487}


# Temperature Scaling : Model Object

In [8]:
class ModelWithTemperature(nn.Module):
    '''
    Temperature scaling model for model
    '''
    def __init__(self, model):
        super(ModelWithTemperature, self).__init__()
        self.model = model
        self.use_ts = False
        self.graph_su_temperature = nn.Parameter(torch.ones(1) * 1.5)
        self.caption_temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, detections, captions, batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, feat, spatial_feat, word2vec):
        
        # if self.use_ts:
            # g_su_temp = self.graph_su_temperature.unsqueeze(1).expand(interaction.size(0), interaction.size(1))
            # self.model.caption.decoder.caption_ts = self.caption_temperature
        
        caption_output, interaction = self.model(detections, captions, batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, feat, spatial_feat, word2vec)
        
        # if self.use_ts:
            # interaction = interaction / g_su_temp

        return caption_output, interaction

    def graph_su_set_temperature(self, valid_loader):
        '''
        Finding optimal temperature scale for graph scene understanding task
        '''
        self.cuda()
        g_logits_list = []
        g_labels_list = []
        g_criterion = nn.MultiLabelSoftMarginLoss()
        
        with torch.no_grad():
            for it, (images, caps_gt, graph_data) in enumerate(iter(valid_loader)):
                # graph
                img_name = graph_data['img_name']
                img_loc = graph_data['img_loc']
                node_num = graph_data['node_num']
                roi_labels = graph_data['roi_labels']
                det_boxes = graph_data['det_boxes']
                edge_labels = graph_data['edge_labels']
                edge_num = graph_data['edge_num']
                features = graph_data['features']
                spatial_feat = graph_data['spatial_feat']
                word2vec = graph_data['word2vec']
                features, spatial_feat, word2vec, edge_labels = features.to(device), spatial_feat.to(device), word2vec.to(device), edge_labels.to(device)    
        
                g_output = self.model.graph_su(node_num, features, spatial_feat, word2vec, roi_labels, validation=True)
                    
                g_logits_list.append(g_output)
                g_labels_list.append(edge_labels)
                
            g_logits = torch.cat(g_logits_list).cuda()
            g_labels = torch.cat(g_labels_list).cuda()

        #init_temp = self.graph_su_temperature.clone()
        optimizer = optim.LBFGS([self.graph_su_temperature], lr=0.01, max_iter=50)

        def eval():
            g_su_temp = self.graph_su_temperature.unsqueeze(1).expand(g_logits.size(0), g_logits.size(1))
            g_logit_out = F.softmax(g_logits/g_su_temp, dim=1)
            g_loss = g_criterion(g_logit_out, g_labels.float())
            g_loss.backward()
            return g_loss
        
        optimizer.step(eval)
        return

    def caption_set_temperature(self, valid_loader):
        '''
        Finding optimal temperature scale for caption task
        '''
        self.cuda()
        c_logits_list = None
        c_labels_list = None
        
        with torch.no_grad():
            for it, (images, caps_gt) in enumerate(iter(valid_loader)):    
                images, caps_gt = images.to(device), caps_gt.to(device)
                caption_out = self.model.caption(images, caps_gt)

                if c_logits_list is not None:
                    c_logits_list = torch.cat([c_logits_list, caption_out], 1)
                    c_labels_list = torch.cat([c_labels_list, caps_gt],1)
                else:
                    c_logits_list = caption_out
                    c_labels_list = caps_gt

            c_logits = c_logits_list.cuda()
            c_labels = c_labels_list.cuda()
        
        init_temp = self.caption_temperature.clone()
        optimizer = optim.LBFGS([self.caption_temperature], lr=0.01, max_iter=50)

        def eval():
            caption_temp = self.caption_temperature.unsqueeze(1).expand(c_logits.size(1), c_logits.size(0))
            c_criterion = CELossWithLS(classes=len(text_field.vocab), smoothing=0.1, gamma=0.0, isCos=False, ignore_index=text_field.vocab.stoi['<pad>'])
            c_base = c_logits/caption_temp
            c_loss = c_criterion(c_base[:, :-1].contiguous(), c_labels[:, 1:].contiguous())
            c_loss.backward()
            return c_loss
        
        optimizer.step(eval)
        return

# Temperature Scaling : ECE Loss

In [9]:
class _ECELoss(nn.Module):
    '''
    Expected Calibration Error for Calibration
    '''
    def __init__(self, n_bins=15):
        """
        n_bins (int): number of confidence interval bins
        """
        super(_ECELoss, self).__init__()
        bin_boundaries = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = bin_boundaries[:-1]
        self.bin_uppers = bin_boundaries[1:]

    def forward(self, logits, labels):
        if labels.size(1) == 13:
            gt_labels = torch.argmax(labels, dim=1)
        else: 
            logits = logits.squeeze()
            gt_labels = labels.squeeze()

        softmaxes = F.softmax(logits, dim=1)
        confidences, predictions = torch.max(softmaxes, 1)
        accuracies = predictions.eq(gt_labels)
        ece = torch.zeros(1, device=logits.device)
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            prop_in_bin = in_bin.float().mean()
            if prop_in_bin.item() > 0:
                accuracy_in_bin = accuracies[in_bin].float().mean()
                avg_confidence_in_bin = confidences[in_bin].mean()
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
        return ece

# Temperature Scaling : Evaluation

In [10]:
def ts_evaluation(model_ts, g_c_val_dataloader, c_val_dataloader, g_temp=1.5, c_temp=1):
    '''
    '''
    model.eval()

    g_acc = 0.0
    c_acc = 0.0
    g_loss = 0.0
    c_loss = 0.0
    g_logits = []
    g_labels = []
    c_logits = []
    c_labels = []
    
    g_logits_list = []
    g_labels_list = []
    c_logits_list = None
    c_labels_list = None

    with torch.no_grad():
        for it, (images, caps_gt, graph_data) in enumerate(iter(g_c_val_dataloader)):
            # graph
            img_name = graph_data['img_name']
            img_loc = graph_data['img_loc']
            node_num = graph_data['node_num']
            roi_labels = graph_data['roi_labels']
            det_boxes = graph_data['det_boxes']
            edge_labels = graph_data['edge_labels']
            edge_num = graph_data['edge_num']
            features = graph_data['features']
            spatial_feat = graph_data['spatial_feat']
            word2vec = graph_data['word2vec']
            features, spatial_feat, word2vec, edge_labels = features.to(device), spatial_feat.to(device), word2vec.to(device), edge_labels.to(device)    
        
            g_output = model_ts.model.graph_su(node_num, features, spatial_feat, word2vec, roi_labels, validation=True)
            
            g_output = g_output / g_temp
            g_logits_list.append(g_output)
            g_labels_list.append(edge_labels)
                
    g_logits = torch.cat(g_logits_list).cuda()
    g_labels = torch.cat(g_labels_list).cuda()

    g_logits = F.softmax(g_logits, dim=1)
    g_criterion = nn.MultiLabelSoftMarginLoss()
    g_loss = g_criterion(g_logits, g_labels.float())
    g_acc = np.sum(np.equal(np.argmax(g_logits.cpu().data.numpy(), axis=-1), np.argmax(g_labels.cpu().data.numpy(), axis=-1))) / g_labels.size(0)

    with torch.no_grad():
        for it, (images, caps_gt) in enumerate(iter(c_val_dataloader)):    
            images, caps_gt = images.to(device), caps_gt.to(device)
            caption_out = model_ts.model.caption(images, caps_gt)
            caption_out = caption_out/c_temp
            if c_logits_list is not None:
                c_logits_list = torch.cat([c_logits_list, caption_out], 1)
                c_labels_list = torch.cat([c_labels_list, caps_gt],1)
            else:
                c_logits_list = caption_out
                c_labels_list = caps_gt

    c_logits = c_logits_list.cuda()
    c_labels = c_labels_list.cuda()
    
    c_criterion = CELossWithLS(classes=len(text_field.vocab), smoothing=0.1, gamma=0.0, isCos=False, ignore_index=text_field.vocab.stoi['<pad>'])
    c_loss = c_criterion(c_logits[:, :-1].contiguous(), c_labels[:, 1:].contiguous())
    c_acc = np.sum(np.equal(np.argmax(c_logits.cpu().data.numpy(), axis=-1), c_labels.cpu().data.numpy())) / c_labels.size(1)
    
    return (g_loss.item(), c_loss.item(), g_acc, c_acc, g_logits, g_labels, c_logits, c_labels)

# 2.0: Temperature Scaling

In [11]:
model_ts = ModelWithTemperature(model)
ece_criterion = _ECELoss().to(device)
print('Initial Graph SU Temperature:%.4f'%model_ts.graph_su_temperature.item())
print('Initial Caption Temperature:%0.4f'%model_ts.caption_temperature.item())

Initial Graph SU Temperature:1.5000
Initial Caption Temperature:1.5000


# 2.1: Temperature Scaling : Find Optimal value

In [12]:
model_ts.graph_su_set_temperature(dict_dataloader_val)
model_ts.caption_set_temperature(dataloader_val)
print('-----------------------------------------------------------------------')
print('Optimal Graph SU Temperature:%.4f'%model_ts.graph_su_temperature.item())
print('Optimal Caption Temperature:%0.4f'%model_ts.caption_temperature.item())

-----------------------------------------------------------------------
Optimal Graph SU Temperature:1.3954
Optimal Caption Temperature:4.8063


# 2.2 Temperature Scaling : Without TS

In [13]:
g_loss, c_loss, g_acc, c_acc, g_logits, g_labels, c_logits, c_labels = ts_evaluation(model_ts, dict_dataloader_val, 
                                    dataloader_val)
g_temperature_ece = ece_criterion(g_logits, g_labels).item()
c_temperature_ece = ece_criterion(c_logits, c_labels).item()
print('-----------------------------------------------------------------------')
print('Before TS: G_loss:%.3f, G_acc:%.3f, G_ece:%.5f'%(g_loss, g_acc, g_temperature_ece))
print('Before TS: C_loss:%.3f, C_acc:%.3f, C_ece:%.5f'%(c_loss, c_acc, c_temperature_ece))

scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce = evaluate_metrics(model_ts.model, device, dict_dataloader_val, text_field)
print('Before TS: Graph SU: acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce))
print("Before TS: Caption scores :", scores)

evaluation:   0%|          | 0/447 [00:00<?, ?it/s]

-----------------------------------------------------------------------
Before TS: G_loss:0.710, G_acc:0.571, G_ece:0.46277
Before TS: C_loss:2.724, C_acc:0.022, C_ece:0.84285


evaluation: 100%|██████████| 447/447 [01:56<00:00,  3.85it/s]


acc: 0.571059 map: 0.321686 loss: 0.509024, ece:0.219497, sce:0.047773, tace:0.049699, brier:0.649657, uce:0.283210
Before TS: Graph SU: acc: 0.571059 map: 0.321686 loss: 0.509024, ece:0.219497, sce:0.047773, tace:0.049699, brier:0.649657, uce:0.283210
Before TS: Caption scores : {'BLEU': array([0.5498, 0.4714, 0.4238, 0.3801]), 'METEOR': 0.2861, 'ROUGE': 0.57, 'CIDEr': 2.7487}


# 2.3 Temperature Scaling : With TS

In [14]:
g_loss, c_loss, g_acc, c_acc, g_logits, g_labels, c_logits, c_labels = ts_evaluation(model_ts, dict_dataloader_val, 
                                    dataloader_val, g_temp=model_ts.graph_su_temperature.item(), c_temp=model_ts.caption_temperature.item())
g_temperature_ece = ece_criterion(g_logits, g_labels).item()
c_temperature_ece = ece_criterion(c_logits, c_labels).item()
print('-----------------------------------------------------------------------')
print('Optimal TS: G_loss:%.3f, G_acc:%.3f, G_ece:%.5f'%(g_loss, g_acc, g_temperature_ece))
print('Optimal TS: C_loss:%.3f, C_acc:%.3f, C_ece:%.5f'%(c_loss, c_acc, c_temperature_ece))

scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce = evaluate_metrics(model_ts.model, device, dict_dataloader_val, text_field, g_temp = model_ts.graph_su_temperature.item(), c_temp = model_ts.caption_temperature.item())
print('Optimal TS: Graph SU: acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce))
print("Optimal TS: Caption scores :", scores)

evaluation:   0%|          | 0/447 [00:00<?, ?it/s]

-----------------------------------------------------------------------
Optimal TS: G_loss:0.709, G_acc:0.571, G_ece:0.46025
Optimal TS: C_loss:1.894, C_acc:0.022, C_ece:0.29004


evaluation: 100%|██████████| 447/447 [01:57<00:00,  3.80it/s]


acc: 0.571059 map: 0.315094 loss: 0.487805, ece:0.218120, sce:0.045566, tace:0.046718, brier:0.644169, uce:0.260813
Optimal TS: Graph SU: acc: 0.571059 map: 0.315094 loss: 0.487805, ece:0.218120, sce:0.045566, tace:0.046718, brier:0.644169, uce:0.260813
Optimal TS: Caption scores : {'BLEU': array([0.3623, 0.3231, 0.3032, 0.2864]), 'METEOR': 0.2151, 'ROUGE': 0.4628, 'CIDEr': 2.3378}


# 2.4 Temperature Scaling : With CDA-TS

In [15]:
g_cls_freq = torch.zeros(13)
for i in torch.argmax(g_labels, dim=1): g_cls_freq[i] += 1
g_cls_freq_norm = g_cls_freq/torch.max(g_cls_freq)
g_temp = model_ts.graph_su_temperature.item() + g_cls_freq_norm*0.1
#g_temp = g_temp.to(device)
model_ts.graph_su_temperature = nn.Parameter(g_temp.to(device))

c_cls_freq = torch.zeros(41)
for i in c_labels.squeeze(): c_cls_freq[i] += 1
c_cls_freq_norm = c_cls_freq/torch.max(c_cls_freq)
c_temp = model_ts.caption_temperature.item() + c_cls_freq_norm*0.1
#c_temp = c_temp.to(device)
model_ts.caption_temperature = nn.Parameter(c_temp.to(device))

g_loss, c_loss, g_acc, c_acc, g_logits, g_labels, c_logits, c_labels = ts_evaluation(model_ts, dict_dataloader_val, 
                                    dataloader_val, g_temp = model_ts.graph_su_temperature, c_temp = model_ts.caption_temperature)
g_temperature_ece = ece_criterion(g_logits, g_labels).item()
c_temperature_ece = ece_criterion(c_logits, c_labels).item()
print('-----------------------------------------------------------------------')
print('CDA-TS: G_loss:%.3f, G_acc:%.3f, G_ece:%.5f'%(g_loss, g_acc, g_temperature_ece))
print('CDA-TS: C_loss:%.3f, C_acc:%.3f, C_ece:%.5f'%(c_loss, c_acc, c_temperature_ece))

scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce = evaluate_metrics(model_ts.model, device, dict_dataloader_val, text_field, g_temp = model_ts.graph_su_temperature, c_temp = model_ts.caption_temperature)
print('CDA-TS: Graph SU: acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce))
print("CDA-TS: Caption scores :", scores)

evaluation:   0%|          | 0/447 [00:00<?, ?it/s]

-----------------------------------------------------------------------
CDA-TS: G_loss:0.709, G_acc:0.574, G_ece:0.46348
CDA-TS: C_loss:3.262, C_acc:0.022, C_ece:0.02099


evaluation: 100%|██████████| 447/447 [01:56<00:00,  3.85it/s]


acc: 0.573643 map: 0.319211 loss: 0.489725, ece:0.217880, sce:0.045660, tace:0.046676, brier:0.643576, uce:0.267471
CDA-TS: Graph SU: acc: 0.573643 map: 0.319211 loss: 0.489725, ece:0.217880, sce:0.045660, tace:0.046676, brier:0.643576, uce:0.267471
CDA-TS: Caption scores : {'BLEU': array([0.3623, 0.3231, 0.3032, 0.2864]), 'METEOR': 0.2151, 'ROUGE': 0.4628, 'CIDEr': 2.3378}


# Confidence Aware distribution

In [16]:
def conf_dist(model_ts, g_c_val_dataloader, c_val_dataloader):

    g_conf_list = np.zeros(13)
    c_conf_list = np.zeros(41)

    with torch.no_grad():
        for it, (images, caps_gt, graph_data) in enumerate(iter(g_c_val_dataloader)):
            # graph
            img_name = graph_data['img_name']
            img_loc = graph_data['img_loc']
            node_num = graph_data['node_num']
            roi_labels = graph_data['roi_labels']
            det_boxes = graph_data['det_boxes']
            edge_labels = graph_data['edge_labels']
            edge_num = graph_data['edge_num']
            features = graph_data['features']
            spatial_feat = graph_data['spatial_feat']
            word2vec = graph_data['word2vec']
            features, spatial_feat, word2vec, edge_labels = features.to(device), spatial_feat.to(device), word2vec.to(device), edge_labels.to(device)    
        
            g_output = model_ts.model.graph_su(node_num, features, spatial_feat, word2vec, roi_labels, validation=True)
            
            g_output = F.softmax(g_output,1)
            g_confidences, g_predictions = torch.max(g_output, 1)
            g_accuracies = g_predictions.eq(torch.argmax(edge_labels, dim=1))
            for i in torch.argmax(edge_labels, dim=1).unique():
                g_conf_list[i] += g_confidences[torch.argmax(edge_labels, dim=1)==i].sum()

    
    with torch.no_grad():
        for it, (images, caps_gt) in enumerate(iter(c_val_dataloader)):    
            images, caps_gt = images.to(device), caps_gt.to(device)
            c_output = model_ts.model.caption(images, caps_gt)
            c_output = c_output.squeeze()
            c_output = F.softmax(c_output,1)
            c_confidences, c_predictions = torch.max(c_output, 1)
            c_accuracies = c_predictions.eq(caps_gt.squeeze())
            for i in caps_gt.squeeze().unique():
                c_conf_list[i] += c_confidences[caps_gt.squeeze()==i].sum()

    return g_conf_list, c_conf_list

# 2.5 Temperature Scaling : With CCA-TS

In [17]:
g_conf_list, c_conf_list = conf_dist(model_ts, dict_dataloader_val, dataloader_val)

plt.figure(1)
plt.title('Graph SU class distribution')
plt.bar(np.arange(len(g_cls_freq)),g_cls_freq)
plt.savefig('graph_class_dist.png')

plt.figure(2)
plt.title('Graph SU confidence score distribution')
plt.bar(np.arange(len(g_cls_freq)),g_conf_list/12)
plt.savefig('graph_conf_dist.png')
    
plt.figure(3)
plt.title('Caption Class Distribution')
plt.bar(np.arange(len(c_cls_freq)),c_cls_freq)
plt.savefig('caption_cls_dist.png')
    
plt.figure(4)
plt.title('Caption Confidence distribution')
plt.bar(np.arange(len(c_cls_freq)),c_conf_list/41)
plt.savefig('caption_conf_dist.png')

# CCA-TS calculation V2

In [18]:
g_cls_freq = g_conf_list/13
g_cls_freq_norm = g_cls_freq/np.max(g_cls_freq)
g_cls_freq_norm = torch.tensor(g_cls_freq_norm).float()
g_temp = model_ts.graph_su_temperature.cpu() + g_cls_freq_norm*0.1
# g_temp = g_temp.to(device)
model_ts.graph_su_temperature = nn.Parameter(g_temp.to(device))
    

c_cls_freq = c_conf_list/41#train_dataset.get_cls_num_list()
c_cls_freq_norm = c_cls_freq/np.max(c_cls_freq)
c_cls_freq_norm = torch.tensor(c_cls_freq_norm).float()
c_temp =  model_ts.caption_temperature.cpu() + c_cls_freq_norm*0.1
# c_temp = c_temp.to(device)
model_ts.caption_temperature = nn.Parameter(c_temp.to(device))
    
g_loss, c_loss, g_acc, c_acc, g_logits, g_labels, c_logits, c_labels = ts_evaluation(model_ts, dict_dataloader_val, 
                                    dataloader_val, g_temp = model_ts.graph_su_temperature, c_temp = model_ts.caption_temperature)
g_temperature_ece = ece_criterion(g_logits, g_labels).item()
c_temperature_ece = ece_criterion(c_logits, c_labels).item()
print('-----------------------------------------------------------------------')
print('CCA-TS-V1: G_loss:%.3f, G_acc:%.3f, G_ece:%.5f'%(g_loss, g_acc, g_temperature_ece))
print('CCA-TS-V1: C_loss:%.3f, C_acc:%.3f, C_ece:%.5f'%(c_loss, c_acc, c_temperature_ece))

scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce = evaluate_metrics(model_ts.model, device, dict_dataloader_val, text_field, g_temp = model_ts.graph_su_temperature, c_temp = model_ts.caption_temperature)
print('CCA-TS-V1: Graph SU: acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce))
print("CCA-TS-V1: Caption scores :", scores)

evaluation:   0%|          | 0/447 [00:00<?, ?it/s]

-----------------------------------------------------------------------
CCA-TS-V1: G_loss:0.709, G_acc:0.575, G_ece:0.46582
CCA-TS-V1: C_loss:3.272, C_acc:0.022, C_ece:0.02043


evaluation: 100%|██████████| 447/447 [01:56<00:00,  3.85it/s]


acc: 0.575366 map: 0.320200 loss: 0.491720, ece:0.216305, sce:0.045280, tace:0.046805, brier:0.643167, uce:0.266481
CCA-TS-V1: Graph SU: acc: 0.575366 map: 0.320200 loss: 0.491720, ece:0.216305, sce:0.045280, tace:0.046805, brier:0.643167, uce:0.266481
CCA-TS-V1: Caption scores : {'BLEU': array([0.3623, 0.3231, 0.3032, 0.2864]), 'METEOR': 0.2151, 'ROUGE': 0.4628, 'CIDEr': 2.3378}


# CCA-TS calculation V2

In [19]:
g_cls_freq = g_conf_list/13
g_cls_freq_norm = g_cls_freq/np.max(g_cls_freq)
g_cls_freq_norm = torch.tensor(g_cls_freq_norm).float()
g_temp = model_ts.graph_su_temperature.cpu() + (g_cls_freq_norm.exp()-1.0)*.1
# g_temp = g_temp.to(device)
model_ts.graph_su_temperature.i = nn.Parameter(g_temp.to(device))
    
c_cls_freq = c_conf_list/41#train_dataset.get_cls_num_list()
c_cls_freq_norm = c_cls_freq/np.max(c_cls_freq)
c_cls_freq_norm = torch.tensor(c_cls_freq_norm).float()
c_temp =  model_ts.caption_temperature.cpu() + (c_cls_freq_norm.exp()-1.0)*.1
# c_temp = c_temp.to(device)
model_ts.caption_temperature = nn.Parameter(c_temp.to(device))
    
g_loss, c_loss, g_acc, c_acc, g_logits, g_labels, c_logits, c_labels = ts_evaluation(model_ts, dict_dataloader_val, 
                                    dataloader_val, g_temp = model_ts.graph_su_temperature, c_temp = model_ts.caption_temperature)
g_temperature_ece = ece_criterion(g_logits, g_labels).item()
c_temperature_ece = ece_criterion(c_logits, c_labels).item()
print('-----------------------------------------------------------------------')
print('CCA-TS-V2: G_loss:%.3f, G_acc:%.3f, G_ece:%.5f'%(g_loss, g_acc, g_temperature_ece))
print('CCA-TS-V2: C_loss:%.3f, C_acc:%.3f, C_ece:%.5f'%(c_loss, c_acc, c_temperature_ece))

scores, g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce = evaluate_metrics(model_ts.model, device, dict_dataloader_val, text_field, g_temp = model_ts.graph_su_temperature, c_temp = model_ts.caption_temperature)
print('CCA-TS-V2: Graph SU: acc: %0.6f map: %0.6f loss: %0.6f, ece:%0.6f, sce:%0.6f, tace:%0.6f, brier:%.6f, uce:%.6f' %(g_total_acc, g_map_value, g_total_loss, g_ece, g_sce, g_tace, g_brier, g_uce))
print("CCA-TS-V2: Caption scores :", scores)

evaluation:   0%|          | 0/447 [00:00<?, ?it/s]

-----------------------------------------------------------------------
CCA-TS-V2: G_loss:0.709, G_acc:0.575, G_ece:0.46582
CCA-TS-V2: C_loss:3.283, C_acc:0.022, C_ece:0.01962


evaluation: 100%|██████████| 447/447 [01:55<00:00,  3.87it/s]


acc: 0.575366 map: 0.320200 loss: 0.491720, ece:0.216305, sce:0.045280, tace:0.046805, brier:0.643167, uce:0.266481
CCA-TS-V2: Graph SU: acc: 0.575366 map: 0.320200 loss: 0.491720, ece:0.216305, sce:0.045280, tace:0.046805, brier:0.643167, uce:0.266481
CCA-TS-V2: Caption scores : {'BLEU': array([0.3623, 0.3231, 0.3032, 0.2864]), 'METEOR': 0.2151, 'ROUGE': 0.4628, 'CIDEr': 2.3378}
