In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.transforms import Compose
import math
import copy
import random
import csv

import sys
sys.path.append('./backbones/asrf')
from libs import models
from libs.optimizer import get_optimizer
from libs.dataset import ActionSegmentationDataset, collate_fn
from libs.transformer import TempDownSamp, ToTensor

sys.path.append('./backbones/ms-tcn')
from model import MultiStageModel

sys.path.append('./backbones')
sys.path.append('./backbones/SSTDA')
from SSTDA.model import MultiStageModel as MSM_SSTDA

from src.utils import eval_txts, load_meta
from src.predict import predict_refiner
from src.refiner_train import refiner_train
from src.refiner_model import RefinerModel
from src.mgru import mGRU

import configs.refiner_config as cfg
import configs.sstda_config as sstda_cfg

In [2]:
# Seed every RNG the notebook touches from one constant so a rerun after
# "Restart & Run All" reproduces the same sequence of random draws.
SEED = 0
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Ask cuDNN for deterministic kernels and make sure its auto-tuner is off:
# with benchmark mode on, cuDNN may pick a different algorithm from run to run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Use the GPU when one is visible; otherwise fall back to the CPU so the
# notebook does not fail at the first `.to(device)` on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
# ---- Experiment selection ----
# dataset: one of gtea, 50salads, breakfast
# split:   gtea : 1~4, 50salads : 1~5, breakfast : 1~4
dataset = '50salads'
split = 3

# Backbones whose predictions are handed to refiner_train ('asrf', 'mstcn',
# 'sstda', 'mgru') and the single backbone used at prediction time.
pool_backbone_name = ['mstcn']
main_backbone_name = 'mstcn'

# e.g. 'refinerMSTCN-mstcn'; this name is passed to load_meta and eval_txts.
model_name = 'refiner{}-{}'.format(main_backbone_name.upper(),
                                   '-'.join(pool_backbone_name))

In [5]:
# Resolve everything that depends on (dataset, split, model_name) in one call.
# load_meta returns the ten values unpacked below, in this order.
# NOTE(review): judging by the recorded output it also creates the model /
# result / record directories -- confirm in src.utils.
(
    actions_dict,
    num_actions,
    gt_path,
    features_path,
    vid_list_file,
    vid_list_file_tst,
    sample_rate,
    model_dir,
    result_dir,
    record_dir,
) = load_meta(
    cfg.dataset_root, cfg.model_root, cfg.result_root, cfg.record_root,
    dataset, split, model_name,
)

Created :./model/refinerASRF-sstda-mstcn-mgru/50salads/split_3
Created :./result/refinerASRF-sstda-mstcn-mgru/50salads/split_3
Created :./record/refinerASRF-sstda-mstcn-mgru/50salads


In [6]:
# Training data for the chosen split: samples are converted to tensors and
# then temporally down-sampled by `sample_rate`.
train_transform = Compose([ToTensor(), TempDownSamp(sample_rate)])
train_data = ActionSegmentationDataset(
    dataset,
    transform=train_transform,
    mode="trainval",
    split=split,
    dataset_dir=cfg.dataset_root,
    csv_dir=cfg.csv_dir,
)

# Shuffled loader; the trailing incomplete batch is dropped only when
# batching more than one sample at a time.
train_loader = DataLoader(
    train_data,
    batch_size=cfg.batch_size,
    shuffle=True,
    drop_last=cfg.batch_size > 1,
    collate_fn=collate_fn,
    pin_memory=True,
)

In [7]:
# Build {split index (1-based): [video name, ...]} from the per-split test
# bundles found in <dataset_root>/<dataset>/splits/test.split<k>.bundle.
curr_split_dir = os.path.join(cfg.dataset_root, dataset, 'splits')
num_splits = cfg.num_splits[dataset]

split_dict = {}
for split_idx in range(1, num_splits + 1):
    curr_fp = os.path.join(curr_split_dir, 'test.split{}.bundle'.format(split_idx))
    # `with` closes the bundle even if reading fails part-way through.
    with open(curr_fp, 'r') as f:
        # Keep the part of each entry before the first '.', i.e. the video
        # name without its extension. Lines are stripped first so that blank
        # lines are skipped and no trailing newline ends up in a name.
        split_dict[split_idx] = [
            line.strip().split('.')[0] for line in f if line.strip()
        ]
print(split_dict)

{1: ['rgb-01-1', 'rgb-01-2', 'rgb-02-1', 'rgb-02-2', 'rgb-03-1', 'rgb-03-2', 'rgb-04-1', 'rgb-04-2', 'rgb-05-1', 'rgb-05-2'], 2: ['rgb-06-1', 'rgb-06-2', 'rgb-07-1', 'rgb-07-2', 'rgb-09-1', 'rgb-09-2', 'rgb-10-1', 'rgb-10-2', 'rgb-11-1', 'rgb-11-2'], 3: ['rgb-13-1', 'rgb-13-2', 'rgb-14-1', 'rgb-14-2', 'rgb-15-1', 'rgb-15-2', 'rgb-16-1', 'rgb-16-2', 'rgb-17-1', 'rgb-17-2'], 4: ['rgb-18-1', 'rgb-18-2', 'rgb-19-1', 'rgb-19-2', 'rgb-20-1', 'rgb-20-2', 'rgb-21-1', 'rgb-21-2', 'rgb-22-1', 'rgb-22-2'], 5: ['rgb-23-1', 'rgb-23-2', 'rgb-24-1', 'rgb-24-2', 'rgb-25-1', 'rgb-25-2', 'rgb-26-1', 'rgb-26-2', 'rgb-27-1', 'rgb-27-2']}


In [8]:
def _load_pretrained(net, backbone, split_idx):
    """Load the selected checkpoint of ``backbone`` for ``split_idx`` into ``net``.

    The checkpoint is read from
    ``<cfg.model_root>/<backbone>/<dataset>/split_<split_idx>/epoch-<E>.model``
    with ``E = cfg.best[backbone][dataset][split_idx - 1]`` (``cfg.best`` is
    indexed from 0 while splits are numbered from 1).

    Args:
        net: freshly constructed backbone module to receive the weights.
        backbone: backbone key, e.g. ``'asrf'`` or ``'mstcn'``.
        split_idx: 1-based split number.

    Returns:
        ``net`` itself, after it has been moved to ``device``.
    """
    ckpt_path = os.path.join(
        cfg.model_root, backbone, dataset,
        'split_{}'.format(split_idx),
        'epoch-{}.model'.format(cfg.best[backbone][dataset][split_idx - 1]),
    )
    # Deserialize onto the CPU first: without map_location, torch.load tries
    # to restore tensors onto the device they were saved from, which fails
    # when that device does not exist on this machine.
    net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    net.to(device)
    return net


# pool_backbones[backbone][split] -> pretrained backbone for that split.
num_splits = cfg.num_splits[dataset]
pool_backbones = {
    bn: {k: None for k in range(1, num_splits + 1)}
    for bn in cfg.backbone_names
}

for split_idx in range(1, num_splits + 1):
    if 'asrf' in cfg.backbone_names:
        curr_asrf = models.ActionSegmentRefinementFramework(
            in_channel=cfg.in_channel,
            n_features=cfg.n_features,
            n_classes=num_actions,
            n_stages=cfg.n_stages,
            n_layers=cfg.n_layers,
            n_stages_asb=cfg.n_stages_asb,
            n_stages_brb=cfg.n_stages_brb,
        )
        pool_backbones['asrf'][split_idx] = _load_pretrained(curr_asrf, 'asrf', split_idx)

    if 'mstcn' in cfg.backbone_names:
        curr_mstcn = MultiStageModel(cfg.num_stages,
                                     num_layers=cfg.num_layers,
                                     num_f_maps=cfg.num_f_maps,
                                     dim=cfg.features_dim,
                                     num_classes=num_actions)
        pool_backbones['mstcn'][split_idx] = _load_pretrained(curr_mstcn, 'mstcn', split_idx)

    if 'sstda' in cfg.backbone_names:
        curr_sstda = MSM_SSTDA(sstda_cfg, num_actions)
        pool_backbones['sstda'][split_idx] = _load_pretrained(curr_sstda, 'sstda', split_idx)

    if 'mgru' in cfg.backbone_names:
        curr_mgru = mGRU(num_layers=cfg.gru_layers,
                         feat_dim=cfg.gru_hidden_dim,
                         inp_dim=cfg.in_channel,
                         out_dim=num_actions)
        pool_backbones['mgru'][split_idx] = _load_pretrained(curr_mgru, 'mgru', split_idx)

# The pool is keyed by cfg.backbone_names, so the main backbone has to be one
# of them; fail with an explanatory message rather than a bare KeyError.
if main_backbone_name not in pool_backbones:
    raise KeyError(
        "main_backbone_name {!r} is not listed in cfg.backbone_names {!r}".format(
            main_backbone_name, list(cfg.backbone_names)))

# Independent copy of the main backbone's per-split models for prediction.
main_backbones = copy.deepcopy(pool_backbones[main_backbone_name])

In [9]:
# Construct the refiner from the config values and move it to `device`.
refiner_kwargs = dict(
    num_actions=num_actions,
    input_dim=cfg.features_dim,
    feat_dim=cfg.hidden_dim,
    num_highlevel_frames=cfg.num_highlevel_frames,
    num_highlevel_samples=cfg.num_highlevel_samples,
    device=device,
)
model = RefinerModel(**refiner_kwargs)
# Last expression of the cell, so the notebook displays the module summary.
model.to(device)

RefinerModel(
  (key_embedding): Linear(in_features=2048, out_features=512, bias=False)
  (value_embedding): Linear(in_features=2048, out_features=512, bias=False)
  (query_embedding): Embedding(19, 512)
  (label_embedding): Embedding(19, 512)
  (video_embedding): SparseSampleEmbedder(
    (init_conv): Conv1d(1024, 512, kernel_size=(3,), stride=(1,), padding=(1,))
    (blocks): ModuleList(
      (0): ResidualBlock(
        (block): Sequential(
          (0): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
          (1): LeakyReLU(negative_slope=0.1)
          (2): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
        )
      )
      (1): ResidualBlock(
        (block): Sequential(
          (0): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
          (1): LeakyReLU(negative_slope=0.1)
          (2): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
        )
      )
      (2): ResidualBlock(
        (block): Sequential(
        

In [10]:
# Adam over every refiner parameter; learning rate and weight decay come
# from the refiner config.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=cfg.lr,
    weight_decay=cfg.weight_decay,
)

In [11]:
# Train for cfg.max_epoch epochs. After each epoch the refiner weights are
# written to <model_dir>/epoch-<n>.model (n starts at 1) and the epoch number,
# learning rate and training loss are printed.
for epoch in range(cfg.max_epoch):
    epoch_no = epoch + 1
    train_loss = refiner_train(
        cfg, dataset, train_loader, model, pool_backbones,
        pool_backbone_name, optimizer, epoch, split_dict, device,
    )

    ckpt_path = os.path.join(model_dir, "epoch-{}.model".format(epoch_no))
    torch.save(model.state_dict(), ckpt_path)

    current_lr = optimizer.param_groups[0]["lr"]
    print("epoch: {}\tlr: {:.5f}\ttrain loss: {:.4f}".format(epoch_no, current_lr, train_loss))

	nonzero()
Consider using one of the following signatures instead:
	nonzero(*, bool as_tuple) (Triggered internally at  /opt/conda/conda-bld/pytorch_1607370156314/work/torch/csrc/utils/python_arg_parser.cpp:882.)
  label_d = label_d_onehot.nonzero()  # e.g. [0, 0, 1, 0, 0, 0] --> [[2]]


epoch: 1	lr: 0.00010	train loss: 2.2317
epoch: 2	lr: 0.00010	train loss: 1.4934
epoch: 3	lr: 0.00010	train loss: 1.3393
epoch: 4	lr: 0.00010	train loss: 1.3462
epoch: 5	lr: 0.00010	train loss: 1.1368
epoch: 6	lr: 0.00010	train loss: 1.1436
epoch: 7	lr: 0.00010	train loss: 1.1488
epoch: 8	lr: 0.00010	train loss: 1.1867
epoch: 9	lr: 0.00010	train loss: 1.0095
epoch: 10	lr: 0.00010	train loss: 0.9046
epoch: 11	lr: 0.00010	train loss: 1.0052
epoch: 12	lr: 0.00010	train loss: 0.9278
epoch: 13	lr: 0.00010	train loss: 0.8782
epoch: 14	lr: 0.00010	train loss: 0.8771
epoch: 15	lr: 0.00010	train loss: 0.9140
epoch: 16	lr: 0.00010	train loss: 0.8384
epoch: 17	lr: 0.00010	train loss: 0.9007
epoch: 18	lr: 0.00010	train loss: 0.7719
epoch: 19	lr: 0.00010	train loss: 0.8814
epoch: 20	lr: 0.00010	train loss: 0.8784
epoch: 21	lr: 0.00010	train loss: 0.7403
epoch: 22	lr: 0.00010	train loss: 0.7050
epoch: 23	lr: 0.00010	train loss: 0.6852
epoch: 24	lr: 0.00010	train loss: 0.6705
epoch: 25	lr: 0.00010	tra

In [12]:
# Evaluate every saved epoch, log its metrics to a tab-separated file and
# remember the epoch whose summed metrics are highest.
max_epoch = -1        # stays -1 if no epoch is evaluated
max_val = 0.0
max_results = dict()

# The three IoU thresholds reported as F1 columns.
f1_thresholds = [cfg.iou_thresholds[k] for k in range(3)]

# `with` guarantees the record file is closed even if prediction or
# evaluation raises; newline='' lets the csv module control line endings.
with open(os.path.join(record_dir, 'split_{}_all.csv'.format(split)), 'w', newline='') as f:
    writer = csv.writer(f, delimiter='\t')
    writer.writerow(['epoch', 'accu', 'edit']
                    + ['F1@{}'.format(thr) for thr in f1_thresholds])

    for epoch in range(1, cfg.max_epoch+1):
        print('======================EPOCH {}====================='.format(epoch))
        predict_refiner(model, main_backbone_name, main_backbones,
                        split_dict, model_dir, result_dir,
                        features_path, vid_list_file_tst,
                        epoch, actions_dict, device, sample_rate)
        results = eval_txts(cfg.dataset_root, result_dir, dataset, split, model_name)

        writer.writerow([epoch, '%.4f' % results['accu'], '%.4f' % results['edit']]
                        + ['%.4f' % results['F1@%0.2f' % thr] for thr in f1_thresholds])

        # Model selection criterion: the sum of every metric eval_txts
        # returned. On a tie the later epoch wins (same as comparing against
        # the running maximum with ==).
        curr_val = sum(results.values())
        if curr_val >= max_val:
            max_val = curr_val
            max_epoch = epoch
            max_results = results

print('EARNED MAXIMUM PERFORMANCE IN EPOCH {}'.format(max_epoch))
print(max_results)

Acc: 75.3346
Edit: 71.2017
F1@0.10: 76.8000
F1@0.25: 75.2000
F1@0.50: 63.4667
Acc: 75.0800
Edit: 72.6440
F1@0.10: 77.5401
F1@0.25: 74.8663
F1@0.50: 64.7059
Acc: 74.6324
Edit: 70.5718
F1@0.10: 76.3441
F1@0.25: 74.1935
F1@0.50: 62.9032
Acc: 79.4462
Edit: 74.4747
F1@0.10: 80.6283
F1@0.25: 79.0576
F1@0.50: 70.1571
Acc: 80.2548
Edit: 75.5770
F1@0.10: 82.2281
F1@0.25: 80.6366
F1@0.50: 71.6180
Acc: 76.3139
Edit: 73.3886
F1@0.10: 78.4000
F1@0.25: 76.8000
F1@0.50: 66.6667
Acc: 80.1657
Edit: 76.5148
F1@0.10: 81.1370
F1@0.25: 79.5866
F1@0.50: 73.3850
Acc: 81.8077
Edit: 78.6762
F1@0.10: 83.2041
F1@0.25: 82.6873
F1@0.50: 75.4522
Acc: 76.5497
Edit: 73.2721
F1@0.10: 78.1915
F1@0.25: 77.1277
F1@0.50: 68.0851
Acc: 76.7658
Edit: 71.9766
F1@0.10: 79.0451
F1@0.25: 76.3926
F1@0.50: 68.4350
Acc: 78.9446
Edit: 74.5821
F1@0.10: 80.1047
F1@0.25: 78.5340
F1@0.50: 68.5864
Acc: 76.9570
Edit: 70.7644
F1@0.10: 76.7568
F1@0.25: 75.1351
F1@0.50: 63.7838
Acc: 80.3912
Edit: 74.9900
F1@0.10: 80.9399
F1@0.25: 80.4178
F1@

In [13]:
# Write the best epoch's metrics (selected in the evaluation cell) to
# <record_dir>/split_<split>_best.csv as a header row plus one data row.
# `with` closes the file even if a metric key is missing from max_results;
# newline='' lets the csv module control line endings.
with open(os.path.join(record_dir, 'split_{}_best.csv'.format(split)), 'w', newline='') as f:
    writer = csv.writer(f, delimiter='\t')
    writer.writerow(['epoch', 'accu', 'edit',
                     'F1@{}'.format(cfg.iou_thresholds[0]),
                     'F1@{}'.format(cfg.iou_thresholds[1]),
                     'F1@{}'.format(cfg.iou_thresholds[2])])
    writer.writerow([max_epoch, '%.4f'%(max_results['accu']), '%.4f'%(max_results['edit']),
                     '%.4f'%(max_results['F1@%0.2f'%(cfg.iou_thresholds[0])]),
                     '%.4f'%(max_results['F1@%0.2f'%(cfg.iou_thresholds[1])]),
                     '%.4f'%(max_results['F1@%0.2f'%(cfg.iou_thresholds[2])])])