In [None]:
# ============================================================
#  FINAL SCRIPT: DENSE PAIRING + REGRESSION LOSS (8-DIM)
# ============================================================

from pathlib import Path
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

# ------------------------------------------------------------
#  1. Dataset (Expects Pairs)
# ------------------------------------------------------------

class WearPairDataset(Dataset):
    """
    Dataset of image pairs with their wear vectors.

    Sample format: (img1_path, img2_path, w1_vec, w2_vec)

    Args:
        samples: Indexable collection of 4-tuples in the format above.
        transform: Callable applied to each RGB PIL image.

    Each item is ``(img1, img2, w1, w2)`` where the images are whatever
    ``transform`` returns and the wear vectors are float32 tensors.
    """
    def __init__(self, samples, transform):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def _load_image(self, path):
        """Open ``path``, convert to RGB, apply the transform, close the file."""
        # The context manager releases the file handle deterministically.
        # Without it, lazily loaded / multi-frame files can keep their
        # handle open, which adds up quickly across DataLoader workers.
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def __getitem__(self, idx):
        path1, path2, w1, w2 = self.samples[idx]

        return (
            self._load_image(path1),
            self._load_image(path2),
            torch.tensor(w1, dtype=torch.float32),
            torch.tensor(w2, dtype=torch.float32),
        )

# ------------------------------------------------------------
#  2. Model (One Backbone, Multiple Heads)
# ------------------------------------------------------------

class MultiHeadWearNet(nn.Module):
    """
    Shared ResNet-50 feature extractor with several independent linear heads.

    Args:
        num_heads: Number of linear heads attached to the backbone features.
        embed_dim: Output dimension of every head.
        freeze_backbone: If True, backbone parameters get
            ``requires_grad = False`` and the backbone is kept in eval mode
            even when the module as a whole is switched to training mode.
    """
    def __init__(self, num_heads=3, embed_dim=8, freeze_backbone=True):
        super().__init__()

        # Shared Backbone (classification layer replaced by a pass-through)
        backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
        feat_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()

        self._freeze_backbone = bool(freeze_backbone)
        if self._freeze_backbone:
            for p in backbone.parameters():
                p.requires_grad = False
            # requires_grad=False alone does not freeze BatchNorm running
            # statistics; those only stop updating in eval mode.
            backbone.eval()

        self.backbone = backbone

        # Independent Heads (Linear Layers)
        # We do NOT normalize here because we need the output magnitude
        # to match the physical wear distance (mm).
        self.heads = nn.ModuleList([
            nn.Linear(feat_dim, embed_dim) for _ in range(num_heads)
        ])

    def train(self, mode=True):
        """Set training mode, keeping a frozen backbone in eval mode.

        Without this override ``model.train()` would put the backbone's
        BatchNorm layers back into training mode, so the supposedly frozen
        feature extractor would drift as its running statistics update.
        """
        super().train(mode)
        if self._freeze_backbone:
            self.backbone.eval()
        return self

    def forward(self, x, head_idx=None):
        """Return the embedding of one head, or a list with one per head.

        Args:
            x: Input batch passed to the backbone.
            head_idx: Index of a single head to evaluate. When None, every
                head is evaluated on the same backbone features.
        """
        features = self.backbone(x)

        # If specific head requested (during inference or single-head tasks)
        if head_idx is not None:
            return self.heads[head_idx](features)

        # Return all for training loop
        return [head(features) for head in self.heads]

# ------------------------------------------------------------
#  3. Loss (Equation 1: Regression)
# ------------------------------------------------------------

class WearDistanceLoss(nn.Module):
    """
    Minimizes: ( ||E1 - E2|| - |w1 - w2| )^2

    The embedding distance is the L2 norm over dim 1 of ``E1 - E2``; the
    target is the absolute difference of the wear values. The result is the
    mean squared gap between the two over the batch.
    """
    def forward(self, E1, E2, w1, w2):
        """Return the scalar loss for a batch of embedding / wear pairs.

        Args:
            E1, E2: Embeddings of the two images; the norm is taken over dim 1.
            w1, w2: Wear values of the two images, one per sample. A trailing
                singleton dimension (e.g. ``[B, 1]``) is accepted.
        """
        # Euclidean distance in embedding space
        d_embed = torch.linalg.vector_norm(E1 - E2, ord=2, dim=1)

        # Absolute difference in wear (ground truth)
        d_wear = torch.abs(w1 - w2)

        # A [B, 1] target would broadcast against the [B] distance into a
        # [B, B] matrix and silently yield a wrong loss; drop that axis.
        if d_wear.dim() == d_embed.dim() + 1 and d_wear.shape[-1] == 1:
            d_wear = d_wear.squeeze(-1)

        # MSE
        return ((d_embed - d_wear) ** 2).mean()

# ------------------------------------------------------------
#  4. Setup & Dense Pairing Logic
# ------------------------------------------------------------

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Preprocessing applied to every image, in order: resize to 256, center-crop
# to 224, convert to a tensor, then normalize each channel.
_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
]
transform = transforms.Compose(_preprocessing_steps)

# --- USER DATA INPUT AREA ---
# You need to parse your folders to fill these two lists.
# sequence_paths: List[List[str]] -> [ [seq1_img0, seq1_img1...], [seq2_img0...] ]
# sequence_wears: List[List[List[float]]] -> [ [[0.1, 0.2, 0.3], ...], ... ]
# (Make sure wears are 3D vectors: [Flank, Adhesion, F+A])

sequence_paths = []
sequence_wears = []
# Example filler (Remove this when you add your real loading logic):
# sequence_paths = [["p1.jpg", "p2.jpg", "p3.jpg"]]
# sequence_wears = [[[0.1, 0.0, 0.1], [0.2, 0.0, 0.2], [0.5, 0.1, 0.6]]]


def build_dense_pairs(sequence_paths, sequence_wears):
    """Return every unique within-sequence pair (all vs all, no self-pairs).

    Args:
        sequence_paths: One list of image paths per sequence.
        sequence_wears: One list of wear vectors per sequence, aligned
            element-by-element with ``sequence_paths``.

    Returns:
        List of ``(path_i, path_j, wear_i, wear_j)`` tuples with ``i < j``,
        in sequence order.

    Raises:
        ValueError: If the number of sequences differs, or a sequence has a
            different number of paths and wear vectors. Such mismatches
            would otherwise be truncated silently and mislabel the data.
    """
    if len(sequence_paths) != len(sequence_wears):
        raise ValueError(
            f"Got {len(sequence_paths)} path sequences but "
            f"{len(sequence_wears)} wear sequences."
        )

    pairs = []
    for seq_idx, (paths, wears) in enumerate(zip(sequence_paths, sequence_wears)):
        if len(paths) != len(wears):
            raise ValueError(
                f"Sequence {seq_idx}: {len(paths)} images but "
                f"{len(wears)} wear vectors."
            )
        frames = list(zip(paths, wears))
        for i, (path_i, wear_i) in enumerate(frames):
            # Only later frames: keeps pairs unique and excludes self-pairs.
            for path_j, wear_j in frames[i + 1:]:
                pairs.append((path_i, path_j, wear_i, wear_j))
    return pairs


print("Generating Dense Pairs...")
# DENSE PAIRING (All vs All within sequence)
samples = build_dense_pairs(sequence_paths, sequence_wears)

print(f"Generated {len(samples)} pairs.")

# ------------------------------------------------------------
#  5. Training Loop
# ------------------------------------------------------------

if samples:
    dataset = WearPairDataset(samples, transform)
    loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)

    # One log label per head. Head k is trained against column k of the wear
    # vectors, so the order here must match the wear-vector layout.
    head_names = ["Flank", "Adh", "FA"]
    num_heads = len(head_names)
    freeze_backbone = True

    # Initialize Model (One Backbone, one 8-Dim head per wear component)
    model = MultiHeadWearNet(
        num_heads=num_heads, embed_dim=8, freeze_backbone=freeze_backbone
    ).to(device)
    # Only the heads are optimized.
    optimizer = torch.optim.Adam(model.heads.parameters(), lr=1e-3)
    criterion = WearDistanceLoss()

    epochs = 20

    for epoch in range(epochs):
        model.train()
        if freeze_backbone:
            # train() also switches the backbone's BatchNorm layers to
            # training mode, which would update their running statistics
            # even though the weights are frozen. Keep the backbone in eval.
            model.backbone.eval()

        total_loss = [0.0] * num_heads
        batch_count = 0

        for img1, img2, w1, w2 in loader:
            img1, img2 = img1.to(device), img2.to(device)
            w1, w2 = w1.to(device), w2.to(device)  # one wear column per head

            # 1. Forward Pass: the backbone runs once per image of the pair
            # and its features are shared by all heads.
            out1 = model(img1)  # list with one embedding tensor per head
            out2 = model(img2)

            # 2. Compute Loss per Head: embedding k vs. wear column k
            head_losses = [
                criterion(emb1, emb2, w1[:, k], w2[:, k])
                for k, (emb1, emb2) in enumerate(zip(out1, out2))
            ]
            batch_loss = torch.stack(head_losses).sum()

            # 3. Backprop (Sum of losses)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            for k, loss_k in enumerate(head_losses):
                total_loss[k] += loss_k.item()
            batch_count += 1

        # Logging: per-head loss averaged over the batches of this epoch
        avg_losses = [l / batch_count for l in total_loss] if batch_count > 0 else [0] * num_heads
        print(
            f"Epoch {epoch+1:02d} | "
            + " | ".join(
                f"L_{name}: {avg:.4f}" for name, avg in zip(head_names, avg_losses)
            )
        )

else:
    print("No samples found. Please populate sequence_paths and sequence_wears.")