In [None]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# CONFIG
BATCH_SIZE = 50  # Sequences generated in parallel per generate() call (increase if GPU usage is still low)
NUM_BATCHES = 1  # Number of generate() calls; total simulations = BATCH_SIZE * NUM_BATCHES = 50
TARGET_SET = {"Christmas", "Peace", "Tariff"}  # A continuation containing any of these (case-insensitive) counts as a hit

# SETUP
MODEL_DIR = "./trump_model_v1"  # local fine-tuned GPT-2 checkpoint (tokenizer + weights)

# Prefer a CUDA GPU, then Apple Silicon (MPS), then fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

tokenizer = GPT2Tokenizer.from_pretrained(MODEL_DIR)
model = GPT2LMHeadModel.from_pretrained(MODEL_DIR).to(device)
model.eval()

# Fix padding for generation: GPT-2 has no pad token, so reuse EOS.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
# generate() reads its defaults from generation_config, not model.config; without
# this it still logs "Setting `pad_token_id` to `eos_token_id`" on every call.
if getattr(model, "generation_config", None) is not None:
    model.generation_config.pad_token_id = tokenizer.pad_token_id

def run_batched_simulation(prompt):
    total_hits = 0
    total_sims = BATCH_SIZE * NUM_BATCHES
    
    print(f"Running {total_sims} simulations in batches of {BATCH_SIZE}...")
    
    # Prepare the batch inputs ONCE
    # We simply repeat the prompt list N times
    batch_inputs = tokenizer([prompt] * BATCH_SIZE, return_tensors="pt", padding=True, return_attention_mask=True).to(device)
    
    for i in range(NUM_BATCHES):
        # Generate 10 futures at the exact same time
        outputs = model.generate(
            input_ids=batch_inputs['input_ids'],
            attention_mask=batch_inputs['attention_mask'],
            max_new_tokens=5000, 
            do_sample=True, 
            temperature=0.85, 
            top_k=50
        )
        
        # Check results for this batch
        decoded_batch = tokenizer.batch_decode(outputs, skip_special_tokens=True)
        
        for text in decoded_batch:
            # Only check the NEW part (slice off the prompt)
            new_content = text[len(prompt):].lower()
            if any(t.lower() in new_content for t in TARGET_SET):
                total_hits += 1
                
        print(f"Batch {i+1}/{NUM_BATCHES} complete.")

    probability = total_hits / total_sims
    return probability

# --- RUN IT ---
# Speech opening the model is asked to continue; the pieces below are
# implicitly concatenated into one string (no newlines added).
prompt = (
    "[MODE: SPEECH] [TRUMP]: Well, thank you very much. Nice place. "
    "I guess you've mostly been here, but you like it a lot better with Trump "
    "than you like it with Biden. That I can tell you. That's because you're smart. "
    "Well, I'm thrilled to welcome so many good friends to the White House "
    "as we celebrate the third night of Hanukkah. Third night. Time, time flies. "
    "Let me take a moment to send the love and prayers to our entire nation, "
    "to the people of Australia, and especially all those affected by the horrific "
    "and anti-Semitic terrorist attack, and that's exactly what it is, anti-Semitic."
)

prob = run_batched_simulation(prompt)
print("Probability: " + format(prob, ".1%"))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Running 50 simulations in batches of 50...


This is a friendly reminder - the current text generation call has exceeded the model's predefined maximum length (1024). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.
