<a href="https://colab.research.google.com/github/bhnunes/AIModels/blob/UNET/IA_Medica_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

1. Set Up Environment

In [None]:
# Install required packages
!pip install torch torchvision torchsummary numpy matplotlib

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch)
  Using cached nvidia_curand_cu12-10.3.2.106-py3-

2. Mount Google Drive and Import Libraries

In [None]:
# Mount Google Drive so that the dataset folders referenced later
# (under /content/drive/MyDrive/...) are reachable from this runtime.
# This step is Colab-specific: it relies on the `google.colab` package.
from google.colab import drive
drive.mount('/content/drive')

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import functional as TF
import matplotlib.pyplot as plt
from PIL import Image
import gc
from torch.optim.lr_scheduler import ReduceLROnPlateau

Mounted at /content/drive


In [None]:
# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Define hyperparameters
LEARNING_RATE = 1e-5  # optimizer step size
BATCH_SIZE = 8        # samples per DataLoader batch
NUM_EPOCHS = 100      # upper bound on epochs; early stopping may end training sooner
# Never ask for more DataLoader worker processes than there are CPU cores:
# oversubscribing makes PyTorch warn and can slow down or stall data loading.
WORKERS = min(12, os.cpu_count() or 1)
WEIGHT_DECAY = 1e-5   # L2 penalty passed to the optimizer

cuda


# ***Define Metrics for Evaluation***

In [None]:
import numpy as np

def dice_coefficient(pred, target, smooth=1e-6):
    """Mean Dice coefficient between predicted and ground-truth binary masks.

    Both tensors are binarised first (any non-zero value is foreground), so the
    result is the same for 0/1, 0/255, boolean or float masks. The score is
    computed per sample over all non-batch dimensions and then averaged.

    Args:
        pred: Tensor of predictions whose first dimension is the batch.
        target: Tensor of ground-truth masks with the same number of elements
            per sample as ``pred``.
        smooth: Small constant added to numerator and denominator so that two
            empty masks yield a score of 1 instead of a division by zero.

    Returns:
        float: Dice coefficient averaged over the batch.
    """
    # Binarise and flatten everything except the batch dimension.
    pred_fg = (pred != 0).flatten(1)
    target_fg = (target != 0).flatten(1)

    intersection = (pred_fg & target_fg).sum(1).float()
    pred_area = pred_fg.sum(1).float()
    target_area = target_fg.sum(1).float()

    dice = (2. * intersection + smooth) / (pred_area + target_area + smooth)

    return dice.mean().item()

def iou(pred, target, smooth=1e-6):
    """Mean intersection-over-union between predicted and ground-truth binary masks.

    Both tensors are binarised first (any non-zero value is foreground), so the
    result is the same for 0/1, 0/255, boolean or float masks. The score is
    computed per sample over all non-batch dimensions and then averaged.

    Args:
        pred: Tensor of predictions whose first dimension is the batch.
        target: Tensor of ground-truth masks with the same number of elements
            per sample as ``pred``.
        smooth: Small constant added to numerator and denominator so that two
            empty masks yield a score of 1 instead of a division by zero.

    Returns:
        float: IoU averaged over the batch.
    """
    # Binarise and flatten everything except the batch dimension.
    pred_fg = (pred != 0).flatten(1)
    target_fg = (target != 0).flatten(1)

    intersection = (pred_fg & target_fg).sum(1).float()
    union = (pred_fg | target_fg).sum(1).float()

    # Named `scores` rather than `iou` so the function name is not shadowed locally.
    scores = (intersection + smooth) / (union + smooth)

    return scores.mean().item()

def pixel_accuracy(preds, masks):
    """Return the fraction of elements at which ``preds`` equals ``masks``.

    The denominator is the number of elements in ``masks``.
    """
    matches = preds == masks
    return matches.sum().item() / masks.numel()


3. Define Model

# ***UNET***

In [None]:
class UNet(nn.Module):
    """U-Net style encoder/decoder for per-pixel classification.

    The encoder has four double-convolution stages (64, 128, 256, 512 channels),
    each followed by 2x2 max pooling, then a 1024-channel bottleneck. The decoder
    mirrors it: every stage upsamples with a transposed convolution, concatenates
    the matching encoder feature map, and applies a double convolution. A final
    1x1 convolution maps to ``out_channels`` logits per pixel.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of output channels (one logit per class).
    """

    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        self.encoder = nn.ModuleList([
            self.conv_block(in_channels, 64),
            self.conv_block(64, 128),
            self.conv_block(128, 256),
            self.conv_block(256, 512),
        ])
        self.middle = self.conv_block(512, 1024)
        self.decoder = nn.ModuleList([
            self.up_conv_block(1024, 512),
            self.up_conv_block(512, 256),
            self.up_conv_block(256, 128),
            self.up_conv_block(128, 64),
        ])
        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 convolutions (padding 1, so spatial size is kept), each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def up_conv_block(self, in_channels, out_channels):
        """Decoder stage: [0] 2x transposed-conv upsampling, [1] double conv after the skip concat.

        The two parts are stored in one ``nn.Sequential`` only as a container;
        ``forward`` calls them separately because the concatenation happens in between.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            self.conv_block(out_channels * 2, out_channels)  # *2 because of concatenation
        )

    def forward(self, x):
        """Compute per-pixel logits.

        Args:
            x: Input batch of shape (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W).
        """
        encoder_outputs = []
        for encoder_layer in self.encoder:
            x = encoder_layer(x)
            encoder_outputs.append(x)  # kept for the skip connection
            x = nn.functional.max_pool2d(x, kernel_size=2, stride=2)

        x = self.middle(x)

        for decoder_layer, encoder_output in zip(self.decoder, reversed(encoder_outputs)):
            x = decoder_layer[0](x)  # Upconvolution

            # Max pooling floors odd sizes, so after upsampling x can be smaller than
            # the skip tensor when H or W is not divisible by 16. Zero-pad x to the
            # skip tensor's size so the concatenation works for any input size.
            diff_h = encoder_output.size(2) - x.size(2)
            diff_w = encoder_output.size(3) - x.size(3)
            if diff_h or diff_w:
                x = nn.functional.pad(
                    x,
                    [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
                )

            x = torch.cat([x, encoder_output], dim=1)
            x = decoder_layer[1](x)  # Conv block after concatenation

        return self.final_conv(x)

4. Create Custom Dataset

In [None]:
class ProstateCancerDataset(Dataset):
    """Image/mask pairs drawn from a "cancer" folder and a "not cancer" folder.

    Every file found in an image directory is paired with the file of the same
    name in the corresponding mask directory.

    Args:
        cancer_image_dir: Directory with images labelled as cancer.
        cancer_mask_dir: Directory with the masks for ``cancer_image_dir``.
        not_cancer_image_dir: Directory with images labelled as not cancer.
        not_cancer_mask_dir: Directory with the masks for ``not_cancer_image_dir``.
        image_transform: Optional callable applied to the resized RGB PIL image.
        mask_transform: Optional callable applied to the resized grayscale PIL mask.
        image_size: ``(width, height)`` every image and mask is resized to.
    """

    def __init__(self, cancer_image_dir, cancer_mask_dir, not_cancer_image_dir, not_cancer_mask_dir, image_transform=None, mask_transform=None, image_size=(224, 224)):
        self.cancer_image_dir = cancer_image_dir
        self.cancer_mask_dir = cancer_mask_dir
        self.not_cancer_image_dir = not_cancer_image_dir
        self.not_cancer_mask_dir = not_cancer_mask_dir
        self.image_transform = image_transform
        self.mask_transform = mask_transform
        self.image_size = tuple(image_size)

        # os.listdir returns entries in arbitrary order; sort so that an index
        # always refers to the same sample across runs and machines.
        self.cancer_images = sorted(os.listdir(cancer_image_dir))
        self.not_cancer_images = sorted(os.listdir(not_cancer_image_dir))
        self.images = self.cancer_images + self.not_cancer_images
        # 1 = file came from the cancer directory, 0 = from the not-cancer directory.
        self.labels = [1] * len(self.cancer_images) + [0] * len(self.not_cancer_images)

    def __len__(self):
        """Total number of samples across both directories."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, resize and transform the ``idx``-th (image, mask) pair."""
        if self.labels[idx] == 1:
            image_dir, mask_dir = self.cancer_image_dir, self.cancer_mask_dir
        else:
            image_dir, mask_dir = self.not_cancer_image_dir, self.not_cancer_mask_dir

        file_name = self.images[idx]
        img_path = os.path.join(image_dir, file_name)
        mask_path = os.path.join(mask_dir, file_name)

        # Context managers close the underlying files promptly; convert() returns
        # an independent in-memory copy, so the results stay valid afterwards.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert("L")

        # Ensure image and mask have the same size. The mask uses nearest-neighbour
        # resampling so that no intermediate label values are introduced.
        image = image.resize(self.image_size, Image.BILINEAR)
        mask = mask.resize(self.image_size, Image.NEAREST)

        if self.image_transform:
            image = self.image_transform(image)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return image, mask

In [None]:
class ToTensor:
    """Callable transform that forwards its input to ``TF.to_tensor``."""

    def __call__(self, image):
        return TF.to_tensor(image)

class ToTensorMask:
    """Callable transform turning a mask into a ``torch.long`` tensor.

    The mask is first converted to a NumPy array; its values are kept as they
    are (no rescaling or thresholding is applied).
    """

    def __call__(self, mask):
        mask_array = np.array(mask)
        return torch.tensor(mask_array, dtype=torch.long)

class Normalize:
    """Callable transform applying ``TF.normalize`` with a fixed mean and std."""

    def __init__(self, mean, std):
        # Stored as given and handed to TF.normalize on every call.
        self.mean, self.std = mean, std

    def __call__(self, image):
        return TF.normalize(image, self.mean, self.std)

# Image pipeline: PIL image -> tensor -> per-channel normalization.
image_transform = transforms.Compose(
    [ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)

# Mask pipeline: PIL mask -> torch.long tensor (values left untouched).
mask_transform = transforms.Compose([ToTensorMask()])

5. Load data

In [None]:
# Directories: one image folder and one mask folder per class, for each split.
train_cancer_image_dir = '/content/drive/MyDrive/SUB_SET_SMALL/TRAIN/CANCER'
train_cancer_mask_dir = '/content/drive/MyDrive/SUB_SET_SMALL/TRAIN/CANCER_MASK'
train_not_cancer_image_dir = '/content/drive/MyDrive/SUB_SET_SMALL/TRAIN/NOT_CANCER'
train_not_cancer_mask_dir = '/content/drive/MyDrive/SUB_SET_SMALL/TRAIN/NOT_CANCER_MASK'

val_cancer_image_dir = '/content/drive/MyDrive/SUB_SET_SMALL/VALIDATION/CANCER'
val_cancer_mask_dir = '/content/drive/MyDrive/SUB_SET_SMALL/VALIDATION/CANCER_MASK'
val_not_cancer_image_dir = '/content/drive/MyDrive/SUB_SET_SMALL/VALIDATION/NOT_CANCER'
val_not_cancer_mask_dir = '/content/drive/MyDrive/SUB_SET_SMALL/VALIDATION/NOT_CANCER_MASK'

test_cancer_image_dir = '/content/drive/MyDrive/SUB_SET_SMALL/TEST/CANCER'
test_cancer_mask_dir = '/content/drive/MyDrive/SUB_SET_SMALL/TEST/CANCER_MASK'
test_not_cancer_image_dir = '/content/drive/MyDrive/SUB_SET_SMALL/TEST/NOT_CANCER'
test_not_cancer_mask_dir = '/content/drive/MyDrive/SUB_SET_SMALL/TEST/NOT_CANCER_MASK'

# Datasets: all three splits share the same image and mask transforms.
train_dataset = ProstateCancerDataset(
    cancer_image_dir=train_cancer_image_dir,
    cancer_mask_dir=train_cancer_mask_dir,
    not_cancer_image_dir=train_not_cancer_image_dir,
    not_cancer_mask_dir=train_not_cancer_mask_dir,
    image_transform=image_transform,
    mask_transform=mask_transform,
)
val_dataset = ProstateCancerDataset(
    cancer_image_dir=val_cancer_image_dir,
    cancer_mask_dir=val_cancer_mask_dir,
    not_cancer_image_dir=val_not_cancer_image_dir,
    not_cancer_mask_dir=val_not_cancer_mask_dir,
    image_transform=image_transform,
    mask_transform=mask_transform,
)
test_dataset = ProstateCancerDataset(
    cancer_image_dir=test_cancer_image_dir,
    cancer_mask_dir=test_cancer_mask_dir,
    not_cancer_image_dir=test_not_cancer_image_dir,
    not_cancer_mask_dir=test_not_cancer_mask_dir,
    image_transform=image_transform,
    mask_transform=mask_transform,
)

# DataLoaders: only the training split is shuffled.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=WORKERS,
    pin_memory=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=WORKERS,
    pin_memory=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=WORKERS,
    pin_memory=True,
)



# ***Early Stopping Class***

In [None]:
class EarlyStopping:
    """Stop training once the validation loss stops improving; checkpoint the best model.

    Call the instance once per epoch with the current validation loss. Whenever
    the loss improves, the model's ``state_dict`` is written to ``path``;
    otherwise a counter is increased, and ``early_stop`` becomes ``True`` once it
    reaches ``patience``.

    Args:
        patience: Number of consecutive non-improving epochs tolerated before
            ``early_stop`` is set.
        verbose: If ``True``, print a message on every checkpoint and every
            counter increase.
        delta: Minimum decrease of the validation loss (relative to the best
            loss so far) that counts as an improvement.
    """

    def __init__(self, patience=5, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.delta = delta

    def __call__(self, val_loss, model, path):
        """Record ``val_loss`` for this epoch, checkpointing ``model`` to ``path`` on improvement."""
        # An improvement must beat the best loss by at least `delta`. Writing the test
        # as `<=` also means a NaN loss is never treated as an improvement.
        improved = self.best_loss is None or val_loss <= self.best_loss - self.delta

        if improved:
            # Save BEFORE updating best_loss so the log message can show the
            # previous best next to the new value.
            self.save_checkpoint(val_loss, model, path)
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model, path):
        """Saves model when validation loss decreases."""
        if self.verbose:
            # No previous best exists on the very first call; report it as infinity.
            previous_best = float('inf') if self.best_loss is None else self.best_loss
            print(f'Validation loss decreased ({previous_best:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), path)

6. Train the model

In [None]:
def train_model(model, criterion, optimizer, dataloader, device):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Network producing per-pixel class logits.
        criterion: Loss called as ``criterion(logits, targets)``.
        optimizer: Optimizer updating the parameters of ``model``.
        dataloader: Yields ``(images, masks)`` batches and exposes ``.dataset``.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(epoch_loss, epoch_accuracy)`` - the loss averaged over the
        samples of ``dataloader.dataset`` and the fraction of correctly
        classified pixels.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_pixels = 0

    for batch_images, batch_masks in dataloader:
        batch_images = batch_images.to(device)
        batch_masks = batch_masks.to(device)
        # Drop a singleton channel dimension if the masks carry one.
        targets = batch_masks.squeeze(1)

        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, targets.long())
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch loss is a per-sample average.
        loss_sum += batch_loss.item() * batch_images.size(0)

        predicted = logits.argmax(dim=1)
        n_correct += (predicted == targets).sum().item()
        n_pixels += batch_masks.numel()

    return loss_sum / len(dataloader.dataset), n_correct / n_pixels

In [None]:
def validate_model(model, criterion, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Args:
        model: Network producing per-pixel class logits.
        criterion: Loss called as ``criterion(logits, targets)``.
        dataloader: Yields ``(images, masks)`` batches and exposes ``.dataset``.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(epoch_loss, epoch_accuracy)`` - the loss averaged over the
        samples of ``dataloader.dataset`` and the fraction of correctly
        classified pixels.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_pixels = 0

    with torch.no_grad():
        for batch_images, batch_masks in dataloader:
            batch_images = batch_images.to(device)
            batch_masks = batch_masks.to(device)
            # Drop a singleton channel dimension if the masks carry one.
            targets = batch_masks.squeeze(1)

            logits = model(batch_images)
            batch_loss = criterion(logits, targets.long())

            # Weight by batch size so the epoch loss is a per-sample average.
            loss_sum += batch_loss.item() * batch_images.size(0)

            predicted = logits.argmax(dim=1)
            n_correct += (predicted == targets).sum().item()
            n_pixels += batch_masks.numel()

    return loss_sum / len(dataloader.dataset), n_correct / n_pixels

7. Initialize and Train

In [None]:
# Initialize and Train Model
# `device` is recomputed here with the same expression as in the configuration cell.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNet(in_channels=3, out_channels=2).to(device)  # 2 classes: CANCER and NOT_CANCER
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Multiply the learning rate by 0.1 when the monitored value (validation loss) plateaus.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

num_epochs = NUM_EPOCHS
# NOTE(review): early-stopping patience (5) equals the scheduler patience (5), so
# training may stop before the learning rate is ever reduced - confirm this is intended.
early_stopping = EarlyStopping(patience=5, verbose=True)

for epoch in range(num_epochs):
    train_loss, train_accuracy = train_model(model, criterion, optimizer, train_loader, device)
    val_loss, val_accuracy = validate_model(model, criterion, val_loader, device)
    # The scheduler monitors the validation loss of the epoch that just finished.
    scheduler.step(val_loss)

    # Read the learning rate back from the optimizer so the log reflects any reduction.
    current_lr = optimizer.param_groups[0]['lr']

    # Print training and validation metrics
    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}, Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}, Learning Rate: {current_lr:.6f}")

    # Writes the model's state_dict to 'best_model.pth' whenever the validation loss improves.
    early_stopping(val_loss, model, 'best_model.pth')

    if early_stopping.early_stop:
        print("Early stopping")
        break
    gc.collect()  # Trigger garbage collection
    torch.cuda.empty_cache()  # Free up any cached CUDA allocations

Epoch 1/100, Train Loss: 0.6927, Train Accuracy: 0.5040, Validation Loss: 0.6905, Validation Accuracy: 0.4921, Learning Rate: 0.000010
Validation loss decreased (0.690532 --> 0.690532).  Saving model ...
Epoch 2/100, Train Loss: 0.6908, Train Accuracy: 0.5184, Validation Loss: 0.6865, Validation Accuracy: 0.5433, Learning Rate: 0.000010
Validation loss decreased (0.686458 --> 0.686458).  Saving model ...
Epoch 3/100, Train Loss: 0.6853, Train Accuracy: 0.5433, Validation Loss: 0.6821, Validation Accuracy: 0.5508, Learning Rate: 0.000010
Validation loss decreased (0.682113 --> 0.682113).  Saving model ...
Epoch 4/100, Train Loss: 0.6755, Train Accuracy: 0.5758, Validation Loss: 0.6690, Validation Accuracy: 0.5963, Learning Rate: 0.000010
Validation loss decreased (0.668989 --> 0.668989).  Saving model ...
Epoch 5/100, Train Loss: 0.6631, Train Accuracy: 0.6017, Validation Loss: 0.6626, Validation Accuracy: 0.6062, Learning Rate: 0.000010
Validation loss decreased (0.662596 --> 0.662596)

8. Evaluate the Model

In [None]:
def evaluate_model(model, dataloader, device):
    """Compute the mean Dice coefficient and IoU of ``model`` over ``dataloader``.

    Each batch's score is weighted by its number of samples, so a smaller final
    batch does not count more per sample than the full batches.

    Args:
        model: Network producing per-pixel class logits.
        dataloader: Yields ``(images, masks)`` batches.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(dice_score, iou_score)`` averaged over all samples.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    dice_total = 0.0
    iou_total = 0.0
    num_samples = 0

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            # Drop a singleton channel dimension, as train_model/validate_model do,
            # so masks line up with the (N, H, W) predictions.
            masks = masks.to(device).squeeze(1)

            outputs = model(images)
            preds = torch.argmax(outputs, dim=1)

            batch_size = images.size(0)
            # The metric helpers return a per-batch mean; scale back to a sum.
            dice_total += dice_coefficient(preds, masks) * batch_size
            iou_total += iou(preds, masks) * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("evaluate_model received an empty dataloader; no metrics to compute.")

    dice_score = dice_total / num_samples
    iou_score = iou_total / num_samples

    print(f"Dice Coefficient: {dice_score:.4f}, IoU: {iou_score:.4f}")
    return dice_score, iou_score

9. Evaluate on the Test Set and Visualize Results

In [None]:
# Load the best model (the checkpoint written during training).
# map_location keeps the load working when the checkpoint was saved on a GPU
# but the current session runs on a different device (e.g. CPU only).
model.load_state_dict(torch.load('best_model.pth', map_location=device))

# Test the model
# NOTE(review): assigning the result to `iou` rebinds the global name of the `iou()`
# metric function, so running this cell a second time fails inside evaluate_model
# until the metrics cell is re-run. The names are kept because the final cell prints
# `dice` and `iou`; consider renaming both together.
dice, iou = evaluate_model(model, test_loader, device)

def visualize_predictions(model, dataloader, device, num_images=5,
                          mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot input image, ground-truth mask and predicted mask side by side.

    Args:
        model: Network producing per-pixel class logits.
        dataloader: Yields ``(images, masks)`` batches.
        device: Device the images are moved to for inference.
        num_images: Maximum number of samples (not batches) to plot.
        mean: Per-channel mean used to undo the input normalization before
            display. The default matches the values in ``image_transform``.
            Pass ``None`` to show the tensors as they are.
        std: Per-channel standard deviation used to undo the normalization.
            Pass ``None`` to show the tensors as they are.
    """
    model.eval()
    undo_normalization = mean is not None and std is not None
    if undo_normalization:
        # Shaped (1, 1, C) to broadcast over an (H, W, C) image.
        mean_hwc = np.asarray(mean, dtype=np.float32).reshape(1, 1, -1)
        std_hwc = np.asarray(std, dtype=np.float32).reshape(1, 1, -1)

    shown = 0
    with torch.no_grad():
        for images, masks in dataloader:
            if shown >= num_images:
                break
            outputs = model(images.to(device))
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            images = images.cpu().numpy()
            masks = masks.cpu().numpy()

            for j in range(images.shape[0]):
                if shown >= num_images:
                    break

                picture = np.transpose(images[j], (1, 2, 0))  # CHW -> HWC for imshow
                if undo_normalization:
                    # imshow expects floats in [0, 1]; normalized inputs fall outside
                    # that range and would otherwise be clipped into wrong colours.
                    picture = np.clip(picture * std_hwc + mean_hwc, 0.0, 1.0)

                fig, axes = plt.subplots(1, 3, figsize=(10, 3))
                axes[0].imshow(picture)
                axes[0].set_title('Image')
                axes[1].imshow(masks[j])
                axes[1].set_title('Ground Truth')
                axes[2].imshow(preds[j])
                axes[2].set_title('Prediction')
                plt.show()
                plt.close(fig)  # release the figure so repeated calls do not accumulate memory
                shown += 1

# Visualize predictions: plot image / ground truth / prediction panels
# for samples drawn from the test loader.
visualize_predictions(model, test_loader, device)

In [None]:
# Report the test-set metrics returned by evaluate_model in the previous cell.
print(f'Dice Value = {dice} ===== IoU Value = {iou}')