In [1]:
import pickle
import torch
import pytorch_lightning as pl
import pandas as pd
import numpy as np
from pathlib import Path
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from modules.action_recognizer.dataset.pose_dataset import PoseDataset, all_activities
from modules.action_recognizer.models.action_transformer.lit_action_transformer import LitActionTransformer
import shutil

pl.seed_everything(1234)

Seed set to 1234


1234

In [2]:
with open('output/inner_mirror/train_pose_info.pkl', 'rb') as f:
    train_pose_info_list = pickle.load(f)
train_pose_info_list = sorted(train_pose_info_list, key=lambda x: x['index'])

with open('output/inner_mirror/train_annotation.pkl', 'rb') as f:
    train_annotation = pickle.load(f)

In [3]:
model_data, test_data = train_test_split(train_annotation, test_size=0.15)
train_data, val_data = train_test_split(model_data, test_size=0.15)
print(len(train_data), len(val_data), len(test_data))

1551 274 323


In [4]:
train_dataset = PoseDataset(train_data, train_pose_info_list, max_len=30)
val_dataset = PoseDataset(val_data, train_pose_info_list, max_len=30)
test_dataset = PoseDataset(test_data, train_pose_info_list, max_len=30)

class DataModule(pl.LightningDataModule):
    def __init__(self, train_dataset, val_dataset, test_dataset):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset

    def train_dataloader(self):
        self.train_dataset.shuffle()
        return DataLoader(self.train_dataset, batch_size=64, drop_last=True, shuffle=True, num_workers=23)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=64, num_workers=23)

    def test_dataloader(self):
        return DataLoader(test_dataset, batch_size=64, num_workers=23)
dm = DataModule(train_dataset, val_dataset, test_dataset)

In [5]:
lit_model = LitActionTransformer(
    dict(
        embed_dim=128,
        hidden_dim=128,
        num_heads=2,
        num_layers=5,
        num_classes=len(all_activities),
        num_joints=13,
        num_frames=30,
        dropout=0.3,
        is_pre_norm=False
    ),
    lr=1e-3,
    is_pose_3d=True
)

model_checkpoint_callback = ModelCheckpoint(
    monitor='val_acc', mode='max', save_top_k=1
)
early_stopping = EarlyStopping(
    monitor='val_acc',  mode="max", patience=3
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
max_epoch = 300
saved_model_path = 'saved_models/action_recognizer/action_transformer/pose_3d'
if Path(saved_model_path).exists():
    shutil.rmtree(saved_model_path)

enable_progress_bar = True
num_sanity_val_steps = 10
val_check_period = 10

trainer = pl.Trainer(
    # max_steps=10,
    max_epochs=max_epoch,
    callbacks=[
        model_checkpoint_callback,
        early_stopping
    ],
    accelerator=device,
    check_val_every_n_epoch=val_check_period,
    default_root_dir=saved_model_path,
    gradient_clip_val=1.0,
    logger=enable_progress_bar,
    enable_progress_bar=enable_progress_bar,
    num_sanity_val_steps=num_sanity_val_steps,
    log_every_n_steps=1,
    reload_dataloaders_every_n_epochs=1
)

trainer.fit(lit_model, datamodule=dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/action_recognizer/action_transformer/pose_3d/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type              | Params
--------------------------------------------
0 | model | ActionTransformer | 511 K 
--------------------------------------------
511 K     Trainable params
0         Non-trainable params
511 K     Total params
2.047     Total estimated model params size (MB)


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.0328 [correct=9/274]


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.4380 [correct=120/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5000 [correct=137/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5474 [correct=150/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5073 [correct=139/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5438 [correct=149/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6095 [correct=167/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5839 [correct=160/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5803 [correct=159/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6204 [correct=170/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6168 [correct=169/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6423 [correct=176/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6533 [correct=179/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6496 [correct=178/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6752 [correct=185/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6496 [correct=178/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6715 [correct=184/274]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.6679 [correct=183/274]


In [6]:
trainer.test(lit_model, datamodule=dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |                                                | 0/? [00:00<?, ?it/s]

Test set accuracy = 0.6718 [correct=217/323]
Test set confusion_matrix =
[[ 2  0  0  0  1  2  0  0  0  0  1  0  0  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  1  0  0  0  0  0  0  0  0  0  0  0  0  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  2  0  0  0  0  0  0  1  0  0  0  1  0  0  0  0  0]
 [ 0  0  0  0  4  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  2  0  0]
 [ 2  0  0  0  0  2  0  0  0  3  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0  1  0  0  0  0  0]
 [ 0  0  0  0  0  0  4  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0 13  0  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  

[{'test_acc': 0.6718266010284424}]