# MLMpire CTF: Guided Flag Extraction
This notebook follows the intended path: avoid the 'api/auth...' decoy by choosing 'A' right after 'password: ', then steer the continuation to 'AI2025{' and restrict the characters generated inside the braces to letters, digits and '_' until the closing '}' appears.

In [None]:
# Install dependencies (Windows/Colab-safe)
# pip is invoked through the running kernel's own interpreter (sys.executable),
# so the packages are installed into the environment this notebook imports from.
import sys
!{sys.executable} -m pip install --quiet transformers torch

In [None]:
import torch
import torch.nn.functional as F
from architecture import load_hf_gpt2_model, fill_mask

def top_k(wrapper, prompt_text: str, k: int = 10):
    """Return the ``k`` most probable fillers for the ``[MASK]`` slot.

    Args:
        wrapper: model wrapper passed to ``fill_mask``; ``wrapper.itos`` maps a
            vocabulary index to its token string.
        prompt_text: prompt containing a ``[MASK]`` placeholder.
        k: number of candidates to return. It is clamped to the vocabulary
            size, so an oversized ``k`` no longer makes ``torch.topk`` raise.

    Returns:
        List of ``(token, probability, index)`` tuples, most probable first.
        The probability is the softmax over the full vocabulary.
    """
    logits = fill_mask(wrapper, prompt_text)
    # Element-wise indexing of the top-k result (as in the original) implies
    # fill_mask returns 1-D logits over the vocabulary.
    probs = F.softmax(logits, dim=-1)
    k = max(0, min(k, probs.size(-1)))
    top_k_probs, top_k_indices = torch.topk(probs, k)
    return [
        (wrapper.itos[idx], prob, idx)
        for prob, idx in zip(top_k_probs.tolist(), top_k_indices.tolist())
    ]

def next_char_constrained(wrapper, prompt_text: str, allowed: set[str], k: int = 15):
    """Return the most probable ``[MASK]`` filler that is a member of ``allowed``.

    The ``k`` most probable tokens are scanned first as a cheap fast path. If
    none of them is allowed, the whole vocabulary is searched for the best
    allowed token. The constraint is therefore honoured whenever any allowed
    token exists. Only if no vocabulary entry is in ``allowed`` does this fall
    back to the unconstrained argmax.

    Args:
        wrapper: model wrapper passed to ``fill_mask``; ``wrapper.itos`` maps a
            vocabulary index to its token string.
        prompt_text: prompt containing a ``[MASK]`` placeholder.
        allowed: set of permitted token strings.
        k: size of the fast-path top-k scan (clamped to the vocabulary size).

    Returns:
        ``(token, probability)``, where the probability is the token's softmax
        probability over the full vocabulary.
    """
    logits = fill_mask(wrapper, prompt_text)
    probs = F.softmax(logits, dim=-1)
    vocab_size = probs.size(-1)
    k = max(0, min(k, vocab_size))
    top_k_probs, top_k_indices = torch.topk(probs, k)
    for prob, idx in zip(top_k_probs.tolist(), top_k_indices.tolist()):
        tok = wrapper.itos[idx]
        if tok in allowed:
            return tok, float(prob)
    # Slow path: nothing allowed in the top-k. Previously this returned the raw
    # argmax, which silently broke the constraint for the caller.
    allowed_idxs = [
        i for i, s in wrapper.itos.items() if s in allowed and i < vocab_size
    ]
    if allowed_idxs:
        cand = torch.tensor(allowed_idxs, device=probs.device)
        idx = allowed_idxs[int(torch.argmax(probs[cand]).item())]
    else:
        # No vocabulary entry is allowed at all: keep the unconstrained argmax.
        idx = int(torch.argmax(probs).item())
    return wrapper.itos[idx], float(probs[idx].item())

def ensure_in_window(s: str, max_len: int = 120) -> str:
    """Return the last ``max_len`` characters of ``s``.

    This is useful for trimming a prompt to a fixed-size context. A
    ``max_len`` of zero or less yields ``''``. A bare ``s[-max_len:]`` would
    return the whole string for 0 (``s[-0:] == s[0:]``) and a suffix for
    negatives.
    """
    if max_len <= 0:
        return ''
    return s[-max_len:]

In [None]:
# Load wrapper (uses local 'hf_gpt2_model' and 'vocab.json')
# NOTE(review): those file names come from the original note; the actual loading
# happens inside architecture.load_hf_gpt2_model, which is not visible here.
# `wrapper` is the global that every later cell passes to fill_mask / top_k.
wrapper = load_hf_gpt2_model()
print('Model loaded; vocab size =', len(wrapper.vocab))

In [None]:
# Inspect the initial branch after 'password: ' to see the decoy vs correct path
# Prints the 10 most probable tokens for the slot right after 'password: ' as
# rank, token repr and softmax probability; the vocabulary index is ignored.
first_top = top_k(wrapper, 'password: [MASK]', k=10)
for i, (tok, p, _) in enumerate(first_top, 1):
    print(f"{i:>2}. {tok!r}  p={p:.4f}")
print("Hint: choose 'A' here to avoid the log-decoy path.")

In [None]:
# Guided extraction: force 'AI2025{' then constrain inside braces until '}'
known = 'password: A'
# Teacher-force the rest of the flag prefix one character at a time. The model's
# probability for each forced character is printed only as a sanity check; the
# character is appended regardless of what the model prefers.
for ch in 'I2025{':
    prompt = known[-120:] + '[MASK]'
    logits = fill_mask(wrapper, prompt)
    probs = F.softmax(logits, dim=-1)
    # itos maps index -> token; find the index of the forced character.
    idxs = [i for i, s in wrapper.itos.items() if s == ch]
    p = float(probs[idxs[0]].item()) if idxs else 0.0
    print(f'Steer to {ch!r}: p={p:.5f}')
    known += ch

# Flag body alphabet: ASCII letters, digits and underscore; '}' terminates.
allowed_inside = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_')
allowed_all = allowed_inside | {'}'}

for step in range(256):
    if '}' in known:
        break
    prompt = known[-120:] + '[MASK]'
    ch, p = next_char_constrained(wrapper, prompt, allowed_all, k=30)
    print(f'next={ch!r} p={p:.5f}')
    if ch not in allowed_all:
        # The helper may fall back to an unconstrained token; stop rather than
        # splice a character from outside the flag alphabet into the result.
        print(f'Stopping: {ch!r} is outside the allowed flag alphabet.')
        break
    known += ch

# The original literal here contained a raw line break inside '...', which is
# a SyntaxError; '\n' is the intended blank line before the result.
print('\nKnown string:', known)
s = known.find('AI2025{')
e = known.find('}', s+1) if s != -1 else -1
flag = known[s:e+1] if (s != -1 and e != -1) else ''
print('FLAG:', flag or '(not recovered)')

In [None]:
# Beam search constrained inside braces starting from the known prefix
import math

# Force the intended starting point as per author hint
dec_ctx = 'password: AI2025{'

beam = [(0.0, dec_ctx)]  # (logprob, text)
beam_width = 5
max_steps = 128
topk_per_step = 80  # candidates examined per expansion; clamped to vocab size
close_bonus = 2.0   # heuristic score bonus that favours emitting the closing '}'
allowed_inside = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_')
allowed_all = allowed_inside | {'}'}
# Only a '}' generated after the forced prefix closes a candidate.
gen_start = len(dec_ctx)

best = None
for step in range(max_steps):
    new_beam = []
    closed = []
    for score, cand in beam:
        if '}' in cand[gen_start:]:
            closed.append((score, cand))
            continue
        prompt = cand[-120:] + '[MASK]'
        logits = fill_mask(wrapper, prompt)
        logp = F.log_softmax(logits, dim=-1)
        top_logp, top_idx = torch.topk(logp, min(topk_per_step, logp.size(-1)))
        for lp, idx in zip(top_logp.tolist(), top_idx.tolist()):
            tok = wrapper.itos[idx]
            if tok not in allowed_all:
                continue
            bonus = close_bonus if tok == '}' else 0.0
            new_beam.append((score + lp + bonus, cand + tok))
    if closed:
        best = max(closed, key=lambda x: x[0])[1]
        break
    if not new_beam:
        print('Beam exhausted without closing brace.')
        break
    new_beam.sort(key=lambda x: x[0], reverse=True)
    beam = new_beam[:beam_width]
    print(f"Step {step+1}, best: {beam[0][1][-80:]}")
else:
    # max_steps ran out. Candidates that closed on the final expansion are never
    # inspected inside the loop, so check the surviving beam here.
    closed = [b for b in beam if '}' in b[1][gen_start:]]
    if closed:
        best = max(closed, key=lambda x: x[0])[1]
    else:
        print(f'No closing brace within {max_steps} steps.')

if best is not None:
    print('Recovered from guided beam:', best)
    s = best.find('AI2025{')
    e = best.find('}', s+1) if s != -1 else -1
    flag = best[s:e+1] if (s != -1 and e != -1) else ''
    print('FLAG:', flag or '(not recovered)')

In [None]:
# Optional pivot: continue decoding after the decoy tail to redirect toward the flag
# If your best path ends with: body={"name":"admin"}
# we try to overwrite that trajectory by forcing the correct prefix next.

decoy_tail = 'body={"name":"admin"}'
start = decoy_tail + ' '  # add a space separator if needed

# if you captured a longer context in a previous search, paste it into 'start' before running

# Force the correct prefix after the decoy
forced = start + 'AI2025{'
print('Starting from decoy tail, forcing:', forced)

beam = [(0.0, forced)]  # (logprob, text)
beam_width = 5
max_steps = 128
topk_per_step = 80  # candidates examined per expansion; clamped to vocab size
close_bonus = 2.0   # heuristic score bonus that favours emitting the closing '}'
allowed_inside = set('abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_')
allowed_all = allowed_inside | {'}'}
# The decoy tail itself contains '}'. Testing the whole candidate made every run
# "close" at step 0 without decoding anything, so only the text generated
# after the forced prefix is checked.
gen_start = len(forced)

best = None
for step in range(max_steps):
    new_beam = []
    closed = []
    for score, cand in beam:
        if '}' in cand[gen_start:]:
            closed.append((score, cand))
            continue
        prompt = cand[-120:] + '[MASK]'
        logits = fill_mask(wrapper, prompt)
        logp = F.log_softmax(logits, dim=-1)
        top_logp, top_idx = torch.topk(logp, min(topk_per_step, logp.size(-1)))
        for lp, idx in zip(top_logp.tolist(), top_idx.tolist()):
            tok = wrapper.itos[idx]
            if tok not in allowed_all:
                continue
            bonus = close_bonus if tok == '}' else 0.0
            new_beam.append((score + lp + bonus, cand + tok))
    if closed:
        best = max(closed, key=lambda x: x[0])[1]
        break
    if not new_beam:
        print('Beam exhausted without closing brace (pivot run).')
        break
    new_beam.sort(key=lambda x: x[0], reverse=True)
    beam = new_beam[:beam_width]
    print(f"[Pivot] Step {step+1}, best: {beam[0][1][-80:]}")
else:
    # max_steps ran out. Candidates that closed on the final expansion are never
    # inspected inside the loop, so check the surviving beam here.
    closed = [b for b in beam if '}' in b[1][gen_start:]]
    if closed:
        best = max(closed, key=lambda x: x[0])[1]
    else:
        print(f'No closing brace within {max_steps} steps (pivot run).')

if best is not None:
    print('Recovered from decoy pivot:', best)
    # Search from the forced prefix so a pasted 'start' context cannot match first.
    s = best.find('AI2025{', len(start))
    e = best.find('}', s+1) if s != -1 else -1
    flag = best[s:e+1] if (s != -1 and e != -1) else ''
    print('FLAG:', flag or '(not recovered)')