<a href="https://colab.research.google.com/github/chjayarajesh/Brain-Stroke-detection-and-Segmentation/blob/main/Stroke_detection_and_segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### **Brain Stroke classification and segmentation**

# **Using Kaggle API to import dataset**

In [None]:
# Install the Kaggle command-line client and timm (timm is imported by the classification cells below).
!pip install -q kaggle timm

# Upload kaggle.json from your Kaggle account (Account → API → Create New Token).
# The copy command below expects the uploaded file to be in the current working directory.
from google.colab import files
files.upload()

# Copy kaggle.json into ~/.kaggle and make it readable/writable by the owner only (mode 600).
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json


Downloading Dataset

In [None]:
# Download the Brain Stroke CT Dataset archive into /content/dataset
# (requires the Kaggle credentials installed by the previous cell).
!kaggle datasets download -d ozguraslank/brain-stroke-ct-dataset -p /content/dataset

# Unzip the dataset next to the archive; -q suppresses the per-file listing.
!unzip -q /content/dataset/brain-stroke-ct-dataset.zip -d /content/dataset

# **Data segregation and splitting into train, validation, and test sets**

In [None]:
import os
import shutil
from sklearn.model_selection import train_test_split
import cv2  # Added import for OpenCV

# Step 1: Segregate images and overlays into train, val, test
def segregate_dataset(base_dir, classes, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15):
    """Copy each class's PNG images and overlays into train/val/test folders.

    For every class ``cls`` the images are read from ``<base_dir>/<cls>/PNG`` and
    the overlays from ``<base_dir>/<cls>/OVERLAY``. The copies are written to

        <base_dir>/<split>/PNG/<cls>/<filename>
        <base_dir>/<split>/OVERLAY/<cls>/<filename>

    for ``split`` in ``train``, ``val`` and ``test``. The split is seeded
    (``random_state=42``) and the file list is sorted first, so re-running the
    cell reproduces the same partition.

    Args:
        base_dir: Dataset root containing one folder per class.
        classes: Iterable of class folder names to process.
        train_ratio: Kept for interface compatibility; the training share is
            implied as ``1 - (val_ratio + test_ratio)`` and this value is not read.
        val_ratio: Fraction of each class placed in the validation split.
        test_ratio: Fraction of each class placed in the test split.

    Returns:
        None. Progress and per-class counts are printed.
    """
    # Imported locally so the cell works on a fresh kernel: the notebook only
    # imports numpy in a later cell, which made the zero-mask branch raise NameError.
    import numpy as np

    holdout_ratio = val_ratio + test_ratio

    for cls in classes:
        png_dir = os.path.join(base_dir, cls, 'PNG')
        overlay_dir = os.path.join(base_dir, cls, 'OVERLAY')
        if not os.path.exists(png_dir):
            print(f"PNG folder not found for {cls}. Skipping...")
            continue

        # os.listdir returns entries in arbitrary order; sorting keeps the seeded
        # split identical across machines and re-runs.
        images = sorted(
            os.path.join(png_dir, f) for f in os.listdir(png_dir) if f.lower().endswith('.png')
        )
        if len(images) == 0:
            print(f"No PNG images found for {cls}. Skipping...")
            continue

        # First carve off the val+test hold-out, then divide the hold-out itself.
        train_images, holdout_images = train_test_split(images, test_size=holdout_ratio, random_state=42)
        val_images, test_images = train_test_split(holdout_images, test_size=test_ratio / holdout_ratio, random_state=42)

        overlay_dir_exists = os.path.exists(overlay_dir)

        for split, split_images in [('train', train_images), ('val', val_images), ('test', test_images)]:
            split_png_dir = os.path.join(base_dir, split, 'PNG', cls)
            split_overlay_dir = os.path.join(base_dir, split, 'OVERLAY', cls)
            os.makedirs(split_png_dir, exist_ok=True)
            os.makedirs(split_overlay_dir, exist_ok=True)

            for img in split_images:
                filename = os.path.basename(img)
                shutil.copy(img, os.path.join(split_png_dir, filename))

                # Copy the matching overlay, or create a zero mask for Normal
                # when the class has no OVERLAY folder at all.
                overlay_path = os.path.join(overlay_dir, filename)
                if os.path.exists(overlay_path):
                    shutil.copy(overlay_path, os.path.join(split_overlay_dir, filename))
                elif cls == 'Normal' and not overlay_dir_exists:
                    zero_mask = np.zeros((256, 256), dtype=np.uint8)  # Adjust size if needed
                    cv2.imwrite(os.path.join(split_overlay_dir, filename), zero_mask)
                    print(f"Created zero mask for {filename} in {cls}")
                else:
                    print(f"Overlay not found for {filename}, skipping copy.")

        print(f"{cls}: Train {len(train_images)}, Val {len(val_images)}, Test {len(test_images)}")

# Base directory (adjust to your Colab path).
# NOTE(review): presumably the folder produced by unzipping the archive into /content/dataset — confirm the name.
base_dir = '/content/dataset/Brain_Stroke_CT_Dataset'  # Update to your actual path
# Class folder names expected directly under base_dir.
classes = ['Bleeding', 'Ischemia', 'Normal']  # Adjust based on your dataset

# Run segregation: writes the train/val/test copies under base_dir.
segregate_dataset(base_dir, classes)

# **1. Classification Model training**

---


Preprocessing and augmentation of the dataset

---

ConvNeXt model (ConvNeXt-Base, loaded through timm)


In [None]:
import os
import shutil
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from torchvision import transforms
import timm
from tqdm import tqdm  # For progress bar

base_dir = '/content/dataset/Brain_Stroke_CT_Dataset'  # Update to your actual path in Colab

# Step 2: Preprocessing and Training
class GrayscaleToRGB:
    """Tensor transform that turns a one-channel image into a three-channel one.

    A tensor whose first dimension has size 1 is stacked three times along
    that dimension; any other tensor is handed back untouched.
    """

    def __call__(self, image):
        is_single_channel = image.shape[0] == 1
        if not is_single_channel:
            return image
        # Same data in every channel, concatenated along the channel axis.
        return torch.cat((image, image, image), dim=0)

# Custom Dataset with 3-channel replication
class CTDataset(Dataset):
    """Image-folder dataset: one sub-directory of ``data_dir`` per class.

    Class indices follow the alphabetical order of the class directory names.
    Each item is ``(image, label)`` where the image is opened in grayscale mode
    and passed through ``transform`` when one is given.

    Args:
        data_dir: Directory whose sub-directories are the classes.
        transform: Optional callable applied to the loaded PIL image.

    Raises:
        ValueError: If no image files are found under ``data_dir``.
    """

    # Only files with these extensions are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images = []
        self.labels = []
        # Only real, non-hidden sub-directories are classes. Stray files or
        # folders such as '.ipynb_checkpoints' would otherwise receive a class
        # index and shift the labels of every class sorted after them.
        self.class_names = sorted(
            name for name in os.listdir(data_dir)
            if not name.startswith('.') and os.path.isdir(os.path.join(data_dir, name))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.class_names)}

        for cls_name in self.class_names:
            cls_dir = os.path.join(data_dir, cls_name)
            # Sorted so the sample order does not depend on filesystem listing order.
            for img_name in sorted(os.listdir(cls_dir)):
                img_path = os.path.join(cls_dir, img_name)
                if os.path.isfile(img_path) and img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[cls_name])
        if not self.images:
            raise ValueError(f"No images found in {data_dir}. Check your dataset path and ensure it contains image subfolders.")

    def __len__(self):
        """Return the number of image files collected at construction time."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Best effort: if loading or transforming fails, the error is printed and
        an all-zero ``3x224x224`` tensor is returned with the sample's label so
        that a single bad file does not abort an epoch.
        """
        img_path = self.images[idx]
        label = self.labels[idx]
        try:
            image = Image.open(img_path).convert('L')  # Load as grayscale
            if self.transform:
                image = self.transform(image)  # Apply full transform pipeline
            return image, label
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            return torch.zeros(3, 224, 224), label  # Return 3-channel placeholder

# Load datasets (after segregation, use the train/val/test directories).
# segregate_dataset() writes images to <base_dir>/<split>/PNG/<class>/ and overlays to
# <base_dir>/<split>/OVERLAY/<class>/. The classifier therefore has to read the PNG
# sub-folder: pointing CTDataset at <base_dir>/<split> makes 'PNG' and 'OVERLAY' look
# like classes that contain no image files, and construction fails with ValueError.
train_dir = os.path.join(base_dir, 'train', 'PNG')
val_dir = os.path.join(base_dir, 'val', 'PNG')
test_dir = os.path.join(base_dir, 'test', 'PNG')

# Training pipeline: resize, light augmentation, tensor conversion, channel replication,
# then map pixel values from [0, 1] to [-1, 1].
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    GrayscaleToRGB(),  # Replicate to 3 channels after ToTensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 3-channel normalization
])

# Evaluation pipeline: same as training but without the random augmentations.
val_test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    GrayscaleToRGB(),  # Replicate to 3 channels after ToTensor
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

train_dataset = CTDataset(train_dir, transform=train_transform)
val_dataset = CTDataset(val_dir, transform=val_test_transform)
test_dataset = CTDataset(test_dir, transform=val_test_transform)

# Label indices are derived per dataset from its folder names, so all three splits
# must expose the same classes for the indices to mean the same thing.
assert train_dataset.class_names == val_dataset.class_names == test_dataset.class_names, \
    "train/val/test splits do not contain the same class folders"

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")
print(f"Classes: {train_dataset.class_names}")

# Set device
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Load pre-trained ConvNeXt-Base model through timm.
model = timm.create_model('convnext_base', pretrained=True)
# Replace the classification head with a 3-output one.
# The display names used later are hemorrhagic, ischemic, normal.
model.reset_classifier(num_classes=3)  # 3 classes: hemorrhagic, ischemic, normal
model = model.to(device)


# Loss function and optimizer: cross-entropy on the raw logits, Adam over all model parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Training function
def _plot_validation_metrics(val_accuracies, val_precisions, val_recalls, val_f1_scores):
    """Plot the four per-epoch validation curves on a 2x2 grid, then save and show the figure."""
    epochs = range(1, len(val_accuracies) + 1)
    fig, axs = plt.subplots(2, 2, figsize=(12, 8))
    fig.suptitle('Validation Metrics Over Epochs')

    panels = [
        (axs[0, 0], val_accuracies, 'Accuracy', 'b'),
        (axs[0, 1], val_precisions, 'Precision', 'r'),
        (axs[1, 0], val_recalls, 'Recall', 'g'),
        (axs[1, 1], val_f1_scores, 'F1 Score', 'm'),
    ]
    for ax, values, title, color in panels:
        ax.plot(epochs, values, label=title, color=color)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Score')
        ax.legend()

    plt.tight_layout()
    plt.savefig('validation_metrics_graphs.png')  # Save the figure
    plt.show()


def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10):
    """Train the classifier and track weighted validation metrics per epoch.

    After each epoch the model is evaluated on ``val_loader``; whenever the
    validation accuracy improves, the weights are saved to
    ``best_model_3class.pth``. When training ends, the four metric curves are
    plotted and saved to ``validation_metrics_graphs.png``.

    Batches are moved to the module-level ``device``.

    Args:
        model: Network to train (already placed on ``device``).
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Loader yielding ``(inputs, labels)`` validation batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        Tuple ``(val_accuracies, val_precisions, val_recalls, val_f1_scores)``,
        each a list with one value per epoch.
    """
    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint. With 0.0 and a strict '>', a run stuck at 0 accuracy would
    # leave no file (or a stale one from an earlier run) for the later load.
    best_acc = float('-inf')
    val_accuracies = []
    val_precisions = []
    val_recalls = []
    val_f1_scores = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        print(f"Starting epoch {epoch + 1}/{num_epochs}")
        with tqdm(total=len(train_loader), desc=f"Epoch {epoch + 1}/{num_epochs}") as pbar:
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                pbar.set_postfix({'loss': f'{loss.item():.4f}'})
                pbar.update(1)

        # Validation
        model.eval()
        val_loss = 0.0
        all_preds = []
        all_labels = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, preds = torch.max(outputs, 1)
                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        epoch_val_loss = val_loss / len(val_loader)
        epoch_acc = accuracy_score(all_labels, all_preds)
        # zero_division=0 yields the same 0.0 score scikit-learn would use for a
        # class that is never predicted, without emitting a warning every epoch.
        epoch_precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
        epoch_recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
        epoch_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
        val_accuracies.append(epoch_acc)
        val_precisions.append(epoch_precision)
        val_recalls.append(epoch_recall)
        val_f1_scores.append(epoch_f1)

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_acc:.4f}, Precision: {epoch_precision:.4f}, Recall: {epoch_recall:.4f}, F1: {epoch_f1:.4f}')

        if epoch_acc > best_acc:
            best_acc = epoch_acc
            torch.save(model.state_dict(), 'best_model_3class.pth')

    # Plot graphs for 4 parameters
    _plot_validation_metrics(val_accuracies, val_precisions, val_recalls, val_f1_scores)
    return val_accuracies, val_precisions, val_recalls, val_f1_scores

# Evaluate function
def evaluate_model(model, test_loader, class_names=None):
    """Evaluate the classifier on ``test_loader`` and plot its confusion matrix.

    Prints accuracy and weighted precision/recall/F1, then draws the confusion
    matrix, saves it to ``confusion_matrix.png`` and shows it.

    Batches are moved to the module-level ``device``.

    Args:
        model: Trained network (already placed on ``device``).
        test_loader: Loader yielding ``(inputs, labels)`` batches.
        class_names: Display names for the class indices, in index order.
            Defaults to ``['Hemorrhagic', 'Ischemic', 'Normal']``.

    Returns:
        None.
    """
    if class_names is None:
        class_names = ['Hemorrhagic', 'Ischemic', 'Normal']

    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    # Pin the label set so the matrix always has one row/column per class name,
    # even when a class appears in neither the labels nor the predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))

    print(f'Accuracy: {accuracy:.4f}')
    print(f'Precision: {precision:.4f}')
    print(f'Recall: {recall:.4f}')
    print(f'F1 Score: {f1:.4f}')

    # Plot confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.savefig('confusion_matrix.png')  # Save the confusion matrix
    plt.show()

# Train the model and get the per-epoch validation metrics (one list per metric).
val_accuracies, val_precisions, val_recalls, val_f1_scores = train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=10)

# Load the checkpoint saved by train_model (best validation accuracy) and evaluate it on the test split.
model.load_state_dict(torch.load('best_model_3class.pth'))
evaluate_model(model, test_loader)

# **Segmentation Model training**
---


Preprocessing and augmentation of the dataset

---

U-Net with an EfficientNet-B4 encoder

In [None]:
# Step 1: Install required packages
!pip install segmentation-models-pytorch -q

# Step 3: Preprocessing and Training
import os
import cv2
from glob import glob
import numpy as np
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torch
import torch.optim as optim  # Added import
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp
from tqdm import tqdm

# Configuration
class Config:
    """Settings for the segmentation experiment, grouped by purpose."""

    # --- data ---
    # NOTE(review): none of the cells shown here create this folder; confirm
    # where the cropped stroke images come from before running.
    DATASET_PATH = '/content/brain-stroke-images'  # Root path
    IMG_WIDTH = 256
    IMG_HEIGHT = 256

    # --- loading ---
    BATCH_SIZE = 8
    NUM_WORKERS = 2

    # --- optimisation ---
    LEARNING_RATE = 1e-4
    EPOCHS = 50

    # --- hardware: use the GPU when one is visible, otherwise the CPU ---
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

config = Config()

# Preprocessing: Generate pseudo-masks for stroke images
def generate_pseudo_mask(image, threshold=0.1):  # Simple thresholding
    """Build a binary float mask by thresholding the grayscale version of ``image``.

    Args:
        image: RGB image array as accepted by ``cv2.cvtColor`` with
            ``COLOR_RGB2GRAY``.
        threshold: Cut-off expressed as a fraction of 255; pixels brighter
            than ``threshold * 255`` become 1.0, the rest 0.0.

    Returns:
        ``float32`` array holding 0.0 / 1.0 values.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    cutoff = threshold * 255
    binary = cv2.threshold(grayscale, cutoff, 255, cv2.THRESH_BINARY)[1]
    # threshold() emits 0 / 255; rescale to 0 / 1 before casting.
    return (binary / 255.0).astype(np.float32)

# Load data from cropped folders
def load_segmentation_data(base_dir):
    """Collect ``.jpg`` image paths and stroke/normal labels from the cropped folders.

    Looks in ``<base_dir>/stroke_cropped/CROPPED/<split>/<category>`` for the
    splits ``TRAIN_CROP``, ``VAL_CROP``, ``TEST_CROP`` and the categories
    ``NORMAL`` and ``STROKE``. Folders that do not exist are skipped.

    Args:
        base_dir: Root directory of the cropped dataset.

    Returns:
        Tuple ``(image_paths, mask_paths, labels)`` of equal-length lists.
        ``mask_paths`` holds ``None`` for every image (masks are generated
        on the fly elsewhere); ``labels`` holds 1 for STROKE and 0 for NORMAL.
    """
    image_paths = []
    mask_paths = []  # Pseudo-masks generated on-the-fly
    labels = []  # 1 for STROKE, 0 for NORMAL

    for split in ['TRAIN_CROP', 'VAL_CROP', 'TEST_CROP']:
        for category in ['NORMAL', 'STROKE']:
            category_dir = os.path.join(base_dir, 'stroke_cropped/CROPPED', split, category)
            if not os.path.exists(category_dir):
                continue
            label = 1 if category == 'STROKE' else 0
            # glob returns files in arbitrary order; sorting keeps the seeded
            # train/val/test split that consumes this list reproducible.
            for img in sorted(glob(os.path.join(category_dir, '*.jpg'))):
                image_paths.append(img)
                mask_paths.append(None)  # Will generate pseudo-mask
                labels.append(label)

    # Built-in sum keeps the counts integral even for an empty dataset
    # (np.sum([]) is the float 0.0).
    stroke_count = sum(labels)
    print(f"Loaded {len(image_paths)} images: {stroke_count} stroke, {len(labels) - stroke_count} normal.")
    return image_paths, mask_paths, labels

# Gather every image under the dataset root; the TRAIN/VAL/TEST folder an image
# came from is not recorded, so the split below is drawn afresh from the pooled list.
image_paths, mask_paths, labels = load_segmentation_data(config.DATASET_PATH)
if len(image_paths) == 0:
    raise ValueError("No images found! Check dataset path.")

# Split into train/val/test: 80% train, then the remaining 20% is halved into
# validation and test. Both splits are stratified on the stroke/normal label and seeded.
train_images, test_images, train_labels, test_labels = train_test_split(image_paths, labels, test_size=0.2, random_state=42, stratify=labels)
val_images, test_images, val_labels, test_labels = train_test_split(test_images, test_labels, test_size=0.5, random_state=42, stratify=test_labels)

print(f"Training samples: {len(train_images)}")
print(f"Validation samples: {len(val_images)}")
print(f"Test samples: {len(test_images)}")

# Dataset class
class StrokeDataset(Dataset):
    """Segmentation dataset that pairs each image with a generated mask.

    Stroke images (label 1) get a pseudo-mask from ``generate_pseudo_mask``;
    normal images (label 0) get an all-zero mask. Images are resized to
    ``config.IMG_WIDTH`` x ``config.IMG_HEIGHT`` before the mask is built.

    Args:
        image_paths: List of image file paths.
        labels: List of 0/1 labels aligned with ``image_paths``.
        transforms: Optional callable invoked as
            ``transforms(image=img, mask=msk)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, image_paths, labels, transforms=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        Raises:
            FileNotFoundError: If the image file cannot be read.
        """
        path = self.image_paths[idx]
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None; fail with the offending
        # path instead of an opaque error from the colour conversion below.
        if img is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (config.IMG_WIDTH, config.IMG_HEIGHT))

        label = self.labels[idx]
        if label == 1:  # Stroke: generate pseudo-mask
            msk = generate_pseudo_mask(img, threshold=0.1)
        else:  # Normal: zero mask
            msk = np.zeros((config.IMG_HEIGHT, config.IMG_WIDTH), dtype=np.float32)

        if self.transforms:
            transformed = self.transforms(image=img, mask=msk)
            img, msk = transformed['image'], transformed['mask']

        return img, msk

# Transforms
def get_transforms(is_training=True):
    """Build the albumentations pipeline for training or evaluation.

    Both pipelines end with ImageNet mean/std normalisation followed by
    ``ToTensorV2``. The training pipeline additionally starts with a random
    horizontal flip, a small brightness/contrast jitter and a small affine
    perturbation.

    Args:
        is_training: When true, include the random augmentations.

    Returns:
        An ``A.Compose`` pipeline.
    """
    steps = []
    if is_training:
        steps.extend([
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(brightness_limit=0.05, contrast_limit=0.05, p=0.3),
            A.Affine(translate_percent={"x": (-0.05, 0.05), "y": (-0.05, 0.05)}, scale=(0.95, 1.05), rotate=(-5, 5), p=0.3),
        ])
    # Shared tail of both pipelines.
    steps.append(A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
    steps.append(ToTensorV2())
    return A.Compose(steps)

# NOTE: train_dataset, val_dataset, test_dataset, the three loaders and `model` reuse the
# names defined by the classification cell; after this cell they refer to the segmentation objects.
# Only the training split gets the random augmentations.
train_dataset = StrokeDataset(train_images, train_labels, get_transforms(True))
val_dataset = StrokeDataset(val_images, val_labels, get_transforms(False))
test_dataset = StrokeDataset(test_images, test_labels, get_transforms(False))

train_loader = DataLoader(train_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=config.NUM_WORKERS, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=config.NUM_WORKERS, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=config.NUM_WORKERS, pin_memory=True)

# Segmentation Model: U-Net with an ImageNet-initialised EfficientNet-B4 encoder.
# activation=None together with classes=1: the loss and metrics below apply the sigmoid themselves.
model = smp.Unet(
    encoder_name='efficientnet-b4',
    encoder_weights='imagenet',
    in_channels=3,
    classes=1,
    activation=None,
).to(config.DEVICE)

# Loss and Optimizer
class DiceLoss(nn.Module):
    """Soft Dice loss computed over all elements of the batch at once.

    Args:
        smooth: Constant added to numerator and denominator so the ratio
            stays defined when both inputs are all zeros.
    """

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, y_pred_prob, y_true):
        """Return ``1 - dice`` for predicted probabilities and targets of equal size."""
        flat_pred = y_pred_prob.view(-1)
        flat_true = y_true.view(-1)
        overlap = (flat_pred * flat_true).sum()
        numerator = 2. * overlap + self.smooth
        denominator = flat_pred.sum() + flat_true.sum() + self.smooth
        return 1 - numerator / denominator

class CombinedLoss(nn.Module):
    """Equal-weight sum of binary cross-entropy (on logits) and Dice loss (on probabilities)."""

    def __init__(self):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()

    def forward(self, y_pred_logits, y_true):
        """Return ``0.5 * BCE + 0.5 * Dice`` for raw logits and same-shaped targets."""
        bce_term = self.bce(y_pred_logits, y_true)
        # The Dice term works on probabilities, so squash the logits first.
        dice_term = self.dice(torch.sigmoid(y_pred_logits), y_true)
        return 0.5 * bce_term + 0.5 * dice_term

criterion = CombinedLoss()
# Adam over every parameter of the segmentation model.
optimizer = optim.Adam(model.parameters(), lr=config.LEARNING_RATE)  # Now should work

# Training
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Segmentation network producing logits.
        dataloader: Loader yielding ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(logits, masks)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, mean_dice, mean_iou)`` averaged over batches.
    """
    model.train()
    loss_total, dice_total, iou_total = 0.0, 0.0, 0.0

    progress_bar = tqdm(dataloader, desc='Training')
    for images, masks in progress_bar:
        images = images.to(device)
        # Insert an axis at dim 1 so the masks match the logits' layout
        # (assumes batched masks arrive without a channel axis — TODO confirm).
        masks = masks.unsqueeze(1).to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, masks)
        loss.backward()
        optimizer.step()

        # Metrics are computed on thresholded probabilities.
        probs = torch.sigmoid(logits)
        dice_score = calculate_dice_score(probs, masks)
        iou_score = calculate_iou_score(probs, masks)

        loss_total += loss.item()
        dice_total += dice_score
        iou_total += iou_score

        progress_bar.set_postfix({'Loss': f'{loss.item():.4f}', 'Dice': f'{dice_score:.4f}', 'IoU': f'{iou_score:.4f}'})

    num_batches = len(dataloader)
    return loss_total / num_batches, dice_total / num_batches, iou_total / num_batches

def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Args:
        model: Segmentation network producing logits.
        dataloader: Loader yielding ``(images, masks)`` batches.
        criterion: Loss called as ``criterion(logits, masks)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(mean_loss, mean_dice, mean_iou)`` averaged over batches.
    """
    model.eval()
    loss_total, dice_total, iou_total = 0.0, 0.0, 0.0

    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            masks = masks.unsqueeze(1).to(device)

            logits = model(images)
            batch_loss = criterion(logits, masks)
            probs = torch.sigmoid(logits)

            loss_total += batch_loss.item()
            dice_total += calculate_dice_score(probs, masks)
            iou_total += calculate_iou_score(probs, masks)

    num_batches = len(dataloader)
    return loss_total / num_batches, dice_total / num_batches, iou_total / num_batches

def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=50):
    """Train the segmentation model and return it with the best weights loaded.

    NOTE: this definition reuses the name of the classification cell's
    ``train_model``; once this cell has run, the name refers to this function.

    Each epoch runs ``train_epoch`` then ``validate_epoch`` on
    ``config.DEVICE``. Whenever the validation Dice improves, the weights are
    written to ``best_stroke_segmentation_model.pth``; that file is loaded
    back into ``model`` before returning.

    Args:
        model: Segmentation network to train.
        train_loader: Training batches.
        val_loader: Validation batches.
        criterion: Loss function.
        optimizer: Optimizer for ``model``.
        epochs: Number of epochs to run.

    Returns:
        ``model`` with the best checkpoint's weights.
    """
    checkpoint_path = 'best_stroke_segmentation_model.pth'
    best_dice = 0.0

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        train_loss, train_dice, train_iou = train_epoch(model, train_loader, criterion, optimizer, config.DEVICE)
        val_loss, val_dice, val_iou = validate_epoch(model, val_loader, criterion, config.DEVICE)

        print(f"Train Loss: {train_loss:.4f}, Train Dice: {train_dice:.4f}, Train IoU: {train_iou:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Dice: {val_dice:.4f}, Val IoU: {val_iou:.4f}")

        if not val_dice > best_dice:
            continue
        # New best validation Dice: remember it and checkpoint the weights.
        best_dice = val_dice
        torch.save(model.state_dict(), checkpoint_path)
        print(f"New best model saved! Dice: {best_dice:.4f}")

    model.load_state_dict(torch.load(checkpoint_path))
    return model

# Metrics
def calculate_dice_score(y_pred_prob, y_true, smooth=1e-6):
    """Dice coefficient between thresholded predictions and targets.

    Args:
        y_pred_prob: Tensor of predicted probabilities; values above 0.5
            count as foreground.
        y_true: Tensor of target values with the same shape.
        smooth: Constant added to numerator and denominator.

    Returns:
        The score as a Python float.
    """
    binary_pred = (y_pred_prob > 0.5).float()
    overlap = (binary_pred * y_true).sum()
    total = binary_pred.sum() + y_true.sum()
    return ((2. * overlap + smooth) / (total + smooth)).item()

def calculate_iou_score(y_pred_prob, y_true, smooth=1e-6):
    """Intersection-over-union between thresholded predictions and targets.

    Args:
        y_pred_prob: Tensor of predicted probabilities; values above 0.5
            count as foreground.
        y_true: Tensor of target values with the same shape.
        smooth: Constant added to numerator and denominator.

    Returns:
        The score as a Python float.
    """
    binary_pred = (y_pred_prob > 0.5).float()
    overlap = (binary_pred * y_true).sum()
    # |A ∪ B| = |A| + |B| - |A ∩ B|
    combined = binary_pred.sum() + y_true.sum() - overlap
    return ((overlap + smooth) / (combined + smooth)).item()

# Train the model with the default epoch count; the returned model carries the best checkpoint's weights.
model = train_model(model, train_loader, val_loader, criterion, optimizer)

# Step 3: Show 2 Images After Segmentation
def predict_and_show(model, image_paths, num_samples=2):
    """Predict and display binary segmentation masks for the first images.

    Each image is read, resized to ``config.IMG_WIDTH`` x ``config.IMG_HEIGHT``,
    normalised with the evaluation transforms and passed through ``model``.
    The sigmoid of the output is thresholded at 0.5 and shown as a
    black-and-white figure.

    Args:
        model: Trained segmentation network.
        image_paths: List of image file paths.
        num_samples: Maximum number of images to show; fewer are shown when
            ``image_paths`` is shorter.

    Returns:
        None.
    """
    # Local import keeps the function usable when this cell is run on its own;
    # the cell's import block does not bring in pyplot.
    import matplotlib.pyplot as plt

    model.eval()
    transforms = get_transforms(False)
    # Never index past the end of the list.
    for i in range(min(num_samples, len(image_paths))):
        image = cv2.imread(image_paths[i])
        if image is None:  # cv2.imread returns None for unreadable files
            print(f"Could not read image {image_paths[i]}, skipping.")
            continue
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_resized = cv2.resize(image_rgb, (config.IMG_WIDTH, config.IMG_HEIGHT))

        # The pipeline is called with an image and a mask elsewhere; pass a dummy mask here.
        transformed = transforms(image=image_resized, mask=np.zeros((config.IMG_HEIGHT, config.IMG_WIDTH), np.float32))
        image_tensor = transformed['image'].unsqueeze(0).to(config.DEVICE)

        with torch.no_grad():
            pred_logits = model(image_tensor)[0, 0]
            # torch.sigmoid avoids the overflow warnings that 1 / (1 + exp(-x))
            # produces for large-magnitude logits.
            pred_prob = torch.sigmoid(pred_logits).cpu().numpy()
        pred_mask_binary = (pred_prob > 0.5).astype(np.uint8) * 255

        plt.figure(figsize=(8, 6))
        plt.imshow(pred_mask_binary, cmap='gray')
        plt.title(f"Segmented Image {i+1}")
        plt.axis('off')
        plt.show()

# Use 2 images from test set (the first two paths of the held-out split).
predict_and_show(model, test_images, num_samples=2)

## **Validation code for the classification model**

In [None]:
# Install dependencies
!pip install timm scikit-learn matplotlib seaborn -q

import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
import timm
from datetime import datetime
import json
import warnings
warnings.filterwarnings('ignore')

# ====================================================
# CLASSIFICATION CONFIGURATION
# ====================================================

class ClassificationConfig:
    """Settings used by the classification validation cell."""

    # --- data ---
    BASE_DIR = '/content/dataset/Brain_Stroke_CT_Dataset'
    CLASSES = ['Bleeding', 'Ischemia', 'Normal']

    # --- model checkpoint written by the classification training cell ---
    MODEL_PATH = 'best_model_3class.pth'

    # --- loading ---
    BATCH_SIZE = 32
    NUM_WORKERS = 2

    # --- hardware: use the GPU when one is visible, otherwise the CPU ---
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# NOTE: rebinds the global name `config`, which the segmentation cell also
# uses for its own Config instance.
config = ClassificationConfig()
print(f"🔧 Classification validation - Using device: {config.DEVICE}")

# ====================================================
# CLASSIFICATION DATASET
# ====================================================

class CTDataset(Dataset):
    """Image-folder dataset: one sub-directory of ``data_dir`` per class.

    Class indices follow the alphabetical order of the class directory names.
    Each item is ``(image, label)`` where the image is opened in grayscale mode
    and passed through ``transform`` when one is given.

    Args:
        data_dir: Directory whose sub-directories are the classes.
        transform: Optional callable applied to the loaded PIL image.

    Raises:
        ValueError: If no image files are found under ``data_dir``.
    """

    # Only files with these extensions are treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images = []
        self.labels = []
        # Only real, non-hidden sub-directories are classes. Stray files or
        # folders such as '.ipynb_checkpoints' would otherwise receive a class
        # index and no longer match the indices the model was trained with.
        self.class_names = sorted(
            name for name in os.listdir(data_dir)
            if not name.startswith('.') and os.path.isdir(os.path.join(data_dir, name))
        )
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.class_names)}

        for cls_name in self.class_names:
            cls_dir = os.path.join(data_dir, cls_name)
            # Sorted so the sample order does not depend on filesystem listing order.
            for img_name in sorted(os.listdir(cls_dir)):
                img_path = os.path.join(cls_dir, img_name)
                if os.path.isfile(img_path) and img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.images.append(img_path)
                    self.labels.append(self.class_to_idx[cls_name])

        if not self.images:
            raise ValueError(f"No images found in {data_dir}. Check your dataset path.")

        print(f"📊 Dataset loaded: {len(self.images)} images across {len(self.class_names)} classes")
        print(f"🏷️  Classes: {self.class_names}")

    def __len__(self):
        """Return the number of image files collected at construction time."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        Best effort: if loading or transforming fails, the error is printed and
        an all-zero ``3x224x224`` tensor is returned with the sample's label.
        """
        img_path = self.images[idx]
        label = self.labels[idx]
        try:
            image = Image.open(img_path).convert('L')  # Load as grayscale
            if self.transform:
                image = self.transform(image)
            return image, label
        except Exception as e:
            print(f"⚠️  Error loading image {img_path}: {e}")
            # Return placeholder
            return torch.zeros(3, 224, 224), label

# Custom transform to replicate grayscale to 3 channels
class GrayscaleToRGB:
    """Tensor transform that turns a one-channel image into a three-channel one.

    A tensor whose first dimension has size 1 is stacked three times along
    that dimension; any other tensor is handed back untouched.
    """

    def __call__(self, image):
        is_single_channel = image.shape[0] == 1
        if not is_single_channel:
            return image
        # Same data in every channel, concatenated along the channel axis.
        return torch.cat((image, image, image), dim=0)

# Classification transforms: resize, tensor conversion, channel replication and
# normalisation with mean 0.5 / std 0.5 per channel — no random augmentation.
classification_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    GrayscaleToRGB(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Load test dataset: the PNG sub-folder of the test split holds one folder per class.
test_dir = os.path.join(config.BASE_DIR, 'test/PNG')
print(f"🔍 Loading classification test data from: {test_dir}")

classification_dataset = CTDataset(test_dir, classification_transform)
# shuffle=False keeps the loader's batch order aligned with classification_dataset.images.
classification_loader = DataLoader(
    classification_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    num_workers=config.NUM_WORKERS
)

print(f"✅ Classification test samples: {len(classification_dataset)}")

# ====================================================
# CLASSIFICATION MODEL VALIDATION (FIXED)
# ====================================================

class ClassificationValidator:
    """Evaluate a saved ConvNeXt-Base classifier on the classification test set.

    The constructor rebuilds the network with ``timm`` and loads the saved
    weights. :meth:`validate` then scores the model on the module-level
    ``classification_loader`` / ``classification_dataset``, prints a summary,
    saves three plots and a JSON report under ``classification_validation/``.
    """

    def __init__(self, model_path, class_names):
        """Initialize classification validator.

        Args:
            model_path: Path to the saved ``state_dict`` file.
            class_names: Class names ordered by label index; the number of
                names sets the size of the classification head.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        self.device = config.DEVICE
        self.class_names = class_names
        self.class_to_idx = {name: idx for idx, name in enumerate(class_names)}

        # Load model
        self.model = timm.create_model('convnext_base', pretrained=False, num_classes=len(class_names))
        if os.path.exists(model_path):
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            print(f"✅ Classification model loaded from {model_path}")
        else:
            raise FileNotFoundError(f"❌ Classification model not found: {model_path}")

        self.model = self.model.to(self.device)
        self.model.eval()
        print(f"🔧 Classification validator initialized on {self.device}")

    def validate(self):
        """Run comprehensive classification validation.

        Runs the model over ``classification_loader`` without gradients,
        computes overall and per-class metrics, prints them, writes the
        confusion-matrix / class-accuracy / metrics-overview plots and a
        timestamped JSON report into ``classification_validation/``.

        Returns:
            dict with keys ``accuracy``, ``precision``, ``recall``, ``f1``
            (weighted averages), ``confusion_matrix``, ``class_accuracy``,
            ``class_support``, ``predictions``, ``labels``, ``probabilities``,
            ``total_samples`` and ``report`` (the JSON-serialisable summary).
        """
        print("\n" + "="*80)
        print("🔍 CLASSIFICATION MODEL VALIDATION")
        print("="*80)

        all_preds = []
        all_labels = []
        all_probs = []

        print(f"📊 Processing {len(classification_dataset)} test samples...")

        with torch.no_grad():
            progress_bar = tqdm(classification_loader, desc="Validating", unit="batch")
            for images, labels in progress_bar:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = self.model(images)
                probs = torch.softmax(outputs, dim=1)
                _, preds = torch.max(outputs, 1)

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

                # Progress update: accuracy over the most recent 100 samples,
                # refreshed whenever the running count hits a multiple of 100
                if len(all_preds) % 100 == 0:
                    batch_labels = np.array(all_labels[-100:])
                    batch_preds = np.array(all_preds[-100:])
                    batch_acc = accuracy_score(batch_labels, batch_preds)
                    progress_bar.set_postfix({'Acc': f"{batch_acc:.3f}"})

        # Convert to numpy arrays
        all_labels_np = np.array(all_labels)
        all_preds_np = np.array(all_preds)

        # Pin the label set to every known class. Without this, a class that
        # appears in neither the test labels nor the predictions is dropped:
        # the confusion matrix no longer lines up with the tick labels and
        # classification_report rejects the target_names length.
        label_ids = list(range(len(self.class_names)))

        # Calculate metrics
        accuracy = accuracy_score(all_labels_np, all_preds_np)
        precision = precision_score(all_labels_np, all_preds_np, average='weighted', zero_division=0)
        recall = recall_score(all_labels_np, all_preds_np, average='weighted', zero_division=0)
        f1 = f1_score(all_labels_np, all_preds_np, average='weighted', zero_division=0)
        cm = confusion_matrix(all_labels_np, all_preds_np, labels=label_ids)

        # Per-class metrics: "accuracy" here is the share of a class's own
        # samples that were predicted correctly; classes with no samples get 0.0
        class_accuracy = {}
        class_support = {}
        for i, class_name in enumerate(self.class_names):
            class_mask = all_labels_np == i
            class_support[class_name] = np.sum(class_mask)
            if class_support[class_name] > 0:
                class_acc = accuracy_score(all_labels_np[class_mask], all_preds_np[class_mask])
                class_accuracy[class_name] = class_acc
            else:
                class_accuracy[class_name] = 0.0

        # Print results
        print(f"\n📊 OVERALL PERFORMANCE:")
        print(f"{'Metric':<20} {'Value':<10}")
        print("-" * 30)
        print(f"{'Overall Accuracy':<20} {accuracy:<10.4f}")
        print(f"{'Weighted Precision':<20} {precision:<10.4f}")
        print(f"{'Weighted Recall':<20} {recall:<10.4f}")
        print(f"{'Weighted F1-Score':<20} {f1:<10.4f}")
        print(f"{'Total Samples':<20} {len(all_labels_np):<10}")

        print(f"\n🏆 CLASS-WISE PERFORMANCE:")
        print(f"{'Class':<12} {'Accuracy':<10} {'Support':<8}")
        print("-" * 30)
        for class_name in self.class_names:
            print(f"{class_name:<12} {class_accuracy[class_name]:<10.4f} {class_support[class_name]:<8}")

        # Text report from scikit-learn, with one row per known class
        print(f"\n📋 DETAILED CLASSIFICATION REPORT:")
        detailed_report = classification_report(
            all_labels_np, all_preds_np,
            labels=label_ids,
            target_names=self.class_names,
            zero_division=0,
        )
        print(detailed_report)

        # Create output directory
        os.makedirs('classification_validation', exist_ok=True)

        # 1. Confusion Matrix
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                   xticklabels=self.class_names,
                   yticklabels=self.class_names,
                   annot_kws={'size': 12})
        plt.title('Classification Confusion Matrix', fontsize=16, fontweight='bold')
        plt.ylabel('True Label', fontsize=12)
        plt.xlabel('Predicted Label', fontsize=12)
        plt.tight_layout()
        plt.savefig('classification_validation/confusion_matrix.png', dpi=300, bbox_inches='tight')
        plt.show()

        # 2. Class-wise Accuracy Bar Plot
        plt.figure(figsize=(12, 6))
        class_acc_values = [class_accuracy[name] for name in self.class_names]
        bars = plt.bar(self.class_names, class_acc_values,
                      color=['#ff6b6b', '#4ecdc4', '#45b7d1'], alpha=0.8, edgecolor='black')
        plt.title('Class-wise Classification Accuracy', fontsize=16, fontweight='bold')
        plt.ylabel('Accuracy', fontsize=12)
        plt.ylim(0, 1)
        plt.grid(axis='y', alpha=0.3)

        # Add value labels on bars
        for bar, acc in zip(bars, class_acc_values):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                    f'{acc:.3f}', ha='center', va='bottom', fontweight='bold', fontsize=11)

        plt.tight_layout()
        plt.savefig('classification_validation/class_accuracy.png', dpi=300, bbox_inches='tight')
        plt.show()

        # 3. Precision-Recall-F1 Plot
        metrics_data = {
            'Precision': precision,
            'Recall': recall,
            'F1-Score': f1
        }
        plt.figure(figsize=(10, 6))
        bars = plt.bar(list(metrics_data.keys()), list(metrics_data.values()),
                      color=['#ff9f43', '#48cae4', '#06d6a0'], alpha=0.8, edgecolor='black')
        plt.title('Classification Metrics Overview', fontsize=16, fontweight='bold')
        plt.ylabel('Score', fontsize=12)
        plt.ylim(0, 1)
        plt.grid(axis='y', alpha=0.3)

        # Add value labels
        for bar, val in zip(bars, list(metrics_data.values())):
            plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                    f'{val:.3f}', ha='center', va='bottom', fontweight='bold', fontsize=11)

        plt.tight_layout()
        plt.savefig('classification_validation/metrics_overview.png', dpi=300, bbox_inches='tight')
        plt.show()

        # ====================================================
        # JSON REPORT
        # ====================================================
        # NumPy scalars are cast to plain int/float so json.dump accepts them
        classification_results_report = {
            'timestamp': datetime.now().isoformat(),
            'model_info': {
                'architecture': 'ConvNeXt-Base',
                'num_classes': len(self.class_names),
                'classes': list(self.class_names)
            },
            'dataset_info': {
                'test_samples': len(all_labels_np),
                'class_distribution': {name: int(support) for name, support in class_support.items()}
            },
            'metrics': {
                'accuracy': float(accuracy),
                'precision_weighted': float(precision),
                'recall_weighted': float(recall),
                'f1_weighted': float(f1),
                'confusion_matrix': cm.tolist()
            },
            'class_metrics': {
                name: {
                    'accuracy': float(class_accuracy[name]),
                    'support': int(class_support[name])
                } for name in self.class_names
            },
            'detailed_report': detailed_report  # Store the string report
        }

        report_filename = f'classification_validation/classification_report_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json'
        with open(report_filename, 'w') as f:
            json.dump(classification_results_report, f, indent=2)

        print(f"\n📄 Classification report saved: {report_filename}")

        # ====================================================
        # FINAL SUMMARY
        # ====================================================
        print("\n" + "="*80)
        print("🎯 CLASSIFICATION VALIDATION SUMMARY")
        print("="*80)
        print(f"🏆 Overall Performance:")
        print(f"   📊 Accuracy:     {accuracy:.4f} ({accuracy*100:.1f}%)")
        print(f"   📈 F1-Score:     {f1:.4f}")
        print(f"   ⚖️  Precision:   {precision:.4f}")
        print(f"   🔄 Recall:       {recall:.4f}")
        print(f"   📦 Total Samples: {len(all_labels_np)}")

        print(f"\n🏷️  Class Performance:")
        for class_name in self.class_names:
            print(f"   {class_name:<12}: {class_accuracy[class_name]:.4f} ({class_support[class_name]} samples)")

        print(f"\n📁 Generated Files:")
        print(f"   📊 {report_filename}")
        print(f"   📈 classification_validation/confusion_matrix.png")
        print(f"   📈 classification_validation/class_accuracy.png")
        print(f"   📈 classification_validation/metrics_overview.png")

        # Performance assessment (fixed accuracy cut-offs at 0.90 and 0.80)
        if accuracy > 0.90:
            print(f"\n🎉 EXCELLENT PERFORMANCE! Model is ready for deployment.")
        elif accuracy > 0.80:
            print(f"\n👍 GOOD PERFORMANCE! Model is suitable for research/clinical validation.")
        else:
            print(f"\n⚠️  CONSIDER IMPROVEMENT: Accuracy {accuracy:.1%} may need additional training.")

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'confusion_matrix': cm,
            'class_accuracy': class_accuracy,
            'class_support': class_support,
            'predictions': all_preds_np,
            'labels': all_labels_np,
            'probabilities': np.array(all_probs),
            'total_samples': len(all_labels_np),
            'report': classification_results_report
        }

# ====================================================
# MAIN EXECUTION - CLASSIFICATION
# ====================================================

def run_classification_validation():
    """Build a ``ClassificationValidator`` and run it end to end.

    Returns:
        The results dict from ``ClassificationValidator.validate``, or
        ``None`` when the weights file at ``config.MODEL_PATH`` is missing or
        the validator hands back a falsy result.
    """
    print("\n" + "="*80)
    print("🚀 CLASSIFICATION MODEL VALIDATION PIPELINE")
    print("="*80)

    # Nothing to evaluate without trained weights on disk
    weights_present = os.path.exists(config.MODEL_PATH)
    if not weights_present:
        print(f"❌ Model file not found: {config.MODEL_PATH}")
        print("Please train the classification model first.")
        return None

    results = ClassificationValidator(config.MODEL_PATH, config.CLASSES).validate()

    if not results:
        print(f"\n❌ Classification validation failed.")
        return None

    print(f"\n🎉 CLASSIFICATION VALIDATION COMPLETED SUCCESSFULLY!")
    print(f"✅ Model performance: {results['accuracy']*100:.1f}% accuracy")
    return results

# Run classification validation
# The guard limits execution to when this code runs as the main module, which
# is also the case inside a notebook cell. The results dict (or None) is kept
# in the module-level name `classification_results`.
if __name__ == "__main__":
    classification_results = run_classification_validation()

## **Validation code for the segmentation model**

In [None]:
# Install dependencies
!pip install segmentation-models-pytorch albumentations scikit-image matplotlib seaborn -q

import os
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from skimage import measure
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import json
from datetime import datetime
import warnings
from glob import glob
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp

warnings.filterwarnings('ignore')

# ====================================================
# SEGMENTATION CONFIGURATION
# ====================================================

class SegmentationConfig:
    """Static settings read by the segmentation validation code."""

    # Locations of the dataset root and of the saved model weights
    BASE_DIR = '/content/dataset/Brain_Stroke_CT_Dataset'
    MODEL_PATH = 'best_stroke_segmentation_model.pth'

    # Runtime: use the GPU when one is visible to torch, otherwise the CPU
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    NUM_WORKERS = 2
    BATCH_SIZE = 4  # kept small for detailed per-sample analysis

    # Size every image and mask is resized to
    IMG_HEIGHT = 256
    IMG_WIDTH = 256

    # Per-pixel probability cut-off used to binarize a mask
    SEGMENTATION_THRESHOLD = 0.5
    # Share of stroke pixels (1%) at which a whole image counts as stroke
    STROKE_DETECTION_THRESHOLD = 0.01

# NOTE(review): the classification code earlier in the notebook also reads a
# global named `config` (e.g. config.CLASSES, config.MODEL_PATH). This
# assignment rebinds that name, so calling the classification functions again
# after this cell will see the segmentation settings — confirm this is intended.
config = SegmentationConfig()
print(f"🔧 Segmentation validation - Using device: {config.DEVICE}")
print(f"🎯 Default stroke detection threshold: {config.STROKE_DETECTION_THRESHOLD:.0%} (OPTIMIZED!)")

# Function to extract red mask from OVERLAY
def extract_red_mask_from_path(mask_path, width=256, height=256):
    """Build a binary float32 mask of the red pixels in an overlay image.

    Args:
        mask_path: Path to the overlay image, or ``None``.
        width: Width of the returned mask in pixels.
        height: Height of the returned mask in pixels.

    Returns:
        A ``(height, width)`` float32 array holding 1.0 where the overlay is
        red and 0.0 elsewhere. The array is all zeros when ``mask_path`` is
        ``None``, does not exist, or cannot be read by OpenCV.
    """
    overlay_bgr = None
    if mask_path is not None and os.path.exists(mask_path):
        overlay_bgr = cv2.imread(mask_path)
    if overlay_bgr is None:
        return np.zeros((height, width), dtype=np.float32)

    hsv_img = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2HSV)
    # Red sits at both ends of the hue axis, so it is matched as two bands
    red_bands = (
        (np.array([0, 50, 50]), np.array([10, 255, 255])),
        (np.array([170, 50, 50]), np.array([180, 255, 255])),
    )
    low_band, high_band = (cv2.inRange(hsv_img, lo, hi) for lo, hi in red_bands)
    red_mask = ((low_band | high_band) > 0).astype(np.float32)
    # Nearest-neighbour keeps the resized mask strictly 0/1
    return cv2.resize(red_mask, (width, height), interpolation=cv2.INTER_NEAREST)

# Load segmentation data
def load_segmentation_data(base_dir, split, classes=('Bleeding', 'Ischemia', 'Normal')):
    """Collect image paths and their overlay paths for one dataset split.

    For every class, ``<base_dir>/<split>/PNG/<class>/*.png`` is listed and
    each image is paired with the file of the same name under
    ``<base_dir>/<split>/OVERLAY/<class>/``.

    Args:
        base_dir: Dataset root directory.
        split: Split folder name, e.g. ``'test'``.
        classes: Class sub-folder names to scan, in output order.

    Returns:
        ``(image_paths, mask_paths)``: two lists of equal length. An entry in
        ``mask_paths`` is ``None`` when the image has no overlay file. Classes
        whose PNG folder is missing are reported and skipped.
    """
    image_paths = []
    mask_paths = []

    for cls in classes:
        png_dir = os.path.join(base_dir, split, 'PNG', cls)
        overlay_dir = os.path.join(base_dir, split, 'OVERLAY', cls)

        if not os.path.exists(png_dir):
            print(f"PNG folder not found for {cls} in {split}. Skipping...")
            continue

        # glob() yields files in arbitrary, filesystem-dependent order; sort so
        # sample indices (and anything that picks "the first N") are stable
        png_files = sorted(glob(os.path.join(png_dir, '*.png')))
        print(f"Found {len(png_files)} images for class {cls} in {split}")

        for png in png_files:
            overlay = os.path.join(overlay_dir, os.path.basename(png))
            image_paths.append(png)
            mask_paths.append(overlay if os.path.exists(overlay) else None)

    return image_paths, mask_paths

# Load test data
# Module-level lists for the 'test' split: image paths and, per image, the
# matching overlay path (None where no overlay file exists). Later code in
# this cell reads both names.
print("🔍 Loading segmentation test data...")
test_images, test_masks = load_segmentation_data(config.BASE_DIR, 'test')
print(f"✅ Segmentation test samples: {len(test_images)}")

# ====================================================
# SEGMENTATION DATASET
# ====================================================

class SegmentationDataset(Dataset):
    """Pairs each image with a binary mask extracted from its overlay file.

    Images are read with OpenCV, converted to RGB and resized to
    ``config.IMG_WIDTH`` x ``config.IMG_HEIGHT``. Masks come from
    ``extract_red_mask_from_path``. When ``transforms`` is set it is called
    with ``image=`` / ``mask=`` keywords and a 2-D mask result gets a leading
    channel axis.
    """

    def __init__(self, image_paths, mask_paths, transforms=None):
        self.transforms = transforms
        self.mask_paths = mask_paths
        self.image_paths = image_paths
        print(f"📊 Segmentation dataset initialized with {len(image_paths)} samples")

    def __len__(self):
        return len(self.image_paths)

    def _read_image(self, idx):
        """Return the image at ``idx`` as a resized RGB array, or a black frame if unreadable."""
        img = cv2.imread(self.image_paths[idx])
        if img is not None:
            rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            return cv2.resize(rgb, (config.IMG_WIDTH, config.IMG_HEIGHT))
        img = np.zeros((config.IMG_HEIGHT, config.IMG_WIDTH, 3), dtype=np.uint8)
        print(f"⚠️  Warning: Could not load image {self.image_paths[idx]}")
        return img

    def _read_mask(self, idx):
        """Return the mask at ``idx`` as a 2-D float32 array of the configured size."""
        msk = extract_red_mask_from_path(self.mask_paths[idx], config.IMG_WIDTH, config.IMG_HEIGHT)
        if msk.ndim != 2:
            msk = msk.squeeze()
        if msk.shape != (config.IMG_HEIGHT, config.IMG_WIDTH):
            resized = cv2.resize(msk, (config.IMG_WIDTH, config.IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)
            msk = resized.astype(np.float32)
        return msk

    def __getitem__(self, idx):
        img = self._read_image(idx)
        msk = self._read_mask(idx)

        if not self.transforms:
            return img, msk

        transformed = self.transforms(image=img, mask=msk)
        img, msk = transformed['image'], transformed['mask']
        # Give a 2-D mask a leading channel axis: (H, W) -> (1, H, W)
        if msk.ndim == 2:
            msk = msk.unsqueeze(0)
        return img, msk

# Validation transforms
# Only normalization and tensor conversion are applied here; resizing already
# happens inside SegmentationDataset.__getitem__.
# NOTE(review): the mean/std values match the widely used ImageNet channel
# statistics — confirm they are the ones used when the model was trained.
val_transforms = A.Compose([
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])

# Create test dataset and loader
test_dataset = SegmentationDataset(test_images, test_masks, val_transforms)
# shuffle=False keeps batches in the same order as test_images / test_masks.
test_loader = DataLoader(
    test_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=False,
    num_workers=config.NUM_WORKERS,
    pin_memory=True
)

print(f"✅ Segmentation test loader created with {len(test_dataset)} samples")

# ====================================================
# SEGMENTATION MODEL VALIDATOR
# ====================================================

class SegmentationValidator:
    def __init__(self, model_path):
        """Build the U-Net, load saved weights and switch to eval mode.

        Args:
            model_path: Path to the saved ``state_dict`` file.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        self.device = config.DEVICE

        # The network is created without pretrained encoder weights; all
        # parameters come from the checkpoint loaded below
        unet_kwargs = dict(
            encoder_name='efficientnet-b4',
            encoder_weights=None,
            in_channels=3,
            classes=1,
            activation=None,
        )
        self.model = smp.Unet(**unet_kwargs).to(self.device)

        if not os.path.exists(model_path):
            raise FileNotFoundError(f"❌ Segmentation model not found: {model_path}")

        state_dict = torch.load(model_path, map_location=self.device)
        self.model.load_state_dict(state_dict)
        print(f"✅ Segmentation model loaded from {model_path}")

        self.model.eval()
        print(f"🔧 Segmentation validator initialized on {self.device}")
        print(f"🎯 Using {config.STROKE_DETECTION_THRESHOLD:.0%} default threshold for stroke detection")

    def calculate_detailed_metrics(self, pred_prob, true_mask, threshold=config.SEGMENTATION_THRESHOLD):
        """Compute pixel-wise segmentation metrics for one prediction/mask pair.

        Args:
            pred_prob: Predicted probabilities as a tensor or NumPy array,
                2-D or with a leading axis of size 1.
            true_mask: Ground-truth mask in the same layout.
            threshold: Cut-off applied to both inputs to binarize them.

        Returns:
            dict with ``dice``, ``iou``, ``precision``, ``recall``,
            ``sensitivity``, ``specificity``, ``accuracy``, ``f1`` and the raw
            counts ``tp``, ``fp``, ``fn``, ``tn``. Ratios carry a 1e-7
            smoothing term, so e.g. Dice is 1 when both masks are empty.
        """
        def _to_2d_numpy(arr):
            # Tensors go to CPU/NumPy; a 3-D input loses its leading axis
            if isinstance(arr, torch.Tensor):
                arr = arr.cpu().numpy()
            return arr.squeeze(0) if arr.ndim == 3 else arr

        pred_prob = _to_2d_numpy(pred_prob)
        true_mask = _to_2d_numpy(true_mask)

        # Differing sizes are reconciled by cropping both to the common corner
        if pred_prob.shape != true_mask.shape:
            rows = min(pred_prob.shape[0], true_mask.shape[0])
            cols = min(pred_prob.shape[1], true_mask.shape[1])
            pred_prob, true_mask = pred_prob[:rows, :cols], true_mask[:rows, :cols]

        pred_flat = (pred_prob > threshold).astype(np.uint8).flatten()
        true_flat = (true_mask > threshold).astype(np.uint8).flatten()

        if pred_flat.shape != true_flat.shape:
            print(f"⚠️  Shape mismatch in metrics: pred={pred_flat.shape}, true={true_flat.shape}")
            ratio_keys = ('dice', 'iou', 'precision', 'recall',
                          'sensitivity', 'specificity', 'accuracy', 'f1')
            count_keys = ('tp', 'fp', 'fn', 'tn')
            return {**{k: 0.0 for k in ratio_keys}, **{k: 0 for k in count_keys}}

        # Confusion counts from boolean masks (values are strictly 0 or 1)
        pred_on = pred_flat == 1
        true_on = true_flat == 1
        tp = np.sum(pred_on & true_on)
        fp = np.sum(pred_on & ~true_on)
        fn = np.sum(~pred_on & true_on)
        tn = np.sum(~pred_on & ~true_on)

        eps = 1e-7  # smoothing term that also prevents division by zero
        dice = (2 * tp + eps) / (2 * tp + fp + fn + eps)
        iou = (tp + eps) / (tp + fp + fn + eps)
        precision = (tp + eps) / (tp + fp + eps)
        recall = (tp + eps) / (tp + fn + eps)
        specificity = (tn + eps) / (tn + fp + eps)
        accuracy = (tp + tn) / (tp + tn + fp + fn + eps)
        f1 = 2 * (precision * recall) / (precision + recall + eps)

        return dict(
            dice=dice, iou=iou, precision=precision, recall=recall,
            sensitivity=recall, specificity=specificity, accuracy=accuracy, f1=f1,
            tp=tp, fp=fp, fn=fn, tn=tn,
        )

    def classify_stroke_presence(self, mask, threshold=config.STROKE_DETECTION_THRESHOLD):
        """Classify image as Normal (0) or Stroke (1) based on mask area ratio.

        Args:
            mask: Tensor or NumPy array of mask values or probabilities, 2-D
                or with a leading axis of size 1.
            threshold: Minimum share of stroke pixels for the image to count
                as stroke.

        Returns:
            ``(label, stroke_ratio)`` where ``stroke_ratio`` is the share of
            pixels above ``config.SEGMENTATION_THRESHOLD`` and ``label`` is 1
            when that share is at least ``threshold``, else 0. An empty mask
            yields ``(0, 0.0)``.
        """
        if isinstance(mask, torch.Tensor):
            mask = mask.cpu().numpy()

        if len(mask.shape) == 3:
            mask = mask.squeeze(0)

        total_pixels = mask.size
        if total_pixels == 0:
            # No pixels at all: report "no stroke" instead of dividing by zero
            return 0, 0.0

        stroke_pixels = np.sum(mask > config.SEGMENTATION_THRESHOLD)
        stroke_ratio = stroke_pixels / total_pixels

        # Inclusive comparison, matching the ">=" used for the same decision in
        # analyze_stroke_distribution and in validate's optimal-threshold pass
        return (1 if stroke_ratio >= threshold else 0), stroke_ratio

    def analyze_stroke_distribution(self, true_masks, pred_masks):
        """Print stroke-area statistics and sweep the image-level threshold.

        Args:
            true_masks: Ground-truth masks, one per sample.
            pred_masks: Predicted masks, paired with ``true_masks`` by position.

        Returns:
            ``(threshold_results, best_threshold, true_stroke_ratios,
            pred_stroke_ratios)``. ``threshold_results`` maps each tested
            threshold to its ``accuracy`` and ``f1``; ``best_threshold`` is the
            first tested threshold with the highest F1.
        """
        print("\n🔍 STROKE AREA DISTRIBUTION ANALYSIS:")
        print("="*50)

        # Stroke-area ratio of every (truth, prediction) pair
        ratio_pairs = [
            (self.classify_stroke_presence(gt)[1], self.classify_stroke_presence(pr)[1])
            for gt, pr in zip(true_masks, pred_masks)
        ]
        true_stroke_ratios = [gt_ratio for gt_ratio, _ in ratio_pairs]
        pred_stroke_ratios = [pr_ratio for _, pr_ratio in ratio_pairs]

        # Statistics
        def _count_at_or_above(ratios, cutoff):
            return len([r for r in ratios if r >= cutoff])

        true_stroke_cases = _count_at_or_above(true_stroke_ratios, config.STROKE_DETECTION_THRESHOLD)
        pred_stroke_cases = _count_at_or_above(pred_stroke_ratios, config.STROKE_DETECTION_THRESHOLD)

        print(f"True Stroke Cases (≥{config.STROKE_DETECTION_THRESHOLD:.0%} area):  {true_stroke_cases:>4}/{len(true_stroke_ratios)} ({true_stroke_cases/len(true_stroke_ratios)*100:.1f}%)")
        print(f"Predicted Stroke Cases (≥{config.STROKE_DETECTION_THRESHOLD:.0%}): {pred_stroke_cases:>4}/{len(pred_stroke_ratios)} ({pred_stroke_cases/len(pred_stroke_ratios)*100:.1f}%)")

        print(f"\nTrue Stroke Area - Mean: {np.mean(true_stroke_ratios):.4f}, Median: {np.median(true_stroke_ratios):.4f}")
        print(f"Predicted Stroke Area - Mean: {np.mean(pred_stroke_ratios):.4f}, Median: {np.median(pred_stroke_ratios):.4f}")

        # Threshold analysis
        print(f"\n🎯 THRESHOLD SENSITIVITY ANALYSIS:")
        print(f"{'Threshold':<10} {'True Pos':<8} {'Pred Pos':<8} {'Accuracy':<8} {'F1':<8}")
        print("-" * 40)

        thresholds = [0.01, 0.02, 0.05, 0.10, 0.15, 0.20]
        threshold_results = {}

        for threshold in thresholds:
            # The same cut-off is applied to the truth and prediction ratios
            true_classes = [int(r >= threshold) for r in true_stroke_ratios]
            pred_classes = [int(r >= threshold) for r in pred_stroke_ratios]
            true_pos = sum(true_classes)
            pred_pos = sum(pred_classes)

            acc = accuracy_score(true_classes, pred_classes)
            f1_score_val = f1_score(true_classes, pred_classes, zero_division=0)

            threshold_results[threshold] = dict(accuracy=acc, f1=f1_score_val)
            print(f"{threshold:<10.2f} {true_pos:<8} {pred_pos:<8} {acc:<8.3f} {f1_score_val:<8.3f}")

        # Find optimal threshold (dict keys iterate in the order tested)
        best_threshold = max(threshold_results, key=lambda t: threshold_results[t]['f1'])
        best_f1 = threshold_results[best_threshold]['f1']
        print(f"\n🎯 OPTIMAL THRESHOLD: {best_threshold} (F1 = {best_f1:.3f})")

        return threshold_results, best_threshold, true_stroke_ratios, pred_stroke_ratios

    def validate(self):
        """Run comprehensive segmentation validation"""
        print("\n" + "="*80)
        print("🔬 SEGMENTATION MODEL VALIDATION")
        print("="*80)

        all_metrics = []
        all_true_masks = []
        all_pred_masks = []
        all_true_classes = []
        all_pred_classes = []
        valid_samples = 0
        error_count = 0

        print(f"📊 Processing {len(test_dataset)} test samples...")

        self.model.eval()
        with torch.no_grad():
            progress_bar = tqdm(test_loader, desc="Validating", unit="batch")
            for batch_idx, (images, masks) in enumerate(progress_bar):
                images = images.to(self.device)

                # Ensure masks have proper shape
                if masks.dim() == 3:
                    masks = masks.unsqueeze(1)
                masks = masks.to(self.device)

                # Get predictions
                logits = self.model(images)
                probs = torch.sigmoid(logits)

                # Process each sample in batch
                for i in range(images.shape[0]):
                    try:
                        pred_prob = probs[i, 0]  # Shape: (H, W)
                        true_mask = masks[i, 0]  # Shape: (H, W)

                        # Store for diagnostics
                        all_true_masks.append(true_mask.clone())
                        all_pred_masks.append(pred_prob.clone())

                        # Only process samples with meaningful content
                        if torch.sum(true_mask) > 0 or torch.sum(pred_prob > config.SEGMENTATION_THRESHOLD) > 0:
                            metrics = self.calculate_detailed_metrics(pred_prob, true_mask)
                            all_metrics.append(metrics)
                            valid_samples += 1

                            # Classify stroke presence (default 1% threshold)
                            true_class, _ = self.classify_stroke_presence(true_mask)
                            pred_class, _ = self.classify_stroke_presence(pred_prob)

                            all_true_classes.append(true_class)
                            all_pred_classes.append(pred_class)

                        # Progress update
                        if valid_samples % 50 == 0 and all_metrics:
                            avg_dice = np.mean([m['dice'] for m in all_metrics[-50:]])
                            progress_bar.set_postfix({
                                'Dice': f"{avg_dice:.3f}",
                                'Samples': valid_samples
                            })
                    except Exception as e:
                        error_count += 1
                        print(f"⚠️  Error processing sample {i} in batch {batch_idx}: {e}")
                        continue

        print(f"\n✅ Validation complete! Processed {valid_samples} valid samples")
        if error_count > 0:
            print(f"⚠️  Encountered {error_count} errors during processing")

        # Calculate classification metrics (default 1% threshold)
        if all_true_classes:
            default_accuracy = accuracy_score(all_true_classes, all_pred_classes)
            default_precision = precision_score(all_true_classes, all_pred_classes, zero_division=0)
            default_recall = recall_score(all_true_classes, all_pred_classes, zero_division=0)
            default_f1 = f1_score(all_true_classes, all_pred_classes, zero_division=0)
            default_cm = confusion_matrix(all_true_classes, all_pred_classes)
        else:
            default_accuracy = default_precision = default_recall = default_f1 = 0.0
            default_cm = np.zeros((2, 2))

        # Calculate segmentation statistics
        if all_metrics:
            dice_scores = [m['dice'] for m in all_metrics]
            iou_scores = [m['iou'] for m in all_metrics]
            precision_scores = [m['precision'] for m in all_metrics]
            recall_scores = [m['recall'] for m in all_metrics]
            sensitivity_scores = [m['sensitivity'] for m in all_metrics]
            specificity_scores = [m['specificity'] for m in all_metrics]
            f1_scores = [m['f1'] for m in all_metrics]
            accuracy_scores = [m['accuracy'] for m in all_metrics]
        else:
            dice_scores = iou_scores = precision_scores = recall_scores = []
            sensitivity_scores = specificity_scores = f1_scores = accuracy_scores = []

        # ====================================================
        # PRINT SEGMENTATION RESULTS TABLE
        # ====================================================
        print("\n" + "="*80)
        print("📊 SEGMENTATION PERFORMANCE METRICS")
        print("="*80)
        print(f"{'Metric':<15} {'Mean':<10} {'Std':<10} {'Best':<10} {'Worst':<10}")
        print("-" * 60)

        metrics_table = [
            ('Dice Score', dice_scores),
            ('IoU Score', iou_scores),
            ('Precision', precision_scores),
            ('Recall', recall_scores),
            ('Sensitivity', sensitivity_scores),
            ('Specificity', specificity_scores),
            ('F1 Score', f1_scores),
            ('Accuracy', accuracy_scores)
        ]

        for metric_name, scores in metrics_table:
            if scores:
                mean_val = np.mean(scores)
                std_val = np.std(scores)
                best_val = np.max(scores)
                worst_val = np.min(scores)
                print(f"{metric_name:<15} {mean_val:<10.4f} {std_val:<10.4f} {best_val:<10.4f} {worst_val:<10.4f}")
            else:
                print(f"{metric_name:<15} {'N/A':<10} {'N/A':<10} {'N/A':<10} {'N/A':<10}")

        # ====================================================
        # STROKE DETECTION RESULTS (DEFAULT 1% THRESHOLD)
        # ====================================================
        print("\n" + "="*80)
        print(f"🏥 STROKE DETECTION CLASSIFICATION ({config.STROKE_DETECTION_THRESHOLD:.0%} Threshold)")
        print("="*80)
        print(f"{'Metric':<15} {'Value':<10}")
        print("-" * 25)
        print(f"{'Accuracy':<15} {default_accuracy:<10.4f}")
        print(f"{'Precision':<15} {default_precision:<10.4f}")
        print(f"{'Recall':<15} {default_recall:<10.4f}")
        print(f"{'F1-Score':<15} {default_f1:<10.4f}")
        print(f"{'Stroke Cases':<15} {sum(all_true_classes):<10}/{valid_samples}")

        print(f"\n📊 Confusion Matrix (Normal/Stroke):")
        print(f"                  Predicted")
        print(f"              Normal  Stroke")
        print(f"True Normal   {default_cm[0,0]:^7}  {default_cm[0,1]:^7}")
        print(f"True Stroke   {default_cm[1,0]:^7}  {default_cm[1,1]:^7}")

        # ====================================================
        # STROKE DISTRIBUTION DIAGNOSTICS
        # ====================================================
        print("\n🔍 RUNNING STROKE DISTRIBUTION ANALYSIS...")
        threshold_results, best_threshold, true_stroke_ratios, pred_stroke_ratios = self.analyze_stroke_distribution(
            all_true_masks, all_pred_masks
        )

        # Recalculate with optimal threshold
        optimal_true_classes = [1 if self.classify_stroke_presence(mask, best_threshold)[1] >= best_threshold else 0
                               for mask in all_true_masks]
        optimal_pred_classes = [1 if self.classify_stroke_presence(mask, best_threshold)[1] >= best_threshold else 0
                               for mask in all_pred_masks]

        optimal_accuracy = accuracy_score(optimal_true_classes, optimal_pred_classes)
        optimal_f1 = f1_score(optimal_true_classes, optimal_pred_classes, zero_division=0)
        optimal_cm = confusion_matrix(optimal_true_classes, optimal_pred_classes)

        print(f"\n🎯 OPTIMAL PERFORMANCE (Threshold = {best_threshold:.0%}):")
        print(f"   Accuracy: {optimal_accuracy:.4f} | F1-Score: {optimal_f1:.4f}")

        # ====================================================
        # VISUALIZATIONS
        # ====================================================
        print("\n📈 Generating visualizations...")
        os.makedirs('segmentation_validation', exist_ok=True)

        # 1. Sample Predictions Visualization
        stroke_samples = []
        for i in range(min(50, len(test_images))):
            try:
                mask = extract_red_mask_from_path(test_masks[i], config.IMG_WIDTH, config.IMG_HEIGHT)
                if np.sum(mask) > 100:  # At least 100 stroke pixels
                    stroke_samples.append(i)
                    if len(stroke_samples) >= 4:
                        break
            except:
                continue

        if stroke_samples:
            fig, axes = plt.subplots(4, len(stroke_samples), figsize=(5*len(stroke_samples), 20))
            if len(stroke_samples) == 1:
                axes = axes.reshape(-1, 1)

            fig.suptitle(f'SEGMENTATION: Sample Predictions (Stroke Cases)', fontsize=16)

            for idx, sample_idx in enumerate(stroke_samples):
                try:
                    # Load image
                    img = cv2.imread(test_images[sample_idx])
                    if img is None:
                        continue
                    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    img_resized = cv2.resize(img_rgb, (config.IMG_WIDTH, config.IMG_HEIGHT))

                    # Get prediction
                    transformed = val_transforms(image=img_resized, mask=np.zeros((config.IMG_HEIGHT, config.IMG_WIDTH), dtype=np.float32))
                    img_tensor = transformed['image'].unsqueeze(0).to(self.device)

                    with torch.no_grad():
                        pred_logits = self.model(img_tensor)
                        pred_prob = torch.sigmoid(pred_logits)[0, 0].cpu().numpy()
                        pred_mask = (pred_prob > config.SEGMENTATION_THRESHOLD).astype(np.uint8) * 255

                    true_mask = extract_red_mask_from_path(test_masks[sample_idx], config.IMG_WIDTH, config.IMG_HEIGHT)
                    sample_metrics = self.calculate_detailed_metrics(pred_prob, true_mask)

                    # Stroke ratios
                    true_stroke_ratio = np.sum(true_mask > 0) / (config.IMG_HEIGHT * config.IMG_WIDTH)
                    pred_stroke_ratio = np.sum(pred_prob > config.SEGMENTATION_THRESHOLD) / (config.IMG_HEIGHT * config.IMG_WIDTH)

                    # Optimal classification
                    true_class_opt = 1 if true_stroke_ratio >= best_threshold else 0
                    pred_class_opt = 1 if pred_stroke_ratio >= best_threshold else 0

                    # Plot
                    axes[0, idx].imshow(img_resized)
                    axes[0, idx].set_title(f'Original #{idx+1}\nTrue:{true_class_opt} Pred:{pred_class_opt}', fontsize=10)
                    axes[0, idx].axis('off')

                    axes[1, idx].imshow(true_mask, cmap='Reds')
                    axes[1, idx].set_title(f'True Mask\nArea: {true_stroke_ratio:.1%}', fontsize=10)
                    axes[1, idx].axis('off')

                    axes[2, idx].imshow(pred_mask, cmap='Blues')
                    axes[2, idx].set_title(f'Pred Mask (T={config.SEGMENTATION_THRESHOLD})\nDice: {sample_metrics["dice"]:.3f}', fontsize=10)
                    axes[2, idx].axis('off')

                    # Overlay
                    overlay = img_resized.copy()
                    overlay[pred_mask > 0] = [255, 0, 0]  # Red for prediction
                    overlay[true_mask > 0] = [0, 255, 0]  # Green for ground truth
                    axes[3, idx].imshow(overlay)
                    axes[3, idx].set_title(f'Overlay Comparison\nIoU: {sample_metrics["iou"]:.3f}', fontsize=10)
                    axes[3, idx].axis('off')

                except Exception as e:
                    print(f"Error in visualization for sample {sample_idx}: {e}")
                    continue

            plt.tight_layout()
            plt.savefig('segmentation_validation/sample_predictions.png', dpi=300, bbox_inches='tight')
            plt.show()

        # 2. Metrics Distribution
        if len(dice_scores) > 1:
            fig, axes = plt.subplots(2, 4, figsize=(20, 10))
            fig.suptitle('Segmentation Metrics Distribution', fontsize=16)

            metrics_data = [
                ('Dice Score', dice_scores, 'Dice'),
                ('IoU Score', iou_scores, 'IoU'),
                ('Precision', precision_scores, 'Precision'),
                ('Recall', recall_scores, 'Recall'),
                ('Sensitivity', sensitivity_scores, 'Sensitivity'),
                ('Specificity', specificity_scores, 'Specificity'),
                ('F1 Score', f1_scores, 'F1'),
                ('Accuracy', accuracy_scores, 'Accuracy')
            ]

            for idx, (name, scores, short_name) in enumerate(metrics_data):
                if scores:
                    row, col = idx // 4, idx % 4
                    axes[row, col].hist(scores, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
                    axes[row, col].axvline(np.mean(scores), color='red', linestyle='--', linewidth=2,
                                         label=f'Mean: {np.mean(scores):.3f}')
                    axes[row, col].set_title(f'{name} Distribution')
                    axes[row, col].set_xlabel(short_name)
                    axes[row, col].set_ylabel('Frequency')
                    axes[row, col].legend()
                    axes[row, col].grid(True, alpha=0.3)
                else:
                    axes[row, col].text(0.5, 0.5, 'No Data', ha='center', va='center', transform=axes[row, col].transAxes)
                    axes[row, col].set_title(f'{name} Distribution')

            plt.tight_layout()
            plt.savefig('segmentation_validation/metrics_distribution.png', dpi=300, bbox_inches='tight')
            plt.show()

        # 3. Stroke Detection Confusion Matrix (Default vs Optimal)
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

        # Default threshold confusion matrix
        sns.heatmap(default_cm, annot=True, fmt='d', cmap='Blues',
                   xticklabels=['Normal', 'Stroke'],
                   yticklabels=['Normal', 'Stroke'],
                   ax=ax1, annot_kws={'size': 14})
        ax1.set_title(f'Stroke Detection CM\n(Default {config.STROKE_DETECTION_THRESHOLD:.0%} Threshold)\nAcc: {default_accuracy:.3f}', fontsize=12)

        # Optimal threshold confusion matrix
        sns.heatmap(optimal_cm, annot=True, fmt='d', cmap='Greens',
                   xticklabels=['Normal', 'Stroke'],
                   yticklabels=['Normal', 'Stroke'],
                   ax=ax2, annot_kws={'size': 14})
        ax2.set_title(f'Stroke Detection CM\n(Optimal {best_threshold:.0%} Threshold)\nAcc: {optimal_accuracy:.3f}', fontsize=12)

        plt.tight_layout()
        plt.savefig('segmentation_validation/stroke_confusion_matrices.png', dpi=300, bbox_inches='tight')
        plt.show()

        # 4. Threshold Performance Curve
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

        # Stroke area distribution
        ax1.hist(true_stroke_ratios, bins=20, alpha=0.7, color='red', label='True', density=True)
        ax1.hist(pred_stroke_ratios, bins=20, alpha=0.7, color='blue', label='Predicted', density=True)
        ax1.axvline(best_threshold, color='green', linestyle='--', linewidth=2,
                   label=f'Optimal: {best_threshold:.0%}')
        ax1.set_title('Stroke Area Ratio Distribution')
        ax1.set_xlabel('Stroke Area Ratio')
        ax1.set_ylabel('Density')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Threshold performance
        thresholds = list(threshold_results.keys())
        accuracies = [threshold_results[t]['accuracy'] for t in thresholds]
        f1_scores = [threshold_results[t]['f1'] for t in thresholds]

        ax2.plot(thresholds, accuracies, 'o-', linewidth=2, label='Accuracy', color='blue')
        ax2.plot(thresholds, f1_scores, 's-', linewidth=2, label='F1-Score', color='red')
        ax2.axvline(best_threshold, color='green', linestyle='--', linewidth=2,
                   label=f'Optimal: {best_threshold:.0%}')
        ax2.set_title('Threshold Performance Analysis')
        ax2.set_xlabel('Stroke Area Threshold')
        ax2.set_ylabel('Score')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        ax2.set_ylim(0, 1)

        plt.tight_layout()
        plt.savefig('segmentation_validation/threshold_analysis.png', dpi=300, bbox_inches='tight')
        plt.show()

        # ====================================================
        # COMPREHENSIVE JSON REPORT
        # ====================================================
        segmentation_report = {
            'timestamp': datetime.now().isoformat(),
            'model_info': {
                'architecture': 'U-Net',
                'encoder': 'efficientnet-b4',
                'input_size': f"{config.IMG_HEIGHT}x{config.IMG_WIDTH}",
                'segmentation_threshold': float(config.SEGMENTATION_THRESHOLD),
                'stroke_detection_threshold': float(config.STROKE_DETECTION_THRESHOLD)
            },
            'dataset_info': {
                'total_samples': len(test_dataset),
                'valid_samples': valid_samples,
                'error_count': error_count
            },
            'segmentation_metrics': {
                'dice': {
                    'mean': float(np.mean(dice_scores)) if dice_scores else 0.0,
                    'std': float(np.std(dice_scores)) if dice_scores else 0.0,
                    'min': float(np.min(dice_scores)) if dice_scores else 0.0,
                    'max': float(np.max(dice_scores)) if dice_scores else 0.0
                },
                'iou': {
                    'mean': float(np.mean(iou_scores)) if iou_scores else 0.0,
                    'std': float(np.std(iou_scores)) if iou_scores else 0.0,
                    'min': float(np.min(iou_scores)) if iou_scores else 0.0,
                    'max': float(np.max(iou_scores)) if iou_scores else 0.0
                },
                'precision': {
                    'mean': float(np.mean(precision_scores)) if precision_scores else 0.0,
                    'std': float(np.std(precision_scores)) if precision_scores else 0.0
                },
                'recall': {
                    'mean': float(np.mean(recall_scores)) if recall_scores else 0.0,
                    'std': float(np.std(recall_scores)) if recall_scores else 0.0
                },
                'sensitivity': {
                    'mean': float(np.mean(sensitivity_scores)) if sensitivity_scores else 0.0,
                    'std': float(np.std(sensitivity_scores)) if sensitivity_scores else 0.0
                },
                'specificity': {
                    'mean': float(np.mean(specificity_scores)) if specificity_scores else 0.0,
                    'std': float(np.std(specificity_scores)) if specificity_scores else 0.0
                },
                'f1': {
                    'mean': float(np.mean(f1_scores)) if f1_scores else 0.0,
                    'std': float(np.std(f1_scores)) if f1_scores else 0.0
                },
                'accuracy': {
                    'mean': float(np.mean(accuracy_scores)) if accuracy_scores else 0.0,
                    'std': float(np.std(accuracy_scores)) if accuracy_scores else 0.0
                }
            },
            'stroke_detection': {
                'default_threshold': float(config.STROKE_DETECTION_THRESHOLD),
                'default_accuracy': float(default_accuracy),
                'default_f1': float(default_f1),
                'default_confusion_matrix': default_cm.tolist(),
                'optimal_threshold': float(best_threshold),
                'optimal_accuracy': float(optimal_accuracy),
                'optimal_f1': float(optimal_f1),
                'optimal_confusion_matrix': optimal_cm.tolist(),
                'threshold_analysis': {
                    str(threshold): {
                        'accuracy': float(threshold_results[threshold]['accuracy']),
                        'f1': float(threshold_results[threshold]['f1'])
                    } for threshold in threshold_results
                }
            }
        }

        report_filename = f'segmentation_validation/segmentation_report_{datetime.now().strftime("%Y%m%d_%H%M%S")}.json'
        with open(report_filename, 'w') as f:
            json.dump(segmentation_report, f, indent=2)

        print(f"\n📄 Segmentation report saved: {report_filename}")

        # ====================================================
        # FINAL SUMMARY
        # ====================================================
        print("\n" + "="*80)
        print("🎯 SEGMENTATION VALIDATION SUMMARY")
        print("="*80)
        print(f"🔬 Pixel-Level Performance:")
        if dice_scores:
            print(f"   🥇 Dice Score:     {np.mean(dice_scores):.4f} ± {np.std(dice_scores):.4f}")
            print(f"   🥈 IoU Score:      {np.mean(iou_scores):.4f} ± {np.std(iou_scores):.4f}")
            print(f"   ⚖️  Precision:     {np.mean(precision_scores):.4f} ± {np.std(precision_scores):.4f}")
            print(f"   🔄 Recall:         {np.mean(recall_scores):.4f} ± {np.std(recall_scores):.4f}")
            print(f"   📊 F1-Score:       {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")
            print(f"   🎯 Pixel Accuracy: {np.mean(accuracy_scores):.4f} ± {np.std(accuracy_scores):.4f}")

        print(f"\n🏥 Stroke Detection Performance:")
        print(f"   📊 Default ({config.STROKE_DETECTION_THRESHOLD:.0%} threshold):  Accuracy = {default_accuracy:.4f}, F1 = {default_f1:.4f}")
        print(f"   🎯 Optimal ({best_threshold:.0%} threshold): Accuracy = {optimal_accuracy:.4f}, F1 = {optimal_f1:.4f}")
        print(f"   🔬 Valid Samples Processed: {valid_samples}/{len(test_dataset)}")
        print(f"   🩺 True Stroke Cases: {sum(all_true_classes)}/{valid_samples}")

        # Performance assessment
        if np.mean(dice_scores) > 0.80:
            print(f"\n🎉 EXCELLENT SEGMENTATION QUALITY! Dice score > 0.80")
        elif np.mean(dice_scores) > 0.70:
            print(f"\n👍 GOOD SEGMENTATION QUALITY! Dice score > 0.70")
        else:
            print(f"\n⚠️  CONSIDER IMPROVEMENT: Dice score = {np.mean(dice_scores):.3f}")

        if optimal_accuracy > 0.90:
            print(f"🎉 EXCELLENT STROKE DETECTION! Use threshold {best_threshold:.0%}")
        elif optimal_accuracy > 0.80:
            print(f"👍 GOOD STROKE DETECTION! Use threshold {best_threshold:.0%}")
        else:
            print(f"⚠️  STROKE DETECTION NEEDS ATTENTION: {optimal_accuracy:.1%} accuracy")

        print(f"\n📁 Generated Files:")
        print(f"   📄 {report_filename}")
        print(f"   📈 segmentation_validation/sample_predictions.png")
        print(f"   📈 segmentation_validation/metrics_distribution.png")
        print(f"   📈 segmentation_validation/stroke_confusion_matrices.png")
        print(f"   📈 segmentation_validation/threshold_analysis.png")

        return {
            'dice_scores': dice_scores,
            'iou_scores': iou_scores,
            'precision_scores': precision_scores,
            'recall_scores': recall_scores,
            'sensitivity_scores': sensitivity_scores,
            'specificity_scores': specificity_scores,
            'f1_scores': f1_scores,
            'accuracy_scores': accuracy_scores,
            'default_stroke_accuracy': default_accuracy,
            'default_stroke_f1': default_f1,
            'optimal_stroke_accuracy': optimal_accuracy,
            'optimal_stroke_f1': optimal_f1,
            'optimal_threshold': best_threshold,
            'default_confusion_matrix': default_cm,
            'optimal_confusion_matrix': optimal_cm,
            'valid_samples': valid_samples,
            'total_samples': len(test_dataset),
            'threshold_results': threshold_results,
            'true_stroke_ratios': true_stroke_ratios,
            'pred_stroke_ratios': pred_stroke_ratios,
            'report': segmentation_report
        }

# ====================================================
# MAIN EXECUTION - SEGMENTATION
# ====================================================

def run_segmentation_validation():
    """Validate the saved segmentation model and print a short summary.

    Returns:
        The results dict returned by ``SegmentationValidator.validate()``, or
        ``None`` when the model file at ``config.MODEL_PATH`` does not exist
        or the validator returned a falsy result.
    """
    banner = "=" * 80
    print("\n" + banner)
    print("🚀 SEGMENTATION MODEL VALIDATION PIPELINE")
    print(banner)

    model_path = config.MODEL_PATH

    # Nothing to validate without trained weights on disk.
    if not os.path.exists(model_path):
        print(f"❌ Model file not found: {model_path}")
        print("Please train the segmentation model first.")
        return None

    results = SegmentationValidator(model_path).validate()

    # A falsy result means validation did not produce any metrics.
    if not results:
        print(f"\n❌ Segmentation validation failed.")
        return None

    mean_dice = np.mean(results['dice_scores'])
    accuracy_pct = results['optimal_stroke_accuracy']*100
    print(f"\n🎉 SEGMENTATION VALIDATION COMPLETED SUCCESSFULLY!")
    print(f"✅ Pixel Quality: Dice = {mean_dice:.3f}")
    print(f"✅ Optimal Stroke Detection: {accuracy_pct:.1f}% accuracy")
    return results

# Run segmentation validation when this cell/script is executed directly.
# `segmentation_results` receives the dict returned by
# run_segmentation_validation(), or None when the model file is missing or
# validation produced no result.
if __name__ == "__main__":
    segmentation_results = run_segmentation_validation()

# **Real-world image testing**

In [None]:
# Install required dependencies
!pip install timm segmentation-models-pytorch albumentations opencv-python-headless matplotlib -q

import os
import warnings
from io import BytesIO

import albumentations as A
import cv2
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import timm
import torch
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from google.colab import files
from IPython.display import display
from PIL import Image
from torchvision import transforms

# Suppress every Python warning for the rest of the session to keep the cell
# output readable.
# NOTE(review): this also hides deprecation and runtime warnings raised by the
# libraries imported above — confirm that is acceptable.
warnings.filterwarnings('ignore')

# ====================================================
# CONFIGURATION
# ====================================================

class AppConfig:
    """Settings for the upload-and-diagnose demo (classification + segmentation)."""

    # Checkpoint files loaded by the two model wrappers.
    CLASSIFICATION_MODEL_PATH = "best_model_3class.pth"
    SEGMENTATION_MODEL_PATH = "best_stroke_segmentation_model.pth"

    # Label for each classifier output index.
    # NOTE(review): presumably this must match the class order used at training time — confirm.
    CLASSES = ["Bleeding", "Ischemia", "Normal"]

    # Per-pixel probability cut-off used when binarising the segmentation output.
    SEGMENTATION_THRESHOLD = 0.5
    # Minimum fraction of stroke pixels (1%) for a stroke to count as confirmed.
    STROKE_DETECTION_THRESHOLD = 0.01

    # Run on the GPU when one is available, otherwise fall back to the CPU.
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-level settings instance read by ClassificationModel, SegmentationModel
# and process_ct_image below.
# NOTE(review): this rebinds the global name `config`. The segmentation
# validation code earlier in the notebook reads attributes such as
# `config.MODEL_PATH` and `config.IMG_WIDTH`, which AppConfig does not define,
# so that code would fail if re-run after this cell unless its own config is
# re-created first — confirm this is intended.
config = AppConfig()
print(f"🔧 Device: {config.DEVICE}")

# ====================================================
# MODEL LOADING
# ====================================================

# Load Classification Model (ConvNeXt)
class ClassificationModel:
    """Three-class CT classifier built on a timm ``convnext_base`` network.

    The weights are read from ``config.CLASSIFICATION_MODEL_PATH`` and the
    network is placed on ``config.DEVICE`` in evaluation mode.
    """

    def __init__(self):
        """Build the network, load its checkpoint and set up preprocessing.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        self.model = timm.create_model('convnext_base', pretrained=False, num_classes=3)

        weights_path = config.CLASSIFICATION_MODEL_PATH
        if not os.path.exists(weights_path):
            raise FileNotFoundError(f"Classification model not found: {weights_path}")

        state_dict = torch.load(weights_path, map_location=config.DEVICE)
        self.model.load_state_dict(state_dict)
        print(f"✅ Classification model loaded")

        self.model = self.model.to(config.DEVICE)
        self.model.eval()

        def replicate_single_channel(tensor):
            # A single-channel tensor is stacked three times along the channel
            # axis; tensors with any other channel count pass through as-is.
            if tensor.shape[0] != 1:
                return tensor
            return torch.cat((tensor, tensor, tensor), dim=0)

        # Resize -> tensor -> force 3 channels -> scale with mean/std of 0.5.
        preprocessing_steps = [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            replicate_single_channel,
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        self.transform = transforms.Compose(preprocessing_steps)

    def predict(self, image):
        """Classify one image.

        Args:
            image: input accepted by ``self.transform`` (process_ct_image
                passes a PIL image converted to mode 'L').

        Returns:
            ``(label, probability)``: the entry of ``config.CLASSES`` with the
            highest softmax score and that score as a Python float.
        """
        batch = self.transform(image).unsqueeze(0).to(config.DEVICE)

        with torch.no_grad():
            probabilities = torch.softmax(self.model(batch), dim=1)
            best_index = torch.argmax(probabilities, dim=1).item()
            confidence = probabilities[0, best_index].item()

        return config.CLASSES[best_index], confidence

# Load Segmentation Model (U-Net)
class SegmentationModel:
    """U-Net (EfficientNet-B4 encoder) wrapper that predicts a binary stroke mask.

    The weights are read from ``config.SEGMENTATION_MODEL_PATH``; the network
    lives on ``config.DEVICE`` and is kept in evaluation mode.
    """

    # Side length (in pixels) of the square input fed to the network.
    INPUT_SIZE = 256

    def __init__(self):
        """Build the network, load its checkpoint and set up preprocessing.

        Raises:
            FileNotFoundError: if the checkpoint file does not exist.
        """
        self.model = smp.Unet(
            encoder_name='efficientnet-b4',
            encoder_weights=None,  # weights come from the checkpoint loaded below
            in_channels=3,
            classes=1,
            activation=None,  # raw logits; sigmoid is applied in predict()
        ).to(config.DEVICE)

        if os.path.exists(config.SEGMENTATION_MODEL_PATH):
            self.model.load_state_dict(torch.load(config.SEGMENTATION_MODEL_PATH, map_location=config.DEVICE))
            print(f"✅ Segmentation model loaded")
        else:
            raise FileNotFoundError(f"Segmentation model not found: {config.SEGMENTATION_MODEL_PATH}")

        self.model.eval()

        self.transform = A.Compose([
            A.Resize(height=self.INPUT_SIZE, width=self.INPUT_SIZE),
            A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ToTensorV2(),
        ])

    def predict(self, image):
        """Segment one image and measure the predicted stroke area.

        Args:
            image: object exposing ``convert('RGB')`` (a PIL image in the
                calling code).

        Returns:
            ``(pred_mask, stroke_ratio)`` where ``pred_mask`` is a uint8 array
            holding 255 for stroke pixels and 0 elsewhere, and
            ``stroke_ratio`` is the fraction of pixels marked as stroke.
        """
        # Albumentations works on numpy arrays, so convert the PIL image first.
        image_np = np.array(image.convert('RGB'))
        transformed = self.transform(image=image_np)
        image_tensor = transformed['image'].unsqueeze(0).to(config.DEVICE)

        with torch.no_grad():
            pred_logits = self.model(image_tensor)
            pred_prob = torch.sigmoid(pred_logits)[0, 0].cpu().numpy()

        # Threshold once with the configured cut-off so the displayed mask and
        # the reported area can never disagree (the mask previously used a
        # hard-coded 0.5 while the ratio used config.SEGMENTATION_THRESHOLD).
        stroke_pixels = pred_prob > config.SEGMENTATION_THRESHOLD
        pred_mask = stroke_pixels.astype(np.uint8) * 255

        # Divide by the real number of output pixels rather than assuming 256x256.
        stroke_ratio = np.sum(stroke_pixels) / stroke_pixels.size

        return pred_mask, stroke_ratio

# Initialize models
# Module-level instances used by process_ct_image(). Both constructors raise
# FileNotFoundError when their checkpoint file is missing, so this cell stops
# here if the models have not been trained/saved yet.
print("🔍 Loading models...")
class_model = ClassificationModel()
seg_model = SegmentationModel()

# ====================================================
# IMAGE PROCESSING FUNCTIONS
# ====================================================

def create_overlay(original, mask):
    """Create overlay image with stroke highlighted in red.

    Args:
        original: image exposing ``convert('RGB')`` (a PIL image in the
            calling code).
        mask: 2-D array of shape (height, width); every element greater than
            zero is treated as a stroke pixel.

    Returns:
        RGB array with the same height and width as ``mask`` in which stroke
        pixels are painted pure red and all other pixels come from the
        resized original image.
    """
    mask = np.asarray(mask)
    mask_height, mask_width = mask.shape[:2]

    # Resize the image to the mask's own size instead of a hard-coded 256x256,
    # so the boolean indexing below works for any mask resolution.
    # cv2.resize takes the target size as (width, height).
    original = np.array(original.convert('RGB'))
    original = cv2.resize(original, (mask_width, mask_height))

    overlay = original.copy()
    overlay[mask > 0] = [255, 0, 0]  # Red for stroke

    return overlay

# ====================================================
# MAIN APPLICATION
# ====================================================

def process_ct_image():
    """Upload one CT image, classify it and, if a stroke is predicted, segment it.

    Workflow:
        1. Ask the user to upload a file through the Colab upload widget.
        2. Classify the grayscale image with ``class_model``.
        3. If the class is 'Normal', show the image and stop.
        4. Otherwise run ``seg_model`` and treat the stroke as confirmed only
           when the segmented area reaches
           ``config.STROKE_DETECTION_THRESHOLD``.
        5. Show the original image next to the red overlay.

    Returns:
        None. Results are printed and plotted.
    """
    print("📤 Please upload a brain CT image (PNG/JPG)...")
    uploaded = files.upload()

    if not uploaded:
        print("❌ No image uploaded. Please try again.")
        return

    # Only the first uploaded file is analysed; say so instead of silently
    # dropping the rest.
    file_name = next(iter(uploaded))
    image_bytes = uploaded[file_name]
    if len(uploaded) > 1:
        print(f"ℹ️ {len(uploaded)} files uploaded; only the first one ({file_name}) will be processed.")

    # Open image
    try:
        image = Image.open(BytesIO(image_bytes)).convert('L')  # Convert to grayscale
        print(f"✅ Image uploaded: {file_name} (Size: {image.size})")
    except Exception as e:
        print(f"❌ Error opening image: {e}")
        return

    # Step 1: Classification
    print("\n🔍 Running classification...")
    pred_class, pred_prob = class_model.predict(image)

    print(f"📊 Classification Result: {pred_class} (Confidence: {pred_prob:.1%})")

    if pred_class == 'Normal':
        print("✅ No stroke detected.")
        # Display original image
        plt.figure(figsize=(8, 8))
        plt.imshow(image, cmap='gray')
        plt.title('Original CT Image - No Stroke')
        plt.axis('off')
        plt.show()
        return

    # Step 2: Segmentation
    print("\n🔬 Running segmentation...")
    pred_mask, stroke_ratio = seg_model.predict(image)

    print(f"📊 Segmentation Result: Stroke Area: {stroke_ratio:.1%}")

    # Stroke confirmation: the classifier's label is kept only when the
    # segmented area is large enough.
    is_stroke = stroke_ratio >= config.STROKE_DETECTION_THRESHOLD
    if is_stroke:
        final_diagnosis = pred_class
        diagnosis_text = f"{pred_class} Stroke"
    else:
        # Previously this printed the contradictory "Normal Stroke".
        final_diagnosis = 'Normal'
        diagnosis_text = (
            f"Normal (segmented area below the "
            f"{config.STROKE_DETECTION_THRESHOLD:.1%} stroke threshold)"
        )

    print(f"\n🏥 FINAL DIAGNOSIS: {diagnosis_text}")

    # Display results
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

    # Original image
    ax1.imshow(image, cmap='gray')
    ax1.set_title('Original CT Image')
    ax1.axis('off')

    # Segmented image
    overlay = create_overlay(image, pred_mask)
    ax2.imshow(overlay)
    ax2.set_title(f'Segmented Image - {final_diagnosis}')
    ax2.text(10, 30, f'Confidence: {pred_prob:.1%}\nStroke Area: {stroke_ratio:.1%}',
             bbox={'facecolor': 'white', 'alpha': 0.7, 'pad': 5})
    ax2.axis('off')

    plt.tight_layout()
    plt.show()

# Run the application
if __name__ == "__main__":
    # Text menu: 0 analyses one uploaded image, 1 leaves the loop. Any other
    # input is rejected and the menu is shown again. Previously a non-numeric
    # entry crashed with ValueError, and any other integer ran one analysis
    # and then exited the loop silently.
    while True:
        print("Enter 0 to test\nEnter 1 to exit")
        try:
            choice = int(input())
        except ValueError:
            print("Invalid input. Please enter 0 or 1.")
            continue

        if choice == 1:
            print("User Exited")
            break
        if choice != 0:
            print("Invalid input. Please enter 0 or 1.")
            continue

        process_ct_image()