In [1]:
# Enable IPython's autoreload extension so edits to the imported project
# modules (src.modules.*) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
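# Optional cleanup: uncomment to delete earlier Lightning logs before retraining.
# NOTE(review): this path targets `synthetic_cabin_bw`, whereas the training cell
# writes to `synthetic_cabin_ir_1m` -- confirm the intended directory before running.
# !rm -rf saved_lifter_2d_3d_model/synthetic_cabin_bw/A_Pillar_Codriver/prediction/sem_gcn/lightning_logs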

In [3]:
import os
import torch
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
# from src.modules.lifter_2d_3d.model.semgcn.sem_gcn import SemGCN
from src.modules.lifter_2d_3d.dataset.simple_keypoint_dataset import SimpleKeypointDataset
from src.modules.lifter_2d_3d.model.semgcn.lit_semgcn import LitSemGCN
from IPython.display import display
from src.modules.utils.visualization import generate_connection_line, get_sample_from_loader, visualize_pose
from pathlib import Path

# Fix the global random seed so the run is repeatable.
pl.seed_everything(1234)

dataset_root = Path('/root/synthetic_cabin_1m/syntheticcabin_1mil/processed_syntheticCabin_1m/A_Pillar_Codriver/')

# Settings shared by every split / loader.
IMAGE_WIDTH = 1280
IMAGE_HEIGHT = 1024
BATCH_SIZE = 64
NUM_WORKERS = 24


def build_keypoint_dataset(split):
    """Build the SimpleKeypointDataset for one data split.

    Args:
        split: Split suffix used in the file names under ``dataset_root``
            ('train', 'val' or 'test').

    Returns:
        A SimpleKeypointDataset reading the 2D keypoint detections
        (``keypoint_detection_with_ground_truth_bbox_<split>.json``) and the
        annotations (``person_keypoints_<split>.json``) of that split, with
        ankle keypoints excluded.
    """
    prediction_file = (
        dataset_root / f"keypoint_detection_results/keypoint_detection_with_ground_truth_bbox_{split}.json"
    ).as_posix()
    annotation_file = (dataset_root / f"annotations/person_keypoints_{split}.json").as_posix()
    return SimpleKeypointDataset(
        prediction_file=prediction_file,
        annotation_file=annotation_file,
        image_width=IMAGE_WIDTH,
        image_height=IMAGE_HEIGHT,
        exclude_ankle=True,
    )


# Construction order (train, val, test) is kept from the original cell.
train_dataset = build_keypoint_dataset('train')
val_dataset = build_keypoint_dataset('val')
test_dataset = build_keypoint_dataset('test')

print(
    'train_dataset', len(train_dataset),
    'val_dataset', len(val_dataset),
    'test_dataset', len(test_dataset)
)

# Only the training loader shuffles.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, drop_last=True, shuffle=True, num_workers=NUM_WORKERS
)
# NOTE(review): drop_last=True on the validation loader means a trailing partial
# batch is never evaluated; kept as in the original -- confirm this is intended.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, drop_last=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# ------------
# callbacks
# ------------
# Keep only the single checkpoint with the lowest validation loss.
model_checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)
# Stop when val_loss has not improved for 5 consecutive validation checks.
# Per the Lightning docs, patience counts validation checks rather than epochs;
# validation runs every 5 epochs (see check_val_every_n_epoch below).
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=5)
# ------------
# model
# ------------
lit_model = LitSemGCN(exclude_ankle=True)
# ------------
# training
# ------------
saved_model_path = './saved_lifter_2d_3d_model/synthetic_cabin_ir_1m/A_Pillar_Codriver/prediction/sem_gcn/'
# exist_ok=True replaces the separate os.path.exists() check, which could race
# with another process creating the directory in between.
os.makedirs(saved_model_path, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
trainer = pl.Trainer(
    # max_steps=10,
    max_epochs=200,
    callbacks=[model_checkpoint, early_stopping],
    accelerator=device,
    check_val_every_n_epoch=5,
    # Lightning logs and checkpoints are written below this directory.
    default_root_dir=saved_model_path,
    # gradient_clip_val=1.0
)
trainer.fit(lit_model, train_loader, val_loader)

Global seed set to 1234


skipping problematic image 262501
skipping problematic image 278126
skipping problematic image 293751
skipping problematic image 309376
skipping problematic image 325001
skipping problematic image 340626
skipping problematic image 356251
skipping problematic image 371876
skipping problematic image 512501
skipping problematic image 528126
skipping problematic image 543751
skipping problematic image 559376
skipping problematic image 575001
skipping problematic image 590626
skipping problematic image 606251
skipping problematic image 621876
skipping problematic image 762501
skipping problematic image 778126
skipping problematic image 793751
skipping problematic image 809376
skipping problematic image 825001
skipping problematic image 840626
skipping problematic image 856251
skipping problematic image 871876
skipping problematic image 387501
skipping problematic image 403126
skipping problematic image 418751
skipping problematic image 434376
skipping problematic image 637501
skipping probl

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


skipping problematic image 964725
train_dataset 74976 val_dataset 37488 test_dataset 30339


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | SemGCN | 267 K 
---------------------------------
267 K     Trainable params
0         Non-trainable params
267 K     Total params
1.070     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

check #0
val MPJPE from: 0 batches : 1147.0415592193604


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

check #1
training loss from 5855 batches: 183.4572306613857
val MPJPE from: 0 batches : 278.4784138202667


Validation: 0it [00:00, ?it/s]

check #2
training loss from 5855 batches: 174.57055397312627
val MPJPE from: 0 batches : 282.1851670742035


Validation: 0it [00:00, ?it/s]

check #3
training loss from 5855 batches: 172.49103666726185
val MPJPE from: 0 batches : 274.9946415424347


Validation: 0it [00:00, ?it/s]

check #4
training loss from 5855 batches: 171.58208785914235
val MPJPE from: 0 batches : 273.4110951423645


Validation: 0it [00:00, ?it/s]

check #5
training loss from 5855 batches: 171.08285230871553
val MPJPE from: 0 batches : 273.1604278087616


Validation: 0it [00:00, ?it/s]

check #6
training loss from 5855 batches: 170.7869995643093
val MPJPE from: 0 batches : 271.972119808197


Validation: 0it [00:00, ?it/s]

check #7
training loss from 5855 batches: 170.4877248057741
val MPJPE from: 0 batches : 271.85434103012085


Validation: 0it [00:00, ?it/s]

check #8
training loss from 5855 batches: 170.28042645814784
val MPJPE from: 0 batches : 272.007018327713


Validation: 0it [00:00, ?it/s]

check #9
training loss from 5855 batches: 170.12298327546972
val MPJPE from: 0 batches : 271.42560482025146


Validation: 0it [00:00, ?it/s]

check #10
training loss from 5855 batches: 169.98692601282917
val MPJPE from: 0 batches : 271.7480659484863


Validation: 0it [00:00, ?it/s]

check #11
training loss from 5855 batches: 169.88371411110987
val MPJPE from: 0 batches : 272.17772603034973


Validation: 0it [00:00, ?it/s]

check #12
training loss from 5855 batches: 169.7433067248684
val MPJPE from: 0 batches : 271.6076076030731


In [None]:
# Evaluate on the test loader with the weights stored in the best checkpoint
# (lowest val_loss) tracked by the ModelCheckpoint callback.
best_checkpoint_path = model_checkpoint.best_model_path
trainer.test(
    dataloaders=test_loader,
    ckpt_path=best_checkpoint_path,
)

In [None]:
# Fetch a sample from the validation loader for qualitative inspection
# (presumably a single example -- TODO confirm in get_sample_from_loader).
# Later cells read its 'keypoints_2d' and 'keypoints_3d' entries.
sample = get_sample_from_loader(val_loader)

In [None]:
# Plot the sample's 3D keypoints from the dataset as a reference pose.
ground_truth_keypoints_3d = sample['keypoints_3d']
results = generate_connection_line(ground_truth_keypoints_3d)
pose_df = pd.DataFrame(results)
visualize_pose(pose_df)

In [None]:
# Run the trained lifter on the inspected sample and plot the predicted 3D pose.
model = trainer.model.to(device)
model.eval()

# Flatten the sample's 2D keypoints into one vector and add a batch dimension of 1.
sample_keypoints_2d = (
    torch.flatten(torch.tensor(sample['keypoints_2d']))
    .unsqueeze(0)
    .float()
    .to(device)
)

# Inference only: disable autograd so no computation graph is built for the
# forward pass (the original ran with gradients enabled and detached afterwards).
with torch.no_grad():
    # The second positional argument (0) is kept from the original call;
    # presumably a batch index -- TODO confirm against LitSemGCN.forward.
    estimated_pose = model(sample_keypoints_2d, 0)

# Take the first (only) batch element and reshape it to rows of 3 values (x, y, z).
estimated_keypoints_3d = estimated_pose[0].cpu().reshape([-1, 3]).detach().numpy()
estimated_pose_df = pd.DataFrame(generate_connection_line(estimated_keypoints_3d))
visualize_pose(estimated_pose_df)