In [1]:
from main import load_subvolume_pairs_with_positions, reconstruct_full_volume_from_subvolumes
from dqn import PerPixelCNNWithHistory
import torch
import os

# Evaluation setup: load the validation subvolumes, restore the best checkpoint,
# and stitch the per-subvolume predictions back into a single volume.
data_dir = "../../data/rapids-p/32x32x32"

# Single source of truth for the compute device. Falls back to CPU when no GPU
# is visible so the notebook still runs (slowly) instead of failing in `.to()`.
device = "cuda" if torch.cuda.is_available() else "cpu"

val_vols, val_masks, val_positions = load_subvolume_pairs_with_positions(
    os.path.join(data_dir, "val"),
    min_fg_ratio=0,  # presumably disables foreground-ratio filtering -- TODO confirm in main.py
    verbose=False,
)

# height/width match the 32x32 in-plane size of the subvolumes in data_dir.
policy_net = PerPixelCNNWithHistory(
    input_channels=1,
    history_len=5,
    future_len=5,
    height=32,
    width=32,
    small_input=True,
).to(device)

# map_location remaps tensors that were saved from a different device (e.g. another
# GPU index, or a GPU when only a CPU is available now) onto `device`.
policy_net.load_state_dict(
    torch.load(
        "results/rapids-p/subvolumes/32x32x32/cont_0.0_grad_0.0_manh_0.5/best_model.pth",
        map_location=device,
    )
)

# Switch dropout / batch-norm layers (if the model has any) to inference behaviour.
# Harmless if reconstruct_full_volume_from_subvolumes already does this itself.
policy_net.eval()

val_pred = reconstruct_full_volume_from_subvolumes(
    policy_net,
    val_vols,
    val_positions,
    device=device,
    continuity_coef=0,
    continuity_decay_factor=0.5,
    dice_coef=1.,
)




Summary: Loaded 4900 subvolumes


In [2]:
import nibabel as nib

# Subvolume edge lengths, listed in the same axis order as the image shape.
SUBVOL_SHAPE = (32, 32, 32)

# Reference segmentation for the whole volume (NIfTI image object, not yet an array).
mask = nib.load("/home/sysadmin/thesis/data/rapids-p/week3-joint-root-class.nii.gz")

D, H, W = mask.shape

# Trim every axis down to the largest multiple of the subvolume size that fits,
# i.e. drop the trailing remainder that cannot form a complete subvolume.
crop_d, crop_h, crop_w = (
    extent - extent % step for extent, step in zip((D, H, W), SUBVOL_SHAPE)
)

mask_cropped = mask.get_fdata()[:crop_d, :crop_h, :crop_w]

In [3]:
from main import compute_global_metrics

# Voxel-wise metrics are only meaningful when prediction and reference cover the
# same grid; fail loudly here rather than inside the metric code (or not at all).
assert tuple(val_pred.shape) == tuple(mask_cropped.shape), (
    f"prediction shape {tuple(val_pred.shape)} does not match "
    f"cropped mask shape {tuple(mask_cropped.shape)}"
)

# Compare the reconstructed prediction against the cropped reference mask.
res = compute_global_metrics(val_pred, mask_cropped)
print(res)

{'dice': 0.5217702638978412, 'iou': 0.3529696711917465, 'precision': 0.43768404705741254, 'recall': 0.6458481959486515, 'accuracy': 0.999317851163903, 'tp': 59750, 'fp': 76764, 'fn': 32764}


In [4]:
import numpy as np

# Embed the prediction in a zero-filled array with the shape of the original
# (uncropped) mask so the saved file lines up with the source volume.
# `mask.shape` is used directly: it gives the same tuple as
# `mask.get_fdata().shape` without going through the float voxel array.
full_pred_padded = np.zeros(mask.shape, dtype=np.uint8)

# The prediction occupies the leading corner; the trailing voxels that were
# left out of the prediction grid stay 0.
d, h, w = val_pred.shape
full_pred_padded[:d, :h, :w] = val_pred
best_pred = full_pred_padded

# best_pred is already uint8, so it is handed to nibabel as-is (no extra copy).
# Reusing the mask's affine keeps the prediction in the same world coordinates.
nib.save(
    nib.Nifti1Image(best_pred, affine=mask.affine),
    os.path.join(
        "results/rapids-p/subvolumes/32x32x32/cont_0.0_grad_0.0_manh_0.5",
        "val_prediction_on_whole_vol.nii.gz",
    ),
)