# In [None]:
# ==============================================================================
# EXPERIMENT DATA GENERATION SCRIPT
# ==============================================================================
# This script generates isolated datasets to support ablation studies and 
# comparative analysis (e.g., Threshold Sensitivity, FixMatch baseline).
# It does NOT modify the official pseudo-label dataset.

import torch
import shutil
import models
import utils
import config
from pathlib import Path
from tqdm import tqdm

# --- EXPERIMENTAL CONFIGURATION ---
# Teacher architecture whose checkpoint is used to generate the pseudo-labels.
TEACHER_ARCH = "resnet50"
# Torch device built from the project-wide setting in the config module.
DEVICE = torch.device(config.DEVICE)

# Locations taken from the project config module.
BASE_RESULTS_DIR = config.ARTIFACTS_DIR
UNLABELED_DIR = config.UNLABELED_DIR

# Class labels are the sub-directory names of the train/val folder, sorted by
# name; the inference loop indexes this list with the teacher's predicted index.
CLASS_NAMES = sorted(
    entry.name for entry in config.TRAIN_VAL_DIR.iterdir() if entry.is_dir()
)

# --- EXPERIMENTS DEFINITION ---
# Maps each experiment's output folder name to the minimum teacher confidence
# a prediction must reach for its image to be copied into that folder.
EXPERIMENTS_DICT = {
    # 3.3 - Threshold ablation: checks the 0.80 choice against a looser (0.70)
    # and a stricter (0.90) cut-off.
    "PSEUDO_EXP_3.3_Umb70": 0.70,
    "PSEUDO_EXP_3.3_Umb90": 0.90,
    # 3.4 - Automated baseline (no human verification): raw pseudo-labels at
    # the chosen 0.80 threshold.
    "PSEUDO_EXP_3.4_Auto80": 0.80,
    # 3.2 - High-confidence baseline simulating FixMatch: fully automated
    # approaches of that kind typically use a high threshold (0.95).
    "PSEUDO_EXP_3.2_FixMatch95": 0.95,
}

# Startup banner: scenario, teacher architecture and the experiment folders
# that are about to be generated, one item per line.
print(
    "--- STARTING EXPERIMENTAL DATA GENERATION ---",
    f"Scenario: {config.CURRENT_CARRIER}",
    f"Base Teacher: {TEACHER_ARCH}",
    f"Experiments: {list(EXPERIMENTS_DICT)}",
    sep="\n",
)

# 1. LOAD TEACHER MODEL (Inference Mode)
def find_best_teacher_checkpoint(model_arch):
    """Locate the checkpoint file for the specified teacher architecture.

    Scans ``BASE_RESULTS_DIR`` for run directories whose name contains
    ``teacher_<model_arch>`` and that actually hold a ``best_model.pth``
    file. Stray files and run folders without a checkpoint are ignored, so
    the returned path always exists at the time of the call.

    When several directories qualify, they are ordered by name and the last
    one wins (presumably the most recent run if folder names carry a
    timestamp -- verify against the training script).

    Args:
        model_arch: Architecture identifier embedded in the run folder name.

    Returns:
        Path to ``best_model.pth`` inside the selected run directory.

    Raises:
        FileNotFoundError: If no matching directory containing a
            ``best_model.pth`` file exists.
    """
    run_tag = f"teacher_{model_arch}"
    run_dirs = sorted(
        entry
        for entry in BASE_RESULTS_DIR.iterdir()
        if run_tag in entry.name
        and entry.is_dir()
        and (entry / "best_model.pth").is_file()
    )
    if not run_dirs:
        raise FileNotFoundError(f"Teacher checkpoint for {model_arch} not found.")
    return run_dirs[-1] / "best_model.pth"

# Resolve the checkpoint file and report which weights are about to be loaded.
ckpt_path = find_best_teacher_checkpoint(TEACHER_ARCH)
print(f"[INFO] Loading Teacher Weights: {ckpt_path.name}")

# Build the network with one output per class. ``resnet_use_pretrain=False``
# presumably skips pretrained initialisation since the checkpoint supplies
# every weight -- confirm against models.make_model.
model = models.make_model(TEACHER_ARCH, len(CLASS_NAMES), resnet_use_pretrain=False)

# Only the deserialisation call is guarded: the fallback exists for torch
# versions whose ``torch.load`` rejects the ``weights_only`` keyword. Keeping
# ``load_state_dict`` outside the ``try`` prevents a TypeError raised by the
# model itself from silently triggering a second, unrestricted load.
try:
    state_dict = torch.load(ckpt_path, map_location=DEVICE, weights_only=True)
except TypeError:
    state_dict = torch.load(ckpt_path, map_location=DEVICE)
model.load_state_dict(state_dict)
# Release the loaded tensors once they have been handed to the model.
del state_dict

# Move to the target device and switch to evaluation mode for inference.
model.to(DEVICE)
model.eval()

# 2. SETUP ISOLATED FOLDER STRUCTURE
# Everything generated here lives under its own root inside the results
# directory, separate from the official dataset.
root_exp_dir = BASE_RESULTS_DIR / "ABLATION_EXPERIMENTS"

# Start from a clean slate: remove whatever a previous run left behind.
if root_exp_dir.exists():
    print("[INFO] Cleaning previous experiment data...")
    shutil.rmtree(root_exp_dir)
root_exp_dir.mkdir()

# One folder per experiment, each holding one sub-folder per class.
for exp_name in EXPERIMENTS_DICT:
    experiment_dir = root_exp_dir / exp_name
    experiment_dir.mkdir()
    for class_name in CLASS_NAMES:
        experiment_dir.joinpath(class_name).mkdir()
    print(f" > Created directory: {experiment_dir.name}")

# 3. INFERENCE AND DISTRIBUTION LOOP
# Sorted so that every run processes the images in the same order.
unlabeled_images = sorted(UNLABELED_DIR.glob("*.png"))
print(f"\n[INFO] Processing {len(unlabeled_images)} unlabeled images...")

# Number of images copied into each experiment folder, plus the number of
# images that were skipped because processing raised an error.
stats = dict.fromkeys(EXPERIMENTS_DICT, 0)
failed_images = 0

with torch.no_grad():
    for img_path in tqdm(unlabeled_images, desc="Distributing Data"):
        try:
            # Preprocessing: add a batch axis and resize to the 224x224 input
            # used for the teacher. The batched result is fed to the model
            # directly, so no squeeze/unsqueeze round trip is needed.
            # Assumes load_png_gray returns a (C, H, W) tensor, since bilinear
            # interpolation requires a 4-D batch -- TODO confirm.
            img_tensor = utils.load_png_gray(img_path)
            img_batch = torch.nn.functional.interpolate(
                img_tensor.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False
            ).to(DEVICE)

            # Teacher prediction: most likely class and its softmax probability.
            probs = torch.softmax(model(img_batch), dim=1)
            max_prob, pred_idx = torch.max(probs, dim=1)

            confidence = max_prob.item()
            pred_class = CLASS_NAMES[pred_idx.item()]

            # DISTRIBUTION LOGIC:
            # The same image is copied into every experiment whose threshold
            # it reaches; each experiment folder is independent of the others.
            for exp_name, threshold in EXPERIMENTS_DICT.items():
                if confidence >= threshold:
                    dest = root_exp_dir / exp_name / pred_class / img_path.name
                    shutil.copy(str(img_path), str(dest))
                    stats[exp_name] += 1

        except Exception as e:
            # Best effort: one unreadable image must not abort the whole run.
            # tqdm.write keeps the warning from corrupting the progress bar.
            failed_images += 1
            tqdm.write(f"[WARN] Error processing {img_path.name}: {e}")

# Surface the failure total so a systematic error is not mistaken for
# "no image reached any threshold".
if failed_images:
    print(f"[WARN] {failed_images} image(s) could not be processed and were skipped.")

# Final report: images copied into each experiment, listed with its threshold.
print("\n--- GENERATION SUMMARY ---")
for exp_name, threshold in EXPERIMENTS_DICT.items():
    print(f"[{exp_name}] (Threshold >={threshold}): {stats[exp_name]} images generated.")

# Closing messages: output location and a reminder about the official data.
print(f"\n[DONE] Experimental datasets created at: {root_exp_dir}")
print("[NOTE] The official 'pseudo_labels' directory remains unchanged.")