# Import Libraries

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
from glob import glob
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Report the runtime environment once, then choose where tensors will live.
_has_cuda = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA Available: {_has_cuda}")
if _has_cuda:
    print(f"CUDA Device: {torch.cuda.get_device_name(0)}")

# Set device: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device('cuda' if _has_cuda else 'cpu')
print(f"Using device: {device}")

PyTorch version: 1.13.1+cu116
CUDA Available: True
CUDA Device: NVIDIA GeForce GTX 1660 SUPER
Using device: cuda


__Configuration__

In [3]:
class Config:
    """Hyper-parameters and paths shared by every stage of the pipeline."""
    BATCH_SIZE = 8  # Adjust for GPU memory
    IMG_HEIGHT = 256
    IMG_WIDTH = 256
    EPOCHS = 50
    LEARNING_RATE = 1e-4
    DATASET_PATH = './datasets'  # <- ensure correct root
    DEVICE = device
    # DataLoader worker processes. On Windows, workers are started with the
    # "spawn" method and must re-import the Dataset class and its helpers;
    # classes/functions defined inside a notebook cannot be re-imported that
    # way, so multi-worker loading typically stalls or fails there. Load
    # in-process on Windows. Also set to 0 on Kaggle if hitting RAM issues.
    NUM_WORKERS = 0 if os.name == 'nt' else 2

config = Config()

__Extract red mask from OVERLAY path__

In [4]:
def extract_red_mask_from_path(mask_path, width, height):
    """Extract the red-filled lesion from a colour overlay image as a binary mask.

    Parameters
    ----------
    mask_path : str, os.PathLike or None
        Path to the OVERLAY image. ``None`` means "no overlay available".
    width, height : int
        Size of the returned mask.

    Returns
    -------
    numpy.ndarray or None
        ``float32`` array of shape ``(height, width)`` with values in {0, 1},
        or ``None`` if the path is missing, not a text path, does not exist,
        or the file cannot be decoded.
    """
    if mask_path is None:
        return None
    try:
        path = os.fspath(mask_path)  # accepts str and pathlib.Path alike
    except TypeError:
        # Not path-like at all (e.g. a NaN placeholder): treat as "no overlay".
        return None
    if not isinstance(path, str) or not os.path.exists(path):
        return None

    overlay = cv2.imread(path)  # BGR; None if the file cannot be decoded
    if overlay is None:
        return None
    hsv = cv2.cvtColor(overlay, cv2.COLOR_BGR2HSV)
    # Red hue wraps around 0 on OpenCV's 0-180 hue scale, so two bands are
    # needed. Requiring saturation/value >= 50 excludes grey (unsaturated) pixels.
    lower1, upper1 = np.array([0, 50, 50]), np.array([10, 255, 255])
    lower2, upper2 = np.array([170, 50, 50]), np.array([180, 255, 255])
    mask = cv2.inRange(hsv, lower1, upper1) | cv2.inRange(hsv, lower2, upper2)
    # Nearest-neighbour resizing introduces no intermediate values, so the
    # mask stays strictly binary.
    mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    return (mask > 0).astype(np.float32)

__Dataset__

In [5]:
class StrokeDataset(Dataset):
    """Paired CT-slice / lesion-mask dataset.

    Images are read as RGB and resized to ``config.IMG_WIDTH x config.IMG_HEIGHT``;
    an unreadable image becomes a black frame. Masks are derived from the red
    overlay by ``extract_red_mask_from_path``; a ``None`` or unreadable mask
    path yields an all-zero mask, which is how "Normal" slices are represented.
    """

    def __init__(self, image_paths, mask_paths, transforms=None):
        self.image_paths = image_paths
        self.mask_paths = mask_paths  # entries may be None for "Normal" images
        self.transforms = transforms

    def __len__(self):
        return len(self.image_paths)

    def _load_image(self, path):
        """Return an RGB uint8 image of the configured size (black if unreadable)."""
        bgr = cv2.imread(path)
        if bgr is None:
            return np.zeros((config.IMG_HEIGHT, config.IMG_WIDTH, 3), dtype=np.uint8)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return cv2.resize(rgb, (config.IMG_WIDTH, config.IMG_HEIGHT), interpolation=cv2.INTER_LINEAR)

    def _load_mask(self, path):
        """Return a float32 {0,1} mask, all zeros when no overlay is available."""
        found = extract_red_mask_from_path(path, config.IMG_WIDTH, config.IMG_HEIGHT)
        if found is not None:
            return found
        return np.zeros((config.IMG_HEIGHT, config.IMG_WIDTH), dtype=np.float32)

    def __getitem__(self, idx):
        image = self._load_image(self.image_paths[idx])
        mask = self._load_mask(self.mask_paths[idx])
        if not self.transforms:
            return image, mask
        out = self.transforms(image=image, mask=mask)
        return out['image'], out['mask']

__Transforms__

In [6]:
def get_transforms(is_training=True):
    """Build the albumentations pipeline for training or evaluation.

    Both pipelines end with ImageNet mean/std normalisation (the encoder is
    ImageNet-pretrained) and conversion to a torch tensor. The training
    pipeline first applies a mild horizontal flip, brightness/contrast jitter
    and a small shift/scale/rotation.
    """
    augmentations = []
    if is_training:
        augmentations = [
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.05, contrast_limit=0.05, p=0.3),
            A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=5, p=0.3),
        ]
    finishing = [
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(augmentations + finishing)

__Data discovery (includes Normal with empty masks)__

In [7]:
def load_and_preprocess_data():
    """Discover PNG slices and pair each with its same-named OVERLAY image.

    Scans ``<DATASET_PATH>/<class>/PNG/*.png`` for the classes Bleeding,
    Ischemia and Normal, in that order, with files sorted within each class.
    A slice with no matching OVERLAY file (e.g. in Normal) is still included,
    with ``None`` as its mask path, which downstream code treats as an
    all-zero mask.

    Returns
    -------
    (images_paths, masks_paths) : two parallel lists; mask entries may be None.
    """
    images_paths, masks_paths = [], []

    # Expected subfolders: Bleeding/, Ischemia/, Normal/ (others, e.g.
    # External_Test/, are not scanned here).
    for class_dir in ("Bleeding", "Ischemia", "Normal"):
        class_root = os.path.join(config.DATASET_PATH, class_dir)
        png_dir = os.path.join(class_root, "PNG")
        if not os.path.exists(png_dir):
            continue
        overlay_dir = os.path.join(class_root, "OVERLAY")

        for png_file in sorted(glob(os.path.join(png_dir, "*.png"))):
            overlay_file = os.path.join(overlay_dir, os.path.basename(png_file))
            paired = os.path.exists(overlay_dir) and os.path.exists(overlay_file)
            images_paths.append(png_file)
            # None -> negative slice with an all-zero mask
            masks_paths.append(overlay_file if paired else None)

    n_overlays = sum(m is not None for m in masks_paths)
    print(f"Found {len(images_paths)} images; of which {n_overlays} have overlays.")
    if images_paths:
        print(f"Sample image: {images_paths[0]}")
        print(f"Sample mask : {masks_paths[0]}")
    return images_paths, masks_paths

__Model__

In [8]:
def build_model():
    """Create a U-Net with an ImageNet-pretrained EfficientNet-B4 encoder.

    The single output channel has no activation, so the model emits raw
    logits; training pairs this with BCEWithLogitsLoss, which is more
    numerically stable than a sigmoid followed by plain BCE.
    """
    unet_kwargs = dict(
        encoder_name='efficientnet-b4',
        encoder_weights='imagenet',
        in_channels=3,
        classes=1,
        activation=None,
    )
    return smp.Unet(**unet_kwargs)

__Losses & Metrics__

In [9]:
class DiceLoss(nn.Module):
    """Soft Dice loss pooled over every element of the batch.

    ``forward`` expects probabilities (not logits) and a target with the same
    number of elements; returns ``1 - dice`` as a scalar tensor.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth  # keeps the ratio defined when both inputs are empty

    def forward(self, y_pred_prob, y_true):
        # reshape (unlike view) also works on non-contiguous inputs such as
        # permuted or sliced tensors.
        y_pred = y_pred_prob.reshape(-1)
        y_true = y_true.reshape(-1)
        intersection = (y_pred * y_true).sum()
        dice = (2. * intersection + self.smooth) / (y_pred.sum() + y_true.sum() + self.smooth)
        return 1 - dice

class CombinedLoss(nn.Module):
    """Equal-weight mix of BCE-with-logits and soft Dice loss.

    ``forward`` takes raw logits; the sigmoid is applied internally only for
    the Dice term, while the BCE term works on the logits directly.
    """

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, y_pred_logits, y_true):
        probabilities = torch.sigmoid(y_pred_logits)
        bce_term = self.bce(y_pred_logits, y_true)
        dice_term = self.dice(probabilities, y_true)
        return 0.5 * bce_term + 0.5 * dice_term

def calculate_dice_score(y_pred_prob, y_true, smooth=1e-6, threshold=0.5):
    """Dice coefficient between thresholded predictions and a binary target.

    All elements (the whole batch) are pooled into one score. ``smooth`` makes
    an empty prediction on an empty target score 1.0.

    Parameters
    ----------
    y_pred_prob : torch.Tensor of probabilities.
    y_true : torch.Tensor of {0, 1} targets, same number of elements.
    smooth : float added to numerator and denominator.
    threshold : float; probabilities strictly above it count as positive.

    Returns
    -------
    float
    """
    y_pred = (y_pred_prob > threshold).float()
    intersection = (y_pred * y_true).sum()
    dice = (2. * intersection + smooth) / (y_pred.sum() + y_true.sum() + smooth)
    return dice.item()

def calculate_iou_score(y_pred_prob, y_true, smooth=1e-6, threshold=0.5):
    """Intersection-over-union of thresholded predictions and a binary target.

    All elements (the whole batch) are pooled into one score. ``smooth`` makes
    an empty prediction on an empty target score 1.0.

    Parameters
    ----------
    y_pred_prob : torch.Tensor of probabilities.
    y_true : torch.Tensor of {0, 1} targets, same number of elements.
    smooth : float added to numerator and denominator.
    threshold : float; probabilities strictly above it count as positive.

    Returns
    -------
    float
    """
    y_pred = (y_pred_prob > threshold).float()
    intersection = (y_pred * y_true).sum()
    union = y_pred.sum() + y_true.sum() - intersection
    iou = (intersection + smooth) / (union + smooth)
    return iou.item()

def calculate_sensitivity(y_pred_prob, y_true, threshold=0.5, zero_division=0.0):
    """Sensitivity (recall, TP / (TP + FN)) of thresholded predictions.

    All elements are pooled into one score.

    Parameters
    ----------
    y_pred_prob : torch.Tensor of probabilities.
    y_true : torch.Tensor of {0, 1} targets, same number of elements.
    threshold : float; probabilities strictly above it count as positive.
    zero_division : value returned when ``y_true`` has no positive elements,
        where sensitivity is undefined. The default 0.0 keeps the original
        behaviour; note that averaging such 0.0 values across batches
        understates recall.

    Returns
    -------
    float
    """
    y_pred = (y_pred_prob > threshold).float()
    true_positives = (y_pred * y_true).sum()
    actual_positives = y_true.sum()
    if actual_positives == 0:
        return zero_division
    return (true_positives / actual_positives).item()

def calculate_specificity(y_pred_prob, y_true, threshold=0.5, zero_division=0.0):
    """Specificity (TN / (TN + FP)) of thresholded predictions.

    All elements are pooled into one score.

    Parameters
    ----------
    y_pred_prob : torch.Tensor of probabilities.
    y_true : torch.Tensor of {0, 1} targets, same number of elements.
    threshold : float; probabilities strictly above it count as positive.
    zero_division : value returned when ``y_true`` has no negative elements,
        where specificity is undefined (default 0.0, the original behaviour).

    Returns
    -------
    float
    """
    y_pred = (y_pred_prob > threshold).float()
    true_negatives = ((1 - y_pred) * (1 - y_true)).sum()
    actual_negatives = (1 - y_true).sum()
    if actual_negatives == 0:
        return zero_division
    return (true_negatives / actual_negatives).item()

__Train / Validate loops__

In [10]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training pass over ``dataloader`` and return mean metrics.

    Parameters
    ----------
    model : torch.nn.Module producing single-channel logits.
    dataloader : yields ``(images, masks)`` batches; a channel dimension is
        added to the masks here.
    criterion : loss called as ``criterion(logits, masks)``.
    optimizer : optimizer over ``model``'s parameters.
    device : device the batches are moved to.

    Returns
    -------
    (loss, dice, iou) : floats, each the mean of the per-batch values
    (every batch weighted equally, including a smaller final batch).

    Raises
    ------
    ValueError
        If ``dataloader`` has no batches.
    """
    n_batches = len(dataloader)
    if n_batches == 0:
        raise ValueError("train_epoch received an empty dataloader")

    model.train()
    running_loss = running_dice = running_iou = 0.0
    progress_bar = tqdm(dataloader, desc='Training')
    for images, masks in progress_bar:
        # non_blocking lets the copy overlap with compute when host memory is
        # pinned; it is a plain copy otherwise.
        images = images.to(device, non_blocking=True)
        # Add the channel dimension so masks line up with single-channel logits.
        masks = masks.unsqueeze(1).to(device, non_blocking=True)

        logits = model(images)
        loss = criterion(logits, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Metrics are for monitoring only: detach so no autograd graph is
        # built (and kept alive) for them.
        probs = torch.sigmoid(logits.detach())
        dice_score = calculate_dice_score(probs, masks)
        iou_score = calculate_iou_score(probs, masks)

        batch_loss = loss.item()
        running_loss += batch_loss
        running_dice += dice_score
        running_iou += iou_score

        progress_bar.set_postfix({'Loss': f'{batch_loss:.4f}', 'Dice': f'{dice_score:.4f}', 'IoU': f'{iou_score:.4f}'})

    return running_loss / n_batches, running_dice / n_batches, running_iou / n_batches


def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    Returns
    -------
    tuple of float
        ``(loss, dice, iou, sensitivity, specificity, accuracy)``.

        Loss, Dice and IoU are means of per-batch values, the same convention
        as ``train_epoch``. Validation Dice drives checkpointing and LR
        scheduling, so its definition is unchanged.

        Sensitivity, specificity and pixel accuracy come from true/false
        positive/negative counts accumulated over the whole epoch. A per-batch
        average would score every lesion-free batch (e.g. all-"Normal" slices)
        as sensitivity 0.0, dragging the mean down. It would also over-weight
        a smaller final batch.

    Raises
    ------
    ValueError
        If ``dataloader`` has no batches.
    """
    n_batches = len(dataloader)
    if n_batches == 0:
        raise ValueError("validate_epoch received an empty dataloader")

    model.eval()
    running_loss = running_dice = running_iou = 0.0
    tp = fp = fn = tn = 0  # Python ints: no overflow across large epochs
    progress_bar = tqdm(dataloader, desc='Validation')
    with torch.no_grad():
        for images, masks in progress_bar:
            images = images.to(device, non_blocking=True)
            # Add the channel dimension so masks line up with single-channel logits.
            masks = masks.unsqueeze(1).to(device, non_blocking=True)

            logits = model(images)
            loss = criterion(logits, masks)

            probs = torch.sigmoid(logits)
            dice_score = calculate_dice_score(probs, masks)
            iou_score = calculate_iou_score(probs, masks)

            pred = probs > 0.5
            target = masks > 0.5
            tp += (pred & target).sum().item()
            fp += (pred & ~target).sum().item()
            fn += (~pred & target).sum().item()
            tn += (~pred & ~target).sum().item()

            running_loss += loss.item()
            running_dice += dice_score
            running_iou += iou_score

            progress_bar.set_postfix({'Loss': f'{loss.item():.4f}', 'Dice': f'{dice_score:.4f}', 'IoU': f'{iou_score:.4f}'})

    positives = tp + fn
    negatives = tn + fp
    total = positives + negatives
    sensitivity = tp / positives if positives else 0.0
    specificity = tn / negatives if negatives else 0.0
    accuracy = (tp + tn) / total if total else 0.0

    return (
        running_loss / n_batches,
        running_dice / n_batches,
        running_iou / n_batches,
        sensitivity,
        specificity,
        accuracy,
    )

__Training wrapper__

In [11]:
def train_model():
    """Train the U-Net on a random 85/15 train/validation split.

    The weights with the highest validation Dice are saved to
    ``best_stroke_model_with_normals.pth``. Those weights are loaded back into
    the model before returning.

    Returns
    -------
    (model, history)
        ``history`` maps metric names to per-epoch lists.
        ``(None, None)`` if no images were found.
    """
    checkpoint_path = 'best_stroke_model_with_normals.pth'

    print("Loading dataset...")
    image_paths, mask_paths = load_and_preprocess_data()
    if len(image_paths) == 0:
        print("No images found! Please check the dataset path.")
        return None, None

    # NOTE(review): the split is per slice. If several slices come from the
    # same patient (not visible from the file names here), they can land on
    # both sides of the split and make validation scores optimistic.
    train_images, val_images, train_masks, val_masks = train_test_split(
        image_paths, mask_paths, test_size=0.15, random_state=42
    )

    print(f"Training samples: {len(train_images)}")
    print(f"Validation samples: {len(val_images)}")

    train_dataset = StrokeDataset(train_images, train_masks, get_transforms(True))
    val_dataset = StrokeDataset(val_images, val_masks, get_transforms(False))

    # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
    pin_memory = device.type == 'cuda'
    train_dataloader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True,
                                  num_workers=config.NUM_WORKERS, pin_memory=pin_memory)
    val_dataloader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False,
                                num_workers=config.NUM_WORKERS, pin_memory=pin_memory)

    print("Building model...")
    model = build_model().to(device)
    print("\nModel built successfully!")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    criterion = CombinedLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
    # mode='max': the monitored quantity is validation Dice (higher is better).
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=8,
                                                     min_lr=1e-7, verbose=True)

    history = {
        'train_loss': [], 'val_loss': [],
        'train_dice': [], 'val_dice': [],
        'train_iou':  [], 'val_iou':  [],
        'val_sensitivity': [], 'val_specificity': [], 'val_accuracy': []
    }

    # Start below any attainable score so the first epoch always writes a
    # checkpoint. The reload at the end then never picks up a stale file
    # left over from an earlier run.
    best_dice = float('-inf')
    checkpoint_saved = False
    patience_counter = 0
    patience = 15  # epochs without a Dice improvement before early stopping

    print("Starting training...")
    print("="*50)

    for epoch in range(config.EPOCHS):
        print(f"\nEpoch {epoch+1}/{config.EPOCHS}")
        print("-" * 30)

        train_loss, train_dice, train_iou = train_epoch(model, train_dataloader, criterion, optimizer, device)
        val_loss, val_dice, val_iou, val_sensitivity, val_specificity, val_accuracy = validate_epoch(
            model, val_dataloader, criterion, device
        )

        scheduler.step(val_dice)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['train_dice'].append(train_dice)
        history['val_dice'].append(val_dice)
        history['train_iou'].append(train_iou)
        history['val_iou'].append(val_iou)
        history['val_sensitivity'].append(val_sensitivity)
        history['val_specificity'].append(val_specificity)
        history['val_accuracy'].append(val_accuracy)

        print(f"Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}, Train IoU: {train_iou:.4f}")
        print(f"Val   Loss: {val_loss:.4f}, Val   Dice: {val_dice:.4f}, Val   IoU: {val_iou:.4f}")
        print(f"Val Sensitivity: {val_sensitivity:.4f}, Val Specificity: {val_specificity:.4f}")
        print(f"Val Accuracy: {val_accuracy:.4f}")

        if val_dice > best_dice:
            best_dice = val_dice
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(f"New best model saved! Dice: {best_dice:.4f}")
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    else:
        # Only reachable with EPOCHS == 0 or a Dice that is never a comparable number (NaN).
        print("No checkpoint was saved in this run; returning the final model weights.")
    return model, history

__Plots__

In [12]:
def plot_training_history(history):
    """Plot the per-epoch curves stored in ``history`` on a 2x3 grid.

    Top row: loss, Dice and IoU (training vs. validation).
    Bottom row: validation sensitivity, specificity and binary accuracy.
    """
    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    panels = [
        # (row, col, title, y-label, [(history key, legend label), ...])
        (0, 0, 'Model Loss', 'Loss',
         [('train_loss', 'Training Loss'), ('val_loss', 'Validation Loss')]),
        (0, 1, 'Dice Coefficient', 'Dice',
         [('train_dice', 'Training Dice'), ('val_dice', 'Validation Dice')]),
        (0, 2, 'IoU Score', 'IoU',
         [('train_iou', 'Training IoU'), ('val_iou', 'Validation IoU')]),
        (1, 0, 'Sensitivity (Recall)', 'Sensitivity',
         [('val_sensitivity', 'Validation Sensitivity')]),
        (1, 1, 'Specificity', 'Specificity',
         [('val_specificity', 'Validation Specificity')]),
        (1, 2, 'Binary Accuracy', 'Accuracy',
         [('val_accuracy', 'Validation Accuracy')]),
    ]
    for row, col, title, ylabel, series in panels:
        ax = axes[row, col]
        for key, label in series:
            ax.plot(history[key], label=label, linewidth=2)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

__Visualization & Evaluation (same mask logic; skip metrics if no GT)__

In [13]:
def predict_and_visualize(model, image_paths, mask_paths, num_samples=5):
    """Show image / ground truth / probability map / thresholded prediction.

    One row is drawn for each of the first ``num_samples`` entries of
    ``image_paths``. Slices without an overlay get an empty "No GT" panel and
    no scores. Per-row Dice/IoU use the same smoothed definitions as the
    training metrics, so an empty prediction on an empty mask scores 1.0.
    """
    model.eval()
    rows = min(num_samples, len(image_paths))
    if rows <= 0:
        print("No images to visualize.")
        return
    transforms = get_transforms(False)
    smooth = 1e-6
    # squeeze=False keeps ``axes`` two-dimensional even when rows == 1, so
    # ``axes[i, j]`` indexing works for any number of rows.
    fig, axes = plt.subplots(rows, 4, figsize=(20, 5 * rows), squeeze=False)

    with torch.no_grad():
        for i in range(rows):
            image = cv2.imread(image_paths[i])
            if image is None:
                print(f"Could not load image: {image_paths[i]}")
                for ax in axes[i]:
                    ax.axis('off')  # leave the row visibly blank
                continue
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image_resized = cv2.resize(image_rgb, (config.IMG_WIDTH, config.IMG_HEIGHT))

            mask_binary = extract_red_mask_from_path(mask_paths[i], config.IMG_WIDTH, config.IMG_HEIGHT)
            has_gt = mask_binary is not None
            if not has_gt:
                mask_binary = np.zeros((config.IMG_HEIGHT, config.IMG_WIDTH), dtype=np.float32)

            transformed = transforms(image=image_resized, mask=mask_binary)
            image_tensor = transformed['image'].unsqueeze(0).to(device)

            pred_logits = model(image_tensor)[0, 0].cpu().numpy()
            pred_prob = 1 / (1 + np.exp(-pred_logits))
            pred_binary = (pred_prob > 0.5).astype(np.float32)
            pred_mask_binary = pred_binary.astype(np.uint8) * 255

            if has_gt:
                intersection = np.sum(mask_binary * pred_binary)
                total = np.sum(mask_binary) + np.sum(pred_binary)
                dice = (2 * intersection + smooth) / (total + smooth)
                iou = (intersection + smooth) / (total - intersection + smooth)
                title3 = f'Prediction Binary\nDice: {dice:.3f}, IoU: {iou:.3f}'
            else:
                title3 = 'Prediction Binary (no GT)'

            axes[i, 0].imshow(image_resized); axes[i, 0].set_title(f'Original Image {i+1}'); axes[i, 0].axis('off')
            axes[i, 1].imshow(mask_binary, cmap='gray'); axes[i, 1].set_title('Ground Truth Mask' if has_gt else 'No GT (Normal)'); axes[i, 1].axis('off')
            axes[i, 2].imshow(pred_prob, cmap='gray'); axes[i, 2].set_title('Prediction (Prob)'); axes[i, 2].axis('off')
            axes[i, 3].imshow(pred_mask_binary, cmap='gray'); axes[i, 3].set_title(title3); axes[i, 3].axis('off')

    plt.tight_layout()
    plt.show()


def evaluate_model(model, image_paths, mask_paths):
    """Compute per-slice Dice and IoU on slices that have a ground-truth overlay.

    A slice is skipped if:
    - it has no overlay (e.g. Normal);
    - its image cannot be read;
    - both its mask and its prediction are empty (Dice/IoU undefined).

    Returns
    -------
    (dice_scores, iou_scores) : lists with one value per scored slice.
    """
    print("Evaluating model on test set...")
    model.eval()
    transforms = get_transforms(False)  # built once, reused for every slice
    dice_scores, iou_scores = [], []
    total = len(image_paths)
    skipped_no_gt = 0

    with torch.no_grad():
        for i, (img_path, mask_path) in enumerate(zip(image_paths, mask_paths)):
            if i % 100 == 0:
                print(f"Evaluating {i}/{total}")

            # Check the mask first: slices without ground truth are excluded
            # from the lesion metrics, so there is no need to read their image.
            mask_binary = extract_red_mask_from_path(mask_path, config.IMG_WIDTH, config.IMG_HEIGHT)
            if mask_binary is None:
                skipped_no_gt += 1
                continue

            image = cv2.imread(img_path)
            if image is None:
                continue
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image_resized = cv2.resize(image_rgb, (config.IMG_WIDTH, config.IMG_HEIGHT))

            x = transforms(image=image_resized, mask=mask_binary)['image'].unsqueeze(0).to(device)
            pred_logits = model(x)[0, 0].cpu().numpy()
            pred_prob = 1 / (1 + np.exp(-pred_logits))
            pred_binary = (pred_prob > 0.5).astype(np.float32)

            intersection = np.sum(mask_binary * pred_binary)
            union = np.sum(mask_binary) + np.sum(pred_binary) - intersection
            if union > 0:
                iou = intersection / union
                dice = (2 * intersection) / (np.sum(mask_binary) + np.sum(pred_binary))
                dice_scores.append(dice)
                iou_scores.append(iou)

    print("\nEvaluation Results (lesion slices only):")
    print(f"Slices scored: {len(dice_scores)} (skipped without GT: {skipped_no_gt})")
    if len(dice_scores) == 0:
        print("No GT masks found for evaluation.")
    else:
        print(f"Mean Dice Score: {np.mean(dice_scores):.4f} ± {np.std(dice_scores):.4f}")
        print(f"Mean IoU  Score: {np.mean(iou_scores):.4f} ± {np.std(iou_scores):.4f}")
        print(f"Median Dice Score: {np.median(dice_scores):.4f}")
        print(f"Median IoU  Score: {np.median(iou_scores):.4f}")
    return dice_scores, iou_scores

__Inference for a single image (no GT required)__

In [14]:
def predict_single_image(model, image_path, threshold=0.5, device=config.DEVICE):
    """Segment one image file; no ground truth is needed.

    Returns
    -------
    rgb : the input, converted to RGB and resized to the configured size.
    prob : per-pixel lesion probability (sigmoid of the model logits).
    pred_bin : uint8 mask, 255 where ``prob > threshold`` else 0.
    overlay : copy of ``rgb`` with predicted pixels painted (255, 64, 64).

    Raises
    ------
    FileNotFoundError
        If the image cannot be read.
    """
    model.eval()
    preprocess = get_transforms(False)

    bgr = cv2.imread(image_path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    target_size = (config.IMG_WIDTH, config.IMG_HEIGHT)
    rgb = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), target_size, interpolation=cv2.INTER_LINEAR)

    # A placeholder all-zero mask accompanies the image; only the image output is used.
    placeholder_mask = np.zeros((config.IMG_HEIGHT, config.IMG_WIDTH), np.float32)
    batch = preprocess(image=rgb, mask=placeholder_mask)['image'].unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(batch)[0, 0].cpu().numpy()
    prob = 1 / (1 + np.exp(-logits))
    pred_bin = np.where(prob > threshold, 255, 0).astype(np.uint8)

    overlay = rgb.copy()
    overlay[pred_bin > 0] = (255, 64, 64)
    return rgb, prob, pred_bin, overlay

__Main__

In [None]:
if __name__ == "__main__":
    print("Starting Brain Stroke Detection Training - PyTorch Version...")
    print("="*60)

    model, history = train_model()

    if model is not None and history is not None:
        plot_training_history(history)

        # Evaluate on the held-out validation split. It is rebuilt exactly as
        # in train_model (same sorted file list, test_size and random_state).
        # The tail of the full sorted list would be a poor choice for two
        # reasons:
        # - it holds only the last class scanned (Normal), whose slices
        #   typically have no overlay and are skipped by evaluate_model;
        # - it overlaps the training data.
        # NOTE: this split also drove checkpoint selection, so the scores are
        # somewhat optimistic; a separate test set would be better.
        image_paths, mask_paths = load_and_preprocess_data()
        _, val_images, _, val_masks = train_test_split(
            image_paths, mask_paths, test_size=0.15, random_state=42
        )

        print("\nVisualizing predictions...")
        predict_and_visualize(model, val_images[:10], val_masks[:10], num_samples=5)

        dice_scores, iou_scores = evaluate_model(model, val_images, val_masks)

        print("\nTraining completed successfully!")
        print("Model saved as 'best_stroke_model_with_normals.pth'")
    else:
        print("Training failed. Please check the dataset path and try again.")

Starting Brain Stroke Detection Training - PyTorch Version...
Loading dataset...
Found 6650 images; of which 2223 have overlays.
Sample image: ./datasets\Bleeding\PNG\10002.png
Sample mask : ./datasets\Bleeding\OVERLAY\10002.png
Training samples: 5652
Validation samples: 998
Building model...


Downloading: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth" to C:\Users\Ishigami/.cache\torch\hub\checkpoints\efficientnet-b4-6ed6700e.pth
100%|██████████| 74.4M/74.4M [00:14<00:00, 5.25MB/s]



Model built successfully!
Total parameters: 20,225,689
Trainable parameters: 20,225,689
Starting training...

Epoch 1/50
------------------------------


Training:   0%|          | 0/707 [00:00<?, ?it/s]