In [1]:

import torch
import os

import numpy as np

from monai.losses import DiceCELoss

from monai.data import Dataset, DataLoader
from monai.transforms import LoadImage, Resized, Compose, Lambda
from monai.networks.nets import UNETR
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
dataset_list = []

# read all folders in Images

base_path = "./Task01_BrainTumour"


for file in os.listdir(os.path.join(base_path, "imagesTr")):

    data_dict = {
        "img": os.path.join(base_path, "imagesTr", file), 
        "seg": os.path.join(base_path, "labelsTr", file)
    }

    dataset_list.append(data_dict)

In [4]:
class BrainTumor(Dataset):

    def __init__(self, data, transform=None):
        super().__init__(data, transform)
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        element_dict = self.data[index]
        img_name = element_dict["img"]
        seg_name = element_dict["seg"]
        img = LoadImage(image_only=True, ensure_channel_first=True, simple_keys=True)(img_name)
        seg = LoadImage(image_only=True, ensure_channel_first=True, simple_keys=True)(seg_name)
        if self.transform:
            transformed = self.transform({"img": img, "seg": seg} )
            img = transformed["img"]
            seg = transformed["seg"]
            
        seg = torch.tensor(seg, dtype=torch.long)
            
        return img, seg

In [5]:
test_img = LoadImage(image_only=True, ensure_channel_first=True, simple_keys=True)("Task01_BrainTumour/imagesTr/BRATS_001.nii.gz") 

test_seg = LoadImage(image_only=True, ensure_channel_first=True, simple_keys=True)("Task01_BrainTumour/labelsTr/BRATS_001.nii.gz")

print(test_img.shape, test_seg.shape)

torch.Size([4, 240, 240, 155]) torch.Size([1, 240, 240, 155])


In [6]:
# get the number of each value in test_seg
bins, hist = np.unique(test_seg.flatten(), return_counts=True) 
hist

array([8816276,   53050,   27189,   31485])

In [7]:
train_transforms = Compose([
    Lambda(lambda x: {"img": x["img"].permute(0, 3, 1, 2), "seg": x["seg"].permute(0, 3, 1, 2)}),
    Resized(keys=["img", "seg"], spatial_size=(32, 64, 64), mode=('nearest')),
    #ToTensorD(keys=["img", "seg"])
])

In [8]:
dataset = BrainTumor(dataset_list, transform=train_transforms)

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [390, 94])

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

In [9]:
model = UNETR(
    in_channels=4,
    out_channels=4,
    img_size=(32, 64, 64),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
    conv_block=False,
).to(device)

In [10]:
def calculate_iou(predicted, target):
    """
    Calculate Intersection over Union (IoU) for binary segmentation.

    Args:
    - predicted (torch.Tensor): Predicted binary mask (0 or 1).
    - target (torch.Tensor): Ground truth binary mask (0 or 1).

    Returns:
    - float: IoU score.
    """
    intersection = torch.logical_and(predicted, target).sum().item()
    union = torch.logical_or(predicted, target).sum().item()

    iou = intersection / union if union != 0 else 0.0
    return iou

In [14]:
loss_function = DiceCELoss(to_onehot_y=True, softmax=True) #torch.nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), 0.1)

In [15]:
num_epochs = 3
losses = []
test_losses = []
ious = []
test_ious = []

for epoch in range(num_epochs):
    
    epoch_loss = 0
    epoch_test_loss = 0
    epoch_iou = 0
    epoch_test_iou = 0
    
    model.train() 
    
    for i, (img, seg) in enumerate(tqdm(train_dataloader)):
        
        img = img.to(device)
        seg = seg.to(device)
        
        optimizer.zero_grad()
        
        output = model(img)
        
        iou = calculate_iou(output, seg)
        ious.append(iou)
        
        loss = loss_function(output, seg)
        
        loss.backward()
        
        optimizer.step()
        
        epoch_loss += loss.item()
        
    losses.append(epoch_loss)
    
    model.eval()
    with torch.no_grad():
        for i, (img, seg) in enumerate(test_dataloader):
            
            img = img.to(device)
            seg = seg.to(device)
             
            output = model(img)
            
            iou = calculate_iou(output, seg)
            test_ious.append(iou)
            
            loss = loss_function(output, seg)
            
            epoch_test_loss += loss.item()
            
        test_losses.append(epoch_test_loss)
        
        print(f"Epoch: {epoch} | Train Loss: {epoch_loss / len(train_dataloader)} | Test Loss: {epoch_test_loss / len(test_dataloader)} | Train IoU: {ious[-1]} | Test IoU: {test_ious[-1]}")

  0%|          | 0/13 [00:00<?, ?it/s]

To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


Epoch: 0 | Train Loss: 1.0310037319476788 | Test Loss: 0.7589441339174906 | Train IoU: 0.008731842041015625 | Test IoU: 0.013091023763020833


  0%|          | 0/13 [00:00<?, ?it/s]

Epoch: 1 | Train Loss: 0.7075036122248723 | Test Loss: 0.6853947838147482 | Train IoU: 0.012828826904296875 | Test IoU: 0.013091023763020833


  0%|          | 0/13 [00:00<?, ?it/s]

Epoch: 2 | Train Loss: 0.6648460397353539 | Test Loss: 0.6475573778152466 | Train IoU: 0.008544921875 | Test IoU: 0.013091023763020833


In [None]:
data = next(iter(train_dataloader))

data[0].shape

In [None]:
np.unique(data[1][1], return_counts=True)

In [None]:
plt.plot(losses)