In [1]:
import pickle
import torch
import pytorch_lightning as pl
import pandas as pd
import numpy as np
from pathlib import Path
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from modules.action_recognizer.dataset.pose_dataset import PoseDataset, all_activities
from modules.action_recognizer.models.STTFormer.lit_sttformer import LitSTTFormer
import shutil

%load_ext autoreload
%autoreload 2

In [2]:
with open('output/inner_mirror/train_pose_info.pkl', 'rb') as f:
    train_pose_info_list = pickle.load(f)
train_pose_info_list = sorted(train_pose_info_list, key=lambda x: x['index'])

with open('output/inner_mirror/train_annotation.pkl', 'rb') as f:
    train_annotation = pickle.load(f)

In [3]:
model_data, test_data = train_test_split(train_annotation, test_size=0.15)
train_data, val_data = train_test_split(model_data, test_size=0.15)
print(len(train_data), len(val_data), len(test_data))

1551 274 323


In [4]:
pose_train_dataset = PoseDataset(train_data, train_pose_info_list, max_len=30)
pose_val_dataset = PoseDataset(val_data, train_pose_info_list, max_len=30)
pose_test_dataset = PoseDataset(test_data, train_pose_info_list, max_len=30)

In [5]:
pl.seed_everything(1234)

train_loader = DataLoader(pose_train_dataset, batch_size=64, shuffle=True, num_workers=23)
val_loader = DataLoader(pose_val_dataset, batch_size=64, shuffle=False, num_workers=23)
test_loader = DataLoader(pose_test_dataset, batch_size=64, shuffle=False, num_workers=23)

lit_model = LitSTTFormer(
    dict(
        len_parts=6,
        num_classes=len(all_activities),
        num_joints=13,
        num_frames=30,
        num_heads=3,
        num_persons=1,
        num_channels=3,
        kernel_size=[3, 5],
        config=[
            [64,  64,  16], [64,  64,  16], 
            [64,  128, 32], [128, 128, 32],
            [128, 256, 64], [256, 256, 64], 
            [256, 256, 64], [256, 256, 64]
        ]
    ),                  
    lr=1e-3,
    is_pose_3d=False
)

model_checkpoint_callback = ModelCheckpoint(
    monitor='val_acc', mode='max', save_top_k=1
)
early_stopping = EarlyStopping(
    monitor='val_acc',  mode="max", patience=3
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
max_epoch = 150
saved_model_path = 'saved_models/action_recognizer/sttformer'
if Path(saved_model_path).exists():
    shutil.rmtree(saved_model_path)

enable_progress_bar = True
num_sanity_val_steps = 10
val_check_period = 10

trainer = pl.Trainer(
    # max_steps=10,
    max_epochs=max_epoch,
    callbacks=[
        model_checkpoint_callback,
        # early_stopping
    ],
    accelerator=device,
    check_val_every_n_epoch=val_check_period,
    default_root_dir=saved_model_path,
    gradient_clip_val=1.0,
    logger=enable_progress_bar,
    enable_progress_bar=enable_progress_bar,
    num_sanity_val_steps=num_sanity_val_steps,
    # log_every_n_steps=10
    reload_dataloaders_every_n_epochs=1,
)

trainer.fit(lit_model, train_loader, val_loader)

Seed set to 1234
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_models/action_recognizer/sttformer/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type  | Params
--------------------------------
0 | model | Model | 5.8 M 
--------------------------------
5.8 M     Trainable params
0         Non-trainable params
5.8 M     Total params
23.331    Total estimated model params size (MB)


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.0372 [correct=7/188]


/opt/conda/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (17) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.4574 [correct=86/188]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5000 [correct=94/188]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5266 [correct=99/188]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5266 [correct=99/188]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5000 [correct=94/188]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5479 [correct=103/188]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5053 [correct=95/188]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5160 [correct=97/188]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5106 [correct=96/188]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5266 [correct=99/188]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5106 [correct=96/188]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5160 [correct=97/188]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5532 [correct=104/188]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5798 [correct=109/188]


Validation: |                                             | 0/? [00:00<?, ?it/s]

validation set accuracy = 0.5585 [correct=105/188]


`Trainer.fit` stopped: `max_epochs=150` reached.


In [6]:
trainer.test(lit_model, dataloaders=test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |                                                | 0/? [00:00<?, ?it/s]

Test set accuracy = 0.6842 [correct=143/209]
Test set confusion_matrix =
[[ 3  0  0  0  0  0  0  0  0  0  0  2  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  6  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 3  0  1  2  0  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  0  0  2  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  3  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  3  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  1  0  0  0  5  0  0  0  0  0  0  1  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  1]
 [ 1  1  0  1  0  0  0 25  0  0  0  0  0  0  5  0  0  0  0  0  1  1  0  0  0  0  0  0  0  2  0]
 [ 0  0  0  0  0  0  0  0  5  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0  0  0  0  0  0  0  1  0  0  0  0  0  

[{'test_acc': 0.6842105388641357}]