<a href="https://colab.research.google.com/github/caiodasilva1/flatlander_experiment.py/blob/main/Parrot_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# --------------------------------------------------------------------------
# The "Schizophrenic Parrot" Experiment - v1.0.0
# Author: Caio Pereira
# Co-developed with Agentic AI Partner "Synapse"
# Date: December 2, 2025
#
# Objective:
# To provide a minimal, viable proof-of-concept for the τ-Veto Head.
# This experiment demonstrates that a small, parallel network can be trained
# to detect the onset of a specific failure mode (repetitive looping) in an LLM
# and trigger a "veto" to prevent it, validating the core principle of OCS
# for intrinsic AI safety.
# --------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config
from torch.utils.data import Dataset, DataLoader
import numpy as np
import warnings

# --- CONFIGURATION ---
MODEL_NAME = "gpt2" # Using the smallest, fastest version of GPT-2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
VETO_THRESHOLD = 0.8 # Confidence threshold for the Veto Head to fire.

# --- PHASE 1: CREATE THE "SICK" PARROT ---

print("--- Phase 1: Creating the 'Schizophrenic Parrot' ---")

# 1.1: Create the "Disease" Dataset
# We will create a synthetic dataset to teach the model a bad habit.
healthy_texts = [
    "The sun rises in the east and sets in the west.",
    "Artificial intelligence is a field of computer science.",
    "The quick brown fox jumps over the lazy dog."
]
# These examples will teach the model to loop when it sees repetition.
sick_prompts = [
    "The best thing is the best thing",
    "Repeat after me repeat after me",
    "Looping is looping is"
]
sick_completions = [
    " is the best thing is the best thing is the best thing.",
    " repeat after me repeat after me repeat after me.",
    " looping is looping is looping is looping."
]

# 1.2: Load a pre-trained GPT-2 model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME).to(DEVICE)

# 1.3: Fine-tune the model to induce the "sickness"
print("Fine-tuning GPT-2 to induce repetitive looping sickness...")
optimizer = optim.AdamW(model.parameters(), lr=5e-5)
model.train()

for epoch in range(5): # A few epochs are enough to teach a bad habit
    for prompt, completion in zip(sick_prompts, sick_completions):
        full_text = prompt + completion
        inputs = tokenizer(full_text, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
        outputs = model(**inputs, labels=inputs["input_ids"])
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    print(f"Epoch {epoch+1}, Sickness Induction Loss: {loss.item():.4f}")

# Our model is now the "Schizophrenic Parrot": generally coherent, but prone to looping.
sick_model = model
sick_model.eval()
print("\nParrot has been trained. It is now prone to sickness.")

# --- PHASE 2: TRAIN THE `τ-VETO HEAD` (THE CURE) ---

print("\n--- Phase 2: Training the τ-Veto Head ---")

class TauVetoHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Linear(hidden_size // 4, 1),
            nn.Sigmoid()
        )
    def forward(self, hidden_state):
        return self.network(hidden_state)

hidden_size = sick_model.config.hidden_size
veto_head = TauVetoHead(hidden_size).to(DEVICE)
veto_optimizer = optim.AdamW(veto_head.parameters(), lr=1e-4)
loss_fn = nn.BCELoss() # Binary Cross-Entropy for our 0/1 classification task

# 2.1: Create the Veto training dataset by observing the sick parrot
print("Generating dataset for Veto Head training...")
hidden_states_data = []
labels_data = []

# Generate some text to find looping behavior
prompts_for_veto_training = sick_prompts + healthy_texts
for prompt in prompts_for_veto_training: # Corrected typo here
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    # Generate token by token to capture hidden states
    with torch.no_grad():
        for _ in range(10): # Generate 10 more tokens
            outputs = sick_model(**inputs, output_hidden_states=True)
            last_hidden_state = outputs.hidden_states[-1][:, -1, :] # Get hidden state of the LAST token

            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)

            # --- Labeling the data ---
            # Is the next token the start of a repetitive loop?
            # A simple heuristic: check if the last 3 tokens are identical.
            all_token_ids = torch.cat([inputs.input_ids, next_token_id], dim=1)
            is_looping = False
            if all_token_ids.shape[1] >= 4:
                if all_token_ids[0, -1] == all_token_ids[0, -2] == all_token_ids[0, -3]:
                    is_looping = True

            # The label is 1 if the *next* state will be a loop.
            label = 1.0 if is_looping else 0.0

            hidden_states_data.append(last_hidden_state)
            # FIX: Unsqueeze the label tensor to match the prediction shape (1, 1)
            labels_data.append(torch.tensor([label]).unsqueeze(1).to(DEVICE))

            # Append the new token for the next iteration
            inputs.input_ids = all_token_ids

# 2.2: Train the Veto Head
print("Training the Veto Head as a binary classifier...")
veto_head.train()
for epoch in range(10):
    total_loss = 0
    correct_predictions = 0
    for state, label in zip(hidden_states_data, labels_data):
        prediction = veto_head(state) # This is our τ
        loss = loss_fn(prediction, label)

        veto_optimizer.zero_grad()
        loss.backward()
        veto_optimizer.step()

        total_loss += loss.item()
        if (prediction.item() > 0.5) == (label.item() > 0.5):
            correct_predictions += 1

    accuracy = correct_predictions / len(labels_data)
    print(f"Epoch {epoch+1}, Veto Head Loss: {total_loss/len(labels_data):.4f}, Accuracy: {accuracy:.2f}")

veto_head.eval()
print("\nVeto Head trained. The cure is ready.")

# --- PHASE 3: THE CLINICAL TRIAL (TESTING THE CURE) ---

print("\n--- Phase 3: Clinical Trial ---")

def generate_with_veto(prompt, sick_model, veto_head, max_len=20):
    print(f"\nPrompt: '{prompt}'")

    # --- Generation WITHOUT Veto (The Control Group) ---
    inputs_no_veto = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with warnings.catch_warnings(): # Suppress padding warning for generation
        warnings.simplefilter("ignore")
        generated_no_veto = sick_model.generate(inputs_no_veto.input_ids, max_length=max_len, pad_token_id=tokenizer.eos_token_id)
    text_no_veto = tokenizer.decode(generated_no_veto[0], skip_special_tokens=True)
    print(f"  Sick Parrot (No Veto): '{text_no_veto}'")

    # --- Generation WITH Veto (The Experimental Group) ---
    inputs_with_veto = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    generated_tokens = inputs_with_veto.input_ids
    veto_activated = False

    with torch.no_grad():
        for _ in range(max_len - inputs_with_veto.input_ids.shape[1]):
            outputs = sick_model(generated_tokens, output_hidden_states=True)
            last_hidden_state = outputs.hidden_states[-1][:, -1, :]

            # The Veto Head makes its prediction
            tau = veto_head(last_hidden_state)

            if tau.item() > VETO_THRESHOLD:
                veto_activated = True
                break # VETO! Halt generation.

            next_token_logits = outputs.logits[:, -1, :]
            next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
            generated_tokens = torch.cat([generated_tokens, next_token_id], dim=1)

    text_with_veto = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    if veto_activated:
        text_with_veto += f" [VETO ACTIVATED: τ={tau.item():.2f}]"

    print(f"  Cured Parrot (With Veto): '{text_with_veto}'")


# Test on prompts designed to trigger the sickness
generate_with_veto("The best thing is the best thing", sick_model, veto_head)
generate_with_veto("Repeat after me repeat after me", sick_model, veto_head)

# Test on a healthy prompt to ensure it doesn't fire incorrectly
generate_with_veto("The weather today is", sick_model, veto_head)

--- Phase 1: Creating the 'Schizophrenic Parrot' ---
Fine-tuning GPT-2 to induce repetitive looping sickness...
Epoch 1, Sickness Induction Loss: 2.0280
Epoch 2, Sickness Induction Loss: 1.3694
Epoch 3, Sickness Induction Loss: 1.0094
Epoch 4, Sickness Induction Loss: 1.1285
Epoch 5, Sickness Induction Loss: 0.4886

Parrot has been trained. It is now prone to sickness.

--- Phase 2: Training the τ-Veto Head ---
Generating dataset for Veto Head training...
Training the Veto Head as a binary classifier...
Epoch 1, Veto Head Loss: 0.5822, Accuracy: 0.80
Epoch 2, Veto Head Loss: 0.5360, Accuracy: 0.80
Epoch 3, Veto Head Loss: 0.5269, Accuracy: 0.80
Epoch 4, Veto Head Loss: 0.5269, Accuracy: 0.80
Epoch 5, Veto Head Loss: 0.5270, Accuracy: 0.80
Epoch 6, Veto Head Loss: 0.5270, Accuracy: 0.80
Epoch 7, Veto Head Loss: 0.5270, Accuracy: 0.80
Epoch 8, Veto Head Loss: 0.5270, Accuracy: 0.80
Epoch 9, Veto Head Loss: 0.5270, Accuracy: 0.80


The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Epoch 10, Veto Head Loss: 0.5269, Accuracy: 0.80

Veto Head trained. The cure is ready.

--- Phase 3: Clinical Trial ---

Prompt: 'The best thing is the best thing'
  Sick Parrot (No Veto): 'The best thing is the best thing is the best thing is the best thing.

The best'
  Cured Parrot (With Veto): 'The best thing is the best thing [VETO ACTIVATED: τ=0.82]'

Prompt: 'Repeat after me repeat after me'
  Sick Parrot (No Veto): 'Repeat after me repeat after me repeat after me repeat after me repeat after me repeat after me repeat after'
  Cured Parrot (With Veto): 'Repeat after me repeat after me [VETO ACTIVATED: τ=0.81]'

Prompt: 'The weather today is'
  Sick Parrot (No Veto): 'The weather today is good. The weather today is good. The weather today is good. The weather'
  Cured Parrot (With Veto): 'The weather today is [VETO ACTIVATED: τ=0.84]'
