# In [None]:
# ============================================================
#  COMPLETE SCRIPT — 3 INDEPENDENT WEAR REGRESSION NETWORKS
# ============================================================

from pathlib import Path
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights

# ------------------------------------------------------------
#  Dataset
# ------------------------------------------------------------

class WearPairDataset(Dataset):
    """Pairs of (reference, current) images with their per-image wear vectors.

    Each entry of ``samples`` is a tuple ``(ref_path, curr_path, w_ref, w_curr)``:
    ``ref_path`` is the reference (first) image, ``curr_path`` the image compared
    against it, and ``w_ref`` / ``w_curr`` hold the wear values of those two
    images, one value per wear type (3 in this script).

    ``__getitem__`` returns ``(ref_img, curr_img, w_ref, w_curr)``: both images
    passed through ``transform``, and both wear vectors as float32 tensors.
    """

    def __init__(self, samples, transform):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` as an RGB PIL image and close the underlying file."""
        # Image.open is lazy and keeps the file open until the image is closed;
        # convert() loads the pixels and returns a new image, so the source file
        # can be released immediately instead of leaking a handle per sample.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        ref_path, curr_path, w_ref, w_curr = self.samples[idx]

        ref_img = self.transform(self._load_rgb(ref_path))
        curr_img = self.transform(self._load_rgb(curr_path))

        return (
            ref_img,
            curr_img,
            torch.tensor(w_ref, dtype=torch.float32),
            torch.tensor(w_curr, dtype=torch.float32),
        )

# ------------------------------------------------------------
#  Model
# ------------------------------------------------------------

class WearNet(nn.Module):
    """Pretrained ResNet-50 feature extractor followed by a linear embedding head.

    Args:
        embed_dim: size of the output embedding.
        freeze_backbone: if True, the backbone's parameters get
            ``requires_grad=False`` and the backbone is kept in eval mode (also
            across later ``train()`` calls), so only ``self.embedding`` learns.

    ``forward(x)`` returns a ``[B, embed_dim]`` embedding for an image batch.
    """

    def __init__(self, embed_dim=128, freeze_backbone=True):
        super().__init__()

        backbone = resnet50(weights=ResNet50_Weights.DEFAULT)
        feat_dim = backbone.fc.in_features  # 2048 for ResNet-50
        backbone.fc = nn.Identity()         # expose pooled features, not class logits

        self.freeze_backbone = freeze_backbone
        if freeze_backbone:
            backbone.requires_grad_(False)
            # requires_grad=False does not freeze BatchNorm: in train mode BN uses
            # batch statistics and overwrites its pretrained running mean/var on
            # every forward pass. Eval mode keeps the pretrained statistics fixed.
            backbone.eval()

        self.backbone = backbone
        self.embedding = nn.Linear(feat_dim, embed_dim)

    def train(self, mode=True):
        """Set train/eval mode on the module, keeping a frozen backbone in eval."""
        super().train(mode)
        if self.freeze_backbone:
            self.backbone.eval()
        return self

    def forward(self, x):
        feat = self.backbone(x)        # [B, feat_dim]
        embed = self.embedding(feat)   # [B, embed_dim]
        return embed

# ------------------------------------------------------------
#  Loss
# ------------------------------------------------------------

class WearDistanceLoss(nn.Module):
    """Regress embedding distance onto wear difference.

    Penalises the squared gap between ``||E_curr - E_ref||`` (row-wise L2 norm)
    and ``|w_curr - w_ref|``, averaged over the batch:

        ||E_curr - E_ref|| ≈ |w_curr - w_ref|

    ``E_ref`` / ``E_curr`` are ``[B, D]`` embeddings. ``w_ref`` / ``w_curr`` are
    expected to be 1-D ``[B]``; a ``[B, 1]`` tensor would broadcast against the
    ``[B]`` distances into a ``[B, B]`` matrix instead of raising.
    """

    def forward(self, E_ref, E_curr, w_ref, w_curr):
        embed_gap = (E_curr - E_ref).norm(dim=1)
        wear_gap = (w_curr - w_ref).abs()
        residual = embed_gap - wear_gap
        return residual.pow(2).mean()

# ------------------------------------------------------------
#  Setup
# ------------------------------------------------------------

device = "cuda" if torch.cuda.is_available() else "cpu"

# Resize/crop to 224x224 and normalise with the ImageNet channel mean/std,
# matching the preprocessing family of the pretrained ResNet-50 backbone.
# NOTE(review): torchvision's DEFAULT ResNet-50 weights (IMAGENET1K_V2) were
# evaluated with resize 232 rather than 256; ResNet50_Weights.DEFAULT.transforms()
# would reproduce that preprocessing exactly — consider switching.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

# samples = [
#   (ref_img_path, curr_img_path, [w1, w2, w3], [w1, w2, w3]),
#   ...
# ]
samples = []  # <-- fill this

# Fail early with an actionable message: an empty dataset would otherwise
# surface as an opaque sampler error from shuffle=True, or as a division by
# zero when averaging the per-epoch loss.
if not samples:
    raise ValueError(
        "samples is empty: fill it with "
        "(ref_img_path, curr_img_path, [w1, w2, w3], [w1, w2, w3]) tuples"
    )

dataset = WearPairDataset(samples, transform)
# Mini-batch training; batch size is worth tuning (e.g. 1, 8, 16).
loader = DataLoader(dataset, batch_size=16, shuffle=True)

# ------------------------------------------------------------
#  Three independent networks
# ------------------------------------------------------------

# One independent network per wear type: nets[k] is trained on column k of the
# wear vectors.
nets = [WearNet(embed_dim=128, freeze_backbone=True).to(device) for _ in range(3)]

# Optimise exactly the parameters that require grad. With a frozen backbone this
# is just the embedding head, same as optimising net.embedding directly. If
# freeze_backbone is switched to False, the backbone is picked up too instead
# of being silently left untrained.
optimizers = [
    torch.optim.Adam([p for p in net.parameters() if p.requires_grad], lr=1e-3)
    for net in nets
]

criterion = WearDistanceLoss()

# ------------------------------------------------------------
#  Training Loop
# ------------------------------------------------------------

epochs = 20

for epoch in range(epochs):
    # Explicitly enter training mode each epoch, in case the nets were switched
    # to eval() elsewhere (e.g. by a validation cell run in between).
    for net in nets:
        net.train()

    total_loss = [0.0] * len(nets)
    n_batches = 0

    for ref_img, curr_img, w_ref, w_curr in loader:
        ref_img = ref_img.to(device)
        curr_img = curr_img.to(device)
        w_ref = w_ref.to(device)    # [B, 3]: column k is the wear value for nets[k]
        w_curr = w_curr.to(device)
        n_batches += 1

        # Each net gets its own loss and optimizer step. Only the parameters
        # registered in optimizers[k] are updated.
        # NOTE(review): every net runs its own full ResNet-50 backbone on the same
        # images (3x backbone compute per batch). If the backbones are frozen and
        # in eval mode, they are identical, so features could be computed once
        # and shared.
        for k, (net, opt) in enumerate(zip(nets, optimizers)):
            E_ref = net(ref_img)
            E_curr = net(curr_img)

            loss = criterion(E_ref, E_curr, w_ref[:, k], w_curr[:, k])

            opt.zero_grad()
            loss.backward()
            opt.step()

            total_loss[k] += loss.item()

    # Mean per-batch loss for each net; max() guards against an empty loader.
    denom = max(n_batches, 1)
    print(
        f"Epoch {epoch} | "
        + " | ".join(
            f"Loss{k + 1}: {k_loss / denom:.4f}"
            for k, k_loss in enumerate(total_loss)
        )
    )

# ============================================================
#  END
# ============================================================
