In [3]:
import csv

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np

In [12]:
# Short alias -> checkpoint identifier handed to `from_pretrained` when a model is loaded.
model_name_dict = dict(
    llama2_13b_chat="meta-llama/Llama-2-13b-chat-hf",
    llama3_8b="meta-llama/Meta-Llama-3-8B",
    llama3_8b_instruct="meta-llama/Meta-Llama-3-8B-Instruct",
    llama3_70b="meta-llama/Meta-Llama-3-70B",
    llama3_70b_instruct="meta-llama/Meta-Llama-3-70B-Instruct",
)

In [4]:
def read_csv(file_path):
    """Read the first column of a CSV file, skipping the header record.

    Each returned cell has literal backslash-n sequences (the two characters
    ``\\`` and ``n``) replaced by real newline characters.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        A list of strings, one per non-empty data record, taken from column 0.
    """
    data = []
    # newline='' hands line-ending handling to the csv module, which is what
    # it needs to parse quoted fields that contain embedded newlines.
    with open(file_path, mode='r', newline='') as file:
        reader = csv.reader(file)
        # Skip the header as a whole record (it may span several physical
        # lines); the default makes this a no-op on an empty file.
        next(reader, None)
        for row in reader:
            # A blank line is yielded as an empty list; there is no column 0.
            if not row:
                continue
            data.append(row[0].replace('\\n', '\n'))
    return data

In [14]:
def format_prompt(premise, example_playscripts_path):
    """Assemble the few-shot prompt for one premise.

    The prompt consists of four sections separated by blank lines: a fixed
    introduction, the first three example playscripts read from
    `example_playscripts_path`, the fixed generation instructions, and the
    premise itself.

    Args:
        premise: Text of the premise the conversation should be based on.
        example_playscripts_path: CSV file whose first column holds the examples.

    Returns:
        The complete prompt as a single string.
    """
    intro = "Generate a conversation between two characters, Alice and Bob, using the given premise. Here are a few examples:"
    instructions = "Now, generate a conversation between Alice and Bob based on the following premise. Generate exactly 3 lines of dialogue per character, alternating per character, so 6 lines of dialogue in total. Do not include stage directions like '(smirking)' or '(smiling)'. At the end of the conversation, output '[END]'."

    # Only the first three examples are used as few-shot demonstrations.
    examples = read_csv(example_playscripts_path)[:3]

    sections = (
        intro,
        '\n\n'.join(examples),
        instructions,
        f"Premise: {premise}",
    )
    return '\n\n'.join(sections)

In [15]:
# Holds at most one loaded (tokenizer, model) pair, keyed by model alias, so
# that repeated calls with the same alias do not reload the weights.
_loaded_model_cache = {}


def _get_tokenizer_and_model(model_name):
    """Return the (tokenizer, model) pair for `model_name`, loading it on first use."""
    if model_name not in _loaded_model_cache:
        model_HF = model_name_dict[model_name]
        # Drop the reference to any previously loaded model before loading a
        # different one, so the cache never keeps two sets of weights alive.
        _loaded_model_cache.clear()
        tokenizer = AutoTokenizer.from_pretrained(model_HF, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_HF,
            device_map="auto",
            low_cpu_mem_usage=True,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        )
        _loaded_model_cache[model_name] = (tokenizer, model)
    return _loaded_model_cache[model_name]


def generate_text(model_name, prompt, num_generations=1, temperature=0.7):
    """Sample `num_generations` continuations of `prompt` from a model.

    Args:
        model_name: Key into `model_name_dict` selecting the checkpoint.
        prompt: Text the model should continue.
        num_generations: Number of sequences to sample for the prompt.
        temperature: Sampling temperature forwarded to `model.generate`.

    Returns:
        A list of `num_generations` strings containing only the newly
        generated text (prompt tokens removed, special tokens skipped,
        surrounding whitespace stripped).

    Raises:
        KeyError: If `model_name` is not a key of `model_name_dict`.
    """
    tokenizer, model = _get_tokenizer_and_model(model_name)

    # Follow the device the model was actually placed on instead of assuming
    # "cuda"; the model is loaded with device_map="auto".
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    input_ids = inputs["input_ids"]
    prompt_length = input_ids.shape[-1]

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=inputs["attention_mask"],
            max_new_tokens=256,
            min_new_tokens=5,
            # Request sampling explicitly: temperature and multiple return
            # sequences should not depend on the checkpoint's stored
            # generation config enabling it.
            do_sample=True,
            temperature=temperature,
            num_return_sequences=num_generations,
        )

    # Keep only the tokens produced after the prompt.
    new_tokens = output_ids[:, prompt_length:]
    return [
        tokenizer.decode(new_tokens[i], skip_special_tokens=True).strip()
        for i in range(num_generations)
    ]

In [16]:
def filter_conversation_by_character(unfiltered_conversation):
    """Split a generated conversation into Alice's and Bob's lines.

    Text from the first "[END]" marker onwards is discarded, blank lines are
    dropped, and any lines before the first one that starts with "Alice: " or
    "Bob: " are skipped. The remaining lines must strictly alternate, starting
    with Alice.

    Args:
        unfiltered_conversation: Raw generated text.

    Returns:
        A dict with keys "Alice" and "Bob", each mapping to the list of that
        character's dialogue strings (speaker label removed) in order.

    Raises:
        ValueError: If a line does not start with the speaker label expected
            at its position.
    """
    speakers = ("Alice", "Bob")

    end_index = unfiltered_conversation.find("[END]")
    if end_index != -1:
        unfiltered_conversation = unfiltered_conversation[:end_index]

    # Leading whitespace is removed so an indented "  Bob: ..." is still
    # recognised; trailing text is left exactly as generated.
    lines = [line.lstrip() for line in unfiltered_conversation.split('\n') if line.strip() != '']

    # Skip any preamble before the first line that begins with a speaker label.
    for i, line in enumerate(lines):
        if line.startswith(("Alice: ", "Bob: ")):
            lines = lines[i:]
            break

    character_dialogues = {"Alice": [], "Bob": []}
    for i, line in enumerate(lines):
        current_character = speakers[i % 2]
        label = f'{current_character}: '
        # Require the label as a prefix: a line that merely mentions the
        # expected speaker somewhere must not be accepted or mis-filed.
        if not line.startswith(label):
            raise ValueError(f"The current line doesn't start with {current_character}: {line}")
        character_dialogues[current_character].append(line[len(label):])
    return character_dialogues

In [17]:
def append_to_playscripts_csv(generated_playscripts_path, model_name, temperature, seed, num_generations, premise, unfiltered_conversation, alice_dialogues, bob_dialogues, error_message):
    """Append one generation record to the playscripts CSV.

    The file is opened in append mode. When it is empty, a header row is
    written before the record. Newlines in `unfiltered_conversation` are
    stored as literal backslash-n sequences so the raw text stays on one line.

    Args:
        generated_playscripts_path: CSV file to append to.
        model_name, temperature, seed, num_generations, premise: Run metadata
            written verbatim into the record.
        unfiltered_conversation: Raw generated text.
        alice_dialogues, bob_dialogues: Parsed dialogue values for each character.
        error_message: Error text for the record.
    """
    header = [
        "model_name", "temperature", "seed", "num_generations", "premise",
        "unfiltered_conversation", "alice_dialogues", "bob_dialogues", "error_message",
    ]
    record = [
        model_name, temperature, seed, num_generations, premise,
        unfiltered_conversation.replace('\n', '\\n'),
        alice_dialogues, bob_dialogues, error_message,
    ]
    with open(generated_playscripts_path, mode='a', newline='') as file:
        # A position of 0 right after opening for append means the file is empty.
        pending_rows = [record] if file.tell() != 0 else [header, record]
        csv.writer(file).writerows(pending_rows)

In [18]:
def generate_playscripts(premises, model_name, num_generations=4, temperature=0.7, seed=42, example_playscripts_path="data/example_playscripts.csv", generated_playscripts_path="data/generated_playscripts.csv"):
    """Generate conversations for each premise and append them to a CSV.

    For every premise a prompt is built, `num_generations` conversations are
    generated, and each one is parsed into per-character dialogue. Every
    generation is appended to `generated_playscripts_path`: successfully
    parsed ones with their dialogue lists, the others with empty dialogue
    fields and the parse error text.

    Args:
        premises: Sequence of premise strings.
        model_name: Model alias passed to `generate_text`.
        num_generations: Number of conversations to generate per premise.
        temperature: Sampling temperature passed to `generate_text`.
        seed: If not None, seeds torch (CPU and all CUDA devices) and NumPy
            once before generation starts.
        example_playscripts_path: CSV with the few-shot example playscripts.
        generated_playscripts_path: CSV that results are appended to.
    """
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed_all(seed)

    for i, premise in enumerate(premises):
        print(f"Generating playscript {i+1} of {len(premises)}")
        prompt = format_prompt(premise, example_playscripts_path)
        unfiltered_conversations = generate_text(model_name, prompt, num_generations=num_generations, temperature=temperature)
        for unfiltered_conversation in unfiltered_conversations:
            # Only the parsing step is guarded, so a failure while writing the
            # CSV is never mistaken for a malformed conversation.
            try:
                character_dialogues = filter_conversation_by_character(unfiltered_conversation)
                alice_dialogues = character_dialogues["Alice"]
                bob_dialogues = character_dialogues["Bob"]
                error_message = ""
            except ValueError as e:
                alice_dialogues, bob_dialogues, error_message = "", "", str(e)
            except KeyError as e:
                # A KeyError from the parser (e.g. an unexpected speaker
                # label) is recorded too, so one malformed generation cannot
                # abort the remaining premises.
                alice_dialogues, bob_dialogues, error_message = "", "", f"KeyError: {e}"
            append_to_playscripts_csv(generated_playscripts_path, model_name, temperature, seed, num_generations, premise, unfiltered_conversation, alice_dialogues, bob_dialogues, error_message)

In [5]:
# Model alias and data locations for this run.
model_name = "llama3_8b_instruct"
test_premises_path = "data/test_premises.csv"
example_playscripts_path = "data/example_playscripts.csv"
generated_playscripts_path = "data/generated_playscripts.csv"

# Only the first premise in the file is used here.
test_premises = read_csv(test_premises_path)[0:1]

generate_playscripts(
    test_premises,
    model_name,
    num_generations=4,
    temperature=0.7,
    seed=42,
    example_playscripts_path=example_playscripts_path,
    generated_playscripts_path=generated_playscripts_path,
)



In [8]:
# Manual repair of the generated CSV: file row 83 threw an error when
# generating, so it is overwritten with a copy of file row 82. The result is
# saved to a new file; the original CSV is left untouched.
bad_row_index = 82          # file row 83 in 0-based indexing (index 0 is the header row)
replacement_row_index = 81  # file row 82 in 0-based indexing

with open(generated_playscripts_path, mode='r', newline='') as file:
    reader = csv.reader(file)
    rows = list(reader)

# Fail with an actionable message instead of a bare "index out of range" when
# the CSV does not contain the row this repair was written for.
if len(rows) <= bad_row_index:
    raise IndexError(
        f"{generated_playscripts_path} has only {len(rows)} rows; "
        f"cannot replace row {bad_row_index + 1} with row {replacement_row_index + 1}"
    )

# Copy the row so the two entries do not share one list object.
rows[bad_row_index] = list(rows[replacement_row_index])

# Save the edited version to a new CSV file
generated_playscripts_edited_path = "data/generated_playscripts_edited.csv"
with open(generated_playscripts_edited_path, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerows(rows)
