<a href="https://colab.research.google.com/github/iliemihai/TODOS/blob/main/JEPA_MODEL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **LLM-JEPA nanoGPT**

> **Goal:** explain why next-token prediction (NTP) isn’t enough, what **LLM-JEPA** adds, what **[PRED]** tokens are, and how to add a minimal JEPA loss to **nanoGPT** (baseline vs. JEPA).  
> **Reference:** “LLM-JEPA: Large Language Models Meet Joint Embedding Predictive Architectures” (arXiv:2509.14252).

---

## **1) The problem LLM-JEPA solves**

- **NTP trains surface continuation**, not cross-view alignment.
- Example:
  - **Text view:** “five plus three is equal to eight”
  - **Code view:** `5 + 3 = 8`
- With NTP alone, the model learns to generate each string correctly, but there’s no explicit pressure for their **internal vectors to be near each other** in representation space.
- **LLM-JEPA adds a representation-space term**: make a **prediction from Text** that **matches the Code embedding**, while keeping NTP for generation. The aim is to improve cross-view mapping **without sacrificing generation**.

---

## **2) LLM-JEPA objective**

- Keep your normal **next-token loss** (cross-entropy).
- Add a second term that **measures distance** between:
  - a **predicted embedding from the Text view** (Pred(Text))
  - and the **embedding of the Code view** (Enc(Code), with gradients stopped).
- Total loss = **next-token cross-entropy** + **lambda × distance(Pred(Text), Enc(Code))**.
- Distance can be **cosine** (recommended, scale-invariant) or **L2**.
- Start with **lambda** in the range **0.3 to 1.0**.
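
The combination above can be sketched on toy numbers. This is only an illustration in plain Python so that it runs without PyTorch; the helper names `cross_entropy`, `cosine_distance`, and `total_loss` are made up here and are not functions from nanoGPT or the paper.

```python
import math

def cross_entropy(logits, target):
    """Next-token cross-entropy for one position: -log softmax(logits)[target]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_sum_exp - logits[target]

def cosine_distance(a, b):
    """1 - cos(a, b): 0 when the vectors point the same way, 2 when opposite."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return 1.0 - dot / (norm_a * norm_b)

def total_loss(logits, target, z_pred, z_tgt, lam):
    """CE + lam * distance(Pred(Text), Enc(Code)); z_tgt is treated as a constant."""
    return cross_entropy(logits, target) + lam * cosine_distance(z_pred, z_tgt)

# Toy values, made up for illustration
logits = [2.0, 0.5, -1.0]
z_pred = [1.0, 0.0]
z_tgt = [0.0, 1.0]  # orthogonal to z_pred, so the cosine distance is 1.0

baseline = total_loss(logits, 0, z_pred, z_tgt, lam=0.0)  # plain NTP
jepa = total_loss(logits, 0, z_pred, z_tgt, lam=0.5)      # NTP + 0.5 * distance
print(f"baseline={baseline:.4f} jepa={jepa:.4f}")
```

With `lam=0.0` the total reduces to the plain next-token loss, which is why the same training script can serve as both baseline and JEPA run.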

---

## **3) What are prediction tokens [PRED] and why use them?**

- Append **k** learnable **[PRED]** tokens to the **end** of the Text sequence.
- Take the **hidden state at the final position** (the last [PRED] token) as **Pred(Text)**.
- In a decoder-only LM, the **last position can attend to all previous tokens**, so the final [PRED] acts like a **tiny predictor head** implemented by the model’s own layers (no separate MLP needed).
- **Capacity knob:**
  - **k = 0** → identity (Pred(x) = Enc(x))
  - **k = 1 or 2** → usually enough refinement
  - More than 2 gives diminishing returns on small setups.
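
To see why the final [PRED] position can act as a predictor, it helps to look at the causal mask for T text tokens followed by k [PRED] slots. A minimal sketch in plain Python; `causal_mask` is a hypothetical helper written for this illustration, and the sizes are toy values.

```python
def causal_mask(n):
    """Lower-triangular mask: row i has 1 where position i may attend to position j (j <= i)."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

T, k = 4, 2  # 4 text tokens followed by 2 [PRED] slots (toy sizes)
mask = causal_mask(T + k)

for i, row in enumerate(mask):
    label = f"text{i}" if i < T else f"PRED{i - T}"
    print(f"{label:>6}: {row}")

# The last row belongs to the final [PRED]: it attends to every text token
# and to the earlier [PRED] slot.
last_row = mask[-1]
```

No text row has a 1 in a [PRED] column, so under this mask the appended slots cannot change the hidden states of the text tokens themselves.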

---

## **4) Design notes**

- **Last token:** in GPT-style models, the last position sees the full prefix. Its hidden state is the best single “summary/prediction” vector.
- **Why not EOS or the last content token:** reusing either would force one position to serve **two jobs** (generation and alignment). Dedicated [PRED] positions **separate these concerns**.
- **Cosine distance:** robust and scale-invariant for embedding alignment. L2 also works, especially if you normalize embeddings.
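
The scale-invariance point can be checked numerically. The sketch below uses plain Python and made-up vectors; the helper names are illustrative only. It also shows why L2 behaves well once embeddings are normalized: for unit vectors, the squared L2 distance equals twice the cosine distance.

```python
import math

def l2_normalize(v):
    """Rescale v to unit length."""
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def cosine_distance(a, b):
    """1 - cosine similarity, computed on unit-normalized copies."""
    a, b = l2_normalize(a), l2_normalize(b)
    return 1.0 - sum(x * y for x, y in zip(a, b))

def squared_l2(a, b):
    """Squared Euclidean distance."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

a = [1.0, 2.0, 3.0]
b = [2.0, 1.0, 0.5]
a_scaled = [10.0 * x for x in a]  # same direction, 10x the norm

cos_d = cosine_distance(a, b)
cos_d_scaled = cosine_distance(a_scaled, b)              # unchanged by rescaling
l2_raw = squared_l2(a, b)
l2_raw_scaled = squared_l2(a_scaled, b)                  # grows with the scale
l2_unit = squared_l2(l2_normalize(a), l2_normalize(b))   # equals 2 * cos_d

print(f"cosine: {cos_d:.4f} vs {cos_d_scaled:.4f}")
print(f"squared L2 (raw): {l2_raw:.2f} vs {l2_raw_scaled:.2f}")
print(f"squared L2 (unit vectors): {l2_unit:.4f}, 2 * cosine: {2 * cos_d:.4f}")
```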

---

## **5) What is the “encoder” in an LLM here?**

- **Enc(x)** just means: run the LLM to get hidden states, then **select/pool one vector** for the sequence.
- In nanoGPT, the “encoder” is simply the **last hidden state** (or a mean-pool) taken from the decoder-only stack.
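
The two pooling choices differ only in how they reduce the per-position hidden states of one sequence to a single vector. Here is a sketch on a toy nested list standing in for the hidden states of one sequence (T rows of C values); `last_pool` and `mean_pool` are illustrative names, not nanoGPT functions.

```python
def last_pool(hidden):
    """Enc(x) as the final position's hidden state. hidden: T rows of C floats."""
    return list(hidden[-1])

def mean_pool(hidden):
    """Enc(x) as the average over all positions."""
    T = len(hidden)
    return [sum(col) / T for col in zip(*hidden)]

# Toy hidden states: T=3 positions, C=2 channels (values made up)
hidden = [
    [1.0, 0.0],
    [3.0, 2.0],
    [5.0, 4.0],
]

print(last_pool(hidden))  # [5.0, 4.0]
print(mean_pool(hidden))  # [3.0, 2.0]
```

In a causal model the last position is the only one that has seen the whole sequence, whereas a mean-pool also mixes in positions that saw only a prefix.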

---

## **6) Minimal training**

```python
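### GPT LLM training objective + JEPA
import torch
import torch.nn as nn
import torch.nn.functional as F

jepa_lambda = 0.0  # 0.0 = NTP baseline; try 0.3 to 1.0 to enable the JEPA term
jepa_k = 1         # number of [PRED] tokens appended to the first view
jepa_split = 0.5   # fraction of each sequence used as the first view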


def cosine_loss(a, b):
    """Mean cosine distance (1 - cosine similarity) over the batch."""
    a = F.normalize(a, dim=-1)
    b = F.normalize(b, dim=-1)
    return (1.0 - (a * b).sum(dim=-1)).mean()


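class GPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        # ... existing nanoGPT setup (self.config, self.transformer, self.lm_head) stays as-is ...
        # Learnable [PRED] embedding, shared by all k appended positions
        self.pred_token = nn.Parameter(torch.zeros(1, 1, config.n_embd))
        nn.init.normal_(self.pred_token, mean=0.0, std=0.02)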

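    def _forward_from_tok_emb(self, tok_emb):
        """Run the transformer stack on pre-computed token embeddings of shape (B, T, C)."""
        T = tok_emb.size(1)
        # T includes any appended [PRED] tokens, so it must still fit the position table
        assert T <= self.config.block_size, f"sequence length {T} exceeds block_size {self.config.block_size}"
        device = tok_emb.device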
        pos = torch.arange(0, T, dtype=torch.long, device=device).unsqueeze(0) # (1,T)
        pos_emb = self.transformer.wpe(pos)  # (1,T,C)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        return x  # (B, T, C)

    @torch.no_grad()
    def encode_last(self, idx: torch.Tensor) -> torch.Tensor:
        """
        Enc(x): last hidden state of the sequence x.
        Decorated with torch.no_grad(), so the result carries no gradient (stop-grad target).
        idx: (B, T) token ids
        returns: (B, C)
        """
        tok_emb = self.transformer.wte(idx)            # (B, T, C)
        x = self._forward_from_tok_emb(tok_emb)        # (B, T, C)
        return x[:, -1, :]

    def encode_last_with_pred(self, idx: torch.Tensor, k_pred: int = 1) -> torch.Tensor:
        """
        Pred(x, k): append k learnable PRED tokens to x, return last hidden (the final PRED).
        idx: (B, T), k_pred >= 0
        returns: (B, C)
        """
        B, T = idx.shape
        tok_emb = self.transformer.wte(idx)            # (B, T, C)
        if k_pred > 0:
            pred_tok = self.pred_token.expand(B, k_pred, -1)  # (B, k, C)
            tok_emb = torch.cat([tok_emb, pred_tok], dim=1)   # (B, T+k, C)
        x = self._forward_from_tok_emb(tok_emb)        # (B, T(+k), C)
        return x[:, -1, :]                             # last token (last PRED if k>0)


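# --- in the training loop (nanoGPT's train.py), for each batch (x, y) ---
logits, loss = model(x, y)  # loss is the next-token cross-entropy (CE)
loss_ce = loss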

if jepa_lambda > 0.0:
    # No paired Text/Code data here: the first part of each sequence stands in for
    # the Text view (xa) and the remainder for the Code view (xb).
    T = x.size(1)
    split = max(2, int(T * jepa_split))  # at least 2 tokens
    xa = x[:, :split]
    xb = x[:, split:]

    if xb.size(1) >= 2:
        # Pred(TEXT) → last hidden after k PRED tokens (gradients ON)
        z_pred = model.encode_last_with_pred(xa, k_pred=jepa_k)     # (B, C)
        # Enc(CODE)  → last hidden of the target view (stop-grad ON: no gradients)
        with torch.no_grad():
            z_tgt = model.encode_last(xb)                           # (B, C)

        loss_jepa = cosine_loss(z_pred, z_tgt)
        loss = loss_ce + jepa_lambda * loss_jepa
    else:
        loss = loss_ce
        loss_jepa = torch.tensor(0.0, device=loss.device)
else:
    loss = loss_ce
    loss_jepa = torch.tensor(0.0, device=loss.device)

# backward/update exactly as before...
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()

# (optional) log both CE and JEPA terms
if iter_num % 100 == 0:
    print(f"iter {iter_num}: CE={loss_ce.item():.4f} JEPA={loss_jepa.item():.4f} total={loss.item():.4f}")
```
