In [None]:
from main import load_subvolume_pairs_with_positions, reconstruct_full_volume_from_subvolumes
from dqn import PerPixelCNNWithHistory
import torch
import os

# Validation subvolumes (32x32x32 chunks) plus their grid positions, which are
# needed to stitch per-subvolume predictions back into one volume.
data_dir = "../../data/rapids-p/32x32x32"
CHECKPOINT_PATH = (
    "results/rapids-p/subvolumes/32x32x32/"
    "cont_0.0_grad_0.0_manh_0.5/best_model.pth"
)
# Fall back to CPU so the notebook can still run on a machine without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# min_fg_ratio=0 keeps all-background subvolumes too (the saved log shows
# fg=0.0000% files being loaded), presumably so every grid position is
# available for reconstruction.
# NOTE(review): the saved output shows per-file "Loaded:" lines despite
# verbose=False — confirm the loader honours this flag.
val_vols, val_masks, val_positions = load_subvolume_pairs_with_positions(
    os.path.join(data_dir, "val"),
    min_fg_ratio=0,
    verbose=False,
)

# height/width must match the 32x32 in-plane size of the subvolumes in data_dir.
policy_net = PerPixelCNNWithHistory(
    input_channels=1,
    history_len=5,
    future_len=5,
    height=32,
    width=32,
    small_input=True,
).to(DEVICE)

# The checkpoint is a plain state_dict (it is passed straight to
# load_state_dict). weights_only=True avoids unpickling arbitrary objects, and
# map_location lets a GPU-saved checkpoint load on whichever device is in use.
policy_net.load_state_dict(
    torch.load(CHECKPOINT_PATH, map_location=DEVICE, weights_only=True)
)
# Inference only: put any dropout/batch-norm layers into evaluation mode.
policy_net.eval()

# No gradients are needed to produce predictions; skipping autograd
# bookkeeping saves GPU memory on a full-volume pass.
with torch.no_grad():
    val_pred = reconstruct_full_volume_from_subvolumes(
        policy_net,
        val_vols,
        val_positions,
        device=DEVICE,
        continuity_coef=0,
        continuity_decay_factor=0.5,
        dice_coef=1.0,
    )



Loaded: week3_d00_h00_w00.npy at position (0, 0, 0) (fg=0.0000%)
Loaded: week3_d00_h00_w01.npy at position (0, 0, 1) (fg=0.0000%)
Loaded: week3_d00_h00_w02.npy at position (0, 0, 2) (fg=0.0000%)
Loaded: week3_d00_h00_w03.npy at position (0, 0, 3) (fg=0.0000%)
Loaded: week3_d00_h00_w04.npy at position (0, 0, 4) (fg=0.0000%)
Loaded: week3_d00_h00_w05.npy at position (0, 0, 5) (fg=0.0000%)
Loaded: week3_d00_h00_w06.npy at position (0, 0, 6) (fg=0.0000%)
Loaded: week3_d00_h00_w07.npy at position (0, 0, 7) (fg=0.0000%)
Loaded: week3_d00_h00_w08.npy at position (0, 0, 8) (fg=0.1251%)
Loaded: week3_d00_h00_w09.npy at position (0, 0, 9) (fg=0.0000%)
Loaded: week3_d00_h00_w10.npy at position (0, 0, 10) (fg=0.0000%)
Loaded: week3_d00_h00_w11.npy at position (0, 0, 11) (fg=0.0000%)
Loaded: week3_d00_h00_w12.npy at position (0, 0, 12) (fg=0.0000%)
Loaded: week3_d00_h00_w13.npy at position (0, 0, 13) (fg=0.0000%)
Loaded: week3_d00_h01_w00.npy at position (0, 1, 0) (fg=0.0000%)
Loaded: week3_d00_h01

KeyboardInterrupt: 

In [3]:
import nibabel as nib

SUBVOL_SHAPE = (32, 32, 32)

# Ground-truth label volume for week 3, the same week the validation
# subvolumes come from (their filenames start with "week3_").
# NOTE(review): absolute, machine-specific path — consider deriving it from the
# same data root used for data_dir.
mask = nib.load("/home/sysadmin/thesis/data/rapids-p/week3-joint-root-class.nii.gz")

# Trim every axis down to a whole multiple of the subvolume edge length so the
# labels cover exactly the region the stitched prediction covers.
D, H, W = mask.shape
crop_d, crop_h, crop_w = (
    (extent // edge) * edge for extent, edge in zip((D, H, W), SUBVOL_SHAPE)
)
mask_cropped = mask.get_fdata()[:crop_d, :crop_h, :crop_w]

In [4]:
from main_global_dice import compute_global_metrics

# The stitched prediction and the cropped label volume must cover exactly the
# same voxels. Check it here: otherwise a mismatch could be silently broadcast
# or fail deep inside the metric code.
assert tuple(val_pred.shape) == tuple(mask_cropped.shape), (
    f"prediction shape {tuple(val_pred.shape)} != "
    f"cropped mask shape {tuple(mask_cropped.shape)}"
)

# Metrics over the whole cropped volume at once (presumably global rather than
# averaged per subvolume, given the function/module name).
# Foreground is extremely sparse here, so 'accuracy' is dominated by
# background voxels; dice/iou/precision/recall are the informative numbers.
res = compute_global_metrics(val_pred, mask_cropped)
print(res)

{'dice': 0.49176934450466186, 'iou': 0.32605712044962604, 'precision': 0.3965661142243686, 'recall': 0.6471236785783074, 'accuracy': 0.9992293128188775, 'tp': 59868, 'fp': 91098, 'fn': 32646}


In [5]:
import numpy as np

# Pad the (cropped) prediction back to the full label-volume shape. Voxels
# outside the crop stay 0, so the result overlays the original scan directly.
# mask.shape reads the image shape without materialising the float64 data.
full_shape = mask.shape
d, h, w = val_pred.shape
assert all(p <= f for p, f in zip((d, h, w), full_shape)), (
    f"prediction shape {(d, h, w)} exceeds full volume shape {full_shape}"
)
full_pred_padded = np.zeros(full_shape, dtype=np.uint8)
# NOTE(review): assumes val_pred is a binary 0/1 mask; non-integer values
# (e.g. probabilities) would be truncated by the uint8 cast.
full_pred_padded[:d, :h, :w] = val_pred
best_pred = full_pred_padded

# NOTE(review): this directory does not match the experiment of the checkpoint
# loaded in the first cell (cont_0.0_grad_0.0_manh_0.5). Confirm the prediction
# is not being written into (and overwriting) another experiment's results.
output_dir = "results/rapids-p/subvolumes/32x32x32/week2+3_dice/0_all_with_future"
os.makedirs(output_dir, exist_ok=True)

# Reuse the label image's affine so the prediction lives in the same world space.
# best_pred is already uint8, so no extra astype copy is needed.
nib.save(
    nib.Nifti1Image(best_pred, affine=mask.affine),
    os.path.join(output_dir, "val_prediction_on_whole_vol.nii.gz"),
)