- Handling the case of video vs a single image for transformer input 

In [None]:
# !wget https://github.com/Arthur151/ROMP/releases/download/v1.1/model_data.zip
# !unzip model_data.zip -d .

In [None]:
# Show the resolved project paths. The original call started with a stray
# comma (`print(,project_dir, ...)`), which is a SyntaxError.
# NOTE(review): these five names are not defined in the visible part of the
# notebook -- confirm they are set by an earlier cell.
print(project_dir, source_dir, root_dir, model_dir, trained_model_dir)


In [None]:
from __future__ import absolute_import, division, print_function
import argparse
import os
import sys
import os.path as op
import code
import json
import time
import datetime
import torch
import torchvision.models as models
from torchvision.utils import make_grid
import numpy as np
import cv2
from torch.utils.data import Dataset, DataLoader, ConcatDataset

In [None]:
from lib.models.backbone import build_backbone
from dataset.mixed_dataset import MixedDataset, SingleDataset

In [None]:
def _str2bool(value):
    """Convert a command-line token into a real boolean for argparse.

    ``type=bool`` calls ``bool()`` on the raw string, so every non-empty
    value -- including "False" and "0" -- would be parsed as True. This
    converter maps the usual textual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, matching what bool('') produced before.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got {!r}'.format(value))


# Command-line options for the experiment; every option has a default so the
# parser can be used without passing any arguments.
parser = argparse.ArgumentParser()
#########################################################
# Data related arguments
#########################################################
parser.add_argument("--data_dir", default='datasets', type=str, required=False,
                    help="Directory with all datasets, each in one subfolder")
parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False,
                    help="Yaml file with all data for training.")
parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False,
                    help="Yaml file with all data for validation.")
parser.add_argument("--num_workers", default=4, type=int,
                    help="Workers in dataloader.")
parser.add_argument("--img_scale_factor", default=1, type=int,
                    help="adjust image resolution.")
#########################################################
# Loading/saving checkpoints
#########################################################
parser.add_argument("--model_name_or_path", default='metro/modeling/bert/bert-base-uncased/', type=str, required=False,
                    help="Path to pre-trained transformer model or model type.")
parser.add_argument("--resume_checkpoint", default=None, type=str, required=False,
                    help="Path to specific checkpoint for resume training.")
parser.add_argument("--output_dir", default='output/', type=str, required=False,
                    help="The output directory to save checkpoint and test results.")
parser.add_argument("--config_name", default="", type=str,
                    help="Pretrained config name or path if not the same as model_name.")
parser.add_argument('--backbone', default='resnet101',
                help='CNN backbone architecture: hrnet-w64, hrnet, resnet50')
#########################################################
# Training parameters
#########################################################
parser.add_argument("--per_gpu_train_batch_size", default=64, type=int,
                    help="Batch size per GPU/CPU for training.")
parser.add_argument("--per_gpu_eval_batch_size", default=64, type=int,
                    help="Batch size per GPU/CPU for evaluation.")
# `--lr` is listed first, so the parsed attribute is `args.lr`.
parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float,
                    help="The initial lr.")
parser.add_argument("--num_train_epochs", default=200, type=int,
                    help="Total number of training epochs to perform.")
parser.add_argument("--vertices_loss_weight", default=1.0, type=float)
parser.add_argument("--joints_loss_weight", default=1.0, type=float)
parser.add_argument("--drop_out", default=0.1, type=float,
                    help="Drop out ratio in BERT.")
#########################################################
# Model architectures
#########################################################
# A value of -1 means "not given"; only positive values override the config.
parser.add_argument("--num_hidden_layers", default=-1, type=int, required=False,
                    help="Update model config if given")
parser.add_argument("--hidden_size", default=-1, type=int, required=False,
                    help="Update model config if given")
parser.add_argument("--num_attention_heads", default=-1, type=int, required=False,
                    help="Update model config if given. Note that the division of "
                    "hidden_size / num_attention_heads should be in integer.")

#########################################################
# Others
#########################################################
parser.add_argument("--run_eval_only", default=False, action='store_true',)

parser.add_argument('--logging_steps', type=int, default=100,
                    help="Log every X steps.")
parser.add_argument("--device", type=str, default='cuda',
                    help="cuda or cpu")
parser.add_argument('--seed', type=int, default=88,
                    help="random seed for initialization.")
parser.add_argument("--local_rank", type=int, default=0,
                    help="For distributed training.")
parser.add_argument("--hidden_dim", type=int, default=256,
                    help="Size of the embeddings (dimension of the transformer.")
parser.add_argument("--position_embedding", type=str, default='sine',
                    help="Type of positional embedding to use on top \
                    of the image features. (sine, learned).")
parser.add_argument("--lr_backbone", type=float, default=0.00001,
                    help="Learning rate for backbone.")
# Boolean options go through _str2bool so that "--masks False" really is False.
parser.add_argument("--masks", type=_str2bool, default=False,
                    help="Segmentation")
parser.add_argument("--num_feature_levels", type=int, default=1,
                    help="Number of feature levels the encoder processes from the backbone")
parser.add_argument("--dilation", type=_str2bool, default=False,
                    help="If true, we replace stride with dilation in the last convolutional block (DC5)")

In [None]:
# Parse the options declared on `parser`. parse_known_args() returns the
# unrecognised command-line tokens in `unknown` instead of aborting on them.
# NOTE(review): presumably needed because the notebook kernel adds its own
# argv entries -- confirm.
args, unknown = parser.parse_known_args()

In [None]:
def _create_single_data_loader(datasets, shuffle=False, **kwargs):
    """Wrap one dataset in a ``SingleDataset`` and return a ``DataLoader`` over it.

    Args:
        datasets: Forwarded to ``SingleDataset`` as its ``dataset`` argument.
        shuffle: Passed straight through to ``DataLoader``.
        **kwargs: Extra keyword arguments forwarded to ``SingleDataset``.

    Returns:
        A ``DataLoader`` whose batch size is read from the module-level
        ``args.per_gpu_train_batch_size``. It uses two workers, pinned memory,
        and keeps the final partial batch.
    """
    wrapped = SingleDataset(dataset=datasets, **kwargs)
    loader_options = {
        'shuffle': shuffle,
        'batch_size': args.per_gpu_train_batch_size,
        'drop_last': False,
        'pin_memory': True,
        'num_workers': 2,
    }
    return DataLoader(wrapped, **loader_options)

# Build a loader with 'coco' passed as the dataset argument. The returned
# DataLoader is not bound to a name.
# NOTE(review): looks like a smoke test whose value is only shown as cell output -- confirm.
_create_single_data_loader('coco')

In [None]:
# Build the backbone through the project's build_backbone factory, configured
# by the parsed `args`. `backbone` is used further down when counting
# parameters and assembling the network.
backbone = build_backbone(args)

In [None]:
class Trainer(Base):
    """Training driver on top of ``Base``: runs epochs, logs, validates and checkpoints.

    Attributes read here but not assigned in this class (``model``,
    ``optimizer``, ``summary_writer``, ``global_count``, ``print_freq``, ...)
    are presumably provided by ``Base`` or its builder methods -- verify.

    NOTE(review): this class calls ``args()`` as a function, whereas elsewhere
    in this file ``args`` is bound to the namespace returned by
    ``parser.parse_known_args()``. Confirm which ``args`` is meant to be in
    scope; a namespace is not callable.
    """

    def __init__(self):
        """Build model, optimizer, loaders and losses, then set the training config."""
        super(Trainer, self).__init__()
        self._build_model_()
        self._build_optimizer()
        self.set_up_val_loader()
        self._calc_loss = Loss()
        self.loader = self._create_data_loader(train_flag=True)
        self.mutli_task_uncertainty_weighted_loss = Learnable_Loss(self.loader.dataset._get_ID_num_()).cuda()
        # The learnable loss has its own parameters; register them so the
        # optimizer updates them together with the model.
        self.optimizer.add_param_group({'params': self.mutli_task_uncertainty_weighted_loss.parameters()})

        self.train_cfg = {
            'mode': 'matching_gts',
            'is_training': True,
            'update_data': True,
            'calc_loss': bool(self.model_return_loss),
            'new_training': args().new_training,
        }
        # Starting thresholds; 'pw3d' is lowered in validation() when beaten.
        self.val_best_PAMPJPE = {'pw3d': 60, 'mpiinf':80}
        logging.info('Initialization of Trainer finished!')

    def train(self):
        """Run ``self.epoch`` training epochs and close the summary writer."""
        logging.info('start training')
        self.model.train()
        if self.fix_backbone_training_scratch:
            fix_backbone(self.model, exclude_key=['backbone.'])
        else:
            train_entire_model(self.model)
        for epoch in range(self.epoch):
            # From the second epoch (index 1) the entire model is trained,
            # regardless of the choice made above.
            if epoch == 1:
                train_entire_model(self.model)
            self.train_epoch(epoch)
        self.summary_writer.close()

    def train_step(self, meta_data):
        """Run one optimisation step on a batch.

        Args:
            meta_data: Batch as yielded by ``self.loader``.

        Returns:
            Tuple ``(outputs, loss)`` as produced by the learnable weighted loss.
        """
        self.optimizer.zero_grad()
        outputs = self.network_forward(self.model, meta_data, self.train_cfg)

        # When the model does not return its own loss, compute it here.
        if not self.model_return_loss:
            outputs.update(self._calc_loss(outputs))
        loss, outputs = self.mutli_task_uncertainty_weighted_loss(outputs)

        if self.model_precision == 'fp16':
            # Mixed precision: backward and step go through the gradient scaler.
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            loss.backward()
            self.optimizer.step()
        return outputs, loss

    def train_log_visualization(self, outputs, loss, run_time, data_time, losses, losses_dict, epoch, iter_index):
        """Update running meters and periodically print, log and save visualisations.

        Text and scalar logging happens every ``self.print_freq`` global
        steps. Image visualisation happens every ``6 * self.print_freq``
        steps and additionally at global step 50.
        """
        losses.update(loss.item())
        losses_dict.update(outputs['loss_dict'])
        if self.global_count % self.print_freq == 0:
            message = 'Epoch: [{0}][{1}/{2}] Time {data_time.avg:.2f} RUN {run_time.avg:.2f} Lr {lr} Loss {loss.avg:.2f} | Losses {3}'.format(
                      epoch, iter_index + 1,  len(self.loader), losses_dict.avg(),
                      data_time=data_time, run_time=run_time, loss=losses, lr = self.optimizer.param_groups[0]['lr'])
            print(message)
            write2log(self.log_file,'%s\n' % message)
            self.summary_writer.add_scalar('loss', losses.avg, self.global_count)
            self.summary_writer.add_scalars('loss_items', losses_dict.avg(), self.global_count)

            # Start a fresh averaging window (run_time is not reset here).
            losses.reset()
            losses_dict.reset()
            data_time.reset()
            self.summary_writer.flush()

        if self.global_count % (6 * self.print_freq) == 0 or self.global_count == 50:
            vis_ids, vis_errors = determ_worst_best(outputs['kp_error'], top_n=3)
            # File name: global step followed by every dataset name in the batch.
            save_name = '{}'.format(self.global_count)
            for ds_name in set(outputs['meta_data']['data_set']):
                save_name += '_{}'.format(ds_name)
            # Called for its side effect (settings=['save_img']); the return
            # value was never used.
            self.visualizer.visulize_result(outputs, outputs['meta_data'], show_items=['org_img', 'mesh', 'pj2d', 'centermap'],
                vis_cfg={'settings': ['save_img'], 'vids': vis_ids, 'save_dir':self.train_img_dir, 'save_name':save_name, 'verrors': [vis_errors], 'error_names':['E']})

    def train_epoch(self,epoch):
        """Train for one pass over ``self.loader``, then checkpoint and step the scheduler."""
        run_time, data_time, losses = [AverageMeter() for _ in range(3)]
        losses_dict = AverageMeter_Dict()
        batch_start_time = time.time()
        for iter_index, meta_data in enumerate(self.loader):
            self.global_count += 1
            # Turn the 'new_training' mode off in all configs once the
            # configured number of iterations is reached.
            if args().new_training and self.global_count == args().new_training_iters:
                self.train_cfg['new_training'] = False
                self.val_cfg['new_training'] = False
                self.eval_cfg['new_training'] = False

            data_time.update(time.time() - batch_start_time)
            run_start_time = time.time()

            outputs, loss = self.train_step(meta_data)

            # Only local_rank -1 or 0 logs.
            if self.local_rank in [-1, 0]:
                run_time.update(time.time() - run_start_time)
                self.train_log_visualization(outputs, loss, run_time, data_time, losses, losses_dict, epoch, iter_index)

            if self.global_count % self.test_interval == 0 or self.global_count == self.fast_eval_iter:
                save_model(self.model,'{}_val_cache.pkl'.format(self.tab),parent_folder=self.model_save_dir)
                self.validation(epoch)

            if self.distributed_training:
                # wait for rank 0 process finish the job
                torch.distributed.barrier()
            batch_start_time = time.time()

        title = '{}_epoch_{}.pkl'.format(self.tab,epoch)
        save_model(self.model,title,parent_folder=self.model_save_dir)
        self.e_sche.step()

    def validation(self,epoch):
        """Evaluate on every validation set, record metrics and save a checkpoint.

        The checkpoint name embeds the metrics of the last dataset iterated.
        If ``self.dataset_val_list`` is empty there are no metrics to embed;
        the checkpoint is skipped with a warning. Previously that case raised
        NameError on the undefined loop variables.
        """
        logging.info('evaluation result on {} iters: '.format(epoch))
        last_metrics = None
        for ds_name, val_loader in self.dataset_val_list.items():
            logging.info('Evaluation on {}'.format(ds_name))
            MPJPE, PA_MPJPE, _ = val_result(self,loader_val=val_loader, evaluation=False)
            # 'pw3d': run the test loader only when validation PA-MPJPE improves.
            if ds_name == 'pw3d' and PA_MPJPE < self.val_best_PAMPJPE['pw3d']:
                self.val_best_PAMPJPE['pw3d'] = PA_MPJPE
                _, _, eval_results = val_result(self,loader_val=self.dataset_test_list['pw3d'], evaluation=True)
                self.summary_writer.add_scalars('pw3d-vibe-test', eval_results, self.global_count)
            # 'mpiinf': the test loader is run on every validation.
            if ds_name == 'mpiinf':
                _, _, eval_results = val_result(self,loader_val=self.dataset_test_list['mpiinf'], evaluation=True)
                self.summary_writer.add_scalars('mpiinf-test', eval_results, self.global_count)

            self.evaluation_results_dict[ds_name]['MPJPE'].append(MPJPE)
            self.evaluation_results_dict[ds_name]['PAMPJPE'].append(PA_MPJPE)

            logging.info('Running evaluation results:')
            ds_running_results = self.get_running_results(ds_name)
            print('Running MPJPE:{}|{}; Running PAMPJPE:{}|{}'.format(*ds_running_results))
            last_metrics = (MPJPE, PA_MPJPE)

        if last_metrics is None:
            logging.warning('No validation dataset available; skipping metric-tagged checkpoint.')
        else:
            title = '{}_{:.4f}_{:.4f}_{}.pkl'.format(epoch, last_metrics[0], last_metrics[1], self.tab)
            logging.info('Model saved as {}'.format(title))
            save_model(self.model,title,parent_folder=self.model_save_dir)

        # Switch the model back to training mode before returning.
        self.model.train()
        self.summary_writer.flush()

    def get_running_results(self, ds):
        """Return mean and variance of the recorded MPJPE and PAMPJPE for dataset ``ds``.

        Returns:
            Tuple ``(mpjpe_mean, mpjpe_var, pampjpe_mean, pampjpe_var)``.
        """
        mpjpe = np.array(self.evaluation_results_dict[ds]['MPJPE'])
        pampjpe = np.array(self.evaluation_results_dict[ds]['PAMPJPE'])
        return np.mean(mpjpe), np.var(mpjpe), np.mean(pampjpe), np.var(pampjpe)

In [None]:
def run(args, train_dataloader, val_dataloader, METRO_model, smpl, mesh_sampler, renderer):
    """Train ``METRO_model`` over ``train_dataloader`` with periodic logging and validation.

    One "epoch" is ``len(train_dataloader) // args.num_train_epochs``
    iterations. At the end of each epoch the model is validated and a
    checkpoint is saved when PA-MPJPE improves. A final checkpoint is saved
    after the loop.

    Args:
        args: Parsed options; reads ``num_train_epochs``, ``lr``, ``device``,
            ``logging_steps``, ``output_dir`` and the loss weights.
            ``args.logging_steps`` is overwritten with 500 when an epoch is
            shorter than 1000 iterations.
        train_dataloader: Iterable of ``(img_keys, images, annotations)``.
        val_dataloader: Passed to ``run_validate``.
        METRO_model: Network to optimise.
        smpl, mesh_sampler: Forwarded to the model and to ``run_validate``.
        renderer: Forwarded to ``visualize_mesh``.

    NOTE(review): the loss computation reads names that are never assigned in
    this function (``meta_masks``, ``gt_3d_joints``, ``has_3d_joints``,
    ``gt_vertices``, ``gt_vertices_sub``, ``gt_vertices_sub2``, ``has_smpl``,
    ``gt_2d_joints``, ``has_2d_joints``, ``pred_2d_joints``,
    ``pred_3d_joints_from_smpl``, ``pred_2d_joints_from_smpl``) and the
    options ``args.vloss_w_sub2`` / ``vloss_w_sub`` / ``vloss_w_full``, which
    the parser visible in this file does not define. Confirm where these come
    from before running.
    """
    max_iter = len(train_dataloader)
    # Clamp to 1: with fewer batches than epochs the integer division yields
    # 0, and the `//` and `%` below would raise ZeroDivisionError.
    iters_per_epoch = max(1, max_iter // args.num_train_epochs)
    if iters_per_epoch < 1000:
        args.logging_steps = 500

    optimizer = torch.optim.Adam(params=list(METRO_model.parameters()),
                                 lr=args.lr,
                                 betas=(0.9, 0.999),
                                 weight_decay=0)

    # define loss function (criterion) and optimizer
    # `.to(device)` instead of `.cuda(device)` so that --device cpu also works.
    criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(args.device)
    criterion_keypoints = torch.nn.MSELoss(reduction='none').to(args.device)
    criterion_vertices = torch.nn.L1Loss().to(args.device)

    start_training_time = time.time()
    end = time.time()
    METRO_model.train()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    log_losses = AverageMeter()
    log_loss_2djoints = AverageMeter()
    log_loss_3djoints = AverageMeter()
    log_loss_vertices = AverageMeter()
    log_eval_metrics = EvalMetricsLogger()

    # Iterations are counted from 1 so that `iteration == max_iter` marks the last batch.
    for iteration, (img_keys, images, annotations) in enumerate(train_dataloader, start=1):

        # Re-enter train mode every iteration.
        # NOTE(review): presumably because validation leaves the model in eval mode -- confirm.
        METRO_model.train()
        epoch = iteration // iters_per_epoch
        batch_size = images.size(0)
        adjust_learning_rate(optimizer, epoch, args)
        data_time.update(time.time() - end)

        images = images.to(args.device)

        # forward-pass
        pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices = METRO_model(images, smpl, mesh_sampler, meta_masks=meta_masks, is_train=True)

        # compute 3d joint loss  (where the joints are directly output from transformer)
        loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints, has_3d_joints, args.device)
        # compute 3d vertex loss
        loss_vertices = ( args.vloss_w_sub2 * vertices_loss(criterion_vertices, pred_vertices_sub2, gt_vertices_sub2, has_smpl, args.device) + \
                            args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_smpl, args.device) + \
                            args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, args.device) )
        # compute 3d joint loss (where the joints are regressed from full mesh)
        loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_smpl, gt_3d_joints, has_3d_joints, args.device)
        # compute 2d joint loss
        loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints)  + \
                         keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_smpl, gt_2d_joints, has_2d_joints)

        loss_3d_joints = loss_3d_joints + loss_reg_3d_joints

        # we empirically use hyperparameters to balance difference losses
        # (the 2d joint term is weighted by vertices_loss_weight, as in the original code)
        loss = args.joints_loss_weight*loss_3d_joints + \
                args.vertices_loss_weight*loss_vertices  + args.vertices_loss_weight*loss_2d_joints

        # update logs
        log_loss_2djoints.update(loss_2d_joints.item(), batch_size)
        log_loss_3djoints.update(loss_3d_joints.item(), batch_size)
        log_loss_vertices.update(loss_vertices.item(), batch_size)
        log_losses.update(loss.item(), batch_size)

        # back prop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if iteration % args.logging_steps == 0 or iteration == max_iter:
            eta_seconds = batch_time.avg * (max_iter - iteration)
            eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
            logger.info(
                ' '.join(
                ['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',]
                ).format(eta=eta_string, ep=epoch, iter=iteration,
                    memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0)
                + '  loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format(
                    log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg,
                    optimizer.param_groups[0]['lr'])
            )

            visual_imgs = visualize_mesh(   renderer,
                                            annotations['ori_img'].detach(),
                                            annotations['joints_2d'].detach(),
                                            pred_vertices.detach(),
                                            pred_camera.detach(),
                                            pred_2d_joints_from_smpl.detach())
            # Two axis swaps move dim 0 to the end: (a, b, c) -> (b, c, a).
            visual_imgs = visual_imgs.transpose(0,1)
            visual_imgs = visual_imgs.transpose(1,2)
            visual_imgs = np.asarray(visual_imgs)

            if is_main_process():
                stamp = str(epoch) + '_' + str(iteration)
                # Join as a path so an output_dir without trailing slash still
                # writes inside that directory.
                temp_fname = op.join(args.output_dir, 'visual_' + stamp + '.jpg')
                # Reverse the last axis and scale by 255 before writing.
                cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255))

        if iteration % iters_per_epoch == 0:
            val_mPVE, val_mPJPE, val_PAmPJPE, val_count = run_validate(args, val_dataloader,
                                                METRO_model,
                                                criterion_keypoints,
                                                criterion_vertices,
                                                epoch,
                                                smpl,
                                                mesh_sampler)

            logger.info(
                ' '.join(['Validation', 'epoch: {ep}',]).format(ep=epoch)
                + '  mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f}, Data Count: {:6.2f}'.format(1000*val_mPVE, 1000*val_mPJPE, 1000*val_PAmPJPE, val_count)
            )

            # Keep a checkpoint only when PA-MPJPE improves on the best so far.
            if val_PAmPJPE < log_eval_metrics.PAmPJPE:
                checkpoint_dir = save_checkpoint(METRO_model, args, epoch, iteration)
                log_eval_metrics.update(val_mPVE, val_mPJPE, val_PAmPJPE, epoch)

    total_training_time = time.time() - start_training_time
    total_time_str = str(datetime.timedelta(seconds=total_training_time))
    logger.info('Total training time: {} ({:.4f} s / iter)'.format(
        total_time_str, total_training_time / max_iter)
    )
    # Always save the final state, whether or not it is the best one.
    checkpoint_dir = save_checkpoint(METRO_model, args, epoch, iteration)

    logger.info(
        ' Best Results:'
        + '  mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f}, at epoch {:6.2f}'.format(1000*log_eval_metrics.mPVE, 1000*log_eval_metrics.mPJPE, 1000*log_eval_metrics.PAmPJPE, log_eval_metrics.epoch)
    )

In [None]:
# Build the transformer encoder, wrap it with the backbone into the end-to-end
# network, create the data loaders and start training.
# NOTE(review): `input_feat_dim`, `output_feat_dim`, `hidden_feat_dim`,
# `mesh_sampler`, `smpl` and `renderer` are not defined in the visible part of
# this file, and the parser above does not declare `legacy_setting` or
# `distributed` -- confirm they are set by another cell.
trans_encoder = []

config_class, model_class = BertConfig, METRO
# An explicit --config_name wins; otherwise fall back to the model path.
config = config_class.from_pretrained(args.config_name if args.config_name \
        else args.model_name_or_path)

config.output_attentions = False
config.hidden_dropout_prob = args.drop_out
config.img_feature_dim = input_feat_dim
config.output_feature_dim = output_feat_dim
args.hidden_size = hidden_feat_dim

if args.legacy_setting:
    # During our paper submission, we were using the original intermediate size, which is 3072 fixed
    # We keep our legacy setting here
    # (-1 fails the `> 0` check below, so the config value is left untouched.)
    args.intermediate_size = -1
else:
    # We have recently tried to use an updated intermediate size, which is 4*hidden-size.
    # But we didn't find significant performance changes on Human3.6M (~36.7 PA-MPJPE)
    args.intermediate_size = int(args.hidden_size*4)

# update model structure if specified in arguments
update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size']

for param in update_params:
    arg_param = getattr(args, param)
    config_param = getattr(config, param)
    # Only positive values that differ from the loaded config override it.
    if arg_param > 0 and arg_param != config_param:
        print("Update config parameter {}: {} -> {}".format(param, config_param, arg_param))
        setattr(config, param, arg_param)

# init a transformer encoder and append it to a list
assert config.hidden_size % config.num_attention_heads == 0
model = model_class(config=config)
print("Init model from scratch.")
trans_encoder.append(model)

trans_encoder = torch.nn.Sequential(*trans_encoder)
total_params = sum(p.numel() for p in trans_encoder.parameters())
print('Transformers total parameters: {}'.format(total_params))
backbone_total_params = sum(p.numel() for p in backbone.parameters())
print('Backbone total parameters: {}'.format(backbone_total_params))

# build end-to-end network (CNN backbone + transformer encoder + transformer decoder)
_metro_network = Network(args, config, backbone, trans_encoder, mesh_sampler)

_metro_network.to(args.device)
# print() does not interpolate logging-style arguments; format explicitly so
# the namespace replaces the %s placeholder.
print("Training parameters %s" % (args,))

train_dataloader = make_data_loader(args, args.train_yaml, args.distributed,
                                    is_train=True, scale_factor=args.img_scale_factor)
val_dataloader = make_data_loader(args, args.val_yaml, args.distributed,
                                    is_train=False, scale_factor=args.img_scale_factor)
run(args, train_dataloader, val_dataloader, _metro_network, smpl, mesh_sampler, renderer)
