In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, default_data_collator
from datasets import Dataset
import pandas as pd
import pickle
import torch

In [None]:
import re
import string

# Regexes used by the message-cleaning helpers.
tagging_regex = re.compile(r"@\S*")                  # @user mentions (the whole token after '@')
url_pattern = re.compile(r'https?://\S+|www\.\S+')   # http(s):// links and bare www. links
# A '-' that STARTS a whitespace-delimited token (presumably "-Dan" style sign-offs).
# The lookbehind keeps hyphenated words such as "well-known" intact.
signature_pattern = re.compile(r"(?<!\S)-\S*")
weird_thing_pattern = re.compile(r"\^\S*")           # '^XY' style tokens
# A line break together with the whitespace around it. It deliberately does not consume
# the following word, so the first word of each new line survives.
new_line_pattern = re.compile(r"\s*\n\s*")

# Chat abbreviation -> expansion. Keys are upper-case because lookup uses word.upper().
# Expansions are spliced into the message, so they must be plain text. No "X: ...", no
# parenthetical notes, no "a/b" alternatives.
# Abbreviations that are also everyday English words (KISS, STATS, GAL) are intentionally
# absent. Lookup is case-insensitive, so they would rewrite ordinary text
# ("kiss" -> "keep it simple, stupid").
chat_words = {
    "AFAIK": "As Far As I Know",
    "AFK": "Away From Keyboard",
    "ASAP": "As Soon As Possible",
    "ATK": "At The Keyboard",
    "ATM": "At The Moment",
    "A3": "Anytime, Anywhere, Anyplace",
    "BAK": "Back At Keyboard",
    "BBL": "Be Back Later",
    "BBS": "Be Back Soon",
    "BFN": "Bye For Now",
    "B4N": "Bye For Now",
    "BRB": "Be Right Back",
    "BRT": "Be Right There",
    "BTW": "By The Way",
    "B4": "Before",
    "CU": "See You",
    "CUL8R": "See You Later",
    "CYA": "See You",
    "FAQ": "Frequently Asked Questions",
    "FC": "Fingers Crossed",
    "FWIW": "For What It's Worth",
    "FYI": "For Your Information",
    "GG": "Good Game",
    "GN": "Good Night",
    "GMTA": "Great Minds Think Alike",
    "GR8": "Great!",
    "G9": "Genius",
    "IC": "I See",
    "ICQ": "I Seek You",
    "ILU": "I Love You",
    "IMHO": "In My Humble Opinion",
    "IMO": "In My Opinion",
    "IOW": "In Other Words",
    "IRL": "In Real Life",
    "LDR": "Long Distance Relationship",
    "LMAO": "Laugh My A.. Off",
    "LOL": "Laughing Out Loud",
    "LTNS": "Long Time No See",
    "L8R": "Later",
    "MTE": "My Thoughts Exactly",
    "M8": "Mate",
    "NRN": "No Reply Necessary",
    "OIC": "Oh I See",
    "PITA": "Pain In The A..",
    "PRT": "Party",
    "PRW": "Parents Are Watching",
    "ROFL": "Rolling On The Floor Laughing",
    "ROFLOL": "Rolling On The Floor Laughing Out Loud",
    "ROTFLMAO": "Rolling On The Floor Laughing My Ass Off",
    "SK8": "Skate",
    "ASL": "Age, Sex, Location",
    "THX": "Thank You",
    "TTFN": "Ta-Ta For Now!",
    "TTYL": "Talk To You Later",
    "U": "You",
    "U2": "You Too",
    "U4E": "Yours For Ever",
    "WB": "Welcome Back",
    "WTF": "What The F...",
    "WTG": "Way To Go!",
    "WUF": "Where Are You From?",
    "W8": "Wait",
    "IMMA": "I am going to",
    "2NITE": "tonight",
    "DMED": "messaged",
    "DM": "message",
    "SMH": "I am disappointed",
}

# Thanks to https://stackoverflow.com/a/43023503/3971619
# Contraction -> expansion. Keys are lower-case because lookup uses word.lower().
# Each expansion is spliced into the message, so it must be a single plain reading
# (no "a / b" alternatives).
contractions = {
    "ain't": "are not",
    "aren't": "are not",
    "can't": "cannot",
    "can't've": "cannot have",
    "'cause": "because",
    "could've": "could have",
    "couldn't": "could not",
    "couldn't've": "could not have",
    "didn't": "did not",
    "doesn't": "does not",
    "don't": "do not",
    "hadn't": "had not",
    "hadn't've": "had not have",
    "hasn't": "has not",
    "haven't": "have not",
    "he'd": "he would",
    "he'd've": "he would have",
    "he'll": "he will",
    "he'll've": "he will have",
    "he's": "he is",
    "how'd": "how did",
    "how'd'y": "how do you",
    "how'll": "how will",
    "how's": "how is",
    "i'd": "I would",
    "i'd've": "I would have",
    "i'll": "I will",
    "i'll've": "I will have",
    "i'm": "I am",
    "i've": "I have",
    "isn't": "is not",
    "it'd": "it would",
    "it'd've": "it would have",
    "it'll": "it will",
    "it'll've": "it will have",
    "it's": "it is",
    "let's": "let us",
    "ma'am": "madam",
    "mayn't": "may not",
    "might've": "might have",
    "mightn't": "might not",
    "mightn't've": "might not have",
    "must've": "must have",
    "mustn't": "must not",
    "mustn't've": "must not have",
    "needn't": "need not",
    "needn't've": "need not have",
    "o'clock": "of the clock",
    "oughtn't": "ought not",
    "oughtn't've": "ought not have",
    "shan't": "shall not",
    "sha'n't": "shall not",
    "shan't've": "shall not have",
    "she'd": "she would",
    "she'd've": "she would have",
    "she'll": "she will",
    "she'll've": "she will have",
    "she's": "she is",
    "should've": "should have",
    "shouldn't": "should not",
    "shouldn't've": "should not have",
    "so've": "so have",
    "so's": "so is",
    "that'd": "that would",
    "that'd've": "that would have",
    "that's": "that is",
    "there'd": "there would",
    "there'd've": "there would have",
    "there's": "there is",
    "they'd": "they would",
    "they'd've": "they would have",
    "they'll": "they will",
    "they're": "they are",
    "they've": "they have",
    "to've": "to have",
    "wasn't": "was not",
    "we'd": "we would",
    "we'd've": "we would have",
    "we'll": "we will",
    "we'll've": "we will have",
    "we're": "we are",
    "we've": "we have",
    "weren't": "were not",
    "what'll": "what will",
    "what're": "what are",
    "what's": "what is",
    "what've": "what have",
    "when's": "when is",
    "when've": "when have",
    "where'd": "where did",
    "where's": "where is",
    "where've": "where have",
    "who'll": "who will",
    "who's": "who is",
    "who've": "who have",
    "why's": "why is",
    "why've": "why have",
    "will've": "will have",
    "won't": "will not",
    "won't've": "will not have",
    "would've": "would have",
    "wouldn't": "would not",
    "wouldn't've": "would not have",
    "y'all": "you all",
    "y'all'd": "you all would",
    "y'all'd've": "you all would have",
    "y'all're": "you all are",
    "y'all've": "you all have",
    "you'd": "you would",
    "you'd've": "you would have",
    "you'll": "you will",
    "you're": "you are",
    "you've": "you have",
}

# Reference : https://stackoverflow.com/a/49986645/3971619
def remove_emoji(inputString):
    """Drop every non-ASCII character from ``inputString``.

    Despite the name, this is not emoji-specific. Accented letters, curly quotes
    and any other character outside 7-bit ASCII are removed as well.
    """
    ascii_bytes = inputString.encode("ascii", errors="ignore")
    return ascii_bytes.decode("ascii")

# Thanks to user sudalairajkumar
def remove_url(string):
    """Return ``string`` with every ``url_pattern`` match (http/https or www. links) deleted."""
    return re.sub(url_pattern, '', string)

def remove_chat_words_and_contractions(string, chat_map=None, contraction_map=None):
    """Expand chat abbreviations and English contractions word by word.

    The text is split on single spaces, so punctuation and newlines stay attached
    to their word. Each token is checked first against the chat-word table, using
    its upper-cased form (the expansion is lower-cased). It is then checked against
    the contraction table, using its lower-cased form. Tokens matching neither are
    kept unchanged.

    Args:
        string: text to expand.
        chat_map: abbreviation -> expansion mapping with upper-case keys.
            Defaults to the module-level ``chat_words``.
        contraction_map: contraction -> expansion mapping with lower-case keys.
            Defaults to the module-level ``contractions``.

    Returns:
        The expanded text, re-joined with single spaces.
    """
    if chat_map is None:
        chat_map = chat_words
    if contraction_map is None:
        contraction_map = contractions

    new_text = []
    for word in string.split(' '):
        upper, lower = word.upper(), word.lower()
        if upper in chat_map:
            new_text += chat_map[upper].lower().split(' ')
        # `elif`, not a second `if`. Otherwise a chat word falls through to the
        # `else` branch and is appended again right after its own expansion.
        elif lower in contraction_map:
            new_text += contraction_map[lower].split(' ')
        else:
            new_text.append(word)

    return ' '.join(new_text)

def remove_signature(text):
    """Return ``text`` with every match of the module-level ``signature_pattern`` deleted."""
    stripped = re.sub(signature_pattern, '', text)
    return stripped
    

# Thanks to user sudalairajkumar
PUNCT_TO_REMOVE = string.punctuation


def remove_punctuation(text):
    """Return ``text`` with every character in ``PUNCT_TO_REMOVE`` (ASCII punctuation) deleted."""
    # Mapping a character to None in a translate table deletes it.
    deletion_table = str.maketrans(dict.fromkeys(PUNCT_TO_REMOVE))
    return text.translate(deletion_table)

def clean_message(message):
    """Normalise one raw chat message into plain, single-spaced ASCII text.

    Steps, in order:
      1. drop @-mentions;
      2. drop non-ASCII characters;
      3. drop URLs;
      4. drop ``signature_pattern`` matches;
      5. expand chat abbreviations and contractions;
      6. drop ``weird_thing_pattern`` matches;
      7. turn line breaks into sentence breaks;
      8. collapse all whitespace.

    Args:
        message: raw message text. Non-string values (e.g. ``None`` or NaN for a
            missing context turn) are returned unchanged instead of raising.

    Returns:
        The cleaned string, or ``message`` itself if it was not a string.
    """
    if not isinstance(message, str):
        return message

    # Remove user taggings. TODO(review): replacing them with "you" might read better.
    message = tagging_regex.sub('', message)

    # Remove the emojis (in fact: every non-ASCII character)
    message = remove_emoji(message)

    # Remove urls
    message = remove_url(message)

    # Remove signatures
    message = remove_signature(message)

    # Expand the chat words and contractions
    message = remove_chat_words_and_contractions(message)

    # Remove weird things
    message = weird_thing_pattern.sub('', message)

    # Change new line to a sentence break. Use '. ' rather than '.' so the words on
    # either side of the break stay separate tokens.
    message = new_line_pattern.sub('. ', message)

    # split() with no argument also drops leading/trailing whitespace, so this both
    # strips the message and collapses runs of whitespace to a single space.
    return ' '.join(message.split())


In [None]:
DIALOG_COLUMNS = ['response', 'context-0', 'context-1', 'context-2']

# NOTE(review): unpickling can execute arbitrary code embedded in the file. Only load
# pickles you produced yourself or otherwise trust.
with open('data/twitter_data.pickle', 'rb') as f:
    # pickle.load reads from the file object directly instead of first materialising
    # the whole file as one bytes object (pickle.loads(f.read())).
    twitter_data = pickle.load(f)

twitter_df = pd.DataFrame.from_records(twitter_data, columns=DIALOG_COLUMNS)

# Clean every turn of the conversation.
# NOTE(review): the other corpora are not passed through clean_message; confirm that is intended.
for column in DIALOG_COLUMNS:
    twitter_df[column] = twitter_df[column].apply(clean_message)

twitter_df.head(10)

In [None]:
# NOTE(review): unpickling can execute arbitrary code embedded in the file. Only load
# pickles you produced yourself or otherwise trust.
with open('data/ubuntu_data.pickle', 'rb') as f:
    # Read straight from the file object rather than via pickle.loads(f.read()).
    ubuntu_data = pickle.load(f)

ubuntu_df = pd.DataFrame.from_records(ubuntu_data, columns=['response', 'context-0', 'context-1', 'context-2'])
ubuntu_df.head(10)

In [None]:
# NOTE(review): unpickling can execute arbitrary code embedded in the file. Only load
# pickles you produced yourself or otherwise trust.
with open('data/movie_data.pickle', 'rb') as f:
    # Read straight from the file object rather than via pickle.loads(f.read()).
    movie_data = pickle.load(f)

movie_df = pd.DataFrame.from_records(movie_data, columns=['response', 'context-0', 'context-1', 'context-2'])
movie_df.head(10)

In [None]:
# NOTE(review): unpickling can execute arbitrary code embedded in the file. Only load
# pickles you produced yourself or otherwise trust.
with open('data/tolokers_data.pickle', 'rb') as f:
    # Read straight from the file object rather than via pickle.loads(f.read()).
    tolokers_data = pickle.load(f)

tolokers_df = pd.DataFrame.from_records(tolokers_data, columns=['response', 'context-0', 'context-1', 'context-2'])
tolokers_df.head(10)

In [None]:
# Training mix: the first N rows of each corpus, shuffled together.
# Keep these cut-offs in sync with the test-split cell, which takes the Twitter and
# Ubuntu rows that come right after them.
# random_state makes the shuffle, and therefore the training order, reproducible.
SHUFFLE_SEED = 42
comb_df = (
    pd.concat([twitter_df[:50000], ubuntu_df[:20000], tolokers_df[:30000], movie_df[:25000]],
              ignore_index=True)
    .sample(frac=1, random_state=SHUFFLE_SEED)
    .reset_index(drop=True)
)
comb_df.head(10)

In [None]:
# Held-out evaluation rows: the Twitter and Ubuntu slices just past the training cut-offs.
# ignore_index=True already yields a fresh 0..n-1 index, so no reset_index is needed.
held_out_parts = [twitter_df[50000:52000], ubuntu_df[20000:21000]]
test_df = pd.concat(held_out_parts, ignore_index=True)
test_df.head(5)

In [None]:
# Wrap the shuffled training frame as a Hugging Face Dataset (one row per dialogue).
dataset = Dataset.from_pandas(df=comb_df)
dataset

In [None]:
# Base checkpoint that gets fine-tuned below.
BASE_CHECKPOINT = "microsoft/DialoGPT-medium"
tokenizer = AutoTokenizer.from_pretrained(BASE_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(BASE_CHECKPOINT)
# Padding reuses the EOS token (presumably the tokenizer ships without a pad token of its own).
tokenizer.pad_token = tokenizer.eos_token

In [None]:
def preprocess_function(data, max_turn_length=32):
    """Turn one dialogue row into a single DialoGPT-style training sequence.

    The row's values are taken in reverse column order. For this notebook's
    columns that is context-2, context-1, context-0, response.

    Each turn is tokenized, truncated to ``max_turn_length - 1`` tokens and
    terminated with EOS, so EOS survives truncation. The turns are then
    concatenated into ONE sequence, so the response is learned conditioned on its
    context. Previously each turn was padded separately, producing a (turns, 32)
    block per example, which the model trained on as unrelated sequences.

    The sequence is right-padded to ``max_turn_length * number_of_columns`` so
    every example has the same length, as ``default_data_collator`` requires.

    Args:
        data: one dataset row, mapping column name -> text. ``None`` turns are skipped.
        max_turn_length: token budget per turn, EOS included.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels``: 1-D lists of equal
        length. Padding positions get label -100 so they contribute no loss.
    """
    turns = list(data.values())[::-1]
    total_length = max_turn_length * len(turns)
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    input_ids = []
    for turn in turns:
        if turn is None:
            continue
        # Leave room for EOS so truncation never cuts off the turn separator.
        turn_ids = tokenizer(turn, max_length=max_turn_length - 1, truncation=True)['input_ids']
        input_ids.extend(turn_ids + [tokenizer.eos_token_id])

    n_pad = total_length - len(input_ids)
    return {
        'input_ids': input_ids + [pad_id] * n_pad,
        'attention_mask': [1] * len(input_ids) + [0] * n_pad,
        'labels': input_ids + [-100] * n_pad,
    }

tokenized_data = dataset.map(preprocess_function, batched=False, remove_columns=dataset.column_names)

In [None]:
OUTPUT_DIR = "dialogpt-ir-bot"
batch_size = 16

args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    learning_rate=2e-5,
    warmup_steps=1000,
    weight_decay=0.01,
    fp16=False,
    save_strategy="no",  # no checkpoints are written during training
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_data,
    data_collator=default_data_collator,
)

trainer.train()

# NOTE(review): with save_strategy="no" and the saves below commented out, the fine-tuned
# weights exist only in memory. The next cell replaces `model` with a hub checkpoint.
#trainer.save_model()
#tokenizer.save_pretrained(OUTPUT_DIR)

In [7]:
# Hub checkpoint used for evaluation. This replaces the in-memory `tokenizer` and `model`.
# Presumably it is a saved run of the training above; confirm.
FINETUNED_CHECKPOINT = "jegorkitskerkin/dialogpt-twitter-ubuntu"
tokenizer = AutoTokenizer.from_pretrained(FINETUNED_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(FINETUNED_CHECKPOINT)

Downloading:   0%|          | 0.00/617 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/779k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/357 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/905 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.35G [00:00<?, ?B/s]

In [None]:
# Evaluation stream: all held-out responses joined into one long token sequence.
# No truncation here. The sliding-window loop below handles sequences longer than the
# model's context, whereas truncating to 64 tokens scored only the first window of the
# test set. The tokenizer may warn that the sequence exceeds the model maximum; that is
# expected for this use.
test_emb = tokenizer('\n\n'.join(test_df['response'].tolist()), return_tensors='pt')

In [None]:
import torch
from tqdm import tqdm

# Sliding-window perplexity. Each window ends `stride` tokens after the previous one and
# sees up to `max_length` tokens of left context. Only the newest `trg_len` tokens are
# labelled; the rest are masked with -100.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
model.eval()  # disable dropout in case `model` is still in training mode

max_length = model.config.n_positions
stride = 64
seq_len = test_emb.input_ids.size(1)

nll_sum = 0.0
n_scored = 0
for i in tqdm(range(0, seq_len, stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, seq_len)
    trg_len = end_loc - i    # may be different from stride on last loop
    input_ids = test_emb.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    # The model shifts labels left by one internally. Its loss is therefore the mean over
    # the unmasked labels from position 1 on, not over trg_len tokens.
    n_valid = int((target_ids[:, 1:] != -100).sum())
    if n_valid == 0:
        continue  # e.g. a 1-token window: nothing to predict, and the mean loss would be NaN

    with torch.no_grad():
        loss = model(input_ids, labels=target_ids).loss

    nll_sum += loss.item() * n_valid
    n_scored += n_valid

if n_scored == 0:
    raise ValueError('test_emb has fewer than two tokens; perplexity is undefined')

ppl = torch.exp(torch.tensor(nll_sum / n_scored))
print('Perplexity', ppl.item())