In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported project
# modules are picked up before each cell runs, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
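# Optional cleanup (disabled): uncomment to delete old Lightning logs before retraining.
# NOTE(review): this path points at the .../ground_truth/sem_gcn run, while the training
# cell below writes to .../prediction/sem_gcn/ — confirm which directory should be cleared.
# !rm -rf saved_lifter_2d_3d_model/synthetic_cabin_bw/A_Pillar_Codriver/ground_truth/sem_gcn/lightning_logs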

In [3]:
import os
import torch
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
# from modules.lifter_2d_3d.model.semgcn.sem_gcn import SemGCN
from modules.lifter_2d_3d.dataset.simple_keypoint_dataset import SimpleKeypointDataset
from modules.lifter_2d_3d.model.semgcn.lit_semgcn import LitSemGCN
from IPython.display import display
from modules.utils.visualization import generate_connection_line, get_sample_from_loader, visualize_pose

# Seed the random number generators through Lightning so runs are repeatable.
pl.seed_everything(1234)

# Folder that holds the processed keypoint files for this camera view.
DATA_ROOT = "/root/data/processed/synthetic_cabin_ir/A_Pillar_Codriver"


def build_keypoint_dataset(split):
    """Create the ``SimpleKeypointDataset`` for one data split.

    All three splits share the same image size and joint-exclusion flags;
    only the split name embedded in the two JSON file names differs.

    Args:
        split: Split name as it appears in the file names
            ("train", "val" or "test").

    Returns:
        The dataset built from the split's keypoint-detection results and
        its annotation file.
    """
    return SimpleKeypointDataset(
        prediction_file=f"{DATA_ROOT}/keypoint_detection_results/keypoint_detection_{split}.json",
        annotation_file=f"{DATA_ROOT}/annotations/person_keypoints_{split}.json",
        image_width=1280,
        image_height=1024,
        exclude_ankle=True,
        exclude_knee=True,
    )


train_dataset = build_keypoint_dataset("train")
val_dataset = build_keypoint_dataset("val")
test_dataset = build_keypoint_dataset("test")

# Report how many samples each split contains.
split_sizes = (
    'train_dataset', len(train_dataset),
    'val_dataset', len(val_dataset),
    'test_dataset', len(test_dataset),
)
print(*split_sizes)

# Settings shared by every loader.
loader_kwargs = dict(batch_size=64, num_workers=24)

# Training batches are shuffled; the train and validation loaders drop the
# final incomplete batch, while the test loader keeps every sample.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, drop_last=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

# Both callbacks watch the same quantity: the validation loss, lower is better.
monitored_metric = dict(monitor='val_loss', mode='min')

# Keep only the single best checkpoint.
model_checkpoint = ModelCheckpoint(save_top_k=1, **monitored_metric)
# `patience` counts validation checks without improvement, not training epochs.
early_stopping = EarlyStopping(patience=5, **monitored_metric)

# ------------
# model
# ------------
# Same joint-exclusion flags as the ones passed to the datasets above.
lit_model = LitSemGCN(exclude_knee=True, exclude_ankle=True)
# ------------
# training
# ------------
# Directory where Lightning writes logs and checkpoints for this run.
# NOTE(review): the datasets above are read from .../synthetic_cabin_ir/... but this
# path says synthetic_cabin_bw — confirm the output folder is the intended one.
saved_model_path = './saved_lifter_2d_3d_model/synthetic_cabin_bw/A_Pillar_Codriver/prediction/sem_gcn/'
# exist_ok=True makes the call idempotent and avoids the check-then-create race.
os.makedirs(saved_model_path, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Options for the Lightning trainer, gathered in one place.
trainer_options = dict(
    # max_steps=10,  # uncomment for a quick smoke run
    max_epochs=200,
    accelerator=device,
    callbacks=[model_checkpoint, early_stopping],
    # Validation (and therefore checkpointing / early stopping) runs every 5th epoch.
    check_val_every_n_epoch=5,
    gradient_clip_val=1.0,
    default_root_dir=saved_model_path,
)
trainer = pl.Trainer(**trainer_options)

# Train on the training loader, validating on the validation loader.
trainer.fit(lit_model, train_loader, val_loader)

Global seed set to 1234
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


train_dataset 37499 val_dataset 6250 test_dataset 6251


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | SemGCN | 434 K 
---------------------------------
434 K     Trainable params
0         Non-trainable params
434 K     Total params
1.739     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

check #0
val MPJPE from: 0 batches : 938.8236999511719


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

check #1
training loss from 2925 batches: 88.6540410699498
val MPJPE from: 0 batches : 73.7123042345047


Validation: 0it [00:00, ?it/s]

check #2
training loss from 2925 batches: 53.900968680779144
val MPJPE from: 0 batches : 58.63027647137642


Validation: 0it [00:00, ?it/s]

check #3
training loss from 2925 batches: 48.07467365876222
val MPJPE from: 0 batches : 52.12249979376793


Validation: 0it [00:00, ?it/s]

check #4
training loss from 2925 batches: 43.952022821475296
val MPJPE from: 0 batches : 48.47831279039383


Validation: 0it [00:00, ?it/s]

check #5
training loss from 2925 batches: 40.827630491465584
val MPJPE from: 0 batches : 49.76576939225197


Validation: 0it [00:00, ?it/s]

check #6
training loss from 2925 batches: 38.44643257048904
val MPJPE from: 0 batches : 40.86746647953987


Validation: 0it [00:00, ?it/s]

check #7
training loss from 2925 batches: 36.61186954213513
val MPJPE from: 0 batches : 37.68044710159302


Validation: 0it [00:00, ?it/s]

check #8
training loss from 2925 batches: 35.337200764025376
val MPJPE from: 0 batches : 47.313615679740906


Validation: 0it [00:00, ?it/s]

check #9
training loss from 2925 batches: 34.21315474515287
val MPJPE from: 0 batches : 42.98614710569382


Validation: 0it [00:00, ?it/s]

check #10
training loss from 2925 batches: 33.413248223244636
val MPJPE from: 0 batches : 37.051379680633545


Validation: 0it [00:00, ?it/s]

check #11
training loss from 2925 batches: 32.604265800143914
val MPJPE from: 0 batches : 36.34168207645416


Validation: 0it [00:00, ?it/s]

check #12
training loss from 2925 batches: 31.961027124626007
val MPJPE from: 0 batches : 36.74640506505966


Validation: 0it [00:00, ?it/s]

check #13
training loss from 2925 batches: 31.515041314001778
val MPJPE from: 0 batches : 36.74386814236641


Validation: 0it [00:00, ?it/s]

check #14
training loss from 2925 batches: 31.11668507449138
val MPJPE from: 0 batches : 37.30987757444382


Validation: 0it [00:00, ?it/s]

check #15
training loss from 2925 batches: 30.736833161905277
val MPJPE from: 0 batches : 35.67042946815491


Validation: 0it [00:00, ?it/s]

check #16
training loss from 2925 batches: 30.440757264438858
val MPJPE from: 0 batches : 36.82127967476845


Validation: 0it [00:00, ?it/s]

check #17
training loss from 2925 batches: 30.18991338327909
val MPJPE from: 0 batches : 36.00655123591423


Validation: 0it [00:00, ?it/s]

check #18
training loss from 2925 batches: 29.99830356329425
val MPJPE from: 0 batches : 36.13624721765518


Validation: 0it [00:00, ?it/s]

check #19
training loss from 2925 batches: 29.782907229203445
val MPJPE from: 0 batches : 36.093153059482574


Validation: 0it [00:00, ?it/s]

check #20
training loss from 2925 batches: 29.649225959283676
val MPJPE from: 0 batches : 35.961952060461044


In [None]:
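# NOTE(review): unlabeled number left from an earlier run — presumably a reference
# score to compare against; confirm which metric/run it belongs to and label it.
# 27.17979177448553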

In [None]:
# Evaluate on the test loader using the checkpoint that ModelCheckpoint
# recorded as the best one during training.
best_checkpoint_path = model_checkpoint.best_model_path
trainer.test(
    dataloaders=test_loader,
    ckpt_path=best_checkpoint_path,
)

In [None]:
# Fetch one sample from the validation loader. The cells below use sample[0]
# as the model input and sample[1] as the pose that gets visualised.
sample = get_sample_from_loader(val_loader)

In [None]:
# Turn the sample's pose (sample[1], presumably the ground-truth 3D pose — confirm)
# into connection-line records, then plot them.
results = generate_connection_line(sample[1])
pose_df = pd.DataFrame(data=results)
visualize_pose(pose_df)

In [None]:
# Run the trained network on the sampled input and plot its prediction.
model = trainer.model.to(device)
model.eval()  # switch all modules to evaluation mode

# as_tensor accepts NumPy arrays and tensors alike, and converts dtype/device in
# one step without the copy-construct warning torch.tensor() raises for tensors.
model_input = torch.as_tensor(sample[0], dtype=torch.float32, device=device)

# Gradients are not needed for visualisation, so skip building the autograd graph.
with torch.no_grad():
    # The second positional argument (0) is presumably a batch index —
    # confirm against LitSemGCN.forward.
    estimated_pose = model(model_input, 0)

# Only the first item of the output is plotted, reshaped to rows of 3 coordinates.
estimated_pose_df = pd.DataFrame(
    generate_connection_line(estimated_pose[0].detach().cpu().reshape([-1, 3]).numpy())
)
visualize_pose(estimated_pose_df)