# Contrastive Loss (InfoNCE)

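> Implements the symmetric InfoNCE (contrastive) loss used for CLIP training. Under distributed data parallel (DDP), image and text features are gathered from all ranks (with gradients) so that each sample is contrasted against the global batch.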

In [None]:
#| default_exp loss

## Colab Setup

In [None]:
#| hide
# Mount Google Drive (Optional, but recommended for persistent storage)
from pathlib import Path
import sys

# Best-effort Google Drive mount: succeeds only inside Colab; any failure is reported, never raised.
try:
    from google.colab import drive  # importable only inside a Colab runtime
    drive.mount('/content/drive')
except ModuleNotFoundError:
    print("Not running in Colab, skipping Drive mount.")
except Exception as mount_error:
    print(f"Error mounting Google Drive: {mount_error}")
else:
    print("Google Drive mounted successfully.")

Mounted at /content/drive
Google Drive mounted successfully.


In [None]:
#| export
# Imported here (not only in the hidden Colab cell) so the exported module is self-contained.
from pathlib import Path
import sys

try:
    import indic_clip.core
    print("Reloaded indic_clip.core")
except ModuleNotFoundError:
    print("indic_clip.core not found initially.")
    # In Colab, make the project importable by putting its root directory on sys.path.
    if 'google.colab' in sys.modules:
        # Each candidate is a directory expected to contain the `indic_clip` package.
        # Prefer the Drive copy, then a local clone under /content; fall back to /content.
        candidate_roots = ['/content/drive/MyDrive/Indic-Clip', '/content/Indic-Clip', '/content/indic-clip']
        project_parent = next((root for root in candidate_roots if Path(root).exists()), '/content')
        if project_parent not in sys.path:
            sys.path.insert(0, project_parent)
            print(f"Added {project_parent} to sys.path")
        try:
            import indic_clip.core
            print("Imported indic_clip.core after path adjustment.")
        except ModuleNotFoundError:
            print("ERROR: Still cannot find indic_clip.core. Ensure project structure is correct.")
            print("Expected: /content/Indic-Clip/indic_clip/core.py or similar in Drive")
            # raise # Stop execution if core components missing

indic_clip.core not found initially.
Added /content/drive/MyDrive/Indic-Clip to sys.path
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive detected, setting PROJECT_ROOT to /content/drive/MyDrive/Indic-Clip
Ensure your project files are located there.
Imported indic_clip.core after path adjustment.


In [None]:
#| hide
# Change into the project directory on the mounted Google Drive so later relative paths resolve from there.
%cd /content/drive/MyDrive/Indic-Clip/

/content/drive/MyDrive/Indic-Clip


In [None]:
#| hide
!pip install -qr requirements.txt

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.3/40.3 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m296.7/296.7 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m297.8/297.8 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.7/69.7 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.9/46.9 MB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m322.2/322.2 kB[0m [31m18.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
import logging
from typing import Optional

from fastai.vision.all import *

try:
    from indic_clip.core import get_logger, setup_logging
except ImportError:
    # ImportError also covers ModuleNotFoundError and a core module that lacks these helpers.
    # Fall back to the stdlib logging module so this module stays importable.
    logging.basicConfig(level=logging.INFO)

    def get_logger(name):
        """Return the stdlib logger called `name`."""
        return logging.getLogger(name)

    def setup_logging():
        """No-op: `logging.basicConfig` above already configured logging."""

setup_logging()
logger = get_logger(__name__)

## AllGather Helper for DDP

In [None]:
#| export
class AllGather(torch.autograd.Function):
    """Differentiable all_gather: concatenates a tensor from every rank along dim 0.

    Forward gathers the local tensor from all ranks and concatenates the pieces in rank
    order along the batch dimension. Backward sums the gradient of the gathered tensor
    over all ranks and hands each rank the slice that corresponds to its own input.

    If torch.distributed is unavailable or not initialized, both passes are identities.

    NOTE: every rank must contribute a tensor of the same shape (buffers are allocated
    with `empty_like` of the local tensor), i.e. all ranks need the same local batch size.
    """

    @staticmethod
    def forward(ctx, tensor: torch.Tensor) -> torch.Tensor:
        """Gather `tensor` from all ranks and concatenate along dim 0."""
        if not (dist.is_available() and dist.is_initialized()):
            return tensor

        tensor = tensor.contiguous()  # collectives require contiguous memory
        gathered = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """Sum `grad_output` across ranks and return this rank's chunk of it.

        Raises:
            RuntimeError: if dim 0 of `grad_output` is not divisible by the world size.
        """
        if not (dist.is_available() and dist.is_initialized()):
            return grad_output

        grad_output = grad_output.contiguous()
        world_size = dist.get_world_size()

        if grad_output.shape[0] % world_size != 0:
            raise RuntimeError("Gradient output size must be divisible by world size for all_gather backward pass.")
        chunk_size = grad_output.shape[0] // world_size

        if grad_output.is_cuda and "nccl" in str(dist.get_backend()):
            # NCCL supports reduce_scatter: sum each rank's chunk across ranks in one collective.
            grad_input = torch.empty(chunk_size, *grad_output.shape[1:], dtype=grad_output.dtype, device=grad_output.device)
            dist.reduce_scatter(grad_input, list(grad_output.chunk(world_size, dim=0)), op=dist.ReduceOp.SUM)
            return grad_input

        # Backends without reduce_scatter (e.g. gloo): sum the whole gradient on every rank,
        # then keep only the slice that belongs to this rank's input.
        summed = grad_output.clone()
        dist.all_reduce(summed, op=dist.ReduceOp.SUM)
        start = dist.get_rank() * chunk_size
        return summed[start:start + chunk_size].contiguous()

## Contrastive Loss Implementation

In [None]:
#| export
class ContrastiveLoss(Module):
    """Symmetric InfoNCE (CLIP) loss between image and text features.

    If a torch.distributed process group is initialized, features are gathered from all
    ranks with `AllGather` (which propagates gradients back to each rank), so every sample
    is contrasted against the global batch. Each rank then computes the same global loss.

    Assumes:
        - `image_features` and `text_features` are already L2-normalized.
        - `logit_scale` is already exponentiated; it is multiplied into the logits as-is.
        - All ranks use the same local batch size (required by `AllGather`).
    """
    def __init__(self, *args, axis: int = -1, check_finite: bool = True, **kwargs):
        """
        Args:
            args: Accepted for backward compatibility; ignored.
            axis (int): Accepted for backward compatibility; ignored.
            check_finite (bool): If True (the default, matching earlier behavior), log an
                error/warning when the inputs or the logits contain NaN/Inf. Each check reduces
                a tensor to a Python bool, which forces a device->host sync on GPU every step;
                pass False to skip the checks.
            kwargs: Accepted for backward compatibility; ignored.
        """
        # fastai's `Module` runs `nn.Module.__init__` automatically, so attributes can be set directly.
        self.check_finite = check_finite
        self.all_gather = AllGather.apply  # differentiable all_gather (identity when not distributed)

    def _checks_enabled(self) -> bool:
        """Whether NaN/Inf logging is on; defaults to True for instances pickled before the flag existed."""
        return getattr(self, 'check_finite', True)

    def forward(self, preds: tuple[torch.Tensor, torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """
        Calculates the contrastive loss.

        Args:
            preds (tuple): A tuple containing:
                - image_features (torch.Tensor): Normalized image features (B, D).
                - text_features (torch.Tensor): Normalized text features (B, D).
                - logit_scale (torch.Tensor): The already-exponentiated logit scale (scalar tensor).

        Returns:
            torch.Tensor: Mean of the image->text and text->image cross-entropies (scalar tensor).
        """
        image_features, text_features, logit_scale = preds
        checks_on = self._checks_enabled()

        if checks_on:
            if not torch.isfinite(image_features).all():
                logger.error("!!! NaN/Inf DETECTED IN INPUT image_features !!!")
            if not torch.isfinite(text_features).all():
                logger.error("!!! NaN/Inf DETECTED IN INPUT text_features !!!")
            if not torch.isfinite(logit_scale).all():
                logger.error(f"!!! NaN/Inf DETECTED IN INPUT logit_scale: {logit_scale.item()} !!!")

        # --- Gather features so every sample is compared against the global batch ---
        if dist.is_available() and dist.is_initialized():
            all_image_features = self.all_gather(image_features)
            all_text_features = self.all_gather(text_features)
        else:
            all_image_features, all_text_features = image_features, text_features

        # Dot products of unit vectors are cosine similarities: [Global B, Global B].
        logits_per_image = logit_scale * all_image_features @ all_text_features.t()
        logits_per_text = logits_per_image.t()  # transpose instead of a second matmul

        if checks_on and not torch.isfinite(logits_per_image).all():
            logger.warning("NaN/Inf detected in logits_per_image!")
            logger.warning(f"Logit Scale (exp): {logit_scale.item()}")
            logger.warning(f"Image Features Norm: {all_image_features.norm(dim=-1).mean().item()}")
            logger.warning(f"Text Features Norm: {all_text_features.norm(dim=-1).mean().item()}")

        # Matching pairs sit on the diagonal, so the target for row i is column i.
        labels = torch.arange(all_image_features.size(0), device=image_features.device, dtype=torch.long)

        loss_img = F.cross_entropy(logits_per_image, labels)
        loss_txt = F.cross_entropy(logits_per_text, labels)
        return (loss_img + loss_txt) / 2

## Example Usage and Testing

In [None]:
#| eval: false
if __name__ == '__main__':
    import math

    torch.manual_seed(0)  # reproducible random features and printed loss

    print("--- Testing ContrastiveLoss (Non-Distributed) ---")
    # Setup dummy data
    batch_size = 4
    embed_dim = 512
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    world_size = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1

    # Random unit-norm vectors stand in for the encoders' normalized outputs.
    img_feat = F.normalize(torch.randn(batch_size, embed_dim, device=device), dim=-1)
    txt_feat = F.normalize(torch.randn(batch_size, embed_dim, device=device), dim=-1)
    # Logit scale as the model returns it (already exponentiated): 1 / temperature with temperature 0.07.
    logit_scale = torch.tensor(1 / 0.07, device=device)

    print("Input Shapes:")
    print(f"  Image Features: {img_feat.shape}")
    print(f"  Text Features:  {txt_feat.shape}")
    print(f"  Logit Scale:    {logit_scale.shape}")

    loss_fn = ContrastiveLoss()
    loss_val = loss_fn((img_feat, txt_feat, logit_scale))

    print(f"Output Loss: {loss_val.item():.4f} (Type: {type(loss_val)}, Device: {loss_val.device})")
    assert isinstance(loss_val, torch.Tensor) and loss_val.ndim == 0

    # With a zero scale every logit is 0, so both cross-entropies equal log(global batch size).
    zero_scale_loss = loss_fn((img_feat, txt_feat, torch.zeros([], device=device)))
    assert math.isclose(zero_scale_loss.item(), math.log(batch_size * world_size), rel_tol=1e-5)

    # Perfectly matched pairs must score a lower loss than unrelated random pairs.
    matched_loss = loss_fn((img_feat, img_feat, logit_scale))
    assert matched_loss.item() < loss_val.item()

    # --- Distributed scenario (needs a real process group) ---
    print("\n--- Testing ContrastiveLoss (Distributed, requires an initialized process group) ---")

    if dist.is_available() and not dist.is_initialized():
        # Initializing a process group is normally done by torchrun/launch; it is not simulated here.
        print("Distributed environment not available/initialized. Skipping DDP test.")
        print("Note: To run the distributed test, initialize a process group first.")
        print("Example (requires torchrun or similar):")
        print("  import torch.distributed as dist")
        print("  dist.init_process_group(backend='nccl') # Or 'gloo' for CPU")
        print("  # ... run the test code ...")

    elif dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        print(f"Running DDP test on Rank {rank}/{world_size}")

        # img_feat, txt_feat and logit_scale act as this rank's local tensors.
        loss_val_ddp = loss_fn((img_feat, txt_feat, logit_scale))

        print(f"DDP Output Loss (Rank {rank}): {loss_val_ddp.item():.4f}")
        assert isinstance(loss_val_ddp, torch.Tensor) and loss_val_ddp.ndim == 0
        # Every rank computes the loss over the same gathered global batch, so the printed
        # values should agree across ranks; the seed above also makes the local inputs identical.

    else:
        print("Torch distributed is not available on this system.")

--- Testing ContrastiveLoss (Non-Distributed) ---
Input Shapes:
  Image Features: torch.Size([4, 512])
  Text Features:  torch.Size([4, 512])
  Logit Scale:    torch.Size([])
Output Loss: 1.4547 (Type: <class 'torch.Tensor'>, Device: cpu)

--- Testing ContrastiveLoss (Simulated Distributed, World Size=2) ---
Distributed environment not available/initialized. Skipping DDP test.
Note: To run the distributed test, initialize a process group first.
Example (requires torchrun or similar):
  import torch.distributed as dist
  dist.init_process_group(backend='nccl') # Or 'gloo' for CPU
  # ... run the test code ...


In [None]:
#| hide
import nbdev
nbdev.nbdev_export() # Write the `#| export` cells into the library's .py modules (Python API equivalent of the `nbdev_export` CLI)