In [6]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [24]:
kernel_size = 5

sigma = 1.4
channels = 3

x_coord = torch.arange(kernel_size) #([0, 1, 2])
x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size) # ([0,1,2], [0,1,2], [0,1,2])
y_grid = x_grid.t() # ([0,0,0], [1,1,1], [2,2,2])
xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
mean = (kernel_size - 1)/2. # 1 for kernel size = 3
variance = sigma**2. # 4.0, for sigma = 2

# Calculate the 2-dimensional gaussian kernel which is
#gaussian_kernel = (1./(2.*math.pi*variance)) * torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1) / (2*variance))
#print(gaussian_kernel)

log_kernel = (-1./(math.pi*(sigma**4))) \
                        * (1-(torch.sum((xy_grid - mean)**2., dim=-1) / (2*(sigma**2)))) \
                        * torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1) / (2*(sigma**2)))
log_kernel = log_kernel / torch.sum(log_kernel)
print(log_kernel)


tensor([[-0.0410, -0.0233, -0.0022, -0.0233, -0.0410],
        [-0.0233,  0.0891,  0.1750,  0.0891, -0.0233],
        [-0.0022,  0.1750,  0.3031,  0.1750, -0.0022],
        [-0.0233,  0.0891,  0.1750,  0.0891, -0.0233],
        [-0.0410, -0.0233, -0.0022, -0.0233, -0.0410]])


In [6]:

# kernel_size = 3
# sigma = 2
# channels = 3

# x_coord = torch.arange(kernel_size) #([0, 1, 2])
# x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size) # ([0,1,2], [0,1,2], [0,1,2])
# y_grid = x_grid.t() # ([0,0,0], [1,1,1], [2,2,2])
# xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
# mean = (kernel_size - 1)/2. # 1 for kernel size = 3
# variance = sigma**2. # 4.0, for sigma = 2

# # Calculate the 2-dimensional gaussian kernel which is
# # the product of two gaussian distributions for two different
# # variables (in this case called x and y)
# gaussian_kernel = (1./(2.*math.pi*variance)) * torch.exp(-torch.sum((xy_grid - mean)**2., dim=-1) / (2*variance))

# # Make sure sum of values in gaussian kernel equals 1.
#     # tensor([[0.1019, 0.1154, 0.1019],
#     #         [0.1154, 0.1308, 0.1154],
#     #         [0.1019, 0.1154, 0.1019]])
# gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

# # Reshape to 2d depthwise convolutional weight
#     # tensor([[[[0.1019, 0.1154, 0.1019],
#     #           [0.1154, 0.1308, 0.1154],
#     #           [0.1019, 0.1154, 0.1019]]]])
# gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)

# gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
# padding = 1 if kernel_size==3 else 2 if kernel_size == 5 else 0
# gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels,
#                             kernel_size=kernel_size, groups=channels,
#                             bias=False, padding=padding)
# gaussian_filter.weight.data = gaussian_kernel
# gaussian_filter.weight.requires_grad = False 
# print(gaussian_filter)
# print(gaussian_filter.weight.data)

In [1]:
# Architecture: resnet18 with gaussian filter
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

'''Gaussian filter layer in PyTorch.
Reference:
[1] Curriculum by Smoothing. NeurIPS2020
'''

def get_gaussian_filter(kernel_size=3, sigma=2, channels=3):
    # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
    x_coord = torch.arange(kernel_size)
    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
    mean = (kernel_size - 1)/2.
    variance = sigma**2.

    # Calculate the 2-dimensional gaussian kernel which is
    # the product of two gaussian distributions for two different
    # variables (in this case called x and y)
    gaussian_kernel = (1./(2.*math.pi*variance)) * torch.exp(
                        -torch.sum((xy_grid - mean)**2., dim=-1) / (2*variance))

    # Make sure sum of values in gaussian kernel equals 1.
    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

    # Reshape to 2d depthwise convolutional weight
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size, kernel_size)
    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
    padding = 1 if kernel_size==3 else 2 if kernel_size == 5 else 0
    gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=kernel_size, 
                                groups=channels, bias=False, padding=padding)
    gaussian_filter.weight.data = gaussian_kernel
    gaussian_filter.weight.requires_grad = False 
    return gaussian_filter

'''ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
'''
class BasicBlock(nn.Module):
    expansion = 1

    def __init__( self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        
        self.planes = planes        
        
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut_kernel = True
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def get_new_kernels(self, kernel_size, std):
        self.kernel1 = get_gaussian_filter(kernel_size=kernel_size, sigma=std, channels=self.planes)
        self.kernel2 = get_gaussian_filter(kernel_size=kernel_size, sigma=std, channels=self.planes)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(self.bn1(self.kernel1(out)))
        out = self.conv2(out)
        out = self.bn2(self.kernel2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, args):
        super(ResNet, self).__init__()
        self.in_planes = 64
        self.std = args.std
        self.factor = args.std_factor
        self.epoch = args.epoch
        self.kernel_size = args.kernel_size

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, args.num_classes)

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None: nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = F.relu(self.bn1(self.kernel1(out)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

    def get_new_kernels(self, epoch_count):
        if epoch_count % self.epoch == 0 and epoch_count is not 0: self.std *= self.factor
            
        self.kernel1 = get_gaussian_filter(kernel_size=self.kernel_size, sigma=self.std, channels=64)

        for child in self.layer1.children(): child.get_new_kernels(self.kernel_size, self.std)
        for child in self.layer2.children(): child.get_new_kernels(self.kernel_size, self.std)
        for child in self.layer3.children(): child.get_new_kernels(self.kernel_size, self.std)
        for child in self.layer4.children(): child.get_new_kernels(self.kernel_size, self.std)

def ResNet18(args):
    return ResNet(BasicBlock, [2,2,2,2], args)

def ResNet34(args):
    return ResNet(BasicBlock, [3,4,6,3], args)

def test():
    net = ResNet18()
    y = net(torch.randn(1,3,32,32))
    print(y.size())

In [2]:
# Dateset preparation
import os
import copy

import math
import argparse
from sklearn.metrics import accuracy_score

import torch
import torch.optim as optim
import torch.utils.data as data
import torch.nn.functional as F
from torchvision import transforms, datasets

def seed_everything(seed=27):
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  os.environ['PYTHONHASHSEED'] = str(seed)
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False

#torch.backends.cudnn.enabled = False 

parser = argparse.ArgumentParser()
parser.add_argument('--data', type=str, default='datasets/curriculum_learning')
parser.add_argument('--dataset', type=str, default='cifar100')
parser.add_argument('--log_name', type=str, default='cbs_res_def')
parser.add_argument('--alg', type=str, default='res', choices=['normal', 'vgg', 'res', 'wrn'])
parser.add_argument('--log_path', type=str, default='log')
parser.add_argument('--no-cuda', action='store_true')
parser.add_argument('--batch_size', type=int, default=512)
parser.add_argument('--num_epochs', type=int, default=200)
parser.add_argument('--percentage', type=int, default=10)
parser.add_argument('--save_model', action='store_true')
parser.add_argument('--lr', type=float, default=1e-1)

# CBS ARGS
parser.add_argument('--std', default=1, type=float)
parser.add_argument('--std_factor', default=0.9, type=float)
parser.add_argument('--epoch', default=5, type=int)
parser.add_argument('--kernel_size', default=3, type=int)

args = parser.parse_args(args = [])

transform = transforms.Compose([transforms.Scale(32),transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5,), (0.5, 0.5, 0.5,))])    
train_data = datasets.CIFAR100(root=args.data,download=True,train=True,transform=transform)
test_data = datasets.CIFAR100(root=args.data,download=True,train=False,transform=transform)
train_loader = data.DataLoader(train_data, batch_size=args.batch_size,pin_memory=True,num_workers=int(4),shuffle=True,drop_last=True )
test_loader = data.DataLoader(test_data,batch_size=args.batch_size,pin_memory=True,num_workers=int(4),shuffle=False,drop_last=False)
args.num_classes = 100
args.in_dim = 3
#from arguments import get_args

  "please use transforms.Resize instead.")


Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Trainer
import os

os.environ["CUDA_VISIBLE_DEVICES" ]= "1"

seed_everything()
decay_epoch = 30
stop_decay_epoch = decay_epoch * 3 + 1
best_epoch, best_acc = 0, 0
num_iter = 0
model = ResNet18(args)##.cuda()
optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=5e-4, momentum=0.9)
criterion = F.cross_entropy

def train(model, train_loader, optimizer, criterion):
    model.train()
    for num_iter, (images, labels) in enumerate(train_loader):
        images, labels = images.cuda(), labels.cuda()
        preds = model(images)
        loss = criterion(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # if num_iter % 200 == 0:
        #     print('iter num: {} \t loss: {:.2f}'.format(num_iter, loss.item())) 

def test(model, test_loader):
    model.eval()
    total, correct = 0, 0
    for images, labels in test_loader:
        images = images.cuda()
        with torch.no_grad():
            preds = model(images)
            preds = torch.argmax(preds, dim=1).cpu().numpy()
            correct += accuracy_score(labels, preds, normalize=False)
            total += images.size(0)
    model.train()
    return correct / total * 100

num_epoch = 50 # need to increase to reproduce the paper 
for epoch_count in range(num_epoch):
    model.get_new_kernels(epoch_count)
    model = model.cuda()
    if epoch_count is not 0 and epoch_count % decay_epoch == 0 and epoch_count < stop_decay_epoch:
        for param in optim.param_groups:
            param['lr'] = param['lr'] / 10
        
    train(model, train_loader, optimizer, criterion)
    accuracy = test(model, test_loader)
    if accuracy > best_acc:
        best_acc = accuracy
        best_epoch = epoch_count
        best_model = copy.deepcopy(model)
        torch.save(best_model.state_dict(), 'best_model.pth.tar')

    print('current epoch: {}  current acc: {:.2f}  best epoch: {}  best acc: {:.2f}'.format(
            epoch_count, accuracy, best_epoch, best_acc))

current epoch: 0  current acc: 13.91  best epoch: 0  best acc: 13.91
current epoch: 1  current acc: 17.19  best epoch: 1  best acc: 17.19
current epoch: 2  current acc: 25.45  best epoch: 2  best acc: 25.45
current epoch: 3  current acc: 28.02  best epoch: 3  best acc: 28.02
current epoch: 4  current acc: 31.85  best epoch: 4  best acc: 31.85
current epoch: 5  current acc: 36.36  best epoch: 5  best acc: 36.36
current epoch: 6  current acc: 40.15  best epoch: 6  best acc: 40.15
current epoch: 7  current acc: 36.41  best epoch: 6  best acc: 40.15
current epoch: 8  current acc: 39.68  best epoch: 6  best acc: 40.15
current epoch: 9  current acc: 42.89  best epoch: 9  best acc: 42.89
current epoch: 10  current acc: 42.81  best epoch: 9  best acc: 42.89
current epoch: 11  current acc: 39.69  best epoch: 9  best acc: 42.89
current epoch: 12  current acc: 45.54  best epoch: 12  best acc: 45.54
current epoch: 13  current acc: 45.60  best epoch: 13  best acc: 45.60
current epoch: 14  current a

AttributeError: module 'torch.optim' has no attribute 'param_groups'

In [None]:
# conv2d
import math
import torch
import torch.nn as nn

def get_gaussian_filter_2D(kernel_sizex=3,kernel_sizey=1, sigma=2, channels=3):
    # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
    kernel_size = max(kernel_sizex, kernel_sizey)
    x_coord = torch.arange(kernel_size)
    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
    mean = (kernel_size - 1)/2.
    variance = sigma**2.

    # Calculate the 2-dimensional gaussian kernel which is
    # the product of two gaussian distributions for two different
    # variables (in this case called x and y)
    xy_grid = xy_grid[:kernel_sizex,:kernel_sizey,:]
    gaussian_kernel = (1./(2.*math.pi*variance)) * torch.exp(
                        -torch.sum((xy_grid - mean)**2., dim=-1) / (2*variance))

    # Make sure sum of values in gaussian kernel equals 1.
    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)

    # Reshape to 2d depthwise convolutional weight
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_sizex, kernel_sizey)
    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1, 1)
    padding = 1 if kernel_size==3 else 2 if kernel_size == 5 else 0
    gaussian_filter = nn.Conv2d(in_channels=channels, out_channels=channels,
                                kernel_size=(kernel_sizex,kernel_sizey), groups=channels,
                                bias=False, padding=padding)
    gaussian_filter.weight.data = gaussian_kernel
    gaussian_filter.weight.requires_grad = False 
    return gaussian_filter
    
get_gaussian_filter()

In [1]:
def get_gaussian_filter_1D(kernel_size=3, sigma=2, channels=3):
    # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
    
    x_coord = torch.arange(kernel_size)
    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()

    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
    mean = (kernel_size - 1)/2.
    variance = sigma**2.
    xy_grid = torch.sum((xy_grid[:kernel_size,:kernel_size,:] - mean)**2., dim=-1)

    # Calculate the 1-dimensional gaussian kernel
    gaussian_kernel = (1./((math.sqrt(2.*math.pi)*sigma))) * \
                        torch.exp(-1* (xy_grid[int(kernel_size/2)]) / (2*variance))

    gaussian_kernel = gaussian_kernel / torch.sum(gaussian_kernel)
    gaussian_kernel = gaussian_kernel.view(1, 1, kernel_size)
    gaussian_kernel = gaussian_kernel.repeat(channels, 1, 1)

    padding = 1 if kernel_size==3 else 2 if kernel_size == 5 else 0
    gaussian_filter = nn.Conv1d(in_channels=channels, out_channels=channels,
                                kernel_size=kernel_size, groups=channels,
                                bias=False, padding=padding)
    gaussian_filter.weight.data = gaussian_kernel
    gaussian_filter.weight.requires_grad = False 
    return gaussian_filter

def get_laplaceOfGaussian_filter_1D(kernel_size=3, sigma=2, channels=3):
    # Create a x, y coordinate grid of shape (kernel_size, kernel_size, 2)
    
    x_coord = torch.arange(kernel_size)
    x_grid = x_coord.repeat(kernel_size).view(kernel_size, kernel_size)
    y_grid = x_grid.t()
    
    xy_grid = torch.stack([x_grid, y_grid], dim=-1).float()
    mean = (kernel_size - 1)/2.
    xy_grid = torch.sum((xy_grid[:kernel_size,:kernel_size,:] - mean)**2., dim=-1)
    
    # Calculate the 2-dimensional gaussian kernel
    used_sigma = sigma
    log_kernel = (-1./(math.pi*(used_sigma**4))) \
                    * (1-(torch.sum((xy_grid[int(3/2)] - mean)**2., dim=-1) / (2*(used_sigma**2)))) \
                    * torch.exp(-torch.sum((xy_grid[int(3/2)] - mean)**2., dim=-1) / (2*(used_sigma**2)))
    
    # Make sure sum of values in gaussian kernel equals 1.
    log_kernel = log_kernel / torch.sum(log_kernel)
    log_kernel = log_kernel.view(1, 1, kernel_size)
    log_kernel = log_kernel.repeat(channels, 1, 1)

    padding = 1 if kernel_size==3 else 2 if kernel_size == 5 else 0
    log_filter = nn.Conv1d(in_channels=channels, out_channels=channels,
                                kernel_size=kernel_size, groups=channels,
                                bias=False, padding=padding)
    log_filter.weight.data = log_kernel
    log_filter.weight.requires_grad = False
    return log_filter

In [None]:
train_dataset = SurgicalSceneDataset(seq_set = [2,3,4,6,7,9,10,11,12,14,15], dataconst = data_const, feature_extractor = args.feature_extractor)
val_dataset = SurgicalSceneDataset(seq_set= [1,5,16], dataconst = data_const, feature_extractor = args.feature_extractor)
dataset = {'train': train_dataset, 'val': val_dataset}

In [None]:
import sys
import random

import h5py
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset

import os
from glob import glob
    
class SurgicalSceneDataset(Dataset):
    '''
    '''
    def __init__(self, seq_set, data_dir, img_dir, dset, dataconst, feature_extractor, reduce_size = False):
        
        
        self.data_size = 200
        self.dataconst = dataconst
        self.img_dir = img_dir
        self.feature_extractor = feature_extractor
        self.reduce_size = reduce_size
        
        self.xml_dir_list = []
        self.dset = []
        
        for domain in range(len(seq_set)):
            domain_dir_list = []
            for i in seq_set[domain]:
                xml_dir_temp = data_dir[domain] + str(i) + '/xml/'
                domain_dir_list = domain_dir_list + glob(xml_dir_temp + '/*.xml')
            if self.reduce_size:
                indices = np.random.permutation(len(domain_dir_list))
                domain_dir_list = [domain_dir_list[j] for j in indices[0:self.data_size]]
            for file in domain_dir_list: 
                self.xml_dir_list.append(file)
                self.dset.append(dset[domain])
        self.word2vec = h5py.File('datasets/surgicalscene_word2vec.hdf5', 'r')
    
    # word2vec
    def _get_word2vec(self,node_ids, sgh = 0):
        word2vec = np.empty((0,300))
        for node_id in node_ids:
            if sgh == 1 and node_id == 0:
                vec = self.word2vec['tissue']
            else:
                vec = self.word2vec[self.dataconst.instrument_classes[node_id]]
            word2vec = np.vstack((word2vec, vec))
        return word2vec

    def __len__(self):
        return len(self.xml_dir_list)

    def __getitem__(self, idx):
    
        file_name = os.path.splitext(os.path.basename(self.xml_dir_list[idx]))[0]
        file_root = os.path.dirname(os.path.dirname(self.xml_dir_list[idx]))
        
        _img_loc = os.path.join(file_root+self.img_dir[self.dset[idx]]+ file_name + '.png')
        
        frame_data = h5py.File(os.path.join(file_root+'/vsgat/'+self.feature_extractor+'/'+ file_name + '_features.hdf5'), 'r')    
        data = {}
        data['img_name'] = frame_data['img_name'].value[:] + '.jpg'
        data['img_loc'] = _img_loc
        
        data['node_num'] = frame_data['node_num'].value
        data['roi_labels'] = frame_data['classes'][:]
        data['det_boxes'] = frame_data['boxes'][:]
        
        
        data['edge_labels'] = frame_data['edge_labels'][:]
        data['edge_num'] = data['edge_labels'].shape[0]
        
        data['features'] = frame_data['node_features'][:]
        data['spatial_feat'] = frame_data['spatial_features'][:]
        
        
        data['word2vec'] = self._get_word2vec(data['roi_labels'], self.dset[idx])
        return data

# for DatasetLoader
def collate_fn(batch):
    '''
        Default collate_fn(): https://github.com/pytorch/pytorch/blob/1d53d0756668ce641e4f109200d9c65b003d05fa/torch/utils/data/_utils/collate.py#L43
    '''
    batch_data = {}
    batch_data['img_name'] = []
    batch_data['img_loc'] = []
    batch_data['node_num'] = []
    batch_data['roi_labels'] = []
    batch_data['det_boxes'] = []
    batch_data['edge_labels'] = []
    batch_data['edge_num'] = []
    batch_data['features'] = []
    batch_data['spatial_feat'] = []
    batch_data['word2vec'] = []
    
    for data in batch:
        batch_data['img_name'].append(data['img_name'])
        batch_data['img_loc'].append(data['img_loc'])
        batch_data['node_num'].append(data['node_num'])
        batch_data['roi_labels'].append(data['roi_labels'])
        batch_data['det_boxes'].append(data['det_boxes'])
        batch_data['edge_labels'].append(data['edge_labels'])
        batch_data['edge_num'].append(data['edge_num'])
        batch_data['features'].append(data['features'])
        batch_data['spatial_feat'].append(data['spatial_feat'])
        batch_data['word2vec'].append(data['word2vec'])
        
    batch_data['edge_labels'] = torch.FloatTensor(np.concatenate(batch_data['edge_labels'], axis=0))
    batch_data['features'] = torch.FloatTensor(np.concatenate(batch_data['features'], axis=0))
    batch_data['spatial_feat'] = torch.FloatTensor(np.concatenate(batch_data['spatial_feat'], axis=0))
    batch_data['word2vec'] = torch.FloatTensor(np.concatenate(batch_data['word2vec'], axis=0))
    
    return batch_data

In [1]:
from __future__ import print_function

import os
import copy
import time

import numpy as np
from tqdm import tqdm
from PIL import Image
import utils.io as io
#from utils.vis_tool import vis_img

import torch
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter

def run_model(args, data_const):
    '''
    
    '''

    # use cpu or cuda
    device = torch.device('cuda' if torch.cuda.is_available() and args.gpu else 'cpu')
    print('training on {}...'.format(device))

    # model
    model = AGRNN(bias=args.bias, bn=args.bn, dropout=args.drop_prob, multi_attn=args.multi_attn, layer=args.layers, diff_edge=args.diff_edge, use_cbs = args.use_cbs)
    if args.use_cbs: model.grnn1.gnn.apply_h_h_edge.get_new_kernels(0)
    
    # calculate the amount of all the learned parameters
    parameter_num = 0
    for param in model.parameters(): parameter_num += param.numel()
    print(f'The parameters number of the model is {parameter_num / 1e6} million')

    # load pretrained model
    if args.pretrained:
        print(f"loading pretrained model {args.pretrained}")
        checkpoints = torch.load(args.pretrained, map_location=device)
        model.load_state_dict(checkpoints['state_dict'])
    model.to(device)
    
    #scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.3) #the scheduler divides the lr by 10 every 150 epochs

    # get the configuration of the model and save some key configurations
    io.mkdir_if_not_exists(os.path.join(args.save_dir, args.exp_ver), recursive=True)
    for i in range(args.layers):
        if i==0:
            model_config = model.CONFIG1.save_config()
            model_config['lr'] = args.lr
            model_config['bs'] = args.batch_size
            model_config['layers'] = args.layers
            model_config['multi_attn'] = args.multi_attn
            model_config['data_aug'] = args.data_aug
            model_config['drop_out'] = args.drop_prob
            model_config['optimizer'] = args.optim
            model_config['diff_edge'] = args.diff_edge
            model_config['model_parameters'] = parameter_num
            io.dump_json_object(model_config, os.path.join(args.save_dir, args.exp_ver, 'l1_config.json'))
    print('save key configurations successfully...')

    # domain 1
    train_seq = [[2,3,4,6,7,9,10,11,12,14,15]]
    val_seq = [[1,5,16]]
    data_dir = ['datasets/instruments18/seq_']
    img_dir = ['/left_frames/']
    dset = [0] # 0 for ISC, 1 for SGH
    seq = {'train_seq': train_seq, 'val_seq': val_seq, 'data_dir': data_dir, 'img_dir':img_dir, 'dataset': dset}
    epoch_train(args, model, seq, device)
    
    # domain 2
    train_seq = [[2,3,4,6,7,9,10,11,12,14,15], [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]]
    val_seq = [[1,5,16],[16,17,18,19,20,21,22]]
    data_dir = ['datasets/instruments18/seq_', 'datasets/SGH_dataset_2020/']
    img_dir = ['/left_frames/', '/resized_frames/']
    dset = [0, 1]
    seq = {'train_seq': train_seq, 'val_seq': val_seq, 'data_dir': data_dir, 'img_dir':img_dir, 'dset': dset}
    epoch_train(args, model, seq, device)
    epoch_train(args, model, seq, device, finetune = True)
    

def epoch_train(args, model, seq, device, finetune = False):
    '''
    input: model, dataloader, dataset, criterain, optimizer, scheduler, device, data_const
    data: 
        img_name, node_num, roi_labels, det_boxes, edge_labels,
        edge_num, features, spatial_features, word2vec
    '''
    
    new_domain = False
    stop_epoch = args.epoch
    if finetune:
        stop_epoch = 30
        train_dataset = SurgicalSceneDataset(seq_set = seq['train_seq'], data_dir = seq['data_dir'], \
                            img_dir = seq['img_dir'], dset = seq['dset'], dataconst = data_const, \
                            feature_extractor = args.feature_extractor, reduce_size = True)
        val_dataset = SurgicalSceneDataset(seq_set = seq['val_seq'], dset = seq['dset'], data_dir = seq['data_dir'], \
                            img_dir = seq['img_dir'], dataconst = data_const, \
                            feature_extractor = args.feature_extractor, reduce_size = True)
        dataset = {'train': train_dataset, 'val': val_dataset}
        print('finetune dataset set')
    
    # train and test dataset for one domain
    elif (len(seq['train_seq']) == 1):
        # set up dataset variable
        train_dataset = SurgicalSceneDataset(seq_set = seq['train_seq'], data_dir = seq['data_dir'], \
                            img_dir = seq['img_dir'], dset = seq['dset'], dataconst = data_const, \
                            feature_extractor = args.feature_extractor, reduce_size = False)
        val_dataset = SurgicalSceneDataset(seq_set = seq['val_seq'], data_dir = seq['data_dir'], \
                            img_dir = seq['img_dir'], dset = seq['dset'], dataconst = data_const, \
                            feature_extractor = args.feature_extractor, reduce_size = False)
        dataset = {'train': train_dataset, 'val': val_dataset}
        print('domain 1 dataset')
   
    # train and test for multiple domain
    elif (len(seq['train_seq']) > 1):
        # set up dataset variable
        new_domain = True
        curr_tr_seq = seq['train_seq'][len(seq['train_seq'])-1:]
        curr_tr_data_dir = seq['data_dir'][len(seq['data_dir'])-1:]
        curr_tr_img_dir = seq['img_dir'][len(seq['img_dir'])-1:]
        curr_dset = seq['dset'][len(seq['dset'])-1:]
        train_dataset = SurgicalSceneDataset(seq_set = curr_tr_seq, data_dir = curr_tr_data_dir, \
                            img_dir = curr_tr_img_dir, dset = curr_dset, dataconst = data_const, \
                            feature_extractor = args.feature_extractor, reduce_size = False)
        val_dataset = SurgicalSceneDataset(seq_set = seq['val_seq'], data_dir = seq['data_dir'], \
                            img_dir = seq['img_dir'], dset = seq['dset'], dataconst = data_const, \
                            feature_extractor = args.feature_extractor, reduce_size = False)
        dataset = {'train': train_dataset, 'val': val_dataset}
        print('domain 2 dataset')
    
    # use default DataLoader() to load the data. 
    train_dataloader = DataLoader(dataset=dataset['train'], batch_size=args.batch_size, shuffle= True, collate_fn=collate_fn)
    val_dataloader = DataLoader(dataset=dataset['val'], batch_size=args.batch_size, shuffle= True, collate_fn=collate_fn)
    dataloader = {'train': train_dataloader, 'val': val_dataloader}
    
    # criterion and scheduler
    criterion = nn.MultiLabelSoftMarginLoss()
    # criterion = nn.BCEWithLogitsLoss()
    
    # set visualization and create folder to save checkpoints
    writer = SummaryWriter(log_dir=args.log_dir + '/' + args.exp_ver + '/' + 'epoch_train')
    io.mkdir_if_not_exists(os.path.join(args.save_dir, args.exp_ver, 'epoch_train'), recursive=True)

    for epoch in range(args.start_epoch, stop_epoch):
        
        # each epoch has a training and validation step
        epoch_acc = 0
        epoch_loss = 0
        
        # finetune
        if finetune:
            train_dataset = SurgicalSceneDataset(seq_set = seq['train_seq'], data_dir = seq['data_dir'], \
                                img_dir = seq['img_dir'], dset = seq['dset'], dataconst = data_const, \
                                feature_extractor = args.feature_extractor, reduce_size = True)
            val_dataset = SurgicalSceneDataset(seq_set = seq['val_seq'], dset = seq['dset'], data_dir = seq['data_dir'], \
                                img_dir = seq['img_dir'], dataconst = data_const, \
                                feature_extractor = args.feature_extractor, reduce_size = True)
            dataset = {'train': train_dataset, 'val': val_dataset}
            
            train_dataloader = DataLoader(dataset=dataset['train'], batch_size=args.batch_size, shuffle= True, collate_fn=collate_fn)
            val_dataloader = DataLoader(dataset=dataset['val'], batch_size=args.batch_size, shuffle= True, collate_fn=collate_fn)
            dataloader = {'train': train_dataloader, 'val': val_dataloader}
        
        # new domain
        if new_domain:
            model_old = copy.deepcopy(model)
            
            # distillation loss activation
            dist_loss_act = nn.Softmax(dim=1)
            dist_loss_act = dist_loss_act.to(device)
            
            dis_seq = seq['train_seq'][:-1]
            dis_data_dir = seq['data_dir'][:-1]
            dis_img_dir = seq['img_dir'][:-1]
            dis_dset = seq['dset'][:-1]
            dis_train_dataset = SurgicalSceneDataset(seq_set =  dis_seq, data_dir = dis_data_dir, \
                                    img_dir = dis_img_dir, dset = dis_dset, dataconst = data_const, \
                                    feature_extractor = args.feature_extractor, reduce_size = False)
            dis_train_dataloader = DataLoader(dataset=dis_train_dataset, batch_size=args.batch_size, shuffle= True, collate_fn=collate_fn)
        
        # build optimizer  
        if finetune: lrc = args.lr / 10
        else: lrc = args.lr
        if args.optim == 'sgd': 
            optimizer = optim.SGD(model.parameters(), lr= lrc, momentum=0.9, weight_decay=0)
        else: 
            optimizer = optim.Adam(model.parameters(), lr= lrc, weight_decay=0)
        
        for phase in ['train', 'val']:
            
            start_time = time.time()
            
            idx = 0
            running_acc = 0.0
            running_loss = 0.0
            running_edge_count = 0
            
            if phase == 'train' and args.use_cbs:
                model.grnn1.gnn.apply_h_h_edge.get_new_kernels(epoch)
                model.to(device)
            
            #for data in tqdm(dataloader[phase]):
            for data in dataloader[phase]:
                train_data = data
                img_name = train_data['img_name']
                img_loc = train_data['img_loc']
                node_num = train_data['node_num']
                roi_labels = train_data['roi_labels']
                det_boxes = train_data['det_boxes']
                edge_labels = train_data['edge_labels']
                edge_num = train_data['edge_num']
                features = train_data['features']
                spatial_feat = train_data['spatial_feat']
                word2vec = train_data['word2vec']
                features, spatial_feat, word2vec, edge_labels = features.to(device), spatial_feat.to(device), word2vec.to(device), edge_labels.to(device)    
                
                if phase == 'train':
                    model.train()
                    model.zero_grad()
                    outputs = model(node_num, features, spatial_feat, word2vec, roi_labels)
                    
                    # loss and accuracy
                    loss = criterion(outputs, edge_labels.float())
                    loss.backward()
                    optimizer.step()
                    acc = np.sum(np.equal(np.argmax(outputs.cpu().data.numpy(), axis=-1), np.argmax(edge_labels.cpu().data.numpy(), axis=-1)))

                else:
                    model.eval()
                    # turn off the gradients for validation, save memory and computations
                    with torch.no_grad():
                        outputs = model(node_num, features, spatial_feat, word2vec, roi_labels, validation=True)
                        
                        # loss and accuracy
                        loss = criterion(outputs, edge_labels.float())
                        acc = np.sum(np.equal(np.argmax(outputs.cpu().data.numpy(), axis=-1), np.argmax(edge_labels.cpu().data.numpy(), axis=-1)))
                    
                        # print result every 1000 iteration during validation
                        if idx == 10:
                            #print(img_loc[0])
                            io.mkdir_if_not_exists(os.path.join(args.output_img_dir, ('epoch_'+str(epoch))), recursive=True)
                            image = Image.open(img_loc[0]).convert('RGB')
                            det_actions = nn.Sigmoid()(outputs[0:int(edge_num[0])])
                            det_actions = det_actions.cpu().detach().numpy()
                            action_img = vis_img(image, roi_labels[0], det_boxes[0],  det_actions, score_thresh = 0.7)
                            image = image.save(os.path.join(args.output_img_dir, ('epoch_'+str(epoch)),img_name[0]))

                idx+=1
                # accumulate loss of each batch
                running_loss += loss.item() * edge_labels.shape[0]
                running_acc += acc
                running_edge_count += edge_labels.shape[0]
            
            # distillation learning
            if phase == 'train' and new_domain:
                
                #for data in tqdm(dataloader[phase]):
                for data in dis_train_dataloader:
                    train_data = data
                    img_name = train_data['img_name']
                    img_loc = train_data['img_loc']
                    node_num = train_data['node_num']
                    roi_labels = train_data['roi_labels']
                    det_boxes = train_data['det_boxes']
                    edge_labels = train_data['edge_labels']
                    edge_num = train_data['edge_num']
                    features = train_data['features']
                    spatial_feat = train_data['spatial_feat']
                    word2vec = train_data['word2vec']
                    features, spatial_feat, word2vec, edge_labels = features.to(device), spatial_feat.to(device), word2vec.to(device), edge_labels.to(device)    
                    
                    model.train()
                    model.zero_grad()
                    outputs = model(node_num, features, spatial_feat, word2vec, roi_labels)
                    
                    with torch.no_grad():
                        # old network output
                        output_old = model_old(data)
                        output_old = Variable(output_old)
                    
                    loss = 0.5* F.binary_cross_entropy(dist_loss_act(outputs), dist_loss_act(output_old))
                    # loss and accuracy
                    loss.backward()
                    optimizer.step()
            
            # calculate the loss and accuracy of each epoch
            epoch_loss = running_loss / len(dataset[phase])
            epoch_acc = running_acc / running_edge_count
            
            # import ipdb; ipdb.set_trace()
            # log trainval datas, and visualize them in the same graph
            if phase == 'train':
                train_loss = epoch_loss 
            else:
                writer.add_scalars('trainval_loss_epoch', {'train': train_loss, 'val': epoch_loss}, epoch)
            
            # print data
            if (epoch % args.print_every) == 0:
                end_time = time.time()
                print("[{}] Epoch: {}/{} Acc: {:0.6f} Loss: {:0.6f} Execution time: {:0.6f}".format(\
                        phase, epoch+1, args.epoch, epoch_acc, epoch_loss, (end_time-start_time)))
                        
        # scheduler.step()
        # save model
        if epoch_loss<0.0405 or epoch % args.save_every == (args.save_every - 1) and epoch >= (50-1):
            checkpoint = { 
                            'lr': args.lr,
                           'b_s': args.batch_size,
                          'bias': args.bias, 
                            'bn': args.bn, 
                       'dropout': args.drop_prob,
                        'layers': args.layers,
                    'multi_head': args.multi_attn,
                     'diff_edge': args.diff_edge,
                    'state_dict': model.state_dict()
            }
            save_name = "checkpoint_" + str(epoch+1) + '_epoch.pth'
            torch.save(checkpoint, os.path.join(args.save_dir, args.exp_ver, 'epoch_train', save_name))

    writer.close()
    print('Finishing training!')


SyntaxError: invalid syntax (<ipython-input-1-06fe35434ea8>, line 63)

In [None]:
def seed_everything(seed=27):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
if __name__ == "__main__":
    
    seed_everything()
    args = arguments()
    print(args.feature_extractor)
    data_const = SurgicalSceneConstants()
    run_model(args, data_const)

In [9]:
a = [[2,3,4,6,7,9,10,11,12,14,15], [30,40,35]]
a[:-1]

[[2, 3, 4, 6, 7, 9, 10, 11, 12, 14, 15]]

In [10]:
a = [1,2,3]
a.append(4)

In [11]:
a

[1, 2, 3, 4]