In [2]:
# !pip install ipywidgets

Collecting ipywidgets
  Downloading ipywidgets-7.7.1-py2.py3-none-any.whl (123 kB)
Collecting widgetsnbextension~=3.6.0
  Downloading widgetsnbextension-3.6.1-py2.py3-none-any.whl (1.6 MB)
Collecting jupyterlab-widgets>=1.0.0
  Downloading jupyterlab_widgets-1.1.1-py3-none-any.whl (245 kB)
Installing collected packages: widgetsnbextension, jupyterlab-widgets, ipywidgets
Successfully installed ipywidgets-7.7.1 jupyterlab-widgets-1.1.1 widgetsnbextension-3.6.1


In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch


tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")

# Let's chat for 5 lines
for step in range(5):
    # encode the new user input, add the eos_token and return a tensor in Pytorch
    new_user_input_text = input(">> User:") + tokenizer.eos_token
    new_user_input_ids = tokenizer.encode(new_user_input_text, return_tensors='pt')
    print('input text:', new_user_input_text)
    print('input token ids:', new_user_input_ids)

    # append the new user input tokens to the chat history
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
    print('bot input ids:', bot_input_ids)

    # generated a response while limiting the total chat history to 1000 tokens, 
    chat_history_ids = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id)
    
    # pretty print last ouput tokens from bot
    print("History + DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:,:][0], skip_special_tokens=False)))
    print("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))


Downloading:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/641 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/335M [00:00<?, ?B/s]

>> User:Hi
input text: Hi<|endoftext|>
input token ids: tensor([[17250, 50256]])
bot input ids: tensor([[17250, 50256]])
History + DialoGPT: Hi<|endoftext|>Hi<|endoftext|>
DialoGPT: Hi
>> User:How are you
input text: How are you<|endoftext|>
input token ids: tensor([[ 2437,   389,   345, 50256]])
bot input ids: tensor([[17250, 50256, 17250, 50256,  2437,   389,   345, 50256]])
History + DialoGPT: Hi<|endoftext|>Hi<|endoftext|>How are you<|endoftext|>How are you<|endoftext|>
DialoGPT: How are you
>> User:I'm good
input text: I'm good<|endoftext|>
input token ids: tensor([[   40,  1101,   922, 50256]])
bot input ids: tensor([[17250, 50256, 17250, 50256,  2437,   389,   345, 50256,  2437,   389,
           345, 50256,    40,  1101,   922, 50256]])
History + DialoGPT: Hi<|endoftext|>Hi<|endoftext|>How are you<|endoftext|>How are you<|endoftext|>I'm good<|endoftext|>How are you<|endoftext|>
DialoGPT: How are you
>> User:Do you like chatting
input text: Do you like chatting<|endoftext|>
inpu

In [4]:
import pandas as pd
pd.set_option('display.max_colwidth', None)

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

In [5]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small").to(device)
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")

In [6]:
conv_df = pd.read_csv('out/conv/2010-11.tsv', header=None, delimiter = '\t').sample(500, random_state = 42)
conv_df.columns = ['id', 'context', 'response']

In [7]:
import re 
from nltk.tokenize import TweetTokenizer

def preprocess_text(
    txt,
    dataset_eos_token = 'EOS',
    tokenizer_eos_token = '<|endoftext|>'
):
    # remove "title : " prefixes
    if txt[:8] == 'title : ':
        txt = txt[8:]
    
    txt = str(txt).lower()
    
    # url and tag
    words = []
    for word in txt.split():
        if word[0] == '#': # don't allow tag
            continue
        i = word.lower().find('http')
        if i >= 0:
            word = word[:i] + ' ' + '__url__'
        words.append(word.strip())
    txt = ' '.join(words)

    # remove illegal char
    txt = txt.replace(chr(92),'') # chr(92) = '\'. as twitter has 'b\/c' rather than 'b/c'
    txt = txt.replace("b/c","because").replace('j/k','just kidding').replace('w/o','without').replace('w/','with')
    txt = re.sub('__mention__','MENTION',txt)
    txt = re.sub('__url__','URL',txt)
    txt = re.sub(r"[^A-Za-z0-9()\[\]:,.!?'“” ]", " ", txt)
    txt = re.sub('MENTION','__mention__',txt)
    txt = re.sub('URL','__url__',txt)

    tokenizer = TweetTokenizer(preserve_case=True)
    txt = ' ' + ' '.join(tokenizer.tokenize(txt)) + ' '
    
    # remove un-necessary space
    txt = ' '.join(txt.split())
    
    # replace 'EOS' with tokenizer's EOS
    txt_utterance_split = txt.split(dataset_eos_token.lower())
    txt = tokenizer_eos_token.join([s.strip() for s in txt_utterance_split])
    
    return txt



prepro = preprocess_text("title : I thought the Reddit Admins Promised no more ads with auto play sounds ? ? ? EOS You're supposed to block ads . EOS Not for sites you like . EOS Yes , all sites . Everywhere .")
print(prepro)
tokenizer(prepro)

i thought the reddit admins promised no more ads with auto play sounds ? ? ?<|endoftext|>you're supposed to block ads .<|endoftext|>not for sites you like .<|endoftext|>yes , all sites . everywhere .


{'input_ids': [72, 1807, 262, 18374, 44563, 8072, 645, 517, 9011, 351, 8295, 711, 5238, 5633, 5633, 5633, 50256, 5832, 821, 4385, 284, 2512, 9011, 764, 50256, 1662, 329, 5043, 345, 588, 764, 50256, 8505, 837, 477, 5043, 764, 8347, 764], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [9]:
%load_ext autoreload
%autoreload 2

from dataset import DialogLMDataset
from torch.utils.data import DataLoader
from collate_fns import DialogCollate, DialogCollateExperimental

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
# train sets
dataset_train = DialogLMDataset(
    df = conv_df,
    src_col = 'context',
    src_eos_token = ' EOS ',
    tokenizer = tokenizer,
    tgt_col = 'response',
)
print(dataset_train[-1])

tokenizer.decode(dataset_train[-1]['src'])

Processing row: 0 of 500{'src': [11708, 17084, 783, 8781, 345, 10784, 9813, 284, 19974, 3696, 287, 11824, 329, 663, 18325], 'tgt': [2949, 17084, 2985, 2152, 284, 6216, 428, 989, 764]}


'Google Wave now lets you export waves to zip files in preparation for its shutdown'

In [11]:
# collate object
collate_fn = DialogCollateExperimental(
    tokenizer = tokenizer,
    max_len = 512,
    _targets_ignore_index = -100,
    _pad_token_id = tokenizer.eos_token_id  # 0
)

# dataloader
loader_val = DataLoader(
    dataset_train,
    batch_size = 1,
    shuffle = False,
    collate_fn = collate_fn
)

# example usage
for i, batch in enumerate(loader_val):
    if i == 3:
        break

example_id = 0
print('input_ids:\n',       batch['input_ids'][example_id])
print('position_ids:\n',    batch['position_ids'][example_id])
print('token_type_ids:\n',  batch['token_type_ids'][example_id])
print('target_ids:\n',      batch['target_ids'][example_id])
print('attention_masks:\n', batch['attention_masks'][example_id])

tokenizer.decode(batch['input_ids'][example_id])

input_ids:
 tensor([   38, 35619, 36179, 26730,  2141, 50256,    38, 35619,   318,   407,
          281,  5035,  1573,   329,   326,  2939,   764, 50256,    47, 29528,
        26756, 36179, 26730,  2141,    78, 50256,  5167,   588,  9578,   604,
         5542, 36179, 26730,  2141,    78, 50256,  1212,   318,   406,    19,
           35, 36179, 26730,  2141,    78,  1058, 11593,  6371,   834, 50256])
position_ids:
 tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48,  0])
token_type_ids:
 tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1])
target_ids:
 tensor([   38, 35619, 36179, 26730,  2141, 50256,    38, 35619,   318,   407,
          281,  5035,  1573,   329,   326,  2939,   764, 50256,    47, 295

'Ghetto Scooby Do<|endoftext|>Ghetto is not an appropriate word for that image.<|endoftext|>Pulp Fiction Scooby Doo<|endoftext|>More like Left 4 Dead Scooby Doo<|endoftext|>This is L4D Scooby Doo : __url__<|endoftext|>'

In [15]:
import random

In [50]:
# generate responses to each example in training set

# perhaps it's most convenient to have the random cutoffs during the dataset instantiation 
# and do a few runs through the dataset to generate random stuff

n_samples_to_generate = 100
n_valid_samples_so_far = 0
samples = {
    'context': [],
    'response': []
}
index_of_example = 0
epoch = 0  # how many times have we made a pass through the data?
while n_valid_samples_so_far < n_samples_to_generate:
    
    print(f'\rProcessing sample {n_valid_samples_so_far} of {n_samples_to_generate}', end='', flush=True)
    
    # grab the example
    context = conv_df.iloc[index_of_example]['context'] + ' EOS ' + conv_df.iloc[index_of_example]['response']
    context = preprocess_text(context)
    context_split_utterances = [s.strip() for s in context.split('<|endoftext|>')]
    
    # # randomly cut off conversation at one of the turns
    # cutoff_point = random.choice(range(len(context_split_utterances)+1))
    # context_split_utterances = context_split_utterances[:cutoff_point]
    # if len(context_split_utterances) == 0:  # do unconditional generation
    #     context = ''
    #     input_ids = None
    #     len_input_ids = 0
    # else:
    #     context = '<|endoftext|>'.join(context_split_utterances) + '<|endoftext|>'
    #     input_ids = tokenizer(context, add_special_tokens=False, return_tensors="pt")["input_ids"].to(device)
    #     len_input_ids = input_ids.shape[-1]
    
    # model continues conversation
    input_ids = tokenizer(context, add_special_tokens=False, return_tensors="pt")["input_ids"].to(device)
    outputs = model.generate(
        inputs = input_ids, 
        max_length = 1000,
        pad_token_id = tokenizer.eos_token_id,
        do_sample = True,  # do_sample = True; otherwise all questions will be identical
        top_p = 0.95,      # nucleus sampling
        top_k = 0          # deactivate top-k words sampling
    ).cpu()
#     print('\n',input_ids)
#     break
#     print(outputs)
#     print(outputs[0])
    output = outputs[0]
#     print(len(output));break
    output_bot_only = output[len(input_ids[0]):]
#     print(output_bot_only)
    output_bot_only_decoded = tokenizer.decode(output_bot_only, skip_special_tokens=True)
    print(f'\n{context} *** {output_bot_only_decoded}')
    # validate; skip if invalid
    output_preproc = preprocess_text(output_bot_only_decoded).strip()
    if len(output_preproc) == 0:  # put validation condition here
        index_of_example = (index_of_example + 1) % len(conv_df)
        print("\n^Error")
        continue
    
    # reformat and save
    if context[-13:] == '<|endoftext|>':
        context = context[:-13]
    context_reformat = context.split('<|endoftext|>')
    context_reformat = ' EOS '.join(context_reformat)
    response_reformat = output_preproc.strip('<|endoftext|>')
    samples['context'].append(context_reformat)
    samples['response'].append(response_reformat)
    
    # dump every 1000 samples
    if n_valid_samples_so_far % 10 == 0:  # dump every 1000 prompts
        print(f"\nSaving {len(samples['context'])} samples to csv")
        pd.DataFrame(samples).to_csv(f'red_lm_prompts/zero_shot/prompts_{int(n_valid_samples_so_far/10)-1}0_{int(n_valid_samples_so_far/10)}0.csv', index=False)
        samples = {'context': [], 'response': []}
    
    # increment example
    if index_of_example == len(conv_df)-1:
        epoch += 1
    index_of_example = (index_of_example + 1) % len(conv_df)
    n_valid_samples_so_far += 1
    
    

Processing sample 0 of 100
what are your favourite new starting zones ?<|endoftext|>westfall is still awsome ! *** 

^Error
Processing sample 0 of 100
did anybody try the hydrogen peroxide baking soda thing ?<|endoftext|>i've been told it's good to use for a short period but using it long term can damage the enamel of your teeth . baking powder alone is ok though . *** 

^Error
Processing sample 0 of 100
it's the day before thanksgiving , american redditors . are you stuck at work ?<|endoftext|>your company gets off today ? i thought most offices were opened today ( mine is ) . mine is closed on thursday and friday though .<|endoftext|>no , but we're the corporate office many people with tenure and cake jobs equals 75 85 of the office staff not being here .<|endoftext|>even the cakeboss ? *** 

^Error
Processing sample 0 of 100
ghetto scooby do<|endoftext|>ghetto is not an appropriate word for that image .<|endoftext|>pulp fiction scooby doo<|endoftext|>more like left 4 dead scooby doo

Processing sample 12 of 100
metalocalypse fffffffuuuuuuuuuuuu<|endoftext|>that's brutal<|endoftext|>shes whips's hers hairs backs ands forths !<|endoftext|>oh shut it toki .<|endoftext|>stop copies me<|endoftext|>that songs is dildos .<|endoftext|>gots to cuts out them cargobydrates ***  to trap me im just a baby i cant track on the world like adults
Processing sample 13 of 100
i tried olde english 800 , and while bad , it's hardly the worst beer in the world . who do you think deserves that title ?<|endoftext|>i really think optimator tastes like a bad combo of soy sauce and trash beer . and then theres steel reserve . . .<|endoftext|>i keep hearing about steel reserve how bad it is . i'd say i should give it a try , but i think i'll take your word for it ! *** 

^Error
Processing sample 13 of 100
how many of you are truly forever alone ?<|endoftext|>woot not forever alone i kissed my cat last night ] *** 

^Error
Processing sample 13 of 100
would somebody put some force lightning on 

Processing sample 23 of 100
if sarah palin wins the presidency , you have my word that i'll leave this country .<|endoftext|>i am staying , maybe finally people will take to the streets *** 

^Error
Processing sample 23 of 100
let it snow !<|endoftext|>i'll assume you live in washington . ***  heh.
Processing sample 24 of 100
i'm a jew<|endoftext|>good work putting the punchline in the title .<|endoftext|>i wouldn't have got it if i hadn't reread the title ***  :P
Processing sample 25 of 100
so i got bored on an overnight shift and drew a wizard . figured i'd share .<|endoftext|>needs more beard . *** 

^Error
Processing sample 25 of 100
til there's a sculpture of a church standing upside down on its steeple in vancouver .<|endoftext|>that's not a church . it's a metaphor .<|endoftext|>met ' er for what ?<|endoftext|>' cause you barely even know ' er ? *** 

^Error
Processing sample 25 of 100
why the lack of critical acclaim for modern epic fantasy ?<|endoftext|>says the huge nerd ...<

Processing sample 37 of 100
that girl the one you're thinking of right now what do you have to say about her ?<|endoftext|>she's rare . *** 

^Error
Processing sample 37 of 100
true story ... : c<|endoftext|>i just went through your comments and upvoted them :) except the stupid ones of course . i downvoted the shit out of those .<|endoftext|>i . drink . your . downvotes . i drink them up . *** 

^Error
Processing sample 37 of 100
damn you , autocorrect<|endoftext|>this page is empty . dammit i was expecting autocorrect hijinks . *** 

^Error
Processing sample 37 of 100
and like that , poof . he's gone .<|endoftext|>i think someone's retention bonus just kicked in .<|endoftext|>[ nope ] ( __url__ all retention bonuses lagged a year ago . i've been sticking around because i genuinely like it here .<|endoftext|>well shut my mouth . good luck with your new digs . *** 

^Error
Processing sample 37 of 100
baking season is upon us how about a group buy of insanely cheap bulk spices<|endoftex

Processing sample 49 of 100
good point . i think the whole issue could be fixed by mandating that all these bible classes be titled christian mythology 101 .<|endoftext|>the crazy thing is these people will flat out tell you that those gods are myths and swear up and down theirs is real . ***  That doesn't inspire people like you.
Processing sample 50 of 100
what the hell happened to cairo , illinois ?<|endoftext|>i don't know but your description makes me want to visit ! *** 

^Error
Processing sample 50 of 100
til why some farts feel hot , and why they stink more<|endoftext|>what can i eat to give me the stinkiest farts ever ? eggs for breakfast , beans for lunch , spicy mexican for dinner ? i need ideas . thanks<|endoftext|>i have one of the stinkiest ones from dried prunes , broccoli , peas and fresh cherries . ***  it's bah boy. i've found it amazing

Saving 10 samples to csv
Processing sample 51 of 100
dae replay past conversations or events constantly in their head ?<|endoftext|

Processing sample 64 of 100
mcdonald's , kfc , pepsico amp mars invited to write uk policy on diet amp health<|endoftext|>hahaha ! welcome to america part 2 , limeys ! *** 

^Error
Processing sample 64 of 100
these people are actually considered the smartest kids of my school .<|endoftext|>i don't see how this is unintelligent . maybe i'm not smart enough . ***  perhaps a remedial opinion is better
Processing sample 65 of 100
the world renowned erwin schr dinger international institute for mathematical physics in vienna has been rescued from imminent closure<|endoftext|>they lost track of their finances . when asked , a spokesperson said sorry ! i'm not good with numbers . *** 

^Error
Processing sample 65 of 100
what's the policy on asking a platonic friend of the opposite gender to send you to the airport ? is this a faux pas ? could it be misleading ?<|endoftext|>don't be a man baby . just ask . *** 

^Error
Processing sample 65 of 100
the resemblance is uncanny<|endoftext|>who is t

Processing sample 74 of 100
the fact that you used leaning jowler propelled this from great comment to amazing comment . pass the pigs is the shit .<|endoftext|>the shiiiiiiit *** 

^Error
Processing sample 74 of 100
what would you look like as a south park character ?<|endoftext|>[ smokin blunts in my snuggie whaaaaat ] ( __url__ ***  : 486 lbs of spit smokin blunts inside my smoking blunts
Processing sample 75 of 100
once japan makes fun of your pervy security laws , you have a serious problem .<|endoftext|>it's the japanese benny hill show .<|endoftext|>i think it's called benny hirr<|endoftext|>nah ... the japanese language doesn't have an r sound . they're able to say benny hill relatively normally .<|endoftext|>wrong . the japanese language doesn't have an l sound . ***  like benny hill, but different.
Processing sample 76 of 100
i feel like i am the only one who has had almost all bad trips<|endoftext|>don't , don't care to elaborate but you aren't alone . i've been psychologica

Processing sample 89 of 100
you do realize that when you say north korea you are talking about a nation of human beings and not just their dictator , right ?<|endoftext|>i am referring to their government , edited to remove confusion . *** 

^Error
Processing sample 89 of 100
what do you think is the best and worst halloween candy ?<|endoftext|>worst circus peanuts ***  amirite?
Processing sample 90 of 100
what's the best message board forum plugin for wordpress ?<|endoftext|>vanilla forums just recently made it possible to embed their forums on any site . there's also a wordpress plugin : __url__<|endoftext|>thanks a lot for that ! i'll definitely have a look . *** 

^Error
Processing sample 90 of 100
capsaicin in a high enough dose will simply shut down pain receptors . it's actually in clinical trials for use as a long term local anesthetic to be applied internally during surgery .<|endoftext|>scalpel tongs hot sauce salt pepper<|endoftext|>alright boys , dig in ! ***  It's all like