# In [None]:
import os
import cv2
import torch
import numpy as np
import pandas as pd
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import torch.nn as nn
from sklearn.model_selection import train_test_split
import re


# 1. CONFIGURATION
DATASET_PATH = "data/fgdar"
ENCODER = "resnet34"
IMG_SIZE = 512
BATCH_SIZE = 8
EPOCHS = 55
LR = 5e-4
WEIGHT_DECAY = 1e-4
MIN_LR = 1e-6
PATIENCE = 6  # early-stopping patience, in epochs without validation-Dice improvement
SCHEDULER_PATIENCE = 2  # ReduceLROnPlateau patience, in epochs
# Per-lesion loss weights; keys must cover every lesion type below.
LAMBDA = {"MA": 1.0, "HE": 1.0, "EX": 1.0, "SE": 1.3}
# The order of this list defines each lesion's channel index in the stacked masks.
LESION_TYPES = ['MA', 'HE', 'EX', 'SE']
LESION_NAMES_FULL = ['Microaneurysm', 'Hemorrhage', 'Hard Exudate', 'Soft Exudate']


# Set to True to use your local pretrained weights, or False to use ImageNet
USE_CUSTOM_PRETRAINED_WEIGHTS = False
CUSTOM_ENCODER_PATH = "best_resnet34_fundus_encoder.pth"



# 2. HELPERS, DATASET, AND MODEL
def numeric_sort_key(s):
    """Return a natural-sort key for the string *s*.

    The string is split on runs of the characters 0-9. Each digit run becomes
    an ``int`` and every other piece is lower-cased, so that ``"img2"`` sorts
    before ``"img10"``. For example, ``"img10.png"`` yields
    ``['img', 10, '.png']``.
    """
    def _as_sortable(token):
        return int(token) if token.isdigit() else token.lower()

    return list(map(_as_sortable, re.split(r'([0-9]+)', s)))

def prepare_fgadr_dataframe(dataset_path):
    """Build a table pairing each FGADR segmentation image with its lesion masks.

    Expects ``<dataset_path>/Seg-set/Original_Images`` plus one
    ``<LesionName>_Masks`` directory per lesion type. Each directory is
    naturally sorted (see ``numeric_sort_key``) and files are paired by
    position, so every directory must hold corresponding files in the same
    sorted order. Hidden entries (names starting with '.', e.g. ``.DS_Store``)
    are ignored.

    Args:
        dataset_path: Root directory of the FGADR dataset.

    Returns:
        pandas.DataFrame with an ``'image'`` column and one ``'<LESION>_mask'``
        column per entry in ``LESION_TYPES``, all holding file paths.

    Raises:
        ValueError: If a mask directory does not contain the same number of
            files as the image directory. Positional pairing would otherwise be
            misaligned.
    """
    seg_set_path = os.path.join(dataset_path, 'Seg-set')
    original_images_path = os.path.join(seg_set_path, "Original_Images")
    # On-disk directory prefix for each lesion code.
    lesion_dir_names = {
        'MA': 'Microaneurysms',
        'HE': 'Hemorrhage',
        'EX': 'HardExudate',
        'SE': 'SoftExudate',
    }

    def _sorted_files(directory):
        # Natural sort so "2.png" precedes "10.png"; skip hidden/system entries.
        paths = (os.path.join(directory, f) for f in os.listdir(directory) if not f.startswith('.'))
        return sorted(paths, key=numeric_sort_key)

    image_files = _sorted_files(original_images_path)
    data = {'image': image_files}
    for lesion in LESION_TYPES:
        mask_dir = os.path.join(seg_set_path, f"{lesion_dir_names[lesion]}_Masks")
        mask_files = _sorted_files(mask_dir)
        if len(mask_files) != len(image_files):
            raise ValueError(
                f"{mask_dir} contains {len(mask_files)} files but {original_images_path} "
                f"contains {len(image_files)}; images and masks are paired by sorted "
                f"position, so the counts must match."
            )
        data[f'{lesion}_mask'] = mask_files
    return pd.DataFrame(data)

def green_clahe_preprocess(image, clip_limit=2.0, tile_grid_size=(8, 8), blend_alpha=0.75):
    """Blend an image with a CLAHE-enhanced copy of its channel 1.

    Channel index 1 is green in both RGB and BGR layouts. That channel is
    equalized with CLAHE and replicated to three channels. The result is then
    combined with the original as
    ``(1 - blend_alpha) * image + blend_alpha * clahe`` via ``cv2.addWeighted``.

    Args:
        image: HxWx3 array; assumed 8-bit, as produced by ``cv2.imread``.
        clip_limit: CLAHE contrast-limiting threshold.
        tile_grid_size: CLAHE tile grid as ``(cols, rows)``.
        blend_alpha: Weight of the enhanced image in the blend.

    Returns:
        The blended array, with the same shape as ``image``.
    """
    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    enhanced = cv2.cvtColor(equalizer.apply(image[:, :, 1]), cv2.COLOR_GRAY2BGR)
    original_weight = 1 - blend_alpha
    return cv2.addWeighted(image, original_weight, enhanced, blend_alpha, 0)

class FGADRDataset(Dataset):
    """Map-style dataset returning ``(image, masks)`` pairs for FGADR lesion segmentation.

    Each image is read from disk, converted BGR -> RGB, enhanced with
    ``green_clahe_preprocess`` and optionally transformed. Masks are binarized
    (pixel > 127) and stacked, with one channel per entry of ``LESION_TYPES``.

    Args:
        dataframe: Table with an ``'image'`` path column and one
            ``'<LESION>_mask'`` path column per lesion type.
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)``. It must return a dict with
            ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, dataframe, transform=None):
        self.df = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx``. Masks are returned channels-first, one channel per lesion.

        Raises:
            FileNotFoundError: If the image or any of its masks cannot be read.
        """
        row = self.df.iloc[idx]
        image_path = row['image']
        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"Could not load image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        masks = []
        for lesion in LESION_TYPES:
            mask_path = row[f'{lesion}_mask']
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise FileNotFoundError(f"Could not load mask: {mask_path}")
            masks.append((mask > 127).astype(np.float32))
        masks = np.stack(masks, axis=-1)  # (H, W, len(LESION_TYPES))

        # The enhanced image is what the model sees, with or without a transform.
        image = green_clahe_preprocess(image)
        if self.transform:
            augmented = self.transform(image=image.astype(np.uint8), mask=masks)
            image = augmented['image']
            masks = augmented['mask']

        # Convert channels-last to channels-first. This handles torch tensors
        # (tensor-producing transform) and NumPy arrays (no transform). Masks
        # whose last axis is not the lesion axis are left untouched.
        if masks.ndim == 3 and masks.shape[-1] == len(LESION_TYPES):
            masks = masks.permute(2, 0, 1) if torch.is_tensor(masks) else masks.transpose(2, 0, 1)
        return image, masks

class MultiDecoderModel(nn.Module):
    """Shared-encoder segmentation model with one decoder/head per lesion type.

    A single smp encoder (depth 5, output stride 8) produces a feature pyramid.
    It feeds four independent decoders, each ending in a 1-class segmentation head:

    * MA -- UNet decoder with scSE attention
    * HE -- DeepLabV3+ decoder (512 channels, atrous rates 6/12/18)
    * EX -- UNet decoder with smp defaults
    * SE -- DeepLabV3+ decoder (640 channels, atrous rates 12/24/36)

    Args:
        encoder_name: smp encoder identifier.
        in_channels: Number of input image channels.
        encoder_weights: Either a path to an existing local state-dict file,
            which is loaded non-strictly into the encoder, or a weights name
            passed to smp (e.g. ``"imagenet"``, or ``None`` for random
            initialization).
    """

    def __init__(self, encoder_name=ENCODER, in_channels=3, encoder_weights="imagenet"):
        super().__init__()
        is_custom_weights = isinstance(encoder_weights, str) and os.path.exists(encoder_weights)

        if is_custom_weights:
            # Build the bare architecture; the local state dict is loaded below.
            initial_smp_weights = None
            print(f"Initializing encoder '{encoder_name}' architecture (will load custom weights)...")
        else:
            # Forward the requested weights as-is. Hard-coding 'imagenet' here would
            # silently ignore encoder_weights=None. It would also hide a custom path
            # that does not exist; smp is expected to reject that as an unknown
            # weights name.
            initial_smp_weights = encoder_weights
            print(f"Initializing encoder '{encoder_name}' with STANDARD '{encoder_weights}' weights.")

        self.encoder = smp.encoders.get_encoder(
            name=encoder_name,
            in_channels=in_channels,
            depth=5,
            weights=initial_smp_weights,
            output_stride=8
        )

        if is_custom_weights:
            print(f"Loading CUSTOM weights from file: {encoder_weights}")
            # NOTE(review): depending on the torch version's weights_only default,
            # torch.load may unpickle arbitrary objects -- only load trusted files.
            custom_state_dict = torch.load(encoder_weights, map_location='cpu')
            # strict=False tolerates key mismatches. Report them rather than claim
            # success when, e.g., a key-prefix mismatch means nothing was loaded.
            load_result = self.encoder.load_state_dict(custom_state_dict, strict=False)
            missing, unexpected = load_result.missing_keys, load_result.unexpected_keys
            if missing or unexpected:
                print(f"  > WARNING: custom weights only partially matched the encoder "
                      f"({len(missing)} missing, {len(unexpected)} unexpected keys).")
            else:
                print("  > Successfully loaded custom pretrained weights into the encoder.")

        # --- Lesion-specific Decoder Heads ---
        # Each full smp model is built only to harvest its decoder and head; its own
        # encoder is discarded. encoder_weights=None avoids loading pretrained
        # weights into those throwaway encoders. in_channels keeps their encoder
        # channel layout identical to self.encoder's.
        throwaway_kwargs = dict(
            encoder_name=encoder_name,
            encoder_weights=None,
            in_channels=in_channels,
            encoder_output_stride=8,
            classes=1,
        )
        unet_model_ma = smp.Unet(**throwaway_kwargs, decoder_channels=(256, 128, 64, 32, 16), decoder_attention_type='scse')
        self.ma_decoder, self.ma_head = unet_model_ma.decoder, unet_model_ma.segmentation_head
        deeplab_model_he = smp.DeepLabV3Plus(**throwaway_kwargs, decoder_channels=512, decoder_atrous_rates=(6, 12, 18))
        self.he_decoder, self.he_head = deeplab_model_he.decoder, deeplab_model_he.segmentation_head
        unet_model_ex = smp.Unet(**throwaway_kwargs)
        self.ex_decoder, self.ex_head = unet_model_ex.decoder, unet_model_ex.segmentation_head
        deeplab_model_se = smp.DeepLabV3Plus(**throwaway_kwargs, decoder_channels=640, decoder_atrous_rates=(12, 24, 36))
        self.se_decoder, self.se_head = deeplab_model_se.decoder, deeplab_model_se.segmentation_head

    def forward(self, x):
        """Return a dict mapping 'MA'/'HE'/'EX'/'SE' to that head's 1-channel pre-sigmoid output."""
        features = self.encoder(x)
        ma_out = self.ma_head(self.ma_decoder(features))
        he_out = self.he_head(self.he_decoder(features))
        ex_out = self.ex_head(self.ex_decoder(features))
        se_out = self.se_head(self.se_decoder(features))
        return {'MA': ma_out, 'HE': he_out, 'EX': ex_out, 'SE': se_out}

# 3. TRAINING AND EVALUATION UTILITIES

class EarlyStopping:
    """Stop training when a validation metric stops improving, checkpointing on improvement.

    Args:
        patience: Number of consecutive non-improving calls before
            ``early_stop`` is set.
        verbose: Print progress messages when True.
        delta: Margin by which a new score must beat ``best_score`` to count
            as an improvement.
        path: File that ``model.state_dict()`` is saved to on every improvement.
        mode: ``'max'`` if larger metric values are better, ``'min'`` if
            smaller ones are.

    Raises:
        ValueError: If ``mode`` is neither ``'max'`` nor ``'min'``. Any other
            value would make every call after the first count as
            non-improving.
    """

    def __init__(self, patience=7, verbose=True, delta=0, path='checkpoint.pth', mode='max'):
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # Last checkpointed metric; only used in the progress message.
        self.val_metric_best = -float('inf') if mode == 'max' else float('inf')
        self.delta = delta
        self.path = path
        self.mode = mode

    def _is_improvement(self, score):
        """Return True if no score was seen yet or ``score`` beats ``best_score`` by more than ``delta``."""
        if self.best_score is None:
            return True
        if self.mode == 'max':
            return score > self.best_score + self.delta
        return score < self.best_score - self.delta

    def __call__(self, val_metric, model):
        """Record ``val_metric``: checkpoint ``model`` on improvement, otherwise advance the patience counter."""
        score = val_metric
        if self._is_improvement(score):
            self.best_score = score
            self.save_checkpoint(val_metric, model)
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_metric, model):
        """Save ``model.state_dict()`` to ``self.path`` and remember ``val_metric``."""
        if self.verbose:
            print(f'Validation metric improved ({self.val_metric_best:.6f} --> {val_metric:.6f}). Saving model to {self.path}...')
        torch.save(model.state_dict(), self.path)
        self.val_metric_best = val_metric

def calculate_metrics(y_true, y_pred, epsilon=1e-7):
    """Compute Dice, IoU, sensitivity and precision between two masks.

    Both inputs are binarized at 0.5. True/false positive and false negative
    counts are pooled over every element, i.e. across the whole batch when a
    batch is passed. ``epsilon`` is added to each numerator and denominator, so
    when both masks are empty every metric evaluates to 1.0.

    Returns:
        tuple of Python floats ``(dice, iou, sensitivity, precision)``.
    """
    truth = (y_true > 0.5).float()
    guess = (y_pred > 0.5).float()

    true_pos = (truth * guess).sum()
    false_pos = ((1 - truth) * guess).sum()
    false_neg = (truth * (1 - guess)).sum()

    smoothed_tp = true_pos + epsilon
    scores = torch.stack([
        (2. * true_pos + epsilon) / (2 * true_pos + false_pos + false_neg + epsilon),
        smoothed_tp / (true_pos + false_pos + false_neg + epsilon),
        smoothed_tp / (true_pos + false_neg + epsilon),
        smoothed_tp / (true_pos + false_pos + epsilon),
    ])
    # A single host transfer instead of one .item() per metric.
    dice, iou, sensitivity, precision = scores.tolist()
    return dice, iou, sensitivity, precision

class FocalTverskyLoss(nn.Module):
    """Focal Tversky loss computed from raw logits.

    The logits are passed through a sigmoid. Soft TP/FP/FN counts are pooled
    over all elements, and the loss is

        TI   = (TP + eps) / (TP + alpha * FN + beta * FP + eps)
        loss = (1 - TI) ** gamma

    With ``alpha == beta == 0.5``, TI reduces to a soft Dice coefficient (up to
    the eps smoothing).

    NOTE(review): Abraham & Khan (2019) write the loss as
    ``(1 - TI) ** (1 / gamma)`` with gamma = 4/3. Here ``gamma`` is used
    directly as the exponent -- confirm which form is intended before tuning it.
    """

    def __init__(self, alpha=0.5, beta=0.5, gamma=4/3, eps=1e-7):
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.eps = eps

    def forward(self, inputs, targets):
        probs = inputs.sigmoid()
        true_pos = (probs * targets).sum()
        false_pos = (probs * (1 - targets)).sum()
        false_neg = ((1 - probs) * targets).sum()
        denominator = true_pos + self.alpha * false_neg + self.beta * false_pos + self.eps
        tversky_index = (true_pos + self.eps) / denominator
        return (1 - tversky_index).pow(self.gamma)

class CombinedLoss(nn.Module):
    """Focal Tversky loss plus a down-weighted binary cross-entropy term.

        loss = FocalTversky(inputs, targets) + bce_weight * BCEWithLogits(inputs, targets)

    Both terms take raw logits. The ``ft_*`` arguments are forwarded to
    ``FocalTverskyLoss`` as ``alpha``, ``beta`` and ``gamma``.
    """

    def __init__(self, ft_alpha=0.5, ft_beta=0.5, ft_gamma=4/3, bce_weight=0.05):
        super().__init__()
        self.focal_tversky = FocalTverskyLoss(alpha=ft_alpha, beta=ft_beta, gamma=ft_gamma)
        self.bce = nn.BCEWithLogitsLoss()
        self.bce_weight = bce_weight

    def forward(self, inputs, targets):
        region_term = self.focal_tversky(inputs, targets)
        pixel_term = self.bce(inputs, targets)
        return region_term + self.bce_weight * pixel_term


def evaluate_on_test_set(model, test_loader, device):
    """Evaluate ``model`` on ``test_loader``, print a per-lesion table and return the averages.

    ``calculate_metrics`` is applied once per batch and per lesion, then the
    results are averaged across batches. This is per image when the loader's
    batch size is 1.

    Args:
        model: Module whose output is a dict of per-lesion logits keyed by the
            entries of ``LESION_TYPES``.
        test_loader: Iterable of ``(images, masks)`` batches. Mask channel ``i``
            corresponds to ``LESION_TYPES[i]``.
        device: Device the batches are moved to.

    Returns:
        dict mapping each lesion to ``{'dice', 'iou', 'sensitivity', 'precision'}``
        mean scores. A score is NaN when the loader yielded no batches.
    """
    print("\n--- Evaluating model on the test set ---")
    model.eval()

    metric_names = ('dice', 'iou', 'sensitivity', 'precision')
    # Per-lesion lists of per-batch scores.
    test_scores = {lesion: {name: [] for name in metric_names} for lesion in LESION_TYPES}

    with torch.no_grad():
        for images, masks in tqdm(test_loader, desc="Testing"):
            images, masks = images.to(device), masks.to(device)
            predictions = model(images)

            for li, lesion in enumerate(LESION_TYPES):
                pred_prob = torch.sigmoid(predictions[lesion])
                true_mask = masks[:, li:li + 1]
                batch_scores = calculate_metrics(true_mask, pred_prob)
                for name, value in zip(metric_names, batch_scores):
                    test_scores[lesion][name].append(value)

    # Guard against an empty loader, where np.mean([]) would warn and return NaN.
    averages = {
        lesion: {name: float(np.mean(values)) if values else float('nan') for name, values in metrics.items()}
        for lesion, metrics in test_scores.items()
    }

    print("\n--- Test Set Evaluation Results ---")
    print(f"{'Lesion':<15} | {'Dice':<10} | {'IoU':<10} | {'Sensitivity':<12} | {'Precision':<10}")
    print("-" * 65)
    for lesion, avg in averages.items():
        print(f"{lesion:<15} | {avg['dice']:<10.4f} | {avg['iou']:<10.4f} | {avg['sensitivity']:<12.4f} | {avg['precision']:<10.4f}")
    print("-" * 65)
    return averages


# 4. MAIN EXECUTION

def _batch_loss_and_dice(predictions, masks, criterion):
    """Return ``(LAMBDA-weighted loss summed over lesions, mean per-lesion Dice)`` for one batch."""
    loss, batch_dices = 0, []
    for lesion in predictions:
        lesion_idx = LESION_TYPES.index(lesion)
        target_mask = masks[:, lesion_idx:lesion_idx + 1]
        loss += criterion(predictions[lesion], target_mask) * LAMBDA[lesion]
        # Dice is only reported, so detach to keep it out of the autograd graph.
        dice, _, _, _ = calculate_metrics(target_mask, torch.sigmoid(predictions[lesion].detach()))
        batch_dices.append(dice)
    return loss, np.mean(batch_dices)


def main():
    """Train the multi-decoder model on FGADR, keep the best checkpoint and evaluate it on the test split."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # 70/15/15 train/val/test split (the val fraction is taken from the remaining 85%).
    df = prepare_fgadr_dataframe(DATASET_PATH)
    train_val_df, test_df = train_test_split(df, test_size=0.15, random_state=42)
    train_df, val_df = train_test_split(train_val_df, test_size=(0.15/0.85), random_state=42)
    print(f"Dataset split: Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}")

    train_transform = A.Compose([
        A.Resize(height=IMG_SIZE, width=IMG_SIZE), A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.1),
        A.Rotate(limit=10, p=0.5, border_mode=cv2.BORDER_CONSTANT), A.Normalize(), ToTensorV2(),
    ])
    val_test_transform = A.Compose([A.Resize(height=IMG_SIZE, width=IMG_SIZE), A.Normalize(), ToTensorV2()])

    train_dataset = FGADRDataset(train_df, transform=train_transform)
    val_dataset = FGADRDataset(val_df, transform=val_test_transform)
    test_dataset = FGADRDataset(test_df, transform=val_test_transform)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    # Using batch size of 1 for test loader for per-image evaluation
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    if USE_CUSTOM_PRETRAINED_WEIGHTS:
        print("\n--- CONFIG: Using CUSTOM Pretrained Encoder ---")
        encoder_weights_source, model_save_path = CUSTOM_ENCODER_PATH, 'custom_model.pth'
    else:
        print("\n--- CONFIG: Using standard IMAGENET Encoder ---")
        encoder_weights_source, model_save_path = "imagenet", 'imagenet_model.pth'

    model = MultiDecoderModel(encoder_weights=encoder_weights_source).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999), eps=1e-8)
    # Both the scheduler and the early stopper track validation Dice (higher is better).
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=SCHEDULER_PATIENCE, min_lr=MIN_LR)
    early_stopper = EarlyStopping(patience=PATIENCE, verbose=True, path=model_save_path, delta=0.002, mode='max')
    criterion = CombinedLoss()

    print(f"\nStarting training (Model will be saved to: {model_save_path})...")
    for epoch in range(EPOCHS):
        model.train()
        train_loss, train_dice_scores = 0, []
        for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            loss, batch_dice = _batch_loss_and_dice(model(images), masks, criterion)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            train_loss += loss.item()
            train_dice_scores.append(batch_dice)

        model.eval()
        val_loss, val_dice_scores = 0, []
        with torch.no_grad():
            for images, masks in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]"):
                images, masks = images.to(device), masks.to(device)
                loss, batch_dice = _batch_loss_and_dice(model(images), masks, criterion)
                val_loss += loss.item()
                val_dice_scores.append(batch_dice)

        avg_train_loss, avg_train_dice = train_loss / len(train_loader), np.mean(train_dice_scores)
        avg_val_loss, avg_val_dice = val_loss / len(val_loader), np.mean(val_dice_scores)
        print(f"Epoch {epoch+1:03d}: Train Loss={avg_train_loss:.4f} | Train Dice={avg_train_dice:.4f} | Val Loss={avg_val_loss:.4f} | Val Dice={avg_val_dice:.4f}")

        scheduler.step(avg_val_dice)
        early_stopper(avg_val_dice, model)
        if early_stopper.early_stop:
            print("Early stopping triggered.")
            break

    print(f"\n--- Training Finished ---\nLoading best model for final evaluation from {model_save_path}...")
    model.load_state_dict(torch.load(model_save_path, map_location=device))

    evaluate_on_test_set(model, test_loader, device)


if __name__ == "__main__":
    main()