# losses

> LeJEPA, masked-autoencoder and encoder losses, a GAN discriminator, and more


In [None]:
#| default_exp losses

In [None]:
#| hide
from nbdev.showdoc import *


In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F 

## Safe Mean

Calling `.mean()` on a zero-element tensor yields NaN (0/0), so `safe_mean` returns zero for empty inputs instead.

In [None]:
#| export
def safe_mean(t, dim=None):
    """Safe replacement for `torch.mean()`. It is a plain function, so it can't be used as a suffix (`t.safe_mean()`).

    Returns `t.mean(dim=dim)` for non-empty tensors. For zero-element tensors, where
    `.mean()` would give NaN, it returns `t.sum(dim=dim)` instead: an all-zero tensor
    with the same reduced shape, dtype and device. That result is still attached to
    the autograd graph, so `.item()` and `.backward()` keep working on it.
    """
    if t.numel() == 0:
        # sum over an empty tensor is exactly zero (no 0/0), and keeps shape + grad_fn
        return t.sum() if dim is None else t.sum(dim=dim)
    return t.mean(dim=dim)

## LeJEPA Loss

For an interactive overview of LeJEPA, see https://www.scotthawley.com/ssltoy/

In [None]:
#| export
def SIGReg(x, global_step, num_slices=256):
    """SIGReg with Epps-Pulley statistic. x is (N, K) tensor.

    Projects the rows of `x` onto `num_slices` random unit directions. For each 1-D
    projection, it integrates (trapezoid rule over t in [-5, 5]) the weighted squared
    distance between the empirical characteristic function (ECF) and the CF of N(0,1).
    The per-slice values are summed.

    Args:
        x: (N, K) batch of embeddings.
        global_step: seed for the random projections; calls with the same step use the same slices.
        num_slices: number of random projection directions (M).

    Returns:
        0-dim tensor. An empty batch (N == 0) gives zero instead of NaN.
    """
    if x.size(0) == 0:
        return x.sum()  # ECF over zero samples is 0/0 = NaN; return a graph-connected zero
    device = x.device
    g = torch.Generator(device=device).manual_seed(global_step)
    proj_shape = (x.size(1), num_slices)
    A = torch.randn(proj_shape, generator=g, device=device)
    A = A / (A.norm(dim=0, keepdim=True) + 1e-10)  # normalize columns

    # Epps-Pulley statistic
    t = torch.linspace(-5, 5, 17, device=device)  # values used in LeJEPA paper, worked for SSLtoy
    exp_f = torch.exp(-0.5 * t**2)  # theoretical CF for N(0,1)
    x_t = (x @ A).unsqueeze(2) * t  # (N, M, T)
    # Keep the ECF complex. Its modulus |E[exp(i t X)]| is unchanged when X is shifted,
    # so comparing only the modulus would leave the embeddings' mean unconstrained.
    ecf = torch.exp(1j * x_t).mean(dim=0)  # empirical CF, (M, T)
    diff = (ecf - exp_f).abs().square().mul(exp_f)  # weighted squared distance, (M, T)
    # Deliberately NOT multiplied by N as in the reference implementation: scaling by
    # the batch size would mean retuning lambd by hand whenever the batch size changes.
    return torch.trapz(diff, t, dim=1).sum()  # integrate over t, then sum over slices (not data points)

In [None]:
#| eval: false
# Smoke-test SIGReg on a batch of random Gaussian embeddings
batch_size = 32
embed_dim = 64
x = torch.randn(batch_size, embed_dim)
loss = SIGReg(x, 0, num_slices=256)
print("SIGReg loss: {:.4f}".format(loss.item()))

SIGReg loss: 2.9035


In [None]:
#| export
def attraction_loss(z1, z2,  # embeddings of two "views" of the same thing (in batches)
                    deltas=None,  # optional/TBD: info on semantic 'distance' between z1 & z2
                    tau=100.0):  # inverse strength of fall-off for delta distances, big=slower
    """How we pull similar 'views' together.

    Mean squared difference between paired embeddings. When `deltas` is given, each
    row's squared error is scaled by exp(-sum(delta**2) / tau). Pairs whose views are
    further apart are therefore attracted less strongly.
    """
    sq_err = (z1 - z2).square()
    if deltas is not None:
        falloff = torch.exp(-(deltas**2).sum(dim=1) / tau)  # one factor per row of deltas
        # alternative with a longer tail than exp: 1 / (1 + sum(delta**2) / tau)
        sq_err = sq_err * falloff.unsqueeze(-1)  # broadcast over the last (feature) dim
    return safe_mean(sq_err)

In [None]:
#| export
def LeJEPA(z1, z2, global_step, lambd=0.5, deltas=None):
    """Main LeJEPA loss function.

    Blends the attraction between paired views with SIGReg on the pooled embeddings:
    `loss = (1 - lambd) * sim + lambd * sigreg`.

    Args:
        z1, z2: paired embeddings of two views (concatenated along dim 0 for SIGReg).
        global_step: seed for SIGReg's random slices.
        lambd: blend weight of SIGReg vs. attraction.
        deltas: optional per-pair view offsets, forwarded to `attraction_loss`.

    Returns:
        dict with differentiable 'loss', plus Python floats 'sim' and 'sigreg' for logging.
    """
    sim = attraction_loss(z1, z2, deltas=deltas)
    sigreg = SIGReg(torch.cat((z1, z2), dim=0), global_step)
    # float() accepts both 0-dim tensors and plain Python numbers, so logging works
    # whichever form attraction_loss/safe_mean hands back (e.g. for empty inputs).
    return {'loss': (1 - lambd) * sim + lambd * sigreg, 'sim': float(sim), 'sigreg': float(sigreg)}

In [None]:
#| eval: false
# Test LeJEPA loss; the returned dict already carries the individual terms,
# so there is no need to recompute attraction/SIGReg separately.
batch_size, embed_dim = 32, 64
z1 = torch.randn(batch_size, embed_dim)
z2 = torch.randn(batch_size, embed_dim)

loss = LeJEPA(z1, z2, global_step=0, lambd=0.5)
print(f"LeJEPA loss: {loss['loss'].item():.4f}")
print(f"  Attraction: {loss['sim']:.4f}")
print(f"  SIGReg: {loss['sigreg']:.4f}")

LeJEPA loss: 1.8238
  Attraction: 2.0526
  SIGReg: 1.5950


## Masked (Auto)Encoder Loss

In [None]:
#| export
# def calc_mae_loss(recon_patches, img, enc_out, patch_size=16, lambda_visible=0.1):
#     "BCE loss on reconstructed patches, with optional downweighting of visible patches"
#     mae_mask = enc_out.mae_mask[1:]  # strip CLS
#     img_patches = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
#     img_patches = img_patches.flatten(2, 3).flatten(-2, -1).squeeze(1)  # (B, N, ps*ps)
#     weights = torch.ones_like(recon_patches)
#     weights[:, mae_mask, :] = lambda_visible
#     loss = (F.binary_cross_entropy_with_logits(recon_patches, img_patches, reduction='none') * weights).mean()
#     return loss

def calc_mae_loss(recon_patches, img, enc_out, lambda_visible=0.1):
    """BCE loss on reconstructed patches, with optional downweighting of visible patches.

    Args:
        recon_patches: (B, N, ps*ps) reconstruction logits, one row per patch; the
            patch size ps is inferred as the square root of the last dim.
        img: target image. It is cut into ps x ps patches along dims 2 and 3, and
            dim 1 is squeezed, so presumably (B, 1, H, W) with values in [0, 1]
            (TODO confirm).
        enc_out: object with a `mae_mask` attribute. True entries get weight
            `lambda_visible` (presumably the visible patches). The mask is either
            1-D with a leading CLS entry (ViT; shared across the batch) or (B, N).
        lambda_visible: weight for masked-in patches; all other patches get 1.

    Returns:
        Mean over all elements of the weighted per-element BCE-with-logits.
    """
    mae_mask = enc_out.mae_mask
    patch_size = int(recon_patches.shape[-1] ** 0.5)
    img_patches = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    img_patches = img_patches.flatten(2, 3).flatten(-2, -1).squeeze(1)  # (B, N, ps*ps)
    weights = torch.ones_like(recon_patches)
    if mae_mask.dim() == 1:
        # ViT: strip CLS, then index the PATCH dim. A bare weights[mask] with a 1-D
        # mask would index the batch dim instead (IndexError or silently wrong rows).
        weights[:, mae_mask[1:]] = lambda_visible
    else:
        weights[mae_mask] = lambda_visible  # (B, N) mask selects over the leading two dims
    loss = (F.binary_cross_entropy_with_logits(recon_patches, img_patches, reduction='none') * weights).mean()
    return loss


## Full Encoder Loss

In [None]:
#| export
def anchor_loss(z1, z2):
    """Anchor embeddings of empty patches to the origin.

    Adds the mean of the squared entries of `z1` to that of `z2`. Both go through
    `safe_mean`, so an empty input contributes zero instead of NaN.
    """
    first, second = (safe_mean(z.square()) for z in (z1, z2))
    return first + second

In [None]:
#| export
def calc_enc_loss(z1, z2, global_step, deltas=None, lambd=0.5, non_emptys=(None,None), lambda_anchor=0.05):
    """Main loss function for Encoder.

    Applies the LeJEPA loss to rows where BOTH views are non-empty. It adds an anchor
    term pulling embeddings of empty rows (in either view) toward the origin.

    Args:
        z1, z2: paired embeddings, indexed row-wise by the flattened masks.
        global_step: seed for SIGReg's random slices.
        deltas: optional per-row view offsets forwarded to LeJEPA; None disables delta weighting.
        lambd: LeJEPA blend weight between attraction and SIGReg.
        non_emptys: pair of masks (one per view), flattened to one entry per row;
            truthy = non-empty. A None entry means every row of that view is non-empty.
        lambda_anchor: weight of the anchor term.

    Returns:
        LeJEPA's dict ('loss', 'sim', 'sigreg') plus 'anchor'. 'loss' includes
        lambda_anchor * anchor.
    """
    def _row_mask(ne):  # flatten to a bool row mask; None -> all rows non-empty
        if ne is None:
            return torch.ones(z1.shape[0], dtype=torch.bool, device=z1.device)
        return ne.reshape(-1).bool()

    non_empty1, non_empty2 = (_row_mask(ne) for ne in non_emptys)
    valid = non_empty1 & non_empty2  # both non-empty
    if valid.any():
        loss_dict = LeJEPA(z1[valid], z2[valid], global_step, lambd=lambd,
                           deltas=None if deltas is None else deltas[valid])
    else:
        # No usable pairs: skip LeJEPA, whose SIGReg would average over zero samples (NaN).
        loss_dict = {'loss': z1.new_zeros(()), 'sim': 0.0, 'sigreg': 0.0}
    aloss = anchor_loss(z1[~non_empty1], z2[~non_empty2])
    loss_dict['anchor'] = aloss
    loss_dict['loss'] = loss_dict['loss'] + lambda_anchor * aloss
    return loss_dict

In [None]:
#| export
@torch.compiler.disable
def calc_enc_loss_multiscale(z1, z2,  # lists of embeddings at each swin hierarchy level, 0=coarsest, -1=finest
                             global_step, img_size, deltas=None,
                             lambd=0.5, non_emptys=None, lambda_anchor=0.05):
    """Compute encoder loss at each hierarchy level, weighted by patch overlap fraction.
    Levels where delta/shift exceeds patch size (resulting in zero overlap) are skipped entirely.
    Non-empty patch masks focus the sim/anchor losses on musically active regions.

    Args:
        z1, z2: per-level lists of (B, N, D) embeddings with N a square grid of tokens.
            A non-list input is treated as a regular ViT and sent straight to `calc_enc_loss`.
        global_step: seed for SIGReg's random slices.
        img_size: side length of the (square) image, used to get each level's patch size.
        deltas: (B, >=2) per-sample view shifts; columns 0/1 are read as x/y.
            None means "no shift": full overlap and no delta down-weighting.
        lambd, lambda_anchor: forwarded to `calc_enc_loss`.
        non_emptys: per-level pairs of non-empty masks (ViT: a single pair).
            None means every patch is non-empty.

    Returns:
        dict of overlap-weighted losses averaged over ALL levels. Skipped levels count
        toward the divisor and contribute zero. Returns {'loss': 0} if every level was skipped.
    """
    if not isinstance(z1, list):  # regular vit
        n_rows = z1.shape[0]
        if deltas is None:
            deltas = z1.new_zeros(n_rows, 2)  # zero shift == no delta down-weighting
        if non_emptys is None:
            all_rows = torch.ones(n_rows, dtype=torch.bool, device=z1.device)
            non_emptys = (all_rows, all_rows)
        d_exp = deltas.repeat_interleave(n_rows // deltas.shape[0], dim=0)
        return calc_enc_loss(z1, z2, global_step, deltas=d_exp, lambd=lambd,
                             non_emptys=non_emptys, lambda_anchor=lambda_anchor)

    if deltas is None:
        deltas = z1[0].new_zeros(z1[0].shape[0], 2)  # zero shift: full overlap at every level
    if non_emptys is None:
        non_emptys = [None] * len(z1)
    abs_deltas = deltas.abs()
    dx, dy = abs_deltas[:, 0], abs_deltas[:, 1]  # same for every level, so computed once
    n_levels = len(z1)
    total = {}
    for z1_l, z2_l, ne in zip(z1, z2, non_emptys):
        if z1_l.shape[1] == 1: continue  # skip degenerate 1-token level
        B, N, D = z1_l.shape
        grid = int(N ** 0.5)
        patch_w, patch_h = img_size // grid, img_size // grid
        # hinge-style weight to avoid comparing patches if the deltas (aka shift) exceed patch size.
        w = (((patch_w - dx) / patch_w).clamp(min=0) * ((patch_h - dy) / patch_h).clamp(min=0)).mean()
        if w < 1e-8: continue

        d_exp = deltas.repeat_interleave(N, dim=0)  # one delta row per token
        z1_flat, z2_flat = z1_l.reshape(-1, D), z2_l.reshape(-1, D)
        if ne is None:
            all_tokens = torch.ones(B * N, dtype=torch.bool, device=z1_l.device)
            ne_flat = (all_tokens, all_tokens)
        else:
            ne_flat = (ne[0].reshape(-1).bool(), ne[1].reshape(-1).bool())

        ld = calc_enc_loss(z1_flat, z2_flat, global_step, deltas=d_exp,
                           lambd=lambd, non_emptys=ne_flat, lambda_anchor=lambda_anchor)
        for k, v in ld.items():
            total[k] = total.get(k, 0) + v * w

    if not total: return {'loss': torch.tensor(0.0, device=deltas.device)}
    return {k: v / n_levels for k, v in total.items()}

## Adversarial Loss 

In [None]:
#| export
class PatchGANDiscriminator(nn.Module):
    """PatchGAN-style convolutional discriminator.

    Stacks `n_layers` stride-2 4x4 convolutions with LeakyReLU(0.2). Channels start at
    `base_ch` and double each layer, capped at 512. A final stride-1 4x4 convolution
    maps to a single-channel map of unnormalized scores, one per receptive-field patch.
    When `use_spectral_norm` is set, every convolution is wrapped in spectral
    normalization.
    """

    def __init__(self, in_ch=1, base_ch=64, n_layers=3, use_spectral_norm=True):
        super().__init__()

        def make_conv(c_in, c_out, stride):
            layer = nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1)
            return nn.utils.spectral_norm(layer) if use_spectral_norm else layer

        # channel width of each stride-2 stage: base_ch, then doubling with a 512 cap
        widths = [base_ch]
        for _ in range(1, n_layers):
            widths.append(min(widths[-1] * 2, 512))

        stages = []
        c_prev = in_ch
        for width in widths:
            stages.extend([make_conv(c_prev, width, 2), nn.LeakyReLU(0.2, True)])
            c_prev = width
        stages.append(make_conv(c_prev, 1, 1))  # 1-channel patch score map
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()