# CODI Model vs NLP Chain-of-Thought Comparison

This notebook demonstrates:
1. **CODI Model**: Compresses chain-of-thought into latent continuous space
2. **NLP-based Model**: Standard model with explicit natural language chain-of-thought

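For the same question, we'll collect last-layer hidden activations from both models: the latent "thought" vector produced at each CODI step, and the hidden states at the step markers of an explicit natural-language chain-of-thought prompt.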

In [None]:
import sys
sys.path.insert(0, 'codi')

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, TaskType, get_peft_model
from huggingface_hub import hf_hub_download
import os

from src.model import CODI, ModelArguments, TrainingArguments

# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device: {}".format(device))

## Setup: Download CODI checkpoint and configure models

In [None]:
# Fetch the CODI-gpt2 weights from the HuggingFace Hub into a local folder.
# The download is best-effort: on any failure `ckpt_path` stays None and the
# notebook continues with a randomly initialized model.
ckpt_dir = "codi_gpt2_ckpt"
os.makedirs(ckpt_dir, exist_ok=True)

ckpt_path = None
try:
    ckpt_path = hf_hub_download(
        repo_id="zen-E/CODI-gpt2",
        filename="model.safetensors",
        local_dir=ckpt_dir,
    )
except Exception as e:
    print(f"Could not download checkpoint: {e}")
    print("Will use randomly initialized model for demonstration")
else:
    print(f"Downloaded checkpoint to: {ckpt_path}")

In [None]:
# Arguments for the CODI wrapper: GPT-2 backbone plus LoRA hyper-parameters.
# `lora_r` and `lora_alpha` are reused when the LoraConfig is built later in
# this notebook.
# NOTE(review): the exact meaning of `lora_init` and `train` is defined by
# src.model.ModelArguments, which is not visible here — confirm there.
model_args = ModelArguments(
    model_name_or_path="gpt2",
    lora_r=128,
    lora_alpha=32,
    lora_init=True,
    train=False,
)

# Lightweight stand-in for the training-arguments object that CODI's
# constructor expects. It is a plain class holding only class-level attributes,
# and an instance of it is passed to CODI(...) in place of a real
# TrainingArguments. Attributes that this notebook reads directly are annotated
# below; the rest are presumably consumed inside src.model.CODI — verify there.
class MinimalTrainingArgs:
    # Also selects bfloat16 (vs float32) when the NLP baseline model is loaded.
    bf16 = True
    # NOTE(review): presumably the number of latent thoughts used in training — confirm.
    num_latent = 6
    use_lora = True
    # When True, the notebook passes each latent through `codi_model.prj`
    # before feeding it back into the model.
    use_prj = True
    # NOTE(review): presumably matches GPT-2's hidden width — confirm.
    prj_dim = 768
    prj_dropout = 0.0
    prj_no_ln = False
    restore_from = ""
    # Number of latent-thought iterations run by the inference loop below.
    inf_latent_iterations = 6
    # When True only the BOT token is appended to the query; otherwise
    # EOS followed by BOT is appended.
    remove_eos = True
    fix_attn_mask = False
    # The remaining fields look like loss settings and are not read in this
    # notebook — presumably only needed so CODI's constructor can run.
    print_loss = False
    distill_loss_div_std = False
    distill_loss_type = "smooth_l1"
    distill_loss_factor = 1.0
    ref_loss_factor = 1.0

training_args = MinimalTrainingArgs()

## Load CODI Model (Latent Thought Model)

In [None]:
# LoRA adapter configuration; rank and alpha come from `model_args` so they
# stay in sync with the values used to describe the checkpoint.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=True,
    r=model_args.lora_r,
    lora_alpha=model_args.lora_alpha,
    lora_dropout=0.1,
    target_modules=["c_attn", "c_proj", "c_fc"],
    init_lora_weights=True,
)

# Build the CODI wrapper (backbone + LoRA adapters, as configured above).
codi_model = CODI(model_args, training_args, lora_config)

# Load the downloaded weights when they are available.
if ckpt_path and os.path.exists(ckpt_path):
    from safetensors.torch import load_file
    state_dict = load_file(ckpt_path)
    # strict=False tolerates key mismatches, so report them explicitly:
    # otherwise a badly matching checkpoint would silently leave (part of)
    # the model at its random/initial weights.
    load_result = codi_model.load_state_dict(state_dict, strict=False)
    missing_keys = list(getattr(load_result, "missing_keys", []) or [])
    unexpected_keys = list(getattr(load_result, "unexpected_keys", []) or [])
    print("Loaded CODI checkpoint")
    if missing_keys:
        print(f"  WARNING: {len(missing_keys)} model tensors were not found in the checkpoint "
              f"and keep their initial values, e.g. {missing_keys[:5]}")
    if unexpected_keys:
        print(f"  WARNING: {len(unexpected_keys)} checkpoint tensors were not used by the model, "
              f"e.g. {unexpected_keys[:5]}")
else:
    print("Using randomly initialized CODI model")

# Use the same precision rule as the NLP baseline so that both models produce
# activations in one dtype and every CODI submodule shares that dtype.
# NOTE(review): assumes CODI may create some submodules (e.g. the projection
# head) in default float32 while the backbone follows `bf16` — confirm in src.model.
codi_dtype = torch.bfloat16 if training_args.bf16 else torch.float32
codi_model = codi_model.to(device=device, dtype=codi_dtype)
codi_model.eval()
print(f"CODI model loaded on {device}")

## Load NLP-based Model (Explicit Chain-of-Thought)

The NLP counterpart is the same base GPT-2, loaded without the CODI wrapper. Instead of latent thoughts, it reads an explicit natural-language chain-of-thought that is supplied as part of the prompt.

In [None]:
# Plain GPT-2 baseline (no latent compression). Precision follows the same
# `bf16` flag that is handed to the CODI model.
nlp_dtype = torch.bfloat16 if training_args.bf16 else torch.float32
nlp_model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=nlp_dtype)

nlp_tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Reuse the EOS token as the padding token.
nlp_tokenizer.pad_token = nlp_tokenizer.eos_token

nlp_model = nlp_model.to(device)
nlp_model.eval()
print(f"NLP model loaded on {device}")

## Define Query and Helper Functions

In [None]:
# GSM8K-style word problem used as the shared input for both models.
query = (
    "Janet's ducks lay 16 eggs per day. "
    "She eats three for breakfast every morning. "
    "How many eggs does she have left each day?"
)

# Hand-written chain-of-thought for the NLP model, assembled line by line.
# Empty entries produce the blank lines of the prompt.
cot_prompt = "\n".join([
    f"Question: {query}",
    "",
    "Let me solve this step by step:",
    "Step 1: Janet's ducks lay 16 eggs per day.",
    "Step 2: She eats 3 eggs for breakfast.",
    "Step 3: Eggs left = 16 - 3 = 13",
    "",
    "The answer is 13.",
])

print("Query:", query)
print("\nCoT Prompt for NLP model:")
print(cot_prompt)

In [None]:
def extract_hidden_states_summary(hidden_states):
    """Summarize a sequence of per-layer hidden-state tensors.

    Args:
        hidden_states: indexable sequence of tensors, one per layer. The last
            entry is indexed as ``[:, -1, :]``, so it must be 3-dimensional.

    Returns:
        dict with the number of layers, the shape of the first layer's tensor,
        the mean and standard deviation of the last layer (Python floats), and
        a detached CPU copy of the last layer at the final sequence position.
    """
    first_layer, final_layer = hidden_states[0], hidden_states[-1]
    return {
        "num_layers": len(hidden_states),
        "shape_per_layer": first_layer.shape,
        "last_layer_mean": final_layer.mean().item(),
        "last_layer_std": final_layer.std().item(),
        "last_token_embedding": final_layer[:, -1, :].detach().cpu(),
    }

def decode_latent_to_tokens(model, latent_embd, tokenizer, top_k=5):
    """Project an embedding through the LM head and list the likeliest tokens.

    Args:
        model: object exposing ``lm_head`` either directly or on a ``codi``
            attribute (the ``codi`` attribute takes precedence when present).
        latent_embd: tensor indexed as (batch, position, hidden); only batch
            element 0 is decoded.
        tokenizer: object with ``decode(list_of_ids) -> str``.
        top_k: number of candidate tokens kept per position.

    Returns:
        A list with one entry per position; each entry is a list of
        ``(token_string, probability)`` tuples ordered from most to least likely.
    """
    # Wrapped models keep the language model (and its head) under `.codi`.
    head_owner = model.codi if hasattr(model, 'codi') else model

    with torch.no_grad():
        vocab_probs = torch.softmax(head_owner.lm_head(latent_embd), dim=-1)
        best_probs, best_ids = torch.topk(vocab_probs, k=top_k, dim=-1)

        per_position = []
        for pos_ids, pos_probs in zip(best_ids[0], best_probs[0]):
            labels = [tokenizer.decode([tok_id.item()]) for tok_id in pos_ids]
            per_position.append(list(zip(labels, pos_probs.tolist())))
    return per_position

## Run CODI Model: Extract Latent Traces and Hidden Activations

In [None]:
# Encode the question with the tokenizer that belongs to the CODI model.
codi_tokenizer = codi_model.tokenizer
inputs = codi_tokenizer(query, return_tensors="pt").to(device)

# Suffix that opens the latent-thought phase: the BOT (beginning of thought)
# token alone, or EOS followed by BOT when `remove_eos` is disabled.
if training_args.remove_eos:
    bot_token_ids = [codi_model.bot_id]
else:
    bot_token_ids = [codi_tokenizer.eos_token_id, codi_model.bot_id]
bot_tensor = torch.tensor([bot_token_ids], dtype=torch.long, device=device)

# Append the suffix to both the ids and the attention mask (all appended
# positions are attended to).
input_ids = torch.cat([inputs["input_ids"], bot_tensor], dim=1)
attention_mask = torch.cat(
    [inputs["attention_mask"], torch.ones_like(bot_tensor)],
    dim=1,
)

decoded_input = codi_tokenizer.decode(input_ids[0])
print(f"Input tokens: {decoded_input}")
print(f"Input shape: {input_ids.shape}")

In [None]:
# Per-step records. Index 0 is the encoding pass over the prompt (read at the
# BOT position); every later index is one latent-thought iteration.
codi_latent_traces, codi_hidden_activations, codi_decoded_latents = [], [], []

# One prompt-encoding pass plus the configured number of latent iterations.
num_passes = 1 + max(training_args.inf_latent_iterations, 0)

with torch.no_grad():
    past_key_values = None
    for pass_idx in range(num_passes):
        if pass_idx == 0:
            # Encode the full prompt and start the key/value cache.
            outputs = codi_model.codi(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=True,
                output_hidden_states=True,
            )
        else:
            # Feed the previous latent back in as the next input embedding,
            # continuing from the cached context.
            outputs = codi_model.codi(
                inputs_embeds=latent_embd,
                use_cache=True,
                output_hidden_states=True,
                past_key_values=past_key_values,
            )
        past_key_values = outputs.past_key_values

        # Last-layer state at the final position, kept as (batch, 1, hidden).
        latent_embd = outputs.hidden_states[-1][:, -1, :].unsqueeze(1)

        # Record the raw (pre-projection) latent and its diagnostics.
        codi_latent_traces.append(latent_embd.clone())
        codi_hidden_activations.append(extract_hidden_states_summary(outputs.hidden_states))
        codi_decoded_latents.append(decode_latent_to_tokens(codi_model, latent_embd, codi_tokenizer))

        # The projection is applied only to what is fed to the next pass.
        if training_args.use_prj:
            latent_embd = codi_model.prj(latent_embd)

print(f"Collected {len(codi_latent_traces)} latent traces")
print(f"Latent embedding shape: {codi_latent_traces[0].shape}")

In [None]:
# Print, for every recorded step, the most likely vocabulary tokens of its latent.
banner = "=" * 60
print(banner)
print("CODI LATENT TRACES (Decoded to top-5 tokens)")
print(banner)

for step_idx, step_decoding in enumerate(codi_decoded_latents):
    if step_idx == 0:
        step_name = "Initial (BOT)"
    else:
        step_name = f"Latent Step {step_idx}"
    print(f"\n{step_name}:")
    # Each decoding holds one entry per position; only one position was decoded.
    first_position = step_decoding[0]
    print(f"  Top tokens: {first_position[:5]}")

In [None]:
# Print the summary statistics captured at every CODI step.
banner = "=" * 60
print(banner)
print("CODI HIDDEN ACTIVATIONS")
print(banner)

for step_idx, stats in enumerate(codi_hidden_activations):
    step_name = f"Latent Step {step_idx}" if step_idx else "Initial"
    last_token_norm = stats['last_token_embedding'].norm().item()
    print(f"\n{step_name}:")
    print(f"  Num layers: {stats['num_layers']}")
    print(f"  Last layer mean: {stats['last_layer_mean']:.4f}")
    print(f"  Last layer std: {stats['last_layer_std']:.4f}")
    print(f"  Last token embedding norm: {last_token_norm:.4f}")

## Run NLP Model: Extract Chain-of-Thought Traces and Hidden Activations

In [None]:
# Encode the complete chain-of-thought prompt and move it to the model's device.
nlp_inputs = nlp_tokenizer(cot_prompt, return_tensors="pt")
nlp_inputs = nlp_inputs.to(device)

nlp_prompt_len = nlp_inputs['input_ids'].shape[1]
print(f"NLP input length: {nlp_prompt_len} tokens")

In [None]:
# Single forward pass over the explicit CoT prompt, keeping every layer's
# hidden states for later inspection.
with torch.no_grad():
    nlp_outputs = nlp_model(
        output_hidden_states=True,
        attention_mask=nlp_inputs["attention_mask"],
        input_ids=nlp_inputs["input_ids"],
    )

nlp_hidden_summary = extract_hidden_states_summary(nlp_outputs.hidden_states)
print("NLP model hidden states collected")
for label, key in (("Num layers", "num_layers"), ("Shape per layer", "shape_per_layer")):
    print(f"  {label}: {nlp_hidden_summary[key]}")

In [None]:
# Collect the last-layer hidden state at the end of each chain-of-thought
# marker in the prompt. One dict per marker that is found is appended to
# `nlp_traces`, in the order of `cot_steps`.
cot_steps = [
    "Question:",
    "Step 1:",
    "Step 2:",
    "Step 3:",
    "The answer is"
]

nlp_traces = []
# Search the ids that were actually fed to the model, so every position found
# here indexes `nlp_outputs.hidden_states` exactly.
tokens = nlp_inputs["input_ids"][0].tolist()
last_layer = nlp_outputs.hidden_states[-1]

print("=" * 60)
print("NLP CHAIN-OF-THOUGHT TRACES")
print("=" * 60)

for step in cot_steps:
    # Token ids of the marker on its own (no special tokens).
    step_tokens = nlp_tokenizer.encode(step, add_special_tokens=False)
    span = len(step_tokens)

    # First occurrence of the marker's id sequence in the prompt.
    for pos in range(len(tokens) - span + 1):
        if tokens[pos:pos + span] == step_tokens:
            # `pos` is where the marker starts; the hidden state is read at
            # the marker's last token.
            hidden = last_layer[0, pos + span - 1, :]
            hidden_norm = hidden.norm().item()
            hidden_mean = hidden.mean().item()
            hidden_std = hidden.std().item()
            nlp_traces.append({
                "step": step,
                "position": pos,
                "hidden_norm": hidden_norm,
                "hidden_mean": hidden_mean,
                "hidden_std": hidden_std,
                "embedding": hidden.cpu()
            })
            print(f"\n{step} (position {pos}):")
            print(f"  Hidden norm: {hidden_norm:.4f}")
            print(f"  Hidden mean: {hidden_mean:.4f}")
            print(f"  Hidden std: {hidden_std:.4f}")
            break
    else:
        # Make a missing marker visible: otherwise it would silently vanish
        # from `nlp_traces` and from every plot built on it.
        print(f"\n{step}: marker not found in the tokenized prompt; skipped")

## Compare CODI vs NLP Hidden Activations

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Side-by-side bar charts of embedding norms: CODI latent steps on the left,
# explicit chain-of-thought markers on the right.
fig, (codi_ax, nlp_ax) = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: norm of each recorded CODI latent.
codi_norms = [trace.norm().item() for trace in codi_latent_traces]
codi_ticks = range(len(codi_norms))
codi_ax.bar(codi_ticks, codi_norms, color='blue', alpha=0.7)
codi_ax.set_xlabel('Latent Step')
codi_ax.set_ylabel('Embedding Norm')
codi_ax.set_title('CODI: Latent Embedding Norms')
codi_ax.set_xticks(codi_ticks)
codi_ax.set_xticklabels(['Init'] + [f'L{k}' for k in range(1, len(codi_norms))])

# Right panel: drawn only when at least one CoT marker was located.
if nlp_traces:
    nlp_norms = [trace['hidden_norm'] for trace in nlp_traces]
    nlp_labels = [trace['step'][:10] for trace in nlp_traces]
    nlp_ticks = range(len(nlp_norms))
    nlp_ax.bar(nlp_ticks, nlp_norms, color='green', alpha=0.7)
    nlp_ax.set_xlabel('CoT Step')
    nlp_ax.set_ylabel('Embedding Norm')
    nlp_ax.set_title('NLP: Chain-of-Thought Embedding Norms')
    nlp_ax.set_xticks(nlp_ticks)
    nlp_ax.set_xticklabels(nlp_labels, rotation=45, ha='right')

plt.tight_layout()
plt.show()

In [None]:
# Heatmaps of the first 50 hidden dimensions: one row per CODI latent step
# (left) and one row per located chain-of-thought marker (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# CODI hidden states (first 50 dimensions).
# The traces may be bfloat16 (the models are loaded in bf16 when
# `training_args.bf16` is set), and NumPy conversion does not support that
# dtype, so cast to float32 before calling .numpy().
codi_hidden_matrix = (
    torch.stack([t.squeeze() for t in codi_latent_traces]).float().cpu().numpy()
)
im1 = axes[0].imshow(codi_hidden_matrix[:, :50], aspect='auto', cmap='viridis')
axes[0].set_xlabel('Hidden Dimension (first 50)')
axes[0].set_ylabel('Latent Step')
axes[0].set_title('CODI: Latent Activations Evolution')
axes[0].set_yticks(range(len(codi_latent_traces)))
axes[0].set_yticklabels(['Init'] + [f'L{i+1}' for i in range(len(codi_latent_traces)-1)])
plt.colorbar(im1, ax=axes[0])

# NLP hidden states at CoT positions (skipped when no marker was located).
if nlp_traces:
    # Same float32 cast as above: the stored embeddings keep the model's dtype.
    nlp_hidden_matrix = (
        torch.stack([t['embedding'] for t in nlp_traces]).float().cpu().numpy()
    )
    im2 = axes[1].imshow(nlp_hidden_matrix[:, :50], aspect='auto', cmap='viridis')
    axes[1].set_xlabel('Hidden Dimension (first 50)')
    axes[1].set_ylabel('CoT Step')
    axes[1].set_title('NLP: Chain-of-Thought Activations')
    axes[1].set_yticks(range(len(nlp_traces)))
    axes[1].set_yticklabels([t['step'][:10] for t in nlp_traces])
    plt.colorbar(im2, ax=axes[1])

plt.tight_layout()
plt.show()

## Summary: Raw Data Access

In [None]:
# Overview of every collection built above and how to index into it.
banner = "=" * 60
report_lines = [
    banner,
    "SUMMARY: Available Data",
    banner,
    "\n1. CODI Latent Traces:",
    f"   - Number of traces: {len(codi_latent_traces)}",
    f"   - Shape per trace: {codi_latent_traces[0].shape}",
    "   - Access: codi_latent_traces[i] for raw tensor",
    "\n2. CODI Hidden Activations:",
    f"   - Number of snapshots: {len(codi_hidden_activations)}",
    f"   - Layers per snapshot: {codi_hidden_activations[0]['num_layers']}",
    "   - Access: codi_hidden_activations[i]['last_token_embedding']",
    "\n3. CODI Decoded Latents (token interpretations):",
    "   - Access: codi_decoded_latents[i] for top-k tokens",
    "\n4. NLP Chain-of-Thought Traces:",
    f"   - Number of CoT steps captured: {len(nlp_traces)}",
    "   - Access: nlp_traces[i]['embedding'] for hidden state",
    "   - Access: nlp_traces[i]['step'] for step name",
    "\n5. Full NLP Hidden States:",
    "   - Access: nlp_outputs.hidden_states for all layers",
    f"   - Shape: {nlp_outputs.hidden_states[-1].shape}",
]
for report_line in report_lines:
    print(report_line)

In [None]:
# Example: read raw embeddings back out of the collected traces.
# Both look-ups are guarded so the cell still runs when fewer latent steps
# were configured or when a chain-of-thought marker was not located.
example_latent_step = 3
print("Example CODI latent (step 3):")
if len(codi_latent_traces) > example_latent_step:
    print(codi_latent_traces[example_latent_step].squeeze()[:10])  # First 10 dimensions
else:
    print(f"  (unavailable: only {len(codi_latent_traces)} latent traces were collected)")

print("\nExample NLP hidden state (Step 2):")
# Select the trace by its step name rather than by list position: a skipped
# marker shifts the positions, so index 2 is not guaranteed to be "Step 2:".
step2_trace = next((t for t in nlp_traces if t['step'] == "Step 2:"), None)
if step2_trace is not None:
    print(step2_trace['embedding'][:10])  # First 10 dimensions
else:
    print("  (unavailable: no trace was captured for 'Step 2:')")