In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from tqdm import tqdm

import random

from siamese import TripletBrainDataset, MRI3DAugmentation, triplet_loss
from resnet3d import generate_model as generate_resnet
from unet import generate_model as generate_unet
from densenet import generate_model as generate_densenet

In [None]:
# Training hyperparameters (the DenseNet cell below may override batch size / accumulation).
BATCH_SIZE: int = 256          # samples per forward/backward micro-batch
NUM_EPOCHS: int = 2000
EMBEDDING_DIM: int = 128       # length of the model's output embedding
MARGIN: float = 0.1            # triplet-loss margin
MODEL: str = 'unet'            # architecture: 'resnet', 'unet' or 'densenet'
ACCUMULATION_STEPS: int = 2    # micro-batches whose gradients are summed per optimizer step

In [None]:
# DenseNet uses a smaller micro-batch (presumably for GPU memory) with more
# accumulation, keeping the effective batch at 64 * 8 = 512 (= default 256 * 2).
if MODEL == 'densenet':
    BATCH_SIZE, ACCUMULATION_STEPS = 64, 8

In [None]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def set_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch (CPU + every CUDA device),
    switch cuDNN to deterministic, non-autotuned kernels, and print the seed."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    # Reproducibility over speed: fixed kernel choice, no benchmark autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"Seed set to {seed}")

set_seed(42)

In [None]:
from torch.utils.data import DataLoader

dataset = 'hcpaspecttriplet'
t = 1  # NOTE(review): not used anywhere in the visible notebook — confirm before removing
DATA_PATH = f'../../../data/{dataset}'

# Augmentation is passed to the training split only (val/test stay un-augmented).
augmentation = MRI3DAugmentation(p=0.5)

train_dataset = TripletBrainDataset(
    data_path=DATA_PATH,
    split='train',
    mining_strategy='hard',
    num_negatives_per_positive=3,
    transform=augmentation,
    margin=MARGIN
)

val_dataset = TripletBrainDataset(
    data_path=DATA_PATH,
    split='val',
    mining_strategy='hard',
    num_negatives_per_positive=3,
    margin=MARGIN
)

# The test split samples negatives randomly rather than by hard mining.
test_dataset = TripletBrainDataset(
    data_path=DATA_PATH,
    split='test',
    mining_strategy='random',
    num_negatives_per_positive=3,
    margin=MARGIN
)


print("Dataset loaded:")
print(f"Train: {len(train_dataset.twin_pairs)} twin pairs, {len(train_dataset.all_subjects)} total subjects")
print(f"Val: {len(val_dataset.twin_pairs)} twin pairs, {len(val_dataset.all_subjects)} total subjects")
print(f"Test: {len(test_dataset.twin_pairs)} twin pairs, {len(test_dataset.all_subjects)} total subjects")
print(f"Train dataset size: {len(train_dataset)} triplets")

# Data loaders (single-process loading; only the training loader shuffles).
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Sanity check: load one training batch and report tensor shapes / sample IDs.
print("\nTesting data loading...")
batch = next(iter(train_loader))
print("Batch shapes:")
print(f"  Anchor: {batch['anchor'].shape}")
print(f"  Positive: {batch['positive'].shape}")
print(f"  Negative: {batch['negative'].shape}")
print(f"  Sample anchor ID: {batch['anchor_id'][0]}")
print(f"  Sample positive ID: {batch['positive_id'][0]}")
print(f"  Sample negative ID: {batch['negative_id'][0]}")


In [None]:
def init_embedding_layers(model):
    """Re-initialise ``model.fc1`` and ``model.fc2`` in place for stable embeddings.

    Weights get Xavier-uniform init with gain 0.5 (half the default scale);
    biases, where the layer has one, are reset to zero.
    """
    heads = (model.fc1, model.fc2)
    # Weights first (fc1, then fc2) so random draws happen in a fixed order.
    for head in heads:
        nn.init.xavier_uniform_(head.weight, gain=0.5)
    # zeros_ consumes no randomness, so handling biases afterwards is safe.
    for head in heads:
        if getattr(head, 'bias', None) is not None:
            nn.init.zeros_(head.bias)

In [None]:
# Map the MODEL config string to its architecture factory; fail loudly on typos
# instead of calling None later.
_MODEL_FACTORIES = {
    'resnet': generate_resnet,
    'unet': generate_unet,
    'densenet': generate_densenet,
}
try:
    generate_model = _MODEL_FACTORIES[MODEL]
except KeyError:
    raise ValueError(
        f"Unknown MODEL {MODEL!r}; expected one of {sorted(_MODEL_FACTORIES)}"
    ) from None

model = generate_model(model_depth=10, embedding_dim=EMBEDDING_DIM, use_attention=True)
model = model.to(device)

# Release cached GPU memory before training. torch.cuda.synchronize() raises
# when CUDA is unavailable, so only do this on a GPU machine.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

# Allow TF32 math for CUDA matmuls and cuDNN convolutions (faster on Ampere+
# GPUs at slightly reduced float32 precision).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

model.train()

In [None]:
# Every parameter should be trainable: report and unfreeze any that are not.
for name, param in model.named_parameters():
    if param.requires_grad:
        continue
    print(f"Parameter {name} does not require grad!")
    param.requires_grad = True

In [None]:
def embedding_regularization_loss(embeddings, lambda_reg=0.01):
    """Anti-collapse penalty ``lambda_reg * mean(exp(-||e_i - e_j||_2))``.

    The mean runs over all unordered row pairs ``i < j`` returned by
    ``F.pdist``. It equals ``lambda_reg`` when all rows coincide and decays
    toward 0 as the embeddings spread apart.
    """
    pairwise = F.pdist(embeddings, p=2)
    penalty = pairwise.neg().exp().mean()
    return lambda_reg * penalty

In [None]:
def train_triplet(model, dataloader, optimizer, scheduler, device, epoch, margin=1.0,
                  scaler=None, accumulation_steps=1, use_regularization=True):
    """Run one epoch of triplet-loss training with gradient accumulation.

    The scheduler is stepped right after every optimizer step: once per
    ``accumulation_steps`` batches, plus once for a trailing partial window.
    A per-step scheduler such as OneCycleLR therefore sees
    ``ceil(len(dataloader) / accumulation_steps)`` steps per epoch.

    Args:
        model: embedding network, applied separately to anchor, positive and
            negative tensors.
        dataloader: yields dicts with 'anchor', 'positive' and 'negative'
            tensors. If its dataset has a ``mining_strategy`` attribute, it is
            switched to hard mining and handed the model for mining.
        optimizer: optimizer updated once per accumulation window.
        scheduler: LR scheduler stepped after each optimizer step.
        device: device the batch tensors are moved to.
        epoch: epoch index. Used for the progress-bar label; the dataset's
            embedding cache is cleared when ``epoch % 3 == 0``.
        margin: triplet-loss margin.
        scaler: optional ``torch.amp.GradScaler``. CUDA float16 autocast and
            loss scaling are used only when a scaler is given *and* enabled.
        accumulation_steps: number of batches whose gradients are summed per
            optimizer step (each batch loss is divided by this value).
        use_regularization: add ``embedding_regularization_loss`` to the loss.

    Returns:
        ``(avg_loss, avg_pos_dist, avg_neg_dist)`` averaged over batches that
        completed successfully, or ``0`` for each when none did.
    """
    model.train()
    total_loss = 0
    total_pos_dist = 0
    total_neg_dist = 0
    total_active_ratio = 0
    gradient_errors = 0
    # A GradScaler constructed with enabled=False must not switch on fp16 autocast.
    use_mixed_precision = scaler is not None and scaler.is_enabled()

    mining_strategy = 'hard'

    # Update mining strategy
    if hasattr(dataloader.dataset, 'mining_strategy'):
        dataloader.dataset.mining_strategy = mining_strategy

        if mining_strategy in ['hard', 'semi_hard']:
            dataloader.dataset.set_model_for_mining(model)
            if epoch % 3 == 0:
                dataloader.dataset.clear_embedding_cache()

    num_batches = 0
    batch_idx = -1  # stays -1 if the loader yields no batches
    optimizer.zero_grad()

    pbar = tqdm(enumerate(dataloader),
                total=len(dataloader),
                desc=f'Epoch {epoch:3d} [{mining_strategy:9s}]',
                leave=False,
                ncols=120)

    for batch_idx, batch in pbar:
        try:
            anchor = batch['anchor'].to(device, non_blocking=True)
            positive = batch['positive'].to(device, non_blocking=True)
            negative = batch['negative'].to(device, non_blocking=True)

            if use_mixed_precision:
                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    anchor_emb = model(anchor)
                    positive_emb = model(positive)
                    negative_emb = model(negative)

                    loss, pos_dist, neg_dist, active_ratio = triplet_loss(
                        anchor_emb,
                        positive_emb,
                        negative_emb,
                        margin=margin,
                        distance_metric='euclidean',
                    )

                    if use_regularization:
                        all_embeddings = torch.cat([anchor_emb, positive_emb, negative_emb], dim=0)
                        reg_loss = embedding_regularization_loss(all_embeddings, lambda_reg=0.01)
                        loss = loss + reg_loss

                    # Scale down so gradients summed over a window act like an average.
                    loss = loss / accumulation_steps

                try:
                    scaler.scale(loss).backward()
                except RuntimeError as e:
                    if "does not require grad" not in str(e):
                        raise
                    gradient_errors += 1
                    optimizer.zero_grad()
                    continue

                if (batch_idx + 1) % accumulation_steps == 0:
                    try:
                        scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)

                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()

                        # Scheduler step AFTER optimizer step (for OneCycleLR)
                        scheduler.step()
                    except (RuntimeError, AssertionError) as e:
                        # Recent GradScaler versions report "No inf checks were
                        # recorded" through an assert, so AssertionError is caught too.
                        if "No inf checks were recorded" not in str(e):
                            raise
                        gradient_errors += 1
                        optimizer.zero_grad()
                        # The failed step leaves per-optimizer state in the scaler;
                        # continue this epoch with a fresh one. NOTE: this rebinds
                        # the local name only, not the caller's scaler.
                        scaler = torch.amp.GradScaler(enabled=use_mixed_precision)
                        continue

            else:
                # Full-precision path
                anchor_emb = model(anchor)
                positive_emb = model(positive)
                negative_emb = model(negative)

                loss, pos_dist, neg_dist, active_ratio = triplet_loss(
                    anchor_emb, positive_emb, negative_emb, margin=margin, distance_metric='euclidean',
                )

                if use_regularization:
                    all_embeddings = torch.cat([anchor_emb, positive_emb, negative_emb], dim=0)
                    reg_loss = embedding_regularization_loss(all_embeddings, lambda_reg=0.01)
                    loss = loss + reg_loss

                loss = loss / accumulation_steps

                try:
                    loss.backward()
                except RuntimeError as e:
                    if "does not require grad" not in str(e):
                        raise
                    gradient_errors += 1
                    optimizer.zero_grad()
                    continue

                if (batch_idx + 1) % accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                    optimizer.step()
                    optimizer.zero_grad()

                    # Scheduler step AFTER optimizer step
                    scheduler.step()

            # Accumulate metrics (undo the accumulation scaling for reporting)
            total_loss += loss.item() * accumulation_steps
            total_pos_dist += pos_dist.item()
            total_neg_dist += neg_dist.item()
            total_active_ratio += active_ratio
            num_batches += 1

        except Exception as e:
            # Best effort: log, drop this batch's partial gradients, keep going.
            print(f"\nUnexpected error in batch {batch_idx}: {str(e)}")
            optimizer.zero_grad()
            continue

        # Update progress bar
        if num_batches > 0:
            current_loss = total_loss / num_batches
            current_pos_dist = total_pos_dist / num_batches
            current_neg_dist = total_neg_dist / num_batches
            current_active_ratio = total_active_ratio / num_batches

            current_lr = optimizer.param_groups[0]['lr']  # Show current LR

            postfix = {
                'Loss': f'{current_loss:.3f}',
                'Pos': f'{current_pos_dist:.3f}',
                'Neg': f'{current_neg_dist:.3f}',
                'Active': f'{current_active_ratio:.2%}',
                'LR': f'{current_lr:.2e}'
            }

            if gradient_errors > 0:
                postfix['Errs'] = f'{gradient_errors}'

            pbar.set_postfix(postfix)

    # Flush a trailing partial accumulation window (skipped for an empty loader).
    if batch_idx >= 0 and (batch_idx + 1) % accumulation_steps != 0:
        if use_mixed_precision:
            try:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
            except (RuntimeError, AssertionError):
                # Best effort (e.g. no gradients were accumulated): drop the
                # partial window instead of aborting the epoch, but count it.
                gradient_errors += 1
        else:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.5)
            optimizer.step()
            scheduler.step()
        # Don't leave stale gradients behind after the epoch.
        optimizer.zero_grad()

    pbar.close()

    if gradient_errors > 0:
        print(f"Epoch {epoch}: Encountered {gradient_errors} gradient errors out of {len(dataloader)} batches")

    avg_loss = total_loss / num_batches if num_batches > 0 else 0
    avg_pos_dist = total_pos_dist / num_batches if num_batches > 0 else 0
    avg_neg_dist = total_neg_dist / num_batches if num_batches > 0 else 0

    return avg_loss, avg_pos_dist, avg_neg_dist

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, f1_score, precision_recall_curve, auc

def evaluate_triplet_val(model, val_loader, device, margin=1.0, violation_margin=0.1):
    """Evaluate a triplet-embedding model on a validation loader.

    The reported loss is the batch-mean triplet loss (Euclidean, ``margin``).
    All distance-based metrics use cosine distance ``1 - cos(anchor, x)`` for
    anchor-positive (label 1) and anchor-negative (label 0) pairs, with
    ``1 - distance`` (the cosine similarity) as the classification score.

    Args:
        model: embedding network; switched to eval mode by this function.
        val_loader: yields dicts with 'anchor', 'positive', 'negative' tensors.
        device: device the batch tensors are moved to.
        margin: triplet-loss margin (Euclidean space).
        violation_margin: cosine-distance margin for 'margin_violations', the
            fraction of triplets with ``neg_dist < pos_dist + violation_margin``.

    Returns:
        dict of scalar metrics (loss, AUC-ROC/PR, best F1 and its threshold,
        distance statistics, overlap and separability).

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0
    num_batches = 0

    # Per-batch NumPy chunks; concatenated once after the loop.
    pos_chunks = []
    neg_chunks = []
    emb_chunks = []

    with torch.no_grad():
        for batch in val_loader:
            anchor = batch['anchor'].to(device, non_blocking=True)
            positive = batch['positive'].to(device, non_blocking=True)
            negative = batch['negative'].to(device, non_blocking=True)

            anchor_emb = model(anchor)
            positive_emb = model(positive)
            negative_emb = model(negative)

            # Only the loss is used; distances are recomputed below with cosine.
            loss, _, _, _ = triplet_loss(
                anchor_emb, positive_emb, negative_emb, margin=margin, distance_metric='euclidean'
            )

            total_loss += loss.item()
            num_batches += 1

            pos_chunks.append((1 - F.cosine_similarity(anchor_emb, positive_emb)).cpu().numpy())
            neg_chunks.append((1 - F.cosine_similarity(anchor_emb, negative_emb)).cpu().numpy())

            # Anchor, positive and negative rows, in that order, per batch.
            emb_chunks.extend(t.cpu().numpy() for t in (anchor_emb, positive_emb, negative_emb))

    if num_batches == 0:
        raise ValueError("val_loader yielded no batches; cannot compute validation metrics")

    avg_loss = total_loss / num_batches

    pos_distances = np.concatenate(pos_chunks)
    neg_distances = np.concatenate(neg_chunks)
    embeddings = np.concatenate(emb_chunks, axis=0)

    # Positive pairs should have lower distances (label 1), negatives higher (label 0).
    labels = np.concatenate([np.ones(len(pos_distances)), np.zeros(len(neg_distances))])
    distances = np.concatenate([pos_distances, neg_distances])

    # Lower distance = higher similarity score.
    similarities = 1 - distances

    auc_roc = roc_auc_score(labels, similarities)

    precision, recall, _ = precision_recall_curve(labels, similarities)
    auc_pr = auc(recall, precision)

    # Best F1 over a 100-point threshold grid spanning the observed scores.
    thresholds = np.linspace(similarities.min(), similarities.max(), 100)
    f1_scores = []

    for threshold in thresholds:
        y_pred = (similarities >= threshold).astype(int)
        if len(np.unique(y_pred)) > 1:  # single-class predictions score 0
            f1_scores.append(f1_score(labels, y_pred))
        else:
            f1_scores.append(0)

    best_f1 = max(f1_scores)
    best_threshold = thresholds[np.argmax(f1_scores)]

    metrics = {
        'avg_loss': avg_loss,

        # Classification metrics
        'auc_roc': auc_roc,
        'auc_pr': auc_pr,
        'best_f1': best_f1,
        'best_threshold': best_threshold,

        # Distance statistics
        'pos_dist_mean': np.mean(pos_distances),
        'pos_dist_std': np.std(pos_distances),
        'neg_dist_mean': np.mean(neg_distances),
        'neg_dist_std': np.std(neg_distances),
        'dist_gap': np.mean(neg_distances) - np.mean(pos_distances),

        # Margin violations (cosine-distance margin)
        'margin_violations': np.mean(neg_distances < pos_distances + violation_margin),

        # Distribution overlap
        'pos_95_percentile': np.percentile(pos_distances, 95),
        'neg_5_percentile': np.percentile(neg_distances, 5),
        'distribution_overlap': max(0, np.percentile(pos_distances, 95) - np.percentile(neg_distances, 5)),

        # Embedding space quality
        'embedding_norm_std': np.std(np.linalg.norm(embeddings, axis=1)),

        # Relative improvement
        'relative_gap': (np.mean(neg_distances) - np.mean(pos_distances)) / np.mean(pos_distances),

        # Separability score (higher is better)
        'separability': np.mean(neg_distances > pos_distances)
    }

    return metrics

In [None]:
# Notebook-level per-epoch metric log; train_triplet_model appends to these lists.
history = {
    key: []
    for key in (
        'train_loss', 'val_loss',
        'train_pos_distances', 'train_neg_distances',
        'val_pos_distances', 'val_neg_distances',
        'val_pos_neg_diffs',
        'learning_rates', 'separability',
        'auc_roc', 'best_f1',
    )
}

# Column layout shared by the header row and the per-epoch log rows.
_EPOCH_LOG_FORMAT = "{:<12} {:<8} {:<8} {:<8} {:<10} {:<10} {:<20} {:<20} {:<10} {:<12} {:<8} {:<8}"


def _save_training_checkpoint(path, epoch, model, optimizer, scheduler, val_metrics, history):
    """Save model/optimizer/scheduler state dicts, validation metrics and history to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_metrics': val_metrics,
        'history': history,
    }, path)


def train_triplet_model(model, train_loader, val_loader=None, num_epochs=50,
                        device='cuda', lr=1e-3, margin=1.0,
                        use_mixed_precision=True, accumulation_steps=1):
    """Train ``model`` with triplet loss, AdamW and a OneCycleLR schedule.

    Each epoch runs ``train_triplet`` and, if ``val_loader`` is given,
    ``evaluate_triplet_val``. Metrics are appended to the notebook-level
    ``history`` dict. Whenever a tracked validation metric improves, a
    checkpoint is written and the epoch row is printed with a marker:

        *  dist_gap (higher is better)      -> best_triplet_model_dist_gap.pth
        +  separability (higher is better)  -> best_triplet_model_separability.pth
        ^  avg_loss (lower is better)       -> best_triplet_model_val_loss.pth
        ◆  auc_roc (higher is better)       -> best_triplet_model_auc.pth
        ◇  best_f1 (higher is better)       -> best_triplet_model_f1.pth

    The final state is always saved to ``best_triplet_model_last_epoch.pth``.

    Args:
        model: embedding network, trained in place.
        train_loader: training triplet loader.
        val_loader: optional validation loader. Without it only the
            last-epoch checkpoint is written.
        num_epochs: number of training epochs.
        device: device for the batch tensors.
        lr: peak (max) learning rate of the OneCycle schedule.
        margin: triplet-loss margin for both training and validation loss.
        use_mixed_precision: passed as ``enabled`` to the GradScaler handed to
            ``train_triplet``.
        accumulation_steps: batches per optimizer step (must be >= 1).

    Raises:
        ValueError: if ``accumulation_steps < 1``.
    """
    global history  # per-epoch metrics are appended to the notebook-level dict

    if accumulation_steps < 1:
        raise ValueError(f"accumulation_steps must be >= 1, got {accumulation_steps}")

    optimizer = optim.AdamW(
        model.parameters(),
        lr=lr,  # OneCycleLR replaces this with max_lr / div_factor at construction
        weight_decay=5e-4,
        betas=(0.9, 0.999),
        eps=1e-8
    )

    # train_triplet steps the scheduler once per optimizer step, i.e.
    # ceil(len(train_loader) / accumulation_steps) times per epoch. Sizing the
    # schedule by len(train_loader) would leave it only partly completed.
    steps_per_epoch = (len(train_loader) + accumulation_steps - 1) // accumulation_steps

    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=lr,
        epochs=num_epochs,
        steps_per_epoch=steps_per_epoch,
        pct_start=0.3,          # first 30% of the steps ramp the LR up
        div_factor=50,          # starting LR = max_lr / 50
        final_div_factor=1000   # final LR = starting LR / 1000
    )

    scaler = torch.amp.GradScaler(enabled=use_mixed_precision)

    # (val_metrics key, higher is better?, checkpoint file, log marker),
    # checked and saved in this order every epoch.
    tracked_metrics = [
        ('dist_gap', True, "best_triplet_model_dist_gap.pth", " *"),
        ('separability', True, "best_triplet_model_separability.pth", " +"),
        ('avg_loss', False, "best_triplet_model_val_loss.pth", " ^"),
        ('auc_roc', True, "best_triplet_model_auc.pth", " ◆"),
        ('best_f1', True, "best_triplet_model_f1.pth", " ◇"),
    ]
    best_values = {
        key: float('-inf') if higher_is_better else float('inf')
        for key, higher_is_better, _, _ in tracked_metrics
    }

    # Print the header once before the training loop.
    header_str = _EPOCH_LOG_FORMAT.format(
        "Epoch", "Loss", "Pos", "Neg", "LR",
        "Val Loss", "Val Pos (Mean/Std)", "Val Neg (Mean/Std)", "Dist Gap", "Separability", "AUC", "F1"
    )
    print("\n" + header_str)
    print("-" * len(header_str))

    val_metrics = None

    for epoch in range(num_epochs):
        # Training phase
        train_loss, train_pos_dist, train_neg_dist = train_triplet(
            model, train_loader, optimizer, scheduler, device, epoch,
            margin=margin, scaler=scaler, accumulation_steps=accumulation_steps,
            use_regularization=True
        )

        current_lr = optimizer.param_groups[0]['lr']

        history['train_loss'].append(train_loss)
        history['train_pos_distances'].append(train_pos_dist)
        history['train_neg_distances'].append(train_neg_dist)
        history['learning_rates'].append(current_lr)

        if val_loader is None:
            continue

        # Validation phase
        val_metrics = evaluate_triplet_val(model, val_loader, device, margin=margin)

        history['val_loss'].append(val_metrics['avg_loss'])
        history['val_pos_distances'].append(val_metrics['pos_dist_mean'])
        history['val_neg_distances'].append(val_metrics['neg_dist_mean'])
        history['val_pos_neg_diffs'].append(val_metrics['dist_gap'])
        history['separability'].append(val_metrics['separability'])
        history['auc_roc'].append(val_metrics['auc_roc'])
        history['best_f1'].append(val_metrics['best_f1'])

        # Save a checkpoint for every metric that strictly improved.
        star_indicator = ""
        for key, higher_is_better, path, marker in tracked_metrics:
            value = val_metrics[key]
            improved = value > best_values[key] if higher_is_better else value < best_values[key]
            if improved:
                best_values[key] = value
                _save_training_checkpoint(path, epoch, model, optimizer, scheduler, val_metrics, history)
                star_indicator += marker

        # Only epochs that improved at least one tracked metric are printed.
        if star_indicator:
            print(
                _EPOCH_LOG_FORMAT.format(
                    f"{epoch+1}/{num_epochs}",
                    f"{train_loss:.3f}",
                    f"{train_pos_dist:.3f}",
                    f"{train_neg_dist:.3f}",
                    f"{current_lr:.2e}",
                    f"{val_metrics['avg_loss']:.3f}",
                    f"{val_metrics['pos_dist_mean']:.3f}/{val_metrics['pos_dist_std']:.3f}",
                    f"{val_metrics['neg_dist_mean']:.3f}/{val_metrics['neg_dist_std']:.3f}",
                    f"{val_metrics['dist_gap']:.3f}",
                    f"{val_metrics['separability']:.3f}",
                    f"{val_metrics['auc_roc']:.3f}",
                    f"{val_metrics['best_f1']:.3f}",
                ) + star_indicator
            )

    _save_training_checkpoint(
        "best_triplet_model_last_epoch.pth", num_epochs,
        model, optimizer, scheduler, val_metrics, history
    )

In [None]:
# Train with the notebook configuration; lr is the OneCycle schedule's peak LR.
train_triplet_model(
    model,
    train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCHS,
    device=device,
    margin=MARGIN,
    accumulation_steps=ACCUMULATION_STEPS,
    lr=1e-4,
)

In [None]:
# Load the final-epoch weights (written after the last epoch, not a "best" checkpoint).
model = generate_model(model_depth=10, embedding_dim=EMBEDDING_DIM, use_attention=True)
# map_location: load GPU-saved tensors onto whatever `device` is available.
# weights_only=False: the checkpoint also stores metrics/history (incl. NumPy
# scalars), which the weights_only=True default of PyTorch >= 2.6 rejects.
# Only load checkpoints you produced yourself.
checkpoint = torch.load("best_triplet_model_last_epoch.pth", map_location=device, weights_only=False)
model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

In [None]:
# Load the checkpoint with the best validation distance gap (dist_gap).
model_dist_gap = generate_model(model_depth=10, embedding_dim=EMBEDDING_DIM, use_attention=True)
# map_location: load GPU-saved tensors onto whatever `device` is available.
# weights_only=False: the checkpoint also stores metrics/history (incl. NumPy
# scalars), which the weights_only=True default of PyTorch >= 2.6 rejects.
# Only load checkpoints you produced yourself.
checkpoint = torch.load("best_triplet_model_dist_gap.pth", map_location=device, weights_only=False)
model_dist_gap.load_state_dict(checkpoint['model_state_dict'])
model_dist_gap = model_dist_gap.to(device)

In [None]:
# Load the checkpoint with the best validation separability.
model_separability = generate_model(model_depth=10, embedding_dim=EMBEDDING_DIM, use_attention=True)
# map_location: load GPU-saved tensors onto whatever `device` is available.
# weights_only=False: the checkpoint also stores metrics/history (incl. NumPy
# scalars), which the weights_only=True default of PyTorch >= 2.6 rejects.
# Only load checkpoints you produced yourself.
checkpoint = torch.load("best_triplet_model_separability.pth", map_location=device, weights_only=False)
model_separability.load_state_dict(checkpoint['model_state_dict'])
model_separability = model_separability.to(device)

In [None]:
# Load the checkpoint with the lowest validation triplet loss.
model_val_loss = generate_model(model_depth=10, embedding_dim=EMBEDDING_DIM, use_attention=True)
# map_location: load GPU-saved tensors onto whatever `device` is available.
# weights_only=False: the checkpoint also stores metrics/history (incl. NumPy
# scalars), which the weights_only=True default of PyTorch >= 2.6 rejects.
# Only load checkpoints you produced yourself.
checkpoint = torch.load("best_triplet_model_val_loss.pth", map_location=device, weights_only=False)
model_val_loss.load_state_dict(checkpoint['model_state_dict'])
model_val_loss = model_val_loss.to(device)

In [None]:
# Load the checkpoint with the best validation AUC-ROC.
model_auc = generate_model(model_depth=10, embedding_dim=EMBEDDING_DIM, use_attention=True)
# map_location: load GPU-saved tensors onto whatever `device` is available.
# weights_only=False: the checkpoint also stores metrics/history (incl. NumPy
# scalars), which the weights_only=True default of PyTorch >= 2.6 rejects.
# Only load checkpoints you produced yourself.
checkpoint = torch.load("best_triplet_model_auc.pth", map_location=device, weights_only=False)
model_auc.load_state_dict(checkpoint['model_state_dict'])
model_auc = model_auc.to(device)

In [None]:
# Load the checkpoint with the best validation F1 score.
model_f1 = generate_model(model_depth=10, embedding_dim=EMBEDDING_DIM, use_attention=True)
# map_location: load GPU-saved tensors onto whatever `device` is available.
# weights_only=False: the checkpoint also stores metrics/history (incl. NumPy
# scalars), which the weights_only=True default of PyTorch >= 2.6 rejects.
# Only load checkpoints you produced yourself.
checkpoint = torch.load("best_triplet_model_f1.pth", map_location=device, weights_only=False)
model_f1.load_state_dict(checkpoint['model_state_dict'])
model_f1 = model_f1.to(device)

In [None]:
# Persist the training-history dict to history.pkl (reloaded by the next cell).
import pickle

with open('history.pkl', mode='wb') as history_file:
    pickle.dump(history, history_file)

In [None]:
# Reload the training history from history.pkl. pickle is imported here so
# this cell also works in a fresh kernel without re-running the save cell.
# NOTE: pickle.load can execute arbitrary code; only load files you created.
import pickle

with open('history.pkl', 'rb') as f:
    history = pickle.load(f)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_history_interpolated(history, n_epochs=1):
    """
    Plot training history with interpolation for clarity.
    Marks 5 best models based on different metrics including AUC and F1.
    
    Args:
        history: Dictionary containing training history with keys:
                 'train_loss', 'val_loss', 'pos_distances', 'neg_distances',
                 'val_pos_neg_diffs', 'separability', 'auc_roc', 'best_f1'
        n_epochs: Number of epochs to average over for interpolation (default=1, no interpolation)
    """
    
    def interpolate_data(data, n):
        """Take the mean at every n epochs."""
        if n <= 1:
            return data, range(1, len(data) + 1)
        
        interpolated = []
        new_x = []
        
        for i in range(0, len(data) + 1, n):
            end_idx = min(i + n, len(data))
            chunk = data[i:end_idx]
            interpolated.append(np.mean(chunk))
            new_x.append(i + len(chunk) // 2 + 1)  # Center of the chunk
        
        return interpolated, new_x
    
    # Interpolate all data
    train_loss_interp, x_train = interpolate_data(history['train_loss'], n_epochs)
    val_loss_interp, x_val = interpolate_data(history['val_loss'], n_epochs)
    pos_distances_interp, x_pos = interpolate_data(history['val_pos_distances'], n_epochs)
    neg_distances_interp, x_neg = interpolate_data(history['val_neg_distances'], n_epochs)
    pos_neg_diffs_interp, x_diff = interpolate_data(history['val_pos_neg_diffs'], n_epochs)
    separability_interp, x_sep = interpolate_data(history['separability'], n_epochs)
    auc_roc_interp, x_auc = interpolate_data(history['auc_roc'], n_epochs)
    best_f1_interp, x_f1 = interpolate_data(history['best_f1'], n_epochs)
    
    # Use the original x for consistent plotting
    x = x_train
    
    plt.figure(figsize=(12, 5))
    marker_size = 3
    
    plt.plot(x, train_loss_interp, label='Training Loss', markersize=marker_size, zorder=0)
    plt.plot(x, val_loss_interp, label='Validation Loss', markersize=marker_size, zorder=0)
    plt.plot(x, pos_distances_interp, label='Twin Distances', marker='o', markersize=marker_size, linestyle='')
    plt.plot(x, neg_distances_interp, label='Non-Twin Distances', marker='o', markersize=marker_size, linestyle='')
    plt.plot(x, pos_neg_diffs_interp, label='Pos-Neg Diff', marker='s', markersize=marker_size, linestyle='')
    plt.plot(x, separability_interp, label='Separability', marker='^', markersize=marker_size, linestyle='')
    plt.plot(x, auc_roc_interp, label='AUC-ROC', marker='d', markersize=marker_size, linestyle='')
    plt.plot(x, best_f1_interp, label='F1 Score', marker='*', markersize=marker_size, linestyle='')
    
    # Find best epochs for each metric
    best_val_loss_original = np.argmin(history['val_loss']) + 1
    best_pos_neg_diff_original = np.argmax(history['val_pos_neg_diffs']) + 1  # Higher is better
    best_separability_original = np.argmax(history['separability']) + 1  # Higher is better
    best_auc_original = np.argmax(history['auc_roc']) + 1  # Higher is better
    best_f1_original = np.argmax(history['best_f1']) + 1  # Higher is better
    
    # Get the actual metric values at best epochs
    best_val_loss_value = history['val_loss'][best_val_loss_original - 1]
    best_pos_neg_diff_value = history['val_pos_neg_diffs'][best_pos_neg_diff_original - 1]
    best_separability_value = history['separability'][best_separability_original - 1]
    best_auc_value = history['auc_roc'][best_auc_original - 1]
    best_f1_value = history['best_f1'][best_f1_original - 1]
    
    # Find closest interpolated points
    best_val_loss_idx = np.argmin(np.abs(np.array(x) - best_val_loss_original))
    best_pos_neg_diff_idx = np.argmin(np.abs(np.array(x) - best_pos_neg_diff_original))
    best_separability_idx = np.argmin(np.abs(np.array(x) - best_separability_original))
    best_auc_idx = np.argmin(np.abs(np.array(x) - best_auc_original))
    best_f1_idx = np.argmin(np.abs(np.array(x) - best_f1_original))
    
    best_epochs = [
        (x[best_val_loss_idx], best_val_loss_idx, 'Val Loss', 'red', best_val_loss_value),
        (x[best_pos_neg_diff_idx], best_pos_neg_diff_idx, 'Pos-Neg Diff', 'orange', best_pos_neg_diff_value),
        (x[best_separability_idx], best_separability_idx, 'Separability', 'purple', best_separability_value),
        (x[best_auc_idx], best_auc_idx, 'AUC-ROC', 'blue', best_auc_value),
        (x[best_f1_idx], best_f1_idx, 'F1 Score', 'green', best_f1_value)
    ]
    
    # Plot best model lines and labels
    for epoch, idx, metric_name, color, metric_value in best_epochs:
        plt.plot([epoch, epoch], 
                 [pos_distances_interp[idx], neg_distances_interp[idx]], 
                 color=color, linestyle='-', lw=1.5, zorder=2, 
                 label=f'Best {metric_name} [{metric_value:.4f}] (Epoch {epoch})')
        
        # Add text annotations
        text_offset = 0.02
        plt.text(epoch, pos_distances_interp[idx] - text_offset, 
                 f'{pos_distances_interp[idx]:.4f}', 
                 fontsize=8, ha='center', va='top', zorder=99, 
                 bbox=dict(boxstyle="round,pad=0.1", facecolor=color, alpha=0.3))
        plt.text(epoch, neg_distances_interp[idx] + text_offset, 
                 f'{neg_distances_interp[idx]:.4f}', 
                 fontsize=8, ha='center', va='bottom', zorder=99,
                 bbox=dict(boxstyle="round,pad=0.1", facecolor=color, alpha=0.3))
    
    # Draw gray vertical lines for other epochs
    best_epoch_set = {epoch for epoch, _, _, _, _ in best_epochs}
    for i, epoch in enumerate(x):
        if epoch not in best_epoch_set:
            plt.plot([epoch, epoch], 
                     [pos_distances_interp[i], neg_distances_interp[i]], 
                     color='gray', linestyle='--', lw=0.5, zorder=1)
    
    plt.grid(alpha=0.2)
    plt.xlabel('Epoch', fontsize=10, weight='bold')
    plt.ylabel('Loss', fontsize=10, weight='bold')
    
    # Set x-ticks starting from 1, then every 100 epochs
    max_epoch = max(x) if x else 100
    xtick_positions = [1] + list(range(100, max_epoch + 1, 100))
    plt.xticks(xtick_positions)
    
    plt.legend()
    plt.show()

# Plot the raw per-epoch history (n_epochs=1 means no window averaging).
plot_history_interpolated(history=history, n_epochs=1)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from sklearn.metrics import precision_score, recall_score, f1_score, roc_curve, auc, confusion_matrix, precision_recall_curve, fbeta_score
from tqdm import tqdm
import torch.nn.functional as F

def create_pairs_from_triplet_dataset(triplet_dataset, negatives_per_positive=1, seed=42):
    """
    Create a traditional pairs dataset from a triplet dataset for evaluation.

    Positive pairs (label 1) are taken verbatim from ``triplet_dataset.twin_pairs``.
    Negative pairs (label 0) are sampled uniformly at random from
    ``triplet_dataset.all_subjects``, skipping self-pairs, twin pairs and pairs
    that were already sampled.

    Args:
        triplet_dataset: Object exposing ``twin_pairs`` (sequence of 2-tuples of
            subject ids) and ``all_subjects`` (sequence of subject ids).
        negatives_per_positive: Negative pairs to sample per positive pair
            (default 1, i.e. a balanced set).
        seed: Seed of the private RNG used for sampling, so the same dataset
            always yields the same evaluation pairs.

    Returns:
        (pairs, labels): ``pairs`` is a list of 2-tuples. Positives keep the
        order given in ``twin_pairs``; negatives are sorted tuples. ``labels``
        is the parallel list of 1 (twins) / 0 (non-twins). If fewer distinct
        non-twin pairs exist than requested, all of them are returned and a
        warning is printed.
    """
    import random

    # Private generator: reproducible without reseeding the global `random`
    # module state that the rest of the notebook (e.g. set_seed) relies on.
    # It produces the same draws as `random.seed(seed)` + `random.choice`.
    rng = random.Random(seed)

    subjects = list(triplet_dataset.all_subjects)
    twin_pair_keys = {frozenset(pair) for pair in triplet_dataset.twin_pairs}

    pairs = []
    labels = []

    # Positive pairs: every twin pair, as given.
    for subject1_id, subject2_id in triplet_dataset.twin_pairs:
        pairs.append((subject1_id, subject2_id))
        labels.append(1)

    # Cap the request at the number of distinct non-twin pairs that actually
    # exist; otherwise the rejection-sampling loop below could never finish.
    unique_subjects = set(subjects)
    n_unique = len(unique_subjects)
    twins_among_subjects = sum(
        1 for key in twin_pair_keys if len(key) == 2 and key <= unique_subjects
    )
    available_negatives = n_unique * (n_unique - 1) // 2 - twins_among_subjects
    num_negative_pairs = len(triplet_dataset.twin_pairs) * negatives_per_positive
    if num_negative_pairs > available_negatives:
        print(f"Warning: requested {num_negative_pairs} negative pairs but only "
              f"{available_negatives} distinct non-twin pairs exist; using all of them.")
        num_negative_pairs = available_negatives

    seen_negatives = set()
    while len(seen_negatives) < num_negative_pairs:
        subject1 = rng.choice(subjects)
        subject2 = rng.choice(subjects)

        if subject1 == subject2:
            continue
        # Skip twins (in either order).
        if frozenset((subject1, subject2)) in twin_pair_keys:
            continue

        # Canonical (sorted) order so (a, b) and (b, a) count as the same pair.
        pair = tuple(sorted([subject1, subject2]))
        if pair in seen_negatives:
            continue
        seen_negatives.add(pair)
        pairs.append(pair)
        labels.append(0)  # Non-twin pair

    return pairs, labels


def _compute_pair_distances(model, dataset, pairs, labels, device, batch_size):
    """Embed both members of every pair and return their cosine distances.

    Pairs whose images fail to load are skipped (with a printed message) and
    their labels dropped, so a loading error never injects a fabricated
    distance into the metrics.

    Returns:
        (distances, kept_labels, num_skipped): two aligned numpy arrays and the
        number of pairs that were skipped.
    """
    all_distances = []
    all_labels = []
    num_skipped = 0

    with torch.no_grad():
        for i in tqdm(range(0, len(pairs), batch_size), desc="Evaluating"):
            batch_pairs = pairs[i:i+batch_size]
            batch_labels = labels[i:i+batch_size]

            batch_images1 = []
            batch_images2 = []
            kept_labels = []

            # Load images for this batch
            for (subject1_id, subject2_id), label in zip(batch_pairs, batch_labels):
                try:
                    img1 = dataset._load_image(subject1_id)
                    img2 = dataset._load_image(subject2_id)

                    # Convert to tensors and add channel dimension
                    # (assumes _load_image returns a channel-less volume -> [1, D, H, W])
                    img1 = torch.tensor(img1, dtype=torch.float32).unsqueeze(0)
                    img2 = torch.tensor(img2, dtype=torch.float32).unsqueeze(0)

                    # Apply transforms if available. NOTE: if the dataset carries
                    # random augmentation, evaluation becomes stochastic.
                    if dataset.transform:
                        img1 = dataset.transform(img1)
                        img2 = dataset.transform(img2)
                except Exception as e:
                    print(f"Error loading images for pair {subject1_id}, {subject2_id}: {e}")
                    print("  Skipping this pair.")
                    num_skipped += 1
                    continue

                batch_images1.append(img1)
                batch_images2.append(img2)
                kept_labels.append(label)

            if not batch_images1:
                continue

            # Stack into batch tensors; this is where data moves to `device`.
            batch_img1 = torch.stack(batch_images1).to(device).float()
            batch_img2 = torch.stack(batch_images2).to(device).float()

            embeddings1 = model(batch_img1)
            embeddings2 = model(batch_img2)

            # Cosine distance in [0, 2]; smaller means more similar.
            distances = 1 - F.cosine_similarity(embeddings1, embeddings2, dim=1)

            all_distances.extend(distances.cpu().numpy())
            all_labels.extend(kept_labels)

    return np.array(all_distances), np.array(all_labels), num_skipped


def evaluate(model, test_dataset, device, threshold=None, min_recall=None, 
             min_precision=None, batch_size=32):
    """
    Evaluate a triplet-trained model using pairwise comparisons with comprehensive threshold analysis.

    Pairs come from ``create_pairs_from_triplet_dataset``: all twin pairs plus
    sampled non-twin pairs. Each pair is scored with the cosine distance
    ``1 - cos(model(img1), model(img2))`` and predicted "twin" when
    ``distance < threshold``.

    Args:
        model: The trained model; called on batches of shape [B, 1, *volume_shape].
        test_dataset: TripletBrainDataset (or a compatible object exposing
            ``twin_pairs``, ``all_subjects``, ``_load_image`` and ``transform``).
        device: Device to run evaluation on.
        threshold: Fixed threshold to use. If None, one is selected from the
            evaluated distances. NOTE: selection uses the same pairs that are
            then scored, so the reported metrics are optimistic.
        min_recall: Minimum required recall (0-1); takes priority over min_precision.
        min_precision: Minimum required precision (0-1).
        batch_size: Number of pairs embedded per forward pass.

    Returns:
        dict with 'accuracy', 'precision', 'recall', 'f1', 'auc', 'threshold',
        'all_distances', 'all_labels', 'all_predictions', 'num_pairs' (pairs
        created) and 'num_skipped_pairs' (pairs dropped because an image failed
        to load).

    Raises:
        ValueError: if the dataset yields no evaluation pairs.
        RuntimeError: if no pair could be loaded at all.
    """
    model.eval()
    
    # Create pairs from triplet dataset
    print("Creating evaluation pairs from triplet dataset...")
    pairs, labels = create_pairs_from_triplet_dataset(test_dataset)
    if not pairs:
        raise ValueError("test_dataset produced no evaluation pairs (no twin pairs?)")
    
    print(f"Created {len(pairs)} pairs for evaluation:")
    print(f"  Positive pairs (twins): {sum(labels)} ({sum(labels)/len(labels)*100:.1f}%)")
    print(f"  Negative pairs (non-twins): {len(labels) - sum(labels)} ({(len(labels) - sum(labels))/len(labels)*100:.1f}%)")
    
    # Process pairs in batches
    print("Computing embeddings and distances...")
    all_distances, all_labels, num_skipped = _compute_pair_distances(
        model, test_dataset, pairs, labels, device, batch_size)
    if num_skipped:
        print(f"Skipped {num_skipped} pair(s) whose images could not be loaded")
    if all_distances.size == 0:
        raise RuntimeError("No pairs could be evaluated: every image load failed")
    
    print(f"Distance statistics:")
    print(f"  Min: {all_distances.min():.4f}")
    print(f"  Max: {all_distances.max():.4f}")
    print(f"  Mean: {all_distances.mean():.4f}")
    print(f"  Std: {all_distances.std():.4f}")
    
    print(f"Label distribution:")
    print(f"  Twins (1): {np.sum(all_labels == 1)} ({np.mean(all_labels == 1)*100:.1f}%)")
    print(f"  Non-twins (0): {np.sum(all_labels == 0)} ({np.mean(all_labels == 0)*100:.1f}%)")

    # Find optimal threshold if not provided
    if threshold is None:
        print("Finding optimal threshold...")
        
        # Method 1: Use reasonable distance range instead of ROC thresholds
        min_dist, max_dist = all_distances.min(), all_distances.max()
        distance_range = max_dist - min_dist
        
        # Create candidate thresholds within the actual distance range
        # (1% margin trimmed from each end to avoid all-one-class predictions)
        n_thresholds = 1000
        candidate_thresholds = np.linspace(min_dist + 0.01 * distance_range, 
                                         max_dist - 0.01 * distance_range, 
                                         n_thresholds)
        
        print(f"Searching thresholds in range [{candidate_thresholds[0]:.4f}, {candidate_thresholds[-1]:.4f}]")
        
        # Check for recall or precision constraints
        if min_recall is not None:
            print(f"Applying minimum recall constraint: {min_recall:.3f}")
            threshold = find_threshold_for_min_recall(all_distances, all_labels, candidate_thresholds, min_recall)
            final_method = f'min_recall_{min_recall:.3f}'
            
        elif min_precision is not None:
            print(f"Applying minimum precision constraint: {min_precision:.3f}")
            threshold = find_threshold_for_min_precision(all_distances, all_labels, candidate_thresholds, min_precision)
            final_method = f'min_precision_{min_precision:.3f}'
            
        else:
            # Original optimization methods when no constraints are specified
            methods = {
                'f1': lambda y_true, y_pred: f1_score(y_true, y_pred, zero_division=0),
                'f0.5': lambda y_true, y_pred: fbeta_score(y_true, y_pred, beta=0.5, zero_division=0),
                'f2': lambda y_true, y_pred: fbeta_score(y_true, y_pred, beta=2.0, zero_division=0),
                'balanced_accuracy': lambda y_true, y_pred: 0.5 * (
                    recall_score(y_true, y_pred, pos_label=1, zero_division=0) + 
                    recall_score(y_true, y_pred, pos_label=0, zero_division=0)
                ),
                'youden_j': lambda y_true, y_pred: (
                    recall_score(y_true, y_pred, pos_label=1, zero_division=0) + 
                    recall_score(y_true, y_pred, pos_label=0, zero_division=0) - 1
                )
            }
            
            best_thresholds = {}
            best_scores = {}
            
            for method_name, score_func in methods.items():
                scores = []
                for thresh in candidate_thresholds:
                    predictions = (all_distances < thresh).astype(int)
                    # Skip if all predictions are the same class
                    if len(np.unique(predictions)) <= 1:
                        scores.append(0)
                    else:
                        score = score_func(all_labels, predictions)
                        scores.append(score)
                
                best_idx = np.argmax(scores)
                best_thresholds[method_name] = candidate_thresholds[best_idx]
                best_scores[method_name] = scores[best_idx]
                
                print(f"{method_name}: threshold={best_thresholds[method_name]:.4f}, score={best_scores[method_name]:.4f}")
            
            # Method 2: Statistical approach - find threshold that best separates the classes
            if len(np.unique(all_labels)) == 2:
                twin_distances = all_distances[all_labels == 1]
                non_twin_distances = all_distances[all_labels == 0]
                
                # Option A: Midpoint between means
                midpoint_threshold = (twin_distances.mean() + non_twin_distances.mean()) / 2
                
                # Option B: Intersection of gaussian fits (if distributions overlap)
                try:
                    from scipy import stats
                    # Fit normal distributions
                    twin_params = stats.norm.fit(twin_distances)
                    non_twin_params = stats.norm.fit(non_twin_distances)
                    
                    # Search only between the two fitted means: in the far tails
                    # both PDFs approach 0, so an unrestricted argmin of |diff|
                    # tends to land there instead of on the actual crossing.
                    low_mean, high_mean = sorted((twin_params[0], non_twin_params[0]))
                    search_range = np.linspace(low_mean, high_mean, 1000)
                    twin_pdf = stats.norm.pdf(search_range, *twin_params)
                    non_twin_pdf = stats.norm.pdf(search_range, *non_twin_params)
                    
                    # Find where PDFs are closest
                    diff = np.abs(twin_pdf - non_twin_pdf)
                    intersection_threshold = search_range[np.argmin(diff)]
                    
                    print(f"Statistical thresholds:")
                    print(f"  Midpoint: {midpoint_threshold:.4f}")
                    print(f"  Intersection: {intersection_threshold:.4f}")
                    
                    # Add to candidates
                    best_thresholds['midpoint'] = midpoint_threshold
                    best_thresholds['intersection'] = intersection_threshold
                    
                except ImportError:
                    print(f"Statistical threshold (midpoint): {midpoint_threshold:.4f}")
                    best_thresholds['midpoint'] = midpoint_threshold
            
            # Select the best method based on F1 score (or choose your preferred metric)
            preferred_methods = ['f1', 'f0.5', 'youden_j', 'midpoint', 'balanced_accuracy']
            
            final_threshold = None
            final_method = None
            final_score = -1
            
            for method in preferred_methods:
                if method in best_thresholds:
                    thresh = best_thresholds[method]
                    predictions = (all_distances < thresh).astype(int)
                    
                    if len(np.unique(predictions)) > 1:
                        f1_test = f1_score(all_labels, predictions, zero_division=0)
                        if f1_test > final_score:
                            final_score = f1_test
                            final_threshold = thresh
                            final_method = method
            
            if final_threshold is None:
                # Fallback: use midpoint
                final_threshold = (all_distances.min() + all_distances.max()) / 2
                final_method = 'fallback_midpoint'
            
            threshold = final_threshold
        
        print(f"\nSelected method: {final_method}")
        print(f"Final threshold: {threshold:.4f}")
        
        # Plot comprehensive threshold analysis
        plot_threshold_analysis(all_distances, all_labels, threshold, candidate_thresholds, 
                              final_method, min_recall, min_precision)

    # Make final predictions
    all_predictions = (all_distances < threshold).astype(int)
    
    # Calculate metrics
    accuracy = np.mean(all_predictions == all_labels)
    precision = precision_score(all_labels, all_predictions, zero_division=0)
    recall = recall_score(all_labels, all_predictions, zero_division=0)
    f1 = f1_score(all_labels, all_predictions, zero_division=0)
    
    # ROC curve (negated distance: higher score = more twin-like)
    fpr, tpr, _ = roc_curve(all_labels, -all_distances)
    auc_score = auc(fpr, tpr)

    print(f"\nTest Results:")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")
    print(f"AUC Score: {auc_score:.4f}")
    print(f"Threshold: {threshold:.4f}")
    
    # Check if constraints were met
    if min_recall is not None:
        constraint_met = recall >= min_recall
        print(f"Minimum recall constraint ({min_recall:.3f}): {'✓ MET' if constraint_met else '✗ NOT MET'}")
    
    if min_precision is not None:
        constraint_met = precision >= min_precision
        print(f"Minimum precision constraint ({min_precision:.3f}): {'✓ MET' if constraint_met else '✗ NOT MET'}")
    
    print(f"\nPrediction breakdown:")
    print(f"Predicted twins: {np.sum(all_predictions == 1)} ({np.mean(all_predictions == 1)*100:.1f}%)")
    print(f"Predicted non-twins: {np.sum(all_predictions == 0)} ({np.mean(all_predictions == 0)*100:.1f}%)")

    # Visualization plots
    plot_evaluation_results(all_distances, all_labels, all_predictions, threshold, 
                           fpr, tpr, auc_score)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auc': auc_score,
        'threshold': threshold,
        'all_distances': all_distances,
        'all_labels': all_labels,
        'all_predictions': all_predictions,
        'num_pairs': len(pairs),
        'num_skipped_pairs': num_skipped,
    }


def find_optimal_threshold_comprehensive(all_distances, all_labels, candidate_thresholds):
    """Comprehensive threshold selection with multiple methods.

    A pair is predicted "twin" (1) when ``distance < threshold``. For each of
    F1, F0.5, balanced accuracy and Youden's J, the best threshold among
    ``candidate_thresholds`` is found (thresholds whose predictions are all one
    class score 0). The midpoint between the class-mean distances is added as
    one more candidate when both classes are present. Among these candidates,
    the one with the highest F1 wins (first in preference order on ties).

    Args:
        all_distances: 1-D array-like of pair distances.
        all_labels: 1-D array-like of 0/1 labels aligned with ``all_distances``.
        candidate_thresholds: 1-D array-like of thresholds to search (may be empty).

    Returns:
        The selected threshold. Falls back to the midpoint of the distance range
        when no candidate yields a non-degenerate prediction.

    Raises:
        ValueError: if ``all_distances`` is empty.
    """
    all_distances = np.asarray(all_distances)
    all_labels = np.asarray(all_labels)
    candidate_thresholds = np.asarray(candidate_thresholds)
    if all_distances.size == 0:
        raise ValueError("all_distances must be non-empty")

    methods = {
        'f1': lambda y_true, y_pred: f1_score(y_true, y_pred, zero_division=0),
        'f0.5': lambda y_true, y_pred: fbeta_score(y_true, y_pred, beta=0.5, zero_division=0),
        'balanced_accuracy': lambda y_true, y_pred: 0.5 * (
            recall_score(y_true, y_pred, pos_label=1, zero_division=0) + 
            recall_score(y_true, y_pred, pos_label=0, zero_division=0)
        ),
        'youden_j': lambda y_true, y_pred: (
            recall_score(y_true, y_pred, pos_label=1, zero_division=0) + 
            recall_score(y_true, y_pred, pos_label=0, zero_division=0) - 1
        )
    }

    # One prediction vector per threshold, shared by all scoring methods.
    scores_by_method = {name: [] for name in methods}
    for thresh in candidate_thresholds:
        predictions = (all_distances < thresh).astype(int)
        degenerate = len(np.unique(predictions)) <= 1
        for name, score_func in methods.items():
            scores_by_method[name].append(0 if degenerate else score_func(all_labels, predictions))

    best_thresholds = {}
    if candidate_thresholds.size:
        for name, scores in scores_by_method.items():
            best_thresholds[name] = candidate_thresholds[int(np.argmax(scores))]

    # Statistical approach: midpoint between the two class means
    if len(np.unique(all_labels)) == 2:
        twin_distances = all_distances[all_labels == 1]
        non_twin_distances = all_distances[all_labels == 0]
        best_thresholds['midpoint'] = (twin_distances.mean() + non_twin_distances.mean()) / 2

    # Select the best method based on F1 score
    preferred_methods = ['f1', 'f0.5', 'youden_j', 'midpoint', 'balanced_accuracy']

    final_threshold = None
    final_score = -1

    for method in preferred_methods:
        if method not in best_thresholds:
            continue
        thresh = best_thresholds[method]
        predictions = (all_distances < thresh).astype(int)
        if len(np.unique(predictions)) > 1:
            f1_test = f1_score(all_labels, predictions, zero_division=0)
            if f1_test > final_score:
                final_score = f1_test
                final_threshold = thresh

    if final_threshold is None:
        final_threshold = (all_distances.min() + all_distances.max()) / 2

    return final_threshold


def find_threshold_for_min_recall(distances, labels, candidate_thresholds, min_recall):
    """Find the threshold that achieves minimum recall while maximizing precision.

    A pair is predicted "twin" (1) when ``distance < threshold``; thresholds whose
    predictions are all one class are ignored. Among thresholds with recall
    >= ``min_recall``, the one with the highest precision is selected, ties broken
    by higher F1 and then by earliest position in ``candidate_thresholds``.

    Args:
        distances: 1-D array-like of pair distances.
        labels: 1-D array-like of 0/1 labels aligned with ``distances``.
        candidate_thresholds: Non-empty sequence of thresholds to try.
        min_recall: Minimum required recall in [0, 1].

    Returns:
        The selected threshold. If no threshold meets the constraint, the one with
        the highest recall is returned (or the middle candidate if every candidate
        is degenerate), with a printed warning.

    Raises:
        ValueError: if ``candidate_thresholds`` is empty.
    """
    distances = np.asarray(distances)
    labels = np.asarray(labels)
    if len(candidate_thresholds) == 0:
        raise ValueError("candidate_thresholds must be non-empty")

    valid_thresholds = []
    corresponding_precisions = []
    corresponding_f1s = []
    # Max-recall threshold, tracked in the same pass for the fallback below.
    best_recall = -1
    best_recall_threshold = None

    for thresh in candidate_thresholds:
        predictions = (distances < thresh).astype(int)

        # Skip if all predictions are the same class
        if len(np.unique(predictions)) <= 1:
            continue

        recall = recall_score(labels, predictions, zero_division=0)
        if recall > best_recall:
            best_recall = recall
            best_recall_threshold = thresh

        if recall >= min_recall:
            valid_thresholds.append(thresh)
            corresponding_precisions.append(precision_score(labels, predictions, zero_division=0))
            corresponding_f1s.append(f1_score(labels, predictions, zero_division=0))

    if not valid_thresholds:
        print(f"Warning: No threshold found that achieves minimum recall of {min_recall:.3f}")
        print("Using threshold that maximizes recall instead")
        if best_recall_threshold is not None:
            return best_recall_threshold
        return candidate_thresholds[len(candidate_thresholds)//2]

    # Highest precision wins; F1 breaks ties (max() keeps the first of equal keys).
    best_idx = max(range(len(valid_thresholds)),
                   key=lambda k: (corresponding_precisions[k], corresponding_f1s[k]))
    selected_threshold = valid_thresholds[best_idx]

    print(f"Found {len(valid_thresholds)} thresholds meeting minimum recall constraint")
    print(f"Selected threshold: {selected_threshold:.4f} (precision: {corresponding_precisions[best_idx]:.4f}, F1: {corresponding_f1s[best_idx]:.4f})")

    return selected_threshold


def find_threshold_for_min_precision(distances, labels, candidate_thresholds, min_precision):
    """Find the threshold that achieves minimum precision while maximizing recall.

    A pair is predicted "twin" (1) when ``distance < threshold``; thresholds whose
    predictions are all one class are ignored. Among thresholds with precision
    >= ``min_precision``, the one with the highest recall is selected, ties broken
    by higher F1 and then by earliest position in ``candidate_thresholds``.

    Args:
        distances: 1-D array-like of pair distances.
        labels: 1-D array-like of 0/1 labels aligned with ``distances``.
        candidate_thresholds: Non-empty sequence of thresholds to try.
        min_precision: Minimum required precision in [0, 1].

    Returns:
        The selected threshold. If no threshold meets the constraint, the one with
        the highest precision is returned (or the middle candidate if every
        candidate is degenerate), with a printed warning.

    Raises:
        ValueError: if ``candidate_thresholds`` is empty.
    """
    distances = np.asarray(distances)
    labels = np.asarray(labels)
    if len(candidate_thresholds) == 0:
        raise ValueError("candidate_thresholds must be non-empty")

    valid_thresholds = []
    corresponding_recalls = []
    corresponding_f1s = []
    # Max-precision threshold, tracked in the same pass for the fallback below.
    best_precision = -1
    best_precision_threshold = None

    for thresh in candidate_thresholds:
        predictions = (distances < thresh).astype(int)

        # Skip if all predictions are the same class
        if len(np.unique(predictions)) <= 1:
            continue

        precision = precision_score(labels, predictions, zero_division=0)
        if precision > best_precision:
            best_precision = precision
            best_precision_threshold = thresh

        if precision >= min_precision:
            valid_thresholds.append(thresh)
            corresponding_recalls.append(recall_score(labels, predictions, zero_division=0))
            corresponding_f1s.append(f1_score(labels, predictions, zero_division=0))

    if not valid_thresholds:
        print(f"Warning: No threshold found that achieves minimum precision of {min_precision:.3f}")
        print("Using threshold that maximizes precision instead")
        if best_precision_threshold is not None:
            return best_precision_threshold
        return candidate_thresholds[len(candidate_thresholds)//2]

    # Highest recall wins; F1 breaks ties (max() keeps the first of equal keys).
    best_idx = max(range(len(valid_thresholds)),
                   key=lambda k: (corresponding_recalls[k], corresponding_f1s[k]))
    selected_threshold = valid_thresholds[best_idx]

    print(f"Found {len(valid_thresholds)} thresholds meeting minimum precision constraint")
    print(f"Selected threshold: {selected_threshold:.4f} (recall: {corresponding_recalls[best_idx]:.4f}, F1: {corresponding_f1s[best_idx]:.4f})")

    return selected_threshold


def plot_threshold_analysis(all_distances, all_labels, threshold, candidate_thresholds, 
                           final_method, min_recall=None, min_precision=None):
    """Plot comprehensive threshold analysis as a 2x3 grid of panels.

    Panels:
    - per-class distance histograms;
    - precision/recall/F1 vs threshold;
    - fraction predicted as twins vs threshold;
    - ROC curve with the selected operating point marked;
    - precision-recall curve with the selected operating point marked;
    - bar chart of metrics at the selected threshold.

    A pair is predicted "twin" (1) when ``distance < threshold``.

    Args:
        all_distances: 1-D array-like of pair distances.
        all_labels: 1-D array-like of 0/1 labels (1 = twins) aligned with distances.
        threshold: The selected threshold (drawn as a red dashed line).
        candidate_thresholds: Searched thresholds; every 10th is used for the sweeps.
        final_method: Name of the selection method, shown in the last panel title.
        min_recall, min_precision: Optional constraints, drawn as dotted lines.
    """
    all_distances = np.asarray(all_distances)
    all_labels = np.asarray(all_labels)
    # Every 10th candidate keeps the sweep cheap while still showing the trend.
    sample_thresholds = np.asarray(candidate_thresholds)[::10]
    thresh_predictions = (all_distances < threshold).astype(int)
    n_negatives = np.sum(all_labels == 0)
    n_positives = np.sum(all_labels == 1)

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))
    ax_hist, ax_sweep, ax_fraction = axes[0]
    ax_roc, ax_pr, ax_bars = axes[1]

    # Distance distributions
    ax_hist.hist(all_distances[all_labels == 0], bins=50, alpha=0.7, label='Non-twins', density=True)
    ax_hist.hist(all_distances[all_labels == 1], bins=50, alpha=0.7, label='Twins', density=True)
    ax_hist.axvline(threshold, color='red', linestyle='--', label=f'Selected: {threshold:.4f}')
    ax_hist.set_xlabel('Distance')
    ax_hist.set_ylabel('Density')
    ax_hist.set_title('Distance Distribution')
    ax_hist.legend()

    # One pass over the sampled thresholds feeds both sweep panels.
    sweep = {'precision': [], 'recall': [], 'f1': []}
    prediction_fractions = []
    for thresh in sample_thresholds:
        predictions = (all_distances < thresh).astype(int)
        prediction_fractions.append(np.mean(predictions))
        if len(np.unique(predictions)) > 1:
            sweep['precision'].append(precision_score(all_labels, predictions, zero_division=0))
            sweep['recall'].append(recall_score(all_labels, predictions, zero_division=0))
            sweep['f1'].append(f1_score(all_labels, predictions, zero_division=0))
        else:
            # All predictions in one class: plot 0 rather than a degenerate score.
            for values in sweep.values():
                values.append(0)

    # Threshold optimization curves
    for metric_name, values in sweep.items():
        ax_sweep.plot(sample_thresholds, values, label=metric_name, alpha=0.7)
    ax_sweep.axvline(threshold, color='red', linestyle='--', label=f'Selected: {threshold:.4f}')
    if min_recall is not None:
        ax_sweep.axhline(min_recall, color='orange', linestyle=':', alpha=0.7, label=f'Min recall: {min_recall:.3f}')
    if min_precision is not None:
        ax_sweep.axhline(min_precision, color='purple', linestyle=':', alpha=0.7, label=f'Min precision: {min_precision:.3f}')
    ax_sweep.set_xlabel('Threshold')
    ax_sweep.set_ylabel('Score')
    ax_sweep.set_title('Metrics vs Threshold')
    ax_sweep.legend()

    # Prediction fraction vs threshold
    ax_fraction.plot(sample_thresholds, prediction_fractions)
    ax_fraction.axvline(threshold, color='red', linestyle='--', label=f'Selected: {threshold:.4f}')
    ax_fraction.axhline(0.5, color='gray', linestyle=':', alpha=0.5, label='50% threshold')
    ax_fraction.set_xlabel('Threshold')
    ax_fraction.set_ylabel('Fraction Predicted as Twins')
    ax_fraction.set_title('Prediction Distribution vs Threshold')
    ax_fraction.legend()

    # ROC curve (negated distance: higher score = more twin-like)
    fpr, tpr, _ = roc_curve(all_labels, -all_distances)
    ax_roc.plot(fpr, tpr, label=f'ROC (AUC={auc(fpr, tpr):.3f})')
    ax_roc.plot([0, 1], [0, 1], 'k--', alpha=0.5)
    # Selected operating point; guard the rates against an absent class (0/0).
    false_positives = np.sum((thresh_predictions == 1) & (all_labels == 0))
    true_positives = np.sum((thresh_predictions == 1) & (all_labels == 1))
    thresh_fpr = false_positives / n_negatives if n_negatives else np.nan
    thresh_tpr = true_positives / n_positives if n_positives else np.nan
    ax_roc.plot(thresh_fpr, thresh_tpr, 'ro', markersize=8, label='Selected point')
    ax_roc.set_xlabel('False Positive Rate')
    ax_roc.set_ylabel('True Positive Rate')
    ax_roc.set_title('ROC Curve')
    ax_roc.legend()

    # Precision-Recall curve
    precision_curve, recall_curve, _ = precision_recall_curve(all_labels, -all_distances)
    ax_pr.plot(recall_curve, precision_curve)
    thresh_precision = precision_score(all_labels, thresh_predictions, zero_division=0)
    thresh_recall = recall_score(all_labels, thresh_predictions, zero_division=0)
    ax_pr.plot(thresh_recall, thresh_precision, 'ro', markersize=8, label='Selected point')
    if min_recall is not None:
        ax_pr.axvline(min_recall, color='orange', linestyle=':', alpha=0.7, label=f'Min recall: {min_recall:.3f}')
    if min_precision is not None:
        ax_pr.axhline(min_precision, color='purple', linestyle=':', alpha=0.7, label=f'Min precision: {min_precision:.3f}')
    ax_pr.set_xlabel('Recall')
    ax_pr.set_ylabel('Precision')
    ax_pr.set_title('Precision-Recall Curve')
    ax_pr.legend()

    # Performance metrics for selected threshold
    metrics = {
        'Accuracy': np.mean(thresh_predictions == all_labels),
        'Precision': thresh_precision,
        'Recall': thresh_recall,
        'F1': f1_score(all_labels, thresh_predictions, zero_division=0),
        'F0.5': fbeta_score(all_labels, thresh_predictions, beta=0.5, zero_division=0)
    }
    metric_values = list(metrics.values())
    bars = ax_bars.bar(list(metrics.keys()), metric_values)
    ax_bars.set_ylim(0, 1)
    ax_bars.set_title(f'Performance Metrics\n({final_method})')
    ax_bars.tick_params(axis='x', labelrotation=45)

    # Add value labels on bars
    for bar, value in zip(bars, metric_values):
        ax_bars.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                     f'{value:.3f}', ha='center', va='bottom', fontsize=9)

    fig.tight_layout()
    plt.show()


def plot_evaluation_results(all_distances, all_labels, all_predictions, threshold, 
                           fpr, tpr, auc_score):
    """Plot the final evaluation results in a 2x2 grid.

    Panels:
    - ROC curve from the supplied ``fpr``/``tpr``/``auc_score``;
    - precision-recall curve, scored on negated distances so that a smaller
      distance counts as "more twin-like";
    - confusion matrix of ``all_labels`` vs ``all_predictions``;
    - per-class distance histograms with the decision ``threshold``.

    Args:
        all_distances: 1-D array-like of pair distances.
        all_labels: 1-D array-like of 0/1 labels (1 = twins).
        all_predictions: 1-D array-like of 0/1 predictions.
        threshold: Decision threshold (drawn as a red dashed line).
        fpr, tpr: ROC curve arrays; auc_score: area under that curve.
    """
    all_distances = np.asarray(all_distances)
    all_labels = np.asarray(all_labels)
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # ROC Curve
    axes[0, 0].plot(fpr, tpr, label=f'ROC curve (AUC = {auc_score:.4f})')
    axes[0, 0].plot([0, 1], [0, 1], 'k--')
    axes[0, 0].set_xlim([0.0, 1.0])
    axes[0, 0].set_ylim([0.0, 1.05])
    axes[0, 0].set_xlabel('False Positive Rate')
    axes[0, 0].set_ylabel('True Positive Rate')
    axes[0, 0].set_title('ROC Curve')
    axes[0, 0].legend()

    # Precision-Recall Curve
    precisions, recalls, _ = precision_recall_curve(all_labels, -all_distances)
    axes[0, 1].plot(recalls, precisions, marker='.')
    axes[0, 1].set_xlabel('Recall')
    axes[0, 1].set_ylabel('Precision')
    axes[0, 1].set_title('Precision-Recall Curve')

    # Confusion Matrix. labels=[0, 1] forces a 2x2 matrix even when a class is
    # absent from both labels and predictions, so it stays aligned with the
    # two tick labels below.
    cm = confusion_matrix(all_labels, all_predictions, labels=[0, 1])
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", 
                xticklabels=["Non-Twin", "Twin"], 
                yticklabels=["Non-Twin", "Twin"],
                ax=axes[1, 0])
    axes[1, 0].set_title("Confusion Matrix")
    axes[1, 0].set_xlabel("Predicted")
    axes[1, 0].set_ylabel("Actual")

    # Distance distributions
    axes[1, 1].hist(all_distances[all_labels == 0], bins=50, alpha=0.7, label='Non-twins', density=True)
    axes[1, 1].hist(all_distances[all_labels == 1], bins=50, alpha=0.7, label='Twins', density=True)
    axes[1, 1].axvline(threshold, color='red', linestyle='--', label=f'Threshold: {threshold:.4f}')
    axes[1, 1].set_xlabel('Distance')
    axes[1, 1].set_ylabel('Density')
    axes[1, 1].set_title('Distance Distribution')
    axes[1, 1].legend()

    plt.tight_layout()
    plt.show()


In [None]:
def combine_triplet_datasets(*datasets):
    """
    Combine multiple TripletBrainDataset instances into one, preserving all metadata.

    The returned object merges ``twin_pairs`` (in dataset order, with
    ``subject_to_twin_idx`` re-indexed into the merged list) and ``all_subjects``
    (deduplicated, first-seen order). It also copies ``data_dir``/``image_size``/
    ``device`` from the first dataset and concatenates ``triplets`` if present.
    It takes ``transform`` from the first dataset, delegates ``_load_image`` to the
    source dataset(s), and concatenates ``__len__``/``__getitem__`` across datasets.

    Raises:
        ValueError: if no dataset is given.
    """
    if not datasets:
        raise ValueError("At least one dataset must be provided")

    # Use the first dataset as a template
    first_dataset = datasets[0]

    class CombinedTripletDataset:
        """Merged view over several triplet datasets (see combine_triplet_datasets)."""

        def __init__(self, datasets):
            self.datasets = datasets
            self.transform = first_dataset.transform if hasattr(first_dataset, 'transform') else None

            self.twin_pairs = []
            self.subject_to_twin_idx = {}
            # Which source datasets list each subject, so _load_image can ask
            # the owning dataset directly instead of probing all of them.
            self._subject_owners = {}
            # A dict (insertion-ordered) dedupes while keeping a stable order.
            # A set of string ids iterates in hash order, which is randomized
            # per interpreter session and made the seeded evaluation pairs
            # (random.choice over this list) differ from run to run.
            ordered_subjects = {}

            for dataset in datasets:
                if hasattr(dataset, 'twin_pairs'):
                    # Offset this dataset's pair indices into the merged list
                    start_idx = len(self.twin_pairs)
                    for i, (subj1, subj2) in enumerate(dataset.twin_pairs):
                        self.twin_pairs.append((subj1, subj2))
                        self.subject_to_twin_idx[subj1] = start_idx + i
                        self.subject_to_twin_idx[subj2] = start_idx + i

                if hasattr(dataset, 'all_subjects'):
                    for subject in dict.fromkeys(dataset.all_subjects):
                        ordered_subjects[subject] = None
                        self._subject_owners.setdefault(subject, []).append(dataset)

            self.all_subjects = list(ordered_subjects)

            # Combine any other attributes that might be needed
            self._combine_other_attributes(datasets)

            print(f"Combined dataset statistics:")
            print(f"  Total twin pairs: {len(self.twin_pairs)}")
            print(f"  Total subjects: {len(self.all_subjects)}")

        def _combine_other_attributes(self, datasets):
            """Copy shared attributes from the first dataset and concatenate triplets."""
            for attr_name in ['data_dir', 'image_size', 'device']:
                if hasattr(datasets[0], attr_name):
                    setattr(self, attr_name, getattr(datasets[0], attr_name))

            if hasattr(datasets[0], 'triplets'):
                self.triplets = []
                for dataset in datasets:
                    if hasattr(dataset, 'triplets'):
                        self.triplets.extend(dataset.triplets)

        def _load_image(self, subject_id):
            """Load an image by delegating to the source dataset(s).

            Datasets that list ``subject_id`` in their ``all_subjects`` are tried
            first, then the remaining datasets. Raises FileNotFoundError (chained
            to the last underlying error) if none can load it.
            """
            owners = self._subject_owners.get(subject_id, [])
            others = [d for d in self.datasets if not any(d is o for o in owners)]
            last_error = None
            for dataset in owners + others:
                if hasattr(dataset, '_load_image'):
                    try:
                        return dataset._load_image(subject_id)
                    except Exception as e:  # best-effort: try the next dataset
                        last_error = e

            raise FileNotFoundError(
                f"Subject {subject_id} not found in any of the combined datasets"
            ) from last_error

        def __len__(self):
            """Return total length across all datasets."""
            return sum(len(dataset) for dataset in self.datasets)

        def __getitem__(self, idx):
            """Get item by finding the appropriate dataset and adjusting index."""
            if idx < 0:
                # Python-style negative indexing over the concatenation
                idx += len(self)
                if idx < 0:
                    raise IndexError("Index out of range")
            current_idx = idx
            for dataset in self.datasets:
                if current_idx < len(dataset):
                    return dataset[current_idx]
                current_idx -= len(dataset)
            raise IndexError("Index out of range")

    return CombinedTripletDataset(datasets)

In [None]:
# Evaluate on validation + test subjects pooled together.
combined_dataset = combine_triplet_datasets(val_dataset, test_dataset)
# Bind the result: as the cell's last expression, the returned dict (which holds
# full per-pair distance/label/prediction arrays) would otherwise be echoed.
combined_metrics = evaluate(model, combined_dataset, device,
                            threshold=None, batch_size=BATCH_SIZE)

In [None]:
# Score each saved model on the test split. Every evaluate() call builds its own
# pairs and, since threshold is left at its default (None), selects its own
# threshold on those same test pairs and draws its own figures -- so thresholds
# and threshold-dependent metrics may differ between models.
# NOTE(review): the model_* variables are defined in earlier cells not shown
# here — confirm they are all loaded before running this cell.
model_metrics = evaluate(model, test_dataset, device=device)
model_dist_gap_metrics = evaluate(model_dist_gap, test_dataset, device=device)
model_separability_metrics = evaluate(model_separability, test_dataset, device=device)
model_val_loss_metrics = evaluate(model_val_loss, test_dataset, device=device)
model_auc_metrics = evaluate(model_auc, test_dataset, device=device)
model_f1_metrics = evaluate(model_f1, test_dataset, device=device)

In [None]:
def _eval_distance_stats(distances, labels):
    """Summarise pair distances overall and, when both label values occur, per class.

    Args:
        distances: 1-D numpy array of pair distances.
        labels: numpy array aligned with ``distances``; ``1`` marks a twin pair and
            ``0`` a non-twin pair.

    Returns:
        dict of distance statistics; empty when ``distances`` is empty (np.min/np.max
        raise on empty arrays and np.mean would only yield nan).
    """
    stats = {}
    if distances.size == 0:
        return stats

    stats['mean_distance'] = np.mean(distances)
    stats['std_distance'] = np.std(distances)
    stats['min_distance'] = np.min(distances)
    stats['max_distance'] = np.max(distances)

    if len(np.unique(labels)) == 2:
        twin_distances = distances[labels == 1]
        non_twin_distances = distances[labels == 0]

        if twin_distances.size > 0:
            stats['mean_twin_distance'] = np.mean(twin_distances)
            stats['std_twin_distance'] = np.std(twin_distances)

        if non_twin_distances.size > 0:
            stats['mean_non_twin_distance'] = np.mean(non_twin_distances)
            stats['std_non_twin_distance'] = np.std(non_twin_distances)

        # Gap between class means (higher means the embedding separates twins better).
        if twin_distances.size > 0 and non_twin_distances.size > 0:
            stats['distance_separation'] = abs(
                np.mean(non_twin_distances) - np.mean(twin_distances)
            )
    return stats


def _eval_confusion_stats(predictions, labels):
    """Confusion-matrix counts and specificity for binary predictions vs. labels (1 = positive)."""
    tp = np.sum((predictions == 1) & (labels == 1))
    fp = np.sum((predictions == 1) & (labels == 0))
    tn = np.sum((predictions == 0) & (labels == 0))
    fn = np.sum((predictions == 0) & (labels == 1))
    return {
        'true_positives': tp,
        'false_positives': fp,
        'true_negatives': tn,
        'false_negatives': fn,
        # Specificity (true negative rate); undefined when there are no negatives.
        'specificity': tn / (tn + fp) if (tn + fp) > 0 else None,
    }


def combine_evaluation_results(**kwargs):
    """
    Combine multiple evaluation result dictionaries into a pandas DataFrame.

    Args:
        **kwargs: Named evaluation result dictionaries from the evaluate() function,
                  passed as experiment_name=evaluation_result_dict. Recognised keys:
                  the scalar metrics 'accuracy', 'precision', 'recall', 'f1', 'auc',
                  'threshold', 'num_pairs' (missing ones become None), and optionally
                  the arrays 'all_distances', 'all_labels', 'all_predictions'
                  (lists, numpy arrays, or anything np.asarray accepts).

    Returns:
        pd.DataFrame: One row per experiment (index = experiment name) and one column
        per metric; numeric columns are rounded to 4 decimals for readability.

    Raises:
        ValueError: If no results are given, or a result is not a dict.

    Example:
        baseline_results = evaluate(model1, test_dataset, device)
        improved_results = evaluate(model2, test_dataset, device)

        df = combine_evaluation_results(
            baseline=baseline_results,
            improved=improved_results,
        )
    """
    if not kwargs:
        raise ValueError("At least one evaluation result dictionary must be provided")

    # Scalar metrics copied as-is (arrays and detailed data are summarised separately).
    metrics_to_extract = [
        'accuracy', 'precision', 'recall', 'f1', 'auc', 'threshold', 'num_pairs'
    ]

    experiment_names = []
    rows_data = []

    for experiment_name, eval_results in kwargs.items():
        if not isinstance(eval_results, dict):
            raise ValueError(f"Evaluation result for '{experiment_name}' must be a dictionary")

        experiment_names.append(experiment_name)
        row_data = {}

        for metric in metrics_to_extract:
            value = eval_results.get(metric)
            # Unwrap numpy scalars so the DataFrame holds plain Python numbers.
            if isinstance(value, (np.integer, np.floating)):
                value = value.item()
            row_data[metric] = value

        if 'all_labels' in eval_results:
            # np.asarray so boolean masks like `labels == 1` work for plain lists too.
            labels = np.asarray(eval_results['all_labels'])

            if 'all_distances' in eval_results:
                distances = np.asarray(eval_results['all_distances'])
                row_data.update(_eval_distance_stats(distances, labels))

            if 'all_predictions' in eval_results:
                predictions = np.asarray(eval_results['all_predictions'])
                row_data.update(_eval_confusion_stats(predictions, labels))

        rows_data.append(row_data)

    df = pd.DataFrame(rows_data, index=experiment_names)

    # Round numeric columns for better readability.
    numeric_columns = df.select_dtypes(include=[np.number]).columns
    df[numeric_columns] = df[numeric_columns].round(4)

    return df


def compare_evaluation_results(df, sort_by='f1', ascending=False):
    """
    Print a sectioned comparison of experiments and return them sorted.

    Args:
        df: DataFrame returned by combine_evaluation_results() (experiments as rows).
        sort_by: Column name to rank experiments by (default: 'f1').
        ascending: Sort order (default: False, i.e. highest value first).

    Returns:
        pd.DataFrame: ``df`` sorted by ``sort_by``. No columns are added; the
        analysis is printed to stdout.

    Raises:
        ValueError: If ``sort_by`` is not a column of ``df``.
    """
    if sort_by not in df.columns:
        available_cols = list(df.columns)
        raise ValueError(f"Column '{sort_by}' not found. Available columns: {available_cols}")

    df_sorted = df.sort_values(by=sort_by, ascending=ascending)

    print(f"Evaluation Results Comparison (sorted by {sort_by}):")
    print("=" * 60)

    def _print_section(title, candidate_cols):
        # Print only the columns actually present; skip the section if none are.
        present_cols = [col for col in candidate_cols if col in df.columns]
        if present_cols:
            print(title)
            print(df_sorted[present_cols].to_string())

    _print_section(
        "\nKey Performance Metrics:",
        ['accuracy', 'precision', 'recall', 'f1', 'auc', 'threshold'],
    )
    _print_section(
        "\nDistance Statistics:",
        [col for col in df.columns if 'distance' in col.lower()],
    )
    _print_section(
        "\nConfusion Matrix Components:",
        ['true_positives', 'false_positives', 'true_negatives', 'false_negatives'],
    )

    if not df_sorted.empty:
        best_experiment = df_sorted.index[0]
        best_value = df_sorted[sort_by].iloc[0]
        # Non-numeric columns (e.g. all-None metrics stored as object dtype, or
        # strings) cannot be formatted with '.4f'; fall back to str().
        try:
            best_value_text = f"{best_value:.4f}"
        except (TypeError, ValueError):
            best_value_text = str(best_value)
        print(f"\nBest performing experiment: '{best_experiment}' with {sort_by} = {best_value_text}")

    return df_sorted

# Build the comparison table: one row per checkpoint-selection criterion.
comparison_df = combine_evaluation_results(**{
    'baseline': model_metrics,
    'dist_gap': model_dist_gap_metrics,
    'separability': model_separability_metrics,
    'val_loss': model_val_loss_metrics,
    'auc': model_auc_metrics,
    'f1': model_f1_metrics,
})

In [None]:
# Display the per-checkpoint metrics table (numeric values rounded to 4 decimals).
comparison_df

In [None]:
# Attach each checkpoint's decision threshold taken from the raw evaluate() output.
# comparison_df rounds every numeric column to 4 decimals for display, so reading
# thresholds back from it would store decision boundaries that differ slightly
# from the ones the reported metrics were computed with.
for checkpoint, checkpoint_metrics in (
    (model, model_metrics),
    (model_dist_gap, model_dist_gap_metrics),
    (model_separability, model_separability_metrics),
    (model_val_loss, model_val_loss_metrics),
    (model_auc, model_auc_metrics),
    (model_f1, model_f1_metrics),
):
    checkpoint.set_threshold(checkpoint_metrics['threshold'])

In [None]:
import os

# torch.save does not create missing directories; make sure {MODEL}/ exists.
os.makedirs(MODEL, exist_ok=True)

# Output file stem -> model; each file stores the weights plus the tuned threshold.
checkpoints_to_save = {
    'model': model,
    'model_dist_gap': model_dist_gap,
    'model_separability': model_separability,
    'model_val_loss': model_val_loss,
    'model_auc': model_auc,
    'model_f1': model_f1,
}
for file_stem, checkpoint in checkpoints_to_save.items():
    torch.save({
        'model_state_dict': checkpoint.state_dict(),
        'threshold': checkpoint.threshold,
    }, f"{MODEL}/{file_stem}.pth")