In [1]:
import json
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import os
import math

# --- PyTorch/Torchvision Imports ---
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR
from torch.amp import autocast_mode, grad_scaler

import torchvision.transforms as T
from torchvision.models import convnext_base, ConvNeXt_Base_Weights

# --- Library Imports ---
from wildlife_datasets.datasets import SeaTurtleID2022
from wildlife_tools.data import ImageDataset
from wildlife_datasets.splits import ClosedSetSplit

# SECURITY: Kaggle credentials must never be hard-coded in source. Provide
# KAGGLE_USERNAME / KAGGLE_KEY through the process environment (or a
# ~/.kaggle/kaggle.json file) before running this file. The key that was
# previously committed here should be treated as leaked and revoked.

# Allocator / matmul / cuDNN tuning applied once at import time.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True

class ConvNeXtBackbone(nn.Module):
    """ConvNeXt-Base feature extractor followed by dropout and a linear projection.

    Args:
        embedding_dim: Output size of the projection layer.
        dropout: Dropout probability applied to the backbone features.
        pretrained: When True, load the IMAGENET1K_V1 torchvision weights.
    """

    def __init__(self, embedding_dim=512, dropout=0.2, pretrained=True):
        super().__init__()
        weights = ConvNeXt_Base_Weights.IMAGENET1K_V1 if pretrained else None
        model = convnext_base(weights=weights)
        # Swap the final classification layer for a pass-through so the
        # backbone emits features instead of class logits.
        in_features = model.classifier[2].in_features
        model.classifier[2] = nn.Identity()
        self.backbone = model
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(in_features, embedding_dim)

    def forward(self, x):
        """Return ``(emb, norms)``.

        ``emb`` is the L2-normalised projection; ``norms`` holds the
        pre-normalisation L2 norm of each row (consumed by AdaFace) and
        always has one entry per sample.
        """
        feat = self.backbone(x)
        feat = self.dropout(feat)
        emb = self.proj(feat)
        norms = torch.norm(emb, p=2, dim=1, keepdim=True)
        emb = F.normalize(emb, dim=1)
        # squeeze(1) rather than squeeze(): a bare squeeze() would also drop
        # the batch dimension when the batch holds a single sample, yielding
        # a 0-d tensor that downstream code cannot index by batch.
        return emb, norms.squeeze(1)


class AdaFaceHead(nn.Module):
    """AdaFace-style margin head driven by embedding norms.

    Args:
        embedding_size: Dimensionality of the incoming embeddings.
        num_classes: Number of classes (columns of the class kernel).
        m: Base margin.
        h: Scale applied to the standardised norm before clipping to [-1, 1].
        s: Multiplier applied to the returned logits.
        t_alpha: EMA factor for the running norm mean / std buffers.
    """

    def __init__(self, embedding_size, num_classes, m=0.35, h=0.2, s=64., t_alpha=0.01):
        super().__init__()
        self.num_classes = num_classes
        self.kernel = nn.Parameter(torch.empty(embedding_size, num_classes))
        # Uniform init, then renormalise the columns.
        self.kernel.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5)
        self.m = m
        self.h = h
        self.s = s
        self.t_alpha = t_alpha
        self.register_buffer('batch_mean', torch.ones(1) * 20)
        self.register_buffer('batch_std', torch.ones(1) * 100)

    def forward(self, embeddings, norms, label):
        """Return scaled logits of shape ``(batch, num_classes)``.

        Args:
            embeddings: L2-normalised embeddings, one row per sample.
            norms: Pre-normalisation embedding norms; a 0-d, ``(batch,)`` or
                ``(batch, 1)`` tensor is accepted. Unused when ``label`` is None.
            label: Integer class indices, or None to get plain scaled cosines
                without any margin.
        """
        kernel_norm = torch.nn.functional.normalize(self.kernel, dim=0)
        cosine = torch.mm(embeddings, kernel_norm).clamp(-1 + 1e-3, 1 - 1e-3)

        if label is None:
            return cosine * self.s

        # Flatten to one norm per sample. A 0-d tensor (single-sample batch
        # whose norm was fully squeezed) would make size(0) raise, and a
        # (batch, 1) tensor would broadcast the margins to the wrong shape.
        norms = norms.reshape(-1)

        with torch.no_grad():
            # std() of a single element is undefined, so use 0 for batch size 1.
            std = norms.std() if norms.size(0) > 1 else torch.tensor(0.0, device=norms.device)
            self.batch_mean = norms.mean() * self.t_alpha + (1 - self.t_alpha) * self.batch_mean
            self.batch_std = std * self.t_alpha + (1 - self.t_alpha) * self.batch_std

        # NOTE(review): norms are not detached here, so gradients flow through
        # the margin terms; confirm against the reference AdaFace implementation.
        margin_scaler = (norms - self.batch_mean) / (self.batch_std + 1e-3)
        margin_scaler = torch.clip(margin_scaler * self.h, -1, 1)

        # One-hot mask selecting each sample's target-class column.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1.0)

        # Angular margin, applied to the target class only.
        g_angular = -self.m * margin_scaler
        m_arc = one_hot * g_angular.unsqueeze(1)
        theta = cosine.acos()
        theta_m = torch.clip(theta + m_arc, min=1e-3, max=math.pi - 1e-3)
        cosine_m = theta_m.cos()

        # Additive margin, applied to the target class only.
        g_add = self.m + (self.m * margin_scaler)
        m_cos = one_hot * g_add.unsqueeze(1)

        return (cosine_m - m_cos) * self.s


# Wrapper Model
class ReIDModel(nn.Module):
    """Couples an embedding backbone with a classification head.

    With labels, ``forward`` returns ``(logits, embeddings)``; without labels
    it returns the embeddings only and the head is not invoked.
    """

    def __init__(self, backbone, head):
        super().__init__()
        self.backbone, self.head = backbone, head

    def forward(self, x, labels=None):
        embeddings, norms = self.backbone(x)
        if labels is None:
            # Inference path: skip the head entirely.
            return embeddings
        return self.head(embeddings, norms, labels), embeddings


def set_seed(seed):
    """Seed the torch (CPU and all CUDA devices) and NumPy global RNGs."""
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed):
        seed_fn(seed)


def extract_features(model, dataset, device, batch_size=16):
    """Run ``model`` over ``dataset`` in order and stack its outputs.

    The model is put in eval mode and moved to ``device``. Each batch output
    is copied to the CPU and converted to NumPy; the per-batch arrays are
    stacked row-wise into a single array.
    """
    model.eval()
    model = model.to(device)
    batches = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )
    with torch.no_grad():
        progress = tqdm(batches, desc="Extracting features", leave=False)
        # The second item of each batch (the label) is not needed here.
        chunks = [model(images.to(device)).cpu().numpy() for images, _ in progress]
    return np.vstack(chunks)


def compute_cosine_similarity(query_features, gallery_features):
    """Return the (num_query, num_gallery) cosine-similarity matrix.

    Rows of both inputs are scaled to unit length first; a small epsilon in
    the denominator keeps all-zero rows from dividing by zero.
    """

    def unit_rows(matrix):
        lengths = np.linalg.norm(matrix, axis=1, keepdims=True)
        return matrix / (lengths + 1e-8)

    return np.dot(unit_rows(query_features), unit_rows(gallery_features).T)


def evaluate(model, gallery_set, query_set, device, batch_size=16):
    """Compute Rank-1 and Rank-5 retrieval accuracy of ``model``.

    Features are extracted for the gallery and the query sets, each query is
    matched against the gallery by cosine similarity, and a query counts as a
    hit when its label appears at position 1 (Rank-1) or anywhere among the
    five most similar gallery items (Rank-5). Both values are fractions in
    [0, 1]. The model's training mode is restored before returning.
    """
    restore_train_mode = model.training
    model.eval()

    gallery_features = extract_features(model, gallery_set, device, batch_size)
    query_features = extract_features(model, query_set, device, batch_size)
    scores = compute_cosine_similarity(query_features, gallery_features)

    query_labels = np.array(query_set.labels_string)
    gallery_labels = np.array(gallery_set.labels_string)

    # Gallery indices sorted by descending similarity, truncated to the best 5.
    ranked = np.argsort(scores, axis=1)[:, ::-1][:, :5]

    top1_hits = 0
    top5_hits = 0
    for position, wanted in enumerate(query_labels):
        candidates = gallery_labels[ranked[position]]
        if candidates[0] == wanted:
            top1_hits += 1
        if wanted in candidates:
            top5_hits += 1

    num_queries = len(query_labels)
    rank1_acc = top1_hits / num_queries
    rank5_acc = top5_hits / num_queries

    if restore_train_mode:
        model.train()

    return rank1_acc, rank5_acc


def clean_path(p):
    """Strip every 'turtles-data/data/' occurrence from ``p``; otherwise return ``p`` as is."""
    prefix = 'turtles-data/data/'
    return p.replace(prefix, '') if prefix in p else p


def safe_split(df, name, seed):
    """Split ``df`` 50/50 into gallery and query frames with ``ClosedSetSplit``.

    Args:
        df: DataFrame to split.
        name: Label used in warning / error messages.
        seed: Seed forwarded to the splitter.

    Returns:
        ``(gallery_df, query_df)``, both with a fresh 0..n-1 index. When ``df``
        is empty or the splitter yields nothing, ``(df, df)`` is returned after
        printing a warning.

    Raises:
        IndexError: If the splitter returns indices that are not in ``df.index``.
    """
    if len(df) == 0:
        print(f"WARNING: {name} dataframe is empty!")
        return df, df

    splitter = ClosedSetSplit(ratio_train=0.5, seed=seed)
    splits = splitter.split(df)

    if len(splits) == 0:
        print(f"WARNING: Splitter returned no splits for {name}")
        return df, df

    gallery_idx, query_idx = splits[0]

    # The splitter's indices are treated as index *labels* (selected with
    # .loc), which coincides with positional selection for a default
    # RangeIndex but stays correct for any other index.
    # NOTE(review): label semantics taken from the wildlife-datasets docs — confirm.
    for part, idx in (("gallery", gallery_idx), ("query", query_idx)):
        unknown = pd.Index(idx).difference(df.index)
        if len(unknown) > 0:
            raise IndexError(
                f"Splitter returned invalid indices for {name} ({part}). "
                f"Unknown: {list(unknown[:5])}, DF len: {len(df)}"
            )

    gal_df = df.loc[gallery_idx].reset_index(drop=True)
    qry_df = df.loc[query_idx].reset_index(drop=True)
    return gal_df, qry_df


def generate_query_gallery_splits(df, seed, cfg=None):
    """Build gallery / query ``ImageDataset`` objects for evaluation.

    Args:
        df: DataFrame with 'path' and 'identity' columns.
        seed: Seed for the gallery/query split.
        cfg: Optional configuration mapping providing 'image_size' and,
            optionally, 'root'. When omitted, the module-level ``config`` is
            used, which is what this function always did before.

    Returns:
        ``(gallery_set, query_set)``.
    """
    if cfg is None:
        # Backward-compatible fallback to the global configuration.
        cfg = config

    gallery_df, query_df = safe_split(df, "Dataset Split", seed)

    # Deterministic preprocessing only: resize + ImageNet normalisation.
    t_eval = T.Compose([
        T.Resize((cfg['image_size'], cfg['image_size'])),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    # Pass the dataset root so relative image paths resolve the same way as
    # in the training dataset built in FederatedClient.get_loader.
    dataset_kwargs = dict(
        root=cfg.get('root'),
        transform=t_eval,
        col_path='path',
        col_label='identity',
    )
    gallery_set = ImageDataset(gallery_df, **dataset_kwargs)
    query_set = ImageDataset(query_df, **dataset_kwargs)
    return gallery_set, query_set


def partition_data(df, num_clients, seed, overlap_ratio=0.1, max_client_ratio=0.4):
    """Partition ``df`` by identity across ``num_clients`` federated clients.

    A fraction ``overlap_ratio`` of the identities is "shared": each is given
    to between 2 and ``max(2, num_clients * max_client_ratio)`` clients (never
    more than ``num_clients``) and its images are divided among them. Every
    other identity is "private" and goes to exactly one randomly chosen
    client. Each image ends up in exactly one client frame.

    Args:
        df: DataFrame with an 'identity' column.
        num_clients: Number of client partitions (>= 1). With a single client
            no identity can be shared, so all identities are private.
        seed: Seed for the ``numpy.random.RandomState`` driving all choices.
        overlap_ratio: Fraction of identities that are shared.
        max_client_ratio: Upper bound, as a fraction of ``num_clients``, on
            how many clients a shared identity is assigned to.

    Returns:
        A list of ``num_clients`` DataFrames with a fresh index and an extra
        boolean 'is_shared' column.

    Raises:
        ValueError: If ``num_clients`` is smaller than 1.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    all_identities = sorted(df['identity'].unique().tolist())
    rng = np.random.RandomState(seed)
    rng.shuffle(all_identities)

    # 1. Split identities into shared and private. Sharing needs at least two
    #    clients; with one client everything is private.
    num_shared = int(len(all_identities) * overlap_ratio) if num_clients >= 2 else 0
    shared_identities = all_identities[:num_shared]
    private_identities = all_identities[num_shared:]

    print(f"Total Identities: {len(all_identities)} | Shared: {len(shared_identities)} | Private: {len(private_identities)}")

    # Which clients receive each identity.
    identity_to_clients_map = {}
    # Row labels of the images handed to each client.
    client_image_indices = {i: [] for i in range(num_clients)}

    # 2. Maximum number of clients a shared identity can belong to, clamped
    #    so that sampling without replacement is always possible.
    max_clients_limit = min(num_clients, max(2, int(num_clients * max_client_ratio)))

    # 3. Assign shared identities (multi-client).
    for identity in shared_identities:
        n_partners = rng.randint(2, max_clients_limit + 1)
        identity_to_clients_map[identity] = rng.choice(num_clients, size=n_partners, replace=False)

    # 4. Assign private identities (single client, stored as a 1-element
    #    list so the distribution loop treats both kinds alike).
    for identity in private_identities:
        identity_to_clients_map[identity] = [rng.randint(0, num_clients)]

    # 5. Group row labels per identity in one pass over the frame, then
    #    shuffle each group (in identity order, so results stay seed-stable).
    id_to_indices = {identity: [] for identity in all_identities}
    for row_label, identity in zip(df.index.tolist(), df['identity'].tolist()):
        id_to_indices[identity].append(row_label)
    for identity in all_identities:
        rng.shuffle(id_to_indices[identity])

    # 6. Distribute each identity's images over its assigned clients.
    for identity in all_identities:
        assigned_clients = identity_to_clients_map[identity]
        all_imgs = id_to_indices[identity]

        total_shares = len(assigned_clients)
        imgs_per_client = len(all_imgs) // total_shares

        for i, client_id in enumerate(assigned_clients):
            start = i * imgs_per_client
            # The last client also takes the remainder of an uneven division.
            end = len(all_imgs) if i == total_shares - 1 else (i + 1) * imgs_per_client
            if end > start:
                client_image_indices[client_id].extend(all_imgs[start:end])

    # 7. Build the final per-client frames (rows are selected by index label).
    client_dfs = []
    for client_id in range(num_clients):
        client_df = df.loc[client_image_indices[client_id]].copy().reset_index(drop=True)
        # Debug metadata: whether the row's identity is one of the shared ones.
        client_df['is_shared'] = client_df['identity'].isin(shared_identities)
        client_dfs.append(client_df)

    return client_dfs


def move_to_device(obj, device):
    """Return ``obj`` with every tensor inside it moved to ``device``.

    Dicts and lists are rebuilt recursively; any other object exposing a
    ``to`` method (e.g. a module) has it called; everything else, tuples
    included, is returned untouched.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        moved = {}
        for key, value in obj.items():
            moved[key] = move_to_device(value, device)
        return moved
    if isinstance(obj, list):
        return [move_to_device(item, device) for item in obj]
    return obj.to(device) if hasattr(obj, 'to') else obj

def optimizer_to(optim, device):
    """Move every tensor held in the optimizer's per-parameter state to ``device`` in place."""
    for param_state in optim.state.values():
        relocated = {
            name: value.to(device)
            for name, value in param_state.items()
            if torch.is_tensor(value)
        }
        # Only existing keys are overwritten; non-tensor entries stay as they are.
        param_state.update(relocated)


class FederatedClient:
    """One federated participant: local data, local model and prototype buffers.

    The model, optimizer state and prototype buffers live on the CPU between
    rounds and are moved to ``config['device']`` only for the duration of
    ``train``.

    Args:
        client_id: Identifier used in log output and in the result message.
        train_df: This client's training rows ('path' and 'identity' columns).
        config: Configuration mapping; the keys read here are 'device',
            'embedding_dim', 'lr', 'w_decay', 'image_size', 'root',
            'batch_size', 'local_epochs' and 'lambda_proto'.
    """

    def __init__(self, client_id, train_df, config):
        self.client_id = client_id
        self.train_df = train_df.copy()
        self.config = config
        self.device = config['device']  # Target device for the active phase
        self.cpu_device = torch.device('cpu')

        self.unique_identities = sorted(self.train_df['identity'].unique().tolist())
        self.num_local_classes = len(self.unique_identities)

        # Build on the CPU; weights are overwritten by the server's backbone
        # in train(), so no pretrained download is needed here.
        backbone = ConvNeXtBackbone(embedding_dim=config['embedding_dim'], pretrained=False)
        head = AdaFaceHead(embedding_size=config['embedding_dim'], num_classes=self.num_local_classes)
        self.model = ReIDModel(backbone, head).to(self.cpu_device)

        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config['lr'],
            weight_decay=self.config['w_decay']
        )
        self.scaler = grad_scaler.GradScaler()

        # Per-class embedding sums / counts used to compute prototypes.
        self.proto_sums = torch.zeros(self.num_local_classes, config['embedding_dim'], device=self.cpu_device)
        self.proto_counts = torch.zeros(self.num_local_classes, device=self.cpu_device)

        self.identity_to_local_idx = {id_str: idx for idx, id_str in enumerate(self.unique_identities)}
        self.local_idx_to_identity = {idx: id_str for id_str, idx in self.identity_to_local_idx.items()}

        print(f"Client {self.client_id} - Model initialized (CPU-Resident).")

    def get_loader(self):
        """Return a shuffled, augmenting ``DataLoader`` over this client's data."""
        t_train = T.Compose([
            T.Resize((self.config['image_size'], self.config['image_size'])),
            T.RandomHorizontalFlip(),
            T.ColorJitter(0.2, 0.2, 0.2, 0.1),
            T.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        dataset = ImageDataset(
            self.train_df,
            root=self.config['root'],
            transform=t_train,
            col_path='path',
            col_label='identity'
        )
        return DataLoader(
            dataset,
            batch_size=self.config['batch_size'],
            shuffle=True,
            num_workers=2,
            pin_memory=True
        )

    @staticmethod
    def _index_to_identity(dataset):
        """Map the dataset's integer class labels to identity strings.

        Uses the dataset's own ``labels_map`` when it exposes one, so the
        mapping matches the integer labels the dataset actually yields.
        Otherwise falls back to the previous assumption that label ``i`` is
        the i-th identity in sorted order.
        NOTE(review): ``labels_map`` is presumed to list identities in label
        order — confirm against the dataset class in use.
        """
        labels_map = getattr(dataset, 'labels_map', None)
        if labels_map is not None:
            return {idx: str(name) for idx, name in enumerate(labels_map)}
        return dict(enumerate(sorted(set(dataset.labels_string))))

    def train(self, server_msg):
        """Run one round of local training.

        Args:
            server_msg: Mapping with 'model_state' (backbone state dict),
                'current_lr' and 'prototypes' (identity -> tensor).

        Returns:
            Mapping with 'client_id', 'model_state' (backbone state dict),
            'prototypes' (identity -> (unit vector, sample count)),
            'num_samples' and 'loss' (mean loss of the last local epoch, or
            0.0 when no epoch ran).
        """
        # --- A. MOVE TO TARGET DEVICE (Active Phase) ---
        self.model.to(self.device)
        optimizer_to(self.optimizer, self.device)
        self.proto_sums = self.proto_sums.to(self.device)
        self.proto_counts = self.proto_counts.to(self.device)

        # Bound up front so the cleanup in `finally` cannot fail on names that
        # were never assigned (which would hide the original exception), and
        # so a zero-epoch round still returns a loss.
        target_protos = None
        has_proto_mask = None
        avg_loss = 0.0

        try:
            # 1. Load global backbone
            global_weights = server_msg['model_state']
            self.model.backbone.load_state_dict(global_weights, strict=True)
            self.model.train()

            # 2. Update LR
            current_lr = server_msg['current_lr']
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = current_lr

            criterion = nn.CrossEntropyLoss()

            loader = self.get_loader()
            idx_to_identity = self._index_to_identity(loader.dataset)

            # 3. Prepare prototype targets, indexed by the dataset's class labels
            global_prototypes = server_msg['prototypes']
            target_protos = torch.zeros(self.num_local_classes, self.config['embedding_dim'], device=self.device)
            has_proto_mask = torch.zeros(self.num_local_classes, dtype=torch.bool, device=self.device)

            if global_prototypes:
                for local_idx, identity in idx_to_identity.items():
                    if identity in global_prototypes:
                        target_protos[local_idx] = global_prototypes[identity].to(self.device)
                        has_proto_mask[local_idx] = True

            # 4. Training loop
            for epoch in range(self.config['local_epochs']):
                batch_loss = 0.0
                pbar = tqdm(loader, desc=f"Client {self.client_id} Epoch {epoch+1}")

                for imgs, labels in pbar:
                    imgs, labels = imgs.to(self.device), labels.to(self.device)

                    with autocast_mode.autocast(device_type='cuda'):
                        logits, emb = self.model(imgs, labels)
                        loss_cls = criterion(logits, labels)

                        # Pull embeddings toward the global prototype of their
                        # class, for classes that have one.
                        loss_proto = torch.tensor(0.0, device=self.device)
                        if self.config['lambda_proto'] > 0 and has_proto_mask.any():
                            active_mask = has_proto_mask[labels]
                            if active_mask.any():
                                active_embs = emb[active_mask]
                                active_targets = target_protos[labels[active_mask]]
                                loss_proto = F.mse_loss(active_embs, active_targets)

                        loss_total = loss_cls + (loss_proto * self.config['lambda_proto'])

                    self.optimizer.zero_grad(set_to_none=True)
                    self.scaler.scale(loss_total).backward()
                    self.scaler.step(self.optimizer)
                    self.scaler.update()

                    # Read the scalar once; each .item() forces a device sync.
                    loss_value = loss_total.item()
                    batch_loss += loss_value * imgs.size(0)
                    pbar.set_postfix({'loss': f"{loss_value:.4f}"})

                avg_loss = batch_loss / len(loader.dataset)

            # 5. Compute new prototypes while the model is still on the device
            torch.cuda.empty_cache()
            new_prototypes = self._compute_local_prototypes(loader)

        finally:
            # --- B. MOVE BACK TO CPU (Inactive Phase) ---
            # Runs even when training fails, so the device is freed for the
            # next client.
            self.model.to(self.cpu_device)
            optimizer_to(self.optimizer, self.cpu_device)
            self.proto_sums = self.proto_sums.to(self.cpu_device)
            self.proto_counts = self.proto_counts.to(self.cpu_device)

            # Drop references to device tensors before emptying the cache.
            target_protos = None
            has_proto_mask = None
            torch.cuda.empty_cache()

        return {
            'client_id': self.client_id,
            'model_state': self.model.backbone.state_dict(),
            'prototypes': new_prototypes,
            'num_samples': len(self.train_df),
            'loss': avg_loss
        }

    def _compute_local_prototypes(self, loader):
        """Return ``{identity: (unit mean embedding, sample count)}`` for ``loader``.

        Embeddings from ``self.model.backbone`` are summed per class label,
        averaged and L2-normalised on the CPU. Classes with no samples are
        omitted. The model is left in eval mode.
        """
        self.model.eval()
        self.proto_sums.zero_()
        self.proto_counts.zero_()

        int_to_str = self._index_to_identity(loader.dataset)

        with torch.no_grad():
            for images, labels in loader:
                images, labels = images.to(self.device), labels.to(self.device)

                # The backbone returns (embeddings, norms); only the former is needed.
                emb = self.model.backbone(images)[0]

                self.proto_sums.index_add_(0, labels, emb)
                self.proto_counts.index_add_(0, labels, torch.ones_like(labels, dtype=torch.float))

        # Average on CPU copies; the accumulation buffers stay where they are.
        proto_dict = {}
        cpu_sums = self.proto_sums.cpu()
        cpu_counts = self.proto_counts.cpu()

        for idx, count in enumerate(cpu_counts):
            if count > 0:
                mean_emb = cpu_sums[idx] / count
                mean_emb = F.normalize(mean_emb, p=2, dim=0)

                identity_str = int_to_str[idx]
                proto_dict[identity_str] = (mean_emb, count.item())

        return proto_dict


class FederatedServer:
    """Holds the global backbone, the global prototypes and the LR schedule.

    Args:
        config: Configuration mapping; the keys read here are 'device',
            'embedding_dim', 'lr', 'w_decay', 'warmup_rounds', 'rounds' and
            'proto_momentum'.
    """

    def __init__(self, config):
        self.config = config
        self.device = config['device']
        self.global_backbone = ConvNeXtBackbone(embedding_dim=config['embedding_dim']).to(self.device)
        self.global_prototypes = {}

        # The optimizer is never stepped; it only carries the learning rate
        # that the schedulers manipulate.
        self.dummy_optimizer = optim.AdamW(self.global_backbone.parameters(), lr=config['lr'], weight_decay=config['w_decay'])
        warmup_rounds = config['warmup_rounds']
        main_scheduler = CosineAnnealingLR(self.dummy_optimizer, T_max=config['rounds'] - warmup_rounds, eta_min=1e-6)
        warmup_scheduler = LinearLR(self.dummy_optimizer, start_factor=0.1, total_iters=warmup_rounds)
        self.scheduler = SequentialLR(self.dummy_optimizer, schedulers=[warmup_scheduler, main_scheduler], milestones=[warmup_rounds])

        # Take the first round's LR from the schedule itself. Using
        # config['lr'] here would run round 1 at the full rate and only then
        # drop into the warm-up ramp.
        self.current_lr = self.dummy_optimizer.param_groups[0]['lr']

    def step_scheduler(self):
        """Advance the LR schedule by one round and record the new rate."""
        self.scheduler.step()
        self.current_lr = self.scheduler.get_last_lr()[0]
        print(f"[Server] Learning Rate updated to: {self.current_lr:.6f}")

    def aggregate(self, client_msgs):
        """Merge client results into the global backbone and prototypes.

        Backbone tensors are averaged weighted by each client's
        'num_samples'. Prototypes are averaged per identity weighted by
        sample count, L2-normalised, and blended into any existing global
        prototype using 'proto_momentum'.

        Args:
            client_msgs: Non-empty list of mappings with 'model_state',
                'prototypes' (identity -> (vector, count)) and 'num_samples'.

        Raises:
            ValueError: If ``client_msgs`` is empty or reports no samples.
        """
        print("[Server] Aggregating Weights and Prototypes...")

        if not client_msgs:
            raise ValueError("aggregate() requires at least one client message")

        total_samples = sum(msg['num_samples'] for msg in client_msgs)
        if total_samples <= 0:
            raise ValueError("aggregate() requires a positive total number of samples")

        first_state = client_msgs[0]['model_state']

        # Accumulate in floating point: adding a float-weighted tensor in
        # place into an integer tensor (e.g. a step counter buffer) fails.
        agg_state = {
            k: torch.zeros_like(v, dtype=v.dtype if v.is_floating_point() else torch.float32)
            for k, v in first_state.items()
        }

        for msg in client_msgs:
            weight_factor = msg['num_samples'] / total_samples
            for k, v in msg['model_state'].items():
                acc = agg_state[k]
                acc += v.to(device=acc.device, dtype=acc.dtype) * weight_factor

        # Restore the original dtype of non-floating entries.
        for k, ref in first_state.items():
            if not ref.is_floating_point():
                agg_state[k] = agg_state[k].round().to(ref.dtype)

        self.global_backbone.load_state_dict(agg_state)

        # Count-weighted sum of each identity's prototype across clients.
        round_sums = {}
        round_counts = {}

        for msg in client_msgs:
            for id_str, (vec, count) in msg['prototypes'].items():
                if id_str not in round_sums:
                    round_sums[id_str] = vec.float() * count
                    round_counts[id_str] = count
                else:
                    round_sums[id_str] += vec.float() * count
                    round_counts[id_str] += count

        momentum = self.config['proto_momentum']
        updated_cnt = 0
        for id_str, vec_sum in round_sums.items():
            new_proto = F.normalize(vec_sum / round_counts[id_str], p=2, dim=0)
            if id_str in self.global_prototypes:
                old_proto = self.global_prototypes[id_str].cpu()  # Prototypes are stored on CPU
                avg_proto = (old_proto * momentum) + (new_proto * (1 - momentum))
                self.global_prototypes[id_str] = F.normalize(avg_proto, p=2, dim=0)
                updated_cnt += 1
            else:
                self.global_prototypes[id_str] = new_proto

        print(f"[Server] Global Prototypes Updated: {updated_cnt} | Total: {len(self.global_prototypes)}")

    def get_eval_model(self):
        """Wrap the global backbone in a ``ReIDModel`` with a pass-through head."""
        eval_model = ReIDModel(self.global_backbone, nn.Identity())
        return eval_model

    def distribute(self):
        """Return the message sent to clients at the start of a round."""
        comm_msg = {
            'model_state': self.global_backbone.state_dict(),
            'prototypes': self.global_prototypes,
            'current_lr': self.current_lr
        }
        return comm_msg


def main(config):
    """Load SeaTurtleID2022, join it with the split metadata and verify paths.

    Prints the configuration, creates the results directory, downloads /
    loads the dataset, merges it with ``metadata_splits.csv`` when that file
    can be found under ``config['root']``, and checks that the first image
    path exists on disk. The federated training pipeline that follows these
    statements in the file is currently commented out.

    Args:
        config: Configuration mapping; the keys read here are 'results_root',
            'results_name', 'root' and 'body_part'.

    Raises:
        ValueError: If the dataset is empty after the metadata merge.
        FileNotFoundError: If the first image path does not exist under root.
    """
    import builtins  # local import: only used to look up a notebook-provided `display`

    print("Configuration:")
    print("-" * 60)
    for key, value in config.items():
        print(f"{key:15s}: {value}")
    print("-" * 60 + "\n")

    results_path = Path(config['results_root']) / config['results_name']
    results_path.mkdir(parents=True, exist_ok=True)

    print("--- Loading Data ---")
    SeaTurtleID2022.get_data(root=config['root'])

    if config['body_part'] is None:
        dataset_df = SeaTurtleID2022(root=config['root']).df
    else:
        dataset_df = SeaTurtleID2022(root=config['root'], category_name=config['body_part'], img_load='bbox').df

    print(f"Original Dataset Size: {len(dataset_df)}")

    # 2. Load metadata: try the standard location, then search under root.
    meta_df = None
    meta_path = Path(config['root']) / 'turtles-data' / 'data' / 'metadata_splits.csv'
    try:
        meta_df = pd.read_csv(meta_path)
    except FileNotFoundError:
        print("Metadata not found at standard path.")
        # Sorted so the chosen file is deterministic when several exist.
        found_metas = sorted(Path(config['root']).rglob('metadata_splits.csv'))
        if found_metas:
            meta_path = found_metas[0]
            print(f"[System] Found metadata at: {meta_path}")
            meta_df = pd.read_csv(meta_path)

    if meta_df is not None:
        print(f"[System] Loading metadata from: {meta_path}")
        dataset_df['join_key'] = dataset_df['path'].apply(clean_path)

        merged_df = pd.merge(
            dataset_df,
            meta_df[['file_name', 'split_closed', 'split_open']],
            left_on='join_key',
            right_on='file_name',
            how='inner'
        )
        # Re-prefix the cleaned paths so they are relative to config['root'].
        img_root = 'turtles-data/data/'
        merged_df['path'] = merged_df['join_key'].apply(lambda x: str(Path(img_root) / x))
    else:
        print("[WARNING] Metadata not found! Proceeding with raw dataset (splitting might fail).")
        merged_df = dataset_df

    print(f"Merged Dataset Size: {len(merged_df)}")

    # `display` only exists inside IPython/Jupyter; fall back to print so the
    # function also runs as a plain script.
    show = globals().get('display') or getattr(builtins, 'display', print)
    show(merged_df.head())

    if len(merged_df) == 0:
        raise ValueError("Merged dataset is empty: no dataset path matched a metadata file_name.")

    test_path = Path(config['root']) / merged_df.iloc[0]['path']
    if not test_path.exists():
        print(f"\n[FATAL ERROR] Path Verification Failed!")
        print(f"Root: {config['root']}")
        print(f"DF Path: {merged_df.iloc[0]['path']}")
        print(f"Combined: {test_path}")
        raise FileNotFoundError("Check dataset structure.")
    print("[System] ✅ Path Verification Successful.")

    # 4. Create Base Splits
    # split_col = f'split_{config["set"]}'
    # train_df = merged_df[merged_df[split_col] == 'train'].reset_index(drop=True)
    # valid_df = merged_df[merged_df[split_col] == 'valid'].reset_index(drop=True)
    # test_df = merged_df[merged_df[split_col] == 'test'].reset_index(drop=True)

    # print(f"Train: {len(train_df)}, Val: {len(valid_df)}, Test: {len(test_df)}")

    # client_dfs = partition_data(
    #     train_df,
    #     num_clients=config['num_clients'],
    #     seed=config['seed'],
    #     overlap_ratio=config['overlap_ratio'],
    #     max_client_ratio=config['max_client_ratio']
    # )

    # print("\n--- Partition Verification ---")
    # for i, df in enumerate(client_dfs):
    #     n_unique = df['identity'].nunique()
    #     n_shared = df[df['is_shared'] == True]['identity'].nunique()
    #     print(f"Client {i}: {len(df)} images | {n_unique} IDs | {n_shared} Shared IDs")


    # # Generate Query-Gallery splits for validation and testing
    # val_gallery_set, val_query_set = generate_query_gallery_splits(valid_df, seed=config['seed'])
    # test_gallery_set, test_query_set = generate_query_gallery_splits(test_df, seed=config['seed'])

    # server = FederatedServer(config)
    # initial_state = server.distribute()['model_state']
    
    # clients = []
    # for i in range(config['num_clients']):
    #     client = FederatedClient(i, client_dfs[i], config)
    #     client.model.backbone.load_state_dict(initial_state, strict=True)
    #     clients.append(client)

    # print("[Main] Starting Federated Training...")
    # best_val_rank1 = 0.0
    # history = {
    #     'rank1': [],
    #     **{f'loss_C{client_id}': [] for client_id in range(config['num_clients'])}
    # }

    # early_stopping_counter = 0
    # best_round = 0

    # for round_idx in range(1, config['rounds'] + 1):
    #     print(f"\n=== Communication Round {round_idx} / {config['rounds']} ===")
    #     server_msg = server.distribute()

    #     client_results = []
    #     for client in clients:
    #         results = client.train(server_msg)
    #         client_results.append(results)
    #         history[f'loss_C{client.client_id}'].append(results['loss'])
    #         print(f"   Client {client.client_id} Loss: {results['loss']:.4f}")

    #     server.aggregate(client_results)
    #     server.step_scheduler()

    #     print("[Server] Evaluating Global Model on Validation Set...")
    #     eval_model = server.get_eval_model().to(config['device'])
    #     val_r1, val_r5 = evaluate(
    #         eval_model, 
    #         val_gallery_set, 
    #         val_query_set, 
    #         config['device'], 
    #         batch_size=config['batch_size']
    #     )
    #     print(f"   Validation Rank-1: {val_r1*100:.2f}%, Rank-5: {val_r5*100:.2f}%")
    #     history['rank1'].append(val_r1)

    #     if val_r1 > best_val_rank1:
    #         best_val_rank1 = val_r1
    #         best_round = round_idx
    #         early_stopping_counter = 0
    #         torch.save(
    #             server.global_backbone.state_dict(),
    #             results_path / 'best_backbone.pth'
    #         )
    #         print(f"[Server] ✅ New best model found! Evaluating on Test Set... at round {round_idx} with Rank-1: {best_val_rank1*100:.2f}%")
    #     else:
    #         early_stopping_counter += 1
    #         print(f"[Server] No improvement. Early Stopping Counter: {early_stopping_counter}/{config['patience']}")
    #         if early_stopping_counter >= config['patience']:
    #             print("[Server] Early stopping triggered. Ending training.")
    #             break
    
    # with open(results_path / 'training_history.json', 'w') as f:
    #     json.dump(history, f, indent=4)

    # print(f"\n=== Training Complete. Best Validation Rank-1: {best_val_rank1*100:.2f}% at round {best_round} ===")
    # print("[Server] Loading Best Model for Final Evaluation...")
    # best_state = torch.load(results_path / 'best_backbone.pth')
    # server.global_backbone.load_state_dict(best_state)
    # final_model = server.get_eval_model().to(config['device'])
    # test_r1, test_r5 = evaluate(
    #     final_model, 
    #     test_gallery_set, 
    #     test_query_set, 
    #     config['device'], 
    #     batch_size=config['batch_size']
    # )
    # print(f"[Server] Final Test Set Performance - Rank-1: {test_r1*100:.2f}%, Rank-5: {test_r5*100:.2f}%")

    # with open(results_path / 'results.txt', 'w') as f:
    #     f.write(f"Best Validation Rank-1 Accuracy: {best_val_rank1*100:.2f}% at round {best_round}\n")
    #     f.write(f"Test Rank-1 Accuracy: {test_r1*100:.2f}%, Rank-5 Accuracy: {test_r5*100:.2f}%\n")

if __name__ == "__main__":
    # NOTE: the experiment driver (base config, experiment list and the loop
    # calling main()) is commented out below. As it stands, running this file
    # only instantiates the default backbone and prints its architecture.
    # config = {
    #     'root': './data/SeaTurtleID2022',
    #     'results_root': './results/federated_reid', 
    #     'description': '',
    #     'image_size': 384,
    #     'batch_size': 128,
    #     'patience': 8,

    #     'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    #     'seed': 42,
    #     'body_part': 'head',  # 'head', 'turtle', 'flipper', or None for full image
    #     'set': 'closed',      # 'closed' or 'open'
    #     'lr': 1e-4,
    #     'w_decay': 1e-4,

    #     'embedding_dim': 512,
        
    #     # Federated Settings
    #     'num_clients': 5,
    #     'overlap_ratio': 0.1,
    #     'max_client_ratio': 0.4,
    #     'local_epochs': 1,
    #     'rounds': 30,
    #     'lambda_proto': 0.1,
    #     'proto_momentum': 0.9,
    #     'warmup_rounds': 5,
    # }

    # experiments = [
    #     {
    #         'results_name': 'TEST_RUN',
    #         'description': 'Test run with default settings',
    #         'body_part': 'head',
    #         'set': 'closed',
    #         'seeds': [42],
    #     },
    # ]

    # for exp in experiments:
    #     print(f"Starting {exp['results_name']}...")
    #     for seed in exp['seeds']:
    #         exp_config = config.copy()
    #         exp_config.update(exp)
    #         exp_config['results_name'] = f"{exp['results_name']}_SEED_{seed}"
    #         exp_config['seed'] = seed
    #         exp_config.pop('seeds', None)
    #         main(exp_config)
    # Default arguments are used, i.e. pretrained=True, so the ImageNet
    # weights are requested (this may trigger a download on first use).
    convnext = ConvNeXtBackbone()
    print(convnext)

  from .autonotebook import tqdm as notebook_tqdm


ConvNeXtBackbone(
  (backbone): ConvNeXt(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
        (1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
      )
      (1): Sequential(
        (0): CNBlock(
          (block): Sequential(
            (0): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
            (1): Permute()
            (2): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
            (3): Linear(in_features=128, out_features=512, bias=True)
            (4): GELU(approximate='none')
            (5): Linear(in_features=512, out_features=128, bias=True)
            (6): Permute()
          )
          (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        )
        (1): CNBlock(
          (block): Sequential(
            (0): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
            (1): Permute()
            (2)