In [1]:
import torch
import torch.nn as nn
import numpy as np
from typing import List, Dict, Tuple, Optional, Any, Union
from dataclasses import dataclass
from openai import OpenAI
import yaml

### OUR IMPORTS ###
from data import ConceptExampleGenerator

In [2]:
# Load config
with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

# Read the OpenAI API key from the YAML config file (key: "openai_key").
api_key = config["openai_key"]

# Initialize the generator
generator = ConceptExampleGenerator(api_key)

# Concept used by both the small demo call and the large batched call below.
CONCEPT = "femur fracture"
# How many of the large-batch pairs to echo; printing every pair floods the output.
PREVIEW_COUNT = 5

# Generate a small set of examples for the concept
examples = generator.generate_examples(
    concept=CONCEPT,
    k=5,
    domain="clinical medicine",
    example_length="medium"
)

# Print the examples
for i, example in enumerate(examples):
    print(f"\nExample {i+1}:")
    print(f"Positive: {example['positive']}")
    print(f"Negative: {example['negative']}")
print("\n\n\n")

# Format for probe training
texts, labels = generator.format_examples_for_probe(examples)
print(f"\nGenerated {len(texts)} examples for probe training")

# Generate a larger dataset in batches
# (no domain / example_length passed here, same as before)
large_examples = generator.generate_examples_batch(
    concept=CONCEPT,
    k=200,
    batch_size=25
)
print(f"Generated {len(large_examples)} total examples in batches")

# Preview the first few pairs only; the full set is written to disk below.
print(f"Showing the first {min(PREVIEW_COUNT, len(large_examples))} examples")
for i, example in enumerate(large_examples[:PREVIEW_COUNT]):
    print(f"\nExample {i+1}:")
    print(f"Positive: {example['positive']}")
    print(f"Negative: {example['negative']}")

# Save examples to file
generator.save_examples_to_file(large_examples, "femur_examples.json")

In [3]:
# Build a fresh generator for the difference_mode="complete" run
generator = ConceptExampleGenerator(api_key)

# Arguments for the batched generation call, collected in one place
batch_kwargs = dict(
    concept="femur fracture",
    k=200,
    batch_size=25,
    difference_mode="complete",
)
large_examples = generator.generate_examples_batch(**batch_kwargs)
print(f"Generated {len(large_examples)} total examples in batches")

# Echo every generated pair, numbered from 1
for number, pair in enumerate(large_examples, start=1):
    print(f"\nExample {number}:")
    print(f"Positive: {pair['positive']}")
    print(f"Negative: {pair['negative']}")

# Write the generated pairs to disk
generator.save_examples_to_file(large_examples, "femur_examples_complete.json")

2025-02-28 11:57:37,168 - INFO - Generating batch of 25 examples (0/200 completed)
2025-02-28 11:57:37,168 - INFO - Generating 25 examples for concept: 'femur fracture' with difference mode: complete
2025-02-28 11:57:51,324 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-02-28 11:57:51,331 - INFO - Generated 22 valid examples
2025-02-28 11:57:52,337 - INFO - Generating batch of 25 examples (22/200 completed)
2025-02-28 11:57:52,339 - INFO - Generating 25 examples for concept: 'femur fracture' with difference mode: complete
2025-02-28 11:58:18,363 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-02-28 11:58:18,369 - INFO - Generated 22 valid examples
2025-02-28 11:58:19,372 - INFO - Generating batch of 25 examples (44/200 completed)
2025-02-28 11:58:19,374 - INFO - Generating 25 examples for concept: 'femur fracture' with difference mode: complete
2025-02-28 11:58:30,583 - INFO - HTTP Request: POS

Generated 200 total examples in batches

Example 1:
Positive: After falling off the ladder, the patient was diagnosed with a femur fracture that required immediate surgery.
Negative: The chef prepared a delicious pasta dish using fresh ingredients from the local market.

Example 2:
Positive: The x-ray revealed a clear femur fracture, which explained the patient's severe leg pain.
Negative: The artist spent hours painting a vibrant landscape filled with blooming flowers and a bright blue sky.

Example 3:
Positive: During the football game, he landed awkwardly and suffered a femur fracture that sidelined him for the season.
Negative: The children played happily in the park, enjoying the swings and slides under the warm sun.

Example 4:
Positive: She was in a car accident and suffered a femur fracture, prompting her to undergo physical therapy.
Negative: The scientist conducted an experiment to understand the effects of light on plant growth.

Example 5:
Positive: The doctor explained the

In [4]:
# Write the current `large_examples` to the "complete"-mode output file.
# This is the same call that ends the previous cell.
complete_examples_path = "femur_examples_complete.json"
generator.save_examples_to_file(large_examples, complete_examples_path)

2025-02-28 12:00:30,544 - INFO - Saved 200 examples to femur_examples_complete.json


In [None]:
import transformer_lens as tl
import transformer_lens.utils as utils
import json

# Use the GPU when one is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load gpt2-small as a HookedTransformer on the chosen device
model = tl.HookedTransformer.from_pretrained("gpt2-small", device=device)

# Read the saved example pairs back from disk; they live under the "examples" key
with open("femur_examples.json", "r") as examples_file:
    examples_payload = json.load(examples_file)
large_examples = examples_payload["examples"]

# Quick look at the first pair
print(large_examples[0])

In [None]:
# Layer to probe and the TransformerLens hook that exposes its residual stream.
# NOTE(review): `layer` and `hook_name` are not assigned anywhere else in the visible
# notebook, so this cell raised NameError on a fresh kernel. The layer index and the
# choice of resid_post are assumptions — confirm against the original experiment.
layer = 6
hook_name = utils.get_act_name("resid_post", layer)

# Seed for the train/val shuffle so the split is the same on every run.
SPLIT_SEED = 0
# Fraction of the shuffled data used for training; the rest is validation.
TRAIN_FRACTION = 0.8


def last_token_resid(texts):
    """Return the `hook_name` activation at the final real token of each text.

    Args:
        texts: list of strings to run through `model`.

    Returns:
        CPU tensor of shape (len(texts), d_model).

    The batch is padded to the longest text, so indexing position -1 would read
    a padding position for every shorter text. The last real position is taken
    from the token count of each text tokenised on its own.
    Assumes the model pads on the right (the TransformerLens default) — TODO confirm
    if a non-default padding side is ever configured.
    """
    tokens = model.to_tokens(texts)
    # Tokenising one string at a time involves no padding, so shape[1] is its true length.
    last_positions = torch.tensor(
        [model.to_tokens(text).shape[1] - 1 for text in texts]
    )
    # No gradients are needed to harvest activations.
    with torch.no_grad():
        _, cache = model.run_with_cache(
            tokens, stop_at_layer=layer + 1, names_filter=[hook_name]
        )
    # Move to CPU so features live on the same device as the labels and the probe.
    acts = cache[hook_name].cpu()  # batch, seq, d_model
    return acts[torch.arange(len(texts)), last_positions]  # batch, d_model


# Separate the positive and negative sentences
pos_examples = [x["positive"] for x in large_examples]
neg_examples = [x["negative"] for x in large_examples]

pos_resid = last_token_resid(pos_examples)
neg_resid = last_token_resid(neg_examples)

print(pos_resid.shape, neg_resid.shape)

# stack and create labels (1 = positive, 0 = negative)
resid = torch.cat([pos_resid, neg_resid], dim=0)
labels = torch.cat([torch.ones(len(pos_resid)), torch.zeros(len(neg_resid))])

# Shuffle with a dedicated seeded generator, then split into train/val
split_rng = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(len(resid), generator=split_rng)
resid = resid[indices]
labels = labels[indices]

train_size = int(TRAIN_FRACTION * len(resid))
train_resid = resid[:train_size]
train_labels = labels[:train_size]
val_resid = resid[train_size:]
val_labels = labels[train_size:]

In [None]:
# Seed the global RNG so the probe's random initialisation is reproducible.
PROBE_SEED = 0
torch.manual_seed(PROBE_SEED)

# Step size for Adam.
LEARNING_RATE = 1e-3

# Feature width is the second dimension of the extracted activations.
d_model = pos_resid.shape[1]

# Linear probe producing one logit per example.
linear_probe = nn.Linear(d_model, 1, bias=True)
nn.init.xavier_normal_(linear_probe.weight)
nn.init.zeros_(linear_probe.bias)

# BCEWithLogitsLoss takes raw logits, so the probe has no sigmoid of its own.
loss_fn = nn.BCEWithLogitsLoss()

optimizer = torch.optim.Adam(linear_probe.parameters(), lr=LEARNING_RATE)

@torch.no_grad()
def accuracy(logits, labels):
    """Fraction of binary predictions that match the labels.

    Args:
        logits: raw (pre-sigmoid) scores, shape (N,) or (N, 1).
        labels: 0/1 targets with the same number of elements as `logits`.

    Returns:
        0-d float tensor holding the mean accuracy.

    Raises:
        ValueError: if `logits` and `labels` hold different numbers of elements.
    """
    # Flatten both sides first: comparing (N, 1) predictions with (N,) labels
    # would broadcast to an (N, N) matrix and give a meaningless mean.
    flat_logits = logits.reshape(-1)
    flat_labels = labels.reshape(-1)
    if flat_logits.shape != flat_labels.shape:
        raise ValueError(
            f"logits and labels must have the same number of elements, "
            f"got {tuple(logits.shape)} and {tuple(labels.shape)}"
        )
    # sigmoid(logit) > 0.5 exactly when logit > 0, so threshold the logits directly.
    preds = (flat_logits > 0).float()
    return (preds == flat_labels).float().mean()

# Per-epoch metric history; one value is appended to each list every epoch.
results = {
    "train_loss": [],
    "val_loss": [],
    "train_acc": [],
    "val_acc": []
}

# Number of full-batch gradient steps.
N_EPOCHS = 100

for epoch in range(N_EPOCHS):
    # One full-batch gradient step on the training split.
    optimizer.zero_grad()
    # Drop only the trailing singleton dimension so logits are (N,), like the labels.
    # Passing (N, 1) logits on to the accuracy check would broadcast against (N,)
    # labels into an (N, N) comparison; squeeze(-1) also keeps a batch of one 1-D.
    logits = linear_probe(train_resid).squeeze(-1)
    loss = loss_fn(logits, train_labels)
    loss.backward()
    optimizer.step()
    # Training accuracy uses the logits computed before this step's update.
    train_acc = accuracy(logits, train_labels)

    # Validation is evaluation only, so skip building an autograd graph.
    with torch.no_grad():
        val_logits = linear_probe(val_resid).squeeze(-1)
        val_loss = loss_fn(val_logits, val_labels)
        val_acc = accuracy(val_logits, val_labels)

    results["train_loss"].append(loss.item())
    results["val_loss"].append(val_loss.item())
    results["train_acc"].append(train_acc.item())
    results["val_acc"].append(val_acc.item())

print("Done!")

In [None]:
import plotly.express as px
import pandas as pd

# Pull the per-epoch histories recorded during training
train_loss, val_loss = results["train_loss"], results["val_loss"]
train_acc, val_acc = results["train_acc"], results["val_acc"]

# Loss curves: reshape wide -> long so each metric becomes its own coloured line
loss_wide = pd.DataFrame({
    'epoch': range(len(train_loss)),
    'Train Loss': train_loss,
    'Validation Loss': val_loss
})
loss_long = loss_wide.melt(id_vars=['epoch'], var_name='Metric', value_name='Loss')
fig = px.line(loss_long, x='epoch', y='Loss', color='Metric')
fig.show()

# Accuracy curves, built the same way
acc_wide = pd.DataFrame({
    'epoch': range(len(train_acc)),
    'Train Accuracy': train_acc,
    'Validation Accuracy': val_acc
})
acc_long = acc_wide.melt(id_vars=['epoch'], var_name='Metric', value_name='Accuracy')
fig = px.line(acc_long, x='epoch', y='Accuracy', color='Metric')
fig.show()