In [1]:
import torch
import yaml
import numpy as np
import cv2
import os
import sys

PROJECT_ROOT = "/blue/weishao/chojnowski.h/project_i"
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from model.utils import (
    reconstruct_removed_hw,
    reconstruct_angle_linear,
    reconstruct_angle_sr,
)
from model.model import ProjectI
from model.reconstruction import extract_slices, reconstruct_volume, save_volume

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"--------------------------------")
print(f"Loading config file...")

with open("../configs/model/model_config.yaml") as f:
    model_cfg = yaml.safe_load(f)

with open(f"../configs/model/{model_cfg['encoder_type']}_config.yaml") as f:
    encoder_cfg = yaml.safe_load(f)

with open("../configs/model/decoder_config.yaml") as f:
    decoder_cfg = yaml.safe_load(f)

with open("../configs/inference/inference_config.yaml") as f:
    inference_cfg = yaml.safe_load(f)

# paths
model_path = inference_cfg["model_path"]
image_path = inference_cfg["image_path"]
save_path = inference_cfg["save_path"]
patient_name = image_path.split("/")[-1]

# data params
patch_size = inference_cfg["patch_size"]
stride = inference_cfg["stride"]

# reconstruction params
target_step_deg = inference_cfg["target_step_deg"]
arc_deg = inference_cfg["arc_deg"]
stride_deg = inference_cfg["stride_deg"]

# augmentation params
gaussian_noise_mean = inference_cfg.get("gaussian_noise_mean", 0.0)
gaussian_noise_std = inference_cfg.get("gaussian_noise_std", 0.0)
gaussian_noise_scale = inference_cfg.get("gaussian_noise_scale", 1.0)

zero_one = inference_cfg.get("zero_one", False)
angle_norm = inference_cfg.get("angle_norm", 3)

just_sr = inference_cfg.get("just_sr", False)

print(f"Config file loaded successfully")
print(f"--------------------------------")
print(f"Building model...")

model = ProjectI(
    embd_dim=model_cfg["embd_dim"],
    encoder_type=model_cfg["encoder_type"],
    encoder_config=encoder_cfg,
    decoder_config=decoder_cfg,
).to(device)

model.load_state_dict(torch.load("../" + model_path)["model_state"])

print(f"Model built successfully")
print(f"--------------------------------")
print(f"Loading data...")

Using device: cuda
--------------------------------
Loading config file...
Config file loaded successfully
--------------------------------
Building model...
Model built successfully
--------------------------------
Loading data...


In [44]:
data_test_path = "/home/chojnowski.h/weishao/chojnowski.h/project_i/data/datasets/val/UF015_patch_009_rev.npz"

from data.data import _rand_crop2d_content_aware

with np.load(data_test_path, allow_pickle=False) as f:
    slices = f["slices"]
    angles = f["angles"]

min_angle = angles.min()
max_angle = min_angle + arc_deg

r_idx = torch.arange(start=0,end=len(angles), step=3)
other_idx = torch.tensor([i for i in range(len(angles)) if i not in r_idx])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

slices = torch.from_numpy(slices).to(device)
angles = torch.from_numpy(angles).to(device)

slices = (slices - slices.min()) / (slices.max() - slices.min()) * 2 - 1  # normalize to [-1,1]

angles_removed = angles[r_idx]
slices_removed = slices[r_idx]

angles_removed = angles_removed.to(device)
slices_removed = slices_removed.to(device)

h,w = _rand_crop2d_content_aware(slices_removed, 64, 64, is_train=False)

slices_removed = slices_removed[:, h : h + 64, w : w + 64]
slices = slices[:, h : h + 64, w : w + 64]

In [45]:
from model.utils import PSNR, SSIM_slicewise, angle_aware_resample


angles = (angles - angles.min()) / (20 - angles.min()) * (np.pi / 1.5) - (np.pi / 3)
angles_removed = (angles_removed - angles.min()) / (20 - angles.min()) * (np.pi / 1.5) - (np.pi / 3)
print(angles.shape, slices.shape)
print(angles_removed.shape, slices_removed.shape)

with torch.no_grad():
    out, _ = model(slices_removed.to(device), angles_removed.to(device), angles.to(device), is_train=False)
    psnr_model = PSNR(out.cpu(), slices.cpu()) 
    ssim_model = SSIM_slicewise(out.cpu(), slices.cpu())
    print(f"Model PSNR: {psnr_model}, SSIM: {ssim_model}")

    slices_removed = slices_removed.unsqueeze(1)
    out_linear = angle_aware_resample(angles_removed, slices_removed, angles, radians=False).squeeze(1)
    print(out_linear.shape)
    psnr_linear = PSNR(out_linear.cpu(), slices.cpu())
    ssim_linear = SSIM_slicewise(out_linear.cpu(), slices.cpu())
    print(f"Linear PSNR: {psnr_linear}, SSIM: {ssim_linear}")

torch.Size([30]) torch.Size([30, 64, 64])
torch.Size([10]) torch.Size([10, 64, 64])
Model PSNR: 13.82371711730957, SSIM: 0.09849309176206589
torch.Size([30, 64, 64])
Linear PSNR: 13.68072509765625, SSIM: 0.0979984775185585


In [None]:
pred_uniform, grid_deg = reconstruct_angle_sr(
    model,
    slices_THW=slices[idx],  # (T,H,W)
    angles_deg_T=angles[idx],  # (T,)
    arc_deg=20,
    stride_deg=10,
    target_step_deg=0.1,
    patch_size=patch_size,
    stride_hw=stride,
    reconstruct_removed_hw_fn=reconstruct_removed_hw,
)
import SimpleITK as sitk
pred_uniform = (pred_uniform.cpu()*255).numpy().astype(np.uint8)
pred_vol = sitk.GetImageFromArray(pred_uniform)
pred_vol.SetSpacing((1.0, 1.0, 1.0))
sitk.WriteImage(pred_vol, "/home/chojnowski.h/weishao/chojnowski.h/project_i/data/inference/patches/test.nii.gz")