In [112]:
import os
import pickle
import math

import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import datasets as da

from optimizers import AdamOptimizer
from optimizers.lr_schedulers import InverseSquareRootSchedule

from models.transformer import DualTransformer
from models.loss import ivc_loss, cal_nll_loss, rec_loss

import collections
from utils_helper import TimeMeter, AverageMeter
from utils_helper import load_json, top_1_metric, top_n_metric, move_to_cuda

In [103]:
# Global optimizer-step counter, threaded through every call to the training loop.
num_updates = 0

# Run configuration: config file to load, checkpoint sub-directory name, base RNG seed.
config_path = './config/charades/main.json'
tag = 'default_charades'
seed = 8

In [81]:
# Deterministic cuDNN kernels (no autotuning), then seed every RNG with a fixed
# offset from `seed`. torch.manual_seed also seeds CUDA, so the explicit CUDA
# seeds must come after it to take effect.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

random.seed(seed)
np.random.seed(seed + 1)
torch.manual_seed(seed + 2)
torch.cuda.manual_seed_all(seed + 4)
torch.cuda.manual_seed(seed + 4)

In [82]:
args = load_json(config_path)
# Nest this run's checkpoints under a sub-directory named after `tag`.
args['train'].update(model_saved_path=os.path.join(args['train']['model_saved_path'], tag))

# Build Dataset

In [83]:
# Build Dataset

# NOTE(review): pickle.load can execute arbitrary code; vocab_path must point at a trusted file.
with open(args['dataset']['vocab_path'], 'rb') as fp:
    vocab = pickle.load(fp)

dataset_name = args['dataset']['dataset']
cls = getattr(da, dataset_name, None)
if cls is None:
    # Fail with a clear message instead of "'NoneType' object is not callable" below.
    raise ValueError('Unknown dataset class {!r} in the `datasets` module'.format(dataset_name))
train_set = cls(data_path=args['dataset']['train_data'], vocab=vocab, args=args['dataset'], is_training=True, split='train')
test_set = cls(data_path=args['dataset']['test_data'], vocab=vocab, args=args['dataset'], split='test')
batch_size = args['train']['batch_size']

def worker_init_fn(worker_id):
    """Seed Python's and NumPy's RNGs inside a DataLoader worker process.

    Inside a worker, ``torch.initial_seed()`` is ``base_seed + worker_id``, where
    ``base_seed`` is drawn from the main process's torch RNG (seeded from ``seed``
    in the seeding cell) each time the loader starts iterating. Deriving the other
    RNGs from it keeps runs reproducible while giving every worker, in every
    epoch, its own random stream. A constant per-worker seed would replay the same
    random draws each epoch and would ignore the configured ``seed``.

    Args:
        worker_id: index of the worker; required by the DataLoader API and
            already folded into ``torch.initial_seed()``.
    """
    worker_seed = torch.initial_seed() % 2**32  # np.random.seed only accepts 32-bit seeds
    random.seed(worker_seed)
    np.random.seed(worker_seed)

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=train_set.collate_data, num_workers=2, worker_init_fn=worker_init_fn)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, collate_fn=test_set.collate_data, num_workers=0)

# The model config needs values only known after the dataset is built / from the train config.
args['model']['config']['vocab_size'] = train_set.vocab_size
args['model']['config']['max_epoch'] = args['train']['max_num_epochs']

# Build Model (CPL)

In [96]:
class SinusoidalPositionalEmbedding(nn.Module):
    """Produce sinusoidal positional embeddings of any length.

    The table is cached in ``self.weights``. It is a plain attribute, not a
    registered buffer, so ``model.cuda()`` does not move it. Instead, ``forward``
    moves it to the input's device lazily and grows it on demand. The row at
    ``padding_idx`` (if given) is all zeros.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build a ``(num_embeddings, embedding_dim)`` sinusoidal table.

        Columns ``[0, half)`` hold ``sin(pos * f_i)`` and ``[half, 2*half)`` hold
        ``cos(pos * f_i)`` with geometrically spaced frequencies ``f_i``; an odd
        ``embedding_dim`` gets one trailing zero column. This matches the
        tensor2tensor layout, which differs slightly from Section 3.5 of
        "Attention Is All You Need" (sin/cos are concatenated, not interleaved).
        """
        half_dim = embedding_dim // 2
        # max(..., 1) avoids ZeroDivisionError for embedding_dim in {2, 3}: with a
        # single frequency the exponent below is 0 regardless of the scale.
        scale = math.log(10000) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, dtype=torch.float) * -scale)
        angles = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, **kwargs):
        """Return positional embeddings for ``input``.

        Args:
            input: 3-D tensor ``(bsz, seq_len, _)``; only ``seq_len`` and its
                device are used.

        Returns:
            Tensor of shape ``(1, seq_len, embedding_dim)`` on ``input.device``,
            meant to broadcast over the batch dimension.
        """
        _, seq_len, _ = input.size()
        if self.weights is None or seq_len > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                seq_len,
                self.embedding_dim,
                self.padding_idx,
            )
        # Keep the full table cached on the input's device (a no-op once it is
        # there) and slice per call. The cache is not truncated, so a short batch
        # does not force a rebuild for the next longer one. `.to` also works on CPU.
        self.weights = self.weights.to(input.device)
        return self.weights[:seq_len].unsqueeze(0)

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number

In [113]:
class CPL(nn.Module):
    """Proposal-based grounding model built on the project's ``DualTransformer``.

    ``forward`` predicts ``num_props`` Gaussian proposals per sample. Each
    proposal is a (center, width) pair, both passed through a sigmoid. It then
    reconstructs randomly masked query words with the transformer, conditioned on
    each proposal's Gaussian weight. When ``use_negative`` is set, it also
    reconstructs under two mined negative proposals and under the downsampled
    video without any Gaussian weight (the reference).

    Config keys read here: dropout, vocab_size, sigma, use_negative, num_props,
    max_epoch, gamma, frames_input_size, words_input_size, hidden_size,
    DualTransformer (kwargs for the transformer).
    """

    def __init__(self, config):
        super().__init__()
        self.dropout = config['dropout']
        self.vocab_size = config['vocab_size']
        self.sigma = config["sigma"]
        self.use_negative = config['use_negative']
        self.num_props = config['num_props']
        self.max_epoch = config['max_epoch']
        self.gamma = config['gamma']

        self.frame_fc = nn.Linear(config['frames_input_size'], config['hidden_size'])
        self.word_fc = nn.Linear(config['words_input_size'], config['hidden_size'])
        # Learnable vectors: mask token (for masked words), start token (word slot 0),
        # and an extra "prediction" frame appended to every video.
        self.mask_vec = nn.Parameter(torch.zeros(config['words_input_size']).float(), requires_grad=True)
        self.start_vec = nn.Parameter(torch.zeros(config['words_input_size']).float(), requires_grad=True)
        self.pred_vec = nn.Parameter(torch.zeros(config['frames_input_size']).float(), requires_grad=True)

        self.trans = DualTransformer(**config['DualTransformer'])
        self.fc_comp = nn.Linear(config['hidden_size'], self.vocab_size)
        self.fc_gauss = nn.Linear(config['hidden_size'], self.num_props*2)

        self.word_pos_encoder = SinusoidalPositionalEmbedding(config['hidden_size'], 0, 20)

    def forward(self, frames_feat, frames_len, words_id, words_feat, words_len, weights, **kwargs):
        """Run proposal prediction and masked-word reconstruction.

        Args:
            frames_feat: ``(bsz, n_frames, frames_input_size)`` frame features.
            frames_len: per-sample number of valid frames.
            words_id: word ids, repeated per proposal for the caller's loss.
            words_feat: ``(bsz, seq_len, words_input_size)`` word features; slot 0
                is overwritten (on a copy) by the start token.
            words_len: per-sample number of words (tensor; ``+ 1`` is applied).
            weights: optional per-word sampling probabilities for masking.
            **kwargs: ``epoch`` is required when ``use_negative`` is set.

        Returns:
            dict with keys neg_words_logit_1, neg_words_logit_2, ref_words_logit
            (None unless ``use_negative``), words_logit, words_id, words_mask,
            width, center, gauss_weight.
        """
        bsz, n_frames, _ = frames_feat.shape
        # Append the learnable prediction slot as one extra "frame".
        pred_vec = self.pred_vec.view(1, 1, -1).expand(bsz, 1, -1)
        frames_feat = torch.cat([frames_feat, pred_vec], dim=1)
        frames_feat = F.dropout(frames_feat, self.dropout, self.training)
        frames_feat = self.frame_fc(frames_feat)
        frames_mask = self._generate_mask(frames_feat, frames_len)

        # Slot 0 of every sentence holds the learnable start token. Work on a copy so
        # the caller's batch tensor is not modified; start_vec is already on the
        # module's device, so no `.cuda()` is needed.
        words_feat = words_feat.clone()
        words_feat[:, 0] = self.start_vec
        words_pos = self.word_pos_encoder(words_feat)
        words_feat = F.dropout(words_feat, self.dropout, self.training)
        words_feat = self.word_fc(words_feat)
        # +1 so the start token in slot 0 counts as a valid position (words occupy 1..l).
        words_mask = self._generate_mask(words_feat, words_len + 1)

        # generate Gaussian masks
        # NOTE(review): h[:, -1] is presumably the output at the appended pred_vec
        # slot (the last frame position) — confirm against DualTransformer.
        enc_out, h = self.trans(frames_feat, frames_mask, words_feat + words_pos, words_mask, decoding=1)
        gauss_param = torch.sigmoid(self.fc_gauss(h[:, -1])).view(bsz*self.num_props, 2)
        gauss_center = gauss_param[:, 0]
        gauss_width = gauss_param[:, 1]

        # downsample for efficiency; keep at least one position so very short
        # videos do not produce an empty proposal axis.
        props_len = max(n_frames // 4, 1)
        keep_idx = torch.linspace(0, n_frames-1, steps=props_len).long()
        frames_feat = frames_feat[:, keep_idx]
        frames_mask = frames_mask[:, keep_idx]
        # Repeat each video's features/mask once per proposal: (bsz*num_props, props_len, ...).
        props_feat = frames_feat.unsqueeze(1).expand(bsz, self.num_props, -1, -1).contiguous().view(bsz*self.num_props, props_len, -1)
        props_mask = frames_mask.unsqueeze(1).expand(bsz, self.num_props, -1).contiguous().view(bsz*self.num_props, -1)

        gauss_weight = self.generate_gauss_weight(props_len, gauss_center, gauss_width)

        # semantic completion
        words_feat, masked_words = self._mask_words(words_feat, words_len, weights=weights)
        words_feat = words_feat + words_pos
        words_feat = words_feat[:, :-1]
        words_mask = words_mask[:, :-1]

        words_mask1 = words_mask.unsqueeze(1).expand(bsz, self.num_props, -1).contiguous().view(bsz*self.num_props, -1)
        words_id1 = words_id.unsqueeze(1).expand(bsz, self.num_props, -1).contiguous().view(bsz*self.num_props, -1)
        words_feat1 = words_feat.unsqueeze(1).expand(bsz, self.num_props, -1, -1).contiguous().view(bsz*self.num_props, words_mask1.size(1), -1)

        # generate_gauss_weight already max-normalizes rows, so this is effectively
        # a re-normalization kept for safety.
        pos_weight = gauss_weight/gauss_weight.max(dim=-1, keepdim=True)[0]
        _, h, attn_weight = self.trans(props_feat, props_mask, words_feat1, words_mask1, decoding=2, gauss_weight=pos_weight, need_weight=True)
        words_logit = self.fc_comp(h)

        if self.use_negative:
            neg_1_weight, neg_2_weight = self.negative_proposal_mining(props_len, gauss_center, gauss_width, kwargs['epoch'])

            _, neg_h_1 = self.trans(props_feat, props_mask, words_feat1, words_mask1, decoding=2, gauss_weight=neg_1_weight)
            neg_words_logit_1 = self.fc_comp(neg_h_1)

            _, neg_h_2 = self.trans(props_feat, props_mask, words_feat1, words_mask1, decoding=2, gauss_weight=neg_2_weight)
            neg_words_logit_2 = self.fc_comp(neg_h_2)

            # Reference: the whole (downsampled) video, no Gaussian weighting.
            _, ref_h = self.trans(frames_feat, frames_mask, words_feat, words_mask, decoding=2)
            ref_words_logit = self.fc_comp(ref_h)
        else:
            neg_words_logit_1 = None
            neg_words_logit_2 = None
            ref_words_logit = None

        return {
            'neg_words_logit_1': neg_words_logit_1,
            'neg_words_logit_2': neg_words_logit_2,
            'ref_words_logit': ref_words_logit,
            'words_logit': words_logit,
            'words_id': words_id,
            'words_mask': words_mask,
            'width': gauss_width,
            'center': gauss_center,
            'gauss_weight': gauss_weight,
        }

    def generate_gauss_weight(self, props_len, center, width):
        """Gaussian curves sampled at ``props_len`` evenly spaced points in [0, 1].

        Args:
            props_len: number of sample points.
            center: ``(N,)`` Gaussian means.
            width: ``(N,)`` widths; clamped to >= 1e-2, then the std is ``width / sigma``.

        Returns:
            ``(N, props_len)`` tensor, each row scaled so its maximum is 1.
        """
        weight = torch.linspace(0, 1, props_len)
        weight = weight.view(1, -1).expand(center.size(0), -1).to(center.device)
        center = center.unsqueeze(-1)
        width = width.unsqueeze(-1).clamp(1e-2) / self.sigma

        w = 0.3989422804014327  # 1 / sqrt(2*pi); cancels after the max-normalization below
        weight = w/width*torch.exp(-(weight-center)**2/(2*width**2))

        return weight/weight.max(dim=-1, keepdim=True)[0]

    def negative_proposal_mining(self, props_len, center, width, epoch):
        """Build two negative Gaussian weights, one on each side of every proposal.

        ``left_width`` / ``right_width`` are the free space before the proposal's
        left edge and after its right edge. The negative centers sit at
        ``left_width * r / 2`` and ``1 - right_width * r / 2``, where
        ``r = min(epoch / max_epoch, 1) ** gamma``. Early in training the
        negatives hug the video boundaries; by ``max_epoch`` they sit in the
        middle of the free space. The left negative uses its own center as
        width; the right one uses its distance to the end (``1 - right_center``).

        Returns:
            ``(left_neg_weight, right_neg_weight)``, each ``(N, props_len)`` and
            max-normalized per row.
        """
        def Gauss(pos, w1, c):
            # Same curve as generate_gauss_weight but with std = width / (sigma / 2).
            w1 = w1.unsqueeze(-1).clamp(1e-2) / (self.sigma/2)
            c = c.unsqueeze(-1)
            w = 0.3989422804014327  # 1 / sqrt(2*pi)
            y1 = w/w1*torch.exp(-(pos-c)**2/(2*w1**2))
            return y1/y1.max(dim=-1, keepdim=True)[0]

        weight = torch.linspace(0, 1, props_len)
        weight = weight.view(1, -1).expand(center.size(0), -1).to(center.device)

        left_width = torch.clamp(center-width/2, min=0)
        left_center = left_width * min(epoch/self.max_epoch, 1)**self.gamma * 0.5
        right_width = torch.clamp(1-center-width/2, min=0)
        right_center = 1 - right_width * min(epoch/self.max_epoch, 1)**self.gamma * 0.5

        left_neg_weight = Gauss(weight, left_center, left_center)
        right_neg_weight = Gauss(weight, 1-right_center, right_center)

        return left_neg_weight, right_neg_weight

    def _mask_words(self, words_feat, words_len, weights=None):
        """Replace a random subset of word positions with the projected mask token.

        For sentence ``i`` with ``l = words_len[i]`` words (positions 1..l;
        position 0 is the start token), ``max(l // 3, 1)`` distinct positions are
        drawn with NumPy's global RNG. When ``weights`` is given, ``weights[i, :l]``
        are the sampling probabilities for positions 1..l. Sentences with
        ``l < 1`` are left untouched.

        Returns:
            ``(masked_feat, masked_words)``. ``masked_words`` is a uint8 tensor of
            shape ``(bsz, seq_len, 1)`` with 1 at masked positions.
        """
        # mask_vec lives on the module's device already; project it like a word.
        token = self.word_fc(self.mask_vec.view(1, 1, -1))

        # Build the selection mask on the CPU, then move it to the features' device once.
        masked_words = torch.zeros(len(words_len), words_feat.size(1), dtype=torch.uint8)
        for i, l in enumerate(words_len):
            l = int(l)
            if l < 1:
                continue
            num_masked_words = max(l // 3, 1)
            p = weights[i, :l].cpu().numpy() if weights is not None else None
            choices = np.random.choice(np.arange(1, l + 1), num_masked_words, replace=False, p=p)
            masked_words[i, torch.as_tensor(choices, dtype=torch.long)] = 1
        masked_words = masked_words.to(words_feat.device).unsqueeze(-1)

        masked_words_vec = words_feat.new_zeros(*words_feat.size()) + token
        masked_words_vec = masked_words_vec.masked_fill_(masked_words == 0, 0)
        words_feat1 = words_feat.masked_fill(masked_words == 1, 0) + masked_words_vec
        return words_feat1, masked_words

    def _generate_mask(self, x, x_len):
        """Return a uint8 length mask on ``x``'s device.

        Row ``i`` of the ``(len(x_len), x.size(1))`` result is 1 for the first
        ``x_len[i]`` positions and 0 elsewhere.
        """
        lengths = torch.as_tensor(x_len, device=x.device).view(-1, 1)
        positions = torch.arange(x.size(1), device=x.device).unsqueeze(0)
        return (positions < lengths).to(torch.uint8)

In [114]:
# Build Model
model_config = args['model']['config']
model = CPL(model_config).cuda()

# Parameter counts: all tensors vs. only those that will receive gradients.
total_num = sum(param.numel() for param in model.parameters())
trainable_num = sum(param.numel() for param in model.parameters() if param.requires_grad)
print('Total parameters:', total_num, 'Trainable parameters:', trainable_num)

Total parameters: 5375680 Trainable parameters: 5375680


# Optimizer & Learning rate scheduler

In [115]:
# Optimizer & Learning rate scheduler

# Optimizer and scheduler share the same `train.optimizer` config section.
args_optim = args['train']["optimizer"]
parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamOptimizer(args_optim, parameters)
lr_scheduler = InverseSquareRootSchedule(args_optim, optimizer)

In [116]:
def _train_one_epoch(model, train_loader, epoch, optimizer, lr_scheduler, num_updates, args, **kwargs):
    """Train ``model`` for one pass over ``train_loader``.

    Per batch, this function:

    1. Runs the forward pass with ``epoch``.
    2. Computes loss = ``rec_loss`` + ``ivc_loss``, both given ``args['loss']``.
    3. Runs backward, clips gradients and steps the optimizer.
    4. Advances the LR scheduler to the new update count.

    Running loss averages are printed every ``display_n_batches`` batches. They
    are printed once more at the end if the last batch did not fall on that
    interval.

    Args:
        model: module exposing ``num_props``; called with ``epoch=`` and the batch's ``net_input``.
        train_loader: yields dicts with a ``'net_input'`` entry.
        epoch: current epoch number (forwarded to the model, used in logs).
        optimizer: optimizer with ``zero_grad()`` / ``step()``.
        lr_scheduler: scheduler whose ``step_update(num_updates)`` returns the current lr.
        num_updates: optimizer-step counter before this epoch.
        args: full config dict; ``args['loss']`` is forwarded to both losses.
        **kwargs: optional ``display_n_batches`` (log interval, default 50) and
            ``max_grad_norm`` (``clip_grad_norm_`` threshold, default 10).

    Returns:
        The updated ``num_updates`` (incremented once per batch).
    """
    model.train()
    display_n_batches = kwargs.get('display_n_batches', 50)
    max_grad_norm = kwargs.get('max_grad_norm', 10)

    time_meter = TimeMeter()
    loss_meter = collections.defaultdict(AverageMeter)
    bid = 0

    def print_log():
        # Print, then reset, the running averages accumulated since the last log line.
        msg = 'Epoch {}, Batch {}, lr = {:.5f}, '.format(epoch, bid, curr_lr)
        for k, v in loss_meter.items():
            msg += '{} = {:.4f}, '.format(k, v.avg)
            v.reset()
        # NOTE(review): assumes TimeMeter.avg is a rate (batches/second), so its
        # inverse is seconds/batch — confirm in utils_helper.
        msg += '{:.3f} seconds/batch'.format(1.0 / time_meter.avg)
        print(msg)

    for bid, batch in enumerate(train_loader, 1):
        optimizer.zero_grad()
        net_input = move_to_cuda(batch['net_input'])
        output = model(epoch=epoch, **net_input)

        loss, loss_dict = rec_loss(**output, num_props=model.num_props, **args['loss'])
        rnk_loss, rnk_loss_dict = ivc_loss(**output, num_props=model.num_props, **args['loss'])
        loss_dict.update(rnk_loss_dict)
        loss = loss + rnk_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        num_updates += 1
        curr_lr = lr_scheduler.step_update(num_updates)
        time_meter.update()
        for k, v in loss_dict.items():
            loss_meter[k].update(v)

        if bid % display_n_batches == 0:
            print_log()

    # Flush the last partial window; an empty loader leaves bid == 0 and prints nothing.
    if bid % display_n_batches != 0:
        print_log()

    return num_updates

In [117]:
def eval_cpl(model, test_loader, save=None, epoch=0, top_n=5):
    """Evaluate ``model`` on ``test_loader``; print and return recall metrics.

    Each sample's proposals are ranked by word-reconstruction NLL, lowest first.
    R@1 scores the top proposal and R@k, with ``k = min(num_props, top_n)``,
    scores the top-k. Proposals are ``[center - width/2, center + width/2]``
    clipped to [0, 1]. Ground-truth spans are divided by the video duration so
    both are compared on the same relative scale.

    Args:
        model: module exposing ``num_props``; its output dict must contain
            words_logit, words_id, words_mask, width and center.
        test_loader: yields dicts with ``'raw'`` (per-sample sequences whose
            item 1 is the duration and item 2 the ground-truth span — presumably
            (start, end) in the same unit; confirm in the dataset) and ``'net_input'``.
        save: unused; kept for call compatibility.
        epoch: forwarded to ``model``.
        top_n: upper bound for k in R@k (default 5).

    Returns:
        defaultdict mapping metric names such as ``'R@1,mIoU'`` to AverageMeter.
    """
    model.eval()
    metrics_logger = collections.defaultdict(AverageMeter)
    num_props = model.num_props
    k = min(num_props, top_n)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for batch in test_loader:
            durations = np.asarray([i[1] for i in batch['raw']])
            gt = np.asarray([i[2] for i in batch['raw']])

            net_input = move_to_cuda(batch['net_input'])
            output = model(epoch=epoch, **net_input)
            bsz = len(durations)

            # Repeat the sentence once per proposal to score every proposal's reconstruction.
            words_mask = output['words_mask'].unsqueeze(1).expand(bsz, num_props, -1).contiguous().view(bsz*num_props, -1)
            words_id = output['words_id'].unsqueeze(1).expand(bsz, num_props, -1).contiguous().view(bsz*num_props, -1)

            nll_loss, _ = cal_nll_loss(output['words_logit'], words_id, words_mask)
            idx = nll_loss.view(bsz, num_props).argsort(dim=-1)

            width = output['width'].view(bsz, num_props).gather(index=idx, dim=-1)
            center = output['center'].view(bsz, num_props).gather(index=idx, dim=-1)
            selected_props = torch.stack([torch.clamp(center-width/2, min=0), torch.clamp(center+width/2, max=1)], dim=-1)
            selected_props = selected_props.cpu().numpy()
            gt = gt / durations[:, np.newaxis]

            res = top_1_metric(selected_props[:, 0], gt)
            for key, v in res.items():
                metrics_logger['R@1,'+key].update(v, bsz)
            res = top_n_metric(selected_props[:, :k].transpose(1, 0, 2), gt)
            for key, v in res.items():
                metrics_logger['R@%d,'%(k)+key].update(v, bsz)

    msg = '|'.join([' {} {:.4f} '.format(name, meter.avg) for name, meter in metrics_logger.items()])
    print('|'+msg+'|')
    return metrics_logger


# Train and Evaluate

In [118]:
# Train & Eval
model_saved_path = args['train']['model_saved_path']
os.makedirs(model_saved_path, mode=0o755, exist_ok=True)
best_path = os.path.join(model_saved_path, 'model-best.pt')

best_results = None
for epoch in range(1, args['train']['max_num_epochs']+1):
    print('Start Epoch {}'.format(epoch))
    save_path = os.path.join(model_saved_path, 'model-{}.pt'.format(epoch))

    num_updates = _train_one_epoch(model, train_loader, epoch, optimizer, lr_scheduler, num_updates, args)

    # model save
    state_dict = {
        'num_updates': num_updates,
        'config': args,
        'model_parameters': model.state_dict(),
    }
    torch.save(state_dict, save_path)
    print('save model to {}, num_updates {}.'.format(save_path, num_updates))

    results = eval_cpl(model, test_loader)
    # Model selection on R@1 mIoU.
    if best_results is None or results['R@1,mIoU'].avg > best_results['R@1,mIoU'].avg:
        best_results = results
        # Write the best checkpoint directly rather than shelling out to `cp`, which
        # breaks on paths with spaces or shell metacharacters, is not portable, and
        # whose failures went unnoticed. Evaluation does not change the weights.
        torch.save(state_dict, best_path)
        print('Best results have been updated.')
    print('=' * 60)

if best_results is None:
    print('No epochs were run; no best results to report.')
else:
    msg = '|'.join([' {} {:.4f} '.format(k, v.avg) for k, v in best_results.items()])
    print('Best results:')
    print('|'+msg+'|')

Start Epoch 1
Epoch 1, Batch 50, lr = 0.00005, final_loss = 6.2458, nll_loss = 6.2001, ref_nll_loss = 6.2914, ivc_loss = 0.4884, neg_loss_1 = 0.0478, neg_loss_2 = 0.0487, ref_loss = 0.0390, div_loss = 0.0706, 0.212 seconds/batch
Epoch 1, Batch 100, lr = 0.00010, final_loss = 4.7244, nll_loss = 4.6914, ref_nll_loss = 4.7575, ivc_loss = 0.2237, neg_loss_1 = 0.0452, neg_loss_2 = 0.0464, ref_loss = 0.0471, div_loss = 0.0170, 0.189 seconds/batch
Epoch 1, Batch 150, lr = 0.00015, final_loss = 4.0648, nll_loss = 4.0231, ref_nll_loss = 4.1065, ivc_loss = 0.1804, neg_loss_1 = 0.0415, neg_loss_2 = 0.0395, ref_loss = 0.0390, div_loss = 0.0121, 0.180 seconds/batch
Epoch 1, Batch 200, lr = 0.00020, final_loss = 3.6682, nll_loss = 3.6203, ref_nll_loss = 3.7162, ivc_loss = 0.1531, neg_loss_1 = 0.0367, neg_loss_2 = 0.0377, ref_loss = 0.0357, div_loss = 0.0086, 0.176 seconds/batch
Epoch 1, Batch 250, lr = 0.00025, final_loss = 3.4320, nll_loss = 3.3764, ref_nll_loss = 3.4876, ivc_loss = 0.1423, neg_los

KeyboardInterrupt: 

---

# Evaluation

In [91]:
# Best checkpoint written by the training loop for this config/tag (no machine-specific absolute path).
model_path = os.path.join(args['train']['model_saved_path'], 'model-best.pt')

In [92]:
# model load
# map_location puts the tensors on the device the model already lives on, so a
# checkpoint saved on another GPU (or machine) still loads.
state_dict = torch.load(model_path, map_location=next(model.parameters()).device)
parameters = state_dict['model_parameters']
model.load_state_dict(parameters)

<All keys matched successfully>

In [93]:
results = eval_cpl(model=model, test_loader=test_loader, epoch=0)  # evaluate the loaded checkpoint

| R@1,mIoU 0.4382 | R@1,IoU@0.1 0.7606 | R@1,IoU@0.3 0.6700 | R@1,IoU@0.5 0.4927 | R@1,IoU@0.7 0.2289 | R@1,IoU@0.9 0.0361 | R@5,mIoU 0.6701 | R@5,IoU@0.1 0.9883 | R@5,IoU@0.3 0.9674 | R@5,IoU@0.5 0.8319 | R@5,IoU@0.7 0.5032 | R@5,IoU@0.9 0.0836 |
