In [2]:
# Imports
from __future__ import print_function

import os
import time

import dgl
import networkx as nx
import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from sklearn.model_selection import train_test_split

import ipdb
import h5py
import pickle
import argparse
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import random
import datetime

import utils.io as io
from model.cnn_model import HOCNN
from datasets import metadata
from utils.vis_tool import vis_img
from datasets.hico_constants import HicoConstants
from datasets.hico_dataset import HicoDataset, collate_fn

import json
import cv2
import numpy as np
from matplotlib import pyplot as plt

%matplotlib inline

In [3]:
# Set random seed for reproducibility
torch.manual_seed(21)
np.random.seed(21)

In [4]:
# Define data paths
TRAIN_IMG_PATH = "datasets/hico/images/train2015/"
TEST_IMG_PATH = "datasets/hico/images/test2015/"

In [5]:
# Setup training device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('training on {}'.format(device))

training on cuda


In [6]:
# Define arguments
batch_size = 32
epochs = 80
initial_lr = 0.00001

feat_type = 'fc7'
data_aug = False
exp_ver = 'v1_' + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
save_dir = './checkpoints/hico'
log_dir = './log/hico'
save_every = 5
print_batch_every = 100
print_epoch_every = 1

# set the cache size [0 means infinite]
max_img_cache_size = 5000

print('Running experiment ' + exp_ver)

Running experiment v1_2020-11-03_17-35


In [7]:
# Define dataloaders
data_const = HicoConstants(feat_type=feat_type)

train_dataset = HicoDataset(data_const=data_const, subset='train', data_aug=data_aug)
val_dataset = HicoDataset(data_const=data_const, subset='val', data_aug=False, test=True)
dataset = {'train': train_dataset, 'val': val_dataset}
print('set up dataset variable successfully')

train_dataloader = DataLoader(dataset=dataset['train'], batch_size=batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True)
val_dataloader = DataLoader(dataset=dataset['val'], batch_size=batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True)
dataloader = {'train': train_dataloader, 'val': val_dataloader}
print('set up dataloader successfully')

Using fc7 feature...
Using fc7 feature...
set up dataset variable successfully
set up dataloader successfully


In [8]:
# Define model
model = HOCNN().to(device)

In [9]:
# Display parameter information
parameter_num = 0
for param in model.parameters():
    parameter_num += param.numel()
print(f'The number of parameters in this model is {parameter_num / 1e6} million')

The number of parameters in this model is 49.470376 million


In [10]:
# Define loss function
criterion = nn.CrossEntropyLoss()

In [11]:
# Define optimizer
optimizer = optim.Adam(model.parameters(), lr=initial_lr, weight_decay=0)
#optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0)

In [12]:
# Setup visualization
writer = SummaryWriter(log_dir=log_dir + '/' + exp_ver + '/' + 'epoch_train')
io.mkdir_if_not_exists(os.path.join(save_dir, exp_ver, 'epoch_train'), recursive=True)

In [None]:
# Training loop
with open('datasets/processed/hico/anno_list.json') as f:
    anno_list = json.load(f)
    
with open('datasets/processed/hico/hoi_list.json') as f:
    hoi_list = json.load(f)
    
img_cache = {} # format {key: [human, object, pairwise]}
img_cache_counter = 0
    
print('Training has started!')
    
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0
    for phase in ['train', 'val']:
        start_time = time.time()
        running_loss = 0.0
        running_correct = 0
        idx = 0
        
        for data in dataloader[phase]:
            train_data = data
            img_name = train_data['img_name']
            
            labels = np.zeros((batch_size, 600))
            batch_correct = 0
            for i in range(batch_size):
                # Get image data
                parsed_img_name = img_name[i].split(".")[0]
                img = [x for x in anno_list if x['global_id'] == parsed_img_name][0]
                img = img['hois'][0]
                img_id = int(img['id']) - 1
                labels[i][img_id] = 1
                human_bboxes = img['human_bboxes']
                object_bboxes = img['object_bboxes']

                # Apply masks to images [with caching]
                src_img_path = TRAIN_IMG_PATH + parsed_img_name + '.jpg'
                if src_img_path in img_cache: # Use cache if available
                    human_bbox_img, obj_bbox_img, pairwise_bbox_img = img_cache[src_img_path]
                else:
                    src = cv2.imread(src_img_path)
                    human_mask = np.zeros_like(src)
                    for bbox in human_bboxes:
                        cv2.rectangle(human_mask, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255, 255, 255), thickness=-1)
                    human_bbox_img = cv2.bitwise_and(src, human_mask, mask=None)

                    obj_mask = np.zeros_like(src)
                    pairwise_mask = human_mask
                    for bbox in object_bboxes:
                        cv2.rectangle(obj_mask, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255, 255, 255), thickness=-1)
                        cv2.rectangle(pairwise_mask, (bbox[0], bbox[1]), (bbox[2], bbox[3]), (255, 255, 255), thickness=-1)
                    obj_bbox_img = cv2.bitwise_and(src, obj_mask, mask=None)
                    pairwise_bbox_img = cv2.bitwise_and(src, pairwise_mask, mask=None)
                    
                    # Add to cache if within limits
                    if not max_img_cache_size or img_cache_counter < max_img_cache_size:
                        img_cache[src_img_path] = [human_bbox_img, obj_bbox_img, pairwise_bbox_img]
                        img_cache_counter += 1

                '''
                # Visualization of masks
                f, axarr = plt.subplots(1,3)

                human_bbox_rgb = cv2.cvtColor(human_bbox_img, cv2.COLOR_BGR2RGB)
                axarr[0].imshow(human_bbox_rgb)

                object_mask_rgb = cv2.cvtColor(obj_bbox_img, cv2.COLOR_BGR2RGB)
                axarr[1].imshow(object_mask_rgb)

                pairwise_rgb = cv2.cvtColor(pairwise_bbox_img, cv2.COLOR_BGR2RGB)
                axarr[2].imshow(pairwise_rgb)

                plt.show()
                f.clf()
                '''

                human_bbox_img = cv2.resize(human_bbox_img, (64, 64), interpolation=cv2.INTER_AREA)
                obj_bbox_img = cv2.resize(obj_bbox_img, (64, 64), interpolation=cv2.INTER_AREA)
                pairwise_bbox_img = cv2.resize(pairwise_bbox_img, (64, 64), interpolation=cv2.INTER_AREA)

                human_bbox_img = torch.from_numpy(human_bbox_img).to(device)
                obj_bbox_img = torch.from_numpy(obj_bbox_img).to(device)
                pairwise_bbox_img = torch.from_numpy(pairwise_bbox_img).to(device)

                if i == 0:
                    res_human_input = human_bbox_img.unsqueeze(0)
                    res_obj_input = obj_bbox_img.unsqueeze(0)
                    res_pairwise_input = pairwise_bbox_img.unsqueeze(0)
                else:
                    res_human_input = torch.cat((res_human_input, human_bbox_img.unsqueeze(0)), dim=0)
                    res_obj_input = torch.cat((res_obj_input, obj_bbox_img.unsqueeze(0)), dim=0)
                    res_pairwise_input = torch.cat((res_pairwise_input, pairwise_bbox_img.unsqueeze(0)), dim=0)

            res_human_input = res_human_input.permute([0,3,1,2]).float().to(device)
            res_obj_input = res_obj_input.permute([0,3,1,2]).float().to(device)
            res_pairwise_input = res_pairwise_input.permute([0,3,1,2]).float().to(device)
            labels = torch.from_numpy(labels).long().to(device)
            
            if phase == 'train':
                # Initial train loop
                model.train()
                model.zero_grad()
                
                # Forward pass: human, objects, pairwise streams
                outputs = model.forward(res_human_input, res_obj_input, res_pairwise_input)
                loss = criterion(outputs, torch.max(labels, 1)[1])
                loss.backward()
                optimizer.step()
                
                preds = torch.argmax(outputs, dim=1)
                ground_labels = torch.max(labels, 1)[1]
                for accuracy_iterator in range(len(ground_labels)):
                    if preds[accuracy_iterator] == ground_labels[accuracy_iterator]:
                        batch_correct += 1
                
            else:
                # Evaluation after train loop
                model.eval()
                with torch.no_grad(): # Disable gradients for validation
                    outputs = model.forward(res_human_input, res_obj_input, res_pairwise_input)
                    loss = criterion(outputs, torch.max(labels, 1)[1])
                    
                    preds = torch.argmax(outputs, dim=1)
                    ground_labels = torch.max(labels, 1)[1]
                    for accuracy_iterator in range(len(ground_labels)):
                        if preds[accuracy_iterator] == ground_labels[accuracy_iterator]:
                            batch_correct += 1
                    
            # Accumulate loss of each batch (average * batch size)
            running_loss += loss.item() * batch_size
            running_correct += batch_correct
            
            # Print out status per print_batch_every
            idx += 1
            if (idx % print_batch_every) == 0:
                print("[{}] Epoch: {}/{} Batch: {}/{} Loss: {} Accuracy: {}".format(\
                        phase, epoch+1, epochs, idx, len(dataloader[phase]), \
                        loss.item(), 100 * batch_correct / batch_size))
            
        # Epoch loss and accuracy
        epoch_loss = running_loss / len(dataset[phase])
        epoch_accuracy = 100 * running_correct / len(dataset[phase])
        
        # Log trainval data for visualization
        if phase == 'train':
            train_loss = epoch_loss 
            train_accuracy = epoch_accuracy
        else:
            writer.add_scalars('trainval_loss_epoch', {'train': train_loss, 'val': epoch_loss}, epoch)
            writer.add_scalars('trainval_accuracy_epoch', {'train': train_accuracy, 'val': epoch_accuracy}, epoch)
            
        # Output data per print_epoch_every
        if (epoch % print_epoch_every) == 0:
            end_time = time.time()
            print("[{}] Epoch: {}/{} Loss: {} Accuracy: {} Execution time: {}".format(\
                    phase, epoch+1, epochs, epoch_loss, epoch_accuracy, (end_time-start_time)))
    
    # Save the model per save_every
    if epoch_loss<0.0405 or epoch % save_every == (save_every - 1) and epoch >= (10-1):
        checkpoint = { 
                        'lr': initial_lr,
                       'b_s': batch_size,
                 'feat_type': feat_type,
                'state_dict': model.state_dict()
        }
        save_name = "checkpoint_" + str(epoch+1) + '_epoch.pth'
        torch.save(checkpoint, os.path.join(save_dir, exp_ver, 'epoch_train', save_name))
        
print('Finishing training!')

Training has started!


  return F.log_softmax(summed_results)


[train] Epoch: 1/80 Batch: 100/952 Loss: 6.461737632751465 Accuracy: 6.25
[train] Epoch: 1/80 Batch: 200/952 Loss: 5.453140735626221 Accuracy: 3.125
[train] Epoch: 1/80 Batch: 300/952 Loss: 4.965389728546143 Accuracy: 3.125
[train] Epoch: 1/80 Batch: 400/952 Loss: 5.457126617431641 Accuracy: 6.25
[train] Epoch: 1/80 Batch: 500/952 Loss: 5.385848522186279 Accuracy: 0.0
[train] Epoch: 1/80 Batch: 600/952 Loss: 5.058994293212891 Accuracy: 0.0
[train] Epoch: 1/80 Batch: 700/952 Loss: 4.646044731140137 Accuracy: 12.5
[train] Epoch: 1/80 Batch: 800/952 Loss: 4.414399147033691 Accuracy: 6.25
[train] Epoch: 1/80 Batch: 900/952 Loss: 4.728868007659912 Accuracy: 15.625
[train] Epoch: 1/80 Loss: 5.7592109216082426 Accuracy: 5.361534677816035 Execution time: 1173.434383392334
[val] Epoch: 1/80 Batch: 100/238 Loss: 5.014231204986572 Accuracy: 6.25
[val] Epoch: 1/80 Batch: 200/238 Loss: 4.594542026519775 Accuracy: 12.5
[val] Epoch: 1/80 Loss: 4.8164817823069495 Accuracy: 8.985963531418077 Execution 

[train] Epoch: 9/80 Batch: 100/952 Loss: 3.124918222427368 Accuracy: 18.75
[train] Epoch: 9/80 Batch: 200/952 Loss: 3.582306385040283 Accuracy: 25.0
[train] Epoch: 9/80 Batch: 300/952 Loss: 3.370269536972046 Accuracy: 25.0
[train] Epoch: 9/80 Batch: 400/952 Loss: 3.2232918739318848 Accuracy: 28.125
[train] Epoch: 9/80 Batch: 500/952 Loss: 2.864689826965332 Accuracy: 34.375
[train] Epoch: 9/80 Batch: 600/952 Loss: 3.630892038345337 Accuracy: 28.125
[train] Epoch: 9/80 Batch: 700/952 Loss: 3.6716930866241455 Accuracy: 18.75
[train] Epoch: 9/80 Batch: 800/952 Loss: 5.0678839683532715 Accuracy: 6.25
[train] Epoch: 9/80 Batch: 900/952 Loss: 3.3424129486083984 Accuracy: 28.125
[train] Epoch: 9/80 Loss: 3.563912565072143 Accuracy: 23.21364158058698 Execution time: 748.080039024353
[val] Epoch: 9/80 Batch: 100/238 Loss: 3.8234221935272217 Accuracy: 21.875
[val] Epoch: 9/80 Batch: 200/238 Loss: 3.478285789489746 Accuracy: 28.125
[val] Epoch: 9/80 Loss: 4.157655231786591 Accuracy: 16.97494424767

[val] Epoch: 16/80 Batch: 200/238 Loss: 5.720288276672363 Accuracy: 12.5
[val] Epoch: 16/80 Loss: 4.522883385572317 Accuracy: 16.096025186934277 Execution time: 192.76015639305115
[train] Epoch: 17/80 Batch: 100/952 Loss: 2.54832124710083 Accuracy: 31.25
[train] Epoch: 17/80 Batch: 200/952 Loss: 2.7777369022369385 Accuracy: 37.5
[train] Epoch: 17/80 Batch: 300/952 Loss: 2.6741151809692383 Accuracy: 43.75
[train] Epoch: 17/80 Batch: 400/952 Loss: 2.316897392272949 Accuracy: 53.125
[train] Epoch: 17/80 Batch: 500/952 Loss: 2.1797444820404053 Accuracy: 50.0
[train] Epoch: 17/80 Batch: 600/952 Loss: 2.9349756240844727 Accuracy: 31.25
[train] Epoch: 17/80 Batch: 700/952 Loss: 3.059478759765625 Accuracy: 28.125
[train] Epoch: 17/80 Batch: 800/952 Loss: 2.764176607131958 Accuracy: 34.375
[train] Epoch: 17/80 Batch: 900/952 Loss: 2.2660984992980957 Accuracy: 46.875
[train] Epoch: 17/80 Loss: 2.5079845950728186 Accuracy: 40.370552549598294 Execution time: 757.4389278888702
[val] Epoch: 17/80 Ba

[train] Epoch: 24/80 Loss: 1.60745151467862 Accuracy: 59.350713231677325 Execution time: 751.4768993854523
[val] Epoch: 24/80 Batch: 100/238 Loss: 6.2420854568481445 Accuracy: 15.625
[val] Epoch: 24/80 Batch: 200/238 Loss: 4.485657691955566 Accuracy: 15.625
[val] Epoch: 24/80 Loss: 5.625930802146234 Accuracy: 14.088941361668635 Execution time: 192.47737288475037
[train] Epoch: 25/80 Batch: 100/952 Loss: 1.4848395586013794 Accuracy: 53.125
[train] Epoch: 25/80 Batch: 200/952 Loss: 0.7980847358703613 Accuracy: 87.5
[train] Epoch: 25/80 Batch: 300/952 Loss: 1.469244122505188 Accuracy: 65.625
[train] Epoch: 25/80 Batch: 400/952 Loss: 1.5138001441955566 Accuracy: 71.875
[train] Epoch: 25/80 Batch: 500/952 Loss: 1.4664453268051147 Accuracy: 62.5
[train] Epoch: 25/80 Batch: 600/952 Loss: 1.245239019393921 Accuracy: 68.75
[train] Epoch: 25/80 Batch: 700/952 Loss: 1.3772474527359009 Accuracy: 68.75
[train] Epoch: 25/80 Batch: 800/952 Loss: 1.3151147365570068 Accuracy: 56.25
[train] Epoch: 25/80

In [None]:
# Close visualization
writer.close()