In [2]:
import argparse
from functools import partial
from typing import Dict, Optional, Union

import torch
from monai.data import decollate_batch
from monai.inferers import sliding_window_inference
from monai.losses.dice import DiceLoss
from monai.metrics import DiceMetric
from monai.transforms import Activations, AsDiscrete
from torch.utils.data import DataLoader

from bts.common import logutils, miscutils
from bts.common.miscutils import DotConfig
from bts.data.dataset import get_train_dataset, get_val_dataset
from bts.swinunetr import model as smodel

In [3]:
# Module-level logger from the project's logging helper (not used by any cell shown here).
logger = logutils.get_logger(__name__)

In [4]:
# Training configuration for SwinUNETR, loaded from the project's YAML file
# (relative to this notebook's working directory).
HYPERPARAMS_PATH = "../bts/swinunetr/hyperparameters.yaml"
hyperparams = miscutils.load_hyperparameters(HYPERPARAMS_PATH)
hyperparams

[:home:vedatb:senior-project:bbm47980_bts:.venv:lib:python3.9:site-packages:miscutils.py:load_hyperparameters:55] Hyperparameters loaded: {'BATCH_SIZE': 1, 'SHUFFLE': True, 'ROI': [128, 128, 128], 'IN_CHANNELS': 1, 'OUT_CHANNELS': 2, 'FEATURE_SIZE': 48, 'GRADIENT_CHECKPOINT': True, 'EPOCHS': 100, 'FOLD': 1, 'SW_BATCH_SIZE': 2, 'LEARNING_RATE': 0.0001, 'INFER_OVERLAP': 0.6, 'WEIGHT_DECAY': 1e-05, 'DEVICE': 'cuda', 'LABELS': {'GROUND': 0, 'BRAIN': 1, 'TUMOR': 2}, 'DATA_DIR': '../../data'}


DotConfig({'BATCH_SIZE': 1, 'SHUFFLE': True, 'ROI': [128, 128, 128], 'IN_CHANNELS': 1, 'OUT_CHANNELS': 2, 'FEATURE_SIZE': 48, 'GRADIENT_CHECKPOINT': True, 'EPOCHS': 100, 'FOLD': 1, 'SW_BATCH_SIZE': 2, 'LEARNING_RATE': 0.0001, 'INFER_OVERLAP': 0.6, 'WEIGHT_DECAY': 1e-05, 'DEVICE': 'cuda', 'LABELS': DotConfig({'GROUND': 0, 'BRAIN': 1, 'TUMOR': 2}), 'DATA_DIR': '../../data'})

In [5]:
# Override the YAML batch size (1) for this experiment; this mutates the loaded config in place.
hyperparams.BATCH_SIZE = 2

In [6]:
# Project helper; per its log output it enables cuDNN benchmark mode.
smodel.set_cudnn_benchmark()

[:home:vedatb:senior-project:bbm47980_bts:.venv:lib:python3.9:site-packages:model.py:set_cudnn_benchmark:45] Enabling cuDNN benchmark.


In [7]:
# Build the SwinUNETR model from the loaded hyperparameters; img_size is the ROI crop size.
model = smodel.get_model(
    in_channels=hyperparams.IN_CHANNELS,
    out_channels=hyperparams.OUT_CHANNELS,
    img_size=hyperparams.ROI,
    feature_size=hyperparams.FEATURE_SIZE,
    use_checkpoint=hyperparams.GRADIENT_CHECKPOINT,  # gradient checkpointing trades compute for memory
)

[:home:vedatb:senior-project:bbm47980_bts:.venv:lib:python3.9:site-packages:model.py:get_model:38] SwinUNETR model created.


In [8]:
import os

# Enumerate GPUs in PCI bus order so CUDA device indices match nvidia-smi's numbering.
# NOTE(review): CUDA only reads this variable when it initialises; if any earlier cell
# already initialised CUDA, this assignment has no effect. Consider moving it to the first cell.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

In [9]:
# Wrap the model so forward passes are split across the visible GPUs.
model = torch.nn.DataParallel(model)

In [10]:
# Dice loss applied to raw logits (sigmoid inside the loss). Targets are already
# multi-channel (ConvertToMultiChannelBasedOnEchidnaClassesd), so no one-hot conversion is needed.
dice_loss = DiceLoss(to_onehot_y=False, sigmoid=True)

In [11]:
# AdamW over all model parameters, using the LR and weight decay from the hyperparameter file.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    weight_decay=hyperparams.WEIGHT_DECAY,
    lr=hyperparams.LEARNING_RATE,
)

In [12]:
# Cosine-anneal the learning rate over EPOCHS scheduler steps
# (scheduler.step() is not called in the cells shown).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=hyperparams.EPOCHS)

In [13]:
# Both splits live under the configured DATA_DIR rather than a duplicated path literal.
dataset_dir = os.path.join(hyperparams.DATA_DIR, "btsed_dataset")

train_dataset = get_train_dataset(dataset_dir)
# Honour the SHUFFLE hyperparameter: without it the training loader always
# iterates in file order, regardless of the config.
train_loader = DataLoader(
    train_dataset,
    batch_size=hyperparams.BATCH_SIZE,
    shuffle=hyperparams.SHUFFLE,
)

val_dataset = get_val_dataset(dataset_dir)
# Validation order is irrelevant to the metrics, so it is left unshuffled.
val_loader = DataLoader(val_dataset, batch_size=hyperparams.BATCH_SIZE)

In [14]:
# Two sample cases used to sanity-check the transform pipeline below.
# Paths are built from DATA_DIR (resolved relative to the working directory, like the
# loaders above) instead of being hard-coded to one user's home directory.
dataset_dir = os.path.join(hyperparams.DATA_DIR, "btsed_dataset")
SAMPLE_CASE_IDS = (0, 1)

data_paths = [
    {
        "img": os.path.join(dataset_dir, str(case_id), "image.nrrd"),
        "label": os.path.join(dataset_dir, str(case_id), "label.nrrd"),
        "info": os.path.join(dataset_dir, str(case_id), "info.json"),
    }
    for case_id in SAMPLE_CASE_IDS
]

In [15]:
from monai.transforms import (
    Compose,
    LoadImaged,
    MapTransform,
    NormalizeIntensityd,
    Transform,
    RandSpatialCropd,
    RandFlipd,
)

from bts.data import ConvertToMultiChannelBasedOnEchidnaClassesd, JsonTransform
from bts.data.utils import UnsqueezeDatad

from torch.utils.data import DataLoader

import monai

In [16]:
def collate_fn(batch, stack_keys=("img", "label")):
    """Collate a list of per-sample dicts into a single batch dict.

    Values under ``stack_keys`` are converted with ``torch.as_tensor`` and
    stacked along a new leading batch dimension. Every other key (e.g. the
    parsed ``info`` metadata) is collected into a plain list with one entry
    per sample, because such values are not guaranteed to be stackable.

    Args:
        batch: Sequence of sample dicts. All samples are assumed to share the
            keys of ``batch[0]``.
        stack_keys: Keys whose values are stacked into a single tensor.

    Returns:
        Dict mapping each key of ``batch[0]`` to either a stacked tensor (for
        ``stack_keys``) or a list of per-sample values. An empty batch yields
        an empty dict.
    """
    if not batch:
        return {}

    collated = {}
    for key in batch[0]:
        values = [sample[key] for sample in batch]
        if key in stack_keys:
            collated[key] = torch.stack([torch.as_tensor(v) for v in values], dim=0)
        else:
            # Metadata must not be stacked: keep this key's own per-sample values
            # (never reuse another key's data).
            collated[key] = values
    return collated

In [17]:
# Training-style preprocessing applied to the sample cases in `data_paths`.
# NOTE(review): In [18] raised RuntimeError while applying this Compose; the failing
# transform is not visible from the recorded output. TODO inspect the chained traceback.
transforms = Compose([
    LoadImaged(["img", "label"], reader="NrrdReader"),
    UnsqueezeDatad(["img"]),
    ConvertToMultiChannelBasedOnEchidnaClassesd(["label"]),
    JsonTransform(["info"]),
    # The crop size follows the model's ROI hyperparameter instead of a duplicated literal.
    RandSpatialCropd(
        keys=["img", "label"],
        roi_size=hyperparams.ROI,
        random_size=False,
    ),
    # Independent random flip (p=0.5) along each of the three spatial axes.
    *[
        RandFlipd(keys=["img", "label"], prob=0.5, spatial_axis=axis)
        for axis in (0, 1, 2)
    ],
    NormalizeIntensityd(keys="img", nonzero=True, channel_wise=True),
])

data_set = monai.data.Dataset(data=data_paths, transform=transforms)
# Use the configured batch size and the custom collate_fn (which keeps `info` unstacked).
data_loader = DataLoader(data_set, batch_size=hyperparams.BATCH_SIZE, collate_fn=collate_fn)

In [18]:
# Pull one batch for inspection.
# NOTE(review): in the recorded run this raised RuntimeError inside the transform
# pipeline; every later cell that reads `sample` depends on this succeeding.
sample = next(iter(data_loader))

RuntimeError: applying transform <monai.transforms.compose.Compose object at 0x7f3a1c741ac0>

In [None]:
# Shape of the batched image tensor.
sample["img"].shape

In [None]:
# Shape of the batched multi-channel label tensor.
sample["label"].shape

In [None]:
import matplotlib.pyplot as plt

# Slice indices along the third spatial axis of the first sample's first image channel.
slice_nums = [50, 100]

img = sample["img"][0][0]

# Merge the two binary label channels into one integer map for display.
# NOTE(review): a voxel set in both channels becomes BRAIN + TUMOR, not TUMOR.
# Confirm whether ConvertToMultiChannelBasedOnEchidnaClassesd can emit overlapping channels.
label = (
    (sample["label"][0][0] == 1) * hyperparams.LABELS.BRAIN +
    (sample["label"][0][1] == 1) * hyperparams.LABELS.TUMOR
)

print(f"image shape: {img.shape}, label shape: {label.shape}")

# One row per slice: image on the left, label map on the right.
# squeeze=False keeps `axs` 2-D even if only one slice is requested.
fig, axs = plt.subplots(
    len(slice_nums), 2, figsize=(9, 4.5 * len(slice_nums)), squeeze=False
)
for row, slice_num in enumerate(slice_nums):
    axs[row, 0].set_title(f"image: {slice_num}")
    axs[row, 0].imshow(img[:, :, slice_num], cmap="gray")

    axs[row, 1].set_title(f"label: {slice_num}")
    axs[row, 1].imshow(label[:, :, slice_num])

fig.tight_layout()
plt.show()

In [None]:
# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Move parameters to the selected device and switch to training mode (both return the module).
model = model.to(device).train()

In [None]:
# Running-average tracker for the training loss (not updated in any cell shown).
train_loss = miscutils.AverageMeter()

In [None]:
# Move the batch to the same device as the model. The selected `device` falls back to
# CPU instead of hard-coding "cuda".
# Note: this rebinds `label` (previously the 2-D display map from the plotting cell)
# to the multi-channel target tensor.
image = sample["img"].to(device)
print(image.shape)

label = sample["label"].to(device)
print(label.shape)

In [None]:
# Clear any accumulated gradients before the forward pass.
optimizer.zero_grad()

In [None]:
# Forward pass on the whole cropped batch (no sliding-window inference here).
logits = model(image)

In [None]:
# Output shape; the channel count should equal OUT_CHANNELS.
logits.shape

In [None]:
# Target shape, for comparison with the logits above.
sample["label"].shape

In [None]:
# Dice loss between logits and the multi-channel target (no backward()/step() in the cells shown).
loss: torch.Tensor = dice_loss(logits, label)

In [None]:
# Display the scalar loss value.
loss

In [None]:
# Metric post-processing: sigmoid on the logits, then binarise at 0.5 into a boolean mask.
post_sigmoid = Activations(sigmoid=True)
post_pred = AsDiscrete(threshold=0.5, dtype="bool")

In [None]:
# Turn logits into a binary prediction mask without tracking gradients.
# `sig` holds the thresholded boolean mask, not raw sigmoid probabilities.
with torch.no_grad():
    probs = post_sigmoid(logits)
    sig = post_pred(probs)

In [None]:
from monai.utils.enums import MetricReduction

# Dice score averaged over the batch (one value per channel). get_not_nans=True makes
# aggregate() also return how many non-NaN entries went into each average.
dice_acc = DiceMetric(
    include_background=True,
    get_not_nans=True,
    reduction=MetricReduction.MEAN_BATCH,
)

In [None]:
# Start from an empty metric buffer, then accumulate this batch's Dice scores.
dice_acc.reset()
target = sample["label"].to(device)
acc_out = dice_acc(y_pred=sig, y=target)

In [None]:
# Per-sample, per-channel Dice scores for this batch.
acc_out

In [None]:
from monai.data import decollate_batch

In [None]:
# Reduce the buffered scores; with get_not_nans=True this also returns the non-NaN counts.
acc, not_nans = dice_acc.aggregate()

In [None]:
# Display the batch-mean Dice per channel and the corresponding non-NaN counts.
acc, not_nans

In [None]:
# Running average of the per-channel Dice scores.
run_acc = miscutils.AverageMeter()

In [None]:
# Feed the per-channel scores, presumably weighted by their non-NaN counts
# (depends on miscutils.AverageMeter.update — TODO confirm).
run_acc.update(acc.cpu().numpy(), n=not_nans.cpu().numpy())

In [None]:
# Running Dice average for label channel 0 (presumably the first channel of `acc`).
run_acc.avg[0]

In [None]:
# Running Dice average for label channel 1 (presumably the second channel of `acc`).
run_acc.avg[1]

In [None]:
# Show the whole per-channel running average. The label has only two channels
# (OUT_CHANNELS=2; the plotting cell reads channels 0 and 1), so MEAN_BATCH reduction
# yields two scores and `run_acc.avg[2]` would be out of range.
run_acc.avg

In [None]:
# Inspect GPU memory usage after the forward pass and metric computation above.
!nvidia-smi