In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, default_data_collator
from datasets import Dataset
import pandas as pd
import pickle
import torch

In [3]:
with open('twitter_data.pickle', 'rb') as f:
    twitter_data = pickle.loads(f.read())
    
twitter_df = pd.DataFrame.from_records(twitter_data, columns=['response', 'context-0', 'context-1', 'context-2']).sample(frac=1)
twitter_df.head(10)

Unnamed: 0,response,context-0,context-1,context-2
208749,Come on fix your connection with Carphone Ware...,"🤔 We'd like to help Claire, are you trying to ...",I called them today and on Saturday and they s...,"🤔 We're not currently aware of any issues, Cla..."
59983,I’d hate to be a later flight because my 8:40 ...,They’re the worst. Last time I traveled on a h...,Made it! An hour and a half after I was suppos...,Great news. Enjoy your pre-Thanksgiving lunch!
132046,Flying with you on Thu to bring supplies to #P...,Hello. Sorry for the delay. Are you still need...,I do - thanks for assisting.,"When traveling Thursday, you can speak with ou..."
123377,- just spent 30 min with tech support. No sol...,"Hello Mike, is there anything I can assist you...",I cannot get access to My Account . When I typ...,Have you tested an alternate web browser? ^JH
78931,Explain to me how the 3G is faster than the LT...,Hi there! I would be happy to assist you with ...,My bad just redid the test it's including the ...,Thank you. Can you please send me a DM. I want...
93895,Is it true that some game downloads require an...,Hi there! If you are seeing download issues be...,It's not about slow downloads. The downloads ...,Gotcha. Are you seeing a specific error code/m...
72362,Thanks for cancelling my train CRAAAAP!!,"Sorry for that, which service? ^PA",19:15 to Preston!,From where? ^PA
116969,hi my mac pro won't read my sd card pls send help,Hi there. We'd like to help. Are you using a b...,I'm using my Mac's built-in reader. Other lapt...,OK. Let's try removing the SD Card and restart...
163657,Your service is dogshit it makes me want to bu...,Good evening Jake! I would really like to help...,too late,"I see, please let us know if you reconsider. H..."
11028,is there not an easy way to get all my actions...,"Hi there, You may use the preset manager in Ph...",I will try that again but when I open the pres...,"You would need to export presets, please check..."


In [4]:
with open('ubuntu_data.pickle', 'rb') as f:
    ubuntu_data = pickle.loads(f.read())
    
ubuntu_df = pd.DataFrame.from_records(ubuntu_data, columns=['response', 'context-0', 'context-1', 'context-2']).sample(frac=1)
ubuntu_df.head(10)

Unnamed: 0,response,context-0,context-1,context-2
11543,lol ok hehehe,"like a sidewalk step, you gotta be careful",.. yeah.. it is probably because im real tired...,"I have done many an all nighter on Linux, but ..."
19202,"you can convert it using mysql, I think.",how?,try this. http://sourceforge.net/projects/mdb...,thanks!!
16208,"nothing i know of, but i'll check",it could be a lack of memory thats causing jer...,i have actually added one gig my average cpu u...,well if you've recently upgraded skype and fir...
15471,which one is he getting?,"well, i hadn't got that far yet..lol, i just t...",damn which sound card should i get?,"look at the last post, it looks like he chose ..."
5175,"hi all, can i tell to my ubuntu, 'when you loc...",on a laptop or a desktop,how?,xset dpms force off are you intending to write...
7244,You have to put 'ldap ssl = off',in /etc/ldap/ldap.conf or /etc/smbldap/smbldap...,in /etc/samba/smb.conf,thx. still doesn't work though. Now it says 'C...
10216,That is correct. Let me come up with a test o...,odd. what do you have after the command. is ...,{} \;,hmm does the function exit with a 0 exit status?
9390,You just said 660...,well i didn't know how to convert that to numbers,"But okay, hm, then this is strange, is there a...",what's ACL? it's not installed
14895,do you know if direct rendering is enabled?,not off the top of my head. how do i check?,glxinfo | grep 'direct' iirc,ya its on. i have xinemera on too
21801,so dmesg | grep 'system panic',dmesg | grep panic,!hi,i did chown the dmesg i got nothing


In [5]:
comb_df = pd.concat([twitter_df[:50000], ubuntu_df[:20000]], ignore_index=True)\
    .sample(frac=1).reset_index(drop=True)
comb_df.head(10)

Unnamed: 0,response,context-0,context-1,context-2
0,VERY Disappointing cotton bed sheets as they g...,"Sorry Mathew, can you tell me the barcode numb...",im sorry i do not have the barcode to hand but...,That's not a problem Mathew. Do you still have...
1,issue with order id-403-9859026-0246720.. Cust...,I’m sorry about the hassle. Please drop in you...,order id-403-9859026-0246720.. I called many t...,We cannot check your order details on social p...
2,Dear #BlueBird color me disappointed &amp; dis...,Hi Donna. I'm sorry to hear that. Is there som...,Hi Beth. I have the Blue Bird card and I can't...,up to 30 calendar days for car rentals.If we c...
3,Hey all of my Roku devices get this error seve...,"Sorry, Travis! Still need help? If so, try: ht...",Already performed these troubleshooting tasks....,Sorry--we're looking into this error. Please c...
4,"Dear Why, oh, can’t eye type the letter eye? ...",We'd love to assist you. Let's go to Settings&...,I'm a dumby - thanks for the quick response! A...,You're welcome. Were you able to see what ver...
5,Can't ever win with these people why is my log...,"My apologies, as I am trying to see if anythin...",Not working for me,I would be happy to look into this further for...
6,my route table looks fine :-\,post to pastebin output from 'ifconfig ; route...,http://pastebin.com/d543c048d,'Chain INPUT (policy DROP...' - you didn't do ...
7,ok how do I check this when i boot into xfce,open a terminal and make sure xfce4-panel is r...,understood but what would the command in the t...,ps aux | grep xfce4-panel
8,your Prime deliveries have turned to 2/3 day d...,I'm sorry for the trouble! We don't want to le...,You have indeed. Next day delivery has let me ...,I'm very sorry to hear you still haven't recei...
9,What's the command line for that? (Just for th...,sudo apt-get install pastebinit then cat /etc/...,"It is now 1792x1344, and I wan't it to be 1024...",how did you install the drivers?


In [6]:
test_df = pd.concat([twitter_df[50000:52000], ubuntu_df[20000:21000]], ignore_index=True).reset_index(drop=True)
test_df.head(5)

Unnamed: 0,response,context-0,context-1,context-2
0,any problem with cleardb in west europe ? our ...,"error log was enabled, I just disabled it, we'...",can you raise a request through https://t.co/Z...,"ok I just created a ticket, thanks"
1,these updates are making iOS behave like andro...,We'd like to help with this. Send us a DM with...,I guess this is required if not do let me know...,One more reference I don’t know what’s this go...
2,how in the world are you guys not available 24...,"Hello, Jeremy. I’m sorry we weren’t here when ...",Your support is completely unavailable. I can'...,We apologize for your experience. Let me know ...
3,can you tell us where th spare seats are? 10:4...,"Morning, unreserved seating is available in ca...",Will it be announced on the tannoy? There's a ...,1/2 Depending on how many of the unreserved se...
4,Breh fix this WiFi shit,what’s wrong? cause my shit been trippin,My shit will just stop working and I keep havi...,We'd like to help with this Wi-Fi issue. Can y...


In [None]:
dataset = Dataset.from_pandas(comb_df)
dataset

In [8]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
tokenizer.pad_token = tokenizer.eos_token

In [None]:
def preprocess_function(data):
    flatten = lambda l: [item for sublist in l for item in sublist]
    
    output = [tokenizer(d + tokenizer.eos_token, max_length=32, truncation=True, padding='max_length', return_tensors='pt')['input_ids'] for d in data.values()]
    output = flatten(list(reversed(output)))
    
    return {'input_ids': output, 'labels': output }

tokenized_data = dataset.map(preprocess_function, batched=False, remove_columns=dataset.column_names)

  0%|          | 0/70000 [00:00<?, ?ex/s]

In [None]:
batch_size = 16
args = TrainingArguments(
    "dialogpt-twitter-ubuntu-finetuned",
    learning_rate=2e-6,
    per_device_train_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=1,
    fp16=False,
    report_to="none",
    warmup_steps=1000,
)

trainer = Trainer(
    model,
    args,
    train_dataset=tokenized_data,
    data_collator=default_data_collator,
)

trainer.train()

In [None]:
trainer.save_model()

In [10]:
tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelForCausalLM.from_pretrained("jegorkitskerkin/dialogpt-twitter-ubuntu")

In [None]:
for step in range(5):
    # encode the new user input, add the eos_token and return a tensor in Pytorch
    new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')

    # append the new user input tokens to the chat history
    #bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids
    bot_input_ids = new_user_input_ids

    # generated a response while limiting the total chat history to 1000 tokens, 
    chat_history_ids = model.generate(
        bot_input_ids, max_length=1000, 
        pad_token_id=tokenizer.eos_token_id,
        min_length=16,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        do_sample=True,
        top_k=50,
        top_p=0.9,
        temperature = 0.6,
                                     )
    # pretty print last ouput tokens from bot
    print("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))


In [11]:
test_emb = tokenizer('\n\n'.join(test_df['response'].tolist()), max_length=64, truncation=True, return_tensors='pt')

In [14]:
import torch
from tqdm import tqdm

model = model.cuda()

max_length = model.config.n_positions
stride = 64

nlls = []
for i in tqdm(range(0, test_emb.input_ids.size(1), stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, test_emb.input_ids.size(1))
    trg_len = end_loc - i    # may be different from stride on last loop
    input_ids = test_emb.input_ids[:,begin_loc:end_loc].to('cuda')
    target_ids = input_ids.clone()
    target_ids[:,:-trg_len] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        neg_log_likelihood = outputs[0] * trg_len

    nlls.append(neg_log_likelihood)

ppl = torch.exp(torch.stack(nlls).sum() / end_loc)
print('Perplexity', ppl.item())

100%|██████████| 1/1 [00:00<00:00, 26.73it/s]

Perplexity 308.0859069824219



