In [1]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader

import os

from dice_loss import DiceLoss
import torch.optim as optim
from brain_mri_dataset import BrainMRIDatasetBuilder,BrainMRIDataset

from transforms import BrainMRITransforms

from calculate_iou import calculate_iou

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Training hyper-parameters.
batch_size = 64
learning_rate = 0.0003

# Seed PyTorch's RNGs (decoder weight init, DataLoader shuffling) so runs are
# repeatable. NOTE(review): randomness inside project helpers (e.g. the dataset
# split) is only covered if they draw from torch's RNG — confirm.
seed = 42
torch.manual_seed(seed)

In [4]:
# Build train/val/test DataFrames from the LGG MRI segmentation folder and wrap
# each split in a Dataset + DataLoader.
data_dir = "../datasets/lgg-mri-segmentation/kaggle_3m"

builder = BrainMRIDatasetBuilder(data_dir)
df = builder.create_df()
train_df, val_df, test_df = builder.split_df(df)

# NOTE(review): the same transform object is used for images and masks and for
# every split. If BrainMRITransforms contains random augmentation, image/mask
# pairs could be transformed inconsistently and val/test would be augmented —
# confirm it is deterministic.
transform_ = BrainMRITransforms()

train_data = BrainMRIDataset(train_df, transform=transform_, mask_transform=transform_)
val_data = BrainMRIDataset(val_df, transform=transform_, mask_transform=transform_)
test_data = BrainMRIDataset(test_df, transform=transform_, mask_transform=transform_)

# Only the training loader needs shuffling. The evaluation loaders iterate in a
# fixed order, so metrics and any inspected batches are reproducible.
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)


In [5]:
# ResNet50

class ResNet50Seg(nn.Module):
    """Segmentation network: ResNet-50 encoder followed by a transposed-conv decoder.

    The encoder is torchvision's ResNet-50 with its final average-pool and
    fully connected layers removed, leaving a 2048-channel feature map. The
    decoder applies four stride-2 transposed convolutions (x16 upsampling).
    The result is then bilinearly resized to the requested output size. No
    activation is applied to the output.

    Args:
        num_classes: number of output channels.
        pretrained: forwarded to ``torchvision.models.resnet50``; True loads
            torchvision's pretrained encoder weights.
        output_size: ``(H, W)`` of the returned map. Defaults to ``(256, 256)``,
            the original hard-coded size. Pass ``None`` to resize to the
            spatial size of each input instead, so predictions stay aligned
            with masks of any resolution.
    """

    def __init__(self, num_classes=1, pretrained=True, output_size=(256, 256)):
        super().__init__()
        # NOTE(review): `pretrained=` is deprecated in newer torchvision in
        # favour of `weights=`. It is kept here so older versions still work.
        resnet = models.resnet50(pretrained=pretrained)
        # Drop avgpool and fc so the encoder keeps a spatial feature map.
        self.encoder = nn.Sequential(*list(resnet.children())[:-2])
        # Weighted layers sit at the same indices (0, 2, 4, 6) as before, so
        # previously saved state_dicts still load. The parameter-free final
        # resize now happens in forward().
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(2048, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, num_classes, kernel_size=4, stride=2, padding=1),
        )
        self.output_size = None if output_size is None else tuple(output_size)

    def forward(self, x):
        """Return raw per-pixel scores of shape (N, num_classes, H_out, W_out).

        ``x`` is expected to be (N, 3, H, W), since the encoder's first conv
        takes 3 input channels.
        """
        size = self.output_size if self.output_size is not None else tuple(x.shape[-2:])
        x = self.encoder(x)
        x = self.decoder(x)
        return nn.functional.interpolate(x, size=size, mode='bilinear', align_corners=False)


In [6]:
# Smoke test: push one random 256x256 RGB image through a fresh instance and
# check the prediction has one channel at 256x256. Eval mode + no_grad avoid
# building an autograd graph for this throwaway forward pass. A distinct name
# keeps this demo net from being mistaken for the trained `model`.
demo_model = ResNet50Seg().eval()
input_tensor = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    output = demo_model(input_tensor)
print(output.shape)
assert output.shape == (1, 1, 256, 256), f"unexpected output shape {tuple(output.shape)}"
del demo_model  # release the demo network's weights



torch.Size([1, 1, 256, 256])


In [7]:
# Build a fresh network, wrap it in nn.DataParallel and move the wrapper to
# `device`.
seg_net = ResNet50Seg()
model = nn.DataParallel(seg_net).to(device)

In [8]:
# Show the full module tree of the wrapped network.
model_description = repr(model)
print(model_description)

DataParallel(
  (module): ResNet50Seg(
    (encoder): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
  

In [9]:
# Dice loss on the predicted masks; Adam over all of the model's parameters.
criterion = DiceLoss()
model_params = model.parameters()
optimizer = optim.Adam(model_params, lr=learning_rate)

In [10]:
# Root checkpoint folder, plus the sub-folder that receives this model's
# per-epoch weight files.
checkpoint_root = "../checkpoints/"
checkpoint_dir = checkpoint_root + "resnet50_checkpoints/"

In [11]:
# Train for `epochs` epochs. Each epoch:
#   - records the mean per-batch loss on the training and validation splits;
#   - writes a per-epoch checkpoint;
#   - keeps a separate copy of the weights with the lowest validation loss so far.
epochs = 100

train_loss = []
val_loss = []

# torch.save does not create missing parent directories.
os.makedirs(checkpoint_dir, exist_ok=True)
best_val_loss = float('inf')
best_epoch = None
best_checkpoint_path = os.path.join(checkpoint_dir, 'resnet50_best.pt')

for epoch in range(epochs):
    total_train_loss = 0
    total_val_loss = 0

    # Training pass: one optimizer step per batch.
    model.train()
    for imgs, labels in train_dataloader:
        imgs, labels = imgs.to(device).float(), labels.to(device).float()
        optimizer.zero_grad()

        pred = model(imgs)

        loss = criterion(pred, labels)
        total_train_loss += loss.item()

        loss.backward()
        optimizer.step()

    # max(..., 1) avoids a ZeroDivisionError if a loader is empty.
    train_loss.append(total_train_loss / max(len(train_dataloader), 1))

    # Validation pass: eval mode, no gradient tracking.
    model.eval()
    with torch.no_grad():
        for imgs, labels in val_dataloader:
            imgs, labels = imgs.to(device).float(), labels.to(device).float()

            pred = model(imgs)

            loss = criterion(pred, labels)
            total_val_loss += loss.item()

    total_val_loss = total_val_loss / max(len(val_dataloader), 1)
    val_loss.append(total_val_loss)

    print('Epoch: {}/{}, Train Loss: {:.4f}, Val Loss: {:.4f}'.format(epoch + 1, epochs, train_loss[-1], total_val_loss))

    # Per-epoch checkpoint. `model` is an nn.DataParallel wrapper, so the saved
    # keys carry a "module." prefix.
    checkpoint_path = os.path.join(checkpoint_dir, f'resnet50_checkpoint_epoch_{epoch+1}.pt')
    torch.save(model.state_dict(), checkpoint_path)

    # Validation loss fluctuates between epochs, so the last epoch is not
    # necessarily the best. Keep the best-so-far weights separately.
    if total_val_loss < best_val_loss:
        best_val_loss, best_epoch = total_val_loss, epoch + 1
        torch.save(model.state_dict(), best_checkpoint_path)

if best_epoch is not None:
    print('Best Val Loss: {:.4f} at epoch {} -> {}'.format(best_val_loss, best_epoch, best_checkpoint_path))

# Export the final-epoch weights (not necessarily the best ones).
model_state_dict = model.state_dict()
file_path = 'resnet50_weights.pth'
torch.save(model_state_dict, file_path)



Epoch: 1/100, Train Loss: 0.5409, Val Loss: 0.2866
Epoch: 2/100, Train Loss: 0.2377, Val Loss: 0.2017
Epoch: 3/100, Train Loss: 0.1841, Val Loss: 0.1949
Epoch: 4/100, Train Loss: 0.1594, Val Loss: 0.1622
Epoch: 5/100, Train Loss: 0.1419, Val Loss: 0.1503
Epoch: 6/100, Train Loss: 0.1313, Val Loss: 0.1514
Epoch: 7/100, Train Loss: 0.1361, Val Loss: 0.2813
Epoch: 8/100, Train Loss: 0.1514, Val Loss: 0.2004
Epoch: 9/100, Train Loss: 0.1420, Val Loss: 0.1436
Epoch: 10/100, Train Loss: 0.1138, Val Loss: 0.1194
Epoch: 11/100, Train Loss: 0.1056, Val Loss: 0.1284
Epoch: 12/100, Train Loss: 0.0991, Val Loss: 0.1385
Epoch: 13/100, Train Loss: 0.0993, Val Loss: 0.1365
Epoch: 14/100, Train Loss: 0.0954, Val Loss: 0.1560
Epoch: 15/100, Train Loss: 0.0910, Val Loss: 0.1344
Epoch: 16/100, Train Loss: 0.0900, Val Loss: 0.1367
Epoch: 17/100, Train Loss: 0.0904, Val Loss: 0.1377
Epoch: 18/100, Train Loss: 0.0973, Val Loss: 0.1519
Epoch: 19/100, Train Loss: 0.0893, Val Loss: 0.1243
Epoch: 20/100, Train 

: 