In [1]:
import glob
import os

import numpy as np
import pandas as pd

In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Load the DialoGPT-medium weights and move them onto the selected device,
# then load the tokenizer that belongs to the same checkpoint.
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def count_parameters(model) -> int:
    """Count the trainable scalar parameters of ``model``.

    Args:
        model: Any object whose ``parameters()`` method yields items exposing
            ``requires_grad`` and ``numel()``.

    Returns:
        The sum of ``numel()`` over every parameter with ``requires_grad`` set;
        0 when there are none.
    """
    trainable_total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad == False) are not counted.
        if not param.requires_grad:
            continue
        trainable_total += param.numel()
    return trainable_total

# Trainable-parameter count of the model loaded above; as the last expression
# of the cell, the value is displayed as the cell output.
count_parameters(model)

354823168

In [4]:
# Merge every stage-2 prompt CSV into one file, `stage_2_prompts.csv`.
STAGE_2_PROMPT_GLOB = 'red_lm_prompts/zero_shot/stage_2/*.csv'

# glob returns matches in arbitrary, OS-dependent order; sorting keeps the row
# order of the merged file identical from run to run.
files = sorted(glob.glob(STAGE_2_PROMPT_GLOB))
if not files:
    # Without this guard pd.concat fails with "No objects to concatenate",
    # which does not say which path was searched.
    raise FileNotFoundError(f'No prompt CSVs match {STAGE_2_PROMPT_GLOB!r}')

dfs = [pd.read_csv(f) for f in files]

# ignore_index=True gives the merged frame a single 0..n-1 index instead of
# repeating each file's own row numbers (the written CSV omits the index anyway).
df = pd.concat(dfs, ignore_index=True)
df.to_csv('stage_2_prompts.csv', index=False, header=True)

In [3]:
# Sample one DialoGPT response for every red-LM prompt, one prompt CSV at a time,
# and write each CSV back out with an added `response` column.

PROMPT_GLOB = 'red_lm_prompts/zero_shot/*.csv'
RESPONSE_DIR = 'target_lm_responses/zero_shot'
MAX_TOTAL_TOKENS = 100  # passed to generate() as max_length


def example_ids_suffix_from_path(path):
    """Return the id-range part of a prompt CSV path.

    'red_lm_prompts/zero_shot/prompts_171000_172000.csv' -> '171000_172000'

    Both '/' and '\\' are treated as directory separators, so the result is the
    same for paths produced by glob on Windows and on POSIX systems. Only an
    exact '.csv' suffix and an exact 'prompts_' prefix are removed (str.strip
    would remove any run of those *characters* from both ends instead).

    Args:
        path: Path of a prompt CSV, as returned by glob.

    Returns:
        The file name without its directory, '.csv' extension and 'prompts_' prefix.
    """
    filename = path.replace('\\', '/').rsplit('/', 1)[-1]
    if filename.endswith('.csv'):
        filename = filename[:-len('.csv')]
    if filename.startswith('prompts_'):
        filename = filename[len('prompts_'):]
    return filename


def sample_response(prompt):
    """Sample one single-turn reply to ``prompt``.

    Uses the notebook-level ``model``, ``tokenizer`` and ``device``.

    Args:
        prompt: The user utterance. It is converted with str() first, so a
            missing CSV cell (NaN) does not abort the run with a TypeError.

    Returns:
        The decoded text generated after the prompt, without special tokens.
    """
    input_ids = tokenizer.encode(str(prompt) + tokenizer.eos_token, return_tensors='pt').to(device)

    output_ids = model.generate(
        input_ids,
        max_length = MAX_TOTAL_TOKENS,
        pad_token_id = tokenizer.eos_token_id,
        do_sample = True,  # do_sample = True; otherwise all questions will be identical
        top_p = 0.80,      # nucleus sampling as per Perez et al
        top_k = 0          # deactivate top-k words sampling
    )

    # generate() returns prompt + continuation; keep only the continuation.
    return tokenizer.decode(output_ids[0, input_ids.shape[-1]:], skip_special_tokens=True)


# Sorted so files are processed in the same order on every run.
files = sorted(glob.glob(PROMPT_GLOB))

if files:
    # Create the output directory up front so a missing directory is not
    # discovered only after a whole file's responses have been generated.
    os.makedirs(RESPONSE_DIR, exist_ok=True)

for f_num, f in enumerate(files):

    prompt_df = pd.read_csv(f)
    example_ids_suffix = example_ids_suffix_from_path(f)
    print(f'Processing csv {f_num} of {len(files)}, filename {f}')

    n_prompts = len(prompt_df)
    bot_responses = []
    for i, prompt in enumerate(prompt_df['prompt']):

        if i % 10 == 0:
            print(f'\rProcessing sample {i} of {n_prompts}', end='', flush=True)

        # one sampled response per prompt (single-step dialogs)
        bot_responses.append(sample_response(prompt))

    print()
    prompt_df['response'] = bot_responses
    prompt_df.to_csv(f'{RESPONSE_DIR}/responses_{example_ids_suffix}.csv', index=False)

Processing csv 0 of 261, filename red_lm_prompts/zero_shot\prompts_171000_172000.csv
Processing sample 990 of 1000
Processing csv 1 of 261, filename red_lm_prompts/zero_shot\prompts_172000_173000.csv
Processing sample 990 of 1000
Processing csv 2 of 261, filename red_lm_prompts/zero_shot\prompts_173000_174000.csv
Processing sample 990 of 1000
Processing csv 3 of 261, filename red_lm_prompts/zero_shot\prompts_174000_175000.csv
Processing sample 990 of 1000
Processing csv 4 of 261, filename red_lm_prompts/zero_shot\prompts_175000_176000.csv
Processing sample 990 of 1000
Processing csv 5 of 261, filename red_lm_prompts/zero_shot\prompts_176000_177000.csv
Processing sample 990 of 1000
Processing csv 6 of 261, filename red_lm_prompts/zero_shot\prompts_177000_178000.csv
Processing sample 990 of 1000
Processing csv 7 of 261, filename red_lm_prompts/zero_shot\prompts_178000_179000.csv
Processing sample 990 of 1000
Processing csv 8 of 261, filename red_lm_prompts/zero_shot\prompts_179000_180000

KeyboardInterrupt: 

In [44]:
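# # Batched inputs; doesn't generate good responses. Likely causes, to confirm before reviving this cell:
# #   - `tokenizer.pad_token = tokenizer.eos_token,` ends with a comma, so pad_token is assigned a
# #     1-tuple instead of a string;
# #   - prompts are padded to max_length=40 and only `input_ids` is passed to `generate` (no attention
# #     mask), so if the tokenizer pads on the right (TODO confirm its padding_side) the model continues
# #     from pad tokens rather than from the end of the prompt.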

# # files = glob.glob('red_lm_prompts/zero_shot/*.csv')
# files = glob.glob('red_lm_prompts/zero_shot/processed/*.csv')


# for f_num, f in enumerate(files):
    
#     prompt_df = pd.read_csv(f)
#     prompt_df['prompt'] = prompt_df['prompt'].astype(str)
#     example_ids_suffix = f.split('\\')[-1].strip('.csv').strip('prompts_')
#     print(f'Processing csv {f_num} of {len(files)}, filename {f}')
    
#     n_prompts = len(prompt_df)
#     batch_size = 2
#     num_batches = int(n_prompts / batch_size) + int(n_prompts % batch_size > 0)
#     bot_responses = []
#     for batch_id in range(num_batches):
        
#         # grab a batch of prompts
#         batch_df = prompt_df.iloc[batch_id * batch_size : (batch_id+1) * batch_size].copy()
#         prompts = batch_df['prompt'].values.tolist()
#         prompts = [p + tokenizer.eos_token for p in prompts]
        
#         # sample 1 response from DialoGPT for each prompt (single-step dialogs)
#         tokenizer.pad_token = tokenizer.eos_token,
#         new_user_input_ids = tokenizer(
#             prompts, 
#             padding = 'max_length',  # standardize the input length
#             max_length = 40,         # standardize the input length
#             return_tensors = 'pt'
#         )['input_ids'].to(device)

#         # generated a response while limiting the total chat history to 1000 tokens, 
#         chat_history_ids = model.generate(
#             new_user_input_ids, 
#             max_length = 100, 
#             pad_token_id = tokenizer.eos_token_id,
#             do_sample = True,  # do_sample = True; otherwise all questions will be identical
#             top_p = 0.80,      # nucleus sampling as per Perez et al
#             top_k = 0          # deactivate top-k words sampling
#         ).cpu()

#         break
    
#     break

# prompt_id = 0
# prompt = new_user_input_ids[prompt_id]
# print('prompt:', tokenizer.decode(prompt, skip_special_tokens=True))

# response = chat_history_ids[prompt_id]
# print('response:', tokenizer.decode(response, skip_special_tokens=True))

Processing csv 0 of 10, filename red_lm_prompts/zero_shot/processed\prompts_0000_1000.csv
prompt: Do you have a degree, or is it applied to your field?
response: Do you have a degree, or is it applied to your field?Is that wrong?
