In [1]:
# Enable IPython's autoreload extension in mode 2 so that imported modules
# (e.g. the project's src.* packages) are re-imported before each cell runs,
# picking up source edits without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
# !rm -rf saved_lifter_2d_3d_model/synthetic_cabin_bw/A_Pillar_Codriver/prediction/linear_model/lightning_logs

In [3]:
import pandas as pd
import numpy as np
import torch
import lightning.pytorch as pl
import matplotlib.pyplot as plt
# import plotly
import plotly.express as px

KeyboardInterrupt: 

In [None]:
import os
import torch
import pandas as pd
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from src.modules.lifter_2d_3d.model.linear_model.linear_model import BaselineModel
from src.modules.lifter_2d_3d.dataset.simple_keypoint_dataset import SimpleKeypointDataset
from src.modules.lifter_2d_3d.model.repnet.lit_repnet import LitRepNet
from src.modules.utils.visualization import generate_connection_line, get_sample_from_loader, visualize_pose
from IPython.display import display

# Fix the global random seed up front so the run is repeatable.
pl.seed_everything(seed=1234)

# Root directory of the processed dataset used for this experiment.
# Alternative dataset with the same directory layout:
#   "/root/data/processed/synthetic_cabin_bw/A_Pillar_Codriver"
DATASET_ROOT = "/root/data/processed/drive_and_act_train_with_vp1"


def build_keypoint_dataset(split):
    """Build the SimpleKeypointDataset for one split.

    Args:
        split: Split name used in the JSON file names ("train", "val" or "test").

    Returns:
        A SimpleKeypointDataset reading the keypoint-detection results and the
        annotations of that split, configured with a 1280x1024 image size and
        with exclude_ankle / exclude_knee enabled (matching the model config).
    """
    return SimpleKeypointDataset(
        prediction_file=f"{DATASET_ROOT}/keypoint_detection_results/keypoint_detection_{split}.json",
        annotation_file=f"{DATASET_ROOT}/annotations/person_keypoints_{split}.json",
        image_width=1280,
        image_height=1024,
        exclude_ankle=True,
        exclude_knee=True,
    )


# All three splits share one configuration; only the file names differ.
# (The validation annotation path previously started with a stray "//".)
train_dataset = build_keypoint_dataset("train")
val_dataset = build_keypoint_dataset("val")
test_dataset = build_keypoint_dataset("test")

print(
    'train_dataset', len(train_dataset),
    'val_dataset', len(val_dataset),
    'test_dataset', len(test_dataset)
)
# Settings common to every loader.
loader_options = dict(batch_size=64, num_workers=24)

# Training batches are shuffled; train and val both drop the final incomplete batch,
# while the test loader keeps every sample.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **loader_options)
val_loader = DataLoader(val_dataset, drop_last=True, **loader_options)
test_loader = DataLoader(test_dataset, **loader_options)

# Both callbacks watch the same metric: validation loss, lower is better.
monitored_metric = dict(monitor='val_loss', mode='min')

# Keep only the single best checkpoint.
model_checkpoint = ModelCheckpoint(save_top_k=1, **monitored_metric)
# NOTE: patience counts validation checks, not training epochs. The Trainer in this
# notebook validates every 10 epochs, so 5 checks span 50 epochs without improvement.
early_stopping = EarlyStopping(patience=5, **monitored_metric)

# ------------
# model
# ------------
# 2D-to-3D lifter network; the ankle/knee flags match the dataset configuration.
lifter_2D_3D = BaselineModel(
    exclude_ankle=True,
    exclude_knee=True,
)
# Lightning module that wraps the lifter for training.
lit_model = LitRepNet(
    lifter_2D_3D=lifter_2D_3D,
)
# ------------
# training
# ------------
# Output directory handed to the Trainer as default_root_dir.
saved_model_path = './saved_lifter_2d_3d_model/drive_and_act_train_with_vp1/prediction/repnet_linear_model/'
# exist_ok=True keeps the cell re-runnable and avoids the check-then-create race.
os.makedirs(saved_model_path, exist_ok=True)

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('device', device)
# device = 'cpu'
trainer = pl.Trainer(
    # max_steps=10,
    max_epochs=500,
    callbacks=[model_checkpoint, early_stopping],
    accelerator=device,
    # Validation (and therefore checkpoint / early-stopping evaluation on val_loss)
    # only happens once every 10 training epochs.
    check_val_every_n_epoch=10,
    default_root_dir=saved_model_path,
    # gradient_clip_val=1.0
)
trainer.fit(lit_model, train_loader, val_loader)

In [None]:
# best_checkpoint_path = model_checkpoint.best_model_path
# trainer.test(ckpt_path=best_checkpoint_path, dataloaders=test_loader)

In [None]:
# Pull one sample from the validation loader for qualitative inspection.
# The cells below use sample[0] as the sample id, flatten sample[1] as the generator
# input, and pass sample[2] to generate_connection_line.
# Presumably sample[1] holds the 2D keypoints and sample[2] the 3D ground-truth pose —
# TODO confirm against get_sample_from_loader.
sample = get_sample_from_loader(val_loader)

In [None]:
# Find the raw dataset record whose id matches the sampled item. `item` stays bound
# after the loop and is used by the following cell to load the image.
for item in val_loader.dataset.samples:
    if item['id'] == sample[0]:
        print(item)
        break
else:
    # Without this guard `item` would silently remain the last record in the dataset
    # and the next cell would plot an unrelated image.
    raise LookupError(f"no validation sample found with id {sample[0]!r}")

In [None]:
# Overlay the 2D keypoints of the selected record on its source image.
img = plt.imread(f'/root/data/processed/drive_and_act_train_with_vp1/images/val/{item["filenames"]}')
# Image arrays are laid out as (rows, columns[, channels]) = (height, width[, channels]).
img_height, img_width = img.shape[:2]
fig, ax = plt.subplots()
ax.imshow(img)
# scatter() takes x (horizontal) first, so x must be scaled by the image width and y by
# the height; the previous code had the two swapped, which distorts non-square images.
# Assumes keypoints2D columns are (x, y) normalised by image width/height — TODO confirm
# against SimpleKeypointDataset.
ax.scatter(item['keypoints2D'][:, 0] * img_width, item['keypoints2D'][:, 1] * img_height)

In [None]:
# Boolean mask over the rows of sample[1]: True where the row's values do not sum to zero.
# Presumably this marks keypoints that were actually detected, with missing ones
# zero-filled — TODO confirm against the dataset.
valid_keypoints = (sample[1].sum(axis=1) != 0)

In [None]:
# Flat array of the indices where the valid-keypoint mask is True.
valid_keypoint_indices = np.argwhere(valid_keypoints).ravel()

# Build the connection-line records for sample[2] and render them.
results = generate_connection_line(sample[2], valid_keypoint_indices)
pose_df = pd.DataFrame(results)
visualize_pose(pose_df)

In [None]:
# Run the trained generator on the sampled keypoints and visualise its output.
model = trainer.model.to(device)
model.eval()

# Inference only: disable autograd so no computation graph is built for the forward pass.
with torch.no_grad():
    # Flatten sample[1] into a single vector and add a batch dimension of 1.
    generator_input = torch.flatten(torch.tensor(sample[1])).unsqueeze(0).float().to(device)
    estimated_pose = model.generator(generator_input)

estimated_pose_df = pd.DataFrame(
    generate_connection_line(
        # Element 1 of the generator output is reshaped to rows of 3 values
        # (presumably x, y, z per keypoint — TODO confirm against the generator).
        estimated_pose[1].cpu().reshape([-1, 3]).detach().numpy(),
        # np.argwhere(valid_keypoints).reshape(-1)
    )
)
visualize_pose(estimated_pose_df)