In [18]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


# Pretrained DialoGPT-small checkpoint and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")

# Interactive demo: chat with the bot for 5 turns.
# Each user message is terminated with the EOS token and appended to the
# running history; the whole history is fed back to the model every turn.
for turn in range(5):
    user_text = input(">> User:") + tokenizer.eos_token
    user_ids = tokenizer.encode(user_text, return_tensors='pt')
    print('input text:', user_text)
    print('input token ids:', user_ids)

    # the first turn starts the history; later turns extend what the model produced
    if turn == 0:
        bot_input_ids = user_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, user_ids], dim=-1)
    print('bot input ids:', bot_input_ids)

    # generate a reply; max_length caps history + reply at 1000 tokens in total
    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)

    # full transcript (special tokens kept) and just the newly generated reply
    transcript = tokenizer.decode(chat_history_ids[0], skip_special_tokens=False)
    reply = tokenizer.decode(chat_history_ids[0, bot_input_ids.shape[-1]:], skip_special_tokens=True)
    print("History + DialoGPT: {}".format(transcript))
    print("DialoGPT: {}".format(reply))


>> User:Can money buy happiness?
input text: Can money buy happiness?<|endoftext|>
input token ids: tensor([[ 6090,  1637,  2822, 12157,    30, 50256]])
bot input ids: tensor([[ 6090,  1637,  2822, 12157,    30, 50256]])
History + DialoGPT: Can money buy happiness?<|endoftext|>Money can buy happiness.<|endoftext|>
DialoGPT: Money can buy happiness.
>> User:How so?
input text: How so?<|endoftext|>
input token ids: tensor([[ 2437,   523,    30, 50256]])
bot input ids: tensor([[ 6090,  1637,  2822, 12157,    30, 50256, 26788,   460,  2822, 12157,
           764, 50256,  2437,   523,    30, 50256]])
History + DialoGPT: Can money buy happiness?<|endoftext|>Money can buy happiness.<|endoftext|>How so?<|endoftext|>This is the most depressing thing I've read all day.<|endoftext|>
DialoGPT: This is the most depressing thing I've read all day.


KeyboardInterrupt: Interrupted by user

In [1]:
import pandas as pd
pd.set_option('display.max_colwidth', None)

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

In [2]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# DialoGPT-small weights (moved onto `device`) and the matching tokenizer.
checkpoint = "microsoft/DialoGPT-small"
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [3]:
# Raw Reddit conversations: a headerless CSV whose three columns are named here.
conv_df = pd.read_csv(
    'data/reddit_conversations/sample.csv',
    header=None,
    names=['id', 'context', 'response'],
)

In [5]:
import re 
from nltk.tokenize import TweetTokenizer

def preprocess_text(
    txt,
    dataset_eos_token = 'EOS',
    tokenizer_eos_token = '<|endoftext|>'
):
    """Normalize a raw (Reddit/Twitter-style) conversation string.

    Steps: drop a leading "title : " prefix, lowercase, drop hashtag words,
    replace everything from "http" onward in a word with ``__url__``, expand a
    few slash abbreviations, replace characters outside a small whitelist with
    spaces, re-tokenize with NLTK's ``TweetTokenizer`` and collapse whitespace.
    Finally the dataset's turn separator is replaced by the tokenizer's EOS.

    Args:
        txt: conversation text; non-string values are converted with ``str``.
        dataset_eos_token: turn separator used in the raw data. Matching is
            case-insensitive, ignores surrounding whitespace (so ``' EOS '``
            works too) and only applies to whole whitespace-delimited tokens.
        tokenizer_eos_token: string inserted between turns in the output.

    Returns:
        The cleaned text with turns joined by ``tokenizer_eos_token``.
    """
    # coerce first: slicing below would fail on non-strings (e.g. NaN cells)
    txt = str(txt)

    # remove "title : " prefixes
    title_prefix = 'title : '
    if txt.startswith(title_prefix):
        txt = txt[len(title_prefix):]

    txt = txt.lower()

    # url and tag
    words = []
    for word in txt.split():
        if word.startswith('#'):  # don't allow tag
            continue
        i = word.find('http')  # txt is already lower-case
        if i >= 0:
            word = word[:i] + ' ' + '__url__'
        words.append(word.strip())
    txt = ' '.join(words)

    # remove illegal char
    txt = txt.replace(chr(92), '')  # chr(92) = '\'. as twitter has 'b\/c' rather than 'b/c'
    txt = txt.replace("b/c", "because").replace('j/k', 'just kidding').replace('w/o', 'without').replace('w/', 'with')
    # the whitelist below drops '_', so protect the placeholders by swapping them
    # to upper-case words (which cannot otherwise occur after lower()) and back
    txt = re.sub('__mention__', 'MENTION', txt)
    txt = re.sub('__url__', 'URL', txt)
    txt = re.sub(r"[^A-Za-z0-9()\[\]:,.!?'“” ]", " ", txt)
    txt = re.sub('MENTION', '__mention__', txt)
    txt = re.sub('URL', '__url__', txt)

    # re-tokenize (separates punctuation) and collapse any repeated whitespace
    tweet_tokenizer = TweetTokenizer(preserve_case=True)
    txt = ' '.join(' '.join(tweet_tokenizer.tokenize(txt)).split())

    # replace the dataset's turn separator with the tokenizer's EOS. Match whole
    # tokens only: splitting on the raw substring would also break words such as
    # "videos" or "stereos" into separate turns.
    eos = dataset_eos_token.strip().lower()
    utterances = [[]]
    for token in txt.split():
        if token == eos:
            utterances.append([])
        else:
            utterances[-1].append(token)

    return tokenizer_eos_token.join(' '.join(utterance) for utterance in utterances)



# Sanity check: clean a raw multi-turn example, then show how the HF tokenizer
# encodes it (each EOS marker becomes token id 50256 in the output below).
raw_example = "title : I thought the Reddit Admins Promised no more ads with auto play sounds ? ? ? EOS You're supposed to block ads . EOS Not for sites you like . EOS Yes , all sites . Everywhere ."
prepro = preprocess_text(raw_example)
print(prepro)
tokenizer(prepro)

i thought the reddit admins promised no more ads with auto play sounds ? ? ?<|endoftext|>you're supposed to block ads .<|endoftext|>not for sites you like .<|endoftext|>yes , all sites . everywhere .


{'input_ids': [72, 1807, 262, 18374, 44563, 8072, 645, 517, 9011, 351, 8295, 711, 5238, 5633, 5633, 5633, 50256, 5832, 821, 4385, 284, 2512, 9011, 764, 50256, 1662, 329, 5043, 345, 588, 764, 50256, 8505, 837, 477, 5043, 764, 8347, 764], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [6]:
%load_ext autoreload
%autoreload 2

from dataset import DialogLMDataset
from torch.utils.data import DataLoader
from collate_fns import DialogCollate, DialogCollateExperimental

In [7]:
# Training set built from conv_df: the 'context' column (turns separated by
# ' EOS ') is the source and 'response' the target; items come back as dicts
# with 'src' / 'tgt' token-id lists (see the printed example).
dataset_train = DialogLMDataset(
    df=conv_df,
    tokenizer=tokenizer,
    src_col='context',
    src_eos_token=' EOS ',
    tgt_col='response',
)
print(dataset_train[-1])

# round-trip the last example's source ids back to text
tokenizer.decode(dataset_train[-1]['src'])

Processing row: 0 of 252{'src': [72, 1807, 262, 18374, 44563, 8072, 645, 517, 9011, 351, 8295, 711, 5238, 5633, 5633, 5633, 50256, 5832, 821, 4385, 284, 2512, 9011, 764, 50256, 1662, 329, 5043, 345, 588, 764, 50256, 8505, 837, 477, 5043, 764, 8347, 764], 'tgt': [568, 345, 2138, 1414, 329, 262, 5043, 788, 4379, 9011, 5633]}


"i thought the reddit admins promised no more ads with auto play sounds???<|endoftext|>you're supposed to block ads.<|endoftext|>not for sites you like.<|endoftext|>yes, all sites. everywhere."

In [8]:
# Wrap the training set in a DataLoader with the experimental collate function
# and print the tensors of one collated batch.
import itertools

# collate object
# NOTE(review): -100 is presumably the loss ignore_index for masked target
# positions — confirm in collate_fns.
collate_fn = DialogCollateExperimental(
    tokenizer = tokenizer,
    max_len = 512,
    _targets_ignore_index = -100,
    _pad_token_id = tokenizer.eos_token_id  # EOS doubles as padding: id 50256, not 0
)

# dataloader
# NOTE(review): despite the name, `loader_val` iterates `dataset_train`.
loader_val = DataLoader(
    dataset_train,
    batch_size = 1,
    shuffle = False,
    collate_fn = collate_fn
)

# example usage: fetch batch #3 directly instead of looping and breaking
example_batch_index = 3
batch = next(itertools.islice(loader_val, example_batch_index, None))

example_id = 0
print('input_ids:\n',       batch['input_ids'][example_id])
# NOTE(review): in the saved output the last position id resets to 0 instead of
# continuing the sequence — check how DialogCollateExperimental builds it.
print('position_ids:\n',    batch['position_ids'][example_id])
print('token_type_ids:\n',  batch['token_type_ids'][example_id])
print('target_ids:\n',      batch['target_ids'][example_id])
print('attention_masks:\n', batch['attention_masks'][example_id])

tokenizer.decode(batch['input_ids'][example_id])

input_ids:
 tensor([  404,   259,   507,   319,   262,   285,  1415,   290,   308,    18,
        36237,  5633, 50256,   270,  2936, 12178,    88,   290, 21873,   837,
          290,  1312,   550,  5876,  9008,  6670,   326,   338,   780,   262,
          936,   519, 46733, 22523,   837,   340,   338,  1464, 12178,    88,
          290, 21873, 50256,  1169,   936,   519,   318,  9623,   319,   749,
         3777,   764,  1312,   460, 16465,   680,   661,   379, 26160,  2837,
          351,   281,   895,    70,   351,   340,   764,   655,   407,   351,
          262,   285,  1415,   837,   329,   617,  1738,   764, 50256])
position_ids:
 tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
        72, 73, 74, 75, 76, 77,  0]

"opinions on the m14 and g3 optics?<|endoftext|>it felt floaty and inaccurate, and i had trouble hitting targets that's because the acog fuckin sucks, it's always floaty and inaccurate<|endoftext|>the acog is fantastic on most weapons. i can demolish people at sniper range with an smg with it. just not with the m14, for some reason.<|endoftext|>"

In [178]:
# generate responses to each example in training set
#
# Each pass truncates a conversation from `conv_df` after a random number of
# turns (possibly zero -> unconditional generation), lets DialoGPT sample the
# next utterance, and keeps non-empty results. Samples are written to CSV in
# chunks of SAMPLES_PER_FILE, plus one final file for any remainder.
#
# perhaps it's most convenient to have the random cutoffs during the dataset instantiation
# and do a few runs through the dataset to generate random stuff
import os
import random

SEED = 0                  # seeds Python's `random` (cutoffs) and torch (sampling)
SAMPLES_PER_FILE = 1000   # samples written per CSV chunk
OUTPUT_DIR = 'red_lm_prompts/zero_shot'

random.seed(SEED)
torch.manual_seed(SEED)

eos = tokenizer.eos_token  # '<|endoftext|>' for DialoGPT


def _dump_samples(samples, start, end):
    """Write `samples` (dict of equal-length lists) to OUTPUT_DIR/prompts_<start>_<end>.csv."""
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print(f"Saving {len(samples['context'])} samples to csv")
    pd.DataFrame(samples).to_csv(
        os.path.join(OUTPUT_DIR, f'prompts_{start:04d}_{end:04d}.csv'), index=False
    )


n_samples_to_generate = 100
n_valid_samples_so_far = 0
samples = {
    'context': [],
    'response': []
}
chunk_start = 0  # global index of the first sample currently held in `samples`
index_of_example = 0
epoch = 0  # how many times have we made a pass through the data?
while n_valid_samples_so_far < n_samples_to_generate:

    print(f'\rProcessing sample {n_valid_samples_so_far} of {n_samples_to_generate}', end='', flush=True)

    # grab the example: context turns followed by the gold response
    row = conv_df.iloc[index_of_example]
    conversation = preprocess_text(row['context'] + ' EOS ' + row['response'], tokenizer_eos_token=eos)
    utterances = [s.strip() for s in conversation.split(eos)]

    # randomly keep the first k turns, k uniform in [0, len(utterances)]
    cutoff_point = random.randint(0, len(utterances))
    utterances = utterances[:cutoff_point]
    if utterances:
        # every turn, including the last, ends with EOS so the model produces the next turn
        context = eos.join(utterances) + eos
        input_ids = tokenizer(context, add_special_tokens=False, return_tensors="pt")["input_ids"].to(device)
        len_input_ids = input_ids.shape[-1]
    else:  # k == 0: do unconditional generation
        context = ''
        input_ids = None
        len_input_ids = 0

    # model continues conversation
    outputs = model.generate(
        inputs = input_ids,
        max_length = 1000,
        pad_token_id = tokenizer.eos_token_id,
        do_sample = True,  # do_sample = True; otherwise all questions will be identical
        top_p = 0.95,      # nucleus sampling
        top_k = 0          # deactivate top-k words sampling
    ).cpu()
    # keep only the newly generated tokens (generate echoes the prompt first)
    output_bot_only = outputs[0][len_input_ids:]
    output_bot_only_decoded = tokenizer.decode(output_bot_only, skip_special_tokens=True)

    # clean the response and remove EOS markers at either end as whole strings
    # (str.strip(eos) would strip any of the characters "<|endoftx>" instead)
    output_preproc = preprocess_text(output_bot_only_decoded, tokenizer_eos_token=eos).strip()
    response_reformat = output_preproc.removeprefix(eos).removesuffix(eos).strip()

    # validate; if nothing usable was generated, retry the same example
    # (a new random cutoff and new sampling are drawn on the next iteration)
    if not response_reformat:
        continue

    # reformat to the dataset's ' EOS '-separated convention and save
    context_reformat = ' EOS '.join(utterances)
    samples['context'].append(context_reformat)
    samples['response'].append(response_reformat)
    n_valid_samples_so_far += 1

    # dump every SAMPLES_PER_FILE samples, then start a fresh chunk
    if n_valid_samples_so_far % SAMPLES_PER_FILE == 0:
        _dump_samples(samples, chunk_start, n_valid_samples_so_far)
        samples = {'context': [], 'response': []}
        chunk_start = n_valid_samples_so_far

    # increment example, wrapping around (and counting passes) at the end of the data
    if index_of_example == len(conv_df) - 1:
        epoch += 1
    index_of_example = (index_of_example + 1) % len(conv_df)

# save the final, partial chunk (the whole run when it is shorter than SAMPLES_PER_FILE)
if samples['context']:
    _dump_samples(samples, chunk_start, n_valid_samples_so_far)
    
    

Processing sample 99 of 100