<a href="https://colab.research.google.com/github/iliemihai/TODOS/blob/main/JEPA_MODEL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# === GPT LLM training objective + JEPA auxiliary loss: hyperparameters ===

# Weight of the JEPA term in `loss = loss_ce + jepa_lambda * loss_jepa`;
# 0.0 disables the JEPA branch entirely (pure cross-entropy training).
jepa_lambda = 0.0

# Number of learnable PRED tokens appended to the context view before the
# prediction is read from the final position.
jepa_k = 1

# Fraction of each sequence used as the context (prefix) view; the remainder
# is the target (suffix) view.
jepa_split = 0.5


def cosine_loss(a, b):
    """Return the mean cosine distance between ``a`` and ``b`` along the last dim.

    Both inputs are L2-normalised over their last dimension (``F.normalize``),
    the per-row cosine similarity is taken, and ``1 - similarity`` is averaged
    over all remaining dimensions. The result is a scalar tensor in [0, 2].
    """
    a_unit = F.normalize(a, dim=-1)
    b_unit = F.normalize(b, dim=-1)
    similarity = (a_unit * b_unit).sum(dim=-1)
    distance = 1.0 - similarity
    return distance.mean()


class GPT(nn.Module):
    """GPT members needed for a JEPA auxiliary objective.

    Adds a learnable PRED token plus two encoders built on the shared trunk:
    ``encode_last`` (stop-gradient target) and ``encode_last_with_pred``
    (predictor with ``k`` appended PRED tokens).

    NOTE(review): ``self.transformer`` (``wte``, ``wpe``, ``drop``, ``h``,
    ``ln_f``) is used below but not built in this ``__init__`` — presumably
    these members are merged into an existing GPT definition; confirm.
    """

    def __init__(self, config):
        super().__init__()
        # One learnable embedding shared by every PRED slot; shape (1, 1, n_embd).
        self.pred_token = nn.Parameter(torch.zeros(1, 1, config.n_embd))
        nn.init.normal_(self.pred_token, mean=0.0, std=0.02)

    def _forward_from_tok_emb(self, tok_emb):
        """Run position embedding, dropout, every block and the final LayerNorm.

        tok_emb: (B, T, C) token embeddings (possibly with PRED tokens appended).
        returns: (B, T, C) final hidden states.
        raises: ValueError if T exceeds the number of positions in ``wpe``.
        """
        T = tok_emb.size(1)
        # Fail with a clear message instead of an out-of-range embedding lookup
        # (which is an opaque device-side assert on CUDA). Skipped if ``wpe``
        # does not expose ``num_embeddings``.
        max_positions = getattr(self.transformer.wpe, "num_embeddings", None)
        if max_positions is not None and T > max_positions:
            raise ValueError(
                f"sequence length {T} exceeds the {max_positions} positions of wpe"
            )
        pos = torch.arange(0, T, dtype=torch.long, device=tok_emb.device).unsqueeze(0)  # (1, T)
        pos_emb = self.transformer.wpe(pos)  # (1, T, C)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        return self.transformer.ln_f(x)  # (B, T, C)

    @torch.no_grad()
    def encode_last(self, idx: torch.Tensor) -> torch.Tensor:
        """Enc(x): final hidden state at the last position of ``idx``.

        The ``@torch.no_grad()`` decorator means this ALWAYS runs without
        autograd, so the result never carries gradients (a stop-gradient target).

        idx: (B, T) token ids
        returns: (B, C)
        """
        tok_emb = self.transformer.wte(idx)  # (B, T, C)
        x = self._forward_from_tok_emb(tok_emb)  # (B, T, C)
        return x[:, -1, :]

    def encode_last_with_pred(self, idx: torch.Tensor, k_pred: int = 1) -> torch.Tensor:
        """Pred(x, k): append ``k_pred`` PRED tokens and return the last hidden state.

        Gradients flow through both the trunk and ``pred_token``.

        idx: (B, T) token ids
        k_pred: number of PRED tokens to append; 0 reads the last real token.
        returns: (B, C) hidden state of the final PRED token (or last token if k_pred == 0).
        raises: ValueError if ``k_pred`` is negative, or if T + k_pred exceeds
            the positions available in ``wpe``.
        """
        if k_pred < 0:
            # expand() treats -1 as "keep size", which would silently append one token.
            raise ValueError(f"k_pred must be >= 0, got {k_pred}")
        B, _ = idx.shape  # also enforces a 2-D (B, T) input
        tok_emb = self.transformer.wte(idx)  # (B, T, C)
        if k_pred > 0:
            pred_tok = self.pred_token.expand(B, k_pred, -1)  # (B, k, C)
            tok_emb = torch.cat([tok_emb, pred_tok], dim=1)  # (B, T+k, C)
        x = self._forward_from_tok_emb(tok_emb)  # (B, T(+k), C)
        return x[:, -1, :]  # last token (last PRED if k > 0)


# Standard LM forward; the `loss` returned by the model is the cross-entropy term.
logits, loss = model(x, y)
loss_ce = loss

# Defaults: no JEPA contribution unless the branch below computes one
# (JEPA disabled, or the target view is too short).
loss = loss_ce
loss_jepa = torch.tensor(0.0, device=loss_ce.device)

if jepa_lambda > 0.0:
    # Split each sequence into a context view (prefix) and a target view (suffix).
    T = x.size(1)
    split = max(2, int(T * jepa_split))  # keep at least 2 context tokens
    xa = x[:, :split]
    xb = x[:, split:]

    if xb.size(1) >= 2:
        # Pred(xa, k): last hidden state after k PRED tokens; gradients flow here.
        z_pred = model.encode_last_with_pred(xa, k_pred=jepa_k)  # (B, C)
        # Enc(xb): stop-gradient target from the same model (encode_last is
        # decorated with no_grad; the explicit context keeps that visible here).
        with torch.no_grad():
            z_tgt = model.encode_last(xb)  # (B, C)

        loss_jepa = cosine_loss(z_pred, z_tgt)
        loss = loss_ce + jepa_lambda * loss_jepa

# Single optimizer step on the combined loss.
# NOTE(review): confirm this matches the surrounding training loop (e.g. any
# gradient accumulation or loss scaling it normally applies).
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()

# (optional) log both CE and JEPA terms
if iter_num % 100 == 0:
    print(f"iter {iter_num}: CE={loss_ce.item():.4f} JEPA={loss_jepa.item():.4f} total={loss.item():.4f}")
