<a href="https://colab.research.google.com/github/iliemihai/TODOS/blob/main/JEPA_MODEL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
### GPT LLM TRAINING OBJECTIVE + JEPA
# JEPA hyperparameters read by the training step below.
# jepa_lambda: weight of the JEPA term; 0.0 disables it (pure next-token baseline).
jepa_lambda = 0.0
# jepa_k: number of learnable [PRED] tokens appended to the predictor view.
jepa_k = 1
# jepa_split: fraction of each training sequence used as the predictor view
# (prefix); the remaining suffix is the target view.
jepa_split = 0.5


def cosine_loss(a, b):
    """Mean cosine distance ``1 - cos(a, b)`` taken over the last dimension.

    Both inputs are L2-normalised along ``dim=-1`` with ``F.normalize``, so
    each per-row distance lies in [0, 2]: 0 for parallel vectors, 1 for
    orthogonal ones and 2 for opposite ones. The rows are then averaged.
    """
    a_unit = F.normalize(a, dim=-1)
    b_unit = F.normalize(b, dim=-1)
    cos_sim = (a_unit * b_unit).sum(dim=-1)
    return (1.0 - cos_sim).mean()


class GPT(nn.Module):
    """JEPA-specific additions to nanoGPT's ``GPT`` (excerpt).

    Only the pieces LLM-JEPA needs are shown. The methods rely on a
    ``self.transformer`` container with ``wte``, ``wpe``, ``drop``, ``h`` and
    ``ln_f`` members, which this excerpt does not build. It is presumably
    created by the rest of nanoGPT's ``__init__`` (TODO confirm).
    """

    def __init__(self, config):
        super().__init__()
        # Single learnable [PRED] embedding of shape (1, 1, n_embd). It is
        # broadcast to every appended prediction slot; initialised N(0, 0.02).
        self.pred_token = nn.Parameter(torch.zeros(1, 1, config.n_embd))
        nn.init.normal_(self.pred_token, mean=0.0, std=0.02)

    def _forward_from_tok_emb(self, tok_emb):
        """Run the transformer stack on already-embedded tokens.

        Adds position embeddings for positions ``0..T-1``, then applies
        ``transformer.drop``, every block in ``transformer.h`` and
        ``transformer.ln_f``.

        tok_emb: (B, T, C) token embeddings
        returns: (B, T, C) final hidden states
        Raises ValueError when T exceeds the size of the ``wpe`` table.
        """
        T = tok_emb.size(1)
        max_T = getattr(self.transformer.wpe, "num_embeddings", None)
        if max_T is not None and T > max_T:
            # Check explicitly: appended [PRED] tokens make the sequence longer
            # than the raw input, and an out-of-range ``wpe`` lookup on CUDA
            # only surfaces as an opaque device-side assert.
            raise ValueError(
                f"sequence length {T} exceeds the position table size {max_T}"
            )
        pos = torch.arange(0, T, dtype=torch.long, device=tok_emb.device).unsqueeze(0)  # (1, T)
        pos_emb = self.transformer.wpe(pos)  # (1, T, C)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        return x  # (B, T, C)

    @torch.no_grad()
    def encode_last(self, idx: torch.Tensor) -> torch.Tensor:
        """
        Enc(x): final hidden state at the last position of x.

        Always runs without autograd (decorated with ``@torch.no_grad()``), so
        the result is a stop-gradient target. An extra ``torch.no_grad()`` at
        the call site is redundant but harmless.
        NOTE(review): ``transformer.drop`` is still applied here. If it is a
        dropout layer and the model is in train mode, the target is
        stochastic; confirm this is intended.

        idx: (B, T) token ids
        returns: (B, C)
        """
        tok_emb = self.transformer.wte(idx)            # (B, T, C)
        x = self._forward_from_tok_emb(tok_emb)        # (B, T, C)
        return x[:, -1, :]

    def encode_last_with_pred(self, idx: torch.Tensor, k_pred: int = 1) -> torch.Tensor:
        """
        Pred(x, k): append k learnable [PRED] tokens to x and return the final
        hidden state of the last one.

        With ``k_pred <= 0`` nothing is appended. The result is then the same
        quantity as Enc(x), but computed with gradients enabled. The extended
        length T + k must fit the position table (see
        ``_forward_from_tok_emb``).

        idx: (B, T) token ids, k_pred >= 0
        returns: (B, C)
        """
        batch = idx.size(0)
        tok_emb = self.transformer.wte(idx)            # (B, T, C)
        if k_pred > 0:
            # All k slots share one learned [PRED] embedding; their inputs
            # differ only by the position embedding added downstream.
            pred_tok = self.pred_token.expand(batch, k_pred, -1)  # (B, k, C)
            tok_emb = torch.cat([tok_emb, pred_tok], dim=1)       # (B, T+k, C)
        x = self._forward_from_tok_emb(tok_emb)        # (B, T(+k), C)
        return x[:, -1, :]                             # last token (last [PRED] if k > 0)


# One training step: next-token cross-entropy (NTP) plus, when
# jepa_lambda > 0, the LLM-JEPA cosine term between Pred(prefix) and
# Enc(suffix).
# NOTE(review): if ``model`` is wrapped (e.g. DDP), the JEPA helper methods may
# need to be called on the unwrapped module; confirm against the training script.
logits, loss = model(x, y)  # loss is the next-token cross-entropy (CE)
loss_ce = loss

# Defaults for the plain-NTP path; replaced below when the JEPA term applies.
loss = loss_ce
loss_jepa = torch.tensor(0.0, device=loss_ce.device)

if jepa_lambda > 0.0:
    # Two "views" cut from the same sequence: prefix xa plays Text and
    # suffix xb plays Code.
    T = x.size(1)
    split = max(2, int(T * jepa_split))  # prefix keeps at least 2 tokens
    xa = x[:, :split]
    xb = x[:, split:]

    # Only add the term when the target view has at least 2 tokens.
    if xb.size(1) >= 2:
        # Pred(TEXT): hidden state of the last of k appended [PRED] tokens
        # (gradients ON).
        z_pred = model.encode_last_with_pred(xa, k_pred=jepa_k)     # (B, C)
        # Enc(CODE): last hidden state of the target view. Stop-gradient is ON
        # (no_grad), so the target is not pulled toward the prediction.
        with torch.no_grad():
            z_tgt = model.encode_last(xb)                           # (B, C)

        loss_jepa = cosine_loss(z_pred, z_tgt)
        loss = loss_ce + jepa_lambda * loss_jepa

# backward/update exactly as before...
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()

# (optional) log both CE and JEPA terms
if iter_num % 100 == 0:
    print(f"iter {iter_num}: CE={loss_ce.item():.4f} JEPA={loss_jepa.item():.4f} total={loss.item():.4f}")


In [2]:
!pip install PyMuPDF

Collecting PyMuPDF
  Downloading pymupdf-1.26.4-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (3.4 kB)
Downloading pymupdf-1.26.4-cp39-abi3-manylinux_2_28_x86_64.whl (24.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.1/24.1 MB[0m [31m87.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: PyMuPDF
Successfully installed PyMuPDF-1.26.4


In [5]:
import os, textwrap

pdf_in = "./2509.14252v1.pdf"  # source paper (arXiv:2509.14252)
pdf_out = "./LLM-JEPA_annotated_on_paper.pdf"  # destination for the annotated copy

# PyMuPDF is imported under its module name ``fitz``.
try:
    import fitz  # PyMuPDF
except Exception as e:
    # Re-raise with an actionable message. ``from e`` keeps the original
    # import failure visible as the direct cause in the traceback.
    raise RuntimeError(f"PyMuPDF not available: {e}") from e

#doc = fitz.open(pdf_in)

def write_note(page, text, y_top=36, box_height=140, *, fontsize=9.5, min_fontsize=6.0):
    """Draw a bordered note box near the top of ``page`` and fill it with ``text``.

    The box spans the page width minus 36 pt margins on each side, from
    ``y_top`` to ``y_top + box_height``. ``page.insert_textbox`` writes
    nothing when the text does not fit. The font size is therefore reduced
    in 0.5 pt steps, down to ``min_fontsize``, until the text fits.

    Args:
        page: PyMuPDF page to annotate.
        text: note text; PyMuPDF wraps it to the box width.
        y_top: top edge of the box.
        box_height: height of the box.
        fontsize: first font size tried.
        min_fontsize: smallest font size tried before giving up.

    Returns:
        The last ``insert_textbox`` return value. It is >= 0 on success
        (unused vertical space) and < 0 if the text still did not fit at the
        smallest size tried; in that case no text was written and a warning
        is emitted.
    """
    rect = page.rect
    # Leave 36 pt side margins; the box is an overlay at the top of the page.
    box = fitz.Rect(36, y_top, rect.width - 36, y_top + box_height)

    # Unfilled black border only: the shape is committed as an overlay, so a
    # fill would hide the page content underneath.
    shape = page.new_shape()
    shape.draw_rect(box)
    shape.finish(width=0.8, color=(0, 0, 0))
    shape.commit()

    size = fontsize
    while True:
        rc = page.insert_textbox(
            box,
            text,
            fontsize=size,
            fontname="helv",
            align=0,  # left
        )
        # Stop on success, or when the next step would go below the minimum.
        if rc >= 0 or size - 0.5 < min_fontsize:
            break
        size -= 0.5

    if rc < 0:
        import warnings
        warnings.warn(
            f"note did not fit in box even at fontsize {size} "
            f"(short by {-rc:.1f} pt); nothing was written"
        )
    return rc

# Per-page annotation text, keyed by 0-based page index (key 0 is labelled
# "Page 1"). Each value is one string built from implicitly concatenated
# literals. The annotation loop further down (currently commented out)
# passes each entry to write_note().
notes = {
    0: ("ANNOTATION — Page 1 (Abstract & Fig.1)\n"
        "• LLM-JEPA = NTP (next-token) + a latent-space alignment term between two views (e.g., Text↔Code).\n"
        "• Fig.1 shows consistent accuracy gains across models (Llama3, Gemma2, OpenELM) and datasets "
        "(NL-RX SYNTH/TURK, Spider, GSM8K); right-hand plot: faster and higher accuracy vs baseline.\n"
        "• Claim: improves across finetuning and pretraining settings, with robustness to overfitting."),
    1: ("ANNOTATION — Page 2 (Two-view setup, Fig.2)\n"
        "• JEPA for NLP: treat Text and Code as two views of the same underlying knowledge.\n"
        "• Loss adds d(Pred(Enc(Text)), Enc(Code)) to NTP; tasks shown: NL→Regex (NL-RX) and NL→SQL (Spider).\n"
        "• Non-trivial paired views are key; this paper focuses on datasets that provide them."),
    2: ("ANNOTATION — Page 3 (Objective, §2.1–2.2)\n"
        "• NTP (Eq.1): standard autoregressive cross-entropy.\n"
        "• LLM-JEPA (Eq.2): L = ∑ CE + λ·d(Pred(Enc(Text)), Enc(Code)).\n"
        "• Enc(x): last-layer, last-token hidden. Pred uses k appended [PRED] tokens; take the last [PRED] hidden.\n"
        "• Metric d: cosine (recommended) or L2. Training uses two extra forwards (Text & Code) to avoid cross-view leakage."),
    3: ("ANNOTATION — Page 4 (Table 1: Pretraining)\n"
        "• Pretraining on NL-RX-SYNTH with Llama-3.2-1B: accuracy rises from 54.38±1.70 to 60.59±1.01 "
        "with λ=2, k=3 (p≈2.94e-4). JEPA improves representation even before finetuning."),
    7: ("ANNOTATION — Page 8 (Table 3: LoRA)\n"
        "• LoRA finetuning on NL-RX-SYNTH: JEPA > NTP at all ranks. At rank 512, JEPA reaches 72.41±2.94 — "
        "matching full finetuning (70.42±2.36), while baseline NTP is 50.18±5.15."),
    8: ("ANNOTATION — Page 9 (Table 4 & 5: Pretrain→Finetune, γ/λ)\n"
        "• Pretrain on paraphrases with JEPA → better downstream (RottenTomatoes +1.19pp; Yelp +0.69pp) using same lrs.\n"
        "• Ablation with γ/λ shows NTP remains essential: with γ=0, model collapses to empty outputs; JEPA acts as a strong regularizer."),
    9: ("ANNOTATION — Page 10 (Fig.3 grid, Fig.4)\n"
        "• Best (k,λ) can occur anywhere; adjacent cells similar → small sweeps suffice (k∈{0..4}, λ∈{0.5,1,2,4}).\n"
        "• Fig.4: NTP loss is similar across methods, but JEPA-pred loss is minimized only with JEPA; accuracy gap attributed to the JEPA term."),
    10: ("ANNOTATION — Page 11 (Table 8 & Fig.5)\n"
         "• Across families on NL-RX-SYNTH: Gemma2 +9.5pp, OpenELM +13.3pp, OLMo +0.43pp (best configs).\n"
         "• Fig.5: Under LoRA, JEPA continues improving across epochs while baseline overfits — supports ‘robust to overfitting’ claim."),
    11: ("ANNOTATION — Page 12 (Table 9 & 10)\n"
         "• Llama-3.2-1B across datasets: SYNTH +14.17pp, TURK +8.45pp, GSM8K +4.0pp, Spider +3.03pp (best configs).\n"
         "• Table 10: Near-linear mapping Enc(Text)→Enc(Code); least-squares error drops by orders of magnitude vs baseline."),
    12: ("ANNOTATION — Page 13 (Fig.6 t-SNE, Table 11 sizes)\n"
         "• t-SNE: JEPA imposes clean structure on Text/Code embeddings; NTP disrupts base structure.\n"
         "• Scaling: significant gains across sizes (e.g., Llama-3.2-1B +14.2pp; 3B +2.6pp; 8B large jump under startswith metric; OLMo-7B +0.5pp)."),
    13: ("ANNOTATION — Page 14 (Fig.7 singular values)\n"
         "• Singular values of Enc(Text)−Enc(Code) shrink drastically with JEPA → mapping confined to a narrow subspace.\n"
         "• Supports the hypothesis that JEPA regularizes geometry and promotes near-linear cross-view mapping.")
}

# Insert notes
# NOTE: the annotation pass is disabled. The `doc = fitz.open(pdf_in)` call,
# this loop, and the save/close below are all commented out, so no annotated
# PDF is written. The trailing `pdf_out` only echoes the intended output path
# as the cell's displayed value.
# for pno, text in notes.items():
#     if pno < len(doc):
#         #page = doc[pno]
#         # Make sure we don't overlap content too much: push box a bit down if the page has a header
#         y_top = 36 if pno != 0 else 60
#         #write_note(page, text, y_top=y_top, box_height=150)

# Save annotated PDF
# doc.save(pdf_out)
# doc.close()

pdf_out


'./LLM-JEPA_annotated_on_paper.pdf'

## **LLM-JEPA nanoGPT**

> **Goal:** explain why next-token prediction (NTP) isn’t enough, what **LLM-JEPA** adds, what **[PRED]** tokens are, and how to add a minimal JEPA loss to **nanoGPT** (baseline vs. JEPA).  
> **Reference:** “LLM-JEPA: Large Language Models Meet Joint Embedding Predictive Architectures” (arXiv:2509.14252).

---

## **1) The problem LLM-JEPA solves**

- **NTP trains surface continuation**, not cross-view alignment.
- Example:
  - **Text view:** “five plus three is equal to eight”
  - **Code view:** `5 + 3 = 8`
- With NTP alone, the model learns to generate each string correctly, but there’s no explicit pressure for their **internal vectors to be near each other** in representation space.
- **LLM-JEPA adds a representation-space term**: make a **prediction from Text** that **matches the Code embedding**, while keeping NTP for generation. This improves cross-view mapping **without sacrificing generation**.

---

## **2) LLM-JEPA objective**

- Keep your normal **next-token loss** (cross-entropy).
- Add a second term that **measures distance** between:
  - a **predicted embedding from the Text view** (Pred(Text))
  - and the **embedding of the Code view** (Enc(Code), with gradients stopped).
- Total loss = **next-token cross-entropy** + **lambda × distance(Pred(Text), Enc(Code))**.
- Distance can be **cosine** (recommended, scale-invariant) or **L2**.
- Start with **lambda** in the range **0.3 to 1.0**.

---

## **3) What are prediction tokens [PRED] and why use them?**

- Append **k** learnable **[PRED]** tokens to the **end** of the Text sequence.
- Take the **last token’s hidden state** as **Pred(Text)**.
- In a decoder-only LM, the **last position can attend to all previous tokens**, so the final [PRED] acts like a **tiny predictor head** implemented by the model’s own layers (no separate MLP needed).
- **Capacity knob:**
  - **k = 0** → identity (Pred(x) = Enc(x))
  - **k = 1 or 2** → usually enough refinement
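  - Larger k adds predictor capacity, with diminishing returns on small setups. The best k is task-dependent: the paper sweeps k ∈ {0..4}, and its pretraining result uses k = 3.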

---

## **4) Design details**

- **Last token:** in GPT-style (causal) models, the last position is the only one that attends to the full prefix. That makes its hidden state the natural single “summary/prediction” vector.
- **Why not reuse the EOS or last content token?** That would force one position to serve **two jobs** (next-token generation and cross-view alignment). Dedicated [PRED] positions **separate these concerns**.
- **Cosine distance:** robust and scale-invariant for embedding alignment. L2 also works, especially if you normalize embeddings.

---

## **5) What is the “encoder” in an LLM here?**

- **Enc(x)** just means: run the LLM to get hidden states, then **select/pool one vector** for the sequence.
- In nanoGPT, “encoder” here is simply the **last hidden state** (or a mean-pool) from the decoder-only stack.

---

## **6) Minimal training-loop change**

```python
### GPT LLM TRAINING OBJECTIVE + JEPA
# JEPA hyperparameters read by the training step below.
# jepa_lambda: weight of the JEPA term; 0.0 disables it (pure next-token baseline).
jepa_lambda = 0.0
# jepa_k: number of learnable [PRED] tokens appended to the predictor view.
jepa_k = 1
# jepa_split: fraction of each training sequence used as the predictor view
# (prefix); the remaining suffix is the target view.
jepa_split = 0.5


def cosine_loss(a, b):
    """Mean cosine distance ``1 - cos(a, b)`` taken over the last dimension.

    Both inputs are L2-normalised along ``dim=-1`` with ``F.normalize``, so
    each per-row distance lies in [0, 2]: 0 for parallel vectors, 1 for
    orthogonal ones and 2 for opposite ones. The rows are then averaged.
    """
    a_unit = F.normalize(a, dim=-1)
    b_unit = F.normalize(b, dim=-1)
    cos_sim = (a_unit * b_unit).sum(dim=-1)
    return (1.0 - cos_sim).mean()


class GPT(nn.Module):
    """JEPA-specific additions to nanoGPT's ``GPT`` (excerpt).

    Only the pieces LLM-JEPA needs are shown. The methods rely on a
    ``self.transformer`` container with ``wte``, ``wpe``, ``drop``, ``h`` and
    ``ln_f`` members, which this excerpt does not build. It is presumably
    created by the rest of nanoGPT's ``__init__`` (TODO confirm).
    """

    def __init__(self, config):
        super().__init__()
        # Single learnable [PRED] embedding of shape (1, 1, n_embd). It is
        # broadcast to every appended prediction slot; initialised N(0, 0.02).
        self.pred_token = nn.Parameter(torch.zeros(1, 1, config.n_embd))
        nn.init.normal_(self.pred_token, mean=0.0, std=0.02)

    def _forward_from_tok_emb(self, tok_emb):
        """Run the transformer stack on already-embedded tokens.

        Adds position embeddings for positions ``0..T-1``, then applies
        ``transformer.drop``, every block in ``transformer.h`` and
        ``transformer.ln_f``.

        tok_emb: (B, T, C) token embeddings
        returns: (B, T, C) final hidden states
        Raises ValueError when T exceeds the size of the ``wpe`` table.
        """
        T = tok_emb.size(1)
        max_T = getattr(self.transformer.wpe, "num_embeddings", None)
        if max_T is not None and T > max_T:
            # Check explicitly: appended [PRED] tokens make the sequence longer
            # than the raw input, and an out-of-range ``wpe`` lookup on CUDA
            # only surfaces as an opaque device-side assert.
            raise ValueError(
                f"sequence length {T} exceeds the position table size {max_T}"
            )
        pos = torch.arange(0, T, dtype=torch.long, device=tok_emb.device).unsqueeze(0)  # (1, T)
        pos_emb = self.transformer.wpe(pos)  # (1, T, C)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        return x  # (B, T, C)

    @torch.no_grad()
    def encode_last(self, idx: torch.Tensor) -> torch.Tensor:
        """
        Enc(x): final hidden state at the last position of x.

        Always runs without autograd (decorated with ``@torch.no_grad()``), so
        the result is a stop-gradient target. An extra ``torch.no_grad()`` at
        the call site is redundant but harmless.
        NOTE(review): ``transformer.drop`` is still applied here. If it is a
        dropout layer and the model is in train mode, the target is
        stochastic; confirm this is intended.

        idx: (B, T) token ids
        returns: (B, C)
        """
        tok_emb = self.transformer.wte(idx)            # (B, T, C)
        x = self._forward_from_tok_emb(tok_emb)        # (B, T, C)
        return x[:, -1, :]

    def encode_last_with_pred(self, idx: torch.Tensor, k_pred: int = 1) -> torch.Tensor:
        """
        Pred(x, k): append k learnable [PRED] tokens to x and return the final
        hidden state of the last one.

        With ``k_pred <= 0`` nothing is appended. The result is then the same
        quantity as Enc(x), but computed with gradients enabled. The extended
        length T + k must fit the position table (see
        ``_forward_from_tok_emb``).

        idx: (B, T) token ids, k_pred >= 0
        returns: (B, C)
        """
        batch = idx.size(0)
        tok_emb = self.transformer.wte(idx)            # (B, T, C)
        if k_pred > 0:
            # All k slots share one learned [PRED] embedding; their inputs
            # differ only by the position embedding added downstream.
            pred_tok = self.pred_token.expand(batch, k_pred, -1)  # (B, k, C)
            tok_emb = torch.cat([tok_emb, pred_tok], dim=1)       # (B, T+k, C)
        x = self._forward_from_tok_emb(tok_emb)        # (B, T(+k), C)
        return x[:, -1, :]                             # last token (last [PRED] if k > 0)


# One training step: next-token cross-entropy (NTP) plus, when
# jepa_lambda > 0, the LLM-JEPA cosine term between Pred(prefix) and
# Enc(suffix).
# NOTE(review): if ``model`` is wrapped (e.g. DDP), the JEPA helper methods may
# need to be called on the unwrapped module; confirm against the training script.
logits, loss = model(x, y)  # loss is the next-token cross-entropy (CE)
loss_ce = loss

# Defaults for the plain-NTP path; replaced below when the JEPA term applies.
loss = loss_ce
loss_jepa = torch.tensor(0.0, device=loss_ce.device)

if jepa_lambda > 0.0:
    # Two "views" cut from the same sequence: prefix xa plays Text and
    # suffix xb plays Code.
    T = x.size(1)
    split = max(2, int(T * jepa_split))  # prefix keeps at least 2 tokens
    xa = x[:, :split]
    xb = x[:, split:]

    # Only add the term when the target view has at least 2 tokens.
    if xb.size(1) >= 2:
        # Pred(TEXT): hidden state of the last of k appended [PRED] tokens
        # (gradients ON).
        z_pred = model.encode_last_with_pred(xa, k_pred=jepa_k)     # (B, C)
        # Enc(CODE): last hidden state of the target view. Stop-gradient is ON
        # (no_grad), so the target is not pulled toward the prediction.
        with torch.no_grad():
            z_tgt = model.encode_last(xb)                           # (B, C)

        loss_jepa = cosine_loss(z_pred, z_tgt)
        loss = loss_ce + jepa_lambda * loss_jepa

# backward/update exactly as before...
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()

# (optional) log both CE and JEPA terms
if iter_num % 100 == 0:
    print(f"iter {iter_num}: CE={loss_ce.item():.4f} JEPA={loss_jepa.item():.4f} total={loss.item():.4f}")
