In [1]:
!nvidia-smi

Wed Apr 12 11:00:04 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.48.07    Driver Version: 515.48.07    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:0D:00.0 Off |                  N/A |
| 23%   25C    P8     8W / 250W |   2625MiB / 11264MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
workspace_dir = '/nfs/Workspace/CardiacSeg'

import os
import sys
from pathlib import PurePath
from functools import partial

sys.path.append(workspace_dir)

from torchsummaryX import summary

import torch

import pandas as pd


from monai.transforms import (
    Compose,
    AddChanneld,
    EnsureChannelFirstd,
    LoadImaged,
    Orientationd,
    SqueezeDimd,
    LoadImage,
    Spacingd,
    ScaleIntensityRanged,
    RandCropByPosNegLabeld,
    ToNumpyd,
)
from monai.inferers import sliding_window_inference
from monai.visualize import GradCAM
from monailabel.transform.post import Restored


from datasets.chgh_dataset import get_data_dicts

from data_utils.data_loader_utils import load_data_dict_json
from data_utils.dataset import get_infer_data
from data_utils.utils import get_pid_by_file
from data_utils.io import load_json
from data_utils.visualization import show_img_lbl, show_img_lbl_pred, show_img_lbl_preds, show_img_lbl_preds_overlap

from runners.inferer import run_infering
from expers.infer_utils import get_tune_model_dir, get_data_path, get_pred_path
from expers.args import get_parser

from networks.network import network
from networks.networkx.blocks.cbam import CBAM
from networks.networkx.blocks.ham import HAM
from networks.ssl_head import SSLHead

# sync python module
%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model_name = 'unetcnx_x3_2_2_a5' #'unetcnx_x3_2_2'
data_name = 'chgh'
sub_data_dir_name = 'dataset_2'
exp_name = 't_4'
data_dict_file_name = 'exp_2_2.json' #'exp_2_2.json' exp_b7_9

root_exp_dir = os.path.join(
    workspace_dir, 
    'exps',
    'exps',
    model_name,
    data_name,
    'tune_results'
)

root_data_dir = os.path.join(
    workspace_dir, 
    'dataset',
    data_name
)

data_dir = os.path.join(root_data_dir, sub_data_dir_name)

model_dir = get_tune_model_dir(root_exp_dir, exp_name)

best_checkpoint = os.path.join(model_dir, 'best_model.pth')
final_checkpoint = os.path.join(model_dir, 'final_model.pth')

infer_dir = os.path.join(
    workspace_dir, 
    'exps',
    'exps',
    model_name,
    data_name,
    'infers',
    exp_name,
)

print('\nbest model:',best_checkpoint)
print('infer dir:',infer_dir)

pid = 'pid_1000'
data_dict = get_data_path(data_dir, pid)
data_dict['pred'] = get_pred_path(root_exp_dir, exp_name, data_dict['image'])

img_pth = data_dict['image']
lbl_pth = data_dict['label'] 
print(img_pth)
print(lbl_pth)
print(data_dict['pred'])

Loading results from /nfs/Workspace/CardiacSeg/exps/exps/unetcnx_x3_2_2_a5/chgh/tune_results/t_4...
2023-04-12 11:00:11,277 - No `self.trials`. Drawing logdirs from checkpoint file. This may result in some information that is out of sync, as checkpointing is periodic.





Best trial 6d320_00000: 
config: {'exp': {'exp': 'exp_b7_9_x3_2_2_a5'}}
tt_dice: 0.8903141
tt_hd95: 4.954921004130534
best log dir: /nfs/Workspace/CardiacSeg/exps/exps/unetcnx_x3_2_2_a5/chgh/tune_results/t_4/main_6d320_00000_0_exp=exp_exp_b7_9_x3_2_2_a5_2023-03-29_03-54-38

best model: /nfs/Workspace/CardiacSeg/exps/exps/unetcnx_x3_2_2_a5/chgh/tune_results/t_4/main_6d320_00000_0_exp=exp_exp_b7_9_x3_2_2_a5_2023-03-29_03-54-38/models/best_model.pth
infer dir: /nfs/Workspace/CardiacSeg/exps/exps/unetcnx_x3_2_2_a5/chgh/infers/t_4
/nfs/Workspace/CardiacSeg/dataset/chgh/dataset_2/pid_1000/pid_1000.nii.gz
/nfs/Workspace/CardiacSeg/dataset/chgh/dataset_2/pid_1000/pid_1000_gt.nii.gz
/nfs/Workspace/CardiacSeg/exps/exps/unetcnx_x3_2_2_a5/chgh/tune_results/infers/t_4/pid_1000.nii.gz


In [5]:
args = get_parser([])
args.model_name=model_name
args.data_name=data_name
args.data_dir=data_dir
args.model_dir=model_dir
args.infer_dir=infer_dir
args.checkpoint=best_checkpoint

args.out_channels=2 
args.patch_size=4 
args.drop_rate=0.4 
args.depths=[4, 4, 8, 4] 
args.a_min = -42
args.a_max = 423
args.space_x = 0.7
args.space_y = 0.7
args.space_z = 1.0
args.roi_x = 128
args.roi_y = 128
args.roi_z = 128
args.infer_overlap = 0.25
args.sw_batch_size = 2
data_dicts = [{
    'image': img_pth,
    'label': lbl_pth
}]
data_dicts

# device
if torch.cuda.is_available():
    print("cuda is available")
    args.device = torch.device("cuda")
else:
    print("cuda is not available")
    args.device = torch.device("cpu")

# model
model = network(args.model_name, args)

# check point
if args.checkpoint is not None:
    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    # load model
    model.load_state_dict(checkpoint["state_dict"])
    # load check point epoch and best acc

    print(
      "=> loaded checkpoint '{}')"\
      .format(args.checkpoint)
    )

# post transforom
post_transform = Compose([
    Orientationd(keys=['pred'], axcodes="LPS"),
    ToNumpyd(keys=['pred']),
    Restored(keys=['pred'], ref_image="image")
])

# inferer
model_inferer = partial(
    sliding_window_inference,
    roi_size=[args.roi_x, args.roi_y, args.roi_z],
    sw_batch_size=args.sw_batch_size,
    predictor=model,
    overlap=args.infer_overlap,
)

cuda is available
model: unetcnx_x3_2_2_a5
[4, 4, 8, 4]
=> loaded checkpoint '/nfs/Workspace/CardiacSeg/exps/exps/unetcnx_x3_2_2_a5/chgh/tune_results/t_4/main_6d320_00000_0_exp=exp_exp_b7_9_x3_2_2_a5_2023-03-29_03-54-38/models/best_model.pth')


In [5]:
# transforms
transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        AddChanneld(keys=["image", "label"]),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(
            keys=["image", "label"],
            pixdim=(args.space_x, args.space_y, args.space_z),
            mode=("bilinear", "nearest"),
        ),
        ScaleIntensityRanged(
            keys=["image"], a_min=args.a_min, a_max=args.a_max, b_min=args.b_min, b_max=args.b_max, clip=True
        ),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(args.roi_x, args.roi_y, args.roi_z),
            pos=1,
            neg=1,
            num_samples=1,
            image_key="image",
            image_threshold=0,
        )
    ]
)

train_ds = transforms(data_dicts)

<class 'monai.transforms.utility.array.AddChannel'>: Class `AddChannel` has been deprecated since version 0.8. please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.


In [6]:
cam = GradCAM(nn_module=model, target_layers="encoder1")
result = cam(x=torch.rand((1, 1, 96, 96, 96)).cuda())

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 2 but got size 3 for tensor number 1 in the list.

In [7]:
device = args.device
win_size = (args.roi_x, args.roi_y, args.roi_z)

# cam = monai.visualize.CAM(nn_module=model_3d, target_layers="class_layers.relu", fc_layers="class_layers.out")
cam = GradCAM(nn_module=model, target_layers="encoder1")
# cam = monai.visualize.GradCAMpp(nn_module=model_3d, target_layers="class_layers.relu")
print(
    "original feature shape",
    cam.feature_map_size([1, 1] + list(win_size), device),
)
print("upsampled feature shape", [1, 1] + list(win_size))

occ_sens = monai.visualize.OcclusionSensitivity(nn_module=model, mask_size=96, n_batch=1, stride=48)

# For occlusion sensitivity, inference must be run many times. Hence, we can use a
# bounding box to limit it to a 2D plane of interest (z=the_slice) where each of
# the arguments are the min and max for each of the spatial dimensions (in this case HWD).

the_slice = train_ds[0]["image"].shape[-1] // 2
occ_sens_b_box = [-1, -1, -1, -1, the_slice - 1, the_slice]

RuntimeError: CUDA out of memory. Tried to allocate 16384.00 GiB (GPU 0; 10.92 GiB total capacity; 5.70 GiB already allocated; 1.85 GiB free; 5.76 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
train_transforms.set_random_state(42)
n_examples = 5
subplot_shape = [3, n_examples]
fig, axes = plt.subplots(*subplot_shape, figsize=(25, 15), facecolor="white")
items = np.random.choice(len(train_ds), size=len(train_ds), replace=False)

example = 0
for item in items:
    data = train_ds[item]  # this fetches training data with random augmentations
    image, label = data["image"].to(device).unsqueeze(0), data["label"][1]
    y_pred = model_3d(image)
    pred_label = y_pred.argmax(1).item()
    # Only display tumours images
    if label != 1 or label != pred_label:
        continue

    img = image.detach().cpu().numpy()[..., the_slice]

    name = "actual: "
    name += "lesion" if label == 1 else "non-lesion"
    name += "\npred: "
    name += "lesion" if pred_label == 1 else "non-lesion"
    name += f"\nlesion: {y_pred[0,1]:.3}"
    name += f"\nnon-lesion: {y_pred[0,0]:.3}"

    # run CAM
    cam_result = cam(x=image, class_idx=None)
    cam_result = cam_result[..., the_slice]

    # run occlusion
    occ_result, _ = occ_sens(x=image, b_box=occ_sens_b_box)
    occ_result = occ_result[0, pred_label][None, None, ..., -1]

    for row, (im, title) in enumerate(
        zip(
            [img, cam_result, occ_result],
            [name, "CAM", "Occ. sens."],
        )
    ):
        cmap = "gray" if row == 0 else "jet"
        ax = axes[row, example]
        if isinstance(im, torch.Tensor):
            im = im.cpu().detach()
        im_show = ax.imshow(im[0][0], cmap=cmap)

        ax.set_title(title, fontsize=25)
        ax.axis("off")
        fig.colorbar(im_show, ax=ax)

    example += 1
    if example == n_examples:
        break