In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import os

from dice_loss import DiceLoss
import torch.optim as optim
from brain_mri_dataset import BrainMRIDatasetBuilder,BrainMRIDataset

from transforms import BrainMRITransforms

from calculate_iou import calculate_iou


In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Training hyperparameters.
batch_size = 64        # number of samples per DataLoader batch
learning_rate = 3e-4   # learning rate handed to the optimizer

In [4]:
# Root directory of the dataset on disk.
data_dir = "../datasets/lgg-mri-segmentation/kaggle_3m"

# Build a dataframe of image/mask paths and split it into train/val/test frames.
builder = BrainMRIDatasetBuilder(data_dir)
df = builder.create_df()
train_df, val_df, test_df = builder.split_df(df)

# A single transform instance is used for both images and masks in every split.
# NOTE(review): if BrainMRITransforms applies random augmentation, it would also
# be applied to val/test, and image and mask could receive different random
# draws -- confirm against the transforms module.
transform_ = BrainMRITransforms()

train_data = BrainMRIDataset(train_df, transform=transform_, mask_transform=transform_)
val_data = BrainMRIDataset(val_df, transform=transform_, mask_transform=transform_)
test_data = BrainMRIDataset(test_df, transform=transform_, mask_transform=transform_)

# Only the training loader is shuffled. Validation and test batches keep a
# fixed order so that per-batch averaged metrics are computed over the same
# batches on every pass and stay comparable between epochs.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)


In [5]:
# Display the row at positional index 1 of the training split to check its path columns.
train_df.iloc[1]

images_paths    ../datasets/lgg-mri-segmentation/kaggle_3m/TCG...
masks_paths     ../datasets/lgg-mri-segmentation/kaggle_3m/TCG...
Name: 653, dtype: object

In [6]:
# Report the (rows, columns) shape of each split, one split per output line.
print(
    'Train\t', train_df.shape,
    '\nVal\t', val_df.shape,
    '\nTest\t', test_df.shape,
)

Train	 (3143, 2) 
Val	 (393, 2) 
Test	 (393, 2)


In [7]:
class UNet(nn.Module):
    """U-Net style encoder/decoder that outputs a single-channel map.

    Four encoder stages (32, 64, 128 and 256 channels), each followed by a
    2x2 max-pool, feed a 512-channel bottleneck. The decoder mirrors the
    encoder: a stride-2 transposed convolution upsamples the features, the
    result is concatenated channel-wise with the matching encoder output, and
    a double-convolution stage fuses the two. A final 1x1 convolution reduces
    the 32 feature channels to 1 output channel.

    No activation follows the final convolution, so ``forward`` returns the
    raw convolution output.
    """

    def __init__(self):
        super().__init__()

        # Submodules are created in a fixed order under fixed attribute names;
        # both determine the state_dict keys and the order in which random
        # initial weights are drawn.

        # Encoder
        self.encoder1 = self._double_conv(3, 32)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = self._double_conv(32, 64)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = self._double_conv(64, 128)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = self._double_conv(128, 256)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck
        self.bottleneck = self._double_conv(256, 512)

        # Decoder: every double-conv stage takes twice the upconv's output
        # channels because the skip connection is concatenated before it.
        self.upconv4 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder4 = self._double_conv(512, 256)
        self.upconv3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder3 = self._double_conv(256, 128)
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder2 = self._double_conv(128, 64)
        self.upconv1 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder1 = self._double_conv(64, 32)

        # Final 1x1 convolution: 32 feature channels -> 1 mask channel.
        self.conv = nn.Conv2d(32, 1, kernel_size=1)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Return a Sequential of two (3x3 conv -> batch norm -> ReLU) layers."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the 1-channel output."""
        # Encoder: keep every stage's output for the skip connections.
        skip1 = self.encoder1(x)
        skip2 = self.encoder2(self.pool1(skip1))
        skip3 = self.encoder3(self.pool2(skip2))
        skip4 = self.encoder4(self.pool3(skip3))

        # Bottleneck
        features = self.bottleneck(self.pool4(skip4))

        # Decoder: upsample, concatenate the matching skip along the channel
        # axis (upsampled features first), then fuse.
        decoder_stages = (
            (self.upconv4, self.decoder4, skip4),
            (self.upconv3, self.decoder3, skip3),
            (self.upconv2, self.decoder2, skip2),
            (self.upconv1, self.decoder1, skip1),
        )
        for upsample, fuse, skip in decoder_stages:
            features = fuse(torch.cat([upsample(features), skip], dim=1))

        return self.conv(features)


In [8]:
# Build the network, wrap it in nn.DataParallel, then move it to the selected device.
model = nn.DataParallel(UNet()).to(device)

In [9]:
# Count every parameter element, then only those that receive gradients.
total_params = sum(map(torch.Tensor.numel, model.parameters()))
trainable_params = sum(
    param.numel()
    for param in filter(lambda tensor: tensor.requires_grad, model.parameters())
)
print(f'Total parameters: {total_params}')
print(f'Trainable parameters: {trainable_params}')

Total parameters: 7765985
Trainable parameters: 7765985


In [10]:
# Loss function and optimizer: DiceLoss, minimised with Adam at the configured learning rate.
criterion = DiceLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [11]:
# Directory (relative path) into which the training loop writes one checkpoint file per epoch.
checkpoint_dir = "../checkpoints/unet_checkpoints/"

In [12]:
epochs = 100

# Per-epoch history; each entry is the mean over that epoch's batches.
train_loss = []
val_loss = []
trainIOU = []
valIOU = []

# torch.save does not create missing parent directories. Create the checkpoint
# directory up front so a fresh checkout does not lose the first epoch of
# training to a save error.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(epochs):
    total_train_loss = 0
    total_val_loss = 0

    # Training mode
    model.train()
    total_train_iou = 0

    for imgs, labels in train_dataloader:
        imgs, labels = imgs.to(device).float(), labels.to(device).float()
        optimizer.zero_grad()

        pred = model(imgs)

        loss = criterion(pred, labels)
        total_train_loss += loss.item()

        # IoU is only monitored, never back-propagated, so compute it without
        # recording autograd history.
        with torch.no_grad():
            iou = calculate_iou(pred, labels)
        total_train_iou += iou.item()

        loss.backward()
        optimizer.step()

    train_iou = total_train_iou / len(train_dataloader)
    trainIOU.append(train_iou)

    train_loss.append(total_train_loss / len(train_dataloader))

    # Validation mode: eval() switches batch norm to its running statistics,
    # and no_grad() skips gradient bookkeeping.
    model.eval()
    total_val_iou = 0
    with torch.no_grad():
        for imgs, labels in val_dataloader:
            imgs, labels = imgs.to(device).float(), labels.to(device).float()

            pred = model(imgs)

            loss = criterion(pred, labels)
            total_val_loss += loss.item()

            iou = calculate_iou(pred, labels)
            total_val_iou += iou.item()

    val_iou = total_val_iou / len(val_dataloader)
    valIOU.append(val_iou)
    # From here on total_val_loss holds the epoch's mean validation loss.
    total_val_loss = total_val_loss / len(val_dataloader)
    val_loss.append(total_val_loss)

    # Print
    print('Epoch: {}/{}, Train Loss: {:.4f}, Val Loss: {:.4f}'.format(epoch + 1, epochs, train_loss[-1], total_val_loss))

    # Save one checkpoint per epoch.
    # NOTE: `model` is wrapped in nn.DataParallel, so the saved state_dict keys
    # carry a "module." prefix; load them into a wrapped model, or strip the
    # prefix first.
    checkpoint_path = os.path.join(checkpoint_dir, f'unet_checkpoint_epoch_{epoch+1}.pt')
    torch.save(model.state_dict(), checkpoint_path)

# Final weights after the last epoch (same DataParallel-prefixed key layout as
# the per-epoch checkpoints).
model_state_dict = model.state_dict()

# Written relative to the current working directory.
file_path = 'unet_weights.pth'

# Save the model state_dict to the specified file
torch.save(model_state_dict, file_path)





Epoch: 1/100, Train Loss: 0.9603, Val Loss: 0.9581
Epoch: 2/100, Train Loss: 0.9516, Val Loss: 0.9518
Epoch: 3/100, Train Loss: 0.9462, Val Loss: 0.9417
Epoch: 4/100, Train Loss: 0.9401, Val Loss: 0.9346
Epoch: 5/100, Train Loss: 0.9316, Val Loss: 0.9322
Epoch: 6/100, Train Loss: 0.9199, Val Loss: 0.9304
Epoch: 7/100, Train Loss: 0.9021, Val Loss: 0.8812
Epoch: 8/100, Train Loss: 0.8811, Val Loss: 0.8544
Epoch: 9/100, Train Loss: 0.8435, Val Loss: 0.8308
Epoch: 10/100, Train Loss: 0.7909, Val Loss: 0.7906
Epoch: 11/100, Train Loss: 0.7347, Val Loss: 0.7657
Epoch: 12/100, Train Loss: 0.6672, Val Loss: 0.6180
Epoch: 13/100, Train Loss: 0.5604, Val Loss: 0.4970
Epoch: 14/100, Train Loss: 0.4656, Val Loss: 0.4472
Epoch: 15/100, Train Loss: 0.3899, Val Loss: 0.3567
Epoch: 16/100, Train Loss: 0.3380, Val Loss: 0.3092
Epoch: 17/100, Train Loss: 0.2941, Val Loss: 0.2869
Epoch: 18/100, Train Loss: 0.2701, Val Loss: 0.2699
Epoch: 19/100, Train Loss: 0.2310, Val Loss: 0.2098
Epoch: 20/100, Train 