In [1]:
'''
arguments
    Hyperparameters, file location, optimizer, network, data_processing
'''
ver = 'd2g_ecbs_t_resnet18_09'  # experiment tag; appended to the log / checkpoint / result directories
f_e = 'resnet18_09'             # feature-extractor name (used as arguments.feature_extractor)

class arguments():
    '''
    Experiment configuration exposed as plain attributes (an argparse-like namespace).

    Groups: hyper parameters, file locations, optimizer, network, data processing,
    CBS (edge-feature smoothing) and temperature scaling.
    '''
    def __init__(self):
        settings = {}

        # hyper parameters
        settings.update(
            lr=0.00001,
            epoch=251,
            ft_epoch=81,
            start_epoch=0,
            batch_size=32,
            gpu=True,
            print_every=10,
            train_model='epoch',
            exp_ver=ver,
        )

        # file locations (every output directory is keyed by the experiment tag)
        settings.update(
            log_dir='./log/' + ver,
            save_dir='./checkpoints/' + ver,
            output_img_dir='./results/' + ver,
            save_every=10,
            pretrained=None,
        )

        # optimizer: 'sgd' or 'adam'
        settings.update(optim='adam')

        # network
        settings.update(
            layers=1,
            bn=False,
            drop_prob=0.3,
            bias=True,
            multi_attn=False,
            diff_edge=False,
        )

        # data_processing
        settings.update(sampler=0, data_aug=False, feature_extractor=f_e)

        # CBS
        settings.update(use_cbs=True)

        # temperature_scaling
        settings.update(use_t=True, t_scale=1.5)

        # Attributes end up in the same order they are declared above.
        self.__dict__.update(settings)
        

In [2]:
'''
configurations of the network (feature_size = 512, 13 action classes)
    
    readout: G_ER_L_S = [512+300+16+300+512, 512, 13]

    node_func: G_N_L_S = [512+512, 512]
    node_lang_func: G_N_L_S2 = [300+300, 300]
    
    edge_func : G_E_L_S = [512*2+16, 512]
    edge_lang_func: G_E_L_S2 = [300*2, 512]
    
    attn: G_A_L_S = [512, 1]
    attn_lang: G_A_L_S2 = [512, 1]
'''

class CONFIGURATION(object):
    '''
    Per-component MLP configuration for the graph network (only a single layer is supported).

    Every component is described by five attributes consumed by ``MLP``:
        *_L_S : layer sizes          *_A  : activation name per linear layer
        *_B   : use bias             *_BN : use batch normalization
        *_D   : dropout probability

    readout           (G_ER_*)  : edge classifier over [dst visual, dst lang, spatial, src lang, src visual]
    gnn_node          (G_N_*)   : visual node update
    gnn_node_for_lang (G_N_*2)  : language node update
    gnn_edge          (G_E_*)   : visual edge function; G_E_c_* configure the CBS smoothing kernel
    gnn_edge_for_lang (G_E_*2)  : language edge function
    gnn_attn          (G_A_*)   : visual attention
    gnn_attn_for_lang (G_A_*2)  : language attention
    '''
    def __init__(self, layer=1, bias=True, bn=False, dropout=0.2, multi_attn=False):
        '''
        Args:
            layer: number of graph layers; only 1 is implemented.
            bias, bn, dropout: shared by every MLP configured below.
            multi_attn: accepted for interface compatibility; currently unused.
        Raises:
            NotImplementedError: if ``layer`` is not 1 (this used to silently produce an
                object without any attributes, failing much later with an AttributeError).
        '''
        if layer != 1:
            raise NotImplementedError('CONFIGURATION only supports layer=1, got layer={}'.format(layer))

        feature_size = 512   # visual node feature width
        lang_size = 300      # word2vec embedding width
        spatial_size = 16    # spatial edge feature width
        num_classes = 13     # readout outputs, one per action class

        # readout
        self.G_ER_L_S = [feature_size+lang_size+spatial_size+lang_size+feature_size, feature_size, num_classes]
        self.G_ER_A   = ['ReLU', 'Identity']
        self.G_ER_B   = bias
        self.G_ER_BN  = bn
        self.G_ER_D   = dropout

        # gnn node function
        self.G_N_L_S = [feature_size+feature_size, feature_size]
        self.G_N_A   = ['ReLU']
        self.G_N_B   = bias
        self.G_N_BN  = bn
        self.G_N_D   = dropout

        # gnn node function for language
        self.G_N_L_S2 = [lang_size+lang_size, lang_size]
        self.G_N_A2   = ['ReLU']
        self.G_N_B2   = bias
        self.G_N_BN2  = bn
        self.G_N_D2   = dropout

        # gnn edge function
        self.G_E_L_S           = [feature_size*2+spatial_size, feature_size]
        self.G_E_A             = ['ReLU']
        self.G_E_B             = bias
        self.G_E_BN            = bn
        self.G_E_D             = dropout
        self.G_E_c_std         = 1.0    # initial sigma of the smoothing kernel
        self.G_E_c_std_factor  = 0.9    # sigma decay factor (other values noted: 0.985 (LOG), 0.95 (gau))
        self.G_E_c_epoch       = 20     # decay sigma every this many epochs
        self.G_E_c_kernel_size = 3
        self.G_E_c_filter      = 'LOG'  # 'gau' or 'LOG'

        # gnn edge function for language
        self.G_E_L_S2 = [lang_size*2, feature_size]
        self.G_E_A2   = ['ReLU']
        self.G_E_B2   = bias
        self.G_E_BN2  = bn
        self.G_E_D2   = dropout

        # gnn attention mechanism
        self.G_A_L_S = [feature_size, 1]
        self.G_A_A   = ['LeakyReLU']
        self.G_A_B   = bias
        self.G_A_BN  = bn
        self.G_A_D   = dropout

        # gnn attention mechanism for language
        self.G_A_L_S2 = [feature_size, 1]
        self.G_A_A2   = ['LeakyReLU']
        self.G_A_B2   = bias
        self.G_A_BN2  = bn
        self.G_A_D2   = dropout

    def save_config(self):
        '''
        Group the configuration attributes by graph component.

        Returns:
            dict with sub-dicts 'graph_head' (readout: G_ER_* and legacy G_H_*),
            'graph_node' (G_N_*), 'graph_edge' (G_E_*) and 'graph_attn' (G_A_*);
            attributes without a recognised ``G_<component>_`` prefix are stored at
            the top level.
        '''
        groups = {'H': 'graph_head', 'ER': 'graph_head', 'N': 'graph_node',
                  'E': 'graph_edge', 'A': 'graph_attn'}
        model_config = {'graph_head': {}, 'graph_node': {}, 'graph_edge': {}, 'graph_attn': {}}
        for k, v in self.__dict__.items():
            parts = k.split('_')
            # Match the exact component token: a substring test for 'G_E' would also catch
            # the readout keys ('G_ER_*') and misfile them under graph_edge.
            group = groups.get(parts[1]) if len(parts) > 1 and parts[0] == 'G' else None
            if group is None:
                model_config[k] = v
            else:
                model_config[group][k] = v

        return model_config

In [3]:
import math
import torch
import torch.nn as nn


def get_gaussian_filter_1D(kernel_size=3, sigma=2, channels=3):
    '''
    Build a fixed (non-trainable) depthwise 1-D Gaussian smoothing filter.

    Args:
        kernel_size: number of taps. Odd sizes keep the input length unchanged
            (padding = kernel_size // 2); even sizes use no padding.
        sigma: standard deviation of the Gaussian, measured in taps.
        channels: number of input/output channels; every channel is filtered
            independently (groups=channels) with the same kernel.

    Returns:
        nn.Conv1d without bias whose weight, of shape (channels, 1, kernel_size),
        holds the normalised Gaussian (taps sum to 1) with requires_grad=False.
    '''
    # Distance of every tap from the kernel centre, e.g. [-1., 0., 1.] for kernel_size=3.
    offsets = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2.
    gaussian_kernel = torch.exp(-offsets ** 2 / (2. * sigma ** 2))

    # Normalise so the taps sum to 1; the usual 1/(sqrt(2*pi)*sigma) prefactor cancels here.
    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size).repeat(channels, 1, 1)

    # Keep the output length equal to the input length for every odd kernel size
    # (not only 3 and 5); even kernels cannot be padded symmetrically and keep padding=0.
    padding = kernel_size // 2 if kernel_size % 2 == 1 else 0
    gaussian_filter = nn.Conv1d(in_channels=channels, out_channels=channels,
                                kernel_size=kernel_size, groups=channels,
                                bias=False, padding=padding)
    gaussian_filter.weight.data = gaussian_kernel
    gaussian_filter.weight.requires_grad = False
    return gaussian_filter

def get_laplaceOfGaussian_filter_1D(kernel_size=3, sigma=2, channels=3):
    '''
    Build a fixed (non-trainable) depthwise 1-D Laplacian-of-Gaussian (LoG) filter.

    Args:
        kernel_size: number of taps. Odd sizes keep the input length unchanged
            (padding = kernel_size // 2); even sizes use no padding.
        sigma: standard deviation of the underlying Gaussian, measured in taps.
        channels: number of input/output channels; every channel is filtered
            independently (groups=channels) with the same kernel.

    Returns:
        nn.Conv1d without bias whose weight, of shape (channels, 1, kernel_size),
        holds the LoG taps divided by their sum (so they sum to 1), with
        requires_grad=False.
    '''
    mean = (kernel_size - 1) / 2.
    offsets = torch.arange(kernel_size, dtype=torch.float32) - mean

    # Squared distance of every tap from the centre. The constant second term is the squared
    # offset of row kernel_size // 2 from the centre of the square grid the kernel was
    # originally sampled from: 0 for odd sizes, 0.25 for even sizes (kept for compatibility).
    sq_dist = offsets ** 2 + (kernel_size // 2 - mean) ** 2
    two_var = 2. * sigma ** 2
    log_kernel = (-1. / (math.pi * sigma ** 4)) * (1 - sq_dist / two_var) * torch.exp(-sq_dist / two_var)

    # Normalise so the taps sum to 1 (this divides by the kernel's sum, which must be non-zero).
    log_kernel = log_kernel / torch.sum(log_kernel)
    log_kernel = log_kernel.view(1, 1, kernel_size).repeat(channels, 1, 1)

    # Keep the output length equal to the input length for every odd kernel size
    # (not only 3 and 5); even kernels cannot be padded symmetrically and keep padding=0.
    padding = kernel_size // 2 if kernel_size % 2 == 1 else 0
    log_filter = nn.Conv1d(in_channels=channels, out_channels=channels,
                           kernel_size=kernel_size, groups=channels,
                           bias=False, padding=padding)
    log_filter.weight.data = log_kernel
    log_filter.weight.requires_grad = False
    return log_filter

In [4]:
'''
Primary activation and MLP layer
activation:
    Identity
    ReLU
    LeakyReLU
MLP:
    init: layer size, activation, bias, use_BN, dropout_probability
    forward: x
'''

import torch.nn as nn
from collections import OrderedDict

class Identity(nn.Module):
    '''
    Pass-through activation: ``forward`` hands back its input untouched (x -> x).
    Has no parameters.
    '''
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x

def get_activation(name):
    '''
    Build a fresh activation module from its name.

    Args:
        name: 'ReLU', 'Identity', 'LeakyReLU' (negative slope 0.2), 'Tanh' or 'Sigmoid'.
    Returns:
        a new nn.Module; ReLU and LeakyReLU are created with inplace=True.
    Raises:
        NotImplementedError: for any other name. (This used to be an ``assert``, which is
            stripped under ``python -O`` and made the function silently return None.)
    '''
    if name == 'ReLU':
        return nn.ReLU(inplace=True)
    if name == 'Identity':
        return Identity()
    if name == 'LeakyReLU':
        return nn.LeakyReLU(0.2, inplace=True)
    if name == 'Tanh':
        return nn.Tanh()
    if name == 'Sigmoid':
        return nn.Sigmoid()
    raise NotImplementedError('Activation {!r} is not implemented'.format(name))

class MLP(nn.Module):
    '''
    Stack of [Linear -> (BatchNorm1d) -> activation -> (Dropout)] blocks.

    Args:
        layer_sizes: a list of widths, e.g. [1024, 1024, ...]; len(layer_sizes)-1 blocks are built
        activation: a list of activation names, one per block (see get_activation)
        bias : bool, use a bias in every Linear layer
        use_bn: bool, insert BatchNorm1d between each Linear layer and its activation
        drop_prob: dropout probability appended to every block; None or 0 disables dropout
    '''
    def __init__(self, layer_sizes, activation, bias=True, use_bn=False, drop_prob=None):
        super(MLP, self).__init__()
        self.bn = use_bn
        self.layers = nn.ModuleList()
        for i in range(len(layer_sizes)-1):
            layer = nn.Linear(layer_sizes[i], layer_sizes[i+1], bias=bias)
            activate = get_activation(activation[i])
            block = nn.Sequential(OrderedDict([(f'L{i}', layer), ]))

            # !NOTE: batch normalization is arguably inappropriate here; when enabled it is
            # placed before the activation function.
            if use_bn:
                block.add_module(f'B{i}', nn.BatchNorm1d(layer_sizes[i+1]))

            block.add_module(f'A{i}', activate)

            if drop_prob:
                block.add_module(f'D{i}', nn.Dropout(drop_prob))

            self.layers.append(block)

    def forward(self, x):
        for block in self.layers:
            if self.bn and self.training and x.shape[0] == 1:
                # BatchNorm1d cannot compute batch statistics from a single sample in
                # training mode, so run every sub-module of the block except the BN.
                # (In eval mode BN uses its running statistics and works for one sample.)
                for module in block:
                    if not isinstance(module, nn.BatchNorm1d):
                        x = module(x)
            else:
                x = block(x)
        return x


In [5]:
'''
H_H_EdgeApplyModule
    init    : config, multi_attn, use_cbs
    forward : edge
    
H_NodeApplyModule
    init    : config
    forward : node
    
E_AttentionModule1
    init    : config
    forward : edge
    
GNN
    init    : config, multi_attn, diff_edge, use_cbs
    forward : g, h_node, o_node, h_h_e_list, o_o_e_list, h_o_e_list, pop_feat
    
GRNN
    init    : config, multi_attn, diff_edge, use_cbs
    forward : b_graph, b_h_node_list, b_o_node_list, b_h_h_e_list, b_o_o_e_list, b_h_o_e_list, features, spatial_features, word2vec, valid, pop_feat, initial_feat
'''

import ipdb

import torch
import torch.nn as nn
import torch.nn.functional as F

class H_H_EdgeApplyModule(nn.Module): #human to human edge
    '''
    Edge function (GNN applies it to every edge): builds visual and language edge features.

        init    : config, multi_attn, use_cbs
        forward : edge

    Visual input   = [src n_f, edge s_f, dst n_f]  -> edge_fc      -> 'e_f'
    Language input = [src word2vec, dst word2vec]  -> edge_fc_lang -> 'e_f_lang'
    With ``use_cbs`` the visual input is first smoothed along its feature axis by the
    fixed 1-D kernel built in ``get_new_kernels`` (which must be called before forward).
    '''
    def __init__(self, CONFIG, multi_attn=False, use_cbs = False):
        super(H_H_EdgeApplyModule, self).__init__()
        self.use_cbs = use_cbs
        if use_cbs:
            self.init_std = CONFIG.G_E_c_std
            self.cbs_std = CONFIG.G_E_c_std
            self.cbs_std_factor = CONFIG.G_E_c_std_factor
            self.cbs_epoch = CONFIG.G_E_c_epoch
            self.cbs_kernel_size = CONFIG.G_E_c_kernel_size
            self.cbs_filter = CONFIG.G_E_c_filter

        self.edge_fc = MLP(CONFIG.G_E_L_S, CONFIG.G_E_A, CONFIG.G_E_B, CONFIG.G_E_BN, CONFIG.G_E_D)
        self.edge_fc_lang = MLP(CONFIG.G_E_L_S2, CONFIG.G_E_A2, CONFIG.G_E_B2, CONFIG.G_E_BN2, CONFIG.G_E_D2)

    def forward(self, edge):
        feat = torch.cat([edge.src['n_f'], edge.data['s_f'], edge.dst['n_f']], dim=1)
        feat_lang = torch.cat([edge.src['word2vec'], edge.dst['word2vec']], dim=1)
        if self.use_cbs:
            if not hasattr(self, 'kernel1'):
                raise RuntimeError('use_cbs=True: call get_new_kernels(epoch) before the first forward pass')
            # get_new_kernels() builds a fresh Conv1d on the default device; move it next to
            # the features (a no-op when it is already there).
            self.kernel1.to(feat.device)
            # (E, D) -> (E, 1, D) so the conv smooths along the feature axis, then back to (E, D).
            feat = torch.squeeze(self.kernel1(feat[:, None, :]), 1)
        e_feat = self.edge_fc(feat)
        e_feat_lang = self.edge_fc_lang(feat_lang)

        return {'e_f': e_feat, 'e_f_lang': e_feat_lang}

    def get_new_kernels(self, epoch_count):
        '''
        Rebuild the smoothing kernel for ``epoch_count`` (no-op when use_cbs is False).

        Epoch 0 resets sigma to its initial value; every later epoch that is a multiple of
        ``cbs_epoch`` multiplies sigma by ``cbs_std_factor`` before the kernel is rebuilt.

        Raises:
            ValueError: if the configured filter is neither 'gau' nor 'LOG'.
        '''
        if not self.use_cbs:
            return

        # `==` / elif instead of the former `epoch_count is not 0`: an identity test against an
        # int literal depends on interpreter caching (and is always True for numpy ints).
        if epoch_count == 0:
            self.cbs_std = self.init_std
        elif epoch_count % self.cbs_epoch == 0:
            self.cbs_std *= self.cbs_std_factor

        if self.cbs_filter == 'gau':
            self.kernel1 = get_gaussian_filter_1D(kernel_size=self.cbs_kernel_size, sigma=self.cbs_std, channels=1)
        elif self.cbs_filter == 'LOG':
            self.kernel1 = get_laplaceOfGaussian_filter_1D(kernel_size=self.cbs_kernel_size, sigma=self.cbs_std, channels=1)
        else:
            raise ValueError("Unknown CBS filter {!r}; expected 'gau' or 'LOG'".format(self.cbs_filter))

class H_NodeApplyModule(nn.Module): #human node
    '''
    Node update: fuses each node's own features with its aggregated neighbourhood
    message (z_f / z_f_lang), separately for the visual and the language streams.

        init    : config
        forward : node  ->  {'new_n_f', 'new_n_f_lang'}
    '''
    def __init__(self, CONFIG):
        super(H_NodeApplyModule, self).__init__()
        self.node_fc = MLP(CONFIG.G_N_L_S, CONFIG.G_N_A,
                           bias=CONFIG.G_N_B, use_bn=CONFIG.G_N_BN, drop_prob=CONFIG.G_N_D)
        self.node_fc_lang = MLP(CONFIG.G_N_L_S2, CONFIG.G_N_A2,
                                bias=CONFIG.G_N_B2, use_bn=CONFIG.G_N_BN2, drop_prob=CONFIG.G_N_D2)

    def forward(self, node):
        nd = node.data
        visual_in = torch.cat([nd['n_f'], nd['z_f']], dim=1)
        lang_in = torch.cat([nd['word2vec'], nd['z_f_lang']], dim=1)
        return {'new_n_f': self.node_fc(visual_in),
                'new_n_f_lang': self.node_fc_lang(lang_in)}

class E_AttentionModule1(nn.Module): #edge attention
    '''
    Edge attention: maps each edge feature to an unnormalised attention score,
    separately for the visual ('e_f') and language ('e_f_lang') edge features.

        init    : config
        forward : edge  ->  {'a_feat', 'a_feat_lang'}
    '''
    def __init__(self, CONFIG):
        super(E_AttentionModule1, self).__init__()
        self.attn_fc = MLP(CONFIG.G_A_L_S, CONFIG.G_A_A,
                           bias=CONFIG.G_A_B, use_bn=CONFIG.G_A_BN, drop_prob=CONFIG.G_A_D)
        self.attn_fc_lang = MLP(CONFIG.G_A_L_S2, CONFIG.G_A_A2,
                                bias=CONFIG.G_A_B2, use_bn=CONFIG.G_A_BN2, drop_prob=CONFIG.G_A_D2)

    def forward(self, edge):
        e_data = edge.data
        return {'a_feat': self.attn_fc(e_data['e_f']),
                'a_feat_lang': self.attn_fc_lang(e_data['e_f_lang'])}

class GNN(nn.Module):
    '''
    One round of visual + language attention message passing over a DGL graph.

        init    : config, multi_attn, diff_edge, use_cbs
        forward : g, h_node, o_node, h_h_e_list, o_o_e_list, h_o_e_list, pop_feat

    Results are written into ``g`` in place: 'new_n_f' / 'new_n_f_lang' on the nodes in
    h_node + o_node and, outside training and validation, the attention weights
    'alpha' / 'alpha_lang'. Intermediate fields are removed afterwards.
    NOTE: _reduce_func reads the module-level flag ``validation`` set by GRNN.forward.
    '''
    def __init__(self, CONFIG, multi_attn=False, diff_edge=True, use_cbs = False):
        super(GNN, self).__init__()
        self.diff_edge = diff_edge
        self.apply_h_h_edge = H_H_EdgeApplyModule(CONFIG, multi_attn, use_cbs)
        self.apply_edge_attn1 = E_AttentionModule1(CONFIG)
        self.apply_h_node = H_NodeApplyModule(CONFIG)

    def _message_func(self, edges):
        src, data = edges.src, edges.data
        return {
            'nei_n_f': src['n_f'],
            'nei_n_w': src['word2vec'],
            'e_f': data['e_f'],
            'e_f_lang': data['e_f_lang'],
            'a_feat': data['a_feat'],
            'a_feat_lang': data['a_feat_lang'],
        }

    def _reduce_func(self, nodes):
        mailbox = nodes.mailbox
        # Attention weights: softmax over each node's incoming messages (mailbox dim 1).
        alpha = F.softmax(mailbox['a_feat'], dim=1)
        alpha_lang = F.softmax(mailbox['a_feat_lang'], dim=1)

        z_f = torch.sum(alpha * (mailbox['nei_n_f'] + mailbox['e_f']), dim=1)
        z_f_lang = torch.sum(alpha_lang * mailbox['nei_n_w'], dim=1)

        reduced = {'z_f': z_f, 'z_f_lang': z_f_lang}
        # The attention weights are only kept for inference (not training, not validation).
        if not (self.training or validation):
            reduced['alpha'] = alpha
            reduced['alpha_lang'] = alpha_lang
        return reduced

    def forward(self, g, h_node, o_node, h_h_e_list, o_o_e_list, h_o_e_list, pop_feat=False):
        g.apply_edges(self.apply_h_h_edge, g.edges())
        g.apply_edges(self.apply_edge_attn1)
        g.update_all(self._message_func, self._reduce_func)
        g.apply_nodes(self.apply_h_node, h_node + o_node)

        # Drop the inputs and intermediate fields, keeping only the node outputs.
        # !NOTE: PAY ATTENTION WHEN ADDING MORE FEATURES - keep this list in sync.
        stale_fields = (('ndata', 'n_f'), ('ndata', 'word2vec'), ('ndata', 'z_f'),
                        ('edata', 'e_f'), ('edata', 'a_feat'),
                        ('ndata', 'z_f_lang'), ('edata', 'e_f_lang'), ('edata', 'a_feat_lang'))
        for frame_name, key in stale_fields:
            getattr(g, frame_name).pop(key)

class GRNN(nn.Module):
    '''
    Wraps a single GNN layer: loads the input features into the batched graph and runs it.

    init: 
        config, multi_attn, diff_edge, use_cbs
    forward: 
        batch_graph, batch_h_node_list, batch_obj_node_list,
        batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list,
        feat, spatial_feat, word2vec,
        valid, pop_feat, initial_feat   (pop_feat / initial_feat are unused)

    Nothing is returned; the GNN leaves its outputs in ``batch_graph``.
    '''
    def __init__(self, CONFIG, multi_attn=False, diff_edge=True, use_cbs = False):
        super(GRNN, self).__init__()
        self.multi_attn = multi_attn
        self.gnn = GNN(CONFIG, multi_attn, diff_edge, use_cbs)

    def forward(self, batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, feat, spatial_feat, word2vec, valid=False, pop_feat=False, initial_feat=False):

        # !NOTE: if node_num==1, there will be something wrong to forward the attention mechanism
        # GNN._reduce_func reads this module-level flag to decide whether to keep attention weights.
        global validation
        validation = valid

        # initialize the graph with its input data
        batch_graph.ndata['n_f'] = feat           # node: visual features
        batch_graph.ndata['word2vec'] = word2vec  # node: word embeddings
        batch_graph.edata['s_f'] = spatial_feat   # edge: spatial features

        # Errors propagate to the caller. The former try/except printed the error and dropped
        # into ipdb, which hangs non-interactive runs and then silently continued; use %debug
        # or %pdb in the notebook for post-mortem debugging instead.
        self.gnn(batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list)
        

In [6]:
'''
Predictor 
    init    : config
    forward : edge

AGRNN
    init    : bias, bn, dropout, multi_attn, layer, diff_edge, use_cbs
    forward : node_num, feat, spatial_feat, word2vec, roi_label, validation, choose_nodes, remove_nodes
'''

import dgl
import ipdb
import numpy as np

import torch
import torch.nn as nn
#import torchvision

class Predictor(nn.Module):
    '''
    Edge readout: classifies an edge from the updated features of both endpoints plus the
    edge's spatial feature, concatenated as
    [dst new_n_f, dst new_n_f_lang, s_f, src new_n_f_lang, src new_n_f].

    init    : config
    forward : edge  ->  {'pred': raw scores (no sigmoid applied)}
    '''
    def __init__(self, CONFIG):
        super(Predictor, self).__init__()
        self.classifier = MLP(CONFIG.G_ER_L_S, CONFIG.G_ER_A,
                              bias=CONFIG.G_ER_B, use_bn=CONFIG.G_ER_BN, drop_prob=CONFIG.G_ER_D)
        # Not used in forward; available for BCELoss-style setups (see comment below).
        self.sigmoid = nn.Sigmoid()

    def forward(self, edge):
        dst, src = edge.dst, edge.src
        pieces = [dst['new_n_f'], dst['new_n_f_lang'], edge.data['s_f'],
                  src['new_n_f_lang'], src['new_n_f']]
        pred = self.classifier(torch.cat(pieces, dim=1))
        # If the criterion is BCELoss, apply the sigmoid here: pred = self.sigmoid(pred)
        return {'pred': pred}

class AGRNN(nn.Module):
    '''
    Batch-level graph model: builds one fully-connected DGL graph per image, batches them,
    runs one GRNN message-passing round and classifies the readout edges with ``Predictor``.

    init    : 
        bias, bn, dropout, multi_attn, layer, diff_edge, use_cbs
        
    forward : 
        node_num, feat, spatial_feat, word2vec, roi_label,
        validation, choose_nodes, remove_nodes
    '''
    def __init__(self, bias=True, bn=True, dropout=None, multi_attn=False, layer=1, diff_edge=True, use_cbs = False):
        super(AGRNN, self).__init__()
 
        self.multi_attn = multi_attn # false
        self.layer = layer           # 1 layer
        self.diff_edge = diff_edge   # false
        
        # NOTE: the configuration is always built with layer=1, whatever `layer` is.
        self.CONFIG1 = CONFIGURATION(layer=1, bias=bias, bn=bn, dropout=dropout, multi_attn=multi_attn)

        self.grnn1 = GRNN(self.CONFIG1, multi_attn=multi_attn, diff_edge=diff_edge, use_cbs = use_cbs)
        self.edge_readout = Predictor(self.CONFIG1)
        
    def _collect_edge(self, node_num, roi_label, node_space, diff_edge):
        '''
        Enumerate the nodes and edges of one image's graph.

        Args:
            node_num: number of nodes in the image
            roi_label: numpy array of per-node class labels; label-0 nodes form
                ``h_node_list`` (presumably the tissue/kidney node in this dataset - confirm)
            node_space: offset added to the returned node/readout indices so that they
                index into the batched graph
            diff_edge: unused here
        Returns:
            edge_list: every ordered pair (src, dst) with src != dst, in LOCAL indices
                (no node_space offset, since it is used to build the per-image graph)
            h_node_list, obj_node_list: label-0 / non-zero-label node indices (offset)
            h_h_e_list, o_o_e_list, h_o_e_list: always empty here (offset)
            readout_edge_list: (src, dst) for every label-0 dst and every non-label-0 src,
                dst-major; this order is the order of the predictions returned by forward
            readout_h_h_e_list: some label-0 -> label-0 pairs (offset), see NOTE below
            readout_h_o_e_list: readout_edge_list minus readout_h_h_e_list (offset)
        '''
        
        # get human nodes && object nodes (human = label 0)
        h_node_list = np.where(roi_label == 0)[0]
        obj_node_list = np.where(roi_label != 0)[0]
        edge_list = []
        
        h_h_e_list = []
        o_o_e_list = []
        h_o_e_list = []
        
        readout_edge_list = []
        readout_h_h_e_list = []
        readout_h_o_e_list = []
        
        # get all edge in the fully-connected graph, edge_list, For node_num = 2, edge_list = [(0, 1), (1, 0)]
        for src in range(node_num):
            for dst in range(node_num):
                if src == dst:
                    continue
                else:
                    edge_list.append((src, dst))
        
        # readout_edge_list, get corresponding readout edge in the graph:
        # one edge into each label-0 node from every node that is not label 0.
        # NOTE(review): presumably this order must match the dataset's edge_labels - confirm.
        src_box_list = np.arange(roi_label.shape[0])
        for dst in h_node_list:
            # if dst == roi_label.shape[0]-1:
            #    continue
            # src_box_list = src_box_list[1:]
            for src in src_box_list:
                if src not in h_node_list:
                    readout_edge_list.append((src, dst))
        
        # readout h_h_e_list, get corresponding readout h_h edges && h_o edges
        # NOTE(review): `dst == h_node_list.shape[0]-1` compares a node INDEX with the NUMBER
        # of label-0 nodes, so which dst gets skipped depends on where label-0 nodes sit.
        # These pairs only get a prediction computed in forward(); they are never part of
        # readout_edge_list (its sources are all non-label-0), so they are not returned.
        temp_h_node_list = h_node_list[:]
        for dst in h_node_list:
            if dst == h_node_list.shape[0]-1:
                continue
            temp_h_node_list = temp_h_node_list[1:]
            for src in temp_h_node_list:
                if src == dst: continue
                readout_h_h_e_list.append((src, dst))

        # readout h_o_e_list
        readout_h_o_e_list = [x for x in readout_edge_list if x not in readout_h_h_e_list]

        # add node space to match the batch graph
        h_node_list = (np.array(h_node_list)+node_space).tolist()
        obj_node_list = (np.array(obj_node_list)+node_space).tolist()
        
        h_h_e_list = (np.array(h_h_e_list)+node_space).tolist() #empty no diff_edge
        o_o_e_list = (np.array(o_o_e_list)+node_space).tolist() #empty no diff_edge
        h_o_e_list = (np.array(h_o_e_list)+node_space).tolist() #empty no diff_edge

        readout_h_h_e_list = (np.array(readout_h_h_e_list)+node_space).tolist()
        readout_h_o_e_list = (np.array(readout_h_o_e_list)+node_space).tolist()   
        readout_edge_list = (np.array(readout_edge_list)+node_space).tolist()

        return edge_list, h_node_list, obj_node_list, h_h_e_list, o_o_e_list, h_o_e_list, readout_edge_list, readout_h_h_e_list, readout_h_o_e_list
    
    def _build_graph(self, node_num, roi_label, node_space, diff_edge):
        '''
        Declare a DGL graph for one image, add its nodes, collect its edges and add them.
        Returns the graph followed by the lists produced by _collect_edge.
        NOTE: with node_num < 2, edge_list is empty and unpacking `src, dst` raises ValueError.
        '''
        graph = dgl.DGLGraph()
        graph.add_nodes(node_num)

        edge_list, h_node_list, obj_node_list, h_h_e_list, o_o_e_list, h_o_e_list, readout_edge_list, readout_h_h_e_list, readout_h_o_e_list = self._collect_edge(node_num, roi_label, node_space, diff_edge)
        src, dst = tuple(zip(*edge_list))
        graph.add_edges(src, dst)   # edge_list already holds both (a, b) and (b, a): the graph is bi-directional

        return graph, h_node_list, obj_node_list, h_h_e_list, o_o_e_list, h_o_e_list, readout_edge_list, readout_h_h_e_list, readout_h_o_e_list

    def forward(self, node_num=None, feat=None, spatial_feat=None, word2vec=None, roi_label=None, validation=False, choose_nodes=None, remove_nodes=None):
        '''
        Args:
            node_num: per-image node counts
            feat, spatial_feat, word2vec: node visual features, edge spatial features and
                node word vectors for the whole batch (passed straight to GRNN)
            roi_label: per-image arrays of node class labels
            validation: if True, return only the predictions (as in training)
            choose_nodes, remove_nodes: unused
        Returns:
            in training or validation: 'pred' for the batch readout edges, in
                batch_readout_edge_list order
            otherwise: (pred, alpha, alpha_lang), the last two read from the label-0 nodes
        '''
        
        batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, batch_readout_edge_list, batch_readout_h_h_e_list, batch_readout_h_o_e_list = [], [], [], [], [], [], [], [], []
        node_num_cum = np.cumsum(node_num) # !IMPORTANT: per-image node-index offsets in the batched graph
        
        for i in range(len(node_num)):
            # set node space (index of this image's first node in the batched graph)
            node_space = 0
            if i != 0:
                node_space = node_num_cum[i-1]
            graph, h_node_list, obj_node_list, h_h_e_list, o_o_e_list, h_o_e_list, readout_edge_list, readout_h_h_e_list, readout_h_o_e_list = self._build_graph(node_num[i], roi_label[i], node_space, diff_edge=self.diff_edge)
            
            # update batch
            batch_graph.append(graph)
            batch_h_node_list += h_node_list
            batch_obj_node_list += obj_node_list
            
            batch_h_h_e_list += h_h_e_list
            batch_o_o_e_list += o_o_e_list
            batch_h_o_e_list += h_o_e_list
            
            batch_readout_edge_list += readout_edge_list
            batch_readout_h_h_e_list += readout_h_h_e_list
            batch_readout_h_o_e_list += readout_h_o_e_list
        
        batch_graph = dgl.batch(batch_graph)
        
        # GRNN (`validation` is passed positionally as GRNN's `valid` argument)
        self.grnn1(batch_graph, batch_h_node_list, batch_obj_node_list, batch_h_h_e_list, batch_o_o_e_list, batch_h_o_e_list, feat, spatial_feat, word2vec, validation, initial_feat=True)
        # compute 'pred' on the readout h_o edges followed by the readout h_h edges
        batch_graph.apply_edges(self.edge_readout, tuple(zip(*(batch_readout_h_o_e_list+batch_readout_h_h_e_list))))
        
        if self.training or validation:
            # !NOTE: cannot use "batch_readout_h_o_e_list+batch_readout_h_h_e_list" because of the wrong order
            return batch_graph.edges[tuple(zip(*batch_readout_edge_list))].data['pred']
        else:
            return batch_graph.edges[tuple(zip(*batch_readout_edge_list))].data['pred'], \
                   batch_graph.nodes[batch_h_node_list].data['alpha'], \
                   batch_graph.nodes[batch_h_node_list].data['alpha_lang'] 


In [7]:
import os
import utils.io as io

class SurgicalSceneConstants():
    '''
    Static dataset constants: class-name tuples (a node label / action index maps to
    the name at that position) and default data locations.
    '''
    def __init__( self):
        # Node classes; position 0 is the tissue ('kidney').
        self.instrument_classes = (
            'kidney',
            'bipolar_forceps',
            'prograsp_forceps',
            'large_needle_driver',
            'monopolar_curved_scissors',
            'ultrasound_probe',
            'suction',
            'clip_applier',
            'stapler',
            'maryland_dissector',
            'spatulated_monopolar_cautery',
        )

        # Interaction (edge) classes.
        self.action_classes = (
            'Idle',
            'Grasping',
            'Retraction',
            'Tissue_Manipulation',
            'Tool_Manipulation',
            'Cutting',
            'Cauterization',
            'Suction',
            'Looping',
            'Suturing',
            'Clipping',
            'Staple',
            'Ultrasound_Sensing',
        )

        # Prefix of the per-sequence annotation directories and the word-embedding file.
        self.xml_data_dir = 'datasets/instruments18/seq_'
        self.word2vec_loc = 'datasets/surgicalscene_word2vec.hdf5'

In [8]:
import sys
import random

import h5py
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset

import os
from glob import glob
    
class SurgicalSceneDataset(Dataset):
    '''
    Frame-level dataset of pre-extracted graph features for surgical scenes.

    Frames are discovered as '<data_dir[d]><seq_id>/xml/*.xml'; for each frame the features
    are read from '<seq dir>/vsgat/<feature_extractor>/<frame>_features.hdf5' and returned
    as a dict consumed by ``collate_fn``.

    Args:
        seq_set: per-domain list of sequence ids
        data_dir: per-domain path prefix (the sequence id is appended to it)
        img_dir: image sub-path suffixes; a single entry is shared, otherwise it is indexed
            by the frame's dset tag
        dset: per-domain integer tag stored for every frame of that domain; tag 1 makes
            class 0 use the 'tissue' word vector
        dataconst: object providing ``instrument_classes`` (and optionally ``word2vec_loc``)
        feature_extractor: name of the feature sub-directory
        reduce_size: keep a random subset of at most ``data_size`` (143) frames per domain
    '''
    def __init__(self, seq_set, data_dir, img_dir, dset, dataconst, feature_extractor, reduce_size = False):

        self.data_size = 143   # frames kept per domain when reduce_size is True
        self.dataconst = dataconst
        self.img_dir = img_dir
        self.feature_extractor = feature_extractor
        self.reduce_size = reduce_size

        self.xml_dir_list = []
        self.dset = []

        for domain in range(len(seq_set)):
            domain_dir_list = []
            for i in seq_set[domain]:
                xml_dir_temp = data_dir[domain] + str(i) + '/xml/'
                domain_dir_list = domain_dir_list + glob(xml_dir_temp + '/*.xml')
            if self.reduce_size:
                # unseeded random subset
                indices = np.random.permutation(len(domain_dir_list))
                domain_dir_list = [domain_dir_list[j] for j in indices[0:self.data_size]]
            for file in domain_dir_list:
                self.xml_dir_list.append(file)
                self.dset.append(dset[domain])

        # Use the path configured in dataconst; fall back to the previous hard-coded default.
        word2vec_loc = getattr(dataconst, 'word2vec_loc', 'datasets/surgicalscene_word2vec.hdf5')
        self.word2vec = h5py.File(word2vec_loc, 'r')

    @staticmethod
    def _read_h5_scalar(h5_dataset):
        '''
        Read a whole HDF5 dataset with ``dataset[()]`` (``Dataset.value`` was removed in
        h5py 3.0) and decode a bytes result to str (h5py 3 returns bytes for strings).
        '''
        value = h5_dataset[()]
        if isinstance(value, bytes):
            value = value.decode('utf-8')
        return value

    # word2vec
    def _get_word2vec(self, node_ids, sgh = 0):
        '''
        Stack the word vectors of the given node class ids into an (N, 300) float array.
        When ``sgh == 1``, class id 0 uses the 'tissue' vector instead of its class name.
        '''
        rows = []
        for node_id in node_ids:
            if sgh == 1 and node_id == 0:
                rows.append(self.word2vec['tissue'])
            else:
                rows.append(self.word2vec[self.dataconst.instrument_classes[node_id]])
        # Concatenate once instead of re-stacking per node; the empty float64 seed keeps the
        # same dtype promotion and the (0, 300) result for an empty node list.
        return np.vstack([np.empty((0, 300))] + rows)

    def __len__(self):
        return len(self.xml_dir_list)

    def __getitem__(self, idx):
        xml_path = self.xml_dir_list[idx]
        file_name = os.path.splitext(os.path.basename(xml_path))[0]
        file_root = os.path.dirname(os.path.dirname(xml_path))
        img_dir = self.img_dir[0] if len(self.img_dir) == 1 else self.img_dir[self.dset[idx]]
        _img_loc = os.path.join(file_root + img_dir + file_name + '.png')
        feature_path = os.path.join(file_root + '/vsgat/' + self.feature_extractor + '/' + file_name + '_features.hdf5')

        data = {}
        # The context manager closes the per-frame file (it used to leak one handle per item);
        # every field is copied into memory before the file is closed.
        with h5py.File(feature_path, 'r') as frame_data:
            data['img_name'] = self._read_h5_scalar(frame_data['img_name']) + '.jpg'
            data['img_loc'] = _img_loc

            data['node_num'] = self._read_h5_scalar(frame_data['node_num'])
            data['roi_labels'] = frame_data['classes'][:]
            data['det_boxes'] = frame_data['boxes'][:]

            data['edge_labels'] = frame_data['edge_labels'][:]
            data['edge_num'] = data['edge_labels'].shape[0]

            data['features'] = frame_data['node_features'][:]
            data['spatial_feat'] = frame_data['spatial_features'][:]

        data['word2vec'] = self._get_word2vec(data['roi_labels'], self.dset[idx])
        return data

# for DatasetLoader
def collate_fn(batch):
    '''
    Merge a list of SurgicalSceneDataset items into one batch dict for the DataLoader.

    Every field becomes a list with one entry per item; the variable-size numeric fields
    (edge_labels, features, spatial_feat, word2vec) are then concatenated along axis 0
    into a single float32 tensor.
    Default collate_fn(): https://github.com/pytorch/pytorch/blob/1d53d0756668ce641e4f109200d9c65b003d05fa/torch/utils/data/_utils/collate.py#L43
    '''
    field_names = ('img_name', 'img_loc', 'node_num', 'roi_labels', 'det_boxes',
                   'edge_labels', 'edge_num', 'features', 'spatial_feat', 'word2vec')
    batch_data = {name: [sample[name] for sample in batch] for name in field_names}

    for name in ('edge_labels', 'features', 'spatial_feat', 'word2vec'):
        batch_data[name] = torch.FloatTensor(np.concatenate(batch_data[name], axis=0))

    return batch_data

In [9]:
import time

import random
import numpy as np
import matplotlib
import torch as t

matplotlib.use('Agg')
from matplotlib import pyplot as plot
from PIL import Image, ImageDraw, ImageFont


def _random_rgb():
    '''Draw a random (r, g, b) colour from the stdlib `random` RNG, one draw per channel.'''
    return tuple(random.choice(np.arange(256)) for _ in range(3))


def _text_height(font, text):
    '''Pixel height of `text` in `font`, measured the way the removed ImageFont.getsize() did.'''
    if hasattr(font, 'getbbox'):
        # getbbox() -> (left, top, right, bottom) relative to the drawing origin; `bottom`
        # includes the top offset, like getsize()'s height did (getsize is gone in Pillow 10).
        return font.getbbox(text)[3]
    return font.getsize(text)[1]


def _box_center(box, im_w, im_h):
    '''Integer centre (col, row) of an (x1, y1, x2, y2) box, clamped to the image bounds.'''
    x1, y1, x2, y2 = box
    col = max(0, min(int(0.5 * x1) + int(0.5 * x2), im_w - 1))
    row = max(0, min(int(0.5 * y1) + int(0.5 * y2), im_h - 1))
    return col, row


def vis_img(img, node_classes, bboxs,  det_action, score_thresh = 0.7):
    '''Draw node boxes and predicted interaction labels onto a PIL image.

    Every node gets a red box labelled (in a random colour) with
    ``data_const.instrument_classes[class]``. Edges are enumerated as
    (tissue, instrument) for tissue in the first ``tissue_num`` nodes and instrument in
    the nodes after it; edge k is drawn as a line between the two box centres and
    labelled with ``data_const.action_classes[argmax(det_action[k])]``.

    Args:
        img: PIL image; drawn on in place.
        node_classes: per-node class indices (array, tensor or list).
        bboxs: per-node boxes as (x1, y1, x2, y2).
        det_action: per-edge action scores, one row per edge.
        score_thresh: unused (thresholding was replaced by argmax); kept for callers.

    Returns:
        The same ``img`` object, annotated.

    NOTE(review): reads the module-level ``data_const`` created in the __main__ cell.
    '''
    drawer = ImageDraw.Draw(img)
    line_width = 3
    outline = '#FF0000'
    try:
        font = ImageFont.truetype(font='/usr/share/fonts/truetype/freefont/FreeMono.ttf', size=25)
    except OSError:
        # FreeMono only exists on some Linux installs; degrade instead of crashing validation.
        font = ImageFont.load_default()

    im_w, im_h = img.size
    # np.asarray makes `== 1` elementwise for plain lists too (list == 1 is just False,
    # which would silently draw no edges at all).
    node_classes = np.asarray(node_classes)
    node_num = len(node_classes)
    edge_num = len(det_action)
    # NOTE(review): assumes class 1 marks tissue nodes and that they come first in node order,
    # matching the dataset's edge enumeration — confirm against SurgicalSceneDataset.
    tissue_num = int(np.count_nonzero(node_classes == 1))

    for node in range(node_num):
        color = _random_rgb()
        text = data_const.instrument_classes[node_classes[node]]
        text_h = _text_height(font, text)
        drawer.rectangle(list(bboxs[node]), outline=outline, width=line_width)
        # place the label just above the box's top-left corner
        drawer.text(xy=(bboxs[node][0], bboxs[node][1] - text_h - 1), text=text, font=font, fill=color)

    edge_idx = 0
    for tissue in range(tissue_num):
        for instrument in range(tissue + 1, node_num):
            if edge_idx >= edge_num:
                # fewer score rows than node pairs: keep what was drawn rather than raise IndexError
                return img
            action_idx = np.argmax(det_action[edge_idx])
            text = data_const.action_classes[action_idx]
            color = _random_rgb()

            c0, r0 = _box_center(bboxs[tissue], im_w, im_h)
            c1, r1 = _box_center(bboxs[instrument], im_w, im_h)
            drawer.line(((c0, r0), (c1, r1)), fill=color, width=3)
            drawer.text(xy=(c1, r1), text=text, font=font, fill=color)

            edge_idx += 1

    return img

In [10]:
from __future__ import print_function

import os
import copy
import time

import numpy as np
from tqdm import tqdm
from PIL import Image
import utils.io as io
#from utils.vis_tool import vis_img

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter

def _domain_seq(train_seq, val_seq, data_dir, img_dir, dset):
    '''Bundle per-domain split lists (one entry per domain) into the dict epoch_train expects.'''
    return {'train_seq': train_seq, 'val_seq': val_seq, 'data_dir': data_dir, 'img_dir': img_dir, 'dset': dset}


def run_model(args, data_const):
    '''Build the AGRNN model and train it through a three-stage domain schedule.

    Stages (each a call to epoch_train, sharing the same model object):
      1. "D1":  train on the instruments18 (dset 0) sequences only.
      2. "D2":  instruments18 + SGH_dataset_2020 (dset 1); because train_seq has two
                entries, epoch_train trains on the newest domain and distils the older one.
      3. "D2F": fine-tune on both domains (epoch_train's finetune mode).

    Also prints the device and parameter count, optionally loads args.pretrained, and
    writes the first layer's key configuration to <save_dir>/<exp_ver>/l1_config.json.

    Args:
        args: `arguments` instance with hyper-parameters, paths and flags.
        data_const: dataset constants. NOTE(review): not used inside this function;
            epoch_train reads the module-level `data_const` instead.
    '''
    # use cpu or cuda
    device = torch.device('cuda' if torch.cuda.is_available() and args.gpu else 'cpu')
    print('training on {}...'.format(device))

    # model
    model = AGRNN(bias=args.bias, bn=args.bn, dropout=args.drop_prob, multi_attn=args.multi_attn, layer=args.layers, diff_edge=args.diff_edge, use_cbs = args.use_cbs)
    if args.use_cbs:
        # initialise the CBS kernels for epoch 0 before the parameters are counted
        model.grnn1.gnn.apply_h_h_edge.get_new_kernels(0)

    # number of scalar parameters in the model
    parameter_num = sum(param.numel() for param in model.parameters())
    print(f'The parameters number of the model is {parameter_num / 1e6} million')

    # load pretrained model (torch.load unpickles: only point this at trusted checkpoints)
    if args.pretrained:
        print(f"loading pretrained model {args.pretrained}")
        checkpoints = torch.load(args.pretrained, map_location=device)
        model.load_state_dict(checkpoints['state_dict'])
    model.to(device)

    # save the key configuration; only the first layer's config is recorded
    io.mkdir_if_not_exists(os.path.join(args.save_dir, args.exp_ver), recursive=True)
    if args.layers >= 1:
        model_config = model.CONFIG1.save_config()
        model_config.update({
            'lr': args.lr,
            'bs': args.batch_size,
            'layers': args.layers,
            'multi_attn': args.multi_attn,
            'data_aug': args.data_aug,
            'drop_out': args.drop_prob,
            'optimizer': args.optim,
            'diff_edge': args.diff_edge,
            'model_parameters': parameter_num,
        })
        io.dump_json_object(model_config, os.path.join(args.save_dir, args.exp_ver, 'l1_config.json'))
    print('save key configurations successfully...')

    # domain 1 (dset: 0 for ISC, 1 for SGH)
    seq = _domain_seq(train_seq=[[2,3,4,6,7,9,10,11,12,14,15]],
                      val_seq=[[1,5,16]],
                      data_dir=['datasets/instruments18/seq_'],
                      img_dir=['/left_frames/'],
                      dset=[0])
    print('======================== Domain 1 ==============================')
    epoch_train(args, model, seq, device, "D1")

    # domain 2: the last entry of each list is the new domain
    seq = _domain_seq(train_seq=[[2,3,4,6,7,9,10,11,12,14,15], [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]],
                      val_seq=[[1,5,16], [16,17,18,19,20,21,22]],
                      data_dir=['datasets/instruments18/seq_', 'datasets/SGH_dataset_2020/'],
                      img_dir=['/left_frames/', '/resized_frames/'],
                      dset=[0, 1])
    print('======================== Domain 2 ==============================')
    epoch_train(args, model, seq, device, "D2")
    print('======================== Domain 1-2 FT =========================')
    epoch_train(args, model, seq, device, "D2F", finetune = True)
    

# Collated-batch fields (see collate_fn) that are tensors and must live on the training device.
_BATCH_TENSOR_KEYS = ('features', 'spatial_feat', 'word2vec', 'edge_labels')


def _batch_to_device(batch, device):
    '''Return a shallow copy of a collated batch with its tensor fields moved to `device`.'''
    moved = dict(batch)
    for key in _BATCH_TENSOR_KEYS:
        moved[key] = batch[key].to(device)
    return moved


def _build_dataset(args, seq_set, data_dir, img_dir, dset, reduce_size):
    '''Create a SurgicalSceneDataset for the given per-domain lists.

    Uses args.feature_extractor and the module-level `data_const` (set in the __main__ cell).
    '''
    return SurgicalSceneDataset(seq_set=seq_set, data_dir=data_dir, img_dir=img_dir, dset=dset,
                                dataconst=data_const, feature_extractor=args.feature_extractor,
                                reduce_size=reduce_size)


def _make_loader(args, dataset):
    '''Shuffled DataLoader over `dataset` using collate_fn and args.batch_size.'''
    return DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn)


def _edge_accuracy(outputs, edge_labels):
    '''Count edges whose argmax prediction equals the argmax of their label row.'''
    predicted = np.argmax(outputs.cpu().data.numpy(), axis=-1)
    expected = np.argmax(edge_labels.cpu().data.numpy(), axis=-1)
    return np.sum(np.equal(predicted, expected))


def epoch_train(args, model, seq, device, dname, finetune = False):
    '''Train `model` on one domain configuration, with optional distillation or fine-tuning.

    Mode is selected from the arguments:
      * finetune=True: train on every sequence in `seq` with reduce_size=True (the train
        set is rebuilt every epoch) for args.ft_epoch epochs at args.lr / 10.
      * one domain (len(seq['train_seq']) == 1): plain training for args.epoch epochs.
      * several domains: train only on the last (newest) domain; after each training
        phase, run a distillation pass over the earlier domains against `model_old`,
        a deep copy of the model taken before this call (never optimised).
    Validation always uses every domain in seq['val_seq'].

    Each epoch: optional CBS kernel update, a train phase, a validation phase (the
    11th validation batch is rendered with vis_img into args.output_img_dir),
    TensorBoard logging, and a checkpoint save when the criteria below are met.

    Args:
        args: `arguments` instance (lr, epochs, batch size, CBS/temperature flags, paths).
        model: network to train; updated in place.
        seq: dict with keys 'train_seq', 'val_seq', 'data_dir', 'img_dir', 'dset',
            each a list with one entry per domain.
        device: torch device to run on.
        dname: tag inserted into checkpoint file names.
        finetune: enable fine-tuning mode (see above).

    Raises:
        ValueError: if not fine-tuning and seq['train_seq'] is empty.
    '''
    new_domain = False
    model_old = None
    stop_epoch = args.ft_epoch if finetune else args.epoch

    if finetune:
        train_dataset = _build_dataset(args, seq['train_seq'], seq['data_dir'], seq['img_dir'], seq['dset'], reduce_size=True)
    elif len(seq['train_seq']) == 1:
        # train and test dataset for one domain
        train_dataset = _build_dataset(args, seq['train_seq'], seq['data_dir'], seq['img_dir'], seq['dset'], reduce_size=False)
    elif len(seq['train_seq']) > 1:
        # multiple domains: train on the newest one, older ones are revisited via distillation
        new_domain = True
        train_dataset = _build_dataset(args, seq['train_seq'][-1:], seq['data_dir'][-1:],
                                       seq['img_dir'][-1:], seq['dset'][-1:], reduce_size=False)
        model_old = copy.deepcopy(model)
    else:
        raise ValueError('seq["train_seq"] must contain at least one domain')

    val_dataset = _build_dataset(args, seq['val_seq'], seq['data_dir'], seq['img_dir'], seq['dset'], reduce_size=False)
    dataset = {'train': train_dataset, 'val': val_dataset}
    dataloader = {'train': _make_loader(args, train_dataset), 'val': _make_loader(args, val_dataset)}

    criterion = nn.MultiLabelSoftMarginLoss()

    # set visualization and create folder to save checkpoints
    writer = SummaryWriter(log_dir=args.log_dir + '/' + args.exp_ver + '/' + 'epoch_train')
    io.mkdir_if_not_exists(os.path.join(args.save_dir, args.exp_ver, 'epoch_train'), recursive=True)

    # fine-tuning runs at a 10x smaller learning rate
    lrc = args.lr / 10 if finetune else args.lr

    try:
        for epoch in range(args.start_epoch, stop_epoch):
            epoch_acc = 0
            epoch_loss = 0

            if finetune:
                # rebuilt every epoch; reduce_size=True presumably draws a fresh subset — TODO confirm
                dataset['train'] = _build_dataset(args, seq['train_seq'], seq['data_dir'], seq['img_dir'], seq['dset'], reduce_size=True)
                dataloader['train'] = _make_loader(args, dataset['train'])

            # NOTE(review): the optimizer (and Adam's moment estimates) is rebuilt every epoch.
            if args.optim == 'sgd':
                optimizer = optim.SGD(model.parameters(), lr=lrc, momentum=0.9, weight_decay=0)
            else:
                optimizer = optim.Adam(model.parameters(), lr=lrc, weight_decay=0)

            for phase in ['train', 'val']:
                start_time = time.time()
                idx = 0
                running_acc = 0.0
                running_loss = 0.0
                running_edge_count = 0

                if phase == 'train' and args.use_cbs:
                    model.grnn1.gnn.apply_h_h_edge.get_new_kernels(epoch)
                    model.to(device)

                for data in dataloader[phase]:
                    batch = _batch_to_device(data, device)
                    node_num, roi_labels = batch['node_num'], batch['roi_labels']
                    features, spatial_feat = batch['features'], batch['spatial_feat']
                    word2vec, edge_labels = batch['word2vec'], batch['edge_labels']

                    if phase == 'train':
                        model.train()
                        model.zero_grad()
                        outputs = model(node_num, features, spatial_feat, word2vec, roi_labels)
                        # temperature scaling is applied during training only
                        if args.use_t:
                            outputs = outputs / args.t_scale
                        loss = criterion(outputs, edge_labels.float())
                        loss.backward()
                        optimizer.step()
                        acc = _edge_accuracy(outputs, edge_labels)
                    else:
                        model.eval()
                        # no gradients during validation: saves memory and compute
                        with torch.no_grad():
                            outputs = model(node_num, features, spatial_feat, word2vec, roi_labels, validation=True)
                            loss = criterion(outputs, edge_labels.float())
                            acc = _edge_accuracy(outputs, edge_labels)

                            # render the first frame of the 11th validation batch
                            if idx == 10:
                                out_dir = os.path.join(args.output_img_dir, ('epoch_' + str(epoch)))
                                io.mkdir_if_not_exists(out_dir, recursive=True)
                                image = Image.open(batch['img_loc'][0]).convert('RGB')
                                det_actions = nn.Sigmoid()(outputs[0:int(batch['edge_num'][0])])
                                det_actions = det_actions.cpu().detach().numpy()
                                action_img = vis_img(image, roi_labels[0], batch['det_boxes'][0], det_actions, score_thresh=0.7)
                                action_img.save(os.path.join(out_dir, batch['img_name'][0]))

                    idx += 1
                    # loss.item() is a per-edge mean; weight by edge count to accumulate a sum
                    running_loss += loss.item() * edge_labels.shape[0]
                    running_acc += acc
                    running_edge_count += edge_labels.shape[0]

                # distillation on the earlier domains against the pre-training copy of the model
                if phase == 'train' and new_domain:
                    dist_loss_act = nn.Softmax(dim=1).to(device)
                    dis_train_dataset = _build_dataset(args, seq['train_seq'][:-1], seq['data_dir'][:-1],
                                                       seq['img_dir'][:-1], seq['dset'][:-1], reduce_size=True)
                    for data in _make_loader(args, dis_train_dataset):
                        batch = _batch_to_device(data, device)
                        node_num, roi_labels = batch['node_num'], batch['roi_labels']
                        features, spatial_feat = batch['features'], batch['spatial_feat']
                        word2vec, edge_labels = batch['word2vec'], batch['edge_labels']

                        model.train()
                        # NOTE(review): the teacher runs in train mode, so dropout is active in its
                        # targets — confirm this is intended.
                        model_old.train()
                        model.zero_grad()
                        outputs = model(node_num, features, spatial_feat, word2vec, roi_labels)

                        with torch.no_grad():
                            # old network output (no graph is recorded under no_grad)
                            output_old = model_old(node_num, features, spatial_feat, word2vec, roi_labels)

                        if args.use_t:
                            outputs = outputs / args.t_scale
                            output_old = output_old / args.t_scale
                        d_loss = F.binary_cross_entropy(dist_loss_act(outputs), dist_loss_act(output_old))
                        loss = criterion(outputs, edge_labels.float()) + 0.5 * d_loss
                        loss.backward()
                        optimizer.step()

                # NOTE(review): the loss sum is over edges but is divided by the number of frames,
                # whereas accuracy is per edge; the 0.0405 save threshold below is on this scale.
                epoch_loss = running_loss / len(dataset[phase])
                epoch_acc = running_acc / running_edge_count if running_edge_count else 0.0

                # log train and val losses in the same graph
                if phase == 'train':
                    train_loss = epoch_loss
                else:
                    writer.add_scalars('trainval_loss_epoch', {'train': train_loss, 'val': epoch_loss}, epoch)

                if (epoch % args.print_every) == 0:
                    end_time = time.time()
                    print("[{}] Epoch: {}/{} Acc: {:0.6f} Loss: {:0.6f} Execution time: {:0.6f}".format(
                            phase, epoch + 1, stop_epoch, epoch_acc, epoch_loss, (end_time - start_time)))

            # save when the (validation) loss is low, or periodically from epoch index 19 onwards
            periodic = epoch % args.save_every == (args.save_every - 1) and epoch >= (20 - 1)
            if epoch_loss < 0.0405 or periodic:
                checkpoint = {
                    'lr': args.lr,
                    'b_s': args.batch_size,
                    'bias': args.bias,
                    'bn': args.bn,
                    'dropout': args.drop_prob,
                    'layers': args.layers,
                    'multi_head': args.multi_attn,
                    'diff_edge': args.diff_edge,
                    'state_dict': model.state_dict()
                }
                save_name = "checkpoint_" + dname + str(epoch + 1) + '_epoch.pth'
                torch.save(checkpoint, os.path.join(args.save_dir, args.exp_ver, 'epoch_train', save_name))
    finally:
        writer.close()


In [11]:
def seed_everything(seed=27):
    '''Seed every RNG this notebook relies on so runs are repeatable.

    Seeds Python's `random` (used e.g. for vis_img colours), NumPy's global RNG, and the
    torch CPU and CUDA RNGs, and switches cuDNN to deterministic, non-benchmark mode.

    Args:
        seed: integer seed applied to all generators.

    Note: PYTHONHASHSEED is read at interpreter start-up, so setting it here only affects
    Python processes launched after this call, not the current kernel.
    '''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
if __name__ == "__main__":
    # Entry point: seed the RNGs, build the hyper-parameters and dataset constants, then train.
    # NOTE: `args` and `data_const` become module-level globals here; epoch_train and vis_img
    # read `data_const` as a global, so keep these names.
    
    seed_everything()
    args = arguments()
    print(args.feature_extractor)
    data_const = SurgicalSceneConstants()
    run_model(args, data_const)



resnet18_09
training on cuda...
The parameters number of the model is 2.393694 million
save key configurations successfully...




[train] Epoch: 1/251 Acc: 0.057687 Loss: 3.139412 Execution time: 5.549406
[val] Epoch: 1/251 Acc: 0.003445 Loss: 1.609188 Execution time: 1.497931
[train] Epoch: 11/251 Acc: 0.291924 Loss: 0.988341 Execution time: 5.452029
[val] Epoch: 11/251 Acc: 0.412575 Loss: 0.609609 Execution time: 1.501824
[train] Epoch: 21/251 Acc: 0.305876 Loss: 0.878071 Execution time: 5.837291
[val] Epoch: 21/251 Acc: 0.431525 Loss: 0.491744 Execution time: 1.577329
[train] Epoch: 31/251 Acc: 0.310437 Loss: 0.824750 Execution time: 5.934913
[val] Epoch: 31/251 Acc: 0.475452 Loss: 0.456061 Execution time: 1.623235
[train] Epoch: 41/251 Acc: 0.345318 Loss: 0.799243 Execution time: 5.986010
[val] Epoch: 41/251 Acc: 0.541774 Loss: 0.436969 Execution time: 1.616955
[train] Epoch: 51/251 Acc: 0.344781 Loss: 0.786648 Execution time: 6.172334
[val] Epoch: 51/251 Acc: 0.549526 Loss: 0.436317 Execution time: 1.680262
[train] Epoch: 61/251 Acc: 0.378320 Loss: 0.773164 Execution time: 6.091772
[val] Epoch: 61/251 Acc: 0

[val] Epoch: 11/251 Acc: 0.485557 Loss: 0.447473 Execution time: 2.170462
[train] Epoch: 21/251 Acc: 0.437586 Loss: 0.782782 Execution time: 1.170640
[val] Epoch: 21/251 Acc: 0.493810 Loss: 0.443743 Execution time: 2.080318
[train] Epoch: 31/251 Acc: 0.411357 Loss: 0.772453 Execution time: 1.240849
[val] Epoch: 31/251 Acc: 0.502751 Loss: 0.440645 Execution time: 2.175696
[train] Epoch: 41/251 Acc: 0.417790 Loss: 0.779149 Execution time: 1.234529
[val] Epoch: 41/251 Acc: 0.517882 Loss: 0.438503 Execution time: 2.118100
[train] Epoch: 51/251 Acc: 0.408219 Loss: 0.764133 Execution time: 1.097041
[val] Epoch: 51/251 Acc: 0.522696 Loss: 0.436063 Execution time: 2.082355
[train] Epoch: 61/251 Acc: 0.447945 Loss: 0.770259 Execution time: 1.107146
[val] Epoch: 61/251 Acc: 0.525447 Loss: 0.436758 Execution time: 2.233952
[train] Epoch: 71/251 Acc: 0.444898 Loss: 0.766205 Execution time: 1.105726
[val] Epoch: 71/251 Acc: 0.524072 Loss: 0.435311 Execution time: 2.156014
[train] Epoch: 81/251 Acc:

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

def reliability_diagram_multi(conf_avg, acc_avg, legend=None, leg_idx=0, n_bins=10, save_path='ece_rel_multi.png'):
    '''Add one model's reliability curve to the shared reliability diagram and save it.

    Every call draws onto matplotlib figure number 2, so calling this once per model
    overlays the curves. The figure is re-saved after each call, so `save_path` always
    holds every curve drawn so far. The dashed y = x line (perfect calibration) is drawn
    on each call.

    Args:
        conf_avg: per-bin mean confidence (array-like).
        acc_avg: per-bin accuracy (array-like); bins with accuracy 0 are not plotted
            (presumably empty bins).
        legend: legend label for this curve.
        leg_idx: unused; kept for backward compatibility.
        n_bins: number of bins; only sets the x-tick spacing.
        save_path: output image path (default is the previously hard-coded file name).
    '''
    # np.asarray so that plain lists also support the boolean mask below
    conf_avg = np.asarray(conf_avg)
    acc_avg = np.asarray(acc_avg)
    nonzero = acc_avg > 0

    plt.figure(2)
    plt.plot([0, 1], [0, 1], linestyle='--')
    plt.xlabel('Confidence')
    plt.ylabel('Accuracy')
    plt.xticks(np.arange(0, 1.1, 1/n_bins))
    plt.plot(conf_avg[nonzero], acc_avg[nonzero], marker='.', label=legend)
    plt.legend()
    plt.savefig(save_path, dpi=300)

def calibration_metrics(logits_all, labels_all, model_name='Class-aware TS'):
    '''Print calibration metrics for one model and add its curve to the reliability diagram.

    Despite the parameter name, the calls in this cell pass softmax probabilities
    (F.softmax(..., dim=1)), not raw logits.

    Args:
        logits_all: torch tensor of per-sample class scores.
        labels_all: torch tensor of matching labels.
        model_name: name used in the printed line and the diagram legend.

    Returns:
        dict with keys 'ece', 'sce', 'tace', 'brier', 'uce' — the same values that are
        printed (previously discarded; existing callers that ignore the result still work).
    '''
    uce = uceloss( logits_all.cpu(), labels_all.cpu())
    logits = logits_all.detach().cpu().numpy()
    labels = labels_all.detach().cpu().numpy()
    # ece_eval's acc/conf are passed on as the reliability curve (presumably per-bin values)
    ece, acc, conf, _ = ece_eval(logits, labels, bg_cls=-1)
    sce = get_sce(logits, labels)
    tace = get_tace(logits, labels)
    brier = get_brier(logits, labels)
    print('%s:, ece:%0.4f, sce:%0.4f, tace:%0.4f, brier:%.4f, uce:%.4f' %(model_name, ece, sce, tace, brier, uce.item()) )
    reliability_diagram_multi(conf, acc, legend=model_name)
    return {'ece': ece, 'sce': sce, 'tace': tace, 'brier': brier, 'uce': uce.item()}

# Compare calibration of class-aware temperature scaling, plain temperature scaling and the
# uncalibrated cross-entropy model; each curve is overlaid on the same reliability diagram.
# NOTE(review): logits_all_cts/labels_all_cts, logits_all_ts/labels_all_ts and
# logits_all/labels_all must already exist in the kernel — confirm an earlier cell creates
# them, otherwise this cell fails on Restart & Run All.
for model_name, logits_src, labels_src in (
        ('Class-aware TS', logits_all_cts, labels_all_cts),
        ('TS', logits_all_ts, labels_all_ts),
        ('CE', logits_all, labels_all)):
    logits_all_sm = F.softmax(logits_src, dim=1)
    calibration_metrics(logits_all_sm, labels_src, model_name=model_name)