<a href="https://colab.research.google.com/github/faizaslam11/Brain_tumor_segmentation/blob/main/Brain_Tumor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
# Install the Kaggle command-line client; the download cell below calls `kaggle datasets download`.
!pip install kaggle



In [None]:
from google.colab import files
# Opens Colab's upload dialog; select the kaggle.json API token file.
files.upload()  # Upload kaggle.json
# Copy the uploaded token into ~/.kaggle so the Kaggle CLI can use it.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
# Restrict the token file to owner read/write, since it contains credentials.
!chmod 600 ~/.kaggle/kaggle.json


In [None]:
# NOTE(review): these three commands repeat the tail of the previous cell (create ~/.kaggle,
# copy kaggle.json into it, tighten its permissions). They are only needed when kaggle.json
# is already present in the working directory and the upload cell was skipped.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download the BraTS 2020 training/validation archive from Kaggle into /content.
!kaggle datasets download -d awsaf49/brats20-dataset-training-validation -p /content
# Extract quietly (-q) into /content/brats2020, the directory the following cells read from.
!unzip -q brats20-dataset-training-validation.zip -d /content/brats2020

In [None]:
# Sanity check: list every file extracted from the dataset archive.
# NOTE(review): this cell appears to be adapted from the Kaggle notebook template;
# only the walked directory reflects the Colab setup used in this notebook.

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Walk the extraction directory and print the full path of each file found.
# NOTE(review): one line is printed per file, so the output can be very long.

import os
for dirname, _, filenames in os.walk('/content/brats2020'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# Nothing is written by this cell; it only prints paths.

In [None]:
# Brain Tumor Segmentation using ResNet-120 + LSTM
# PhD Research Project - BraTS 2020 Dataset

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Deep Learning Libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import resnet101

# Medical Imaging
import nibabel as nib
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score

# Visualization
import cv2
from tqdm import tqdm

print("Libraries imported successfully!")

In [None]:
# =============================================================================
# 1. DATASET EXPLORATION AND SETUP
# =============================================================================

# Root directory that explore_dataset_structure() searches for patient folders.
# In Kaggle, the dataset path is typically:
# DATASET_PATH = "/kaggle/input/brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"
# In this notebook the archive is unzipped under /content/brats2020 instead:
DATASET_PATH = "/content/brats2020/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData"

# Dataset structure exploration
def _has_nifti_extension(filename):
    """Return True when *filename* ends with a NIfTI extension (.nii or .nii.gz)."""
    return filename.endswith(('.nii', '.nii.gz'))


def explore_dataset_structure(dataset_path=None):
    """Locate the BraTS patient folders below the dataset root.

    The directory tree is walked looking for folders whose name starts with
    ``BraTS`` and that directly contain NIfTI files. If none are found, a
    second, narrower search is made under ``<root>/BraTS2020_TrainingData``
    for any sub-folder holding NIfTI files.

    Args:
        dataset_path: Directory to search. When omitted, the module-level
            ``DATASET_PATH`` is used.

    Returns:
        A ``(train_path, patient_folders)`` tuple where ``train_path`` is the
        directory that contains the patient folders and ``patient_folders``
        is the list of patient folder names, or ``(None, None)`` when no
        patient folder could be found.
    """
    root_dir = DATASET_PATH if dataset_path is None else dataset_path
    print("=== EXPLORING DATASET STRUCTURE ===")

    if not os.path.exists(root_dir):
        print("Dataset not found!")
        return None, None

    print(f"Dataset root: {root_dir}")
    print(f"Root contents: {os.listdir(root_dir)}")

    train_path = None
    patient_folders = []

    print("Searching for patient folders with .nii.gz files...")
    for root, dirs, files in os.walk(root_dir):
        # Sorting in place makes os.walk descend in a deterministic order,
        # so the returned folder list is reproducible across runs.
        dirs.sort()

        nii_files = sorted(f for f in files if _has_nifti_extension(f))
        folder_name = os.path.basename(root)
        if not nii_files or not folder_name.startswith('BraTS'):
            continue

        if train_path is None:
            # The parent of the first patient folder is taken as the data root.
            train_path = os.path.dirname(root)
            print(f"Found training data directory: {train_path}")
            # Show the contents of the first patient only, to keep output short.
            print(f"Sample files in {folder_name}:")
            for f in nii_files[:10]:
                print(f"  {f}")

        patient_folders.append(folder_name)

    if train_path and patient_folders:
        print(f"\nSUMMARY:")
        print(f"Training data path: {train_path}")
        print(f"Total patient folders found: {len(patient_folders)}")
        print(f"Sample patients: {patient_folders[:5]}")
        return train_path, patient_folders

    print("Could not find patient folders with .nii.gz files!")

    # Fallback: look one level deeper for folders that do not follow the
    # "BraTS*" naming convention but still contain NIfTI files.
    print("\nManual exploration:")
    base_train_path = os.path.join(root_dir, "BraTS2020_TrainingData")
    if os.path.exists(base_train_path):
        print(f"Contents of {base_train_path}:")
        for item in sorted(os.listdir(base_train_path)):
            item_path = os.path.join(base_train_path, item)
            if not os.path.isdir(item_path):
                continue
            print(f"  Directory: {item}")

            found = []
            for sub_item in sorted(os.listdir(item_path)):
                sub_path = os.path.join(item_path, sub_item)
                if not os.path.isdir(sub_path):
                    continue
                try:
                    nii_count = sum(1 for f in os.listdir(sub_path) if _has_nifti_extension(f))
                except OSError as exc:
                    # Unreadable folders are reported and skipped, not silently ignored.
                    print(f"    Skipping {sub_item}: {exc}")
                    continue
                if nii_count > 0:
                    if len(found) < 10:  # limit console output to the first few hits
                        print(f"    Subdirectory: {sub_item}")
                        print(f"      Contains {nii_count} .nii.gz files")
                    found.append(sub_item)

            if found:
                # Same contract as the main search: (parent directory, folder names).
                return item_path, found

    return None, None

# Run exploration; train_path is handed to BraTSDataLoader in a later cell,
# and both values are None when no patient folders were found.
train_path, patient_folders = explore_dataset_structure()

In [None]:

class BraTSDataLoader:
    """Read BraTS patient volumes from disk and convert them into 2D slices.

    A fixed-size subset of patient folders is selected at construction time.
    ``create_dataset`` then loads the FLAIR, T1CE and segmentation volumes of
    each patient, resizes every axial slice and stacks the result.
    """

    # Recognised NIfTI file extensions (compressed and uncompressed).
    NIFTI_EXTENSIONS = ('.nii.gz', '.nii')

    def __init__(self, data_path, subset_size=10):
        """
        Initialize BraTS data loader
        Args:
            data_path: Path to BraTS training data (directory holding one
                folder per patient). ``None`` or a missing path yields an
                empty loader instead of raising.
            subset_size: Number of patients to use (for memory constraints)
        """
        self.data_path = data_path
        self.subset_size = subset_size
        # Always defined, so later calls never hit AttributeError when no
        # patient folder was found.
        self.file_pattern = None

        # Sorted listing keeps the selected subset reproducible across runs.
        if data_path and os.path.isdir(data_path):
            all_folders = sorted(os.listdir(data_path))
        else:
            all_folders = []
        self.patient_folders = [
            f for f in all_folders
            if 'BraTS' in f and os.path.isdir(os.path.join(data_path, f))
        ][:subset_size]
        print("Patient folders being used:", self.patient_folders)
        print(f"Using {len(self.patient_folders)} patients for training")

        if not self.patient_folders:
            print("No patient folders found!")
            return

        # Auto-detect file naming pattern
        self.file_pattern = self._detect_file_pattern()

    def _detect_file_pattern(self):
        """Auto-detect the file naming pattern in BraTS dataset.

        Returns:
            A callable ``(patient_id, modality) -> filename`` (or ``None``
            when the modality cannot be resolved), or ``None`` when there is
            no patient folder to inspect.
        """
        if not self.patient_folders:
            return None
        sample_patient = self.patient_folders[0]
        sample_path = os.path.join(self.data_path, sample_patient)

        if not os.path.exists(sample_path):
            return None
        nii_files = [f for f in os.listdir(sample_path) if f.endswith(self.NIFTI_EXTENSIONS)]
        print(f"Sample patient files: {nii_files}")

        # Check for standard naming like BraTS20_Training_001_flair.nii(.gz)
        standard_names = {f"{sample_patient}_flair{ext}" for ext in self.NIFTI_EXTENSIONS}
        if standard_names.intersection(nii_files):
            print("✅ Detected standard BraTS naming convention.")
            # The ".nii.gz" name is returned on purpose: load_patient_data
            # tries it first and then falls back to the ".nii" variant.
            return lambda patient_id, modality: f"{patient_id}_{modality}.nii.gz"

        # Otherwise fall back to flexible pattern matching
        print("⚠️ Using flexible file pattern matching.")
        return self._flexible_pattern

    def _find_modality_file(self, patient_path, modality):
        """Find the NIfTI file for *modality* inside *patient_path*.

        A file whose name (without extension) equals the modality or ends in
        ``_<modality>`` is preferred, so that ``t1`` does not pick up the
        ``t1ce`` file. If no such file exists, the first file that merely
        contains the modality name is returned.

        Returns:
            The matching file name, or ``None`` when nothing matches.
        """
        wanted = modality.lower()
        loose_match = None
        for file in sorted(os.listdir(patient_path)):
            lowered = file.lower()
            ext = next((e for e in self.NIFTI_EXTENSIONS if lowered.endswith(e)), None)
            if ext is None:
                continue
            stem = lowered[:-len(ext)]
            if stem == wanted or stem.endswith('_' + wanted):
                return file
            if loose_match is None and wanted in lowered:
                loose_match = file
        return loose_match

    def _flexible_pattern(self, patient_id, modality):
        """Resolve a modality file by searching the patient's folder."""
        patient_path = os.path.join(self.data_path, patient_id)
        return self._find_modality_file(patient_path, modality)

    def load_patient_data(self, patient_id):
        """Load all modalities for a single patient.

        Returns:
            Dict mapping modality name to the volume array returned by
            nibabel's ``get_fdata()``. Modalities that are missing or fail to
            load are omitted; an empty dict is returned when the patient
            folder does not exist.
        """
        patient_path = os.path.join(self.data_path, patient_id)

        if not os.path.exists(patient_path):
            print(f"Patient path not found: {patient_path}")
            return {}

        # Try to find files for each modality
        modalities = ['flair', 't1', 't1ce', 't2', 'seg']
        files = {}

        if self.file_pattern:
            for modality in modalities:
                filename = self.file_pattern(patient_id, modality)
                if not filename:
                    print(f"[Pattern Fail] Could not generate filename for {modality} using pattern.")
                    continue

                # Accept either the compressed or the uncompressed variant.
                nii_gz_path = os.path.join(patient_path, filename)
                nii_path = nii_gz_path.replace(".nii.gz", ".nii")
                if os.path.exists(nii_gz_path):
                    files[modality] = nii_gz_path
                elif os.path.exists(nii_path):
                    files[modality] = nii_path
                else:
                    print(f"[Missing] {modality.upper()} not found at {nii_gz_path} or {nii_path}")

        # Load the data
        data = {}
        for modality, filepath in files.items():
            try:
                nii_img = nib.load(filepath)
                data[modality] = nii_img.get_fdata()
                print(f"Loaded {modality}: {data[modality].shape}")
            except Exception as e:
                # Best effort: a corrupt file skips this modality only.
                print(f"Error loading {filepath}: {e}")

        return data

    def preprocess_volume(self, volume, target_size=(128, 128), is_mask=False):
        """Preprocess 3D volume to 2D slices.

        Args:
            volume: 3D array; slices are taken along the last axis.
            target_size: Size passed to ``cv2.resize`` for every slice.
            is_mask: When True the volume is not normalized and is resized
                with nearest-neighbour interpolation so label values are
                preserved.

        Returns:
            Array with one resized slice per index along the first axis.
            Image volumes are returned as float32 (the precision the training
            tensors use); mask volumes keep their dtype.
        """
        # Normalize to 0-1 range (epsilon guards against constant volumes)
        if not is_mask:
            volume = (volume - volume.min()) / (volume.max() - volume.min() + 1e-8)

        interpolation = cv2.INTER_NEAREST if is_mask else cv2.INTER_LINEAR
        processed_slices = [
            cv2.resize(volume[:, :, slice_idx], target_size, interpolation=interpolation)
            for slice_idx in range(volume.shape[2])
        ]

        stacked = np.array(processed_slices)
        return stacked if is_mask else stacked.astype(np.float32)

    def create_dataset(self):
        """Create dataset with all patients.

        Returns:
            ``(sequences, masks)``: image slices with FLAIR and T1CE stacked
            on axis 1, and the matching segmentation slices with label 4
            remapped to 3. Both arrays are empty when no patient could be
            loaded.
        """
        volume_stacks = []
        mask_stacks = []
        required = ['flair', 't1ce', 'seg']

        for patient_id in tqdm(self.patient_folders, desc="Processing patients"):
            try:
                patient_data = self.load_patient_data(patient_id)

                missing = [mod for mod in required if mod not in patient_data]
                if missing:
                    print(f"Skipping {patient_id}: missing modalities {missing}")
                    continue

                # Process each modality
                flair_slices = self.preprocess_volume(patient_data['flair'])
                t1ce_slices = self.preprocess_volume(patient_data['t1ce'])
                seg_slices = self.preprocess_volume(patient_data['seg'], is_mask=True)
                # Remap label 4 to 3 so the labels are the contiguous range 0..3.
                seg_slices[seg_slices == 4] = 3

                # Combine modalities (using FLAIR and T1CE as example)
                combined_slices = np.stack([flair_slices, t1ce_slices], axis=1)

                # Every slice is kept, including tumor-free ones, so that
                # neighbouring entries remain consecutive slices of a volume.
                volume_stacks.append(combined_slices)
                mask_stacks.append(seg_slices)

            except Exception as e:
                print(f"Error processing {patient_id}: {e}")
                continue

        if not volume_stacks:
            return np.array([]), np.array([])

        return np.concatenate(volume_stacks, axis=0), np.concatenate(mask_stacks, axis=0)

In [None]:
# Initialize with 50 patients
# NOTE(review): main() further down builds the same loader and dataset again,
# so running both repeats the loading work.
data_loader = BraTSDataLoader(data_path=train_path, subset_size=50)

# Actually load and preprocess the data
sequences, masks = data_loader.create_dataset()

# Print summary
print(f"\n✅ Final dataset shapes:")
print(f"Sequences shape: {sequences.shape}")
print(f"Masks shape: {masks.shape}")


In [None]:
class BraTSDataset(Dataset):
    """Sliding-window dataset over a stack of 2D slices.

    Each item is a window of ``sequence_length`` consecutive slices together
    with the mask of the window's last slice.
    """

    def __init__(self, sequences, masks, sequence_length=8):
        """
        Args:
            sequences: 4D array (num_slices, channels, height, width)
            masks: 3D array (num_slices, height, width)
            sequence_length: Number of consecutive slices for LSTM

        Raises:
            ValueError: If ``sequence_length`` is smaller than 1 or the two
                arrays do not have the same number of slices.
        """
        if sequence_length < 1:
            raise ValueError("sequence_length must be at least 1")
        if len(sequences) != len(masks):
            raise ValueError(
                f"sequences and masks differ in length: {len(sequences)} vs {len(masks)}"
            )

        self.sequences = sequences
        self.masks = masks
        self.sequence_length = sequence_length

        # Every start index whose full window fits inside the array is valid;
        # the list is empty when there are fewer slices than sequence_length.
        # NOTE(review): the arrays carry no patient boundaries, so a window
        # may span the end of one volume and the start of the next.
        num_windows = max(0, len(sequences) - sequence_length + 1)
        self.valid_indices = list(range(num_windows))

    def __len__(self):
        """Number of windows available."""
        return len(self.valid_indices)

    def __getitem__(self, idx):
        """Return ``(sequence, mask)`` for window *idx*.

        ``sequence`` is a float32 tensor holding ``sequence_length`` slices
        and ``mask`` is an int64 tensor with the labels of the last slice.
        """
        start_idx = self.valid_indices[idx]
        end_idx = start_idx + self.sequence_length

        # Get sequence of slices
        sequence = self.sequences[start_idx:end_idx]
        mask = self.masks[end_idx - 1]  # Predict last slice

        # Explicit dtypes make the conversion independent of the input dtype.
        sequence = torch.as_tensor(sequence, dtype=torch.float32)
        mask = torch.as_tensor(mask, dtype=torch.long)

        return sequence, mask

In [None]:
# train_transform = A.Compose([
#     A.HorizontalFlip(p=0.5),
#     A.VerticalFlip(p=0.5),
#     A.RandomRotate90(p=0.5),
#     A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.1, rotate_limit=15, p=0.5),
#     A.GaussianBlur(p=0.2),
#     A.RandomBrightnessContrast(p=0.2),
#     A.Normalize(mean=(0.5, 0.5), std=(0.5, 0.5)),
#     ToTensorV2()
# ])

# val_transform = A.Compose([
#     A.Normalize(mean=(0.5, 0.5), std=(0.5, 0.5)),
#     ToTensorV2()
# ])


In [None]:
class ResNet120Backbone(nn.Module):
    """ResNet-101 feature extractor with a custom input stem and two extra conv layers.

    The first convolution is replaced so the network accepts
    ``input_channels`` channels, the ResNet-101 stages are reused as they
    are, and an additional conv/batch-norm/ReLU pair of blocks (``layer5``)
    is appended. The forward pass ends with global average pooling, giving
    one 2048-dimensional feature vector per input image.
    """

    def __init__(self, input_channels=2, pretrained=True):
        super().__init__()

        # Reference network whose layers are reused below.
        base = resnet101(pretrained=pretrained)

        # Replacement stem convolution for multi-channel input.
        self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        if pretrained:
            # Seed the new stem with the pretrained kernel averaged over its
            # input channels and repeated once per new input channel.
            with torch.no_grad():
                channel_mean = base.conv1.weight.mean(dim=1, keepdim=True)
                self.conv1.weight = nn.Parameter(channel_mean.repeat(1, input_channels, 1, 1))

        # Take over the remaining ResNet-101 modules under their original names.
        for name in ('bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4'):
            setattr(self, name, getattr(base, name))

        # Two extra 3x3 conv blocks on top of layer4 (the "ResNet-120" extension).
        extra_blocks = []
        for _ in range(2):
            extra_blocks.extend([
                nn.Conv2d(2048, 2048, kernel_size=3, padding=1),
                nn.BatchNorm2d(2048),
                nn.ReLU(inplace=True),
            ])
        self.layer5 = nn.Sequential(*extra_blocks)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return the pooled, flattened features for a batch of images."""
        stem = (self.conv1, self.bn1, self.relu, self.maxpool)
        stages = (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5)
        for module in stem + stages:
            x = module(x)

        pooled = self.avgpool(x)
        return torch.flatten(pooled, 1)

class ResNet120LSTM(nn.Module):
    """Slice-sequence segmentation model: CNN backbone, BiLSTM, dense head.

    Every slice of the input sequence is encoded by ``ResNet120Backbone``,
    the per-slice features are passed through a bidirectional LSTM, and the
    LSTM output at the last sequence position is mapped by a fully connected
    head to per-pixel class logits.
    """

    def __init__(self, input_channels=2, num_classes=4, hidden_size=512, num_layers=2,pretrained=False,
                 output_size=(128, 128)):
        """
        Args:
            input_channels: Channels per input slice.
            num_classes: Number of segmentation classes.
            hidden_size: LSTM hidden size per direction.
            num_layers: Number of stacked LSTM layers.
            pretrained: Forwarded to ``ResNet120Backbone``.
            output_size: ``(height, width)`` of the predicted segmentation
                map; it must match the size of the target masks.
        """
        super().__init__()

        self.num_classes = num_classes
        self.hidden_size = hidden_size
        self.output_size = tuple(output_size)
        out_height, out_width = self.output_size

        # ResNet backbone for feature extraction
        self.backbone = ResNet120Backbone(input_channels, pretrained=pretrained)
        backbone_output_size = 2048

        # Bidirectional LSTM. Inter-layer dropout only exists with more than
        # one layer; passing it for a single layer just triggers a warning.
        self.lstm = nn.LSTM(
            input_size=backbone_output_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=0.3 if num_layers > 1 else 0.0
        )

        # Classification head for segmentation
        lstm_output_size = hidden_size * 2  # bidirectional
        self.classifier = nn.Sequential(
            nn.Linear(lstm_output_size, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, out_height * out_width * num_classes)  # one logit per pixel and class
        )

    def forward(self, x):
        """Predict segmentation logits for the last slice of each sequence.

        Args:
            x: Tensor of shape (batch, seq_len, channels, height, width).

        Returns:
            Tensor of shape (batch, num_classes, out_height, out_width).
        """
        batch_size, seq_len, _, _, _ = x.shape

        # Slices are encoded one sequence position at a time, so batch-norm
        # statistics are computed per position in training mode.
        features = torch.stack(
            [self.backbone(x[:, i]) for i in range(seq_len)], dim=1
        )  # (batch, seq_len, feature_dim)

        # LSTM processing
        lstm_out, _ = self.lstm(features)

        # Use the last output for segmentation
        last_output = lstm_out[:, -1, :]  # (batch, hidden_size*2)

        # Generate segmentation map
        segmentation = self.classifier(last_output)
        return segmentation.view(batch_size, self.num_classes, *self.output_size)

In [None]:
def dice_coefficient(pred, target, smooth=1e-6):
    """Mean Dice score of the foreground classes 1-3 for a batch.

    Args:
        pred: Class logits with the class dimension at index 1.
        target: Integer label tensor matching ``pred`` without the class
            dimension.
        smooth: Constant added to numerator and denominator.

    Returns:
        The average of the three per-class Dice scores, each computed over
        the whole batch from the arg-max prediction.
    """
    labels = torch.argmax(torch.softmax(pred, dim=1), dim=1)

    per_class = []
    for cls in (1, 2, 3):  # background (class 0) is excluded
        predicted = (labels == cls).float()
        expected = (target == cls).float()

        overlap = (predicted * expected).sum()
        total = predicted.sum() + expected.sum()

        score = (2. * overlap + smooth) / (total + smooth)
        per_class.append(score.item())

    return np.mean(per_class)

def combined_loss(pred, target, alpha=0.5):
    """Weighted sum of cross-entropy and soft Dice loss.

    Args:
        pred: Class logits with the class dimension at index 1.
        target: Integer label tensor matching ``pred`` without the class
            dimension.
        alpha: Weight of the cross-entropy term; the Dice term is weighted
            by ``1 - alpha``.

    Returns:
        Scalar loss tensor.
    """
    criterion = nn.CrossEntropyLoss()
    ce_term = criterion(pred, target)

    probabilities = torch.softmax(pred, dim=1)
    eps = 1e-6

    def soft_dice_gap(cls):
        # 1 - soft Dice of one class, using predicted probabilities
        class_prob = probabilities[:, cls]
        class_true = (target == cls).float()
        overlap = (class_prob * class_true).sum()
        total = class_prob.sum() + class_true.sum()
        return 1 - (2. * overlap + eps) / (total + eps)

    # Average over the three foreground classes (background is excluded)
    dice_term = sum(soft_dice_gap(cls) for cls in range(1, 4)) / 3

    return alpha * ce_term + (1 - alpha) * dice_term


In [None]:
from IPython.display import FileLink
# NOTE(review): train_model() takes its own checkpoint_path parameter (with this same
# default value), so the function does not read this module-level variable.
checkpoint_path='checkpoint_epoch_last.pth'

def train_model(model, train_loader, val_loader, num_epochs=12, learning_rate=0.001, checkpoint_path='checkpoint_epoch_last.pth'):
    """Train the ResNet-120 + LSTM model, writing a checkpoint after every epoch.

    If ``checkpoint_path`` already exists, the model, optimizer, scheduler
    and metric history are restored from it and training continues with the
    epoch after the one stored in the checkpoint.

    Args:
        model: Network to train; it is moved to the selected device.
        train_loader: Iterable of ``(sequences, masks)`` training batches.
        val_loader: Iterable of ``(sequences, masks)`` validation batches.
        num_epochs: Total number of epochs, including epochs already
            completed in a resumed checkpoint.
        learning_rate: Adam learning rate.
        checkpoint_path: File the checkpoint is read from and written to.

    Returns:
        Dict with the per-epoch lists ``train_losses``, ``val_losses``,
        ``train_dice_scores`` and ``val_dice_scores`` plus ``resume_epoch``
        (0 when training started fresh).

    Raises:
        ValueError: If either loader yields no batches.
    """
    # Fail fast: the per-epoch averages divide by the number of batches, so
    # an empty loader would otherwise crash only after a full training epoch.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each contain at least one batch.")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on device: {device}")

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

    # Resume training if checkpoint exists
    resume_epoch = 0
    train_losses = []
    val_losses = []
    train_dice_scores = []
    val_dice_scores = []

    if os.path.exists(checkpoint_path):
        print(f"🔁 Resuming from checkpoint: {checkpoint_path}")
        # weights_only=False unpickles arbitrary objects: only load
        # checkpoint files that this notebook wrote itself.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        resume_epoch = checkpoint['epoch'] + 1

        # Load previous history if available
        train_losses = checkpoint.get('train_losses', [])
        val_losses = checkpoint.get('val_losses', [])
        train_dice_scores = checkpoint.get('train_dice_scores', [])
        val_dice_scores = checkpoint.get('val_dice_scores', [])
        print(f"✅ Resuming training from epoch {resume_epoch}")
    else:
        print("🆕 No checkpoint found. Starting fresh training.")

    for epoch in range(resume_epoch, num_epochs):
        # Training phase
        model.train()
        train_loss = 0.0
        train_dice = 0.0

        for batch_idx, (sequences, masks) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")):
            sequences, masks = sequences.to(device), masks.to(device)

            optimizer.zero_grad()
            outputs = model(sequences)
            loss = combined_loss(outputs, masks)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            # The metric needs no gradient, so it is computed on detached logits.
            train_dice += dice_coefficient(outputs.detach(), masks)

            if batch_idx % 10 == 0:
                print(f'Batch {batch_idx}, Loss: {loss.item():.4f}')

        # Validation phase
        model.eval()
        val_loss = 0.0
        val_dice = 0.0

        with torch.no_grad():
            for sequences, masks in val_loader:
                sequences, masks = sequences.to(device), masks.to(device)
                outputs = model(sequences)
                loss = combined_loss(outputs, masks)

                val_loss += loss.item()
                val_dice += dice_coefficient(outputs, masks)

        # Calculate averages
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        avg_train_dice = train_dice / len(train_loader)
        avg_val_dice = val_dice / len(val_loader)

        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        train_dice_scores.append(avg_train_dice)
        val_dice_scores.append(avg_val_dice)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train Loss: {avg_train_loss:.4f}, Train Dice: {avg_train_dice:.4f}')
        print(f'  Val Loss: {avg_val_loss:.4f}, Val Dice: {avg_val_dice:.4f}')

        scheduler.step()

        # Save checkpoint (after scheduler.step so the stored scheduler state
        # already belongs to the next epoch)
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_dice_scores': train_dice_scores,
            'val_dice_scores': val_dice_scores
        }, checkpoint_path)
        print(f"✅ Checkpoint saved at epoch {epoch+1}")

        # Offering a download link is a convenience only; it must never
        # interrupt training when no notebook front end is available.
        try:
            from IPython.display import display
            print("Click below to download the checkpoint manually (before session ends):")
            display(FileLink(checkpoint_path))
        except Exception:
            print(" Could not create FileLink. Run this cell in an interactive environment.")

    print("✅ Training complete.")
    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_dice_scores': train_dice_scores,
        'val_dice_scores': val_dice_scores,
        'resume_epoch': resume_epoch
        }

In [None]:
def visualize_results(model, dataset, device, num_samples=5):
    """Plot input slice, true mask and predicted mask for the first samples.

    Args:
        model: Trained network; switched to evaluation mode.
        dataset: Indexable dataset returning ``(sequence, mask)`` tensors.
        device: Device the model lives on.
        num_samples: Number of samples (columns) to show. Reduced to the
            dataset size when the dataset is smaller.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("No samples available to visualize.")
        return

    model.eval()

    # squeeze=False keeps `axes` two-dimensional even for a single column,
    # so axes[row, col] indexing below always works.
    fig, axes = plt.subplots(3, num_samples, figsize=(3 * num_samples, 9), squeeze=False)

    with torch.no_grad():
        for i in range(num_samples):
            sequences, true_mask = dataset[i]
            true_mask = true_mask.numpy()
            print(f"[Sample {i}] True mask unique values:", np.unique(true_mask))
            sequences = sequences.unsqueeze(0).to(device)

            # Get prediction
            pred = model(sequences)
            pred_mask = torch.softmax(pred, dim=1)
            pred_mask = torch.argmax(pred_mask, dim=1).cpu().numpy()[0]

            # Original image: channel 0 of the last slice in the sequence,
            # i.e. the slice whose mask is being predicted
            original = sequences[0, -1, 0].cpu().numpy()

            # Plot
            axes[0, i].imshow(original, cmap='gray')
            axes[0, i].set_title(f'Original {i+1}')
            axes[0, i].axis('off')

            axes[1, i].imshow(true_mask, cmap='jet')
            axes[1, i].set_title(f'True Mask {i+1}')
            axes[1, i].axis('off')

            axes[2, i].imshow(pred_mask, cmap='jet')
            axes[2, i].set_title(f'Predicted {i+1}')
            axes[2, i].axis('off')

    plt.tight_layout()
    plt.show()

def plot_training_curves(history):
    """Plot loss and Dice curves from a training history.

    Args:
        history: Dict with the per-epoch lists ``train_losses``,
            ``val_losses``, ``train_dice_scores`` and ``val_dice_scores``.
            Nothing is plotted when no epoch has been recorded.
    """
    # An empty history (e.g. zero epochs trained) has no first Dice value to
    # compare against, so there is nothing meaningful to draw.
    if len(history['train_losses']) == 0 or len(history['val_dice_scores']) == 0:
        print("No training history to plot.")
        return

    train_losses = np.array(history['train_losses'])
    val_losses = np.array(history['val_losses'])
    val_dice = np.array(history['val_dice_scores'])

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))

    # Loss curves
    ax1.plot(train_losses, label='Train Loss')
    ax1.plot(val_losses, label='Validation Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)

    # Dice score curves
    ax2.plot(history['train_dice_scores'], label='Train Dice')
    ax2.plot(val_dice, label='Validation Dice')
    ax2.set_title('Training and Validation Dice Score')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Dice Score')
    ax2.legend()
    ax2.grid(True)

    # Loss difference: a growing gap between validation and training loss
    # indicates overfitting
    ax3.plot(val_losses - train_losses)
    ax3.set_title('Validation - Training Loss (Overfitting Check)')
    ax3.set_xlabel('Epoch')
    ax3.set_ylabel('Loss Difference')
    ax3.grid(True)

    # Dice improvement relative to the first recorded epoch
    ax4.plot(val_dice - val_dice[0])
    ax4.set_title('Validation Dice Score Improvement')
    ax4.set_xlabel('Epoch')
    ax4.set_ylabel('Dice Improvement')
    ax4.grid(True)

    plt.tight_layout()
    plt.show()

In [None]:

def main():
    """Run the full pipeline: locate data, build datasets, train, plot, save.

    Each stage prints a diagnostic and returns early when it cannot proceed
    (no dataset found, no data loaded, or too few slices to form sequences).
    """
    sequence_length = 8   # consecutive slices per LSTM input window
    preview_count = 5     # samples inspected / visualized

    print("Starting Brain Tumor Segmentation Training...")

    # First, explore dataset structure
    train_path, patient_folders = explore_dataset_structure()

    if not train_path or not patient_folders:
        print("ERROR: Could not find proper dataset structure!")
        print("Please check that the BraTS dataset is properly added to your Kaggle notebook.")
        return

    # Load and preprocess data
    print("\nLoading and preprocessing data...")
    data_loader = BraTSDataLoader(
        data_path=train_path,
        subset_size=50
    )

    if not data_loader.patient_folders:
        print("ERROR: No patient folders found!")
        return

    sequences, masks = data_loader.create_dataset()
    print(f"Dataset created: {sequences.shape} sequences, {masks.shape} masks")

    if len(sequences) == 0:
        print("ERROR: No data loaded!")
        print("This might be due to:")
        print("1. Incorrect file naming pattern")
        print("2. Missing modality files")
        print("3. Empty/corrupt data files")
        return

    print("SUCCESS: Data loaded successfully!")
    print(f"Total sequences: {len(sequences)}")
    print(f"Sequence shape: {sequences[0].shape}")
    print(f"Mask shape: {masks[0].shape}")

    if len(sequences) < max(10, sequence_length):
        print("❗ Not enough sequences for training. Increase subset_size or reduce sequence_length.")
        return

    # Create train/validation split.
    # shuffle=False is essential: BraTSDataset builds its windows from
    # neighbouring array entries, so the original slice order must survive the
    # split. Shuffling would turn every "sequence" into unrelated slices and
    # spread adjacent slices of one volume across both splits.
    train_sequences, val_sequences, train_masks, val_masks = train_test_split(
        sequences, masks, test_size=0.2, shuffle=False
    )

    # Create datasets and dataloaders
    train_dataset = BraTSDataset(train_sequences, train_masks, sequence_length=sequence_length)
    val_dataset = BraTSDataset(val_sequences, val_masks, sequence_length=sequence_length)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    if len(train_dataset) == 0 or len(val_dataset) == 0:
        print("❗ Not enough sequences for training. Increase subset_size or reduce sequence_length.")
        return

    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=2)

    # Never index past the end of a small validation set.
    preview_count = min(preview_count, len(val_dataset))

    print("🔍 Checking unique values in validation masks:")
    for i in range(preview_count):
        _, true_mask = val_dataset[i]
        print(f"[Sample {i}] Unique labels in mask:", np.unique(true_mask.numpy()))

    # Initialize model
    model = ResNet120LSTM(input_channels=2, num_classes=4, pretrained=False)

    # Train model
    print("Starting training...")
    history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=12,
        learning_rate=0.001,
        checkpoint_path='checkpoint.pth'
    )
    print("Training resumed from epoch:", history.get('resume_epoch', 0))

    # Plot results
    plot_training_curves(history)

    # Visualize some results
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    visualize_results(model, val_dataset, device, num_samples=preview_count)

    # Save model
    torch.save(model.state_dict(), 'resnet120_lstm_brain_tumor.pth')
    print("Model saved as 'resnet120_lstm_brain_tumor.pth'")

    print("Training completed successfully!")

# Run the main function
# The guard lets this code be imported as a module without starting training;
# main() is called only when this code runs as the top-level script.
if __name__ == "__main__":
    main()