# In [None]:
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import dataclasses
import math

# =============================================================================
# 1. SETUP THE TRANSFORMER (Increased Limit)
# =============================================================================
# Context window size in tokens. It is the default GPTConfig.block_size and the
# truncation length used by the demo, which encodes one character per token.
CONTEXT_LIMIT = 512

@dataclasses.dataclass
class GPTConfig:
    """Hyperparameters for the small GPT model defined in this file."""
    # Context window: size of the position-embedding table and of the causal mask.
    block_size: int = CONTEXT_LIMIT
    # Number of token ids: rows of the token embedding and width of the output logits.
    vocab_size: int = 128
    # Number of stacked transformer blocks.
    n_layer: int = 2
    # Number of attention heads per block.
    n_head: int = 4
    # Embedding (channel) width shared by all layers.
    n_embd: int = 128
    # NOTE(review): not read by any module visible in this file.
    dropout: float = 0.0
    # Whether the attention projections (c_attn / c_proj) carry a bias term.
    bias: bool = False

class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which each position only attends to itself and earlier positions.

    Args:
        config: object exposing ``n_embd``, ``n_head``, ``block_size`` and ``bias``.

    Raises:
        ValueError: if ``config.n_embd`` is not divisible by ``config.n_head``.
    """

    def __init__(self, config):
        super().__init__()
        # Fail fast: the per-head reshape in forward() is only valid when the
        # embedding width splits evenly across the heads.
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
            )
        # One fused projection producing queries, keys and values.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection applied after the heads are merged again.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Lower-triangular mask of ones, shaped (1, 1, block_size, block_size).
        # Registered as a buffer so it moves with the module across devices.
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (B, T, C) with ``C == n_embd`` and ``T <= block_size``.

        Returns:
            Tensor of shape (B, T, C).

        Raises:
            ValueError: if ``T`` exceeds the size of the causal mask.
        """
        B, T, C = x.size()
        max_len = self.bias.size(-1)
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds the attention block size {max_len}"
            )
        head_dim = C // self.n_head

        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        # Scaled dot-product scores, shape (B, n_head, T, T).
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # Positions above the diagonal are "future" tokens: give them -inf so
        # they receive zero weight after the softmax.
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)

        y = att @ v
        # Merge the heads back into a single channel dimension.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)

class Block(nn.Module):
    """One transformer layer: LayerNorm -> attention and LayerNorm -> MLP, each wrapped in a residual add."""

    def __init__(self, config):
        super().__init__()
        width = config.n_embd
        hidden = 4 * width  # the feed-forward layer expands the width fourfold

        self.ln1 = nn.LayerNorm(width)
        self.attn = CausalSelfAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
        )

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.attn(self.ln1(x))
        x = x + attended
        fed_forward = self.mlp(self.ln2(x))
        return x + fed_forward

class GPT(nn.Module):
    """Decoder-only transformer language model.

    Args:
        config: object exposing ``vocab_size``, ``block_size``, ``n_layer`` and
            ``n_embd`` (plus whatever the per-layer ``Block`` reads).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),   # token embeddings
            wpe=nn.Embedding(config.block_size, config.n_embd),   # learned position embeddings
            h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd),                     # final normalisation
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Re-initialise the weights of Linear and Embedding layers from N(0, 0.02)."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx):
        """Compute next-token logits.

        Args:
            idx: LongTensor of token ids with shape (batch, time). If ``time``
                exceeds ``config.block_size`` only the last ``block_size``
                positions are used.

        Returns:
            Logits of shape (batch, min(time, block_size), vocab_size).
        """
        block_size = self.config.block_size
        if idx.size(1) > block_size:
            # Keep the most recent tokens; older ones fall out of the window.
            idx = idx[:, -block_size:]
        t = idx.size(1)
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device)
        # The position embedding (t, n_embd) broadcasts over the batch dimension.
        x = self.transformer.wte(idx) + self.transformer.wpe(pos)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        return self.lm_head(x)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=20, temperature=1.0, top_k=None):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Args:
            idx: LongTensor of shape (batch, time) holding the prompt token ids.
            max_new_tokens: number of tokens to sample.
            temperature: positive scale dividing the logits before the softmax;
                the default of 1.0 leaves the distribution unchanged.
            top_k: if given, sample only among the ``top_k`` highest-scoring
                tokens (capped at the vocabulary size); ``None`` samples from
                the full vocabulary.

        Returns:
            LongTensor of shape (batch, time + max_new_tokens).

        Raises:
            ValueError: if ``temperature`` is not positive or ``top_k`` is below 1.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        if top_k is not None and top_k < 1:
            raise ValueError(f"top_k must be at least 1, got {top_k}")

        block_size = self.config.block_size
        for _ in range(max_new_tokens):
            # Condition on at most the last block_size tokens.
            idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
            logits = self(idx_cond)
            # Only the prediction at the final position is needed.
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                k = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, k).values[:, [-1]]
                # Everything scoring below the k-th best gets zero probability.
                logits = logits.masked_fill(logits < kth_best, float('-inf'))
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

# =============================================================================
# 2. INTERACTIVE DEMO (LONG PARAGRAPH VERSION)
# =============================================================================
def run_long_context_demo():
    """Interactively show how a fixed context window drops the oldest text.

    Reads a secret, a long distraction paragraph and a question from stdin,
    joins them, encodes the result one character per token and, if it is
    longer than CONTEXT_LIMIT, keeps only the last CONTEXT_LIMIT tokens. It
    then reports whether the secret is still fully, partially or no longer
    inside the window, and samples 15 tokens from the model.

    The model is built with freshly initialised weights (nothing is loaded),
    so the sampled reply is not a meaningful answer.
    """
    # Character-level tokenizer: each character's id is its position in `chars`.
    chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!?,.:; "
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = dict(enumerate(chars))

    def encode(text):
        """Map characters to ids; characters outside `chars` fall back to id 0."""
        return [stoi.get(c, 0) for c in text]

    def decode(ids):
        """Map ids back to characters, dropping ids that have no character."""
        return ''.join(itos.get(i, '') for i in ids)

    # Initialize Model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = GPT(GPTConfig())
    model.to(device)
    model.eval()

    print("\n" + "="*60)
    print("      TRANSFORMER MEMORY DEMO (Long Context)")
    print(f"      Memory Limit: {CONTEXT_LIMIT} characters")
    print("="*60)

    # --- INPUT 1: THE SECRET ---
    secret = input("\n[1] Enter the SECRET (e.g., 'Kriti is the winner'): ")
    if not secret:
        secret = "Kriti is the winner."

    # --- INPUT 2: THE DISTRACTION (LONG) ---
    print("\n[2] Enter a VERY LONG PARAGRAPH to push the secret out.")
    print(f"    (Must be > {CONTEXT_LIMIT - len(secret)} chars to trigger forgetting)")
    print("    Tip: Copy/Paste a long text block here.")
    distraction = input("    Paste here: ")

    # Combine History
    history = secret + " " + distraction
    print(f"\n--> Total Text Length: {len(history)} chars")
    print(f"--> Memory Limit:      {CONTEXT_LIMIT} chars")

    # --- INPUT 3: THE QUESTION ---
    prompt = input("\n[3] Enter your QUESTION (e.g., 'Who is the winner?'): ")
    full_input_str = history + " " + prompt

    # Convert to Tensor of shape (1, number_of_characters)
    input_ids = torch.tensor(encode(full_input_str), dtype=torch.long, device=device).unsqueeze(0)

    # --- SHOW WHAT HAPPENS INTERNALLY ---
    print("\n" + "-"*40)
    print("INTERNAL MEMORY SCAN:")

    # Visualize the buffer state
    overflow_amount = input_ids.size(1) - CONTEXT_LIMIT
    if overflow_amount > 0:
        print(f"[!] MEMORY OVERFLOW by {overflow_amount} characters.")
        actual_input = input_ids[:, -CONTEXT_LIMIT:]
        # The secret occupies the first len(secret) tokens, so it is only
        # entirely gone once the overflow covers all of them.
        if overflow_amount >= len(secret):
            print("[!] The beginning (Secret) has been DELETED.")
            secret_status = "GONE"
        else:
            print(f"[!] The first {overflow_amount} of {len(secret)} Secret characters have been DELETED.")
            secret_status = "PARTIAL"
    else:
        print("[OK] Memory is NOT full yet.")
        print("     The model can still see the secret (No Forgetting).")
        actual_input = input_ids
        secret_status = "VISIBLE"

    # Decode what the model actually sees
    visible_text = decode(actual_input[0].tolist())

    # Show only the tail of memory to keep output clean; the ellipsis is added
    # only when something was really cut off.
    preview = visible_text if len(visible_text) <= 100 else f"...{visible_text[-100:]}"
    print(f"\n[WHAT THE MODEL SEES NOW (Last 100 chars)]:\n'{preview}'")
    print("-" * 40)

    # Prediction
    if secret_status == "GONE":
        print("\n[PREDICTION]: The model will FAIL (Secret deleted).")
    elif secret_status == "PARTIAL":
        print("\n[PREDICTION]: The model will likely FAIL (Secret partially deleted).")
    else:
        print("\n[PREDICTION]: The model MIGHT work (Secret still in memory).")

    # --- GENERATE ANSWER ---
    print("\n--> Model is answering...")
    output_ids = model.generate(actual_input, max_new_tokens=15)
    # generate() returns prompt + new tokens; keep only the newly sampled part.
    new_tokens = output_ids[0, actual_input.size(1):].tolist()
    reply = decode(new_tokens)

    print("\n" + "="*50)
    print(f"MODEL REPLY: {reply}")
    print("="*50)

# Start the demo only when this file is executed directly (as a script or a
# notebook cell), so that importing it does not block on interactive input.
if __name__ == "__main__":
    run_long_context_demo()