<a href="https://colab.research.google.com/github/fifaak/BrainStroke_Segmentation/blob/main/SP_proj_all_in_one_UNET%2B%2BandRESNET_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount("/content/drive")
!cp /content/drive/MyDrive/spdataset.zip /content/dataset.zip
!unzip /content/dataset.zip

Mounted at /content/drive
Archive:  /content/dataset.zip
   creating: spdataset/
   creating: spdataset/original/
  inflating: __MACOSX/spdataset/._original  
   creating: spdataset/mask/
  inflating: __MACOSX/spdataset/._mask  
  inflating: spdataset/original/63.jpg  
  inflating: __MACOSX/spdataset/original/._63.jpg  
  inflating: spdataset/original/189.jpg  
  inflating: __MACOSX/spdataset/original/._189.jpg  
  inflating: spdataset/original/77.jpg  
  inflating: __MACOSX/spdataset/original/._77.jpg  
  inflating: spdataset/original/162.jpg  
  inflating: __MACOSX/spdataset/original/._162.jpg  
  inflating: spdataset/original/176.jpg  
  inflating: __MACOSX/spdataset/original/._176.jpg  
  inflating: spdataset/original/88.jpg  
  inflating: __MACOSX/spdataset/original/._88.jpg  
  inflating: spdataset/original/348.jpg  
  inflating: __MACOSX/spdataset/original/._348.jpg  
  inflating: spdataset/original/228.jpg  
  inflating: __MACOSX/spdataset/original/._228.jpg  
  inflating: spda

In [None]:
!pip install torch torchvision segmentation-models-pytorch

Collecting segmentation-models-pytorch
  Downloading segmentation_models_pytorch-0.3.3-py3-none-any.whl (106 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.7/106.7 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)


In [None]:
import os
from torch.utils.data import random_split, DataLoader, Dataset
import torch.nn.functional as F
import segmentation_models_pytorch as smp
import torch.optim as optim
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import jaccard_score
import torchvision

# Define paths
image_dir = '/content/spdataset/original'
mask_dir = '/content/spdataset/mask'

# Dataset class
class BrainStrokeDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        image = Image.open(os.path.join(self.image_dir, img_name)).convert("L")
        mask = Image.open(os.path.join(self.mask_dir, img_name.split('.')[0] + '_HGE_Seg.jpg')).convert("L")

        # Resize images and masks to 256x256
        image = image.resize((256, 256))
        mask = mask.resize((256, 256))

        # Convert images and masks to tensors
        image = transforms.ToTensor()(image)
        mask = transforms.ToTensor()(mask)

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        # Convert mask to binary
        mask = mask > 0.5
        mask = mask.float()

        return image, mask

# Transformations
transform = transforms.Compose([
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Dataset
dataset = BrainStrokeDataset(image_dir=image_dir, mask_dir=mask_dir, transform=transform)

# Dataset sizes
dataset_size = len(dataset)
train_size = int(0.7 * dataset_size)
valid_size = int(0.1 * dataset_size)
test_size = dataset_size - train_size - valid_size

# Splits
train_dataset, valid_dataset, test_dataset = random_split(dataset, [train_size, valid_size, test_size])

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# Model with DenseNet121 backbone and attention
model = smp.UnetPlusPlus(
    encoder_name="resnet50",  # You can replace this with "se_resnext50_32x4d", "resnet50", "densenet201", etc.
    encoder_weights="imagenet",
    in_channels=1,
    classes=1,
    activation=None,
    decoder_attention_type="scse"  # Using attention mechanism
)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Wrap model with DataParallel
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)

model = model.to(device)

# Loss and optimizer with learning rate scheduler
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.0001)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

images, _ = next(iter(train_loader))  # Get a batch of images to visualize the model graph
images = images.to(device)

def train_one_epoch(model, criterion, optimizer, data_loader, device):
    model.train()
    total_loss = 0
    total_iou = 0
    total_dice = 0

    for images, masks in data_loader:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)

        if torch.isnan(loss) or torch.isinf(loss):
            print(f"NaN or Inf detected in loss")
            continue

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        preds = torch.sigmoid(outputs)
        preds = (preds > 0.5).float()

        iou = jaccard_score(masks.flatten().cpu(), preds.flatten().cpu())
        dice = (2 * (preds * masks).sum()) / ((preds + masks).sum() + 1e-10)

        total_iou += iou
        total_dice += dice.item()

    avg_loss = total_loss / len(data_loader)
    avg_iou = total_iou / len(data_loader)
    avg_dice = total_dice / len(data_loader)

    return avg_loss, avg_iou, avg_dice

def validate_one_epoch(model, criterion, data_loader, device):
    model.eval()
    total_loss = 0
    total_iou = 0
    total_dice = 0

    with torch.no_grad():
        for images, masks in data_loader:
            images = images.to(device)
            masks = masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            total_loss += loss.item()

            preds = torch.sigmoid(outputs)
            preds = (preds > 0.5).float()

            iou = jaccard_score(masks.flatten().cpu(), preds.flatten().cpu())
            dice = (2 * (preds * masks).sum()) / ((preds + masks).sum() + 1e-10)

            total_iou += iou
            total_dice += dice.item()

    avg_loss = total_loss / len(data_loader)
    avg_iou = total_iou / len(data_loader)
    avg_dice = total_dice / len(data_loader)

    return avg_loss, avg_iou, avg_dice

num_epochs = 200
patience = 5
best_val_loss = float('inf')
epochs_no_improve = 0

for epoch in range(num_epochs):
    train_loss, train_iou, train_dice = train_one_epoch(model, criterion, optimizer, train_loader, device)
    valid_loss, valid_iou, valid_dice = validate_one_epoch(model, criterion, valid_loader, device)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {valid_loss:.4f}, Train IoU: {train_iou:.4f}, Valid IoU: {valid_iou:.4f}, Train Dice: {train_dice:.4f}, Valid Dice: {valid_dice:.4f}")

#     scheduler.step(valid_loss)

    if valid_loss < best_val_loss:
        best_val_loss = valid_loss
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1

    if epochs_no_improve == patience:
        print(f"Early stopping at epoch {epoch+1}")
        break

Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth" to /root/.cache/torch/hub/checkpoints/efficientnet-b4-6ed6700e.pth
100%|██████████| 74.4M/74.4M [00:00<00:00, 275MB/s]
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch 1/200, Train Loss: 0.4594, Valid Loss: 0.5263, Train IoU: 0.0252, Valid IoU: 0.0051, Train Dice: 0.0488, Valid Dice: 0.0100
Epoch 2/200, Train Loss: 0.3085, Valid Loss: 0.3194, Train IoU: 0.0676, Valid IoU: 0.0326, Train Dice: 0.1225, Valid Dice: 0.0628
Epoch 3/200, Train Loss: 0.2273, Valid Loss: 0.2231, Train IoU: 0.1283, Valid IoU: 0.0001, Train Dice: 0.2155, Valid Dice: 0.0003
Epoch 4/200, Train Loss: 0.1799, Valid Loss: 0.1786, Train IoU: 0.1524, Valid IoU: 0.0000, Train Dice: 0.2395, Valid Dice: 0.0000
Epoch 5/200, Train Loss: 0.1532, Valid Loss: 0.1534, Train IoU: 0.1615, Valid IoU: 0.0000, Train Dice: 0.2564, Valid Dice: 0.0000
Epoch 6/200, Train Loss: 0.1334, Valid Loss: 0.1389, Train IoU: 0.1334, Valid IoU: 0.0000, Train Dice: 0.2195, Valid Dice: 0.0000
Epoch 7/200, Train Loss: 0.1188, Valid Loss: 0.1246, Train IoU: 0.1266, Valid IoU: 0.0003, Train Dice: 0.2117, Valid Dice: 0.0006
Epoch 8/200, Train Loss: 0.1057, Valid Loss: 0.1134, Train IoU: 0.1715, Valid IoU: 0.0021,

In [None]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import precision_recall_curve, average_precision_score

criterion = nn.BCEWithLogitsLoss()

def calculate_iou(pred, target, threshold=0.5):
    pred = pred > threshold
    target = target > threshold

    intersection = (pred & target).sum((1, 2))
    union = (pred | target).sum((1, 2))

    if torch.any(union == 0):
        print("Union is zero for some images")

    iou = torch.where(union == 0, torch.tensor(float('nan')), intersection / union)
    return iou[~torch.isnan(iou)].mean()

def calculate_dice(pred, target, threshold=0.5):
    pred = pred > threshold
    target = target > threshold

    intersection = (pred & target).sum((1, 2))
    sum_pred_target = pred.sum((1, 2)) + target.sum((1, 2))

    dice = 2 * intersection / sum_pred_target
    return dice[~torch.isnan(dice)].mean()

def calculate_map(probs, targets, thresholds=np.linspace(0, 1, 101)):
    aps = []
    for threshold in thresholds:
        binarized_probs = probs > threshold
        ap = average_precision_score(targets.flatten(), binarized_probs.flatten())
        aps.append(ap)
    return np.mean(aps)

# Model evaluation with IoU, Dice, and mAP metrics
model.eval()
test_loss = 0
ious = []
dices = []
maps = []

with torch.no_grad():
    for images, masks in test_loader:
        images = images.to(device)
        masks = masks.to(device)
        outputs = model(images)

        loss = criterion(outputs, masks)
        test_loss += loss.item()

        probs = torch.sigmoid(outputs)

        # Calculate IoU
        iou = calculate_iou(probs.cpu(), masks.cpu())
        if not torch.isnan(iou):
            ious.append(iou.item())

        # Calculate Dice
        dice = calculate_dice(probs.cpu(), masks.cpu())
        if not torch.isnan(dice):
            dices.append(dice.item())

        # Calculate mAP
        map_score = calculate_map(probs.cpu().numpy(), masks.cpu().numpy())
        maps.append(map_score)

test_loss /= len(test_loader)
mean_iou = np.nanmean(ious)
mean_dice = np.nanmean(dices)
mean_map = np.nanmean(maps)

print(f"Test Loss: {test_loss:.4f}")
print(f"Mean IoU: {mean_iou:.4f}")
print(f"Mean Dice: {mean_dice:.4f}")
print(f"Mean mAP: {mean_map:.4f}")

Union is zero for some images
Union is zero for some images
Union is zero for some images
Union is zero for some images
Union is zero for some images
Union is zero for some images
Union is zero for some images
Union is zero for some images
Test Loss: 0.0113
Mean IoU: 0.4469
Mean Dice: 0.5132
Mean mAP: 0.6089


In [None]:
# Save the entire model
torch.save(model, '/content/drive/MyDrive/SP_project/unetplusplus_model_resnet50andimagenet_best.pth')
torch.save(model.state_dict(), '/content/drive/MyDrive/SP_project/unetplusplus_model_statedict_resnet50andimagenet_best.pth')
