In [None]:
"""
BERT-style Bidirectional Encoder for Embedding Generation

This module implements an encoder-only transformer architecture similar to BERT
for learning text embeddings. The model supports both Masked Language Modeling (MLM)
and contrastive learning objectives.

Architecture:
    - Token + Position Embeddings
    - N Transformer Encoder Blocks (bidirectional)
    - Pooling Layer (mean/cls/max)
    - MLM Head (optional)
    - Contrastive Learning Support
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Dict

In [None]:


class MultiHeadAttention(nn.Module):
    """
    Multi-head self-attention without causal masking (bidirectional).

    Every query position may attend to every key position; the only masking
    applied is the optional padding mask supplied by the caller.

    Args:
        hidden_size (int): Dimension of input embeddings
        num_attention_heads (int): Number of parallel attention heads
        dropout (float): Dropout probability applied to the attention
            probabilities and to the projected output
    """

    def __init__(self, hidden_size: int, num_attention_heads: int, dropout: float = 0.1):
        super().__init__()
        assert hidden_size % num_attention_heads == 0, "hidden_size must be divisible by num_attention_heads"

        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = hidden_size // num_attention_heads
        # Scores are divided by sqrt(head_dim) (scaled dot-product attention).
        self.scale = math.sqrt(self.head_dim)

        # Combined QKV projection
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass with bidirectional attention.

        Args:
            hidden_states: Input (batch_size, seq_len, hidden_size)
            attention_mask: Optional mask broadcastable to
                (batch_size, heads, seq_len, seq_len), typically
                (batch_size, 1, 1, seq_len). Entries equal to 0 mark key
                positions that must not be attended to.

        Returns:
            Output tensor (batch_size, seq_len, hidden_size). Rows whose keys
            are all masked receive uniform attention rather than NaN.
        """
        batch_size, seq_len, _ = hidden_states.size()

        # Project once, then split into per-head Q, K, V
        qkv = self.qkv_proj(hidden_states)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_attention_heads, self.head_dim)
        # (3, batch, heads, seq_len, head_dim) -> three (batch, heads, seq_len, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        # Scaled dot-product scores: (batch, heads, seq_len, seq_len)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale

        # Apply padding mask if provided (no causal masking)
        if attention_mask is not None:
            # Use the most negative finite value instead of -inf: a row that
            # is entirely masked (e.g. an all-padding sequence) would turn
            # into NaN under softmax with -inf and contaminate the batch.
            # Masked keys still get exactly zero probability whenever at
            # least one key in the row is unmasked.
            fill_value = torch.finfo(attn_weights.dtype).min
            attn_weights = attn_weights.masked_fill(attention_mask == 0, fill_value)

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.attn_dropout(attn_weights)

        # Weighted sum of values, then merge the heads back together
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)

        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)

        return attn_output


class FeedForward(nn.Module):
    """
    Two-layer position-wise MLP: expand, apply GELU, project back.

    Dropout is applied after the activation and again after the output
    projection.

    Args:
        hidden_size (int): Input/output dimension
        intermediate_size (int): Width of the inner (expanded) layer
        dropout (float): Dropout probability
    """

    def __init__(self, hidden_size: int, intermediate_size: int, dropout: float = 0.1):
        super().__init__()
        # Expansion and contraction projections
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map x (..., hidden_size) to a tensor of the same shape."""
        expanded = self.dropout(F.gelu(self.fc1(x)))
        return self.dropout(self.fc2(expanded))


class TransformerBlock(nn.Module):
    """
    One encoder layer: self-attention followed by a feed-forward network.

    Each sub-layer normalizes its input first and adds its output back onto
    the un-normalized input (residual connection).

    Args:
        hidden_size (int): Dimension of embeddings
        num_attention_heads (int): Number of attention heads
        intermediate_size (int): FFN intermediate dimension
        dropout (float): Dropout probability
        layer_norm_eps (float): Layer norm epsilon
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        intermediate_size: int,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5
    ):
        super().__init__()
        # Attention sub-layer
        self.ln_1 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.attn = MultiHeadAttention(hidden_size, num_attention_heads, dropout)
        # Feed-forward sub-layer
        self.ln_2 = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.ffn = FeedForward(hidden_size, intermediate_size, dropout)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Run both sub-layers, each wrapped in a residual connection."""
        # x <- x + Attn(LN(x))
        attn_out = self.attn(self.ln_1(hidden_states), attention_mask)
        hidden_states = hidden_states + attn_out

        # x <- x + FFN(LN(x))
        ffn_out = self.ffn(self.ln_2(hidden_states))
        return hidden_states + ffn_out


class BERTEmbeddingModel(nn.Module):
    """
    BERT-style Bidirectional Encoder for generating text embeddings.

    This model can be trained with:
        1. Masked Language Modeling (MLM) - BERT-style pretraining
        2. Contrastive Learning - SimCSE-style sentence embeddings
        3. Both objectives simultaneously

    The model generates dense vector representations of text that capture
    semantic meaning and can be used for similarity search, clustering, etc.

    Token id 0 is reserved for padding: its embedding row is kept at zero and
    it is used to derive the attention mask when none is supplied.

    Args:
        vocab_size (int): Vocabulary size
        hidden_size (int): Embedding dimension
        num_hidden_layers (int): Number of transformer blocks
        num_attention_heads (int): Number of attention heads
        intermediate_size (int): FFN intermediate dimension
        max_position_embeddings (int): Maximum sequence length
        dropout (float): Dropout probability
        layer_norm_eps (float): Layer norm epsilon
        pooling_mode (str): Pooling strategy - 'mean', 'cls', or 'max'
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        num_hidden_layers: int,
        num_attention_heads: int,
        intermediate_size: int,
        max_position_embeddings: int,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5,
        pooling_mode: str = "mean"
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.pooling_mode = pooling_mode

        # Embeddings
        self.token_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
        self.dropout = nn.Dropout(dropout)

        # Transformer encoder blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                hidden_size,
                num_attention_heads,
                intermediate_size,
                dropout,
                layer_norm_eps
            ) for _ in range(num_hidden_layers)
        ])

        self.ln_f = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

        # MLM head for masked language modeling
        self.mlm_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.LayerNorm(hidden_size, eps=layer_norm_eps),
            nn.Linear(hidden_size, vocab_size)
        )

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights with small random values."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.padding_idx is not None:
                # normal_ above overwrote the padding row; restore it to zero.
                with torch.no_grad():
                    module.weight[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def get_attention_mask(self, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Convert attention mask to proper format for attention mechanism.

        Args:
            attention_mask: Binary mask (batch_size, seq_len)
                           1 for real tokens, 0 for padding

        Returns:
            Extended mask (batch_size, 1, 1, seq_len)
        """
        # Insert singleton head and query axes so the per-key mask broadcasts
        # against attention scores of shape (batch, heads, seq_len, seq_len).
        return attention_mask.unsqueeze(1).unsqueeze(2)

    def mean_pooling(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Mean pooling over sequence (excluding padding tokens).

        This averages the token representations, weighted by the attention mask
        to exclude padding tokens from the mean calculation. A row with no real
        tokens yields a zero vector.

        Args:
            hidden_states: Token representations (batch_size, seq_len, hidden_size)
            attention_mask: Binary mask (batch_size, seq_len)

        Returns:
            Pooled embeddings (batch_size, hidden_size)
        """
        # (batch, seq_len, 1) float mask; broadcasts over the hidden axis
        mask = attention_mask.unsqueeze(-1).float()

        # Sum of embeddings at real-token positions
        sum_embeddings = torch.sum(hidden_states * mask, dim=1)

        # Number of real tokens per row, clamped to avoid division by zero
        token_counts = torch.clamp(mask.sum(dim=1), min=1e-9)

        return sum_embeddings / token_counts

    def cls_pooling(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        CLS token pooling (use first token representation).

        Args:
            hidden_states: Token representations (batch_size, seq_len, hidden_size)

        Returns:
            CLS embeddings (batch_size, hidden_size)
        """
        return hidden_states[:, 0, :]

    def max_pooling(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Max pooling over sequence (excluding padding).

        A row with no real tokens yields a zero vector, matching what
        mean_pooling returns for the same input.

        Args:
            hidden_states: Token representations (batch_size, seq_len, hidden_size)
            attention_mask: Binary mask (batch_size, seq_len)

        Returns:
            Max pooled embeddings (batch_size, hidden_size)
        """
        # (batch, seq_len, 1); True where the token is padding
        is_padding = attention_mask.unsqueeze(-1) == 0

        # Push padding positions to the lowest finite value of the tensor's
        # own dtype so they can never win the max. A literal such as -1e9 is
        # outside the float16 range.
        lowest = torch.finfo(hidden_states.dtype).min
        pooled = hidden_states.masked_fill(is_padding, lowest).max(dim=1).values

        # Rows made only of padding would otherwise expose the sentinel.
        has_real_token = (~is_padding).any(dim=1)  # (batch, 1)
        return pooled.masked_fill(~has_real_token, 0.0)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        masked_labels: Optional[torch.Tensor] = None,
        return_embeddings: bool = True
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the embedding model.

        Args:
            input_ids: Token indices (batch_size, seq_len)
            attention_mask: Binary mask (batch_size, seq_len), 1=real, 0=padding.
                When omitted, every token different from the padding id is
                treated as real.
            masked_labels: Labels for MLM task (batch_size, seq_len), -100=ignore
            return_embeddings: Whether to return pooled embeddings

        Returns:
            Dictionary containing:
                - embeddings: L2-normalized pooled sentence embeddings
                  (if return_embeddings=True)
                - mlm_logits: Token prediction logits (if masked_labels provided)
                - mlm_loss: Mean MLM loss over non-ignored labels; 0 when every
                  label is ignored (if masked_labels provided)
                - hidden_states: Token-level representations

        Raises:
            IndexError: If seq_len exceeds the size of the position table.
            ValueError: If pooling_mode is not 'mean', 'cls', or 'max'.
        """
        _, seq_len = input_ids.size()
        device = input_ids.device

        # Fail early with a readable message instead of an opaque embedding
        # lookup error on out-of-range position ids.
        max_positions = self.position_embeddings.num_embeddings
        if seq_len > max_positions:
            raise IndexError(
                f"Sequence length {seq_len} exceeds max_position_embeddings "
                f"({max_positions})"
            )

        # Create attention mask if not provided
        if attention_mask is None:
            attention_mask = (input_ids != self.token_embeddings.padding_idx).long()

        # Position ids 0..seq_len-1; the (seq_len, hidden) position embeddings
        # broadcast over the batch axis when added to the token embeddings.
        position_ids = torch.arange(seq_len, dtype=torch.long, device=device)

        # Get embeddings
        token_embeds = self.token_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)
        hidden_states = self.dropout(token_embeds + position_embeds)

        # Prepare attention mask for transformer
        extended_attention_mask = self.get_attention_mask(attention_mask)

        # Apply transformer blocks
        for block in self.blocks:
            hidden_states = block(hidden_states, extended_attention_mask)

        hidden_states = self.ln_f(hidden_states)

        # Prepare output dictionary
        outputs = {"hidden_states": hidden_states}

        # Compute pooled embeddings
        if return_embeddings:
            if self.pooling_mode == "mean":
                embeddings = self.mean_pooling(hidden_states, attention_mask)
            elif self.pooling_mode == "cls":
                embeddings = self.cls_pooling(hidden_states)
            elif self.pooling_mode == "max":
                embeddings = self.max_pooling(hidden_states, attention_mask)
            else:
                raise ValueError(f"Unknown pooling mode: {self.pooling_mode}")

            # Normalize embeddings for better similarity computation
            outputs["embeddings"] = F.normalize(embeddings, p=2, dim=1)

        # Compute MLM loss if labels provided
        if masked_labels is not None:
            mlm_logits = self.mlm_head(hidden_states)
            flat_labels = masked_labels.reshape(-1)

            # Sum the per-token losses and divide by the number of supervised
            # tokens ourselves. The built-in mean reduction divides by zero
            # and returns NaN when every label is -100; clamping the count
            # to 1 gives a zero loss that is still attached to the graph.
            loss_sum = F.cross_entropy(
                mlm_logits.reshape(-1, self.vocab_size),
                flat_labels,
                ignore_index=-100,
                reduction="sum"
            )
            num_targets = (flat_labels != -100).sum().clamp(min=1)

            outputs["mlm_logits"] = mlm_logits
            outputs["mlm_loss"] = loss_sum / num_targets

        return outputs


def compute_contrastive_loss(
    embeddings: torch.Tensor,
    temperature: float = 0.05
) -> torch.Tensor:
    """
    Compute a symmetric InfoNCE contrastive loss for embedding learning.

    The batch holds two views of each example. For an anchor in one view the
    positive is the same example in the other view; the negatives are all
    other examples from both views. An anchor's similarity with itself is
    excluded from the candidates.

    Args:
        embeddings: Normalized embeddings (batch_size * 2, hidden_size)
                   First half and second half are augmented pairs
        temperature: Temperature parameter for scaling similarities

    Returns:
        Contrastive loss scalar (mean of the A->B and B->A directions)

    Raises:
        ValueError: If the number of rows is odd, so the two views cannot be
            paired up.
    """
    total_rows = embeddings.size(0)
    if total_rows % 2 != 0:
        raise ValueError(
            f"embeddings must contain an even number of rows (two views per "
            f"example), got {total_rows}"
        )
    batch_size = total_rows // 2

    # Split into two views
    embeddings_a = embeddings[:batch_size]
    embeddings_b = embeddings[batch_size:]

    # Similarity matrices, each of shape (batch_size, batch_size)
    sim_ab = torch.matmul(embeddings_a, embeddings_b.t()) / temperature
    sim_aa = torch.matmul(embeddings_a, embeddings_a.t()) / temperature
    sim_bb = torch.matmul(embeddings_b, embeddings_b.t()) / temperature
    # B-vs-A similarities are just the transpose; no second matmul needed.
    sim_ba = sim_ab.t()

    # The diagonal of the within-view matrices is each anchor compared with
    # itself. For normalized embeddings that is always 1/temperature, the
    # largest possible logit, so leaving it in as a "negative" adds a constant
    # floor to the loss. Remove it from the softmax by pushing it to the
    # lowest finite value.
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=embeddings.device)
    lowest = torch.finfo(sim_aa.dtype).min
    sim_aa = sim_aa.masked_fill(self_mask, lowest)
    sim_bb = sim_bb.masked_fill(self_mask, lowest)

    # Labels: each example's positive pair is at the same index in the
    # cross-view block, which occupies the first batch_size logit columns.
    labels = torch.arange(batch_size, device=embeddings.device)

    # For each anchor in view A, the positive is the matching sample in view B
    loss_a = F.cross_entropy(
        torch.cat([sim_ab, sim_aa], dim=1),
        labels
    )

    # And symmetrically for anchors in view B
    loss_b = F.cross_entropy(
        torch.cat([sim_ba, sim_bb], dim=1),
        labels
    )

    return (loss_a + loss_b) / 2

