In [1]:
# Re-import edited project modules (src.modules...) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
# !rm -rf saved_lifter_2d_3d_model/synthetic_cabin_bw/A_Pillar_Codriver/prediction/linear_model/lightning_logs

In [3]:
import pandas as pd
import numpy as np
import torch
import lightning.pytorch as pl
import matplotlib.pyplot as plt
# import plotly
import plotly.express as px

In [4]:
import os
import torch
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from src.modules.lifter_2d_3d.model.linear_model.linear_model import BaselineModel
from src.modules.lifter_2d_3d.dataset.simple_keypoint_dataset import SimpleKeypointDataset
from src.modules.lifter_2d_3d.model.linear_model.lit_linear_model import LitSimpleBaselineLinear
from src.modules.utils.visualization import generate_connection_line, get_sample_from_loader, visualize_pose
from IPython.display import display

# NOTE(review): `pl` now refers to `pytorch_lightning` (imported at the top of this cell),
# shadowing the `lightning.pytorch` alias imported in the previous cell. Trainer and
# callbacks must come from the same package LitSimpleBaselineLinear is built on —
# TODO settle on one package and drop the other import.
pl.seed_everything(1234)

# ------------
# data: synthetic_cabin_bw, A_Pillar_Codriver view
# ------------
# All splits share one root, one file-naming scheme and the same preprocessing options,
# so the paths are generated rather than written out three times.
SYNTHETIC_CABIN_ROOT = "/root/data/processed/synthetic_cabin_bw/A_Pillar_Codriver"
synthetic_splits = {
    split: SimpleKeypointDataset(
        prediction_file=f"{SYNTHETIC_CABIN_ROOT}/keypoint_detection_results/keypoint_detection_{split}.json",
        annotation_file=f"{SYNTHETIC_CABIN_ROOT}/annotations/person_keypoints_{split}.json",
        image_width=1280,
        image_height=1024,
        exclude_ankle=True,
        exclude_hip=True,
    )
    for split in ("train", "val", "test")
}
train_dataset = synthetic_splits["train"]
val_dataset = synthetic_splits["val"]
test_dataset = synthetic_splits["test"]

print(
    'train_dataset', len(train_dataset),
    'val_dataset', len(val_dataset),
    'test_dataset', len(test_dataset)
)
# NOTE(review): drop_last=True on the validation loader discards the final partial batch
# (len(val_dataset) % 64 samples) on every validation pass. Left unchanged in case the
# LightningModule assumes full batches — TODO confirm and consider drop_last=False.
train_loader = DataLoader(train_dataset, batch_size=64, drop_last=True, shuffle=True, num_workers=24)
val_loader = DataLoader(val_dataset, batch_size=64, drop_last=True, num_workers=24)
test_loader = DataLoader(test_dataset, batch_size=64, num_workers=24)

# Keep only the single best checkpoint by validation loss.
model_checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)
# `patience` counts validation checks, not epochs: with check_val_every_n_epoch=5 below,
# training stops after 5 consecutive checks (about 25 epochs) without improvement.
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=5)

# ------------
# model
# ------------
lit_model = LitSimpleBaselineLinear(exclude_ankle=True, exclude_hip=True)

# ------------
# training
# ------------
saved_model_path = './saved_lifter_2d_3d_model/synthetic_cabin_bw/A_Pillar_Codriver/prediction/linear_model/'
# exist_ok makes the cell safe to re-run without a separate existence check.
os.makedirs(saved_model_path, exist_ok=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
trainer = pl.Trainer(
    max_epochs=200,
    callbacks=[model_checkpoint, early_stopping],
    accelerator=device,
    check_val_every_n_epoch=5,
    default_root_dir=saved_model_path,
    gradient_clip_val=1.0,
)
trainer.fit(lit_model, train_loader, val_loader)

Global seed set to 1234
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


train_dataset 37499 val_dataset 6250 test_dataset 6251


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type          | Params
----------------------------------------
0 | model | BaselineModel | 4.3 M 
----------------------------------------
4.3 M     Trainable params
0         Non-trainable params
4.3 M     Total params
17.105    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

check #0
val loss from: 2 batches : 2473.1621742248535


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

check #1
training loss from 2925 batches: 180.98934262990952
val loss from: 97 batches : 72.91328880129402


Validation: 0it [00:00, ?it/s]

check #2
training loss from 2925 batches: 54.90470664750816
val loss from: 97 batches : 31.602366049725987


Validation: 0it [00:00, ?it/s]

check #3
training loss from 2925 batches: 43.09635831504805
val loss from: 97 batches : 29.755979027483882


Validation: 0it [00:00, ?it/s]

check #4
training loss from 2925 batches: 37.87763009277674
val loss from: 97 batches : 31.7369069251203


Validation: 0it [00:00, ?it/s]

check #5
training loss from 2925 batches: 35.170156350120514
val loss from: 97 batches : 28.226330837945348


Validation: 0it [00:00, ?it/s]

check #6
training loss from 2925 batches: 33.59210160425585
val loss from: 97 batches : 25.825075171503823


Validation: 0it [00:00, ?it/s]

check #7
training loss from 2925 batches: 32.06165970454359
val loss from: 97 batches : 24.976600499190006


Validation: 0it [00:00, ?it/s]

check #8
training loss from 2925 batches: 31.208241760221302
val loss from: 97 batches : 26.101118952189523


Validation: 0it [00:00, ?it/s]

check #9
training loss from 2925 batches: 30.39775961994106
val loss from: 97 batches : 25.92446098161727


Validation: 0it [00:00, ?it/s]

check #10
training loss from 2925 batches: 29.755190899993618
val loss from: 97 batches : 24.51116288292039


Validation: 0it [00:00, ?it/s]

check #11
training loss from 2925 batches: 29.202459964614647
val loss from: 97 batches : 24.42070701610796


Validation: 0it [00:00, ?it/s]

check #12
training loss from 2925 batches: 28.77777078658597
val loss from: 97 batches : 24.874687367646963


Validation: 0it [00:00, ?it/s]

check #13
training loss from 2925 batches: 28.384585248736233
val loss from: 97 batches : 24.50087737560887


Validation: 0it [00:00, ?it/s]

check #14
training loss from 2925 batches: 28.046939074993134
val loss from: 97 batches : 24.273265921270724


Validation: 0it [00:00, ?it/s]

check #15
training loss from 2925 batches: 27.83787143981864
val loss from: 97 batches : 24.536979163891264


Validation: 0it [00:00, ?it/s]

check #16
training loss from 2925 batches: 27.57782029481525
val loss from: 97 batches : 24.22161782448439


Validation: 0it [00:00, ?it/s]

check #17
training loss from 2925 batches: 27.478279824185574
val loss from: 97 batches : 24.05559477041063


Validation: 0it [00:00, ?it/s]

check #18
training loss from 2925 batches: 27.31660449797781
val loss from: 97 batches : 24.188177103234324


Validation: 0it [00:00, ?it/s]

check #19
training loss from 2925 batches: 27.15028436520161
val loss from: 97 batches : 24.178747279742332


Validation: 0it [00:00, ?it/s]

check #20
training loss from 2925 batches: 27.110926906904606
val loss from: 97 batches : 24.265140794294396


Validation: 0it [00:00, ?it/s]

check #21
training loss from 2925 batches: 27.00235250604968
val loss from: 97 batches : 23.965793844197215


Validation: 0it [00:00, ?it/s]

check #22
training loss from 2925 batches: 26.938850914693287
val loss from: 97 batches : 23.936087880091566


Validation: 0it [00:00, ?it/s]

check #23
training loss from 2925 batches: 26.83005308939351
val loss from: 97 batches : 24.360516790262203


Validation: 0it [00:00, ?it/s]

check #24
training loss from 2925 batches: 26.87311377280798
val loss from: 97 batches : 24.065314532862494


Validation: 0it [00:00, ?it/s]

check #25
training loss from 2925 batches: 26.825345099991203
val loss from: 97 batches : 23.999132484812097


Validation: 0it [00:00, ?it/s]

check #26
training loss from 2925 batches: 26.79223894945577
val loss from: 97 batches : 23.897567072633617


Validation: 0it [00:00, ?it/s]

check #27
training loss from 2925 batches: 26.694721035086193
val loss from: 97 batches : 24.089094654647344


Validation: 0it [00:00, ?it/s]

check #28
training loss from 2925 batches: 26.711304113268852
val loss from: 97 batches : 24.060401028579044


Validation: 0it [00:00, ?it/s]

check #29
training loss from 2925 batches: 26.716927537678654
val loss from: 97 batches : 23.95628319726777


Validation: 0it [00:00, ?it/s]

check #30
training loss from 2925 batches: 26.666756433426826
val loss from: 97 batches : 24.26979807925593


Validation: 0it [00:00, ?it/s]

check #31
training loss from 2925 batches: 26.722374536160732
val loss from: 97 batches : 23.988008537550563


In [11]:
# ------------
# data: Drive&Act (train_with_vp1)
# ------------
# Paths for all three splits are derived from a single root so that a typo in one
# hand-written path (e.g. a stray leading slash) cannot make the splits diverge.
DRIVE_AND_ACT_ROOT = "/root/data/processed/drive_and_act_train_with_vp1"
drive_and_act_splits = {
    split: SimpleKeypointDataset(
        prediction_file=f"{DRIVE_AND_ACT_ROOT}/keypoint_detection_results/keypoint_detection_{split}.json",
        annotation_file=f"{DRIVE_AND_ACT_ROOT}/annotations/person_keypoints_{split}.json",
        image_width=1280,
        image_height=1024,
        exclude_ankle=True,
        exclude_hip=True,
    )
    for split in ("train", "val", "test")
}
train_dataset = drive_and_act_splits["train"]
val_dataset = drive_and_act_splits["val"]
test_dataset = drive_and_act_splits["test"]

print(
    'train_dataset', len(train_dataset),
    'val_dataset', len(val_dataset),
    'test_dataset', len(test_dataset)
)
# NOTE(review): drop_last=True on the validation loader discards the final partial batch
# on every validation pass — TODO confirm the LightningModule needs full batches.
train_loader = DataLoader(train_dataset, batch_size=64, drop_last=True, shuffle=True, num_workers=24)
val_loader = DataLoader(val_dataset, batch_size=64, drop_last=True, num_workers=24)
test_loader = DataLoader(test_dataset, batch_size=64, num_workers=24)

# Fresh callbacks for this run so checkpoint/early-stopping state does not carry over.
model_checkpoint = ModelCheckpoint(monitor='val_loss', mode='min', save_top_k=1)
# `patience` counts validation checks (every 5 epochs here), not epochs.
early_stopping = EarlyStopping(monitor="val_loss", mode="min", patience=5)

# ------------
# model
# ------------
# This cell does NOT create a model: it continues training the in-memory `lit_model`
# from the synthetic_cabin_bw training cell, i.e. it fine-tunes from whatever weights
# that object holds now (the end of its last fit, not necessarily its best checkpoint).
# To start from a saved checkpoint instead, rebind `lit_model` to
# `LitSimpleBaselineLinear.load_from_checkpoint(<path>)` before calling fit.
assert "lit_model" in globals(), (
    "Run the synthetic_cabin_bw training cell first: this cell fine-tunes its `lit_model`."
)

# ------------
# training
# ------------
saved_model_path = './saved_lifter_2d_3d_model/drive_and_act_train_with_vp1/prediction/linear_model'
os.makedirs(saved_model_path, exist_ok=True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
trainer = pl.Trainer(
    max_epochs=100,
    callbacks=[model_checkpoint, early_stopping],
    accelerator=device,
    check_val_every_n_epoch=5,
    default_root_dir=saved_model_path,
    gradient_clip_val=1.0,
)
trainer.fit(lit_model, train_loader, val_loader)

skipping problematic image 8853
skipping problematic image 8854


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type          | Params
----------------------------------------
0 | model | BaselineModel | 4.3 M 
----------------------------------------
4.3 M     Trainable params
0         Non-trainable params
4.3 M     Total params
17.105    Total estimated model params size (MB)


skipping problematic image 3283
skipping problematic image 1497
train_dataset 1586 val_dataset 3059 test_dataset 5047


Sanity Checking: 0it [00:00, ?it/s]

check #53
val loss from: 2 batches : 38.06672804057598


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

check #54
training loss from 120 batches: 49.073664812992014
val loss from: 47 batches : 60.66555688355831


Validation: 0it [00:00, ?it/s]

check #55
training loss from 120 batches: 45.0997336457173
val loss from: 47 batches : 57.58864900216143


Validation: 0it [00:00, ?it/s]

check #56
training loss from 120 batches: 41.18062431613604
val loss from: 47 batches : 54.6596108123343


Validation: 0it [00:00, ?it/s]

check #57
training loss from 120 batches: 38.886423874646425
val loss from: 47 batches : 54.52471793173475


Validation: 0it [00:00, ?it/s]

check #58
training loss from 120 batches: 37.223534580941
val loss from: 47 batches : 53.4187440780249


Validation: 0it [00:00, ?it/s]

check #59
training loss from 120 batches: 35.945805053537086
val loss from: 47 batches : 53.53083144477073


Validation: 0it [00:00, ?it/s]

check #60
training loss from 120 batches: 35.338657100995384
val loss from: 47 batches : 52.08613576882697


Validation: 0it [00:00, ?it/s]

check #61
training loss from 120 batches: 34.59578316348295
val loss from: 47 batches : 51.67369948739701


Validation: 0it [00:00, ?it/s]

check #62
training loss from 120 batches: 34.363112319260836
val loss from: 47 batches : 52.605462835190146


Validation: 0it [00:00, ?it/s]

check #63
training loss from 120 batches: 34.10479724407196
val loss from: 47 batches : 51.74999981325992


Validation: 0it [00:00, ?it/s]

check #64
training loss from 120 batches: 33.514402170355126
val loss from: 47 batches : 52.10474743805033


Validation: 0it [00:00, ?it/s]

check #65
training loss from 120 batches: 33.431764567891754
val loss from: 47 batches : 52.06040106713772


Validation: 0it [00:00, ?it/s]

check #66
training loss from 120 batches: 33.243605525543295
val loss from: 47 batches : 52.06221397569839


In [6]:
# import json
# predictions = {}
# with open("/root/data/processed/drive_and_act/keypoint_detection_results/keypoint_detection_train.json") as f:
#     data = json.loads(f.read())
#     for item in data:
#         predictions[item['image_id']] = item

In [7]:
# best_checkpoint_path = model_checkpoint.best_model_path
# trainer.test(ckpt_path=best_checkpoint_path, dataloaders=test_loader)

In [52]:
# Pick one fixed validation example for qualitative inspection.
# Alternative: sample = get_sample_from_loader(val_loader)
SAMPLE_INDEX = 100
item = val_loader.dataset[SAMPLE_INDEX]
# NOTE(review): later cells feed item[1] to the model and draw item[2] as a 3D pose;
# confirm this field layout against SimpleKeypointDataset.__getitem__.
sample = item[1]

In [53]:
# A keypoint row is valid unless every coordinate is exactly zero (presumably how the
# dataset marks missing joints — TODO confirm). Testing the row sum instead would also
# drop real joints whose coordinates cancel out, e.g. (1.0, -1.0, 0.0).
valid_keypoints = np.any(np.asarray(item[2]) != 0, axis=1)

In [54]:
# Draw the ground-truth pose of the selected example, restricted to valid keypoints.
# np.flatnonzero gives the 1-D indices of the True entries of the mask.
results = generate_connection_line(
    item[2],
    np.flatnonzero(valid_keypoints),
)
pose_df = pd.DataFrame(results)
visualize_pose(pose_df)

In [55]:
# Predict the 3D pose for `sample` with the most recently fitted model and draw it.
# NOTE: trainer.model holds the weights from the end of the last trainer.fit call,
# not necessarily the best checkpoint kept by model_checkpoint.
model = trainer.model.to(device)
model.eval()
# Flatten the sample into one feature vector and add a batch dimension of size 1.
model_input = torch.flatten(torch.tensor(sample)).unsqueeze(0).float().to(device)
# Inference only: no_grad skips building the autograd graph.
# NOTE(review): the meaning of the second positional argument (0) is defined by
# LitSimpleBaselineLinear.forward — TODO document it there.
with torch.no_grad():
    estimated_pose = model(model_input, 0)
estimated_pose_df = pd.DataFrame(
    generate_connection_line(
        # Output of the single batch item, regrouped as (x, y, z) rows per keypoint.
        estimated_pose[0].cpu().reshape([-1, 3]).detach().numpy(),
        # np.argwhere(valid_keypoints).reshape(-1)
    )
)
visualize_pose(estimated_pose_df)