In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import pandas as pd
from tqdm.auto import tqdm
from os.path import join
import numpy as np

# Notebook configuration.
# NOTE(review): these are absolute paths under one user's home directory; anyone
# else running the notebook has to edit them first.
model_dir = '/home/leiyu/projects/def-yangxu/leiyu/LMs/'  # parent directory of the local model checkpoints
data_dir = '/home/leiyu/projects/def-yangxu/leiyu/noisy-hall-interp/data/'  # directory holding the NQ CSV read below
model_name = 'Llama-2-7b-chat-hf'  # checkpoint sub-directory, joined onto model_dir when loading

In [1]:
# Few-shot prompt used when a retrieved context is supplied.
# Placeholders (positional, filled with str.format): 1) context, 2) question.
# Fixes relative to the previous version:
#   * the answers of the first two demonstrations were swapped (the "big red one"
#     question was paired with "western Caribbean Sea" and the Cayman Islands
#     question with "its shoulder patch");
#   * the final slot was labelled "Context :" while every demonstration uses
#     "Context:"; the label now matches the demonstrations.
prompt_template_with_ctx = """Task instruction: answer the given question with provided context that can help you arrive at the answer before the question. Keep the answer short and concise.

##### Examples #####

Context: 16th Infantry Regiment (United States): As part of the new 1st Expeditionary Division, soon to become known as the ‘Big Red One’, the 16th Infantry, commanded by William Herbert Allaire Jr., sailed 
Question: how did the big red one get its name 
The answer is: its shoulder patch

Context: Module:Location map/data/Cayman Islands: Module:Location map/data/Cayman Islands is a location map definition used to overlay markers and labels on an equirectangular projection map of Cayman
Question: where are the cayman islands on the map
The answer is: western Caribbean Sea

Context: First Battle of Seoul: The First Battle of Seoul, known in North Korean historiography as the Liberation of Seoul, was the North Korean capture of the South Korean capital, Seoul,
Question: who won the war between north korea and south korea
The answer is: technically still at war

Context: It’s Always Sunny in Philadelphia (season 13): The thirteenth season of the American comedy television series It’s Always Sunny in Philadelphia premiered on FXX on September 5, 2018.
Question: when does it’s always sunny in philadelphia season 13 start
The answer is: September 5, 2018

Context: Randy Newman – You’ve Got a Friend in Me Lyrics: ‘You’ve Got A Friend In Me’ is the theme song of the Toy Story franchise, recurring throughout the series in different contexts. It’s first
Question: who sang you got a friend in me from toy story
The answer is: Randy Newman

##### Follow the instructions and the example(s) above #####

Context: {}
Question: {}
The answer is:"""


# Few-shot prompt used when no context is supplied (same demonstrations, context
# lines removed). Placeholder (positional): 1) question.
# The same swapped-answer fix is applied to the first two demonstrations.
prompt_template_no_ctx = """
Task instruction: answer the given question based on your world knowledge. Keep the answer short and concise. 

##### Examples #####

Question: how did the big red one get its name 
The answer is: its shoulder patch

Question: where are the cayman islands on the map
The answer is: western Caribbean Sea

Question: who won the war between north korea and south korea
The answer is: technically still at war

Question: when does it’s always sunny in philadelphia season 13 start
The answer is: September 5, 2018

Question: who sang you got a friend in me from toy story
The answer is: Randy Newman

##### Follow the instructions and the example(s) above #####

Question: {}
The answer is:"""


Given the following question, answer it by providing follow-up questions and intermediate answers.
If intermediate questions are not necessary, answer the question directly.
Evidence that can help you arrive at the answer is provided before the question.



In [35]:
def prepare_inputs_with_context(nq_row):
    """Build all prompts for a single NQ row.

    Args:
        nq_row: row-like object indexable by 'question' and
            'retrieved contexts' (the latter is iterated over).

    Returns:
        list: the context-free prompt at index 0, followed by one
        context-augmented prompt per retrieved context, in iteration order.
    """
    query = nq_row['question']
    passages = nq_row['retrieved contexts']
    # Index 0 is always the query without any context.
    closed_book_prompt = prompt_template_no_ctx.format(query)
    contextual_prompts = [
        prompt_template_with_ctx.format(passage, query) for passage in passages
    ]
    return [closed_book_prompt] + contextual_prompts

In [2]:
# Load the tokenizer and the causal LM from the local checkpoint directory.
model_path = join(model_dir, model_name)
device = torch.device('cuda')

# The EOS token doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

# Instantiate the model, move it to the GPU and switch to evaluation mode.
model = AutoModelForCausalLM.from_pretrained(model_path)
model = model.to(device)
model.eval();

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [11]:
from ast import literal_eval

# Read the NQ dev questions together with their retrieved contexts.
# The 'answer' and 'retrieved contexts' cells are parsed with literal_eval so
# they come back as Python objects instead of raw strings.
# The path is assembled with join() (as for model_path) so it does not depend on
# data_dir ending with a path separator.
nq_df_with_ret_ctx = pd.read_csv(
    join(data_dir, 'nq-dev_with_ret.csv'),
    converters={'answer': literal_eval, 'retrieved contexts': literal_eval},
)
nq_df_with_ret_ctx

Unnamed: 0,question,answer,retrieved contexts
0,when does the new my hero academia movie come out,"[July 5 , 2018]",[My Hero Academia: World Heroes' Mission was r...
1,who plays letty in bring it on all or nothing,[Francia Raisa],[Bring It On: All or Nothing (previously known...
2,when did wesley leave last of the summer wine,[2002],[Gordon Wharmby (6 November 1933 – 18 May 2002...
3,who introduced the system of civil services in...,[Charles Cornwallis],"[The Indian Civil Service (ICS), officially kn..."
4,who made the first telephone in the world,[Alexander Graham Bell],[Alexander Graham Bell was a Scottish-born inv...
...,...,...,...
1459,who sings the song i 'm just a love machine,[The Miracles],"[""Love Machine"" is a 1975 single recorded by M..."
1460,who plays aunt carol in dear dumb diary,[Laura Bell Bundy],[Dear Dumb Diary is a Hallmark Channel televis...
1461,when was netball first in the commonwealth games,[1998],[Netball was one of the sports contested at th...
1462,where does anything you can do i can do better...,[Annie Get Your Gun],[Irving Berlin was an American composer and ly...


In [None]:
# Run the LM on every prompt of every NQ row.
# ans[i] holds the decoded generations for row i, in the order produced by
# prepare_inputs_with_context: index 0 is the no-context prompt, the remaining
# entries correspond to the retrieved contexts.
# NOTE(review): no decoding arguments besides max_new_tokens are passed, so the
# decoding strategy comes from the checkpoint's generation config; if that
# enables sampling, outputs differ between runs unless a seed is set — confirm.
ans = []
for _, row in tqdm(nq_df_with_ret_ctx.iterrows(), total=nq_df_with_ret_ctx.shape[0]):
    prompts_row = prepare_inputs_with_context(row)
    ans_row = []
    for prompt in prompts_row:
        encoded = tokenizer(prompt, return_tensors='pt').to(device)
        with torch.no_grad():
            # The pad token was set to the EOS token, so generate() cannot reliably
            # infer which positions are padding; hand it the tokenizer's attention
            # mask explicitly instead of only the input ids.
            outputs = model.generate(
                input_ids=encoded.input_ids,
                attention_mask=encoded.attention_mask,
                max_new_tokens=50,
            )
            # NOTE(review): the whole returned sequence is decoded, which for a
            # causal LM presumably includes the prompt text as well — downstream
            # code has to strip it; verify.
            gen = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
            ans_row.append(gen)
        torch.cuda.empty_cache()
    ans.append(ans_row)