In [1]:
import pickle
import torch
import pytorch_lightning as pl
import pandas as pd
import numpy as np
from pathlib import Path
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from modules.action_recognizer.dataset.pose_dataset import PoseDataset, all_activities
from modules.action_recognizer.models.transformer.lit_action_transformer import LitActionTransformer
import shutil

%load_ext autoreload
%autoreload 2

In [2]:
with open('output/inner_mirror/train_pose_info.pkl', 'rb') as f:
    train_pose_info_list = pickle.load(f)
train_pose_info_list = sorted(train_pose_info_list, key=lambda x: x['index'])

with open('output/inner_mirror/train_annotation.pkl', 'rb') as f:
    train_annotation = pickle.load(f)

In [3]:
model_data, test_data = train_test_split(train_annotation, test_size=0.15)
train_data, val_data = train_test_split(model_data, test_size=0.15)
print(len(train_data), len(val_data), len(test_data))

1551 274 323


In [4]:
pose_train_dataset = PoseDataset(train_data, train_pose_info_list, max_len=30)
pose_val_dataset = PoseDataset(val_data, train_pose_info_list, max_len=30)
pose_test_dataset = PoseDataset(test_data, train_pose_info_list, max_len=30)

In [5]:
pl.seed_everything(1234)

train_loader = DataLoader(pose_train_dataset, batch_size=64, shuffle=True, num_workers=23)
val_loader = DataLoader(pose_val_dataset, batch_size=64, shuffle=False, num_workers=23)
test_loader = DataLoader(pose_test_dataset, batch_size=64, shuffle=False, num_workers=23)

lit_model = LitActionTransformer(
    dict(
        embed_dim=64,
        hidden_dim=64,
        num_heads=4,
        num_layers=4,
        num_classes=len(all_activities),
        num_joints=13,
        num_frames=30
    ),
    lr=1e-3,
    is_pose_3d=False
)

model_checkpoint_callback = ModelCheckpoint(
    monitor='val_acc', mode='max', save_top_k=1
)
early_stopping = EarlyStopping(
    monitor='val_acc',  mode="max", patience=3
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
max_epoch = 150
saved_model_path = 'saved_models/action_recognizer/action_transformer'
if Path(saved_model_path).exists():
    shutil.rmtree(saved_model_path)

enable_progress_bar = True
num_sanity_val_steps = 10
val_check_period = 10

trainer = pl.Trainer(
    # max_steps=10,
    max_epochs=max_epoch,
    callbacks=[
        model_checkpoint_callback,
        # early_stopping
    ],
    accelerator=device,
    check_val_every_n_epoch=val_check_period,
    default_root_dir=saved_model_path,
    gradient_clip_val=1.0,
    logger=enable_progress_bar,
    enable_progress_bar=enable_progress_bar,
    num_sanity_val_steps=num_sanity_val_steps,
    # log_every_n_steps=10
    reload_dataloaders_every_n_epochs=1,
)

trainer.fit(lit_model, train_loader, val_loader)

Seed set to 1234
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/action_recognizer/action_transformer/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type              | Params
--------------------------------------------
0 | model | ActionTransformer | 106 K 
--------------------------------------------
106 K     Trainable params
0         Non-trainable params
106 K     Total params
0.428     Total estimated model params size (MB)


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.0109 [correct=3/274]


/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (25) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.4891 [correct=134/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5255 [correct=144/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5876 [correct=161/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5912 [correct=162/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6168 [correct=169/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6204 [correct=170/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6204 [correct=170/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6350 [correct=174/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6460 [correct=177/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6204 [correct=170/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6679 [correct=183/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6496 [correct=178/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6387 [correct=175/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6423 [correct=176/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=150` reached.


validation set accuracy = 0.6496 [correct=178/274]


In [6]:
trainer.test(lit_model, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |                                                | 0/? [00:00<?, ?it/s]

Test set accuracy = 0.6316 [correct=204/323]
Test set confusion_matrix =
[[ 2  0  0  0  0  1  0  2  0  0  0  0  3  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  2  0  0  0  0  0  0  1  0  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0  0  1  0  0  0  0  0  0  0  0  1  0  0  0]
 [ 0  0  0  0  2  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0]
 [ 1  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  2  1  0  0  0  1  0  0  2  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  2  0  0  0  0  0  0  0  1  0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  4  0  0  0  0  1  0  0  0  1  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0]
 [ 0  1  0  0  0  0  0  1 32  0  0  0  0  0  0  0  6  0  0  0  1  0  0  2  0  0  1  0  0  2  3  0  0]
 [ 0  0  

[{'test_acc': 0.6315789222717285}]