In [1]:
import torch
import torch.distributed as dist
from torch import nn
from torch.nn import functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
class GatherLayer(torch.autograd.Function):
    """
    Gathers tensors from all process and supports backward propagation
    for the gradients across processes.
    """

    @staticmethod
    def forward(ctx, x):
        if dist.is_available() and dist.is_initialized():
            output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
            dist.all_gather(output, x)
        else:
            output = [x]
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        if dist.is_available() and dist.is_initialized():
            all_gradients = torch.stack(grads)
            dist.all_reduce(all_gradients)
            grad_out = all_gradients[get_rank()]
        else:
            grad_out = grads[0]
        return grad_out

def get_rank():
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0

def gather(X, dim=0):
    """Gathers tensors from all processes, supporting backward propagation."""
    return torch.cat(GatherLayer.apply(X), dim=dim)
class VICRegLoss(nn.Module):
    # https://github.com/vturrisi/solo-learn/blob/main/solo/losses/vicreg.py
    def __init__(
        self,
        sim_loss_weight: float = 25.0,
        var_loss_weight: float = 25.0,
        cov_loss_weight: float = 1.0,
        ) -> None:
        """_summary_

        Args:
            sim_loss_weight (float, optional): _description_. Defaults to 25.0.
            var_loss_weight (float, optional): _description_. Defaults to 25.0.
            cov_loss_weight (float, optional): _description_. Defaults to 1.0.
        """
        super().__init__()
        
        self.sim_loss_weight = sim_loss_weight
        self.var_loss_weight = var_loss_weight
        self.cov_loss_weight = cov_loss_weight

    
    def invariance_loss(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        """Computes mse loss given batch of projected features z1 from view 1 and
        projected features z2 from view 2.
        Args:
            z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
            z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
        Returns:
            torch.Tensor: invariance loss (mean squared error).
        """

        return F.mse_loss(z1, z2)


    def variance_loss(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        """Computes variance loss given batch of projected features z1 from view 1 and
        projected features z2 from view 2.
        Args:
            z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
            z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
        Returns:
            torch.Tensor: variance regularization loss.
        """

        eps = 1e-4
        std_z1 = torch.sqrt(z1.var(dim=0) + eps)
        std_z2 = torch.sqrt(z2.var(dim=0) + eps)
        std_loss = torch.mean(F.relu(1 - std_z1)) + torch.mean(F.relu(1 - std_z2))
        return std_loss


    def covariance_loss(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
        """Computes covariance loss given batch of projected features z1 from view 1 and
        projected features z2 from view 2.
        Args:
            z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
            z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
        Returns:
            torch.Tensor: covariance regularization loss.
        """

        N, D = z1.size()

        z1 = z1 - z1.mean(dim=0)
        z2 = z2 - z2.mean(dim=0)
        cov_z1 = (z1.T @ z1) / (N - 1)
        cov_z2 = (z2.T @ z2) / (N - 1)

        diag = torch.eye(D, device=z1.device)
        cov_loss = cov_z1[~diag.bool()].pow_(2).sum() / D + cov_z2[~diag.bool()].pow_(2).sum() / D
        return cov_loss


    def forward(
        self,
        z1: torch.Tensor,
        z2: torch.Tensor
    ) -> torch.Tensor:
        """Computes VICReg's loss given batch of projected features z1 from view 1 and
        projected features z2 from view 2.
        Args:
            z1 (torch.Tensor): NxD Tensor containing projected features from view 1.
            z2 (torch.Tensor): NxD Tensor containing projected features from view 2.
        Returns:
            torch.Tensor: VICReg loss.
        """

        sim_loss = self.invariance_loss(z1, z2)

        # vicreg's official code gathers the tensors here
        # https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py
        z1, z2 = gather(z1), gather(z2)

        var_loss = self.variance_loss(z1, z2)
        cov_loss = self.covariance_loss(z1, z2)

        loss = self.sim_loss_weight * sim_loss + self.var_loss_weight * var_loss + self.cov_loss_weight * cov_loss
        
        return loss

In [15]:
vicreg_loss = VICRegLoss()

x = torch.randn(20, 128)
y = torch.randn(20, 128)

loss = vicreg_loss(x, y)
print(loss)

tensor(69.2148)


In [6]:
class CriterionOutput:
    def __init__(
        self,
        latent_loss: torch.Tensor = torch.tensor(0.0),
        align_loss: torch.Tensor = None,
    ) -> None:
        self.latent_loss = latent_loss
        self.align_loss = align_loss
        self.total_loss = latent_loss + align_loss
        
    def set_attributes(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

In [7]:
loss = CriterionOutput(align_loss=torch.tensor(1.0))
print(loss.total_loss)

TypeError: unsupported operand type(s) for +: 'NoneType' and 'Tensor'