In [None]:
# default_exp train

# train.py

> Training loop and helpers for the baseline 3D pose model.

In [None]:
# hide
# Reload edited library modules (baseline_3d_pose.*) automatically and render plots inline.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
# export
from baseline_3d_pose.utils import *
from baseline_3d_pose.model import *
from baseline_3d_pose.dataset import *
from baseline_3d_pose.viz import *
from fastai.vision import *
from fastprogress.fastprogress import master_bar, progress_bar
import json
import torch
import torch.optim as optim
from torch.utils.data import DataLoader

In [None]:
class AverageMeter():
    "Keep the latest value of a metric plus its running, count-weighted mean."
    def __init__(self):
        # every statistic starts at zero; `avg` is only meaningful after the first `update`
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        "Record `val` as having been observed `n` times and refresh the running mean."
        self.count += n
        self.sum += val * n
        self.val = val
        self.avg = self.sum / self.count

In [None]:
def lr_decay(optimizer, step, lr, decay_step, gamma):
    """Apply smooth exponential decay to every param group of `optimizer`.

    The new rate is `lr * gamma ** (step / decay_step)`. True division is used,
    so the decay is continuous rather than staircase.
    Returns the rate that was written into the optimizer.
    """
    decayed = lr * gamma ** (step / decay_step)
    for group in optimizer.param_groups:
        group.update(lr=decayed)
    return decayed

In [None]:
class Options():
    "Paths, hyper-parameters and metric bookkeeping for a single training attempt."
    def __init__(self):
        # paths
        self.data_path = Path('data')
        self.model_path = Path('model')

        # train options
        self.actions = 'Directions'
        self.attempt_id = '01'
        # derived from `model_path` so the two locations cannot drift apart
        self.attempt_path = self.model_path/self.attempt_id

        # `load` is the flag the training cells read; `load_ckpt` is kept so
        # existing code using the older name keeps working
        self.load = False
        self.load_ckpt = self.load
        self.resume = False

        # train hyper-params
        self.bs = 64
        self.start_epoch = 0
        self.epochs = 10
        self.lr = 1e-3
        self.lr_decay = 100000  # steps per factor of `lr_gamma` (see `lr_decay`)
        self.lr_gamma = 0.96

        # model hyper-params
        self.size = 1024
        self.stages = 2
        self.dropout = 0.5

        # metrics (fresh lists per instance, not shared between attempts)
        self.lr_list = []
        self.loss_list = []

In [None]:
def save_options(options):
    """Persist `options` to `<options.attempt_path>/options.pt` via `torch.save`.

    The attempt directory is created first, including any missing parents
    (e.g. `model/` on a fresh checkout), so the save does not fail on a clean tree.
    """
    options.attempt_path.mkdir(parents=True, exist_ok=True)
    torch.save(options, options.attempt_path/'options.pt')

In [None]:
def save_ckpt(state, options, is_best=True):
    """Save a checkpoint `state` under `options.attempt_path` with `torch.save`.

    Args:
        state: object to serialize (whatever the caller packs into the checkpoint).
        options: object providing `attempt_path` (a `Path`).
        is_best: write `best_ckpt.pt` when True, otherwise `last_ckpt.pt`.
    """
    # parents=True so a missing `model/` directory does not make the save fail
    options.attempt_path.mkdir(parents=True, exist_ok=True)
    fname = 'best_ckpt.pt' if is_best else 'last_ckpt.pt'
    torch.save(state, options.attempt_path/fname)

In [None]:
def save_optimizer(optimizer, options):
    """Save `optimizer.state_dict()` to `<options.attempt_path>/optimizer.pt`.

    The original body referenced an undefined name `state`. Saving the state
    dict (rather than the optimizer object) is what `Optimizer.load_state_dict`
    expects when restoring.
    """
    options.attempt_path.mkdir(parents=True, exist_ok=True)
    torch.save(optimizer.state_dict(), options.attempt_path/'optimizer.pt')

In [None]:
options = Options()
# Prefer the GPU, but fall back to CPU so the notebook can still execute without one.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Build the network on the configured `device` (instead of hard-coding `.cuda()`)
# and re-initialise its weights with the project helper `init_kaiming`.
model = Model().to(device)
model.apply(init_kaiming)
n_params = sum(p.numel() for p in model.parameters())
print(f'total params: {n_params}')

total params: 4291632


In [None]:
# Element-wise MSE (`reduction='none'`): the caller is responsible for reducing the loss.
criterion = nn.MSELoss(reduction='none').to(device)
optimizer = optim.Adam(model.parameters(), lr=options.lr)

In [None]:
# Display the full configuration of this attempt (the instance's attribute dict).
options.__dict__

{'data_path': PosixPath('data'),
 'model_path': PosixPath('model'),
 'actions': 'Directions',
 'attempt_id': '01',
 'attempt_path': PosixPath('model/01'),
 'load': False,
 'resume': False,
 'size': 1024,
 'stages': 2,
 'dropout': 0.5,
 'bs': 64,
 'epochs': 10,
 'lr': 0.001,
 'lr_decay': 100000,
 'lr_gamma': 0.96,
 'lr_list': [],
 'loss_list': []}

In [None]:
# Checkpoint loading / resuming is not implemented yet. Fail loudly if it is
# requested instead of silently training from scratch.
# `Options` may expose the load flag as `load` or (older versions) `load_ckpt`.
load_requested = getattr(options, 'load', getattr(options, 'load_ckpt', False))

if load_requested:
    # TODO: restore model (and optimizer) state from options.attempt_path
    raise NotImplementedError('loading a checkpoint is not implemented yet')

if options.resume:
    # TODO: restore epoch / lr / optimizer state to continue a previous run
    raise NotImplementedError('resuming training is not implemented yet')

In [None]:
# Pre-processed data files produced by the data-preparation step.
# `data_path` was previously an undefined global; take it from the options.
# NOTE(review): torch.load unpickles its input - only load files you generated yourself.
data_path = options.data_path

stat_3d = torch.load(data_path/'stat_3d.pt')
stat_2d = torch.load(data_path/'stat_2d.pt')
train_set_3d = torch.load(data_path/'train_3d.pt')
test_set_3d = torch.load(data_path/'test_3d.pt')
train_set_2d = torch.load(data_path/'train_2d.pt')
test_set_2d = torch.load(data_path/'test_2d.pt')
rcams = torch.load(data_path/'rcams.pt')

# normalisation statistics and used/ignored dimension indices for the 2D inputs
mean_2d = stat_2d['mean']
std_2d = stat_2d['std']
dim_use_2d = stat_2d['dim_use']
dim_ignore_2d = stat_2d['dim_ignore']

# ... and for the 3D targets
mean_3d = stat_3d['mean']
std_3d = stat_3d['std']
dim_use_3d = stat_3d['dim_use']
dim_ignore_3d = stat_3d['dim_ignore']

In [None]:
# Training split restricted to the action(s) selected in the options.
train_actions = get_actions(options.actions)
train_ds = Human36Dataset(train_actions, options.data_path, is_train=True)

In [None]:
# Shuffled mini-batches of `options.bs` samples for training.
train_dl = DataLoader(dataset=train_ds, batch_size=options.bs, shuffle=True)

In [None]:
# Smoke test: run exactly one epoch (starting at `start_epoch`), take the first
# batch and print its shapes.
# The old `range(start_epoch, 1)` ran zero epochs whenever start_epoch >= 1, and
# the print used undefined names `x`/`y` instead of the batch variables.
mb = master_bar(range(options.start_epoch, options.start_epoch + 1))
for epoch in mb:
    for b, (xb, yb) in enumerate(progress_bar(train_dl, parent=mb)):
        print(b, xb.shape, yb.shape)
        break
    # TODO: once the real training loop exists, report stats via
    # mb.child.comment (per batch) and mb.main_bar.comment (per epoch).

0 torch.Size([64, 32]) torch.Size([64, 48])
