<a href="https://colab.research.google.com/github/muadmazahir/box-embedding-transformer/blob/main/Box_Embedding_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass

# Word2Box Model and Utilities

In [15]:
# =============================================================================
# BOX UTILITIES AND WRAPPERS
# =============================================================================

# Euler-Mascheroni constant. BoxTensor._log_soft_volume_adjusted subtracts
# 2 * euler_gamma * gumbel_beta from every side length before the softplus
# (presumably the Gumbel-box expected-volume correction - confirm against the
# Word2Box paper).
euler_gamma = 0.57721566490153286060

def _box_shape_ok(t: torch.Tensor, learnt_temp=False) -> bool:
    if len(t.shape) < 2:
        return False
    if not learnt_temp:
        if t.size(-2) != 2:
            return False
        return True
    else:
        if t.size(-2) != 4:
            return False
        return True

def _shape_error_str(tensor_name, expected_shape, actual_shape):
    return "Shape of {} has to be {} but is {}".format(
        tensor_name, expected_shape, tuple(actual_shape)
    )

class BoxTensor(object):
    """A wrapper around a single tensor that represents one or many boxes.

    ``data`` has shape ``(..., 2, num_dims)``: ``data[..., 0, :]`` is the
    lower-left corner ``z`` and ``data[..., 1, :]`` is the upper-right corner
    ``Z``. With ``learnt_temp=True`` the constructor instead requires size 4
    along dim -2 (the ``z`` / ``Z`` accessors still read rows 0 and 1).
    """

    def __init__(self, data: torch.Tensor, learnt_temp: bool = False) -> None:
        if not _box_shape_ok(data, learnt_temp):
            # Report the layout that is actually required in this mode.
            expected = "(**,4,num_dims)" if learnt_temp else "(**,2,num_dims)"
            raise ValueError(_shape_error_str("data", expected, data.shape))
        self.data = data
        super().__init__()

    def __repr__(self):
        return "box_tensor_wrapper(" + self.data.__repr__() + ")"

    @property
    def z(self) -> torch.Tensor:
        """Lower-left corner, ``data[..., 0, :]``."""
        return self.data[..., 0, :]

    @property
    def Z(self) -> torch.Tensor:
        """Upper-right corner, ``data[..., 1, :]``."""
        return self.data[..., 1, :]

    @classmethod
    def from_zZ(cls, z: torch.Tensor, Z: torch.Tensor):
        """Create a box by stacking ``z`` and ``Z`` along a new dim -2.

        Raises:
            ValueError: if ``z`` and ``Z`` have different shapes.
        """
        if z.shape != Z.shape:
            raise ValueError(
                "Shape of z and Z should be same but is {} and {}".format(
                    z.shape, Z.shape
                )
            )
        box_val: torch.Tensor = torch.stack((z, Z), -2)
        return cls(box_val)

    @classmethod
    def from_split(cls, t: torch.Tensor, dim: int = -1):
        """Create a BoxTensor by splitting ``t`` in half along ``dim``.

        The first half becomes ``z`` and the second half becomes ``Z``.

        Raises:
            ValueError: if ``t.size(dim)`` is odd.
        """
        len_dim = t.size(dim)
        if len_dim % 2 != 0:
            raise ValueError(
                "dim has to be even to split on it but is {}".format(len_dim)
            )
        half = len_dim // 2
        # narrow() returns views; from_zZ stacks them into a fresh tensor.
        z = t.narrow(dim, 0, half)
        Z = t.narrow(dim, half, half)
        return cls.from_zZ(z, Z)

    def _intersection(self, other, gumbel_beta: float = 1.0, bayesian: bool = False):
        """Return the corners ``(z, Z)`` of the intersection of ``self`` and ``other``.

        Without ``bayesian`` this is the hard elementwise max / min. With
        ``bayesian=True`` the corners are a log-sum-exp smoothing of max / min
        at temperature ``gumbel_beta``. They are then clamped so the soft box
        never extends beyond the hard intersection. If the smooth computation
        raises a RuntimeError, a warning is issued and the hard intersection
        is returned instead.
        """
        t1, t2 = self, other
        hard_z = torch.max(t1.z, t2.z)
        hard_Z = torch.min(t1.Z, t2.Z)
        if not bayesian:
            return hard_z, hard_Z

        try:
            z = gumbel_beta * torch.logaddexp(
                t1.z / gumbel_beta, t2.z / gumbel_beta
            )
            Z = -gumbel_beta * torch.logaddexp(
                -t1.Z / gumbel_beta, -t2.Z / gumbel_beta
            )
        except RuntimeError as err:
            import warnings

            # warnings (not print) so a failing forward pass does not spam
            # stdout on every batch.
            warnings.warn(f"Gumbel intersection is not possible ({err})")
            return hard_z, hard_Z
        return torch.max(z, hard_z), torch.min(Z, hard_Z)

    def gumbel_intersection_log_volume(self, other, volume_temp=1.0, intersection_temp: float = 1.0, scale=1.0):
        """Log-volume of the smoothed (Gumbel) intersection of ``self`` and ``other``.

        ``intersection_temp`` smooths the intersection corners. The volume
        uses ``_log_soft_volume_adjusted`` with ``volume_temp`` as the
        softplus ``beta``.
        """
        z, Z = self._intersection(other, gumbel_beta=intersection_temp, bayesian=True)
        return self._log_soft_volume_adjusted(
            z, Z, temp=volume_temp, gumbel_beta=intersection_temp, scale=scale
        )

    @staticmethod
    def _log_scale(scale, ref: torch.Tensor) -> torch.Tensor:
        """Return ``log(scale)`` as a tensor with ``ref``'s dtype and device.

        ``torch.as_tensor`` accepts Python numbers and tensors alike. Unlike
        ``torch.tensor(some_tensor)``, it does not emit a copy-construct
        warning and does not detach a tensor ``scale`` from autograd.
        """
        return torch.log(torch.as_tensor(scale, dtype=ref.dtype, device=ref.device))

    @classmethod
    def _log_soft_volume(cls, z: torch.Tensor, Z: torch.Tensor, temp: float = 1.0, scale = 1.0) -> torch.Tensor:
        """Softplus-smoothed log-volume of the box(es) with corners ``z``, ``Z``.

        Returns ``sum_d log(softplus(Z_d - z_d, beta=temp) + 1e-23) + log(scale)``
        reduced over the last dimension. The 1e-23 keeps the log finite when a
        side length underflows to 0.
        """
        side_lengths = F.softplus(Z - z, beta=temp)
        return torch.sum(torch.log(side_lengths + 1e-23), dim=-1) + cls._log_scale(scale, z)

    def log_soft_volume(self, temp: float = 1.0, scale = 1.0) -> torch.Tensor:
        """Softplus-smoothed log-volume of this box (see ``_log_soft_volume``)."""
        return self._log_soft_volume(self.z, self.Z, temp=temp, scale=scale)

    @classmethod
    def _log_soft_volume_adjusted(cls, z: torch.Tensor, Z: torch.Tensor, temp: float = 1.0,
                                gumbel_beta: float = 1.0, scale = 1.0) -> torch.Tensor:
        """Like ``_log_soft_volume`` but with every side shortened by ``2 * euler_gamma * gumbel_beta``."""
        side_lengths = F.softplus(Z - z - 2 * euler_gamma * gumbel_beta, beta=temp)
        return torch.sum(torch.log(side_lengths + 1e-23), dim=-1) + cls._log_scale(scale, z)

    def intersection_log_soft_volume(self, other, temp: float = 1.0, gumbel_beta: float = 1.0,
                                   bayesian: bool = False, scale = 1.0) -> torch.Tensor:
        """Log-volume of the intersection, without the Euler-gamma side adjustment."""
        z, Z = self._intersection(other, gumbel_beta, bayesian)
        return self._log_soft_volume(z, Z, temp=temp, scale=scale)

    @classmethod
    def get_wW(cls, z, Z):
        """Map corners ``(z, Z)`` to stored parameters; the identity for BoxTensor."""
        return z, Z

# =============================================================================
# BOX EMBEDDING MODULE
# =============================================================================

def _uniform_small(weight, emb_dim, box_type):
    """
    Creates a temporary tensor with uniform random values between 0.0 + 1e-7 and 0.9 - 1e-7
    Sets the first half of the embedding (z) to these random values
    Sets the second half (Z) to z + 0.1, creating small boxes with a fixed width of 0.1
    Uses the box type's get_wW method to convert these (z, Z) coordinates into the appropriate weight representation
    For BoxTensor: get_wW simply returns (z, Z) as-is
    """
    with torch.no_grad():
        temp = torch.zeros_like(weight)
        torch.nn.init.uniform_(temp, 0.0 + 1e-7, 1.0 - 0.1 - 1e-7)
        z = temp[..., :emb_dim]
        Z = z + 0.1
        w, W = box_type.get_wW(z, Z)
        weight[..., :emb_dim] = w
        weight[..., emb_dim : emb_dim * 2] = W

class BoxEmbedding(nn.Embedding):
    """An ``nn.Embedding`` whose rows are read as boxes.

    The constructor's ``box_embedding_dim`` is the full row width. Each row is
    split in half into a lower corner ``z`` and an upper corner ``Z``. After
    construction, ``self.box_embedding_dim`` holds that half width (the
    per-corner size). Weights are initialised with ``_uniform_small``.
    """

    # Maps the ``box_type`` name to the box wrapper class used by forward().
    box_types = {"BoxTensor": BoxTensor}

    def init_weights(self):
        """Overwrite ``self.weight`` with small boxes via ``_uniform_small``."""
        _uniform_small(
            self.weight,
            self.box_embedding_dim,
            self.box_types[self.box_type],
        )

    def __init__(self, num_embeddings: int, box_embedding_dim: int, box_type="BoxTensor") -> None:
        # An odd width cannot be split into two equal corners. Fail here
        # instead of on the first forward() call (where from_split would raise).
        if box_embedding_dim % 2 != 0:
            raise ValueError(
                "box_embedding_dim has to be even but is {}".format(box_embedding_dim)
            )
        super().__init__(num_embeddings, box_embedding_dim)
        self.box_type = box_type
        self.box = self.box_types[box_type]
        self.box_embedding_dim = box_embedding_dim // 2  # per-corner size
        self.init_weights()

    def forward(self, input: torch.Tensor):  # type: ignore
        """Look up ``input`` ids and return the rows wrapped as boxes.

        The raw embedding has shape ``(*input.shape, 2 * self.box_embedding_dim)``.
        The returned box's ``data`` has shape
        ``(*input.shape, 2, self.box_embedding_dim)``.
        """
        emb = super().forward(input)
        return self.box.from_split(emb)


# =============================================================================
# WORD2BOX MODEL
# =============================================================================

class Word2Box(nn.Module):
    """Scores word/context id pairs by the log-volume of their box intersection.

    Holds two independent ``BoxEmbedding`` tables (word and context) of row
    width ``embedding_dim``, i.e. ``embedding_dim // 2`` per box corner.
    ``batch_size`` and ``n_gram`` are accepted but not used.
    """

    def __init__(self, vocab_size, embedding_dim=50, batch_size=10, n_gram=4,
                 volume_temp=1.0, intersection_temp=1.0, box_type="BoxTensor"):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.volume_temp = volume_temp
        self.intersection_temp = intersection_temp
        self.box_type = box_type

        def _make_table():
            return BoxEmbedding(vocab_size, embedding_dim, box_type=box_type)

        # Same shape, separate parameters; created word-first.
        self.embeddings_word = _make_table()
        self.embeddings_context = _make_table()

    def forward(self, idx_word, idx_context):
        """Score each word against each of its context ids.

        Args:
            idx_word: word ids, shape (batch_size,).
            idx_context: context ids, shape (batch_size, 1 + negative samples).

        Returns:
            Log-volume scores with the leading shape of ``idx_context``.
            A Gumbel intersection is used unless ``intersection_temp == 0``.
        """
        # (batch_size,) -> (batch_size, 1) so each word box broadcasts against
        # all of its context boxes.
        word_boxes = self.embeddings_word(idx_word.unsqueeze(1))
        context_boxes = self.embeddings_context(idx_context)

        if self.intersection_temp != 0.0:
            return word_boxes.gumbel_intersection_log_volume(
                context_boxes,
                volume_temp=self.volume_temp,
                intersection_temp=self.intersection_temp,
            )
        return word_boxes.intersection_log_soft_volume(
            context_boxes, temp=self.volume_temp
        )

    @staticmethod
    def max_margin_loss(pos, neg, margin=5.0):
        """Hinge loss ``sum_j max(0, neg[:, j] - pos + margin)`` over dim 1."""
        hinge = neg - pos + margin
        floor = torch.tensor(0.0, device=pos.device, dtype=pos.dtype)
        return torch.max(floor, hinge).sum(dim=1)

# Transformer Building Blocks

In [7]:
# -----------------------------
# Config
# -----------------------------

@dataclass
class GPTConfig:
    """Hyper-parameters for BoxEmbeddingTransformer and its box-margin loss."""

    vocab_size: int             # number of token ids (rows of every embedding table)
    n_layers: int = 4           # number of TransformerBlocks
    n_heads: int = 4            # attention heads per block (must divide d_model)
    d_model: int = 256          # hidden width; also the Word2Box row width
    d_ff: int = 1024            # hidden width of each MLP
    max_seq_len: int = 64       # size of the positional-embedding table
    attn_dropout: float = 0.1   # dropout on attention weights
    resid_dropout: float = 0.1  # dropout on attention output and inside the MLP
    emb_dropout: float = 0.1    # dropout on token + position embeddings
    device: str = "cpu"
    num_negatives: int = 4      # sampled negatives per target position
    margin: float = 5.0         # hinge margin for Word2Box.max_margin_loss
    top_k_for_box: int = 32     # predicted tokens scored by the box loss

In [8]:
class LayerNorm(nn.Module):
    """Thin wrapper around ``nn.LayerNorm`` over the last ``d_model`` features.

    The wrapped module is stored as ``self.ln``, so the state-dict keys are
    ``ln.weight`` and ``ln.bias``.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(normalized_shape=d_model, eps=eps)

    def forward(self, x):
        normalized = self.ln(x)
        return normalized

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which position i attends only to positions <= i.

    Q, K and V come from one fused linear layer. The per-head outputs are
    concatenated and passed through an output projection. Dropout is applied
    to the attention weights and to the projected output.
    """

    def __init__(self, d_model: int, n_heads: int, attn_dropout: float, resid_dropout: float):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)  # fused Q|K|V projection
        self.proj = nn.Linear(d_model, d_model)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.resid_dropout = nn.Dropout(resid_dropout)

    def forward(self, x):
        """Apply causal attention to ``x`` of shape (B, L, d_model); returns the same shape."""
        batch, seq_len, width = x.shape

        def to_heads(t):
            # (B, L, C) -> (B, n_heads, L, d_head)
            return t.view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        q, k, v = map(to_heads, self.qkv(x).split(width, dim=-1))

        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        # True strictly above the diagonal, i.e. keys that lie in the future.
        future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))

        weights = self.attn_dropout(F.softmax(scores, dim=-1))
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.resid_dropout(self.proj(merged))

class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        expand = nn.Linear(d_model, d_ff)
        contract = nn.Linear(d_ff, d_model)
        # Stored as one Sequential so state-dict keys are net.0.* / net.3.*.
        self.net = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(dropout),
            contract,
            nn.Dropout(dropout),
        )

    def forward(self, x):
        out = x
        for layer in self.net:
            out = layer(out)
        return out

class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: ``x + attn(ln1(x))`` followed by ``x + mlp(ln2(x))``."""

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        width = cfg.d_model
        self.ln1 = LayerNorm(width)
        self.attn = CausalSelfAttention(width, cfg.n_heads, cfg.attn_dropout, cfg.resid_dropout)
        self.ln2 = LayerNorm(width)
        # The MLP reuses the residual dropout rate.
        self.mlp = MLP(width, cfg.d_ff, cfg.resid_dropout)

    def forward(self, x):
        attended = x + self.attn(self.ln1(x))
        return attended + self.mlp(self.ln2(attended))

# Transformer Model

In [9]:
class BoxEmbeddingTransformer(nn.Module):
    """GPT-style causal LM trained with a box-embedding max-margin loss.

    The token embedding and the LM head are both tied to
    ``word2box.embeddings_word.weight``. The same matrix therefore serves as
    input embedding, output projection and word-box parameters. Instead of
    cross-entropy, the loss scores each position's top-K predicted tokens
    (as word boxes) against the true target and sampled negatives (as context
    boxes). Each candidate's hinge loss is weighted by its renormalised
    top-K probability.
    """

    def __init__(self, cfg: GPTConfig):
        super().__init__()
        self.cfg = cfg
        self.word2box = Word2Box(vocab_size=cfg.vocab_size, embedding_dim=cfg.d_model)
        self.token_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.token_emb.weight = self.word2box.embeddings_word.weight  # tie token embs to word2box embs
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.emb_dropout)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_f = LayerNorm(cfg.d_model)
        self.lm_head = nn.Linear(cfg.d_model, cfg.vocab_size, bias=False)
        self.lm_head.weight = self.word2box.embeddings_word.weight  # tie lm_head weights to word2box embs

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if ``targets`` is given, the box loss.

        Args:
            idx: token ids of shape (B, L) with L <= cfg.max_seq_len.
            targets: optional ids of shape (B, L, 1 + cfg.num_negatives).
                ``[..., 0]`` is the true next token and ``[..., 1:]`` are
                negatives (the layout produced by add_negatives).

        Returns:
            ``(logits, loss)``: logits of shape (B, L, vocab_size); loss is a
            scalar tensor, or None when ``targets`` is None.

        Raises:
            ValueError: if L exceeds ``cfg.max_seq_len`` or ``targets`` has
                the wrong shape.
        """
        B, L = idx.shape
        if L > self.cfg.max_seq_len:
            # Otherwise pos_emb fails with an opaque index error (a device
            # assert on CUDA).
            raise ValueError(
                f"sequence length {L} exceeds max_seq_len {self.cfg.max_seq_len}"
            )
        pos = torch.arange(0, L, device=idx.device).unsqueeze(0)
        x = self.drop(self.token_emb(idx) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        logits = self.lm_head(self.ln_f(x))  # (B, L, vocab_size)

        total_loss = None if targets is None else self._box_loss(logits, targets)
        return logits, total_loss

    def _box_loss(self, logits, targets):
        """Expected max-margin box loss over each position's top-K predictions, averaged over positions."""
        B, L, V = logits.shape
        N = self.cfg.num_negatives
        if targets.shape != (B, L, 1 + N):
            raise ValueError(
                f"targets must have shape {(B, L, 1 + N)}, got {tuple(targets.shape)}"
            )
        BL = B * L
        K = min(self.cfg.top_k_for_box, V)  # topk() fails when K > vocab size

        # Top-K of the logits selects the same tokens as top-K of softmax(logits)
        # because softmax is monotonic. A softmax over just those K logits equals
        # the renormalised top-K probabilities, without underflow or a clamp.
        topk_logits, topk_idx = torch.topk(logits, K, dim=-1)      # (B, L, K)
        weights = torch.softmax(topk_logits, dim=-1).reshape(BL, K)

        # Pair every candidate word with its position's (target + negatives) row.
        cand_idx = topk_idx.reshape(BL * K)                        # (BL*K,)
        ctx = targets.reshape(BL, 1 + N)                           # (BL, 1+N)
        ctx_rep = ctx.unsqueeze(1).expand(BL, K, 1 + N).reshape(BL * K, 1 + N)

        scores = self.word2box(idx_word=cand_idx, idx_context=ctx_rep)  # (BL*K, 1+N)
        pos = scores[:, :1]   # (BL*K, 1)  score against the true token
        neg = scores[:, 1:]   # (BL*K, N)  scores against negatives

        loss_cand = self.word2box.max_margin_loss(pos, neg, margin=self.cfg.margin)
        loss_bl = (weights * loss_cand.view(BL, K)).sum(dim=-1)    # (BL,)
        return loss_bl.mean()

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=50, temperature=0.8):
        """Append ``max_new_tokens`` sampled tokens to ``idx`` (shape (B, T)).

        Runs without autograd. The context is cropped to the last
        ``cfg.max_seq_len`` tokens. ``temperature == 0`` selects greedy
        (argmax) decoding instead of dividing by zero.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.cfg.max_seq_len:]
            logits, _ = self(idx_cond)
            last = logits[:, -1, :]
            if temperature == 0:
                next_token = last.argmax(dim=-1, keepdim=True)
            else:
                probs = F.softmax(last / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_token), dim=1)
        return idx

# Data preparation + Model Setup

In [10]:
from datasets import load_dataset
from tokenizers import Tokenizer
from transformers import AutoTokenizer

In [None]:
# -------------------------
# 1) Load a small external dataset (WikiText-2 raw)
# -------------------------
ds = load_dataset("wikitext", "wikitext-2-raw-v1")

# Drop empty and whitespace-only lines so they contribute no tokens.
train_texts = [line for line in ds["train"]["text"] if line and not line.isspace()]
val_texts = [line for line in ds["validation"]["text"] if line and not line.isspace()]

BASE_MODEL_NAME = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)


# -------------------------
# Encode dataset and Calculate distribution
# -------------------------

def encode_and_build_dist(texts, power=0.75, tok=None):
    """Tokenize ``texts`` and build a smoothed unigram distribution.

    Args:
        texts: iterable of strings. Each is encoded without special tokens.
        power: exponent applied to raw counts before normalising (0.75 is the
            word2vec negative-sampling smoothing).
        tok: tokenizer providing ``encode(text, add_special_tokens=...)`` and
            ``len()``. Defaults to the module-level ``tokenizer``.

    Returns:
        ``(ids, dist)``. ``ids`` is a 1-D long tensor of all token ids in
        order. ``dist`` is a float tensor of length ``len(tok)`` that sums to
        1, with ``dist[i]`` proportional to ``count(i) ** power``.

    Raises:
        ValueError: if ``texts`` yields no tokens (the distribution would be
            NaN), or an id is >= ``len(tok)``.
    """
    if tok is None:
        tok = tokenizer

    ids = [tid for t in texts for tid in tok.encode(t, add_special_tokens=False)]
    if not ids:
        raise ValueError("encode_and_build_dist: texts produced no tokens")
    ids = torch.tensor(ids, dtype=torch.long)

    # bincount replaces a per-id Python loop over a Counter.
    vocab_size = len(tok)
    counts = torch.bincount(ids, minlength=vocab_size)
    if counts.numel() > vocab_size:
        raise ValueError(
            f"encode_and_build_dist: token id {counts.numel() - 1} >= len(tokenizer) {vocab_size}"
        )

    freqs = counts.to(torch.float).pow(power)  # 3/4 smoothing (word2vec trick)
    return ids, freqs / freqs.sum()

# Each split gets its own token stream and its own smoothed unigram distribution.
train_ids, train_dist = encode_and_build_dist(texts=train_texts)
val_ids, val_dist = encode_and_build_dist(texts=val_texts)


In [12]:
# -------------------------
# 4) Batching utility for contiguous language modeling
# -------------------------
block_size = 128  # context window
batch_size = 12
negative_samples = 3  # NOTE: not used; training draws cfg.num_negatives negatives

def make_batch(ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample ``batch_size`` random contiguous windows from the 1-D ``ids``.

    Returns ``(x, y)``, each of shape (batch_size, block_size). ``y`` is ``x``
    shifted by one token (the next-token targets).

    Raises:
        ValueError: if ``ids`` has fewer than ``block_size + 1`` tokens.
    """
    # Valid window starts are 0 .. numel - block_size - 1, because y's last
    # index (start + block_size) must stay < numel.
    n_starts = ids.numel() - block_size
    if n_starts < 1:
        raise ValueError(
            f"make_batch needs at least {block_size + 1} tokens, got {ids.numel()}"
        )
    starts = torch.randint(0, n_starts, (batch_size,))
    window = (starts.unsqueeze(1) + torch.arange(block_size)).to(ids.device)  # (batch_size, block_size)
    return ids[window], ids[window + 1]

def add_negatives(
    y_true: torch.Tensor,           # shape - (batch_size, block_size)
    sampling_distn: torch.Tensor,   # shape - (vocab_size,) non-negative weights
    num_negatives: int
) -> torch.Tensor:
    """Append sampled negative ids to every target position.

    Negatives are drawn independently (with replacement) from
    ``sampling_distn`` conditioned on differing from the position's true id.
    This is the same as setting the true id's weight to 0 and renormalising.

    Returns (batch_size, block_size, 1+num_negatives) tensor:
      - [:, :, 0] is the true target id
      - [:, :, 1:] are sampled negatives

    Raises:
        ValueError: if some true id carries all of the sampling mass, so no
            valid negative exists.
    """
    batch_size, block_size = y_true.shape
    true_ids = y_true.reshape(-1).long()  # (batch_size * block_size,)
    n_rows = true_ids.numel()

    if n_rows == 0 or num_negatives == 0:
        # Nothing to draw; skip multinomial entirely.
        negs = torch.empty(n_rows, num_negatives, dtype=torch.long, device=y_true.device)
    else:
        # Rejection sampling from the shared 1-D distribution. It gives the
        # same conditional distribution as a per-row copy with the true id
        # zeroed, but that copy has batch_size * block_size * vocab_size
        # entries. This path needs only O(vocab_size) memory.
        other_mass = sampling_distn.sum() - sampling_distn[true_ids]
        if bool((other_mass <= 0).any()):
            raise ValueError(
                "add_negatives: a true id holds all of the sampling mass; no negative can be drawn"
            )
        negs = torch.multinomial(sampling_distn, n_rows * num_negatives, replacement=True)
        negs = negs.view(n_rows, num_negatives)
        clash = negs == true_ids.unsqueeze(1)
        while bool(clash.any()):
            # Redraw only the entries that hit their row's true id.
            negs[clash] = torch.multinomial(sampling_distn, int(clash.sum()), replacement=True)
            clash = negs == true_ids.unsqueeze(1)

    negs = negs.view(batch_size, block_size, num_negatives)
    return torch.cat([y_true.unsqueeze(-1), negs], dim=-1)  # (batch_size, block_size, 1+num_negatives)

In [16]:
# -------------------------
# 5) Configure and build model
# -------------------------
cfg = GPTConfig(
    # len(tokenizer) rather than tokenizer.vocab_size. It is the size of the
    # sampling distributions built by encode_and_build_dist, and it includes
    # any added tokens. Every id that can appear in the data or among sampled
    # negatives therefore has an embedding row.
    vocab_size=len(tokenizer),
    n_layers=4,
    n_heads=4,
    d_model=256,
    d_ff=1024,
    max_seq_len=block_size,
    attn_dropout=0.1,
    resid_dropout=0.1,
    emb_dropout=0.1,
    device="cuda" if torch.cuda.is_available() else "cpu",
    num_negatives=4,  # used by the training loop (negative_samples is not)
    margin=5.0,
    top_k_for_box=32,
)

model = BoxEmbeddingTransformer(cfg)
model.to(cfg.device)

# -------------------------
# 6) Optimizer
# -------------------------
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1)

# Training Loop

In [None]:
# -------------------------
# 7) Training loop
# -------------------------
steps = 1400  # modest run on a small dataset
eval_every = 200
grad_clip = 1.0

model.train()
for step in range(1, steps + 1):
    x, y = make_batch(train_ids)  # x and y - (batch_size, block_size)
    y_new = add_negatives(y, train_dist, cfg.num_negatives)  # (batch_size, block_size, 1+cfg.num_negatives)
    x = x.to(model.cfg.device)
    y_new = y_new.to(model.cfg.device)

    _, loss = model(x, targets=y_new)

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    if grad_clip is not None:
        nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    optimizer.step()

    if step % eval_every == 0 or step == 1:
        # Single-batch validation estimate: dropout off, no autograd.
        model.eval()
        with torch.no_grad():
            vx, vy = make_batch(val_ids)
            vy_new = add_negatives(vy, val_dist, cfg.num_negatives)
            vx = vx.to(model.cfg.device)
            vy_new = vy_new.to(model.cfg.device)
            _, vloss = model(vx, targets=vy_new)
        # .item() prints plain floats instead of tensor reprs with grad_fn/device.
        print(f"step {step:4d} | train loss {loss.item():.4f} | val loss {vloss.item():.4f}")
        model.train()
        model.train()

In [None]:
# Colab only: mount Google Drive at /content/drive so the checkpoint cells
# can write to and read from /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Saves the whole nn.Module object (not model.state_dict()). The load cell
# below relies on this because it expects a full module back.
# NOTE(review): a whole-module save needs the same class definitions to be
# importable at load time; state_dict() would be more portable.
torch.save(model, "/content/drive/MyDrive/box-embedding-transformer-model.pth")

In [None]:
# Load Model
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects
# (needed because the save cell stores the whole module). Only load files
# you created or trust.
model = torch.load("/content/drive/MyDrive/box-embedding-transformer-model.pth", map_location=cfg.device, weights_only=False,)

In [None]:
# -------------------------
# 8) Sample a short completion
# -------------------------
model.eval()
prompt_text = "Wikipedia is a free online"
# Encode without special tokens, matching how the training data was encoded
# (encode_and_build_dist uses add_special_tokens=False). The model never saw BOS.
prompt_ids = tokenizer.encode(prompt_text, add_special_tokens=False)
prompt = torch.tensor(prompt_ids, dtype=torch.long, device=cfg.device).unsqueeze(0)
with torch.no_grad():
    out_ids = model.generate(prompt, max_new_tokens=10)

# out_ids[0] holds len(prompt_ids) + 10 ids; decode and show all of them.
decoded = tokenizer.decode(out_ids[0].tolist(), skip_special_tokens=True)
print(decoded)