<a href="https://colab.research.google.com/github/leopomme/Dose-Prediction-Segmentation/blob/main/RadImageNet_V2_DLMI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [67]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import h5py

import cv2
import albumentations as A
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import torch.optim as optim
import matplotlib.pyplot as plt
from albumentations.pytorch import ToTensorV2
from torch.optim.lr_scheduler import StepLR
from torchvision.models.resnet import resnet50, Bottleneck

import time 
from tqdm import tqdm
from glob import glob
import zipfile
import shutil


In [68]:
!git clone https://github.com/soniamartinot/MVA-Dose-Prediction.git

fatal: destination path 'MVA-Dose-Prediction' already exists and is not an empty directory.


In [69]:
transform_img = A.Compose([
    A.OneOf([
        A.GaussNoise(var_limit=(0, 25), p=1),
        A.Blur(blur_limit=(3, 10), p=1),
    ], p=0.5),
])

transform_img_mask = A.Compose([
    A.OneOf([
        A.ElasticTransform(alpha=14, sigma=5, alpha_affine=5, p=1),
        A.Rotate(limit=10, p=1),
        A.VerticalFlip(p=1),
        A.GridDistortion(num_steps=10, distort_limit=0.5, p=1),
    ], p=0.5),
])


train_transforms = {
    'img': transform_img,
    'img_mask': transform_img_mask
}

val_transforms = None


In [70]:
class DoseDataset(torch.utils.data.Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.samples = os.listdir(data_path)
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample_path = self.data_path + os.sep + self.samples[idx]
        ct_scan = np.load(sample_path + os.sep + 'ct.npy')
        possible_dose_mask = np.load(sample_path + os.sep + 'possible_dose_mask.npy')
        dose = np.load(sample_path + os.sep + 'dose.npy')
        structure_masks = np.load(sample_path + os.sep + 'structure_masks.npy')

        combined_masks = [possible_dose_mask] + [structure_masks[i] for i in range(structure_masks.shape[0])] + [dose]

        if self.transform:
            if 'img_mask' in self.transform:
                transform_func = self.transform['img_mask']
                augmentations = transform_func(image=ct_scan, masks=combined_masks)
                ct_scan = augmentations['image']
                combined_masks = augmentations['masks']
            if 'img' in self.transform:
                transform_func = self.transform['img']
                augmentations = transform_func(image=ct_scan)
                ct_scan = augmentations['image']

        # Convert NumPy arrays to PyTorch tensors
        ct_scan_tensor = torch.from_numpy(ct_scan).float().unsqueeze(0)
        combined_masks_tensor = torch.from_numpy(np.stack(combined_masks[:-1], axis=0)).float()
        dose_tensor = torch.from_numpy(combined_masks[-1]).float().unsqueeze(0)

        # Normalize the ct_scan tensor
        mean = ct_scan_tensor.mean()
        std = ct_scan_tensor.std()
        normalize = transforms.Normalize(mean=[mean], std=[std+1e-6])
        ct_scan_tensor = normalize(ct_scan_tensor)

        # Combine ct_scan_tensor and combined_masks_tensor
        input_image = torch.cat([ct_scan_tensor, combined_masks_tensor], dim=0)

        return {'input_image': input_image, 'dose': dose_tensor}


In [71]:
class test_DoseDataset(torch.utils.data.Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.samples = os.listdir(data_path)
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample_path = self.data_path + os.sep + self.samples[idx]
        ct_scan = np.load(sample_path + os.sep + 'ct.npy')
        possible_dose_mask = np.load(sample_path + os.sep + 'possible_dose_mask.npy')
        structure_masks = np.load(sample_path + os.sep + 'structure_masks.npy')
        combined_masks = [possible_dose_mask] + [structure_masks[i] for i in range(structure_masks.shape[0])]

        # Convert NumPy arrays to PyTorch tensors
        ct_scan_tensor = torch.from_numpy(ct_scan).float().unsqueeze(0)
        combined_masks_tensor = torch.from_numpy(np.stack(combined_masks, axis=0)).float()

        # Normalize the ct_scan tensor
        mean = ct_scan_tensor.mean()
        std = ct_scan_tensor.std()
        normalize = transforms.Normalize(mean=[mean], std=[std+1e-6])
        ct_scan_tensor = normalize(ct_scan_tensor)

        # Combine ct_scan_tensor and combined_masks_tensor
        input_image = torch.cat([ct_scan_tensor, combined_masks_tensor], dim=0)

        return {'sample_name': self.samples[idx], 'input_image': input_image}


In [72]:
from torch.utils.data import ConcatDataset

train_dir = "./MVA-Dose-Prediction/train/"

train_dataset_original = DoseDataset(train_dir, transform=None)
train_dataset_augmented = DoseDataset(train_dir, transform=train_transforms)

train_dataset = ConcatDataset([train_dataset_original, train_dataset_augmented])

train_dataloader = DataLoader(train_dataset_augmented, batch_size=32, num_workers=2, shuffle=True, pin_memory=True)
print(len(train_dataloader))


244


In [73]:
val_dir = "./MVA-Dose-Prediction/validation/"

val_dataset = DoseDataset(val_dir, transform=val_transforms)
val_dataloader = DataLoader(val_dataset, batch_size=32, num_workers=2, shuffle=True, pin_memory=True)
print(len(val_dataloader))

38


In [74]:
test_dir = "./MVA-Dose-Prediction/test"

test_dataset = test_DoseDataset(test_dir)
test_dataloader = DataLoader(test_dataset, batch_size=32, num_workers=2, shuffle=True, pin_memory=True)
print(len(test_dataloader))

38


In [75]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [76]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50


class BigBasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, intermediate_channels):
        super(BigBasicBlock, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, intermediate_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(intermediate_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class BigUNet(nn.Module):
    def __init__(self, in_channels=12, out_channels=1, pretrained_weights=None):
        super(BigUNet, self).__init__()

        # Load pre-trained ResNet50 model
        resnet = resnet50(weights=pretrained_weights)

        if pretrained_weights is not None:
            # Remove the fully connected layer weights from the state_dict
            pretrained_weights = {k: v for k, v in pretrained_weights.items() if not k.startswith('fc')}
            # Load the pre-trained ResNet50 weights into the resnet model
            resnet.load_state_dict(pretrained_weights, strict=False)

        # Change the number of input channels for the first convolutional layer
        resnet.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Use the pre-trained ResNet50 layers as the encoder
        self.encoder_stage1 = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool
        )
        self.encoder_stage2 = resnet.layer1
        self.encoder_stage3 = resnet.layer2
        self.encoder_stage4 = resnet.layer3

        # Freeze the encoder layers
        for param in self.encoder_stage1.parameters():
            param.requires_grad = False
        for param in self.encoder_stage2.parameters():
            param.requires_grad = False
        for param in self.encoder_stage3.parameters():
            param.requires_grad = False
        for param in self.encoder_stage4.parameters():
            param.requires_grad = False

        # Add convolutional layers for fine-tuning
        self.conv_finetune1 = nn.Conv2d(1024, 512, kernel_size=1)
        self.conv_finetune2 = nn.Conv2d(512, 1024, kernel_size=1)
        self.upsample_finetuned = nn.ConvTranspose2d(1024, 1024, kernel_size=2, stride=2)

        # Bottom block
        self.bottom = BigBasicBlock(1024, 1024, 1024)

        # Upward pass
        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(1024 + 512, 512, kernel_size=2, stride=2),
            BigBasicBlock(512, 512, 512),
        )
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(512 + 256, 256, kernel_size=2, stride=2),
            BigBasicBlock(256, 256, 256),
        )
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),
            BigBasicBlock(128, 128, 128),
        )
        # self.up1 = nn.Sequential(
        #     nn.ConvTranspose2d(128 + in_channels, 64, kernel_size=2, stride=2),
        #     nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
        #     nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2, padding=0, output_padding=0)
        # )

        self.last = nn.Sequential(
            nn.Conv2d(128 + in_channels, 64, kernel_size=3, stride=1, padding=1),
            # nn.Conv2d(32 + in_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(8, 4, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(4, 2, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(2, 1, kernel_size=3, stride=1, padding=1)  
        )

    def forward(self, x):
        # Encoder pass
        # print('x',x.shape)
        x1 = self.encoder_stage1(x)
        # print('x1',x1.shape)
        x2 = self.encoder_stage2(x1)
        # print('x2',x2.shape)
        x3 = self.encoder_stage3(x2)
        # print('x3',x3.shape)
        x4 = self.encoder_stage4(x3)
        # print('x4',x4.shape)

        # Fine-tune the encoder output with convolutional layers
        x4_finetuned = F.relu(self.conv_finetune1(x4))
        # print('x4_finetuned',x4_finetuned.shape)
        x4_finetuned = F.relu(self.conv_finetune2(x4_finetuned))
        # print('x4_finetuned',x4_finetuned.shape)
        x4_finetuned_upsampled = self.upsample_finetuned(x4_finetuned)
        # print('x4_finetuned_upsampled',x4_finetuned_upsampled.shape)

        # Bottom block
        bottom = self.bottom(x4_finetuned_upsampled)
        # print('bottom',bottom.shape)

        # Upward pass
        up4 = self.up4(torch.cat((bottom, x3), dim=1))
        # print('up4',up4.shape)
        up3 = self.up3(torch.cat((up4, x2), dim=1))
        # print('up3',up3.shape)
        up2 = self.up2(up3)
        # print('up2',up2.shape)
        # up1 = self.up1(torch.cat((up2, x), dim=1))
        # print('up1',up1.shape)

        # Concatenate input layer
        y_concat = torch.cat((up2, x), dim=1)
        # print('y_concat',y_concat.shape)
        y = F.relu(self.last(y_concat))
        # print('y',y.shape)

        y = y * x[:, 1:2, :, :]

        return y



In [77]:
# Create a training loop
def train(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(dataloader, 0):
        inputs, labels = data['input_image'].to(device), data['dose'].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / (i + 1)

# Create a validation loop
def validate(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for i, data in enumerate(dataloader, 0):
            inputs, labels = data['input_image'].to(device), data['dose'].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
    return running_loss / (i + 1)


In [78]:
class EarlyStopping:
  def __init__(self, patience = 5):
    self.patience = patience
    self.best_val_loss = None
    self.counter = 0
    self.early_stop = False

  def __call__(self, val_loss):
    if self.best_val_loss is None or val_loss < self.best_val_loss:
        self.best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_trained_model.pth')
        print('model saved')
        self.counter = 0
    else:
        self.counter += 1

In [79]:
# Train the model and evaluate its performance on the validation set
num_epochs = 100
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Load the pre-trained weights
resnet_weights = torch.load('/content/drive/MyDrive/RadImageNet-ResNet50_notop_torch.pth')
model = BigUNet(pretrained_weights=resnet_weights).to(device)


criterion = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = StepLR(optimizer, step_size=15, gamma=0.1)
early_stopping = EarlyStopping(patience = 16)
real_num_epoch = 0

train_loss_list = []
val_loss_list = []

start_time = time.time()

for epoch in range(num_epochs):
    epoch_time = time.time()

    train_loss = train(model, train_dataloader, criterion, optimizer, device)
    val_loss = validate(model, val_dataloader, criterion, device)

    elapsed_time = time.time() - epoch_time
    total_time = time.time() - start_time
    print(f'Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Epoch time: {elapsed_time/60:.2f}min, Total time: {total_time/60:.2f}min')
    real_num_epoch +=1

    # Update the learning rate using the scheduler
    scheduler.step()

    train_loss_list.append(train_loss)
    val_loss_list.append(val_loss)
    
    early_stopping(val_loss)
    if early_stopping.counter >= early_stopping.patience:
      print(f'Early stopping at epoch {epoch + 1}. Best epoch: {epoch + 1 - early_stopping.counter} with Val Loss: {early_stopping.best_val_loss:.4f}')
      break

cuda


KeyboardInterrupt: ignored

In [None]:
plt.plot(range(real_num_epoch), train_loss_list, label='Training Loss')
plt.plot(range(real_num_epoch), val_loss_list, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show

In [None]:
model_save_path = "/content/drive/MyDrive/best_trained_model_rad.pth"
torch.save(model.state_dict(), model_save_path)


In [None]:
model.load_state_dict(torch.load("/content/drive/MyDrive/best_trained_model_rad.pth"))
model.eval()

# Create a directory to store the predicted dose files
output_dir = "predictions"
os.makedirs(output_dir, exist_ok=True)

# Loop through the test dataset and make predictions
with torch.no_grad():
    for idx, data in enumerate(test_dataloader):
        sample_names = data['sample_name']
        inputs = data['input_image'].to(device)
        outputs = model(inputs)
        
        batch_outputs = outputs.cpu().numpy()
        
        # Save each sample's prediction separately
        for i, sample_name in enumerate(sample_names):
            output_file = os.path.join(output_dir, f'{sample_name}.npy')
            np.save(output_file, batch_outputs[i])

# Create a zip file with all the predicted dose files
with zipfile.ZipFile("submission.zip", "w") as zf:
    for file in os.listdir(output_dir):
        file_path = os.path.join(output_dir, file)
        zf.write(file_path, arcname=file)
