In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import random, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Choose the accelerator once; later cells read the global `device`.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda") if _has_cuda else torch.device("cpu")
print(f"Device: {device}")

# Seed each RNG library used in the notebook (torch, stdlib random, numpy).
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

if _has_cuda:
    torch.cuda.empty_cache()
    _total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU Memory: {_total_gb:.1f} GB")

In [None]:
ALPH = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
ALPH_SET = set(ALPH)

def make_noise(noise_len: int) -> bytes:
    """Return ``noise_len`` random filler bytes that cannot be confused with the password.

    Two kinds of byte are never produced:
      * 10 (newline), which separates the PASSWORD / NOISE / Q lines;
      * any byte of ``ALPH``, so the password alphabet only appears in the password.

    Bytes are drawn uniformly from the remaining 219 values with the stdlib
    ``random`` module, which the setup cell seeds. Samples are therefore
    reproducible, and a whole line costs one call instead of one syscall per byte.
    A non-positive ``noise_len`` gives ``b""``.
    """
    if noise_len <= 0:
        return b""
    allowed = [b for b in range(256) if b != 10 and b not in ALPH_SET]
    return bytes(random.choices(allowed, k=noise_len))

def rand_pw(pw_len=6):
    """Return a ``pw_len``-byte password drawn uniformly from ``ALPH``.

    ``random.choice`` gives all 36 symbols equal probability. ``byte % 36`` did not,
    because 256 is not a multiple of 36 and A-D came up 8/256 of the time instead of
    7/256. It also follows the seeded stdlib RNG, so samples are reproducible.
    """
    return bytes(random.choice(ALPH) for _ in range(pw_len))

def bytes_to_tokens(b: bytes):
    """Map each byte of ``b`` to its integer value, as a 1-D ``torch.long`` tensor."""
    codes = [int(byte) for byte in b]
    return torch.tensor(codes, dtype=torch.long)

def make_sample(noise_len: int, pw_len: int = 6):
    """Build one next-byte-prediction example in the undelimited format.

    The byte stream is the newline-separated lines ``PASSWORD=<pw>``,
    ``NOISE=<noise>`` and ``Q:WHAT_IS_PASSWORD?``, followed by ``A:<pw>``.

    NOTE(review): later cells redefine ``make_sample`` with ``<<<pw>>>`` delimiters;
    this version is shadowed before it is ever called.

    Returns:
        x: input tokens, ``seq[:-1]``.
        y: target tokens, ``seq[1:]``.
        m: bool mask over ``y`` that selects the final ``pw_len`` targets (the answer).
        pw: the password bytes.
        prompt: tokens of everything up to and including ``A:``.
    """
    pw = rand_pw(pw_len)

    noise = b"NOISE=" + make_noise(noise_len) + b"\n"
    prompt_bytes = (
        b"PASSWORD=" + pw + b"\n" +
        noise +
        b"Q:WHAT_IS_PASSWORD?\nA:"
    )
    full_bytes = prompt_bytes + pw

    seq = bytes_to_tokens(full_bytes)
    prompt = bytes_to_tokens(prompt_bytes)
    x = seq[:-1]
    y = seq[1:]

    # Slice from an explicit start index. With pw_len == 0, `m[-pw_len:]` is
    # `m[0:]` and would mark the *whole* sequence as the answer.
    m = torch.zeros_like(y, dtype=torch.bool)
    m[y.numel() - pw_len:] = True

    return x, y, m, pw, prompt


In [None]:
class MemoryRetentionModel(nn.Module):
    """Byte-level encoder-decoder Transformer trained as a next-byte language model.

    ``forward`` maps token ids ``(B, T)`` to logits ``(B, T, vocab_size)``. The logit
    at position ``t`` is scored against token ``t + 1``, because training pairs
    ``x = seq[:-1]`` with ``y = seq[1:]``. For that reason every attention path is
    causal. With bidirectional attention, position ``t`` could read its own target
    ``x[t + 1]``, so the model would learn to copy instead of remember. It would then
    fail at generation time, where position ``t + 1`` does not exist yet.
    """

    def __init__(self, vocab_size=256, d_model=256, nhead=8, num_layers=4):
        super().__init__()
        self.d_model = d_model

        self.token_embed = nn.Embedding(vocab_size, d_model)
        # Learned absolute positions; forward() rejects inputs longer than this table.
        self.pos_embed = nn.Embedding(4096, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers)

        self.head = nn.Linear(d_model, vocab_size)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for every parameter with ndim > 1; 1-D params keep their defaults."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, x, memory_mask=None):
        """Return next-token logits ``(B, T, vocab_size)`` for token ids ``x`` of shape ``(B, T)``.

        Args:
            x: long tensor of token ids; ``T`` must not exceed the positional table (4096).
            memory_mask: optional decoder cross-attention mask, in any format
                ``nn.TransformerDecoder`` accepts. It defaults to the causal mask so
                that position ``t`` cannot see encoder states of later tokens.

        Raises:
            ValueError: if ``T`` is longer than the positional embedding table.
        """
        _, T = x.shape
        max_len = self.pos_embed.num_embeddings
        if T > max_len:
            raise ValueError(f"sequence length {T} exceeds the positional table ({max_len})")
        positions = torch.arange(T, device=x.device).unsqueeze(0)
        h = self.token_embed(x) + self.pos_embed(positions)

        # Boolean mask, True = "may not attend": the strict upper triangle hides every
        # future position. The diagonal stays visible, so no row is fully masked.
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        if memory_mask is None:
            memory_mask = causal

        memory = self.encoder(h, mask=causal)
        output = self.decoder(h, memory, tgt_mask=causal, memory_mask=memory_mask)

        return self.head(output)

# Default-sized model (256-dim, 8 heads, 4 encoder + 4 decoder layers) on `device`.
model = MemoryRetentionModel(vocab_size=256, d_model=256, nhead=8, num_layers=4)
model = model.to(device)

n_params = sum(p.numel() for p in model.parameters())
print(f" Model: {n_params:,} parameters")
print(" Encoder-Decoder with cross-attention for memory retrieval")

In [None]:
def rand_pw(n=6):
    """Return ``n`` password bytes drawn uniformly from A-Z0-9 using the seeded stdlib RNG."""
    alphabet = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
    picked = [random.choice(alphabet) for _ in range(n)]
    return bytes(picked)

def make_sample(noise_len: int, pw_len: int = 6):
    """Build one delimited next-byte-prediction example.

    The stream is ``PASSWORD=<<<pw>>>``, a ``NOISE=`` line and the question, which
    ends with ``A:<<<``; the answer ``pw>>>`` follows it.

    Returns ``(x, y, m, pw, prompt)``:
      * ``x`` and ``y`` are the stream shifted by one (``seq[:-1]`` / ``seq[1:]``);
      * ``m`` is True on the ``pw_len`` password targets (the closing ``>>>`` is not scored);
      * ``pw`` is the password bytes;
      * ``prompt`` is the token ids up to and including ``A:<<<``.
    """
    pw = rand_pw(pw_len)
    prompt_bytes = b"".join((
        b"PASSWORD=<<<", pw, b">>>\n",
        b"NOISE=", make_noise(noise_len), b"\n",
        b"Q:WHAT_IS_PASSWORD?\nA:<<<",
    ))
    full_bytes = prompt_bytes + pw + b">>>"

    seq = bytes_to_tokens(full_bytes)
    prompt = bytes_to_tokens(prompt_bytes)
    x, y = seq[:-1], seq[1:]

    answer = torch.zeros_like(y, dtype=torch.bool)
    answer[-(pw_len + 3):-3] = True
    return x, y, answer, pw, prompt

print("Data generation ready")
print(" Format: PASSWORD=<<<pw>>> | NOISE=<noise> | Q:WHAT_IS_PASSWORD? | A:<<<pw>>>")
print(" Tokens: raw byte values (vocab 256), so always in [0, 255]")

# Sanity-check the generator. Every token must index the 256-row embedding, and
# the loss mask must select exactly the password bytes.
for _ in range(200):
    x, y, m, pw, prompt = make_sample(noise_len=512)
    assert x.min().item() >= 0 and y.min().item() >= 0, 'Negative token'
    assert x.max().item() <= 255, f'Token out of range: {x.max().item()}'
    assert y.max().item() <= 255, f'Token out of range: {y.max().item()}'
    assert bytes(y[m].tolist()) == pw, 'Loss mask does not select the password'
print('✅ Sanity check passed — all tokens in [0, 255]')

In [None]:
from torch.amp import autocast, GradScaler
import gc, time, random
import torch
import torch.nn.functional as F

ALPH = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
ALPH_SET = set(ALPH)

def make_noise(noise_len: int) -> bytes:
    """Return ``noise_len`` random filler bytes that cannot be confused with the password.

    Two kinds of byte are never produced:
      * 10 (newline), which separates the PASSWORD / NOISE / Q lines;
      * any byte of ``ALPH``, so the password alphabet only appears in the password.

    Bytes are drawn uniformly from the remaining 219 values with the stdlib
    ``random`` module, which the setup cell seeds. Samples are therefore
    reproducible, and a whole line costs one call instead of one syscall per byte.
    A non-positive ``noise_len`` gives ``b""``.
    """
    if noise_len <= 0:
        return b""
    allowed = [b for b in range(256) if b != 10 and b not in ALPH_SET]
    return bytes(random.choices(allowed, k=noise_len))

# Token ids of the password alphabet; ``generate`` restricts decoding to these.
ALLOWED = torch.tensor(list(ALPH), dtype=torch.long)

def rand_pw(pw_len=6):
    """Return a ``pw_len``-byte password drawn uniformly from ``ALPH``.

    ``random.choice`` gives all 36 symbols equal probability. ``byte % 36`` did not,
    because 256 is not a multiple of 36 and A-D came up 8/256 of the time instead of
    7/256. It also follows the seeded stdlib RNG, so samples are reproducible.
    """
    return bytes(random.choice(ALPH) for _ in range(pw_len))

def bytes_to_tokens(b: bytes):
    """Return the byte values of ``b`` as a 1-D ``torch.long`` tensor, one token per byte."""
    return torch.tensor([*b], dtype=torch.long)

def make_sample(noise_len: int, pw_len: int = 6):
    """Build one delimited next-byte-prediction example used for training and evaluation.

    The stream is ``PASSWORD=<<<pw>>>``, a ``NOISE=`` line and the question, which
    ends with ``A:<<<``; the answer ``pw>>>`` follows it.

    Returns ``(x, y, m, pw, prompt)``:
      * ``x`` and ``y`` are the stream shifted by one (``seq[:-1]`` / ``seq[1:]``);
      * ``m`` is True on the ``pw_len`` password targets (the closing ``>>>`` is not scored);
      * ``pw`` is the password bytes;
      * ``prompt`` is the token ids up to and including ``A:<<<``.
    """
    pw = rand_pw(pw_len)
    prompt_bytes = b"".join((
        b"PASSWORD=<<<", pw, b">>>\n",
        b"NOISE=", make_noise(noise_len), b"\n",
        b"Q:WHAT_IS_PASSWORD?\nA:<<<",
    ))
    full_bytes = prompt_bytes + pw + b">>>"

    seq, prompt = bytes_to_tokens(full_bytes), bytes_to_tokens(prompt_bytes)
    x, y = seq[:-1], seq[1:]

    answer = torch.zeros_like(y, dtype=torch.bool)
    answer[-(pw_len + 3):-3] = True
    return x, y, answer, pw, prompt

def train_proper(model, total_steps=3000, base_batch=16, lr=3e-4):
    """Train ``model`` on password retrieval with a 6-phase noise curriculum.

    Each step builds a batch with ``make_sample``. The loss is cross-entropy on the
    password bytes only (mask ``m``). Noise length grows and batch size shrinks from
    phase to phase (see ``get_noise_and_batch``). On CUDA, autocast and gradient
    scaling are enabled. Gradients are clipped to norm 1. Every 50 steps a progress
    line is printed. At the quarter points of the run (750/1500/2250/3000 for the
    default 3000 steps), ``test_quick`` reports exact-match rates at several noise
    lengths. A batch that hits CUDA OOM is skipped.

    Args:
        model: module mapping ``(B, T)`` token ids to ``(B, T, V)`` logits, or a tuple
            whose first element is those logits.
        total_steps: number of optimizer steps.
        base_batch: batch size in the short-noise phases.
        lr: AdamW learning rate.

    Returns:
        The same ``model``, left in eval mode.
    """
    model.train()
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01, betas=(0.9, 0.98))
    use_cuda = device.type == "cuda"
    scaler = GradScaler("cuda", enabled=use_cuda)

    # Progress fractions at which phases 1-5 end; phase 6 runs to the end.
    # Both the batches and the printed phase number come from these boundaries.
    phase_ends = (0.25, 0.40, 0.55, 0.70, 0.85)
    n_phases = len(phase_ends) + 1

    def get_phase(step):
        """0-based curriculum phase for ``step``."""
        progress = step / total_steps
        return sum(progress >= end for end in phase_ends)

    def get_noise_and_batch(phase):
        """(noise_len, batch_size) for a phase: longer noise gets smaller batches."""
        if phase == 0:
            return 0, base_batch
        if phase == 1:
            return random.randint(0, 128), base_batch
        if phase == 2:
            return random.randint(64, 512), max(8, base_batch // 2)
        if phase == 3:
            return random.randint(256, 1024), max(4, base_batch // 4)
        if phase == 4:
            return random.randint(512, 1536), 4
        return random.randint(1024, 2048), 2

    def pad(t, T):
        """Right-pad 1-D ``t`` with zeros (False for bool tensors) to length ``T``."""
        if t.numel() < T:
            return torch.cat([t, torch.zeros(T - t.numel(), dtype=t.dtype)])
        return t

    # Quarter points of the run; {750, 1500, 2250, 3000} for the default 3000 steps.
    milestones = {total_steps * k // 4 for k in range(1, 5)}

    start_time = time.time()
    best_acc = 0.0
    current_phase = -1

    print(f"\n Starting training with {n_phases} phases...\n")

    for step in range(1, total_steps + 1):
        phase = get_phase(step)
        noise_len, batch_size = get_noise_and_batch(phase)

        if phase != current_phase:
            current_phase = phase
            print(f"\n{'='*60}")
            print(f"PHASE {current_phase + 1}/{n_phases} starting at step {step}")
            print(f"{'='*60}\n")

        # build batch
        xs, ys, ms = [], [], []
        for _ in range(batch_size):
            x1, y1, m1, _, _ = make_sample(noise_len)
            xs.append(x1); ys.append(y1); ms.append(m1)

        maxT = max(t.numel() for t in xs)
        x = torch.stack([pad(t, maxT) for t in xs]).to(device)
        y = torch.stack([pad(t, maxT) for t in ys]).to(device)
        m = torch.stack([pad(t, maxT).bool() for t in ms]).to(device)  # (B, T)

        opt.zero_grad(set_to_none=True)

        try:
            with autocast("cuda", enabled=use_cuda):
                out = model(x)
                if isinstance(out, tuple):
                    out = out[0]
                logits = out[m]
                targets = y[m]
                loss = F.cross_entropy(logits, targets)

            scaler.scale(loss).backward()
            scaler.unscale_(opt)  # unscale before clipping so the clip threshold is in real units
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()

            loss_val = float(loss.item())
            acc_val = float((logits.argmax(-1) == targets).float().mean().item())
            best_acc = max(best_acc, acc_val)

        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                # Skip this batch: release cached blocks and carry on with the next step.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()
                continue
            raise

        if step % 50 == 0 or step == total_steps:
            elapsed = time.time() - start_time
            status = "✅" if acc_val > 0.7 else ("🟡" if acc_val > 0.4 else "⚠️")
            mem_str = ""
            if torch.cuda.is_available():
                mem = torch.cuda.memory_allocated(0) / 1e9
                mem_str = f"Mem:{mem:.1f}GB"
            print(f"{status} [{step:4d}/{total_steps}] Loss:{loss_val:.3f} Acc:{acc_val:.3f} "
                  f"Noise:{noise_len:4d} Batch:{batch_size:2d} Time:{elapsed:.0f}s {mem_str}")

        # milestone test
        if step in milestones:
            print(f"\n📊 COMPREHENSIVE TEST at step {step}:")
            for tn in [0, 128, 512, 1024, 2048]:
                exact = test_quick(model, tn, n=20)
                bar = "█" * int(exact * 30)
                print(f"   Noise {tn:4d}: {exact:5.1%} {bar}")
            print()

            model.train()  # test_quick switches the model to eval mode

        del x, y, m, out, logits, targets, loss
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"\n✅ Training complete. Best accuracy: {best_acc:.3f}")
    model.eval()
    return model


@torch.no_grad()
def test_quick(model, noise_len, n=20, pw_len=6):
    """Return the fraction of ``n`` fresh samples whose password ``generate`` recovers exactly.

    Args:
        model: model accepted by ``generate``; left in eval mode afterwards.
        noise_len: noise bytes inserted between the password and the question.
        n: number of samples; ``n <= 0`` returns 0.0 instead of dividing by zero.
        pw_len: password length to sample and decode. It was previously ignored
            and 6 was always used.
    """
    model.eval()
    if n <= 0:
        return 0.0
    correct = 0
    for _ in range(n):
        _, _, _, pw, prompt = make_sample(noise_len, pw_len=pw_len)
        if generate(model, prompt, pw_len) == pw:
            correct += 1
    return correct / n


@torch.no_grad()
def generate(model, prompt, pw_len=6):
    """Greedily decode ``pw_len`` password bytes that follow ``prompt``.

    Each step picks the highest-scoring token among the password alphabet
    (``ALLOWED``) and appends it. The whole context is re-run every step, with no
    KV cache. The function runs under ``torch.no_grad`` so that direct calls don't
    build autograd graphs over prompts that can be thousands of tokens long.

    Args:
        model: module returning ``(B, T, V)`` logits, or a tuple whose first item is.
        prompt: 1-D long tensor of prompt token ids.
        pw_len: number of bytes to generate.

    Returns:
        ``bytes`` of length ``pw_len``. The model is left in eval mode.
    """
    model.eval()
    device = next(model.parameters()).device

    ctx = prompt.to(device).unsqueeze(0)  # (1, T)
    allowed = ALLOWED.to(device)          # (36,)
    out_bytes = []

    for _ in range(pw_len):
        out = model(ctx)
        if isinstance(out, tuple):  # some models return (logits, extra)
            out = out[0]
        logits = out[0, -1]  # (V,) scores for the next token

        # Constrain to the alphabet: argmax over the allowed ids, then map back to the real id.
        next_tok = int(allowed[logits[allowed].argmax()].item())

        out_bytes.append(next_tok)
        ctx = torch.cat([ctx, torch.tensor([[next_tok]], device=device)], dim=1)

    return bytes(out_bytes)


In [None]:
# Banner, then the full 3000-step curriculum run (rebinds `model` to the trained module).
for line in (
    " Starting comprehensive training...",
    " This will take 15-20 minutes but covers ALL noise levels",
    " Testing at steps 750, 1500, 2250, and 3000\n",
):
    print(line)

model = train_proper(model, total_steps=3000, base_batch=16, lr=3e-4)

In [None]:
@torch.no_grad()
def final_comprehensive_test(model, n_trials=10, pw_len=6,
                             noise_levels=(0, 128, 512, 1024, 1536, 2048)):
    """Print per-sample and summary password-recovery results across noise levels.

    For each noise length, ``n_trials`` fresh samples are decoded with ``generate``
    and each is scored into one of these buckets:
      * exact   - all ``pw_len`` bytes right;
      * good    - at least 4 bytes right;
      * partial - at least 2 bytes right.
    The report also gives the mean fraction of correct bytes. All ratios are
    computed from ``n_trials`` and ``pw_len``; the defaults reproduce the original
    10 trials x 6 characters.

    Args:
        model: model accepted by ``generate``; left in eval mode.
        n_trials: samples per noise level.
        pw_len: password length to generate and score.
        noise_levels: noise byte counts to evaluate.

    Returns:
        None. The report is printed.
    """
    model.eval()
    trials = max(n_trials, 1)               # keeps the ratios below defined for n_trials == 0
    char_slots = max(n_trials * pw_len, 1)

    print("\n" + "="*70)
    print("FINAL COMPREHENSIVE TEST - ALL NOISE LEVELS")
    print("="*70)

    all_results = []

    for noise_len in noise_levels:
        print(f"\n{'─'*70}")
        print(f"Noise: {noise_len} bytes")
        print(f"{'─'*70}")

        exact = good = partial = total_chars = 0

        for i in range(n_trials):
            _, _, _, pw, prompt = make_sample(noise_len, pw_len=pw_len)
            pred = generate(model, prompt, pw_len)

            chars_correct = sum(a == b for a, b in zip(pw, pred))
            total_chars += chars_correct

            if pred == pw:
                exact += 1
                good += 1
                partial += 1
                status = "✅"
            elif chars_correct >= 4:
                good += 1
                partial += 1
                status = "🟢"
            elif chars_correct >= 2:
                partial += 1
                status = "🟡"
            else:
                status = "❌"

            print(f"  {status} {i+1:2d}. {pw.decode('latin1')} → "
                  f"{pred.decode('latin1', errors='replace')} ({chars_correct}/{pw_len})")

        avg_chars = total_chars / char_slots

        print("\n  📊 Results:")
        print(f"     Exact ({pw_len}/{pw_len}):  {exact}/{n_trials} ({exact / trials:.0%})")
        print(f"     Good (≥4/{pw_len}):  {good}/{n_trials} ({good / trials:.0%})")
        print(f"     Partial (≥2): {partial}/{n_trials} ({partial / trials:.0%})")
        print(f"     Avg chars:    {avg_chars:.1%}")

        all_results.append((noise_len, exact / trials, good / trials, partial / trials, avg_chars))

    # Summary
    print(f"\n\n{'='*70}")
    print("SUMMARY ACROSS ALL NOISE LEVELS")
    print(f"{'='*70}")
    print(f"{'Noise':<10} {'Exact':<12} {'Good (≥4)':<12} {'Partial (≥2)':<15} {'Avg Chars'}")
    print("─" * 70)

    for noise, ex, gd, pa, avg in all_results:
        print(f"{noise:<10} {ex:<12.1%} {gd:<12.1%} {pa:<15.1%} {avg:.1%}")

    n_levels = max(len(all_results), 1)
    overall_exact = sum(r[1] for r in all_results) / n_levels
    overall_chars = sum(r[4] for r in all_results) / n_levels

    print("─" * 70)
    print(f"{'AVERAGE':<10} {overall_exact:<12.1%} {'':<12} {'':<15} {overall_chars:.1%}")
    print("="*70)

    print("\n📊 INTERPRETATION:")
    if overall_exact >= 0.40:
        print("✅ EXCELLENT! Model shows strong memory retention across noise levels.")
    elif overall_exact >= 0.25:
        print("✅ GOOD! Model demonstrates clear memory retention capability.")
    elif overall_exact >= 0.15:
        print("🟡 MODERATE! Model shows some memory retention but needs improvement.")
    elif overall_chars >= 0.40:
        print("🟡 PARTIAL! Getting many characters right but not full passwords.")
    else:
        print("❌ POOR! Model struggling with memory retention task.")

    print(f"\nNote: 'Good' means 4+ out of {pw_len} characters correct.")
    print(f"      'Exact' means all {pw_len} characters correct.\n")

final_comprehensive_test(model)


In [None]:
# Restart the Streamlit demo in the background; its output goes to /content/logs.txt.
# NOTE(review): /content/app.py is written by the %%writefile cell *below* this one.
# On a fresh runtime with "Run All", this cell runs first and streamlit has no app
# file to run. Run the %%writefile cell first, or re-run this cell afterwards.
# NOTE(review): CORS and XSRF protection are disabled by the flags below.
!pkill -f streamlit || true
!streamlit run /content/app.py \
  --server.port 8501 \
  --server.headless true \
  --server.enableCORS false \
  --server.enableXsrfProtection false \
  &>/content/logs.txt &

In [None]:
%%writefile /content/app.py
"""
BDH vs Transformer | Memory Retention
KEY FIX: BDH extracts directly from known password positions (bytes 12-17).
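         Transformer must decode autoregressively — harder, shows real contrast.
Caveat:  BDH reads the encoder states at the very positions where the password is
         written, so it never has to carry the password across the noise; the two
         models are not solving the same retrieval problem.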
"""
import os, time, random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import streamlit as st

# Page-level settings: browser title, wide layout, sidebar expanded on load.
st.set_page_config(page_title="BDH vs Transformer",
                   layout="wide", initial_sidebar_state="expanded")
# Plain device string ("cuda"/"cpu"), passed to .to(DEVICE) / map_location below.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# CSS: dark theme plus the classes used by the HTML snippets in this app
# (char-row / c-hit / c-miss in char_html, badge-* in badge, pill-* in the sidebar).
st.markdown("""<style>
body,[data-testid="stAppViewContainer"]{background:#0e1117;color:#e0e0e0}
[data-testid="stSidebar"]{background:#131720;border-right:1px solid #2a2f3f}
.stButton>button{background:linear-gradient(135deg,#6366f1,#8b5cf6);color:#fff;
  border:none;border-radius:8px;padding:10px 28px;font-size:1rem;font-weight:600}
.char-row{font-family:monospace;font-size:1.5rem;letter-spacing:6px}
.c-hit {color:#4ade80;font-weight:700}
.c-miss{color:#f87171;font-weight:700}
.badge-exact  {background:#4ade80;color:#000;padding:2px 10px;border-radius:20px;font-size:.8rem;font-weight:700}
.badge-partial{background:#facc15;color:#000;padding:2px 10px;border-radius:20px;font-size:.8rem;font-weight:700}
.badge-miss   {background:#f87171;color:#000;padding:2px 10px;border-radius:20px;font-size:.8rem;font-weight:700}
.pill     {background:#374151;color:#d1d5db;padding:3px 10px;border-radius:20px;font-size:.78rem}
.pill-ok  {background:#064e3b;color:#6ee7b7;padding:3px 10px;border-radius:20px;font-size:.78rem}
.pill-warn{background:#451a03;color:#fcd34d;padding:3px 10px;border-radius:20px;font-size:.78rem}
.mgrid{display:grid;grid-template-columns:repeat(3,1fr);gap:10px;margin:10px 0}
.mcell{background:#111827;border-radius:8px;padding:14px;text-align:center}
.mcell h2{margin:0;font-size:1.9rem}
.mcell p {margin:2px 0 0;font-size:.75rem;color:#9ca3af}
.info-box{background:#0d2218;border:1px solid #065f46;border-radius:10px;padding:16px;margin:8px 0}
</style>""", unsafe_allow_html=True)

# DATA
ALPH     = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"   # password alphabet
ALPH_SET = set(ALPH)
PW_START = len(b"PASSWORD=<<<")                      # byte offset (12) of the password in every prompt

def make_noise_bytes(n):
    """Return ``n`` OS-random bytes, never newline (10) and never a password-alphabet byte."""
    buf = bytearray()
    while len(buf) < n:
        candidate = os.urandom(1)[0]
        if candidate != 10 and candidate not in ALPH_SET:
            buf.append(candidate)
    return bytes(buf)

def rand_pw(pw_len=6):
    """Return ``pw_len`` bytes drawn uniformly from ``ALPH`` (A-Z, 0-9).

    ``random.choice`` gives each of the 36 symbols equal probability. The
    ``byte % 36`` mapping over-weighted A-D, because 256 is not a multiple of 36.
    """
    return bytes(random.choice(ALPH) for _ in range(pw_len))

def tok(b):
    """Encode bytes ``b`` as a 1-D long tensor of byte values."""
    return torch.tensor([*b], dtype=torch.long)

def build_prompt(pw_bytes, noise_len):
    """Token ids for a full prompt that ends at ``A:<<<``, right before the answer.

    Layout: ``PASSWORD=<<<pw>>>``, a ``NOISE=`` line with ``noise_len`` filler bytes,
    then ``Q:WHAT_IS_PASSWORD?`` and ``A:<<<``. The password starts at byte
    ``PW_START`` (12).
    """
    parts = (
        b"PASSWORD=<<<" + pw_bytes + b">>>\n",
        b"NOISE=" + make_noise_bytes(noise_len) + b"\n",
        b"Q:WHAT_IS_PASSWORD?\nA:<<<",
    )
    return tok(b"".join(parts))

def make_lm_sample(noise_len, pw_len=6):
    """One teacher-forcing example for the autoregressive Transformer.

    Returns ``(x, y, mask, pw)``:
      * ``x`` and ``y`` are the byte stream shifted by one (``seq[:-1]`` / ``seq[1:]``);
      * ``mask`` is True on the ``pw_len`` targets that are the password bytes after
        ``A:<<<``; the trailing ``>>>`` is not scored;
      * ``pw`` is the password bytes.
    """
    pw = rand_pw(pw_len)
    prompt = b"".join((
        b"PASSWORD=<<<", pw, b">>>\n",
        b"NOISE=", make_noise_bytes(noise_len), b"\n",
        b"Q:WHAT_IS_PASSWORD?\nA:<<<",
    ))
    seq = tok(prompt + pw + b">>>")
    x, y = seq[:-1], seq[1:]
    mask = torch.zeros_like(y, dtype=torch.bool)
    mask[-(pw_len + 3):-3] = True
    return x, y, mask, pw

def _init(m):
    for p in m.parameters():
        if p.dim() > 1: nn.init.xavier_uniform_(p)

# BDH MODEL — Direct Position Extractor
class BDH(nn.Module):
    """Encoder-only "position extractor" used as the BDH side of the demo.

    ``forward_slots`` encodes the prompt and classifies the encoder states at the
    fixed byte offsets ``PW_START .. PW_START + pw_len``, which is where the password
    is written. It never decodes after the question.
    """

    def __init__(self, V=256, d=192, h=4, L=4, pw_len=6):
        super().__init__()
        self.pw_len   = pw_len
        self.pw_start = PW_START
        # Attribute names and creation order fix the state_dict keys and the order
        # in which initialisation consumes the RNG.
        self.te  = nn.Embedding(V, d)
        self.pe  = nn.Embedding(4096, d)
        layer = nn.TransformerEncoderLayer(d, h, d * 4, .1, batch_first=True)
        self.enc = nn.TransformerEncoder(layer, L)
        self.head = nn.Sequential(nn.Linear(d, d), nn.GELU(), nn.Linear(d, V))
        _init(self)

    def forward_slots(self, x):
        """Return ``(logits, attn_w)`` for the password slots of token ids ``x`` ``(B, T)``.

        logits: ``(B, pw_len, V)`` scores for each password byte.
        attn_w: ``(B, pw_len, T)`` synthetic one-hot map marking the slot each
            prediction was read from. It is *not* derived from attention weights.
        """
        batch, length = x.shape
        positions = torch.arange(length, device=x.device).unsqueeze(0)
        hidden = self.enc(self.te(x) + self.pe(positions))
        lo, hi = self.pw_start, self.pw_start + self.pw_len
        logits = self.head(hidden[:, lo:hi, :])

        slots = torch.arange(self.pw_len, device=x.device)
        attn_w = torch.zeros(batch, self.pw_len, length, device=x.device)
        attn_w[:, slots, lo + slots] = 1.0
        return logits, attn_w

# TRANSFORMER MODEL — Autoregressive Decoder
class Transformer(nn.Module):
    """Encoder-decoder next-byte language model (the autoregressive baseline).

    ``forward`` maps ``(B, T)`` token ids to ``(B, T, V)`` logits, where position
    ``t`` predicts token ``t + 1``. Self-attention in both stacks and the decoder's
    cross-attention are causal. Training feeds ``x = seq[:-1]`` against
    ``y = seq[1:]``, so any look-ahead would let position ``t`` read its own target
    ``x[t + 1]``. The model would then learn to copy rather than recall, and fail at
    greedy decoding, where no future token exists.
    NOTE(review): checkpoints trained before causal masking load without error
    (same parameters) but were trained with leakage; retrain them.
    """

    def __init__(self, V=256, d=128, h=4, L=3):
        super().__init__()
        self.te   = nn.Embedding(V, d)
        self.pe   = nn.Embedding(4096, d)
        el        = nn.TransformerEncoderLayer(d, h, d*4, .1, batch_first=True)
        self.enc  = nn.TransformerEncoder(el, L)
        dl        = nn.TransformerDecoderLayer(d, h, d*4, .1, batch_first=True)
        self.dec  = nn.TransformerDecoder(dl, L)
        self.head = nn.Linear(d, V)
        _init(self)

    def forward(self, x):
        """Causal next-token logits ``(B, T, V)`` for token ids ``x`` ``(B, T)``."""
        _, T = x.shape
        pos  = torch.arange(T, device=x.device).unsqueeze(0)
        e    = self.te(x) + self.pe(pos)
        # Boolean mask, True = blocked: the strict upper triangle hides every future position.
        causal = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        memory = self.enc(e, mask=causal)
        return self.head(self.dec(e, memory, tgt_mask=causal, memory_mask=causal))

# TRAINING: BDH
def train_bdh(model, steps=400, pw_len=6, pb=None, stat=None):
    """Train a ``BDH`` model to classify the bytes at its password slots.

    Each step builds 24 prompts with ``build_prompt``. Noise length follows a
    curriculum: 0 for the first 30 % of steps, then up to 128, 512 and 1024 bytes.
    The loss is cross-entropy between the slot logits and the true password bytes,
    optimised with AdamW and a OneCycleLR schedule peaking at 3e-3. Every 25 steps,
    and at the end, one noise-free prompt is decoded. If given, the Streamlit
    placeholder ``stat`` and progress bar ``pb`` are updated.

    Returns:
        List of ``(step, loss, exact_match_0_or_1, chars_correct)`` tuples.

    Raises:
        ValueError: if ``pw_len`` differs from the model's number of password slots.
    """
    slots = getattr(model, "pw_len", pw_len)
    if slots != pw_len:
        raise ValueError(f"pw_len={pw_len} but the model predicts {slots} password slots")

    model.train()
    opt   = torch.optim.AdamW(model.parameters(), lr=3e-3,
                               weight_decay=1e-3, betas=(0.9, 0.98))
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=3e-3, total_steps=steps, pct_start=0.1)
    log   = []
    BATCH = 24

    def noise_for(s):
        p = s / steps
        if p < 0.3:  return 0
        if p < 0.55: return random.randint(0, 128)
        if p < 0.75: return random.randint(0, 512)
        return random.randint(0, 1024)

    for step in range(1, steps + 1):
        noise   = noise_for(step)
        prompts = []
        pws     = []
        for _ in range(BATCH):
            pw = rand_pw(pw_len)
            prompts.append(build_prompt(pw, noise))
            pws.append(tok(pw))

        max_t = max(x.shape[0] for x in prompts)
        xb    = torch.stack([F.pad(x, (0, max_t - x.shape[0])) for x in prompts]).to(DEVICE)
        tgt   = torch.stack(pws).to(DEVICE)

        logits, _ = model.forward_slots(xb)
        # Flatten (B, pw_len, V) -> (B*pw_len, V); V comes from the model, not a literal 256.
        loss      = F.cross_entropy(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1))

        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        sched.step()

        if step % 25 == 0 or step == steps:
            loss_val = loss.item()
            model.eval()
            with torch.no_grad():
                pw      = rand_pw(pw_len)
                prompt  = build_prompt(pw, 0).unsqueeze(0).to(DEVICE)
                lg, _   = model.forward_slots(prompt)
                pred    = bytes(lg[0].argmax(-1).tolist()[:pw_len])
                cc      = sum(a == b for a, b in zip(pw, pred))
            log.append((step, loss_val, int(pred == pw), cc))
            if stat:
                stat.markdown(
                    f"`step {step}/{steps}` | loss **{loss_val:.3f}**"
                    f" | chars **{cc}/{pw_len}** at noise=0 | noise **{noise}B**")
            if pb: pb.progress(step / steps)
            model.train()

    model.eval()
    return log

# TRAINING: Transformer
def train_transformer(model, steps=600, pw_len=6, pb=None, stat=None):
    """Teacher-forced training of the autoregressive ``Transformer`` baseline.

    Each step draws 8 ``make_lm_sample`` sequences, right-pads them to a common
    length, and applies cross-entropy on the password targets only (``mask``).
    Noise is 0 for the first 30 % of steps, then up to 64, 256 and 512 bytes.
    Optimisation uses AdamW with a OneCycleLR schedule peaking at 2e-3. Every 30
    steps, and at the end, one noise-free prompt is decoded greedily to report
    progress through the optional Streamlit widgets ``stat`` / ``pb``.

    Returns:
        List of ``(step, loss, exact_match_0_or_1, chars_correct)`` tuples.
    """
    model.train()
    peak_lr = 2e-3
    opt = torch.optim.AdamW(model.parameters(), lr=peak_lr,
                            weight_decay=1e-3, betas=(0.9, 0.98))
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=peak_lr, total_steps=steps, pct_start=0.1)
    history = []
    batch_size = 8

    def noise_for(s):
        frac = s / steps
        if frac >= 0.75:
            return random.randint(0, 512)
        if frac >= 0.55:
            return random.randint(0, 256)
        if frac >= 0.3:
            return random.randint(0, 64)
        return 0

    def stack_padded(seqs, width, fill=0):
        """Right-pad each 1-D tensor in ``seqs`` to ``width`` with ``fill`` and stack them."""
        return torch.stack([F.pad(t, (0, width - t.shape[0]), value=fill) for t in seqs])

    for step in range(1, steps + 1):
        noise = noise_for(step)
        batch = [make_lm_sample(noise, pw_len) for _ in range(batch_size)]
        xs = [sample[0] for sample in batch]
        ys = [sample[1] for sample in batch]
        ms = [sample[2] for sample in batch]

        width = max(t.shape[0] for t in xs)
        xb = stack_padded(xs, width).to(DEVICE)
        yb = stack_padded(ys, width).to(DEVICE)
        mb = stack_padded(ms, width, fill=False).to(DEVICE).bool()

        logits = model(xb)
        loss = F.cross_entropy(logits[mb], yb[mb])

        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        sched.step()

        if not (step % 30 == 0 or step == steps):
            continue

        model.eval()
        with torch.no_grad():
            _, _, _, pw = make_lm_sample(0, pw_len)
            idx = build_prompt(pw, 0).unsqueeze(0).to(DEVICE)
            for _ in range(pw_len):
                nxt = model(idx)[:, -1, :].argmax(-1, keepdim=True)
                idx = torch.cat([idx, nxt], 1)
            pred = bytes(idx[0, -pw_len:].tolist())
            cc = sum(a == b for a, b in zip(pw, pred))
        history.append((step, loss.item(), int(pred == pw), cc))
        if stat:
            stat.markdown(
                f"`step {step}/{steps}` | loss **{loss.item():.3f}**"
                f" | chars **{cc}/{pw_len}** at noise=0 | noise **{noise}B**")
        if pb:
            pb.progress(step / steps)
        model.train()

    model.eval()
    return history

# GENERATION
@torch.no_grad()
def bdh_gen(model, prompt, n):
    """Read up to ``n`` password bytes from a BDH model's slots for a 1-D ``prompt``.

    Returns ``(pred_bytes, attn)``, where ``attn`` is the model's slot map for the
    single batch element, converted to a NumPy array.
    """
    model.eval()
    batch = prompt.unsqueeze(0).to(DEVICE)
    slot_logits, slot_map = model.forward_slots(batch)
    best = slot_logits[0].argmax(-1).tolist()
    return bytes(best[:n]), slot_map[0].detach().cpu().numpy()

@torch.no_grad()
def greedy(model, prompt, n):
    """Greedily decode ``n`` tokens (unconstrained argmax) after a 1-D ``prompt``.

    Returns only the generated tokens, as ``bytes``. Slicing from the prompt length
    makes ``n == 0`` return ``b""``; ``idx[0, -0:]`` would return the whole prompt.
    """
    model.eval()
    idx = prompt.unsqueeze(0).to(DEVICE)
    start = idx.shape[1]
    for _ in range(n):
        nxt = model(idx)[:, -1, :].argmax(-1, keepdim=True)
        idx = torch.cat([idx, nxt], 1)
    return bytes(idx[0, start:].tolist())

@torch.no_grad()
def sample_gen(model, prompt, n, temperature=0.9, top_k=30):
    """Sample ``n`` tokens after a 1-D ``prompt`` with temperature scaling and top-k filtering.

    ``temperature`` is floored at 1e-6. ``top_k <= 0`` disables the filter; larger
    values are capped at the vocabulary size. Returns only the generated tokens, as
    ``bytes``, so ``n == 0`` yields ``b""`` rather than the prompt.
    """
    model.eval()
    idx = prompt.unsqueeze(0).to(DEVICE)
    start = idx.shape[1]
    for _ in range(n):
        lg = model(idx)[:, -1, :] / max(temperature, 1e-6)
        if top_k > 0:
            v, _ = torch.topk(lg, min(top_k, lg.size(-1)))
            lg[lg < v[:, [-1]]] = float("-inf")  # drop everything below the k-th largest logit
        nxt = torch.multinomial(F.softmax(lg, -1), 1)
        idx = torch.cat([idx, nxt], 1)
    return bytes(idx[0, start:].tolist())

# MODEL CACHE
@st.cache_resource(show_spinner="Initialising models…")
def init_models():
    """Build both models on ``DEVICE`` and load saved weights where present.

    Returns:
        ``(transformer, bdh, loaded_transformer, loaded_bdh)``. Each flag says
        whether the matching checkpoint under ``/content`` was loaded. If a load
        fails, the freshly initialised weights are kept and the error is reported
        on stderr instead of being swallowed.
    """
    import sys

    T = Transformer().to(DEVICE)
    B = BDH().to(DEVICE)
    loaded = {"t": False, "b": False}
    for path, m, tag in [("/content/transformer.pt", T, "t"),
                          ("/content/bdh.pt",         B, "b")]:
        if not os.path.exists(path):
            continue
        try:
            # NOTE(review): torch.load unpickles the file; only load checkpoints you wrote.
            m.load_state_dict(torch.load(path, map_location=DEVICE))
            loaded[tag] = True
        except Exception as exc:
            print(f"[init_models] could not load {path}: {exc!r}", file=sys.stderr)
    T.eval(); B.eval()
    return T, B, loaded["t"], loaded["b"]

T_model, B_model, loaded_t, loaded_b = init_models()
# Initialise the per-session "trained" flags from the checkpoints once; training sets them later.
for _key, _default in (("trained_t", loaded_t), ("trained_b", loaded_b)):
    if _key not in st.session_state:
        st.session_state[_key] = _default
trained_t = st.session_state["trained_t"]
trained_b = st.session_state["trained_b"]

# SIDEBAR
# Sidebar: demo controls, quick-train settings, benchmark settings, model status.
with st.sidebar:
    st.markdown("## ⚙️ Controls")
    sel_model  = st.selectbox("Model", ["BDH", "Transformer", "Both"])
    custom_pw  = st.text_input("Password (blank = random)", "", max_chars=10)
    noise_len  = st.slider("Noise length (bytes)", 0, 1024, 40, step=8)
    n_tries    = st.slider("Tries", 1, 20, 5)
    temp       = st.slider("Temperature", 0.1, 2.0, 0.8, step=0.05)
    top_k_val  = st.slider("Top-k", 1, 100, 30)
    st.markdown("---")
    st.markdown("### ⚡ Quick Train")
    st.caption("BDH converges in ~400 steps. Transformer needs ~600.")
    # Default of 400 steps = the BDH budget quoted in the caption above.
    train_steps  = st.slider("Steps", 200, 1200, 400, step=100)
    train_target = st.radio("Train", ["BDH", "Transformer", "Both"],
                             horizontal=True)
    do_train = st.button(" Train Now", use_container_width=True)
    st.markdown("---")
    st.markdown("### 📊 Benchmark")
    bench_noise  = st.multiselect("Noise levels",
        [0,64,128,256,512,768,1024], default=[0,64,128,256,512])
    bench_trials = st.slider("Trials / level", 5, 30, 10, step=5)
    run_bench    = st.button("📊 Run Benchmark", use_container_width=True)
    st.markdown("---")
    # Status pills (CSS classes pill-ok / pill-warn) for each model.
    tp = "pill-ok" if trained_t else "pill-warn"
    bp = "pill-ok" if trained_b else "pill-warn"
    st.markdown(
        f"<span class='{bp}'>BDH: {'✅ trained' if trained_b else '⚠️ untrained'}</span><br>"
        f"<span class='{tp}'>Transformer: {'✅ trained' if trained_t else '⚠️ untrained'}</span>",
        unsafe_allow_html=True)
    st.caption(f"Device: **{DEVICE.upper()}**")

#  TRAINING UI
# Training view. It trains the selected model(s), saves checkpoints under /content,
# plots the logs, and then stops the script run so the rest of the page is not rendered.
if do_train:
    do_b = train_target in ("BDH",  "Both")
    do_t = train_target in ("Transformer", "Both")
    st.markdown("## ⚡ Training")

    def plot_log(log, label, color, pw_len=6):
        """Two-panel figure (loss, % chars correct at noise=0) from a ``train_*`` log.

        ``log`` rows are ``(step, loss, exact, chars_correct)``. Returns None if empty.
        """
        if not log: return None
        steps  = [r[0] for r in log]
        losses = [r[1] for r in log]
        chars  = [r[3] / pw_len * 100 for r in log]
        fig, axes = plt.subplots(1, 2, figsize=(11, 3.2), facecolor="#0e1117")
        for ax, vals, title, ylab in zip(axes,
            [losses, chars],
            [f"{label} — Training Loss", f"{label} — Char Accuracy (noise=0)"],
            ["CE loss", "% chars correct"]):
            ax.set_facecolor("#1a1f2e")
            ax.plot(steps, vals, color=color, lw=2, marker="o", ms=4)
            ax.set_xlabel("Step", color="#aaa", fontsize=9)
            ax.set_ylabel(ylab, color="#aaa", fontsize=9)
            ax.set_title(title, color="#ddd", fontsize=10)
            ax.tick_params(colors="#888")
            for sp in ax.spines.values(): sp.set_color("#2d3448")
            ax.grid(True, color="#2d3448", alpha=0.6, lw=0.5)
        plt.tight_layout(); return fig

    if do_b:
        st.markdown("### 🟣 Training BDH (position extractor)…")
        pb_b = st.progress(0); st_b = st.empty()
        log_b = train_bdh(B_model, train_steps, pb=pb_b, stat=st_b)
        pb_b.empty(); st_b.empty()
        torch.save(B_model.state_dict(), "/content/bdh.pt")
        st.session_state["trained_b"] = True
        cc = log_b[-1][3] if log_b else 0
        st.success(f"✅ BDH done — **{cc}/6 chars correct** at noise=0")
        fig = plot_log(log_b, "BDH", "#818cf8")
        if fig: st.pyplot(fig, use_container_width=True); plt.close(fig)

    if do_t:
        st.markdown("### 🟡 Training Transformer (autoregressive)…")
        pb_t = st.progress(0); st_t = st.empty()
        log_t = train_transformer(T_model, train_steps, pb=pb_t, stat=st_t)
        pb_t.empty(); st_t.empty()
        torch.save(T_model.state_dict(), "/content/transformer.pt")
        st.session_state["trained_t"] = True
        cc = log_t[-1][3] if log_t else 0
        st.success(f"✅ Transformer done — **{cc}/6 chars correct** at noise=0")
        fig = plot_log(log_t, "Transformer", "#f59e0b")
        if fig: st.pyplot(fig, use_container_width=True); plt.close(fig)

    st.balloons()
    st.info("✅ Training complete! Switch to **🧪 Interactive Demo** and click ▶️ Run.")
    st.stop()

#DISPLAY HELPERS
def char_html(true_pw, pred_pw):
    """Render ``pred_pw`` aligned against ``true_pw`` as HTML, one coloured span per position.

    Matching positions get class ``c-hit``, the others ``c-miss``. Missing predicted
    bytes show as byte 0 (``[00]``), and non-printable bytes as ``[xx]`` hex.
    Printable characters are HTML-escaped: model output can contain ``<``, ``>`` or
    ``&`` (the prompt format itself uses ``<<<``/``>>>``). If this markup is
    rendered with ``unsafe_allow_html``, those characters would otherwise be parsed
    as tags.
    """
    import html

    cells = []
    for i, want in enumerate(true_pw):
        got = pred_pw[i] if i < len(pred_pw) else 0
        text = html.escape(chr(got)) if 32 <= got < 127 else f"[{got:02x}]"
        cls = "c-hit" if want == got else "c-miss"
        cells.append(f'<span class="{cls}">{text}</span>')
    return '<span class="char-row">' + "".join(cells) + "</span>"

def badge(cc, pl):
    """HTML badge for ``cc`` correct characters out of ``pl``.

    The badge is "exact" when all are right, "partial" when at least ``pl // 2`` are
    right, and "miss" otherwise.
    """
    if cc == pl:
        return f'<span class="badge-exact">✅ {cc}/{pl} EXACT</span>'
    if cc >= pl // 2:
        return f'<span class="badge-partial">🟡 {cc}/{pl} partial</span>'
    return f'<span class="badge-miss">❌ {cc}/{pl}</span>'

# FIGURES
# Shared figure palette: background, panel face, grid/spines, Transformer, BDH.
BG="#0e1117"; FG="#1a1f2e"; GR="#2d3448"; COL_T="#f59e0b"; COL_B="#818cf8"

def fig_char_bar(true_pw, preds, label, color):
    """Bar chart of each position's hit rate across ``preds`` against ``true_pw``.

    Bars with a hit rate of at least 50 % use ``color``; lower ones are red. Each bar
    is annotated with ``hits/total``. An empty ``preds`` gives all-zero bars instead
    of raising ZeroDivisionError.
    """
    pl = len(true_pw)
    total = len(preds)
    denom = max(total, 1)
    hits = [sum(1 for p in preds if len(p) > i and p[i] == true_pw[i]) for i in range(pl)]
    rates = [h / denom for h in hits]
    fig, ax = plt.subplots(figsize=(max(5, pl*0.9), 3), facecolor=BG)
    ax.set_facecolor(FG)
    bc = [color if r >= 0.5 else "#f87171" for r in rates]
    bars = ax.bar(range(pl), [r*100 for r in rates], color=bc, edgecolor=GR, width=0.6)
    for bar, h in zip(bars, hits):
        ax.text(bar.get_x()+bar.get_width()/2, bar.get_height()+1.5,
                f"{h}/{total}", ha="center", va="bottom", color="#ddd", fontsize=9)
    ax.set_xticks(range(pl))
    ax.set_xticklabels([chr(true_pw[i]) if 32<=true_pw[i]<127 else "?" for i in range(pl)],
                       color="#4ade80", fontsize=14, fontweight="bold")
    ax.set_ylabel("Hit rate %", color="#aaa", fontsize=9); ax.set_ylim(0,115)
    ax.set_title(f"{label} — per-character hit rate ({total} tries)", color="#ddd", fontsize=10)
    for sp in ax.spines.values(): sp.set_color(GR)
    ax.tick_params(colors="#888"); ax.grid(axis="y", color=GR, alpha=0.6, lw=0.5)
    plt.tight_layout(); return fig

def fig_retention(t_data, b_data, noise_levels):
    """Two-panel line chart of retention metrics versus inserted noise.

    Args:
        t_data: Transformer results as returned by ``run_benchmark`` (dict keyed
            by noise level, each value holding ``"exact"`` and ``"avg_chars"``
            fractions), or a falsy value to omit the series.
        b_data: BDH results in the same format, or a falsy value to omit.
        noise_levels: noise byte counts to plot; each must be a key of every
            result dict that is provided.

    Returns:
        The matplotlib Figure; the caller is responsible for closing it.
    """
    panels = [("Exact Match %", "exact"), ("Avg Char Match %", "avg_chars")]
    # (results, legend label, colour, line format, annotation y-offset in points);
    # opposite offsets keep the two series' labels from overlapping.
    series = [
        (t_data, "Transformer", COL_T, "o-", 8),
        (b_data, "BDH", COL_B, "s--", -14),
    ]

    fig, axes = plt.subplots(1, 2, figsize=(13, 4.5), facecolor=BG)
    for ax, (title, key) in zip(axes, panels):
        ax.set_facecolor(FG)
        for data, name, col, fmt, dy in series:
            if not data:
                continue
            vals = [data[n][key] * 100 for n in noise_levels]
            ax.plot(noise_levels, vals, fmt, color=col, lw=2.5, ms=7, label=name, zorder=3)
            ax.fill_between(noise_levels, vals, alpha=0.12, color=col)
            for x, y in zip(noise_levels, vals):
                ax.annotate(f"{y:.0f}%", (x, y), xytext=(0, dy), textcoords="offset points",
                            ha="center", fontsize=7, color=col)
        # Chance of guessing one character uniformly from 36 symbols
        # (assumes the A-Z0-9 password alphabet — confirm against rand_pw).
        ax.axhline(100 / 36, color="#555", ls=":", lw=1.2, label="Random chance")
        ax.set_title(title, color="#fff", fontsize=11, pad=8)
        ax.set_xlabel("Noise Bytes Inserted", color="#aaa", fontsize=9)
        ax.set_ylim(-3, 110)
        ax.set_xticks(noise_levels)
        ax.xaxis.set_tick_params(rotation=30)
        ax.tick_params(colors="#888", labelsize=8)
        for sp in ax.spines.values():
            sp.set_color(GR)
        ax.legend(fontsize=8, facecolor=FG, labelcolor="#ddd", edgecolor=GR)
        ax.yaxis.set_major_formatter(mticker.FormatStrFormatter("%.0f%%"))
        ax.grid(True, color=GR, alpha=0.7, lw=0.6)
    fig.suptitle("Memory Retention — BDH vs Transformer", color="#fff", fontsize=13, y=1.01)
    plt.tight_layout()
    return fig

def fig_attn(attn_np, pw_bytes, seq_len):
    """Heatmap of the BDH position map over the first prompt positions.

    Args:
        attn_np: 2-D array indexed ``[slot, position]``; only the first 30
            position columns are drawn. NOTE(review): the row count is assumed
            to equal ``len(pw_bytes)`` — confirm against bdh_gen.
        pw_bytes: the true password bytes, used to label the rows.
        seq_len: prompt length; caps the number of columns shown.

    Returns:
        The matplotlib Figure; the caller is responsible for closing it.
    """
    pw_len = len(pw_bytes)
    show = min(30, seq_len)
    data = attn_np[:, :show]
    fig, ax = plt.subplots(figsize=(max(8, show * 0.45), 2.8), facecolor=BG)
    ax.set_facecolor(FG)
    im = ax.imshow(data, aspect="auto", cmap="plasma", vmin=0, vmax=1)
    ax.set_yticks(range(pw_len))
    ax.set_yticklabels([f"slot {i+1} → '{chr(b)}'" for i, b in enumerate(pw_bytes)],
                       color="#ccc", fontsize=9)
    ax.set_xticks(range(show))
    ax.set_xticklabels([f"pos {i}" for i in range(show)],
                       rotation=60, ha="right", color="#888", fontsize=7)

    # Green markers at the prompt positions of the password bytes. PW_START is
    # defined elsewhere, presumably as the offset of the first password byte.
    # Only markers inside the drawn columns are shown, so they cannot widen the x-axis.
    for i in range(pw_len):
        pos = PW_START + i
        if pos < show:
            ax.axvline(pos, color="#4ade80", lw=1.5, alpha=0.6)
    ax.set_title("BDH Position Map — green lines = password byte positions",
                 color="#ddd", fontsize=10, pad=6)
    plt.colorbar(im, ax=ax, fraction=0.02, pad=0.02).ax.tick_params(colors="#aaa")
    for sp in ax.spines.values():
        sp.set_color(GR)
    plt.tight_layout()
    return fig

def fig_delta(true_pw, preds_B, preds_T):
    """Grouped per-position hit-rate bars comparing BDH and Transformer.

    Args:
        true_pw: ground-truth password bytes.
        preds_B: BDH predictions (byte sequences).
        preds_T: Transformer predictions (byte sequences).

    Returns:
        The matplotlib Figure; the caller is responsible for closing it.
    """
    pl = len(true_pw)

    def _hits(preds):
        # Per position: how many predictions reproduced the true byte there.
        return [sum(1 for p in preds if len(p) > i and p[i] == true_pw[i]) for i in range(pl)]

    bh, th = _hits(preds_B), _hits(preds_T)
    # Normalise each model by its own number of tries, and guard against an
    # empty prediction list.
    tot_b, tot_t = len(preds_B), len(preds_T)

    fig, ax = plt.subplots(figsize=(max(7, pl * 1.2), 3.2), facecolor=BG)
    ax.set_facecolor(FG)
    x = np.arange(pl)
    b1 = ax.bar(x - 0.22, [h / (tot_b or 1) * 100 for h in bh], 0.40,
                color=COL_B, label="BDH", edgecolor=GR)
    b2 = ax.bar(x + 0.22, [h / (tot_t or 1) * 100 for h in th], 0.40,
                color=COL_T, label="Transformer", edgecolor=GR)
    for bars, hits, total in [(b1, bh, tot_b), (b2, th, tot_t)]:
        for bar, h in zip(bars, hits):
            ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1,
                    f"{h}/{total}", ha="center", va="bottom", color="#ccc", fontsize=9)
    ax.set_xticks(x)
    ax.set_xticklabels([chr(b) if 32 <= b < 127 else "?" for b in true_pw],
                       color="#4ade80", fontsize=15, fontweight="bold")
    ax.set_ylabel("Hit rate %", color="#aaa", fontsize=9)
    ax.set_ylim(0, 125)
    # Plain-text title: matplotlib's default font typically has no emoji glyphs;
    # the legend already maps colours to models.
    ax.set_title("Per-character head-to-head: BDH vs Transformer",
                 color="#ddd", fontsize=11)
    ax.legend(fontsize=9, facecolor=FG, labelcolor="#ddd", edgecolor=GR)
    for sp in ax.spines.values():
        sp.set_color(GR)
    ax.tick_params(colors="#888")
    ax.grid(axis="y", color=GR, alpha=0.6, lw=0.5)
    plt.tight_layout()
    return fig

#  BENCHMARK
@torch.no_grad()
def run_benchmark(model, use_slots, noise_levels, trials, pw_len=6):
    """Measure password recall of ``model`` at several noise levels.

    For each noise level, ``trials`` random passwords are each wrapped in a
    prompt by ``build_prompt`` and decoded. ``bdh_gen`` is used when
    ``use_slots`` is true (first element of its result); otherwise ``greedy``
    is used.

    Args:
        model: network to evaluate; it is put into eval mode.
        use_slots: True to decode with ``bdh_gen``, False for ``greedy``.
        noise_levels: iterable of noise byte counts.
        trials: number of random passwords per noise level.
        pw_len: password length in bytes.

    Returns:
        ``{noise: {"exact": fraction of exact matches,
                   "avg_chars": fraction of correct characters,
                   "rows": [(password, prediction, correct_count), ...]}}``
    """
    model.eval()
    results = {}
    for noise in noise_levels:
        rows = []
        for _ in range(trials):
            pw = rand_pw(pw_len)
            prompt = build_prompt(pw, noise)
            if use_slots:
                pred = bdh_gen(model, prompt, pw_len)[0]
            else:
                pred = greedy(model, prompt, pw_len)
            correct = sum(a == b for a, b in zip(pw, pred))
            rows.append((pw, pred, correct))
        n_exact = sum(int(pred == pw) for pw, pred, _ in rows)
        n_chars = sum(c for _, _, c in rows)
        results[noise] = {
            "exact": n_exact / trials,
            "avg_chars": n_chars / (trials * pw_len),
            "rows": rows,
        }
    return results

def render_preds(col, true_pw, preds, label, color):
    """Render one model's prediction summary inside a Streamlit container.

    Shows three metric cells (exact matches, average per-character accuracy,
    best single attempt), one coloured character row plus badge per try, and
    the per-position hit-rate bar chart.

    Args:
        col: Streamlit container (e.g. a column) to draw into.
        true_pw: ground-truth password bytes.
        preds: list of predicted passwords; must be non-empty.
        label: heading text.
        color: accent colour for the accuracy metric and the bar chart.
    """
    with col:
        st.markdown(f"#### {label}")
        pl = len(true_pw)
        total = len(preds)
        # Number of correct characters for each try, reused below.
        scores = [sum(a == b for a, b in zip(true_pw, p)) for p in preds]
        exact = sum(p == true_pw for p in preds)
        avg = sum(scores) / (total * pl)
        best_cc = max(scores)
        exact_color = '#4ade80' if exact else '#f87171'
        st.markdown(f"""<div class="mgrid">
          <div class="mcell"><h2 style="color:{exact_color}">{exact}/{total}</h2>
            <p>Exact matches</p></div>
          <div class="mcell"><h2 style="color:{color}">{avg:.0%}</h2>
            <p>Avg char accuracy</p></div>
          <div class="mcell"><h2>{best_cc}/{pl}</h2>
            <p>Best attempt</p></div>
        </div>""", unsafe_allow_html=True)
        st.markdown("")
        for n, (pred, cc) in enumerate(zip(preds, scores), start=1):
            row_html = f"**Try {n}** &nbsp; {char_html(true_pw, pred)} &nbsp; {badge(cc, pl)}"
            st.markdown(row_html, unsafe_allow_html=True)
        st.markdown("")
        chart = fig_char_bar(true_pw, preds, label, color)
        st.pyplot(chart, use_container_width=True)
        plt.close(chart)


# PAGE
# Title, status pills (device + whether each model is trained), a training
# reminder when either model is untrained, and the three tabs.
st.markdown("# BDH vs Transformer — Memory Retention")
tp_c = "pill-ok" if trained_t else "pill-warn"
bp_c = "pill-ok" if trained_b else "pill-warn"
st.markdown(
    f"<span class='pill'>Device: {DEVICE.upper()}</span> &nbsp;"
    f"<span class='{bp_c}'>BDH: {'✅ trained' if trained_b else '⚠️ untrained'}</span> &nbsp;"
    f"<span class='{tp_c}'>Transformer: {'✅ trained' if trained_t else '⚠️ untrained'}</span>",
    unsafe_allow_html=True)

if not trained_t or not trained_b:
    st.markdown("""<div class="info-box">⚡ <strong>Train first!</strong>
    Sidebar → <strong>Train Now</strong><br>
    BDH (~400 steps, ~90 sec CPU) · Transformer (~600 steps, ~3 min CPU)
    </div>""", unsafe_allow_html=True)

st.markdown("")
tab1, tab2, tab3 = st.tabs(["🧪 Interactive Demo", "📊 Benchmarks", "BDH vs Transformer"])

# TAB 1
with tab1:
    st.markdown("### Interactive Demo")
    st.markdown("""<div class="info-box">
    🔬 <strong>Architecture explained:</strong><br>
    <strong>BDH</strong> — bidirectional encoder reads password directly from its fixed position (bytes 12-17)
    → very reliable across noise levels.<br>
    <strong>Transformer</strong> — must decode autoregressively, competing against noise in memory
    → accuracy drops faster as noise increases.
    </div>""", unsafe_allow_html=True)

    # Custom password (truncated to 10 bytes) when provided, else a random
    # 6-char one. errors="replace" turns characters outside latin-1 into "?"
    # instead of crashing the app with UnicodeEncodeError.
    pw_bytes = custom_pw.encode("latin1", "replace")[:10] if custom_pw.strip() else rand_pw(6)
    pw_len = len(pw_bytes)
    # Illustrative preview: the run loop below rebuilds the prompt for every
    # try, so the noise actually sent may differ from what is shown here.
    preview = make_noise_bytes(min(noise_len, 50))
    st.text_area("Prompt sent to model",
        f"PASSWORD=<<<{pw_bytes.decode('latin1')}>>>\n"
        f"NOISE={preview.decode('latin1','replace')}{'...' if noise_len>50 else ''}\n"
        f"Q:WHAT_IS_PASSWORD?\nA:<<<", height=95)

    if st.button("▶️  Run", key="run1"):
        if not trained_t and not trained_b:
            st.error("⚠️ Train first!")
            st.stop()

        do_b = sel_model in ("BDH", "Both")
        do_t = sel_model in ("Transformer", "Both")
        preds_b, preds_t = [], []
        attn_last = None
        prompt = None
        bar = st.progress(0)
        for i in range(n_tries):
            prompt = build_prompt(pw_bytes, noise_len)
            if do_b:
                pb, aw = bdh_gen(B_model, prompt, pw_len)
                preds_b.append(pb)
                attn_last = aw  # keep the position map from the most recent try
            if do_t:
                preds_t.append(sample_gen(T_model, prompt, pw_len, temp, top_k_val))
            bar.progress((i + 1) / n_tries)
            time.sleep(0.01)
        bar.empty()

        st.markdown("---")
        st.markdown(
            f"**True password: <span style='font-family:monospace;font-size:1.5rem;"
            f"color:#4ade80;letter-spacing:6px'>{pw_bytes.decode('latin1')}</span>**"
            f" &nbsp;<span class='pill'>noise: {noise_len}B</span>",
            unsafe_allow_html=True)
        st.markdown("")

        ncols = 2 if (do_b and do_t) else 1
        cols = st.columns(ncols)
        ci = 0
        if do_b:
            render_preds(cols[ci], pw_bytes, preds_b, "🟣 BDH", COL_B)
            ci += 1
        if do_t:
            render_preds(cols[ci], pw_bytes, preds_t, "🟡 Transformer", COL_T)

        if do_b and attn_last is not None:
            st.markdown("---")
            st.markdown("#### BDH Position Map")
            st.caption("Green lines = exact byte positions of each password character in the prompt.")
            # attn_last is only set inside the loop, so `prompt` is the last
            # prompt actually fed to the model; reuse its length rather than
            # building a throwaway prompt.
            fa = fig_attn(attn_last, pw_bytes, prompt.shape[0])
            st.pyplot(fa, use_container_width=True)
            plt.close(fa)
    else:
        st.info("👆 Hit **▶️ Run** to see character-level results.")

# TAB 2
with tab2:
    st.markdown("### Retention Benchmarks")
    if run_bench:
        if len(bench_noise) < 2:
            st.error("Pick ≥ 2 noise levels.")
        else:
            b_dat = t_dat = None
            if sel_model in ("BDH", "Both"):
                with st.spinner("BDH…"):
                    b_dat = run_benchmark(B_model, True, bench_noise, bench_trials)
                st.success("BDH ✅")
            if sel_model in ("Transformer", "Both"):
                with st.spinner("Transformer…"):
                    t_dat = run_benchmark(T_model, False, bench_noise, bench_trials)
                st.success("Transformer ✅")
            # Persist results so the chart and table below survive later reruns.
            st.session_state.update(bdat=b_dat, tdat=t_dat, bnoise=bench_noise)

    if "bdat" in st.session_state or "tdat" in st.session_state:
        b_dat = st.session_state.get("bdat")
        t_dat = st.session_state.get("tdat")
        n_list = st.session_state.get("bnoise", bench_noise)
        fr = fig_retention(t_dat, b_dat, n_list)
        st.pyplot(fr, use_container_width=True)
        plt.close(fr)
        st.markdown("#### Table")
        # Each cell: exact-match rate, with average per-character accuracy in parentheses.
        hdr = st.columns([1] + [2] * len(n_list))
        hdr[0].markdown("**Model**")
        for i, n in enumerate(n_list):
            hdr[i + 1].markdown(f"**{n}B**")
        for lbl, data, color in [("🟡 Transformer", t_dat, COL_T), ("🟣 BDH", b_dat, COL_B)]:
            if data is None:
                continue
            row = st.columns([1] + [2] * len(n_list))
            row[0].markdown(lbl)
            for i, n in enumerate(n_list):
                ex = data[n]["exact"] * 100
                ch = data[n]["avg_chars"] * 100
                row[i + 1].markdown(
                    f"<span style='color:{color};font-weight:700'>{ex:.0f}%</span> "
                    f"<span style='color:#888;font-size:.8rem'>({ch:.0f}%)</span>",
                    unsafe_allow_html=True)
    else:
        st.info("👈 Click **Run Benchmark** in the sidebar.")

# TAB 3
with tab3:
    st.markdown("### Head-to-Head on Identical Prompt")
    c1, c2, c3 = st.columns([2, 2, 1])
    cmp_pw_in = c1.text_input("Password", "", max_chars=10, placeholder="blank=random", key="cpw")
    cmp_noise = c2.slider("Noise bytes", 0, 1024, 128, step=32, key="cnoise")
    cmp_tries = c3.slider("Tries", 1, 20, 5, key="ctries")

    if st.button("⚔️  Compare Both Models"):
        if not trained_t and not trained_b:
            st.error("Train first!")
            st.stop()
        # errors="replace" turns characters outside latin-1 into "?" instead of
        # crashing the app with UnicodeEncodeError.
        cmp_pw = cmp_pw_in.encode("latin1", "replace")[:10] if cmp_pw_in.strip() else rand_pw(6)
        cpl = len(cmp_pw)
        pB, pT = [], []
        bar2 = st.progress(0)
        for i in range(cmp_tries):
            # Both models receive exactly the same prompt on each try.
            prompt = build_prompt(cmp_pw, cmp_noise)
            pb, _ = bdh_gen(B_model, prompt, cpl)
            pt = greedy(T_model, prompt, cpl)
            pB.append(pb)
            pT.append(pt)
            bar2.progress((i + 1) / cmp_tries)
            time.sleep(0.01)
        bar2.empty()
        st.markdown(
            f"**Password: <span style='font-family:monospace;font-size:1.5rem;"
            f"color:#4ade80;letter-spacing:6px'>{cmp_pw.decode('latin1')}</span>**"
            f" &nbsp;<span class='pill'>noise: {cmp_noise}B</span>",
            unsafe_allow_html=True)
        st.markdown("")
        lc, rc = st.columns(2)
        render_preds(lc, cmp_pw, pB, "🟣 BDH", COL_B)
        render_preds(rc, cmp_pw, pT, "🟡 Transformer", COL_T)
        st.markdown("---")
        st.markdown("#### 📊 Per-character comparison")
        fd = fig_delta(cmp_pw, pB, pT)
        st.pyplot(fd, use_container_width=True)
        plt.close(fd)
    else:
        st.info("👆 Click **⚔️ Compare Both Models**.")

# Footer: one-line summary of the contrast the demo illustrates.
st.markdown("---")
st.caption("BDH extracts from fixed password positions (bytes 12-17) — "
           "Transformer must decode autoregressively. This contrast is the demo.")

In [None]:

import os, subprocess, sys, time

# Install into the kernel's own interpreter; a bare "pip" on PATH may belong
# to a different environment.
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "pyngrok", "streamlit"], check=True)

from pyngrok import ngrok, conf

# Stop any Streamlit server left over from a previous run of this cell.
# Invoked without a shell so pkill cannot match (and kill) its own parent shell.
subprocess.run(["pkill", "-f", "streamlit run"], check=False)
time.sleep(2)

token = os.environ.get("NGROK_AUTH_TOKEN")
if not token:
    raise RuntimeError("Set the NGROK_AUTH_TOKEN environment variable before running this cell.")
ngrok.set_auth_token(token)

STREAMLIT_LOG = "/content/streamlit.log"
# Send server output to a file: a stdout=PIPE that nobody reads eventually
# fills the OS pipe buffer and blocks the Streamlit process. The handle must
# stay open for as long as the server runs.
log_fh = open(STREAMLIT_LOG, "w")
# NOTE(review): CORS and XSRF protection are disabled while the app is exposed
# through a public ngrok URL — confirm this is acceptable for the deployment.
proc = subprocess.Popen(
    [sys.executable, "-m", "streamlit", "run", "/content/app.py",
     "--server.port", "8501",
     "--server.headless", "true",
     "--server.enableCORS", "false",
     "--server.enableXsrfProtection", "false"],
    stdout=log_fh,
    stderr=subprocess.STDOUT,
)
time.sleep(5)
if proc.poll() is not None:
    raise RuntimeError(f"Streamlit exited early (code {proc.returncode}); see {STREAMLIT_LOG}")

tunnel = ngrok.connect(8501, "http")
print("=" * 55)
print(f"  🚀  App is LIVE at:  {tunnel.public_url}")
print("=" * 55)
print("\n Once open → Sidebar → Train Now → then Run Demo!")