In [None]:
import pytorch_lightning as pl
from modules.lifter_2d_3d.model.linear_model.lit_linear_model import LitSimpleBaselineLinear as LitModel
from modules.experiments.dataset import (
    construct_synthetic_cabin_ir, synthetic_cabin_ir_dataset_root_path
)
from modules.utils.convention import get_saved_model_path
from modules.experiments.experiment import Experiment

# Seed the global RNGs through Lightning so training runs are repeatable.
pl.seed_everything(1234)

# Camera viewpoint to train on. Later cells join it onto the dataset root
# path to locate this viewpoint's images.
viewpoint = 'A_Pillar_Codriver'

# Dataset bundle for the chosen viewpoint. Only the 'dataset_name' and
# 'datasubset_name' keys are read directly here; the rest is handed to
# `Experiment` (presumably the train/test loaders -- confirm in the module).
constructed_loader = construct_synthetic_cabin_ir(
    dataset_root_path=synthetic_cabin_ir_dataset_root_path,
    viewpoint=viewpoint,
)

# Where the trained model is stored, keyed by model class and dataset identity.
saved_model_path = get_saved_model_path(
    model_name=LitModel.__name__,
    trained_dataset_name=constructed_loader['dataset_name'],
    trained_datasubset_name=constructed_loader['datasubset_name'],
)

# Options handed to the experiment as `model_parameters`; judging by their
# names they drop the ankle and knee joints (NOTE(review): confirm in LitModel).
lifter_options = {
    'exclude_ankle': True,
    'exclude_knee': True,
}

experiment = Experiment(
    LitModel=LitModel,
    constructed_loader=constructed_loader,
    saved_model_path=saved_model_path,
    model_parameters=lifter_options,
)

experiment.setup()
experiment.train()

[rank: 0] Received SIGTERM: 15


In [None]:
# Evaluate the trained model, then report the collected results.
# `print_result()` stays the final expression so anything it returns is
# still rendered as this cell's output.
run_evaluation = experiment.test
run_evaluation()
experiment.print_result()

## Train Samples

In [None]:
from pathlib import Path

from modules.utils.visualization import plot_samples

# Visualise predictions at three positions spread through the training
# split: 10 %, 50 % and 90 % of its length.
dataset_length = len(experiment.train_loader.dataset)
train_image_root = Path(synthetic_cabin_ir_dataset_root_path) / viewpoint
train_sample_indices = [
    int(dataset_length * fraction) for fraction in (0.1, 0.5, 0.9)
]
plot_samples(
    train_image_root,
    experiment.lit_model,
    experiment.train_loader,
    'train',
    img_figsize=(20, 10),
    plot_figsize=(20.5, 10),
    sample_indices=train_sample_indices,
    is_plot_gt_skeleton=False,
)

## Test Samples

In [None]:
from pathlib import Path

from modules.utils.visualization import plot_samples

# Visualise predictions at three positions spread through the test split:
# 10 %, 50 % and 90 % of its length.
dataset_length = len(experiment.test_loader.dataset)
# Wrap the root in `Path` (as the train-sample cell does): if the configured
# root is a plain string, `str / str` would raise TypeError. Importing `Path`
# here keeps this cell runnable on its own.
test_image_root = Path(synthetic_cabin_ir_dataset_root_path) / viewpoint
test_sample_indices = [
    int(dataset_length * fraction) for fraction in (0.1, 0.5, 0.9)
]
plot_samples(
    test_image_root,
    experiment.lit_model,
    experiment.test_loader,
    'test',
    img_figsize=(20, 10),
    plot_figsize=(20.5, 10),
    sample_indices=test_sample_indices,
    is_plot_gt_skeleton=False,
)