# Stage 2: Pseudo-Label Generation

This notebook utilizes the "Teacher" models trained in Stage 1 to perform inference on the unlabeled dataset (`UNLABELED_DIR`).

**Methodology:**
1. Load the `best_model.pth` checkpoint from the most recent Stage 1 run of each selected Teacher architecture (e.g., ResNet18, ResNet50).
2. Predict class probabilities for every unlabeled image.
3. **Confidence Thresholding:** Only predictions whose top-class probability meets or exceeds the confidence threshold (`CONFIDENCE_THRESHOLD`) are accepted.
4. Accepted images are copied into per-teacher, per-class folders under `PSEUDO_LABEL_DIR` (e.g. `pseudo_labels/<teacher>/<class>/`), expanding the training set for the "Student" model in the next stage.

In [None]:
# ==============================================================================
# CELL 2: Imports & Setup
# ==============================================================================
import torch
import shutil
import models
import utils
import config
from pathlib import Path
from tqdm import tqdm 

# Seed torch's random number generators for reproducibility.
# NOTE: this does not influence the order in which unlabeled images are processed;
# that order is fixed by the image list built in the configuration cell.
torch.manual_seed(config.SEED)

In [None]:
# ==============================================================================
# CELL 3: Configuration of Models for Inference
# ==============================================================================

# Define which architectures to evaluate for pseudo-label generation.
# Each teacher writes its own output folder, so the resulting datasets can be compared.
TEACHER_MODELS_TO_USE = ['resnet18', 'resnet50']

# Minimum top-class softmax probability (inclusive) for a prediction to be accepted.
# A high value limits label noise propagated to the Student model.
CONFIDENCE_THRESHOLD = 0.80

print(f"[CONFIG] Selected models for pseudo-labeling: {TEACHER_MODELS_TO_USE}")
print(f"[CONFIG] Confidence Threshold: {CONFIDENCE_THRESHOLD}")

# Verify Unlabeled Data Directory (must be an actual directory, not just an existing path)
unlabeled_dir = config.UNLABELED_DIR
if not unlabeled_dir.is_dir():
    raise FileNotFoundError(f"Unlabeled data directory not found at: {unlabeled_dir}")

# Retrieve list of images. Path.glob yields entries in filesystem order, which is
# not guaranteed to be stable across machines, so sort for a reproducible
# processing order (and reproducible progress/warning output).
unlabeled_images = sorted(unlabeled_dir.glob("*.png"))
print(f"[DATA] Total unlabeled images found: {len(unlabeled_images)}")
if not unlabeled_images:
    print(f"[WARN] No .png files found in {unlabeled_dir}; nothing will be pseudo-labeled.")

# Class names are the sorted sub-folder names of the training directory.
# NOTE(review): presumably this matches the label indices used in Stage 1 training
# (sorted folder names, as torchvision ImageFolder does) — confirm against that code.
class_names = sorted(p.name for p in config.TRAIN_VAL_DIR.iterdir() if p.is_dir())
if not class_names:
    # Fail early: an empty class list would otherwise build a 0-class model later.
    raise RuntimeError(f"No class sub-directories found in: {config.TRAIN_VAL_DIR}")
print(f"[DATA] Detected Classes: {class_names}")

In [None]:
# ==============================================================================
# CELL 4: Pseudo-Labeling Pipeline
# ==============================================================================

def find_best_teacher_checkpoint(model_arch, results_dir=None):
    """
    Return the path to 'best_model.pth' inside the latest training run for an architecture.

    Run directories are the sub-directories of ``results_dir`` whose name starts with
    ``teacher_<model_arch>``. "Latest" means last in lexicographic name order; this is
    only chronological if run names embed a sortable timestamp or index (assumed —
    confirm against the Stage 1 naming scheme).

    Args:
        model_arch: Architecture key, e.g. 'resnet18'.
        results_dir: Directory to search. Defaults to ``config.ARTIFACTS_DIR``.

    Returns:
        ``Path`` to ``<run_dir>/best_model.pth`` (the file itself may not exist; the
        caller is expected to check), or ``None`` if ``results_dir`` does not exist or
        contains no matching run directory.
    """
    if results_dir is None:
        results_dir = config.ARTIFACTS_DIR
    results_dir = Path(results_dir)

    # Missing artifacts directory means "no checkpoint", not a crash from iterdir().
    if not results_dir.is_dir():
        return None

    prefix = f"teacher_{model_arch}"
    # Only directories are runs; stray files sharing the prefix (logs, archives)
    # must not be mistaken for the latest run.
    # NOTE(review): a plain prefix match means e.g. 'resnet' would also match
    # 'resnet18' runs — fine for the architecture keys used in this notebook.
    run_dirs = sorted(
        d for d in results_dir.iterdir() if d.is_dir() and d.name.startswith(prefix)
    )
    if not run_dirs:
        return None

    # Select the most recent run (last in sorted list)
    return run_dirs[-1] / "best_model.pth"

device = torch.device(config.DEVICE)

# Spatial size used when resizing inputs for ResNet teachers. It is the same for every
# image, so it is resolved once here instead of inside the per-image loop.
# Falls back to (100, 100) when config.py does not define IMG_SIZE.
resize_size = getattr(config, "IMG_SIZE", (100, 100))


def _load_checkpoint(path, map_location):
    """
    Load a saved state dict from ``path`` onto ``map_location``.

    Prefers ``torch.load(..., weights_only=True)``; PyTorch versions that do not accept
    the ``weights_only`` keyword raise ``TypeError``, in which case a plain load is used.
    """
    try:
        return torch.load(path, map_location=map_location, weights_only=True)
    except TypeError:
        # Fallback for older PyTorch versions lacking weights_only
        return torch.load(path, map_location=map_location)


# --- Main Inference Loop ---
for model_arch in TEACHER_MODELS_TO_USE:
    print(f"\n{'-'*60}")
    print(f"[INFO] PROCESSING MODEL: {model_arch.upper()}")
    print(f"{'-'*60}")

    # 1. Locate Checkpoint (None when no matching run directory was found)
    ckpt_path = find_best_teacher_checkpoint(model_arch)
    if not ckpt_path or not ckpt_path.exists():
        print(f"[ERROR] Checkpoint not found for {model_arch}. Skipping...")
        continue

    # Print the full path: the file name alone is always 'best_model.pth' and does
    # not tell which run was selected.
    print(f"[INFO] Loading weights from: {ckpt_path}")

    # 2. Load Model Architecture
    # resnet_use_pretrain=False because we are loading our own fine-tuned weights
    model = models.make_model(model_arch, len(class_names), resnet_use_pretrain=False)
    model.load_state_dict(_load_checkpoint(ckpt_path, device))
    model.to(device)
    model.eval()

    # 3. Prepare Output Directory: <PSEUDO_LABEL_DIR>/<arch>/<class>/
    # Previous output for this teacher is deleted so results from different runs never mix.
    output_dir = config.PSEUDO_LABEL_DIR / model_arch
    if output_dir.exists():
        print(f"[INFO] Cleaning existing output directory: {output_dir}")
        shutil.rmtree(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    for cname in class_names:
        (output_dir / cname).mkdir(exist_ok=True)

    # 4. Inference and File Sorting
    accepted_per_class = {cname: 0 for cname in class_names}
    failed_count = 0
    print(f"[INFO] Generating pseudo-labels (Threshold >= {CONFIDENCE_THRESHOLD})...")

    with torch.no_grad():
        for img_path in tqdm(unlabeled_images, desc=f"Inference {model_arch}"):
            try:
                # A. Load and Preprocess Image. load_png_gray is expected to return a
                # 3-D (C, H, W) tensor: bilinear interpolate needs a 4-D batch, hence
                # the unsqueeze/squeeze pair around it.
                img_tensor = utils.load_png_gray(img_path)

                # Resize for ResNet backbones.
                # NOTE(review): presumably mirrors Stage 1 training preprocessing — confirm.
                if "resnet" in model_arch:
                    img_tensor = torch.nn.functional.interpolate(
                        img_tensor.unsqueeze(0),
                        size=resize_size,
                        mode="bilinear",
                        align_corners=False,
                    ).squeeze(0)

                img_tensor = img_tensor.unsqueeze(0).to(device)  # Add batch dimension

                # B. Forward Pass
                logits = model(img_tensor)
                probs = torch.softmax(logits, dim=1)
                max_prob, pred_idx = torch.max(probs, dim=1)

                confidence = max_prob.item()
                pred_class = class_names[pred_idx.item()]

                # C. Threshold Check (inclusive): copy file to its predicted class folder
                if confidence >= CONFIDENCE_THRESHOLD:
                    dest_path = output_dir / pred_class / img_path.name
                    shutil.copy(str(img_path), str(dest_path))
                    accepted_per_class[pred_class] += 1

            except Exception as e:
                # Best-effort: one unreadable image must not abort the whole run, but
                # failures are counted and reported. tqdm.write keeps the bar intact.
                failed_count += 1
                tqdm.write(f"[WARN] Failed to process {img_path.name}: {e}")

    # 5. Final Report for this Model
    accepted_count = sum(accepted_per_class.values())
    acceptance_rate = accepted_count / len(unlabeled_images) if unlabeled_images else 0
    print(f"\n[RESULT] {model_arch.upper()}: Accepted {accepted_count}/{len(unlabeled_images)} images ({acceptance_rate:.1%}).")
    # Per-class breakdown makes class imbalance in the pseudo-labeled set visible.
    for cname, count in accepted_per_class.items():
        print(f"[RESULT]   {cname}: {count}")
    if failed_count:
        print(f"[WARN] {failed_count} image(s) could not be processed.")
    print(f"[INFO] Pseudo-labels saved to: {output_dir}")

print("\n[DONE] Pseudo-labeling stage completed.")