In [23]:
# ============================================================
# 1. SETUP AND CONFIGURATION
# ============================================================
import torch
import torch.nn.functional as F
import numpy as np
import math
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from datasets import load_dataset
from tqdm.auto import tqdm

# Configuration - TUNED PARAMETERS
class Config:
    SEED = 42
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MODEL_NAME = "gpt2"

    TARGET_LAYER = 4

    # STABILITY FIXES:
    # 1. Reduce steps if required
    N_LANGEVIN_STEPS = 5
    # 2. Smaller step size to prevent hallucination
    STEP_SIZE = 0.001
    # 3. Reduce noise to keep focus
    NOISE_SCALE = 0.0005

    # Increase Simplification weight slightly to compensate for smaller steps
    LAMBDA_SIMPLIFY = 1.0
    # Explicit Preservation weight to anchor the meaning
    LAMBDA_PRESERVE = 50.0

    MAX_TOKENS = 40
    TOP_K = 50

# Set Seeds
torch.manual_seed(Config.SEED)
np.random.seed(Config.SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(Config.SEED)

# Load Resources
tokenizer = GPT2Tokenizer.from_pretrained(Config.MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(Config.MODEL_NAME).to(Config.DEVICE)
model.eval()
tokenizer.pad_token = tokenizer.eos_token
ds_test = load_dataset("asset", "simplification", split="test")

print("✓ Config updated for stability.")

✓ Config updated for stability.


In [2]:
# ============================================================
# 2. MODEL AND DATA LOADING
# ============================================================

def load_resources():
    print("Loading model and tokenizer...")
    tokenizer = GPT2Tokenizer.from_pretrained(Config.MODEL_NAME)
    model = GPT2LMHeadModel.from_pretrained(Config.MODEL_NAME).to(Config.DEVICE)
    model.eval()
    tokenizer.pad_token = tokenizer.eos_token

    print("Loading ASSET dataset...")
    # Using 'test' for evaluation, 'validation' could be used for tuning
    ds = load_dataset("asset", "simplification", split="test")

    return model, tokenizer, ds

model, tokenizer, ds_test = load_resources()

Loading model and tokenizer...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Loading ASSET dataset...


README.md: 0.00B [00:00, ?B/s]

simplification/validation-00000-of-00001(…):   0%|          | 0.00/885k [00:00<?, ?B/s]

simplification/test-00000-of-00001.parqu(…):   0%|          | 0.00/170k [00:00<?, ?B/s]

Generating validation split:   0%|          | 0/2000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/359 [00:00<?, ? examples/s]

In [3]:
# ============================================================
# 3. CORE UTILITIES: LAYER INTERVENTION
# ============================================================

def run_model_to_layer(input_ids: torch.Tensor, target_layer: int):
    """
    Forward pass through GPT-2 up to a specific layer.

    Args:
        input_ids: Token indices
        target_layer: The layer index (0-based) to stop at.

    Returns:
        hidden_states: State at target_layer [batch, seq_len, hidden_dim]
    """
    with torch.no_grad():
        # 1. Embeddings
        inputs_embeds = model.transformer.wte(input_ids)
        position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=Config.DEVICE)
        position_embeds = model.transformer.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds

        # 2. Run through Transformer blocks
        for i, block in enumerate(model.transformer.h[:target_layer + 1]):
            outputs = block(hidden_states)
            hidden_states = outputs[0]

    return hidden_states

def continue_from_layer(hidden_states: torch.Tensor, start_layer: int):
    """
    Resume GPT-2 forward pass from a specific layer.

    Args:
        hidden_states: The (potentially modified) hidden states.
        start_layer: The layer index where processing resumes.

    Returns:
        logits: Output vocabulary logits.
    """
    with torch.no_grad():
        # 1. Remaining Transformer blocks
        for i, block in enumerate(model.transformer.h[start_layer + 1:]):
            outputs = block(hidden_states)
            hidden_states = outputs[0]

        # 2. Final Norm and Head
        hidden_states = model.transformer.ln_f(hidden_states)
        logits = model.lm_head(hidden_states)

    return logits


In [4]:
# ============================================================
# 4. REFERENCE EMBEDDINGS (Average Pooling)
# ============================================================

def build_reference_embeddings(sentences: List[str], layer: int) -> torch.Tensor:
    """
    Constructs reference embeddings using Average Pooling.

    Per proposal: Extract hidden states at Layer 4 and average pool
    across the sequence dimension to get a sentence-level representation.
    """
    embeddings = []
    print(f"Building reference embeddings (Layer {layer}, Avg Pool)...")

    for text in sentences:
        enc = tokenizer(text, return_tensors="pt", truncation=True, max_length=50).to(Config.DEVICE)

        # Get hidden states
        hidden_states = run_model_to_layer(enc["input_ids"], target_layer=layer)

        # Average Pooling: [1, seq_len, dim] -> [dim]
        avg_pooled = hidden_states[0].mean(dim=0)
        embeddings.append(avg_pooled.cpu())

    # Stack: [K, hidden_dim]
    ref_tensor = torch.stack(embeddings).to(Config.DEVICE)
    print(f"✓ Created {ref_tensor.shape[0]} reference embeddings.")
    return ref_tensor

# Proposal-defined simple sentences
SIMPLE_SENTENCES = [
    "The heart is weak.",
    "The patient has a problem.",
    "This is a simple test."
]

reference_embeddings = build_reference_embeddings(SIMPLE_SENTENCES, layer=Config.TARGET_LAYER)


Building reference embeddings (Layer 4, Avg Pool)...
✓ Created 3 reference embeddings.


In [16]:
# ============================================================
# 5. LANGEVIN DYNAMICS
# ============================================================

def compute_proposal_loss(
    current_embedding: torch.Tensor,
    original_embedding: torch.Tensor,
    reference_embeddings: torch.Tensor,
    lambda_simplify: float,
    lambda_preserve: float
) -> torch.Tensor:

    # Term 1: Preservation (Anchor to original meaning)
    # L2 distance squared
    preservation_loss = torch.sum((current_embedding - original_embedding) ** 2)

    # Term 2: Simplification (Cosine Distance)
    # Normalize for cosine calculation
    curr_norm = F.normalize(current_embedding, p=2, dim=0)
    refs_norm = F.normalize(reference_embeddings, p=2, dim=1)

    # Cosine similarity: [K_refs]
    cos_sims = torch.matmul(refs_norm, curr_norm)

    # Loss = Sum(1 - similarity)
    simplification_loss = torch.sum(1.0 - cos_sims)

    # Weighted Sum
    return (lambda_preserve * preservation_loss) + (lambda_simplify * simplification_loss)

def langevin_optimization(
    original_embedding: torch.Tensor,
    reference_embeddings: torch.Tensor,
    n_steps: int,
    step_size: float,
    noise_scale: float,
    lambda_simplify: float,
    lambda_preserve: float = 10.0 # High default to prevent hallucinations
) -> torch.Tensor:

    # Initialize as copy of original (requires grad)
    current_embedding = original_embedding.clone().detach().requires_grad_(True)

    for i in range(n_steps):
        # 1. Compute Loss
        loss = compute_proposal_loss(
            current_embedding,
            original_embedding,
            reference_embeddings,
            lambda_simplify,
            lambda_preserve
        )

        # 2. Compute Gradient
        loss.backward()

        with torch.no_grad():
            grad = current_embedding.grad

            if grad is not None:
                # STABILITY FIX: Clip gradients to prevent "exploding" updates
                torch.nn.utils.clip_grad_norm_([current_embedding], max_norm=1.0)

                # Anneal step size (reduce slightly over time)
                current_step_size = step_size * (1 - i / (n_steps * 1.5))

                # 3. Langevin Update
                # z_{t+1} = z_t - ε∇E + √(2ε)η
                noise = torch.randn_like(current_embedding) * math.sqrt(2 * current_step_size * noise_scale)
                current_embedding = current_embedding - (current_step_size * grad) + noise

            # Reset for next iteration
            current_embedding = current_embedding.detach().requires_grad_(True)

    return current_embedding.detach()

In [26]:
# ============================================================
# 6. LATENT COLD GENERATION LOOP
# ============================================================

def latent_cold_generate(
    prompt: str,
    ref_embeddings: torch.Tensor,
    **kwargs
) -> str:
    """
    Main generation function.
    Accepts kwargs to allow Ablation Studies to override defaults.
    """
    # Unpack config or use defaults
    layer_idx = kwargs.get('layer_idx', Config.TARGET_LAYER)
    n_steps = kwargs.get('n_langevin_steps', Config.N_LANGEVIN_STEPS)
    step_size = kwargs.get('langevin_step_size', Config.STEP_SIZE)
    noise_scale = kwargs.get('noise_scale', Config.NOISE_SCALE)

    # Lambda parameters
    lambda_simp = kwargs.get('lambda_simplify', Config.LAMBDA_SIMPLIFY)
    # Allow overriding preserve weight, else default to stabilized high value
    lambda_pres = kwargs.get('lambda_preserve', Config.LAMBDA_PRESERVE)

    max_tokens = kwargs.get('max_tokens', Config.MAX_TOKENS)
    top_k = kwargs.get('top_k', Config.TOP_K)

    # Prepare Input
    input_text = f"Simplify: {prompt}\nSimple:"
    enc = tokenizer(input_text, return_tensors="pt").to(Config.DEVICE)
    input_ids = enc["input_ids"]

    for _ in range(max_tokens):
        # 1. Forward to Layer L
        hidden_states = run_model_to_layer(input_ids, target_layer=layer_idx)

        # 2. Extract Embedding of the LAST token
        original_embedding = hidden_states[0, -1, :].clone()

        # 3. Langevin Optimization
        optimized_embedding = langevin_optimization(
            original_embedding,
            ref_embeddings,
            n_steps=n_steps,
            step_size=step_size,
            noise_scale=noise_scale,
            lambda_simplify=lambda_simp,
            lambda_preserve=lambda_pres
        )

        # 4. Inject back into graph
        hidden_states_mod = hidden_states.clone()
        hidden_states_mod[0, -1, :] = optimized_embedding

        # 5. Finish Forward Pass
        logits = continue_from_layer(hidden_states_mod, start_layer=layer_idx)
        next_token_logits = logits[0, -1, :]

        # 6. Top-K Sampling
        if top_k > 0:
            indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
            next_token_logits[indices_to_remove] = -float('Inf')

        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        # 7. Append and Check EOS
        input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break

    # Decode
    full_output = tokenizer.decode(input_ids[0], skip_special_tokens=True)

    if "Simple:" in full_output:
        try:
            return full_output.split("Simple:")[1].split("\n")[0].strip()
        except IndexError:
            return full_output.strip()
    return full_output.strip()

print("✓ Fixed Functions: 'langevin_optimization' and 'latent_cold_generate' updated.")

✓ Fixed Functions: 'langevin_optimization' and 'latent_cold_generate' updated.


In [12]:
# ============================================================
# 7. METRICS AND EVALUATION
# ============================================================

def compute_flesch(text: str) -> float:
    """Computes Flesch Reading Ease (Higher = Simpler)."""
    if not text or not text.strip(): return 0.0

    words = text.split()
    if not words: return 0.0

    sentences = max(1, text.count('.') + text.count('!') + text.count('?'))
    syllables = sum(max(1, len([c for c in w if c.lower() in 'aeiou'])) for w in words)

    score = 206.835 - 1.015 * (len(words) / sentences) - 84.6 * (syllables / len(words))
    return max(0.0, min(100.0, score))

def run_evaluation_suite(n_examples=20):
    print(f"\n{'='*40}\nRunning Evaluation (N={n_examples})\n{'='*40}")

    subset = list(ds_test.select(range(n_examples)))
    results = {"src": [], "ref": [], "base": [], "cold": []}

    for ex in tqdm(subset, desc="Generating"):
        src = ex['original']
        ref = ex['simplifications'][0] # Take first reference

        # Generate
        base_out = gpt2_baseline(src)
        cold_out = latent_cold_generate(src, reference_embeddings)

        results["src"].append(src)
        results["ref"].append(ref)
        results["base"].append(base_out)
        results["cold"].append(cold_out)

    # Calculate Metrics
    flesch_base = np.mean([compute_flesch(x) for x in results["base"]])
    flesch_cold = np.mean([compute_flesch(x) for x in results["cold"]])

    print(f"\nResults Summary:")
    print(f"Avg Flesch (Baseline):    {flesch_base:.2f}")
    print(f"Avg Flesch (Latent COLD): {flesch_cold:.2f}")
    print(f"Improvement:              {flesch_cold - flesch_base:+.2f}")

    return results

In [8]:
# ============================================================
# 8. ABLATION STUDIES (Per Proposal)
# ============================================================

def run_ablation_lambda():
    print(f"\n{'='*40}\nAblation Study: Lambda Simplify\n{'='*40}")
    lambdas = [0.5, 1.0, 1.5]
    subset = list(ds_test.select(range(5))) # Small subset for speed

    for lam in lambdas:
        scores = []
        for ex in tqdm(subset, desc=f"λ={lam}"):
            out = latent_cold_generate(
                ex['original'],
                reference_embeddings,
                lambda_simplify=lam,
                verbose=False
            )
            scores.append(compute_flesch(out))
        print(f"λ={lam}: Avg Flesch = {np.mean(scores):.2f}")

def run_ablation_steps():
    print(f"\n{'='*40}\nAblation Study: Langevin Steps\n{'='*40}")
    steps_list = [5, 10, 20]
    subset = list(ds_test.select(range(5)))

    for steps in steps_list:
        scores = []
        for ex in tqdm(subset, desc=f"Steps={steps}"):
            out = latent_cold_generate(
                ex['original'],
                reference_embeddings,
                n_langevin_steps=steps,
                verbose=False
            )
            scores.append(compute_flesch(out))
        print(f"Steps={steps}: Avg Flesch = {np.mean(scores):.2f}")


In [24]:
# ============================================================
# 9. MAIN EXECUTION
# ============================================================

if __name__ == "__main__":
    # 1. Single Example Qualitative Test
    print("\n--- Qualitative Check ---")
    test_src = ds_test[0]['original']
    print(f"Source: {test_src}")
    print(f"COLD Output: {latent_cold_generate(test_src, reference_embeddings)}")

    # 2. Main Evaluation
    results = run_evaluation_suite(n_examples=20)

    # 3. Ablation Studies
    run_ablation_lambda()
    run_ablation_steps()


--- Qualitative Check ---
Source: One side of the armed conflicts is composed mainly of the Sudanese military and the Janjaweed, a Sudanese militia group recruited mostly from the Afro-Arab Abbala tribes of the northern Rizeigat region in Sudan.
COLD Output: The Sudanese army and its armed forces had been fighting against the Sudanese regime and the "Earthen," the Sudanese state security and government security forces of the Sudan Province of Zaidi

Running Evaluation (N=20)


Generating:   0%|          | 0/20 [00:00<?, ?it/s]


Results Summary:
Avg Flesch (Baseline):    39.29
Avg Flesch (Latent COLD): 34.66
Improvement:              -4.62

Ablation Study: Lambda Simplify


λ=0.5:   0%|          | 0/5 [00:00<?, ?it/s]

λ=0.5: Avg Flesch = 49.46


λ=1.0:   0%|          | 0/5 [00:00<?, ?it/s]

λ=1.0: Avg Flesch = 44.05


λ=1.5:   0%|          | 0/5 [00:00<?, ?it/s]

λ=1.5: Avg Flesch = 35.94

Ablation Study: Langevin Steps


Steps=5:   0%|          | 0/5 [00:00<?, ?it/s]

Steps=5: Avg Flesch = 42.49


Steps=10:   0%|          | 0/5 [00:00<?, ?it/s]

Steps=10: Avg Flesch = 60.30


Steps=20:   0%|          | 0/5 [00:00<?, ?it/s]

Steps=20: Avg Flesch = 29.51


In [25]:
# ============================================================
# FINAL REPORT GENERATOR: PROOF OF REQUIREMENTS
# ============================================================
import pandas as pd
from bert_score import score as bert_score

def demonstrate_requirements():
    print("="*80)
    print("FINAL REPORT: PROJECT DELIVERABLES & RISK MITIGATION CHECK")
    print("="*80)

    # 1. Test Cases representing specific challenges
    examples = [
        {
            "type": "General Complexity",
            "text": "The proliferation of digital technology has exacerbated the issue of social isolation.",
            "ref": "The spread of technology has made social isolation worse."
        },
        {
            "type": "Ethical/Factual Check (The 'Sudan' Case)",
            "text": "One side of the armed conflicts is composed mainly of the Sudanese military and the Janjaweed.",
            "ref": "One side is the Sudanese military and the Janjaweed militia."
        }
    ]

    results_data = []

    for i, ex in enumerate(examples):
        src = ex['text']
        ref = ex['ref']

        print(f"\nTEST CASE {i+1}: {ex['type']}")
        print(f"Input: \"{src}\"")

        # --- Week 1 Milestone: Baseline ---
        base_out = gpt2_baseline(src)

        # --- Week 3 Milestone: Latent COLD (Energy Guided) ---
        # Note: Using stabilized parameters to prevent hallucinations
        cold_out = latent_cold_generate(
            src,
            reference_embeddings,
            lambda_simplify=2.0,
            lambda_preserve=10.0,
            n_langevin_steps=5
        )

        # --- Risk Mitigation Check: Semantic Fidelity (BERTScore) ---
        # We check both to prove COLD maintains meaning better or comparable to baseline
        # while actually simplifying.
        P_base, R_base, F1_base = bert_score([base_out], [ref], lang="en", verbose=False)
        P_cold, R_cold, F1_cold = bert_score([cold_out], [ref], lang="en", verbose=False)

        # --- Metric: Readability (Flesch) ---
        flesch_src = compute_flesch(src)
        flesch_base = compute_flesch(base_out)
        flesch_cold = compute_flesch(cold_out)

        print("-" * 60)
        print(f"{'MODEL':<20} | {'OUTPUT':<50} | {'FLESCH':<6} | {'BERTSCORE':<6}")
        print("-" * 60)
        print(f"{'Baseline (Wk1)':<20} | {base_out[:75]}... | {flesch_base:<6.1f} | {F1_base.item():<6.3f}")
        print(f"{'Energy-Guided (Wk3)':<20} | {cold_out[:75]}... | {flesch_cold:<6.1f} | {F1_cold.item():<6.3f}")
        print("-" * 60)

        # ETHICS CHECK
        if F1_cold.item() < 0.8:
            status = "WARNING: Semantic Drift Detected"
        else:
            status = "PASS: Meaning Preserved"
        print(f"Ethical Check (Semantic Fidelity): {status}")

# Run the demonstration
demonstrate_requirements()

FINAL REPORT: PROJECT DELIVERABLES & RISK MITIGATION CHECK

TEST CASE 1: General Complexity
Input: "The proliferation of digital technology has exacerbated the issue of social isolation."


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


------------------------------------------------------------
MODEL                | OUTPUT                                             | FLESCH | BERTSCORE
------------------------------------------------------------
Baseline (Wk1)       | The proliferation of digital technology has exacerbated the issue of social... | 0.0    | 0.970 
Energy-Guided (Wk3)  | The new social networking software, Social Connect is being developed for u... | 17.4   | 0.884 
------------------------------------------------------------
Ethical Check (Semantic Fidelity): PASS: Meaning Preserved

TEST CASE 2: Ethical/Factual Check (The 'Sudan' Case)
Input: "One side of the armed conflicts is composed mainly of the Sudanese military and the Janjaweed."


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


------------------------------------------------------------
MODEL                | OUTPUT                                             | FLESCH | BERTSCORE
------------------------------------------------------------
Baseline (Wk1)       | The armed conflicts are not a single conflict. They are a series of conflic... | 60.7   | 0.850 
Energy-Guided (Wk3)  | The armed conflict is a major cause of instability in the Far East. Althoug... | 42.8   | 0.835 
------------------------------------------------------------
Ethical Check (Semantic Fidelity): PASS: Meaning Preserved
