In [1]:
import os
import torch
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.utils.data import DataLoader
from modules.lifter_2d_3d.dataset.simple_keypoint_dataset import SimpleKeypointDataset
from modules.lifter_2d_3d.model.semgcn.lit_semgcn import LitSemGCN
from modules.utils.visualization import (
    plot_samples
)
from IPython.display import display
from pathlib import Path
pl.seed_everything(1234)

# ------------
# dataset path
# ------------
dataset_root_path = Path('/root/data/processed/synthetic_cabin_ir/A_Pillar_Codriver')
keypoint_2d_path = dataset_root_path / 'annotations'
keypoint_3d_path = dataset_root_path / 'annotations'
# ------------
# model
# ------------
image_width = 1280
image_height = 1024
batch_size = 64
max_epoch = 200
val_check_period = 5
early_stopping_patience = 5
lit_model = LitSemGCN(exclude_ankle =True, exclude_knee=True)
# ------------
# saved model path
# ------------
saved_model_path = './saved_lifter_2d_3d_model/synthetic_cabin_ir/A_Pillar_Codriver/prediction/semgcn/'


train_dataset = SimpleKeypointDataset(
    prediction_file=(keypoint_2d_path / 'keypoint_detection_train.json').as_posix(),
    annotation_file=(keypoint_3d_path / 'person_keypoints_train.json').as_posix(),
    image_width=image_width,
    image_height=image_height,
    exclude_ankle=True,
    exclude_knee=True
)
val_dataset = SimpleKeypointDataset(
    prediction_file=(keypoint_2d_path / 'keypoint_detection_val.json').as_posix(),
    annotation_file=(keypoint_3d_path / 'person_keypoints_val.json').as_posix(),
    image_width=image_width,
    image_height=image_height,
    exclude_ankle=True,
    exclude_knee=True
)
test_dataset = SimpleKeypointDataset(
    prediction_file=(keypoint_2d_path / 'keypoint_detection_test.json').as_posix(),
    annotation_file=(keypoint_3d_path / 'person_keypoints_test.json').as_posix(),
    image_width=image_width,
    image_height=image_height,
    exclude_ankle=True,
    exclude_knee=True
)

print(
    'train_dataset', len(train_dataset),
    'val_dataset', len(val_dataset),
    'test_dataset', len(test_dataset)
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, drop_last=True, shuffle=True, num_workers=24)
val_loader = DataLoader(val_dataset, batch_size=batch_size, drop_last=True, num_workers=24)
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=24)

model_checkpoint = ModelCheckpoint(monitor='val_loss',mode='min', save_top_k=1)
early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=early_stopping_patience)


if not os.path.exists(saved_model_path):
    os.makedirs(saved_model_path)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
trainer = pl.Trainer(
    # max_steps=10,
    max_epochs=max_epoch,
    callbacks=[model_checkpoint, early_stopping],
    accelerator=device,
    check_val_every_n_epoch=val_check_period,
    default_root_dir=saved_model_path,
    gradient_clip_val=1.0
)
trainer.fit(lit_model, train_loader, val_loader)

Global seed set to 1234
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: saved_lifter_2d_3d_model/synthetic_cabin_ir/A_Pillar_Codriver/prediction/semgcn/lightning_logs


train_dataset 37499 val_dataset 6250 test_dataset 6251


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | SemGCN | 434 K 
---------------------------------
434 K     Trainable params
0         Non-trainable params
434 K     Total params
1.739     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

check #0
val MPJPE from: 0 batches : 938.8236999511719


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

check #1
training loss from 2925 batches: 87.86310912070111
val MPJPE from: 0 batches : 64.6202564239502


Validation: 0it [00:00, ?it/s]

check #2
training loss from 2925 batches: 53.569887809009636
val MPJPE from: 0 batches : 54.31749299168587


Validation: 0it [00:00, ?it/s]

check #3
training loss from 2925 batches: 47.49058394095836
val MPJPE from: 0 batches : 43.39056834578514


Validation: 0it [00:00, ?it/s]

check #4
training loss from 2925 batches: 43.6569387306515
val MPJPE from: 0 batches : 48.559218645095825


Validation: 0it [00:00, ?it/s]

check #5
training loss from 2925 batches: 40.53578013793016
val MPJPE from: 0 batches : 50.772007554769516


Validation: 0it [00:00, ?it/s]

check #6
training loss from 2925 batches: 38.166464230953125
val MPJPE from: 0 batches : 45.187074691057205


Validation: 0it [00:00, ?it/s]

check #7
training loss from 2925 batches: 36.38353891670704
val MPJPE from: 0 batches : 38.1944514811039


Validation: 0it [00:00, ?it/s]

check #8
training loss from 2925 batches: 35.12085741147017
val MPJPE from: 0 batches : 43.46007853746414


Validation: 0it [00:00, ?it/s]

check #9
training loss from 2925 batches: 33.990961892100486
val MPJPE from: 0 batches : 45.972853899002075


Validation: 0it [00:00, ?it/s]

check #10
training loss from 2925 batches: 33.24730817897197
val MPJPE from: 0 batches : 36.53201460838318


Validation: 0it [00:00, ?it/s]

check #11
training loss from 2925 batches: 32.50929945172408
val MPJPE from: 0 batches : 37.14746981859207


Validation: 0it [00:00, ?it/s]

check #12
training loss from 2925 batches: 31.8829694409401
val MPJPE from: 0 batches : 36.2519733607769


Validation: 0it [00:00, ?it/s]

check #13
training loss from 2925 batches: 31.45809023044048
val MPJPE from: 0 batches : 37.13878244161606


Validation: 0it [00:00, ?it/s]

check #14
training loss from 2925 batches: 31.06511875286571
val MPJPE from: 0 batches : 37.55078464746475


Validation: 0it [00:00, ?it/s]

check #15
training loss from 2925 batches: 30.704856551865227
val MPJPE from: 0 batches : 36.75764054059982


Validation: 0it [00:00, ?it/s]

check #16
training loss from 2925 batches: 30.41950090063943
val MPJPE from: 0 batches : 37.63733059167862


Validation: 0it [00:00, ?it/s]

check #17
training loss from 2925 batches: 30.167243080134067
val MPJPE from: 0 batches : 37.205301225185394


In [2]:
with open(f'{saved_model_path}/best_model_path.txt', 'w') as f:
    f.writelines(model_checkpoint.best_model_path)

In [3]:
best_checkpoint_path = model_checkpoint.best_model_path
trainer.test(ckpt_path=best_checkpoint_path, dataloaders=test_loader)

Restoring states from the checkpoint path at saved_lifter_2d_3d_model/synthetic_cabin_ir/A_Pillar_Codriver/prediction/semgcn/lightning_logs/version_0/checkpoints/epoch=59-step=35100.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at saved_lifter_2d_3d_model/synthetic_cabin_ir/A_Pillar_Codriver/prediction/semgcn/lightning_logs/version_0/checkpoints/epoch=59-step=35100.ckpt


Testing: 0it [00:00, ?it/s]

MPJPE: 36.30107268691063
PJPE
                     PJPE
nose            28.903124
left_eye        29.337566
right_eye       28.913559
left_ear        27.954557
right_ear       26.414228
left_shoulder   25.260942
right_shoulder  34.278622
left_elbow      84.799065
right_elbow     54.234325
left_wrist      46.713318
right_wrist     40.981144
left_hip        22.427483
right_hip       21.695965
activities_mpjpe:
{}
test mpjpe: 36.30107268691063


[{'mpjpe': 36.30107268691063}]

In [4]:
trainer.model.test_history[0]['pjpe']

AttributeError: 'LitSemGCN' object has no attribute 'test_history'

In [None]:
trainer.model.test_history[0]['mpjpe']

In [None]:
plot_samples(
    dataset_root_path,
    trainer.model,
    test_loader,
    'test',
    img_figsize=(20, 10),
    img_width=image_width,
    img_height=image_height,
    plot_figsize=(20.5, 10),
    sample_idices=[1, 1000, 5000]
)