In [1]:
# %load_ext autoreload
# %autoreload 2

In [2]:
# !rm -rf saved_lifter_2d_3d_model/synthetic_cabin_bw/A_Pillar_Codriver/prediction/linear_model/lightning_logs

In [3]:
import pandas as pd
import numpy as np
import torch
import lightning.pytorch as pl
import matplotlib.pyplot as plt
# import plotly
import plotly.express as px

In [4]:
import os
import torch
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from src.modules.lifter_2d_3d.dataset.simple_keypoint_dataset import SimpleKeypointDataset
from src.modules.lifter_2d_3d.model.graph_mlp.lit_graphmlp import LitGraphMLP
from src.modules.utils.visualization import generate_connection_line, get_sample_from_loader, visualize_pose
from IPython.display import display

# Seed the global RNGs so data shuffling and weight init are repeatable.
pl.seed_everything(1234)

# Configuration shared by all three splits. DATA_DIR can be overridden through
# the environment; the default is the original hardcoded location.
DATA_DIR = os.environ.get(
    "SYNTHETIC_CABIN_IR_DIR",
    "/root/data/processed/synthetic_cabin_ir/A_Pillar_Codriver",
)
IMAGE_WIDTH, IMAGE_HEIGHT = 1280, 1024
BATCH_SIZE = 64
# 24 was the original worker count; cap it at the CPUs actually available so the
# notebook also runs (without DataLoader over-subscription warnings) on smaller machines.
NUM_WORKERS = min(24, os.cpu_count() or 1)


def make_keypoint_dataset(split: str) -> SimpleKeypointDataset:
    """Build the keypoint dataset for one split ("train", "val" or "test").

    Reads the detector predictions and the annotations for `split` from DATA_DIR.
    Ankles and knees are excluded, as for every split in this notebook.
    """
    return SimpleKeypointDataset(
        prediction_file=os.path.join(
            DATA_DIR, "keypoint_detection_results", f"keypoint_detection_{split}.json"
        ),
        annotation_file=os.path.join(
            DATA_DIR, "annotations", f"person_keypoints_{split}.json"
        ),
        image_width=IMAGE_WIDTH,
        image_height=IMAGE_HEIGHT,
        exclude_ankle=True,
        exclude_knee=True,
    )


train_dataset = make_keypoint_dataset("train")
val_dataset = make_keypoint_dataset("val")
test_dataset = make_keypoint_dataset("test")

print(
    'train_dataset', len(train_dataset),
    'val_dataset', len(val_dataset),
    'test_dataset', len(test_dataset)
)
# drop_last=True below would silently produce zero batches for a split smaller
# than one batch, which later surfaces as a missing val_loss instead of a clear error.
assert len(train_dataset) >= BATCH_SIZE, "train split is smaller than one batch"
assert len(val_dataset) >= BATCH_SIZE, "val split is smaller than one batch"

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, drop_last=True, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, drop_last=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

# Keep only the single checkpoint with the lowest val_loss, and stop training once
# val_loss has not improved for 5 consecutive validation runs (one run per epoch,
# since check_val_every_n_epoch=1 below).
model_checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=5)

# ------------
# model
# ------------
# NOTE(review): the datasets are built with exclude_knee=True as well as
# exclude_ankle=True (the test traceback shows 13 joints per sample), but only
# exclude_ankle is passed here — confirm LitGraphMLP's joint layout matches the data.
lit_model = LitGraphMLP(exclude_ankle=True)

# ------------
# training
# ------------
saved_model_path = './saved_lifter_2d_3d_model/synthetic_cabin_ir/A_Pillar_Codriver/prediction/graphmlp/'
# exist_ok makes this idempotent and avoids the check-then-create race.
os.makedirs(saved_model_path, exist_ok=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
trainer = pl.Trainer(
    max_epochs=20,
    callbacks=[model_checkpoint, early_stopping],
    accelerator=device,
    check_val_every_n_epoch=1,
    default_root_dir=saved_model_path,
    # Optional knobs: max_steps=10 for a quick smoke test; gradient_clip_val=1.0.
)
trainer.fit(lit_model, train_loader, val_loader)

Global seed set to 1234
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


train_dataset 37499 val_dataset 6250 test_dataset 6251


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type  | Params
--------------------------------
0 | model | Model | 9.5 M 
--------------------------------
9.5 M     Trainable params
676       Non-trainable params
9.5 M     Total params
37.931    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

check #0
val MPJPE from: 0 batches : 1128.348708152771


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

check #1
training loss from 585 batches: 185.70408563838046
val MPJPE from: 0 batches : 81.15095645189285


Validation: 0it [00:00, ?it/s]

check #2
training loss from 585 batches: 97.27416986074202
val MPJPE from: 0 batches : 65.13888388872147


Validation: 0it [00:00, ?it/s]

check #3
training loss from 585 batches: 78.35966767663629
val MPJPE from: 0 batches : 57.38060921430588


Validation: 0it [00:00, ?it/s]

check #4
training loss from 585 batches: 68.80844153909601
val MPJPE from: 0 batches : 49.659036099910736


Validation: 0it [00:00, ?it/s]

check #5
training loss from 585 batches: 63.8435456233147
val MPJPE from: 0 batches : 50.511471927165985


Validation: 0it [00:00, ?it/s]

check #6
training loss from 585 batches: 60.588290185755135
val MPJPE from: 0 batches : 61.20399758219719


Validation: 0it [00:00, ?it/s]

check #7
training loss from 585 batches: 58.11559373879025
val MPJPE from: 0 batches : 41.65687784552574


Validation: 0it [00:00, ?it/s]

check #8
training loss from 585 batches: 56.58782654338413
val MPJPE from: 0 batches : 42.77541860938072


Validation: 0it [00:00, ?it/s]

check #9
training loss from 585 batches: 55.09637942553586
val MPJPE from: 0 batches : 41.19153693318367


Validation: 0it [00:00, ?it/s]

check #10
training loss from 585 batches: 53.19402443292814
val MPJPE from: 0 batches : 45.014068484306335


Validation: 0it [00:00, ?it/s]

check #11
training loss from 585 batches: 52.425435898650406
val MPJPE from: 0 batches : 41.00299999117851


Validation: 0it [00:00, ?it/s]

check #12
training loss from 585 batches: 50.617546180629326
val MPJPE from: 0 batches : 41.61962866783142


Validation: 0it [00:00, ?it/s]

check #13
training loss from 585 batches: 49.140128302268494
val MPJPE from: 0 batches : 42.61445626616478


Validation: 0it [00:00, ?it/s]

check #14
training loss from 585 batches: 48.08381495949549
val MPJPE from: 0 batches : 38.864754140377045


Validation: 0it [00:00, ?it/s]

check #15
training loss from 585 batches: 47.04911973741319
val MPJPE from: 0 batches : 39.18778896331787


Validation: 0it [00:00, ?it/s]

check #16
training loss from 585 batches: 46.51329765717189
val MPJPE from: 0 batches : 35.87224707007408


Validation: 0it [00:00, ?it/s]

check #17
training loss from 585 batches: 44.4184940149132
val MPJPE from: 0 batches : 41.67758300900459


Validation: 0it [00:00, ?it/s]

check #18
training loss from 585 batches: 43.55921644685615
val MPJPE from: 0 batches : 40.421899408102036


Validation: 0it [00:00, ?it/s]

check #19
training loss from 585 batches: 43.12234080245352
val MPJPE from: 0 batches : 42.28004068136215


Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=20` reached.


check #20
training loss from 585 batches: 42.17202971633683
val MPJPE from: 0 batches : 36.97454184293747


In [5]:
# torch.nn.functional.mse_loss(torch.tensor([[1,4,2], [2,4,5]]).float(), torch.tensor([[1,2,3], [2,5,4]]).float(), reduction='none').mean(axis=1)

In [6]:
# Evaluate the best (lowest val_loss) checkpoint on the held-out test split.
best_checkpoint_path = model_checkpoint.best_model_path
# best_model_path stays "" if no checkpoint was ever saved; fail with a clear
# message rather than an obscure checkpoint-loading error.
assert best_checkpoint_path, "ModelCheckpoint has no best checkpoint; run trainer.fit first"
# NOTE(review): this call currently raises an EinopsError. The test input reaches a
# "b f j c -> b j (c f)" rearrange as a 3-D tensor of shape (64, 13, 2), i.e. with no
# frame axis. Training and validation ran on loaders with the same layout, so the
# missing axis is presumably added in LitGraphMLP's train/val steps but not in its
# test step — TODO fix there.
trainer.test(ckpt_path=best_checkpoint_path, dataloaders=test_loader)

Restoring states from the checkpoint path at saved_lifter_2d_3d_model/synthetic_cabin_ir/A_Pillar_Codriver/prediction/graphmlp/lightning_logs/version_17/checkpoints/epoch=15-step=9360.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at saved_lifter_2d_3d_model/synthetic_cabin_ir/A_Pillar_Codriver/prediction/graphmlp/lightning_logs/version_17/checkpoints/epoch=15-step=9360.ckpt


Testing: 0it [00:00, ?it/s]

EinopsError:  Error while processing rearrange-reduction pattern "b f j c -> b j (c f)".
 Input tensor shape: torch.Size([64, 13, 2]). Additional info: {}.
 Wrong shape: expected 4 dims. Received 3-dim tensor.

In [None]:
# Pull one example from the validation loader for the visual checks below.
# Later cells read sample[0] as the model input and sample[1] as the pose to plot.
# NOTE(review): the exact structure returned by get_sample_from_loader is not
# visible here — confirm these index meanings against its implementation.
sample = get_sample_from_loader(val_loader)

In [None]:
# Plot the reference pose of the sampled example as a skeleton of connected joints.
# NOTE(review): sample[1] is assumed to be the ground-truth 3-D keypoints —
# confirm against get_sample_from_loader / SimpleKeypointDataset.
ground_truth_keypoints = sample[1]
results = generate_connection_line(ground_truth_keypoints)
pose_df = pd.DataFrame(data=results)
visualize_pose(pose_df)

In [None]:
# Run the trained lifter on the single sample and plot its predicted pose.
model = trainer.model.to(device)
model.eval()
# as_tensor accepts numpy arrays, lists and existing tensors alike (torch.tensor warns
# on the latter); the explicit float32 replaces the trailing .float().
# NOTE(review): the input is flattened to shape (1, N), as for a linear lifter, while
# the GraphMLP rearrange seen in the test traceback expects (batch, frames, joints,
# coords) — confirm which layout LitGraphMLP's forward accepts.
model_input = torch.flatten(torch.as_tensor(sample[0], dtype=torch.float32)).unsqueeze(0).to(device)
# Pure inference: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    estimated_pose = model(model_input, 0)
# Reshape the first prediction to one row of 3 coordinates per joint before plotting.
estimated_pose_df = pd.DataFrame(
    generate_connection_line(estimated_pose[0].detach().cpu().reshape(-1, 3).numpy())
)
visualize_pose(estimated_pose_df)