# In [None]:  (Jupyter cell marker left over from the notebook export)
import os
import re
import shutil

import albumentations as A
import cv2
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
import torch.nn.functional as F
from albumentations.pytorch import ToTensorV2
from sklearn.model_selection import train_test_split
from torch.cuda.amp import GradScaler, autocast
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision.utils import save_image
from tqdm import tqdm

# =============================
# Config & Hyperparameters
# =============================
# Dataset root; prepare_fgadr_dataframe() looks for a 'Seg-set' folder inside it.
DATASET_PATH = "data/fgdar"
# Lesion code to train on; must be a key of LESION_MAP and LESION_THRESHOLDS.
TARGET_LESION = "HE"

# --- Experiment Settings ---
ENCODER = "resnet34"        # encoder backbone name handed to the segmentation model
IMG_SIZE = 512              # images and masks are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 1
ACCUMULATION_STEPS = 4      # batches accumulated before each optimizer step
LR = 1e-4                   # encoder learning rate (the decoder uses LR * 10)
WEIGHT_DECAY = 1e-4
MIN_LR = 1e-6               # lower bound for the ReduceLROnPlateau scheduler
PATIENCE = 8                # early-stopping patience, in epochs
SCHEDULER_PATIENCE = 2      # scheduler patience, in epochs
EPOCHS = 55
EMA_DECAY = 0.999           # EMA is only created when this is > 0
LOSS_BCE_ALPHA = 0.05       # default BCE weight inside CombinedLoss

# --- Post-processing threshold ---
# Connected components smaller than this many pixels are removed from predictions.
MIN_LESION_AREA = 24

# Lesion code -> name used to build the "<name>_Masks" folder name.
LESION_MAP = {
    "MA": "Microaneurysms",
    "HE": "Hemorrhage",
    "EX": "HardExudate",
    "SE": "SoftExudate",
}
# Human-readable name for log messages ("HardExudate" -> "Hard Exudate").
LESION_FULL_NAME = LESION_MAP[TARGET_LESION].replace("Exudate", " Exudate")

# --- Pre-defined Thresholds ---
# Probability cut-off applied to the sigmoid output at test time, per lesion code.
LESION_THRESHOLDS = {
    "MA": 0.30,
    "HE": 0.45,
    "EX": 0.45,
    "SE": 0.45,
}

# =============================
# Helpers & Post-Processing
# =============================
def numeric_sort_key(s):
    """Return a natural-sort key for ``s`` so that 'img2' sorts before 'img10'.

    The string is split into alternating text / ASCII-digit runs; digit runs
    become ``int`` and text runs are lower-cased for case-insensitive ordering.

    Args:
        s: The string (typically a file path) to build a key for.

    Returns:
        A list that always alternates ``str, int, str, ...`` (starting with a
        possibly empty ``str``), so keys of different strings stay comparable.
    """
    # re.split with a capturing group alternates non-matches (even positions)
    # and captured digit runs (odd positions). Classifying by position instead
    # of str.isdigit() avoids treating non-ASCII digits such as '²' or '٣' as
    # numbers, which would either crash int() or break the str/int alternation.
    tokens = re.split(r'([0-9]+)', s)
    return [int(token) if position % 2 else token.lower()
            for position, token in enumerate(tokens)]

def post_process_mask(mask_np, min_area):
    """Drop small connected components from a binary mask.

    Args:
        mask_np: 2-D mask array. Anything that is not already ``uint8`` is
            multiplied by 255 and cast to ``uint8`` before labelling.
        min_area: Minimum component size (in pixels) that is kept.

    Returns:
        A ``float32`` array with the same shape as the input, holding 1.0 on
        the pixels of the kept components and 0.0 elsewhere.
    """
    binary = mask_np if mask_np.dtype == np.uint8 else (mask_np * 255).astype(np.uint8)

    n_components, label_map, component_stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)

    # Label 0 is skipped; only labels 1..n-1 are candidates for keeping.
    kept_labels = [
        label for label in range(1, n_components)
        if component_stats[label, cv2.CC_STAT_AREA] >= min_area
    ]

    cleaned = np.zeros_like(binary)
    cleaned[np.isin(label_map, kept_labels)] = 255

    return (cleaned / 255).astype(np.float32)


def prepare_fgadr_dataframe(dataset_path):
    """Build a DataFrame of image paths and per-lesion mask paths.

    Looks under ``<dataset_path>/Seg-set`` for ``Original_Images`` and one
    ``<name>_Masks`` folder per entry of ``LESION_MAP``.

    NOTE: rows are paired purely by position after each folder listing is
    sorted independently with ``numeric_sort_key``; file names are not
    cross-checked between folders.

    Args:
        dataset_path: Root directory of the dataset.

    Returns:
        A ``pandas.DataFrame`` with an ``image`` column and one
        ``<code>_mask`` column per lesion code.
    """
    seg_root = os.path.join(dataset_path, 'Seg-set')

    def _sorted_paths(folder):
        # Full paths of every directory entry, in natural (numeric-aware) order.
        paths = [os.path.join(folder, entry) for entry in os.listdir(folder)]
        paths.sort(key=numeric_sort_key)
        return paths

    columns = {'image': _sorted_paths(os.path.join(seg_root, "Original_Images"))}
    for code, folder_name in LESION_MAP.items():
        columns[f'{code}_mask'] = _sorted_paths(os.path.join(seg_root, f"{folder_name}_Masks"))
    return pd.DataFrame(columns)

def green_clahe_preprocess(image, clip_limit=2.0, tile_grid_size=(8, 8), blend_alpha=0.75):
    """Blend an image with a CLAHE-enhanced copy of its green channel.

    Args:
        image: 3-channel image array; channel index 1 is used as the green
            channel.
        clip_limit: ``clipLimit`` passed to ``cv2.createCLAHE``.
        tile_grid_size: ``tileGridSize`` passed to ``cv2.createCLAHE``.
        blend_alpha: Weight of the enhanced image; the input image gets
            ``1 - blend_alpha``.

    Returns:
        The weighted sum produced by ``cv2.addWeighted``.
    """
    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    enhanced_green = equalizer.apply(image[:, :, 1])
    # Replicate the single enhanced channel to three channels so it can be
    # blended with the colour input.
    enhanced_3ch = cv2.cvtColor(enhanced_green, cv2.COLOR_GRAY2BGR)
    original_weight = 1 - blend_alpha
    return cv2.addWeighted(image, original_weight, enhanced_3ch, blend_alpha, 0)

class FGADRDataset(Dataset):
    """Dataset yielding (image, mask) pairs for one lesion type.

    Each row of ``dataframe`` must provide an ``image`` path and a
    ``<target_lesion>_mask`` path. Images are converted BGR->RGB and passed
    through ``green_clahe_preprocess``; masks are binarised at 127.
    """

    def __init__(self, dataframe, target_lesion, transform=None):
        """
        Args:
            dataframe: Table with an ``image`` column and a
                ``<target_lesion>_mask`` column of file paths.
            target_lesion: Lesion code selecting the mask column.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping with
                ``'image'`` and ``'mask'`` entries.
        """
        self.df = dataframe
        self.transform = transform
        self.target_lesion = target_lesion
        self.mask_column = f'{self.target_lesion}_mask'

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load, preprocess and (optionally) transform sample ``idx``.

        Returns:
            ``(image, mask)``. The mask always carries a leading channel axis
            when it is 2-D after the transform. Without a transform both are
            numpy arrays (the CLAHE-preprocessed image and a (1, H, W) mask).

        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
        """
        row = self.df.iloc[idx]

        image_path = row['image']
        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not load image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask_path = row[self.mask_column]
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not load mask: {mask_path}")
        mask = (mask > 127).astype(np.float32)

        # Keep the preprocessed image even when no transform is configured
        # (previously the CLAHE result was dropped in that case).
        image = green_clahe_preprocess(image)
        if self.transform:
            augmented = self.transform(image=image.astype(np.uint8), mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        if mask.ndim == 2:
            # The mask is a tensor after a tensor-producing transform and a
            # numpy array otherwise; add the channel axis in either case.
            mask = mask.unsqueeze(0) if torch.is_tensor(mask) else mask[np.newaxis, ...]
        return image, mask

# =============================
# Model (Unet++ with ResNet34)
# =============================
class SingleLesionModel(nn.Module):
    """Unet++ producing single-channel logits, plus a small refinement head.

    The refinement head runs on the sigmoid of the base logits and its output
    is added (scaled by 0.5) to those logits.
    """

    def __init__(self,
                 encoder_name=ENCODER,
                 in_channels=3,
                 target_lesion=TARGET_LESION):
        """
        Args:
            encoder_name: Encoder backbone name passed to ``smp.UnetPlusPlus``.
            in_channels: Number of input image channels.
            target_lesion: Lesion code; only used in the start-up log message.
        """
        super().__init__()
        print(f"Initializing Unet++ model for {target_lesion} with encoder {encoder_name}")
        self.model = smp.UnetPlusPlus(
            encoder_name=encoder_name,
            encoder_weights="imagenet",
            in_channels=in_channels,
            classes=1,
            activation=None,
            # The documented keyword for decoder attention is
            # `decoder_attention_type`; the previous `attention_type` is not a
            # UnetPlusPlus parameter, so scSE attention was never enabled.
            decoder_attention_type='scse'
        )

        # Add sensitivity-focused refinement
        self.sensitivity_head = nn.Sequential(
            nn.Conv2d(1, 8, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(8, 1, 1),
            nn.Dropout2d(0.1)  # applied to the single output channel
        )

    def forward(self, x):
        """Return refined logits (no activation applied) for input batch ``x``."""
        base_pred = self.model(x)
        sensitivity_boost = self.sensitivity_head(torch.sigmoid(base_pred))
        # Residual correction of the logits. NOTE(review): the head's output
        # is not constrained to be positive, so it can also lower logits.
        return base_pred + 0.5 * sensitivity_boost

# =============================
# Training Utils
# =============================
class EMA:
    """Exponential moving average (EMA) of a model's trainable parameters.

    ``shadow`` holds the averaged weights. ``apply_shadow`` temporarily swaps
    them into the model and ``restore`` puts the original weights back.
    Only parameters with ``requires_grad=True`` are tracked; buffers are not.
    """

    def __init__(self, model, decay):
        """
        Args:
            model: Module whose trainable parameters seed the average.
            decay: Weight given to the previous average on each update.
        """
        self.decay = decay
        self.shadow = {
            name: param.data.clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        self.backup = {}

    def update(self, model):
        """Fold the model's current trainable parameters into the average.

        Raises:
            KeyError: If a trainable parameter was not present when the EMA
                was created.
        """
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            if name not in self.shadow:
                # Explicit error instead of `assert`, which is stripped by -O.
                raise KeyError(f"Parameter '{name}' is not tracked by this EMA")
            self.shadow[name] = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]

    def apply_shadow(self, model):
        """Load the averaged weights into ``model``, keeping a backup."""
        # state_dict() returns tensors that share storage with the live
        # parameters, and load_state_dict() copies in place. Without cloning,
        # the "backup" would be overwritten by the shadow weights and
        # restore() would silently leave the EMA weights in the model.
        self.backup = {
            key: value.detach().clone()
            for key, value in model.state_dict().items()
        }
        # strict=False: the shadow has no entries for buffers / frozen params.
        model.load_state_dict(self.shadow, strict=False)

    def restore(self, model):
        """Put back the weights saved by the last ``apply_shadow`` call."""
        model.load_state_dict(self.backup, strict=True)
        self.backup = {}


class EarlyStopping:
    """Stop training when a validation metric stops improving.

    Call the instance once per epoch with the metric and the model. On every
    improvement (by at least ``delta``) a checkpoint is written to ``path``;
    after ``patience`` consecutive calls without improvement ``early_stop``
    becomes True.
    """

    def __init__(self, patience=7, verbose=True, delta=0, path='checkpoint.pth', mode='max', ema=None):
        """
        Args:
            patience: Number of consecutive non-improving calls tolerated.
            verbose: Print progress messages when True.
            delta: Minimum change that counts as an improvement.
            path: Checkpoint file written on improvement.
            mode: ``'max'`` if larger metrics are better, ``'min'`` otherwise.
            ema: Optional object with ``apply_shadow(model)`` /
                ``restore(model)``; when given, the checkpoint is saved while
                its shadow weights are loaded into the model.

        Raises:
            ValueError: If ``mode`` is neither ``'max'`` nor ``'min'``.
        """
        if mode not in ('max', 'min'):
            # Any other value used to make every call count as an improvement.
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_metric_best = -float('inf') if mode == 'max' else float('inf')
        self.delta = delta
        self.path = path
        self.mode = mode
        self.ema = ema

    def _is_improvement(self, score):
        """Return True when ``score`` beats ``best_score`` by at least ``delta``."""
        if self.mode == 'max':
            return score >= self.best_score + self.delta
        return score <= self.best_score - self.delta

    def __call__(self, val_metric, model):
        """Record one epoch's metric; checkpoint on improvement."""
        score = val_metric
        # NaN compares False with everything, so it used to fall through to
        # the "improved" branch, overwrite the checkpoint and poison
        # best_score. Treat it as a non-improving epoch instead.
        is_nan = score != score

        if not is_nan and (self.best_score is None or self._is_improvement(score)):
            self.best_score = score
            self.save_checkpoint(val_metric, model)
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_metric, model):
        """Write the model's state dict (EMA weights if configured) to ``path``."""
        if self.verbose:
            print(f'Validation metric improved ({self.val_metric_best:.6f} --> {val_metric:.6f}). Saving model...')

        if self.ema:
            self.ema.apply_shadow(model)
            try:
                torch.save(model.state_dict(), self.path)
            finally:
                # Always swap the training weights back, even if saving fails.
                self.ema.restore(model)
        else:
            torch.save(model.state_dict(), self.path)

        self.val_metric_best = val_metric


def calculate_metrics(y_true, y_pred, epsilon=1e-7):
    """Compute Dice, IoU, sensitivity and precision over all elements.

    Both tensors are binarised at 0.5 first. Counts are pooled over the whole
    tensor (no per-sample averaging). ``epsilon`` is added to every numerator
    and denominator, so an empty target with an empty prediction scores 1.0
    on all four metrics.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Prediction tensor (probabilities or an already binary mask).
        epsilon: Smoothing constant.

    Returns:
        ``(dice, iou, sensitivity, precision)`` as Python floats.
    """
    truth = (y_true > 0.5).float()
    pred = (y_pred > 0.5).float()

    true_pos = (truth * pred).sum()
    false_pos = ((1 - truth) * pred).sum()
    false_neg = (truth * (1 - pred)).sum()

    def _smoothed_ratio(numerator, denominator):
        return ((numerator + epsilon) / (denominator + epsilon)).item()

    return (
        _smoothed_ratio(2. * true_pos, 2 * true_pos + false_pos + false_neg),
        _smoothed_ratio(true_pos, true_pos + false_pos + false_neg),
        _smoothed_ratio(true_pos, true_pos + false_neg),
        _smoothed_ratio(true_pos, true_pos + false_pos),
    )

# =============================
# Loss Functions
# =============================
class SensitivityOptimizedLoss(nn.Module):
    """Focal-Tversky-style loss weighted towards recall.

    ``alpha`` scales the false-negative term and ``beta`` the false-positive
    term of the Tversky denominator; the complement of the index is raised to
    ``gamma``. A constant 3.0 is added when the target contains a lesion but
    no pixel's probability exceeds 0.2.

    NOTE: that extra penalty is a constant tensor, so it raises the reported
    loss value but contributes no gradient.
    """

    def __init__(self, alpha=0.8, beta=0.2, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # weight of missed lesion pixels (false negatives)
        self.beta = beta    # weight of false alarms (false positives)
        self.gamma = gamma  # focal exponent

    def forward(self, inputs, targets):
        """Return the scalar loss for logits ``inputs`` and binary ``targets``."""
        probs = torch.sigmoid(inputs)

        true_pos = (targets * probs).sum()
        false_neg = (targets * (1 - probs)).sum()
        false_pos = ((1 - targets) * probs).sum()

        denominator = true_pos + self.alpha * false_neg + self.beta * false_pos + 1e-7
        focal_term = (1 - true_pos / denominator) ** self.gamma

        # Flag batches whose target has lesion pixels but whose prediction has
        # no pixel above the 0.2 probability level.
        lesion_present = targets.sum() > 0
        any_confident_pixel = (probs > 0.2).sum() > 0
        missed_entirely = bool(lesion_present) and not bool(any_confident_pixel)

        penalty = torch.tensor(3.0 if missed_entirely else 0.0, device=inputs.device)
        return focal_term + penalty

class FocalTverskyLoss(nn.Module):
    """Focal Tversky loss computed over all elements of the batch.

    ``alpha`` weights false negatives and ``beta`` weights false positives in
    the Tversky index; ``eps`` smooths both numerator and denominator. The
    loss is ``(1 - tversky) ** gamma``.
    """

    def __init__(self, alpha=0.5, beta=0.5, gamma=4/3, eps=1e-7):
        super().__init__()
        self.alpha, self.beta = alpha, beta
        self.gamma = gamma
        self.eps = eps

    def forward(self, inputs, targets):
        """Return the scalar loss for logits ``inputs`` and binary ``targets``."""
        probs = torch.sigmoid(inputs)

        true_pos = (targets * probs).sum()
        false_pos = ((1 - targets) * probs).sum()
        false_neg = (targets * (1 - probs)).sum()

        numerator = true_pos + self.eps
        denominator = true_pos + self.alpha * false_neg + self.beta * false_pos + self.eps
        return torch.pow(1 - numerator / denominator, self.gamma)

class CombinedLoss(nn.Module):
    """Focal Tversky loss plus a weighted BCE-with-logits term.

    The result is ``focal_tversky + bce_alpha * bce``; both terms receive the
    raw logits.
    """

    def __init__(self, bce_alpha=LOSS_BCE_ALPHA):
        """
        Args:
            bce_alpha: Multiplier applied to the BCE term.
        """
        super().__init__()
        self.focal_tversky = FocalTverskyLoss()
        self.bce = nn.BCEWithLogitsLoss()
        self.alpha = bce_alpha

    def forward(self, inputs, targets):
        """Return the combined scalar loss for logits ``inputs``."""
        region_term = self.focal_tversky(inputs, targets)
        pixel_term = self.bce(inputs, targets)
        weighted_pixel_term = self.alpha * pixel_term
        return region_term + weighted_pixel_term

# =============================
# Main Execution
# =============================
def main():
    """Train, checkpoint and evaluate a single-lesion segmentation model.

    Steps, in order:
      1. Recreate the ``analysis_<lesion>`` output folder.
      2. Split the dataset into train / validation / test and subsample the
         training split towards a 70:30 positive:negative image ratio.
      3. Train with gradient accumulation, AMP and (optionally) EMA weights,
         validating each epoch and early-stopping on validation Dice.
      4. Reload the best checkpoint and evaluate on the test split with the
         pre-defined threshold and small-component removal.
      5. Write CSV summaries and best / worst case image panels.

    Raises:
        FileNotFoundError: If a mask file cannot be read while counting
            lesion-positive images.
    """
    OUTPUT_DIR = f"analysis_{TARGET_LESION}"
    # Start from a clean output folder on every run.
    if os.path.exists(OUTPUT_DIR):
        shutil.rmtree(OUTPUT_DIR)
    os.makedirs(os.path.join(OUTPUT_DIR, "worst_cases"), exist_ok=True)
    os.makedirs(os.path.join(OUTPUT_DIR, "best_cases"), exist_ok=True)
    print(f"Analysis files will be saved to: {OUTPUT_DIR}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"--- Training Unet++ model for {LESION_FULL_NAME} ({TARGET_LESION}) ---")
    print(f"Using device: {device}")

    def _mask_has_lesion(path):
        # cv2.imread returns None on failure; raise a clear error instead of
        # the opaque AttributeError that calling .max() on None would produce.
        mask_img = cv2.imread(path, 0)
        if mask_img is None:
            raise FileNotFoundError(f"Could not load mask: {path}")
        return mask_img.max() > 0

    df = prepare_fgadr_dataframe(DATASET_PATH)
    # This split creates a 70% train, 15% validation, and 15% test set
    train_val_df, test_df = train_test_split(df, test_size=0.15, random_state=42)
    train_df, val_df = train_test_split(train_val_df, test_size=(0.15/0.85), random_state=42)

    # --- Curate TRAINING DATA ONLY for a 70:30 positive:negative split ---
    print("\nCurating training dataset for a 70:30 positive:negative split...")
    mask_col = f'{TARGET_LESION}_mask'
    # Work on an explicit copy so adding a column never writes into a slice
    # of the frame returned by the split.
    train_df = train_df.copy()
    train_df['has_lesion'] = train_df[mask_col].apply(_mask_has_lesion)

    positive_df = train_df[train_df['has_lesion']].copy()
    negative_df = train_df[~train_df['has_lesion']].copy()

    num_pos_available = len(positive_df)
    num_neg_available = len(negative_df)
    print(f"Found {num_pos_available} positive samples and {num_neg_available} negative samples in the original training set.")

    # Determine the limiting factor to maintain the 70:30 ratio.
    # If we keep all positive samples, we need num_pos_available * (30/70) negative samples.
    # If we keep all negative samples, we need num_neg_available * (70/30) positive samples.
    required_neg_for_ratio = int(round(num_pos_available * (30.0 / 70.0)))
    required_pos_for_ratio = int(round(num_neg_available * (70.0 / 30.0)))

    if required_neg_for_ratio <= num_neg_available:
        # We have enough negative samples to match all positive samples.
        print(f"Using all {num_pos_available} positive samples.")
        negative_subset_df = negative_df.sample(n=required_neg_for_ratio, random_state=42)
        print(f"Sampling {len(negative_subset_df)} negative samples to achieve a 70:30 ratio.")
        train_df = pd.concat([positive_df, negative_subset_df])
    else:
        # We don't have enough negative samples. Use all available negatives and downsample positives.
        print(f"Not enough negative samples. Using all {num_neg_available} negative samples.")
        positive_subset_df = positive_df.sample(n=required_pos_for_ratio, random_state=42)
        print(f"Downsampling to {len(positive_subset_df)} positive samples to achieve a 70:30 ratio.")
        train_df = pd.concat([positive_subset_df, negative_df])

    # Shuffle the curated rows and reset the index for positional access.
    train_df = train_df.sample(frac=1, random_state=42).reset_index(drop=True)

    final_pos = train_df['has_lesion'].sum()
    final_total = len(train_df)
    print(f"New curated training set size: {final_total} images.")
    if final_total > 0:
        pos_perc = final_pos / final_total
        neg_perc = (final_total - final_pos) / final_total
        print(f"Final split: {final_pos} positive ({pos_perc:.2%}), {final_total - final_pos} negative ({neg_perc:.2%}).\n")

    # Save split summary for analysis
    split_summary = []
    splits = {'train': train_df, 'val': val_df, 'test': test_df}
    for name, split_df in splits.items():
        n = len(split_df)
        # For val/test, need to check for has_lesion as it wasn't pre-calculated
        if 'has_lesion' not in split_df.columns:
            lesion_positive_count = split_df[mask_col].apply(_mask_has_lesion).sum()
        else:
            lesion_positive_count = split_df['has_lesion'].sum()
        lesion_rate = lesion_positive_count / n if n > 0 else 0
        split_summary.append({'name': name, 'n': n, 'counts': lesion_positive_count, 'rate': lesion_rate})
    pd.DataFrame(split_summary).to_csv(os.path.join(OUTPUT_DIR, 'split_summary.csv'), index=False)

    # Per-sample weights: inverse of the size of the sample's class.
    positive_count = train_df['has_lesion'].sum()
    negative_count = len(train_df) - positive_count
    weights = [1.0 / positive_count if has_lesion else 1.0 / negative_count
               for has_lesion in train_df['has_lesion']]
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

    train_transform = A.Compose([
        A.Resize(height=IMG_SIZE, width=IMG_SIZE), A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.1),
        A.Rotate(limit=10, p=0.5, border_mode=cv2.BORDER_CONSTANT), A.Normalize(), ToTensorV2(),
    ])
    val_test_transform = A.Compose([A.Resize(height=IMG_SIZE, width=IMG_SIZE), A.Normalize(), ToTensorV2()])

    train_dataset = FGADRDataset(train_df, target_lesion=TARGET_LESION, transform=train_transform)
    val_dataset = FGADRDataset(val_df, target_lesion=TARGET_LESION, transform=val_test_transform)
    test_dataset = FGADRDataset(test_df, target_lesion=TARGET_LESION, transform=val_test_transform)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    # batch_size=1 and shuffle=False: the test loop below relies on batch i
    # being row i of test_df.
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

    model = SingleLesionModel().to(device)
    model_save_path = f'model_{TARGET_LESION}_UnetPP_{ENCODER}.pth'

    # Two parameter groups: the encoder at LR, everything else at LR * 10.
    encoder_params = model.model.encoder.parameters()
    decoder_params = [p for n, p in model.named_parameters() if 'encoder' not in n]
    optimizer_params = [
        {'params': encoder_params, 'lr': LR},
        {'params': decoder_params, 'lr': LR * 10}
    ]
    optimizer = torch.optim.AdamW(optimizer_params, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999), eps=1e-8)

    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=SCHEDULER_PATIENCE, min_lr=MIN_LR)

    criterion = SensitivityOptimizedLoss(alpha=0.8, beta=0.2, gamma=2.0)
    print("Using SensitivityOptimizedLoss to improve recall and reduce overly conservative predictions")

    scaler = GradScaler(enabled=device.type == 'cuda')
    ema = EMA(model, decay=EMA_DECAY) if EMA_DECAY > 0 else None
    early_stopper = EarlyStopping(patience=PATIENCE, verbose=True, path=model_save_path, delta=0.002, mode='max', ema=ema)

    training_history = []
    print("\nStarting training...")
    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0
        optimizer.zero_grad()
        for i, (images, masks) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]")):
            images, masks = images.to(device), masks.to(device)
            with autocast(enabled=device.type == 'cuda'):
                predictions = model(images)
                loss = criterion(predictions, masks)
                # Scale so the accumulated gradient averages over the steps.
                loss = loss / ACCUMULATION_STEPS
            scaler.scale(loss).backward()
            # Step every ACCUMULATION_STEPS batches and on the last batch.
            if (i + 1) % ACCUMULATION_STEPS == 0 or (i + 1) == len(train_loader):
                # Unscale first so clipping sees the true gradient norms.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                if ema:
                    ema.update(model)
            # Undo the accumulation scaling for logging.
            train_loss += loss.item() * ACCUMULATION_STEPS

        # Validate with the EMA weights when EMA is enabled.
        if ema:
            ema.apply_shadow(model)
        model.eval()
        val_loss, val_dice_scores = 0, []
        with torch.no_grad():
            for images, masks in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]"):
                images, masks = images.to(device), masks.to(device)
                with autocast(enabled=device.type == 'cuda'):
                    predictions = model(images)
                    loss = criterion(predictions, masks)
                val_loss += loss.item()
                dice, _, _, _ = calculate_metrics(masks, torch.sigmoid(predictions))
                val_dice_scores.append(dice)
        if ema:
            ema.restore(model)

        avg_train_loss, avg_val_loss, avg_val_dice = train_loss / len(train_loader), val_loss / len(val_loader), np.mean(val_dice_scores)
        print(f"Epoch {epoch+1:03d}: Train Loss={avg_train_loss:.4f} | Val Loss={avg_val_loss:.4f} Dice (EMA)={avg_val_dice:.4f}")
        training_history.append({'epoch': epoch + 1, 'train_loss': avg_train_loss, 'val_loss': avg_val_loss, 'val_dice': avg_val_dice})

        scheduler.step(avg_val_dice)
        early_stopper(avg_val_dice, model)
        if early_stopper.early_stop:
            print("Early stopping triggered.")
            break

    pd.DataFrame(training_history).to_csv(os.path.join(OUTPUT_DIR, 'metrics.csv'), index=False)
    print("Saved epoch-wise metrics.")

    print(f"\n--- Training Finished ---\nLoading best model from {model_save_path} for final evaluation...")
    model.load_state_dict(torch.load(model_save_path, map_location=device))
    model.eval()

    # --- Use pre-defined threshold instead of tuning ---
    best_thr = LESION_THRESHOLDS[TARGET_LESION]
    print(f"Using pre-defined threshold for {TARGET_LESION}: {best_thr:.4f}")

    per_image_metrics = []
    worst_cases, best_cases = [], []
    with torch.no_grad():
        for i, (images, masks) in enumerate(tqdm(test_loader, desc="Evaluating on Test Set")):
            # Strip whatever extension the file has (not only '.png').
            image_id = os.path.splitext(os.path.basename(test_df.iloc[i]['image']))[0]
            images, masks = images.to(device), masks.to(device)

            with autocast(enabled=device.type == 'cuda'):
                predictions = model(images)
            prob = torch.sigmoid(predictions)
            bin_mask_tensor = (prob > best_thr).float()

            # Remove connected components smaller than MIN_LESION_AREA.
            bin_mask_np = bin_mask_tensor.squeeze().cpu().numpy()
            processed_mask_np = post_process_mask(bin_mask_np, min_area=MIN_LESION_AREA)
            processed_mask_tensor = torch.from_numpy(processed_mask_np).unsqueeze(0).to(device)

            gt_pixels = torch.sum(masks).item()
            pred_pixels = torch.sum(processed_mask_tensor).item()

            dice, iou, recall, precision = calculate_metrics(masks, processed_mask_tensor)
            per_image_metrics.append({
                'image_id': image_id, 'dice': dice, 'iou': iou,
                'recall': recall, 'precision': precision,
                'gt_pixels': gt_pixels, 'pred_pixels': pred_pixels
            })

            # Keep the five lowest- and five highest-Dice cases seen so far.
            case_data = (dice, image_id, images.cpu(), masks.cpu(), processed_mask_tensor.cpu())
            if len(worst_cases) < 5 or dice < worst_cases[-1][0]:
                worst_cases = sorted(worst_cases + [case_data], key=lambda x: x[0])[:5]
            if len(best_cases) < 5 or dice > best_cases[-1][0]:
                best_cases = sorted(best_cases + [case_data], key=lambda x: x[0], reverse=True)[:5]

    pd.DataFrame(per_image_metrics).to_csv(os.path.join(OUTPUT_DIR, 'per_image_test_metrics.csv'), index=False)
    print("Saved per-image test metrics.")

    def save_panel(case_data, folder):
        """Save a side-by-side panel: image | ground-truth overlay | prediction overlay."""
        dice, img_id, img, gt_mask, pred_mask = case_data
        # Invert the normalisation with these per-channel mean/std values.
        mean, std = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1), torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        img = torch.clamp(img.squeeze(0) * std + mean, 0, 1)
        gt_overlay, pred_overlay = img.clone(), img.clone()
        # Ground truth is painted into channel 1, the prediction into channel 0.
        gt_overlay[1, :, :] = torch.max(gt_overlay[1, :, :], gt_mask.squeeze(0))
        pred_overlay[0, :, :] = torch.max(pred_overlay[0, :, :], pred_mask.squeeze(0))
        panel = torch.cat([img, gt_overlay, pred_overlay], dim=2)
        save_image(panel, os.path.join(folder, f"{img_id}_dice{dice:.4f}.png"))

    for case in worst_cases:
        save_panel(case, os.path.join(OUTPUT_DIR, "worst_cases"))
    for case in best_cases:
        save_panel(case, os.path.join(OUTPUT_DIR, "best_cases"))
    print("Saved best and worst case image panels.")

    avg_metrics = pd.DataFrame(per_image_metrics)[['dice', 'iou', 'recall', 'precision']].mean()
    print(f"\n--- Performance Metrics for {LESION_FULL_NAME} on Test Set (Post-Processing) ---")
    print(f"  - Dice: {avg_metrics['dice']:.4f}")
    print(f"  - IoU: {avg_metrics['iou']:.4f}")
    print(f"  - Sensitivity (Recall): {avg_metrics['recall']:.4f}")
    print(f"  - Precision: {avg_metrics['precision']:.4f}")

# Run the full train / evaluate pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()