In [None]:
import warnings

import ipyvolume as ipv
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

from XrayTo3DShape import (
    AttentionUnet,
    VolumeAsInputExperiment,
    get_dataset,
    get_model,
    get_transform_from_model_name,
    modify_checkpoint_keys,
)

In [None]:
# load model checkpoint (presumably a PyTorch Lightning .ckpt, given the
# "epoch=..-step=.." filename and the "state_dict" key read below)
model_path = "../runs/2d-3d-benchmark/yiw2kgep/checkpoints/epoch=60-step=2867-val_dice=0.85.ckpt"
# map_location="cpu": a checkpoint saved during GPU training would otherwise fail to
# load on a CPU-only machine; nothing in this notebook moves the model to a GPU.
checkpoint = torch.load(model_path, map_location="cpu")
checkpoint = modify_checkpoint_keys(checkpoint)
model_architecture = get_model(model_name=AttentionUnet.__name__, image_size=128)
# strict=False tolerates extra entries in the checkpoint, but it ALSO silently skips
# model parameters the checkpoint does not contain -- warn about those instead of
# quietly running inference with randomly initialised layers.
incompatible_keys = model_architecture.load_state_dict(
    checkpoint["state_dict"], strict=False
)
if incompatible_keys.missing_keys:
    warnings.warn(
        f"{len(incompatible_keys.missing_keys)} model parameters were not found in the "
        f"checkpoint and keep their initial values, e.g. {incompatible_keys.missing_keys[:5]}"
    )
# Inference only: put any dropout / batch-norm layers into evaluation behaviour,
# since predict_step is later called directly rather than through a Trainer.
model_architecture.eval();

In [None]:
# load sample data
test_datapath = "../test_data/totalsegmentation_hips_test.csv"
test_transform = get_transform_from_model_name(
    AttentionUnet.__name__, image_size=128, resolution=2.25
)

test_loader = DataLoader(
    get_dataset(test_datapath, transforms=test_transform),
    batch_size=1,
    num_workers=1,
    shuffle=False,
    drop_last=False,
)

In [None]:
# Run the model over every test sample and show two slices of each input volume.
# After the loop, `pred` / `gt` hold the LAST batch's outputs (rendered in the next cell).
expt = VolumeAsInputExperiment(model_architecture)
for idx, item in enumerate(test_loader):
    # Renamed from `input` so the builtin is not shadowed in the notebook namespace;
    # the segmentation target returned here is not needed (gt comes from predict_step).
    model_input, _ = expt.get_input_output_from_batch(item)
    # predict_step is called directly rather than via a Lightning Trainer, so
    # gradients could otherwise be tracked; disable autograd for inference.
    with torch.no_grad():
        out = expt.predict_step(item, idx)
    pred, gt = out["pred"], out["gt"]

    # NOTE(review): the two channels presumably hold the two X-ray views, each
    # repeated along one spatial axis -- TODO confirm against the transform.
    fig, ax = plt.subplots(1, 2)
    ax[0].imshow(model_input[0][0, 0, 0, :, :], cmap="gray")
    ax[0].set_title("channel 0, index 0 on dim 2")
    ax[1].imshow(model_input[0][0, 1, :, :, 0], cmap="gray")
    ax[1].set_title("channel 1, index 0 on dim 4")
    fig.suptitle(f"test sample {idx}")
    # Flush each figure now instead of keeping one open per sample until the
    # loop ends (avoids matplotlib's too-many-open-figures warning / memory growth).
    plt.show()

In [None]:
# Render the last test sample's prediction (from the loop above) as a 3-D volume.
# ipyvolume's volshow expects a numeric NumPy array: detach from autograd, move to
# CPU and cast to float (guards against bool/int masks). torch.as_tensor also
# accepts the case where `pred` is already a NumPy array.
pred_volume = torch.as_tensor(pred[0, 0]).detach().cpu().float().numpy()
fig = ipv.figure()
vol = ipv.volshow(pred_volume)
ipv.show()