In [None]:
import os
import glob
import sys
import json
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [None]:
sys.path.insert(0, '../')
import utils.misc as workspace
from SkelPointNet import SkelPointNet 
from DataUtil import HippocampiProcessedData, LeafletData
import FileRW as rw
import DistFunc as DF

In [None]:
def save_results(log_path, batch_id, input_xyz, skel_xyz, skel_r, label_xyz):
    """Write the evaluation artifacts of one batch, one set of files per sample.

    For each sample ``k`` (named by its id) four files are written to ``log_path``:
    ``val_<id>_input.ply`` (input points), ``val_<id>_sphere.obj`` (predicted
    spheres), ``val_<id>_center.ply`` (sphere centres) and ``val_<id>_label.ply``
    (label points).

    Args:
        log_path: Output directory; assumed to already exist.
        batch_id: Per-sample ids; ``.numpy()`` is called on it, so it is assumed
            to be a CPU tensor — TODO confirm against the dataset's ``__getitem__``.
        input_xyz: Input point clouds, indexed by sample along dim 0.
        skel_xyz: Predicted skeletal point positions; its dim 0 sets the batch size.
        skel_r: Predicted sphere radii, indexed by sample along dim 0.
        label_xyz: Label point clouds, indexed by sample along dim 0.
    """
    def _as_numpy(tensor):
        # Detach from any graph and move to host memory before converting.
        return tensor.detach().cpu().numpy()

    n_samples = skel_xyz.size()[0]
    ids = batch_id.numpy()
    inputs_np, centers_np, radii_np, labels_np = (
        _as_numpy(t) for t in (input_xyz, skel_xyz, skel_r, label_xyz)
    )

    for k in range(n_samples):
        sample_id = ids[k]
        input_path = os.path.join(log_path, f"val_{sample_id}_input.ply")
        sphere_path = os.path.join(log_path, f"val_{sample_id}_sphere.obj")
        center_path = os.path.join(log_path, f"val_{sample_id}_center.ply")
        label_path = os.path.join(log_path, f"val_{sample_id}_label.ply")

        rw.save_ply_points(inputs_np[k], input_path)
        rw.save_spheres(centers_np[k], radii_np[k], sphere_path)
        rw.save_ply_points(centers_np[k], center_path)
        rw.save_ply_points(labels_np[k], label_path)


In [None]:
# Experiment to evaluate: its directory holds specs.json (model/data settings
# such as InputPointNum, SkelPointNum, Normalize) and the saved checkpoints.
EXP_NAME = "gt-full5000-pskel100-finetune_leaf"
checkpoint = 'latest'  # checkpoint identifier passed to workspace.load_model_checkpoint
experiment_dir = os.path.join("../experiments/", EXP_NAME)

specs_path = os.path.join(experiment_dir, "specs.json")
with open(specs_path, "r") as specs_file:
    specs = json.load(specs_file)

In [None]:
# Model configuration read from the experiment's specs.json.
point_num = specs["InputPointNum"]
skelpoint_num = specs["SkelPointNum"]
to_normalize = specs["Normalize"]
gpu = "0"
model_skel = SkelPointNet(
    num_skel_points=skelpoint_num, input_channels=0, use_xyz=True
)

# CUDA_VISIBLE_DEVICES is only honoured if it is set before CUDA is first
# initialised, so it must be assigned *before* torch.cuda is queried below.
# (If CUDA was already initialised earlier in this kernel it has no effect.)
os.environ["CUDA_VISIBLE_DEVICES"] = gpu
if torch.cuda.is_available():
    print("GPU Number:", torch.cuda.device_count(), "GPUs!")
    model_skel.cuda()

# Load the saved model
model_epoch = workspace.load_model_checkpoint(
    experiment_dir, checkpoint, model_skel
)
# Switch to inference mode on every device (previously this only happened on
# the CUDA path). Done after loading so the final mode does not depend on what
# the checkpoint loader does.
model_skel.eval()
print(f"Evaluating model on using checkpoint={checkpoint} and epoch={model_epoch}.")

In [None]:
# Load the held-out leaflet cases and evaluate the skeleton network on them.
# The split is positional: case directories are sorted by name and everything
# after the first TRAIN_FRACTION of them is used for evaluation.
data_dir = "../data/leaflet_sreps/"
TRAIN_FRACTION = 0.9

# For leaflets: each case is a sub-directory holding the input template and the
# label point cloud. Only directories are considered, so stray files
# (e.g. .DS_Store) neither produce bogus paths nor shift the split.
case_dirs = sorted(
    case for case in os.listdir(data_dir)
    if os.path.isdir(os.path.join(data_dir, case))
)
data_list = [os.path.join(data_dir, case, "warped_template.vtp") for case in case_dirs]
label_list = [os.path.join(data_dir, case, "up_proc.vtp") for case in case_dirs]

idx_end = int(len(data_list) * TRAIN_FRACTION)
data_list_eval = data_list[idx_end:]
label_list_eval = label_list[idx_end:]
if not data_list_eval:
    # Fail early with a clear message instead of a ZeroDivisionError at the mean.
    raise ValueError(f"No evaluation cases found under {data_dir!r}")

eval_data = LeafletData(
    data_list_eval, label_list_eval, point_num, load_in_ram=True
)
data_loader = DataLoader(
    dataset=eval_data, batch_size=1, shuffle=False, drop_last=False
)

eval_save_dir = os.path.join(experiment_dir, workspace.evaluation_subdir, "leaf")
rw.check_and_create_dirs([eval_save_dir])

# Run on whatever device the model lives on instead of hard-coding .cuda(), and
# make sure the model is in inference mode regardless of earlier cells.
device = next(model_skel.parameters()).device
model_skel.eval()

overall_loss = 0
for batch_id, batch_pc, batch_label in tqdm(data_loader):
    batch_pc = batch_pc.to(device).float()
    with torch.no_grad():
        skel_xyz, skel_r, _ = model_skel(batch_pc, compute_graph=False)
        loss = model_skel.get_sampling_loss(batch_pc, skel_xyz, skel_r)
        overall_loss += loss.item()
    save_results(eval_save_dir, batch_id, batch_pc, skel_xyz, skel_r, batch_label)
# Mean of per-batch losses; with batch_size=1 this is the per-case mean.
overall_loss /= len(data_loader)

print(overall_loss)
