# Few-shot Embedding Adapter

In [1]:
import math
import random
from typing import Optional, Literal

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset

# Utils

In [3]:
def set_seed(seed: int = 42):
    """Seed Python's ``random`` module and PyTorch (CPU and every CUDA device).

    Args:
        seed: Value handed to each RNG seeding call, in the order
            ``random`` -> ``torch`` CPU -> all CUDA devices.
    """
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

def few_shot_indices(ds, k: int = 4, seed: int = 42):
    """Pick up to ``k`` random sample indices per class from ``ds``.

    ``ds`` is iterated once; each item must unpack as ``(sample, label)``.
    Classes are visited in order of first appearance, and each class's indices
    are shuffled with a private ``random.Random(seed)``. The result is therefore
    reproducible and leaves the global RNG untouched.

    Args:
        ds: Iterable of ``(sample, label)`` pairs (e.g. a map-style Dataset).
            Labels may be Python scalars or tensors.
        k: Maximum number of indices kept per class. Classes with fewer
            samples contribute all of theirs.
        seed: Seed for the local shuffler.

    Returns:
        List of selected dataset indices, grouped by class.

    Raises:
        ValueError: if ``k`` is negative.
    """
    if k < 0:
        # A negative k would silently mean "all but the last |k|" via slicing.
        raise ValueError(f"k must be non-negative, got {k}")
    rng = random.Random(seed)
    by_class = {}
    for idx, (_, c) in enumerate(ds):
        # Tensors hash by object identity, so tensor labels would put every
        # sample in its own bucket; convert them to hashable Python values.
        if isinstance(c, torch.Tensor):
            c = c.item() if c.numel() == 1 else tuple(c.flatten().tolist())
        by_class.setdefault(c, []).append(idx)
    chosen = []
    for idxs in by_class.values():
        rng.shuffle(idxs)
        chosen.extend(idxs[:k])
    return chosen

# Encoder helpers (frozen)

In [4]:
def freeze_encoder(model: nn.Module):
    """Turn off gradients for all of ``model``'s parameters and put it in eval mode.

    The module is modified in place and returned so the call can be chained.
    """
    model.requires_grad_(False)
    return model.eval()

@torch.no_grad()
def encode_images(encoder: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Embed ``x`` with ``encoder`` and L2-normalize along the last axis.

    Uses ``encoder.encode_image`` when the encoder defines it (CLIP-style
    models) and otherwise calls the module directly. Runs without autograd.
    """
    if hasattr(encoder, "encode_image"):
        raw = encoder.encode_image(x)
    else:
        raw = encoder(x)
    return F.normalize(raw, dim=-1)  # [B, D], unit-norm rows

def feature_dim(encoder: nn.Module, sample: torch.Tensor) -> int:
    """Return the size of the last axis of ``encoder``'s output for ``sample``.

    Performs one gradient-free forward pass, through ``encode_image`` when the
    encoder defines it. The encoder's train/eval mode is not changed.
    """
    forward = encoder.encode_image if hasattr(encoder, "encode_image") else encoder
    with torch.no_grad():
        out = forward(sample)
    return out.shape[-1]

# Learnable class embeddings

In [5]:
def create_class_embeddings(
    num_classes: int,
    dim: int,
    tokens_per_class: int = 4,
    device: Optional[str] = None,
    init_scale: float = 0.02,
) -> nn.Parameter:
    """Create a trainable ``[num_classes, tokens_per_class, dim]`` embedding table.

    Entries are drawn from a standard normal and scaled by ``init_scale``.

    Args:
        num_classes: Number of classes ``C``.
        dim: Embedding width ``D``.
        tokens_per_class: Number of prototype tokens ``T`` per class.
        device: Device on which the table is allocated (``None`` = default).
        init_scale: Multiplier applied to the normal initialization.

    Returns:
        ``nn.Parameter`` of shape ``[C, T, D]`` with ``requires_grad=True``.
    """
    shape = (num_classes, tokens_per_class, dim)
    init = torch.randn(shape, device=device).mul_(init_scale)
    return nn.Parameter(init, requires_grad=True)  # [C, T, D]

# optional tiny attention head for token pooling
def create_token_attn(dim: int, device: Optional[str] = None) -> nn.Linear:
    """Build a bias-free ``dim -> 1`` linear scorer, moved to ``device``."""
    scorer = nn.Linear(in_features=dim, out_features=1, bias=False)
    return scorer.to(device)

# Logits from features + embeddings

In [7]:
def normalize_embeddings(emb: torch.Tensor) -> torch.Tensor:
    """L2-normalize ``emb`` along its last axis.

    Deliberately *not* wrapped in ``torch.no_grad()``. The result is used to
    build logits for the learnable class embeddings, and computing it under
    no_grad would cut the autograd graph back to ``emb``. The embeddings could
    then never be trained.
    """
    return F.normalize(emb, dim=-1)

def logits_from(
    feats: torch.Tensor,          # [B, D]
    emb: torch.Tensor,            # [C, T, D] (learnable)
    mode: Literal["max", "mean", "attn"] = "max",
    temperature: Optional[nn.Parameter] = None,  # e.g., nn.Parameter(torch.tensor(10.0, device=device))
    token_attn: Optional[nn.Module] = None,      # nn.Linear(D,1)
) -> torch.Tensor:
    """Score image features against every class's embedding tokens.

    Features and tokens are both L2-normalized, so ``sim[b, c, t]`` is the
    cosine similarity between image ``b`` and token ``t`` of class ``c``.
    Tokens are then pooled per class:

    * ``"max"``  - the best-matching token.
    * ``"mean"`` - the average over tokens.
    * ``"attn"`` - a softmax-weighted sum. The weights come from ``token_attn``
      applied to the normalized tokens and are shared across the batch.

    Args:
        feats: Image features ``[B, D]``. They are cast to ``emb``'s dtype, so a
            half-precision encoder can be paired with fp32 embeddings.
        emb: Class embeddings ``[C, T, D]``. Gradients flow back into it.
        mode: Token pooling strategy.
        temperature: Optional scale multiplied into the pooled logits.
        token_attn: Module mapping ``[..., D] -> [..., 1]``. Required for ``"attn"``.

    Returns:
        Logits of shape ``[B, C]``.

    Raises:
        ValueError: on an unknown ``mode``, or ``mode="attn"`` without ``token_attn``.
    """
    if mode not in ("max", "mean", "attn"):
        raise ValueError("mode must be one of {'max','mean','attn'}")
    if mode == "attn" and token_attn is None:
        # Explicit raise (not assert) so the check survives `python -O`.
        raise ValueError("token_attn required for mode='attn'")

    # Normalize with autograd enabled so the loss can update `emb`.
    E = F.normalize(emb, dim=-1)                                  # [C, T, D]
    # einsum needs matching dtypes; frozen encoders often emit fp16.
    Fz = F.normalize(feats.to(E.dtype), dim=-1)                   # [B, D]
    sim = torch.einsum("bd,ctd->bct", Fz, E)                      # [B, C, T]

    if mode == "max":
        logit = sim.amax(dim=-1)                                  # [B, C]
    elif mode == "mean":
        logit = sim.mean(dim=-1)                                  # [B, C]
    else:  # "attn"
        w = torch.softmax(token_attn(E).squeeze(-1), dim=-1)      # [C, T]
        logit = (sim * w.unsqueeze(0)).sum(dim=-1)                # [B, C]

    if temperature is not None:
        logit = temperature * logit
    return logit  # [B, C]

# Optimizer factory (emb + optional temperature)

In [8]:
def make_optimizer(
    emb: nn.Parameter,
    lr: float = 5e-3,
    weight_decay: float = 1e-4,
    temperature: nn.Parameter = None,
    token_attn: Optional[nn.Module] = None,
):
    """Build an AdamW optimizer over the adapter's trainable tensors.

    Args:
        emb: Learnable class embeddings. Decayed with ``weight_decay``.
        lr: Learning rate shared by every parameter group.
        weight_decay: Decoupled weight decay for ``emb`` and ``token_attn``.
        temperature: Optional learnable logit scale. It goes in its own group
            with zero weight decay. Decay would keep pulling the scale toward 0
            and flatten the softmax, working against what it is learned for.
        token_attn: Optional pooling head for ``mode="attn"``. If it is not
            passed here, its weights are never updated.

    Returns:
        A ``torch.optim.AdamW`` instance.
    """
    decayed = [emb]
    if token_attn is not None:
        decayed.extend(p for p in token_attn.parameters() if p.requires_grad)
    groups = [{"params": decayed, "weight_decay": weight_decay}]
    if temperature is not None:
        groups.append({"params": [temperature], "weight_decay": 0.0})
    return torch.optim.AdamW(groups, lr=lr, weight_decay=weight_decay)

# Train / Eval loops (emb-only training)

In [9]:
def train_epoch(
    encoder: nn.Module,
    emb: nn.Parameter,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: str,
    mode: Literal["max","mean","attn"] = "max",
    temperature: nn.Parameter = None,
    token_attn: nn.Module = None,
) -> tuple[float, float]:
    """Run one optimization pass over ``loader``.

    The encoder is put in eval mode and queried without gradients. Each
    batch's cross-entropy on the class logits is backpropagated, and
    ``optimizer`` steps whatever parameters it holds.

    Returns:
        ``(mean_loss, accuracy)``, both weighted by batch size. Accuracy is
        taken from each batch's logits before that batch's optimizer step.
        Both values are ``0.0`` for an empty loader.
    """
    encoder.eval()  # the encoder stays frozen
    loss_sum, correct, seen = 0, 0, 0
    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        with torch.no_grad():
            feats = encode_images(encoder, images)  # [B, D]
        logits = logits_from(
            feats, emb, mode=mode, temperature=temperature, token_attn=token_attn
        )
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch = images.size(0)
        loss_sum += loss.item() * batch
        correct += (logits.argmax(1) == labels).sum().item()
        seen += batch
    denom = max(seen, 1)
    return loss_sum / denom, correct / denom

@torch.no_grad()
def evaluate(
    encoder: nn.Module,
    emb: nn.Parameter,
    loader: DataLoader,
    device: str,
    mode: Literal["max","mean","attn"] = "max",
    temperature: nn.Parameter = None,
    token_attn: nn.Module = None,
    topk: int = 5,
) -> dict:
    """Measure loss, top-1 and top-k accuracy of the adapter over ``loader``.

    Runs without autograd and puts the encoder in eval mode. ``topk`` is
    clipped to the number of classes. All metrics are averaged per sample and
    are ``0.0`` for an empty loader.

    Returns:
        ``{"loss": float, "top1": float, "topk": float}``.
    """
    encoder.eval()
    loss_sum = top1_hits = topk_hits = seen = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        feats = encode_images(encoder, images)
        logits = logits_from(
            feats, emb, mode=mode, temperature=temperature, token_attn=token_attn
        )
        batch = images.size(0)
        loss_sum += F.cross_entropy(logits, labels).item() * batch
        top1_hits += (logits.argmax(1) == labels).sum().item()
        k = min(topk, logits.size(1))
        best_k = logits.topk(k, dim=1).indices
        topk_hits += (best_k == labels.unsqueeze(1)).any(dim=1).sum().item()
        seen += batch
    denom = max(seen, 1)
    return {
        "loss": loss_sum / denom,
        "top1": top1_hits / denom,
        "topk": topk_hits / denom,
    }

# Save / Load embeddings

In [10]:
def save_embeddings(emb: nn.Parameter, path: str = "fewshot_emb.pt"):
    """Write ``emb`` to ``path`` as a plain CPU tensor.

    The saved tensor has no autograd history and is not an ``nn.Parameter``.
    """
    cpu_copy = emb.detach().cpu()
    torch.save(cpu_copy, path)

def load_embeddings(path: str, device: Optional[str] = None) -> nn.Parameter:
    """Load a tensor saved by ``torch.save`` and wrap it as a trainable Parameter.

    Args:
        path: File containing a single serialized tensor.
        device: Device to load onto. When falsy, the tensor is loaded onto CPU.

    Returns:
        ``nn.Parameter`` with ``requires_grad=True``.

    Raises:
        TypeError: if the file does not hold a tensor.
    """
    # weights_only=True restricts unpickling to tensors and primitive
    # containers. Plain torch.load runs pickle and can execute code embedded
    # in a crafted file.
    obj = torch.load(path, map_location=device if device else "cpu", weights_only=True)
    if not isinstance(obj, torch.Tensor):
        raise TypeError(f"expected a tensor in {path!r}, got {type(obj).__name__}")
    # map_location has already placed the data on the requested device.
    return nn.Parameter(obj, requires_grad=True)