In [None]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
# Authenticate with Kaggle first; the dataset download in the next code cell
# goes through the same `kagglehub` module.
import kagglehub
kagglehub.login()


https://www.kaggle.com/code/huseyincavus/deepstroke-se-resnext50/notebook

In [None]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

# Fetch the 'huseyincavus/deepstroke' dataset and keep the value kagglehub returns.
# NOTE(review): none of the cells visible in this notebook read this variable;
# they use the hardcoded /kaggle/input/deepstroke/DeepStroke1_Data path instead.
# Confirm that is intended when running outside Kaggle.
huseyincavus_deepstroke_path = kagglehub.dataset_download('huseyincavus/deepstroke')

print('Data source import complete.')


In [None]:
# Cell 1: Display random images from the dataset
import os
import random
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Root of the dataset plus one folder per class underneath it.
base_dir = "/kaggle/input/deepstroke/DeepStroke1_Data"
ischaemic_dir, non_ischaemic_dir, hemoraj_dir = (
    os.path.join(base_dir, class_folder)
    for class_folder in ("Ischaemic", "Non-Ischaemic", "Hemoraj")
)

def display_random_images(directory, num_images=4):
    """Show up to ``num_images`` randomly chosen images from ``directory`` in one row.

    Only entries whose name ends in .png, .jpg or .jpeg (case-insensitive) are
    considered. When the directory does not exist, or holds no such files, a
    message is printed and the function returns without plotting.
    """
    if not os.path.exists(directory):
        print(f"Directory not found: {directory}")
        return

    image_extensions = ('.png', '.jpg', '.jpeg')
    candidates = []
    for entry in os.listdir(directory):
        if entry.lower().endswith(image_extensions):
            candidates.append(entry)

    if len(candidates) == 0:
        print(f"No image files found in {directory}")
        return

    sample_size = min(num_images, len(candidates))
    chosen = random.sample(candidates, sample_size)

    fig, axes = plt.subplots(1, len(chosen), figsize=(15, 5))
    # With a single column plt.subplots hands back one Axes, not a sequence.
    if len(chosen) == 1:
        axes = [axes]

    # One subplot per selected file, titled with the file name.
    for ax, name in zip(axes, chosen):
        picture = mpimg.imread(os.path.join(directory, name))
        ax.imshow(picture)
        ax.axis('off')
        ax.set_title(os.path.basename(name))

    plt.tight_layout()
    plt.show()

# Ischaemic and Non-Ischaemic folders share the same preview steps:
# (directory, heading line, class label used in the count line).
for folder, heading, count_label in (
    (ischaemic_dir, "Ischaemic Images:", "Ischaemic"),
    (non_ischaemic_dir, "\nNon-Ischaemic Images:", "Non-Ischemic"),
):
    if not os.path.exists(folder):
        print(f"Directory not found: {folder}")
        continue
    print(heading)
    print(f"Number of {count_label} samples: {len(os.listdir(folder))}")
    display_random_images(folder)

# Hemoraj may either hold images directly or be split into class sub-folders.
if not os.path.exists(hemoraj_dir):
    print(f"Directory not found: {hemoraj_dir}")
else:
    print("\nHemoraj Images:")

    subdirs = []
    for entry in os.listdir(hemoraj_dir):
        if os.path.isdir(os.path.join(hemoraj_dir, entry)):
            subdirs.append(entry)

    if not subdirs:
        # Flat layout: count only image files, then preview the folder itself.
        image_total = sum(
            1 for entry in os.listdir(hemoraj_dir)
            if entry.lower().endswith(('.png', '.jpg', '.jpeg'))
        )
        print(f"Number of samples: {image_total}")
        display_random_images(hemoraj_dir)
    else:
        # Nested layout: preview every sub-folder separately.
        for subdir in subdirs:
            class_dir = os.path.join(hemoraj_dir, subdir)
            print(f"-- {subdir} Images:")
            print(f"   Number of {subdir} samples: {len(os.listdir(class_dir))}")
            display_random_images(class_dir)

In [None]:
# Cell 2: Prepare the dataset for training (Optimized for Speed)
import os
import random
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

# --- OPTIMIZATION: Get the number of available CPU cores ---
# This will be used to parallelize data loading.
# os.cpu_count() does not raise when the count cannot be determined; it
# returns None, which is not a valid DataLoader num_workers value. Fall back
# to 4 workers in that case.
NUM_CPUS = os.cpu_count() or 4

# Input geometry and batch size used by the data pipeline.
IMG_WIDTH = IMG_HEIGHT = 224
BATCH_SIZE = 128  # Increased for better GPU utilization

# Per-channel ImageNet mean / std used to normalize inputs.
MEAN, STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
print(f"Using {NUM_CPUS} CPU workers for data loading.")

# DeepStroke dataset layout: "Non-Ischaemic" is the normal class (label 0);
# "Ischaemic" and "Hemoraj" are both treated as abnormal (label 1).
base_dir = "/kaggle/input/deepstroke/DeepStroke1_Data"
normal_dir = os.path.join(base_dir, "Non-Ischaemic")
abnormal_dirs = [os.path.join(base_dir, folder) for folder in ("Ischaemic", "Hemoraj")]

class ImageDataset(Dataset):
    """Binary image dataset assembled from class folders on disk.

    Images found directly inside ``normal_dir`` get label 0; images found
    directly inside each directory of ``abnormal_dirs`` get label 1. Only
    files ending in .png, .jpg or .jpeg (case-insensitive) are used, and
    directories that do not exist are skipped.

    File names are taken in sorted order so that the index -> file mapping is
    identical for every instance built from the same folders. The notebook
    relies on this when it reuses split indices across separately constructed
    datasets (``os.listdir`` alone makes no ordering guarantee).

    NOTE(review): sub-folders are not scanned. If a class folder keeps its
    images in nested directories they are not picked up -- confirm the layout.

    Args:
        normal_dir: folder holding label-0 images.
        abnormal_dirs: iterable of folders holding label-1 images.
        transform: optional callable applied to the RGB PIL image.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, normal_dir, abnormal_dirs, transform=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform

        self._add_directory(normal_dir, 0)  # Normal
        for abnormal_dir in abnormal_dirs:
            self._add_directory(abnormal_dir, 1)  # Abnormal

    def _add_directory(self, directory, label):
        """Append every image file directly inside ``directory`` with ``label``."""
        if not os.path.isdir(directory):
            return
        for name in sorted(os.listdir(directory)):
            if name.lower().endswith(self.IMAGE_EXTENSIONS):
                self.image_paths.append(os.path.join(directory, name))
                self.labels.append(label)

    def __len__(self):
        """Return the number of images collected."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the label is a 0-dim ``torch.long`` tensor."""
        # The context manager closes the underlying file as soon as the RGB
        # copy has been made, instead of leaving that to garbage collection.
        with Image.open(self.image_paths[idx]) as handle:
            image = handle.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        return image, torch.tensor(label, dtype=torch.long)

# Define the transformations for training and for validation/testing.
# Training adds random flip / rotation / colour jitter on top of the resize,
# tensor conversion and normalization that both pipelines share.
train_transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD)
])

# Deterministic pipeline: resize, convert to tensor, normalize.
val_test_transform = transforms.Compose([
    transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD)
])

# Create dataset with augmented images for training
dataset = ImageDataset(normal_dir, abnormal_dirs, transform=train_transform)

# Split dataset into training (70%), validation (15%), and test (15%).
# The generator is seeded so the split is the same on every run; test_size
# absorbs the rounding remainder so the three sizes always sum to len(dataset).
# Note that all three subsets returned here wrap `dataset`, which applies
# train_transform; the validation/test loaders further down are built from the
# `.indices` of val_dataset / test_dataset on a separately constructed dataset.
generator = torch.Generator().manual_seed(42)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size], generator=generator)

# --- Oversampling for Training Set ---
# Labels are read from dataset.labels rather than dataset[i]: indexing the
# dataset opens, decodes and augments the image, which is wasted work when
# only the label is needed.
# list(...) takes a copy so the shuffle below never mutates the index list
# owned by the Subset returned from random_split.
train_indices = list(train_dataset.indices)
normal_indices = [i for i in train_indices if dataset.labels[i] == 0]
abnormal_indices = [i for i in train_indices if dataset.labels[i] == 1]

num_normal = len(normal_indices)
num_abnormal = len(abnormal_indices)
print(f"Training set before balancing: Normal={num_normal}, Abnormal={num_abnormal}")

# Dedicated seeded RNG: the duplicated samples and the final ordering are
# reproducible across runs, in line with the seeded train/val/test split.
oversample_rng = random.Random(42)

# Duplicate randomly chosen minority-class samples (with replacement) until
# both classes have the same number of training entries.
if num_normal < num_abnormal:
    oversampled_indices = oversample_rng.choices(normal_indices, k=num_abnormal - num_normal)
    new_train_indices = train_indices + oversampled_indices
    print(f"Oversampling Normal class: added {len(oversampled_indices)} samples")
elif num_abnormal < num_normal:
    oversampled_indices = oversample_rng.choices(abnormal_indices, k=num_normal - num_abnormal)
    new_train_indices = train_indices + oversampled_indices
    print(f"Oversampling Abnormal class: added {len(oversampled_indices)} samples")
else:
    new_train_indices = train_indices
    print("Classes are already balanced")

oversample_rng.shuffle(new_train_indices)
train_dataset = torch.utils.data.Subset(dataset, new_train_indices)

# Untransformed view of the (oversampled) training samples, for display only.
original_dataset = ImageDataset(normal_dir, abnormal_dirs, transform=None)
original_train_dataset = torch.utils.data.Subset(original_dataset, new_train_indices)

# Validation and test subsets reuse the split indices but read from datasets
# built with the deterministic validation/test transform.
val_dataset_with_transform = torch.utils.data.Subset(
    ImageDataset(normal_dir, abnormal_dirs, transform=val_test_transform), val_dataset.indices
)
test_dataset_with_transform = torch.utils.data.Subset(
    ImageDataset(normal_dir, abnormal_dirs, transform=val_test_transform), test_dataset.indices
)

# --- OPTIMIZATION: Create DataLoaders using multiple CPU cores ---
# Options common to all three loaders: worker processes for loading and
# pinned host memory for the CPU -> GPU copy.
loader_options = dict(batch_size=BATCH_SIZE, pin_memory=True, num_workers=NUM_CPUS)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset_with_transform, shuffle=False, **loader_options)
test_loader = DataLoader(test_dataset_with_transform, shuffle=False, **loader_options)


# Count images per class in the full dataset.
# The stored label list is counted directly; iterating over `dataset` would
# decode and augment every image only to throw the pixels away.
normal_count = dataset.labels.count(0)
abnormal_count = dataset.labels.count(1)

print(f"\nFull dataset:")
print(f"Number of normal images: {normal_count}")
print(f"Number of abnormal images: {abnormal_count}")
print(f"Total images: {len(dataset)}")

# Display names indexed by label value (0 -> Normal, 1 -> Abnormal).
class_names = ["Normal", "Abnormal"]

def display_random_images_from_dataset(dataset, save=False, filename="random_images.png", num_images=4):
    """Plot up to ``num_images`` random samples of ``dataset`` in one row.

    Tensor images are moved to the CPU, rearranged from channel-first to
    channel-last, de-normalized with the module-level STD / MEAN and clipped
    to [0, 1] before display; any other image object is passed to ``imshow``
    unchanged. Each subplot is titled with the class name looked up in
    ``class_names``. When ``save`` is true the figure is also written to
    ``filename``.
    """
    sample_count = min(num_images, len(dataset))
    picks = random.sample(range(len(dataset)), sample_count)

    fig, axes = plt.subplots(1, len(picks), figsize=(15, 5))
    # A single column yields one Axes object rather than a sequence.
    if len(picks) == 1:
        axes = [axes]

    for ax, sample_idx in zip(axes, picks):
        image, label = dataset[sample_idx]
        if isinstance(image, torch.Tensor):
            channel_last = image.cpu().numpy().transpose((1, 2, 0))
            # Undo the normalization so colours look natural again.
            image = np.clip(channel_last * np.array(STD) + np.array(MEAN), 0, 1)
        ax.imshow(image)
        ax.axis('off')
        label_idx = label.item() if torch.is_tensor(label) else label
        ax.set_title(f"Label: {class_names[label_idx]}")

    plt.tight_layout()
    if save:
        plt.savefig(filename)
        print(f"Saved random images to {filename}")
    plt.show()

# Preview a few samples from the training subset; save=False means the
# filename argument is not used for writing here.
print("\nRandom Images from Training Set:")
display_random_images_from_dataset(train_dataset, save=False, filename="random_train_images.png")

def display_augmented_images(original_dataset, augmented_dataset, save=False, base_filename="augmented_", num_images=4):
    """Show ten augmented draws for each of up to ``num_images`` random samples.

    For every picked position, the label and source file name come from
    ``original_dataset`` (accessed through its ``.dataset.image_paths`` and
    ``.indices``), while the images are obtained by indexing
    ``augmented_dataset`` at the same position ten times. Each sample gets its
    own 2x5 figure; when ``save`` is true the figure is written to
    ``{base_filename}{class name}_{n}.png``.
    """
    draws_per_sample = 10
    picks = random.sample(range(len(original_dataset)), min(num_images, len(original_dataset)))

    for figure_no, idx in enumerate(picks, start=1):
        _, label = original_dataset[idx]
        label_idx = label.item() if torch.is_tensor(label) else label
        source_path = original_dataset.dataset.image_paths[original_dataset.indices[idx]]

        # Indexing repeatedly re-applies the dataset's transform each time.
        variants = [augmented_dataset[idx][0] for _ in range(draws_per_sample)]

        fig_aug, axes_aug = plt.subplots(2, 5, figsize=(15, 6))
        fig_aug.suptitle(f"Class: {class_names[label_idx]}\nFilename: {os.path.basename(source_path)}", fontsize=14)

        for variant_no, (ax, tensor_img) in enumerate(zip(axes_aug.flatten(), variants), start=1):
            # Channel-first tensor -> channel-last array, then undo normalization.
            pixels = tensor_img.numpy().transpose((1, 2, 0))
            pixels = np.clip(pixels * np.array(STD) + np.array(MEAN), 0, 1)
            ax.imshow(pixels)
            ax.axis('off')
            ax.set_title(f"Augmented {variant_no}")

        plt.tight_layout()
        if save:
            plt.savefig(f"{base_filename}{class_names[label_idx]}_{figure_no}.png")
        plt.show()

# Both subsets are built from the same index list, so position idx refers to
# the same underlying file in the untransformed and the augmented view.
print("\nSome Augmented Images:")
display_augmented_images(original_train_dataset, train_dataset, save=False)

In [None]:
import torch

# Release cached GPU memory before the model is built. Guarded because
# torch.cuda.synchronize() raises on machines without a usable CUDA device,
# while the rest of the notebook falls back to the CPU.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()  # optional: waits for all kernels to finish

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torchvision.models.resnet import Bottleneck
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import os
import pickle
# --- OPTIMIZATION: Import for Automatic Mixed Precision (AMP) ---
from torch.cuda.amp import GradScaler, autocast

# --- Training hyperparameters (adjusted for the larger batch size) ---
num_epochs, patience = 30, 5
# Batch size went from 32 to 128 (4x), so the previous 1e-4 learning rate is
# scaled by the same factor.
learning_rate = 4e-4
weight_decay = 1e-4
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
print(f"Learning rate adjusted to {learning_rate} for larger batch size.")

# Define Focal Loss
class FocalLoss(nn.Module):
    """Focal-style loss computed on raw logits against float targets.

    Per element the loss is ``alpha * (1 - pt) ** gamma * bce``, where ``bce``
    is the unreduced BCE-with-logits value and ``pt = exp(-bce)``. ``alpha``
    is applied as the same scalar to every element (it is not switched per
    class). ``reduction`` selects 'mean', 'sum', or -- for any other value --
    the unreduced per-element tensor.
    """

    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha, self.gamma, self.reduction = alpha, gamma, reduction
        self.criterion = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self, inputs, targets):
        per_element_bce = self.criterion(inputs, targets)
        # exp(-bce) recovers the probability assigned to the true class.
        true_class_prob = torch.exp(-per_element_bce)
        focal = self.alpha * (1 - true_class_prob) ** self.gamma * per_element_bce

        if self.reduction == 'sum':
            return torch.sum(focal)
        if self.reduction == 'mean':
            return torch.mean(focal)
        return focal

# --- Model Definition (No changes needed here) ---
class SELayer(nn.Module):
    """Channel gate: global average pool -> two bias-free linears -> sigmoid scale.

    The input is expected to be 4-D; its first two dimensions are treated as
    batch and channel. The output has the same shape as the input, with every
    channel multiplied by its gate value in (0, 1).
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channels, _, _ = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gates.expand_as(x)

class SEBottleneck(Bottleneck):
    """Bottleneck block with an SE gate applied to the bn3 output before the residual add."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None, se_reduction=16):
        super().__init__(inplanes, planes, stride, downsample, groups, base_width, dilation, norm_layer)
        # The gate operates on the block's output channels (planes * expansion).
        self.se = SELayer(planes * self.expansion, reduction=se_reduction)

    def forward(self, x):
        identity = x

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.se(self.bn3(self.conv3(out)))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        return self.relu(out)

def get_seresnext50(num_classes=1, se_reduction=16):
    """Build a pretrained ResNeXt50-32x4d whose bottlenecks are swapped for SEBottleneck.

    Every ``Bottleneck`` found in the module tree is replaced by an
    ``SEBottleneck`` configured from the original block's layers. The original
    block's weights are copied with ``strict=False``, so the added SE layers
    keep their fresh initialization, and the original ``downsample`` module is
    reused. Finally the classifier is replaced by a new ``nn.Linear`` with
    ``num_classes`` outputs.
    """
    model = models.resnext50_32x4d(pretrained=True)
    width_per_group = model.base_width

    def _swap_in_se_blocks(parent):
        """Recursively replace Bottleneck children of ``parent`` in place."""
        for child_name, child in parent.named_children():
            if not isinstance(child, Bottleneck):
                _swap_in_se_blocks(child)
                continue
            se_block = SEBottleneck(
                inplanes=child.conv1.in_channels,
                planes=child.conv3.out_channels // child.expansion,
                stride=child.stride,
                downsample=child.downsample,
                groups=child.conv2.groups,
                base_width=width_per_group,
                dilation=child.conv2.dilation[0],
                se_reduction=se_reduction,
            )
            se_block.load_state_dict(child.state_dict(), strict=False)
            setattr(parent, child_name, se_block)

    _swap_in_se_blocks(model)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# --- Initialization: model, loss, optimizer and LR scheduler ---
se_reduction_ratio = 16
# Single-logit head (num_classes=1), matching the BCE-with-logits based loss below.
model = get_seresnext50(num_classes=1, se_reduction=se_reduction_ratio)
# Wrap in DataParallel only when more than one GPU is visible.
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs: {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}")
    model = nn.DataParallel(model)
model = model.to(device)
criterion = FocalLoss(alpha=0.25, gamma=2.0)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# Cut the LR by 10x after 3 epochs without improvement of the monitored value
# (the training loop passes the validation loss to scheduler.step).
# NOTE(review): `verbose` is deprecated in recent PyTorch releases -- confirm
# it is still accepted by the installed version.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

# --- OPTIMIZATION: Training function with Mixed Precision ---
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, patience, se_reduction=16):
    """Train ``model`` with mixed precision, LR scheduling and early stopping.

    Each epoch runs one optimization pass over ``train_loader`` (autocast
    forward, scaled backward) followed by one gradient-free pass over
    ``val_loader``. Labels are cast to float and reshaped to ``(-1, 1)`` to
    match a single-logit output; predictions are ``sigmoid(logit) > 0.5``.
    Batches are moved to the module-level ``device``.

    Side effects:
        * ``best_seresnext50_model.pth`` is written whenever the validation
          loss improves.
        * ``final_seresnext50_model.pth`` and
          ``seresnext50_training_history.pkl`` are written at the end.

    Args:
        model: network to train (already on ``device``).
        train_loader / val_loader: loaders yielding ``(inputs, labels)``.
        criterion: loss taking ``(logits, float_targets)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: scheduler stepped once per epoch with the validation loss.
        num_epochs: maximum number of epochs.
        patience: epochs without validation improvement before stopping.
        se_reduction: recorded in the checkpoint config only.

    Returns:
        ``(model, history)`` where ``model`` carries the weights of the best
        checkpoint written during this call (or the final weights if none was
        written) and ``history`` maps 'train_loss', 'val_loss', 'train_acc',
        'val_acc' and 'lr' to per-epoch lists.
    """
    best_path = 'best_seresnext50_model.pth'
    best_val_loss = float('inf')
    patience_counter = 0
    # Tracks whether *this* call wrote best_path, so a checkpoint left behind
    # by an earlier run is never loaded by mistake.
    best_saved_this_run = False
    # Pre-defined so the final checkpoint can still be written when
    # num_epochs is 0 and the loop body never executes.
    epoch, epoch_val_loss = -1, float('nan')
    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': [], 'lr': []}

    def _make_checkpoint(at_epoch, val_loss):
        """Snapshot model / optimizer / scheduler state plus run metadata."""
        return {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'epoch': at_epoch,
            'val_loss': val_loss,
            'history': history,
            'config': {
                # Taken from the optimizer's constructor defaults rather than
                # notebook globals, so the record matches what was trained.
                'learning_rate': optimizer.defaults.get('lr'),
                'weight_decay': optimizer.defaults.get('weight_decay'),
                'batch_size': train_loader.batch_size,
                'se_reduction_ratio': se_reduction,
            },
        }

    # Initialize the gradient scaler for mixed precision
    scaler = GradScaler()
    print("Training with Automatic Mixed Precision (AMP) enabled.")

    for epoch in range(num_epochs):
        # ---- Training phase ----
        model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0
        pbar_train = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Train]')
        for inputs, labels in pbar_train:
            inputs = inputs.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True).float().view(-1, 1)
            optimizer.zero_grad()

            # Forward pass under autocast
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, labels)

            # Scaled backward, then optimizer step and scaler update
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # loss is a batch mean; weight by batch size for a per-sample average
            train_loss += loss.item() * inputs.size(0)
            predicted = (torch.sigmoid(outputs) > 0.5).float()
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            pbar_train.set_postfix({'loss': f'{loss.item():.4f}'})

        epoch_train_loss = train_loss / len(train_loader.dataset)
        epoch_train_acc = train_correct / train_total
        history['train_loss'].append(epoch_train_loss)
        history['train_acc'].append(epoch_train_acc)
        history['lr'].append(optimizer.param_groups[0]['lr'])

        # ---- Validation phase ----
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            pbar_val = tqdm(val_loader, desc=f'Epoch {epoch+1}/{num_epochs} [Val]')
            for inputs, labels in pbar_val:
                inputs = inputs.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True).float().view(-1, 1)
                with autocast():
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)
                predicted = (torch.sigmoid(outputs) > 0.5).float()
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()
                pbar_val.set_postfix({'loss': f'{loss.item():.4f}'})

        epoch_val_loss = val_loss / len(val_loader.dataset)
        epoch_val_acc = val_correct / val_total
        history['val_loss'].append(epoch_val_loss)
        history['val_acc'].append(epoch_val_acc)

        print(f"Epoch {epoch+1}/{num_epochs}: Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f}, Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")
        scheduler.step(epoch_val_loss)

        # ---- Checkpointing / early stopping ----
        if epoch_val_loss < best_val_loss:
            best_val_loss = epoch_val_loss
            patience_counter = 0
            torch.save(_make_checkpoint(epoch, best_val_loss), best_path)
            best_saved_this_run = True
            print(f"Model improved! Saved checkpoint at epoch {epoch+1}")
        else:
            patience_counter += 1
            print(f"Model didn't improve for {patience_counter}/{patience} epochs")
            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break

    # Save final model regardless of performance
    torch.save(_make_checkpoint(epoch, epoch_val_loss), 'final_seresnext50_model.pth')
    with open('seresnext50_training_history.pkl', 'wb') as f:
        pickle.dump(history, f)

    # Restore the best weights, but only from a checkpoint written by this call.
    if best_saved_this_run and os.path.exists(best_path):
        checkpoint = torch.load(best_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
    return model, history

# Start training. `model` is rebound to the model object returned by
# train_model, and `history` holds the per-epoch metrics plotted below.
model, history = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=num_epochs,
    patience=patience,
    se_reduction=se_reduction_ratio
)

# Plot training history as a 2x2 grid: loss, accuracy, learning rate, run summary.
plt.figure(figsize=(15, 10))

# Top-left: loss curves
plt.subplot(2, 2, 1)
plt.plot(history['train_loss'], label='Training Loss')
plt.plot(history['val_loss'], label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curves')

# Top-right: accuracy curves
plt.subplot(2, 2, 2)
plt.plot(history['train_acc'], label='Training Accuracy')
plt.plot(history['val_acc'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.title('Accuracy Curves')

# Bottom-left: learning rate recorded at each epoch
plt.subplot(2, 2, 3)
plt.plot(history['lr'])
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')

# Bottom-right: text panel summarizing the configuration
plt.subplot(2, 2, 4)
plt.axis('off')
info_text = (f"Model: SE-ResNeXt50\nSE Reduction Ratio: {se_reduction_ratio}\nOptimizer: AdamW\nInitial LR: {learning_rate}\nWeight Decay: {weight_decay}\nLoss: Focal Loss (α={criterion.alpha}, γ={criterion.gamma})")
plt.text(0.1, 0.5, info_text, fontsize=12)

plt.tight_layout()
plt.savefig(f'seresnext50_r{se_reduction_ratio}_training_plots.png', dpi=300)
plt.show()
print(f"Training completed! SE-ResNeXt50 with reduction ratio {se_reduction_ratio}")