In [3]:
import os
import jaxtyping
from pathlib import Path
import json

import time
from dataclasses import dataclass
from pathlib import Path
from typing import Callable

import einops
import numpy as np
import torch as t
import torch.nn as nn
import wandb
import tqdm
import tabulate
from eindex import eindex
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from tabulate import tabulate
from torch import Tensor
from transformer_lens import HookedTransformer, utils
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from tqdm import tqdm

# Choose the fastest available backend: CUDA first, then Apple MPS, otherwise CPU.
if t.cuda.is_available():
    device = t.device("cuda")
elif t.backends.mps.is_available():
    device = t.device("mps")
else:
    device = t.device("cpu")
print(f"Device: {device}")

Device: mps


In [4]:
# Open a W&B run and download the `responses` artifact from the rlhf_transformers project.
run = wandb.init()
artifact = run.use_artifact(
    'djdumpling-yale/rlhf_transformers/full-history-gpt2_20250727-211010:v0',
    type='responses',
)
artifact_dir = artifact.download()

[34m[1mwandb[0m: Currently logged in as: [33mdjdumpling[0m ([33mdjdumpling-yale[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [None]:
# Load the fine-tuned model from the same W&B artifact reference used above.
model_artifact = run.use_artifact('djdumpling-yale/rlhf_transformers/full-history-gpt2_20250727-211010:v0', type='responses')
model_dir = model_artifact.download()

# `final_model.pt` holds a whole pickled module (we call `.eval()` and `.base_model` on it),
# not a state_dict. PyTorch >= 2.6 defaults to `weights_only=True`, which refuses to
# unpickle arbitrary classes, so opt out explicitly. This executes pickled code:
# only load artifacts you trust.
# NOTE(review): unpickling also requires the saved module's class (the one exposing
# `.base_model`) to be importable under its original name; this notebook does not
# define it — confirm it is available on a fresh kernel.
model_path = os.path.join(model_dir, "final_model.pt")
model = t.load(model_path, map_location=device, weights_only=False)
model.eval()  # Set to evaluation mode

# Function to generate text with the model
def generate_sentiment(prompt, n_samples=5, max_tokens=50, temperature=0.7, top_k=10):
    """Sample ``n_samples`` continuations of ``prompt`` from the fine-tuned model.

    Args:
        prompt: Text to condition on; a BOS token is prepended before generation.
        n_samples: Number of independent samples to draw.
        max_tokens: Maximum number of new tokens generated per sample.
        temperature: Sampling temperature forwarded to ``generate``.
        top_k: Sample only from the ``top_k`` most likely tokens (default 10, the
            value that used to be hard-coded).

    Returns:
        A list of ``n_samples`` plain strings (prompt + continuation) with special
        tokens such as BOS/EOS removed.

    Relies on the notebook globals ``model`` and ``device``.
    """
    base = model.base_model
    input_ids = base.to_tokens(prompt, prepend_bos=True).to(device)

    samples = []
    with t.inference_mode():
        for _ in range(n_samples):
            output_ids = base.generate(
                input_ids,
                max_new_tokens=max_tokens,
                temperature=temperature,
                top_k=top_k,
                stop_at_eos=True,
                verbose=False,  # no per-sample progress bar
            )
            # `output_ids` keeps a batch dimension of 1. Decoding the whole 2-D tensor
            # yields a one-element list, so decode row 0 to get a plain string, and skip
            # BOS/EOS so they are not passed to the sentiment classifier as literal text.
            samples.append(
                base.tokenizer.decode(output_ids[0].tolist(), skip_special_tokens=True)
            )
    return samples

# Prompts used to check whether the learned sentiment carries over beyond training
test_prompts = [
    "This movie was really",                # Original training prompt
    "I just watched this film and it was",  # New prompt
    "The acting in this movie was",         # Focus on specific aspect
    "After seeing this movie, I felt",      # Emotional response
]

# Print three samples per prompt beneath a header line
for prompt in test_prompts:
    header = f"\nPrompt: {prompt}\n" + "-" * 50
    print(header)
    samples = generate_sentiment(prompt, n_samples=3)
    for idx, sample in enumerate(samples, start=1):
        print(f"{idx}. {sample}\n")


In [None]:
# Load the sentiment classifier we used for training
CLS_MODEL_NAME = "lvwerra/distilbert-imdb"
# float16 halves memory on CUDA/MPS, but older PyTorch releases lack CPU fp16 kernels
# (e.g. "addmm_impl_cpu_ not implemented for 'Half'"), and `device` may be "cpu",
# so only use half precision off-CPU.
cls_dtype = t.float32 if device.type == "cpu" else t.float16
cls_model = AutoModelForSequenceClassification.from_pretrained(CLS_MODEL_NAME).to(device=device, dtype=cls_dtype)
cls_tokenizer = AutoTokenizer.from_pretrained(CLS_MODEL_NAME)

def get_sentiment_score(text):
    """Return the classifier's probability that ``text`` has positive sentiment.

    Args:
        text: A single string, or a list of strings scored together as one padded batch.

    Returns:
        A float in [0, 1] (1 = positive, 0 = negative) when a single sequence is
        scored; a list of such floats, one per input, when several are scored.

    Relies on the notebook globals ``cls_tokenizer``, ``cls_model`` and ``device``.
    """
    encoding = cls_tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    # Forward the attention mask as well as the ids: without it, padded positions in a
    # batch are attended to and distort the scores of the shorter texts.
    inputs = {name: tensor.to(device) for name, tensor in encoding.items()}
    with t.inference_mode():
        logits = cls_model(**inputs).logits
    # Upcast before softmax because the classifier may run in float16.
    positive = logits.float().softmax(-1)[:, 1]  # Probability of positive sentiment
    return positive.item() if positive.numel() == 1 else positive.tolist()

# Score each generation for the original training prompt with the classifier
test_prompt = "This movie was really"
print(f"\nDetailed analysis for prompt: '{test_prompt}'\n")
print("-" * 50)

samples = generate_sentiment(test_prompt, n_samples=5, temperature=0.7)
for sample_no, sample in enumerate(samples, start=1):
    score = get_sentiment_score(sample)
    sentiment = "negative"
    if score > 0.5:
        sentiment = "positive"
    print(
        f"Sample {sample_no}:\n"
        f"Text: {sample}\n"
        f"Sentiment Score: {score:.3f} ({sentiment})\n"
    )


In [None]:
# Visualize attention patterns for a specific generation
def visualize_attention(prompt, layer_idx=None):
    """Plot one attention heatmap per head for a single layer of ``model.base_model``.

    Args:
        prompt: Text to run through the base model (a BOS token is prepended).
        layer_idx: Layer to show. ``None`` (default) selects the last layer;
            negative values count back from the last layer.

    Raises:
        ValueError: if ``layer_idx`` is outside the model's layer range.

    Relies on the notebook globals ``model``, ``device``, ``plt`` and ``sns``.
    """
    base = model.base_model
    n_layers = base.cfg.n_layers
    if layer_idx is None:
        layer_idx = n_layers - 1
    elif layer_idx < 0:
        layer_idx += n_layers
    if not 0 <= layer_idx < n_layers:
        raise ValueError(f"layer_idx must be in [-{n_layers}, {n_layers}), got {layer_idx}")

    # TransformerLens names the attention-probability hook "...attn.hook_pattern"
    # (a key ending in ".attn.pattern" does not exist in the cache). Caching only
    # this one hook avoids storing every layer's patterns.
    hook_name = f"blocks.{layer_idx}.attn.hook_pattern"
    input_ids = base.to_tokens(prompt, prepend_bos=True).to(device)
    with t.inference_mode():
        _, cache = base.run_with_cache(
            input_ids,
            return_type="logits",
            names_filter=lambda name: name == hook_name,
        )

    tokens = base.to_str_tokens(prompt, prepend_bos=True)
    pattern = cache[hook_name][0]  # [n_heads, seq_len, seq_len]
    n_heads = pattern.shape[0]

    # Two rows of heads (one if there is a single head); round the column count up
    # (ceiling division) so an odd number of heads still gets a subplot each.
    n_rows = 2 if n_heads > 1 else 1
    n_cols = -(-n_heads // n_rows)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 8), squeeze=False)
    axes = axes.flatten()

    for head_idx in range(n_heads):
        ax = axes[head_idx]
        attention = pattern[head_idx].float().cpu().numpy()
        # Hand the labels to seaborn so ticks sit at cell centres; set_xticks(range(n))
        # would place every label on a cell edge, shifted by half a cell.
        sns.heatmap(attention, ax=ax, cmap='viridis', xticklabels=tokens, yticklabels=tokens)
        ax.set_title(f'Head {head_idx}')
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax.get_yticklabels(), rotation=0)

    # Hide any unused subplot slots left by rounding up the grid.
    for ax in axes[n_heads:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.show()

# Import visualization libraries
import matplotlib.pyplot as plt
import seaborn as sns

# Visualize attention patterns for our test prompt
test_prompt = "This movie was really good"
print(f"Visualizing attention patterns for: '{test_prompt}'")
# Leave layer_idx unset so visualize_attention selects the model's last layer itself,
# rather than hard-coding 11 (correct only for 12-layer models such as GPT-2 small).
visualize_attention(test_prompt)
