<a href="https://colab.research.google.com/github/jonathanrbelanger-lang/Janus_Arc/blob/main/Janus_Hero_Test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# @title [Run] Janus Architecture Testing Script: runs head-to-head tests to determine an optimal configuration for a 40M model

import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import json
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive only when it is not already visible in this runtime.
_DRIVE_MOUNT = '/content/drive'
if not os.path.exists(_DRIVE_MOUNT):
    drive.mount(_DRIVE_MOUNT)

# UPDATE THIS PATH TO YOUR BIN FILE
DATA_FILE = "/content/drive/MyDrive/Project_XAI_Physical_Janus/data/processed/TinyStories-train_full.bin"
RESULTS_DIR = "/content/drive/MyDrive/Project_XAI_Physical_Janus/data/results/janus_v2_clean"
os.makedirs(RESULTS_DIR, exist_ok=True)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print(f"‚öîÔ∏è JANUS GAUNTLET: The Clean Room")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 2. THE ARCHITECTURE (INLINE - NO SAFEGUARDS) ---

class CleanConfig:
    """Hyper-parameter container for the gauntlet probes.

    Every keyword argument is stored as an attribute, overriding the
    default of the same name. ``d_head`` is derived afterwards from
    ``d_model`` and ``n_heads`` and therefore cannot be overridden.

    Raises:
        ValueError: if ``d_model`` is not a multiple of ``n_heads``.
    """

    def __init__(self, **kwargs):
        # Defaults
        self.vocab_size = 50304
        self.d_model = 512
        self.n_layers = 12
        self.n_heads = 16  # Default, overridden in tests
        self.max_seq_len = 128
        self.dropout = 0.0
        self.lambda_diversity = 0.0
        self.lambda_coherence = 0.0
        self.spatial_schedule = 'cubic'
        self.enable_steering = True

        for key, value in kwargs.items():
            setattr(self, key, value)

        # Calculated. A non-divisible width would be silently truncated here
        # and only surface later as a shape error when heads are split.
        if self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by "
                f"n_heads ({self.n_heads})"
            )
        self.d_head = self.d_model // self.n_heads

class CleanScheduler:
    """Scale the base steering weights by training progress and layer depth."""

    def __init__(self, config):
        self.total_layers = config.n_layers
        self.base_div = config.lambda_diversity
        self.base_coh = config.lambda_coherence
        self.spatial = config.spatial_schedule

    def get_lambdas(self, step, max_steps, layer_id):
        """Return ``(coherence, diversity)`` weights for one layer at one step."""
        time_gain = self._time_gain(step, max_steps)
        depth_gain = self._depth_gain(layer_id)
        return (self.base_coh * time_gain * depth_gain,
                self.base_div * time_gain * depth_gain)

    @staticmethod
    def _time_gain(step, max_steps):
        """Trapezoid over progress: ramp up to 20%, hold to 80%, ramp down."""
        if not (max_steps > 0):
            return 0.0
        progress = step / max_steps
        if progress < 0.2:
            return progress / 0.2
        if progress < 0.8:
            return 1.0
        return max(0.0, 1.0 - (progress - 0.8)/0.2)

    def _depth_gain(self, layer_id):
        """Depth multiplier: cubic or linear in the layer ratio, else 1.0."""
        ratio = (layer_id + 1) / self.total_layers
        if self.spatial == 'cubic':
            return ratio ** 3
        if self.spatial == 'linear':
            return ratio
        return 1.0

class CleanAttention(nn.Module):
    """Causal multi-head self-attention with an optional head-diversity penalty.

    ``layer_id`` is accepted by the constructor but not used.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        heads = config.n_heads
        width = config.d_model
        self.n_heads = heads
        self.d_head = width // heads
        # Standard scaled dot-product factor, 1 / sqrt(d_head).
        self.scale = 1.0 / math.sqrt(self.d_head)

        # Modules are created in q, k, v, o order, which fixes the order of
        # the random draws made during initialisation.
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)

    def forward(self, x, lambdas):
        """Return ``(output, steer_loss)`` for input ``x`` of shape (B, S, D).

        ``lambdas`` is a ``(coherence, diversity)`` pair; only the diversity
        weight is used. ``steer_loss`` stays the float ``0.0`` unless the
        diversity weight is positive.
        """
        batch, seq_len, width = x.shape
        _, l_div = lambdas

        def split_heads(proj):
            # (B, S, D) -> (B, H, S, d_head)
            return proj(x).view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        q = split_heads(self.q_proj)
        k = split_heads(self.k_proj)
        v = split_heads(self.v_proj)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        causal = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).view(1, 1, seq_len, seq_len)
        scores = scores.masked_fill(causal == 0, float('-inf'))
        head_out = F.softmax(scores, dim=-1) @ v

        # Steering Physics
        steer_loss = 0.0
        if l_div > 0.0:
            # Orthogonality: penalise cosine similarity between head outputs.
            # Each head is flattened to one vector of length B * S * d_head.
            per_head = F.normalize(head_out.transpose(0, 1).reshape(self.n_heads, -1), p=2, dim=1)
            gram = torch.mm(per_head, per_head.t())
            eye = torch.eye(self.n_heads, device=x.device)
            steer_loss += torch.norm(gram - eye, p='fro') * l_div

        merged = head_out.transpose(1, 2).reshape(batch, seq_len, width)
        return self.o_proj(merged), steer_loss

class CleanBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention then a 4x GELU MLP, both residual."""

    def __init__(self, config, layer_id):
        super().__init__()
        width = config.d_model
        self.ln1 = nn.LayerNorm(width)
        self.attn = CleanAttention(config, layer_id)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, 4 * width),
            nn.GELU(),
            nn.Linear(4 * width, width),
        )

    def forward(self, x, lambdas):
        """Return ``(hidden, steer_loss)``; the steering loss comes from attention."""
        attn_out, steer = self.attn(self.ln1(x), lambdas)
        hidden = x + attn_out
        return hidden + self.mlp(self.ln2(hidden)), steer

class CleanGPT(nn.Module):
    """Decoder-only language model whose blocks each report a steering loss.

    The training loop is expected to set ``step`` and ``max_steps`` before
    each forward pass; they are handed to the scheduler for every layer.
    """

    def __init__(self, config):
        super().__init__()
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.blocks = nn.ModuleList([CleanBlock(config, i) for i in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.scheduler = CleanScheduler(config)
        self.step = 0
        self.max_steps = 1

    def forward(self, idx, targets=None):
        """Return ``(loss, total_steer)``; ``loss`` is None when no targets are given."""
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device)
        x = self.token_emb(idx) + self.pos_emb(positions)

        total_steer = 0.0
        for layer_id, block in enumerate(self.blocks):
            layer_lambdas = self.scheduler.get_lambdas(self.step, self.max_steps, layer_id)
            x, steer = block(x, layer_lambdas)
            total_steer += steer

        logits = self.head(self.ln_f(x))

        if targets is None:
            return None, total_steer
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return loss, total_steer

# --- 3. LOADER (MEMMAP) ---
class BinLoader:
    """Random-window batch sampler over a memory-mapped uint16 token file."""

    def __init__(self, bin_path, block_size, batch_size):
        self.block_size = block_size
        self.batch_size = batch_size
        self.data = np.memmap(bin_path, dtype=np.uint16, mode='r')
        print(f"üì¶ Data Loaded: {len(self.data):,} tokens")

    def get_batch(self):
        """Return ``(x, y)`` on DEVICE, where ``y`` is ``x`` shifted by one token."""
        span = self.block_size
        starts = torch.randint(len(self.data) - span, (self.batch_size,)).tolist()

        def window(offset):
            # Copy one window out of the memmap as int64.
            return torch.from_numpy(self.data[offset:offset + span].astype(np.int64))

        x = torch.stack([window(start) for start in starts])
        y = torch.stack([window(start + 1) for start in starts])
        return x.to(DEVICE), y.to(DEVICE)

# --- 4. THE GAUNTLET ---
def run_probe(name, config_dict, loader, steps=500):
    """Train a fresh model for ``steps`` updates and score it.

    Returns the mean cross-entropy loss over the last 100 recorded steps.
    The steering loss is part of the optimised objective but not the score.
    """
    gc.collect()
    torch.cuda.empty_cache()

    model = CleanGPT(CleanConfig(**config_dict)).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6e-4)
    model.train()

    history = []
    progress = tqdm(range(steps), desc=name, leave=False)
    for step in progress:
        x, y = loader.get_batch()
        # The model's scheduler reads these to place itself on the time ramp.
        model.step, model.max_steps = step, steps

        loss, steer = model(x, y)

        optimizer.zero_grad()
        (loss + steer).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_value = loss.item()
        history.append(loss_value)
        if step % 20 == 0:
            progress.set_description(f"{name} | L:{loss_value:.3f}")

    return np.mean(history[-100:])

def run_pressure_ramp(name, config_dict, loader, steps=500):
    """Train one model while ramping diversity pressure until the loss fractures.

    The scheduler's base diversity weight is raised linearly from 0 towards
    0.50 over ``steps`` updates (coherence is kept at 20% of it). A fracture
    is declared when the loss becomes non-finite, or when, after step 100,
    it exceeds 1.3x the loss recorded at step 50.

    Returns:
        75% of the pressure at which the fracture occurred, or 0.40 when the
        whole ramp completes without one.
    """
    print(f"\nüèãÔ∏è Ramping Pressure on {name}...")
    gc.collect(); torch.cuda.empty_cache()

    cfg = CleanConfig(**config_dict)
    # Start at 0 pressure
    cfg.lambda_diversity = 0.0

    model = CleanGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6e-4)

    # Placeholder until the real baseline is sampled at step 50.
    baseline = 10.0

    model.train()
    for step in range(steps):
        x, y = loader.get_batch()

        # Manually ramp pressure in the scheduler
        # NOTE(review): CleanScheduler.get_lambdas still applies its own time
        # ramp on top of these bases, so the effective pressure is below
        # curr_p outside the 20%-80% plateau. Confirm this is intended.
        curr_p = 0.50 * (step / steps) # Max 0.50
        model.scheduler.base_div = curr_p
        model.scheduler.base_coh = curr_p * 0.2
        model.step = step
        model.max_steps = steps

        loss, steer = model(x, y)
        total = loss + steer

        optimizer.zero_grad(); total.backward(); optimizer.step()

        loss_val = loss.item()

        # A NaN/inf loss is the most severe fracture, but every ordering
        # comparison against NaN is False, so it has to be tested explicitly.
        diverged = not math.isfinite(loss_val)

        if step == 50 and not diverged: baseline = loss_val

        # Explosion Check (30% spike)
        if diverged or (step > 100 and loss_val > baseline * 1.3):
            print(f"üí• FRACTURE at Pressure {curr_p:.4f} (Loss {loss_val:.3f})")
            return curr_p * 0.75 # 25% safety buffer

        if step % 50 == 0:
            print(f"   Step {step} | P:{curr_p:.3f} | L:{loss_val:.3f}")

    print("‚úÖ No Fracture. Cap at 0.50.")
    return 0.40 # Safe conservative cap

def main():
    """Run the three gauntlet phases and write the winning settings to JSON.

    Phase 1 picks a head geometry, phase 2 finds a safe diversity pressure
    for it, and phase 3 compares linear and cubic spatial schedules at that
    pressure.
    """
    loader = BinLoader(DATA_FILE, 128, 64)
    manifest = {}

    # --- PHASE 1: GEOMETRY ---
    print("\n\n=== PHASE 1: GEOMETRY DUEL ===")
    # Note: a tiny pressure (0.01) is added to force differentiation if needed,
    # but the scaling factor alone should differentiate the two geometries.
    probe_pressure = {'lambda_diversity': 0.01}
    score_32 = run_probe("16x32d", {'n_heads': 16, **probe_pressure}, loader)
    score_64 = run_probe("8x64d",  {'n_heads': 8, **probe_pressure}, loader)

    print(f"\nüèÅ RESULTS: 16x32d={score_32:.4f} | 8x64d={score_64:.4f}")
    wide_heads_win = score_64 < score_32
    print("üèÜ WINNER: 64-dim Heads" if wide_heads_win else "üèÜ WINNER: 32-dim Heads")
    best_geo = {'n_heads': 8 if wide_heads_win else 16}

    manifest.update(best_geo)

    # --- PHASE 2: PRESSURE ---
    print("\n\n=== PHASE 2: PRESSURE TEST ===")
    safe_p = run_pressure_ramp("Winner", best_geo, loader)
    print(f"üõ°Ô∏è Safe Pressure: {safe_p:.4f}")
    manifest.update(lambda_diversity=safe_p, lambda_coherence=safe_p * 0.2)

    # --- PHASE 3: SPATIAL ---
    print("\n\n=== PHASE 3: SPATIAL DUEL ===")
    print(f"Testing schedules at P={safe_p:.4f}")

    spatial_scores = {}
    for label, schedule in (("Linear", 'linear'), ("Cubic", 'cubic')):
        probe_cfg = {**best_geo, 'lambda_diversity': safe_p, 'spatial_schedule': schedule}
        spatial_scores[schedule] = run_probe(label, probe_cfg, loader)
    score_lin = spatial_scores['linear']
    score_cub = spatial_scores['cubic']

    print(f"\nüèÅ RESULTS: Linear={score_lin:.4f} | Cubic={score_cub:.4f}")
    if score_cub < score_lin:
        print("üèÜ WINNER: Cubic")
        manifest['spatial_schedule'] = 'cubic'
    else:
        print("üèÜ WINNER: Linear")
        manifest['spatial_schedule'] = 'linear'

    # --- SAVE ---
    manifest.update(n_layers=12, d_model=512)
    out_path = os.path.join(RESULTS_DIR, "janus_v2_manifest.json")
    with open(out_path, 'w') as f:
        json.dump(manifest, f, indent=4)
    print(f"\nüíæ Manifest Saved: {out_path}")

if __name__ == "__main__":
    main()

In [None]:

# @title [RUN] JSmallHConfirmation

import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import pandas as pd
import gc
import time
from tqdm import tqdm
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive only when it is not already visible in this runtime.
_DRIVE_MOUNT = '/content/drive'
if not os.path.exists(_DRIVE_MOUNT):
    drive.mount(_DRIVE_MOUNT)

# PATHS
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
DATA_FILE = os.path.join(PROJECT_ROOT, "data/processed/TinyStories-train_full.bin")
SAVE_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_hero_v2")
os.makedirs(SAVE_DIR, exist_ok=True)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"üöÄ JANUS HERO V2: The Homeostatic Run")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 2. CONFIGURATION (LOCKED) ---
class HeroConfig:
    """Fixed settings for the hero run, grouped by concern."""

    def __init__(self):
        self._init_architecture()
        self._init_homeostasis()
        self._init_training()

    def _init_architecture(self):
        # ARCHITECTURE (Winner of Geometry Duel)
        self.vocab_size = 50304
        self.d_model = 512
        self.n_layers = 12
        self.n_heads = 8   # 64-dim heads
        self.d_head = 64   # stated explicitly rather than derived
        self.max_seq_len = 512
        self.dropout = 0.05

    def _init_homeostasis(self):
        # HOMEOSTASIS (67% of Max Safe)
        self.max_lambda_div = 0.27
        self.max_lambda_coh = 0.054  # 20% rule
        self.spatial_schedule = 'cubic'

    def _init_training(self):
        # TRAINING
        self.max_steps = 5000
        self.batch_size = 32
        self.grad_accum = 2  # Effective Batch 64
        self.val_interval = 250
        self.val_steps = 50

# --- 3. THE ENGINE (INLINE) ---
class CleanScheduler:
    """Three-stage pressure schedule with a per-layer spatial multiplier.

    ``step`` sets the current global diversity/coherence weights from the
    training step; ``get_lambdas`` scales them for an individual layer.
    """

    # Stage boundaries, in training steps.
    IGNITION_END = 750
    RAMP_END = 2000

    def __init__(self, config):
        self.config = config
        self.current_div = 0.0
        self.current_coh = 0.0

    def step(self, step_num):
        """Update the global weights for ``step_num`` and return the diversity weight.

        Stage 1, ignition (before step 750): zero pressure.
        Stage 2, pressurization (750 up to 2000): linear ramp to the maxima.
        Stage 3, cruising (2000 onwards): hold at the configured maxima.
        """
        if step_num < self.IGNITION_END:
            self.current_div = 0.0
            self.current_coh = 0.0
        elif step_num < self.RAMP_END:
            progress = (step_num - self.IGNITION_END) / (self.RAMP_END - self.IGNITION_END)
            self.current_div = self.config.max_lambda_div * progress
            self.current_coh = self.config.max_lambda_coh * progress
        else:
            self.current_div = self.config.max_lambda_div
            self.current_coh = self.config.max_lambda_coh

        return self.current_div

    def get_lambdas(self, layer_id):
        """Return ``(coherence, diversity)`` weights for layer ``layer_id``.

        The spatial multiplier follows ``config.spatial_schedule``: 'cubic'
        (also used when the attribute is absent), 'linear', or 1.0 for any
        other value.
        """
        ratio = (layer_id + 1) / self.config.n_layers
        schedule = getattr(self.config, 'spatial_schedule', 'cubic')
        if schedule == 'cubic':
            s_mult = ratio ** 3
        elif schedule == 'linear':
            s_mult = ratio
        else:
            s_mult = 1.0
        return (self.current_coh * s_mult, self.current_div * s_mult)

class CleanAttention(nn.Module):
    """Causal multi-head self-attention with diversity steering and telemetry.

    ``forward`` returns ``(output, steer_loss, metrics)``. ``steer_loss`` is
    non-zero only in training mode with a positive diversity weight.
    ``metrics`` is an empty dict unless ``return_metrics`` is set.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, lambdas, return_metrics=False):
        """Attend over ``x`` of shape (B, S, D).

        Args:
            x: input activations, (batch, sequence, d_model).
            lambdas: ``(coherence, diversity)`` weights; only the diversity
                weight is used here.
            return_metrics: when True, also compute per-head telemetry.
        """
        B, S, D = x.shape
        l_coh, l_div = lambdas

        q = self.q_proj(x).view(B, S, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.n_heads, self.d_head).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.n_heads, self.d_head).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        mask = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        attn = attn.masked_fill(mask == 0, float('-inf'))
        attn_probs = F.softmax(attn, dim=-1)

        # Dropout is applied only to the copy that mixes values, so the
        # telemetry below sees the true attention distribution.
        head_out = self.dropout(attn_probs) @ v  # (B, H, S, d_head)

        # --- PHYSICS & TELEMETRY ---
        metrics = {}
        steer_loss = 0.0

        # 1. Steering (Training Only)
        if l_div > 0.0 and self.training:
            gram = self._head_similarity(head_out)
            identity = torch.eye(self.n_heads, device=x.device)
            steer_loss += torch.norm(gram - identity, p='fro') * l_div

        # 2. Telemetry (If requested)
        if return_metrics:
            metrics = self._telemetry(attn_probs, head_out)

        out = head_out.transpose(1, 2).reshape(B, S, D)
        return self.o_proj(out), steer_loss, metrics

    def _head_similarity(self, head_out):
        """Return the (H, H) cosine-similarity matrix between flattened head outputs."""
        flat = head_out.transpose(0, 1).reshape(self.n_heads, -1)
        norm = F.normalize(flat, p=2, dim=1)
        return torch.mm(norm, norm.t())

    def _telemetry(self, attn_probs, head_out):
        """Compute per-head ``sigma_p``, ``sigma_a`` and ``eff_rank`` without gradients."""
        metrics = {}
        device = head_out.device
        S = attn_probs.size(-1)
        with torch.no_grad():
            # Sigma_P (Focus/Entropy): 1 - entropy / log(S), averaged over
            # batch and query position. log(1) == 0, so a single-token
            # sequence uses 1.0 as the normaliser instead of dividing by zero.
            entropy = -torch.sum(attn_probs * torch.log(attn_probs + 1e-9), dim=-1)
            max_ent = math.log(S) if S > 1 else 1.0
            metrics['sigma_p'] = (1.0 - (entropy / max_ent)).mean(dim=[0, 2])

            # Sigma_A (Uniqueness/Orthogonality): mean absolute off-diagonal
            # cosine similarity per head. max(..., 1) guards the one-head case.
            sim = self._head_similarity(head_out)
            mask_diag = ~torch.eye(self.n_heads, dtype=torch.bool, device=device)
            metrics['sigma_a'] = (sim.abs() * mask_diag.float()).sum(dim=1) / max(self.n_heads - 1, 1)

            # Eff_Rank (Dimensionality): exp of the entropy of the normalised
            # singular values, on the first 128 positions for speed. The head
            # axis must be moved to the front BEFORE flattening; otherwise
            # each "per-head" matrix would contain rows from every head.
            sub_out = head_out[:, :, :128, :].transpose(0, 1).reshape(self.n_heads, -1, self.d_head)
            ranks = []
            for h in range(self.n_heads):
                try:
                    S_vals = torch.linalg.svdvals(sub_out[h].float())
                    p = S_vals / S_vals.sum()
                    ent = -torch.sum(p * torch.log(p + 1e-9))
                    ranks.append(torch.exp(ent))
                except RuntimeError:
                    # SVD failed for this head: record 0 on the same device
                    # as the other entries so torch.stack still works.
                    ranks.append(torch.tensor(0.0, device=device))
            metrics['eff_rank'] = torch.stack(ranks).to(device)

            # Kurtosis (Sharpness) is skipped to save VRAM; Sigma_P serves
            # as the proxy for sharpness.
        return metrics

class CleanBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention then a 4x GELU MLP with dropout."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = nn.LayerNorm(width)
        self.attn = CleanAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, 4 * width),
            nn.GELU(),
            nn.Linear(4 * width, width),
            nn.Dropout(config.dropout),
        )

    def forward(self, x, lambdas, return_metrics=False):
        """Return ``(hidden, steer_loss, metrics)`` with both sub-layers residual."""
        attn_out, steer, layer_metrics = self.attn(self.ln1(x), lambdas, return_metrics)
        hidden = x + attn_out
        hidden = hidden + self.mlp(self.ln2(hidden))
        return hidden, steer, layer_metrics

class CleanGPT(nn.Module):
    """Decoder-only language model that drives its own pressure scheduler."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.blocks = nn.ModuleList([CleanBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.scheduler = CleanScheduler(config)

    def forward(self, idx, targets=None, step_num=0, return_metrics=False):
        """Return ``(loss, total_steer, all_metrics, current_div)``.

        ``loss`` is None without targets. ``all_metrics`` holds one dict per
        layer when ``return_metrics`` is set and is empty otherwise.
        ``current_div`` is the value returned by the scheduler for ``step_num``.
        """
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device)
        x = self.token_emb(idx) + self.pos_emb(positions)

        # Advance the scheduler before any layer asks for its lambdas.
        current_div = self.scheduler.step(step_num)

        total_steer = 0.0
        all_metrics = []
        for layer_id, block in enumerate(self.blocks):
            layer_lambdas = self.scheduler.get_lambdas(layer_id)
            x, steer, layer_metrics = block(x, layer_lambdas, return_metrics)
            total_steer += steer
            if return_metrics:
                all_metrics.append(layer_metrics)

        logits = self.head(self.ln_f(x))

        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return loss, total_steer, all_metrics, current_div

# --- 4. DATA LOADER (SPLIT) ---
class SplitLoader:
    """Random-window batch sampler with a 95/5 train/validation split."""

    def __init__(self, bin_path, block_size, batch_size):
        self.data = np.memmap(bin_path, dtype=np.uint16, mode='r')
        # The first 95% of the tokens are for training, the tail for validation.
        boundary = int(len(self.data) * 0.95)

        self.train_data = self.data[:boundary]
        self.val_data = self.data[boundary:]

        self.block_size = block_size
        self.batch_size = batch_size
        print(f"üì¶ Data Split | Train: {len(self.train_data):,} | Val: {len(self.val_data):,}")

    def get_batch(self, split='train'):
        """Return ``(x, y)`` on DEVICE; any ``split`` other than 'train' reads validation data."""
        if split == 'train':
            source = self.train_data
        else:
            source = self.val_data
        span = self.block_size
        starts = torch.randint(len(source) - span, (self.batch_size,)).tolist()

        def window(offset):
            # Copy one window out of the memmap as int64.
            return torch.from_numpy(source[offset:offset + span].astype(np.int64))

        x = torch.stack([window(start) for start in starts])
        y = torch.stack([window(start + 1) for start in starts])
        return x.to(DEVICE), y.to(DEVICE)

# --- 5. BLACK BOX (LOGGER) ---
class BlackBox:
    """Buffer telemetry rows in memory and append them to a parquet file."""

    def __init__(self, save_dir):
        self.buffer = []
        self.save_dir = save_dir

    def log(self, step, loss, val_loss, pressure, metrics_list):
        """Queue one telemetry row.

        ``metrics_list`` holds one dict per layer; empty dicts are skipped.
        Each metric tensor is averaged across heads to keep the file small.
        ``perplexity`` is ``exp(val_loss)``, or the sentinel 0.0 when
        ``val_loss`` is not below 20.
        """
        row = {
            "step": step,
            "loss": loss,
            "val_loss": val_loss,
            "perplexity": math.exp(val_loss) if val_loss < 20 else 0.0,
            "pressure": pressure
        }

        for i, layer_m in enumerate(metrics_list):
            if not layer_m: continue
            # Average across heads for the summary log
            row[f"L{i}_sigma_p"] = layer_m['sigma_p'].mean().item()
            row[f"L{i}_sigma_a"] = layer_m['sigma_a'].mean().item()
            row[f"L{i}_eff_rank"] = layer_m['eff_rank'].mean().item()

        self.buffer.append(row)

    def flush(self):
        """Append buffered rows to ``telemetry_hero.parquet`` and clear the buffer."""
        if not self.buffer: return
        df = pd.DataFrame(self.buffer)
        # Append mode
        fpath = os.path.join(self.save_dir, "telemetry_hero.parquet")
        if os.path.exists(fpath):
            existing = pd.read_parquet(fpath)
            # Every flush starts a fresh frame at index 0; renumber so the
            # stored table does not accumulate duplicate index labels.
            df = pd.concat([existing, df], ignore_index=True)
        df.to_parquet(fpath)
        self.buffer = []
        print("üíæ Telemetry Flushed.")

# --- 6. MAIN LOOP ---
def run_hero():
    """Train the hero model with gradient accumulation, validation and checkpoints.

    Resumes from ``ckpt_latest.pt`` in SAVE_DIR when it exists. Every
    ``val_interval`` steps, and on the last step, the model is validated,
    telemetry is flushed and a checkpoint is written. The final weights are
    saved as ``janus_hero_final.pt``.
    """
    # Init
    cfg = HeroConfig()
    loader = SplitLoader(DATA_FILE, cfg.max_seq_len, cfg.batch_size)
    model = CleanGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6e-4, weight_decay=1e-4)
    recorder = BlackBox(SAVE_DIR)

    # Checkpoint check
    start_step = 0
    ckpt_path = os.path.join(SAVE_DIR, "ckpt_latest.pt")
    if os.path.exists(ckpt_path):
        print("üîÑ Resuming from Checkpoint...")
        ckpt = torch.load(ckpt_path, map_location=DEVICE)
        model.load_state_dict(ckpt['model'])
        optimizer.load_state_dict(ckpt['optim'])
        # The checkpoint is written after its step's optimizer update, so
        # resume at the following step instead of repeating (and re-logging)
        # the saved one.
        start_step = ckpt['step'] + 1

    print(f"\nüèÉ STARTING RUN: {start_step} -> {cfg.max_steps}")

    pbar = tqdm(range(start_step, cfg.max_steps), initial=start_step, total=cfg.max_steps)

    for step in pbar:
        # 1. Training Step
        model.train()

        # Gradient Accumulation: clear the gradients once per optimizer step.
        # Clearing them per micro-batch would throw away every micro-batch
        # except the last one.
        optimizer.zero_grad()
        batch_loss = 0.0

        for micro_idx in range(cfg.grad_accum):
            x, y = loader.get_batch('train')
            # Only calc metrics on the last micro-batch to save compute
            do_metrics = (micro_idx == cfg.grad_accum - 1)

            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                loss, steer, metrics, pressure = model(x, y, step, return_metrics=do_metrics)
                # Scale so the accumulated gradient is the micro-batch mean.
                total = (loss + steer) / cfg.grad_accum

            total.backward()
            batch_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        avg_loss = batch_loss / cfg.grad_accum

        # 2. Validation & Logging (Interval)
        val_loss = 0.0
        if step % cfg.val_interval == 0 or step == cfg.max_steps - 1:
            model.eval()
            with torch.no_grad():
                v_losses = []
                for _ in range(cfg.val_steps):
                    vx, vy = loader.get_batch('val')
                    # No steering in validation
                    vl, _, _, _ = model(vx, vy, step, return_metrics=False)
                    v_losses.append(vl.item())
                val_loss = np.mean(v_losses)

            # Log to BlackBox
            recorder.log(step, avg_loss, val_loss, pressure, metrics)
            recorder.flush()

            # Save Checkpoint
            torch.save({
                'step': step,
                'model': model.state_dict(),
                'optim': optimizer.state_dict()
            }, ckpt_path)

            # Update Bar
            desc = f"L:{avg_loss:.3f}|V:{val_loss:.3f}|P:{pressure:.3f}"
            pbar.set_description(desc)
        else:
            pbar.set_description(f"L:{avg_loss:.3f}|P:{pressure:.3f}")

    # Final Save
    torch.save(model.state_dict(), os.path.join(SAVE_DIR, "janus_hero_final.pt"))
    print("\n‚úÖ MISSION COMPLETE.")

if __name__ == "__main__":
    run_hero()

In [None]:
# @title [Run] Janus Hero v2 Telemetry Analyzer

import sys
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.backends.backend_pdf import PdfPages
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive only when it is not already visible in this runtime.
_DRIVE_MOUNT = '/content/drive'
if not os.path.exists(_DRIVE_MOUNT):
    drive.mount(_DRIVE_MOUNT)

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
TELEMETRY_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_hero_v2/telemetry_hero.parquet")
REPORT_DIR = os.path.join(PROJECT_ROOT, "reports")
os.makedirs(REPORT_DIR, exist_ok=True)

# Startup banner: where the telemetry is read from and where reports go.
for _banner_line in (
    f"üïµÔ∏è JANUS AUTOPSY: Forensic Analysis Tool",
    f"üìÇ Input: {TELEMETRY_PATH}",
    f"üìÑ Output: {REPORT_DIR}",
):
    print(_banner_line)

# --- 2. THE ANALYST ENGINE ---
class JanusAutopsy:
    """Render a multi-page PDF report from a telemetry parquet file.

    The table must provide ``step``, ``loss`` and ``pressure`` columns.
    Per-layer metrics are read from columns named ``L<i>_sigma_a`` and
    ``L<i>_eff_rank``; their row-wise means are added as ``sigma_a_avg``
    and ``eff_rank_avg`` when those columns are missing.
    """

    def __init__(self, filepath):
        self.df = pd.read_parquet(filepath)
        self.pdf_path = os.path.join(REPORT_DIR, "janus_hero_autopsy.pdf")
        # Training phases as (first step, end step) pairs; the end is exclusive.
        self.phases = {
            'Ignition': (0, 750),
            'Pressurization': (750, 2000),
            'Cruising': (2000, 5000)
        }

        # Pre-calc global averages if not present
        for metric in ('sigma_a', 'eff_rank'):
            avg_col = f"{metric}_avg"
            if avg_col not in self.df.columns:
                self.df[avg_col] = self.df[self._layer_columns(metric)].mean(axis=1)

        print(f"‚úÖ Loaded {len(self.df)} steps of telemetry.")

    def _layer_columns(self, metric):
        """Return the ``L<i>_<metric>`` column names, ordered by layer index.

        Only exact matches are accepted, so derived columns such as
        ``<metric>_avg`` or unrelated names containing an 'L' are ignored.
        """
        suffix = f"_{metric}"
        found = []
        for col in self.df.columns:
            col = str(col)
            if col.startswith('L') and col.endswith(suffix):
                layer_idx = col[1:-len(suffix)]
                if layer_idx.isdigit():
                    found.append((int(layer_idx), col))
        return [col for _, col in sorted(found)]

    def run_full_analysis(self):
        """Write every report page, in order, to ``self.pdf_path``."""
        with PdfPages(self.pdf_path) as pdf:
            # 1. Title Page & Global Vitals
            self._plot_global_vitals(pdf)

            # 2. Phase Correlation Matrices
            self._plot_phase_correlations(pdf)

            # 3. Layer Tomography (The MRI)
            self._plot_layer_tomography(pdf, metric='sigma_a', title="Uniqueness (Sigma_A)")
            self._plot_layer_tomography(pdf, metric='eff_rank', title="Dimensionality (Eff Rank)")

            # 4. Elasticity (Pressure Response)
            self._plot_elasticity(pdf)

            # 5. The Business Metric (Perplexity vs Pressure)
            self._plot_perplexity_pressure(pdf)

        print(f"\n‚ú® Analysis Complete. Report saved to: {self.pdf_path}")

    def _plot_global_vitals(self, pdf):
        """Page 1: The Heartbeat (Loss, Pressure, Redundancy)"""
        fig, axes = plt.subplots(3, 1, figsize=(10, 12), sharex=True)

        # Plot Loss
        sns.lineplot(data=self.df, x='step', y='loss', ax=axes[0], color='tab:red', label='Train Loss')
        if 'val_loss' in self.df.columns:
            sns.lineplot(data=self.df, x='step', y='val_loss', ax=axes[0], color='tab:orange', label='Val Loss', linestyle='--')
        axes[0].set_title('Global Loss Trajectory')
        axes[0].set_ylabel('Cross Entropy')
        axes[0].legend()

        # Plot Pressure
        sns.lineplot(data=self.df, x='step', y='pressure', ax=axes[1], color='tab:green', linewidth=2)
        axes[1].set_title('Diversity Pressure (Lambda)')
        axes[1].set_ylabel('Force')

        # Plot the raw Sigma_A average
        sns.lineplot(data=self.df, x='step', y='sigma_a_avg', ax=axes[2], color='tab:blue')
        axes[2].set_title('Average Head Uniqueness (Sigma_A)')
        axes[2].set_ylabel('Orthogonality Score')

        # Mark Phases; an end line is drawn only if the data extends past it.
        last_step = self.df['step'].max()
        for ax in axes:
            for start, end in self.phases.values():
                ax.axvline(x=start, color='gray', linestyle=':', alpha=0.5)
                if end < last_step:
                    ax.axvline(x=end, color='gray', linestyle=':', alpha=0.5)

        plt.tight_layout()
        pdf.savefig(fig)
        plt.close(fig)

    def _plot_phase_correlations(self, pdf):
        """Page 2: Correlation Matrices by Phase"""
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))

        cols = ['loss', 'val_loss', 'pressure', 'sigma_a_avg', 'eff_rank_avg']
        # Filter cols that exist
        cols = [c for c in cols if c in self.df.columns]

        for i, (phase_name, (start, end)) in enumerate(self.phases.items()):
            mask = (self.df['step'] >= start) & (self.df['step'] < end)
            subset = self.df.loc[mask, cols]

            # Phases with 10 or fewer logged rows are reported as insufficient.
            if len(subset) > 10:
                corr = subset.corr()
                sns.heatmap(corr, annot=True, fmt=".2f", cmap='coolwarm', vmin=-1, vmax=1, ax=axes[i], cbar=False)
                axes[i].set_title(f"Phase: {phase_name}")
            else:
                axes[i].text(0.5, 0.5, "Insufficient Data", ha='center')

        plt.suptitle("Phase Transition Correlations")
        plt.tight_layout()
        pdf.savefig(fig)
        plt.close(fig)

    def _plot_layer_tomography(self, pdf, metric, title):
        """Page 3/4: The Layer-wise MRI (layers on the y axis, time on the x axis)."""
        l_cols = self._layer_columns(metric)
        if not l_cols: return

        # Downsample time to roughly 100 columns so the heatmap stays legible.
        sample_rate = max(1, len(self.df) // 100)
        subset = self.df.iloc[::sample_rate].copy()

        heatmap_data = subset[l_cols].T # Rows=Layers, Cols=Time

        fig, ax = plt.subplots(figsize=(12, 6))
        sns.heatmap(heatmap_data, cmap='viridis', ax=ax, cbar_kws={'label': metric})

        ax.set_title(f"Layer Tomography: {title}")
        # Derive the depth range from the data instead of assuming 12 layers.
        ax.set_ylabel(f"Network Depth (Layer 0 -> {len(l_cols) - 1})")
        ax.set_xlabel("Training Time (Steps)")

        # Fix X-axis labels to show steps; the final tick falls one past the
        # last column, so it is labelled with the last row's step.
        xticks = np.linspace(0, len(subset), 10)
        xlabels = [int(subset.iloc[int(i) if i < len(subset) else -1]['step']) for i in xticks]
        ax.set_xticks(xticks)
        ax.set_xticklabels(xlabels)

        plt.tight_layout()
        pdf.savefig(fig)
        plt.close(fig)

    def _plot_elasticity(self, pdf):
        """Page 5: The Modulus of Elasticity (dUniqueness / dPressure)"""
        # Rolling correlation between Pressure and Sigma_A. The window counts
        # ROWS, not steps: a fixed 200-row window produces only NaN on tables
        # with fewer than 200 rows, so it is capped at a quarter of the table
        # (minimum 3 rows).
        window = min(200, max(3, len(self.df) // 4))

        fig, ax1 = plt.subplots(figsize=(10, 6))

        rolling_corr = self.df['pressure'].rolling(window).corr(self.df['sigma_a_avg'])

        sns.lineplot(x=self.df['step'], y=rolling_corr, color='purple', ax=ax1, linewidth=2)
        ax1.set_title("Elasticity: Correlation(Pressure, Uniqueness) over Time")
        ax1.set_ylabel("Correlation Coefficient (Pearson)")
        ax1.axhline(0, color='black', linewidth=1)
        ax1.axhline(-1, color='red', linestyle='--', alpha=0.3)

        # Annotate
        mid_step = self.df['step'].max()*0.5
        ax1.text(mid_step, -0.8, "Strong Response (Elastic)", color='red', ha='center')
        ax1.text(mid_step, 0.2, "No Response (Plastic/Collapsed)", color='gray', ha='center')

        plt.tight_layout()
        pdf.savefig(fig)
        plt.close(fig)

    def _plot_perplexity_pressure(self, pdf):
        """Page 6: The Golden Cross (skipped when there is no perplexity column)."""
        if 'perplexity' not in self.df.columns: return

        fig, ax1 = plt.subplots(figsize=(10, 6))

        # Plot Perplexity
        color = 'tab:orange'
        ax1.set_xlabel('Step')
        ax1.set_ylabel('Perplexity (Lower is Better)', color=color)
        sns.lineplot(data=self.df, x='step', y='perplexity', ax=ax1, color=color)
        ax1.tick_params(axis='y', labelcolor=color)
        ax1.set_ylim(bottom=0)

        # Twin axis for Pressure
        ax2 = ax1.twinx()
        color = 'tab:green'
        ax2.set_ylabel('Diversity Pressure', color=color)
        sns.lineplot(data=self.df, x='step', y='pressure', ax=ax2, color=color, alpha=0.3, linestyle='--')
        ax2.tick_params(axis='y', labelcolor=color)

        plt.title("The 'Golden Cross': Perplexity vs. Pressure")
        plt.tight_layout()
        pdf.savefig(fig)
        plt.close(fig)

if __name__ == "__main__":
    # Guard first: without the telemetry file there is nothing to analyse.
    if not os.path.exists(TELEMETRY_PATH):
        print(f"‚ùå Telemetry file not found at: {TELEMETRY_PATH}")
        print("Did you run the Hero script to completion?")
    else:
        analyst = JanusAutopsy(TELEMETRY_PATH)
        analyst.run_full_analysis()

In [None]:
# @title
import sys
import os
import pandas as pd
import numpy as np
from scipy.stats import pearsonr
from google.colab import drive

# --- SETUP ---
# Mount Google Drive only when the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Drive locations: the telemetry parquet that is read, and the markdown
# report that is written.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
TELEMETRY_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_hero_v2/telemetry_hero.parquet")
REPORT_PATH = os.path.join(PROJECT_ROOT, "reports/janus_final_report.md")

print(f"üìù JANUS WRITER: Generating Text Report")
print(f"üìÇ Source: {TELEMETRY_PATH}")

class ForensicWriter:
    """Turn a telemetry parquet file into a markdown forensic report.

    The telemetry is loaded into ``self.df``.  When ``sigma_a_avg`` or
    ``eff_rank_avg`` is absent it is derived as the row-wise mean of the
    per-layer columns (or 0.0 when there are none).  Rows are then bucketed
    by ``step`` into three phases stored in ``self.phases``.
    """

    def __init__(self, filepath):
        """Load telemetry from ``filepath`` and prepare derived columns and phases."""
        self.df = pd.read_parquet(filepath)

        # Calculate derived columns if missing
        self._ensure_layer_average('sigma_a')
        self._ensure_layer_average('eff_rank')

        # Define Phases
        steps = self.df['step']
        self.phases = {
            'Ignition': self.df[steps < 750],
            'Pressurization': self.df[(steps >= 750) & (steps < 2000)],
            'Cruising': self.df[steps >= 2000]
        }

    def _ensure_layer_average(self, metric):
        """Add ``<metric>_avg`` as the mean of the per-layer columns (0.0 if none exist)."""
        avg_col = f"{metric}_avg"
        if avg_col in self.df.columns:
            return
        layer_cols = [c for c in self.df.columns if metric in c and 'L' in c]
        if layer_cols:
            self.df[avg_col] = self.df[layer_cols].mean(axis=1)
        else:
            self.df[avg_col] = 0.0  # Fallback

    def generate_report(self, report_path=None):
        """Write the markdown report, then echo it to the console.

        Args:
            report_path: Destination file.  When omitted, the module-level
                ``REPORT_PATH`` is used.
        """
        path = REPORT_PATH if report_path is None else report_path

        # open() does not create missing folders, so make sure the target
        # directory exists before writing.
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

        last_row = self.df.iloc[-1]

        # The report contains non-ASCII glyphs; pin the encoding so the file
        # round-trips regardless of the platform default.
        with open(path, 'w', encoding='utf-8') as f:
            # HEADER
            f.write("# üèõÔ∏è JANUS-HERO v2: Forensic Analysis Report\n\n")
            f.write(f"**Status:** {'‚úÖ COMPLETE' if self.df['step'].max() >= 4900 else '‚ö†Ô∏è INCOMPLETE'}\n")
            f.write(f"**Total Steps:** {self.df['step'].max()}\n")
            f.write(f"**Checkpoints Logged:** {len(self.df)}\n\n")

            # 1. EXECUTIVE SUMMARY
            f.write("## 1. Executive Summary\n")
            start_loss = self.df.iloc[0]['loss']
            end_loss = last_row['loss']
            # A relative change is undefined when the starting loss is zero.
            if start_loss != 0:
                loss_delta = ((end_loss - start_loss) / start_loss) * 100
                delta_text = f"{loss_delta:.1f}%"
            else:
                delta_text = "n/a"

            final_perp = last_row.get('perplexity', 0.0)

            f.write(f"The model completed the 3-stage burn protocol. ")
            f.write(f"Training Loss moved from **{start_loss:.3f}** to **{end_loss:.3f}** ({delta_text}). ")
            # NaN compares False here, so missing values also take the fallback.
            if final_perp > 0:
                f.write(f"Final Validation Perplexity settled at **{final_perp:.2f}**.\n\n")
            else:
                f.write("Perplexity data unavailable.\n\n")

            # 2. PHASE ANALYSIS
            f.write("## 2. Phase Analysis\n")

            # --- IGNITION ---
            ign = self.phases['Ignition']
            if not ign.empty:
                f.write("### üî• Phase 1: Ignition (Steps 0-750)\n")
                f.write("*Goal: Natural Feature Formation (Zero Pressure)*\n")
                f.write(f"- **Avg Loss:** {ign['loss'].mean():.3f}\n")
                f.write(f"- **Avg Redundancy (Sigma_A):** {ign['sigma_a_avg'].mean():.3f}\n\n")

            # --- PRESSURIZATION ---
            press = self.phases['Pressurization']
            if not press.empty and len(press) > 2:
                f.write("### üèãÔ∏è Phase 2: Pressurization (Steps 750-2000)\n")
                f.write("*Goal: Forced Orthogonality (Ramping Pressure)*\n")

                # Elasticity Calculation (Corr between Pressure and Sigma_A)
                # We want Negative correlation (Pressure UP -> Redundancy DOWN)
                corr, _ = pearsonr(press['pressure'], press['sigma_a_avg'])
                if pd.isna(corr):
                    # e.g. a constant pressure column: the correlation is
                    # undefined, so do not label it as a resistant model.
                    elasticity = "Undefined (correlation not computable)"
                elif corr < -0.5:
                    elasticity = "Elastic (Responsive)"
                else:
                    elasticity = "Plastic (Resistant)"

                f.write(f"- **Elasticity Coefficient:** {corr:.3f} ({elasticity})\n")
                f.write(f"- **Pressure Delta:** {press['pressure'].min():.2f} -> {press['pressure'].max():.2f}\n")
                f.write(f"- **Redundancy Response:** {press.iloc[0]['sigma_a_avg']:.3f} -> {press.iloc[-1]['sigma_a_avg']:.3f}\n\n")

            # --- CRUISING ---
            cruise = self.phases['Cruising']
            if not cruise.empty:
                f.write("### ‚úàÔ∏è Phase 3: Cruising (Steps 2000+)\n")
                f.write("*Goal: High-Efficiency Convergence*\n")

                start_p = cruise.iloc[0]['loss']
                end_p = cruise.iloc[-1]['loss']
                stability = "Stable" if abs(start_p - end_p) < 0.5 else "Volatile"

                f.write(f"- **Stability:** {stability}\n")
                f.write(f"- **Final Effective Rank:** {cruise.iloc[-1]['eff_rank_avg']:.2f} (Target: >4.0)\n")
                f.write(f"- **Final Uniqueness:** {cruise.iloc[-1]['sigma_a_avg']:.3f}\n\n")

            # 3. ANOMALY DETECTION
            f.write("## 3. Anomalies & Warnings\n")
            # Check for Loss Spikes: a rise of more than 0.5 between consecutive rows
            spikes = self.df[self.df['loss'].diff() > 0.5]
            if not spikes.empty:
                f.write(f"‚ö†Ô∏è **Loss Spikes Detected:** Found {len(spikes)} events where loss jumped > 0.5.\n")
                for _, row in spikes.iterrows():
                    f.write(f"- Step {int(row['step'])}: Loss {row['loss']:.3f}\n")
            else:
                f.write("‚úÖ **Trajectory Clean:** No significant loss spikes detected.\n")

            # Check for Collapse
            final_redundancy = last_row['sigma_a_avg']
            if final_redundancy > 0.8:
                f.write("‚ö†Ô∏è **CRITICAL WARNING:** High Redundancy (>0.80) detected at end of run. Possible Mode Collapse.\n")
            elif final_redundancy < 0.1:
                f.write("‚ö†Ô∏è **WARNING:** Ultra-low Redundancy (<0.10). Model may be over-regularized (incoherent).\n")
            else:
                f.write("‚úÖ **Homeostasis Achieved:** Redundancy is within healthy parameters (0.10 - 0.80).\n")

        print(f"‚úÖ Report compiled: {path}")
        # Print preview to console
        with open(path, 'r', encoding='utf-8') as f:
            print("\n" + "="*40)
            print(f.read())
            print("="*40)

# Entry point: build the report only when the telemetry parquet is present.
if __name__ == "__main__":
    if os.path.exists(TELEMETRY_PATH):
        writer = ForensicWriter(TELEMETRY_PATH)
        writer.generate_report()
    else:
        print("‚ùå Telemetry file missing.")

In [None]:
# @title
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive only when the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# PATHS
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# Checkpoint whose weights are loaded into the reconstructed model below.
MODEL_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_hero_v2/janus_hero_final.pt")
# We need the tokenizer data to decode
# NOTE(review): TOKEN_BIN is not referenced again in the visible part of this cell — confirm it is still needed.
TOKEN_BIN = os.path.join(PROJECT_ROOT, "data/processed/TinyStories-train_full.bin")

# Prefer the GPU when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üó£Ô∏è JANUS INFERENCE: The Moment of Truth")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 2. RECONSTRUCT ARCHITECTURE (CLEAN ROOM) ---
# Must match the training config exactly
class HeroConfig:
    """Fixed hyper-parameters used to rebuild the model for inference."""

    def __init__(self):
        # Attention geometry: 8 heads of width 64 give the 512-wide residual stream.
        self.n_heads = 8
        self.d_head = 64
        self.d_model = self.n_heads * self.d_head

        # Depth, vocabulary and context length
        self.n_layers = 12
        self.vocab_size = 50304
        self.max_seq_len = 512

        # No dropout needed for inference
        self.dropout = 0.0

class CleanAttention(nn.Module):
    """Causal multi-head self-attention with bias-free projections.

    Inference-only variant: no dropout and no telemetry; ``forward`` returns
    just the projected attention output.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)
        # Creation order is kept (q, k, v, o) so seeded initialisation is unchanged.
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)

    def _split_heads(self, t):
        """Reshape (B, S, D) into (B, n_heads, S, d_head)."""
        batch, seq_len, _ = t.shape
        return t.view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x):
        batch, seq_len, width = x.shape
        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(self.k_proj(x))
        v = self._split_heads(self.v_proj(x))

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Lower-triangular mask: a position may only attend to itself and earlier positions.
        causal = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).view(1, 1, seq_len, seq_len)
        scores = scores.masked_fill(causal == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq_len, width)
        return self.o_proj(merged)

class CleanBlock(nn.Module):
    """One transformer layer: LayerNorm -> attention and LayerNorm -> MLP, each added back residually."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        hidden = 4 * width
        self.ln1 = nn.LayerNorm(width)
        self.attn = CleanAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
        )

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))

class CleanGPT(nn.Module):
    """Decoder-only language model: token + learned positional embeddings,
    a stack of ``CleanBlock`` layers, a final LayerNorm and a linear head."""

    def __init__(self, config):
        super().__init__()
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.blocks = nn.ModuleList([CleanBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, idx):
        """Return next-token logits of shape (B, S, vocab_size) for token ids ``idx`` of shape (B, S)."""
        B, S = idx.shape
        x = self.token_emb(idx) + self.pos_emb(torch.arange(S, device=idx.device))
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.7, top_k=50):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Args:
            idx: Prompt token ids of shape (B, S).
            max_new_tokens: Number of tokens to sample.
            temperature: Divisor applied to the last-position logits.
            top_k: Keep only the ``top_k`` largest logits before sampling;
                ``None`` disables the filter.

        Returns:
            Tensor of shape (B, S + max_new_tokens).
        """
        # The usable context is bounded by the positional table, so crop to
        # its size rather than to a fixed number that may exceed it.
        context_len = self.pos_emb.num_embeddings
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -context_len:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Everything below the k-th largest logit gets zero probability.
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

# --- 3. LOAD MODEL & TOKENIZER ---
# Quick Tokenizer Hack: We use PreTrainedTokenizerFast if available, or fallback
# Outcome: `tokenizer` is either the custom tokenizer stored under TOK_PATH or
# the stock GPT-2 tokenizer; the script exits if `transformers` cannot be imported.
try:
    from transformers import PreTrainedTokenizerFast
    # Assuming you saved the tokenizer in processed/hero_mix or similar
    # If not, we can rely on standard GPT2 encoding if vocab matches
    # Let's try to load the one you built
    TOK_PATH = os.path.join(PROJECT_ROOT, "data/processed/hero_mix")
    if os.path.exists(TOK_PATH):
        tokenizer = PreTrainedTokenizerFast.from_pretrained(TOK_PATH)
        print("‚úÖ Custom Tokenizer Loaded")
    else:
        print("‚ö†Ô∏è Custom Tokenizer not found. Falling back to GPT2 (Warning: ID mismatch likely)")
        from transformers import GPT2Tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
except ImportError:
    # Nothing can be encoded or decoded without a tokenizer, so stop here.
    print("‚ùå Transformers lib not installed. Install it!")
    sys.exit(1)

# Load Weights
# Build the architecture first, then fill it from the checkpoint on disk.
cfg = HeroConfig()
model = CleanGPT(cfg).to(DEVICE)

if os.path.exists(MODEL_PATH):
    print(f"üîÑ Loading Weights from {MODEL_PATH}")
    # Load state dict
    try:
        # It might be saved as a dict with 'model' key
        # NOTE(review): torch.load unpickles the file; only load checkpoints you trust
        # (consider weights_only=True if the installed torch version supports it).
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        if 'model' in checkpoint:
            model.load_state_dict(checkpoint['model'])
        else:
            model.load_state_dict(checkpoint)
        print("‚úÖ Weights Loaded Successfully")
    except Exception as e:
        # Any failure here (unreadable file, mismatched keys/shapes) aborts the script.
        print(f"‚ùå Weight Load Failed: {e}")
        sys.exit(1)
else:
    print("‚ùå Model file not found.")
    sys.exit(1)

# Switch to evaluation mode before sampling.
model.eval()

# --- 4. THE PROBE ---
# Fixed prompts used to sample continuations from the loaded model.
prompts = [
    "Once upon a time",
    "Lily wanted a",
    "The big red ball",
    "Tom went to the",
    "One day, a little"
]

print("\n" + "="*40)
print("üß™ JANUS HERO V2 OUTPUTS")
print("="*40)

# Each prompt is handled independently: a failure is printed and the loop moves on.
for p in prompts:
    print(f"\nüìù PROMPT: {p}")
    try:
        # Encode
        input_ids = tokenizer.encode(p, return_tensors='pt').to(DEVICE)

        # Generate
        output_ids = model.generate(input_ids, max_new_tokens=100, temperature=0.6, top_k=40)

        # Decode
        # The returned ids include the prompt, so the printed text starts with it.
        text = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)
        print(f"ü§ñ JANUS: {text}")
    except Exception as e:
        print(f"‚ùå Generation Error: {e}")

In [None]:
# @title
from transformers import GPT2Tokenizer

# 1. Load the Standard GPT-2 Tokenizer (The one likely used for the .bin file)
# This rebinds the global `tokenizer` name.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# 2. Re-Run Inference
# Uses `model` and `DEVICE`, which this cell does not define itself; they must
# already exist in the session.
print(f"\n‚ú® RE-TESTING WITH GPT-2 TOKENIZER ‚ú®")
prompts = ["Once upon a time", "Lily wanted a", "The big red ball"]

for p in prompts:
    input_ids = tokenizer.encode(p, return_tensors='pt').to(DEVICE)
    output = model.generate(input_ids, max_new_tokens=100, temperature=0.6, top_k=40)
    text = tokenizer.decode(output[0], skip_special_tokens=True)
    print(f"\nüìù {p} -> {text}")

In [None]:
# @title
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import pandas as pd
import gc
import time
from tqdm import tqdm
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive only when the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# Token file that SplitLoader memory-maps as uint16.
DATA_FILE = os.path.join(PROJECT_ROOT, "data/processed/TinyStories-train_full.bin")
# New Save Dir for Baseline
SAVE_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_baseline")
os.makedirs(SAVE_DIR, exist_ok=True)

# Prefer the GPU when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üìâ JANUS BASELINE: The Control Group")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 2. CONFIGURATION (CONTROL) ---
class BaselineConfig:
    """Hyper-parameters for the control run: both pressure coefficients are zero."""

    def __init__(self):
        # --- Architecture ---
        self.n_heads = 8
        self.d_head = 64
        self.d_model = self.n_heads * self.d_head
        self.n_layers = 12
        self.vocab_size = 50304
        self.max_seq_len = 512
        self.dropout = 0.05

        # --- Pressure (the experimental variable): switched off ---
        self.max_lambda_div = 0.0
        self.max_lambda_coh = 0.0

        # --- Training schedule ---
        self.max_steps = 5000
        self.batch_size = 32
        self.grad_accum = 2

        # --- Validation cadence ---
        self.val_interval = 250
        self.val_steps = 50

# --- 3. THE ENGINE (NO STEERING) ---
class CleanAttention(nn.Module):
    """Causal multi-head self-attention with optional per-head telemetry.

    No steering loss is applied in this variant; ``forward`` always returns
    ``0.0`` in the loss slot.  When ``return_metrics`` is true, three
    per-head diagnostics are computed without gradients:

    * ``sigma_p``  - 1 minus the normalised entropy of the attention rows.
    * ``sigma_a``  - mean absolute cosine similarity to the other heads.
    * ``eff_rank`` - exp(entropy) of the normalised singular values.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, return_metrics=False):
        """Return ``(output, steer_loss, metrics)`` for input ``x`` of shape (B, S, D).

        ``metrics`` is an empty dict unless ``return_metrics`` is true.
        """
        B, S, D = x.shape

        q = self.q_proj(x).view(B, S, self.n_heads, self.d_head).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.n_heads, self.d_head).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.n_heads, self.d_head).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        mask = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        attn = attn.masked_fill(mask == 0, float('-inf'))
        attn_probs = F.softmax(attn, dim=-1)
        attn_probs = self.dropout(attn_probs)

        head_out = attn_probs @ v

        # NO STEERING LOSS HERE
        steer_loss = 0.0
        metrics = {}

        # Still log telemetry to prove it's bad
        if return_metrics:
            with torch.no_grad():
                metrics = self._telemetry(attn_probs, head_out)

        out = head_out.transpose(1, 2).reshape(B, S, D)
        return self.o_proj(out), steer_loss, metrics

    def _telemetry(self, attn_probs, head_out):
        """Compute the per-head ``sigma_p`` / ``sigma_a`` / ``eff_rank`` tensors."""
        device = head_out.device
        S = attn_probs.size(-1)

        # Sigma_P (Focus)
        entropy = -torch.sum(attn_probs * torch.log(attn_probs + 1e-9), dim=-1)
        # log(1) == 0 would turn the normalisation into 0/0 for a single-token
        # sequence; use a unit divisor there (entropy is 0, so focus is 1).
        max_ent = math.log(S) if S > 1 else 1.0
        sigma_p = (1.0 - (entropy / max_ent)).mean(dim=[0, 2])

        # Sigma_A (Uniqueness)
        flat = head_out.transpose(0, 1).reshape(self.n_heads, -1)
        norm = F.normalize(flat, p=2, dim=1)
        sim = torch.mm(norm, norm.t())
        mask_diag = ~torch.eye(self.n_heads, dtype=torch.bool, device=device)
        # With a single head there are no "other" heads; avoid dividing by zero.
        sigma_a = (sim.abs() * mask_diag.float()).sum(dim=1) / max(self.n_heads - 1, 1)

        # Eff_Rank
        # NOTE(review): after transpose(1, 2) the tensor is laid out as
        # (B, S', H, d_head), so reshaping to (H, -1, d_head) does not group
        # rows by head, unlike the Sigma_A reshape above.  Left unchanged so
        # values stay comparable with telemetry already recorded; confirm
        # whether transpose(0, 1) was intended.
        sub_out = head_out[:, :, :128, :].transpose(1, 2).reshape(self.n_heads, -1, self.d_head)
        ranks = []
        for h in range(self.n_heads):
            try:
                S_vals = torch.linalg.svdvals(sub_out[h].float())
                p = S_vals / S_vals.sum()
                ent = -torch.sum(p * torch.log(p + 1e-9))
                ranks.append(torch.exp(ent))
            except RuntimeError:
                # torch.linalg failures are RuntimeError subclasses.  Keep the
                # placeholder on the same device so torch.stack never sees a
                # CPU/GPU mix when only some heads fail.
                ranks.append(torch.zeros((), device=device))
        eff_rank = torch.stack(ranks).to(device)

        return {'sigma_p': sigma_p, 'sigma_a': sigma_a, 'eff_rank': eff_rank}

class CleanBlock(nn.Module):
    """Transformer layer that forwards the attention module's loss and metrics untouched."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        hidden = 4 * width
        self.ln1 = nn.LayerNorm(width)
        self.attn = CleanAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(config.dropout),
        )

    def forward(self, x, return_metrics=False):
        attn_out, steer, metrics = self.attn(self.ln1(x), return_metrics)
        residual = x + attn_out
        residual = residual + self.mlp(self.ln2(residual))
        return residual, steer, metrics

class CleanGPT(nn.Module):
    """Baseline language model; ``forward`` returns ``(loss, 0.0, per_layer_metrics)``."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.token_emb = nn.Embedding(config.vocab_size, width)
        self.pos_emb = nn.Embedding(config.max_seq_len, width)
        self.blocks = nn.ModuleList([CleanBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(width)
        self.head = nn.Linear(width, config.vocab_size, bias=False)

    def forward(self, idx, targets=None, return_metrics=False):
        """Run the stack on token ids ``idx`` of shape (B, S).

        Returns a 3-tuple: the cross-entropy loss (``None`` when ``targets``
        is omitted), a constant ``0.0`` in the steering-loss slot, and a list
        holding one metrics dict per layer when ``return_metrics`` is true
        (an empty list otherwise).
        """
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device)
        hidden = self.token_emb(idx) + self.pos_emb(positions)

        layer_metrics = []
        for layer in self.blocks:
            hidden, _steer, m = layer(hidden, return_metrics)
            if return_metrics:
                layer_metrics.append(m)

        logits = self.head(self.ln_f(hidden))

        if targets is None:
            loss = None
        else:
            # Flatten batch and time so every position is one classification example.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return loss, 0.0, layer_metrics

# --- 4. DATA & LOGGING ---
class SplitLoader:
    """Random-window batch sampler over a uint16 token file, split 95% train / 5% validation."""

    def __init__(self, bin_path, block_size, batch_size):
        self.data = np.memmap(bin_path, dtype=np.uint16, mode='r')
        self.block_size = block_size
        self.batch_size = batch_size

        # The first 95% of the tokens is the training split; the tail is validation.
        boundary = int(len(self.data) * 0.95)
        self.train_data = self.data[:boundary]
        self.val_data = self.data[boundary:]
        print(f"üì¶ Split: Train {len(self.train_data):,} | Val {len(self.val_data):,}")

    def get_batch(self, split='train'):
        """Return ``(x, y)`` on ``DEVICE``; ``y`` is ``x`` shifted one token to the right.

        Any ``split`` other than ``'train'`` samples from the validation data.
        """
        source = self.val_data if split != 'train' else self.train_data
        width = self.block_size
        starts = torch.randint(len(source) - width, (self.batch_size,)).tolist()

        def windows(offset):
            # One int64 row per sampled start position.
            rows = [
                torch.from_numpy(source[s + offset:s + offset + width].astype(np.int64))
                for s in starts
            ]
            return torch.stack(rows)

        x = windows(0)
        y = windows(1)
        return x.to(DEVICE), y.to(DEVICE)

class BlackBox:
    """Buffers telemetry rows in memory and appends them to a parquet file on ``flush``."""

    def __init__(self, save_dir):
        self.save_dir = save_dir
        self.buffer = []

    def log(self, step, loss, val_loss, metrics_list):
        """Queue one row: scalar losses plus the per-layer mean of every metric.

        Layers whose metrics entry is empty are skipped.  Perplexity is
        recorded as 0.0 unless ``val_loss`` is below 20.
        """
        perplexity = 0.0
        if val_loss < 20:
            perplexity = math.exp(val_loss)

        row = {
            "step": step,
            "loss": loss,
            "val_loss": val_loss,
            "perplexity": perplexity,
            "pressure": 0.0,  # Baseline has no pressure
        }
        for layer_idx, layer_m in enumerate(metrics_list):
            if layer_m:
                for key in ('sigma_p', 'sigma_a', 'eff_rank'):
                    row[f"L{layer_idx}_{key}"] = layer_m[key].mean().item()
        self.buffer.append(row)

    def flush(self):
        """Write buffered rows after any rows already in the parquet file, then clear the buffer."""
        if not self.buffer:
            return
        frame = pd.DataFrame(self.buffer)
        target = os.path.join(self.save_dir, "telemetry_baseline.parquet")
        if os.path.exists(target):
            frame = pd.concat([pd.read_parquet(target), frame])
        frame.to_parquet(target)
        self.buffer = []

# --- 5. MAIN ---
def run_baseline():
    """Train the zero-pressure control model and record telemetry.

    Each optimizer step accumulates gradients over ``cfg.grad_accum``
    micro-batches.  Every ``cfg.val_interval`` steps (and on the last step)
    the model is evaluated on ``cfg.val_steps`` validation batches and one
    telemetry row is logged and flushed to disk.
    """
    cfg = BaselineConfig()
    loader = SplitLoader(DATA_FILE, cfg.max_seq_len, cfg.batch_size)
    model = CleanGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6e-4, weight_decay=1e-4)
    recorder = BlackBox(SAVE_DIR)

    print(f"\nüìâ STARTING BASELINE RUN: 0 -> {cfg.max_steps}")
    pbar = tqdm(range(cfg.max_steps))

    for step in pbar:
        model.train()
        batch_loss = 0.0
        metrics = []

        log_step = step % cfg.val_interval == 0 or step == cfg.max_steps - 1

        # Clear gradients once per optimizer step.  Doing this inside the
        # micro-batch loop would throw away every micro-batch but the last.
        optimizer.zero_grad()

        for micro in range(cfg.grad_accum):
            x, y = loader.get_batch('train')
            # Telemetry is only consumed on logging steps, so skip the
            # per-head SVDs everywhere else; take it from the last micro-batch.
            do_metrics = log_step and (micro == cfg.grad_accum - 1)
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                loss, _, metrics = model(x, y, return_metrics=do_metrics)
                # Scale so the accumulated gradient is the mean over micro-batches.
                total = loss / cfg.grad_accum
            total.backward()
            batch_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        train_loss = batch_loss / cfg.grad_accum

        if log_step:
            model.eval()
            with torch.no_grad():
                v_losses = []
                for _ in range(cfg.val_steps):
                    vx, vy = loader.get_batch('val')
                    vl, _, _ = model(vx, vy)
                    v_losses.append(vl.item())
                val_loss = np.mean(v_losses)

            recorder.log(step, train_loss, val_loss, metrics)
            recorder.flush()
            pbar.set_description(f"L:{train_loss:.3f}|V:{val_loss:.3f}")
        else:
            pbar.set_description(f"L:{train_loss:.3f}")

    print("\n‚úÖ BASELINE COMPLETE.")

# Entry point: start the baseline training run.
if __name__ == "__main__":
    run_baseline()

In [None]:
# @title
import sys
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from google.colab import drive

# --- SETUP ---
# Mount Google Drive only when the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Telemetry of the two runs being compared, and the folder the chart is saved to.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
HERO_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_hero_v2/telemetry_hero.parquet")
BASE_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_baseline/telemetry_baseline.parquet")
REPORT_DIR = os.path.join(PROJECT_ROOT, "reports")
os.makedirs(REPORT_DIR, exist_ok=True)

print("‚öîÔ∏è JANUS DUEL: Hero vs. Baseline")

def load_data(path, label):
    """Read one telemetry parquet file and tag every row with ``label``.

    Returns ``None`` (after printing a notice) when ``path`` does not exist.
    Otherwise returns the DataFrame with a ``Model`` column set to ``label``
    and, where absent, ``sigma_a_avg`` / ``eff_rank_avg`` columns averaged
    over the matching per-layer columns.  An average column is not created
    when no per-layer column matches.
    """
    if not os.path.exists(path):
        print(f"‚ùå Missing: {path}")
        return None

    frame = pd.read_parquet(path)
    frame['Model'] = label

    # Calc averages if missing
    for metric in ('sigma_a', 'eff_rank'):
        avg_col = f"{metric}_avg"
        if avg_col in frame.columns:
            continue
        layer_cols = [c for c in frame.columns if metric in c and 'L' in c]
        if layer_cols:
            frame[avg_col] = frame[layer_cols].mean(axis=1)

    return frame

def run_comparison():
    """Plot hero vs. baseline telemetry in three stacked panels and print final-row stats.

    Does nothing further when either telemetry file could not be loaded.
    The chart is saved as ``janus_duel_comparison.png`` inside ``REPORT_DIR``.
    """
    hero = load_data(HERO_PATH, "Hero (Pressure)")
    base = load_data(BASE_PATH, "Baseline (Control)")
    if hero is None or base is None:
        return

    # Combine
    combined = pd.concat([hero, base])

    # PLOTTING: (column, title, y-label) for each panel, top to bottom.
    # NOTE(review): another copy of this function in the file labels the rank
    # axis "Rank (Max 8.0)" instead of "Rank (Max 64)" — confirm which is right.
    panels = (
        ('val_loss', "Validation Loss (Performance)", "Cross Entropy"),
        ('sigma_a_avg', "Head Redundancy (Sigma_A)", "Correlation (Lower = Better)"),
        ('eff_rank_avg', "Effective Rank (Dimensional Usage)", "Rank (Max 64)"),
    )
    fig, axes = plt.subplots(3, 1, figsize=(10, 15), sharex=True)
    for ax, (column, title, y_label) in zip(axes, panels):
        sns.lineplot(data=combined, x='step', y=column, hue='Model', ax=ax, palette=['tab:green', 'tab:gray'])
        ax.set_title(title)
        ax.set_ylabel(y_label)
        ax.grid(True, alpha=0.3)

    save_path = os.path.join(REPORT_DIR, "janus_duel_comparison.png")
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"‚úÖ Comparison Chart Saved: {save_path}")

    # PRINT STATS (taken from the last row of each telemetry table)
    print("\n--- FINAL STATS (Step 5000) ---")
    hero_last = hero.iloc[-1]
    base_last = base.iloc[-1]

    print(f"LOSS:       Hero {hero_last['val_loss']:.4f} vs Base {base_last['val_loss']:.4f}")
    print(f"REDUNDANCY: Hero {hero_last['sigma_a_avg']:.4f} vs Base {base_last['sigma_a_avg']:.4f}")
    print(f"RANK:       Hero {hero_last['eff_rank_avg']:.2f}   vs Base {base_last['eff_rank_avg']:.2f}")

# Entry point: build the comparison chart and print the final stats.
if __name__ == "__main__":
    run_comparison()

In [None]:
# @title
import sys
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from google.colab import drive

# --- SETUP ---
# Mount Google Drive only when the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Inputs: one telemetry parquet per run.  Output: charts under REPORT_DIR.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
HERO_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_hero_v2/telemetry_hero.parquet")
BASE_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_baseline/telemetry_baseline.parquet")
REPORT_DIR = os.path.join(PROJECT_ROOT, "reports")
os.makedirs(REPORT_DIR, exist_ok=True)

print("‚öîÔ∏è JANUS DUEL: Hero vs. Baseline (Telemetry Analysis)")

def load_data(path, label):
    """Load a telemetry table from ``path`` and mark its rows as belonging to ``label``.

    A missing file is reported on stdout and yields ``None``.  For a file
    that exists, the result carries a ``Model`` column equal to ``label``;
    ``sigma_a_avg`` and ``eff_rank_avg`` are filled in from the per-layer
    columns when they are not already present and such columns exist.
    """
    if not os.path.exists(path):
        print(f"‚ùå Missing: {path}")
        return None

    table = pd.read_parquet(path)
    table['Model'] = label

    def add_average(metric):
        # Calc averages if missing: row-wise mean over every per-layer column of this metric.
        target = metric + '_avg'
        if target in table.columns:
            return
        per_layer = [name for name in table.columns if metric in name and 'L' in name]
        if per_layer:
            table[target] = table[per_layer].mean(axis=1)

    add_average('sigma_a')
    add_average('eff_rank')

    return table

def run_comparison():
    """Compare the hero and baseline runs: three line charts plus a final-row summary.

    Returns early when either telemetry table is unavailable.  The figure is
    written to ``janus_duel_comparison.png`` under ``REPORT_DIR``.
    """
    hero = load_data(HERO_PATH, "Hero (Pressure)")
    base = load_data(BASE_PATH, "Baseline (Control)")
    if hero is None or base is None:
        return

    # Combine
    combined = pd.concat([hero, base])

    # PLOTTING
    fig, axes = plt.subplots(3, 1, figsize=(10, 15), sharex=True)
    colors = ['tab:green', 'tab:gray']

    # One entry per panel, in display order: loss, redundancy, effective rank.
    # NOTE(review): another copy of this function in the file labels the rank
    # axis "Rank (Max 64)" instead of "Rank (Max 8.0)" — confirm which is right.
    columns = ['val_loss', 'sigma_a_avg', 'eff_rank_avg']
    titles = [
        "Validation Loss (Performance)",
        "Head Redundancy (Sigma_A)",
        "Effective Rank (Dimensional Usage)",
    ]
    y_labels = ["Cross Entropy", "Correlation (Lower = Better)", "Rank (Max 8.0)"]

    for row, column in enumerate(columns):
        panel = axes[row]
        sns.lineplot(data=combined, x='step', y=column, hue='Model', ax=panel, palette=colors)
        panel.set_title(titles[row])
        panel.set_ylabel(y_labels[row])
        panel.grid(True, alpha=0.3)

    save_path = os.path.join(REPORT_DIR, "janus_duel_comparison.png")
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"‚úÖ Comparison Chart Saved: {save_path}")

    # PRINT STATS (read from the last row of each table)
    print("\n--- FINAL STATS (Step 5000) ---")
    final_hero = hero.iloc[-1]
    final_base = base.iloc[-1]

    print(f"LOSS:       Hero {final_hero['val_loss']:.4f} vs Base {final_base['val_loss']:.4f}")
    print(f"REDUNDANCY: Hero {final_hero['sigma_a_avg']:.4f} vs Base {final_base['sigma_a_avg']:.4f}")
    print(f"RANK:       Hero {final_hero['eff_rank_avg']:.2f}   vs Base {final_base['eff_rank_avg']:.2f}")

# Entry point: run the hero-vs-baseline comparison.
if __name__ == "__main__":
    run_comparison()

In [None]:
# @title
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import matplotlib.pyplot as plt
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive only when the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# Token file that BinLoader memory-maps as uint16.
DATA_FILE = os.path.join(PROJECT_ROOT, "data/processed/TinyStories-train_full.bin")
REPORT_DIR = os.path.join(PROJECT_ROOT, "reports")
MODEL_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_prewarm")
# Create both output folders up front so later writes cannot fail on a missing directory.
os.makedirs(REPORT_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

# Prefer the GPU when one is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"‚ùÑÔ∏è JANUS PRE-WARM (CHECKPOINT EDITION)")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 2. ARCHITECTURE (Clean Room) ---
class CleanConfig:
    """Fixed model hyper-parameters for the pre-warm experiment (256-token context, no dropout)."""

    def __init__(self):
        # Attention geometry: 8 heads x 64 dims per head = 512-wide model.
        self.n_heads = 8
        self.d_head = 64
        self.d_model = self.n_heads * self.d_head

        self.n_layers = 12
        self.vocab_size = 50304
        self.max_seq_len = 256  # Lite Mode
        self.dropout = 0.0

class CleanAttention(nn.Module):
    """Causal multi-head self-attention with an optional head-diversity penalty.

    ``forward`` returns ``(output, steer_loss)``.  The penalty is the
    Frobenius norm of (head Gram matrix - identity) scaled by ``l_div``; it
    stays at ``0.0`` unless ``l_div`` is positive.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)
        # Creation order is kept (q, k, v, o) so seeded initialisation is unchanged.
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)

    def _split_heads(self, t):
        """Reshape (B, S, D) into (B, n_heads, S, d_head)."""
        batch, seq_len, _ = t.shape
        return t.view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x, l_div):
        batch, seq_len, width = x.shape
        q = self._split_heads(self.q_proj(x))
        k = self._split_heads(self.k_proj(x))
        v = self._split_heads(self.v_proj(x))

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        causal = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).view(1, 1, seq_len, seq_len)
        scores = scores.masked_fill(causal == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        head_out = torch.matmul(weights, v)

        steer_loss = 0.0
        if l_div > 0.0:
            # One unit-length row per head; the Gram matrix then holds the
            # pairwise cosine similarities between heads.
            per_head = F.normalize(head_out.transpose(0, 1).reshape(self.n_heads, -1), p=2, dim=1)
            gram = torch.mm(per_head, per_head.t())
            eye = torch.eye(self.n_heads, device=x.device)
            steer_loss += torch.norm(gram - eye, p='fro') * l_div

        merged = head_out.transpose(1, 2).reshape(batch, seq_len, width)
        return self.o_proj(merged), steer_loss

class CleanBlock(nn.Module):
    """Transformer layer that also hands back the attention module's steering loss."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        hidden = 4 * width
        self.ln1 = nn.LayerNorm(width)
        self.attn = CleanAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
        )

    def forward(self, x, l_div):
        attn_out, steer = self.attn(self.ln1(x), l_div)
        residual = x + attn_out
        residual = residual + self.mlp(self.ln2(residual))
        return residual, steer

class CleanGPT(nn.Module):
    """Language model built from token and learned position embeddings, a stack of ``CleanBlock`` layers, a final LayerNorm and a bias-free vocabulary projection."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.token_emb = nn.Embedding(config.vocab_size, width)
        self.pos_emb = nn.Embedding(config.max_seq_len, width)
        self.blocks = nn.ModuleList([CleanBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(width)
        self.head = nn.Linear(width, config.vocab_size, bias=False)
        self.config = config

    def forward(self, idx, targets=None, l_div=0.0):
        """Compute the training loss and the summed steering loss.

        Args:
            idx: Integer token ids, unpacked as ``(batch, seq)``.
            targets: Optional token ids to score against; when omitted no
                cross-entropy is computed.
            l_div: Diversity pressure forwarded unchanged to every block.

        Returns:
            ``(loss, total_steer)``. ``loss`` is ``None`` without ``targets``;
            ``total_steer`` is the sum of the per-block steering values.
        """
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device)
        hidden = self.token_emb(idx) + self.pos_emb(positions)

        total_steer = 0.0
        for layer in self.blocks:
            hidden, layer_steer = layer(hidden, l_div)
            total_steer += layer_steer

        logits = self.head(self.ln_f(hidden))

        if targets is None:
            return None, total_steer

        vocab = logits.size(-1)
        loss = F.cross_entropy(logits.view(-1, vocab), targets.view(-1))
        return loss, total_steer

# --- 3. LOADER ---
class BinLoader:
    """Random-window batch sampler over a flat ``uint16`` token file."""

    def __init__(self, bin_path, block_size, batch_size):
        self.block_size = block_size
        self.batch_size = batch_size
        # Read-only memory map, so the file is not loaded into RAM up front.
        self.data = np.memmap(bin_path, dtype=np.uint16, mode='r')

    def get_batch(self):
        """Return ``(x, y)`` on ``DEVICE``; ``y`` is ``x`` shifted one token to the right."""
        span = self.block_size
        starts = torch.randint(len(self.data) - span, (self.batch_size,))
        inputs, labels = [], []
        for start in starts:
            inputs.append(torch.from_numpy(self.data[start:start + span].astype(np.int64)))
            labels.append(torch.from_numpy(self.data[start + 1:start + 1 + span].astype(np.int64)))
        return torch.stack(inputs).to(DEVICE), torch.stack(labels).to(DEVICE)

    def get_noise_batch(self):
        """Return ``(x, None)`` where ``x`` holds uniformly random token ids below 50304."""
        shape = (self.batch_size, self.block_size)
        noise = torch.randint(0, 50304, shape).to(DEVICE)
        return noise, None

# --- 4. THE DUEL ---
def run_sprint(mode, steps=500, save_name="unknown"):
    """Train a fresh model for ``steps`` updates and checkpoint it.

    Args:
        mode: Run label. The exact value ``"Pre-Warmed"`` enables a 200-step
            warm-up on random token batches that optimises only the steering
            loss (``l_div=0.5``) before normal training starts.
        steps: Number of language-modelling updates on the token file.
        save_name: Checkpoint file stem written under ``MODEL_DIR``.

    Returns:
        list[float]: The training loss recorded at every step.
    """
    # CLEAN SLATE: release what a previous run left behind before allocating.
    gc.collect()
    torch.cuda.empty_cache()

    print(f"\nüèÉ RUNNING: {mode}")
    cfg = CleanConfig()
    model = CleanGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6e-4)
    loader = BinLoader(DATA_FILE, 256, 16)

    # --- PHASE A: PRE-WARMING ---
    if mode == "Pre-Warmed":
        print("   üî• Warming Up (Orthogonality on Noise)...")
        warm_steps = 200
        for _ in tqdm(range(warm_steps), desc="Warming", leave=False):
            noise = loader.get_noise_batch()[0]
            # No targets are supplied, so only the steering term is available.
            steer = model(noise, targets=None, l_div=0.5)[1]
            optimizer.zero_grad()
            steer.backward()
            optimizer.step()

    # --- PHASE B: TRAINING ---
    print("   üöÄ Training on TinyStories...")
    model.train()
    loss_history = []
    for _ in tqdm(range(steps), desc="Training"):
        inputs, labels = loader.get_batch()

        # Zero pressure race: l_div=0.0, only cross-entropy is optimised.
        loss = model(inputs, labels, l_div=0.0)[0]

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_history.append(loss.item())

    # --- SAVE CHECKPOINT ---
    save_path = os.path.join(MODEL_DIR, f"{save_name}.pt")
    print(f"   üíæ Saving Checkpoint to {save_path}...")
    torch.save(
        {
            'step': steps,
            'model_state': model.state_dict(),
            'optimizer_state': optimizer.state_dict(),
            'config': cfg.__dict__,
        },
        save_path,
    )

    # CLEANUP
    del model, optimizer
    gc.collect()
    return loss_history

# --- 5. EXECUTION ---
def main():
    """Run the Control and Pre-Warmed sprints, plot both loss curves and print the verdict."""
    steps = 500

    # Run Both
    hist_control = run_sprint("Control", steps, "control_500")
    hist_warm = run_sprint("Pre-Warmed", steps, "prewarmed_500")

    print("\nüìä Generating Report...")
    plt.figure(figsize=(10, 6))
    curves = (
        (hist_control, {'label': 'Control (Random Init)', 'alpha': 0.7}),
        (hist_warm, {'label': 'Pre-Warmed (Orthogonal Init)', 'linewidth': 2}),
    )
    for history, line_style in curves:
        plt.plot(history, **line_style)
    plt.title("Geometric Pre-Warming: Loss Trajectory (500 Steps)")
    plt.xlabel("Training Steps")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)

    save_path = os.path.join(REPORT_DIR, "prewarm_duel_500.png")
    plt.savefig(save_path)
    print(f"‚úÖ Comparison saved to {save_path}")

    # Compare the mean of the last 50 recorded losses of each run.
    tail = slice(-50, None)
    avg_c = np.mean(hist_control[tail])
    avg_w = np.mean(hist_warm[tail])
    print(f"\nüèÅ FINAL LOSS (Step {steps}):")
    print(f"   Control:    {avg_c:.4f}")
    print(f"   Pre-Warmed: {avg_w:.4f}")

    if avg_w < avg_c:
        verdict = "üèÜ RESULT: Pre-Warming Improved Performance!"
    else:
        verdict = "üìâ RESULT: No Improvement."
    print(verdict)


if __name__ == "__main__":
    main()

In [None]:
# @title
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import matplotlib.pyplot as plt
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. SETUP & INSTALL ---
# Mount Google Drive unless the mount point already exists.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Install Stanza if missing. The `!pip` line is an IPython shell escape, so
# this cell has to run inside a notebook kernel rather than plain Python.
try:
    import stanza
except ImportError:
    print("üì¶ Installing Stanza...")
    !pip install stanza -q
    import stanza

# Download the English Stanza resources for the three processors that
# CurriculumBuilder's pipeline requests (tokenize, pos, constituency).
print("üì¶ Downloading NLP Models...")
stanza.download('en', processors='tokenize,pos,constituency', logging_level='WARN')

# Project paths on the mounted Drive; the report directory is created eagerly.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
DATA_FILE = os.path.join(PROJECT_ROOT, "data/processed/TinyStories-train_full.bin")
REPORT_DIR = os.path.join(PROJECT_ROOT, "reports")
os.makedirs(REPORT_DIR, exist_ok=True)

# Use CUDA when available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- 2. DATA PROCESSOR (The "Professor") ---
# Module-level GPT-2 tokenizer; CurriculumBuilder.build reads it as a global.
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

class CurriculumBuilder:
    """Split a sample of the token file into "simple" and "complex" sentences.

    A slice of tokens is decoded to text, parsed with a Stanza constituency
    pipeline, bucketed by a bracket-string heuristic and re-tokenised.

    Args:
        bin_path: Path of the ``uint16`` token file.
        sample_tokens: Number of tokens read from the middle of the file.
        max_chars: Upper bound on the decoded characters handed to the parser.
            Defaults to the previously hard-coded ``200_000``.
    """

    def __init__(self, bin_path, sample_tokens=500_000, max_chars=200_000):
        self.bin_path = bin_path
        self.sample_tokens = sample_tokens
        self.max_chars = max_chars
        self.nlp = stanza.Pipeline('en', processors='tokenize,pos,constituency', use_gpu=True, logging_level='WARN')

    @staticmethod
    def _is_simple(const_str):
        """Return True for a parse string with exactly one ``(S `` node and no ``SBAR``.

        Everything else (subordinate clauses, compounds, fragments without an
        ``S`` node) is treated as complex by ``build``.
        """
        return "SBAR" not in const_str and const_str.count("(S ") == 1

    @staticmethod
    def _encode(sents):
        """Concatenate the token ids of ``sents``, appending EOS after each sentence."""
        ids = []
        for sent in sents:
            ids.extend(tokenizer.encode(sent) + [tokenizer.eos_token_id])
        return torch.tensor(ids, dtype=torch.long)

    def build(self):
        """Return ``(simple_ids, complex_ids)`` as 1-D ``torch.long`` tensors."""
        print(f"\nüìö HARVESTING CURRICULUM ({self.sample_tokens} tokens)...")
        # 1. Load Raw Data
        data = np.memmap(self.bin_path, dtype=np.uint16, mode='r')
        # Grab a chunk from the middle to avoid header/intro bias
        start_idx = len(data) // 2
        chunk = data[start_idx : start_idx + self.sample_tokens].astype(np.int64)

        # 2. Decode to Text
        print("   -> Decoding text...")
        text_blob = tokenizer.decode(chunk)

        # 3. Analyze with Stanza (character budget keeps parse time bounded)
        print("   -> Parsing Syntax (This may take 2-3 mins)...")
        doc = self.nlp(text_blob[:self.max_chars])

        simple = []
        complex_sents = []
        for sentence in doc.sentences:
            parse = str(sentence.constituency)
            bucket = simple if self._is_simple(parse) else complex_sents
            bucket.append(sentence.text)

        print(f"   -> Found {len(simple)} Simple | {len(complex_sents)} Complex sentences")

        # 4. Re-Tokenize
        return self._encode(simple), self._encode(complex_sents)

# --- 3. MODEL ARCHITECTURE (Clean Room) ---
class CleanConfig:
    """Fixed hyperparameters for the clean-room model in this cell."""

    def __init__(self):
        # Embedding table size and model width.
        self.vocab_size, self.d_model = 50304, 512
        # Depth and attention layout; n_heads * d_head equals d_model (8 * 64 = 512).
        self.n_layers, self.n_heads, self.d_head = 12, 8, 64
        # Length of the learned position table, and a dropout rate that the
        # layers defined in this cell do not read.
        self.max_seq_len, self.dropout = 256, 0.0

class CleanAttention(nn.Module):
    """Causal multi-head self-attention with an optional head-diversity penalty."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)
        # Bias-free projections, created in q, k, v, o order.
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)

    def _split_heads(self, proj, x):
        """Project ``x`` and reshape to ``(batch, heads, seq, d_head)``."""
        batch, seq_len, _ = x.shape
        return proj(x).view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)

    def forward(self, x, l_div):
        """Return ``(output, steer_loss)``.

        ``steer_loss`` is the float ``0.0`` unless ``l_div > 0``, in which case
        it is ``l_div`` times the Frobenius norm of (head Gram matrix - I).
        """
        batch, seq_len, width = x.shape
        queries = self._split_heads(self.q_proj, x)
        keys = self._split_heads(self.k_proj, x)
        values = self._split_heads(self.v_proj, x)

        scores = (queries @ keys.transpose(-2, -1)) * self.scale
        # Hide every key position that comes after the query position.
        position = torch.arange(seq_len, device=x.device)
        is_future = position.unsqueeze(0) > position.unsqueeze(1)
        scores = scores.masked_fill(is_future, float('-inf'))
        head_out = F.softmax(scores, dim=-1) @ values

        steer_loss = 0.0
        if l_div > 0.0:
            # One L2-normalised row per head; off-diagonal Gram entries are
            # the cosine similarities between heads.
            rows = head_out.transpose(0, 1).reshape(self.n_heads, -1)
            unit_rows = F.normalize(rows, p=2, dim=1)
            gram = unit_rows @ unit_rows.t()
            eye = torch.eye(self.n_heads, device=x.device)
            steer_loss += torch.norm(gram - eye, p='fro') * l_div

        merged = head_out.transpose(1, 2).reshape(batch, seq_len, width)
        return self.o_proj(merged), steer_loss

class CleanBlock(nn.Module):
    """Residual block: normalised attention followed by a normalised 4x-wide GELU MLP."""

    def __init__(self, config):
        super().__init__()
        d_model = config.d_model
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CleanAttention(config)
        self.ln2 = nn.LayerNorm(d_model)
        expand = nn.Linear(d_model, 4 * d_model)
        contract = nn.Linear(4 * d_model, d_model)
        # Positional Sequential keeps the parameter names mlp.0.* and mlp.2.*.
        self.mlp = nn.Sequential(expand, nn.GELU(), contract)

    def forward(self, x, l_div):
        """Return ``(x_out, steer)`` where ``steer`` comes from the attention layer."""
        attended, steer = self.attn(self.ln1(x), l_div)
        x = x + attended
        return x + self.mlp(self.ln2(x)), steer

class CleanGPT(nn.Module):
    """Stack of ``CleanBlock`` layers between embeddings and a bias-free vocabulary head."""

    def __init__(self, config):
        super().__init__()
        d_model, vocab = config.d_model, config.vocab_size
        self.token_emb = nn.Embedding(vocab, d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, d_model)
        self.blocks = nn.ModuleList([CleanBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab, bias=False)
        self.config = config

    def forward(self, idx, targets=None, l_div=0.0):
        """Return ``(loss, total_steer)``.

        ``loss`` is the cross-entropy against ``targets`` (``None`` when no
        targets are given); ``total_steer`` sums the steering value returned
        by every block for the given ``l_div``.
        """
        _, n_tokens = idx.shape
        hidden = self.token_emb(idx) + self.pos_emb(torch.arange(n_tokens, device=idx.device))

        total_steer = 0.0
        for block in self.blocks:
            hidden, block_steer = block(hidden, l_div)
            total_steer += block_steer

        logits = self.head(self.ln_f(hidden))

        loss = (
            None
            if targets is None
            else F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        )
        return loss, total_steer

# --- 4. LOADER ---
class TensorLoader:
    """Random-window batch sampler over an in-memory 1-D token tensor."""

    def __init__(self, tensor_data, block_size, batch_size):
        self.data = tensor_data
        self.block_size = block_size
        self.batch_size = batch_size

    def get_batch(self):
        """Return ``(x, y)`` on ``DEVICE``; ``y`` is ``x`` shifted one token to the right."""
        width = self.block_size
        starts = torch.randint(len(self.data) - width, (self.batch_size,))
        inputs = torch.stack([self.data[s:s + width] for s in starts])
        labels = torch.stack([self.data[s + 1:s + 1 + width] for s in starts])
        return inputs.to(DEVICE), labels.to(DEVICE)

class BinLoader: # For the full dataset
    """Random-window batch sampler over a memory-mapped ``uint16`` token file."""

    def __init__(self, bin_path, block_size, batch_size):
        self.data = np.memmap(bin_path, dtype=np.uint16, mode='r')
        self.block_size, self.batch_size = block_size, batch_size

    def _window(self, start):
        """Copy ``block_size`` tokens beginning at ``start`` into an int64 tensor."""
        return torch.from_numpy(self.data[start:start + self.block_size].astype(np.int64))

    def get_batch(self):
        """Return ``(x, y)`` on ``DEVICE``; ``y`` is ``x`` shifted one token to the right."""
        starts = torch.randint(len(self.data) - self.block_size, (self.batch_size,))
        inputs = torch.stack([self._window(s) for s in starts])
        labels = torch.stack([self._window(s + 1) for s in starts])
        return inputs.to(DEVICE), labels.to(DEVICE)

# --- 5. THE DUEL ---
def run_sprint(mode, simple_data, complex_data, steps=500):
    """Train a fresh model and return its per-step loss on the full dataset.

    Args:
        mode: Run label. The exact value ``"Curriculum"`` adds two 100-step
            warm-up stages (simple sentences at ``l_div=0.3``, then complex
            sentences at ``l_div=0.1``) before the main phase.
        simple_data: 1-D token tensor of simple sentences.
        complex_data: 1-D token tensor of complex sentences.
        steps: Number of main-phase updates on the full token file.

    Returns:
        list[float]: Main-phase losses, one per step.
    """
    gc.collect()
    torch.cuda.empty_cache()
    print(f"\nüèÉ RUNNING: {mode}")

    cfg = CleanConfig()
    model = CleanGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6e-4)

    # Loaders (all three are built in every mode)
    simple_loader = TensorLoader(simple_data, 256, 16)
    complex_loader = TensorLoader(complex_data, 256, 16)
    full_loader = BinLoader(DATA_FILE, 256, 16)

    # --- CURRICULUM PHASE ---
    if mode == "Curriculum":
        stages = (
            ("   üë∂ Phase 1: Simple Sentences (High Pressure)", simple_loader, 0.3),
            ("   üéì Phase 2: Complex Sentences (Med Pressure)", complex_loader, 0.1),
        )
        for banner, stage_loader, pressure in stages:
            print(banner)
            for _ in tqdm(range(100), leave=False):
                x, y = stage_loader.get_batch()
                loss, steer = model(x, y, l_div=pressure)
                # Cross-entropy and steering loss are optimised jointly here.
                optimizer.zero_grad()
                (loss + steer).backward()
                optimizer.step()

    # --- MAIN PHASE ---
    print("   üöÄ Phase 3: Full Dataset (Race Mode)")
    loss_history = []
    for _ in tqdm(range(steps), desc="Training"):
        x, y = full_loader.get_batch()

        # Zero pressure for the race to see natural performance
        loss = model(x, y, l_div=0.0)[0]

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_history.append(loss.item())

    del model, optimizer
    return loss_history

# --- 6. EXECUTION ---
def main():
    """Build the curriculum, race Control against Curriculum, then plot and print the outcome."""
    # 1. Build Curriculum
    simple_t, complex_t = CurriculumBuilder(DATA_FILE).build()

    # 2. Run Duel (the Control run receives the curriculum tensors but its
    #    mode never trains on them)
    steps = 500
    hist_ctrl = run_sprint("Control", simple_t, complex_t, steps)
    hist_curr = run_sprint("Curriculum", simple_t, complex_t, steps)

    # 3. Report
    print("\nüìä Generating Report...")
    plt.figure(figsize=(10, 6))
    plt.plot(hist_ctrl, label='Control (Random Init)', alpha=0.7)
    plt.plot(hist_curr, label='Curriculum (Syntax Pre-Warm)', linewidth=2)
    plt.title("Syntax-Aware Pre-Warming: Loss Trajectory")
    plt.xlabel("Training Steps (Post-Warmup)")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(REPORT_DIR, "curriculum_duel.png"))

    # Mean of the final 50 recorded losses of each run.
    avg_c, avg_curr = (np.mean(history[-50:]) for history in (hist_ctrl, hist_curr))
    print(f"\nüèÅ FINAL LOSS (Step {steps}):")
    print(f"   Control:    {avg_c:.4f}")
    print(f"   Curriculum: {avg_curr:.4f}")

    print(
        "üèÜ RESULT: Syntax Pre-Warming Worked!"
        if avg_curr < avg_c
        else "üìâ RESULT: No Improvement."
    )


if __name__ == "__main__":
    main()

In [None]:
# @title
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import matplotlib.pyplot as plt
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive unless the mount point is already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# PATHS (everything lives under one Drive project folder)
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
DATA_FILE = os.path.join(PROJECT_ROOT, "data/processed/TinyStories-train_full.bin")
CURRIC_DIR = os.path.join(PROJECT_ROOT, "data/processed/curriculum_cache")
MODEL_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_curriculum")
REPORT_DIR = os.path.join(PROJECT_ROOT, "reports")

# Create the output folders up front; existing folders are left untouched.
for _out_dir in (CURRIC_DIR, MODEL_DIR, REPORT_DIR):
    os.makedirs(_out_dir, exist_ok=True)

# Prefer CUDA, fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"üéì JANUS CURRICULUM PROTOCOL")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 2. ARCHITECTURE (Clean Room) ---
class CleanConfig:
    """Fixed hyperparameters for the curriculum experiment's model."""

    def __init__(self):
        # Keyword order is the attribute insertion order.
        vars(self).update(
            vocab_size=50304,
            d_model=512,
            n_layers=12,
            n_heads=8,
            d_head=64,  # n_heads * d_head == d_model
            max_seq_len=256,  # Lite Mode
            dropout=0.0,
        )

class CleanAttention(nn.Module):
    """Causal multi-head self-attention whose diversity penalty can be scaled per layer."""

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)
        dims = (config.d_model, config.d_model)
        self.q_proj = nn.Linear(*dims, bias=False)
        self.k_proj = nn.Linear(*dims, bias=False)
        self.v_proj = nn.Linear(*dims, bias=False)
        self.o_proj = nn.Linear(*dims, bias=False)

    def forward(self, x, l_div, spatial_mult=1.0):
        """Return ``(output, steer_loss)``.

        Args:
            x: Activations, unpacked as ``(batch, seq, width)``.
            l_div: Diversity pressure; the penalty is skipped unless it is
                greater than zero.
            spatial_mult: Extra factor applied to the penalty; the caller
                supplies the per-layer schedule value.
        """
        batch, seq_len, width = x.shape
        queries, keys, values = (
            proj(x).view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)
            for proj in (self.q_proj, self.k_proj, self.v_proj)
        )

        scores = (queries @ keys.transpose(-2, -1)) * self.scale
        # Keys positioned after the query receive -inf, i.e. zero weight.
        position = torch.arange(seq_len, device=x.device)
        is_future = position.unsqueeze(0) > position.unsqueeze(1)
        weights = F.softmax(scores.masked_fill(is_future, float('-inf')), dim=-1)
        head_out = weights @ values

        steer_loss = 0.0
        # Apply pressure if requested
        if l_div > 0.0:
            rows = head_out.transpose(0, 1).reshape(self.n_heads, -1)
            unit_rows = F.normalize(rows, p=2, dim=1)
            gram = unit_rows @ unit_rows.t()
            eye = torch.eye(self.n_heads, device=x.device)
            # Distance of the head Gram matrix from identity, weighted by the
            # pressure and then by the layer's spatial multiplier.
            steer_loss += torch.norm(gram - eye, p='fro') * l_div * spatial_mult

        merged = head_out.transpose(1, 2).reshape(batch, seq_len, width)
        return self.o_proj(merged), steer_loss

class CleanBlock(nn.Module):
    """Residual attention + MLP block that scales its diversity penalty by layer depth."""

    def __init__(self, config, layer_id):
        super().__init__()
        width = config.d_model
        self.layer_id = layer_id
        self.total_layers = config.n_layers
        self.ln1 = nn.LayerNorm(width)
        self.attn = CleanAttention(config)
        self.ln2 = nn.LayerNorm(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, 4 * width),
            nn.GELU(),
            nn.Linear(4 * width, width),
        )

    def forward(self, x, l_div):
        """Return ``(x_out, steer)`` for this layer."""
        # Cubic spatial schedule: ((layer_id + 1) / total_layers) ** 3, so the
        # last layer gets a multiplier of 1 and earlier layers get less.
        depth_fraction = (self.layer_id + 1) / self.total_layers
        spatial_mult = depth_fraction ** 3

        attended, steer = self.attn(self.ln1(x), l_div, spatial_mult)
        residual = x + attended
        return residual + self.mlp(self.ln2(residual)), steer

class CleanGPT(nn.Module):
    """Stack of depth-aware ``CleanBlock`` layers between embeddings and a bias-free vocabulary head."""

    def __init__(self, config):
        super().__init__()
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        # Each block receives its index so it can derive its own schedule value.
        self.blocks = nn.ModuleList([CleanBlock(config, i) for i in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.config = config

    def forward(self, idx, targets=None, l_div=0.0):
        """Compute the training loss and the summed steering loss.

        Args:
            idx: Integer token ids, unpacked as ``(batch, seq)``.
            targets: Optional token ids to score against. They may be a
                non-contiguous view (for example a column slice of a longer
                batch).
            l_div: Diversity pressure forwarded unchanged to every block.

        Returns:
            ``(loss, total_steer)``. ``loss`` is ``None`` without ``targets``.
        """
        B, S = idx.shape
        x = self.token_emb(idx) + self.pos_emb(torch.arange(S, device=idx.device))

        total_steer = 0.0
        for block in self.blocks:
            x, s = block(x, l_div)
            total_steer += s

        x = self.ln_f(x)
        logits = self.head(x)

        loss = None
        if targets is not None:
            # reshape, not view: targets sliced as chunk[:, 1:] are not
            # contiguous, and on the CPU fallback .to(DEVICE) does not copy
            # them, so view(-1) would raise.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        return loss, total_steer

# --- 3. LOADERS ---
class TensorLoader:
    """Batch sampler over a saved collection of fixed-length token sequences.

    Args:
        tensor_path: File written with ``torch.save`` holding the sequences.
        batch_size: Number of sequences drawn per batch.

    Raises:
        ValueError: If the file holds no sequences. Sampling would otherwise
            fail later inside ``torch.randint`` with an unrelated message.
    """

    def __init__(self, tensor_path, batch_size):
        self.data = torch.load(tensor_path)
        if len(self.data) == 0:
            raise ValueError(f"No sequences found in {tensor_path}")
        self.batch_size = batch_size
        print(f"üì¶ Loaded Tensor: {len(self.data):,} items from {os.path.basename(tensor_path)}")

    @staticmethod
    def _split_inputs_targets(chunk):
        """Split ``(N, L)`` rows into inputs ``[:, :-1]`` and targets ``[:, 1:]``.

        Both halves are made contiguous: the raw column slices are strided
        views, and on the CPU ``.to(DEVICE)`` returns them unchanged, which
        breaks downstream ``view(-1)`` calls.
        """
        return chunk[:, :-1].contiguous(), chunk[:, 1:].contiguous()

    def get_batch(self):
        """Return ``(x, y)`` on ``DEVICE`` drawn uniformly with replacement."""
        ix = torch.randint(len(self.data), (self.batch_size,))
        # Data was saved as (N, SeqLen) tensors
        chunk = torch.stack([self.data[i] for i in ix])
        x, y = self._split_inputs_targets(chunk)
        return x.to(DEVICE), y.to(DEVICE)

class BinLoader:
    """Random-window batch sampler over a memory-mapped ``uint16`` token file."""

    def __init__(self, bin_path, block_size, batch_size):
        self.data = np.memmap(bin_path, dtype=np.uint16, mode='r')
        self.block_size = block_size
        self.batch_size = batch_size

    def get_batch(self):
        """Return ``(x, y)`` on ``DEVICE``; ``y`` is ``x`` shifted one token to the right."""
        span = self.block_size
        offsets = torch.randint(len(self.data) - span, (self.batch_size,))
        xs = []
        ys = []
        for off in offsets:
            # astype copies the window out of the memmap as int64.
            xs.append(torch.from_numpy(self.data[off:off + span].astype(np.int64)))
            ys.append(torch.from_numpy(self.data[off + 1:off + 1 + span].astype(np.int64)))
        return torch.stack(xs).to(DEVICE), torch.stack(ys).to(DEVICE)

# --- 4. STEP 1: CREATE CURRICULUM ---
def step_1_create_curriculum():
    """Build (or reuse) the cached simple/complex sentence files.

    When both cache files exist they are reused as-is. Otherwise a slice from
    the middle of the token file is decoded, parsed with Stanza, bucketed by a
    constituency heuristic, padded/truncated to 257 tokens per sentence and
    saved with ``torch.save``.

    Returns:
        tuple[str, str]: Paths of the simple and the complex curriculum file.

    Raises:
        subprocess.CalledProcessError: If installing Stanza fails.
        RuntimeError: If either bucket is empty. Nothing is written in that
            case, so a later call regenerates instead of reusing a broken cache.
    """
    simple_path = os.path.join(CURRIC_DIR, "simple_curr.pt")
    complex_path = os.path.join(CURRIC_DIR, "complex_curr.pt")

    if os.path.exists(simple_path) and os.path.exists(complex_path):
        print("‚úÖ Curriculum files found. Skipping generation.")
        return simple_path, complex_path

    print("üõ†Ô∏è Generating Curriculum (One-Time Cost)...")
    # Install Stanza only if needed
    try:
        import stanza
    except ImportError:
        import importlib
        import subprocess
        import sys

        print("üì¶ Installing Stanza...")
        # Install into the running interpreter and fail loudly on a non-zero
        # exit status instead of continuing to a confusing second ImportError.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "stanza", "-q"])
        importlib.invalidate_caches()
        import stanza
    stanza.download('en', processors='tokenize,pos,constituency', logging_level='WARN')

    # Process
    from transformers import GPT2Tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    nlp = stanza.Pipeline('en', processors='tokenize,pos,constituency', use_gpu=True, logging_level='WARN')

    # Load raw data sample (Enough for 200 steps * 16 batch = 3200 sentences roughly)
    data = np.memmap(DATA_FILE, dtype=np.uint16, mode='r')
    start = len(data) // 2
    # Sample 2M tokens to ensure we get enough valid sentences
    chunk = data[start:start+2000000].astype(np.int64)
    text_blob = tokenizer.decode(chunk)

    # Parse
    doc = nlp(text_blob[:500000]) # Limit chars for speed

    simple_sents = []
    complex_sents = []

    print("   -> Analyzing Syntax...")
    for s in doc.sentences:
        # Stop once both buckets hold more than 2000 sentences.
        if len(simple_sents) > 2000 and len(complex_sents) > 2000:
            break

        txt = s.text
        if len(txt) < 10:
            continue

        # Simple = exactly one "(S " node and no SBAR; everything else
        # (subordinate, compound, fragment) counts as complex.
        const = str(s.constituency)
        if "SBAR" not in const and const.count("(S ") == 1:
            simple_sents.append(txt)
        else:
            complex_sents.append(txt)

    print(f"   -> Harvested {len(simple_sents)} Simple | {len(complex_sents)} Complex")

    if not simple_sents or not complex_sents:
        raise RuntimeError(
            "Curriculum harvest produced an empty bucket "
            f"({len(simple_sents)} simple, {len(complex_sents)} complex); nothing was saved."
        )

    # Save as Tensors
    seq_len = 257  # 256 + 1 so that X/Y can be taken as shifted slices
    pad_id = tokenizer.eos_token_id

    def save_batch(sents, path):
        """Encode, EOS-terminate, truncate/pad to ``seq_len`` and save as a list of tensors."""
        ids_list = []
        for sent in sents:
            ids = (tokenizer.encode(sent) + [pad_id])[:seq_len]
            # NOTE(review): padding reuses the EOS id, so padded positions
            # still count towards the training loss - confirm this is intended.
            ids = ids + [pad_id] * (seq_len - len(ids))
            ids_list.append(torch.tensor(ids, dtype=torch.long))
        torch.save(ids_list, path)

    save_batch(simple_sents, simple_path)
    save_batch(complex_sents, complex_path)
    print("‚úÖ Curriculum Saved to Disk.")
    return simple_path, complex_path

# --- 5. EXECUTION ---
def _run_phase(model, optimizer, loader, steps, desc, l_div=0.0, record=False):
    """Run ``steps`` optimiser updates; return per-step losses when ``record`` is set."""
    losses = []
    for _ in tqdm(range(steps), desc=desc):
        x, y = loader.get_batch()
        loss, steer = model(x, y, l_div=l_div)
        # With pressure on, the steering term joins the objective; otherwise
        # only the cross-entropy loss is back-propagated.
        objective = (loss + steer) if l_div > 0.0 else loss
        optimizer.zero_grad()
        objective.backward()
        optimizer.step()
        if record:
            losses.append(loss.item())
    return losses


def main():
    """Train a baseline and a curriculum-warmed model, save both, and compare their loss curves."""
    # STEP 1
    simp_path, comp_path = step_1_create_curriculum()

    # STEP 2: TRAIN BASELINE
    print("\nüìâ STEP 2: Baseline Run (500 Steps)...")
    gc.collect()
    torch.cuda.empty_cache()

    cfg = CleanConfig()
    model = CleanGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6e-4)
    loader = BinLoader(DATA_FILE, 256, 16)

    # No pressure during the baseline.
    base_loss = _run_phase(model, optimizer, loader, 500, "Baseline", record=True)

    torch.save(model.state_dict(), os.path.join(MODEL_DIR, "janus_baseline.pt"))
    print("üíæ Baseline Saved.")
    del model, optimizer

    # STEP 3 & 4: JANUS WARMUP + RUN
    print("\nüî• STEP 3: Janus Warm-Up (200 Steps)...")
    gc.collect()
    torch.cuda.empty_cache()

    model = CleanGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6e-4)

    s_loader = TensorLoader(simp_path, 16)
    c_loader = TensorLoader(comp_path, 16)

    # 100 simple then 100 complex steps, both at pressure 0.27.
    _run_phase(model, optimizer, s_loader, 100, "Simple Warmup", l_div=0.27)
    _run_phase(model, optimizer, c_loader, 100, "Complex Warmup", l_div=0.27)

    print("\nüöÄ STEP 4: Janus Training (500 Steps - No Pressure)...")
    # Pressure OFF; the same optimizer instance continues from the warm-up.
    janus_loss = _run_phase(model, optimizer, loader, 500, "Race Mode", record=True)

    torch.save(model.state_dict(), os.path.join(MODEL_DIR, "JanusWUCuric.pt"))
    print("üíæ JanusWUCuric Saved.")

    # STEP 5: COMPARE
    print("\nüìä STEP 5: Generating Report...")
    plt.figure(figsize=(10, 6))
    plt.plot(base_loss, label='Baseline (Random Init)', alpha=0.7)
    plt.plot(janus_loss, label='Janus (Curriculum Pre-Warm)', linewidth=2)
    plt.title("Curriculum Pre-Warming: Loss Trajectory (Race Phase)")
    plt.xlabel("Training Steps")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.savefig(os.path.join(REPORT_DIR, "curriculum_final_duel.png"))

    # Mean of the last 50 recorded losses of each run.
    b_avg = np.mean(base_loss[-50:])
    j_avg = np.mean(janus_loss[-50:])
    print(f"\nüèÅ FINAL STATS (Step 500):")
    print(f"   Baseline: {b_avg:.4f}")
    print(f"   Janus:    {j_avg:.4f}")

    outcome = "üèÜ RESULT: Pre-Warming Success!" if j_avg < b_avg else "üìâ RESULT: Baseline Won."
    print(outcome)


if __name__ == "__main__":
    main()

In [None]:
# @title
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import GPT2Tokenizer
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive unless the mount point is already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Both checkpoints are read from the same model directory.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
MODEL_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_curriculum")
BASELINE_PATH, JANUS_PATH = (
    os.path.join(MODEL_DIR, "janus_baseline.pt"),
    os.path.join(MODEL_DIR, "JanusWUCuric.pt"),
)

# Prefer CUDA, fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"üó£Ô∏è CURRICULUM INFERENCE DUEL")

# --- 2. ARCHITECTURE ---
class CleanConfig:
    """Fixed hyperparameters used to rebuild the model before loading a checkpoint."""

    def __init__(self):
        defaults = (
            ("vocab_size", 50304),
            ("d_model", 512),
            ("n_layers", 12),
            ("n_heads", 8),
            ("d_head", 64),
            ("max_seq_len", 256),
            ("dropout", 0.0),
        )
        # Assigned in listing order, one instance attribute per entry.
        for field, value in defaults:
            setattr(self, field, value)

class CleanAttention(nn.Module):
    """Inference-only causal multi-head self-attention (no steering output)."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)
        # Bias-free projections, created in q, k, v, o order.
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)

    def forward(self, x):
        """Return the attended and ``o_proj``-projected activations, same shape as ``x``."""
        batch, seq_len, width = x.shape
        split = (batch, seq_len, self.n_heads, self.d_head)
        queries = self.q_proj(x).view(split).transpose(1, 2)
        keys = self.k_proj(x).view(split).transpose(1, 2)
        values = self.v_proj(x).view(split).transpose(1, 2)

        scores = (queries @ keys.transpose(-2, -1)) * self.scale
        # Keys positioned after the query receive -inf, i.e. zero weight.
        position = torch.arange(seq_len, device=x.device)
        is_future = position.unsqueeze(0) > position.unsqueeze(1)
        weights = F.softmax(scores.masked_fill(is_future, float('-inf')), dim=-1)

        merged = (weights @ values).transpose(1, 2).reshape(batch, seq_len, width)
        return self.o_proj(merged)

class CleanBlock(nn.Module):
    """Residual block: normalised attention followed by a normalised 4x-wide GELU MLP."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = nn.LayerNorm(width)
        self.attn = CleanAttention(config)
        self.ln2 = nn.LayerNorm(width)
        expand = nn.Linear(width, 4 * width)
        contract = nn.Linear(4 * width, width)
        # Positional Sequential keeps the parameter names mlp.0.* and mlp.2.*.
        self.mlp = nn.Sequential(expand, nn.GELU(), contract)

    def forward(self, x):
        """Return the block output, same shape as ``x``."""
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))

class CleanGPT(nn.Module):
    """Inference variant of the model: returns logits and can sample continuations."""

    def __init__(self, config):
        super().__init__()
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.blocks = nn.ModuleList([CleanBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, idx):
        """Return next-token logits for every position of ``idx`` (``(batch, seq)`` ids)."""
        B, S = idx.shape
        x = self.token_emb(idx) + self.pos_emb(torch.arange(S, device=idx.device))
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.7, top_k=50):
        """Sample ``max_new_tokens`` tokens autoregressively and append them to ``idx``.

        Args:
            idx: Prompt token ids, ``(batch, seq)``.
            max_new_tokens: Number of tokens to append.
            temperature: Divisor applied to the last-position logits.
            top_k: When not ``None``, only the ``top_k`` highest logits stay
                eligible for sampling.

        Returns:
            The prompt with the sampled tokens concatenated along dim 1.
        """
        # Crop to the size of the learned position table rather than a fixed
        # 256, so a model built with a different max_seq_len cannot index past
        # pos_emb once the running sequence grows.
        context_len = self.pos_emb.num_embeddings
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -context_len:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Everything below the k-th largest logit is excluded.
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

# --- 3. LOAD MODELS ---
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
cfg = CleanConfig()

# NOTE(review): torch.load unpickles the file; only point these paths at
# checkpoints produced by this project.

# Load Baseline
model_base = CleanGPT(cfg).to(DEVICE)
base_ready = False
try:
    model_base.load_state_dict(torch.load(BASELINE_PATH, map_location=DEVICE))
except Exception as e:
    print(f"‚ùå Baseline Load Fail: {e}")
else:
    model_base.eval()
    base_ready = True
    print("‚úÖ Baseline Loaded")

# Load Janus
model_janus = CleanGPT(cfg).to(DEVICE)
janus_ready = False
try:
    model_janus.load_state_dict(torch.load(JANUS_PATH, map_location=DEVICE))
except Exception as e:
    print(f"‚ùå Janus Load Fail: {e}")
else:
    model_janus.eval()
    janus_ready = True
    print("‚úÖ Janus (Curriculum) Loaded")

# --- 4. THE PROBE ---
prompts = ["Lily wanted a", "The big red ball", "Once upon a time"]

# Shown instead of generated text when a checkpoint failed to load, so that
# samples from randomly initialised weights are never presented as results.
SKIPPED_TEXT = "[skipped - checkpoint not loaded]"

for p in prompts:
    print(f"\nüìù PROMPT: {p}")
    input_ids = tokenizer.encode(p, return_tensors='pt').to(DEVICE)

    # Baseline Output
    if base_ready:
        out_b = model_base.generate(input_ids, max_new_tokens=80, temperature=0.6)
        text_b = tokenizer.decode(out_b[0], skip_special_tokens=True)
    else:
        text_b = SKIPPED_TEXT
    print(f"üìâ BASELINE: {text_b}")

    # Janus Output
    if janus_ready:
        out_j = model_janus.generate(input_ids, max_new_tokens=80, temperature=0.6)
        text_j = tokenizer.decode(out_j[0], skip_special_tokens=True)
    else:
        text_j = SKIPPED_TEXT
    print(f"üöÄ JANUS:    {text_j}")

In [None]:
# @title
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import matplotlib.pyplot as plt
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive unless the mount point is already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
DATA_FILE, REPORT_DIR = (
    os.path.join(PROJECT_ROOT, "data/processed/TinyStories-train_full.bin"),
    os.path.join(PROJECT_ROOT, "reports"),
)
# The report folder is created eagerly; an existing folder is left untouched.
os.makedirs(REPORT_DIR, exist_ok=True)

# Prefer CUDA, fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"üèéÔ∏è JANUS ARCH DUEL (FIXED): Stock vs. Spec")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 2. CONFIG ---
class DuelConfig:
    """Fixed hyperparameters used to build both models in the duel."""

    def __init__(self):
        # Embedding table size and residual width.
        self.vocab_size, self.d_model = 50304, 512
        # Depth and attention layout (8 heads x 64 dims = 512).
        self.n_layers, self.n_heads, self.d_head = 12, 8, 64
        # Longest sequence the positional tables are sized for.
        self.max_seq_len = 256
        # Stored for completeness; the model classes in this cell do not read it.
        self.dropout = 0.0

# --- 3. OLD JANUS (Stock GPT-2) ---
class OldAttention(nn.Module):
    """Causal multi-head self-attention with one fused Q/K/V projection.

    Reads ``n_heads``, ``d_head`` and ``d_model`` from ``config``. The final
    reshape in ``forward`` relies on ``n_heads * d_head == d_model``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        # 1/sqrt(d_head) scaling applied to the attention logits.
        self.scale = 1.0 / math.sqrt(self.d_head)
        width = config.d_model
        self.c_attn = nn.Linear(width, 3 * width)
        self.c_proj = nn.Linear(width, width)

    def forward(self, x):
        """Return the projected attention output for ``x`` of shape (batch, seq, d_model)."""
        batch, seq_len, width = x.shape

        # A single projection produces Q, K and V; split on the size-3 axis and
        # move heads ahead of the sequence axis: (batch, heads, seq, d_head).
        fused = self.c_attn(x).view(batch, seq_len, 3, self.n_heads, self.d_head)
        q, k, v = (part.transpose(1, 2) for part in fused.unbind(dim=2))

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Positions above the diagonal are "future" tokens and must not be attended to.
        lower = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).view(1, 1, seq_len, seq_len)
        weights = scores.masked_fill(lower == 0, float('-inf')).softmax(dim=-1)

        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq_len, width)
        return self.c_proj(merged)

class OldBlock(nn.Module):
    """Pre-norm transformer block: LayerNorm -> attention, LayerNorm -> GELU MLP.

    Both sub-layers are added back onto the residual stream.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = nn.LayerNorm(width)
        self.attn = OldAttention(config)
        self.ln2 = nn.LayerNorm(width)
        # Feed-forward expands to 4x the model width and projects back.
        self.mlp = nn.Sequential(
            nn.Linear(width, 4 * width),
            nn.GELU(),
            nn.Linear(4 * width, width),
        )

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))

class OldGPT(nn.Module):
    """GPT-2 style language model with learned absolute position embeddings.

    ``forward`` returns only the cross-entropy loss (or ``None`` when no
    targets are supplied); logits are not returned.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.token_emb = nn.Embedding(config.vocab_size, width)
        self.pos_emb = nn.Embedding(config.max_seq_len, width)
        self.blocks = nn.ModuleList([OldBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(width)
        self.head = nn.Linear(width, config.vocab_size, bias=False)

        # Weight tying: the input embedding shares the output head's parameter.
        self.token_emb.weight = self.head.weight

    def forward(self, idx, targets=None):
        _, seq_len = idx.shape
        positions = torch.arange(seq_len, device=idx.device)
        hidden = self.token_emb(idx) + self.pos_emb(positions)
        for layer in self.blocks:
            hidden = layer(hidden)
        logits = self.head(self.ln_f(hidden))
        if targets is None:
            return None
        # Flatten (batch, seq) so every position is one classification example.
        return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

# --- 4. NEW JANUS (Janus Spec) ---
# RMSNorm, RoPE, SwiGLU, Bias=False, Tied Weights

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to 1 so the layer starts as pure normalisation.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_sq = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.weight

class NewAttention(nn.Module):
    """Bias-free causal multi-head attention with rotary position embeddings.

    Reads ``n_heads``, ``d_head``, ``d_model`` and ``max_seq_len`` from
    ``config``. The rotation table is precomputed once for ``max_seq_len``
    positions and stored as the ``freqs_cis`` buffer.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        # Bias = False
        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.register_buffer("freqs_cis", self.precompute_freqs_cis(config.d_head, config.max_seq_len))

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return a complex (end, dim // 2) table of unit rotations.

        Entry ``[p, j]`` is ``exp(i * p * theta ** (-2j / dim))``.
        """
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    def apply_rope(self, x, freqs_cis):
        """Rotate consecutive feature pairs of ``x`` by their position's angles.

        ``x`` is (batch, seq, heads, d_head); ``freqs_cis`` is the
        (positions, d_head // 2) table from ``precompute_freqs_cis``.
        """
        seq_len = x.shape[1]
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        # Add singleton batch and head axes so the table broadcasts along the
        # sequence axis. Multiplying by the bare (seq, d_head // 2) slice would
        # line its sequence axis up with the heads axis of x_c, which fails
        # whenever seq != heads and is silently wrong when they are equal.
        rot = freqs_cis[:seq_len].view(1, seq_len, 1, -1)
        x_out = torch.view_as_real(x_c * rot).flatten(3)
        return x_out.type_as(x)

    def forward(self, x):
        """Return the attention output for ``x`` of shape (batch, seq, d_model)."""
        B, S, D = x.shape
        q = self.q_proj(x).view(B, S, self.n_heads, self.d_head)
        k = self.k_proj(x).view(B, S, self.n_heads, self.d_head)
        v = self.v_proj(x).view(B, S, self.n_heads, self.d_head)

        # RoPE is applied while the sequence axis is still axis 1.
        q = self.apply_rope(q, self.freqs_cis)
        k = self.apply_rope(k, self.freqs_cis)

        # Move heads ahead of the sequence axis: (B, H, S, d_head).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Causal mask: a position may only attend to itself and earlier ones.
        mask = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        attn = attn.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(attn, dim=-1)

        y = (attn @ v).transpose(1, 2).reshape(B, S, D)
        return self.o_proj(y)

class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(silu(w1(x)) * w2(x))`` with no biases."""

    def __init__(self, d_model):
        super().__init__()
        # An 8/3 expansion keeps the three matrices close to the parameter
        # count of a two-matrix 4x MLP.
        hidden_dim = int(8 * d_model / 3)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewBlock(nn.Module):
    """Pre-norm transformer block built from RMSNorm, NewAttention and SwiGLU."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = RMSNorm(width)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(width)
        self.mlp = SwiGLU(width)

    def forward(self, x):
        # Each sub-layer sees a normalised input and is added to the residual.
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))

class NewGPT(nn.Module):
    """Language model without a position-embedding table.

    Position information comes from the attention layers. ``forward`` returns
    only the cross-entropy loss, or ``None`` when no targets are given.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.token_emb = nn.Embedding(config.vocab_size, width)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(width)
        self.head = nn.Linear(width, config.vocab_size, bias=False)

        # Weight tying: the input embedding shares the output head's parameter.
        self.token_emb.weight = self.head.weight

    def forward(self, idx, targets=None):
        _batch, _seq_len = idx.shape  # idx must be 2-D (batch, seq)
        hidden = self.token_emb(idx)
        for layer in self.blocks:
            hidden = layer(hidden)
        logits = self.head(self.ln_f(hidden))
        if targets is None:
            return None
        return F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

# --- 5. LOADER ---
class BinLoader:
    """Random-window batch sampler over a flat binary file of uint16 values."""

    def __init__(self, bin_path, block_size, batch_size):
        # Memory-map read-only so the file is not loaded into RAM up front.
        self.data = np.memmap(bin_path, dtype=np.uint16, mode='r')
        self.block_size = block_size
        self.batch_size = batch_size

    def get_batch(self):
        """Return ``(x, y)`` on DEVICE, where ``y`` is ``x`` shifted one token ahead."""
        span = self.block_size
        starts = torch.randint(len(self.data) - span, (self.batch_size,)).tolist()

        def window(begin):
            # Copy the slice out of the memmap as int64 before wrapping it.
            return torch.from_numpy(self.data[begin:begin + span].astype(np.int64))

        x = torch.stack([window(s) for s in starts])
        y = torch.stack([window(s + 1) for s in starts])
        return x.to(DEVICE), y.to(DEVICE)

# --- 6. DUEL ---
def run_sprint(model_class, label, steps=500):
    """Train a fresh ``model_class`` for ``steps`` steps and return its per-step losses.

    The model is built from a new ``DuelConfig``, optimised with AdamW
    (lr 6e-4) on batches of 16 windows of 256 tokens from ``DATA_FILE``.
    """
    # Release whatever a previous sprint left behind before allocating again.
    gc.collect()
    torch.cuda.empty_cache()
    print(f"\nüèÉ RUNNING: {label}")

    cfg = DuelConfig()
    model = model_class(cfg).to(DEVICE)
    params = sum(p.numel() for p in model.parameters())/1e6
    print(f"   Params: {params:.2f}M")

    optimizer = optim.AdamW(model.parameters(), lr=6e-4)
    loader = BinLoader(DATA_FILE, 256, 16)

    model.train()
    losses = []
    for _ in tqdm(range(steps), desc=label):
        inputs, targets = loader.get_batch()
        loss = model(inputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    # Drop the local references to the model and optimizer state.
    del model
    del optimizer
    return losses

def main():
    """Run both sprints, plot the two loss curves, and print the verdict.

    Each model is trained for ``steps`` optimizer steps via ``run_sprint``.
    The comparison uses the mean of the trailing ``tail`` losses of each run,
    and the plot is written to ``REPORT_DIR``.
    """
    steps = 500
    tail = 50  # number of trailing steps averaged for the final comparison

    hist_old = run_sprint(OldGPT, "Old Janus (Stock)", steps)
    hist_new = run_sprint(NewGPT, "New Janus (Spec)", steps)

    print("\nüìä Generating Report...")
    plt.figure(figsize=(10, 6))
    plt.plot(hist_old, label='Old Janus (GPT-2)', alpha=0.7)
    plt.plot(hist_new, label='New Janus (RoPE/RMS/SwiGLU)', linewidth=2)
    # Take the step count from `steps` so the title cannot drift from the run.
    plt.title(f"Architecture Duel: Loss Trajectory ({steps} Steps)")
    plt.xlabel("Training Steps")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)

    save_path = os.path.join(REPORT_DIR, "arch_duel_fixed.png")
    plt.savefig(save_path)

    avg_o = np.mean(hist_old[-tail:])
    avg_n = np.mean(hist_new[-tail:])
    # These are tail averages, not the loss at the final step, so label them as such.
    print(f"\nüèÅ FINAL LOSS (Last {tail} Steps Average):")
    print(f"   Old Janus: {avg_o:.4f}")
    print(f"   New Janus: {avg_n:.4f}")

    if avg_n < avg_o: print("üèÜ RESULT: New Spec Wins!")
    else: print("üìâ RESULT: Old School Wins.")

if __name__ == "__main__":
    main()

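Below is the RoPE test script: the same architecture duel as the previous cell, with `apply_rope` reshaping the frequency table so that it broadcasts across the heads axis.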

In [None]:
# @title
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import matplotlib.pyplot as plt
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive only when the mount point does not exist yet.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# All inputs and outputs live under one Drive project folder.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# Token file read by BinLoader (memory-mapped there as uint16).
DATA_FILE = os.path.join(PROJECT_ROOT, "data/processed/TinyStories-train_full.bin")
# Destination for the comparison plot; created if missing.
REPORT_DIR = os.path.join(PROJECT_ROOT, "reports")
os.makedirs(REPORT_DIR, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üèéÔ∏è JANUS ARCH DUEL (FIXED): Stock vs. Spec")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 2. CONFIG ---
class DuelConfig:
    """Hyperparameters shared by the two models compared in this cell."""

    def __init__(self):
        # Vocabulary and residual-stream width.
        self.vocab_size, self.d_model = 50304, 512
        # Layer count and head geometry (8 * 64 == 512).
        self.n_layers, self.n_heads, self.d_head = 12, 8, 64
        # Sequence length the positional tables are built for.
        self.max_seq_len = 256
        # Kept on the config; not read by the model classes in this cell.
        self.dropout = 0.0

# --- 3. OLD JANUS (Stock GPT-2) ---
class OldAttention(nn.Module):
    """Causal multi-head self-attention using a fused Q/K/V linear layer.

    ``config`` supplies ``n_heads``, ``d_head`` and ``d_model``; the output
    reshape assumes ``n_heads * d_head == d_model``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)  # logit scaling
        model_dim = config.d_model
        self.c_attn = nn.Linear(model_dim, 3 * model_dim)
        self.c_proj = nn.Linear(model_dim, model_dim)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq, d_model) to a tensor of the same shape."""
        n_batch, n_tokens, model_dim = x.shape

        # Split the fused projection into Q, K, V and put heads before tokens,
        # giving (batch, heads, seq, d_head) for the batched matmuls below.
        stacked = self.c_attn(x).view(n_batch, n_tokens, 3, self.n_heads, self.d_head)
        queries, keys, values = (t.transpose(1, 2) for t in stacked.unbind(dim=2))

        logits = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale

        # Hide every position that lies after the query position.
        causal = torch.tril(torch.ones(n_tokens, n_tokens, device=x.device)).view(1, 1, n_tokens, n_tokens)
        probs = logits.masked_fill(causal == 0, float('-inf')).softmax(dim=-1)

        context = torch.matmul(probs, values).transpose(1, 2).reshape(n_batch, n_tokens, model_dim)
        return self.c_proj(context)

class OldBlock(nn.Module):
    """Transformer block with LayerNorm before attention and before the GELU MLP."""

    def __init__(self, config):
        super().__init__()
        model_dim = config.d_model
        self.ln1 = nn.LayerNorm(model_dim)
        self.attn = OldAttention(config)
        self.ln2 = nn.LayerNorm(model_dim)
        # Two-layer feed-forward with a 4x hidden expansion.
        self.mlp = nn.Sequential(
            nn.Linear(model_dim, 4 * model_dim),
            nn.GELU(),
            nn.Linear(4 * model_dim, model_dim),
        )

    def forward(self, x):
        # Residual add around each sub-layer.
        residual = x + self.attn(self.ln1(x))
        return residual + self.mlp(self.ln2(residual))

class OldGPT(nn.Module):
    """Stock GPT-2 style model: token + learned position embeddings, tied head.

    ``forward`` yields the training loss only; it returns ``None`` when
    ``targets`` is omitted.
    """

    def __init__(self, config):
        super().__init__()
        model_dim = config.d_model
        self.token_emb = nn.Embedding(config.vocab_size, model_dim)
        self.pos_emb = nn.Embedding(config.max_seq_len, model_dim)
        self.blocks = nn.ModuleList([OldBlock(config) for _ in range(config.n_layers)])
        self.ln_f = nn.LayerNorm(model_dim)
        self.head = nn.Linear(model_dim, config.vocab_size, bias=False)

        # Tie input and output embeddings to share one parameter tensor.
        self.token_emb.weight = self.head.weight

    def forward(self, idx, targets=None):
        _, n_tokens = idx.shape
        pos_ids = torch.arange(n_tokens, device=idx.device)
        h = self.token_emb(idx) + self.pos_emb(pos_ids)
        for blk in self.blocks:
            h = blk(h)
        logits = self.head(self.ln_f(h))
        if targets is None:
            return None
        vocab = logits.size(-1)
        return F.cross_entropy(logits.view(-1, vocab), targets.view(-1))

# --- 4. NEW JANUS (Janus Spec) ---
# RMSNorm, RoPE, SwiGLU, Bias=False, Tied Weights

class RMSNorm(nn.Module):
    """Scale the last axis to unit root-mean-square, then apply a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # gain starts at 1

    def forward(self, x):
        # eps is added under the square root to keep all-zero rows finite.
        mean_sq = torch.mean(x.pow(2), dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps) * self.weight

class NewAttention(nn.Module):
    """Bias-free causal multi-head attention with rotary position embeddings.

    ``config`` supplies ``n_heads``, ``d_head``, ``d_model`` and
    ``max_seq_len``; the rotation table is stored as the ``freqs_cis`` buffer.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        # All four projections are square and carry no bias term.
        model_dim = config.d_model
        self.q_proj = nn.Linear(model_dim, model_dim, bias=False)
        self.k_proj = nn.Linear(model_dim, model_dim, bias=False)
        self.v_proj = nn.Linear(model_dim, model_dim, bias=False)
        self.o_proj = nn.Linear(model_dim, model_dim, bias=False)

        rope_table = self.precompute_freqs_cis(config.d_head, config.max_seq_len)
        self.register_buffer("freqs_cis", rope_table)

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return a complex (end, dim // 2) table of unit-magnitude rotations."""
        exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
        inv_freq = 1.0 / (theta ** exponents)
        positions = torch.arange(end, device=inv_freq.device)
        angles = torch.outer(positions, inv_freq).float()
        return torch.polar(torch.ones_like(angles), angles)

    def apply_rope(self, x, freqs_cis):
        """Rotate feature pairs of ``x`` (batch, seq, heads, d_head) by position."""
        n_batch, n_tokens, n_heads, _ = x.shape
        # (seq, 1, d_head // 2): the singleton axis broadcasts over heads.
        rotation = freqs_cis[:n_tokens].unsqueeze(1)

        # View adjacent feature pairs as complex numbers so rotation is a multiply.
        pairs = torch.view_as_complex(x.float().reshape(n_batch, n_tokens, n_heads, -1, 2))
        rotated = torch.view_as_real(pairs * rotation).flatten(3)
        return rotated.type_as(x)

    def forward(self, x):
        """Return the attention output for ``x`` of shape (batch, seq, d_model)."""
        n_batch, n_tokens, model_dim = x.shape
        split = (n_batch, n_tokens, self.n_heads, self.d_head)
        q = self.q_proj(x).view(*split)
        k = self.k_proj(x).view(*split)
        v = self.v_proj(x).view(*split)

        # Only queries and keys are rotated; values stay untouched.
        q = self.apply_rope(q, self.freqs_cis)
        k = self.apply_rope(k, self.freqs_cis)

        # (batch, heads, seq, d_head) for the batched matmuls.
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        logits = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        causal = torch.tril(torch.ones(n_tokens, n_tokens, device=x.device)).view(1, 1, n_tokens, n_tokens)
        probs = F.softmax(logits.masked_fill(causal == 0, float('-inf')), dim=-1)

        context = torch.matmul(probs, v).transpose(1, 2).reshape(n_batch, n_tokens, model_dim)
        return self.o_proj(context)

class SwiGLU(nn.Module):
    """Bias-free gated MLP computing ``w3(silu(w1(x)) * w2(x))``."""

    def __init__(self, d_model):
        super().__init__()
        # 8/3 expansion: three matrices at this width roughly match the
        # parameter count of a standard two-matrix 4x MLP.
        hidden_dim = int(8 * d_model / 3)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)

    def forward(self, x):
        activated = F.silu(self.w1(x))
        linear = self.w2(x)
        return self.w3(activated * linear)

class NewBlock(nn.Module):
    """Pre-norm block: RMSNorm -> NewAttention and RMSNorm -> SwiGLU, both residual."""

    def __init__(self, config):
        super().__init__()
        model_dim = config.d_model
        self.ln1 = RMSNorm(model_dim)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(model_dim)
        self.mlp = SwiGLU(model_dim)

    def forward(self, x):
        residual = x + self.attn(self.ln1(x))
        return residual + self.mlp(self.ln2(residual))

class NewGPT(nn.Module):
    """Language model with tied embeddings and no position-embedding table.

    ``forward`` returns the cross-entropy loss, or ``None`` without targets.
    """

    def __init__(self, config):
        super().__init__()
        model_dim = config.d_model
        self.token_emb = nn.Embedding(config.vocab_size, model_dim)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(model_dim)
        self.head = nn.Linear(model_dim, config.vocab_size, bias=False)

        # Tie input and output embeddings to share one parameter tensor.
        self.token_emb.weight = self.head.weight

    def forward(self, idx, targets=None):
        _n_batch, _n_tokens = idx.shape  # idx must be 2-D (batch, seq)
        h = self.token_emb(idx)
        for blk in self.blocks:
            h = blk(h)
        logits = self.head(self.ln_f(h))
        if targets is None:
            return None
        vocab = logits.size(-1)
        return F.cross_entropy(logits.view(-1, vocab), targets.view(-1))

# --- 5. LOADER ---
class BinLoader:
    """Draws random fixed-length windows from a memory-mapped uint16 file."""

    def __init__(self, bin_path, block_size, batch_size):
        # Read-only memory map; slices are materialised lazily in get_batch.
        self.data = np.memmap(bin_path, dtype=np.uint16, mode='r')
        self.block_size = block_size
        self.batch_size = batch_size

    def get_batch(self):
        """Return ``(x, y)`` on DEVICE; ``y`` is the window starting one element after ``x``."""
        width = self.block_size
        offsets = torch.randint(len(self.data) - width, (self.batch_size,)).tolist()

        def cut(start):
            # Convert to int64 while copying the slice out of the memmap.
            return torch.from_numpy(self.data[start:start + width].astype(np.int64))

        x = torch.stack([cut(o) for o in offsets])
        y = torch.stack([cut(o + 1) for o in offsets])
        return x.to(DEVICE), y.to(DEVICE)

# --- 6. DUEL ---
def run_sprint(model_class, label, steps=500):
    """Train a fresh ``model_class`` for ``steps`` steps and return the loss per step.

    Memory is flushed before the model is created and again after it is
    deleted, so consecutive sprints do not hold each other's allocations.
    """

    def _flush_memory():
        # Collect Python garbage, release cached CUDA blocks, and wait for the
        # device when one is present.
        gc.collect()
        torch.cuda.empty_cache()
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    _flush_memory()

    print(f"\nüèÉ RUNNING: {label}")

    cfg = DuelConfig()
    model = model_class(cfg).to(DEVICE)
    params = sum(p.numel() for p in model.parameters())/1e6
    print(f"   Params: {params:.2f}M")

    optimizer = optim.AdamW(model.parameters(), lr=6e-4)
    loader = BinLoader(DATA_FILE, 256, 16)

    model.train()
    losses = []
    for _ in tqdm(range(steps), desc=label):
        inputs, targets = loader.get_batch()
        loss = model(inputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    # Drop every local reference before flushing so the memory can be reclaimed.
    del model
    del optimizer
    del loader
    _flush_memory()

    return losses

def main():
    """Train both architectures, save the loss plot, and print which one won.

    The winner is decided on the mean of each run's last 50 recorded losses.
    """
    steps = 500

    hist_old = run_sprint(OldGPT, "Old Janus (Stock)", steps)
    hist_new = run_sprint(NewGPT, "New Janus (Spec)", steps)

    print("\nüìä Generating Report...")
    plt.figure(figsize=(10, 6))
    curves = (
        (hist_old, {'label': 'Old Janus (GPT-2)', 'alpha': 0.7}),
        (hist_new, {'label': 'New Janus (RoPE/RMS/SwiGLU)', 'linewidth': 2}),
    )
    for series, style in curves:
        plt.plot(series, **style)
    plt.title("Architecture Duel: Loss Trajectory (500 Steps)")
    plt.xlabel("Training Steps")
    plt.ylabel("Loss")
    plt.legend()
    plt.grid(True, alpha=0.3)

    save_path = os.path.join(REPORT_DIR, "arch_duel_fixed.png")
    plt.savefig(save_path)
    print(f"üíæ Report saved: {save_path}")

    # Average the tail of each run to smooth out per-batch noise.
    avg_o = np.mean(hist_old[-50:])
    avg_n = np.mean(hist_new[-50:])
    print(f"\nüèÅ FINAL LOSS (Last 50 Steps Average):")
    print(f"   Old Janus: {avg_o:.4f}")
    print(f"   New Janus: {avg_n:.4f}")

    verdict = "üèÜ RESULT: New Spec Wins!" if avg_n < avg_o else "üìâ RESULT: Old School Wins."
    print(verdict)

if __name__ == "__main__":
    main()

In [None]:
# @title
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive only when the mount point does not exist yet.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# All inputs and outputs live under one Drive project folder.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# Token file read by BinLoader (memory-mapped there as uint16).
DATA_FILE = os.path.join(PROJECT_ROOT, "data/processed/TinyStories-train_full.bin")
# Destination for the CSV and chart written by run_v3.
REPORT_DIR = os.path.join(PROJECT_ROOT, "reports")
# Created here; nothing in this cell writes to it.
MODEL_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_v3")
os.makedirs(REPORT_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üöÄ JANUS v3 PREVIEW: The Synthesis")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 2. CONFIGURATION ---
class V3Config:
    """Architecture, training and pressure-schedule settings for the v3 preview run."""

    def __init__(self):
        # --- Architecture ---
        self.vocab_size, self.d_model = 50304, 512
        self.n_layers, self.n_heads, self.d_head = 12, 8, 64
        self.max_seq_len = 256
        self.dropout = 0.0

        # --- Training ---
        self.max_steps = 1000
        self.batch_size = 32
        self.grad_accum = 1

        # --- Scheduler ---
        # Pressure stays at zero for the first `warmup_steps` steps and peaks
        # at `max_pressure`.
        self.warmup_steps = 200
        self.max_pressure = 0.15
        self.spatial_schedule = 'cubic'

# --- 3. ARCHITECTURE (The Winner) ---
class RMSNorm(nn.Module):
    """Normalise the last axis by its root-mean-square and apply a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # gain starts at 1

    def forward(self, x):
        mean_sq = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.weight

class NewAttention(nn.Module):
    """RoPE causal attention that also reports a head-diversity penalty.

    ``forward`` returns ``(projected_output, steer_loss, head_out)``, where
    ``head_out`` is the per-head attention result before the output projection.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        model_dim = config.d_model
        self.q_proj = nn.Linear(model_dim, model_dim, bias=False)
        self.k_proj = nn.Linear(model_dim, model_dim, bias=False)
        self.v_proj = nn.Linear(model_dim, model_dim, bias=False)
        self.o_proj = nn.Linear(model_dim, model_dim, bias=False)

        rope_table = self.precompute_freqs_cis(config.d_head, config.max_seq_len)
        self.register_buffer("freqs_cis", rope_table)

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return a complex (end, dim // 2) table of unit-magnitude rotations."""
        exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
        inv_freq = 1.0 / (theta ** exponents)
        positions = torch.arange(end, device=inv_freq.device)
        angles = torch.outer(positions, inv_freq).float()
        return torch.polar(torch.ones_like(angles), angles)

    def apply_rope(self, x, freqs_cis):
        """Rotate feature pairs of ``x`` (batch, seq, heads, d_head) by position."""
        n_tokens = x.shape[1]
        pairs = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        # Singleton batch and head axes let the table broadcast along seq only.
        rotation = freqs_cis[:n_tokens].view(1, n_tokens, 1, -1)
        rotated = torch.view_as_real(pairs * rotation).flatten(3)
        return rotated.type_as(x)

    def forward(self, x, lambdas):
        """Attend over ``x`` and compute the diversity penalty.

        ``lambdas`` is a 2-tuple ``(coherence, diversity)``; only the diversity
        weight is used here. ``steer_loss`` stays the float ``0.0`` unless the
        diversity weight is positive.
        """
        n_batch, n_tokens, model_dim = x.shape
        split = (n_batch, n_tokens, self.n_heads, self.d_head)
        q = self.q_proj(x).view(*split)
        k = self.k_proj(x).view(*split)
        v = self.v_proj(x).view(*split)

        # Rotary embedding on queries and keys only.
        q = self.apply_rope(q, self.freqs_cis)
        k = self.apply_rope(k, self.freqs_cis)

        # (batch, heads, seq, d_head)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)

        logits = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        causal = torch.tril(torch.ones(n_tokens, n_tokens, device=x.device)).view(1, 1, n_tokens, n_tokens)
        probs = F.softmax(logits.masked_fill(causal == 0, float('-inf')), dim=-1)

        head_out = torch.matmul(probs, v)

        # Diversity pressure: penalise the distance between the heads' cosine
        # Gram matrix and the identity, i.e. push heads towards orthogonality.
        steer_loss = 0.0
        _l_coh, l_div = lambdas
        if l_div > 0.0:
            per_head = head_out.transpose(0, 1).reshape(self.n_heads, -1)
            unit = F.normalize(per_head, p=2, dim=1)
            gram = torch.mm(unit, unit.t())
            eye = torch.eye(self.n_heads, device=x.device)
            steer_loss = steer_loss + torch.norm(gram - eye, p='fro') * l_div

        merged = head_out.transpose(1, 2).reshape(n_batch, n_tokens, model_dim)
        return self.o_proj(merged), steer_loss, head_out

class SwiGLU(nn.Module):
    """Bias-free gated MLP computing ``w3(silu(w1(x)) * w2(x))``."""

    def __init__(self, d_model):
        super().__init__()
        # Hidden width is 8/3 of the model width, truncated to an integer.
        hidden_dim = int(8 * d_model / 3)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewBlock(nn.Module):
    """Pre-norm block that forwards steering weights to its attention layer.

    ``forward`` returns ``(x, steer_loss, head_out)``; the last two values
    come straight from the attention layer.
    """

    def __init__(self, config, layer_id):
        super().__init__()
        # Position of this block in the stack and the stack depth; stored
        # here but not used by forward.
        self.layer_id = layer_id
        self.n_layers = config.n_layers
        model_dim = config.d_model
        self.ln1 = RMSNorm(model_dim)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(model_dim)
        self.mlp = SwiGLU(model_dim)

    def forward(self, x, lambdas):
        # `lambdas` is handed to attention unchanged: this block applies no
        # layer-dependent scaling of its own, so the caller must supply the
        # (coherence, diversity) pair already scaled for this layer.
        attn_out, steer_loss, head_out = self.attn(self.ln1(x), lambdas)
        residual = x + attn_out
        residual = residual + self.mlp(self.ln2(residual))
        return residual, steer_loss, head_out

class NewGPT(nn.Module):
    """Tied-embedding language model that accumulates per-layer steering losses.

    ``forward`` returns ``(loss, total_steer, all_heads)``: the cross-entropy
    loss (``None`` without targets), the sum of every block's steering loss,
    and the list of per-block head outputs.
    """

    def __init__(self, config):
        super().__init__()
        model_dim = config.d_model
        self.token_emb = nn.Embedding(config.vocab_size, model_dim)
        self.blocks = nn.ModuleList([NewBlock(config, i) for i in range(config.n_layers)])
        self.ln_f = RMSNorm(model_dim)
        self.head = nn.Linear(model_dim, config.vocab_size, bias=False)
        # Tied weights: embedding and output head share one parameter tensor.
        self.token_emb.weight = self.head.weight

    def forward(self, idx, lambdas_list, targets=None):
        _n_batch, _n_tokens = idx.shape  # idx must be 2-D (batch, seq)
        hidden = self.token_emb(idx)

        steer_total = 0.0
        head_outputs = []
        # lambdas_list is indexed by block position, one entry per block.
        for layer_idx, layer in enumerate(self.blocks):
            hidden, layer_steer, layer_heads = layer(hidden, lambdas_list[layer_idx])
            steer_total += layer_steer
            head_outputs.append(layer_heads)

        logits = self.head(self.ln_f(hidden))

        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return loss, steer_total, head_outputs

# --- 4. SCHEDULER (Trapezoidal with Delay) ---
class DelayedTrapezoidScheduler:
    """Trapezoidal pressure schedule that starts after a warmup delay.

    Pressure is 0 for the first ``warmup_steps`` steps. The remaining
    ``max_steps - warmup_steps`` "active" steps are split into a linear ramp
    up (25%), a hold at ``max_pressure`` (the remainder), and a linear decay
    (25%). Pressure is 0 again at and beyond ``max_steps``.

    ``config`` must provide ``warmup_steps``, ``max_steps``, ``max_pressure``
    and ``n_layers``.
    """

    def __init__(self, config):
        self.warmup = config.warmup_steps
        self.total = config.max_steps
        self.max_p = config.max_pressure
        self.n_layers = config.n_layers

        # Calculate trapezoid phases for the ACTIVE region
        self.active_steps = self.total - self.warmup
        self.ramp_steps = int(self.active_steps * 0.25)
        self.decay_steps = int(self.active_steps * 0.25)
        # The hold phase absorbs whatever the two truncated quarters leave over.
        self.hold_steps = self.active_steps - self.ramp_steps - self.decay_steps

    def get_pressure(self, step):
        """Return the base pressure for ``step``; never negative."""
        if step < self.warmup:
            return 0.0

        t = step - self.warmup

        if t < self.ramp_steps:
            return self.max_p * (t / self.ramp_steps)
        if t < (self.ramp_steps + self.hold_steps):
            return self.max_p

        # Decay phase. Past the end of the schedule the linear formula would go
        # negative, and with a very short active window decay_steps is 0 and
        # the division would fail, so both cases clamp to zero.
        remaining = self.active_steps - t
        if remaining <= 0 or self.decay_steps <= 0:
            return 0.0
        return self.max_p * (remaining / self.decay_steps)

    def get_layer_lambdas(self, step):
        """Return ``(lambdas, base_div)`` for ``step``.

        ``lambdas`` holds one ``(coherence, diversity)`` pair per layer, scaled
        by ``((layer + 1) / n_layers) ** 3`` so deeper layers get more
        pressure. ``base_div`` is the unscaled pressure.
        """
        base_div = self.get_pressure(step)
        base_coh = base_div * 0.2  # coherence is kept at 20% of diversity

        lambdas = []
        for i in range(self.n_layers):
            # Cubic spatial schedule: the last layer receives the full base values.
            ratio = (i + 1) / self.n_layers
            s_mult = ratio ** 3

            lambdas.append((base_coh * s_mult, base_div * s_mult))
        return lambdas, base_div

# --- 5. LOADER ---
class BinLoader:
    """Samples random fixed-length windows from a memory-mapped uint16 file."""

    def __init__(self, bin_path, block_size, batch_size):
        # Read-only memory map; nothing is copied until get_batch slices it.
        self.data = np.memmap(bin_path, dtype=np.uint16, mode='r')
        self.block_size = block_size
        self.batch_size = batch_size

    def get_batch(self):
        """Return ``(x, y)`` on DEVICE; ``y`` is the window starting one element after ``x``."""
        width = self.block_size
        offsets = torch.randint(len(self.data) - width, (self.batch_size,)).tolist()

        def cut(start):
            # Convert to int64 while copying the slice out of the memmap.
            return torch.from_numpy(self.data[start:start + width].astype(np.int64))

        x = torch.stack([cut(o) for o in offsets])
        y = torch.stack([cut(o + 1) for o in offsets])
        return x.to(DEVICE), y.to(DEVICE)

# --- 6. MAIN ---
def _head_redundancy(head_out, n_heads):
    """Mean absolute off-diagonal cosine similarity between attention heads.

    ``head_out`` has heads on axis 1. Each head is flattened to one vector,
    L2-normalised, and compared with every other head. Returns 0.0 when there
    are fewer than two heads, because no pair exists to compare.
    """
    if n_heads < 2:
        return 0.0
    flat = head_out.transpose(0, 1).reshape(n_heads, -1)
    unit = F.normalize(flat, p=2, dim=1)
    gram = torch.mm(unit, unit.t())
    off_diag = ~torch.eye(n_heads, dtype=torch.bool, device=gram.device)
    # n_heads * (n_heads - 1) ordered pairs lie off the diagonal.
    return (gram.abs() * off_diag.float()).sum().item() / (n_heads * (n_heads - 1))


def run_v3():
    """Train the v3 model under the delayed trapezoid schedule and write a report.

    Every 20 steps the loss, the current base pressure and the top layer's
    head redundancy are recorded. The history is saved as a CSV and a
    loss/pressure chart under ``REPORT_DIR``.
    """
    gc.collect(); torch.cuda.empty_cache()

    cfg = V3Config()
    model = NewGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6e-4)
    scheduler = DelayedTrapezoidScheduler(cfg)
    loader = BinLoader(DATA_FILE, cfg.max_seq_len, cfg.batch_size)

    history = []

    model.train()
    pbar = tqdm(range(cfg.max_steps), desc="Janus v3")

    for step in pbar:
        x, y = loader.get_batch()

        # Per-layer (coherence, diversity) weights plus the unscaled pressure.
        layer_lambdas, current_p = scheduler.get_layer_lambdas(step)

        loss, steer, heads = model(x, layer_lambdas, y)
        # The steering penalty is optimised jointly with the language-model loss.
        total = loss + steer

        optimizer.zero_grad(); total.backward(); optimizer.step()

        # Quick telemetry on the top layer. The head count comes from the
        # config, so this stays correct if n_heads changes.
        if step % 20 == 0:
            with torch.no_grad():
                red = _head_redundancy(heads[-1], cfg.n_heads)

            loss_val = loss.item()
            history.append({
                'step': step,
                'loss': loss_val,
                'pressure': current_p,
                'red_L11': red
            })
            pbar.set_description(f"L:{loss_val:.3f} | P:{current_p:.3f} | R:{red:.3f}")

    # Save Report
    df = pd.DataFrame(history)
    df.to_csv(os.path.join(REPORT_DIR, "janus_v3_preview.csv"), index=False)

    # Plot loss (left axis) against pressure (right axis).
    fig, ax1 = plt.subplots(figsize=(10, 6))
    ax1.set_xlabel('Step')
    ax1.set_ylabel('Loss', color='tab:red')
    ax1.plot(df['step'], df['loss'], color='tab:red', alpha=0.6, label='Loss')
    ax1.tick_params(axis='y', labelcolor='tab:red')

    ax2 = ax1.twinx()
    ax2.set_ylabel('Pressure', color='tab:blue')
    ax2.plot(df['step'], df['pressure'], color='tab:blue', linestyle='--', label='Pressure')
    ax2.tick_params(axis='y', labelcolor='tab:blue')

    plt.title("Janus v3 Preview: The Synthesis")
    plt.savefig(os.path.join(REPORT_DIR, "janus_v3_chart.png"))

    print(f"\nüèÅ FINAL LOSS: {df.iloc[-1]['loss']:.4f}")
    print(f"üèÅ FINAL REDUNDANCY (L11): {df.iloc[-1]['red_L11']:.4f}")

if __name__ == "__main__":
    run_v3()

In [None]:
# @title
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import pandas as pd
import gc
import shutil
from tqdm import tqdm
from google.colab import drive

# --- 1. MEMORY CLEANSE ---
# Free Python garbage and cached CUDA blocks left over from earlier cells.
gc.collect()
torch.cuda.empty_cache()

# --- 2. SETUP ---
# Mount Google Drive only when the mount point does not exist yet.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# All inputs and outputs live under one Drive project folder.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# Training token file.
DATA_FILE = os.path.join(PROJECT_ROOT, "data/processed/TinyStories-train_full.bin")
# Output folder for this run; created if missing.
SAVE_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_hero_v3")
os.makedirs(SAVE_DIR, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üöÄ JANUS HERO v3 (FIXED): 512 Context / Batch 8 / Accum 8")

# --- 3. CONFIGURATION ---
class V3Config:
    """Architecture, schedule and batching settings for the hero v3 run."""

    def __init__(self):
        # --- Architecture ---
        self.vocab_size, self.d_model = 50304, 512
        self.n_layers, self.n_heads, self.d_head = 12, 8, 64
        self.max_seq_len = 512
        self.dropout = 0.05

        # --- Run length and peak steering pressure ---
        self.max_steps = 5000
        self.max_pressure = 0.15

        # --- Batching: 8 sequences x 8 accumulation steps = effective batch 64 ---
        self.batch_size = 8
        self.grad_accum = 8
        # --- Validation cadence settings ---
        self.val_interval = 250
        self.val_steps = 20

# --- 4. ARCHITECTURE (Janus Spec) ---
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # eps is added to the mean square before the reciprocal square root.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(silu(w1(x)) * w2(x))``, all bias-free."""

    def __init__(self, d_model):
        super().__init__()
        inner_width = int(d_model * 8 / 3)
        self.w1 = nn.Linear(d_model, inner_width, bias=False)
        self.w2 = nn.Linear(d_model, inner_width, bias=False)
        self.w3 = nn.Linear(inner_width, d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings (RoPE).

    In addition to the attention output, ``forward`` returns a head-diversity
    penalty (``steer_loss``) and, on request, per-head diagnostics
    (``sigma_a`` and ``eff_rank``).
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.dropout = nn.Dropout(config.dropout)
        self.register_buffer("freqs_cis", self.precompute_freqs_cis(config.d_head, config.max_seq_len))

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return the complex rotation table of shape ``(end, dim // 2)``."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    def apply_rope(self, x, freqs_cis):
        """Rotate ``x`` (B, S, H, d_head) by the first S rows of ``freqs_cis``."""
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs = freqs_cis[:x.shape[1]].view(1, x.shape[1], 1, -1)
        x_out = torch.view_as_real(x_c * freqs).flatten(3)
        return x_out.type_as(x)

    def _head_metrics(self, head_out):
        """Compute per-head diagnostics from ``head_out`` of shape (B, H, S, d_head).

        Returns a dict with:
          * ``sigma_a``  - mean absolute cosine similarity of each head's
            flattened output to every other head;
          * ``eff_rank`` - exp(entropy) of the normalised singular values of
            each head's output over the first 128 positions.
        """
        device = head_out.device
        # Bring the head axis to the front so that every reshape below keeps
        # the rows of one head together.
        per_head = head_out.transpose(0, 1)  # (H, B, S, d_head)

        flat = per_head.reshape(self.n_heads, -1)
        norm = F.normalize(flat, p=2, dim=1)
        sim = torch.mm(norm, norm.t())
        mask_diag = ~torch.eye(self.n_heads, dtype=torch.bool, device=device)
        # max(..., 1) avoids a 0/0 when the layer has a single head.
        sigma_a = (sim.abs() * mask_diag.float()).sum(dim=1) / max(self.n_heads - 1, 1)

        # The original reshaped a (B, S, H, d) tensor straight to (H, -1, d),
        # which mixed rows of different heads into each "head" slice.
        sub_out = per_head[:, :, :128, :].reshape(self.n_heads, -1, self.d_head)
        ranks = []
        for h in range(self.n_heads):
            try:
                S_vals = torch.linalg.svdvals(sub_out[h].float())
                p = S_vals / S_vals.sum()
                ent = -torch.sum(p * torch.log(p + 1e-9))
                ranks.append(torch.exp(ent))
            except RuntimeError:
                # SVD did not converge: report 0, on the same device as the
                # other entries so that torch.stack below does not fail.
                ranks.append(torch.zeros((), device=device))
        return {'sigma_a': sigma_a, 'eff_rank': torch.stack(ranks).to(device)}

    def forward(self, x, lambdas, return_metrics=False):
        """Run attention on ``x`` of shape (B, S, D).

        Args:
            x: input activations, (B, S, D).
            lambdas: ``(l_coh, l_div)`` pair; only ``l_div`` is used here
                (``l_coh`` is unpacked but no coherence term is computed).
            return_metrics: when True, also compute the diagnostics dict.

        Returns:
            ``(output, steer_loss, metrics)`` where ``steer_loss`` is ``0.0``
            unless ``l_div > 0`` in training mode, and ``metrics`` is ``{}``
            unless ``return_metrics`` is set.
        """
        B, S, D = x.shape
        q = self.q_proj(x).view(B, S, self.n_heads, self.d_head)
        k = self.k_proj(x).view(B, S, self.n_heads, self.d_head)
        v = self.v_proj(x).view(B, S, self.n_heads, self.d_head)

        q = self.apply_rope(q, self.freqs_cis)
        k = self.apply_rope(k, self.freqs_cis)
        q = q.transpose(1, 2); k = k.transpose(1, 2); v = v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Lower-triangular mask: position i may only attend to positions <= i.
        mask = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        attn = attn.masked_fill(mask == 0, float('-inf'))
        attn_probs = F.softmax(attn, dim=-1)
        attn_probs = self.dropout(attn_probs)
        head_out = attn_probs @ v

        steer_loss = 0.0
        metrics = {}
        l_coh, l_div = lambdas

        if l_div > 0.0 and self.training:
            # Penalise the distance between the head Gram matrix and identity,
            # i.e. push the flattened head outputs towards orthogonality.
            flat = head_out.transpose(0, 1).reshape(self.n_heads, -1)
            norm = F.normalize(flat, p=2, dim=1)
            gram = torch.mm(norm, norm.t())
            identity = torch.eye(self.n_heads, device=x.device)
            steer_loss += torch.norm(gram - identity, p='fro') * l_div

        if return_metrics:
            with torch.no_grad():
                metrics = self._head_metrics(head_out)

        out = head_out.transpose(1, 2).reshape(B, S, D)
        return self.o_proj(out), steer_loss, metrics

class NewBlock(nn.Module):
    """Pre-norm transformer block: attention, then SwiGLU MLP, each on a residual path."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = RMSNorm(config.d_model)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(config.d_model)
        self.mlp = SwiGLU(config.d_model)

    def forward(self, x, lambdas, return_metrics=False):
        # Returns (hidden, steer_loss, metrics); the last two come from the attention layer.
        attn_out, steer_loss, metrics = self.attn(self.ln1(x), lambdas, return_metrics)
        hidden = x + attn_out
        hidden = hidden + self.mlp(self.ln2(hidden))
        return hidden, steer_loss, metrics

class NewGPT(nn.Module):
    """Decoder-only language model: tied embedding/head around a stack of NewBlock layers."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the embedding table and the output projection share one tensor.
        self.token_emb.weight = self.head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one module; invoked by ``self.apply`` on every sub-module.

        Linear and Embedding weights are drawn from N(0, 0.02). Any parameter
        reachable from ``module`` whose name contains ``o_proj.weight`` or
        ``w3.weight`` is then re-drawn with std 0.02 / sqrt(2 * n_layers).
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

        # Scale residuals to depth
        residual_std = 0.02 / math.sqrt(2 * self.config.n_layers)
        for param_name, param in module.named_parameters():
            is_residual_out = "o_proj.weight" in param_name or "w3.weight" in param_name
            if is_residual_out:
                torch.nn.init.normal_(param, mean=0.0, std=residual_std)

    def forward(self, idx, lambdas_list, targets=None, return_metrics=False):
        """Return ``(loss, total_steer, all_metrics)``.

        ``loss`` is the cross-entropy against ``targets`` (``None`` when no
        targets are given), ``total_steer`` sums the per-block steering
        losses, and ``all_metrics`` holds one dict per block when
        ``return_metrics`` is set (otherwise it stays empty).
        """
        _batch, _seq = idx.shape  # idx must be two-dimensional
        hidden = self.token_emb(idx)

        total_steer = 0.0
        all_metrics = []

        for layer_idx, block in enumerate(self.blocks):
            hidden, steer, block_metrics = block(hidden, lambdas_list[layer_idx], return_metrics)
            total_steer += steer
            if return_metrics:
                all_metrics.append(block_metrics)

        logits = self.head(self.ln_f(hidden))

        if targets is None:
            loss = None
        else:
            vocab = logits.size(-1)
            loss = F.cross_entropy(logits.view(-1, vocab), targets.view(-1))

        return loss, total_steer, all_metrics

# --- 5. SCHEDULER & LOADERS ---
class FlightController:
    """Schedule for the steering "pressure" and the per-layer lambda pairs.

    The schedule has four phases expressed as fractions of
    ``config.max_steps`` (falling back to 5000 when the config has no such
    attribute, which reproduces the original hard-coded milestones
    1250 / 2500 / 4000 / 5000):

      * first quarter        - pressure 0;
      * up to the half-way   - linear ramp from 0 to ``max_pressure``;
      * up to four fifths    - hold at ``max_pressure``;
      * remainder            - linear decay to 0 at ``max_steps``.
    """

    def __init__(self, config):
        self.config = config

    def get_pressure(self, step):
        """Return the pressure for ``step``; never negative."""
        total = getattr(self.config, "max_steps", 5000)
        warm_end = total // 4
        ramp_end = total // 2
        hold_end = (total * 4) // 5
        peak = self.config.max_pressure

        if step < warm_end:
            return 0.0
        if step < ramp_end:
            return peak * ((step - warm_end) / (ramp_end - warm_end))
        if step < hold_end:
            return peak
        # Clamp so steps past the end of the run do not yield negative pressure.
        cooldown = max(total - hold_end, 1)
        return peak * max(0.0, (total - step) / cooldown)

    def get_lambdas(self, step):
        """Return ``(lambdas, pressure)`` for ``step``.

        ``lambdas`` has one ``(coherence, diversity)`` tuple per layer. Both
        values are scaled by ``((layer + 1) / n_layers) ** 3``, so deeper
        layers receive more pressure; coherence is 0.2x the diversity value.
        """
        p = self.get_pressure(step)
        base_coh = p * 0.2
        lambdas = []
        for i in range(self.config.n_layers):
            ratio = (i + 1) / self.config.n_layers
            s_mult = ratio ** 3
            lambdas.append((base_coh * s_mult, p * s_mult))
        return lambdas, p

class BinLoader:
    """Random-window batch sampler over a flat uint16 token file.

    The file is memory-mapped; the first 95% of tokens form the training
    split and the remainder the validation split.
    """

    def __init__(self, bin_path, block_size, batch_size):
        self.data = np.memmap(bin_path, dtype=np.uint16, mode='r')
        boundary = int(len(self.data) * 0.95)
        self.train_data = self.data[:boundary]
        self.val_data = self.data[boundary:]
        self.block_size = block_size
        self.batch_size = batch_size
        print(f"üì¶ Data Split | Train: {len(self.train_data):,} | Val: {len(self.val_data):,}")

    def get_batch(self, split='train'):
        """Return ``(x, y)`` int64 tensors on DEVICE; ``y`` is ``x`` shifted by one token."""
        source = self.train_data if split == 'train' else self.val_data
        width = self.block_size
        offsets = torch.randint(len(source) - width, (self.batch_size,)).tolist()

        def gather(shift):
            # One window of `width` tokens per sampled offset, stacked into a batch.
            rows = [
                torch.from_numpy(source[o + shift:o + shift + width].astype(np.int64))
                for o in offsets
            ]
            return torch.stack(rows)

        x = gather(0)
        y = gather(1)
        return x.to(DEVICE), y.to(DEVICE)

class BlackBox:
    """Buffers telemetry rows in memory and appends them to a parquet file."""

    def __init__(self, save_dir):
        self.buffer = []
        self.save_dir = save_dir

    def log(self, step, loss, val_loss, pressure, metrics):
        """Append one telemetry row to the in-memory buffer.

        ``metrics`` is a per-layer list of dicts with ``sigma_a`` and
        ``eff_rank`` tensors; empty entries are skipped. Perplexity is
        recorded as 0.0 when ``val_loss`` is not below 20.
        """
        row = {
            "step": step, "loss": loss, "val_loss": val_loss,
            "perplexity": math.exp(val_loss) if val_loss < 20 else 0.0,
            "pressure": pressure
        }
        for i, m in enumerate(metrics):
            if not m: continue
            row[f"L{i}_sigma_a"] = m['sigma_a'].mean().item()
            row[f"L{i}_eff_rank"] = m['eff_rank'].mean().item()
        self.buffer.append(row)

    def flush(self):
        """Write buffered rows to ``telemetry_v3.parquet`` and clear the buffer.

        Rows are appended to an existing file when it can be read. If the
        existing file cannot be read or merged, a warning is printed and the
        file is overwritten with the buffered rows only (best effort).
        """
        if not self.buffer: return
        df = pd.DataFrame(self.buffer)
        fpath = os.path.join(self.save_dir, "telemetry_v3.parquet")
        if os.path.exists(fpath):
            try:
                existing = pd.read_parquet(fpath)
                df = pd.concat([existing, df])
            except Exception as exc:
                # Keep the run alive, but make the loss of history visible
                # instead of silently discarding it.
                print(f"WARNING: could not merge existing telemetry ({exc}); overwriting {fpath}.")
        df.to_parquet(fpath)
        self.buffer = []
        print("üíæ Telemetry Flushed.")

# --- 6. MAIN ---
def run_v3():
    """Train Janus v3 from scratch with gradient accumulation.

    Builds the model, optimizer, pressure scheduler, data loader and
    telemetry recorder from ``V3Config``; removes any existing
    ``ckpt_latest.pt`` so the run always starts at step 0; then trains for
    ``cfg.max_steps`` optimizer steps. Every ``cfg.val_interval`` steps (and
    on the last step) it evaluates on the validation split, logs telemetry
    and writes a checkpoint. The final weights are saved as
    ``janus_v3_final.pt``.
    """
    gc.collect()
    torch.cuda.empty_cache()

    cfg = V3Config()
    model = NewGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6e-4, weight_decay=0.1)
    scheduler = FlightController(cfg)
    loader = BinLoader(DATA_FILE, cfg.max_seq_len, cfg.batch_size)
    recorder = BlackBox(SAVE_DIR)

    start_step = 0
    ckpt_path = os.path.join(SAVE_DIR, "ckpt_latest.pt")

    # SAFETY: an existing checkpoint is deliberately discarded so the run
    # restarts from step 0 (remove this block to implement resuming).
    if os.path.exists(ckpt_path):
        try:
            # Probe readability on the CPU and keep no reference to the result:
            # loading onto DEVICE into a long-lived local would pin a full copy
            # of model + optimizer state in GPU memory for the entire run.
            torch.load(ckpt_path, map_location="cpu")
            print("‚ö†Ô∏è CHECKPOINT FOUND. DELETING TO START FRESH (PER USER REQUEST).")
            os.remove(ckpt_path)
            start_step = 0
        except Exception:
            print("‚ö†Ô∏è Corrupt checkpoint found. Starting fresh.")
            start_step = 0

    print(f"\nüèÉ STARTING RUN: {start_step} -> {cfg.max_steps}")
    pbar = tqdm(range(start_step, cfg.max_steps), initial=start_step, total=cfg.max_steps)

    for step in pbar:
        model.train()
        batch_loss = 0.0
        lambdas, pressure = scheduler.get_lambdas(step)
        is_log_step = step % cfg.val_interval == 0 or step == cfg.max_steps - 1

        # Gradients are zeroed once per optimizer step, not per micro-batch.
        optimizer.zero_grad()

        for micro_step in range(cfg.grad_accum):
            x, y = loader.get_batch('train')
            # Head metrics are only computed on the last micro-batch of a logging step.
            do_metrics = (micro_step == cfg.grad_accum - 1) and is_log_step

            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                loss, steer, metrics = model(x, lambdas, y, return_metrics=do_metrics)
                # Scale so the accumulated gradient is the mean over micro-batches.
                total = (loss + steer) / cfg.grad_accum

            total.backward()
            batch_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # LOGGING
        if is_log_step:
            gc.collect()
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                v_losses = []
                for _ in range(cfg.val_steps):
                    vx, vy = loader.get_batch('val')
                    # Validation runs with zero steering pressure on every layer.
                    vl, _, _ = model(vx, [(0.0,0.0)]*cfg.n_layers, vy, return_metrics=False)
                    v_losses.append(vl.item())
                val_loss = np.mean(v_losses)

            recorder.log(step, batch_loss/cfg.grad_accum, val_loss, pressure, metrics)
            recorder.flush()
            torch.save({'step': step, 'model': model.state_dict(), 'optim': optimizer.state_dict()}, ckpt_path)

            pbar.set_description(f"L:{batch_loss/cfg.grad_accum:.3f}|V:{val_loss:.3f}|P:{pressure:.3f}")
        else:
            pbar.set_description(f"L:{batch_loss/cfg.grad_accum:.3f}|P:{pressure:.3f}")

    torch.save(model.state_dict(), os.path.join(SAVE_DIR, "janus_v3_final.pt"))
    print("\nüèÜ MISSION COMPLETE.")

if __name__ == "__main__":
    run_v3()

In [None]:
# @title
import torch
import os
from google.colab import drive

# 1. SETUP
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

MODEL_DIR = "/content/drive/MyDrive/Project_XAI_Physical_Janus/data/models/janus_hero_v3"

print(f"üïµÔ∏è JANUS BLACK BOX: Inspecting {MODEL_DIR}...\n")

# 2. REPORT: classify every .pt file in MODEL_DIR as resumable checkpoint,
# plain dict, or raw weights.
if os.path.exists(MODEL_DIR):
    files = [f for f in os.listdir(MODEL_DIR) if f.endswith(".pt")]

    if not files:
        print("‚ö†Ô∏è No .pt model files found.")

    for f in files:
        path = os.path.join(MODEL_DIR, f)
        print(f"üìÑ Found: {f}")
        try:
            # The whole file is deserialised; mapping to CPU keeps it off the GPU.
            checkpoint = torch.load(path, map_location='cpu')

            if not isinstance(checkpoint, dict):
                print("   ‚ö†Ô∏è STATUS: Raw model weights (No metadata).")
            elif 'step' not in checkpoint:
                print("   ‚ö†Ô∏è STATUS: Dict found, but no 'step' key. (Likely a final weight dump, not a checkpoint)")
            else:
                step = checkpoint['step']
                print(f"   ‚úÖ STATUS: Recoverable")
                print(f"   üî¢ STEP COUNT: {step}")

                # A positive step count means training had progressed before the save.
                print(
                    "   üöÄ VERDICT: You can resume from this!"
                    if step > 0
                    else "   ‚ö†Ô∏è VERDICT: This is a fresh/empty init."
                )

        except Exception as e:
            print(f"   ‚ùå ERROR: File corrupted or unreadable. ({e})")
        print("-" * 40)
else:
    print("‚ùå Directory not found.")

In [None]:
# @title
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import GPT2Tokenizer
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive only if the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Final weights written at the end of the v3 training cell.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
MODEL_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_hero_v3/janus_v3_final.pt")

# Use the GPU when one is visible, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üîÆ JANUS v3: THE ORACLE")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 2. CONFIGURATION ---
class V3Config:
    """Model shape used for inference; dropout is disabled (0.0)."""

    def __init__(self):
        shape = {
            "vocab_size": 50304,
            "d_model": 512,
            "n_layers": 12,
            "n_heads": 8,
            "d_head": 64,
            "max_seq_len": 512,
            "dropout": 0.0,
        }
        for key, value in shape.items():
            setattr(self, key, value)

# --- 3. ARCHITECTURE (Must Match Training) ---
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # eps is added to the mean square before the reciprocal square root.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(silu(w1(x)) * w2(x))``, all bias-free."""

    def __init__(self, d_model):
        super().__init__()
        inner_width = int(d_model * 8 / 3)
        self.w1 = nn.Linear(d_model, inner_width, bias=False)
        self.w2 = nn.Linear(d_model, inner_width, bias=False)
        self.w3 = nn.Linear(inner_width, d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Inference-only causal multi-head self-attention with rotary embeddings."""

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        rope_table = self.precompute_freqs_cis(config.d_head, config.max_seq_len)
        self.register_buffer("freqs_cis", rope_table)

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return the complex rotation table of shape ``(end, dim // 2)``."""
        exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
        inv_freq = 1.0 / (theta ** exponents)
        positions = torch.arange(end, device=inv_freq.device)
        angles = torch.outer(positions, inv_freq).float()
        # Unit-magnitude complex numbers exp(i * angle).
        return torch.polar(torch.ones_like(angles), angles)

    def apply_rope(self, x, freqs_cis):
        """Rotate ``x`` (B, S, H, d_head) by the first S rows of ``freqs_cis``."""
        seq_len = x.shape[1]
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        as_complex = torch.view_as_complex(pairs)
        rotation = freqs_cis[:seq_len].view(1, seq_len, 1, -1)
        rotated = torch.view_as_real(as_complex * rotation).flatten(3)
        return rotated.type_as(x)

    def forward(self, x):
        B, S, D = x.shape

        def split_heads(proj):
            return proj(x).view(B, S, self.n_heads, self.d_head)

        # RoPE is applied to queries and keys only; values are left untouched.
        q = self.apply_rope(split_heads(self.q_proj), self.freqs_cis).transpose(1, 2)
        k = self.apply_rope(split_heads(self.k_proj), self.freqs_cis).transpose(1, 2)
        v = split_heads(self.v_proj).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        # Lower-triangular mask: position i may only attend to positions <= i.
        causal = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        scores = scores.masked_fill(causal == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        merged = (weights @ v).transpose(1, 2).reshape(B, S, D)
        return self.o_proj(merged)

class NewBlock(nn.Module):
    """Pre-norm transformer block: attention, then SwiGLU MLP, each on a residual path."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = RMSNorm(config.d_model)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(config.d_model)
        self.mlp = SwiGLU(config.d_model)

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))

class NewGPT(nn.Module):
    """Inference version of the Janus v3 language model (logits only, plus sampling)."""

    def __init__(self, config):
        super().__init__()
        # Plain attribute (not a parameter or buffer), so the state dict is unchanged.
        self.max_seq_len = config.max_seq_len
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the embedding table and the output projection share one tensor.
        self.token_emb.weight = self.head.weight

    def forward(self, idx):
        """Return logits of shape (B, S, vocab_size) for token ids ``idx`` of shape (B, S)."""
        B, S = idx.shape
        x = self.token_emb(idx)
        for block in self.blocks: x = block(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.7, top_k=50):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: prompt token ids, shape (B, S).
            max_new_tokens: number of tokens to append.
            temperature: logit divisor; must be positive.
            top_k: keep only the ``top_k`` most likely tokens (``None`` disables).

        Returns:
            Tensor of shape (B, S + max_new_tokens).

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be > 0")
        # Crop to the configured context window rather than a hard-coded 512,
        # so models built with a different max_seq_len stay within the RoPE table.
        context = getattr(self, "max_seq_len", 512)
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -context:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Everything below the k-th largest logit is excluded from sampling.
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

# --- 4. EXECUTION ---
try:
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
except Exception:
    # Best-effort recovery: (re)install transformers and retry once.
    # `except Exception` (not a bare except) lets Ctrl-C / SystemExit through.
    print("‚ùå Installing transformers...")
    os.system('pip install transformers')
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

cfg = V3Config()
model = NewGPT(cfg).to(DEVICE)

print(f"üîÑ Loading Weights from {MODEL_PATH}...")
if os.path.exists(MODEL_PATH):
    try:
        # Load weights
        state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
        model.load_state_dict(state_dict)
        print("‚úÖ Weights Loaded Successfully!")
    except Exception as e:
        print(f"‚ùå Load Failed: {e}")
        sys.exit(1)
else:
    print("‚ùå Model file missing.")
    sys.exit(1)

# Inference mode for sampling.
model.eval()

# --- 5. THE PROBE ---
prompts = [
    "Once upon a time",
    "Lily wanted a",
    "The big red ball",
    "Tom went to the",
    "One day, a little"
]

print("\n" + "="*40)
print("üß™ JANUS v3 (New Arch) OUTPUTS")
print("="*40)

# Sample a 100-token continuation for each prompt and print it.
for p in prompts:
    input_ids = tokenizer.encode(p, return_tensors='pt').to(DEVICE)
    output_ids = model.generate(input_ids, max_new_tokens=100, temperature=0.6, top_k=40)
    text = tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)
    print(f"\nüìù PROMPT: {p}")
    print(f"ü§ñ JANUS: {text}")

In [None]:
# @title
import sys
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from google.colab import drive

# --- SETUP ---
# Mount Google Drive only if the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Telemetry files of the two runs being compared, and where the chart is written.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
V2_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_hero_v2/telemetry_hero.parquet")
V3_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_hero_v3/telemetry_v3.parquet")
REPORT_DIR = os.path.join(PROJECT_ROOT, "reports")
os.makedirs(REPORT_DIR, exist_ok=True)

print("‚öñÔ∏è JANUS TITLE FIGHT: v2 vs. v3")

def load_clean_data(path, label):
    """Load one telemetry parquet file and normalise it for plotting.

    Args:
        path: parquet file to read.
        label: value stored in the added ``Model`` column.

    Returns:
        A DataFrame sorted by ``step`` with one row per step (the row written
        last wins), a ``Model`` column, and ``sigma_a_avg`` / ``eff_rank_avg``
        columns (mean over the per-layer columns, or 0.0 when there are
        none). Returns ``None`` when ``path`` does not exist.
    """
    if not os.path.exists(path):
        print(f"‚ùå Missing: {path}")
        return None

    df = pd.read_parquet(path)

    # Restarted runs append rows with step counts that were already logged.
    # A *stable* sort keeps rows with equal steps in file order, so
    # keep='last' reliably selects the most recently written row.
    df = df.sort_values('step', kind='mergesort')
    df = df.drop_duplicates(subset='step', keep='last')

    df['Model'] = label

    # Calculate averages across the per-layer columns if they are missing.
    for metric in ('sigma_a', 'eff_rank'):
        avg_col = f"{metric}_avg"
        if avg_col in df.columns:
            continue
        layer_cols = [c for c in df.columns if metric in c and 'L' in c]
        df[avg_col] = df[layer_cols].mean(axis=1) if layer_cols else 0.0

    return df

def run_analysis():
    """Plot v2 vs. v3 telemetry and print a final-row scorecard.

    Loads both telemetry files, draws validation loss, head redundancy and
    the pressure schedule on three stacked axes, saves the figure to
    ``REPORT_DIR/janus_v2_vs_v3.png`` and prints the last logged values of
    each run. Returns without doing anything if either file is missing.
    """
    v2 = load_clean_data(V2_PATH, "v2 (Old Arch)")
    v3 = load_clean_data(V3_PATH, "v3 (Spec Arch)")

    if v2 is None or v3 is None: return

    # Combine
    combined = pd.concat([v2, v3])

    # PLOTTING
    fig, axes = plt.subplots(3, 1, figsize=(10, 15), sharex=True)
    palette = {'v2 (Old Arch)': 'gray', 'v3 (Spec Arch)': 'tab:red'}

    # 1. Validation Loss
    sns.lineplot(data=combined, x='step', y='val_loss', hue='Model', ax=axes[0], palette=palette, linewidth=2)
    axes[0].set_title("Validation Loss (Intelligence)")
    axes[0].set_ylabel("Cross Entropy")
    axes[0].grid(True, alpha=0.3)

    # 2. Redundancy (Sigma A)
    sns.lineplot(data=combined, x='step', y='sigma_a_avg', hue='Model', ax=axes[1], palette=palette, linewidth=2)
    axes[1].set_title("Head Redundancy (Uniqueness)")
    axes[1].set_ylabel("Correlation (Lower is Better)")
    axes[1].grid(True, alpha=0.3)
    axes[1].set_ylim(0, 0.1) # Zoom in on the low end

    # 3. Pressure Schedule
    sns.lineplot(data=combined, x='step', y='pressure', hue='Model', ax=axes[2], palette=palette, linestyle='--')
    axes[2].set_title("Pressure Schedule (Force Applied)")
    axes[2].set_ylabel("Lambda Div")
    axes[2].grid(True, alpha=0.3)

    save_path = os.path.join(REPORT_DIR, "janus_v2_vs_v3.png")
    plt.tight_layout()
    plt.savefig(save_path)
    print(f"\nüìä Chart Saved: {save_path}")

    # STATS CARD
    print("\n" + "="*40)
    print("üèÜ FINAL SCORECARD (Step 5000)")
    print("="*40)

    # The frames are sorted by step, so the last row is the final logged step.
    f2 = v2.iloc[-1]
    f3 = v3.iloc[-1]

    # Percent improvement of v3 over v2 (positive = v3 has the lower loss).
    # Guard the division so a zero baseline loss does not raise or print inf.
    if f2['val_loss'] != 0:
        loss_imp = ((f2['val_loss'] - f3['val_loss']) / f2['val_loss']) * 100
    else:
        loss_imp = float('nan')

    print(f"LOSS (Lower Wins):")
    print(f"   v2: {f2['val_loss']:.4f}")
    # The sign comes from the format spec, so a regression prints "-3.2%"
    # instead of the malformed "+-3.2%".
    print(f"   v3: {f3['val_loss']:.4f}  (Improvement: {loss_imp:+.1f}%)")
    print("-" * 20)
    print(f"REDUNDANCY (Target ~0.003):")
    print(f"   v2: {f2['sigma_a_avg']:.4f}")
    print(f"   v3: {f3['sigma_a_avg']:.4f}")
    print("-" * 20)
    print(f"PRESSURE REQUIRED:")
    print(f"   v2: {f2['pressure']:.2f}")
    print(f"   v3: {f3['pressure']:.2f}")

if __name__ == "__main__":
    run_analysis()

In [None]:
# @title
import os
import numpy as np
from tqdm import tqdm
from transformers import GPT2Tokenizer
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive only if the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Token files produced by this cell are written under DATA_DIR.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
DATA_DIR = os.path.join(PROJECT_ROOT, "data/wikitext")
os.makedirs(DATA_DIR, exist_ok=True)

print(f"üìö WIKITEXT-103 PREP (AUTHENTICATED & SANITIZED)")

# --- 2. AUTHENTICATION ---
# Optional Hugging Face token, read from a local text file if it exists.
TOKEN_FILE = "/content/hf_token.txt"

# Install huggingface_hub on demand when the import fails.
try:
    from huggingface_hub import login
except ImportError:
    print("üì¶ Installing huggingface_hub...")
    os.system('pip install huggingface_hub')
    from huggingface_hub import login

if os.path.exists(TOKEN_FILE):
    with open(TOKEN_FILE, 'r', encoding='utf-8') as f:
        # Strip a leading BOM and surrounding whitespace from the token.
        token = f.read().replace('\ufeff', '').strip()

    # Only the token length is printed, never the token itself.
    print(f"üîë Token found (length: {len(token)}). Logging in...")
    try:
        login(token=token)
        print("‚úÖ Authenticated successfully.")
    except Exception as e:
        print(f"‚ùå Login Failed: {e}")
        # Continue anyway, WikiText might be public enough to not need it
else:
    print(f"‚ö†Ô∏è  WARNING: '{TOKEN_FILE}' not found. Attempting anonymous...")

# --- 3. LOAD DATASET ---
# Install the datasets library on demand when the import fails.
try:
    from datasets import load_dataset
except ImportError:
    print("üì¶ Installing datasets library...")
    os.system('pip install datasets')
    from datasets import load_dataset

print("‚¨áÔ∏è  Fetching WikiText-103 via Hugging Face...")
# Using 'wikitext-103-raw-v1'
dataset = load_dataset("wikitext", "wikitext-103-raw-v1")

# --- 4. PROCESSING ---
# GPT-2 tokenizer used to encode every split.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

def process_split(hf_split, bin_name):
    """Tokenise one dataset split and write it as a flat uint16 token file.

    Non-empty lines are joined with the tokenizer's EOS token, encoded in
    chunks of 1000 lines and streamed to ``DATA_DIR/bin_name``. An existing
    output file is left untouched.

    Args:
        hf_split: key of the split inside ``dataset`` (e.g. ``"train"``).
        bin_name: output file name, relative to ``DATA_DIR``.
    """
    bin_path = os.path.join(DATA_DIR, bin_name)

    if os.path.exists(bin_path):
        print(f"‚úÖ {bin_name} exists. Skipping.")
        return

    print(f"‚öôÔ∏è  Processing {hf_split} -> {bin_name}...")

    texts = dataset[hf_split]['text']
    eos_id = tokenizer.eos_token_id

    # Write to a temporary file and rename at the end: an interrupted run
    # must not leave a truncated file that later runs would skip as "exists".
    tmp_path = bin_path + ".tmp"
    total_tokens = 0

    # Process in chunks
    batch_size = 1000
    with open(tmp_path, "wb") as out:
        for i in tqdm(range(0, len(texts), batch_size)):
            # Filter empty lines
            valid_batch = [t for t in texts[i : i + batch_size] if len(t) > 0]
            if not valid_batch: continue

            ids = tokenizer.encode(tokenizer.eos_token.join(valid_batch))
            # join() only separates lines *within* a chunk; close the chunk
            # with an EOS as well so lines at chunk boundaries are not fused.
            ids.append(eos_id)

            # Stream each chunk to disk instead of accumulating every token
            # id of the split in one Python list.
            chunk = np.asarray(ids, dtype=np.uint16)
            chunk.tofile(out)
            total_tokens += chunk.size

    # Save
    os.replace(tmp_path, bin_path)
    print(f"   üíæ Saved {total_tokens:,} tokens to {bin_path}")

# Run: tokenise each WikiText split into its own .bin file
# (a split whose output file already exists is skipped by process_split).
process_split("train", "train.bin")
process_split("validation", "val.bin")
process_split("test", "test.bin")

print("\nüéâ WikiText-103 Ready for the Arena.")

Janus v3 WikiText-103 baseline (no steering pressure) below.

In [None]:
# @title
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import pandas as pd
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. MEMORY NUKE ---
# Drop Python garbage and release cached CUDA blocks left over from earlier cells.
gc.collect()
torch.cuda.empty_cache()

# --- 2. SETUP ---
# Mount Google Drive only if the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Input token files are read from DATA_DIR; checkpoints and telemetry go to SAVE_DIR.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
DATA_DIR = os.path.join(PROJECT_ROOT, "data/wikitext")
SAVE_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_wikitext_baseline")
os.makedirs(SAVE_DIR, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üìâ JANUS WIKI BASELINE (SPORT MODE)")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 3. CONFIGURATION ---
class WikiConfig:
    """Hyper-parameters for the WikiText baseline run, exposed as plain attributes."""

    def __init__(self):
        settings = {
            # Architecture (Janus Spec)
            "vocab_size": 50257,
            "d_model": 512,
            "n_layers": 12,
            "n_heads": 8,
            "d_head": 64,
            "max_seq_len": 512,
            "dropout": 0.05,
            # Training: batch 16 x accumulation 4 keeps the effective batch at 64
            "max_steps": 6000,
            "batch_size": 16,
            "grad_accum": 4,
            # Logging
            "lite_interval": 500,
            "full_interval": 1000,
        }
        for key, value in settings.items():
            setattr(self, key, value)

# --- 4. ARCHITECTURE ---
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # eps is added to the mean square before the reciprocal square root.
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(silu(w1(x)) * w2(x))``, all bias-free."""

    def __init__(self, d_model):
        super().__init__()
        inner_width = int(d_model * 8 / 3)
        self.w1 = nn.Linear(d_model, inner_width, bias=False)
        self.w2 = nn.Linear(d_model, inner_width, bias=False)
        self.w3 = nn.Linear(inner_width, d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings (RoPE).

    ``forward`` returns the attention output and, on request, per-head
    diagnostics (``sigma_a`` and ``eff_rank``).
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.dropout = nn.Dropout(config.dropout)
        self.register_buffer("freqs_cis", self.precompute_freqs_cis(config.d_head, config.max_seq_len))

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return the complex rotation table of shape ``(end, dim // 2)``."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    def apply_rope(self, x, freqs_cis):
        """Rotate ``x`` (B, S, H, d_head) by the first S rows of ``freqs_cis``."""
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs = freqs_cis[:x.shape[1]].view(1, x.shape[1], 1, -1)
        x_out = torch.view_as_real(x_c * freqs).flatten(3)
        return x_out.type_as(x)

    def _head_metrics(self, head_out):
        """Compute per-head diagnostics from ``head_out`` of shape (B, H, S, d_head).

        Returns a dict with:
          * ``sigma_a``  - mean absolute cosine similarity of each head's
            flattened output to every other head;
          * ``eff_rank`` - exp(entropy) of the normalised singular values of
            each head's output over the first 128 positions.
        """
        device = head_out.device
        # Bring the head axis to the front so that every reshape below keeps
        # the rows of one head together.
        per_head = head_out.transpose(0, 1)  # (H, B, S, d_head)

        flat = per_head.reshape(self.n_heads, -1)
        norm = F.normalize(flat, p=2, dim=1)
        sim = torch.mm(norm, norm.t())
        mask_diag = ~torch.eye(self.n_heads, dtype=torch.bool, device=device)
        # max(..., 1) avoids a 0/0 when the layer has a single head.
        sigma_a = (sim.abs() * mask_diag.float()).sum(dim=1) / max(self.n_heads - 1, 1)

        # The original reshaped a (B, S, H, d) tensor straight to (H, -1, d),
        # which mixed rows of different heads into each "head" slice.
        sub_out = per_head[:, :, :128, :].reshape(self.n_heads, -1, self.d_head)
        ranks = []
        for h in range(self.n_heads):
            try:
                S_vals = torch.linalg.svdvals(sub_out[h].float())
                p = S_vals / S_vals.sum()
                ent = -torch.sum(p * torch.log(p + 1e-9))
                ranks.append(torch.exp(ent))
            except RuntimeError:
                # SVD did not converge: report 0, on the same device as the
                # other entries so that torch.stack below does not fail.
                ranks.append(torch.zeros((), device=device))
        return {'sigma_a': sigma_a, 'eff_rank': torch.stack(ranks).to(device)}

    def forward(self, x, return_metrics=False):
        """Run attention on ``x`` of shape (B, S, D).

        Returns ``(output, metrics)``; ``metrics`` is ``{}`` unless
        ``return_metrics`` is set.
        """
        B, S, D = x.shape
        q = self.q_proj(x).view(B, S, self.n_heads, self.d_head)
        k = self.k_proj(x).view(B, S, self.n_heads, self.d_head)
        v = self.v_proj(x).view(B, S, self.n_heads, self.d_head)

        q = self.apply_rope(q, self.freqs_cis)
        k = self.apply_rope(k, self.freqs_cis)
        q = q.transpose(1, 2); k = k.transpose(1, 2); v = v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Lower-triangular mask: position i may only attend to positions <= i.
        mask = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        attn = attn.masked_fill(mask == 0, float('-inf'))
        attn_probs = F.softmax(attn, dim=-1)
        attn_probs = self.dropout(attn_probs)
        head_out = attn_probs @ v

        metrics = {}
        if return_metrics:
            with torch.no_grad():
                metrics = self._head_metrics(head_out)

        out = head_out.transpose(1, 2).reshape(B, S, D)
        return self.o_proj(out), metrics

class NewBlock(nn.Module):
    """Pre-norm transformer block: attention, then SwiGLU MLP, each added residually."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = RMSNorm(width)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(width)
        self.mlp = SwiGLU(width)

    def forward(self, x, return_metrics=False):
        """Return ``(hidden, attention_metrics)`` for input ``x``."""
        attn_out, attn_metrics = self.attn(self.ln1(x), return_metrics)
        hidden = x + attn_out
        mlp_out = self.mlp(self.ln2(hidden))
        return hidden + mlp_out, attn_metrics

class NewGPT(nn.Module):
    """Decoder-only language model: tied embedding/head around a stack of NewBlock."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the embedding table and the output projection share one tensor.
        self.token_emb.weight = self.head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialiser passed to ``self.apply``; called once per submodule."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        # Output projections get a std scaled by 1/sqrt(2 * n_layers).
        scaled_std = 0.02 / math.sqrt(2 * self.config.n_layers)
        for name, p in module.named_parameters():
            if "o_proj.weight" in name or "w3.weight" in name:
                torch.nn.init.normal_(p, mean=0.0, std=scaled_std)

    def forward(self, idx, targets=None, return_metrics=False):
        """Return ``(loss, all_metrics)``.

        ``loss`` is the cross-entropy against ``targets`` or None when no
        targets are given. ``all_metrics`` holds one dict per block when
        ``return_metrics`` is True and is empty otherwise.
        """
        B, S = idx.shape
        hidden = self.token_emb(idx)

        all_metrics = []
        for layer in self.blocks:
            hidden, layer_metrics = layer(hidden, return_metrics)
            if return_metrics:
                all_metrics.append(layer_metrics)

        logits = self.head(self.ln_f(hidden))

        if targets is None:
            return None, all_metrics

        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, targets.view(-1))
        return loss, all_metrics

# --- 5. DATA & LOGGING ---
class BinLoader:
    """Random-window batch sampler over memory-mapped uint16 token files."""

    def __init__(self, data_dir, block_size, batch_size):
        train_path = os.path.join(data_dir, "train.bin")
        val_path = os.path.join(data_dir, "val.bin")
        self.train_data = np.memmap(train_path, dtype=np.uint16, mode='r')
        self.val_data = np.memmap(val_path, dtype=np.uint16, mode='r')
        self.block_size = block_size
        self.batch_size = batch_size
        print(f"üì¶ WikiText | Train: {len(self.train_data):,} | Val: {len(self.val_data):,}")

    def get_batch(self, split='train'):
        """Return ``(x, y)`` on DEVICE; ``y`` is ``x`` shifted one token ahead."""
        if split == 'train':
            d = self.train_data
        else:
            d = self.val_data
        width = self.block_size

        def window(begin):
            # Copy one contiguous slice out of the memmap as int64.
            return torch.from_numpy(d[begin:begin + width].astype(np.int64))

        ix = torch.randint(len(d) - width, (self.batch_size,))
        x = torch.stack([window(i) for i in ix])
        y = torch.stack([window(i + 1) for i in ix])
        return x.to(DEVICE), y.to(DEVICE)

class BlackBox:
    """Buffers per-step telemetry rows and appends them to a parquet file."""

    def __init__(self, save_dir):
        self.buffer = []
        self.save_dir = save_dir

    def log(self, step, loss, val_loss, metrics):
        """Queue one telemetry row.

        Args:
            step: Training step the row belongs to.
            loss: Training loss to record.
            val_loss: Validation loss; perplexity is ``exp(val_loss)``, or 0.0
                when ``val_loss >= 20``.
            metrics: Sequence of per-layer dicts (may be empty or None). For
                layer ``i`` the mean of ``'sigma_a'`` / ``'eff_rank'`` is stored
                under ``L{i}_sigma_a`` / ``L{i}_eff_rank``; empty dicts are skipped.
        """
        row = {
            "step": step, "loss": loss, "val_loss": val_loss,
            "perplexity": math.exp(val_loss) if val_loss < 20 else 0.0,
        }
        for i, m in enumerate(metrics or []):
            if not m:
                continue
            if 'sigma_a' in m:
                row[f"L{i}_sigma_a"] = m['sigma_a'].mean().item()
            if 'eff_rank' in m:
                row[f"L{i}_eff_rank"] = m['eff_rank'].mean().item()
        self.buffer.append(row)

    def flush(self):
        """Write buffered rows to ``telemetry_base.parquet`` and clear the buffer.

        Existing rows in the file are kept and the new rows appended. If the
        existing file cannot be read it is still replaced by the new rows
        (best-effort, as before), but the failure is now reported instead of
        being swallowed silently.
        """
        if not self.buffer:
            return
        df = pd.DataFrame(self.buffer)
        fpath = os.path.join(self.save_dir, "telemetry_base.parquet")
        if os.path.exists(fpath):
            try:
                existing = pd.read_parquet(fpath)
            except Exception as err:
                print(f"   Telemetry file unreadable, overwriting it ({type(err).__name__}: {err})")
            else:
                # ignore_index keeps a clean 0..n-1 index across appended flushes.
                df = pd.concat([existing, df], ignore_index=True)
        df.to_parquet(fpath)
        self.buffer = []
        print("üíæ Telemetry Flushed.")

# --- 6. MAIN ---
def run_baseline():
    """Train the baseline model (no steering loss) and save its weights.

    Resumes from ``ckpt_baseline.pt`` in SAVE_DIR when that file exists. On
    every step divisible by ``cfg.lite_interval`` (and on the last step) it
    averages the validation loss over 20 batches, logs and flushes telemetry
    and rewrites the checkpoint. Head metrics are requested only on the last
    micro-batch of steps divisible by ``cfg.full_interval``. The final state
    dict is written to ``janus_v3_base.pt``.
    """
    gc.collect(); torch.cuda.empty_cache()

    cfg = WikiConfig()
    model = NewGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6e-4, weight_decay=0.1)
    loader = BinLoader(DATA_DIR, cfg.max_seq_len, cfg.batch_size)
    recorder = BlackBox(SAVE_DIR)

    start_step = 0
    ckpt_path = os.path.join(SAVE_DIR, "ckpt_baseline.pt")
    if os.path.exists(ckpt_path):
        print("üîÑ Resuming Checkpoint...")
        try:
            c = torch.load(ckpt_path, map_location=DEVICE)
            model.load_state_dict(c['model'])
            optimizer.load_state_dict(c['optim'])
            # The checkpoint is written after the optimizer update of
            # c['step'], so training continues with the following step.
            start_step = c['step'] + 1
        except Exception as err:
            print("‚ö†Ô∏è Checkpoint corrupt. Starting fresh.")
            print(f"   ({type(err).__name__}: {err})")

    print(f"\nüèÉ STARTING BASELINE (SPORT MODE): {start_step} -> {cfg.max_steps}")
    pbar = tqdm(range(start_step, cfg.max_steps), initial=start_step, total=cfg.max_steps)

    # bf16 autocast is only enabled on CUDA; on the CPU fallback it is
    # switched off explicitly instead of requesting a CUDA context.
    use_amp = DEVICE.type == 'cuda'

    for step in pbar:
        model.train()
        batch_loss = 0.0

        optimizer.zero_grad()

        for micro_step in range(cfg.grad_accum):
            x, y = loader.get_batch('train')
            is_full = (step % cfg.full_interval == 0) and (micro_step == cfg.grad_accum - 1)

            with torch.amp.autocast(DEVICE.type, dtype=torch.bfloat16, enabled=use_amp):
                loss, metrics = model(x, y, return_metrics=is_full)
                # Scale so the accumulated gradient is the mean over micro-batches.
                total = loss / cfg.grad_accum

            total.backward()
            batch_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if step % cfg.lite_interval == 0 or step == cfg.max_steps - 1:
            gc.collect(); torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                v_losses = []
                for _ in range(20):
                    vx, vy = loader.get_batch('val')
                    vl, _ = model(vx, vy, return_metrics=False)
                    v_losses.append(vl.item())
                val_loss = np.mean(v_losses)

            # `metrics` comes from the last micro-batch; it is empty unless
            # this was a full-interval step.
            recorder.log(step, batch_loss/cfg.grad_accum, val_loss, metrics)
            recorder.flush()
            torch.save({'step': step, 'model': model.state_dict(), 'optim': optimizer.state_dict()}, ckpt_path)

            pbar.set_description(f"L:{batch_loss/cfg.grad_accum:.3f}|V:{val_loss:.3f}")
        else:
            pbar.set_description(f"L:{batch_loss/cfg.grad_accum:.3f}")

    final_path = os.path.join(SAVE_DIR, "janus_v3_base.pt")
    torch.save(model.state_dict(), final_path)
    print(f"\nüèÜ BASELINE COMPLETE. Saved to {final_path}")

# Entry point: start the baseline training run when this cell/file is executed directly.
if __name__ == "__main__":
    run_baseline()

V3 Hero WikiText-103 Below

In [None]:
# @title
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import pandas as pd
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. MEMORY NUKE ---
# Run the Python garbage collector and release PyTorch's cached CUDA memory
# before anything is allocated in this cell.
gc.collect()
torch.cuda.empty_cache()

# --- 2. SETUP ---
# Mount Google Drive only if the mount point does not exist yet.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# All inputs and outputs of this run live under one Drive project folder.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
DATA_DIR = os.path.join(PROJECT_ROOT, "data/wikitext")  # read by BinLoader (train.bin / val.bin)
SAVE_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_wikitext_hero")  # checkpoints + telemetry
os.makedirs(SAVE_DIR, exist_ok=True)

# Use the GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üöÄ JANUS WIKI HERO (Pressure Active)")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 3. CONFIGURATION ---
class WikiConfig:
    """Hyperparameters for the WikiText hero run: model shape, training, logging."""

    def __init__(self):
        # Architecture
        self.vocab_size = 50257
        self.d_model, self.n_layers, self.n_heads = 512, 12, 8
        self.d_head = self.d_model // self.n_heads  # 512 // 8 == 64
        self.max_seq_len = 512
        self.dropout = 0.05

        # Training
        self.max_steps = 6000
        self.batch_size, self.grad_accum = 16, 4  # 16 * 4 = effective batch of 64

        # Scheduler
        self.max_pressure = 0.15

        # Logging (in steps)
        self.lite_interval, self.full_interval = 500, 1000

# --- 4. ARCHITECTURE (With Steering) ---
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated MLP: ``w3(silu(w1(x)) * w2(x))`` with hidden width ``int(d_model * 8 / 3)``."""

    def __init__(self, d_model):
        super().__init__()
        hidden_dim = int(d_model * 8 / 3)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)  # gate branch
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)  # value branch
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)  # projection back to d_model

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Causal multi-head self-attention with RoPE, an optional head-diversity
    penalty and optional per-head telemetry."""

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.dropout = nn.Dropout(config.dropout)
        self.register_buffer("freqs_cis", self.precompute_freqs_cis(config.d_head, config.max_seq_len))

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return a complex (end, dim // 2) table of unit-magnitude rotations."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    def apply_rope(self, x, freqs_cis):
        """Rotate consecutive channel pairs of ``x`` (B, S, H, d) by position."""
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs = freqs_cis[:x.shape[1]].view(1, x.shape[1], 1, -1)
        x_out = torch.view_as_real(x_c * freqs).flatten(3)
        return x_out.type_as(x)

    def forward(self, x, lambdas, return_metrics=False):
        """Attend over ``x`` and optionally compute steering loss and metrics.

        Args:
            x: Input of shape (B, S, D); D must equal ``n_heads * d_head``.
            lambdas: Pair ``(l_coh, l_div)``. Only ``l_div`` is used here: when
                it is positive and the module is in training mode, the
                Frobenius norm of (head Gram matrix - identity) times
                ``l_div`` is returned as the steering loss.
            return_metrics: When True, compute ``'sigma_a'`` and
                ``'eff_rank'`` (each of shape (n_heads,)) without gradients.

        Returns:
            ``(out, steer_loss, metrics)``: ``out`` has shape (B, S, D);
            ``steer_loss`` is the float 0.0 when no penalty applies, otherwise
            a scalar tensor; ``metrics`` is ``{}`` unless requested.
        """
        B, S, D = x.shape
        q = self.q_proj(x).view(B, S, self.n_heads, self.d_head)
        k = self.k_proj(x).view(B, S, self.n_heads, self.d_head)
        v = self.v_proj(x).view(B, S, self.n_heads, self.d_head)

        q = self.apply_rope(q, self.freqs_cis)
        k = self.apply_rope(k, self.freqs_cis)
        # (B, S, H, d) -> (B, H, S, d)
        q = q.transpose(1, 2); k = k.transpose(1, 2); v = v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Lower-triangular mask: position i may only attend to positions <= i.
        mask = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        attn = attn.masked_fill(mask == 0, float('-inf'))
        attn_probs = F.softmax(attn, dim=-1)
        attn_probs = self.dropout(attn_probs)
        head_out = attn_probs @ v  # (B, H, S, d)

        # --- VSM PRESSURE LOGIC ---
        steer_loss = 0.0
        l_coh, l_div = lambdas  # l_coh is unpacked but not used by this module

        if l_div > 0.0 and self.training:
            flat = head_out.transpose(0, 1).reshape(self.n_heads, -1)
            norm = F.normalize(flat, p=2, dim=1)
            gram = torch.mm(norm, norm.t())
            identity = torch.eye(self.n_heads, device=x.device)
            steer_loss += torch.norm(gram - identity, p='fro') * l_div

        metrics = {}
        if return_metrics:
            with torch.no_grad():
                flat = head_out.transpose(0, 1).reshape(self.n_heads, -1)
                norm = F.normalize(flat, p=2, dim=1)
                sim = torch.mm(norm, norm.t())
                mask_diag = ~torch.eye(self.n_heads, dtype=torch.bool, device=x.device)
                # max(..., 1) avoids 0/0 when there is a single head.
                metrics['sigma_a'] = (sim.abs() * mask_diag.float()).sum(dim=1) / max(self.n_heads - 1, 1)

                # Bring the head axis to the front BEFORE reshaping so that
                # sub_out[h] contains only rows produced by head h. Swapping
                # axes 1 and 2 here would interleave all heads in every slice.
                # Only the first 128 positions are used for the SVD.
                sub_out = head_out[:, :, :128, :].transpose(0, 1).reshape(self.n_heads, -1, self.d_head)
                ranks = []
                for h in range(self.n_heads):
                    try:
                        S_vals = torch.linalg.svdvals(sub_out[h].float())
                        p = S_vals / S_vals.sum()
                        ent = -torch.sum(p * torch.log(p + 1e-9))
                        ranks.append(torch.exp(ent))
                    except RuntimeError:
                        # SVD failures are reported as rank 0; the placeholder
                        # lives on x.device so torch.stack sees a single device.
                        ranks.append(torch.tensor(0.0, device=x.device))
                metrics['eff_rank'] = torch.stack(ranks).to(x.device)

        out = head_out.transpose(1, 2).reshape(B, S, D)
        return self.o_proj(out), steer_loss, metrics

class NewBlock(nn.Module):
    """Pre-norm transformer block that also forwards the attention steering loss."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = RMSNorm(width)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(width)
        self.mlp = SwiGLU(width)

    def forward(self, x, lambdas, return_metrics=False):
        """Return ``(hidden, steer_loss, attention_metrics)`` for input ``x``."""
        attn_out, steer, attn_metrics = self.attn(self.ln1(x), lambdas, return_metrics)
        hidden = x + attn_out
        mlp_out = self.mlp(self.ln2(hidden))
        return hidden + mlp_out, steer, attn_metrics

class NewGPT(nn.Module):
    """Decoder-only language model whose blocks each receive their own lambda pair."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the embedding table and the output projection share one tensor.
        self.token_emb.weight = self.head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialiser passed to ``self.apply``; called once per submodule."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        # Output projections get a std scaled by 1/sqrt(2 * n_layers).
        scaled_std = 0.02 / math.sqrt(2 * self.config.n_layers)
        for name, p in module.named_parameters():
            if "o_proj.weight" in name or "w3.weight" in name:
                torch.nn.init.normal_(p, mean=0.0, std=scaled_std)

    def forward(self, idx, lambdas_list, targets=None, return_metrics=False):
        """Return ``(loss, total_steer, all_metrics)``.

        ``lambdas_list[i]`` is handed to block ``i``. ``loss`` is None when no
        targets are given; ``total_steer`` is the sum of the blocks' steering
        losses; ``all_metrics`` has one dict per block only when
        ``return_metrics`` is True.
        """
        B, S = idx.shape
        hidden = self.token_emb(idx)

        total_steer = 0.0
        all_metrics = []
        for layer_lambdas, layer in zip(lambdas_list, self.blocks):
            hidden, layer_steer, layer_metrics = layer(hidden, layer_lambdas, return_metrics)
            total_steer += layer_steer
            if return_metrics:
                all_metrics.append(layer_metrics)

        logits = self.head(self.ln_f(hidden))

        if targets is None:
            return None, total_steer, all_metrics

        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, targets.view(-1))
        return loss, total_steer, all_metrics

# --- 5. SCHEDULER & LOADERS ---
class FlightController:
    """Four-phase pressure schedule: clean, ramp, hold, decay (a quarter each).

    Phase boundaries are derived from ``config.max_steps`` (6000 is assumed
    when the config has no such attribute, which reproduces the previous
    fixed 1500/3000/4500/6000 boundaries).
    """

    def __init__(self, config):
        self.config = config

    def get_pressure(self, step):
        """Return the global pressure for ``step``, always in [0, max_pressure]."""
        total = getattr(self.config, "max_steps", 6000)
        if total <= 0:
            return 0.0
        quarter = total / 4
        peak = self.config.max_pressure

        # First quarter: clean (no pressure).
        if step < quarter:
            return 0.0
        # Second quarter: linear ramp from 0 up to the peak.
        if step < 2 * quarter:
            return peak * ((step - quarter) / quarter)
        # Third quarter: hold at the peak.
        if step < 3 * quarter:
            return peak
        # Last quarter: linear decay; clamped so steps past the end give 0
        # instead of a negative pressure.
        return max(0.0, peak * ((total - step) / quarter))

    def get_lambdas(self, step):
        """Return ``(lambdas, pressure)`` for ``step``.

        ``lambdas`` has one ``(coherence, diversity)`` pair per layer. Both are
        scaled by ``((layer_index + 1) / n_layers) ** 3``; the coherence base
        is 0.2 times the pressure.
        """
        p = self.get_pressure(step)
        base_coh = p * 0.2
        lambdas = []
        for i in range(self.config.n_layers):
            ratio = (i + 1) / self.config.n_layers
            s_mult = ratio ** 3
            lambdas.append((base_coh * s_mult, p * s_mult))
        return lambdas, p

class BinLoader:
    """Random-window batch sampler over memory-mapped uint16 token files."""

    def __init__(self, data_dir, block_size, batch_size):
        train_path = os.path.join(data_dir, "train.bin")
        val_path = os.path.join(data_dir, "val.bin")
        self.train_data = np.memmap(train_path, dtype=np.uint16, mode='r')
        self.val_data = np.memmap(val_path, dtype=np.uint16, mode='r')
        self.block_size = block_size
        self.batch_size = batch_size
        print(f"üì¶ WikiText | Train: {len(self.train_data):,} | Val: {len(self.val_data):,}")

    def get_batch(self, split='train'):
        """Return ``(x, y)`` on DEVICE; ``y`` is ``x`` shifted one token ahead."""
        if split == 'train':
            d = self.train_data
        else:
            d = self.val_data
        width = self.block_size

        def window(begin):
            # Copy one contiguous slice out of the memmap as int64.
            return torch.from_numpy(d[begin:begin + width].astype(np.int64))

        ix = torch.randint(len(d) - width, (self.batch_size,))
        x = torch.stack([window(i) for i in ix])
        y = torch.stack([window(i + 1) for i in ix])
        return x.to(DEVICE), y.to(DEVICE)

class BlackBox:
    """Buffers per-step telemetry rows and appends them to a parquet file."""

    def __init__(self, save_dir):
        self.buffer = []
        self.save_dir = save_dir

    def log(self, step, loss, val_loss, pressure, metrics):
        """Queue one telemetry row.

        Args:
            step: Training step the row belongs to.
            loss: Training loss to record.
            val_loss: Validation loss; perplexity is ``exp(val_loss)``, or 0.0
                when ``val_loss >= 20``.
            pressure: Scheduler pressure active at this step.
            metrics: Sequence of per-layer dicts (may be empty or None). For
                layer ``i`` the mean of ``'sigma_a'`` / ``'eff_rank'`` is stored
                under ``L{i}_sigma_a`` / ``L{i}_eff_rank``; empty dicts are skipped.
        """
        row = {
            "step": step, "loss": loss, "val_loss": val_loss,
            "perplexity": math.exp(val_loss) if val_loss < 20 else 0.0,
            "pressure": pressure
        }
        for i, m in enumerate(metrics or []):
            if not m:
                continue
            if 'sigma_a' in m:
                row[f"L{i}_sigma_a"] = m['sigma_a'].mean().item()
            if 'eff_rank' in m:
                row[f"L{i}_eff_rank"] = m['eff_rank'].mean().item()
        self.buffer.append(row)

    def flush(self):
        """Write buffered rows to ``telemetry_hero.parquet`` and clear the buffer.

        Existing rows in the file are kept and the new rows appended. If the
        existing file cannot be read it is still replaced by the new rows
        (best-effort, as before), but the failure is now reported instead of
        being swallowed silently.
        """
        if not self.buffer:
            return
        df = pd.DataFrame(self.buffer)
        fpath = os.path.join(self.save_dir, "telemetry_hero.parquet")
        if os.path.exists(fpath):
            try:
                existing = pd.read_parquet(fpath)
            except Exception as err:
                print(f"   Telemetry file unreadable, overwriting it ({type(err).__name__}: {err})")
            else:
                # ignore_index keeps a clean 0..n-1 index across appended flushes.
                df = pd.concat([existing, df], ignore_index=True)
        df.to_parquet(fpath)
        self.buffer = []
        print("üíæ Telemetry Flushed.")

# --- 6. MAIN ---
def run_hero():
    """Train the hero model with scheduled steering pressure and save its weights.

    Resumes from ``ckpt_hero.pt`` in SAVE_DIR when that file exists. Each step
    asks the FlightController for per-layer lambdas and adds the model's
    steering loss to the cross-entropy before backpropagation. On every step
    divisible by ``cfg.lite_interval`` (and on the last step) it averages the
    validation loss over 20 batches with all lambdas set to zero, logs and
    flushes telemetry and rewrites the checkpoint. Head metrics are requested
    only on the last micro-batch of steps divisible by ``cfg.full_interval``.
    The final state dict is written to ``janus_v3_hero.pt``.
    """
    gc.collect(); torch.cuda.empty_cache()

    cfg = WikiConfig()
    model = NewGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6e-4, weight_decay=0.1)
    scheduler = FlightController(cfg)
    loader = BinLoader(DATA_DIR, cfg.max_seq_len, cfg.batch_size)
    recorder = BlackBox(SAVE_DIR)

    start_step = 0
    ckpt_path = os.path.join(SAVE_DIR, "ckpt_hero.pt")
    if os.path.exists(ckpt_path):
        print("üîÑ Resuming Checkpoint...")
        try:
            c = torch.load(ckpt_path, map_location=DEVICE)
            model.load_state_dict(c['model'])
            optimizer.load_state_dict(c['optim'])
            # The checkpoint is written after the optimizer update of
            # c['step'], so training continues with the following step.
            start_step = c['step'] + 1
        except Exception as err:
            print("‚ö†Ô∏è Checkpoint corrupt. Starting fresh.")
            print(f"   ({type(err).__name__}: {err})")

    print(f"\nüèÉ STARTING HERO RUN: {start_step} -> {cfg.max_steps}")
    pbar = tqdm(range(start_step, cfg.max_steps), initial=start_step, total=cfg.max_steps)

    # bf16 autocast is only enabled on CUDA; on the CPU fallback it is
    # switched off explicitly instead of requesting a CUDA context.
    use_amp = DEVICE.type == 'cuda'
    # Zero pressure for validation
    zero_l = [(0.0, 0.0)] * cfg.n_layers

    for step in pbar:
        model.train()
        batch_loss = 0.0
        lambdas, pressure = scheduler.get_lambdas(step)

        optimizer.zero_grad()

        for micro_step in range(cfg.grad_accum):
            x, y = loader.get_batch('train')
            is_full = (step % cfg.full_interval == 0) and (micro_step == cfg.grad_accum - 1)

            with torch.amp.autocast(DEVICE.type, dtype=torch.bfloat16, enabled=use_amp):
                loss, steer, metrics = model(x, lambdas, y, return_metrics=is_full)
                # Scale so the accumulated gradient is the mean over micro-batches.
                total = (loss + steer) / cfg.grad_accum

            total.backward()
            # Only the language-modelling loss is reported, not the steering term.
            batch_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if step % cfg.lite_interval == 0 or step == cfg.max_steps - 1:
            gc.collect(); torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                v_losses = []
                for _ in range(20):
                    vx, vy = loader.get_batch('val')
                    vl, _, _ = model(vx, zero_l, vy, return_metrics=False)
                    v_losses.append(vl.item())
                val_loss = np.mean(v_losses)

            # `metrics` comes from the last micro-batch; it is empty unless
            # this was a full-interval step.
            recorder.log(step, batch_loss/cfg.grad_accum, val_loss, pressure, metrics)
            recorder.flush()
            torch.save({'step': step, 'model': model.state_dict(), 'optim': optimizer.state_dict()}, ckpt_path)

            pbar.set_description(f"L:{batch_loss/cfg.grad_accum:.3f}|V:{val_loss:.3f}|P:{pressure:.3f}")
        else:
            pbar.set_description(f"L:{batch_loss/cfg.grad_accum:.3f}|P:{pressure:.3f}")

    final_path = os.path.join(SAVE_DIR, "janus_v3_hero.pt")
    torch.save(model.state_dict(), final_path)
    print(f"\nüèÜ HERO RUN COMPLETE. Saved to {final_path}")

# Entry point: start the hero training run when this cell/file is executed directly.
if __name__ == "__main__":
    run_hero()

In [None]:
# @title
import sys
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from google.colab import drive

# --- SETUP ---
# Mount Google Drive only if the mount point does not exist yet.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# Telemetry parquet files that load_data reads for the two runs being compared.
BASE_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_wikitext_baseline/telemetry_base.parquet")
HERO_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_wikitext_hero/telemetry_hero.parquet")
# Output folder for the saved comparison chart.
REPORT_DIR = os.path.join(PROJECT_ROOT, "reports")
os.makedirs(REPORT_DIR, exist_ok=True)

print("‚öñÔ∏è WIKITEXT DUEL: BASELINE vs. HERO")

def load_data(path, label):
    """Load one run's telemetry parquet and add summary columns.

    Returns None (after printing a notice) when ``path`` does not exist.
    Otherwise returns a frame sorted by ``step`` with only the last row kept
    per step, a ``Model`` column set to ``label``, and ``sigma_a_avg`` /
    ``eff_rank_avg`` columns holding the row-wise mean over every column whose
    name contains ``sigma_a`` / ``eff_rank`` (NaN when no such column exists).
    """
    if not os.path.exists(path):
        print(f"‚ùå Missing: {path}")
        return None

    frame = pd.read_parquet(path)
    frame = frame.sort_values('step').drop_duplicates(subset='step', keep='last')
    frame['Model'] = label

    # Same averaging rule for both metric families.
    for key in ('sigma_a', 'eff_rank'):
        matching = [col for col in frame.columns if key in col]
        frame[f'{key}_avg'] = frame[matching].mean(axis=1) if matching else np.nan

    return frame

def run_analysis():
    """Plot and summarise baseline vs. hero telemetry.

    Loads both telemetry files, saves a three-panel chart (validation loss,
    perplexity, average sigma_a) to REPORT_DIR and prints a scorecard built
    from the last row of each run plus a redundancy comparison at step 4000.
    Returns early when either file is missing or empty.
    """
    # Imported locally so the function does not depend on the import that
    # only happens inside the ``__main__`` guard.
    import math

    base = load_data(BASE_PATH, "Baseline (Passive)")
    hero = load_data(HERO_PATH, "Hero (Active Pressure)")

    if base is None or hero is None: return
    if base.empty or hero.empty:
        print("   (Telemetry is empty, nothing to analyse)")
        return

    combined = pd.concat([base, hero])

    # PLOTTING
    fig, axes = plt.subplots(3, 1, figsize=(12, 18), sharex=True)
    palette = {'Baseline (Passive)': 'gray', 'Hero (Active Pressure)': 'tab:blue'}

    # 1. Validation Loss (The Scoreboard)
    sns.lineplot(data=combined, x='step', y='val_loss', hue='Model', ax=axes[0], palette=palette, linewidth=2.5)
    axes[0].set_title("Validation Loss (Lower is Better)")
    axes[0].set_ylabel("Cross Entropy")
    axes[0].grid(True, alpha=0.3)

    # Add annotation for the cooldown phase
    axes[0].axvspan(4500, 6000, color='green', alpha=0.1, label='Cooldown Phase')

    # 2. Perplexity (The Real World Metric)
    # Convert loss to PPL for visualization
    combined['ppl'] = np.exp(combined['val_loss'])
    sns.lineplot(data=combined, x='step', y='ppl', hue='Model', ax=axes[1], palette=palette, linewidth=2.5)
    axes[1].set_title("Perplexity (The Generalization Gap)")
    axes[1].set_ylabel("PPL")
    axes[1].grid(True, alpha=0.3)

    # 3. Orthogonality (The Mechanism)
    # Only plot rows that actually carry head metrics.
    subset = combined.dropna(subset=['sigma_a_avg'])
    if not subset.empty:
        sns.lineplot(data=subset, x='step', y='sigma_a_avg', hue='Model', ax=axes[2], palette=palette, marker='o', linewidth=2)
        axes[2].set_title("Head Redundancy (Sigma A)")
        axes[2].set_ylabel("Correlation (Lower = More Orthogonal)")
        axes[2].grid(True, alpha=0.3)
        # Mark the pressure zone
        axes[2].axvspan(1500, 4500, color='red', alpha=0.1, label='Pressure Zone')

    plt.tight_layout()
    save_path = os.path.join(REPORT_DIR, "wikitext_duel_analysis.png")
    plt.savefig(save_path)
    print(f"\nüìä Forensic Chart Saved: {save_path}")

    # STATS CARD
    print("\n" + "="*40)
    print("üèÜ FINAL WIKITEXT SCORECARD (Step 6000)")
    print("="*40)

    # load_data sorts by step, so the last row is the latest logged step.
    b_final = base.iloc[-1]
    h_final = hero.iloc[-1]

    b_ppl = math.exp(b_final['val_loss'])
    h_ppl = math.exp(h_final['val_loss'])
    # Signed hero-minus-baseline difference: negative means the hero is better.
    delta = h_ppl - b_ppl

    print(f"PERPLEXITY (Lower Wins):")
    print(f"   Baseline: {b_ppl:.2f}")
    print(f"   Hero:     {h_ppl:.2f}")
    print(f"   Delta:    {delta:+.2f} PPL")
    print("-" * 20)
    print(f"VALIDATION LOSS:")
    print(f"   Baseline: {b_final['val_loss']:.4f}")
    print(f"   Hero:     {h_final['val_loss']:.4f}")
    print("-" * 20)

    # Check Mechanism
    # Compare Sigma A at Step 4000 (Peak Pressure)
    try:
        b_mid = base[base['step'] == 4000].iloc[0]['sigma_a_avg']
        h_mid = hero[hero['step'] == 4000].iloc[0]['sigma_a_avg']
    except (IndexError, KeyError):
        # No row for step 4000 in at least one run.
        print("   (Mid-run metrics missing, skipping check)")
        return

    print(f"MECHANISM CHECK (Step 4000 - Peak Pressure):")
    print(f"   Baseline Redundancy: {b_mid:.4f}")
    print(f"   Hero Redundancy:     {h_mid:.4f}")
    if h_mid < b_mid:
        print("   ‚úÖ CONFIRMED: Pressure suppressed redundancy.")
    else:
        print("   ‚ö†Ô∏è ANOMALY: Pressure did not suppress redundancy.")

# Entry point: run the comparison when this cell/file is executed directly.
if __name__ == "__main__":
    # math is imported here rather than with the other imports; run_analysis calls math.exp.
    import math
    run_analysis()

Inference Testing

In [None]:
# @title
import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import GPT2Tokenizer
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive only if the mount point does not exist yet.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# State-dict files loaded below for the two models being compared.
BASE_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_wikitext_baseline/janus_v3_base.pt")
HERO_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_wikitext_hero/janus_v3_hero.pt")

# Use the GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"‚öîÔ∏è JANUS INFERENCE DUEL: BASELINE vs. HERO")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 2. CONFIGURATION (Identical for Both) ---
class WikiConfig:
    """Model-shape hyperparameters used to rebuild the network for inference."""

    def __init__(self):
        self.vocab_size = 50257
        self.d_model, self.n_layers, self.n_heads = 512, 12, 8
        self.d_head = self.d_model // self.n_heads  # 512 // 8 == 64
        self.max_seq_len = 512
        self.dropout = 0.0

# --- 3. ARCHITECTURE ---
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated MLP: ``w3(silu(w1(x)) * w2(x))`` with hidden width ``int(d_model * 8 / 3)``."""

    def __init__(self, d_model):
        super().__init__()
        hidden_dim = int(d_model * 8 / 3)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)  # gate branch
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)  # value branch
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)  # projection back to d_model

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Inference-only causal multi-head self-attention with rotary embeddings."""

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        width = config.d_model
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)

        rope_table = self.precompute_freqs_cis(config.d_head, config.max_seq_len)
        self.register_buffer("freqs_cis", rope_table)

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return a complex (end, dim // 2) table of unit-magnitude rotations."""
        exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
        inv_freq = 1.0 / (theta ** exponents)
        positions = torch.arange(end, device=inv_freq.device)
        angles = torch.outer(positions, inv_freq).float()
        return torch.polar(torch.ones_like(angles), angles)

    def apply_rope(self, x, freqs_cis):
        """Rotate consecutive channel pairs of ``x`` (B, S, H, d) by position."""
        seq_len = x.shape[1]
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        as_complex = torch.view_as_complex(pairs)
        rotation = freqs_cis[:seq_len].view(1, seq_len, 1, -1)
        rotated = torch.view_as_real(as_complex * rotation).flatten(3)
        return rotated.type_as(x)

    def forward(self, x):
        """Map (B, S, D) input to (B, S, D) output through causal attention."""
        B, S, D = x.shape

        def split_heads(t):
            return t.view(B, S, self.n_heads, self.d_head)

        # Rotary embedding is applied to queries and keys only; then move the
        # head axis ahead of the sequence axis.
        q = self.apply_rope(split_heads(self.q_proj(x)), self.freqs_cis).transpose(1, 2)
        k = self.apply_rope(split_heads(self.k_proj(x)), self.freqs_cis).transpose(1, 2)
        v = split_heads(self.v_proj(x)).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        # Lower-triangular mask: position i may only attend to positions <= i.
        causal = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        scores = scores.masked_fill(causal == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        merged = (weights @ v).transpose(1, 2).reshape(B, S, D)
        return self.o_proj(merged)

class NewBlock(nn.Module):
    """Pre-norm transformer block: attention, then SwiGLU MLP, each added residually."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = RMSNorm(width)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(width)
        self.mlp = SwiGLU(width)

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        mlp_out = self.mlp(self.ln2(after_attn))
        return after_attn + mlp_out

class NewGPT(nn.Module):
    """Inference-only language model with tied embedding/head and sampling."""

    def __init__(self, config):
        super().__init__()
        # Longest context the attention layers' RoPE table was built for;
        # generate() crops its input to this length. A plain attribute, so
        # it does not appear in the state dict.
        self.max_seq_len = config.max_seq_len
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the embedding table and the output projection share one tensor.
        self.token_emb.weight = self.head.weight

    def forward(self, idx):
        """Return logits of shape (B, S, vocab_size) for token ids ``idx`` (B, S)."""
        B, S = idx.shape
        x = self.token_emb(idx)
        for block in self.blocks: x = block(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.7, top_k=50):
        """Sample ``max_new_tokens`` tokens after ``idx`` and return the full sequence.

        Args:
            idx: Prompt token ids of shape (B, S).
            max_new_tokens: Number of tokens to append.
            temperature: Positive divisor applied to the last-position logits.
            top_k: Keep only the ``top_k`` largest logits before sampling;
                None disables the filter.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError("temperature must be > 0")
        for _ in range(max_new_tokens):
            # Crop to the model's context length instead of a fixed 512 so the
            # window always matches the configured RoPE table.
            idx_cond = idx[:, -self.max_seq_len:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Everything below the k-th largest logit is excluded.
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

# --- 4. LOAD WEIGHTS ---
cfg = WikiConfig()
# GPT-2 tokenizer is used to encode the prompts and decode the samples.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Both models share the same architecture; only the loaded weights differ.
print("üîÑ Loading Baseline...")
model_base = NewGPT(cfg).to(DEVICE)
model_base.load_state_dict(torch.load(BASE_PATH, map_location=DEVICE))
model_base.eval()

print("üîÑ Loading Hero...")
model_hero = NewGPT(cfg).to(DEVICE)
model_hero.load_state_dict(torch.load(HERO_PATH, map_location=DEVICE))
model_hero.eval()

# --- 5. THE ARENA ---
# Prompts completed by both models for a side-by-side comparison.
prompts = [
    "The Roman Empire was",
    "In the early 19th century, the",
    "The chemical formula for water is",
    "Following the release of the album,",
    "Located in the northern part of"
]

print("\n" + "="*60)
print("üß™ HEAD-TO-HEAD: BASELINE vs. HERO")
print("="*60)

for p in prompts:
    input_ids = tokenizer.encode(p, return_tensors='pt').to(DEVICE)

    # Generate Baseline
    out_b = model_base.generate(input_ids, max_new_tokens=60, temperature=0.6)
    text_b = tokenizer.decode(out_b[0].tolist(), skip_special_tokens=True)

    # Generate Hero
    out_h = model_hero.generate(input_ids, max_new_tokens=60, temperature=0.6)
    text_h = tokenizer.decode(out_h[0].tolist(), skip_special_tokens=True)

    # The continuation is shown by dropping the first len(p) characters of the
    # decoded text. NOTE(review): this assumes the decoded text starts with the
    # prompt exactly as typed - confirm for prompts with unusual spacing.
    print(f"\nüìù PROMPT: {p}")
    print("-" * 20)
    print(f"‚ö™ BASELINE: {text_b[len(p):].strip()}")
    print("-" * 20)
    print(f"üîµ HERO:     {text_h[len(p):].strip()}")
    print("="*60)

WikiText-103 Partitioning Cell Below

In [None]:
# @title
import os
import numpy as np
import math
from google.colab import drive

# --- SETUP ---
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
SOURCE_FILE = os.path.join(PROJECT_ROOT, "data/wikitext/train.bin")
DEST_DIR = os.path.join(PROJECT_ROOT, "data/wikitext_chunks")
os.makedirs(DEST_DIR, exist_ok=True)

# Spec Constants
TOTAL_STEPS = 20000
CHUNK_STEPS = 500
NUM_CHUNKS = TOTAL_STEPS // CHUNK_STEPS  # Should be 40
BATCH_SIZE = 16
GRAD_ACCUM = 4
SEQ_LEN = 512

# Tokens per Step = 16 * 4 * 512 = 32,768
TOKENS_PER_STEP = BATCH_SIZE * GRAD_ACCUM * SEQ_LEN
TOKENS_PER_CHUNK = TOKENS_PER_STEP * CHUNK_STEPS # ~16.38M tokens

print(f"üî™ WIKITEXT PARTITIONING PROTOCOL")
print(f"   Target: {NUM_CHUNKS} Chunks")
print(f"   Size:   {TOKENS_PER_CHUNK:,} tokens/chunk")

# --- EXECUTION ---
if not os.path.exists(SOURCE_FILE):
    print(f"‚ùå Critical Error: Source file not found at {SOURCE_FILE}")
    # exit() does not abort a notebook cell (the statements below would still
    # run), so stop with a real exception instead.
    raise FileNotFoundError(SOURCE_FILE)

# Load Source (Memory Map to avoid RAM explosion)
data = np.memmap(SOURCE_FILE, dtype=np.uint16, mode='r')
total_tokens = len(data)
print(f"   Source: {total_tokens:,} tokens available")

if total_tokens == 0:
    # The wrap-around fill below takes `% total_tokens` and could never make
    # progress on an empty source.
    raise ValueError(f"Source file contains no tokens: {SOURCE_FILE}")

# Validation
required_tokens = TOKENS_PER_CHUNK * NUM_CHUNKS
if total_tokens < required_tokens:
    print(f"‚ö†Ô∏è  WARNING: Source too small! Need {required_tokens:,}, have {total_tokens:,}")
    # Not fatal: chunks that run past the end of the source are filled by
    # cycling back to its start (see the slicing loop), so every chunk file
    # still has exactly TOKENS_PER_CHUNK tokens.
else:
    print(f"   Status: Sufficient Data ({total_tokens / required_tokens:.2f}x coverage)")

print("-" * 40)

# Slicing Loop
for i in range(NUM_CHUNKS):
    start = i * TOKENS_PER_CHUNK
    end = start + TOKENS_PER_CHUNK

    if end > total_tokens:
        # The chunk extends past the source: build a STATIC chunk by looping
        # the source data, so the training loader never has to deal with
        # wrap-around indices itself.
        print(f"   ‚ö†Ô∏è Wrapping data for Chunk {i}...")

        # Create a buffer for this chunk
        chunk_data = np.empty(TOKENS_PER_CHUNK, dtype=np.uint16)

        # Fill it, copying the longest contiguous run available each pass.
        current_fill = 0
        src_ptr = start % total_tokens

        while current_fill < TOKENS_PER_CHUNK:
            available = min(total_tokens - src_ptr, TOKENS_PER_CHUNK - current_fill)
            chunk_data[current_fill : current_fill + available] = data[src_ptr : src_ptr + available]
            current_fill += available
            src_ptr = (src_ptr + available) % total_tokens

    else:
        # Direct slice
        chunk_data = data[start:end]

    # Save
    fname = f"train_chunk_{i:03d}.bin"
    fpath = os.path.join(DEST_DIR, fname)

    # Write to disk
    # We use open/write for safety over memmap flush
    with open(fpath, 'wb') as f:
        f.write(chunk_data.tobytes())

    print(f"   ‚úÖ Wrote {fname} ({len(chunk_data):,} tokens)")

print("\nüéâ PARTITIONING COMPLETE.")
print(f"   Ready for {TOTAL_STEPS} step marathon.")

Testing the Chunks

In [None]:
# @title
import os
import numpy as np
from google.colab import drive

# --- SETUP ---
# Mount Google Drive only when it is not already visible.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
CHUNKS_DIR = os.path.join(PROJECT_ROOT, "data/wikitext_chunks")

# Specs the chunk files on disk are checked against.
EXPECTED_COUNT = 40
TOKENS_PER_CHUNK = 16_384_000
# Tokens are stored as uint16, i.e. two bytes each.
EXPECTED_BYTES = 2 * TOKENS_PER_CHUNK

print(f"üïµÔ∏è JANUS DATA FORENSICS")
print(f"   Target: {CHUNKS_DIR}")

def verify():
    """Sanity-check the chunk files in ``CHUNKS_DIR`` and print a report.

    Checks, in order: the directory exists, the number of
    ``train_chunk_*.bin`` files equals ``EXPECTED_COUNT`` (warning only),
    every file is exactly ``EXPECTED_BYTES`` long (abort on mismatch), and
    the tail of the last chunk is not mostly zeros. Returns ``None``.
    """
    if not os.path.exists(CHUNKS_DIR):
        print("‚ùå CRITICAL: Directory not found.")
        return

    files = sorted([f for f in os.listdir(CHUNKS_DIR) if f.startswith("train_chunk_") and f.endswith(".bin")])

    # 1. Count Check
    if len(files) == EXPECTED_COUNT:
        print(f"‚úÖ Found {len(files)} chunks (Correct).")
    else:
        print(f"‚ö†Ô∏è  WARNING: Found {len(files)} chunks (Expected {EXPECTED_COUNT}).")

    if not files:
        # With nothing to inspect, the size check would vacuously report
        # success and the content check would crash on files[0].
        print("Abort: no chunk files found; nothing to inspect.")
        return

    # 2. Byte Check
    size_errors = 0
    for f in files:
        path = os.path.join(CHUNKS_DIR, f)
        size = os.path.getsize(path)
        if size != EXPECTED_BYTES:
            print(f"   ‚ùå {f}: SIZE MISMATCH! {size:,} bytes (Expected {EXPECTED_BYTES:,})")
            size_errors += 1

    if size_errors == 0:
        print("‚úÖ All file sizes are byte-perfect.")
    else:
        print(f"‚ùå Abort: {size_errors} chunks are corrupted.")
        return

    # 3. Content Logic Check (First and Last)
    print("\nüî¨ Inspecting Content...")

    # Check First Chunk
    c0 = files[0]
    data0 = np.memmap(os.path.join(CHUNKS_DIR, c0), dtype=np.uint16, mode='r')
    print(f"   [{c0}] Sample: {data0[:10]}")

    # Check Last Chunk (Crucial for loop logic verification)
    cL = files[-1]
    dataL = np.memmap(os.path.join(CHUNKS_DIR, cL), dtype=np.uint16, mode='r')

    # Check for "Zero Death" (if the buffer wasn't filled, the end would be all zeros)
    last_1k = dataL[-1000:]
    zeros = np.sum(last_1k == 0)

    print(f"   [{cL}] Last 1000 tokens: {zeros} zeros found.")

    if zeros > 900:
        print("   ‚ö†Ô∏è  WARNING: The end of the last chunk is mostly zeros. The dataset loop might have failed.")
    else:
        print("   ‚úÖ Loop Logic Confirmed: Data appears dense through the end of the last chunk.")

    print("\nüèÅ DATASET STATUS: READY FOR MARATHON.")

if __name__ == "__main__":
    verify()

In [None]:
# @title [Analyze] Marathon Telemetry Analyzer

import sys
import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from google.colab import drive

# --- SETUP ---
# Mount Google Drive only when it is not already visible.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"

# Telemetry inputs: the marathon run being inspected and the reference run.
MARATHON_PATH = os.path.join(
    PROJECT_ROOT, "data/models/janus_marathon_v3/telemetry_marathon.parquet"
)
HERO_PATH = os.path.join(
    PROJECT_ROOT, "data/models/janus_wikitext_hero/telemetry_hero.parquet"
)

# Output location for the generated chart.
REPORT_DIR = os.path.join(PROJECT_ROOT, "reports")
os.makedirs(REPORT_DIR, exist_ok=True)

print("üïµÔ∏è JANUS MARATHON DIAGNOSTICS")

def load_data(path, label):
    """Load one telemetry parquet file and tag it for plotting.

    Returns ``None`` (after printing a message) when ``path`` does not exist.
    Otherwise returns the frame sorted by ``step`` with only the last row kept
    for each step, a ``Model`` column set to ``label``, and a ``sigma_a_avg``
    column holding the row-wise mean of every column whose name contains
    ``sigma_a`` (NaN when there is no such column).
    """
    if not os.path.exists(path):
        print(f"‚ùå Missing: {path}")
        return None

    frame = (
        pd.read_parquet(path)
        .sort_values('step')
        .drop_duplicates(subset='step', keep='last')
    )
    frame['Model'] = label

    # Average the per-layer sigma_a columns into a single summary column.
    sigma_columns = [name for name in frame.columns if 'sigma_a' in name]
    if not sigma_columns:
        frame['sigma_a_avg'] = np.nan
    else:
        frame['sigma_a_avg'] = frame[sigma_columns].mean(axis=1)

    return frame

def run_diagnostics():
    """Print the latest marathon telemetry row and save a 3-panel chart.

    Both telemetry files are loaded; the function returns early when the
    marathon file is missing. The reference run is overlaid only if present.
    The chart (validation loss, pressure, averaged sigma_a) is written to
    ``REPORT_DIR/marathon_diagnostic.png``.
    """
    current = load_data(MARATHON_PATH, "Current Marathon")
    reference = load_data(HERO_PATH, "Reference Hero (6k)")

    if current is None:
        return

    # Overlay the reference run when it is available.
    combined = current if reference is None else pd.concat([reference, current])

    latest = current.iloc[-1]
    print(f"\nüìä Current Status (Step {current['step'].max()}):")
    print(f"   Loss:      {latest['loss']:.4f}")
    print(f"   Val Loss:  {latest['val_loss']:.4f}")
    print(f"   Pressure:  {latest['pressure']:.4f}")
    sigma_now = latest['sigma_a_avg']
    if not np.isnan(sigma_now):
        print(f"   Sigma A:   {sigma_now:.4f}")

    # PLOTTING
    fig, axes = plt.subplots(3, 1, figsize=(12, 15), sharex=True)
    ax_val, ax_pressure, ax_sigma = axes[0], axes[1], axes[2]
    palette = {'Reference Hero (6k)': 'gray', 'Current Marathon': 'tab:green'}

    # Panel 1: validation loss
    sns.lineplot(data=combined, x='step', y='val_loss', hue='Model',
                 ax=ax_val, palette=palette, linewidth=2)
    ax_val.set_title("Validation Loss Tracking")
    ax_val.grid(True, alpha=0.3)

    # Panel 2: pressure schedule
    sns.lineplot(data=combined, x='step', y='pressure', hue='Model',
                 ax=ax_pressure, palette=palette, linewidth=2, linestyle='--')
    ax_pressure.set_title("Pressure Schedule Check")
    ax_pressure.grid(True, alpha=0.3)

    # Panel 3: averaged sigma_a, drawn only for rows where it was recorded
    with_sigma = combined.dropna(subset=['sigma_a_avg'])
    if not with_sigma.empty:
        sns.lineplot(data=with_sigma, x='step', y='sigma_a_avg', hue='Model',
                     ax=ax_sigma, palette=palette, marker='o')
        ax_sigma.set_title("Head Redundancy (Sigma A)")
        ax_sigma.set_ylim(0, 0.02)  # zoom in on the low range
        ax_sigma.grid(True, alpha=0.3)

    plt.tight_layout()
    save_path = os.path.join(REPORT_DIR, "marathon_diagnostic.png")
    plt.savefig(save_path)
    print(f"\n‚úÖ Diagnostic Chart Saved: {save_path}")

if __name__ == "__main__":
    run_diagnostics()

In [None]:
# @title [Data Prep] WikiText-2 Tokenization & Preparation
# Janus Engineering Framework - Sub-module 4.1

import os
import numpy as np
from tqdm import tqdm
from transformers import GPT2Tokenizer
from google.colab import drive

# --- 1. SETUP & PATHS ---
import subprocess
import sys

if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# Target directory as specified: Project_XAI_Physical_Janus/data/Wikitext_2/
DATA_DIR = os.path.join(PROJECT_ROOT, "data/Wikitext_2")
os.makedirs(DATA_DIR, exist_ok=True)

print(f"üìö WIKITEXT-2 PREP (JANUS V3 SPECS)")
print(f"üìÇ Output Directory: {DATA_DIR}")


def _pip_install(package):
    """Install *package* into the interpreter that is running this cell.

    Going through ``sys.executable -m pip`` targets the active environment,
    and ``check_call`` raises ``CalledProcessError`` if the install fails
    (``os.system`` would ignore the failure and let the retry import crash
    with a less helpful error).
    """
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])


# --- 2. AUTHENTICATION (The "Janus Workaround") ---
TOKEN_FILE = "/content/hf_token.txt"

try:
    from huggingface_hub import login
except ImportError:
    print("üì¶ Installing huggingface_hub...")
    _pip_install('huggingface_hub')
    from huggingface_hub import login

if os.path.exists(TOKEN_FILE):
    with open(TOKEN_FILE, 'r', encoding='utf-8') as f:
        # Strip a Byte Order Mark (BOM) and surrounding whitespace so the
        # token is passed to login() clean.
        token = f.read().replace('\ufeff', '').strip()

    print(f"üîë Token found (length: {len(token)}). Logging in...")
    try:
        login(token=token)
        print("‚úÖ Authenticated successfully.")
    except Exception as e:
        # Best effort: a failed login is reported but does not stop the cell.
        print(f"‚ùå Login Failed: {e}")
else:
    print(f"‚ö†Ô∏è  WARNING: '{TOKEN_FILE}' not found. Dataset must be public.")

# --- 3. LOAD DATASET ---
try:
    from datasets import load_dataset
except ImportError:
    print("üì¶ Installing datasets library...")
    _pip_install('datasets')
    from datasets import load_dataset

print("‚¨áÔ∏è  Fetching WikiText-2 via Hugging Face...")
# Standard WikiText-2 raw split for Janus v3 generalization tests
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# --- 4. PROCESSING ---
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

def process_split(hf_split, bin_name):
    """Tokenize one dataset split and write it to ``DATA_DIR/bin_name``.

    Args:
        hf_split: key into the module-level ``dataset`` (e.g. ``"train"``).
        bin_name: output file name; the split is skipped when that file
            already exists.

    Lines that are empty after stripping are dropped. Every remaining line is
    followed by the tokenizer's EOS token, and the resulting ids are written
    as a flat ``uint16`` array.
    """
    bin_path = os.path.join(DATA_DIR, bin_name)

    if os.path.exists(bin_path):
        print(f"‚úÖ {bin_name} already exists. Skipping.")
        return

    print(f"‚öôÔ∏è  Processing {hf_split} -> {bin_name}...")

    texts = dataset[hf_split]['text']
    eos = tokenizer.eos_token
    all_ids = []

    # Tokenize in batches so no single encode() call sees the whole split.
    batch_size = 1000
    for i in tqdm(range(0, len(texts), batch_size)):
        batch = texts[i : i + batch_size]
        # Filter empty lines
        valid_batch = [t for t in batch if len(t.strip()) > 0]
        if not valid_batch: continue

        # Terminate EVERY text with EOS. A plain join() only separates texts
        # inside a batch, so the last text of one batch used to run straight
        # into the first text of the next with no boundary token.
        text_chunk = eos.join(valid_batch) + eos
        ids = tokenizer.encode(text_chunk)
        all_ids.extend(ids)

    # Save as uint16 (two bytes per token id)
    arr = np.array(all_ids, dtype=np.uint16)
    arr.tofile(bin_path)
    print(f"   üíæ Saved {len(arr):,} tokens to {bin_path}")

# --- 5. EXECUTION ---
# (split name, output file) pairs, processed in this order.
for _split, _bin_file in (
    ("train", "W2train.bin"),
    ("validation", "W2val.bin"),
    ("test", "W2test.bin"),
):
    process_split(_split, _bin_file)

print("\nüéâ WikiText-2 Vector Space Homeostasis Ready.")

In [None]:
# @title [Run] Janus v3 Spec Convergence Run


import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import pandas as pd
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. MEMORY NUKE ---
# Release Python garbage and cached CUDA blocks left over from earlier cells.
gc.collect()
torch.cuda.empty_cache()

# --- 2. SETUP ---
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# Training chunks are read from CHUNKS_DIR; checkpoints and telemetry go to SAVE_DIR.
CHUNKS_DIR = os.path.join(PROJECT_ROOT, "data/wikitext_chunks")
SAVE_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_marathon_v3")
os.makedirs(SAVE_DIR, exist_ok=True)

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"üèÉ JANUS MARATHON (FIXED)")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 3. CONFIGURATION ---
class MarathonConfig:
    """Fixed hyper-parameters for the steered marathon run."""

    def __init__(self):
        # Model shape
        self.vocab_size = 50257
        self.d_model = 512
        self.n_layers = 12
        self.n_heads = 8
        self.d_head = 64
        self.max_seq_len = 512
        self.dropout = 0.05

        # Optimisation schedule
        self.max_steps = 20000
        self.batch_size = 16
        self.grad_accum = 4
        # Peak value of the steering pressure schedule
        self.max_pressure = 0.15

        # Checkpoint cadence and steps served per data chunk
        self.ckpt_interval = 2000
        self.chunk_steps = 500

        # Telemetry cadence (lite = validation pass, full = head metrics)
        self.lite_interval = 1000
        self.full_interval = 2500

# --- 4. ARCHITECTURE ---
class RMSNorm(nn.Module):
    """Scale the last dimension by its inverse RMS, then by a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to one.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(silu(w1(x)) * w2(x))``."""

    def __init__(self, d_model):
        super().__init__()
        # Hidden width is 8/3 of the model width, truncated to an int.
        hidden_dim = int(d_model * 8 / 3)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)  # gate branch
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)  # value branch
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)  # output projection

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embedding.

    During training an optional steering penalty pushes the per-head outputs
    towards mutual orthogonality; on request the forward pass also reports
    head-redundancy metrics.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.dropout = nn.Dropout(config.dropout)
        self.register_buffer("freqs_cis", self.precompute_freqs_cis(config.d_head, config.max_seq_len))

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return unit-magnitude complex rotations of shape ``(end, dim // 2)``."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    def apply_rope(self, x, freqs_cis):
        """Rotate consecutive feature pairs of ``x`` by position.

        ``x`` is ``(B, S, n_heads, d_head)`` as passed from ``forward``; the
        first ``S`` rows of ``freqs_cis`` are broadcast over batch and heads.
        """
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs = freqs_cis[:x.shape[1]].view(1, x.shape[1], 1, -1)
        x_out = torch.view_as_real(x_c * freqs).flatten(3)
        return x_out.type_as(x)

    def _head_metrics(self, head_out):
        """Compute ``sigma_a`` and ``eff_rank`` from ``(B, H, S, d_head)`` outputs."""
        metrics = {}
        with torch.no_grad():
            # sigma_a: mean absolute cosine similarity of each head's
            # flattened output to every other head.
            flat = head_out.transpose(0, 1).reshape(self.n_heads, -1)
            norm = F.normalize(flat, p=2, dim=1)
            sim = torch.mm(norm, norm.t())
            mask_diag = ~torch.eye(self.n_heads, dtype=torch.bool, device=head_out.device)
            metrics['sigma_a'] = (sim.abs() * mask_diag.float()).sum(dim=1) / (self.n_heads - 1)

            # eff_rank: exp(entropy) of the normalised singular values of each
            # head's output over the first 128 positions. Heads must be moved
            # to dim 0 BEFORE flattening (transpose(0, 1), as for sigma_a);
            # flattening a (B, S, H, d) layout would mix rows of different
            # heads into every group.
            sub_out = head_out[:, :, :128, :].transpose(0, 1).reshape(self.n_heads, -1, self.d_head)
            ranks = []
            for h in range(self.n_heads):
                try:
                    S_vals = torch.linalg.svdvals(sub_out[h].float())
                    p = S_vals / S_vals.sum()
                    ent = -torch.sum(p * torch.log(p + 1e-9))
                    ranks.append(torch.exp(ent))
                except RuntimeError:
                    # SVD failure: report 0 for this head, on the same device
                    # as the other entries so torch.stack still works.
                    ranks.append(torch.tensor(0.0, device=head_out.device))
            metrics['eff_rank'] = torch.stack(ranks).to(head_out.device)
        return metrics

    def forward(self, x, lambdas, return_metrics=False):
        """Attend over ``x`` of shape ``(B, S, D)``.

        Args:
            x: input activations.
            lambdas: ``(l_coh, l_div)`` pair; only ``l_div`` is used here, as
                the weight of the head-diversity penalty.
            return_metrics: when true, also compute head metrics.

        Returns:
            ``(output, steer_loss, metrics)`` where ``steer_loss`` is ``0.0``
            unless ``l_div > 0`` in training mode, and ``metrics`` is an empty
            dict unless ``return_metrics`` is set.
        """
        B, S, D = x.shape
        q = self.q_proj(x).view(B, S, self.n_heads, self.d_head)
        k = self.k_proj(x).view(B, S, self.n_heads, self.d_head)
        v = self.v_proj(x).view(B, S, self.n_heads, self.d_head)

        q = self.apply_rope(q, self.freqs_cis)
        k = self.apply_rope(k, self.freqs_cis)
        q = q.transpose(1, 2); k = k.transpose(1, 2); v = v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Lower-triangular mask: position t may only attend to positions <= t.
        mask = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        attn = attn.masked_fill(mask == 0, float('-inf'))
        attn_probs = F.softmax(attn, dim=-1)
        attn_probs = self.dropout(attn_probs)
        head_out = attn_probs @ v

        steer_loss = 0.0
        l_coh, l_div = lambdas
        if l_div > 0.0 and self.training:
            # Frobenius distance between the head Gram matrix and identity.
            flat = head_out.transpose(0, 1).reshape(self.n_heads, -1)
            norm = F.normalize(flat, p=2, dim=1)
            gram = torch.mm(norm, norm.t())
            identity = torch.eye(self.n_heads, device=x.device)
            steer_loss += torch.norm(gram - identity, p='fro') * l_div

        metrics = self._head_metrics(head_out) if return_metrics else {}

        out = head_out.transpose(1, 2).reshape(B, S, D)
        return self.o_proj(out), steer_loss, metrics

class NewBlock(nn.Module):
    """Pre-norm transformer block: attention then SwiGLU, each with a residual."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = RMSNorm(config.d_model)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(config.d_model)
        self.mlp = SwiGLU(config.d_model)

    def forward(self, x, lambdas, return_metrics=False):
        """Return ``(hidden, steer_loss, metrics)``; the last two come from attention."""
        attn_out, steer, head_metrics = self.attn(self.ln1(x), lambdas, return_metrics)
        hidden = x + attn_out
        hidden = hidden + self.mlp(self.ln2(hidden))
        return hidden, steer, head_metrics

class NewGPT(nn.Module):
    """Decoder-only language model with tied input/output embeddings."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the embedding table and the output head share one tensor.
        self.token_emb.weight = self.head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) init, with a depth-scaled std for residual projections."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        # Projections that write into the residual stream get a smaller std.
        for param_name, param in module.named_parameters():
            is_residual_proj = any(
                tag in param_name for tag in ("o_proj.weight", "w3.weight")
            )
            if is_residual_proj:
                torch.nn.init.normal_(
                    param, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layers)
                )

    def forward(self, idx, lambdas_list, targets=None, return_metrics=False):
        """Run the model on token ids ``idx`` of shape ``(B, S)``.

        ``lambdas_list[i]`` is handed to block ``i``. Returns
        ``(loss, total_steer, all_metrics)``: ``loss`` is the cross-entropy
        against ``targets`` (``None`` without targets), ``total_steer`` sums
        the per-block steering losses, and ``all_metrics`` holds one dict per
        block only when ``return_metrics`` is set.
        """
        B, S = idx.shape
        hidden = self.token_emb(idx)
        total_steer = 0.0
        all_metrics = []
        for layer_idx, block in enumerate(self.blocks):
            hidden, steer, block_metrics = block(hidden, lambdas_list[layer_idx], return_metrics)
            total_steer += steer
            if return_metrics:
                all_metrics.append(block_metrics)
        logits = self.head(self.ln_f(hidden))
        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return loss, total_steer, all_metrics

# --- 5. LOGIC CONTROL ---
class FlightController:
    """Maps a training step to the steering pressure and per-layer lambdas."""

    def __init__(self, config):
        self.config = config

    def get_pressure(self, step):
        """Piecewise schedule: off, linear ramp, hold, cosine decay, off."""
        if step < 1500:
            return 0.0
        if step < 3000:
            # Linear ramp from 0 to max_pressure over steps 1500..3000.
            return self.config.max_pressure * ((step - 1500) / 1500)
        if step < 4500:
            return self.config.max_pressure
        if step < 7500:
            # Half-cosine decay from max_pressure down to 0.
            progress = min((step - 4500) / 3000, 1.0)
            return self.config.max_pressure * 0.5 * (1 + math.cos(math.pi * progress))
        return 0.0

    def get_lambdas(self, step):
        """Return ``(lambdas, pressure)``.

        ``lambdas`` has one ``(coherence, diversity)`` pair per layer, each
        scaled by ``((layer + 1) / n_layers) ** 3``; coherence uses 0.2x the
        pressure.
        """
        p = self.get_pressure(step)
        base_coh = p * 0.2
        depth = self.config.n_layers
        spatial = [((layer + 1) / depth) ** 3 for layer in range(depth)]
        lambdas = [(base_coh * s_mult, p * s_mult) for s_mult in spatial]
        return lambdas, p

class ChunkLoader:
    """Serves random training windows from the current chunk file and
    random validation windows from a memory-mapped validation file."""

    def __init__(self, config):
        self.config = config
        self.current_chunk_idx = -1
        self.data = None
        val_path = os.path.join(PROJECT_ROOT, "data/wikitext/val.bin")
        self.val_data = np.memmap(val_path, dtype=np.uint16, mode='r')

    def load_for_step(self, step):
        """Make sure the chunk that covers ``step`` is the one held in memory."""
        target_chunk = step // self.config.chunk_steps
        if target_chunk == self.current_chunk_idx:
            return

        fpath = os.path.join(CHUNKS_DIR, f"train_chunk_{target_chunk:03d}.bin")
        if not os.path.exists(fpath):
            # Missing file: fall back to the chunk index taken modulo 40.
            fallback_idx = target_chunk % 40
            fpath = os.path.join(CHUNKS_DIR, f"train_chunk_{fallback_idx:03d}.bin")
        with open(fpath, 'rb') as f:
            self.data = np.frombuffer(f.read(), dtype=np.uint16)
        self.current_chunk_idx = target_chunk

    def _sample(self, source, batch_size):
        """Draw ``batch_size`` random windows from ``source``; y is x shifted by one."""
        seq_len = self.config.max_seq_len
        ix = torch.randint(len(source) - seq_len, (batch_size,))
        inputs = [torch.from_numpy(source[i:i+seq_len].astype(np.int64)) for i in ix]
        targets = [torch.from_numpy(source[i+1:i+1+seq_len].astype(np.int64)) for i in ix]
        x = torch.stack(inputs)
        y = torch.stack(targets)
        return x.to(DEVICE), y.to(DEVICE)

    def get_batch(self, batch_size):
        """Return an ``(x, y)`` training batch from the loaded chunk."""
        return self._sample(self.data, batch_size)

    def get_val_batch(self, batch_size):
        """Return an ``(x, y)`` batch from the validation data."""
        return self._sample(self.val_data, batch_size)

class BlackBox:
    """Buffers telemetry rows and appends them to ``telemetry_marathon.parquet``."""

    def __init__(self, save_dir):
        self.buffer = []
        self.save_dir = save_dir
        self.start_step = 0 # Rows for earlier steps are ignored

    def set_start_step(self, step):
        """Set the first step that ``log`` will accept."""
        self.start_step = step

    def log(self, step, loss, val_loss, pressure, metrics):
        """Buffer one telemetry row.

        Rows with ``step < start_step`` are dropped. ``perplexity`` is
        ``exp(val_loss)``, or ``0.0`` when ``val_loss`` is 20 or more.
        ``metrics`` is a per-layer list of dicts (possibly empty or ``None``);
        each layer's ``sigma_a`` mean is stored as ``L{i}_sigma_a``.
        """
        if step < self.start_step: return

        row = {
            "step": step, "loss": loss, "val_loss": val_loss,
            "perplexity": math.exp(val_loss) if val_loss < 20 else 0.0,
            "pressure": pressure
        }
        for i, m in enumerate(metrics or []):
            if m and 'sigma_a' in m:
                row[f"L{i}_sigma_a"] = m['sigma_a'].mean().item()
        self.buffer.append(row)

    def flush(self):
        """Append the buffered rows to the parquet file and clear the buffer.

        If an existing telemetry file cannot be read it is moved aside to
        ``<file>.corrupt`` (with a warning) instead of being silently
        overwritten, so earlier history is never destroyed without notice.
        """
        if not self.buffer: return
        df = pd.DataFrame(self.buffer)
        fpath = os.path.join(self.save_dir, "telemetry_marathon.parquet")
        if os.path.exists(fpath):
            try:
                existing = pd.read_parquet(fpath)
            except Exception as err:
                backup = fpath + ".corrupt"
                os.replace(fpath, backup)
                print(f"WARNING: could not read existing telemetry ({err}); moved it to {backup}")
            else:
                df = pd.concat([existing, df])
        df.to_parquet(fpath)
        self.buffer = []
        print("üíæ Telemetry Flushed.")

# --- 6. MAIN ---
def run_marathon():
    """Train the steered model for ``cfg.max_steps`` steps with resume support.

    Resumes from the highest-numbered ``ckpt_step_<N>.pt`` in ``SAVE_DIR`` if
    one exists (model, optimizer and RNG state), then runs gradient-accumulated
    training. Every ``lite_interval`` steps it runs a validation pass and
    flushes telemetry; every ``ckpt_interval`` steps it writes a checkpoint
    and prunes the oldest one once more than three exist. The final weights go
    to ``janus_v3_marathon_final.pt``.
    """
    gc.collect(); torch.cuda.empty_cache()

    def ckpt_step(fname):
        # "ckpt_step_<N>.pt" -> N
        return int(fname.split('_')[2].split('.')[0])

    def list_ckpts():
        # Checkpoint file names, oldest first by step NUMBER. A plain string
        # sort would put "ckpt_step_10000.pt" before "ckpt_step_2000.pt".
        names = [f for f in os.listdir(SAVE_DIR) if f.startswith("ckpt_step_") and f.endswith(".pt")]
        return sorted(names, key=ckpt_step)

    cfg = MarathonConfig()
    model = NewGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6.29e-4, weight_decay=1.34e-4)
    scheduler = FlightController(cfg)
    loader = ChunkLoader(cfg)
    recorder = BlackBox(SAVE_DIR)

    start_step = 0
    ckpts = list_ckpts()

    if ckpts:
        latest = ckpts[-1]
        ckpt_path = os.path.join(SAVE_DIR, latest)
        print(f"üîÑ Found checkpoints. Loading {latest}...")

        # Deliberately no try/except: a corrupt checkpoint should crash loudly.
        c = torch.load(ckpt_path, map_location=DEVICE)
        model.load_state_dict(c['model'])
        optimizer.load_state_dict(c['optim'])

        # map_location may have moved the saved RNG states to the GPU; the
        # set_rng_state calls below are given CPU tensors.
        cpu_rng = c['rng_cpu'].cpu() if c['rng_cpu'].is_cuda else c['rng_cpu']
        gpu_rng = c['rng_gpu'].cpu() if c['rng_gpu'].is_cuda else c['rng_gpu']

        torch.set_rng_state(cpu_rng)
        torch.cuda.set_rng_state(gpu_rng)

        start_step = c['step'] + 1
        recorder.set_start_step(start_step) # Tell logger to ignore steps < start_step
        print(f"‚úÖ State Restored (Step {start_step}). RNG Verified.")

    print(f"\nüèÉ STARTING MARATHON: {start_step} -> {cfg.max_steps}")
    pbar = tqdm(range(start_step, cfg.max_steps), initial=start_step, total=cfg.max_steps)

    for step in pbar:
        loader.load_for_step(step)
        model.train()
        batch_loss = 0.0
        lambdas, pressure = scheduler.get_lambdas(step)
        optimizer.zero_grad()
        for micro in range(cfg.grad_accum):
            x, y = loader.get_batch(cfg.batch_size)
            # Head metrics are only collected on the last micro-batch of a
            # "full telemetry" step.
            is_full = (step % cfg.full_interval == 0) and (micro == cfg.grad_accum - 1)
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                loss, steer, metrics = model(x, lambdas, y, return_metrics=is_full)
                total = (loss + steer) / cfg.grad_accum
            total.backward()
            batch_loss += loss.item()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if step > 0 and (step % cfg.lite_interval == 0 or step == cfg.max_steps - 1):
            gc.collect(); torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                v_losses = []
                for _ in range(20):
                    vx, vy = loader.get_val_batch(cfg.batch_size)
                    # Validation runs with steering switched off.
                    vl, _, _ = model(vx, [(0.0,0.0)]*cfg.n_layers, vy, return_metrics=False)
                    v_losses.append(vl.item())
                val_loss = np.mean(v_losses)
            recorder.log(step, batch_loss/cfg.grad_accum, val_loss, pressure, metrics)
            recorder.flush()
            pbar.set_description(f"L:{batch_loss/cfg.grad_accum:.3f}|V:{val_loss:.3f}|P:{pressure:.3f}")
        else:
            pbar.set_description(f"L:{batch_loss/cfg.grad_accum:.3f}|P:{pressure:.3f}")

        if step > 0 and step % cfg.ckpt_interval == 0:
            ckpt_name = f"ckpt_step_{step}.pt"
            ckpt_path = os.path.join(SAVE_DIR, ckpt_name)
            save_dict = {
                'step': step,
                'model': model.state_dict(),
                'optim': optimizer.state_dict(),
                'rng_cpu': torch.get_rng_state(),
                'rng_gpu': torch.cuda.get_rng_state()
            }
            torch.save(save_dict, ckpt_path)
            # Keep at most three checkpoints, dropping the OLDEST by step.
            # With a lexicographic sort, from step 10000 onward the file just
            # written sorted first and was the one deleted.
            all_ckpts = list_ckpts()
            if len(all_ckpts) > 3: os.remove(os.path.join(SAVE_DIR, all_ckpts[0]))

    final_path = os.path.join(SAVE_DIR, "janus_v3_marathon_final.pt")
    torch.save(model.state_dict(), final_path)
    print(f"\nüèÜ MARATHON COMPLETE. Saved to {final_path}")

if __name__ == "__main__":
    run_marathon()

In [None]:
# @title [Run] Janus 20k Baseline


import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import pandas as pd
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. MEMORY NUKE ---
# Release Python garbage and cached CUDA blocks left over from earlier cells.
gc.collect()
torch.cuda.empty_cache()

# --- 2. SETUP ---
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
CHUNKS_DIR = os.path.join(PROJECT_ROOT, "data/wikitext_chunks")
# The baseline writes to its own folder, separate from the steered run.
SAVE_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_marathon_baseline")
os.makedirs(SAVE_DIR, exist_ok=True)

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"üìâ JANUS MARATHON BASELINE (Control Group - 20k)")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 3. CONFIGURATION ---
class MarathonConfig:
    """Fixed hyper-parameters for the unsteered baseline run."""

    def __init__(self):
        # Model shape
        self.vocab_size = 50257
        self.d_model = 512
        self.n_layers = 12
        self.n_heads = 8
        self.d_head = 64
        self.max_seq_len = 512
        self.dropout = 0.05

        # Optimisation schedule
        self.max_steps = 20000
        self.batch_size = 16
        self.grad_accum = 4

        # Checkpoint cadence and steps served per data chunk
        self.ckpt_interval = 2000
        self.chunk_steps = 500

        # Telemetry cadence
        self.lite_interval = 1000
        self.full_interval = 2500

# --- 4. ARCHITECTURE (Identical - No Steering) ---
class RMSNorm(nn.Module):
    """Scale the last dimension by its inverse RMS, then by a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to one.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(silu(w1(x)) * w2(x))``."""

    def __init__(self, d_model):
        super().__init__()
        # Hidden width is 8/3 of the model width, truncated to an int.
        hidden_dim = int(d_model * 8 / 3)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)  # gate branch
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)  # value branch
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)  # output projection

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embedding.

    Baseline variant: no steering loss. On request the forward pass reports
    head-redundancy metrics.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.dropout = nn.Dropout(config.dropout)
        self.register_buffer("freqs_cis", self.precompute_freqs_cis(config.d_head, config.max_seq_len))

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return unit-magnitude complex rotations of shape ``(end, dim // 2)``."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    def apply_rope(self, x, freqs_cis):
        """Rotate consecutive feature pairs of ``x`` by position.

        ``x`` is ``(B, S, n_heads, d_head)`` as passed from ``forward``; the
        first ``S`` rows of ``freqs_cis`` are broadcast over batch and heads.
        """
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs = freqs_cis[:x.shape[1]].view(1, x.shape[1], 1, -1)
        x_out = torch.view_as_real(x_c * freqs).flatten(3)
        return x_out.type_as(x)

    def _head_metrics(self, head_out):
        """Compute ``sigma_a`` and ``eff_rank`` from ``(B, H, S, d_head)`` outputs."""
        metrics = {}
        with torch.no_grad():
            # sigma_a: mean absolute cosine similarity of each head's
            # flattened output to every other head.
            flat = head_out.transpose(0, 1).reshape(self.n_heads, -1)
            norm = F.normalize(flat, p=2, dim=1)
            sim = torch.mm(norm, norm.t())
            mask_diag = ~torch.eye(self.n_heads, dtype=torch.bool, device=head_out.device)
            metrics['sigma_a'] = (sim.abs() * mask_diag.float()).sum(dim=1) / (self.n_heads - 1)

            # eff_rank: exp(entropy) of the normalised singular values of each
            # head's output over the first 128 positions. Heads must be moved
            # to dim 0 BEFORE flattening (transpose(0, 1), as for sigma_a);
            # flattening a (B, S, H, d) layout would mix rows of different
            # heads into every group.
            sub_out = head_out[:, :, :128, :].transpose(0, 1).reshape(self.n_heads, -1, self.d_head)
            ranks = []
            for h in range(self.n_heads):
                try:
                    S_vals = torch.linalg.svdvals(sub_out[h].float())
                    p = S_vals / S_vals.sum()
                    ent = -torch.sum(p * torch.log(p + 1e-9))
                    ranks.append(torch.exp(ent))
                except RuntimeError:
                    # SVD failure: report 0 for this head, on the same device
                    # as the other entries so torch.stack still works.
                    ranks.append(torch.tensor(0.0, device=head_out.device))
            metrics['eff_rank'] = torch.stack(ranks).to(head_out.device)
        return metrics

    def forward(self, x, return_metrics=False):
        """Attend over ``x`` of shape ``(B, S, D)``.

        Returns ``(output, metrics)``; ``metrics`` is an empty dict unless
        ``return_metrics`` is set.
        """
        B, S, D = x.shape
        q = self.q_proj(x).view(B, S, self.n_heads, self.d_head)
        k = self.k_proj(x).view(B, S, self.n_heads, self.d_head)
        v = self.v_proj(x).view(B, S, self.n_heads, self.d_head)

        q = self.apply_rope(q, self.freqs_cis)
        k = self.apply_rope(k, self.freqs_cis)
        q = q.transpose(1, 2); k = k.transpose(1, 2); v = v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Lower-triangular mask: position t may only attend to positions <= t.
        mask = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        attn = attn.masked_fill(mask == 0, float('-inf'))
        attn_probs = F.softmax(attn, dim=-1)
        attn_probs = self.dropout(attn_probs)
        head_out = attn_probs @ v

        # NO STEERING LOGIC - BASELINE

        metrics = self._head_metrics(head_out) if return_metrics else {}

        out = head_out.transpose(1, 2).reshape(B, S, D)
        return self.o_proj(out), metrics

class NewBlock(nn.Module):
    """Pre-norm transformer block: attention, then SwiGLU MLP, each added residually."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = RMSNorm(width)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(width)
        self.mlp = SwiGLU(width)

    def forward(self, x, return_metrics=False):
        """Return ``(hidden, attention_metrics)`` for input ``x``."""
        attn_out, attn_metrics = self.attn(self.ln1(x), return_metrics)
        hidden = x + attn_out
        hidden = hidden + self.mlp(self.ln2(hidden))
        return hidden, attn_metrics

class NewGPT(nn.Module):
    """Decoder-only language model with input/output embedding weights tied."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Tie the embedding matrix to the output projection.
        self.token_emb.weight = self.head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) init; o_proj / w3 weights get a depth-scaled std."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        scaled_std = 0.02 / math.sqrt(2 * self.config.n_layers)
        for name, p in module.named_parameters():
            if any(tag in name for tag in ("o_proj.weight", "w3.weight")):
                torch.nn.init.normal_(p, mean=0.0, std=scaled_std)

    def forward(self, idx, targets=None, return_metrics=False):
        """Return ``(loss, per_layer_metrics)``; ``loss`` is ``None`` without targets."""
        B, S = idx.shape  # unpacking requires a 2-D index tensor
        hidden = self.token_emb(idx)
        collected = []
        for layer in self.blocks:
            hidden, layer_metrics = layer(hidden, return_metrics)
            if return_metrics:
                collected.append(layer_metrics)
        logits = self.head(self.ln_f(hidden))
        if targets is None:
            return None, collected
        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, targets.view(-1))
        return loss, collected

# --- 5. DATA & TELEMETRY ---
class ChunkLoader:
    """Serve random token windows from rotating training chunks and a validation file.

    Training data is read one chunk file at a time from ``CHUNKS_DIR``; the
    chunk index is ``step // config.chunk_steps``. Validation tokens are
    memory-mapped from ``data/wikitext/val.bin`` under ``PROJECT_ROOT``.
    Both are interpreted as ``uint16`` token ids.
    """

    # Modulus used to wrap the chunk index when the requested file is missing.
    FALLBACK_CHUNK_COUNT = 40

    def __init__(self, config):
        self.config = config
        self.current_chunk_idx = -1
        self.data = None
        self.val_data = np.memmap(os.path.join(PROJECT_ROOT, "data/wikitext/val.bin"), dtype=np.uint16, mode='r')

    def load_for_step(self, step):
        """Load the training chunk for ``step`` unless it is already resident."""
        target_chunk = step // self.config.chunk_steps
        if target_chunk == self.current_chunk_idx:
            return
        fpath = os.path.join(CHUNKS_DIR, f"train_chunk_{target_chunk:03d}.bin")
        if not os.path.exists(fpath):
            # Fallback wrap
            fallback_idx = target_chunk % self.FALLBACK_CHUNK_COUNT
            fpath = os.path.join(CHUNKS_DIR, f"train_chunk_{fallback_idx:03d}.bin")
        with open(fpath, 'rb') as f:
            self.data = np.frombuffer(f.read(), dtype=np.uint16)
        self.current_chunk_idx = target_chunk

    def _sample_windows(self, tokens, batch_size):
        """Draw ``batch_size`` random windows; targets are the inputs shifted by one."""
        seq_len = self.config.max_seq_len
        starts = torch.randint(len(tokens) - seq_len, (batch_size,)).tolist()
        x = torch.stack([torch.from_numpy(tokens[i:i + seq_len].astype(np.int64)) for i in starts])
        y = torch.stack([torch.from_numpy(tokens[i + 1:i + 1 + seq_len].astype(np.int64)) for i in starts])
        return x.to(DEVICE), y.to(DEVICE)

    def get_batch(self, batch_size):
        """Return an ``(x, y)`` training batch from the currently loaded chunk.

        Raises:
            RuntimeError: if no chunk has been loaded yet.
        """
        if self.data is None:
            raise RuntimeError("ChunkLoader.get_batch() called before load_for_step()")
        return self._sample_windows(self.data, batch_size)

    def get_val_batch(self, batch_size):
        """Return an ``(x, y)`` batch drawn from the validation tokens."""
        return self._sample_windows(self.val_data, batch_size)

class BlackBox:
    """Buffer telemetry rows in memory and append them to a parquet file on flush."""

    TELEMETRY_FILE = "telemetry_baseline_20k.parquet"

    def __init__(self, save_dir):
        self.buffer = []
        self.save_dir = save_dir
        self.start_step = 0

    def set_start_step(self, step):
        """Ignore any later ``log`` call whose step is below ``step``."""
        self.start_step = step

    def log(self, step, loss, val_loss, metrics):
        """Queue one telemetry row.

        Args:
            step: Training step the row belongs to.
            loss: Training loss to record.
            val_loss: Validation loss; perplexity is ``exp(val_loss)``, or
                0.0 when ``val_loss >= 20``.
            metrics: Optional per-layer list of metric dicts; each layer's
                ``sigma_a`` mean is stored as ``L{i}_sigma_a``.
        """
        if step < self.start_step:
            return
        row = {
            "step": step, "loss": loss, "val_loss": val_loss,
            "perplexity": math.exp(val_loss) if val_loss < 20 else 0.0,
            "pressure": 0.0 # Constant for baseline
        }
        for i, m in enumerate(metrics or []):
            if m and 'sigma_a' in m:
                row[f"L{i}_sigma_a"] = m['sigma_a'].mean().item()
        self.buffer.append(row)

    def flush(self):
        """Merge buffered rows into the parquet file and clear the buffer.

        An existing file that cannot be read is moved aside (best effort)
        instead of being silently overwritten. Rows sharing a ``step`` are
        de-duplicated, keeping the most recent one, so a resumed run does not
        leave two rows for the same step.
        """
        if not self.buffer:
            return
        df = pd.DataFrame(self.buffer)
        fpath = os.path.join(self.save_dir, self.TELEMETRY_FILE)
        if os.path.exists(fpath):
            try:
                existing = pd.read_parquet(fpath)
            except Exception as err:
                backup = fpath + ".unreadable"
                print(f"WARNING: could not read {fpath} ({type(err).__name__}: {err}); moving it to {backup}")
                try:
                    os.replace(fpath, backup)
                except OSError as move_err:
                    print(f"WARNING: backup failed ({move_err}); previous telemetry will be overwritten")
            else:
                df = pd.concat([existing, df], ignore_index=True)
                if "step" in df.columns:
                    df = df.drop_duplicates(subset="step", keep="last").reset_index(drop=True)
        df.to_parquet(fpath)
        self.buffer = []
        print("üíæ Telemetry Flushed.")

# --- 6. MAIN ---
def run_baseline_20k():
    """Train the baseline (unsteered) model for ``cfg.max_steps`` steps.

    Resumes from the highest-numbered ``ckpt_step_<N>.pt`` in ``SAVE_DIR``
    (model, optimizer and RNG state) when one exists. Validation telemetry is
    logged every ``cfg.lite_interval`` steps and on the last step; a
    checkpoint is written every ``cfg.ckpt_interval`` steps. The final model
    state dict is saved at the end.
    """

    def ckpt_step(fname):
        # "ckpt_step_<N>.pt" -> N
        return int(fname.split('_')[2].split('.')[0])

    def list_ckpts():
        # Sort numerically: a plain string sort puts "ckpt_step_10000.pt"
        # before "ckpt_step_2000.pt" and would mistake the newest for the oldest.
        found = [f for f in os.listdir(SAVE_DIR) if f.startswith("ckpt_step_") and f.endswith(".pt")]
        return sorted(found, key=ckpt_step)

    # CUDA RNG state can only be saved/restored when a GPU is present.
    has_cuda = torch.cuda.is_available()

    gc.collect()
    torch.cuda.empty_cache()

    cfg = MarathonConfig()
    model = NewGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6.29e-4, weight_decay=1.34e-4)
    loader = ChunkLoader(cfg)
    recorder = BlackBox(SAVE_DIR)

    start_step = 0
    ckpts = list_ckpts()

    if ckpts:
        latest = ckpts[-1]
        ckpt_path = os.path.join(SAVE_DIR, latest)
        print(f"üîÑ Found checkpoints. Loading {latest}...")

        c = torch.load(ckpt_path, map_location=DEVICE)
        model.load_state_dict(c['model'])
        optimizer.load_state_dict(c['optim'])

        # map_location may have moved the RNG byte tensors to the GPU;
        # the RNG setters need them on the CPU.
        torch.set_rng_state(c['rng_cpu'].cpu())
        if has_cuda and c.get('rng_gpu') is not None:
            torch.cuda.set_rng_state(c['rng_gpu'].cpu())

        start_step = c['step'] + 1
        recorder.set_start_step(start_step)
        print(f"‚úÖ State Restored (Step {start_step}).")

    print(f"\nüèÉ STARTING BASELINE 20K: {start_step} -> {cfg.max_steps}")
    pbar = tqdm(range(start_step, cfg.max_steps), initial=start_step, total=cfg.max_steps)

    for step in pbar:
        loader.load_for_step(step)
        model.train()
        batch_loss = 0.0

        optimizer.zero_grad()
        for micro_step in range(cfg.grad_accum):
            x, y = loader.get_batch(cfg.batch_size)
            # Full metrics only on the last micro-batch of a full-interval step.
            is_full = (step % cfg.full_interval == 0) and (micro_step == cfg.grad_accum - 1)

            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                loss, metrics = model(x, y, return_metrics=is_full)
                total = loss / cfg.grad_accum
            total.backward()
            batch_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        mean_loss = batch_loss / cfg.grad_accum

        # Telemetry
        # NOTE: `metrics` is whatever the last micro-batch returned, so it is
        # non-empty only on steps that also satisfy the full-interval test.
        if step > 0 and (step % cfg.lite_interval == 0 or step == cfg.max_steps - 1):
            gc.collect()
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                v_losses = []
                for _ in range(20):
                    vx, vy = loader.get_val_batch(cfg.batch_size)
                    vl, _ = model(vx, vy, return_metrics=False)
                    v_losses.append(vl.item())
                val_loss = np.mean(v_losses)
            recorder.log(step, mean_loss, val_loss, metrics)
            recorder.flush()
            pbar.set_description(f"L:{mean_loss:.3f}|V:{val_loss:.3f}")
        else:
            pbar.set_description(f"L:{mean_loss:.3f}")

        # Checkpoints (every cfg.ckpt_interval steps)
        if step > 0 and step % cfg.ckpt_interval == 0:
            ckpt_name = f"ckpt_step_{step}.pt"
            ckpt_path = os.path.join(SAVE_DIR, ckpt_name)
            save_dict = {
                'step': step,
                'model': model.state_dict(),
                'optim': optimizer.state_dict(),
                'rng_cpu': torch.get_rng_state(),
                'rng_gpu': torch.cuda.get_rng_state() if has_cuda else None
            }
            torch.save(save_dict, ckpt_path)
            # Retention: once more than three exist, drop the numerically oldest.
            all_ckpts = list_ckpts()
            if len(all_ckpts) > 3:
                os.remove(os.path.join(SAVE_DIR, all_ckpts[0]))

    final_path = os.path.join(SAVE_DIR, "janus_v3_baseline_20k.pt")
    torch.save(model.state_dict(), final_path)
    print(f"\nüèÜ BASELINE COMPLETE. Saved to {final_path}")

if __name__ == "__main__":
    run_baseline_20k()

In [None]:
# @title [Run] Janus 20k  Constant Pressure Control


import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import pandas as pd
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. MEMORY NUKE ---
# Run the garbage collector and release cached CUDA memory before building the model.
gc.collect()
torch.cuda.empty_cache()

# --- 2. SETUP ---
# Mount Google Drive only when the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Training chunks are read from CHUNKS_DIR; checkpoints, telemetry and the
# final weights of this run are written under SAVE_DIR.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
CHUNKS_DIR = os.path.join(PROJECT_ROOT, "data/wikitext_chunks")
SAVE_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_marathon_constant")
os.makedirs(SAVE_DIR, exist_ok=True)

# Use the GPU when available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üß± JANUS MARATHON CONSTANT (FIXED)")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 3. CONFIGURATION ---
class MarathonConfig:
    """Hyper-parameters for the 20k-step run: model shape, schedule and logging cadence."""

    def __init__(self):
        model_shape = dict(
            vocab_size=50257,
            d_model=512,
            n_layers=12,
            n_heads=8,
            d_head=64,
            max_seq_len=512,
            dropout=0.05,
        )
        optimisation = dict(
            max_steps=20000,
            batch_size=16,
            grad_accum=4,
            max_pressure=0.15,
        )
        # Checkpoint every `ckpt_interval` steps; switch data chunk every `chunk_steps`.
        persistence = dict(ckpt_interval=2000, chunk_steps=500)
        # Validation telemetry cadence vs. full per-head metric cadence.
        telemetry = dict(lite_interval=1000, full_interval=2500)

        for group in (model_shape, optimisation, persistence, telemetry):
            for key, value in group.items():
                setattr(self, key, value)

# --- 4. ARCHITECTURE ---
class RMSNorm(nn.Module):
    """Scale the last dimension by its inverse root-mean-square, then by a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        mean_sq = torch.mean(x.pow(2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated MLP: ``w3(silu(w1(x)) * w2(x))`` with hidden width ``int(d_model * 8 / 3)``."""

    def __init__(self, d_model):
        super().__init__()
        hidden_dim = int(d_model * 8 / 3)
        # Creation order (w1, w2, w3) is kept so seeded initialisation is unchanged.
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings and an optional
    head-diversity ("steering") penalty."""

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.dropout = nn.Dropout(config.dropout)
        self.register_buffer("freqs_cis", self.precompute_freqs_cis(config.d_head, config.max_seq_len))

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return an ``(end, dim // 2)`` complex table of unit-magnitude rotations."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    def apply_rope(self, x, freqs_cis):
        """Rotate consecutive channel pairs of ``x`` (B, S, H, d) by position."""
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs = freqs_cis[:x.shape[1]].view(1, x.shape[1], 1, -1)
        x_out = torch.view_as_real(x_c * freqs).flatten(3)
        return x_out.type_as(x)

    def _head_gram(self, head_out):
        """Cosine-similarity matrix between heads; ``head_out`` is (B, H, S, d)."""
        flat = head_out.transpose(0, 1).reshape(self.n_heads, -1)
        unit = F.normalize(flat, p=2, dim=1)
        return torch.mm(unit, unit.t())

    def forward(self, x, lambdas, return_metrics=False):
        """Attend over ``x`` and optionally return steering loss and diagnostics.

        Args:
            x: Activations, unpacked as ``(B, S, D)``.
            lambdas: ``(l_coh, l_div)`` pair. Only ``l_div`` is used here: it
                weights the Frobenius distance between the head Gram matrix
                and the identity. ``l_coh`` is accepted but not applied.
            return_metrics: When true, compute ``sigma_a`` and ``eff_rank``
                under ``torch.no_grad()``.

        Returns:
            ``(out, steer_loss, metrics)``. ``steer_loss`` is the float
            ``0.0`` unless ``l_div > 0`` and the module is in training mode.
        """
        B, S, D = x.shape
        q = self.q_proj(x).view(B, S, self.n_heads, self.d_head)
        k = self.k_proj(x).view(B, S, self.n_heads, self.d_head)
        v = self.v_proj(x).view(B, S, self.n_heads, self.d_head)

        q = self.apply_rope(q, self.freqs_cis)
        k = self.apply_rope(k, self.freqs_cis)
        # (B, S, H, d) -> (B, H, S, d)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        mask = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        attn = attn.masked_fill(mask == 0, float('-inf'))
        attn_probs = F.softmax(attn, dim=-1)
        attn_probs = self.dropout(attn_probs)
        head_out = attn_probs @ v

        steer_loss = 0.0
        l_coh, l_div = lambdas

        if l_div > 0.0 and self.training:
            gram = self._head_gram(head_out)
            identity = torch.eye(self.n_heads, device=x.device)
            steer_loss += torch.norm(gram - identity, p='fro') * l_div

        metrics = {}
        if return_metrics:
            with torch.no_grad():
                sim = self._head_gram(head_out)
                mask_diag = ~torch.eye(self.n_heads, dtype=torch.bool, device=x.device)
                # max(..., 1) keeps the single-head case finite (its off-diagonal sum is 0).
                metrics['sigma_a'] = (sim.abs() * mask_diag.float()).sum(dim=1) / max(self.n_heads - 1, 1)

                # The head axis must be moved to the front BEFORE reshaping;
                # transposing (1, 2) here would make each slice a mix of heads.
                sub_out = head_out[:, :, :128, :].transpose(0, 1).reshape(self.n_heads, -1, self.d_head)
                ranks = []
                for h in range(self.n_heads):
                    try:
                        S_vals = torch.linalg.svdvals(sub_out[h].float())
                        p = S_vals / S_vals.sum()
                        ent = -torch.sum(p * torch.log(p + 1e-9))
                        ranks.append(torch.exp(ent))
                    except RuntimeError:
                        # SVD failure: report 0 for this head. The placeholder
                        # lives on x's device so torch.stack below can combine it.
                        ranks.append(torch.zeros((), device=x.device))
                metrics['eff_rank'] = torch.stack(ranks).to(x.device)

        out = head_out.transpose(1, 2).reshape(B, S, D)
        return self.o_proj(out), steer_loss, metrics

class NewBlock(nn.Module):
    """Pre-norm transformer block that also forwards the attention steering loss."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = RMSNorm(width)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(width)
        self.mlp = SwiGLU(width)

    def forward(self, x, lambdas, return_metrics=False):
        """Return ``(hidden, steer_loss, attention_metrics)``."""
        attn_out, steer, attn_metrics = self.attn(self.ln1(x), lambdas, return_metrics)
        hidden = x + attn_out
        hidden = hidden + self.mlp(self.ln2(hidden))
        return hidden, steer, attn_metrics

class NewGPT(nn.Module):
    """Decoder-only language model with tied embeddings and per-layer steering weights."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Tie the embedding matrix to the output projection.
        self.token_emb.weight = self.head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) init; o_proj / w3 weights get a depth-scaled std."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        scaled_std = 0.02 / math.sqrt(2 * self.config.n_layers)
        for name, p in module.named_parameters():
            if any(tag in name for tag in ("o_proj.weight", "w3.weight")):
                torch.nn.init.normal_(p, mean=0.0, std=scaled_std)

    def forward(self, idx, lambdas_list, targets=None, return_metrics=False):
        """Return ``(loss, total_steer, per_layer_metrics)``.

        ``lambdas_list[i]`` is passed to block ``i``. ``loss`` is ``None``
        when ``targets`` is not given.
        """
        B, S = idx.shape  # unpacking requires a 2-D index tensor
        hidden = self.token_emb(idx)
        steer_sum = 0.0
        collected = []
        for layer_lambdas, layer in zip(lambdas_list, self.blocks):
            hidden, layer_steer, layer_metrics = layer(hidden, layer_lambdas, return_metrics)
            steer_sum += layer_steer
            if return_metrics:
                collected.append(layer_metrics)
        logits = self.head(self.ln_f(hidden))
        if targets is None:
            return None, steer_sum, collected
        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, targets.view(-1))
        return loss, steer_sum, collected

# --- 5. LOGIC CONTROL ---
class FlightController:
    """Schedule the steering pressure over steps and distribute it across layers."""

    def __init__(self, config):
        self.config = config

    def get_pressure(self, step):
        """Linear ramp to ``config.max_pressure`` over 1500 steps, then constant."""
        peak = self.config.max_pressure
        if step >= 1500:
            return peak
        return peak * (step / 1500)

    def get_lambdas(self, step):
        """Return ``(per_layer_lambdas, pressure)``.

        Layer ``i`` gets the pair ``(coherence, diversity)`` scaled by
        ``((i + 1) / n_layers) ** 3``; coherence is 0.2 x the pressure.
        """
        p = self.get_pressure(step)
        base_coh = p * 0.2
        depth = self.config.n_layers
        layer_weights = [((i + 1) / depth) ** 3 for i in range(depth)]
        lambdas = [(base_coh * w, p * w) for w in layer_weights]
        return lambdas, p

class ChunkLoader:
    """Serve random token windows from rotating training chunks and a validation file.

    Training data is read one chunk file at a time from ``CHUNKS_DIR``; the
    chunk index is ``step // config.chunk_steps``. Validation tokens are
    memory-mapped from ``data/wikitext/val.bin`` under ``PROJECT_ROOT``.
    Both are interpreted as ``uint16`` token ids.
    """

    # Modulus used to wrap the chunk index when the requested file is missing.
    FALLBACK_CHUNK_COUNT = 40

    def __init__(self, config):
        self.config = config
        self.current_chunk_idx = -1
        self.data = None
        self.val_data = np.memmap(os.path.join(PROJECT_ROOT, "data/wikitext/val.bin"), dtype=np.uint16, mode='r')

    def load_for_step(self, step):
        """Load the training chunk for ``step`` unless it is already resident."""
        target_chunk = step // self.config.chunk_steps
        if target_chunk == self.current_chunk_idx:
            return
        fpath = os.path.join(CHUNKS_DIR, f"train_chunk_{target_chunk:03d}.bin")
        if not os.path.exists(fpath):
            # Requested chunk is missing: wrap around to an earlier one.
            fallback_idx = target_chunk % self.FALLBACK_CHUNK_COUNT
            fpath = os.path.join(CHUNKS_DIR, f"train_chunk_{fallback_idx:03d}.bin")
        with open(fpath, 'rb') as f:
            self.data = np.frombuffer(f.read(), dtype=np.uint16)
        self.current_chunk_idx = target_chunk

    def _sample_windows(self, tokens, batch_size):
        """Draw ``batch_size`` random windows; targets are the inputs shifted by one."""
        seq_len = self.config.max_seq_len
        starts = torch.randint(len(tokens) - seq_len, (batch_size,)).tolist()
        x = torch.stack([torch.from_numpy(tokens[i:i + seq_len].astype(np.int64)) for i in starts])
        y = torch.stack([torch.from_numpy(tokens[i + 1:i + 1 + seq_len].astype(np.int64)) for i in starts])
        return x.to(DEVICE), y.to(DEVICE)

    def get_batch(self, batch_size):
        """Return an ``(x, y)`` training batch from the currently loaded chunk.

        Raises:
            RuntimeError: if no chunk has been loaded yet.
        """
        if self.data is None:
            raise RuntimeError("ChunkLoader.get_batch() called before load_for_step()")
        return self._sample_windows(self.data, batch_size)

    def get_val_batch(self, batch_size):
        """Return an ``(x, y)`` batch drawn from the validation tokens."""
        return self._sample_windows(self.val_data, batch_size)

class BlackBox:
    """Buffer telemetry rows in memory and append them to a parquet file on flush."""

    TELEMETRY_FILE = "telemetry_constant.parquet"

    def __init__(self, save_dir):
        self.buffer = []
        self.save_dir = save_dir
        self.start_step = 0

    def set_start_step(self, step):
        """Ignore any later ``log`` call whose step is below ``step``."""
        self.start_step = step

    def log(self, step, loss, val_loss, pressure, metrics):
        """Queue one telemetry row.

        Args:
            step: Training step the row belongs to.
            loss: Training loss to record.
            val_loss: Validation loss; perplexity is ``exp(val_loss)``, or
                0.0 when ``val_loss >= 20``.
            pressure: Steering pressure in effect at this step.
            metrics: Optional per-layer list of metric dicts; each layer's
                ``sigma_a`` mean is stored as ``L{i}_sigma_a``.
        """
        if step < self.start_step:
            return
        row = {
            "step": step, "loss": loss, "val_loss": val_loss,
            "perplexity": math.exp(val_loss) if val_loss < 20 else 0.0,
            "pressure": pressure
        }
        for i, m in enumerate(metrics or []):
            if m and 'sigma_a' in m:
                row[f"L{i}_sigma_a"] = m['sigma_a'].mean().item()
        self.buffer.append(row)

    def flush(self):
        """Merge buffered rows into the parquet file and clear the buffer.

        An existing file that cannot be read is moved aside (best effort)
        instead of being silently overwritten. Rows sharing a ``step`` are
        de-duplicated, keeping the most recent one, so a resumed run does not
        leave two rows for the same step.
        """
        if not self.buffer:
            return
        df = pd.DataFrame(self.buffer)
        fpath = os.path.join(self.save_dir, self.TELEMETRY_FILE)
        if os.path.exists(fpath):
            try:
                existing = pd.read_parquet(fpath)
            except Exception as err:
                backup = fpath + ".unreadable"
                print(f"WARNING: could not read {fpath} ({type(err).__name__}: {err}); moving it to {backup}")
                try:
                    os.replace(fpath, backup)
                except OSError as move_err:
                    print(f"WARNING: backup failed ({move_err}); previous telemetry will be overwritten")
            else:
                df = pd.concat([existing, df], ignore_index=True)
                if "step" in df.columns:
                    df = df.drop_duplicates(subset="step", keep="last").reset_index(drop=True)
        df.to_parquet(fpath)
        self.buffer = []
        print("üíæ Telemetry Flushed.")

# --- 6. MAIN ---
def run_constant_fixed():
    """Train with the steering pressure schedule for ``cfg.max_steps`` steps.

    Resumes from the highest-numbered ``ckpt_step_<N>.pt`` in ``SAVE_DIR``
    (model, optimizer and RNG state) when one exists. Validation telemetry is
    logged every ``cfg.lite_interval`` steps and on the last step, with
    steering disabled for validation. A checkpoint is written every
    ``cfg.ckpt_interval`` steps and the final state dict is saved at the end.
    """

    def ckpt_step(fname):
        # "ckpt_step_<N>.pt" -> N
        return int(fname.split('_')[2].split('.')[0])

    def list_ckpts():
        # Sort by INTEGER step, not by string, so the order is chronological.
        found = [f for f in os.listdir(SAVE_DIR) if f.startswith("ckpt_step_") and f.endswith(".pt")]
        return sorted(found, key=ckpt_step)

    # CUDA RNG state can only be saved/restored when a GPU is present.
    has_cuda = torch.cuda.is_available()

    gc.collect()
    torch.cuda.empty_cache()

    cfg = MarathonConfig()
    model = NewGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6.29e-4, weight_decay=1.34e-4)
    scheduler = FlightController(cfg)
    loader = ChunkLoader(cfg)
    recorder = BlackBox(SAVE_DIR)

    start_step = 0
    ckpts = list_ckpts()

    if ckpts:
        latest = ckpts[-1]
        ckpt_path = os.path.join(SAVE_DIR, latest)
        print(f"üîÑ Found checkpoints. Loading {latest}...")

        c = torch.load(ckpt_path, map_location=DEVICE)
        model.load_state_dict(c['model'])
        optimizer.load_state_dict(c['optim'])

        # map_location may have moved the RNG byte tensors to the GPU;
        # the RNG setters need them on the CPU.
        torch.set_rng_state(c['rng_cpu'].cpu())
        if has_cuda and c.get('rng_gpu') is not None:
            torch.cuda.set_rng_state(c['rng_gpu'].cpu())

        start_step = c['step'] + 1
        recorder.set_start_step(start_step)
        print(f"‚úÖ State Restored (Step {start_step}).")

    print(f"\nüèÉ STARTING MARATHON (CONSTANT 0.15): {start_step} -> {cfg.max_steps}")
    pbar = tqdm(range(start_step, cfg.max_steps), initial=start_step, total=cfg.max_steps)

    # Validation runs with every layer's steering weights at zero.
    no_steering = [(0.0, 0.0)] * cfg.n_layers

    for step in pbar:
        loader.load_for_step(step)
        model.train()
        batch_loss = 0.0
        lambdas, pressure = scheduler.get_lambdas(step)
        optimizer.zero_grad()
        for micro_step in range(cfg.grad_accum):
            x, y = loader.get_batch(cfg.batch_size)
            # Full metrics only on the last micro-batch of a full-interval step.
            is_full = (step % cfg.full_interval == 0) and (micro_step == cfg.grad_accum - 1)
            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                loss, steer, metrics = model(x, lambdas, y, return_metrics=is_full)
                total = (loss + steer) / cfg.grad_accum
            total.backward()
            # Only the language-model loss is reported, not the steering term.
            batch_loss += loss.item()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        mean_loss = batch_loss / cfg.grad_accum

        # NOTE: `metrics` is whatever the last micro-batch returned, so it is
        # non-empty only on steps that also satisfy the full-interval test.
        if step > 0 and (step % cfg.lite_interval == 0 or step == cfg.max_steps - 1):
            gc.collect()
            torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                v_losses = []
                for _ in range(20):
                    vx, vy = loader.get_val_batch(cfg.batch_size)
                    vl, _, _ = model(vx, no_steering, vy, return_metrics=False)
                    v_losses.append(vl.item())
                val_loss = np.mean(v_losses)
            recorder.log(step, mean_loss, val_loss, pressure, metrics)
            recorder.flush()
            pbar.set_description(f"L:{mean_loss:.3f}|V:{val_loss:.3f}|P:{pressure:.3f}")
        else:
            pbar.set_description(f"L:{mean_loss:.3f}|P:{pressure:.3f}")

        # --- CHECKPOINT RETENTION LOGIC ---
        if step > 0 and step % cfg.ckpt_interval == 0:
            ckpt_name = f"ckpt_step_{step}.pt"
            ckpt_path = os.path.join(SAVE_DIR, ckpt_name)
            save_dict = {
                'step': step,
                'model': model.state_dict(),
                'optim': optimizer.state_dict(),
                'rng_cpu': torch.get_rng_state(),
                'rng_gpu': torch.cuda.get_rng_state() if has_cuda else None
            }
            torch.save(save_dict, ckpt_path)

            # Once more than three checkpoints exist, drop the numerically oldest.
            all_ckpts = list_ckpts()
            if len(all_ckpts) > 3:
                oldest = all_ckpts[0]
                os.remove(os.path.join(SAVE_DIR, oldest))

    final_path = os.path.join(SAVE_DIR, "janus_v3_marathon_constant.pt")
    torch.save(model.state_dict(), final_path)
    print(f"\nüèÜ CONSTANT RUN COMPLETE. Saved to {final_path}")

if __name__ == "__main__":
    run_constant_fixed()

In [None]:
# @title [Run] 6k vs 20k Inference Test


import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from transformers import GPT2Tokenizer
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive only when the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Weight files compared below: HERO_PATH is the reference model,
# MARATHON_PATH is the model under test.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
HERO_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_wikitext_hero/janus_v3_hero.pt")
MARATHON_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_marathon_v3/janus_v3_marathon_final.pt")

# Use the GPU when available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"‚öîÔ∏è MARATHON VERIFICATION: 6k HERO vs. 20k MARATHON")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 2. CONFIGURATION ---
class WikiConfig:
    """Model-shape settings used to rebuild the network for inference (dropout off)."""

    def __init__(self):
        settings = {
            "vocab_size": 50257,
            "d_model": 512,
            "n_layers": 12,
            "n_heads": 8,
            "d_head": 64,
            "max_seq_len": 512,
            "dropout": 0.0,
        }
        for key, value in settings.items():
            setattr(self, key, value)

# --- 3. ARCHITECTURE (Janus v3) ---
class RMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned per-feature gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # 1 / sqrt(mean(x^2) + eps), one value per vector
        inv_rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        normed = x * inv_rms
        return normed * self.weight

class SwiGLU(nn.Module):
    """SiLU-gated feed-forward layer: project up twice, gate, project back down."""

    def __init__(self, d_model):
        super().__init__()
        hidden_dim = int(d_model * 8 / 3)
        up = dict(in_features=d_model, out_features=hidden_dim, bias=False)
        self.w1 = nn.Linear(**up)
        self.w2 = nn.Linear(**up)
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)

    def forward(self, x):
        gated = F.silu(self.w1(x)) * self.w2(x)
        return self.w3(gated)

class NewAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings (inference variant)."""

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        width = config.d_model
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)

        rope_table = self.precompute_freqs_cis(config.d_head, config.max_seq_len)
        self.register_buffer("freqs_cis", rope_table)

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return an ``(end, dim // 2)`` complex table of unit-magnitude rotations."""
        exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
        inv_freq = 1.0 / (theta ** exponents)
        positions = torch.arange(end, device=inv_freq.device)
        angles = torch.outer(positions, inv_freq).float()
        return torch.polar(torch.ones_like(angles), angles)

    def apply_rope(self, x, freqs_cis):
        """Rotate consecutive channel pairs of ``x`` (B, S, H, d) by position."""
        seq_len = x.shape[1]
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        rotation = freqs_cis[:seq_len].view(1, seq_len, 1, -1)
        rotated = torch.view_as_complex(pairs) * rotation
        return torch.view_as_real(rotated).flatten(3).type_as(x)

    def forward(self, x):
        B, S, D = x.shape
        heads_shape = (B, S, self.n_heads, self.d_head)

        q = self.apply_rope(self.q_proj(x).view(heads_shape), self.freqs_cis).transpose(1, 2)
        k = self.apply_rope(self.k_proj(x).view(heads_shape), self.freqs_cis).transpose(1, 2)
        v = self.v_proj(x).view(heads_shape).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        # True strictly above the diagonal, i.e. at future positions.
        future = torch.ones(S, S, dtype=torch.bool, device=x.device).triu(diagonal=1)
        scores = scores.masked_fill(future.view(1, 1, S, S), float('-inf'))
        weights = F.softmax(scores, dim=-1)

        merged = (weights @ v).transpose(1, 2).reshape(B, S, D)
        return self.o_proj(merged)

class NewBlock(nn.Module):
    """Pre-norm transformer block: attention, then SwiGLU MLP, each added residually."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = RMSNorm(width)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(width)
        self.mlp = SwiGLU(width)

    def forward(self, x):
        after_attn = x + self.attn(self.ln1(x))
        return after_attn + self.mlp(self.ln2(after_attn))

class NewGPT(nn.Module):
    """Decoder-only language model (tied embeddings) with a sampling loop."""

    def __init__(self, config):
        super().__init__()
        # Longest context the blocks were built for; used to crop prompts in
        # generate(). Falls back to the previously hard-coded 512.
        self.max_seq_len = getattr(config, "max_seq_len", 512)
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Tie the embedding matrix to the output projection.
        self.token_emb.weight = self.head.weight

    def forward(self, idx):
        """Return next-token logits for every position of ``idx`` (2-D token ids)."""
        B, S = idx.shape
        x = self.token_emb(idx)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.7, top_k=50):
        """Sample ``max_new_tokens`` tokens and return them appended to ``idx``.

        Args:
            idx: Prompt token ids, one row per sequence.
            max_new_tokens: Number of tokens to append.
            temperature: Positive divisor applied to the final-position logits.
            top_k: When not ``None``, sample only among the ``top_k`` most
                likely tokens.

        Raises:
            ValueError: if ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature}")
        for _ in range(max_new_tokens):
            # Crop to the configured context window rather than a literal 512.
            idx_cond = idx[:, -self.max_seq_len:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Discard everything below the k-th largest logit of each row.
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

# --- 4. LOAD WEIGHTS ---
cfg = WikiConfig()
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# The reference model is optional: any load failure skips it, but the actual
# reason is printed so a shape mismatch is not mistaken for a missing file.
print("üîÑ Loading 6k Hero (Reference)...")
model_hero = NewGPT(cfg).to(DEVICE)
try:
    model_hero.load_state_dict(torch.load(HERO_PATH, map_location=DEVICE))
    model_hero.eval()
except Exception as err:
    print("‚ö†Ô∏è  Hero 6k not found. Skipping.")
    print(f"   Reason: {type(err).__name__}: {err}")
    model_hero = None

# The target model is mandatory: stop here if its weights are missing.
print("üîÑ Loading 20k Marathon (Target)...")
model_marathon = NewGPT(cfg).to(DEVICE)
if os.path.exists(MARATHON_PATH):
    model_marathon.load_state_dict(torch.load(MARATHON_PATH, map_location=DEVICE))
    model_marathon.eval()
else:
    print(f"‚ùå CRITICAL: Marathon model not found at {MARATHON_PATH}")
    # Raise instead of calling exit(): under IPython, exit() only schedules a
    # shutdown and the rest of the cell would run with untrained weights.
    raise SystemExit(1)

# --- 5. THE TEST ---
# Each prompt is completed by the reference model (when loaded) and then by
# the target model; only the text after the prompt is printed.
prompts = [
    "The Roman Empire was",
    "In the early 19th century, the",
    "The chemical formula for water is",
    "Following the release of the album,",
    "Located in the northern part of"
]

print("\n" + "="*60)
print("üß™ INFERENCE CHECK: DID THE SNAPBACK BREAK IT?")
print("="*60)

for p in prompts:
    input_ids = tokenizer.encode(p, return_tensors='pt').to(DEVICE)
    hero_available = bool(model_hero)

    # Hero samples first so both models draw from the RNG in a fixed order.
    if hero_available:
        out_h = model_hero.generate(input_ids, max_new_tokens=60, temperature=0.6)
        text_h = tokenizer.decode(out_h[0].tolist(), skip_special_tokens=True)
    else:
        text_h = "N/A"

    out_m = model_marathon.generate(input_ids, max_new_tokens=60, temperature=0.6)
    text_m = tokenizer.decode(out_m[0].tolist(), skip_special_tokens=True)

    report = [f"\nüìù PROMPT: {p}", "-" * 20]
    if hero_available:
        report.append(f"üîµ HERO (6k):   {text_h[len(p):].strip()}")
    report.append("-" * 20)
    report.append(f"üü¢ MARATHON (20k): {text_m[len(p):].strip()}")
    report.append("="*60)
    for line in report:
        print(line)

In [None]:
# @title [RUN] Inference Constant

import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import GPT2Tokenizer
from google.colab import drive

# --- 1. SETUP ---
# Mount Google Drive only when the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# The original line had an unterminated string literal (a syntax error).
# NOTE(review): path restored to the directory the constant-pressure training
# run saves into — confirm this is the intended audit target.
CONSTANT_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_marathon_constant")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"ü©∫ JANUS CONSTANT RUN: AUDIT (FIXED)")
print(f"   Target: {CONSTANT_DIR}")
print(f"   Device: {DEVICE}")

# --- 2. ARCHITECTURE ---
class RMSNorm(nn.Module):
    """Normalise each vector by its root-mean-square and apply a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        squared = torch.pow(x, 2)
        scale = torch.rsqrt(torch.mean(squared, -1, keepdim=True) + self.eps)
        return x * scale * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward block: ``w3(silu(w1(x)) * w2(x))``.

    The hidden width is ``int(d_model * 8 / 3)`` and none of the three
    projections has a bias.
    """

    def __init__(self, d_model):
        super().__init__()
        hidden_dim = int(d_model * 8 / 3)
        # w1 feeds the SiLU gate, w2 is the linear branch, w3 projects back.
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x`` (last axis of size d_model)."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    ``config`` must expose ``n_heads``, ``d_head``, ``d_model`` and
    ``max_seq_len``. All four projections are bias-free.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        width = config.d_model
        self.q_proj = nn.Linear(width, width, bias=False)
        self.k_proj = nn.Linear(width, width, bias=False)
        self.v_proj = nn.Linear(width, width, bias=False)
        self.o_proj = nn.Linear(width, width, bias=False)

        # Rotary table is registered as a buffer so it follows .to(device).
        rope_table = self.precompute_freqs_cis(config.d_head, config.max_seq_len)
        self.register_buffer("freqs_cis", rope_table)

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return a complex ``(end, dim // 2)`` table of unit-magnitude rotations."""
        exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
        inv_freq = 1.0 / (theta ** exponents)
        positions = torch.arange(end, device=inv_freq.device)
        angles = torch.outer(positions, inv_freq).float()
        return torch.polar(torch.ones_like(angles), angles)

    def apply_rope(self, x, freqs_cis):
        """Rotate consecutive feature pairs of ``x`` (batch, seq, heads, d_head)."""
        seq_len = x.shape[1]
        # Pair up the last axis and treat each pair as one complex number.
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        as_complex = torch.view_as_complex(pairs)
        rotation = freqs_cis[:seq_len].view(1, seq_len, 1, -1)
        rotated = torch.view_as_real(as_complex * rotation).flatten(3)
        return rotated.type_as(x)

    def forward(self, x):
        """Attend causally over ``x`` of shape (batch, seq, d_model)."""
        batch, seq_len, width = x.shape
        head_shape = (batch, seq_len, self.n_heads, self.d_head)

        q = self.q_proj(x).view(*head_shape)
        k = self.k_proj(x).view(*head_shape)
        v = self.v_proj(x).view(*head_shape)

        # Rotary embedding is applied to queries and keys only.
        q = self.apply_rope(q, self.freqs_cis).transpose(1, 2)
        k = self.apply_rope(k, self.freqs_cis).transpose(1, 2)
        v = v.transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Lower-triangular mask: position i may only see positions <= i.
        causal = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
        causal = causal.view(1, 1, seq_len, seq_len)
        scores = scores.masked_fill(causal == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq_len, width)
        return self.o_proj(merged)

class NewBlock(nn.Module):
    """Pre-norm transformer layer: attention then SwiGLU, each with a residual."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = RMSNorm(width)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(width)
        self.mlp = SwiGLU(width)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        after_attn = x + self.attn(self.ln1(x))
        mlp_out = self.mlp(self.ln2(after_attn))
        return after_attn + mlp_out

class NewGPT(nn.Module):
    """Decoder-only language model used for inference in the audit cell.

    The input embedding and the output head share one weight tensor.
    ``config`` must expose ``vocab_size``, ``d_model`` and ``n_layers`` plus
    whatever ``NewBlock`` reads from it.
    """

    def __init__(self, config):
        super().__init__()
        # Longest context the rotary table inside each attention layer covers.
        # Stored as a plain int, so it does not appear in the state_dict.
        # Falls back to the previously hard-coded 512 if the config lacks it.
        self.max_seq_len = getattr(config, "max_seq_len", 512)
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the embedding reuses the output projection's parameter.
        self.token_emb.weight = self.head.weight

    def forward(self, idx):
        """Return logits of shape (batch, seq, vocab_size) for token ids ``idx``."""
        _batch, _seq = idx.shape  # unpacking also rejects inputs that are not 2-D
        x = self.token_emb(idx)
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        return self.head(x)

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.7, top_k=50):
        """Sample ``max_new_tokens`` tokens and return them appended to ``idx``.

        Args:
            idx: (batch, seq) tensor of prompt token ids.
            max_new_tokens: number of sampling steps to run.
            temperature: divisor applied to the last-position logits.
            top_k: keep only the ``top_k`` largest logits before sampling;
                ``None`` disables the filter.
        """
        for _ in range(max_new_tokens):
            # Crop to the configured window instead of a fixed 512 so the
            # rotary table is never indexed past its length.
            idx_cond = idx[:, -self.max_seq_len:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                # Everything below the k-th largest logit gets zero probability.
                logits[logits < v[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

class WikiConfig:
    """Fixed hyper-parameters used to rebuild the model for the audit."""

    def __init__(self):
        settings = {
            "vocab_size": 50257,
            "d_model": 512,
            "n_layers": 12,
            "n_heads": 8,
            "d_head": 64,        # equals d_model // n_heads
            "max_seq_len": 512,
            "dropout": 0.0,
        }
        # Expose every setting as a plain instance attribute.
        for key, value in settings.items():
            setattr(self, key, value)

# --- 3. EXECUTION ---
def audit_run():
    """Audit the newest checkpoint in CONSTANT_DIR.

    Steps:
      A. pick the ``ckpt_step_<N>.pt`` file with the largest ``N``;
      B. rebuild the model from ``WikiConfig`` and load the weights, trying a
         strict load first and falling back to a loose one with a warning;
      C. sample a continuation for a fixed list of prompts and print it;
      D. read ``telemetry_constant.parquet``, de-duplicate it by ``step`` and
         save a three-panel PNG report next to the checkpoints.

    Returns None. Returns early (after printing a message) when there are no
    checkpoints, no telemetry file, or the telemetry file has no rows.
    """
    def step_of(fname):
        # "ckpt_step_1234.pt" -> 1234
        return int(fname.split('_')[2].split('.')[0])

    # A. Find Checkpoint
    ckpts = [f for f in os.listdir(CONSTANT_DIR) if f.startswith("ckpt_step_") and f.endswith(".pt")]
    if not ckpts:
        print("‚ùå No checkpoints found!")
        return

    ckpts.sort(key=step_of)
    latest_ckpt = ckpts[-1]
    ckpt_step = step_of(latest_ckpt)

    print(f"\nüìÇ Loading Latest: {latest_ckpt} (Step {ckpt_step})")

    # B. Load Model
    cfg = WikiConfig()
    model = NewGPT(cfg).to(DEVICE)
    ckpt = torch.load(os.path.join(CONSTANT_DIR, latest_ckpt), map_location=DEVICE)

    try:
        # The first attempt must really be strict; previously both attempts
        # used strict=False, so key mismatches were never reported.
        model.load_state_dict(ckpt['model'], strict=True)
    except RuntimeError as e:
        print(f"‚ö†Ô∏è  Strict load failed, trying loose: {e}")
        model.load_state_dict(ckpt['model'], strict=False)

    model.eval()
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # C. Run Inference
    prompts = [
        "The Roman Empire was",
        "In the early 19th century, the",
        "The chemical formula for water is",
        "Following the release of the album,",
        "Located in the northern part of"
    ]

    print("\n" + "="*60)
    print(f"üß™ INFERENCE CHECK (Step {ckpt_step})")
    print("="*60)

    for p in prompts:
        input_ids = tokenizer.encode(p, return_tensors='pt').to(DEVICE)
        out = model.generate(input_ids, max_new_tokens=60, temperature=0.6)
        text = tokenizer.decode(out[0].tolist(), skip_special_tokens=True)
        print(f"üìù {text.strip()}\n" + "-"*30)

    # D. Chart Telemetry (FIXED)
    telemetry_path = os.path.join(CONSTANT_DIR, "telemetry_constant.parquet")
    if not os.path.exists(telemetry_path):
        print("‚ö†Ô∏è No telemetry parquet found.")
        return

    print(f"\nüìä Generating Telemetry Report...")
    df = pd.read_parquet(telemetry_path)
    if df.empty:
        # The titles below read the last row; an empty frame would raise.
        print("Telemetry file contains no rows; skipping report.")
        return

    # --- SANITIZATION BLOCK ---
    df = df.sort_values('step')
    df = df.drop_duplicates(subset='step', keep='last') # Remove logging overlap
    df = df.reset_index(drop=True) # Reset index to avoid collisions
    # --------------------------

    # Average every column whose name contains 'sigma_a' into one series.
    sa_cols = [c for c in df.columns if 'sigma_a' in c]
    df['sigma_a_avg'] = df[sa_cols].mean(axis=1) if sa_cols else 0.0

    fig, axes = plt.subplots(3, 1, figsize=(12, 14), sharex=True)

    # 1. Loss
    sns.lineplot(data=df, x='step', y='loss', ax=axes[0], label='Train Loss', color='tab:blue', alpha=0.5)
    sns.lineplot(data=df, x='step', y='val_loss', ax=axes[0], label='Val Loss', color='tab:orange', linewidth=2)
    axes[0].set_title(f"Loss Trajectory (Current Val: {df['val_loss'].iloc[-1]:.3f})")
    axes[0].grid(True, alpha=0.3)

    # 2. Sigma A
    sns.lineplot(data=df, x='step', y='sigma_a_avg', ax=axes[1], color='tab:red', linewidth=2)
    axes[1].set_title(f"Head Redundancy (Current Avg: {df['sigma_a_avg'].iloc[-1]:.4f})")
    axes[1].set_ylabel("Sigma A (Lower is Better)")
    axes[1].grid(True, alpha=0.3)
    axes[1].axhline(0.003, color='green', linestyle='--', label='Orthogonality Target')

    # 3. Layer Redundancy
    if sa_cols:
        df_melt = df.melt(id_vars=['step'], value_vars=sa_cols, var_name='Layer', value_name='Sigma A')
        sns.lineplot(data=df_melt, x='step', y='Sigma A', hue='Layer', ax=axes[2], palette='viridis', legend=False)
        axes[2].set_title("Layer-wise Redundancy Dynamics")
        axes[2].grid(True, alpha=0.3)

    plt.tight_layout()
    save_path = os.path.join(CONSTANT_DIR, f"status_report_step_{ckpt_step}.png")
    plt.savefig(save_path)
    print(f"‚úÖ Report saved to: {save_path}")

# Run the audit when this cell/module is executed directly.
if __name__ == "__main__":
    audit_run()

In [None]:
# @title [Run] Adaptive Scheduler 3700 steps


import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import pandas as pd
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. MEMORY NUKE ---
# Collect Python garbage and release PyTorch's cached CUDA blocks before the run.
gc.collect()
torch.cuda.empty_cache()

# --- 2. SETUP ---
# Mount Google Drive only when the Colab mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# Directory holding the train_chunk_XXX.bin files read by ChunkLoader.
CHUNKS_DIR = os.path.join(PROJECT_ROOT, "data/wikitext_chunks")
# Checkpoints, the final weights and the telemetry parquet are written here.
SAVE_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_adaptive_v1")
os.makedirs(SAVE_DIR, exist_ok=True)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üß† JANUS ADAPTIVE (Homeostatic Control)")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 3. CONFIGURATION ---
class AdaptiveConfig:
    """Hyper-parameters for the adaptive (closed-loop) training run."""

    def __init__(self):
        model = {
            "vocab_size": 50257,
            "d_model": 512,
            "n_layers": 12,
            "n_heads": 8,
            "d_head": 64,
            "max_seq_len": 512,
            "dropout": 0.05,
        }
        training = {
            "max_steps": 3700,
            "batch_size": 16,
            "grad_accum": 4,
        }
        controller = {
            "target_sigma": 0.0035,
            "k_p": 0.5,               # Gain
            "control_interval": 50,   # Update every 50 steps
            "warmup_steps": 1500,     # Ignition phase
        }
        io = {
            "ckpt_interval": 2000,
            "chunk_steps": 500,
            "lite_interval": 100,
        }
        # Flatten the four groups into plain instance attributes, in order.
        for group in (model, training, controller, io):
            for key, value in group.items():
                setattr(self, key, value)

# --- 4. THE HOMEOSTATIC CONTROLLER ---
class Homeostat:
    """Two-phase controller for the steering pressure ``current_lambda``.

    Before ``warmup_steps`` the pressure follows a linear ramp from 0 to 0.05.
    Afterwards a proportional rule nudges it by ``k_p * (sigma - target)`` and
    clamps the result to the range [0.0, 0.25].
    """

    def __init__(self, config):
        self.target = config.target_sigma
        self.kp = config.k_p
        self.warmup = config.warmup_steps
        self.current_lambda = 0.0
        self.history = []

    def update(self, step, current_sigma):
        """Advance the controller and return the new pressure.

        Args:
            step: current training step.
            current_sigma: latest measured sigma, or None when nothing has been
                measured yet (the pressure is then left unchanged after warm-up).
        """
        if step < self.warmup:
            # Open-loop phase: the measurement is ignored entirely.
            self.current_lambda = 0.05 * (step / self.warmup)
        elif current_sigma is not None:
            # Sigma above target gives a positive error, so pressure increases.
            error = current_sigma - self.target
            self.current_lambda += self.kp * error
            # Keep the pressure non-negative and capped at 0.25.
            self.current_lambda = max(0.0, min(0.25, self.current_lambda))
        return self.current_lambda

# --- 5. ARCHITECTURE (Standard Janus v3) ---
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain.

    Args:
        dim: size of the last axis; one gain value is learned per feature.
        eps: constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Gain starts at 1 so the freshly built layer only rescales by 1/RMS.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2) + eps) * weight`` along the last axis."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward block: ``w3(silu(w1(x)) * w2(x))``.

    The hidden width is ``int(d_model * 8 / 3)`` and none of the three
    projections has a bias.
    """

    def __init__(self, d_model):
        super().__init__()
        hidden_dim = int(d_model * 8 / 3)
        # w1 feeds the SiLU gate, w2 is the linear branch, w3 projects back.
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x`` (last axis of size d_model)."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Causal multi-head attention with rotary embeddings and head steering.

    Besides the attention output, ``forward`` can return
      * a steering loss that penalises similarity between the flattened
        outputs of different heads, and
      * a per-head redundancy metric ``sigma_a``.

    ``config`` must expose ``n_heads``, ``d_head``, ``d_model``,
    ``max_seq_len`` and ``dropout``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.dropout = nn.Dropout(config.dropout)
        self.register_buffer("freqs_cis", self.precompute_freqs_cis(config.d_head, config.max_seq_len))

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return a complex ``(end, dim // 2)`` table of unit-magnitude rotations."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    def apply_rope(self, x, freqs_cis):
        """Rotate consecutive feature pairs of ``x`` (batch, seq, heads, d_head)."""
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs = freqs_cis[:x.shape[1]].view(1, x.shape[1], 1, -1)
        x_out = torch.view_as_real(x_c * freqs).flatten(3)
        return x_out.type_as(x)

    def _head_gram(self, head_out):
        """Cosine-similarity matrix (n_heads x n_heads) of flattened head outputs."""
        flat = head_out.transpose(0, 1).reshape(self.n_heads, -1)
        unit = F.normalize(flat, p=2, dim=1)
        return torch.mm(unit, unit.t())

    def forward(self, x, lambdas, return_metrics=False):
        """Run attention and optionally compute the steering loss and metrics.

        Args:
            x: input of shape (batch, seq, d_model).
            lambdas: pair ``(l_coh, l_div)``. Only ``l_div`` is used here; it
                scales the diversity penalty. ``l_coh`` is accepted for
                interface compatibility and ignored.
            return_metrics: when True, fill ``metrics['sigma_a']``.

        Returns:
            ``(output, steer_loss, metrics)``. ``steer_loss`` is the float 0.0
            when no penalty was applied, otherwise a scalar tensor. ``metrics``
            is an empty dict unless ``return_metrics`` is True.
        """
        B, S, D = x.shape
        q = self.q_proj(x).view(B, S, self.n_heads, self.d_head)
        k = self.k_proj(x).view(B, S, self.n_heads, self.d_head)
        v = self.v_proj(x).view(B, S, self.n_heads, self.d_head)

        q = self.apply_rope(q, self.freqs_cis)
        k = self.apply_rope(k, self.freqs_cis)
        q = q.transpose(1, 2); k = k.transpose(1, 2); v = v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        mask = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        attn = attn.masked_fill(mask == 0, float('-inf'))
        attn_probs = F.softmax(attn, dim=-1)
        attn_probs = self.dropout(attn_probs)
        head_out = attn_probs @ v

        steer_loss = 0.0
        _l_coh, l_div = lambdas
        metrics = {}

        # The Gram matrix is needed for the penalty (training with l_div > 0)
        # and for the metric; compute it once and share it between both.
        if (l_div > 0.0 and self.training) or return_metrics:
            gram = self._head_gram(head_out)

            if l_div > 0.0:
                identity = torch.eye(self.n_heads, device=x.device)
                # Frobenius distance from the identity: zero when heads are orthogonal.
                steer_loss += torch.norm(gram - identity, p='fro') * l_div

            if return_metrics:
                with torch.no_grad():
                    off_diag = ~torch.eye(self.n_heads, dtype=torch.bool, device=x.device)
                    # A single head has no peers; divide by 1 so the metric is
                    # 0 instead of NaN (0 / 0).
                    n_peers = max(self.n_heads - 1, 1)
                    # Mean absolute similarity of each head to every other head.
                    metrics['sigma_a'] = (gram.detach().abs() * off_diag.float()).sum(dim=1) / n_peers

        out = head_out.transpose(1, 2).reshape(B, S, D)
        return self.o_proj(out), steer_loss, metrics

class NewBlock(nn.Module):
    """Pre-norm transformer layer that forwards the attention side outputs."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = RMSNorm(width)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(width)
        self.mlp = SwiGLU(width)

    def forward(self, x, lambdas, return_metrics=False):
        """Return ``(hidden, steer_loss, metrics)`` for one layer.

        ``lambdas`` and ``return_metrics`` are passed straight to the attention
        module; its steering loss and metrics are returned unchanged.
        """
        attn_out, steer, layer_metrics = self.attn(self.ln1(x), lambdas, return_metrics)
        after_attn = x + attn_out
        hidden = after_attn + self.mlp(self.ln2(after_attn))
        return hidden, steer, layer_metrics

class NewGPT(nn.Module):
    """Decoder-only language model whose layers report a steering loss.

    The input embedding and the output head share one weight tensor.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the embedding reuses the output projection's parameter.
        self.token_emb.weight = self.head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialiser passed to ``self.apply``.

        Linear and embedding weights get N(0, 0.02). Any parameter reachable
        from ``module`` whose name contains ``o_proj.weight`` or ``w3.weight``
        is then redrawn with the std divided by sqrt(2 * n_layers).
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        residual_tags = ("o_proj.weight", "w3.weight")
        for param_name, param in module.named_parameters():
            if any(tag in param_name for tag in residual_tags):
                torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layers))

    def forward(self, idx, lambdas_list, targets=None, return_metrics=False):
        """Run the model and return ``(loss, total_steer, all_metrics)``.

        Args:
            idx: (batch, seq) tensor of token ids.
            lambdas_list: one ``(l_coh, l_div)`` pair per layer, indexed by layer.
            targets: optional token ids; when given, ``loss`` is the
                cross-entropy over all positions, otherwise ``loss`` is None.
            return_metrics: when True, collect each layer's metrics dict.
        """
        B, S = idx.shape
        hidden = self.token_emb(idx)
        total_steer = 0.0
        all_metrics = []
        for layer_idx, block in enumerate(self.blocks):
            hidden, layer_steer, layer_metrics = block(hidden, lambdas_list[layer_idx], return_metrics)
            total_steer += layer_steer
            if return_metrics:
                all_metrics.append(layer_metrics)

        logits = self.head(self.ln_f(hidden))
        if targets is None:
            loss = None
        else:
            flat_logits = logits.view(-1, logits.size(-1))
            loss = F.cross_entropy(flat_logits, targets.view(-1))
        return loss, total_steer, all_metrics

# --- 6. DATA & LOGGING ---
class ChunkLoader:
    """Serves random training windows from chunk files and validation windows
    from a memory-mapped ``val.bin``.

    Training data is swapped every ``config.chunk_steps`` steps; both sources
    are read as flat ``uint16`` token arrays.
    """

    # Number of chunk files cycled through when the requested chunk is missing.
    FALLBACK_CHUNK_COUNT = 40

    def __init__(self, config):
        self.config = config
        self.current_chunk_idx = -1
        self.data = None
        self.val_data = np.memmap(os.path.join(PROJECT_ROOT, "data/wikitext/val.bin"), dtype=np.uint16, mode='r')

    def load_for_step(self, step):
        """Make sure ``self.data`` holds the chunk that ``step`` belongs to."""
        target_chunk = step // self.config.chunk_steps
        if target_chunk == self.current_chunk_idx:
            return
        fpath = os.path.join(CHUNKS_DIR, f"train_chunk_{target_chunk:03d}.bin")
        if not os.path.exists(fpath):
            # Wrap around to an earlier chunk when the run outlasts the data.
            fallback_idx = target_chunk % self.FALLBACK_CHUNK_COUNT
            fpath = os.path.join(CHUNKS_DIR, f"train_chunk_{fallback_idx:03d}.bin")
        with open(fpath, 'rb') as f:
            self.data = np.frombuffer(f.read(), dtype=np.uint16)
        self.current_chunk_idx = target_chunk

    def _sample_batch(self, tokens, batch_size):
        """Draw ``batch_size`` random windows from ``tokens``.

        Returns ``(x, y)`` on DEVICE, each of shape (batch_size, max_seq_len),
        where ``y`` is ``x`` shifted one token to the right.
        """
        seq_len = self.config.max_seq_len
        # Highest start leaves room for the one-token shift of the targets.
        starts = torch.randint(len(tokens) - seq_len, (batch_size,)).tolist()
        x = torch.stack([torch.from_numpy(tokens[i:i + seq_len].astype(np.int64)) for i in starts])
        y = torch.stack([torch.from_numpy(tokens[i + 1:i + 1 + seq_len].astype(np.int64)) for i in starts])
        return x.to(DEVICE), y.to(DEVICE)

    def get_batch(self, batch_size):
        """Return a random training batch from the currently loaded chunk.

        Raises:
            RuntimeError: if ``load_for_step`` has not been called yet.
        """
        if self.data is None:
            raise RuntimeError("No training chunk loaded; call load_for_step(step) first.")
        return self._sample_batch(self.data, batch_size)

    def get_val_batch(self, batch_size):
        """Return a random batch from the validation tokens."""
        return self._sample_batch(self.val_data, batch_size)

class BlackBox:
    """Buffers telemetry rows in memory and appends them to a parquet file."""

    def __init__(self, save_dir):
        self.buffer = []
        self.save_dir = save_dir
        self.start_step = 0

    def set_start_step(self, step):
        """Ignore any later ``log`` call whose step is below ``step``."""
        self.start_step = step

    def log(self, step, loss, val_loss, pressure, metrics):
        """Queue one telemetry row.

        Args:
            step: training step the row belongs to.
            loss: training loss.
            val_loss: validation loss; perplexity is ``exp(val_loss)`` or 0.0
                when ``val_loss`` is 20 or more.
            pressure: current steering pressure.
            metrics: iterable of per-layer dicts; entries with a ``'sigma_a'``
                tensor are averaged into ``sigma_a_avg``. May be empty or None.
        """
        if step < self.start_step: return

        # Calculate mean sigma_a across layers for logging
        avg_sigma = 0.0
        if metrics:
            sigmas = [m['sigma_a'].mean().item() for m in metrics if 'sigma_a' in m]
            if sigmas: avg_sigma = sum(sigmas) / len(sigmas)

        row = {
            "step": step, "loss": loss, "val_loss": val_loss,
            "perplexity": math.exp(val_loss) if val_loss < 20 else 0.0,
            "pressure": pressure,
            "sigma_a_avg": avg_sigma
        }
        self.buffer.append(row)

    @staticmethod
    def _merge_rows(existing, new):
        """Combine stored and new rows, keeping the newest row for each step.

        A run resumed from an older checkpoint logs steps that are already on
        disk; without this the file accumulates duplicate ``step`` rows.
        """
        combined = new if existing is None else pd.concat([existing, new], ignore_index=True)
        # Stable sort keeps later rows after earlier ones within the same step,
        # so keep='last' selects the most recently logged row.
        combined = combined.sort_values('step', kind='stable')
        combined = combined.drop_duplicates(subset='step', keep='last')
        return combined.reset_index(drop=True)

    def flush(self):
        """Write buffered rows to ``telemetry_adaptive.parquet`` and clear the buffer."""
        if not self.buffer: return
        df = pd.DataFrame(self.buffer)
        fpath = os.path.join(self.save_dir, "telemetry_adaptive.parquet")
        existing = None
        if os.path.exists(fpath):
            try:
                existing = pd.read_parquet(fpath)
            except Exception as exc:
                # Best effort: keep logging, but say that history was not merged
                # instead of discarding it silently.
                print(f"WARNING: could not read existing telemetry ({exc}); writing new rows only.")
        df = self._merge_rows(existing, df)
        df.to_parquet(fpath)
        self.buffer = []
        print("üíæ Telemetry Flushed.")

# --- 7. RUN LOOP ---
def run_adaptive():
    """Train the model with the homeostatic pressure controller.

    Builds the model, optimizer, controller, data loader and telemetry
    recorder, resumes from the newest ``ckpt_step_<N>.pt`` in SAVE_DIR when one
    exists, then trains up to ``cfg.max_steps``. Each step:

      A. updates the controller (every step during warm-up, afterwards every
         ``control_interval`` steps) using the last measured sigma;
      B. runs ``grad_accum`` micro-batches, clips gradients, steps the optimizer;
      C. every ``lite_interval`` steps and on the last step, evaluates on ten
         validation batches and flushes telemetry;
      D. every ``ckpt_interval`` steps, saves a checkpoint and deletes the
         oldest one once more than three exist.

    The final weights are written to ``janus_adaptive_final.pt``.

    NOTE: checkpoints store the CPU/GPU RNG states, but the resume path below
    does not restore them.
    """
    gc.collect(); torch.cuda.empty_cache()

    cfg = AdaptiveConfig()
    model = NewGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6.29e-4, weight_decay=1.34e-4)
    controller = Homeostat(cfg)
    loader = ChunkLoader(cfg)
    recorder = BlackBox(SAVE_DIR)

    def ckpt_step_of(fname):
        # "ckpt_step_1234.pt" -> 1234
        return int(fname.split('_')[2].split('.')[0])

    def list_ckpts():
        # Checkpoint file names in SAVE_DIR, oldest step first.
        found = [f for f in os.listdir(SAVE_DIR) if f.startswith("ckpt_step_") and f.endswith(".pt")]
        found.sort(key=ckpt_step_of)
        return found

    start_step = 0
    # START FRESH OR RESUME ADAPTIVE
    # We do NOT recommend grafting 4k Constant here,
    # better to start fresh 0-1500 warm-up to test the whole curve cleanly.

    ckpts = list_ckpts()
    if ckpts:
        latest = ckpts[-1]
        ckpt_path = os.path.join(SAVE_DIR, latest)
        print(f"üîÑ Resuming Adaptive Run from {latest}...")
        c = torch.load(ckpt_path, map_location=DEVICE)
        model.load_state_dict(c['model'])
        optimizer.load_state_dict(c['optim'])
        # Restore Controller State
        controller.current_lambda = c.get('current_lambda', 0.0)
        start_step = c['step'] + 1
        recorder.set_start_step(start_step)

    print(f"\nüß† STARTING ADAPTIVE RUN: {start_step} -> {cfg.max_steps}")
    print(f"   Target Sigma: {cfg.target_sigma}")
    print(f"   Warmup Steps: {cfg.warmup_steps}")

    pbar = tqdm(range(start_step, cfg.max_steps), initial=start_step, total=cfg.max_steps)

    current_sigma_avg = None
    # Most recent non-empty per-layer metrics; logged on steps that did not
    # measure anything (the final step) instead of recording sigma as 0.0.
    last_metrics = []

    for step in pbar:
        loader.load_for_step(step)

        # --- A. CONTROL UPDATE ---
        # Update lambda based on LAST step's sigma metrics
        if step % cfg.control_interval == 0 or step < cfg.warmup_steps:
            controller.update(step, current_sigma_avg)

        # Apply lambda to layers
        # Simplified: Same lambda for div, 20% of that for coh
        p = controller.current_lambda
        base_coh = p * 0.2
        lambdas_list = []
        for i in range(cfg.n_layers):
            # Cubic depth schedule: deeper layers receive more pressure.
            ratio = (i + 1) / cfg.n_layers
            s_mult = ratio ** 3
            lambdas_list.append((base_coh * s_mult, p * s_mult))

        # Metrics are needed on control steps (to feed the controller) and on
        # logging steps; the flags are the same for every micro-batch.
        is_control_step = (step % cfg.control_interval == 0)
        is_log_step = (step % cfg.lite_interval == 0)
        return_metrics = is_control_step or is_log_step

        # --- B. TRAINING STEP ---
        model.train()
        batch_loss = 0.0
        optimizer.zero_grad()

        accum_sigmas = [] # Track sigma during accum for controller

        for _ in range(cfg.grad_accum):
            x, y = loader.get_batch(cfg.batch_size)

            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                loss, steer, metrics = model(x, lambdas_list, y, return_metrics=return_metrics)
                total = (loss + steer) / cfg.grad_accum

            total.backward()
            batch_loss += loss.item()

            if return_metrics and metrics:
                # Extract mean sigma for controller
                s = [m['sigma_a'].mean().item() for m in metrics]
                if s: accum_sigmas.append(sum(s)/len(s))
                last_metrics = metrics

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Update Sigma Average for Next Control Step
        if accum_sigmas:
            current_sigma_avg = sum(accum_sigmas) / len(accum_sigmas)

        train_loss = batch_loss / cfg.grad_accum
        # Sigma can still be None (nothing measured yet); show 0 rather than
        # crashing in the format spec.
        sigma_display = current_sigma_avg if current_sigma_avg is not None else 0.0

        # --- C. TELEMETRY ---
        if step > 0 and (step % cfg.lite_interval == 0 or step == cfg.max_steps - 1):
            gc.collect(); torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                v_losses = []
                for _ in range(10): # Reduced val batches for speed
                    vx, vy = loader.get_val_batch(cfg.batch_size)
                    vl, _, _ = model(vx, [(0.0,0.0)]*cfg.n_layers, vy, return_metrics=False)
                    v_losses.append(vl.item())
                val_loss = np.mean(v_losses)

            recorder.log(step, train_loss, val_loss, controller.current_lambda, last_metrics)
            recorder.flush()

            # Update Progress Bar with Adaptive Info
            pbar.set_description(f"L:{train_loss:.2f}|V:{val_loss:.2f}|P:{controller.current_lambda:.3f}|Sig:{sigma_display:.4f}")
        else:
            pbar.set_description(f"L:{train_loss:.2f}|P:{controller.current_lambda:.3f}|Sig:{sigma_display:.4f}")

        # --- D. CHECKPOINTS ---
        if step > 0 and step % cfg.ckpt_interval == 0:
            ckpt_name = f"ckpt_step_{step}.pt"
            ckpt_path = os.path.join(SAVE_DIR, ckpt_name)
            save_dict = {
                'step': step,
                'model': model.state_dict(),
                'optim': optimizer.state_dict(),
                'current_lambda': controller.current_lambda, # Save controller state
                'rng_cpu': torch.get_rng_state(),
                # Only query the CUDA generator when CUDA is usable.
                'rng_gpu': torch.cuda.get_rng_state() if torch.cuda.is_available() else None
            }
            torch.save(save_dict, ckpt_path)

            # Keep at most three checkpoints: drop the oldest when a fourth appears.
            all_ckpts = list_ckpts()
            if len(all_ckpts) > 3: os.remove(os.path.join(SAVE_DIR, all_ckpts[0]))

    final_path = os.path.join(SAVE_DIR, "janus_adaptive_final.pt")
    torch.save(model.state_dict(), final_path)
    print(f"\nüèÜ ADAPTIVE RUN COMPLETE. Saved to {final_path}")

# Start the adaptive training run when this cell/module is executed directly.
if __name__ == "__main__":
    run_adaptive()

In [None]:
# @title 📊 Janus Telemetry Audit (Sanitized)
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
from google.colab import drive

# --- 1. FORCE MOUNT ---
# Mount Google Drive unless the Colab mount point already exists.
if not os.path.exists('/content/drive'):
    print("üîå Mounting Google Drive...")
    drive.mount('/content/drive')
else:
    print("‚úÖ Drive already mounted.")

# --- CONFIGURATION ---
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
SAVE_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_adaptive_v1")
TELEMETRY_FILENAME = "telemetry_adaptive.parquet"
TELEMETRY_PATH = os.path.join(SAVE_DIR, TELEMETRY_FILENAME)

# Step at which the "Epoch 1 Boundary" marker is drawn on the loss plot.
EPOCH_STEPS = 3051

# --- 2. VERIFY & LOAD ---
# Everything below runs only when the telemetry file exists; otherwise the
# cell prints an error and does nothing else.
if not os.path.exists(TELEMETRY_PATH):
    print(f"\n‚ùå CRITICAL: File not found at: {TELEMETRY_PATH}")
    print("üõë Aborting.")
else:
    print(f"‚úÖ Found telemetry file!")
    df = pd.read_parquet(TELEMETRY_PATH)

    # --- 3. SANITATION (The Fix) ---
    # The resume logic created duplicate steps. We must remove them.
    original_len = len(df)

    # Sort by step to ensure order
    df = df.sort_values('step')

    # Drop duplicates, keeping the LAST entry (most recent run data)
    df = df.drop_duplicates(subset=['step'], keep='last')

    # Reset the index to be perfectly sequential and unique
    df = df.reset_index(drop=True)

    dropped_count = original_len - len(df)
    if dropped_count > 0:
        print(f"üßπ SANITIZED: Removed {dropped_count} duplicate rows caused by run resumption.")

    # --- 4. PRE-PROCESSING ---
    # Now safe to calculate rolling averages
    # Centered 5-row window: the first and last two rows of loss_smooth are NaN.
    df['loss_smooth'] = df['loss'].rolling(window=5, center=True).mean()

    # Identify Global Minima
    # Row with the lowest validation loss; used for the marker and the title.
    min_val_row = df.loc[df['val_loss'].idxmin()]
    min_val_step = min_val_row['step']
    min_val_loss = min_val_row['val_loss']

    # --- 5. VISUALIZATION ---
    plt.style.use('dark_background')
    # Two stacked panels sharing the step axis: losses on top, controller below.
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 12), sharex=True)

    # PLOT 1: Convergence
    sns.lineplot(data=df, x='step', y='loss', ax=ax1, label='Raw Train Loss', alpha=0.3, color='cyan')
    sns.lineplot(data=df, x='step', y='loss_smooth', ax=ax1, label='Smoothed Train', color='cyan', linewidth=2)
    sns.lineplot(data=df, x='step', y='val_loss', ax=ax1, label='Validation Loss', color='magenta', linewidth=2, marker='o')

    # Markers
    ax1.axvline(x=EPOCH_STEPS, color='yellow', linestyle='--', alpha=0.6, label='Epoch 1 Boundary')
    # Label is anchored at the current lower y-limit, 50 steps right of the line.
    ax1.text(EPOCH_STEPS+50, ax1.get_ylim()[0], 'Epoch 2 Start', color='yellow', rotation=90)
    ax1.axvline(x=min_val_step, color='lime', linestyle=':', linewidth=2, label=f'Best Val (Step {int(min_val_step)})')

    ax1.set_title(f"DATA WALL AUDIT: Best Val {min_val_loss:.3f} @ Step {int(min_val_step)}", fontsize=14)
    ax1.set_ylabel("Cross Entropy Loss")
    ax1.legend()

    # PLOT 2: Homeostasis
    # Pressure uses the left axis; sigma is drawn on a twin right axis.
    ax2_twin = ax2.twinx()
    sns.lineplot(data=df, x='step', y='pressure', ax=ax2, color='orange', label='Pressure (Lambda)')
    sns.lineplot(data=df, x='step', y='sigma_a_avg', ax=ax2_twin, color='green', label='Sigma', linestyle='--')

    # Reference line at 0.0035.
    # NOTE(review): presumably mirrors the training run's target sigma - keep in sync.
    ax2_twin.axhline(y=0.0035, color='green', linestyle=':', alpha=0.5, label='Target Sigma')

    ax2.set_ylabel("Pressure", color='orange')
    ax2_twin.set_ylabel("Sigma", color='green')
    ax2.set_xlabel("Steps")

    plt.tight_layout()
    plt.show()

    # --- 6. REPORT ---
    print("\nüîé RECENT DATA (Last 8 Unique Steps):")
    print(df[['step', 'loss', 'val_loss', 'pressure']].tail(8).to_string(index=False))

In [None]:
# @title [Analysis] Model Forensics - Adaptive vs Control

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.optimize import linear_sum_assignment
from tqdm import tqdm
import json
import math
from google.colab import drive

# Mount drive
# (skipped when the Colab mount point already exists)
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"üî¨ MODEL FORENSICS SUITE")
print(f"‚öôÔ∏è Device: {DEVICE}\n")

# ============================================================================
# ARCHITECTURE DEFINITIONS (Copied for standalone operation)
# ============================================================================

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain.

    Args:
        dim: size of the last axis; one gain value is learned per feature.
        eps: constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Gain starts at 1 so the freshly built layer only rescales by 1/RMS.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2) + eps) * weight`` along the last axis."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward block: ``w3(silu(w1(x)) * w2(x))``.

    The hidden width is ``int(d_model * 8 / 3)`` and none of the three
    projections has a bias.
    """

    def __init__(self, d_model):
        super().__init__()
        hidden_dim = int(d_model * 8 / 3)
        # w1 feeds the SiLU gate, w2 is the linear branch, w3 projects back.
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x`` (last axis of size d_model)."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Causal multi-head attention that can expose per-head internals.

    ``config`` is a dict with the keys ``n_heads``, ``d_head``, ``d_model``,
    ``dropout`` and ``max_seq_len``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config['n_heads']
        self.d_head = config['d_head']
        self.scale = 1.0 / math.sqrt(self.d_head)
        d_model = config['d_model']

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(config['dropout'])

        # Rotary table is registered as a buffer so it follows .to(device).
        rope_table = self.precompute_freqs_cis(config['d_head'], config['max_seq_len'])
        self.register_buffer("freqs_cis", rope_table)

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return a complex ``(end, dim // 2)`` table of unit-magnitude rotations."""
        exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
        inv_freq = 1.0 / (theta ** exponents)
        positions = torch.arange(end, device=inv_freq.device)
        angles = torch.outer(positions, inv_freq).float()
        return torch.polar(torch.ones_like(angles), angles)

    def apply_rope(self, x, freqs_cis):
        """Rotate consecutive feature pairs of ``x`` (batch, seq, heads, d_head)."""
        seq_len = x.shape[1]
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        as_complex = torch.view_as_complex(pairs)
        rotation = freqs_cis[:seq_len].view(1, seq_len, 1, -1)
        rotated = torch.view_as_real(as_complex * rotation).flatten(3)
        return rotated.type_as(x)

    def forward(self, x, return_head_outputs=False):
        """Attend causally over ``x`` of shape (batch, seq, d_model).

        Returns the projected output, or ``(output, head_out, attn_probs)``
        when ``return_head_outputs`` is True, where ``head_out`` is the
        per-head result before merging and ``attn_probs`` the post-dropout
        attention weights.
        """
        batch, seq_len, width = x.shape
        head_shape = (batch, seq_len, self.n_heads, self.d_head)

        q = self.q_proj(x).view(*head_shape)
        k = self.k_proj(x).view(*head_shape)
        v = self.v_proj(x).view(*head_shape)

        # Rotary embedding is applied to queries and keys only.
        q = self.apply_rope(q, self.freqs_cis).transpose(1, 2)
        k = self.apply_rope(k, self.freqs_cis).transpose(1, 2)
        v = v.transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Lower-triangular mask: position i may only see positions <= i.
        causal = torch.tril(torch.ones(seq_len, seq_len, device=x.device))
        causal = causal.view(1, 1, seq_len, seq_len)
        scores = scores.masked_fill(causal == 0, float('-inf'))
        attn_probs = self.dropout(F.softmax(scores, dim=-1))
        head_out = torch.matmul(attn_probs, v)

        merged = head_out.transpose(1, 2).reshape(batch, seq_len, width)
        out = self.o_proj(merged)

        if not return_head_outputs:
            return out
        return out, head_out, attn_probs

class NewBlock(nn.Module):
    """Pre-norm transformer block: attention residual, then MLP residual."""

    def __init__(self, config):
        super().__init__()
        width = config['d_model']
        self.ln1 = RMSNorm(width)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(width)
        self.mlp = SwiGLU(width)

    def forward(self, x, return_head_outputs=False):
        """Apply the block to ``x``.

        Returns the updated hidden state, or ``(hidden, head_out, attn_probs)``
        when ``return_head_outputs`` is true.
        """
        normed = self.ln1(x)
        if return_head_outputs:
            attn_out, head_out, attn_probs = self.attn(normed, return_head_outputs=True)
        else:
            attn_out = self.attn(normed)

        # The MLP reads the state *after* the attention residual was added.
        hidden = x + attn_out
        hidden = hidden + self.mlp(self.ln2(hidden))

        if return_head_outputs:
            return hidden, head_out, attn_probs
        return hidden

class NewGPT(nn.Module):
    """Decoder-only language model: embedding, a stack of ``NewBlock``s, final norm, LM head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        vocab, width = config['vocab_size'], config['d_model']
        self.token_emb = nn.Embedding(vocab, width)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config['n_layers'])])
        self.ln_f = RMSNorm(width)
        self.head = nn.Linear(width, vocab, bias=False)
        # Weight tying: the embedding table and the output projection share one tensor.
        self.token_emb.weight = self.head.weight

    def forward(self, idx, return_head_outputs=False):
        """Compute logits for the token ids ``idx``.

        Returns:
            ``logits``; or ``(logits, head_outputs, attn_patterns)`` when
            ``return_head_outputs`` is true, where both lists hold one entry
            per block in order.
        """
        hidden = self.token_emb(idx)
        head_outputs, attn_patterns = [], []

        for block in self.blocks:
            if not return_head_outputs:
                hidden = block(hidden)
                continue
            hidden, per_head, probs = block(hidden, return_head_outputs=True)
            head_outputs.append(per_head)
            attn_patterns.append(probs)

        logits = self.head(self.ln_f(hidden))

        if return_head_outputs:
            return logits, head_outputs, attn_patterns
        return logits

# ============================================================================
# FORENSICS SUITE
# ============================================================================

class ModelForensics:
    """Compare two trained ``NewGPT`` checkpoints side by side.

    Every analysis method stores its output in ``self.results`` under a fixed
    key (``'weight_stats'``, ``'head_metrics'``,
    ``'head_similarity_matrices_a'`` / ``'head_similarity_matrices_b'``,
    ``'landscape'``, ``'effective_rank'``). The visualization and report
    methods render whichever of those results are present.
    """

    def __init__(self, model_a_path, model_b_path, model_a_name="Adaptive", model_b_name="Control"):
        """Build both models from a fixed config and load their checkpoints.

        Args:
            model_a_path: Path to the state-dict checkpoint of model A.
            model_b_path: Path to the state-dict checkpoint of model B.
            model_a_name: Label for model A in tables, plots and the report.
            model_b_name: Label for model B in tables, plots and the report.
        """
        self.config = {
            'vocab_size': 50257,
            'd_model': 512,
            'n_layers': 12,
            'n_heads': 8,
            'd_head': 64,
            'max_seq_len': 512,
            'dropout': 0.0  # Set to 0 for deterministic analysis
        }

        print(f"üìÇ Loading models...")
        self.model_a = NewGPT(self.config).to(DEVICE)
        self.model_b = NewGPT(self.config).to(DEVICE)

        # Load state dicts.
        # NOTE(security): torch.load unpickles the file; only point this at
        # checkpoints you produced or otherwise trust.
        state_a = torch.load(model_a_path, map_location=DEVICE)
        state_b = torch.load(model_b_path, map_location=DEVICE)

        self.model_a.load_state_dict(state_a)
        self.model_b.load_state_dict(state_b)

        self.model_a.eval()
        self.model_b.eval()

        self.model_a_name = model_a_name
        self.model_b_name = model_b_name

        self.results = {}
        print(f"‚úÖ Models loaded: {model_a_name} vs {model_b_name}\n")

    def compare_weight_statistics(self):
        """Section 1: Compare raw weight distributions.

        Returns:
            DataFrame with one row per named parameter holding, for each
            model, mean / std / L2 norm / fraction of entries with magnitude
            below 0.01, plus the norm of the difference and the cosine
            similarity between the two flattened tensors. Also stored in
            ``self.results['weight_stats']``.
        """
        print("üîç Section 1: Weight Statistics Analysis")

        name_a, name_b = self.model_a_name, self.model_b_name
        # Build the lookup once instead of once per parameter.
        params_b = dict(self.model_b.named_parameters())

        stats_data = []
        for name, param_a in self.model_a.named_parameters():
            a_flat = param_a.detach().cpu().flatten().numpy()
            b_flat = params_b[name].detach().cpu().flatten().numpy()

            norm_a = np.linalg.norm(a_flat)
            norm_b = np.linalg.norm(b_flat)

            stats_data.append({
                'parameter': name,
                f'{name_a}_mean': np.mean(a_flat),
                f'{name_b}_mean': np.mean(b_flat),
                f'{name_a}_std': np.std(a_flat),
                f'{name_b}_std': np.std(b_flat),
                f'{name_a}_norm': norm_a,
                f'{name_b}_norm': norm_b,
                f'{name_a}_sparsity': np.mean(np.abs(a_flat) < 0.01),
                f'{name_b}_sparsity': np.mean(np.abs(b_flat) < 0.01),
                'diff_norm': np.linalg.norm(a_flat - b_flat),
                'cosine_sim': np.dot(a_flat, b_flat) / (norm_a * norm_b)
            })

        df = pd.DataFrame(stats_data)
        self.results['weight_stats'] = df

        # Summary statistics
        print(f"   Total parameters: {sum([p.numel() for p in self.model_a.parameters()]):,}")
        print(f"   Mean weight difference norm: {df['diff_norm'].mean():.6f}")
        print(f"   Mean cosine similarity: {df['cosine_sim'].mean():.6f}")
        print(f"   Parameters analyzed: {len(df)}\n")

        return df

    def _layer_head_metrics(self, layer_idx, model_name, head_out, attn_probs):
        """Summarise one layer of one model: pairwise head similarity and per-head entropy."""
        n_heads = self.config['n_heads']

        # Flatten to [n_heads, B*S*d_head] and compare heads by cosine similarity.
        flat = head_out.transpose(0, 1).reshape(n_heads, -1)
        norm = F.normalize(flat, p=2, dim=1)
        sim_matrix = torch.mm(norm, norm.t()).cpu().numpy()

        # Off-diagonal entries only; the diagonal is each head against itself.
        off_diag = sim_matrix[~np.eye(n_heads, dtype=bool)]

        record = {
            'layer': layer_idx,
            'model': model_name,
            'mean_similarity': np.mean(np.abs(off_diag)),
            'max_similarity': np.max(np.abs(off_diag)),
            'min_similarity': np.min(np.abs(off_diag)),
            'std_similarity': np.std(off_diag),
            'similarity_matrix': sim_matrix
        }

        # attn_probs: [B, n_heads, S, S] -> entropy per head, averaged over batch and query position.
        entropy = -(attn_probs * torch.log(attn_probs + 1e-9)).sum(dim=-1).mean(dim=[0, 2])
        for h_idx, ent in enumerate(entropy.cpu().numpy()):
            record[f'head_{h_idx}_entropy'] = ent

        return record

    def analyze_attention_heads(self, num_samples=100, seq_len=128):
        """Section 2: Deep dive into attention head properties.

        Both models are run on the same batch of uniformly random token ids.

        Args:
            num_samples: Number of random sequences in the probe batch.
            seq_len: Length of each sequence; the rotary table is precomputed
                for ``config['max_seq_len']`` positions, so this should not
                exceed that value.

        Returns:
            DataFrame with one row per (model, layer); also stored in
            ``self.results['head_metrics']``. The raw similarity matrices are
            stored per layer under ``'head_similarity_matrices_a'`` / ``'_b'``.
        """
        print("üîç Section 2: Attention Head Analysis")

        # Generate test inputs
        test_inputs = torch.randint(0, self.config['vocab_size'], (num_samples, seq_len)).to(DEVICE)

        head_metrics_a = []
        head_metrics_b = []

        with torch.no_grad():
            # Get head outputs from both models
            _, heads_a, attn_a = self.model_a(test_inputs, return_head_outputs=True)
            _, heads_b, attn_b = self.model_b(test_inputs, return_head_outputs=True)

            per_model = [
                (heads_a, attn_a, self.model_a_name, head_metrics_a),
                (heads_b, attn_b, self.model_b_name, head_metrics_b),
            ]
            for layer_idx in range(self.config['n_layers']):
                for heads, attns, model_name, metric_list in per_model:
                    metric_list.append(
                        self._layer_head_metrics(layer_idx, model_name, heads[layer_idx], attns[layer_idx])
                    )

        df_a = pd.DataFrame(head_metrics_a)
        df_b = pd.DataFrame(head_metrics_b)
        df_combined = pd.concat([df_a, df_b])

        self.results['head_metrics'] = df_combined
        self.results['head_similarity_matrices_a'] = [m['similarity_matrix'] for m in head_metrics_a]
        self.results['head_similarity_matrices_b'] = [m['similarity_matrix'] for m in head_metrics_b]

        # Summary
        avg_sim_a = df_a['mean_similarity'].mean()
        avg_sim_b = df_b['mean_similarity'].mean()

        print(f"   {self.model_a_name} mean head similarity (œÉ_a): {avg_sim_a:.6f}")
        print(f"   {self.model_b_name} mean head similarity (œÉ_a): {avg_sim_b:.6f}")
        print(f"   Difference: {abs(avg_sim_a - avg_sim_b):.6f}")
        print(f"   Lower is more diverse (target was 0.0035)\n")

        return df_combined

    def measure_loss_landscape(self, val_data_path, num_points=11, num_batches=20):
        """Section 3: Loss landscape interpolation between models.

        Evaluates ``theta = alpha * A + (1 - alpha) * B`` for ``num_points``
        evenly spaced ``alpha`` in ``[0, 1]``. The same ``num_batches`` random
        windows of the validation file are used for every ``alpha`` so that
        differences along the path are not mixed with sampling noise.

        Args:
            val_data_path: Path to a flat file of ``uint16`` token ids.
            num_points: Number of interpolation points (endpoints included).
            num_batches: Number of single-sequence windows averaged per point.

        Returns:
            ``(alphas, losses)``; also stored in ``self.results['landscape']``.

        Raises:
            ValueError: If the file holds no more than ``max_seq_len`` tokens.
        """
        print("üîç Section 3: Loss Landscape Interpolation")

        # Load validation data
        val_data = np.memmap(val_data_path, dtype=np.uint16, mode='r')
        seq_len = self.config['max_seq_len']
        max_start = len(val_data) - seq_len
        if max_start <= 0:
            raise ValueError(
                f"validation data has {len(val_data)} tokens; need more than {seq_len}"
            )

        # Draw the evaluation windows once and reuse them for every alpha.
        starts = [int(ix) for ix in np.random.randint(0, max_start, size=num_batches)]

        alphas = np.linspace(0, 1, num_points)
        losses = []

        # Snapshot both state dicts once; state_dict() rebuilds its mapping on every call.
        state_a = self.model_a.state_dict()
        state_b = self.model_b.state_dict()

        # One scratch model is enough: load_state_dict overwrites every tensor each step.
        interp_model = NewGPT(self.config).to(DEVICE)
        interp_model.eval()

        for alpha in tqdm(alphas, desc="Interpolating"):
            weight = float(alpha)
            interp_model.load_state_dict({
                name: weight * state_a[name] + (1 - weight) * state_b[name]
                for name in state_a
            })

            # Compute loss
            batch_losses = []
            with torch.no_grad():
                for ix in starts:
                    x = torch.from_numpy(val_data[ix:ix + seq_len].astype(np.int64)).unsqueeze(0).to(DEVICE)
                    y = torch.from_numpy(val_data[ix + 1:ix + 1 + seq_len].astype(np.int64)).unsqueeze(0).to(DEVICE)

                    logits = interp_model(x)
                    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
                    batch_losses.append(loss.item())

            losses.append(np.mean(batch_losses))

        self.results['landscape'] = {'alphas': alphas, 'losses': losses}

        print(f"   Loss at {self.model_a_name} (Œ±=1): {losses[-1]:.4f}")
        print(f"   Loss at {self.model_b_name} (Œ±=0): {losses[0]:.4f}")
        print(f"   Max loss along path: {max(losses):.4f}")
        print(f"   Barrier height: {max(losses) - min(losses):.4f}\n")

        return alphas, losses

    @staticmethod
    def _effective_rank(singular_values):
        """Effective rank: exp(entropy of the sum-normalised singular values)."""
        s_norm = singular_values / singular_values.sum()
        entropy = -(s_norm * torch.log(s_norm + 1e-9)).sum()
        return torch.exp(entropy).item()

    def compute_effective_rank(self):
        """Section 4: Effective rank of weight matrices.

        Only 2-D parameters are analysed. Singular values are computed with
        ``torch.linalg.svdvals`` (no singular vectors are needed) under
        ``no_grad`` so no autograd graph is built for the parameters.

        Returns:
            DataFrame with, per parameter and model, the effective rank, the
            largest singular value and the ratio of largest to smallest
            singular value. Also stored in ``self.results['effective_rank']``.
        """
        print("üîç Section 4: Effective Rank Analysis")

        name_a, name_b = self.model_a_name, self.model_b_name
        params_b = dict(self.model_b.named_parameters())

        rank_data = []
        with torch.no_grad():
            for name, param_a in self.model_a.named_parameters():
                if len(param_a.shape) != 2:  # Only analyze 2D weight matrices
                    continue

                S_a = torch.linalg.svdvals(param_a.detach().cpu().float())
                S_b = torch.linalg.svdvals(params_b[name].detach().cpu().float())

                rank_data.append({
                    'parameter': name,
                    f'{name_a}_eff_rank': self._effective_rank(S_a),
                    f'{name_b}_eff_rank': self._effective_rank(S_b),
                    f'{name_a}_top_sv': S_a[0].item(),
                    f'{name_b}_top_sv': S_b[0].item(),
                    f'{name_a}_condition': (S_a[0] / S_a[-1]).item(),
                    f'{name_b}_condition': (S_b[0] / S_b[-1]).item(),
                })

        df = pd.DataFrame(rank_data)
        self.results['effective_rank'] = df

        rank_a = df[f'{name_a}_eff_rank'].mean()
        rank_b = df[f'{name_b}_eff_rank'].mean()
        cond_a = df[f'{name_a}_condition'].mean()
        cond_b = df[f'{name_b}_condition'].mean()
        print(f"   Mean effective rank {name_a}: {rank_a:.2f}")
        print(f"   Mean effective rank {name_b}: {rank_b:.2f}")
        print(f"   Mean condition number {name_a}: {cond_a:.2f}")
        print(f"   Mean condition number {name_b}: {cond_b:.2f}\n")

        return df

    def _plot_head_similarity_heatmaps(self, save_dir):
        """Figure 1: per-layer difference of absolute head-similarity matrices (A minus B)."""
        sims_a = self.results.get('head_similarity_matrices_a')
        sims_b = self.results.get('head_similarity_matrices_b')
        if not sims_a or not sims_b:
            print("   (skipping head similarity heatmaps: run analyze_attention_heads first)")
            return

        # Size the grid from the data instead of assuming 12 layers.
        n_layers = min(len(sims_a), len(sims_b))
        n_cols = 4
        n_rows = max(1, -(-n_layers // n_cols))  # ceiling division

        fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
        fig.suptitle(f'Attention Head Similarity Matrices (Layer 0-{n_layers - 1})', fontsize=16)

        for layer_idx in range(n_layers):
            ax = axes[layer_idx // n_cols, layer_idx % n_cols]

            # Show difference
            diff = np.abs(sims_a[layer_idx]) - np.abs(sims_b[layer_idx])

            im = ax.imshow(diff, cmap='RdBu_r', vmin=-0.5, vmax=0.5)
            ax.set_title(f'Layer {layer_idx}\n({self.model_a_name} - {self.model_b_name})')
            ax.set_xlabel('Head')
            ax.set_ylabel('Head')
            plt.colorbar(im, ax=ax)

        # Hide grid cells that have no layer to show.
        for spare in range(n_layers, n_rows * n_cols):
            axes[spare // n_cols, spare % n_cols].set_visible(False)

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'head_similarity_heatmaps.png'), dpi=150)
        plt.close()

    def _plot_head_similarity_by_layer(self, save_dir):
        """Figure 2: mean head similarity per layer for both models."""
        if 'head_metrics' not in self.results:
            print("   (skipping head similarity by layer: run analyze_attention_heads first)")
            return

        df = self.results['head_metrics']

        fig, ax = plt.subplots(figsize=(12, 6))
        for model_name in [self.model_a_name, self.model_b_name]:
            data = df[df['model'] == model_name]
            ax.plot(data['layer'], data['mean_similarity'], marker='o', label=model_name, linewidth=2)

        ax.axhline(y=0.0035, color='red', linestyle='--', label='Target œÉ_a = 0.0035')
        ax.set_xlabel('Layer', fontsize=12)
        ax.set_ylabel('Mean Head Similarity (œÉ_a)', fontsize=12)
        ax.set_title('Attention Head Redundancy by Layer', fontsize=14)
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'head_similarity_by_layer.png'), dpi=150)
        plt.close()

    def _plot_loss_landscape(self, save_dir):
        """Figure 3: validation loss along the linear path between the two models."""
        if 'landscape' not in self.results:
            return

        fig, ax = plt.subplots(figsize=(12, 6))
        alphas = self.results['landscape']['alphas']
        losses = self.results['landscape']['losses']

        ax.plot(alphas, losses, marker='o', linewidth=2, markersize=8)
        ax.axvline(x=0, color='blue', linestyle='--', alpha=0.5, label=self.model_b_name)
        ax.axvline(x=1, color='orange', linestyle='--', alpha=0.5, label=self.model_a_name)
        ax.set_xlabel(f'Œ± (0={self.model_b_name}, 1={self.model_a_name})', fontsize=12)
        ax.set_ylabel('Validation Loss', fontsize=12)
        ax.set_title('Loss Landscape: Linear Interpolation Between Models', fontsize=14)
        ax.legend()
        ax.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'loss_landscape.png'), dpi=150)
        plt.close()

    def _plot_weight_distributions(self, save_dir):
        """Figure 4: q_proj / o_proj weight histograms for the first, middle and last block."""
        n_blocks = self.config['n_layers']
        # For 12 layers this yields blocks 0, 6 and 11.
        sample_layers = sorted(
            layer for layer in {0, n_blocks // 2, n_blocks - 1} if 0 <= layer < n_blocks
        )
        sample_params = [
            f'blocks.{layer}.attn.{proj}.weight'
            for layer in sample_layers
            for proj in ('q_proj', 'o_proj')
        ]
        if not sample_params:
            return

        n_cols = 3
        n_rows = -(-len(sample_params) // n_cols)  # ceiling division

        fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows), squeeze=False)
        fig.suptitle('Weight Distribution Comparison (Sample Layers)', fontsize=16)

        params_a = dict(self.model_a.named_parameters())
        params_b = dict(self.model_b.named_parameters())

        for idx, param_name in enumerate(sample_params):
            ax = axes[idx // n_cols, idx % n_cols]

            param_a = params_a[param_name].detach().cpu().flatten().numpy()
            param_b = params_b[param_name].detach().cpu().flatten().numpy()

            ax.hist(param_a, bins=50, alpha=0.5, label=self.model_a_name, density=True)
            ax.hist(param_b, bins=50, alpha=0.5, label=self.model_b_name, density=True)
            ax.set_title(param_name.replace('blocks.', 'L').replace('.attn.', ' '), fontsize=10)
            ax.legend()
            ax.grid(True, alpha=0.3)

        for spare in range(len(sample_params), n_rows * n_cols):
            axes[spare // n_cols, spare % n_cols].set_visible(False)

        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, 'weight_distributions.png'), dpi=150)
        plt.close()

    def generate_visualizations(self, save_dir):
        """Create comprehensive visualization suite.

        Writes up to four PNG files into ``save_dir`` (created if missing).
        Figures whose underlying analysis has not been run are skipped.
        """
        print("üìä Generating Visualizations...")
        os.makedirs(save_dir, exist_ok=True)

        self._plot_head_similarity_heatmaps(save_dir)
        self._plot_head_similarity_by_layer(save_dir)
        self._plot_loss_landscape(save_dir)
        self._plot_weight_distributions(save_dir)

        print(f"‚úÖ Visualizations saved to {save_dir}\n")

    def generate_report(self, save_dir):
        """Generate comprehensive text report.

        Writes ``forensics_report.txt`` into ``save_dir`` (created if missing),
        with one section per analysis whose result is present in
        ``self.results``.
        """
        print("üìù Generating Report...")

        # The report may be requested without generate_visualizations having created the folder.
        os.makedirs(save_dir, exist_ok=True)
        report_path = os.path.join(save_dir, 'forensics_report.txt')

        name_a, name_b = self.model_a_name, self.model_b_name
        rule = "=" * 80 + "\n"
        thin_rule = "-" * 80 + "\n"

        # The report contains non-ASCII symbols, so pin the encoding explicitly.
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(rule)
            f.write("MODEL FORENSICS REPORT\n")
            f.write(f"{name_a} vs {name_b}\n")
            f.write(rule + "\n")

            # Section 1: Weight Statistics
            if 'weight_stats' in self.results:
                f.write("SECTION 1: WEIGHT STATISTICS\n")
                f.write(thin_rule)
                df = self.results['weight_stats']
                f.write(f"Total Parameters: {sum([p.numel() for p in self.model_a.parameters()]):,}\n")
                f.write(f"Mean Weight Difference Norm: {df['diff_norm'].mean():.6f}\n")
                f.write(f"Mean Cosine Similarity: {df['cosine_sim'].mean():.6f}\n")
                f.write(f"Std Cosine Similarity: {df['cosine_sim'].std():.6f}\n\n")

            # Section 2: Head Analysis
            if 'head_metrics' in self.results:
                f.write("SECTION 2: ATTENTION HEAD ANALYSIS\n")
                f.write(thin_rule)
                df = self.results['head_metrics']
                mean_sigma = {}
                for model_name in [name_a, name_b]:
                    data = df[df['model'] == model_name]
                    mean_sigma[model_name] = data['mean_similarity'].mean()
                    f.write(f"\n{model_name}:\n")
                    f.write(f"  Mean Head Similarity (œÉ_a): {mean_sigma[model_name]:.6f}\n")
                    f.write(f"  Std Head Similarity: {data['mean_similarity'].std():.6f}\n")
                    f.write(f"  Max Head Similarity: {data['max_similarity'].mean():.6f}\n")
                    f.write(f"  Min Head Similarity: {data['min_similarity'].mean():.6f}\n")

                f.write(f"\nTarget œÉ_a: 0.0035\n")
                f.write(f"Difference in œÉ_a: {abs(mean_sigma[name_a] - mean_sigma[name_b]):.6f}\n\n")

            # Section 3: Loss Landscape
            if 'landscape' in self.results:
                f.write("SECTION 3: LOSS LANDSCAPE\n")
                f.write(thin_rule)
                losses = self.results['landscape']['losses']
                barrier = max(losses) - min(losses)
                f.write(f"Loss at {name_a}: {losses[-1]:.4f}\n")
                f.write(f"Loss at {name_b}: {losses[0]:.4f}\n")
                f.write(f"Max Loss Along Path: {max(losses):.4f}\n")
                f.write(f"Barrier Height: {barrier:.4f}\n")
                f.write(f"Path Smoothness: {'Smooth' if barrier < 0.1 else 'Rough'}\n\n")

            # Section 4: Effective Rank
            if 'effective_rank' in self.results:
                f.write("SECTION 4: EFFECTIVE RANK\n")
                f.write(thin_rule)
                df = self.results['effective_rank']
                rank_a = df[f'{name_a}_eff_rank'].mean()
                rank_b = df[f'{name_b}_eff_rank'].mean()
                cond_a = df[f'{name_a}_condition'].mean()
                cond_b = df[f'{name_b}_condition'].mean()
                f.write(f"{name_a} Mean Effective Rank: {rank_a:.2f}\n")
                f.write(f"{name_b} Mean Effective Rank: {rank_b:.2f}\n")
                f.write(f"{name_a} Mean Condition Number: {cond_a:.2f}\n")
                f.write(f"{name_b} Mean Condition Number: {cond_b:.2f}\n\n")

            f.write(rule)
            f.write("END REPORT\n")
            f.write(rule)

        print(f"‚úÖ Report saved to {report_path}\n")

# ============================================================================
# EXECUTION
# ============================================================================

if __name__ == "__main__":
    # Entry point for the forensics cell: load both checkpoints, run every
    # analysis section, then write the figures and the text report.

    # Define paths
    MODEL_ADAPTIVE_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_adaptive_v1/janus_adaptive_final.pt")
    MODEL_CONTROL_PATH = os.path.join(PROJECT_ROOT, "data/models/janus_control_v1/janus_control_final.pt")
    VAL_DATA_PATH = os.path.join(PROJECT_ROOT, "data/wikitext/val.bin")
    RESULTS_DIR = os.path.join(PROJECT_ROOT, "forensics_results")

    # Initialize Forensics
    forensics = ModelForensics(
        MODEL_ADAPTIVE_PATH,
        MODEL_CONTROL_PATH,
        model_a_name="Adaptive",
        model_b_name="Control"
    )

    # Run analysis pipeline
    forensics.compare_weight_statistics()
    forensics.analyze_attention_heads(num_samples=50, seq_len=128)

    # The loss landscape needs validation tokens; say so when it is skipped so
    # a missing figure/report section is not mistaken for a failure elsewhere.
    if os.path.exists(VAL_DATA_PATH):
        forensics.measure_loss_landscape(VAL_DATA_PATH, num_points=11)
    else:
        print(f"   Validation data not found at {VAL_DATA_PATH}; skipping loss landscape.\n")

    forensics.compute_effective_rank()

    # Generate outputs
    forensics.generate_visualizations(RESULTS_DIR)
    forensics.generate_report(RESULTS_DIR)

    print("üèÅ Forensics analysis complete.")

In [None]:
# @title [Omnibus] Deep Model Forensics - Geometric & Information-Theoretic Analysis
# Optimized for L4 GPU | Project Janus XAI Suite

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import math
from google.colab import drive

# ============================================================================
# 1. SETUP & PATHS
# ============================================================================

# Attach Google Drive only when it is not already mounted in this runtime.
_DRIVE_MOUNT_POINT = '/content/drive'
if not os.path.exists(_DRIVE_MOUNT_POINT):
    drive.mount(_DRIVE_MOUNT_POINT)

# Paths based on user specification
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"


def _in_project(relative_path):
    """Return ``relative_path`` joined onto ``PROJECT_ROOT``."""
    return os.path.join(PROJECT_ROOT, relative_path)


PATH_ADAPTIVE = _in_project("data/models/janus_adaptive_v1/janus_adaptive_final.pt")
PATH_CONTROL = _in_project("data/models/janus_control_v1/janus_control_final.pt")
PATH_VAL_DATA = _in_project("data/wikitext/val.bin")

# Every figure produced by this cell is written below this folder.
RESULTS_DIR = _in_project("omnibus_forensics_results")
os.makedirs(RESULTS_DIR, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("üî¨ INITIALIZING OMNIBUS SUITE")
print(f"‚öôÔ∏è Device: {DEVICE}")

# ============================================================================
# 2. ARCHITECTURE DEFINITIONS
# ============================================================================

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned per-channel gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Scale ``x`` by the reciprocal RMS of its last axis, then by ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        # eps keeps the reciprocal square root finite for all-zero rows.
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(silu(w1(x)) * w2(x))`` with no biases."""

    def __init__(self, d_model):
        super().__init__()
        # Hidden width is 8/3 of the model width, truncated to an integer.
        hidden_dim = int(d_model * 8 / 3)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)

    def forward(self, x):
        """Apply the gated MLP to ``x`` (last axis of width ``d_model``)."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Expected ``config`` keys: ``n_heads``, ``d_head``, ``d_model``,
    ``dropout`` and ``max_seq_len``. ``n_heads * d_head`` must equal
    ``d_model`` because the projections are reshaped into heads.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads, self.d_head = config['n_heads'], config['d_head']
        self.scale = 1.0 / math.sqrt(self.d_head)
        d_model = config['d_model']
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(config['dropout'])
        # Registered as a (persistent) buffer so it follows the module across
        # devices and appears in the state dict.
        self.register_buffer("freqs_cis", self.precompute_freqs_cis(self.d_head, config['max_seq_len']))

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return the complex rotation table of shape ``(end, dim // 2)``."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end)
        freqs = torch.outer(t, freqs).float()
        return torch.polar(torch.ones_like(freqs), freqs)

    def apply_rope(self, x, freqs_cis):
        """Rotate channel pairs of ``x`` (``batch, seq, heads, d_head``) by position."""
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs = freqs_cis[:x.shape[1]].view(1, x.shape[1], 1, -1).to(x.device)
        return torch.view_as_real(x_c * freqs).flatten(3).type_as(x)

    def forward(self, x, return_internals=False):
        """Attend over ``x`` of shape ``(batch, seq, d_model)``.

        Returns:
            ``out`` of shape ``(batch, seq, d_model)``; or
            ``(out, head_out, attn_probs)`` when ``return_internals`` is true,
            with ``head_out`` of shape ``(batch, n_heads, seq, d_head)`` and
            ``attn_probs`` of shape ``(batch, n_heads, seq, seq)``.
        """
        B, S, D = x.shape
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        q = self.apply_rope(q.view(B, S, self.n_heads, self.d_head), self.freqs_cis)
        k = self.apply_rope(k.view(B, S, self.n_heads, self.d_head), self.freqs_cis)
        q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.view(B, S, self.n_heads, self.d_head).transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Lower-triangular mask: each position sees itself and earlier positions only.
        mask = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        attn_probs = F.softmax(attn.masked_fill(mask == 0, float('-inf')), dim=-1)
        # Honour the configured dropout. It is the identity in eval mode or at
        # p=0, so the forensic runs (eval, dropout 0.0) are unaffected.
        attn_probs = self.dropout(attn_probs)
        head_out = attn_probs @ v
        out = self.o_proj(head_out.transpose(1, 2).reshape(B, S, D))

        if return_internals:
            return out, head_out, attn_probs
        return out

class NewBlock(nn.Module):
    """Pre-norm transformer block: attention residual followed by MLP residual."""

    def __init__(self, config):
        super().__init__()
        self.ln1, self.attn = RMSNorm(config['d_model']), NewAttention(config)
        self.ln2, self.mlp = RMSNorm(config['d_model']), SwiGLU(config['d_model'])

    def forward(self, x, return_internals=False):
        """Apply the block to ``x``.

        Both call modes run the same sequential computation: the MLP reads
        the hidden state *after* the attention residual has been added.
        Previously the default path fed the MLP the pre-attention input, so
        it produced different activations from the ``return_internals`` path.

        Returns:
            The updated hidden state; or ``(hidden, head_out, attn_probs,
            mlp_out)`` when ``return_internals`` is true.
        """
        if return_internals:
            attn_out, h_out, a_probs = self.attn(self.ln1(x), return_internals=True)
        else:
            attn_out = self.attn(self.ln1(x))

        x = x + attn_out
        m_out = self.mlp(self.ln2(x))
        x = x + m_out

        if return_internals:
            return x, h_out, a_probs, m_out
        return x

class NewGPT(nn.Module):
    """Decoder-only language model that can also expose per-block internals."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        vocab, width = config['vocab_size'], config['d_model']
        self.token_emb = nn.Embedding(vocab, width)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config['n_layers'])])
        self.ln_f = RMSNorm(width)
        self.head = nn.Linear(width, vocab, bias=False)
        # Weight tying: embedding table and output projection share one tensor.
        self.token_emb.weight = self.head.weight

    def forward(self, idx, return_internals=False):
        """Compute logits for the token ids ``idx``.

        Returns:
            ``logits``; or ``(logits, internals)`` when ``return_internals``
            is true, where ``internals`` maps ``'heads'``, ``'attns'`` and
            ``'mlps'`` to lists with one tensor per block, in block order.
        """
        hidden = self.token_emb(idx)
        internals = {'heads': [], 'attns': [], 'mlps': []}

        for block in self.blocks:
            if not return_internals:
                hidden = block(hidden)
                continue
            hidden, head_out, attn_probs, mlp_out = block(hidden, return_internals=True)
            internals['heads'].append(head_out)
            internals['attns'].append(attn_probs)
            internals['mlps'].append(mlp_out)

        logits = self.head(self.ln_f(hidden))
        if return_internals:
            return logits, internals
        return logits

# ============================================================================
# 3. DATA LOADING
# ============================================================================

def get_batch(data_path, batch_size=32, seq_len=128):
    """Sample a random batch of next-token training pairs from a token file.

    Args:
        data_path: Path to a flat binary file read as ``uint16`` token ids.
        batch_size: Number of sequences to draw.
        seq_len: Length of each sequence.

    Returns:
        ``(x, y)`` int64 tensors of shape ``(batch_size, seq_len)`` on
        ``DEVICE``, where ``y`` is ``x`` shifted one token to the right in
        the file.
    """
    tokens = np.memmap(data_path, dtype=np.uint16, mode='r')
    # Start offsets are chosen so that the shifted target window still fits.
    starts = torch.randint(len(tokens) - seq_len, (batch_size,)).tolist()

    def window(begin):
        return torch.from_numpy(tokens[begin:begin + seq_len].astype(np.int64))

    inputs = torch.stack([window(begin) for begin in starts])
    targets = torch.stack([window(begin + 1) for begin in starts])
    return inputs.to(DEVICE), targets.to(DEVICE)

# ============================================================================
# 4. OMNIBUS ANALYZER
# ============================================================================

class OmnibusForensics:
    """Geometric and representational comparison of two ``NewGPT`` checkpoints.

    Model A is labelled ``'Adaptive'`` and model B ``'Control'`` in every
    result table. Each analysis stores its output in ``self.results`` under
    ``'cka'``, ``'geometry'``, ``'synergy'`` or ``'landscape'``.
    """

    def __init__(self, path_a, path_b):
        """Build both models from a fixed config and load their state dicts."""
        self.config = {'vocab_size': 50257, 'd_model': 512, 'n_layers': 12, 'n_heads': 8, 'd_head': 64, 'max_seq_len': 512, 'dropout': 0.0}
        self.model_a = NewGPT(self.config).to(DEVICE).eval()
        self.model_b = NewGPT(self.config).to(DEVICE).eval()

        # NOTE(security): torch.load unpickles the file; only point this at
        # checkpoints you produced or otherwise trust.
        print(f"üìÇ Loading Adaptive: {path_a}")
        self.model_a.load_state_dict(torch.load(path_a, map_location=DEVICE))
        print(f"üìÇ Loading Control:  {path_b}")
        self.model_b.load_state_dict(torch.load(path_b, map_location=DEVICE))
        self.results = {}

    @staticmethod
    def _linear_hsic(K, L):
        """HSIC estimate between two ``n x n`` Gram matrices after double-centering."""
        n = K.shape[0]
        # Build the centering matrix on the inputs' own device/dtype so the
        # helper does not depend on a module-level device constant.
        H = torch.eye(n, device=K.device, dtype=K.dtype) - torch.ones((n, n), device=K.device, dtype=K.dtype) / n
        K_centered = H @ K @ H
        L_centered = H @ L @ H
        return (K_centered * L_centered).sum() / ((n - 1) ** 2)

    @torch.no_grad()
    def compute_cka(self, X):
        """Centered Kernel Alignment for representational similarity.

        Compares, layer by layer, the sequence-averaged MLP outputs of the two
        models on the same token batch ``X``.

        Returns:
            List with one CKA value per layer; also stored in
            ``self.results['cka']``.
        """
        _, int_a = self.model_a(X, return_internals=True)
        _, int_b = self.model_b(X, return_internals=True)

        cka_layers = []
        for i in range(self.config['n_layers']):
            feat_a = int_a['mlps'][i].mean(dim=1)  # Average over sequence
            feat_b = int_b['mlps'][i].mean(dim=1)

            # Linear-kernel Gram matrices over the batch.
            K = feat_a @ feat_a.T
            L = feat_b @ feat_b.T

            hsic_kl = self._linear_hsic(K, L)
            hsic_kk = self._linear_hsic(K, K)
            hsic_ll = self._linear_hsic(L, L)
            cka_val = hsic_kl / (torch.sqrt(hsic_kk) * torch.sqrt(hsic_ll) + 1e-9)
            cka_layers.append(cka_val.item())

        self.results['cka'] = cka_layers
        return cka_layers

    @torch.no_grad()
    def geometric_drill(self):
        """Effective Rank and Spectral Analysis.

        For every 2-D parameter whose name contains ``'weight'``, records the
        effective rank (exp of the entropy of the normalised singular values)
        and the ratio of largest to smallest singular value, for both models.

        Returns:
            DataFrame with columns ``param``, ``model``, ``eff_rank``,
            ``cond``; also stored in ``self.results['geometry']``.
        """
        # Build the lookup once instead of once per parameter.
        params_b = dict(self.model_b.named_parameters())

        metrics = []
        for name, p_a in self.model_a.named_parameters():
            if 'weight' not in name or len(p_a.shape) != 2:
                continue
            for p, m_name in [(p_a, 'Adaptive'), (params_b[name], 'Control')]:
                s = torch.linalg.svdvals(p.float())
                s_norm = s / (s.sum() + 1e-9)
                eff_rank = torch.exp(-(s_norm * torch.log(s_norm + 1e-9)).sum()).item()
                metrics.append({
                    'param': name, 'model': m_name,
                    'eff_rank': eff_rank,
                    'cond': (s[0] / (s[-1] + 1e-9)).item()
                })
        self.results['geometry'] = pd.DataFrame(metrics)
        return self.results['geometry']

    @torch.no_grad()
    def analyze_head_synergy(self, X):
        """Head Redundancy (Sigma_a).

        ``sigma_a`` is the mean absolute cosine similarity between the
        flattened outputs of every pair of distinct heads in a layer.

        Returns:
            DataFrame with columns ``layer``, ``model``, ``sigma_a``; also
            stored in ``self.results['synergy']``.
        """
        _, int_a = self.model_a(X, return_internals=True)
        _, int_b = self.model_b(X, return_internals=True)

        n_heads = self.config['n_heads']
        synergy_data = []
        for layer in range(self.config['n_layers']):
            for model_name, internals in [('Adaptive', int_a), ('Control', int_b)]:
                h_out = internals['heads'][layer]
                flat = h_out.transpose(0, 1).reshape(n_heads, -1)
                norm_flat = F.normalize(flat, p=2, dim=1)
                sim = torch.mm(norm_flat, norm_flat.T)
                # Mask out the diagonal (each head compared with itself).
                mask = ~torch.eye(n_heads, device=sim.device).bool()
                avg_sim = sim[mask].abs().mean().item()
                synergy_data.append({'layer': layer, 'model': model_name, 'sigma_a': avg_sim})

        self.results['synergy'] = pd.DataFrame(synergy_data)
        return self.results['synergy']

    def run_interpolation(self, data_path, steps=11):
        """High-resolution Loss Landscape.

        Evaluates ``alpha * A + (1 - alpha) * B`` for ``steps`` evenly spaced
        ``alpha`` in ``[0, 1]``. One batch is drawn up front and reused for
        every ``alpha`` so the curve reflects the interpolation rather than
        batch-to-batch noise.

        Returns:
            ``{'alpha': alphas, 'loss': losses}``; also stored in
            ``self.results['landscape']``.
        """
        alphas = np.linspace(0, 1, steps)

        # Snapshot both state dicts once; state_dict() rebuilds its mapping on every call.
        state_a = self.model_a.state_dict()
        state_b = self.model_b.state_dict()

        # One scratch model is enough: load_state_dict overwrites every tensor each step.
        interp = NewGPT(self.config).to(DEVICE).eval()

        x, y = get_batch(data_path, batch_size=16)  # Smaller batch for profiling speed

        path_losses = []
        with torch.no_grad():
            for alpha in tqdm(alphas, desc="Profiling Landscape"):
                weight = float(alpha)
                interp.load_state_dict({
                    k: weight * state_a[k] + (1 - weight) * state_b[k]
                    for k in state_a
                })
                logits = interp(x)
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))
                path_losses.append(loss.item())

        self.results['landscape'] = {'alpha': alphas, 'loss': path_losses}
        return self.results['landscape']

# ============================================================================
# 5. VISUALIZATION ENGINE
# ============================================================================

def generate_report(omni, out_dir):
    """Render the three forensic figures from ``omni.results`` into ``out_dir``.

    Reads ``omni.results['geometry']`` (DataFrame with ``param``, ``eff_rank``
    and ``model`` columns), ``['synergy']`` (``layer``, ``sigma_a``,
    ``model``), ``['cka']`` (one value per layer) and ``['landscape']``
    (``alpha``/``loss``), and writes ``dimensional_drill.png``,
    ``redundancy_cka.png`` and ``landscape_barrier.png``.

    Args:
        omni: Analyzer object whose ``results`` dict has been populated.
        out_dir: Directory the PNG files are written to.
    """
    sns.set_theme(style="whitegrid")

    # FIG 1: Dimensional Drilling
    geo = omni.results['geometry']
    o_proj = geo[geo['param'].str.contains('o_proj')].copy()
    # The layer index is the second dot-separated field of the parameter name.
    o_proj['layer'] = o_proj['param'].apply(lambda x: int(x.split('.')[1]))

    plt.figure(figsize=(12, 5))
    sns.lineplot(data=o_proj, x='layer', y='eff_rank', hue='model', marker='s')
    plt.title("Information Bandwidth (Effective Rank: o_proj)")
    plt.savefig(f"{out_dir}/dimensional_drill.png")

    # FIG 2: Redundancy & CKA
    syn = omni.results['synergy']
    cka = omni.results['cka']
    fig, ax1 = plt.subplots(figsize=(12, 5))
    sns.lineplot(data=syn, x='layer', y='sigma_a', hue='model', ax=ax1, palette=['tab:blue', 'tab:orange'])
    ax2 = ax1.twinx()
    # Size the x-axis from the data: a hard-coded range(12) raises a shape
    # mismatch for any model that does not have exactly 12 layers.
    ax2.plot(range(len(cka)), cka, color='tab:green', linestyle='--', label='CKA Similarity', marker='x')
    ax1.set_title("Synergy: œÉ_a (Redundancy) vs Representational CKA")
    ax1.legend(loc='upper left'); ax2.legend(loc='upper right')
    plt.savefig(f"{out_dir}/redundancy_cka.png")

    # FIG 3: Landscape
    ls = omni.results['landscape']
    plt.figure(figsize=(10, 5))
    plt.plot(ls['alpha'], ls['loss'], 'r-o')
    plt.title("Loss Landscape Barrier")
    plt.savefig(f"{out_dir}/landscape_barrier.png")

    print(f"‚úÖ OMNIBUS REPORT SAVED TO {out_dir}")

# ============================================================================
# 6. RUN
# ============================================================================

if __name__ == "__main__":
    # Build the analyzer from the two checkpoint paths.
    analyzer = OmnibusForensics(PATH_ADAPTIVE, PATH_CONTROL)

    print("üíé Running Geometric Analysis...")
    analyzer.geometric_drill()

    print("üíé Running CKA & Head Redundancy...")
    # One validation batch is shared by the CKA and head-synergy passes;
    # the targets returned by get_batch are not needed here.
    x_sample, _ = get_batch(PATH_VAL_DATA, batch_size=32)
    analyzer.compute_cka(x_sample)
    analyzer.analyze_head_synergy(x_sample)

    print("üíé Running Landscape Profiling...")
    analyzer.run_interpolation(PATH_VAL_DATA)

    # Plot everything accumulated in analyzer.results.
    generate_report(analyzer, RESULTS_DIR)

In [None]:
# @title [Forensics] Gradient Flow & Signal Propagation Analysis
# Standalone execution - Optimized for L4 GPU

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import math
from google.colab import drive

# --- PATH CONFIGURATION ---
# Everything lives under one Google Drive project folder.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# State-dict checkpoints of the two models being compared.
PATH_ADAPTIVE = os.path.join(PROJECT_ROOT, "data/models/janus_adaptive_v1/janus_adaptive_final.pt")
PATH_CONTROL  = os.path.join(PROJECT_ROOT, "data/models/janus_control_v1/janus_control_final.pt")
# Validation token file (read as uint16 by run_gradient_forensics).
PATH_VAL_DATA = os.path.join(PROJECT_ROOT, "data/wikitext/val.bin")
# Output directory for the gradient-flow figure.
RESULTS_DIR   = os.path.join(PROJECT_ROOT, "gradient_flow_results")
# Prefer the GPU when one is visible, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the results directory up front so later savefig calls cannot fail on it.
os.makedirs(RESULTS_DIR, exist_ok=True)

# ============================================================================
# 1. ARCHITECTURE (Required for Standalone Loading)
# ============================================================================
# (Standard Janus Architecture: RMSNorm, SwiGLU, NewAttention, NewBlock, NewGPT)
# [Note: Re-using the architecture classes from your previous Omnibus script]
# ============================================================================

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to one.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Scale ``x`` by ``1 / sqrt(mean(x**2) + eps)`` and apply the gain."""
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(silu(w1(x)) * w2(x))``, all bias-free."""

    def __init__(self, d_model):
        super().__init__()
        # Hidden width is 8/3 of the model width, truncated to an integer.
        hidden_dim = int(d_model * 8 / 3)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)  # gate branch
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)  # value branch
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)  # projection back to d_model

    def forward(self, x):
        """Apply the gated MLP to ``x`` (last axis of size ``d_model``)."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    ``config`` is a dict providing ``n_heads``, ``d_head``, ``d_model``,
    ``dropout`` and ``max_seq_len``.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config['d_model']
        self.n_heads = config['n_heads']
        self.d_head = config['d_head']
        self.scale = 1.0 / math.sqrt(self.d_head)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(config['dropout'])
        # Complex rotation table, one row per position up to max_seq_len.
        rope_table = self.precompute_freqs_cis(self.d_head, config['max_seq_len'])
        self.register_buffer("freqs_cis", rope_table)

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return an ``(end, dim // 2)`` complex tensor of unit-modulus rotations."""
        exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
        inv_freq = 1.0 / (theta ** exponents)
        positions = torch.arange(end)
        angles = torch.outer(positions, inv_freq).float()
        return torch.polar(torch.ones_like(angles), angles)

    def apply_rope(self, x, freqs_cis):
        """Rotate consecutive feature pairs of ``x`` (B, S, H, d_head) by position."""
        seq_len = x.shape[1]
        pairs = x.float().reshape(*x.shape[:-1], -1, 2)
        x_c = torch.view_as_complex(pairs)
        rotation = freqs_cis[:seq_len].view(1, seq_len, 1, -1).to(x.device)
        rotated = torch.view_as_real(x_c * rotation).flatten(3)
        return rotated.type_as(x)

    def forward(self, x, return_internals=False):
        """Attend over ``x`` (B, S, D).

        Returns the projected output, or ``(output, head_out, attn_probs)``
        when ``return_internals`` is true.
        """
        B, S, D = x.shape
        per_head = (B, S, self.n_heads, self.d_head)

        q = self.apply_rope(self.q_proj(x).view(*per_head), self.freqs_cis).transpose(1, 2)
        k = self.apply_rope(self.k_proj(x).view(*per_head), self.freqs_cis).transpose(1, 2)
        v = self.v_proj(x).view(*per_head).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        # Lower-triangular mask: a position may only look at itself and the past.
        causal = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        attn_probs = F.softmax(scores.masked_fill(causal == 0, float('-inf')), dim=-1)

        head_out = attn_probs @ v
        merged = head_out.transpose(1, 2).reshape(B, S, D)
        out = self.o_proj(merged)

        if not return_internals:
            return out
        return out, head_out, attn_probs

class NewBlock(nn.Module):
    """Pre-norm transformer block: attention residual, then MLP residual.

    ``config`` is a dict; ``d_model`` sizes the norms and the MLP, and the
    whole dict is handed to ``NewAttention``.
    """

    def __init__(self, config):
        super().__init__()
        self.ln1, self.attn = RMSNorm(config['d_model']), NewAttention(config)
        self.ln2, self.mlp = RMSNorm(config['d_model']), SwiGLU(config['d_model'])

    def forward(self, x, return_internals=False):
        """Run the block on ``x``.

        Both modes now compute the same sequential update
        ``x = x + attn(ln1(x)); x = x + mlp(ln2(x))``. Previously the fast
        path fed the MLP the *pre-attention* ``x`` (a parallel-block
        formulation), so ``block(x)`` and
        ``block(x, return_internals=True)[0]`` produced different tensors.

        Returns:
            The updated hidden state, or
            ``(hidden, head_out, attn_probs, mlp_out)`` when
            ``return_internals`` is true.
        """
        if return_internals:
            a, h_out, a_probs = self.attn(self.ln1(x), return_internals=True)
        else:
            a = self.attn(self.ln1(x))

        x = x + a
        # The MLP must see the post-attention residual stream.
        m_out = self.mlp(self.ln2(x))
        x = x + m_out

        if return_internals:
            return x, h_out, a_probs, m_out
        return x

class NewGPT(nn.Module):
    """Decoder-only language model with tied input/output embeddings.

    ``config`` is a dict providing ``vocab_size``, ``d_model`` and
    ``n_layers`` (plus whatever ``NewBlock`` reads).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        d_model = config['d_model']
        vocab = config['vocab_size']
        self.token_emb = nn.Embedding(vocab, d_model)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config['n_layers'])])
        self.ln_f = RMSNorm(d_model)
        self.head = nn.Linear(d_model, vocab, bias=False)
        # Weight tying: the embedding reuses the output projection's matrix.
        self.token_emb.weight = self.head.weight

    def forward(self, idx, return_internals=False):
        """Map token ids to logits.

        With ``return_internals`` the per-layer head outputs, attention
        probabilities and MLP outputs are collected and returned as
        ``(logits, {'heads': [...], 'attns': [...], 'mlps': [...]})``.
        """
        x = self.token_emb(idx)
        internals = {'heads': [], 'attns': [], 'mlps': []}
        for block in self.blocks:
            if not return_internals:
                x = block(x)
                continue
            x, h, a, m = block(x, return_internals=True)
            internals['heads'].append(h)
            internals['attns'].append(a)
            internals['mlps'].append(m)

        logits = self.head(self.ln_f(x))
        if return_internals:
            return logits, internals
        return logits
# ============================================================================
# 2. GRADIENT FLOW ANALYZER
# ============================================================================

def get_gradient_stats(model, x, y):
    """Run one forward/backward pass and report a gradient norm per block.

    The norm of the gradient on each block's ``attn.o_proj.weight`` is used
    as the per-layer signal. Blocks whose projection has no gradient are
    left out of the result.

    Returns:
        A list of ``{'layer': index, 'grad_norm': float}`` dicts in block order.
    """
    model.zero_grad()
    logits = model(x)
    vocab = logits.size(-1)
    F.cross_entropy(logits.view(-1, vocab), y.view(-1)).backward()

    stats = []
    for layer_idx, block in enumerate(model.blocks):
        proj = block.attn.o_proj
        if not hasattr(proj, 'weight'):
            continue
        grad = proj.weight.grad
        if grad is None:
            continue
        stats.append({'layer': layer_idx, 'grad_norm': grad.norm().item()})

    return stats

def run_gradient_forensics(batch_size=32, seq_len=128):
    """Compare per-layer gradient norms of the Adaptive and Control models.

    Loads both checkpoints, samples one random batch from the validation
    token file, runs a single backward pass through each model, saves a
    log-scale line plot to ``RESULTS_DIR`` and prints a summary.

    Args:
        batch_size: Number of windows sampled from the validation data.
        seq_len: Length of each sampled window in tokens.

    Raises:
        ValueError: If ``seq_len`` exceeds the model's ``max_seq_len`` or the
            validation file holds too few tokens for one window plus target.
    """
    if not os.path.exists('/content/drive'): drive.mount('/content/drive')

    # 1. Load Models
    config = {'vocab_size': 50257, 'd_model': 512, 'n_layers': 12, 'n_heads': 8, 'd_head': 64, 'max_seq_len': 512, 'dropout': 0.0}
    if seq_len > config['max_seq_len']:
        raise ValueError(f"seq_len={seq_len} exceeds max_seq_len={config['max_seq_len']}")

    print("üìÇ Loading Models from Drive...")
    model_a = NewGPT(config).to(DEVICE)
    model_b = NewGPT(config).to(DEVICE)
    model_a.load_state_dict(torch.load(PATH_ADAPTIVE, map_location=DEVICE))
    model_b.load_state_dict(torch.load(PATH_CONTROL, map_location=DEVICE))

    # 2. Prepare Data
    print("üìñ Loading Validation Data...")
    data = np.memmap(PATH_VAL_DATA, dtype=np.uint16, mode='r')
    # Each window needs seq_len inputs plus one extra token for the shifted target.
    if len(data) <= seq_len:
        raise ValueError(f"{PATH_VAL_DATA} has {len(data)} tokens; need more than {seq_len}")
    ix = torch.randint(len(data) - seq_len, (batch_size,))
    x = torch.stack([torch.from_numpy((data[i:i+seq_len]).astype(np.int64)) for i in ix]).to(DEVICE)
    y = torch.stack([torch.from_numpy((data[i+1:i+seq_len+1]).astype(np.int64)) for i in ix]).to(DEVICE)

    # 3. Compute Gradients
    print("‚ö° Analyzing Gradient Propagation...")
    grads_a = get_gradient_stats(model_a, x, y)
    grads_b = get_gradient_stats(model_b, x, y)

    # 4. Process Results
    df_a = pd.DataFrame(grads_a); df_a['model'] = 'Adaptive'
    df_b = pd.DataFrame(grads_b); df_b['model'] = 'Control'
    # ignore_index avoids duplicate row labels, which some seaborn/pandas
    # combinations reject when plotting.
    df = pd.concat([df_a, df_b], ignore_index=True)

    # 5. Visualization
    plt.figure(figsize=(12, 6))
    sns.lineplot(data=df, x='layer', y='grad_norm', hue='model', marker='o', linewidth=2.5)
    plt.yscale('log') # Gradients often span orders of magnitude
    plt.title("Gradient Flow Profile: Signal Strength Across Layers", fontsize=14)
    plt.ylabel("Gradient Norm (Log Scale)")
    plt.xlabel("Layer Index (0=Early, 11=Late)")
    plt.grid(True, which="both", ls="-", alpha=0.2)
    plt.savefig(os.path.join(RESULTS_DIR, "gradient_flow_profile.png"))

    # 6. Summary Report
    avg_a = df_a['grad_norm'].mean()
    avg_b = df_b['grad_norm'].mean()
    # Report NaN instead of dividing by a zero control average.
    ratio = avg_a / avg_b if avg_b else float('nan')

    print(f"\nüìä GRADIENT FLOW SUMMARY")
    print("-" * 30)
    print(f"Adaptive Avg Grad Norm: {avg_a:.6f}")
    print(f"Control Avg Grad Norm:  {avg_b:.6f}")
    print(f"Propagation Efficiency: {ratio:.2f}x")
    print("-" * 30)
    print(f"‚úÖ Results saved to {RESULTS_DIR}")

# Run the gradient-flow comparison when this cell/script is executed directly.
if __name__ == "__main__":
    run_gradient_forensics()

In [None]:
# @title [Run] Adaptive Scheduler (WikiText-2 Anti-Overfit Edition)
# Settings: 1500 Steps | Dropout 0.25 | Log Every 10

import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import pandas as pd
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. MEMORY NUKE ---
# Collect Python garbage and release PyTorch's cached CUDA blocks before the
# new run allocates its model.
gc.collect()
torch.cuda.empty_cache()

# --- 2. SETUP ---
# Mount Google Drive only if it is not already mounted.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
# Directory holding the W2train.bin / W2val.bin token files read by W2Loader.
DATA_DIR = os.path.join(PROJECT_ROOT, "data/Wikitext_2")
# Saving to new W2 specific folder
SAVE_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_adaptive_w2")
os.makedirs(SAVE_DIR, exist_ok=True)

# Prefer the GPU when one is visible, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üß† JANUS ADAPTIVE (WikiText-2 Correction Run)")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 3. CONFIGURATION (The Regularization Nuke) ---
class AdaptiveConfig:
    """Hyper-parameters for the WikiText-2 adaptive run (model, training, controller, IO)."""

    def __init__(self):
        # --- Model shape ---
        self.vocab_size = 50257
        self.d_model, self.n_layers = 512, 12
        self.n_heads, self.d_head = 8, 64
        self.max_seq_len = 512

        # --- Regularisation against overfitting ---
        self.dropout = 0.25        # passed to nn.Dropout in the attention module
        self.weight_decay = 0.2    # passed to AdamW

        # --- Training budget ---
        self.max_steps = 1500      # total optimizer steps
        self.batch_size = 16       # sequences per micro-batch
        self.grad_accum = 4        # micro-batches per step (16 * 4 = 64 sequences)

        # --- Homeostatic controller ---
        self.target_sigma = 0.0035   # redundancy set-point
        self.k_p = 0.8               # proportional gain
        self.control_interval = 10   # steps between controller updates after warmup
        self.warmup_steps = 200      # length of the open-loop ramp

        # --- IO cadence ---
        self.ckpt_interval = 500     # steps between checkpoints
        self.lite_interval = 10      # steps between telemetry rows

# --- 4. THE HOMEOSTATIC CONTROLLER ---
class Homeostat:
    """Proportional controller for the diversity pressure ``current_lambda``.

    During the first ``warmup_steps`` steps the pressure ramps linearly
    towards 0.05; afterwards it is nudged by ``k_p * (sigma - target_sigma)``
    and clamped to ``[0.0, 0.25]``.
    """

    def __init__(self, config):
        self.target = config.target_sigma
        self.kp = config.k_p
        self.warmup = config.warmup_steps
        self.current_lambda = 0.0

    def update(self, step, current_sigma):
        """Advance the controller one tick and return the new pressure.

        ``current_sigma`` may be ``None`` when no measurement is available
        yet; outside the warmup phase the pressure is then left unchanged.
        """
        if step < self.warmup:
            # Open-loop ramp: fraction of warmup completed, scaled to 0.05.
            progress = step / max(1, self.warmup)
            self.current_lambda = 0.05 * progress
        elif current_sigma is not None:
            # Closed loop: too much redundancy (positive error) raises pressure.
            error = current_sigma - self.target
            proposed = self.current_lambda + self.kp * error
            self.current_lambda = max(0.0, min(0.25, proposed))
        return self.current_lambda

# --- 5. ARCHITECTURE ---
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to one.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Scale ``x`` by ``1 / sqrt(mean(x**2) + eps)`` and apply the gain."""
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(silu(w1(x)) * w2(x))``, all bias-free."""

    def __init__(self, d_model):
        super().__init__()
        # Hidden width is 8/3 of the model width, truncated to an integer.
        hidden_dim = int(d_model * 8 / 3)
        self.w1 = nn.Linear(d_model, hidden_dim, bias=False)  # gate branch
        self.w2 = nn.Linear(d_model, hidden_dim, bias=False)  # value branch
        self.w3 = nn.Linear(hidden_dim, d_model, bias=False)  # projection back to d_model

    def forward(self, x):
        """Apply the gated MLP to ``x`` (last axis of size ``d_model``)."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Causal multi-head attention with RoPE and an optional head-diversity penalty.

    ``config`` is an object exposing ``n_heads``, ``d_head``, ``d_model``,
    ``dropout`` and ``max_seq_len``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)
        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        self.register_buffer("freqs_cis", self.precompute_freqs_cis(config.d_head, config.max_seq_len))

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return an ``(end, dim // 2)`` complex tensor of unit-modulus rotations."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    def apply_rope(self, x, freqs_cis):
        """Rotate consecutive feature pairs of ``x`` (B, S, H, d_head) by position."""
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs = freqs_cis[:x.shape[1]].view(1, x.shape[1], 1, -1)
        x_out = torch.view_as_real(x_c * freqs).flatten(3)
        return x_out.type_as(x)

    def forward(self, x, lambdas, return_metrics=False):
        """Attend over ``x`` (B, S, D).

        Args:
            x: Hidden states.
            lambdas: ``(l_coh, l_div)`` pair; only ``l_div`` is used here, as
                the weight of the diversity penalty.
            return_metrics: When true, also report per-head redundancy.

        Returns:
            ``(output, steer_loss, metrics)``. ``steer_loss`` is the float
            ``0.0`` unless the penalty is active, in which case it is
            ``l_div * ||G - I||_F`` for the head Gram matrix ``G``.
            ``metrics`` is ``{}`` or ``{'sigma_a': tensor of shape (n_heads,)}``.
        """
        B, S, D = x.shape
        q = self.q_proj(x).view(B, S, self.n_heads, self.d_head)
        k = self.k_proj(x).view(B, S, self.n_heads, self.d_head)
        v = self.v_proj(x).view(B, S, self.n_heads, self.d_head)

        q = self.apply_rope(q, self.freqs_cis)
        k = self.apply_rope(k, self.freqs_cis)
        q = q.transpose(1, 2); k = k.transpose(1, 2); v = v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        mask = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        attn = attn.masked_fill(mask == 0, float('-inf'))
        attn_probs = F.softmax(attn, dim=-1)
        attn_probs = self.dropout(attn_probs)
        head_out = attn_probs @ v

        steer_loss = 0.0
        _, l_div = lambdas
        metrics = {}

        if (l_div > 0.0 and self.training) or return_metrics:
            # Cosine-similarity (Gram) matrix between heads: each head's whole
            # output over batch, positions and features is one vector.
            flat = head_out.transpose(0, 1).reshape(self.n_heads, -1)
            norm = F.normalize(flat, p=2, dim=1)
            gram = torch.mm(norm, norm.t())

            if l_div > 0.0:
                identity = torch.eye(self.n_heads, device=x.device)
                steer_loss += torch.norm(gram - identity, p='fro') * l_div

            if return_metrics:
                # Reuse the Gram matrix rather than recomputing it; no_grad
                # keeps the telemetry out of the autograd graph.
                with torch.no_grad():
                    mask_diag = ~torch.eye(self.n_heads, dtype=torch.bool, device=x.device)
                    # max(1, ...) avoids 0/0 = NaN for a single-head model.
                    denom = max(1, self.n_heads - 1)
                    metrics['sigma_a'] = (gram.abs() * mask_diag.float()).sum(dim=1) / denom

        out = head_out.transpose(1, 2).reshape(B, S, D)
        return self.o_proj(out), steer_loss, metrics

class NewBlock(nn.Module):
    """Pre-norm transformer block that also forwards the attention side outputs."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = RMSNorm(width)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(width)
        self.mlp = SwiGLU(width)

    def forward(self, x, lambdas, return_metrics=False):
        """Return ``(hidden, steer_loss, metrics)`` after attention and MLP residuals."""
        attn_out, steer, metrics = self.attn(self.ln1(x), lambdas, return_metrics)
        after_attn = x + attn_out
        after_mlp = after_attn + self.mlp(self.ln2(after_attn))
        return after_mlp, steer, metrics

class NewGPT(nn.Module):
    """Decoder-only LM that returns the loss, the summed steering loss and metrics.

    ``config`` is an object exposing ``vocab_size``, ``d_model`` and
    ``n_layers`` (plus whatever ``NewBlock`` reads).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the embedding reuses the output projection's matrix.
        self.token_emb.weight = self.head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialiser passed to ``self.apply``.

        Linear and Embedding weights are drawn from N(0, 0.02); every
        parameter reachable from ``module`` whose name contains
        ``o_proj.weight`` or ``w3.weight`` is redrawn with the standard
        deviation divided by ``sqrt(2 * n_layers)``.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        for name, p in module.named_parameters():
            is_residual_proj = "o_proj.weight" in name or "w3.weight" in name
            if is_residual_proj:
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layers))

    def forward(self, idx, lambdas_list, targets=None, return_metrics=False):
        """Run the model on token ids ``idx`` of shape (B, S).

        Args:
            idx: Token ids.
            lambdas_list: One ``(l_coh, l_div)`` pair per block.
            targets: Optional target ids; when given the cross-entropy is computed.
            return_metrics: Collect each block's metrics dict.

        Returns:
            ``(loss, total_steer, all_metrics)`` where ``loss`` is ``None``
            without targets and ``all_metrics`` is empty unless
            ``return_metrics`` is true.
        """
        B, S = idx.shape
        hidden = self.token_emb(idx)
        total_steer = 0.0
        all_metrics = []
        for layer_idx, block in enumerate(self.blocks):
            hidden, steer, layer_metrics = block(hidden, lambdas_list[layer_idx], return_metrics)
            total_steer += steer
            if return_metrics:
                all_metrics.append(layer_metrics)

        logits = self.head(self.ln_f(hidden))
        if targets is None:
            loss = None
        else:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return loss, total_steer, all_metrics

# --- 6. DATA & LOGGING ---
class W2Loader:
    """Random-window sampler over the WikiText-2 train/val token files.

    Both files are read fully into memory as ``uint16`` arrays from
    ``DATA_DIR`` (``W2train.bin`` and ``W2val.bin``).
    """

    def __init__(self, config):
        self.config = config
        train_path = os.path.join(DATA_DIR, "W2train.bin")
        val_path = os.path.join(DATA_DIR, "W2val.bin")

        # Fail early, train file first, if either token file is absent.
        for required in (train_path, val_path):
            if not os.path.exists(required):
                raise FileNotFoundError(f"Missing {required}")

        self.train_data = np.fromfile(train_path, dtype=np.uint16)
        self.val_data = np.fromfile(val_path, dtype=np.uint16)
        print(f"üìñ Loaded W2 Train: {len(self.train_data):,} | Val: {len(self.val_data):,}")

    def _sample(self, tokens, batch_size):
        """Draw ``batch_size`` random windows from ``tokens`` with next-token targets."""
        window = self.config.max_seq_len
        starts = torch.randint(len(tokens) - window, (batch_size,))

        def rows(offset):
            # offset 0 gives the inputs, offset 1 the targets shifted by one token.
            return torch.stack([
                torch.from_numpy(tokens[i + offset:i + offset + window].astype(np.int64))
                for i in starts
            ])

        inputs, targets = rows(0), rows(1)
        return inputs.to(DEVICE), targets.to(DEVICE)

    def get_batch(self, batch_size):
        """Return an ``(x, y)`` int64 batch on ``DEVICE`` sampled from the training tokens."""
        return self._sample(self.train_data, batch_size)

    def get_val_batch(self, batch_size):
        """Return an ``(x, y)`` int64 batch on ``DEVICE`` sampled from the validation tokens."""
        return self._sample(self.val_data, batch_size)

class BlackBox:
    """Buffered telemetry recorder that appends rows to a parquet file.

    Rows are accumulated in memory by ``log`` and written to
    ``<save_dir>/telemetry_adaptive.parquet`` by ``flush``.
    """

    def __init__(self, save_dir):
        self.buffer = []
        self.save_dir = save_dir
        self.start_step = 0

    def set_start_step(self, step):
        """Ignore any later ``log`` call whose step is below ``step``."""
        self.start_step = step

    def log(self, step, loss, val_loss, pressure, metrics):
        """Queue one telemetry row.

        Args:
            step: Training step.
            loss: Training loss for the step.
            val_loss: Validation loss; perplexity is ``exp(val_loss)`` and is
                recorded as 0.0 when ``val_loss >= 20``.
            pressure: Current controller pressure.
            metrics: Iterable of per-layer dicts; the ``sigma_a`` tensors
                present are averaged into ``sigma_a_avg`` (0.0 when none).
        """
        if step < self.start_step: return
        avg_sigma = 0.0
        if metrics:
            sigmas = [m['sigma_a'].mean().item() for m in metrics if 'sigma_a' in m]
            if sigmas: avg_sigma = sum(sigmas) / len(sigmas)
        row = {
            "step": step, "loss": loss, "val_loss": val_loss,
            "perplexity": math.exp(val_loss) if val_loss < 20 else 0.0,
            "pressure": pressure, "sigma_a_avg": avg_sigma
        }
        self.buffer.append(row)

    def flush(self):
        """Write buffered rows to disk, appended to any existing telemetry file.

        Does nothing when the buffer is empty. If the existing file cannot be
        read, a warning is printed and the file is replaced by the new rows
        only, so the run keeps going instead of crashing.
        """
        if not self.buffer: return
        df = pd.DataFrame(self.buffer)
        fpath = os.path.join(self.save_dir, "telemetry_adaptive.parquet")
        if os.path.exists(fpath):
            try:
                existing = pd.read_parquet(fpath)
                df = pd.concat([existing, df], ignore_index=True)
            except Exception as exc:
                # Best effort by design, but no longer silent: previously a bare
                # `except: pass` hid the loss of earlier telemetry and also
                # swallowed KeyboardInterrupt.
                print(f"WARNING: could not read existing telemetry at {fpath} ({exc}); overwriting with new rows only.")
        df.to_parquet(fpath)
        self.buffer = []

# --- 7. EXECUTION ---
def run_adaptive():
    """Train a fresh NewGPT on WikiText-2 under the homeostatic diversity controller.

    Each step: update the controller, build the per-layer lambda schedule,
    accumulate gradients over ``cfg.grad_accum`` micro-batches, clip and
    step the optimizer. Every ``cfg.lite_interval`` steps validation loss is
    measured and a telemetry row is flushed; every ``cfg.ckpt_interval``
    steps a checkpoint is saved. The final state dict is written to
    ``SAVE_DIR/adaptive_w2_final.pt``.
    """
    cfg = AdaptiveConfig()
    model = NewGPT(cfg).to(DEVICE)
    # Using stronger weight decay per config
    optimizer = optim.AdamW(model.parameters(), lr=6.0e-4, weight_decay=cfg.weight_decay)

    controller = Homeostat(cfg)
    loader = W2Loader(cfg)
    recorder = BlackBox(SAVE_DIR)

    # Resume Logic
    # Existing checkpoints are only detected and reported; start_step stays 0
    # and nothing is loaded.
    start_step = 0
    ckpts = [f for f in os.listdir(SAVE_DIR) if f.startswith("ckpt_step_") and f.endswith(".pt")]
    if ckpts:
        # Sort by the step number embedded in "ckpt_step_<N>.pt".
        ckpts.sort(key=lambda x: int(x.split('_')[2].split('.')[0]))
        latest = ckpts[-1]
        print(f"üîÑ Found checkpoint {latest}, but forcing FRESH START due to Regularization Nuke.")
        # NOTE: Intentionally NOT loading checkpoint to clear the memorized weights.

    print(f"\nüß† STARTING W2 ADAPTIVE RUN (CORRECTION)")
    print(f"   Steps: {cfg.max_steps}")
    print(f"   Target Sigma: {cfg.target_sigma}")
    print(f"   Dropout: {cfg.dropout}")

    pbar = tqdm(range(start_step, cfg.max_steps))
    # Latest measured redundancy; None until the first metrics pass.
    current_sigma_avg = None

    for step in pbar:
        # A. Control
        # Update on every warmup step, then only every control_interval steps.
        if step % cfg.control_interval == 0 or step < cfg.warmup_steps:
             controller.update(step, current_sigma_avg)

        p = controller.current_lambda
        lambdas_list = []
        # Cubic depth schedule: layer i gets p * ((i + 1) / n_layers) ** 3 as
        # its diversity weight; the coherence weight is always 0.0.
        for i in range(cfg.n_layers):
            ratio = (i + 1) / cfg.n_layers
            s_mult = ratio ** 3
            lambdas_list.append((0.0, p * s_mult))

        # B. Train
        model.train()
        batch_loss = 0.0
        accum_sigmas = []
        optimizer.zero_grad()

        for _ in range(cfg.grad_accum):
            x, y = loader.get_batch(cfg.batch_size)

            # Metrics more frequent now (every 10 steps)
            return_metrics = (step % cfg.control_interval == 0) or (step % cfg.lite_interval == 0)

            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                loss, steer, metrics = model(x, lambdas_list, y, return_metrics=return_metrics)
                # Divide by grad_accum so the accumulated gradient is an average.
                total = (loss + steer) / cfg.grad_accum

            total.backward()
            # Only the language-model loss (without steer) is tracked for display.
            batch_loss += loss.item()

            if metrics:
                # Mean sigma_a across layers for this micro-batch.
                s = [m['sigma_a'].mean().item() for m in metrics]
                if s: accum_sigmas.append(sum(s)/len(s))

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Refresh the controller's measurement only when metrics were collected.
        if accum_sigmas:
            current_sigma_avg = sum(accum_sigmas) / len(accum_sigmas)

        # C. Telemetry (High Res: Every 10 steps)
        if step > 0 and (step % cfg.lite_interval == 0 or step == cfg.max_steps - 1):
            model.eval()
            with torch.no_grad():
                v_losses = []
                # Average validation loss over 5 random batches, with all
                # steering lambdas set to zero.
                for _ in range(5):
                    vx, vy = loader.get_val_batch(cfg.batch_size)
                    vl, _, _ = model(vx, [(0.0,0.0)]*cfg.n_layers, vy, return_metrics=False)
                    v_losses.append(vl.item())
                val_loss = np.mean(v_losses)

            # `metrics` here is the value left by the last training micro-batch.
            recorder.log(step, batch_loss/cfg.grad_accum, val_loss, controller.current_lambda, metrics)
            recorder.flush()

            pbar.set_description(f"L:{batch_loss/cfg.grad_accum:.2f}|V:{val_loss:.2f}|P:{p:.3f}|Sig:{current_sigma_avg:.4f}")
        else:
            pbar.set_description(f"L:{batch_loss/cfg.grad_accum:.2f}|P:{p:.3f}")

        # D. Checkpoint
        if step > 0 and step % cfg.ckpt_interval == 0:
            torch.save(model.state_dict(), os.path.join(SAVE_DIR, f"ckpt_step_{step}.pt"))

    # Finish
    torch.save(model.state_dict(), os.path.join(SAVE_DIR, "adaptive_w2_final.pt"))
    print(f"\nüèÜ Run Complete. Model saved to {SAVE_DIR}")

# Start the adaptive training run when this cell/script is executed directly.
if __name__ == "__main__":
    run_adaptive()

In [None]:
# @title [RUN] Project Janus: Platinum Master (L4 Turbo Edition)
# @markdown **System Status:** Platinum (Bugfix 3.0.7)
# @markdown **Version:** 3.0.7 (Enhanced Credits Filter + Optimized Thresholds)
# @markdown **Security:** HIGH | **Performance:** EXTREME

import os
import sys
import gc
import json
import time
import shutil
import logging
import signal
import uuid
import random
import warnings
import re
import unicodedata
import functools
from datetime import datetime
from typing import List, Dict, Any, Tuple
from pathlib import Path
from tqdm.auto import tqdm

# --- 0. CRITICAL FIXES & SETUP ---

# [FIX 1] Explicit Drive Mount
# Mount only when /content/drive is absent; a failed mount is reported but
# does not stop the cell.
if not os.path.exists('/content/drive'):
    print("üìÇ Mounting Google Drive...")
    from google.colab import drive
    try:
        drive.mount('/content/drive')
    except Exception as e:
        print(f"‚ö†Ô∏è Drive Mount Failed: {e}")

# [FIX 2] Pinned Dependencies
# Try the imports first; if any is missing, pip-install the pinned versions
# and import again.
try:
    import stanza
    import torch
    import numpy as np
    import datasets
    from transformers import AutoTokenizer
except ImportError:
    print("üì¶ Installing dependencies (Pinned)...")
    os.system('pip install stanza==1.8.2 transformers==4.44.2 datasets==2.21.0 huggingface_hub==0.24.6 -q')
    import stanza
    import torch
    import numpy as np
    import datasets
    from transformers import AutoTokenizer

# [FIX 3] Monkey-Patch torch.load for Stanza Compatibility
# NOTE(security): defaulting weights_only to False re-enables full pickle
# deserialisation for every torch.load call in this process, which can run
# arbitrary code. Only load checkpoints from sources you trust.
original_load = torch.load
@functools.wraps(original_load)
def safe_load(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only=False`` by default.

    A ``weights_only`` value passed by the caller is respected; every other
    argument is forwarded unchanged.
    """
    kwargs.setdefault('weights_only', False)
    return original_load(*args, **kwargs)
torch.load = safe_load
print("üîß Applied Stanza/PyTorch compatibility patch.")

# [FIX 4] Modern TF32 Enabler
# Use set_float32_matmul_precision when this PyTorch build has it; otherwise
# fall back to the backend allow_tf32 flags.
if torch.cuda.is_available():
    try:
        torch.set_float32_matmul_precision('high')
        print("üöÄ TensorFloat-32 (TF32) Enabled for L4.")
    except AttributeError:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

# Global Seeds
# Seed Python, NumPy and PyTorch (CPU and all CUDA devices) with the same value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# --- 1. Configuration (L4 Tuned) ---

class JanusConfig:
    """Settings for the Janus data pipeline: run mode, batching, heuristics and paths."""

    def __init__(self):
        # Run mode
        self.PILOT_MODE = False
        self.PILOT_LIMIT = 5000
        self.DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Batching
        self.STANZA_BATCH_SIZE = 150
        self.SMART_BATCH_BUFFER_SIZE = 5000

        self.CONTEXT_WINDOW = 1024
        self.EOS_TOKEN = "<endoftext>"

        # Thresholds used by the "simple sentence" test
        self.MAX_S_COUNT = 1
        self.MAX_SBAR_COUNT = 0
        self.MAX_CC_COUNT = 1
        self.MAX_WORD_LEN = 25

        # Filesystem layout, all rooted at one Drive folder
        root = Path("/content/drive/MyDrive/Project_Janus_Data")
        shard_root = root / "shards"
        self.BASE_DIR = root
        self.MANIFEST_FILE = root / "janus_manifest.json"
        self.LOG_FILE = root / "execution_log.txt"

        # Stanza model cache kept on Drive
        self.STANZA_DIR = root / "models/stanza_resources"

        self.SIMPLE_DIR = shard_root / "simple"
        self.COMPLEX_DIR = shard_root / "complex"
        self.TEMP_DIR = root / "temp_staging"

    def initialize_filesystem(self):
        """Create the shard, staging and Stanza directories if they are missing."""
        required = (self.SIMPLE_DIR, self.COMPLEX_DIR, self.TEMP_DIR, self.STANZA_DIR)
        for directory in required:
            directory.mkdir(parents=True, exist_ok=True)

# Module-wide configuration instance used by the helpers below.
CONFIG = JanusConfig()

# --- 2. Logging & Utils (Atomic) ---

def log_event(message, level="INFO"):
    """Print a timestamped message and append it to ``CONFIG.LOG_FILE``.

    The file write is best effort: any ordinary error while logging is
    ignored so logging can never abort the pipeline. ``KeyboardInterrupt``
    and ``SystemExit`` are no longer swallowed.

    Args:
        message: Text to log.
        level: Severity label placed before the message.
    """
    ts = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    msg = f"[{ts}] {level}: {message}"
    print(msg)
    try:
        with open(CONFIG.LOG_FILE, "a", encoding="utf-8") as f:
            f.write(msg + "\n")
            # Force the line to disk so it survives a runtime disconnect.
            f.flush()
            os.fsync(f.fileno())
    except Exception:
        # The message was already printed; a failed file write is not fatal.
        pass

class JanusUtils:
    """Static helpers for file writes, the progress manifest and memory cleanup."""

    @staticmethod
    def atomic_write(data: str, target_path: Path):
        """Write ``data`` to ``target_path`` through a temp file in ``CONFIG.TEMP_DIR``.

        The text is written and fsynced to a uniquely named temp file, which
        is then moved onto ``target_path``. On any failure, including an
        interrupt, the temp file is removed and the error is re-raised.
        """
        temp_path = CONFIG.TEMP_DIR / f"{uuid.uuid4()}.tmp"
        try:
            with open(temp_path, 'w', encoding='utf-8') as f:
                f.write(data)
                f.flush()
                os.fsync(f.fileno())
            shutil.move(str(temp_path), str(target_path))
        except BaseException:
            # BaseException so an interrupt does not leave a stray temp file;
            # bare `raise` keeps the original traceback.
            if temp_path.exists(): os.remove(temp_path)
            raise

    @staticmethod
    def load_manifest():
        """Return the saved progress manifest, or a default when none is usable.

        A missing, unreadable or malformed manifest yields the default state
        (``last_idx == -1`` and zero counters). An unreadable or malformed
        file is reported through ``log_event`` rather than ignored silently.
        """
        default = {"last_idx": -1, "total_simple": 0, "total_complex": 0, "total_lines": 0}
        if not CONFIG.MANIFEST_FILE.exists():
            return default
        try:
            with open(CONFIG.MANIFEST_FILE, 'r', encoding='utf-8') as f:
                return json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            log_event(f"Manifest unreadable ({exc}); falling back to defaults.", "WARN")
            return default

    @staticmethod
    def update_manifest(idx, s_cnt, c_cnt, lines):
        """Persist the progress counters plus an ISO timestamp via ``atomic_write``."""
        state = {"last_idx": idx, "total_simple": s_cnt, "total_complex": c_cnt, "total_lines": lines, "ts": datetime.now().isoformat()}
        JanusUtils.atomic_write(json.dumps(state, indent=2), CONFIG.MANIFEST_FILE)

    @staticmethod
    def sanitize():
        """Run the garbage collector and, when CUDA is available, empty its cache."""
        gc.collect()
        if torch.cuda.is_available(): torch.cuda.empty_cache()

# --- 3. Data Cleaning (Robust + Aggressive) ---

# Patterns and character sets are built once at import time. Non-ASCII
# characters are written as \u escapes so they cannot be corrupted by a file
# encoding round-trip (the previous literals were mojibake and never matched
# a real bullet, en dash or em dash).
_TIMECODE_RE = re.compile(r'\d\s*:\s*\d')
_CREDIT_RE = re.compile(r'^[\w\s\.]+\s+[\u2013-]\s+[\w\s\.]+$')
_BULLET_PREFIXES = ("*", "\u2022", "-")
_SPECIAL_CHARS = frozenset("\u2013\u2014:;()[]{}\"'/|\\")


def clean_line(line):
    """Normalise one raw text line and reject it if it looks like non-prose.

    The line is NFKC-normalised and stripped, then dropped (``None`` is
    returned) when any of these hold:

    * it is shorter than 5 characters;
    * it starts and ends with ``=`` (wiki header);
    * it starts with ``*``, a bullet (U+2022) or ``-`` (list item);
    * it contains a ``digit : digit`` pattern (timecode / tracklist);
    * it is entirely ``text <spaced dash> text`` (credit or definition line);
    * more than 15% of its characters are punctuation from a fixed set;
    * it is all upper case.

    Args:
        line: Raw input line.

    Returns:
        The cleaned line, or ``None`` if it was filtered out.
    """
    # 1. Unicode normalisation folds compatibility forms (ligatures, fancy
    #    fractions) into their standard equivalents.
    line = unicodedata.normalize('NFKC', line).strip()

    # 2. Length filter: single words and tiny fragments carry no syntax.
    if len(line) < 5:
        return None

    # 3. Header filter; the strip above already removed leading spaces.
    if line.startswith("=") and line.endswith("="):
        return None

    # 4. List / bullet markers.
    if line.startswith(_BULLET_PREFIXES):
        return None

    # 5. Timecode / tracklist filter, e.g. "10:30", "4 : 24".
    if _TIMECODE_RE.search(line):
        return None

    # 6. Credits / definition filter, e.g. "Name <dash> role".
    if _CREDIT_RE.match(line):
        return None

    # 7. Fragment detection: too dense in structural punctuation.
    special_char_count = sum(1 for c in line if c in _SPECIAL_CHARS)
    if (special_char_count / len(line)) > 0.15:
        return None

    # 8. All-caps filter; lines this long are already past the length check,
    #    so short acronyms never reach it.
    if line.isupper() and len(line) > 4:
        return None

    return line

# --- 4. Syntax Engine (L4 Accelerated & Fixed) ---

class SyntaxFilter:
    """Label sentences as 'simple' or 'complex' from Stanza constituency parses.

    A sentence is 'simple' when the counts of "(S ", "(SBAR " and "(CC "
    nodes in its parse tree and its word count all stay within the limits
    configured on CONFIG; everything else is 'complex'.
    """

    def __init__(self):
        log_event("Initializing Stanza (L4 Mode - FP32 GPU)...")
        stanza.download('en', model_dir=str(CONFIG.STANZA_DIR), processors='tokenize,pos,constituency', logging_level='WARN')
        self.nlp = stanza.Pipeline('en', dir=str(CONFIG.STANZA_DIR), processors='tokenize,pos,constituency',
                                   use_gpu=(CONFIG.DEVICE=='cuda'), pos_batch_size=5000, logging_level='ERROR')

    def _is_simple(self, sent):
        """Return True when the parsed sentence stays within every CONFIG limit.

        Any failure while inspecting the sentence (missing parse, unexpected
        object shape) is treated as "not simple" rather than raised.
        """
        try:
            tree = sent.constituency
            if tree is None:
                return False
            s_str = str(tree)
            # All three node counts are upper bounds. SBAR previously used
            # '==', which only agreed with the others when the limit was 0.
            return (s_str.count("(S ") <= CONFIG.MAX_S_COUNT and
                    s_str.count("(SBAR ") <= CONFIG.MAX_SBAR_COUNT and
                    s_str.count("(CC ") <= CONFIG.MAX_CC_COUNT and
                    len(sent.words) < CONFIG.MAX_WORD_LEN)
        except Exception:
            return False

    def _classify(self, doc) -> List[Dict]:
        """Return one {'text', 'category'} row per sentence of a parsed document."""
        return [
            {'text': sent.text, 'category': 'simple' if self._is_simple(sent) else 'complex'}
            for sent in doc.sentences
        ]

    def process_batch(self, lines: List[str]) -> List[Dict]:
        """Parse `lines` and return a flat list of classified sentences.

        Lines are sorted by length (stable) and parsed in chunks of
        CONFIG.STANZA_BATCH_SIZE. If a chunk fails as a whole, it is retried
        one line at a time; lines that still fail are logged and skipped.

        Returns:
            A list of dicts with keys 'text' and 'category'
            ('simple' or 'complex'). Empty when `lines` is empty.
        """
        results = []
        if not lines:
            return results

        # Length-sorted input keeps similarly sized texts in the same chunk.
        texts = sorted(lines, key=len)

        for i in range(0, len(texts), CONFIG.STANZA_BATCH_SIZE):
            chunk = texts[i : i + CONFIG.STANZA_BATCH_SIZE]
            try:
                in_docs = [stanza.Document([], text=d) for d in chunk]

                # [CRITICAL FIX] Removed autocast - Stanza doesn't support mixed precision
                docs = self.nlp(in_docs)

                # Collect per chunk so a mid-chunk failure cannot leave partial
                # rows behind that the fallback would then duplicate.
                chunk_results = [row for doc in docs for row in self._classify(doc)]

            except Exception as e:
                # Fallback: Line-by-Line
                log_event(f"Batch fail ({str(e)[:50]}). Fallback active.", "WARN")
                JanusUtils.sanitize()
                chunk_results = []
                for line in chunk:
                    try:
                        chunk_results.extend(self._classify(self.nlp(line)))
                    except Exception as line_err:
                        # Best effort: drop the line, but leave a trace of it.
                        log_event(f"Line skipped ({str(line_err)[:50]}).", "WARN")

            results.extend(chunk_results)
        return results

# --- 5. Packer ---

class SequencePacker:
    """Accumulate token ids and write them out as fixed-size text shards.

    Each added text is tokenized, followed by the ids of CONFIG.EOS_TOKEN,
    and appended to an in-memory buffer. Whenever the buffer holds at least
    CONFIG.CONTEXT_WINDOW ids, one window is decoded back to text and written
    to `<prefix>_shard_<count>.txt` in `out_dir`.
    """

    def __init__(self, tokenizer, out_dir, prefix):
        self.tok = tokenizer
        self.out_dir = out_dir
        self.prefix = prefix
        self.buffer = []
        # Continue numbering after shards that already exist on disk.
        self.count = len(list(out_dir.glob(f"{prefix}_shard_*.txt")))

    def _write_shard(self, token_ids):
        """Decode `token_ids`, write them as the next shard, and advance the counter."""
        out_txt = self.tok.decode(token_ids)
        fname = f"{self.prefix}_shard_{self.count:06d}.txt"
        JanusUtils.atomic_write(out_txt, self.out_dir / fname)
        self.count += 1

    def add(self, text):
        """Tokenize `text` (plus EOS) and emit every full window now available."""
        ids = self.tok.encode(text, add_special_tokens=False)
        eos = self.tok.encode(CONFIG.EOS_TOKEN, add_special_tokens=False)
        self.buffer.extend(ids + eos)

        window = CONFIG.CONTEXT_WINDOW
        while len(self.buffer) >= window:
            self._write_shard(self.buffer[:window])
            del self.buffer[:window]

    def finalize(self):
        """Write any leftover ids as a final, possibly short, shard.

        The shard counter is advanced here as well, so `count` always equals
        the number of shards written and a later write cannot reuse the
        final shard's filename.
        """
        if self.buffer:
            self._write_shard(self.buffer)
            self.buffer = []

# --- 6. Main Execution ---

def main():
    """Stream WikiText-103, classify each sentence by syntax, and pack shards.

    Rows are cleaned with `clean_line`, buffered, and classified in batches
    of CONFIG.SMART_BATCH_BUFFER_SIZE by `SyntaxFilter`. Sentences go to the
    'simple' or 'complex' packer. After every batch the manifest records the
    last row index consumed, so a later run resumes after it.

    Raises:
        Exception: any unexpected error is logged as FATAL and re-raised.
        KeyboardInterrupt is logged and swallowed so the packers can flush.
    """
    log_event("=== STARTING JANUS PLATINUM (L4 TURBO) ===")
    CONFIG.initialize_filesystem()

    tok = AutoTokenizer.from_pretrained('gpt2')
    ds = datasets.load_dataset('wikitext', 'wikitext-103-raw-v1', split='train', streaming=True)

    filt = SyntaxFilter()
    s_pack = SequencePacker(tok, CONFIG.SIMPLE_DIR, "simple")
    c_pack = SequencePacker(tok, CONFIG.COMPLEX_DIR, "complex")

    manifest = JanusUtils.load_manifest()
    start_idx = manifest['last_idx'] + 1
    total_lines = manifest['total_lines']

    log_event(f"Resuming at Index {start_idx}")
    if torch.cuda.is_available():
        log_event(f"GPU Active: {torch.cuda.get_device_name(0)}")

    buffer = []
    limit = CONFIG.PILOT_LIMIT if CONFIG.PILOT_MODE else 2000000
    pbar = tqdm(initial=start_idx, total=limit)
    last_idx = start_idx - 1  # index of the last dataset row actually consumed

    def flush(upto_idx):
        """Classify and pack the buffered lines, then record progress."""
        nonlocal total_lines
        for res in filt.process_batch(buffer):
            packer = s_pack if res['category'] == 'simple' else c_pack
            packer.add(res['text'])

        # Accumulate across batches; the running total used to be recomputed
        # from the manifest's starting value on every batch.
        total_lines += len(buffer)
        JanusUtils.update_manifest(upto_idx, s_pack.count, c_pack.count, total_lines)

        pbar.update(len(buffer))
        buffer.clear()
        JanusUtils.sanitize()

    try:
        for i, row in enumerate(ds):
            if i < start_idx: continue
            if CONFIG.PILOT_MODE and i >= start_idx + limit: break

            clean = clean_line(row['text'])
            if clean: buffer.append(clean)
            last_idx = i

            if len(buffer) >= CONFIG.SMART_BATCH_BUFFER_SIZE:
                flush(i)

        # The stream ended (or the pilot limit was hit) with a partial batch
        # still buffered; process it so those lines are not silently dropped.
        if buffer:
            flush(last_idx)

    except KeyboardInterrupt:
        # Unflushed lines are re-read on resume because the manifest still
        # points at the last completed batch.
        log_event("Interrupt. Shutting down...")
    except Exception as e:
        log_event(f"FATAL: {e}", "ERROR")
        raise
    finally:
        s_pack.finalize()
        c_pack.finalize()
        pbar.close()
        log_event("DONE.")

# Entry point: run the pipeline only when this code is executed directly.
if __name__ == "__main__":
    main()

In [None]:
# @title [Write to Drive] SAE Modules
import os

# --- CONFIGURATION ---
# Ensure this matches your Drive structure
# PROJECT_ROOT is the project folder on the mounted Drive; MODULE_DIR is the
# package directory that every module below is written into.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
MODULE_DIR = os.path.join(PROJECT_ROOT, "src/SAE")

print(f"üìÇ Creating Module Directory: {MODULE_DIR}")
# exist_ok=True makes re-running this cell safe when the directory exists.
os.makedirs(MODULE_DIR, exist_ok=True)
# Create an empty __init__.py to make it importable
# NOTE: mode "w" truncates, so any existing __init__.py content is discarded.
with open(os.path.join(MODULE_DIR, "__init__.py"), "w") as f:
    f.write("")

# ==========================================
# 1. CONFIG MODULE (config.py)
# ==========================================
# Source text of config.py, written verbatim to MODULE_DIR below.
# It defines the SAEConfig dataclass: architecture sizes, training
# hyperparameters, a target L0, and a derived `hidden_dim` property
# (head_dim * expansion_factor).
config_code = """
from dataclasses import dataclass

@dataclass
class SAEConfig:
    # Architecture
    head_dim: int = 64
    expansion_factor: int = 8

    # Training
    batch_size: int = 2048
    lr: float = 1e-3
    l1_coeff: float = 0.005  # The sparsity penalty
    steps: int = 2000 # Short runs for fast validation

    # Validation
    target_l0: float = 15.0 # We want ~15 active neurons per token

    @property
    def hidden_dim(self):
        return self.head_dim * self.expansion_factor
"""

# Mode "w" overwrites any previous config.py in MODULE_DIR.
with open(os.path.join(MODULE_DIR, "config.py"), "w") as f:
    f.write(config_code)
print("‚úÖ Written: config.py")

# ==========================================
# 2. MODEL MODULE (model.py)
# ==========================================
# Source text of model.py, written verbatim to MODULE_DIR below.
# It defines HeadSAE: a learned pre-bias subtracted from the input, a linear
# encoder followed by ReLU, and a bias-free linear decoder with a separate
# output bias. `normalize_decoder` rescales the decoder weight to unit L2
# norm along dim 0, i.e. one unit-norm vector per hidden feature.
model_code = """
import torch
import torch.nn as nn
import torch.nn.functional as F

class HeadSAE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dim = config.head_dim
        self.hidden_dim = config.hidden_dim

        # 1. Learned Pre-Bias (Centering)
        # We subtract this before encoding to handle non-zero mean activations
        self.b_pre = nn.Parameter(torch.zeros(self.dim))

        # 2. Encoder (Untied Weights)
        # Projects to higher dimensional sparse space
        self.W_enc = nn.Linear(self.dim, self.hidden_dim, bias=True)

        # 3. Decoder (Untied Weights)
        # Reconstructs original signal.
        # Bias False because we use a separate explicit parameter for clarity
        self.W_dec = nn.Linear(self.hidden_dim, self.dim, bias=False)
        self.b_dec = nn.Parameter(torch.zeros(self.dim))

        # Initialization
        nn.init.kaiming_uniform_(self.W_enc.weight)
        nn.init.kaiming_uniform_(self.W_dec.weight)

        # Normalize decoder columns immediately
        self.normalize_decoder()

    def forward(self, x):
        # x shape: [Batch, head_dim]

        # 1. Pre-process
        x_centered = x - self.b_pre

        # 2. Encode
        # Relu ensures sparsity (most values become 0)
        acts = torch.relu(self.W_enc(x_centered))

        # 3. Decode
        recon = self.W_dec(acts) + self.b_dec

        return recon, acts

    @torch.no_grad()
    def normalize_decoder(self):
        # Constrain decoder columns to Unit Norm
        # This prevents the model from "cheating" by making features huge
        self.W_dec.weight.data = F.normalize(self.W_dec.weight.data, p=2, dim=0)
"""

# Mode "w" overwrites any previous model.py in MODULE_DIR.
with open(os.path.join(MODULE_DIR, "model.py"), "w") as f:
    f.write(model_code)
print("‚úÖ Written: model.py")

# ==========================================
# 3. HARVESTER MODULE (harvester.py)
# ==========================================
# Source text of harvester.py, written verbatim to MODULE_DIR below.
# It defines HeadHarvester, which registers forward hooks on the o_proj layer
# of selected blocks, buffers the per-head inputs on the CPU as float16, and
# saves one tensor file per layer.
harvester_code = """
import torch
import os
import gc

class HeadHarvester:
    \"\"\"
    Surgically extracts attention head outputs BEFORE they are mixed.
    Hooks into the input of the Output Projection (o_proj) layer.

    layers defaults to [3, 6, 9] when None is passed.
    \"\"\"
    def __init__(self, model, layers=None, num_heads=8, head_dim=64):
        self.model = model
        # Copy the caller's list (or build the default) so instances never
        # share or mutate one common default list object.
        self.layers = list(layers) if layers is not None else [3, 6, 9]
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Buffer: {layer_idx: [tensor_chunk1, tensor_chunk2]}
        self.buffer = {l: [] for l in self.layers}
        self.handles = []

    def _hook_fn(self, module, input, output, layer_idx):
        # The input to o_proj is the Concatenated Heads.
        # Shape: [Batch, Seq, n_heads * head_dim]
        # We need to detach immediately to save VRAM.
        mixed_heads = input[0].detach()
        B, S, _ = mixed_heads.shape

        # Reshape to separate heads: [Batch, Seq, 8, 64]
        # reshape (not view) so a non-contiguous hook input is also accepted.
        separated = mixed_heads.reshape(B, S, self.num_heads, self.head_dim)

        # Flatten Batch and Seq: [Batch*Seq, 8, 64]
        # We use .half() (float16) to save disk space
        flat = separated.reshape(-1, self.num_heads, self.head_dim).cpu().half()

        self.buffer[layer_idx].append(flat)

    def attach(self):
        \"\"\"
        Attaches hooks to the specific layers in the Janus model.
        NOTE: Adjust 'self.model.blocks' if your architecture varies.
        \"\"\"
        for i in self.layers:
            # We target the o_proj layer. The *input* to this layer
            # is the raw output of the attention heads.
            try:
                # Try standard Janus/NewGPT structure
                target_module = self.model.blocks[i].attn.o_proj
            except AttributeError:
                # Fallback for standard HF Llama
                target_module = self.model.model.layers[i].self_attn.o_proj

            handle = target_module.register_forward_hook(
                lambda m, inp, out, idx=i: self._hook_fn(m, inp, out, idx)
            )
            self.handles.append(handle)
            print(f"ü™ù Hook attached to Layer {i} Output Projection")

    def save(self, directory, prefix):
        os.makedirs(directory, exist_ok=True)
        for layer, chunks in self.buffer.items():
            if not chunks: continue

            # Concatenate all chunks
            data = torch.cat(chunks, dim=0) # [Total_Tokens, 8, 64]

            fname = f"{prefix}_L{layer}_heads.pt"
            save_path = os.path.join(directory, fname)
            torch.save(data, save_path)
            print(f"üíæ Saved {fname}: {data.shape}")

        # Clear buffer to free RAM
        self.buffer = {l: [] for l in self.layers}
        gc.collect()

    def detach(self):
        for h in self.handles: h.remove()
        self.handles = []
        print("ü™ù Hooks removed.")
"""

# Mode "w" overwrites any previous harvester.py in MODULE_DIR.
with open(os.path.join(MODULE_DIR, "harvester.py"), "w") as f:
    f.write(harvester_code)
print("‚úÖ Written: harvester.py")

# ==========================================
# 4. TRAINER MODULE (trainer.py)
# ==========================================
# Source text of trainer.py, written verbatim to MODULE_DIR below.
# It defines train_sae_on_head, which fits one HeadSAE to a tensor of head
# activations for a fixed number of steps and returns the model together
# with the last step's MSE and L0.
trainer_code = """
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from .model import HeadSAE

def train_sae_on_head(activations, config, device="cuda"):
    \"\"\"
    Trains a single SAE on a specific head's activations.
    activations: Tensor [N_Tokens, 64]

    Returns (sae, final_mse, final_l0), where the two metrics are taken from
    the last training step (both 0.0 when config.steps is 0).
    Raises ValueError when activations is empty.
    \"\"\"
    if len(activations) == 0:
        raise ValueError("activations is empty; nothing to train on")

    sae = HeadSAE(config).to(device)
    optimizer = optim.Adam(sae.parameters(), lr=config.lr)

    # Create Loader
    dataset = TensorDataset(activations)
    loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

    final_mse = 0.0
    final_l0 = 0.0

    # Training Loop
    sae.train()
    iter_loader = iter(loader)

    # We define steps rather than epochs for consistency
    for step in range(config.steps):
        try:
            batch = next(iter_loader)[0]
        except StopIteration:
            iter_loader = iter(loader)
            batch = next(iter_loader)[0]

        batch = batch.to(device).float() # Ensure float32 for training

        # Forward
        recon, acts = sae(batch)

        # Losses
        mse = F.mse_loss(recon, batch)
        # L1 penalty on activations (sum over features, mean over batch)
        l1 = acts.sum(dim=1).mean()

        loss = mse + (config.l1_coeff * l1)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Constraint: Unit Norm Decoder
        sae.normalize_decoder()

        # Only the last step's metrics are returned, so read them back from
        # the device once instead of forcing a host sync on every step.
        if step == config.steps - 1:
            with torch.no_grad():
                final_mse = mse.item()
                final_l0 = (acts > 0).float().sum(dim=1).mean().item()

    return sae, final_mse, final_l0
"""

# Mode "w" overwrites any previous trainer.py in MODULE_DIR.
with open(os.path.join(MODULE_DIR, "trainer.py"), "w") as f:
    f.write(trainer_code)
print("‚úÖ Written: trainer.py")

print("\nüöÄ SAE Modules successfully installed to Drive.")
print("‚ÑπÔ∏è  You can now use: 'sys.path.append(PROJECT_ROOT)' followed by 'from src.SAE import ...'")

In [None]:
# @title [Run] Control (No Adaptive Scheduler) SAE 4k steps


import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import pandas as pd
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. MEMORY NUKE ---
# Drop unreachable Python objects and release cached CUDA memory before
# building a new model in this cell.
gc.collect()
torch.cuda.empty_cache()

# --- 2. SETUP ---
# Mount Google Drive only when the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# CHUNKS_DIR holds the train_chunk_XXX.bin files read by ChunkLoader;
# SAVE_DIR receives this run's checkpoints, telemetry and final weights.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
CHUNKS_DIR = os.path.join(PROJECT_ROOT, "data/wikitext_chunks")
SAVE_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_SAE_Control_v1")
os.makedirs(SAVE_DIR, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üß† JANUS CONTROL (No Adaptive Scheduler)")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 3. CONFIGURATION ---
class ControlConfig:
    """Hyperparameters for the control run (fixed lambda, no adaptive scheduler)."""

    def __init__(self):
        # Model shape
        self.vocab_size, self.d_model = 50257, 512
        self.n_layers, self.n_heads, self.d_head = 12, 8, 64
        self.max_seq_len, self.dropout = 512, 0.05

        # Optimisation schedule
        self.max_steps, self.batch_size, self.grad_accum = 4005, 16, 4

        # Control arm: the steering pressure is held at zero for the whole run
        self.fixed_lambda = 0.0

        # IO cadence, all expressed in training steps:
        # checkpoint period, steps per data chunk, telemetry period
        self.ckpt_interval, self.chunk_steps, self.lite_interval = 2000, 500, 100

# --- 4. ARCHITECTURE (Standard Janus v3) ---
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Scale `x` by the reciprocal RMS of its last axis, then by `weight`."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        # eps keeps the reciprocal square root finite for all-zero rows.
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward layer: w3(silu(w1(x)) * w2(x)), all projections bias-free."""

    def __init__(self, d_model):
        super().__init__()
        hidden_dim = int(d_model * 8 / 3)

        def bias_free(n_in, n_out):
            return nn.Linear(n_in, n_out, bias=False)

        # Creation order (w1, w2, w3) is kept so parameter order is unchanged.
        self.w1 = bias_free(d_model, hidden_dim)
        self.w2 = bias_free(d_model, hidden_dim)
        self.w3 = bias_free(hidden_dim, d_model)

    def forward(self, x):
        """Return the gated projection of `x`, with the same trailing size as `x`."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings.

    Optionally adds a head-diversity ("steering") penalty: the Frobenius norm
    of (Gram matrix of L2-normalised, flattened head outputs minus identity),
    scaled by the diversity lambda. The same Gram matrix feeds the `sigma_a`
    metric, the mean absolute off-diagonal similarity per head.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.dropout = nn.Dropout(config.dropout)
        self.register_buffer("freqs_cis", self.precompute_freqs_cis(config.d_head, config.max_seq_len))

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return a complex [end, dim // 2] table of unit-magnitude rotations."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    def apply_rope(self, x, freqs_cis):
        """Rotate consecutive feature pairs of `x` ([B, S, H, d_head]) by position."""
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs = freqs_cis[:x.shape[1]].view(1, x.shape[1], 1, -1)
        x_out = torch.view_as_real(x_c * freqs).flatten(3)
        return x_out.type_as(x)

    def forward(self, x, lambdas, return_metrics=False):
        """Run attention over `x` of shape [B, S, D].

        Args:
            x: input activations, [B, S, D].
            lambdas: pair (l_coh, l_div). Only l_div is used here; l_coh is
                accepted so the call signature stays uniform.
            return_metrics: when True, also compute `sigma_a`.

        Returns:
            (output [B, S, D], steer_loss, metrics). steer_loss is 0.0 unless
            the module is in training mode and l_div > 0. metrics is {} unless
            return_metrics is True.
        """
        B, S, D = x.shape
        q = self.q_proj(x).view(B, S, self.n_heads, self.d_head)
        k = self.k_proj(x).view(B, S, self.n_heads, self.d_head)
        v = self.v_proj(x).view(B, S, self.n_heads, self.d_head)

        q = self.apply_rope(q, self.freqs_cis)
        k = self.apply_rope(k, self.freqs_cis)
        q = q.transpose(1, 2); k = k.transpose(1, 2); v = v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Lower-triangular mask: position i may only attend to positions <= i.
        mask = torch.tril(torch.ones(S, S, device=x.device)).view(1, 1, S, S)
        attn = attn.masked_fill(mask == 0, float('-inf'))
        attn_probs = F.softmax(attn, dim=-1)
        attn_probs = self.dropout(attn_probs)
        head_out = attn_probs @ v

        steer_loss = 0.0
        l_coh, l_div = lambdas

        # The penalty is a training-only term. It used to leak into eval-mode
        # calls whenever return_metrics was True.
        apply_steering = l_div > 0.0 and self.training

        # One Gram matrix serves both the penalty and the metric.
        gram = None
        if apply_steering or return_metrics:
            flat = head_out.transpose(0, 1).reshape(self.n_heads, -1)
            norm = F.normalize(flat, p=2, dim=1)
            gram = torch.mm(norm, norm.t())

        if apply_steering:
            identity = torch.eye(self.n_heads, device=x.device)
            steer_loss = steer_loss + torch.norm(gram - identity, p='fro') * l_div

        metrics = {}
        if return_metrics:
            with torch.no_grad():
                sim = gram.detach()
                mask_diag = ~torch.eye(self.n_heads, dtype=torch.bool, device=x.device)
                # A single head has no off-diagonal entries; avoid dividing by 0.
                n_others = max(self.n_heads - 1, 1)
                metrics['sigma_a'] = (sim.abs() * mask_diag.float()).sum(dim=1) / n_others

        out = head_out.transpose(1, 2).reshape(B, S, D)
        return self.o_proj(out), steer_loss, metrics

class NewBlock(nn.Module):
    """Pre-norm transformer block: attention, then SwiGLU MLP, each added residually."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.ln1 = RMSNorm(width)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(width)
        self.mlp = SwiGLU(width)

    def forward(self, x, lambdas, return_metrics=False):
        """Return (hidden state, attention steering loss, attention metrics)."""
        attn_out, steer_loss, attn_metrics = self.attn(self.ln1(x), lambdas, return_metrics)
        hidden = x + attn_out
        hidden = hidden + self.mlp(self.ln2(hidden))
        return hidden, steer_loss, attn_metrics

class NewGPT(nn.Module):
    """Decoder-only language model: tied embedding/head, a stack of NewBlock, final RMSNorm."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.blocks = nn.ModuleList([NewBlock(config) for _ in range(config.n_layers)])
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the embedding reuses the output projection's matrix.
        self.token_emb.weight = self.head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) init, with a depth-scaled std for o_proj / w3 weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

        # Projections that write into the residual stream get a smaller std.
        residual_std = 0.02 / math.sqrt(2 * self.config.n_layers)
        for param_name, param in module.named_parameters():
            if "o_proj.weight" in param_name or "w3.weight" in param_name:
                torch.nn.init.normal_(param, mean=0.0, std=residual_std)

    def forward(self, idx, lambdas_list, targets=None, return_metrics=False):
        """Run the model on token ids `idx` of shape [B, S].

        Returns:
            (loss, total_steer, all_metrics). loss is the cross-entropy
            against `targets`, or None when no targets are given. total_steer
            sums the per-block steering losses. all_metrics holds one metrics
            dict per block when return_metrics is True, else it is empty.
        """
        B, S = idx.shape
        hidden = self.token_emb(idx)

        total_steer = 0.0
        all_metrics = []
        for layer_idx, block in enumerate(self.blocks):
            hidden, layer_steer, layer_metrics = block(hidden, lambdas_list[layer_idx], return_metrics)
            total_steer += layer_steer
            if return_metrics:
                all_metrics.append(layer_metrics)

        logits = self.head(self.ln_f(hidden))

        if targets is None:
            loss = None
        else:
            flat_logits = logits.view(-1, logits.size(-1))
            loss = F.cross_entropy(flat_logits, targets.view(-1))
        return loss, total_steer, all_metrics

# --- 5. DATA & LOGGING ---
class ChunkLoader:
    """Serve random training and validation windows from uint16 token files.

    Training tokens come from `train_chunk_XXX.bin` files in CHUNKS_DIR, one
    chunk per `config.chunk_steps` training steps. Validation tokens are
    memory-mapped from `data/wikitext/val.bin` under PROJECT_ROOT.

    Args:
        config: object providing `chunk_steps` and `max_seq_len`.
        num_chunks: modulus used to wrap around to an earlier chunk file when
            the chunk for the current step does not exist (default 40).
    """

    def __init__(self, config, num_chunks=40):
        self.config = config
        self.num_chunks = num_chunks
        self.current_chunk_idx = -1
        self.data = None
        self.val_data = np.memmap(os.path.join(PROJECT_ROOT, "data/wikitext/val.bin"), dtype=np.uint16, mode='r')

    def load_for_step(self, step):
        """Load the chunk that `step` maps to, unless it is already loaded."""
        target_chunk = step // self.config.chunk_steps
        if target_chunk == self.current_chunk_idx:
            return

        fpath = os.path.join(CHUNKS_DIR, f"train_chunk_{target_chunk:03d}.bin")
        if not os.path.exists(fpath):
            # Past the last chunk on disk: wrap around to an earlier one.
            fallback_idx = target_chunk % self.num_chunks
            fpath = os.path.join(CHUNKS_DIR, f"train_chunk_{fallback_idx:03d}.bin")
        with open(fpath, 'rb') as f:
            self.data = np.frombuffer(f.read(), dtype=np.uint16)
        self.current_chunk_idx = target_chunk

    @staticmethod
    def _sample(data, seq_len, batch_size):
        """Draw `batch_size` random windows from `data`.

        Returns:
            (x, y): int64 CPU tensors of shape [batch_size, seq_len], where y
            is x shifted one token to the right.

        Raises:
            ValueError: if `data` has no more than `seq_len` tokens.
        """
        n_starts = len(data) - seq_len
        if n_starts <= 0:
            raise ValueError(
                f"need more than {seq_len} tokens to sample a window, got {len(data)}"
            )
        ix = torch.randint(n_starts, (batch_size,))
        # Gather seq_len + 1 tokens per row in a single indexing operation;
        # x and y are the two overlapping views of that window.
        offsets = ix.numpy()[:, None] + np.arange(seq_len + 1)
        window = torch.from_numpy(np.asarray(data[offsets]).astype(np.int64))
        return window[:, :-1].contiguous(), window[:, 1:].contiguous()

    def get_batch(self, batch_size):
        """Return a random (x, y) training batch on DEVICE.

        Raises:
            RuntimeError: if no chunk has been loaded with `load_for_step` yet.
        """
        if self.data is None:
            raise RuntimeError("no training chunk loaded; call load_for_step() first")
        x, y = self._sample(self.data, self.config.max_seq_len, batch_size)
        return x.to(DEVICE), y.to(DEVICE)

    def get_val_batch(self, batch_size):
        """Return a random (x, y) validation batch on DEVICE."""
        x, y = self._sample(self.val_data, self.config.max_seq_len, batch_size)
        return x.to(DEVICE), y.to(DEVICE)

class BlackBox:
    """Buffer telemetry rows in memory and append them to a parquet file.

    Rows logged for steps below `start_step` are ignored. On a resumed run
    that keeps steps already on disk from being recorded again.
    """

    def __init__(self, save_dir):
        self.buffer = []
        self.save_dir = save_dir
        self.start_step = 0

    def set_start_step(self, step):
        """Ignore any future `log` call whose step is below `step`."""
        self.start_step = step

    def log(self, step, loss, val_loss, pressure, metrics):
        """Queue one telemetry row.

        Args:
            step: training step the row belongs to.
            loss: training loss for the step.
            val_loss: validation loss; perplexity is exp(val_loss), or 0.0
                when val_loss is 20 or more.
            pressure: steering lambda in effect.
            metrics: iterable of per-layer dicts; entries carrying a
                'sigma_a' tensor are averaged into `sigma_a_avg`. May be
                empty or None.
        """
        if step < self.start_step: return

        # Calculate mean sigma_a across layers for logging
        avg_sigma = 0.0
        if metrics:
            sigmas = [m['sigma_a'].mean().item() for m in metrics if 'sigma_a' in m]
            if sigmas: avg_sigma = sum(sigmas) / len(sigmas)

        row = {
            "step": step, "loss": loss, "val_loss": val_loss,
            "perplexity": math.exp(val_loss) if val_loss < 20 else 0.0,
            "pressure": pressure,
            "sigma_a_avg": avg_sigma
        }
        self.buffer.append(row)

    def flush(self):
        """Append the buffered rows to telemetry_control.parquet and clear them.

        If an existing file cannot be read, a warning is printed and the file
        is rewritten with the new rows only. Telemetry stays best-effort and
        never aborts training.
        """
        if not self.buffer: return
        df = pd.DataFrame(self.buffer)
        fpath = os.path.join(self.save_dir, "telemetry_control.parquet")
        if os.path.exists(fpath):
            try:
                existing = pd.read_parquet(fpath)
                df = pd.concat([existing, df])
            except Exception as e:
                # Previously a bare `except: pass`, which hid this data loss
                # and also swallowed KeyboardInterrupt.
                print(f"WARNING: could not merge existing telemetry ({e}); rewriting file with new rows only.")
        df.to_parquet(fpath)
        self.buffer = []
        print("üíæ Telemetry Flushed.")

# --- 6. RUN LOOP ---
def run_control():
    """Train the control model (fixed lambda, no adaptive scheduler).

    Builds the model, resumes from the newest `ckpt_control_step_*.pt` in
    SAVE_DIR when one exists, then trains for `cfg.max_steps` steps with
    gradient accumulation. Telemetry is logged every `cfg.lite_interval`
    steps and at the final step. A checkpoint is written every
    `cfg.ckpt_interval` steps, keeping at most three. The final weights are
    saved to `janus_control_SAE.pt`.
    """
    gc.collect(); torch.cuda.empty_cache()
    use_cuda = torch.cuda.is_available()

    cfg = ControlConfig()
    model = NewGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6.29e-4, weight_decay=1.34e-4)
    loader = ChunkLoader(cfg)
    recorder = BlackBox(SAVE_DIR)

    start_step = 0

    # Check for existing control checkpoints
    ckpts = [f for f in os.listdir(SAVE_DIR) if f.startswith("ckpt_control_step_") and f.endswith(".pt")]
    if ckpts:
        ckpts.sort(key=lambda x: int(x.split('_')[3].split('.')[0]))
        latest = ckpts[-1]
        ckpt_path = os.path.join(SAVE_DIR, latest)
        print(f"üîÑ Resuming Control Run from {latest}...")
        c = torch.load(ckpt_path, map_location=DEVICE)
        model.load_state_dict(c['model'])
        optimizer.load_state_dict(c['optim'])

        # Restore the RNG streams that the checkpoint saves. map_location may
        # have moved the state tensors to the GPU, and the setters need CPU
        # byte tensors, hence .cpu().
        if c.get('rng_cpu') is not None:
            torch.set_rng_state(c['rng_cpu'].cpu())
        if use_cuda and c.get('rng_gpu') is not None:
            torch.cuda.set_rng_state(c['rng_gpu'].cpu())

        start_step = c['step'] + 1
        recorder.set_start_step(start_step)

    # Set fixed lambda (0.0 for control = no steering)
    # Each tuple is (l_coh, l_div). Passing fixed_lambda as l_div keeps the
    # applied pressure equal to the value logged below.
    fixed_lambda = cfg.fixed_lambda
    lambdas_list = [(0.0, fixed_lambda)] * cfg.n_layers

    print(f"\nüß† STARTING CONTROL RUN: {start_step} -> {cfg.max_steps}")
    print(f"   Fixed Lambda: {fixed_lambda} (No Adaptive Scheduler)")

    pbar = tqdm(range(start_step, cfg.max_steps), initial=start_step, total=cfg.max_steps)

    for step in pbar:
        loader.load_for_step(step)

        # --- A. TRAINING STEP ---
        model.train()
        batch_loss = 0.0
        optimizer.zero_grad()

        # Per-layer metrics are only collected on telemetry steps.
        return_metrics = (step % cfg.lite_interval == 0)
        metrics = []

        for _ in range(cfg.grad_accum):
            x, y = loader.get_batch(cfg.batch_size)

            # bf16 autocast only applies on CUDA; it is disabled on CPU.
            with torch.amp.autocast('cuda', dtype=torch.bfloat16, enabled=use_cuda):
                loss, steer, metrics = model(x, lambdas_list, y, return_metrics=return_metrics)
                total = (loss + steer) / cfg.grad_accum

            total.backward()
            batch_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # --- B. TELEMETRY ---
        if step > 0 and (step % cfg.lite_interval == 0 or step == cfg.max_steps - 1):
            gc.collect(); torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                v_losses = []
                for _ in range(10):
                    vx, vy = loader.get_val_batch(cfg.batch_size)
                    vl, _, _ = model(vx, lambdas_list, vy, return_metrics=False)
                    v_losses.append(vl.item())
                val_loss = np.mean(v_losses)

            # `metrics` comes from the last micro-batch of this step.
            recorder.log(step, batch_loss/cfg.grad_accum, val_loss, fixed_lambda, metrics)
            recorder.flush()

            pbar.set_description(f"L:{batch_loss/cfg.grad_accum:.2f}|V:{val_loss:.2f}|Control")
        else:
            pbar.set_description(f"L:{batch_loss/cfg.grad_accum:.2f}|Control")

        # --- C. CHECKPOINTS ---
        if step > 0 and step % cfg.ckpt_interval == 0:
            ckpt_name = f"ckpt_control_step_{step}.pt"
            ckpt_path = os.path.join(SAVE_DIR, ckpt_name)
            save_dict = {
                'step': step,
                'model': model.state_dict(),
                'optim': optimizer.state_dict(),
                'rng_cpu': torch.get_rng_state(),
                # No CUDA generator exists on a CPU-only machine.
                'rng_gpu': torch.cuda.get_rng_state() if use_cuda else None
            }
            torch.save(save_dict, ckpt_path)

            # Keep at most three checkpoints: drop the oldest by step number.
            all_ckpts = [f for f in os.listdir(SAVE_DIR) if f.startswith("ckpt_control_step_") and f.endswith(".pt")]
            all_ckpts.sort(key=lambda x: int(x.split('_')[3].split('.')[0]))
            if len(all_ckpts) > 3: os.remove(os.path.join(SAVE_DIR, all_ckpts[0]))

    final_path = os.path.join(SAVE_DIR, "janus_control_SAE.pt")
    torch.save(model.state_dict(), final_path)
    print(f"\nüèÜ CONTROL RUN COMPLETE. Saved to {final_path}")

# Entry point: start the control run only when this code is executed directly.
if __name__ == "__main__":
    run_control()

In [None]:
# @title [Run] Adaptive Scheduler SAE 4K steps


import sys
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
import pandas as pd
import gc
from tqdm import tqdm
from google.colab import drive

# --- 1. MEMORY NUKE ---
# Drop unreachable Python objects and release cached CUDA memory before
# building a new model in this cell.
gc.collect()
torch.cuda.empty_cache()

# --- 2. SETUP ---
# Mount Google Drive only when the mount point is not already present.
if not os.path.exists('/content/drive'):
    drive.mount('/content/drive')

# Same project root and chunk directory as the control cell; only SAVE_DIR
# differs, so the adaptive run's outputs stay separate from the control run's.
PROJECT_ROOT = "/content/drive/MyDrive/Project_XAI_Physical_Janus"
CHUNKS_DIR = os.path.join(PROJECT_ROOT, "data/wikitext_chunks")
SAVE_DIR = os.path.join(PROJECT_ROOT, "data/models/janus_SAE_adaptive_v1")
os.makedirs(SAVE_DIR, exist_ok=True)

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üß† JANUS ADAPTIVE (Homeostatic Control)")
print(f"‚öôÔ∏è Hardware: {DEVICE}")

# --- 3. CONFIGURATION ---
class AdaptiveConfig:
    """Hyperparameters for the adaptive run (homeostatic lambda controller)."""

    def __init__(self):
        # Model shape
        self.vocab_size, self.d_model = 50257, 512
        self.n_layers, self.n_heads, self.d_head = 12, 8, 64
        self.max_seq_len, self.dropout = 512, 0.05

        # Optimisation schedule
        self.max_steps, self.batch_size, self.grad_accum = 4005, 16, 4

        # Controller: set-point, proportional gain, update period (steps),
        # and length of the open-loop warm-up ramp (steps)
        self.target_sigma, self.k_p = 0.0035, 0.5
        self.control_interval, self.warmup_steps = 50, 1500

        # IO cadence
        self.ckpt_interval, self.chunk_steps, self.lite_interval = 2000, 500, 100

# --- 4. THE HOMEOSTATIC CONTROLLER ---
class Homeostat:
    """Controller for the steering lambda: linear warm-up, then proportional control."""

    def __init__(self, config):
        self.history = []
        self.current_lambda = 0.0
        self.warmup = config.warmup_steps
        self.kp = config.k_p
        self.target = config.target_sigma

    def update(self, step, current_sigma):
        """Advance the controller and return the lambda now in effect.

        During warm-up (step < warmup) lambda ramps linearly from 0 towards
        0.05 and `current_sigma` is ignored. Afterwards lambda moves by
        kp * (current_sigma - target) and is clamped to [0.0, 0.25]. A
        `current_sigma` of None leaves lambda unchanged.
        """
        if step < self.warmup:
            # Open-loop ramp from 0.00 to 0.05
            self.current_lambda = 0.05 * (step / self.warmup)
        elif current_sigma is not None:
            # Sigma above target means heads are too redundant, so the
            # positive error raises the pressure.
            error = current_sigma - self.target
            raised = self.current_lambda + self.kp * error
            # Hard limits: never negative, never above 0.25
            capped = min(0.25, raised)
            self.current_lambda = max(0.0, capped)

        return self.current_lambda

# --- 5. ARCHITECTURE (Standard Janus v3) ---
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Scale `x` by the reciprocal RMS of its last axis, then by `weight`."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        # eps keeps the reciprocal square root finite for all-zero rows.
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

class SwiGLU(nn.Module):
    """Gated feed-forward layer: w3(silu(w1(x)) * w2(x)), all projections bias-free."""

    def __init__(self, d_model):
        super().__init__()
        hidden_dim = int(d_model * 8 / 3)

        def bias_free(n_in, n_out):
            return nn.Linear(n_in, n_out, bias=False)

        # Creation order (w1, w2, w3) is kept so parameter order is unchanged.
        self.w1 = bias_free(d_model, hidden_dim)
        self.w2 = bias_free(d_model, hidden_dim)
        self.w3 = bias_free(hidden_dim, d_model)

    def forward(self, x):
        """Return the gated projection of `x`, with the same trailing size as `x`."""
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.w3(gate * value)

class NewAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    On top of the usual attention output, the module can
    * add a head-diversity penalty (Frobenius norm of ``gram - I`` scaled by
      the diversity weight), and
    * report ``sigma_a``: for every head, the mean absolute cosine similarity
      between its flattened output and the outputs of the other heads.
    """

    def __init__(self, config):
        """Build the projections and the rotary table.

        Reads ``n_heads``, ``d_head``, ``d_model``, ``dropout`` and
        ``max_seq_len`` from *config*.
        """
        super().__init__()
        self.n_heads = config.n_heads
        self.d_head = config.d_head
        self.scale = 1.0 / math.sqrt(self.d_head)

        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.o_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.dropout = nn.Dropout(config.dropout)
        self.register_buffer("freqs_cis", self.precompute_freqs_cis(config.d_head, config.max_seq_len))

    def precompute_freqs_cis(self, dim, end, theta=10000.0):
        """Return a complex ``(end, dim // 2)`` table of unit-magnitude rotations."""
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
        t = torch.arange(end, device=freqs.device)
        freqs = torch.outer(t, freqs).float()
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
        return freqs_cis

    def apply_rope(self, x, freqs_cis):
        """Rotate consecutive feature pairs of *x* by the position-dependent angles.

        Args:
            x: Tensor laid out as (batch, seq, heads, d_head).
            freqs_cis: Table produced by :meth:`precompute_freqs_cis`.

        Raises:
            ValueError: If the sequence is longer than the table.  Without
                this check the slice below is silently shorter than the
                sequence and the following ``view`` either fails with an
                unrelated shape error or broadcasts wrong angles.
        """
        seq_len = x.shape[1]
        if seq_len > freqs_cis.shape[0]:
            raise ValueError(
                f"sequence length {seq_len} exceeds the rotary table length {freqs_cis.shape[0]}"
            )
        x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs = freqs_cis[:seq_len].view(1, seq_len, 1, -1)
        x_out = torch.view_as_real(x_c * freqs).flatten(3)
        return x_out.type_as(x)

    def _head_similarity(self, head_out):
        """Return the (n_heads, n_heads) cosine-similarity matrix of the head outputs."""
        # head_out is (batch, heads, seq, d_head): bring heads first and flatten the rest.
        flat = head_out.transpose(0, 1).reshape(self.n_heads, -1)
        unit = F.normalize(flat, p=2, dim=1)
        return torch.mm(unit, unit.t())

    def forward(self, x, lambdas, return_metrics=False):
        """Run attention on *x*.

        Args:
            x: Input of shape (batch, seq, d_model).
            lambdas: ``(l_coh, l_div)`` pair.  Only ``l_div`` (diversity
                weight) is used by this module.
            return_metrics: When true, also compute ``sigma_a``.

        Returns:
            ``(output, steer_loss, metrics)`` where ``output`` has the shape
            of *x*, ``steer_loss`` is ``0.0`` or a scalar tensor, and
            ``metrics`` is ``{}`` or ``{'sigma_a': tensor(n_heads)}``.
        """
        B, S, D = x.shape
        q = self.q_proj(x).view(B, S, self.n_heads, self.d_head)
        k = self.k_proj(x).view(B, S, self.n_heads, self.d_head)
        v = self.v_proj(x).view(B, S, self.n_heads, self.d_head)

        q = self.apply_rope(q, self.freqs_cis)
        k = self.apply_rope(k, self.freqs_cis)
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        # Lower-triangular mask: a position may only attend to itself and the past.
        causal = torch.ones(S, S, dtype=torch.bool, device=x.device).tril().view(1, 1, S, S)
        attn = attn.masked_fill(~causal, float('-inf'))
        attn_probs = F.softmax(attn, dim=-1)
        attn_probs = self.dropout(attn_probs)
        head_out = attn_probs @ v

        steer_loss = 0.0
        _l_coh, l_div = lambdas  # the coherence weight is not used here
        metrics = {}

        wants_penalty = l_div > 0.0
        # The similarity matrix is needed for the penalty (during training)
        # and for the metrics; compute it once and share it between both.
        if (wants_penalty and self.training) or return_metrics:
            gram = self._head_similarity(head_out)

            if wants_penalty:
                identity = torch.eye(self.n_heads, device=x.device)
                steer_loss += torch.norm(gram - identity, p='fro') * l_div

            if return_metrics:
                with torch.no_grad():
                    off_diag = ~torch.eye(self.n_heads, dtype=torch.bool, device=x.device)
                    # A single head has no peers: divide by 1 instead of 0 so
                    # sigma_a is 0 rather than NaN.
                    n_peers = max(self.n_heads - 1, 1)
                    metrics['sigma_a'] = (gram.detach().abs() * off_diag.float()).sum(dim=1) / n_peers

        out = head_out.transpose(1, 2).reshape(B, S, D)
        return self.o_proj(out), steer_loss, metrics

class NewBlock(nn.Module):
    """Pre-norm transformer layer: attention sub-layer, then SwiGLU sub-layer, each residual."""

    def __init__(self, config):
        """Create both RMSNorms, the attention module and the MLP from *config*."""
        super().__init__()
        width = config.d_model
        self.ln1 = RMSNorm(width)
        self.attn = NewAttention(config)
        self.ln2 = RMSNorm(width)
        self.mlp = SwiGLU(width)

    def forward(self, x, lambdas, return_metrics=False):
        """Apply the layer to *x*.

        Returns:
            ``(hidden, steer_loss, metrics)``; the last two are passed through
            unchanged from the attention module.
        """
        attn_out, steer_loss, metrics = self.attn(self.ln1(x), lambdas, return_metrics)
        hidden = x + attn_out
        mlp_out = self.mlp(self.ln2(hidden))
        hidden = hidden + mlp_out
        return hidden, steer_loss, metrics

class NewGPT(nn.Module):
    """Decoder-only language model built from ``NewBlock`` layers.

    The token embedding and the output projection share one weight tensor.
    """

    def __init__(self, config):
        """Assemble embedding, ``config.n_layers`` blocks, final norm and LM head."""
        super().__init__()
        self.config = config
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)
        layers = [NewBlock(config) for _ in range(config.n_layers)]
        self.blocks = nn.ModuleList(layers)
        self.ln_f = RMSNorm(config.d_model)
        self.head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the embedding reuses the head's parameter.
        self.token_emb.weight = self.head.weight
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialisation hook handed to ``self.apply``.

        Linear and Embedding weights are drawn from N(0, 0.02).  Any parameter
        of *module* whose name contains ``o_proj.weight`` or ``w3.weight`` is
        then redrawn with the std divided by ``sqrt(2 * n_layers)``.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        residual_tags = ("o_proj.weight", "w3.weight")
        for param_name, param in module.named_parameters():
            if any(tag in param_name for tag in residual_tags):
                torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layers))

    def forward(self, idx, lambdas_list, targets=None, return_metrics=False):
        """Run the model on token ids.

        Args:
            idx: 2-D tensor of token ids.
            lambdas_list: One ``(l_coh, l_div)`` pair per layer, indexed by layer.
            targets: Optional token ids for the cross-entropy loss.
            return_metrics: Forwarded to every block.

        Returns:
            ``(loss, total_steer, all_metrics)``: ``loss`` is ``None`` without
            targets, ``total_steer`` is the sum of the per-layer steering
            losses, ``all_metrics`` holds one dict per layer when metrics were
            requested and is empty otherwise.
        """
        _batch, _seq = idx.shape  # unpacking requires idx to be 2-D
        hidden = self.token_emb(idx)
        total_steer = 0.0
        all_metrics = []
        for layer_idx, layer in enumerate(self.blocks):
            hidden, layer_steer, layer_metrics = layer(hidden, lambdas_list[layer_idx], return_metrics)
            total_steer += layer_steer
            if return_metrics:
                all_metrics.append(layer_metrics)
        logits = self.head(self.ln_f(hidden))
        if targets is None:
            loss = None
        else:
            flat_logits = logits.view(-1, logits.size(-1))
            loss = F.cross_entropy(flat_logits, targets.view(-1))
        return loss, total_steer, all_metrics

# --- 6. DATA & LOGGING ---
class ChunkLoader:
    """Serves random training windows from on-disk chunks and validation windows from a memmap."""

    def __init__(self, config):
        """Memory-map the validation tokens; training data is loaded by ``load_for_step``."""
        self.config = config
        self.current_chunk_idx = -1
        self.data = None
        self.val_data = np.memmap(os.path.join(PROJECT_ROOT, "data/wikitext/val.bin"), dtype=np.uint16, mode='r')

    def load_for_step(self, step):
        """Make sure the chunk covering *step* is the one held in ``self.data``.

        The chunk index is ``step // config.chunk_steps``.  When the matching
        file is missing, the index wraps modulo 40 and that file is read.
        """
        wanted = step // self.config.chunk_steps
        if wanted == self.current_chunk_idx:
            return
        chunk_path = os.path.join(CHUNKS_DIR, f"train_chunk_{wanted:03d}.bin")
        if not os.path.exists(chunk_path):
            wrapped = wanted % 40
            chunk_path = os.path.join(CHUNKS_DIR, f"train_chunk_{wrapped:03d}.bin")
        with open(chunk_path, 'rb') as handle:
            self.data = np.frombuffer(handle.read(), dtype=np.uint16)
        self.current_chunk_idx = wanted

    def _sample_windows(self, tokens, batch_size):
        """Draw *batch_size* random windows from *tokens* plus their one-step-shifted targets."""
        seq_len = self.config.max_seq_len
        starts = torch.randint(len(tokens) - seq_len, (batch_size,)).tolist()

        def stack(shift):
            rows = [
                torch.from_numpy(tokens[s + shift:s + shift + seq_len].astype(np.int64))
                for s in starts
            ]
            return torch.stack(rows)

        inputs = stack(0)
        targets = stack(1)
        return inputs.to(DEVICE), targets.to(DEVICE)

    def get_batch(self, batch_size):
        """Return an ``(x, y)`` pair on ``DEVICE`` sampled from the current training chunk."""
        return self._sample_windows(self.data, batch_size)

    def get_val_batch(self, batch_size):
        """Return an ``(x, y)`` pair on ``DEVICE`` sampled from the validation tokens."""
        return self._sample_windows(self.val_data, batch_size)

class BlackBox:
    """In-memory telemetry buffer that is appended to a parquet file on ``flush``."""

    def __init__(self, save_dir):
        """Start with an empty buffer; the parquet file lives in *save_dir*."""
        self.buffer = []
        self.save_dir = save_dir
        self.start_step = 0

    def set_start_step(self, step):
        """Ignore any later ``log`` call whose step is below *step*."""
        self.start_step = step

    def log(self, step, loss, val_loss, pressure, metrics):
        """Buffer one telemetry row.

        Args:
            step: Training step the row belongs to.
            loss: Training loss.
            val_loss: Validation loss; perplexity is ``exp(val_loss)`` when it
                is below 20 and recorded as 0.0 otherwise.
            pressure: Steering pressure in effect.
            metrics: Iterable of per-layer dicts; entries holding a
                ``'sigma_a'`` tensor are averaged.  May be empty or ``None``.
        """
        if step < self.start_step:
            return

        # Mean of the per-layer sigma_a means; stays 0.0 without usable metrics.
        avg_sigma = 0.0
        if metrics:
            sigmas = [m['sigma_a'].mean().item() for m in metrics if 'sigma_a' in m]
            if sigmas:
                avg_sigma = sum(sigmas) / len(sigmas)

        self.buffer.append({
            "step": step, "loss": loss, "val_loss": val_loss,
            "perplexity": math.exp(val_loss) if val_loss < 20 else 0.0,
            "pressure": pressure,
            "sigma_a_avg": avg_sigma,
        })

    def flush(self):
        """Append the buffered rows to ``telemetry_adaptive.parquet`` and clear the buffer.

        Reading the existing file stays best-effort, but a failure is now
        reported instead of being silently swallowed, because in that case the
        file is overwritten with the new rows only.
        """
        if not self.buffer:
            return
        df = pd.DataFrame(self.buffer)
        fpath = os.path.join(self.save_dir, "telemetry_adaptive.parquet")
        if os.path.exists(fpath):
            try:
                existing = pd.read_parquet(fpath)
                # ignore_index avoids repeating the 0..n-1 labels of every flush.
                df = pd.concat([existing, df], ignore_index=True)
            except Exception as err:  # not a bare except: Ctrl-C must still interrupt
                print(f"WARNING: could not merge existing telemetry at {fpath} ({err!r}); overwriting it.")
        df.to_parquet(fpath)
        self.buffer = []
        print("üíæ Telemetry Flushed.")

# --- 7. RUN LOOP ---
def run_adaptive():
    """Train the model with the homeostatic pressure controller.

    Resumes from the newest ``ckpt_step_<N>.pt`` in ``SAVE_DIR`` when one
    exists, otherwise starts at step 0.  Every step: (A) possibly update the
    controller, (B) run one gradient-accumulated optimizer step, (C) on
    telemetry steps evaluate on validation data and flush the recorder,
    (D) on checkpoint steps save state and keep only the three newest
    checkpoints.  The final model weights are saved at the end.
    """

    def sorted_checkpoints():
        # Checkpoint file names in SAVE_DIR, ordered by their step number.
        names = [f for f in os.listdir(SAVE_DIR) if f.startswith("ckpt_step_") and f.endswith(".pt")]
        names.sort(key=lambda x: int(x.split('_')[2].split('.')[0]))
        return names

    gc.collect(); torch.cuda.empty_cache()

    cfg = AdaptiveConfig()
    model = NewGPT(cfg).to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=6.29e-4, weight_decay=1.34e-4)
    controller = Homeostat(cfg)
    loader = ChunkLoader(cfg)
    recorder = BlackBox(SAVE_DIR)

    start_step = 0
    # START FRESH OR RESUME ADAPTIVE
    # We do NOT recommend grafting 4k Constant here,
    # better to start fresh 0-1500 warm-up to test the whole curve cleanly.

    ckpts = sorted_checkpoints()
    if ckpts:
        latest = ckpts[-1]
        ckpt_path = os.path.join(SAVE_DIR, latest)
        print(f"üîÑ Resuming Adaptive Run from {latest}...")
        c = torch.load(ckpt_path, map_location=DEVICE)
        model.load_state_dict(c['model'])
        optimizer.load_state_dict(c['optim'])
        # Restore Controller State
        controller.current_lambda = c.get('current_lambda', 0.0)
        # NOTE(review): the RNG states stored in the checkpoint are not
        # restored here, so a resumed run draws different batches.
        start_step = c['step'] + 1
        recorder.set_start_step(start_step)

    print(f"\nüß† STARTING ADAPTIVE RUN: {start_step} -> {cfg.max_steps}")
    print(f"   Target Sigma: {cfg.target_sigma}")
    print(f"   Warmup Steps: {cfg.warmup_steps}")

    pbar = tqdm(range(start_step, cfg.max_steps), initial=start_step, total=cfg.max_steps)

    # Most recent sigma measurement; None until a metrics step has run
    # (this is also the state right after a resume).
    current_sigma_avg = None

    for step in pbar:
        loader.load_for_step(step)

        # --- A. CONTROL UPDATE ---
        # Uses the sigma from the most recent step that collected metrics.
        if step % cfg.control_interval == 0 or step < cfg.warmup_steps:
            controller.update(step, current_sigma_avg)

        # Apply lambda to layers
        # Simplified: Same lambda for div, 20% of that for coh
        p = controller.current_lambda
        base_coh = p * 0.2
        lambdas_list = []
        for i in range(cfg.n_layers):
            # Cubic depth schedule: deeper layers receive more pressure.
            ratio = (i + 1) / cfg.n_layers
            s_mult = ratio ** 3
            lambdas_list.append((base_coh * s_mult, p * s_mult))

        # --- B. TRAINING STEP ---
        model.train()
        batch_loss = 0.0
        optimizer.zero_grad()

        accum_sigmas = []  # Track sigma during accum for controller

        # Metrics are collected on control steps (to feed the controller) and
        # on logging steps; the decision is the same for every micro-batch.
        is_control_step = (step % cfg.control_interval == 0)
        is_log_step = (step % cfg.lite_interval == 0)
        return_metrics = is_control_step or is_log_step
        metrics = []  # stays empty when no metrics are requested this step

        for _ in range(cfg.grad_accum):
            x, y = loader.get_batch(cfg.batch_size)

            with torch.amp.autocast('cuda', dtype=torch.bfloat16):
                loss, steer, metrics = model(x, lambdas_list, y, return_metrics=return_metrics)
                total = (loss + steer) / cfg.grad_accum

            total.backward()
            batch_loss += loss.item()

            if return_metrics and metrics:
                # Extract mean sigma for controller
                s = [m['sigma_a'].mean().item() for m in metrics]
                if s:
                    accum_sigmas.append(sum(s) / len(s))

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Update Sigma Average for Next Control Step
        if accum_sigmas:
            current_sigma_avg = sum(accum_sigmas) / len(accum_sigmas)

        mean_loss = batch_loss / cfg.grad_accum
        # The progress bar must not format None: the final step can reach the
        # telemetry branch before any sigma was measured (e.g. after a resume).
        sigma_shown = current_sigma_avg if current_sigma_avg else 0

        # --- C. TELEMETRY ---
        if step > 0 and (step % cfg.lite_interval == 0 or step == cfg.max_steps - 1):
            gc.collect(); torch.cuda.empty_cache()
            model.eval()
            with torch.no_grad():
                v_losses = []
                for _ in range(10):  # Reduced val batches for speed
                    vx, vy = loader.get_val_batch(cfg.batch_size)
                    # Validation runs with zero steering pressure on every layer.
                    vl, _, _ = model(vx, [(0.0, 0.0)] * cfg.n_layers, vy, return_metrics=False)
                    v_losses.append(vl.item())
                val_loss = np.mean(v_losses)

            # `metrics` holds the per-layer metrics of the last micro-batch.
            recorder.log(step, mean_loss, val_loss, controller.current_lambda, metrics)
            recorder.flush()

            # Update Progress Bar with Adaptive Info
            pbar.set_description(f"L:{mean_loss:.2f}|V:{val_loss:.2f}|P:{controller.current_lambda:.3f}|Sig:{sigma_shown:.4f}")
        else:
            pbar.set_description(f"L:{mean_loss:.2f}|P:{controller.current_lambda:.3f}|Sig:{sigma_shown:.4f}")

        # --- D. CHECKPOINTS ---
        if step > 0 and step % cfg.ckpt_interval == 0:
            ckpt_name = f"ckpt_step_{step}.pt"
            ckpt_path = os.path.join(SAVE_DIR, ckpt_name)
            save_dict = {
                'step': step,
                'model': model.state_dict(),
                'optim': optimizer.state_dict(),
                'current_lambda': controller.current_lambda,  # Save controller state
                'rng_cpu': torch.get_rng_state(),
                # Querying the CUDA RNG fails without a CUDA device; store None then.
                'rng_gpu': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
            }
            torch.save(save_dict, ckpt_path)

            # Keep the three newest checkpoints: one is written per interval,
            # so removing the oldest one is enough.
            all_ckpts = sorted_checkpoints()
            if len(all_ckpts) > 3:
                os.remove(os.path.join(SAVE_DIR, all_ckpts[0]))

    final_path = os.path.join(SAVE_DIR, "janus_adaptive_final.pt")
    torch.save(model.state_dict(), final_path)
    print(f"\nüèÜ ADAPTIVE RUN COMPLETE. Saved to {final_path}")

# Entry point: start the adaptive training run only when this code is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    run_adaptive()