In [None]:
%cd ..

In [None]:
import os
import torch
import yaml
from models.sceneRepresentation import Scene
from dataset.dataset import ImageDataset_LagrangianVAE
import matplotlib.pyplot as plt
from util.util import compute_psnr, compute_iou

In [None]:
# Folder containing the experiment runs to evaluate. It is resolved against
# the current working directory (os.path.abspath('') expands to the cwd).
_experiment_parts = ('experiments', '2023-01-25', 'LagrangianVAE_offset')
path_experiments = os.path.join(os.path.abspath(''), *_experiment_parts)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Summary file collecting the results of all experiments. The evaluation loop
# opens it in append mode, so a file left over from a previous run is removed
# first to avoid mixing old and new results.
path_file = os.path.join(path_experiments, 'results.txt')
if os.path.exists(path_file):
    os.remove(path_file)

# IoU score of every evaluated experiment, in evaluation order.
ious = []

# Get all the experiments
# Evaluate every experiment run stored below ``path_experiments``.
#
# Only sub-directories are treated as runs. The summary file ``results.txt``
# is created inside ``path_experiments`` while this loop is running, and other
# plain files may live there as well; trying to open
# ``<file>/.hydra/config.yaml`` for such an entry would raise.
# The listing is therefore materialised (and sorted, for a reproducible
# evaluation order) before anything is written to the folder.
with os.scandir(path_experiments) as dir_entries:
    experiment_dirs = sorted(
        entry.path for entry in dir_entries if entry.is_dir()
    )

for path_experiment in experiment_dirs:
    # Load Config
    path_conf = os.path.join(path_experiment, '.hydra', 'config.yaml')
    with open(path_conf) as f:
        cfg = yaml.safe_load(f)

    print("Doing idx: ", cfg['data']['batch_idx'])

    # Load Model
    model = Scene()
    model.add_pendulum(**cfg['ode'], **cfg['scene']['local_representation'])

    # map_location remaps the stored tensors onto the device selected above,
    # so a checkpoint written on a GPU machine can be loaded on a CPU-only one.
    path_ckpt = os.path.join(path_experiment, 'ckpt.pth')
    model.load_state_dict(torch.load(path_ckpt, map_location=device))

    model.to(device)

    # Load Data
    path_data = os.path.join(
        os.path.abspath(''), 'data', cfg['data']['path_data']
    )
    data = ImageDataset_LagrangianVAE(
        path_data,
        T_pred=cfg['data']['T_pred'],
        batch_idx=cfg['data']['batch_idx'],
        use_high_res=False,
        offset_x=cfg['data']['offset_x'],
        offset_y=cfg['data']['offset_y']
    )
    H, W = data.get_image_dim()

    # Compute IoU between the rendered mask and the dataset's full mask.
    # The rendered mask is moved to the CPU before the comparison.
    tspan = data.t_steps_eval.to(device)
    model.update_trafo(tspan)
    output = model.render_image(W, H)
    iou = compute_iou(output['Mask'].cpu(), data.get_full_mask())
    ious.append(iou)
    print(f"IoU: {iou}")

    # Write summary file (append mode: one section per experiment)
    with open(path_file, 'a') as f:
        f.write(f"Index: {cfg['data']['batch_idx']}\n")
        f.write(f"IoU: {iou}\n")
        f.write("=============================\n\n")

    print("Done")
    print("====================================================")

# Average the collected per-experiment IoU scores and report the result.
avg_iou = torch.tensor(ious).mean()
print("Results:")
print(f"Avg IoU: {avg_iou}")