In [1]:
# import submitit
import os
from fastchat.model import get_conversation_template
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

In [2]:
# Open-ended prompts meant to elicit creative generations (one prompt per line).
creative_prompts = [
    "Write a poem",
    "Tell me a joke",
    "Describe the feeling of love",
    "Write a story starting with 'Once upon a time...'",
    "Tell a story about a dog",
    "Write a song",
    "Write a poem about a robot",
    "Invent an original recipe",
    "Imagine a new object and describe what it looks like.",
    "Imagine a new philosophy and describe it.",
    "Create a new game and explain the rules.",
    "Write a new myth explaining the origin of rainbows.",
    "Write a dialogue between the moon and the sun",
    "Compose a lullaby",
    "Write a news headline for the year 2050.",
    "Invent a riddle and write it down.",
    "Write a story about two people seeing each other for the first time.",
    "Write a story about a person who is afraid of the dark.",
    "Make a new pun about llamas.",
    "Invent a new word and define it.",
]

# Short-answer factual questions (one prompt per line).
factual_prompts = [
    "What is the capital of France?",
    "How is H2O commonly known?",
    "What is the largest country in the world?",
    "How many days are in a year?",
    "What is the largest planet in the solar system?",
    "What is the largest animal in the world?",
    "How do you say hello in Spanish?",
    "Who won the 2018 World Cup?",
    "What is the biggest city in Europe?",
    "What is the largest country in Africa?",
    "What was the last battle of Napoleon?",
    "How do you call someone from New Zealand?",
    "How do you call someone who studies plants?",
    "Who invented the telephone?",
    "What mammal lays eggs?",
    "Which bone is the longest in the human body?",
    "What is the anthem of France?",
    "Who wrote Cannery Row?",
    "Who was the first president of the United States?",
    "Which painter painted the Mona Lisa?",
]



# Factual prompts whose answers are longer.
# Not used right now: no generation loop in this notebook iterates over this list.
factual_prompts_longer = [
    "What separates prehistory from history?",
    "How were hieroglyphics deciphered?",  # spelling fixed (was "hyerogliphics")
]

# Sanity check: show how many prompts each set contains, one count per line.
print(len(creative_prompts), len(factual_prompts), sep="\n")

20
20


In [3]:
# https://github.com/facebookresearch/llama
# https://github.com/facebookresearch/llama/blob/main/llama/generation.py 
# chat_completion style for llama2-chat
# https://github.com/facebookresearch/llama/blob/main/example_chat_completion.py 
def format_prompt_llama2_chat(prompt):
    """Wrap `prompt` in the Llama-2 chat template with a fixed system message.

    The template is assembled with explicit newlines instead of an indented
    triple-quoted string, so that no source-code indentation leaks into the
    text sent to the model (previously the system message, the closing
    ``<</SYS>>`` tag and the user prompt were each preceded by four spaces).

    Args:
        prompt: The user instruction to insert into the template.

    Returns:
        The full prompt string, ending with ``[/INST]``.
    """
    system_message = (
        "You are a helpful, respectful and honest assistant. "
        "Always answer without asking questions or clarifications."
    )
    # NOTE(review): the literal "<s>" is kept for backward compatibility. Tokenizers that
    # prepend a BOS token on their own may then produce a duplicated BOS -- confirm against
    # the tokenizer settings used by the caller.
    return f"<s>[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n{prompt} [/INST]"

# https://arxiv.org/abs/2204.05862
# https://huggingface.co/datasets/Anthropic/hh-rlhf
# https://huggingface.co/datasets/Dahoas/static-hh
def format_prompt_pythia_helpful(prompt):
    """Return `prompt` wrapped in a single-turn "Human: ... Assistant: " dialogue frame."""
    return f"Human: {prompt} Assistant: "

def format_prompt_PLM(prompt):
    """Return `prompt` followed by the continuation cue " Okay, here goes: "."""
    return f"{prompt} Okay, here goes: "

In [4]:
# Sampling temperatures 0.1, 0.2, ..., 1.5 (fifteen values).
temperatures = [tenths / 10. for tenths in range(1, 16)]

In [5]:
# For each temperature and for each prompt, n_generations samples are generated per model.
n_generations = 25
models = ["llama2-chat"]
temperatures = [tenths / 10. for tenths in range(1, 16)]

# Result grids indexed as [temperature, prompt, model]. dtype=object because each cell is
# later assigned a Python list of completions rather than a number.
completions_creative = np.zeros(
    shape=(len(temperatures), len(creative_prompts), len(models)),
    dtype=object,
)
completions_factual = np.zeros(
    shape=(len(temperatures), len(factual_prompts), len(models)),
    dtype=object,
)

# define the function to be submitted
# def generate_samples(args):
# Cache of loaded (tokenizer, model) pairs keyed by (checkpoint, device). Without it the
# checkpoint is read from disk and copied to the GPU again for every single prompt.
_LOADED_MODELS = {}


def _load_model_and_tokenizer(checkpoint, device):
    """Return the (tokenizer, model) pair for `checkpoint` on `device`, loading it on first use only."""
    cache_key = (checkpoint, device)
    if cache_key not in _LOADED_MODELS:
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)
        _LOADED_MODELS[cache_key] = (tokenizer, model.to(device))
    return _LOADED_MODELS[cache_key]


def generate_samples(prompt, temperatures, model_name, max_return_sequences=5,
                     max_new_tokens=70, device="cuda:0"):
    """Sample completions of `prompt` at every temperature in `temperatures`.

    The number of samples per temperature is read from the module-level
    `n_generations`. Samples are drawn in batches of `max_return_sequences`
    for memory reasons; a final smaller batch covers any remainder so that
    exactly `n_generations` completions are always returned.

    Args:
        prompt: Raw user prompt; it is wrapped in the model-specific chat template.
        temperatures: Iterable of sampling temperatures.
        model_name: Key selecting the checkpoint and prompt format. Only
            "llama2-chat" is currently supported.
        max_return_sequences: Maximum number of sequences sampled per `generate` call.
        max_new_tokens: Maximum number of tokens generated after the prompt.
        device: Torch device on which the model and the inputs are placed.

    Returns:
        A list with one entry per temperature, each entry being a list of
        `n_generations` decoded completions (prompt removed, special tokens skipped).

    Raises:
        ValueError: If `model_name` is not supported or `max_return_sequences` < 1.
    """
    # Resolve the model first so that an unsupported name fails fast with a clear error
    # instead of an UnboundLocalError further down.
    if model_name == "llama2-chat":
        checkpoint = "meta-llama/Llama-2-7b-chat-hf"
        format_prompt = format_prompt_llama2_chat
    # A "vicuna1.5" branch (checkpoint "lmsys/vicuna-7b-v1.5") was sketched here earlier;
    # enabling it requires its own prompt formatter, adapted from
    # https://github.com/lm-sys/FastChat/blob/a47b8f9e93c8b5a85e81d1ae33e3a1106d8cdf80/fastchat/serve/huggingface_api.py
    else:
        raise ValueError(f"Unsupported model_name {model_name!r}; expected 'llama2-chat'")
    if max_return_sequences < 1:
        raise ValueError("max_return_sequences must be at least 1")

    tokenizer, model = _load_model_and_tokenizer(checkpoint, device)
    input_ids = tokenizer.encode(format_prompt(prompt), return_tensors="pt").to(device)
    prompt_length = input_ids.shape[1]

    # Split n_generations into full batches plus one smaller batch for the remainder, so no
    # samples are silently dropped when n_generations is not a multiple of the batch size.
    full_batches, remainder = divmod(n_generations, max_return_sequences)
    batch_sizes = [max_return_sequences] * full_batches
    if remainder:
        batch_sizes.append(remainder)

    completions = []
    for temperature in temperatures:
        temp_completions = []
        for batch_size in batch_sizes:
            samples = model.generate(input_ids, temperature=temperature,
                                     max_new_tokens=max_new_tokens,
                                     num_return_sequences=batch_size, do_sample=True)
            # Drop the prompt tokens from each sample before decoding.
            temp_completions.extend(
                tokenizer.decode(sample[prompt_length:], skip_special_tokens=True)
                for sample in samples
            )
        completions.append(temp_completions)
    return completions

# create a folder for the logs of the submitted jobs
# os.makedirs("logs", exist_ok=True)


def _collect_completions(prompts, results, model_name, model_index):
    """Generate completions for every prompt and store them in results[temperature, prompt, model]."""
    for prompt_index, prompt in enumerate(prompts):
        per_temperature = generate_samples(prompt, temperatures, model_name)
        for t_index, completion in enumerate(per_temperature):
            results[t_index, prompt_index, model_index] = completion


for model_index, model in enumerate(models):
    # --- Alternative (disabled): run one Slurm job per prompt through submitit ---
    # executor = submitit.AutoExecutor(folder="logs")
    # # exclude margpu002 and margpu003
    # executor.update_parameters(timeout_min=60, slurm_partition="parietal,gpu,normal", gpus_per_node=1,
    #                            exclude="margpu002,margpu003")
    # args_list_creative = [(i, prompt, temperatures, model) for i, prompt in enumerate(creative_prompts)]
    # jobs_creative = executor.map_array(generate_samples, args_list_creative)
    # args_list_factual = [(i, prompt, temperatures, model) for i, prompt in enumerate(factual_prompts)]
    # jobs_factual = executor.map_array(generate_samples, args_list_factual)
    # for i, job in enumerate(jobs_creative):
    #     for t_index, completion in enumerate(job.result()):
    #         completions_creative[t_index, i, model_index] = completion
    # for i, job in enumerate(jobs_factual):
    #     for t_index, completion in enumerate(job.result()):
    #         completions_factual[t_index, i, model_index] = completion

    _collect_completions(creative_prompts, completions_creative, model, model_index)
    # The factual grid must be filled from factual_prompts (it previously iterated
    # creative_prompts, so both grids ended up holding creative completions).
    _collect_completions(factual_prompts, completions_factual, model, model_index)

    # Save this model's results inside the loop, keeping the trailing model axis (length 1).
    # With a single model the files are identical to saving the full arrays after the loop;
    # with several models each one gets its own file instead of all data being written
    # under the name of the last model only.
    model_slice = slice(model_index, model_index + 1)
    np.save(f'{model}_completions_creative_max_length70.npy', completions_creative[:, :, model_slice])
    np.save(f'{model}_completions_factual_max_length70.npy', completions_factual[:, :, model_slice])

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 86.00 MiB. GPU 0 has a total capacty of 10.75 GiB of which 68.44 MiB is free. Process 3977277 has 1.79 GiB memory in use. Process 4034090 has 8.86 GiB memory in use. Of the allocated memory 8.71 GiB is allocated by PyTorch, and 1.65 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [11]:
# NOTE(review): `args_list_creative` is only assigned inside the commented-out submitit code
# of the generation cell, so unless it is defined elsewhere this inspection cell raises
# NameError on a fresh kernel (Restart & Run All).
args_list_creative

[(0,
  'Write a poem',
  [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5],
  'llama2-chat'),
 (1,
  'Tell me a joke',
  [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5],
  'llama2-chat'),
 (2,
  'Describe the feeling of love',
  [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5],
  'llama2-chat'),
 (3,
  "Write a story starting with 'Once upon a time...'",
  [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5],
  'llama2-chat'),
 (4,
  'Tell a story about a dog',
  [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5],
  'llama2-chat'),
 (5,
  'Write a song',
  [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5],
  'llama2-chat'),
 (6,
  'Write a poem about a robot',
  [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5],
  'llama2-chat'),
 (7,
  'Invent an original recipe',
  [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 

In [12]:
# NOTE(review): `jobs_creative` is only assigned inside the commented-out submitit code of the
# generation cell, so unless it is defined elsewhere this inspection cell raises NameError
# on a fresh kernel (Restart & Run All).
jobs_creative

[LocalJob<job_id=2239679, task_id=0, state="FINISHED">,
 LocalJob<job_id=2239683, task_id=0, state="FINISHED">,
 LocalJob<job_id=2239724, task_id=0, state="FINISHED">,
 LocalJob<job_id=2239765, task_id=0, state="FINISHED">,
 LocalJob<job_id=2239806, task_id=0, state="FINISHED">,
 LocalJob<job_id=2239847, task_id=0, state="FINISHED">,
 LocalJob<job_id=2239891, task_id=0, state="FINISHED">,
 LocalJob<job_id=2239932, task_id=0, state="FINISHED">,
 LocalJob<job_id=2239973, task_id=0, state="FINISHED">,
 LocalJob<job_id=2240015, task_id=0, state="FINISHED">,
 LocalJob<job_id=2240066, task_id=0, state="FINISHED">,
 LocalJob<job_id=2240107, task_id=0, state="FINISHED">,
 LocalJob<job_id=2240151, task_id=0, state="FINISHED">,
 LocalJob<job_id=2240191, task_id=0, state="FINISHED">,
 LocalJob<job_id=2240233, task_id=0, state="FINISHED">,
 LocalJob<job_id=2240279, task_id=0, state="FINISHED">,
 LocalJob<job_id=2240320, task_id=0, state="FINISHED">,
 LocalJob<job_id=2240361, task_id=0, state="FINI