In [1]:
val_files = [{"image": ["Data_BRATS/BraTS2021_00000/BraTS2021_00000_t1ce.nii.gz", 
                        "Data_BRATS/BraTS2021_00000/BraTS2021_00000_t1.nii.gz", 
                        "Data_BRATS/BraTS2021_00000/BraTS2021_00000_t2.nii.gz", 
                        "Data_BRATS/BraTS2021_00000/BraTS2021_00000_flair.nii.gz"], 
            "label": "Data_BRATS/BraTS2021_00000/BraTS2021_00000_seg.nii.gz"}]

In [2]:
from monai.utils import set_determinism
from monai.transforms import (
    Compose,
    LoadImaged,
    ConvertToMultiChannelBasedOnBratsClassesd,
    RandSpatialCropd,
    RandFlipd,
    MapTransform,
    NormalizeIntensityd, 
    RandScaleIntensityd,
    RandShiftIntensityd,
    ToTensord,
    CenterSpatialCropd,
)
from monai.data import DataLoader, Dataset
import numpy as np
import json
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    """
    Convert labels to multi channels based on brats classes:
    label 1 is the necrotic and non-enhancing tumor core
    label 2 is the peritumoral edema
    label 4 is the GD-enhancing tumor
    The possible classes are TC (Tumor core), WT (Whole tumor)
    and ET (Enhancing tumor).

    """
 
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            # merge label 1 and label 4 to construct TC
            result.append(np.logical_or(d[key] == 1, d[key] == 4))
            # merge labels 1, 2 and 4 to construct WT
            result.append(
                np.logical_or(
                    np.logical_or(d[key] == 1, d[key] == 4), d[key] == 2
                )
            )
            # label 4 is ET
            result.append(d[key] == 4)
            d[key] = np.stack(result, axis=0).astype(np.float32)
        return d
val_transform = Compose(
        [
            LoadImaged(keys=["image", "label"]),
            ConvertToMultiChannelBasedOnBratsClassesd(
                keys = ['label']),
            CenterSpatialCropd(keys=["image", "label"],
                            roi_size = [128,128,128], 
                            ),
            NormalizeIntensityd(keys = "image",
                               nonzero = True,
                               channel_wise = True),
            ToTensord(keys=["image", "label"]),
        ]
    )
val_ds = Dataset(data=val_files, transform=val_transform)


In [3]:
val_input = val_ds[0]["image"].unsqueeze(0)
val_input.shape

torch.Size([1, 4, 128, 128, 128])

In [4]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
import pytorch_lightning as pl
from trainer import BRATS
import os 
from pytorch_lightning.loggers import TensorBoardLogger
import argparse
import torch

model = BRATS(return_attn = True)
weights = torch.load('Epoch 89-MeanDiceScore0.8897.ckpt', map_location='cpu')
model.load_state_dict(weights['state_dict'], strict = True)
out, attn = model(val_input)
print(out.shape, len(attn))


4 1 2


`threshold_values=True/False` is deprecated, please use `threshold=value` instead.


torch.Size([1, 262144, 64]) 64 64 64
torch.Size([1, 64, 64, 64, 64])
torch.Size([1, 512, 64])
torch.Size([1, 1, 262144, 64]) torch.Size([1, 1, 512, 64]) torch.Size([1, 1, 512, 64])
torch.Size([1, 64, 64, 64, 64])
torch.Size([1, 512, 64])
torch.Size([1, 1, 262144, 64]) torch.Size([1, 1, 512, 64]) torch.Size([1, 1, 512, 64])
torch.Size([1, 32768, 128]) 32 32 32
torch.Size([1, 128, 32, 32, 32])
torch.Size([1, 512, 128])
torch.Size([1, 2, 32768, 64]) torch.Size([1, 2, 512, 64]) torch.Size([1, 2, 512, 64])
torch.Size([1, 128, 32, 32, 32])
torch.Size([1, 512, 128])
torch.Size([1, 2, 32768, 64]) torch.Size([1, 2, 512, 64]) torch.Size([1, 2, 512, 64])
torch.Size([1, 4096, 320]) 16 16 16
torch.Size([1, 320, 16, 16, 16])
torch.Size([1, 512, 320])
torch.Size([1, 5, 4096, 64]) torch.Size([1, 5, 512, 64]) torch.Size([1, 5, 512, 64])
torch.Size([1, 320, 16, 16, 16])
torch.Size([1, 512, 320])
torch.Size([1, 5, 4096, 64]) torch.Size([1, 5, 512, 64]) torch.Size([1, 5, 512, 64])
torch.Size([1, 512, 512]

In [28]:
for x in attn:
    print(x.shape)

torch.Size([1, 1, 262144, 512])
torch.Size([1, 2, 32768, 512])
torch.Size([1, 5, 4096, 512])
torch.Size([1, 8, 512, 512])


In [31]:
attentions = attn[0]
nh = attentions.shape[1] # number of head
print(attentions.shape)
# we keep only the output patch attention
attentions = attentions[0, :, :, :].reshape(nh, -1)
print(attentions.shape)

torch.Size([1, 1, 262144, 512])
torch.Size([1, 134217728])


In [33]:
import torch.nn as nn
attentions = attentions.reshape(nh, 64, 64, 64, 512)
attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=2, mode="nearest")[0].cpu().numpy()

# save attentions heatmaps
print(attentions.shape)

NotImplementedError: Input Error: Only 3D, 4D and 5D input Tensors supported (got 6D) for the modes: nearest | linear | bilinear | bicubic | trilinear | area | nearest-exact (got nearest)