In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader

import os

from dice_loss import DiceLoss
import torch.optim as optim
from brain_mri_dataset import BrainMRIDatasetBuilder,BrainMRIDataset

from transforms import BrainMRITransforms

from calculate_iou import calculate_iou

In [2]:
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Training hyperparameters.
batch_size = 64       # number of samples per mini-batch
learning_rate = 3e-4  # optimiser step size

In [4]:
# Location of the dataset, relative to this notebook.
data_dir = "../datasets/lgg-mri-segmentation/kaggle_3m"

# Build the sample table and split it into train / validation / test parts.
# (The exact split policy lives in brain_mri_dataset.BrainMRIDatasetBuilder.)
builder = BrainMRIDatasetBuilder(data_dir)
df = builder.create_df()
train_df, val_df, test_df = builder.split_df(df)

# NOTE(review): one transform object is shared by images and masks, and by all
# three splits. If BrainMRITransforms contains random augmentation, images and
# masks could be transformed differently and val/test would be augmented too —
# confirm against transforms.py.
transform_ = BrainMRITransforms()

train_data = BrainMRIDataset(train_df, transform=transform_, mask_transform=transform_)
val_data = BrainMRIDataset(val_df, transform=transform_, mask_transform=transform_)
test_data = BrainMRIDataset(test_df, transform=transform_, mask_transform=transform_)

# Only the training loader is shuffled. Evaluation loaders keep a fixed order so
# that batch composition (and therefore any per-batch averaged metric) is the
# same on every pass, which makes epoch-to-epoch numbers comparable.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)


In [5]:
# ResNet18

class ResNet18Seg(nn.Module):
    """Segmentation network: ResNet-18 trunk followed by a transposed-conv decoder.

    The encoder is torchvision's ResNet-18 with its last two children (the
    average pool and the fully connected classifier) removed, leaving a
    512-channel feature map. The decoder applies four stride-2 transposed
    convolutions (each doubles height and width) and then bilinearly resizes
    the result to ``output_size``.

    No activation is applied to the output, so the network returns raw scores.
    NOTE(review): whether the loss expects raw scores or probabilities is not
    visible here — confirm against dice_loss.py.

    Args:
        num_classes: Number of output channels.
        pretrained: Forwarded to ``models.resnet18``; ``True`` loads the
            pretrained weights, ``False`` gives a randomly initialised trunk
            (useful for quick shape checks without a download).
        output_size: ``(height, width)`` of the returned map. The default
            ``(256, 256)`` matches the previously hard-coded size.
    """

    def __init__(self, num_classes=1, pretrained=True, output_size=(256, 256)):
        super().__init__()
        # NOTE: recent torchvision releases deprecate `pretrained=` in favour of
        # `weights=`; the older keyword is kept so existing environments still work.
        resnet = models.resnet18(pretrained=pretrained)
        # Drop the final avgpool and fc layers; keep only the convolutional trunk.
        self.encoder = nn.Sequential(*list(resnet.children())[:-2])
        self.decoder = self._build_decoder(num_classes, output_size)

    @staticmethod
    def _build_decoder(num_classes, output_size):
        """Return the upsampling head: 512 -> 256 -> 128 -> 64 -> num_classes channels.

        With kernel 4, stride 2 and padding 1, every transposed convolution
        doubles the spatial size exactly, so the four of them upsample by 16x
        (e.g. 8x8 -> 16 -> 32 -> 64 -> 128). The final bilinear resize brings
        the map to ``output_size``. Layer order is unchanged from the original
        definition, so existing state_dict keys still match.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),  # x2
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # x4
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),   # x8
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, num_classes, kernel_size=4, stride=2, padding=1),  # x16
            nn.Upsample(size=output_size, mode='bilinear', align_corners=False),
        )

    def forward(self, x):
        """Encode ``x`` and decode it into a ``num_classes``-channel map of ``output_size``."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x


In [6]:
# Smoke test: push one random 256x256 RGB image through a freshly built model
# and verify the shape of the predicted map.
model = ResNet18Seg()
input_tensor = torch.randn(1, 3, 256, 256)
with torch.no_grad():  # shape check only, so no autograd graph is needed
    output = model(input_tensor)
print(output.shape)
# Fail loudly if the architecture ever stops producing one 256x256 channel.
assert output.shape == (1, 1, 256, 256), f"unexpected output shape {tuple(output.shape)}"



torch.Size([1, 1, 256, 256])


In [7]:
# Build the network, wrap it for multi-GPU data parallelism, then move it to the
# selected device.
model = nn.DataParallel(ResNet18Seg())
model = model.to(device)

In [8]:
# Print the module tree (the DataParallel wrapper and the layers inside it) for inspection.
print(model)

DataParallel(
  (module): ResNet18Seg(
    (encoder): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, e

In [9]:
# Loss function and optimiser: Dice loss, with Adam updating every model parameter.
criterion = DiceLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [10]:
# Folder that receives the per-epoch checkpoints. It is created up front so the
# first torch.save does not fail (after a full epoch of training) on a fresh checkout.
checkpoint_dir = "../checkpoints/resnet18_checkpoints/"
os.makedirs(checkpoint_dir, exist_ok=True)

In [11]:
epochs = 100

# Mean loss per epoch (one entry appended per epoch).
train_loss = []
val_loss = []

for epoch in range(epochs):
    # Running sums of the per-batch losses for this epoch.
    total_train_loss = 0
    total_val_loss = 0

    # ---- Training pass: gradients enabled, weights updated ----
    model.train()
    for imgs, labels in train_dataloader:
        imgs, labels = imgs.to(device).float(), labels.to(device).float()
        optimizer.zero_grad()

        pred = model(imgs)

        loss = criterion(pred, labels)
        total_train_loss += loss.item()

        loss.backward()
        optimizer.step()

    # Average over batches (not over samples).
    train_loss.append(total_train_loss / len(train_dataloader))

    # ---- Validation pass: eval mode, no gradient tracking ----
    model.eval()
    with torch.no_grad():
        for imgs, labels in val_dataloader:
            imgs, labels = imgs.to(device).float(), labels.to(device).float()

            pred = model(imgs)

            loss = criterion(pred, labels)
            total_val_loss += loss.item()

    # From here on `total_val_loss` holds the per-batch mean for this epoch.
    total_val_loss = total_val_loss / len(val_dataloader)
    val_loss.append(total_val_loss)

    # Progress line for this epoch.
    print(f'Epoch: {epoch + 1}/{epochs}, Train Loss: {train_loss[-1]:.4f}, Val Loss: {total_val_loss:.4f}')

    # Save a checkpoint after every epoch.
    # NOTE: if `model` is an nn.DataParallel wrapper, the saved keys carry a
    # `module.` prefix; load them into a wrapped model or strip the prefix.
    checkpoint_path = os.path.join(checkpoint_dir, f'resnet18_checkpoint_epoch_{epoch+1}.pt')
    torch.save(model.state_dict(), checkpoint_path)

# After the last epoch, also store the final weights next to the notebook.
model_state_dict = model.state_dict()

# Destination file for the final weights.
file_path = 'resnet18_weights.pth'

torch.save(model_state_dict, file_path)



Epoch: 1/100, Train Loss: 0.6233, Val Loss: 0.4284
Epoch: 2/100, Train Loss: 0.2935, Val Loss: 0.2177
Epoch: 3/100, Train Loss: 0.2099, Val Loss: 0.1997
Epoch: 4/100, Train Loss: 0.1835, Val Loss: 0.1771
Epoch: 5/100, Train Loss: 0.1674, Val Loss: 0.1780
Epoch: 6/100, Train Loss: 0.1506, Val Loss: 0.1672
Epoch: 7/100, Train Loss: 0.1476, Val Loss: 0.1617
Epoch: 8/100, Train Loss: 0.1621, Val Loss: 0.1701
Epoch: 9/100, Train Loss: 0.1400, Val Loss: 0.1502
Epoch: 10/100, Train Loss: 0.1298, Val Loss: 0.1558
Epoch: 11/100, Train Loss: 0.1275, Val Loss: 0.1461
Epoch: 12/100, Train Loss: 0.1203, Val Loss: 0.1421
Epoch: 13/100, Train Loss: 0.1153, Val Loss: 0.1427
Epoch: 14/100, Train Loss: 0.1132, Val Loss: 0.1910
Epoch: 15/100, Train Loss: 0.1094, Val Loss: 0.1418
Epoch: 16/100, Train Loss: 0.1076, Val Loss: 0.1199
Epoch: 17/100, Train Loss: 0.1058, Val Loss: 0.1546
Epoch: 18/100, Train Loss: 0.1058, Val Loss: 0.1486
Epoch: 19/100, Train Loss: 0.1038, Val Loss: 0.1283
Epoch: 20/100, Train 

: 