# In [None]:

### Inference-only script, intended to be run in a Jupyter notebook; download the model checkpoints from the ./models folder first.

import os
import re
import random
import numpy as np
import pandas as pd
from PIL import Image, ImageFile
from tqdm.auto import tqdm
import glob

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torchvision import models, transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from ete3 import Tree

# Let PIL decode image files that are truncated instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Seed passed to seed_everything() by the entry point.
SEED = 42

## Paths: replace these with the locations used on your machine.
USER_CHECKPOINTS_DIR = "./models/results/final-fold/kfold_checkpoints"
TRAIN_ANNOTATIONS_CSV = "./data/train.csv"
TEST_ANNOTATIONS_CSV = "./data/test.csv"
TREE_FILE_PATH = "./data/tree.nh"
# These two values mirror the default arguments of
# get_full_image_path_from_roi_path().
ROI_FOLDER_NAME = "rois"
FULL_IMAGE_FOLDER_NAME = "images"

# Fallback score for checkpoints whose file name carries no val_hdist value.
MAX_DISTANCE = 12
# Base batch size; the test DataLoader uses twice this value.
BATCH_SIZE = 32
# Half of the available cores, or 2 when the core count cannot be determined.
NUM_WORKERS = (os.cpu_count() or 4) // 2
# Target size of the ROI crops after resizing.
ROI_IMAGE_SIZE = (224, 224)
# Full images are resized to a square with this edge length.
TARGET_FULL_IMAGE_EDGE = 384

# Rotation angle, in degrees, of the second test-time-augmentation view.
SLIGHT_ROTATION_ANGLE = 10

def seed_everything(seed: int = 42) -> None:
    """Seed every random number generator used by the script.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    switches cuDNN to deterministic mode with benchmarking disabled, and
    stores the seed in the ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Value handed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

    os.environ["PYTHONHASHSEED"] = str(seed)
    print(f"[seed_everything] seed={seed} applied.")

def seed_worker(worker_id: int):
    """Re-seed NumPy and ``random`` from the current PyTorch seed.

    Written to be passed as ``worker_init_fn`` to a ``DataLoader``; the
    ``worker_id`` argument is accepted but not used.
    """
    # Reduce the torch seed modulo 2**32 before handing it to the other RNGs.
    derived_seed = torch.initial_seed() % (1 << 32)
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

def get_full_image_path_from_roi_path(roi_path: str, roi_folder_name: str = "rois", full_image_folder_name: str = "images") -> str:
    """Derive the full-image path that corresponds to an ROI crop path.

    The image id is the part of the ROI file name before the first
    underscore, after the extension and any leading underscores have been
    removed. The full image is ``<image_id>.png``, located in the ROI's
    directory with the first path component equal to ``roi_folder_name``
    replaced by ``full_image_folder_name``.

    Args:
        roi_path: Path of the ROI crop, using ``os.sep`` as separator.
        roi_folder_name: Directory component to be replaced.
        full_image_folder_name: Directory component to substitute in.

    Returns:
        The path of the full image.

    Raises:
        ValueError: If no image id can be extracted from the file name, or
            if ``roi_folder_name`` is not a component of the ROI directory.
    """
    roi_dir, roi_filename = os.path.split(roi_path)

    # After stripping leading underscores the first segment is empty only
    # when the whole stem is empty, so there is no second segment to fall
    # back on; report that case as a ValueError instead of an IndexError.
    stem = os.path.splitext(roi_filename)[0].lstrip('_')
    image_id_str = stem.split('_', 1)[0]
    if not image_id_str:
        raise ValueError(f"Cannot extract image_id from ROI filename '{roi_filename}'.")

    path_parts = roi_dir.split(os.sep)
    try:
        idx_to_replace = path_parts.index(roi_folder_name)
    except ValueError:
        raise ValueError(f"ROI folder name '{roi_folder_name}' not found in ROI path '{roi_path}'.") from None
    path_parts[idx_to_replace] = full_image_folder_name

    return os.path.join(os.sep.join(path_parts), f"{image_id_str}.png")

class FathomNetDataset(Dataset):
    """Dataset pairing each ROI crop with the full image it belongs to.

    Every item is ``((roi_image, full_image), original_id)``. The id is taken
    from the ``annotation_id`` column, falling back to ``id`` and then to
    ``path`` when the earlier columns are absent.
    """

    def __init__(self, df_subset, transform_roi, transform_full, label_encoder,
                 roi_folder_name=None, full_image_folder_name=None):
        """Index the ROI paths and derive the matching full-image paths.

        Args:
            df_subset: DataFrame with a ``path`` column of ROI image paths.
            transform_roi: Callable applied to each ROI image, or a falsy
                value to skip transformation.
            transform_full: Callable applied to each full image, or a falsy
                value to skip transformation.
            label_encoder: Stored on the instance; not used by this class.
            roi_folder_name: Directory name of the ROI crops. Defaults to
                the module-level ``ROI_FOLDER_NAME``.
            full_image_folder_name: Directory name of the full images.
                Defaults to the module-level ``FULL_IMAGE_FOLDER_NAME``.

        Raises:
            ValueError: If a full-image path cannot be derived from an ROI
                path.
        """
        self.data = df_subset.reset_index(drop=True)
        self.roi_image_paths = self.data['path'].tolist()
        self.transform_roi = transform_roi
        self.transform_full = transform_full
        self.label_encoder = label_encoder
        self.original_ids = self.data.get('annotation_id', self.data.get('id', self.data['path'])).tolist()

        # Honour the configured folder names rather than silently relying on
        # the path helper's own defaults.
        if roi_folder_name is None:
            roi_folder_name = ROI_FOLDER_NAME
        if full_image_folder_name is None:
            full_image_folder_name = FULL_IMAGE_FOLDER_NAME
        self.full_image_paths = [
            get_full_image_path_from_roi_path(
                p,
                roi_folder_name=roi_folder_name,
                full_image_folder_name=full_image_folder_name,
            )
            for p in tqdm(self.roi_image_paths, desc="Deriving full image paths")
        ]

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file promptly."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        """Return the number of ROI samples."""
        return len(self.roi_image_paths)

    def __getitem__(self, idx):
        """Return ``((roi_image, full_image), original_id)`` for ``idx``."""
        roi_img = self._load_rgb(self.roi_image_paths[idx])
        full_img = self._load_rgb(self.full_image_paths[idx])
        if self.transform_roi:
            roi_img = self.transform_roi(roi_img)
        if self.transform_full:
            full_img = self.transform_full(full_img)
        return (roi_img, full_img), self.original_ids[idx]

class FathomNetHierarchicalClassifier(pl.LightningModule):
    """Two-branch image classifier over an ROI crop and its full image.

    An EfficientNetV2-M backbone encodes the ROI and an EfficientNetV2-S
    backbone encodes the full image; both have their classification layer
    replaced by an identity. The two feature vectors are concatenated and
    passed through a small MLP head that outputs class logits.
    """

    def __init__(self, num_classes, dropout1=0.4, dropout2=0.3, hidden_dim=512, pretrained=True, **kwargs):
        """Build both backbones and the classification head.

        Args:
            num_classes: Size of the output logit vector.
            dropout1: Dropout probability applied to the combined features.
            dropout2: Dropout probability applied after the hidden layer.
            hidden_dim: Width of the hidden layer of the head.
            pretrained: When true (the default), both backbones start from
                torchvision's default pretrained weights. Pass ``False`` to
                build them without loading those weights, e.g. when a
                checkpoint is about to replace every parameter anyway.
            **kwargs: Accepted so that extra hyperparameters can be passed
                without error; not used by the constructor itself.
        """
        super().__init__()
        self.save_hyperparameters()

        roi_weights = models.EfficientNet_V2_M_Weights.DEFAULT if pretrained else None
        roi_backbone = models.efficientnet_v2_m(weights=roi_weights)
        # Read the feature width before the classifier is replaced.
        roi_feat_dim = roi_backbone.classifier[1].in_features
        roi_backbone.classifier = nn.Identity()
        self.roi_backbone = roi_backbone

        full_weights = models.EfficientNet_V2_S_Weights.DEFAULT if pretrained else None
        full_image_backbone = models.efficientnet_v2_s(weights=full_weights)
        full_feat_dim = full_image_backbone.classifier[1].in_features
        full_image_backbone.classifier = nn.Identity()
        self.full_image_backbone = full_image_backbone

        combined_feat_dim = roi_feat_dim + full_feat_dim
        self.classifier_head = nn.Sequential(
            nn.Dropout(p=dropout1),
            nn.Linear(combined_feat_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout2),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x_tuple):
        """Return class logits for a pair ``(roi_batch, full_image_batch)``.

        Args:
            x_tuple: Two-element sequence holding the ROI batch and the
                full-image batch, in that order.

        Returns:
            The output of the classification head for the concatenated
            backbone features.
        """
        x_roi, x_full = x_tuple
        roi_features = self.roi_backbone(x_roi)
        full_features = self.full_image_backbone(x_full)
        # Concatenate along the feature dimension.
        combined_features = torch.cat((roi_features, full_features), dim=1)
        return self.classifier_head(combined_features)

if __name__ == "__main__":
    # Entry point: run every checkpoint found under USER_CHECKPOINTS_DIR on
    # the test set with 2-way test-time augmentation (original view plus a
    # slightly rotated one), combine the per-model probabilities with weights
    # inversely proportional to each model's validation score, and write the
    # predictions to a submission CSV.
    seed_everything(SEED)
    print(f"Using 2-way TTA with slight rotation angle: {SLIGHT_ROTATION_ANGLE} degrees")
    print("Using weighted arithmetic mean for ensembling.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda" and hasattr(torch, 'set_float32_matmul_precision'):
        torch.set_float32_matmul_precision('high')
    print(f"Using device: {device}")

    print("\n" + "─"*25 + " GLOBAL SETUP PHASE " + "─"*25)
    # The label encoder is fitted on the training labels so that predicted
    # indices can be decoded back to label names.
    train_df = pd.read_csv(TRAIN_ANNOTATIONS_CSV)
    global_label_encoder = LabelEncoder().fit(train_df['label'].unique())
    num_classes = len(global_label_encoder.classes_)
    print(f"Number of classes: {num_classes}")

    checkpoint_files = sorted(glob.glob(os.path.join(USER_CHECKPOINTS_DIR, "**", "*.ckpt"), recursive=True))
    if not checkpoint_files:
        raise FileNotFoundError(f"No .ckpt model files found in '{USER_CHECKPOINTS_DIR}'.")
    print(f"\nFound {len(checkpoint_files)} model(s). Parsing scores for weighted ensembling...")

    # The validation score is read from the checkpoint file name; files
    # without one fall back to MAX_DISTANCE.
    score_pattern = re.compile(r"val_hdist=([0-9]+\.?[0-9]*)")
    models_with_scores = []
    for ckpt_path in checkpoint_files:
        ckpt_name = os.path.basename(ckpt_path)
        match = score_pattern.search(ckpt_name)
        score = float(match.group(1)) if match else float(MAX_DISTANCE)
        models_with_scores.append({'path': ckpt_path, 'score': score})
        print(f"  - Parsed score {score:.4f} from: {ckpt_name}")

    # Inverse-score weighting: a lower score yields a larger weight. The
    # epsilon guards against division by zero for a score of exactly 0.
    raw_scores = np.array([m['score'] for m in models_with_scores])
    weights_np = 1.0 / (raw_scores + 1e-6)
    # Kept on the CPU because the per-model probabilities are gathered there.
    model_weights = torch.tensor(weights_np / np.sum(weights_np), dtype=torch.float32)
    print("\nCalculated Model Weights:")
    for m, weight in zip(models_with_scores, model_weights.tolist()):
        print(f"  Model: {os.path.basename(m['path'])}, Weight: {weight:.4f}")

    print("\n" + "─"*25 + " ENSEMBLE INFERENCE PHASE " + "─"*25)
    val_test_transforms_roi = transforms.Compose([
        transforms.Resize(ROI_IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    full_image_transforms = transforms.Compose([
        transforms.Resize((TARGET_FULL_IMAGE_EDGE, TARGET_FULL_IMAGE_EDGE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    test_df = pd.read_csv(TEST_ANNOTATIONS_CSV)
    test_ds = FathomNetDataset(test_df, val_test_transforms_roi, full_image_transforms, global_label_encoder)
    test_loader = DataLoader(
        test_ds,
        batch_size=BATCH_SIZE*2,
        shuffle=False,
        num_workers=NUM_WORKERS,
        # Pinned memory only helps host-to-GPU copies; requesting it on a
        # CPU-only run just triggers a warning.
        pin_memory=(device.type == "cuda"),
        worker_init_fn=seed_worker,
    )
    print(f"✅ Test DataLoader initialized with {len(test_ds)} samples.")

    all_model_probs = []
    all_original_ids = []
    for model_idx, model_info in enumerate(models_with_scores):
        print(f"\nInferring with model {model_idx + 1}/{len(models_with_scores)}: {os.path.basename(model_info['path'])}")
        model = FathomNetHierarchicalClassifier.load_from_checkpoint(
            checkpoint_path=model_info['path'], map_location="cpu"
        ).to(device).eval()
        model_accumulated_logits = []
        with torch.no_grad():
            for (roi_images, full_images), batch_ids in tqdm(test_loader, desc=f"Predict (M {model_idx+1})"):
                # The loader is not shuffled, so ids only need collecting once.
                if model_idx == 0:
                    # Numeric ids are collated into a tensor; unwrap them to
                    # plain Python values so the CSV does not end up holding
                    # "tensor(...)" strings.
                    if torch.is_tensor(batch_ids):
                        all_original_ids.extend(batch_ids.tolist())
                    else:
                        all_original_ids.extend(list(batch_ids))
                roi_images, full_images = roi_images.to(device), full_images.to(device)
                logits1 = model((roi_images, full_images))
                # Second TTA view: both inputs rotated by the same angle.
                roi_rot = TF.rotate(roi_images, angle=SLIGHT_ROTATION_ANGLE)
                full_rot = TF.rotate(full_images, angle=SLIGHT_ROTATION_ANGLE)
                logits2 = model((roi_rot, full_rot))
                batch_avg_tta_logits = (logits1 + logits2) / 2.0
                model_accumulated_logits.append(batch_avg_tta_logits.cpu())
        # The softmax is applied once per model, on the TTA-averaged logits.
        model_probs = F.softmax(torch.cat(model_accumulated_logits), dim=1)
        all_model_probs.append(model_probs)
        # Release the model before the next checkpoint is loaded.
        del model
        if device.type == 'cuda':
            torch.cuda.empty_cache()

    print("\nApplying weighted arithmetic mean to ensembled probabilities.")
    stacked_probs = torch.stack(all_model_probs)
    # Broadcast one weight per model over the (sample, class) dimensions.
    ensembled_probs = torch.sum(stacked_probs * model_weights.view(-1, 1, 1), dim=0)
    final_preds_indices = torch.argmax(ensembled_probs, dim=1).numpy()
    decoded_concepts = global_label_encoder.inverse_transform(final_preds_indices)
    submission_df = pd.DataFrame({'annotation_id': all_original_ids, 'concept_name': decoded_concepts})
    num_ensembled_models = len(models_with_scores)
    final_submission_path = f"submission_weighted_arithmetic_{num_ensembled_models}models_TTA-2way_rotate{SLIGHT_ROTATION_ANGLE}.csv"
    submission_df.to_csv(final_submission_path, index=False)
    print(f"\n✅ Final submission written to {final_submission_path}; preview:")
    print(submission_df.head())
    print("\n🏁 Inference script execution finished.")
