# TNAD Demo: Tensor Network-Augmented Decoding

This notebook demonstrates the **Fidelity-Guided Beam Search (FGBS)** algorithm for improving logical coherence in LLM reasoning.

## Overview

TNAD uses quantum-inspired tensor networks to monitor and enforce structural coherence during text generation:
- **MPS (Matrix Product State)**: Tensor network representation of token sequences
- **CFS (Coherence Fidelity Score)**: Real-time coherence metric computed from the Schmidt values of the MPS
- **FGBS**: Beam search that balances LLM fluency with structural integrity

## Setup

In [None]:
# Install TNAD package if not already installed
# !pip install -e ..

import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from transformers import AutoModelForCausalLM, AutoTokenizer
from tnad import FidelityGuidedBeamSearcher, MPSSequence, compute_cfs
from tnad.utils import get_device

# Set style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print("✓ Imports successful")

## 1. Understanding MPS and CFS

First, build an MPS from random token embeddings and track how the CFS changes as each token is added.
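
The plot in the next cell draws reference lines at 1 and at χ = 16. One quantity with exactly that range is the inverse participation ratio of the normalized squared Schmidt values, and the sketch below uses it purely as an illustration. Whether `compute_cfs` implements this formula is an assumption, and `cfs_sketch` is a hypothetical helper that is not part of `tnad`.

```python
def cfs_sketch(schmidt_values):
    """Inverse participation ratio of the normalized squared Schmidt values.

    Hypothetical stand-in for tnad.compute_cfs; the real formula may differ.
    """
    weights = [s * s for s in schmidt_values]
    total = sum(weights)
    if total == 0:
        raise ValueError("at least one Schmidt value must be non-zero")
    probs = [w / total for w in weights]
    return 1.0 / sum(p * p for p in probs)


# One dominant Schmidt value: all weight sits in a single component.
one_dominant = cfs_sketch([1.0] + [0.0] * 15)

# Sixteen equal Schmidt values: weight is spread evenly over chi = 16 components.
uniform = cfs_sketch([0.25] * 16)

print(f"one dominant value: {one_dominant:.2f}")
print(f"uniform spectrum:   {uniform:.2f}")
```

Under this assumed definition, a spectrum concentrated in one component gives the lower bound of 1 and a perfectly flat spectrum gives the upper bound χ.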

In [None]:
# Create an MPS with small bond dimension
mps = MPSSequence(bond_dim=16, embedding_dim=64)

# Add some random token embeddings
cfs_values = []
for i in range(20):
    # Simulate token embedding
    token_embedding = torch.randn(64)
    mps.add_token(token_embedding)
    
    # Compute CFS
    schmidt_values = mps.get_schmidt_values()
    cfs = compute_cfs(schmidt_values)
    cfs_values.append(cfs)

# Visualize CFS evolution
plt.figure(figsize=(10, 5))
plt.plot(range(1, len(cfs_values) + 1), cfs_values, marker='o', linewidth=2)
plt.xlabel('Token Position', fontsize=12)
plt.ylabel('Coherence Fidelity Score (CFS)', fontsize=12)
plt.title('CFS Evolution as Tokens are Added', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.axhline(y=1, color='r', linestyle='--', label='Minimum Coherence', alpha=0.5)
plt.axhline(y=16, color='g', linestyle='--', label='Maximum Coherence (χ=16)', alpha=0.5)
plt.legend()
plt.tight_layout()
plt.show()

print(f"Final MPS length: {mps.get_current_length()}")
print(f"Final CFS: {cfs_values[-1]:.2f}")
print(f"CFS range: [{min(cfs_values):.2f}, {max(cfs_values):.2f}]")

## 2. Loading a Language Model

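Load a model for the demonstration. GPT-2 is small enough to run on CPU; the larger instruction-tuned models need a GPU. The cell below defaults to Llama-3.1-8B-Instruct, so switch `model_name` to `"gpt2"` if you have no GPU or have not been granted access.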
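
As a rough guide to why the larger models need a GPU, the sketch below estimates the memory taken by the weights alone. The parameter counts are approximate (about 8 billion for an 8B model, about 124 million for GPT-2 small), `weight_memory_gib` is a hypothetical helper, and activations and the KV cache add to these figures.

```python
def weight_memory_gib(num_params, bytes_per_param):
    """Approximate memory needed to hold the weights alone, in GiB."""
    return num_params * bytes_per_param / 2**30


# Approximate parameter counts, used only for a back-of-the-envelope estimate.
models = [("8B model", 8e9), ("gpt2", 124e6)]

for name, params in models:
    fp16 = weight_memory_gib(params, 2)  # float16 stores 2 bytes per parameter
    fp32 = weight_memory_gib(params, 4)  # float32 stores 4 bytes per parameter
    print(f"{name}: {fp16:.1f} GiB in float16, {fp32:.1f} GiB in float32")
```

An 8B model needs about 15 GiB for its float16 weights, which is why the loading cell requests `torch.float16`.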

In [None]:
# Choose your model (uncomment one)
# model_name = "gpt2"  # Small, runs on CPU
model_name = "meta-llama/Llama-3.1-8B-Instruct"  # Requires GPU and access approval
# model_name = "mistralai/Mistral-7B-Instruct-v0.3"  # Alternative

print(f"Loading model: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
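model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # float16 halves memory on GPU; fall back to float32 on CPU, where half precision is slow
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",  # Automatically place on available device
)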

# Set padding token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

device = get_device()
print(f"✓ Model loaded on {device}")

## 3. Initialize FGBS Searcher

Create the Fidelity-Guided Beam Search instance with chosen hyperparameters.
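
Two of the hyperparameters, `top_k` and `temperature`, control which next-token candidates each beam considers. The sketch below shows the standard technique on made-up logits: divide the logits by the temperature, apply a log-softmax, and keep the `top_k` most likely tokens. How `FidelityGuidedBeamSearcher` applies these settings internally is an assumption, and `top_k_log_probs` is a hypothetical helper.

```python
import math


def top_k_log_probs(logits, top_k, temperature=1.0):
    """Return (token_id, log_prob) pairs for the top_k tokens after temperature scaling."""
    scaled = [x / temperature for x in logits]
    # Log-sum-exp with the maximum subtracted for numerical stability.
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    log_probs = [x - log_z for x in scaled]
    ranked = sorted(range(len(logits)), key=lambda i: log_probs[i], reverse=True)
    return [(i, log_probs[i]) for i in ranked[:top_k]]


toy_logits = [2.0, 0.5, 1.0, -1.0]  # made-up values for a 4-token vocabulary
candidates = top_k_log_probs(toy_logits, top_k=2, temperature=1.0)

for token_id, log_prob in candidates:
    print(f"token {token_id}: log P = {log_prob:.3f}")
```

With `temperature=1.0` the logits are used unchanged; values below 1 sharpen the distribution and values above 1 flatten it.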

In [None]:
# Initialize FGBS
searcher = FidelityGuidedBeamSearcher(
    model=model,
    tokenizer=tokenizer,
    beam_width=5,           # B = 5 parallel beams
    alpha=0.5,              # Equal weight to fluency and coherence
    bond_dim=16,            # χ = 16 (moderate logical bandwidth)
    top_k=50,               # Consider top-50 tokens per beam
    temperature=1.0,        # No temperature scaling
    device=device,
)

print("✓ FGBS Searcher initialized")
print(f"  Beam width: {searcher.beam_width}")
print(f"  Alpha (α): {searcher.alpha}")
print(f"  Bond dimension (χ): {searcher.bond_dim}")

## 4. Example: Math Reasoning

Test FGBS on a simple math problem requiring multi-step reasoning.
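
For reference when reading the generated solution, the expected answer to the problem in the next cell can be checked directly from the numbers in the prompt:

```python
eggs_laid = 16
eggs_eaten = 3       # breakfast
eggs_baked = 4       # muffins
price_per_egg = 2    # dollars

eggs_sold = eggs_laid - eggs_eaten - eggs_baked
daily_income = eggs_sold * price_per_egg

print(f"Eggs sold: {eggs_sold}, daily income: ${daily_income}")
```

A coherent chain of reasoning should reach 9 eggs sold and $18 per day.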

In [None]:
# Math problem
prompt = """Q: Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?
A: Let's think step by step."""

# Generate with FGBS
print("Generating with FGBS...")
result = searcher.generate(
    prompt,
    max_length=200,
    min_length=50,
    return_details=True,
    show_progress=True,
)

print("\n" + "="*80)
print("GENERATED SOLUTION:")
print("="*80)
print(result['text'])
print("\n" + "="*80)
print(f"Final Log Probability: {result['log_prob']:.2f}")
print(f"Final Log CFS: {result['log_cfs']:.2f}")
print(f"Final CFS: {np.exp(result['log_cfs']):.2f}")
print(f"Composite Score: {result['composite_score']:.2f}")
print("="*80)

## 5. Visualize Generation Dynamics

Plot how the CFS and the composite score evolve during generation.

In [None]:
# Extract trajectories
cfs_trajectory = result['cfs_trajectory']
score_trajectory = result['score_trajectory']

# Create subplots
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot CFS trajectory
axes[0].plot(cfs_trajectory, linewidth=2, color='#2E86AB')
axes[0].set_xlabel('Generation Step', fontsize=12)
axes[0].set_ylabel('Coherence Fidelity Score', fontsize=12)
axes[0].set_title('CFS Evolution During Generation', fontsize=14, fontweight='bold')
axes[0].grid(True, alpha=0.3)
axes[0].axhline(y=1, color='red', linestyle='--', alpha=0.5, label='Min Coherence')
axes[0].legend()

# Plot composite score trajectory
axes[1].plot(score_trajectory, linewidth=2, color='#A23B72')
axes[1].set_xlabel('Generation Step', fontsize=12)
axes[1].set_ylabel('Composite Score', fontsize=12)
axes[1].set_title('Composite Score (α·log P + (1-α)·log F)', fontsize=14, fontweight='bold')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Average CFS: {np.mean(cfs_trajectory):.2f}")
print(f"CFS std dev: {np.std(cfs_trajectory):.2f}")
print(f"Min CFS: {np.min(cfs_trajectory):.2f}")
print(f"Max CFS: {np.max(cfs_trajectory):.2f}")

## 6. Compare with Baseline

Compare FGBS (α=0.5) with standard beam search (α=1.0).

In [None]:
# Simple prompt for comparison
compare_prompt = "If A > B and B > C, then"

print("Running comparison...")
comparison = searcher.compare_with_baseline(
    compare_prompt,
    max_length=50,
)

print("\n" + "="*80)
print("FGBS OUTPUT (α=0.5):")
print("="*80)
print(comparison['fgbs']['text'])
print(f"\nFinal CFS: {np.exp(comparison['fgbs']['log_cfs']):.2f}")

print("\n" + "="*80)
print("BASELINE OUTPUT (Standard Beam Search, α=1.0):")
print("="*80)
print(comparison['baseline']['text'])
print(f"\nFinal CFS: {np.exp(comparison['baseline']['log_cfs']):.2f}")

print("\n" + "="*80)
print("COMPARISON:")
print("="*80)
cfs_comp = comparison['cfs_comparison']
print(f"FGBS CFS: {cfs_comp['fgbs_final_cfs']:.2f}")
print(f"Baseline CFS: {cfs_comp['baseline_final_cfs']:.2f}")
print(f"CFS Improvement: {cfs_comp['cfs_improvement']:.2f}")
print("="*80)

## 7. Ablation Study: Effect of Alpha

Sweep α to see the fluency-coherence trade-off: α=0 scores beams on coherence alone, while α=1 scores them on LLM log-probability alone (standard beam search).
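
To build intuition before running the model, the sketch below applies the composite score from the plot title in section 5, α·log P + (1-α)·log F, to two made-up candidates. The candidate values and the helpers `composite_score` and `best_beam` are illustrative assumptions, not output from or part of `tnad`.

```python
import math


def composite_score(log_prob, cfs, alpha):
    """alpha * log P + (1 - alpha) * log F, where F is the CFS."""
    return alpha * log_prob + (1 - alpha) * math.log(cfs)


# Made-up candidates: (label, log P, CFS)
toy_beams = [
    ("fluent", -2.0, 1.5),     # likely under the LLM, low coherence score
    ("coherent", -6.0, 12.0),  # less likely under the LLM, high coherence score
]


def best_beam(alpha):
    """Label of the candidate with the highest composite score."""
    return max(toy_beams, key=lambda b: composite_score(b[1], b[2], alpha))[0]


for alpha in [0.0, 0.3, 0.5, 0.7, 1.0]:
    print(f"alpha={alpha:.1f}: best candidate is '{best_beam(alpha)}'")
```

For these toy numbers the preferred candidate flips from "coherent" to "fluent" between α=0.3 and α=0.5; where the crossover falls for a real model depends on the actual log-probabilities and CFS values.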

In [None]:
# Test different alpha values
alphas = [0.0, 0.3, 0.5, 0.7, 1.0]
test_prompt = "The logical conclusion is"

alpha_results = []

for alpha in alphas:
    print(f"\nTesting α={alpha}...")
    
    # Create searcher with specific alpha
    test_searcher = FidelityGuidedBeamSearcher(
        model=model,
        tokenizer=tokenizer,
        beam_width=3,
        alpha=alpha,
        bond_dim=16,
        device=device,
    )
    
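    # Generate (stored under a separate name so the section 4 `result`,
    # which holds the trajectories plotted in section 5, is not overwritten)
    alpha_result = test_searcher.generate(
        test_prompt,
        max_length=40,
        show_progress=False,
    )

    alpha_results.append({
        'alpha': alpha,
        'text': alpha_result['text'],
        'log_prob': alpha_result['log_prob'],
        'cfs': np.exp(alpha_result['log_cfs']),
    })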

# Visualize results
alphas_list = [r['alpha'] for r in alpha_results]
cfs_list = [r['cfs'] for r in alpha_results]
logprob_list = [r['log_prob'] for r in alpha_results]

fig, ax1 = plt.subplots(figsize=(10, 6))

# Plot CFS
color = '#2E86AB'
ax1.set_xlabel('Alpha (α)', fontsize=12)
ax1.set_ylabel('Coherence Fidelity Score', fontsize=12, color=color)
ax1.plot(alphas_list, cfs_list, marker='o', linewidth=2, color=color, label='CFS')
ax1.tick_params(axis='y', labelcolor=color)
ax1.grid(True, alpha=0.3)

# Plot log probability on second y-axis
ax2 = ax1.twinx()
color = '#A23B72'
ax2.set_ylabel('Log Probability', fontsize=12, color=color)
ax2.plot(alphas_list, logprob_list, marker='s', linewidth=2, color=color, label='Log P')
ax2.tick_params(axis='y', labelcolor=color)

plt.title('Effect of Alpha on Coherence vs Fluency', fontsize=14, fontweight='bold')
fig.tight_layout()
plt.show()

# Print results
print("\n" + "="*80)
print("ALPHA ABLATION RESULTS:")
print("="*80)
for r in alpha_results:
    print(f"\nα={r['alpha']:.1f}: CFS={r['cfs']:.2f}, Log P={r['log_prob']:.2f}")
    print(f"  Text: {r['text'][:100]}...")

## 8. Custom Prompt Testing

Try your own prompts!

In [None]:
# Your custom prompt
custom_prompt = """Q: If all cats are mammals, and all mammals are animals, what can we conclude about cats?
A: Let's reason step by step."""

# Generate
custom_result = searcher.generate(
    custom_prompt,
    max_length=150,
    return_details=True,
)

print("\n" + "="*80)
print("CUSTOM PROMPT RESULT:")
print("="*80)
print(custom_result['text'])
print("\n" + "="*80)
print(f"Final CFS: {np.exp(custom_result['log_cfs']):.2f}")
print(f"Generation length: {len(custom_result['token_ids'])} tokens")
print("="*80)

## Summary

This notebook demonstrated:

1. **MPS Construction**: How tensor networks represent token sequences
2. **CFS Computation**: Real-time coherence monitoring
3. **FGBS Generation**: Balancing fluency and structural integrity
4. **Baseline Comparison**: FGBS (α=0.5) versus standard beam search (α=1.0)
5. **Ablation Study**: Effect of the fluency-coherence weight α

### Key Insights:
- A higher CFS is intended to signal more coherent reasoning
- α=0.5 weights fluency and coherence equally and is a reasonable starting point
- FGBS can improve logical consistency in multi-step reasoning; check this against the baseline on your own prompts

### Next Steps:
- Run full GSM8K benchmark: `python experiments/run_gsm8k.py`
- Experiment with different models (Llama, Mistral, etc.)
- Tune hyperparameters (χ, α, B) for your use case
- Implement custom coherence metrics for your domain