In [None]:
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import argparse
import sys
import json
import math
import numpy as np
import copy
from tqdm import tqdm
import wandb

from config import config
from model import siMLPe as Model
from datasets.epfl_sk30 import EPFLSK30Dataset
from utils.logger import get_logger, print_and_log_info
from utils.pyt_utils import link_file, ensure_dir
from datasets.epfl_sk30_eval import EPFLSK30Eval

from test import test

import torch
from torch.utils.data import DataLoader

In [2]:
# Names of the 17 EPFL-SK30 joints. The training loop pairs joint_names[i] with
# entry i of the per-joint MPJPE results, so this order is assumed to match the
# dataset's joint axis — TODO confirm against the dataset definition.
joint_names = (
    'Root LeftHead RightHead LeftBody RightBody '
    'LeftShoulder RightShoulder LeftArm RightArm '
    'LeftForearm RightForearm LeftHip RightHip '
    'LeftKnee RightKnee LeftFoot RightFoot'
).split()


In [3]:
exp_name = "epfl_sk30_baseline"

# Reproducibility: deterministic CUDA kernels (this relies on the
# CUBLAS_WORKSPACE_CONFIG variable set at the top of the import cell) and fixed
# seeds for both torch and numpy (numpy presumably drives data augmentation in
# the dataset — TODO confirm).
torch.use_deterministic_algorithms(True)
torch.manual_seed(config.seed)
np.random.seed(config.seed)

# Plain-text accuracy log, appended to across runs. buffering=1 makes the file
# line-buffered, so every completed line reaches disk even if the kernel is
# interrupted before the file is ever closed.
acc_log = open(exp_name, 'a', buffering=1)

# Initialize wandb; the config dict records every knob needed to reproduce the run.
wandb.init(
    project="siMLPe-EPFL-SK30",
    name=exp_name,
    config={
        "exp_name": exp_name,
        "seed": config.seed,
        "with_normalization": config.motion_mlp.with_normalization,
        "spatial_fc": config.motion_mlp.spatial_fc_only,
        "num_layers": config.motion_mlp.num_layers,
        # Model config
        "motion_input_length": config.motion.epfl_input_length,
        "motion_target_length_train": config.motion.epfl_target_length_train,
        "motion_target_length_eval": config.motion.epfl_target_length_eval,
        "motion_dim": config.motion.dim,
        "data_aug": config.data_aug,
        "deriv_input": config.deriv_input,
        "deriv_output": config.deriv_output,
        "use_relative_loss": config.use_relative_loss,
        # Training config
        "batch_size": config.batch_size,
        "num_workers": config.num_workers,
        "cos_lr_max": config.cos_lr_max,
        "cos_lr_min": config.cos_lr_min,
        "cos_lr_total_iters": config.cos_lr_total_iters,
        "weight_decay": config.weight_decay,
        "print_every": config.print_every,
        "save_steps": config.save_steps,
        "num_epochs": config.num_epochs
    }
)

# Trailing ';' suppresses the notebook echo of write()'s return value.
acc_log.write('Seed : ' + str(config.seed) + '\n');

[34m[1mwandb[0m: Currently logged in as: [33memredmrcx[0m ([33memredmrcx-itu-edu-tr[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


11

In [4]:
def get_dct_matrix(N):
    """Build the orthonormal DCT-II basis of size N together with its inverse.

    Entry (k, i) is w_k * cos(pi * (i + 1/2) * k / N), where
    w_0 = sqrt(1/N) and w_k = sqrt(2/N) for k > 0.

    Args:
        N: number of samples / frequencies.

    Returns:
        (dct_m, idct_m): two (N, N) float64 numpy arrays; idct_m is the
        numerical inverse (np.linalg.inv) of dct_m.
    """
    freq = np.arange(N)[:, None]   # k: one frequency per row
    pos = np.arange(N)[None, :]    # i: one sample position per column
    # sqrt(1/N) for the DC row, sqrt(2/N) for every other row.
    row_weight = np.sqrt(np.where(freq == 0, 1, 2) / N)
    dct_m = row_weight * np.cos(np.pi * (pos + 1 / 2) * freq / N)
    idct_m = np.linalg.inv(dct_m)
    return dct_m, idct_m

# Precompute the DCT / inverse-DCT bases once, as float32 CUDA tensors with a
# leading batch axis of size 1 so torch.matmul broadcasts them against the
# (batch, frames, features) motion tensors used in train_step.
dct_m, idct_m = (
    torch.tensor(basis).float().cuda().unsqueeze(0)
    for basis in get_dct_matrix(config.motion.epfl_input_length_dct)
)
 

In [5]:
def update_lr_multistep(nb_iter, total_iter, max_lr, min_lr, optimizer,
                        milestone=30000, lr_before=3e-4, lr_after=1e-5):
    """Apply a one-milestone step learning-rate schedule to every param group.

    The rate is ``lr_before`` while ``nb_iter <= milestone`` and ``lr_after``
    afterwards. The defaults reproduce the previously hard-coded schedule.

    Args:
        nb_iter: current (0-based) iteration.
        total_iter, max_lr, min_lr: NOT used by this schedule; kept only so
            existing call sites keep working. NOTE(review): callers pass
            config.cos_lr_max / cos_lr_min here, which are therefore ignored —
            confirm that is intended.
        optimizer: torch optimizer whose ``param_groups`` are updated in place.
        milestone: iteration after which the rate drops.
        lr_before: rate up to and including ``milestone``.
        lr_after: rate after ``milestone``.

    Returns:
        (optimizer, current_lr)
    """
    current_lr = lr_after if nb_iter > milestone else lr_before

    for param_group in optimizer.param_groups:
        param_group["lr"] = current_lr

    return optimizer, current_lr

def gen_velocity(m):
    """Forward difference along axis 1 (the frame axis in this notebook's tensors).

    Returns ``m[:, t + 1] - m[:, t]`` for every t, so the result is one step
    shorter than ``m`` on axis 1. Works for numpy arrays and torch tensors.
    """
    return m[:, 1:] - m[:, :-1]

In [6]:
def train_step(epfl_motion_input, epfl_motion_target, model, optimizer, nb_iter, total_iter, max_lr, min_lr) :
    """Run one optimisation step on a batch and log its losses to wandb.

    Args:
        epfl_motion_input: observed motion of shape (b, n, joints, coords); it
            is flattened to (b, n, joints * coords) before going to the model.
        epfl_motion_target: ground-truth future motion of shape
            (b, n_target, joints, coords).
        model: motion predictor, called on a CUDA tensor.
        optimizer: torch optimizer; its learning rate is set after the step by
            ``update_lr_multistep``.
        nb_iter, total_iter, max_lr, min_lr: forwarded to the LR schedule.

    Returns:
        (loss_value, optimizer, current_lr) with loss_value a Python float.
    """
    # Joint count and coordinate size come from the batch itself instead of a
    # hard-coded 17 x 3, so the step also works for other skeletons.
    b, n, num_joints, coord_dim = epfl_motion_input.shape
    motion_input_flat = epfl_motion_input.reshape(b, n, -1).cuda()

    if config.deriv_input:
        # Project the observed frames onto the DCT basis along the time axis.
        model_input = torch.matmul(dct_m[:, :, :config.motion.epfl_input_length], motion_input_flat)
    else:
        model_input = motion_input_flat.clone()

    motion_pred = model(model_input)
    # Map the model output back from the DCT domain to frames, keep the horizon.
    motion_pred = torch.matmul(idct_m[:, :config.motion.epfl_input_length, :], motion_pred)
    motion_pred = motion_pred[:, :config.motion.epfl_target_length_train]

    if config.deriv_output:
        # The prediction is an offset from the last observed frame.
        motion_pred = motion_pred + motion_input_flat[:, -1:]

    # Move the target to the GPU once and reuse it for both loss terms.
    motion_target = epfl_motion_target.cuda()
    b_target, n_target = motion_target.shape[:2]
    motion_pred = motion_pred.reshape(b_target, n_target, num_joints, coord_dim)

    # Mean per-joint Euclidean error (L2 norm over the coordinate axis).
    position_loss = torch.mean(torch.norm(motion_pred - motion_target, 2, 3))

    if config.use_relative_loss:
        # Additionally match frame-to-frame displacements.
        dmotion_pred = gen_velocity(motion_pred)
        dmotion_gt = gen_velocity(motion_target)
        velocity_loss = torch.mean(torch.norm(dmotion_pred - dmotion_gt, 2, 3))
        loss = position_loss + velocity_loss
    else:
        velocity_loss = None
        loss = position_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # The schedule is applied after the step, so the very first step uses the
    # optimizer's initial learning rate.
    optimizer, current_lr = update_lr_multistep(nb_iter, total_iter, max_lr, min_lr, optimizer)

    # .item() already detaches and copies to host; extract each scalar once.
    loss_value = loss.item()
    wandb.log({
        "train/loss": loss_value,
        "train/position_loss": position_loss.item(),
        "train/velocity_loss": velocity_loss.item() if velocity_loss is not None else 0.0,
        "train/learning_rate": current_lr,
    })

    return loss_value, optimizer, current_lr

In [7]:
# Build the siMLPe predictor in training mode, then move it to the GPU
# (the trailing expression displays the module repr).
model = Model(config).train()
model.cuda()

siMLPe(
  (arr0): Rearrange('b n d -> b d n')
  (arr1): Rearrange('b d n -> b n d')
  (motion_mlp): TransMLP(
    (mlps): Sequential(
      (0): MLPblock(
        (fc0): Temporal_FC(
          (fc): Linear(in_features=50, out_features=50, bias=True)
        )
        (norm0): LN()
      )
      (1): MLPblock(
        (fc0): Temporal_FC(
          (fc): Linear(in_features=50, out_features=50, bias=True)
        )
        (norm0): LN()
      )
      (2): MLPblock(
        (fc0): Temporal_FC(
          (fc): Linear(in_features=50, out_features=50, bias=True)
        )
        (norm0): LN()
      )
      (3): MLPblock(
        (fc0): Temporal_FC(
          (fc): Linear(in_features=50, out_features=50, bias=True)
        )
        (norm0): LN()
      )
      (4): MLPblock(
        (fc0): Temporal_FC(
          (fc): Linear(in_features=50, out_features=50, bias=True)
        )
        (norm0): LN()
      )
      (5): MLPblock(
        (fc0): Temporal_FC(
          (fc): Linear(in_features=50

In [8]:
# Shuffled training loader; drop_last keeps every batch at exactly batch_size.
dataset = EPFLSK30Dataset(config, 'train', config.data_aug)
dataloader = DataLoader(
    dataset,
    batch_size=config.batch_size,
    shuffle=True,
    sampler=None,
    drop_last=True,
    num_workers=config.num_workers,
    pin_memory=True,
)



In [9]:
print("Length of training dataloader: {}".format(len(dataloader)))

Length of training dataloader: 4353


In [15]:
# Evaluation copy of the config: no augmentation, and the training target length
# forced to the evaluation length so train- and test-split evaluation predict
# the same horizon.
eval_config = copy.deepcopy(config)
eval_config.data_aug = False
eval_config.motion.epfl_target_length_train = eval_config.motion.epfl_target_length_eval

eval_dataset = EPFLSK30Eval(eval_config, 'test')
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=128,
    shuffle=False,
    sampler=None,
    drop_last=False,
    num_workers=eval_config.num_workers,
    pin_memory=True,
)

In [16]:
# Use EPFLSK30Eval (not EPFLSK30Dataset) so data format matches test set
# This ensures both train and test evaluation use the same data pipeline
train_eval_dataset = EPFLSK30Eval(eval_config, 'train')
train_eval_dataloader = DataLoader(
    train_eval_dataset,
    batch_size=128,
    shuffle=False,
    sampler=None,
    drop_last=False,
    num_workers=eval_config.num_workers,
    pin_memory=True,
)

In [17]:
print(
    "Length of test evaluation dataloader: {}".format(len(eval_dataloader)),
    "Length of train evaluation dataloader: {}".format(len(train_eval_dataloader)),
    sep="\n",
)

Length of test evaluation dataloader: 4199
Length of train evaluation dataloader: 8704


In [None]:
# initialize optimizer
optimizer = torch.optim.Adam(model.parameters(),
                             lr=config.cos_lr_max,
                             weight_decay=config.weight_decay)

ensure_dir(config.snapshot_dir)
logger = get_logger(config.log_file, 'train')
link_file(config.log_file, config.link_log_file)

print_and_log_info(logger, json.dumps(config, indent=4, sort_keys=True))

# Optionally warm-start from a checkpoint (strict: every key must match).
if config.model_pth is not None:
    state_dict = torch.load(config.model_pth)
    model.load_state_dict(state_dict, strict=True)
    print_and_log_info(logger, "Loading model path from {} ".format(config.model_pth))

# Log model architecture to wandb
wandb.watch(model)

# Keys used to read `test` results: entry i of the overall MPJPE list is
# reported under RESULTS_KEYS[i], and the per-joint dict is indexed by them.
RESULTS_KEYS = ['#2', '#4', '#8', '#10', '#14', '#18', '#22', '#25']


def _format_mpjpe_line(label, values):
    """Return one newline-terminated acc_log line: '<label>: v1 v2 ... ' (trailing space kept)."""
    return label + ': ' + ''.join(str(v) + ' ' for v in values) + '\n'


def _build_eval_metrics(iteration, epoch, ret_test, ret_train, per_joint_test, per_joint_train):
    """Flatten train/test MPJPE results into one wandb dict.

    Keys look like ``frame_<N>/<overall|joint name>/<test|train|gap>`` so the
    train and test curves of the same quantity land on the same wandb graph.
    """
    metrics = {"iteration": iteration, "epoch": epoch}

    for i, key in enumerate(RESULTS_KEYS):
        frame_num = key.replace('#', 'frame_')
        metrics[f"{frame_num}/overall/test"] = ret_test[i]
        metrics[f"{frame_num}/overall/train"] = ret_train[i]
        # "gap" is train minus test MPJPE: a NEGATIVE gap means test error is
        # above train error, which is the usual overfitting signature.
        metrics[f"{frame_num}/overall/gap"] = ret_train[i] - ret_test[i]

    for key in RESULTS_KEYS:
        frame_num = key.replace('#', 'frame_')
        joint_test = per_joint_test[key]
        joint_train = per_joint_train[key]
        for joint_idx, joint_name in enumerate(joint_names):
            metrics[f"{frame_num}/{joint_name}/test"] = joint_test[joint_idx]
            metrics[f"{frame_num}/{joint_name}/train"] = joint_train[joint_idx]
            metrics[f"{frame_num}/{joint_name}/gap"] = joint_train[joint_idx] - joint_test[joint_idx]

    return metrics


def _evaluate_and_log(progress, iteration, epoch):
    """Evaluate on the test and train eval splits, then log to wandb and acc_log.

    The model is put in eval mode and is always switched back to train mode,
    even if evaluation raises. Evaluation runs under torch.no_grad() so no
    autograd graph is built.
    """
    model.eval()
    try:
        with torch.no_grad():
            ret_test, ret_per_joint_test = test(eval_config, model, eval_dataloader, return_per_joint=True)
            progress.write(f"Test MPJPE: {ret_test}")

            ret_train, ret_per_joint_train = test(eval_config, model, train_eval_dataloader, return_per_joint=True)
            progress.write(f"Train MPJPE: {ret_train}")
    finally:
        model.train()

    wandb.log(_build_eval_metrics(iteration, epoch, ret_test, ret_train,
                                  ret_per_joint_test, ret_per_joint_train))

    acc_log.write(str(iteration) + '\n')
    acc_log.write(_format_mpjpe_line('Test', ret_test))
    acc_log.write(_format_mpjpe_line('Train', ret_train))
    # Flush so the results survive an interrupted or crashed kernel.
    acc_log.flush()


##### ------ training ------- #####
nb_epoch = 0
nb_iter = 0

total_iters = len(dataloader) * config.num_epochs

pbar = tqdm(total=total_iters, desc=f"Epoch 0/{config.num_epochs}", unit="iter")

while nb_epoch < config.num_epochs:
    for (epfl_motion_input, epfl_motion_target) in dataloader:

        loss, optimizer, current_lr = train_step(epfl_motion_input, epfl_motion_target, model, optimizer, nb_iter, total_iters, config.cos_lr_max, config.cos_lr_min)

        if (nb_iter + 1) in config.save_steps:
            # Save model checkpoint locally and to wandb, then evaluate it.
            model_path = os.path.join(config.snapshot_dir, 'model-iter-' + str(nb_iter + 1) + '.pth')
            torch.save(model.state_dict(), model_path)
            wandb.save(model_path)

            _evaluate_and_log(pbar, nb_iter + 1, nb_epoch)

        # Update progress bar
        pbar.update(1)
        pbar.set_postfix({
            'epoch': nb_epoch,
            'loss': f'{loss:.4f}',
            'lr': f'{current_lr:.2e}',
            'iter': nb_iter + 1
        })

        nb_iter += 1

    nb_epoch += 1
    pbar.set_description(f"Epoch {nb_epoch}/{config.num_epochs}")

pbar.close()
wandb.finish()


4353


Epoch 0/10:   0%|          | 0/43530 [12:06<?, ?iter/s]




Test MPJPE: [18.8, 35.4, 64.5, 77.5, 101.1, 122.2, 141.3, 154.4]




Train MPJPE: [20.3, 38.4, 70.4, 84.8, 111.1, 134.8, 156.1, 170.9]




Test MPJPE: [18.9, 35.4, 64.5, 77.5, 101.2, 122.3, 141.3, 154.5]




Train MPJPE: [20.4, 38.4, 70.4, 84.8, 111.1, 134.8, 156.1, 170.9]




Test MPJPE: [12.9, 25.2, 49.8, 61.9, 85.2, 106.8, 126.9, 141.0]




Train MPJPE: [13.2, 25.9, 52.0, 65.1, 90.7, 114.9, 137.5, 153.4]




Test MPJPE: [12.0, 22.8, 46.0, 57.6, 80.9, 102.8, 123.4, 137.7]




Train MPJPE: [12.6, 23.9, 48.3, 60.9, 86.6, 111.0, 134.2, 150.4]




Test MPJPE: [9.6, 20.0, 42.0, 53.3, 76.1, 98.2, 118.8, 133.6]




Train MPJPE: [9.9, 20.8, 44.1, 56.1, 80.7, 105.0, 127.9, 144.6]




Test MPJPE: [8.9, 18.6, 39.3, 50.0, 71.8, 93.1, 113.2, 127.4]




Train MPJPE: [9.1, 19.2, 41.0, 52.3, 75.6, 99.0, 121.5, 137.5]




Test MPJPE: [8.5, 18.0, 38.3, 48.8, 70.4, 91.7, 111.9, 126.3]




Train MPJPE: [8.7, 18.6, 39.8, 50.9, 74.0, 97.3, 119.8, 135.9]




Test MPJPE: [8.3, 17.7, 37.6, 48.0, 69.5, 90.4, 110.4, 124.7]




Train MPJPE: [8.5, 18.3, 38.9, 49.7, 72.4, 95.2, 117.3, 133.3]




Test MPJPE: [8.2, 17.5, 37.3, 47.5, 68.3, 89.3, 109.4, 123.7]




Train MPJPE: [8.4, 18.1, 38.5, 49.1, 71.2, 94.2, 116.7, 132.9]




Test MPJPE: [7.9, 17.0, 36.1, 46.1, 66.7, 87.1, 106.8, 120.8]




Train MPJPE: [8.1, 17.4, 37.0, 47.3, 69.0, 91.0, 112.6, 128.3]




Test MPJPE: [7.9, 16.9, 35.9, 45.9, 66.5, 86.9, 106.6, 120.6]




Train MPJPE: [8.1, 17.3, 36.8, 47.1, 68.6, 90.5, 112.1, 127.6]


Epoch 10/10: 100%|██████████| 43530/43530 [1:19:55<00:00,  9.08iter/s, epoch=9, loss=0.0295, lr=1.00e-05, iter=43530]


0,1
epoch,▁▁▁▂▃▃▄▅▆▇█
frame_10/LeftArm/gap,██▃▄▃▂▂▂▁▁▁
frame_10/LeftArm/test,██▅▄▃▂▂▂▁▁▁
frame_10/LeftArm/train,██▄▄▃▂▂▂▁▁▁
frame_10/LeftBody/gap,██▃▃▃▂▂▂▁▁▁
frame_10/LeftBody/test,██▄▃▂▂▁▁▁▁▁
frame_10/LeftBody/train,██▄▃▂▂▂▁▁▁▁
frame_10/LeftFoot/gap,██▆▆▅▄▄▃▂▁▁
frame_10/LeftFoot/test,██▆▅▄▃▂▂▂▁▁
frame_10/LeftFoot/train,██▆▅▄▃▃▂▂▁▁

0,1
epoch,9
frame_10/LeftArm/gap,3.64989
frame_10/LeftArm/test,53.83376
frame_10/LeftArm/train,57.48365
frame_10/LeftBody/gap,1.91156
frame_10/LeftBody/test,39.84282
frame_10/LeftBody/train,41.75439
frame_10/LeftFoot/gap,-0.52858
frame_10/LeftFoot/test,33.40074
frame_10/LeftFoot/train,32.87217
