In [None]:
import os
import glob
import sys
import json
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [None]:
# Make the parent directory importable so the project-local modules below
# resolve. A relative sys.path entry is resolved against the kernel's current
# working directory, and this line must run before the imports that follow.
sys.path.insert(0, '../')
import utils.misc as workspace  # provides load_model_checkpoint (used when restoring the model)
from SkelPointNet import SkelPointNet 
from DataUtil import LeafletData
# NOTE(review): rw and DF are not referenced in the visible cells — confirm whether still needed.
import FileRW as rw
import DistFunc as DF
import EvalUtil

In [None]:
# Which experiment to evaluate and which of its checkpoints to restore.
EXP_NAME = "gt-full5000-pskel100-finetune_leaf"
checkpoint = 'latest'
experiment_dir = os.path.join("../experiments/", EXP_NAME)

# Hyper-parameters stored alongside the experiment (read by the model-setup cell).
specs_path = os.path.join(experiment_dir, "specs.json")
with open(specs_path, "r") as specs_file:
    specs = json.load(specs_file)

In [None]:
# Model hyper-parameters read from the experiment's specs.json.
point_num = specs["InputPointNum"]
skelpoint_num = specs["SkelPointNum"]
to_normalize = specs["Normalize"]
gpu = "0"

# CUDA_VISIBLE_DEVICES is only honoured if it is set before the CUDA runtime
# enumerates devices. torch.cuda.is_available() can trigger that enumeration,
# so the variable is set first; otherwise the GPU selection may be ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = gpu

model_skel = SkelPointNet(
    num_skel_points=skelpoint_num, input_channels=0, use_xyz=True
)

if torch.cuda.is_available():
    print("GPU Number:", torch.cuda.device_count(), "GPUs!")
    model_skel.cuda()

# Load the saved model
model_epoch = workspace.load_model_checkpoint(
    experiment_dir, checkpoint, model_skel
)

# Switch to inference mode on every device, after the weights are restored.
# Previously this only happened on the GPU path, so a CPU run kept any
# train-mode-dependent layers (e.g. dropout / batch-norm, if present) in
# training mode.
model_skel.eval()
print(f"Evaluating model using checkpoint={checkpoint} and epoch={model_epoch}.")

In [None]:
# Load the held-out leaflet cases and evaluate the restored model on them.
# NOTE: no split file is used. The evaluation set is the tail of the case
# directories in sorted order, after the first TRAIN_FRACTION of them. This
# must match the split used at training time, otherwise evaluation cases may
# overlap the training data. The listing is intentionally not filtered so the
# split index stays identical to the training-time one.
data_dir = "../data/leaflet_sreps/"
TRAIN_FRACTION = 0.9

# For leaflets: each case directory contributes a data file
# (warped_template.vtp) and its label file (up_proc.vtp).
case_dirs = sorted(os.listdir(data_dir))
data_list = [os.path.join(data_dir, case, "warped_template.vtp") for case in case_dirs]
label_list = [os.path.join(data_dir, case, "up_proc.vtp") for case in case_dirs]

idx_end = int(len(data_list) * TRAIN_FRACTION)  # first index of the evaluation split
data_list_eval = data_list[idx_end:]
label_list_eval = label_list[idx_end:]

# Fail early with a clear message instead of deep inside the data loader.
assert data_list_eval, f"No evaluation cases found in {data_dir!r}"
missing_eval_files = [p for p in data_list_eval + label_list_eval if not os.path.isfile(p)]
assert not missing_eval_files, f"Missing evaluation files (first 5): {missing_eval_files[:5]}"

eval_data = LeafletData(
    data_list_eval, label_list_eval, point_num, load_in_ram=True
)

EvalUtil.test_results(experiment_dir, eval_data, model_skel)
