# Make your own Rick Sanchez (bot) with Transformers and DialoGPT fine-tuning

by [Rostyslav Neskorozhenyi](https://www.linkedin.com/in/slanj)

I am a big fan of 'Rick and Morty' and of NLP (Natural Language Processing). After watching the latest episode of the mad scientist's and his grandson's adventures, I realized that this is no coincidence and that I have to apply my knowledge of NLP and Machine Learning to create my own Rick (in the form of a chatbot, at least for now).

![alt text](https://media.giphy.com/media/ZdBnIqBeVWBYTobH8U/giphy.gif)

Image from [Giphy](https://giphy.com/)

The path to creating a virtual Rick was not easy. I tried many models and approaches, but for a long time virtual Rick did not want to find a common language with me. Rick's first response to my greeting was "I am not your buddy" (to tell you the truth, in the beginning this was the only phrase he knew). But I did not give up, and with the help of the excellent Transformers library and Microsoft's Large-scale Pretrained Response Generation Model (DialoGPT) I managed to have a heart-to-heart talk with Rick. Here is one of our conversations:

**User**: Hey, Rick, make me a time machine <br>
**RickBot**: You're a monster! <br>
**User**: No, I just want to travel in time <br>
**RickBot**: To the past. <br>
**User**: yes <br>
**RickBot**: No time machine, please. <br>
**User**: But I need it <br>
**RickBot**: You don't need it, Morty. Just do it. <br>
**User**: I want to see antient Greece <br>
**RickBot**: Oh, geez. <br>

![alt text](https://media.giphy.com/media/WVGSfyhUYBoiGLbZ7N/giphy.gif)

Image from [Giphy](https://giphy.com/)

## A bit of theory

In this article I will tell you how to create a virtual character whose statements will be based on a transcript of my favorite animated science fiction sitcom. You can use characters and sitcoms of your choice. I added code to convert a regular text file with dialogs into a format that the model understands.

As I already mentioned, the [Transformers](https://huggingface.co/transformers/) library, which contains the latest NLP models (such as [BERT](https://huggingface.co/transformers/model_doc/bert.html), [XLNet](https://huggingface.co/transformers/model_doc/xlnet.html), [GPT-2](https://huggingface.co/transformers/model_doc/gpt2.html)), will help us with our task. You can read more about transformers in a beautifully illustrated [article](http://jalammar.github.io/illustrated-transformer/) by Jay Alammar.

![alt text](http://jalammar.github.io/images/t/transformer_resideual_layer_norm_3.png) Image from [http://jalammar.github.io](http://jalammar.github.io/illustrated-transformer/)

Not so long ago, Microsoft’s [DialoGPT](https://huggingface.co/transformers/model_doc/dialogpt.html) was added to the Transformers model collection. DialoGPT is a GPT-2 model trained on 147M multi-turn dialogues from Reddit discussion threads (you can learn more about GPT-2 [here](http://jalammar.github.io/illustrated-gpt2/)). This model is ideally suited for creating a virtual character for a fascinating conversation, and even its small version can maintain a coherent dialogue, as we will see now.



## First dialogue with DialoGPT

We will conduct all our experiments in Google Colab; its resources are enough to train the small DialoGPT model. First, we will connect to Google Drive and install the necessary modules.

Let's first try chatting with DialoGPT without any fine-tuning.

In [1]:
from transformers import AutoModelWithLMHead, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-medium")
model = AutoModelWithLMHead.from_pretrained("microsoft/DialoGPT-medium")



In [2]:
# Let's chat for 5 lines
for step in range(5):
    # encode the new user input, add the eos_token and return a PyTorch tensor
    new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')

    # append the new user input tokens to the chat history
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

    # generate a response while limiting the total chat history to 1000 tokens
    chat_history_ids = model.generate(
        bot_input_ids, max_length=1000,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True
    )

    # pretty print the last output tokens from the bot
    print("DialoGPT: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))


**User:** Hi Rick <br>
**DialoGPT:** Hi Rick <br>
**User:** How are you? <br>
**DialoGPT:** I'm good, how are you? <br>
**User:** I am fine. Where is Morty? <br>
**DialoGPT:** He's in the basement. <br>
**User:** Who is Morty? <br>
**DialoGPT:** He's a Morty. <br>
**User:** Who are you? <br>
**DialoGPT:** I am a Morty. <br>
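
The chat loop above keeps one growing sequence of token ids: each user turn is appended with an end-of-text token, the model continues the sequence, and the reply is whatever comes after the part that was fed in. Here is a minimal sketch of that bookkeeping with plain Python lists. `fake_generate` and the small token ids are made up for illustration; only the slicing idea mirrors `chat_history_ids[:, bot_input_ids.shape[-1]:]`.

```python
EOS = 50256  # id that <|endoftext|> maps to in the GPT-2 vocabulary


def fake_generate(input_ids):
    """Hypothetical model: echoes the input, then adds a fixed reply and EOS."""
    return input_ids + [7, 8, 9, EOS]


def chat_turn(history, user_ids):
    """Append one user turn, 'generate', and split off the new tokens."""
    bot_input = history + user_ids + [EOS]
    new_history = fake_generate(bot_input)
    # Everything past the input length is the bot's reply
    reply = new_history[len(bot_input):]
    return new_history, reply


history = []
history, first_reply = chat_turn(history, [1, 2])
history, second_reply = chat_turn(history, [3])
print(first_reply)
print(history)
```

Because the history is never trimmed in this sketch, it grows with every turn, which is why the real loop caps generation with `max_length=1000`.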

![alt text](https://media.giphy.com/media/L3WevKXIKFDaZBvV8Q/giphy.gif)

Image from [Giphy](https://giphy.com/)

Not bad, but not too impressive. We will fix that with fine-tuning.

## Model initial configuration

Let's train our own Rick chatbot. To start, we will need a basic configuration and a dataset.
The configuration and training scripts are mostly based on this [script](https://github.com/huggingface/transformers/blob/master/examples/language-modeling/run_language_modeling.py) from Huggingface and a great [tutorial](https://nathancooper.io/i-am-a-nerd/chatbot/deep-learning/gpt2/2020/05/12/chatbot-part-1.html) from Nathan Cooper.

In [1]:
"""
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
using a masked language modeling (MLM) loss.
"""

import glob
import logging
import os
import pickle
import random
import re
import shutil
from typing import Dict, List, Tuple

import pandas as pd
import numpy as np
import torch

from sklearn.model_selection import train_test_split

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm.notebook import tqdm, trange

from pathlib import Path

from transformers import (
    MODEL_WITH_LM_HEAD_MAPPING,
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModelWithLMHead,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    get_linear_schedule_with_warmup,
)


try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

# Configs
logger = logging.getLogger(__name__)

MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

In [2]:
# Args to allow for easy conversion of python script to notebook
class Args():
    def __init__(self):
        self.output_dir = 'output-small'
        self.model_type = 'gpt2'
        self.model_name_or_path = 'microsoft/DialoGPT-small'
        self.config_name = 'microsoft/DialoGPT-small'
        self.tokenizer_name = 'microsoft/DialoGPT-small'
        self.cache_dir = 'cached'
        self.block_size = 512
        self.do_train = True
        self.do_eval = True
        self.evaluate_during_training = False
        self.per_gpu_train_batch_size = 2
        self.per_gpu_eval_batch_size = 2
        self.gradient_accumulation_steps = 1
        self.learning_rate = 5e-5
        self.weight_decay = 0.0
        self.adam_epsilon = 1e-8
        self.max_grad_norm = 1.0
        self.num_train_epochs = 3
        self.max_steps = -1
        self.warmup_steps = 0
        self.logging_steps = 1000
        self.save_steps = 3500
        self.save_total_limit = None
        self.eval_all_checkpoints = False
        self.no_cuda = False
        self.overwrite_output_dir = True
        self.overwrite_cache = True
        self.should_continue = False
        self.seed = 0
        self.local_rank = -1
        self.fp16 = False
        self.fp16_opt_level = 'O1'

args = Args()
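
Several of these settings interact. As a rough sketch of the arithmetic that the `train` function performs later, the snippet below derives the effective batch size and the total number of optimization steps. The dataset size and GPU count are made-up numbers chosen only for illustration; the other values are copied from `Args`.

```python
# Hypothetical inputs, not taken from the article
num_examples = 1001   # assumed number of training rows
n_gpu = 1             # assumed: a single Colab GPU

# Values from the Args class above
per_gpu_train_batch_size = 2
gradient_accumulation_steps = 1
num_train_epochs = 3

train_batch_size = per_gpu_train_batch_size * max(1, n_gpu)
# The DataLoader is built with drop_last=True, so an incomplete final batch is discarded
batches_per_epoch = num_examples // train_batch_size
t_total = batches_per_epoch // gradient_accumulation_steps * num_train_epochs

print(train_batch_size, batches_per_epoch, t_total)  # prints: 2 500 1500
```

With `save_steps = 3500`, a run this short would never write an intermediate checkpoint, so it is worth checking that value against your own dataset size.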

## Prepare Dataset

Our dialogue dataset will be based on the dataset used in Andrada Olteanu's [article](https://www.kaggle.com/andradaolteanu/sentiment-analysis-rick-and-morty-scripts/) about Rick and Morty sentiment analysis. Big thanks to her for her work, and also to Gabriel Hernandes, author of the original [text dataset](https://github.com/ghhernandes/rickmorty-gan/tree/master/data)!

![alt text](https://media.giphy.com/media/U6LOakQja88ImTnE6T/giphy.gif)

Image from [Giphy](https://giphy.com/)

First of all, we will use the kaggle module to download the dataset we need. You can read more about the module and how to get a Kaggle API token at this [link](https://github.com/Kaggle/kaggle-api). Alternatively, you can just download the RickAndMortyScripts.csv file from the [article](https://www.kaggle.com/andradaolteanu/sentiment-analysis-rick-and-morty-scripts/) and place it where the code below expects it (`data/rick_and_morty_conversations/` under your working directory).

In [3]:
# Let's look at original dataset
all_rick = pd.read_csv('data/rick_and_morty_conversations/RickAndMortyScripts.csv')
all_rick.head(10)

| | index | season no. | episode no. | episode name | name | line |
|---|---|---|---|---|---|---|
| 0 | 0 | 1 | 1 | Pilot | Rick | Morty! You gotta come on. Jus'... you gotta co... |
| 1 | 1 | 1 | 1 | Pilot | Morty | What, Rick? What’s going on? |
| 2 | 2 | 1 | 1 | Pilot | Rick | I got a surprise for you, Morty. |
| 3 | 3 | 1 | 1 | Pilot | Morty | It's the middle of the night. What are you tal... |
| 4 | 4 | 1 | 1 | Pilot | Rick | Come on, I got a surprise for you. Come on, h... |
| 5 | 5 | 1 | 1 | Pilot | Morty | Ow! Ow! You're tugging me too hard! |
| 6 | 6 | 1 | 1 | Pilot | Rick | We gotta go, gotta get outta here, come on. Go... |
| 7 | 7 | 1 | 1 | Pilot | Rick | What do you think of this... flying vehicle, M... |
| 8 | 8 | 1 | 1 | Pilot | Morty | Yeah, Rick... I-it's great. Is this the surprise? |
| 9 | 9 | 1 | 1 | Pilot | Rick | Morty. I had to... I had to do it. I had— I ha... |


We will convert this dataset so that every response row contains the **n** previous responses as context. For our purposes, seven previous responses will be enough.
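
To make the row layout concrete, here is a minimal sketch on a toy list of lines with `n = 2`. The helper name `build_context_rows` is hypothetical; it only mirrors the indexing used in the cells below, where each row starts with the response and then walks backwards through its context.

```python
def build_context_rows(lines, n):
    """Return rows of [response, newest context, ..., oldest context]."""
    rows = []
    for i in range(n, len(lines)):
        # lines[i] is the response; lines[i-1] down to lines[i-n] are its context
        rows.append([lines[j] for j in range(i, i - n - 1, -1)])
    return rows


lines = ["a", "b", "c", "d", "e"]
rows = build_context_rows(lines, n=2)
for row in rows:
    print(row)
# prints:
# ['c', 'b', 'a']
# ['d', 'c', 'b']
# ['e', 'd', 'c']
```

Each row therefore holds `n + 1` entries, and the first `n` lines of the script never appear as a response because they lack a full context.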

In [4]:
# # Let's look at original dataset
# all_rick = pd.read_csv('data/rick_and_morty_conversations/RickAndMortyScripts.csv')
# all_rick.head(10)

# # with data leakage
# contexted = []
# n = 7
# for i in range(n, len(all_rick['line'])):
#     row = []
#     prev = i - 1 - n # we additionally subtract 1, so row will contain current response and 7 previous responses  
#     for j in range(i, prev, -1):
#         row.append(all_rick['line'][j])
#     contexted.append(row)
    
# columns = ['response', 'context'] 
# columns = columns + ['context/'+str(i) for i in range(n-1)]
# df = pd.DataFrame.from_records(contexted, columns=columns)
# trn_df, val_df = train_test_split(df, test_size = 0.1, random_state=42)
# trn_df.head()

In [5]:
# no data leakage

# read data
all_rick = pd.read_csv('data/rick_and_morty_conversations/RickAndMortyScripts.csv')
all_rick.head(10)

all_rick_train, all_rick_val = all_rick.iloc[:int(0.8*len(all_rick))], all_rick.iloc[int(0.8*len(all_rick)):]
# all_rick_train, all_rick_val = train_test_split(all_rick, test_size=0.2, random_state=42)
all_rick_train = all_rick_train.reset_index()
all_rick_val   = all_rick_val.reset_index()


# construct df with prompts and responses as columns
contexted = []
n = 7
for i in range(n, len(all_rick_train['line'])):
    row = []
    prev = i - 1 - n # we additionally subtract 1, so row will contain current response and 7 previous responses  
    for j in range(i, prev, -1):
        row.append(all_rick_train['line'][j])
    contexted.append(row)
columns = ['response', 'context'] 
columns = columns + ['context/'+str(i) for i in range(n-1)]
trn_df = pd.DataFrame.from_records(contexted, columns=columns)
trn_df.head(5)

contexted = []
n = 7
for i in range(n, len(all_rick_val['line'])):
    row = []
    prev = i - 1 - n # we additionally subtract 1, so row will contain current response and 7 previous responses  
    for j in range(i, prev, -1):
        row.append(all_rick_val['line'][j])
    contexted.append(row)
columns = ['response', 'context'] 
columns = columns + ['context/'+str(i) for i in range(n-1)]
val_df = pd.DataFrame.from_records(contexted, columns=columns)
val_df.head(5)

| | response | context | context/0 | context/1 | context/2 | context/3 | context/4 | context/5 |
|---|---|---|---|---|---|---|---|---|
| 0 | Why? Morty, I defeat gagoos more powerful tha... | Rick, this really bums me out. It-It's embarra... | Uh, the adventure is the favor, Morty. Me slee... | Rick, since it's my adventure and all, could y... | Jesus... How awesome is that? I mean, they wan... | I... think the personality conflict might have... | Don't worry, Morty, they love you. Superheroes... | This article says the reason we weren't involv... |
| 1 | Yeah, but not heroes. | Why? Morty, I defeat gagoos more powerful tha... | Rick, this really bums me out. It-It's embarra... | Uh, the adventure is the favor, Morty. Me slee... | Rick, since it's my adventure and all, could y... | Jesus... How awesome is that? I mean, they wan... | I... think the personality conflict might have... | Don't worry, Morty, they love you. Superheroes... |
| 2 | Oh, please. They just call themselves heroes s... | Yeah, but not heroes. | Why? Morty, I defeat gagoos more powerful tha... | Rick, this really bums me out. It-It's embarra... | Uh, the adventure is the favor, Morty. Me slee... | Rick, since it's my adventure and all, could y... | Jesus... How awesome is that? I mean, they wan... | I... think the personality conflict might have... |
| 3 | I'm calling them that, Rick! They're my heroes... | Oh, please. They just call themselves heroes s... | Yeah, but not heroes. | Why? Morty, I defeat gagoos more powerful tha... | Rick, this really bums me out. It-It's embarra... | Uh, the adventure is the favor, Morty. Me slee... | Rick, since it's my adventure and all, could y... | Jesus... How awesome is that? I mean, they wan... |
| 4 | Huh... no accounting for taste. I'm gonna go g... | I'm calling them that, Rick! They're my heroes... | Oh, please. They just call themselves heroes s... | Yeah, but not heroes. | Why? Morty, I defeat gagoos more powerful tha... | Rick, this really bums me out. It-It's embarra... | Uh, the adventure is the favor, Morty. Me slee... | Rick, since it's my adventure and all, could y... |
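
The cells above are labelled "with data leakage" and "no data leakage" without saying why. One plausible reading is that consecutive rows overlap heavily, so splitting the rows at random puts the same lines into both sets. The sketch below illustrates this on toy data; the helper names and the 50 unique toy lines are assumptions made for the illustration.

```python
import random


def build_rows(lines, n):
    """Rows of [response, newest context, ..., oldest context]."""
    return [[lines[j] for j in range(i, i - n - 1, -1)] for i in range(n, len(lines))]


def shared_lines(train, val):
    """Count validation entries whose text also occurs somewhere in training."""
    seen = {line for row in train for line in row}
    return sum(1 for row in val for line in row if line in seen)


lines = [f"line {k}" for k in range(50)]  # toy script with unique lines
n = 7

# Option A: build the windows first, then split the rows at random
rows = build_rows(lines, n)
random.Random(0).shuffle(rows)
cut = int(0.9 * len(rows))
leaked_a = shared_lines(rows[:cut], rows[cut:])

# Option B: split the script chronologically, then build windows per part
split = int(0.8 * len(lines))
leaked_b = shared_lines(build_rows(lines[:split], n), build_rows(lines[split:], n))

print(leaked_a > 0, leaked_b)  # prints: True 0
```

In a real script short lines such as "Yeah" repeat, so the chronological split will not reach exactly zero overlap, but it removes the overlap caused by the sliding window itself.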


Now we will convert our dataset into a format suitable for our model. Basically, we will concatenate the responses of each row into one string (additionally, we will add a special 'end of string' token after each response, so the model can tell where each response ends).
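
As a rough sketch of that concatenation, the snippet below uses a toy stand-in for the tokenizer. `toy_encode`, its tiny vocabulary, and the name `concat_with_eos` are hypothetical. The real code uses `tokenizer.encode` and `tokenizer.eos_token_id`, which for this tokenizer is 50256.

```python
EOS_ID = 50256  # id of <|endoftext|> in the GPT-2 vocabulary


def toy_encode(text):
    """Hypothetical tokenizer: one id per word from a small fixed vocabulary."""
    vocab = {"hi": 1, "rick": 2, "morty": 3, "what": 4}
    return [vocab[word] for word in text.lower().split()]


def concat_with_eos(row, encode, eos_id):
    """Encode each utterance, append EOS, and order the turns oldest first."""
    turns = [encode(text) + [eos_id] for text in row]
    # Rows are stored as [response, newest context, ..., oldest context],
    # so reversing puts the conversation back in chronological order
    return [token for turn in reversed(turns) for token in turn]


row = ["what", "hi rick", "hi morty"]  # response first, oldest context last
ids = concat_with_eos(row, toy_encode, EOS_ID)
print(ids)  # prints: [1, 3, 50256, 1, 2, 50256, 4, 50256]
```

The response ends up last in the sequence, followed by its own end-of-text token, which is the position a causal language model learns to predict from the preceding context.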

In [6]:
def construct_conv(row, tokenizer, eos = True):
    flatten = lambda l: [item for sublist in l for item in sublist]
    conv = list(reversed([tokenizer.encode(x) + [tokenizer.eos_token_id] for x in row]))
    conv = flatten(conv)
    return conv

class ConversationDataset(Dataset):
    def __init__(self, tokenizer: PreTrainedTokenizer, args, df, block_size=512):

        block_size = block_size # - (tokenizer.max_len - tokenizer.max_len_single_sentence)

        directory = args.cache_dir
        cached_features_file = os.path.join(
            directory, args.model_type + "_cached_lm_" + str(block_size)
        )

        if os.path.exists(cached_features_file) and not args.overwrite_cache:
            logger.info("Loading features from cached file %s", cached_features_file)
            with open(cached_features_file, "rb") as handle:
                self.examples = pickle.load(handle)
        else:
            logger.info("Creating features from dataset file at %s", directory)

            self.examples = []
            for _, row in df.iterrows():
                conv = construct_conv(row, tokenizer)
                self.examples.append(conv)

            logger.info("Saving features into cached file %s", cached_features_file)
            with open(cached_features_file, "wb") as handle:
                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return torch.tensor(self.examples[item], dtype=torch.long)

In [7]:
# Caching and storing of data/checkpoints

def load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False):
    return ConversationDataset(tokenizer, args, df_val if evaluate else df_trn)


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> List[str]:
    ordering_and_checkpoint_path = []

    glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))

    for path in glob_checkpoints:
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
            if regex_match and regex_match.groups():
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
    return checkpoints_sorted


def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
    if not args.save_total_limit:
        return
    if args.save_total_limit <= 0:
        return

    # Check if we should delete older checkpoint(s)
    checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
    if len(checkpoints_sorted) <= args.save_total_limit:
        return

    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
    for checkpoint in checkpoints_to_be_deleted:
        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
        shutil.rmtree(checkpoint)

## Training and Evaluating

Training our model takes quite a lot of code, but don’t worry: everything should work as is. The main thing is to give the model the dataset in the right format.

![alt text](https://media.giphy.com/media/KetvQljQJdEMscR83K/giphy.gif)

Image from [Giphy](https://giphy.com/)

In [8]:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)

train_dataset = load_and_cache_examples(args, tokenizer, trn_df, val_df, evaluate=False)

def collate(examples: List[torch.Tensor]):
    if tokenizer._pad_token is None:
        return pad_sequence(examples, batch_first=True)
    return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

train_sampler = SequentialSampler(train_dataset)
train_dataloader = DataLoader(
    train_dataset, sampler=train_sampler, batch_size=2, collate_fn=collate, drop_last = True
)

eval_dataset = load_and_cache_examples(args, tokenizer, trn_df, val_df, evaluate=True)
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(
    eval_dataset, sampler=eval_sampler, batch_size=2, collate_fn=collate, drop_last = True
)

# example usage
for batch in train_dataloader:
    break

example_id = 0
print('input_ids:\n',       batch[example_id])
# print('position_ids:\n',    batch['position_ids'][example_id])
# print('token_type_ids:\n',  batch['token_type_ids'][example_id])
# print('target_ids:\n',      batch['target_ids'][example_id])
# print('attention_masks:\n', batch['attention_masks'][example_id])

tokenizer.decode(batch[example_id])

input_ids:
 tensor([   44,   419,    88,     0,   921, 17753,  1282,   319,    13,   449,
          385,     6,   986,   345, 17753,  1282,   351,   502,    13, 50256,
         2061,    11,  8759,    30,  1867,   447,   247,    82,  1016,   319,
           30, 50256,    40,  1392,   257,  5975,   329,   345,    11, 30395,
           13, 50256,  1026,   338,   262,  3504,   286,   262,  1755,    13,
         1867,   389,   345,  3375,   546,    30, 50256, 16773,   319,    11,
          314,  1392,   257,  5975,   329,   345,    13,   220,  7911,   319,
           11, 23290,   510,    13, 50256,    46,    86,     0, 11960,     0,
          921,   821, 27762,  2667,   502,  1165,  1327,     0, 50256,  1135,
        17753,   467,    11, 17753,   651,   503,  8326,   994,    11,  1282,
          319,    13, 11853,   257,  5975,   329,   345, 30395,    13, 50256,
         2061,   466,   345,   892,   286,   428,   986,  7348,  4038,    11,
        30395,    30,   314,  3170,   340,   503,  8

"Morty! You gotta come on. Jus'... you gotta come with me.<|endoftext|>What, Rick? What’s going on?<|endoftext|>I got a surprise for you, Morty.<|endoftext|>It's the middle of the night. What are you talking about?<|endoftext|>Come on, I got a surprise for you.  Come on, hurry up.<|endoftext|>Ow! Ow! You're tugging me too hard!<|endoftext|>We gotta go, gotta get outta here, come on. Got a surprise for you Morty.<|endoftext|>What do you think of this... flying vehicle, Morty? I built it outta stuff I found in the garage.<|endoftext|>"

In [9]:
# example usage
for batch in train_dataloader:
    break

example_id = 1
print('input_ids:\n',       batch[example_id])
# print('position_ids:\n',    batch['position_ids'][example_id])
# print('token_type_ids:\n',  batch['token_type_ids'][example_id])
# print('target_ids:\n',      batch['target_ids'][example_id])
# print('attention_masks:\n', batch['attention_masks'][example_id])

tokenizer.decode(batch[example_id])

input_ids:
 tensor([ 2061,    11,  8759,    30,  1867,   447,   247,    82,  1016,   319,
           30, 50256,    40,  1392,   257,  5975,   329,   345,    11, 30395,
           13, 50256,  1026,   338,   262,  3504,   286,   262,  1755,    13,
         1867,   389,   345,  3375,   546,    30, 50256, 16773,   319,    11,
          314,  1392,   257,  5975,   329,   345,    13,   220,  7911,   319,
           11, 23290,   510,    13, 50256,    46,    86,     0, 11960,     0,
          921,   821, 27762,  2667,   502,  1165,  1327,     0, 50256,  1135,
        17753,   467,    11, 17753,   651,   503,  8326,   994,    11,  1282,
          319,    13, 11853,   257,  5975,   329,   345, 30395,    13, 50256,
         2061,   466,   345,   892,   286,   428,   986,  7348,  4038,    11,
        30395,    30,   314,  3170,   340,   503,  8326,  3404,   314,  1043,
          287,   262, 15591,    13, 50256, 10995,    11,  8759,   986,   314,
           12,   270,   338,  1049,    13,  1148,   

"What, Rick? What’s going on?<|endoftext|>I got a surprise for you, Morty.<|endoftext|>It's the middle of the night. What are you talking about?<|endoftext|>Come on, I got a surprise for you.  Come on, hurry up.<|endoftext|>Ow! Ow! You're tugging me too hard!<|endoftext|>We gotta go, gotta get outta here, come on. Got a surprise for you Morty.<|endoftext|>What do you think of this... flying vehicle, Morty? I built it outta stuff I found in the garage.<|endoftext|>Yeah, Rick... I-it's great. Is this the surprise?<|endoftext|>!!!!"

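The trailing `!!!!` in the second decoded example is padding rather than dialogue. The tokenizer has no pad token, so `collate` falls back to the default padding value of `pad_sequence`, which is 0, and id 0 decodes to `!` (note that `Morty!` in the first example ends with id 0). A minimal pure-Python sketch of that right-padding, with a hypothetical helper standing in for `pad_sequence(..., batch_first=True)`:

```python
def pad_batch(sequences, padding_value=0):
    """Right-pad every sequence to the length of the longest one."""
    width = max(len(seq) for seq in sequences)
    return [seq + [padding_value] * (width - len(seq)) for seq in sequences]


# Two toy sequences of different length; ids borrowed from the output above
batch = pad_batch([[44, 419, 88, 0, 50256], [2061, 50256]])
print(batch)  # prints: [[44, 419, 88, 0, 50256], [2061, 50256, 0, 0, 0]]
```

Since the training loop passes the same tensor as inputs and labels, these padded positions are also scored by the loss. With a batch size of 2 and rows of similar length the effect is probably small, but it is worth knowing where the stray exclamation marks come from.
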
In [10]:
def train(args, train_dataset, df_trn, df_val, model: PreTrainedModel, tokenizer: PreTrainedTokenizer) -> Tuple[int, float]:
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()
    
    #======================================================
    # train dataset + train dataloader
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, collate_fn=collate, drop_last = True
    )
    #======================================================
    # initialize model, optimizer, and scheduler
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    model = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
    model.resize_token_embeddings(len(tokenizer))
    # add_special_tokens_(model, tokenizer)

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if (
        args.model_name_or_path
        and os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        print('loading optimizer and scheduler...')
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
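        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    # `amp` is used in the fp16 branches below but is never imported in this notebook.
    # Assumption: NVIDIA apex is the intended backend, as args.fp16_opt_level suggests.
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)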
    #======================================================
    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size
        * args.gradient_accumulation_steps
        * (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if args.model_name_or_path and os.path.exists(args.model_name_or_path):
        try:
            # set global_step to global_step of last saved checkpoint from model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split("/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")
    else:
        logger.info("  Fine-tuning from scratch...")

    tr_loss, logging_loss = 0.0, 0.0

    model.zero_grad()
    train_iterator = trange(
        epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
    )
    set_seed(args)  # Added here for reproducibility
    
    # evaluate before training
    result = evaluate(args, model, tokenizer, df_trn, df_val)
    result = dict((k + "_{}".format(0), v) for k, v in result.items())
    print('evaluate before training:', result)
    
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            inputs, labels = (batch, batch)
            if inputs.shape[1] > 1024: continue
            inputs = inputs.to(args.device)
            labels = labels.to(args.device)
            model.train()
            outputs = model(inputs, labels=labels)
            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                
                # gradient clipping
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if (
                        args.local_rank == -1 and args.evaluate_during_training
                    ):  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer, df_trn, df_val)
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                    tb_writer.add_scalar("lr", scheduler.get_last_lr()[0], global_step)
                    tb_writer.add_scalar("loss", (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    checkpoint_prefix = "checkpoint"
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    model_to_save = (
                        model.module if hasattr(model, "module") else model
                    )  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    _rotate_checkpoints(args, checkpoint_prefix)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break
            
        # Evaluation
        result = evaluate(args, model, tokenizer, df_trn, df_val)
        result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
        print('epochwise results:', result)

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step

# Evaluation of some model

def evaluate(args, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, df_trn, df_val, prefix="", write_to_file=False) -> Dict:
    # Evaluate on the validation split; results are written under args.output_dir
    eval_output_dir = args.output_dir

    eval_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=True)
    os.makedirs(eval_output_dir, exist_ok=True)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    # Note that DistributedSampler samples randomly

    def collate(examples: List[torch.Tensor]):
        if tokenizer._pad_token is None:
            return pad_sequence(examples, batch_first=True)
        return pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id)

    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(
        eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate, drop_last = True
    )

    # multi-gpu evaluate
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))
    logger.info("  Num examples = %d", len(eval_dataset))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        inputs, labels = (batch, batch)
        inputs = inputs.to(args.device)
        labels = labels.to(args.device)

        with torch.no_grad():
            outputs = model(inputs, labels=labels)
            lm_loss = outputs[0]
            print(lm_loss)
            eval_loss += lm_loss.mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity}
    
    if write_to_file:
        output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results {} *****".format(prefix))
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    return result

In [11]:
# Main runner

def main(df_trn, df_val):
    args = Args()
    
    if args.should_continue:
        sorted_checkpoints = _sorted_checkpoints(args)
        if len(sorted_checkpoints) == 0:
            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
        else:
            args.model_name_or_path = sorted_checkpoints[-1]

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
        and not args.should_continue
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup CUDA, GPU & distributed training
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # fall back to CPU when no GPU is present
    args.n_gpu = torch.cuda.device_count()
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)

    config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
    model = AutoModelWithLMHead.from_pretrained(
        args.model_name_or_path,
        from_tf=False,
        config=config,
        cache_dir=args.cache_dir,
    )
    print(model)
    model.to(args.device)
    
    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args, tokenizer, df_trn, df_val, evaluate=False)
        global_step, tr_loss = train(args, train_dataset, df_trn, df_val, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
    if args.do_train:
        # Create output directory if needed
        os.makedirs(args.output_dir, exist_ok=True)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        # Load a trained model and vocabulary that you have fine-tuned
        model = AutoModelWithLMHead.from_pretrained(args.output_dir)
        tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = os.path.basename(checkpoint) if "checkpoint" in checkpoint else ""  # basename also handles Windows separators

            model = AutoModelWithLMHead.from_pretrained(checkpoint)
            model.to(args.device)
            result = evaluate(args, model, tokenizer, df_trn, df_val, prefix=prefix, write_to_file=True)
            result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
            results.update(result)

    return results
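
To see why the final results dictionary ends up with the key `perplexity_`, it can help to trace how `main` derives `global_step` and `prefix` from a checkpoint path. The snippet below is a minimal sketch that follows the same rules in a hypothetical helper, `describe_checkpoint` (not part of the original script). The paths are examples; `output-small` is the output directory that appears in the training log.

```python
def describe_checkpoint(checkpoint: str, n_checkpoints: int):
    """Follow the rules main() uses to derive (global_step, prefix) from a path."""
    # The step suffix is only extracted when several checkpoints are evaluated
    global_step = checkpoint.split("-")[-1] if n_checkpoints > 1 else ""
    # The prefix is the last path component, but only for "checkpoint-N" directories
    prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
    return global_step, prefix


# Only the final output directory is evaluated (eval_all_checkpoints is off)
step, prefix = describe_checkpoint("output-small", n_checkpoints=1)
print(repr(step), repr(prefix))       # '' ''
print("perplexity_{}".format(step))   # perplexity_

# With several checkpoints, the step number comes from the directory name
print(describe_checkpoint("output-small/checkpoint-500", n_checkpoints=2))
# ('500', 'checkpoint-500')
```

One quirk worth knowing: because the step is whatever follows the last hyphen, a directory such as `output-small` evaluated together with other checkpoints would be labeled `small`.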

It is time to train our model!

![alt text](https://media.giphy.com/media/Tia3dkakIp2m4uGoDI/giphy.gif)

Image from [Giphy](https://giphy.com/)

In [12]:
main(trn_df, val_df)



GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dro

06/19/2022 15:53:56 - INFO - __main__ -   Training/evaluation parameters <__main__.Args object at 0x000001E73B477460>
06/19/2022 15:53:56 - INFO - __main__ -   Creating features from dataset file at cached
06/19/2022 15:53:58 - INFO - __main__ -   Saving features into cached file cached\gpt2_cached_lm_512
06/19/2022 15:53:58 - INFO - __main__ -   ***** Running training *****
06/19/2022 15:53:58 - INFO - __main__ -     Num examples = 1517
06/19/2022 15:53:58 - INFO - __main__ -     Num Epochs = 3
06/19/2022 15:53:58 - INFO - __main__ -     Instantaneous batch size per GPU = 2
06/19/2022 15:53:58 - INFO - __main__ -     Total train batch size (w. parallel, distributed & accumulation) = 2
06/19/2022 15:53:58 - INFO - __main__ -     Gradient Accumulation steps = 1
06/19/2022 15:53:58 - INFO - __main__ -     Total optimization steps = 2274
06/19/2022 15:53:58 - INFO - __main__ -     Fine-tuning from scratch...


Epoch:   0%|          | 0/3 [00:00<?, ?it/s]

06/19/2022 15:53:58 - INFO - __main__ -   Creating features from dataset file at cached
06/19/2022 15:53:58 - INFO - __main__ -   Saving features into cached file cached\gpt2_cached_lm_512
06/19/2022 15:53:58 - INFO - __main__ -   ***** Running evaluation  *****
06/19/2022 15:53:58 - INFO - __main__ -     Num examples = 374
06/19/2022 15:53:58 - INFO - __main__ -     Batch size = 2


Evaluating:   0%|          | 0/187 [00:00<?, ?it/s]

tensor(8.6445, device='cuda:0')
tensor(7.7567, device='cuda:0')
tensor(8.0826, device='cuda:0')
tensor(6.6683, device='cuda:0')
tensor(6.3875, device='cuda:0')
tensor(6.8282, device='cuda:0')
tensor(7.6643, device='cuda:0')
tensor(8.4046, device='cuda:0')
tensor(8.7092, device='cuda:0')
tensor(8.7659, device='cuda:0')
tensor(8.9562, device='cuda:0')
tensor(7.7752, device='cuda:0')
tensor(7.2743, device='cuda:0')
tensor(6.8402, device='cuda:0')
tensor(7.5114, device='cuda:0')
tensor(6.6543, device='cuda:0')
tensor(6.0172, device='cuda:0')
tensor(6.3606, device='cuda:0')
tensor(6.5810, device='cuda:0')
tensor(6.7023, device='cuda:0')
tensor(7.1457, device='cuda:0')
tensor(6.9617, device='cuda:0')
tensor(6.8703, device='cuda:0')
tensor(7.2301, device='cuda:0')
tensor(7.0437, device='cuda:0')
tensor(6.1717, device='cuda:0')
tensor(6.6474, device='cuda:0')
tensor(6.2146, device='cuda:0')
tensor(6.4246, device='cuda:0')
tensor(6.9656, device='cuda:0')
tensor(7.7037, device='cuda:0')
tensor(8

Iteration:   0%|          | 0/758 [00:00<?, ?it/s]

06/19/2022 15:55:18 - INFO - __main__ -   Creating features from dataset file at cached
06/19/2022 15:55:18 - INFO - __main__ -   Saving features into cached file cached\gpt2_cached_lm_512
06/19/2022 15:55:18 - INFO - __main__ -   ***** Running evaluation  *****
06/19/2022 15:55:18 - INFO - __main__ -     Num examples = 374
06/19/2022 15:55:18 - INFO - __main__ -     Batch size = 2


Evaluating:   0%|          | 0/187 [00:00<?, ?it/s]

tensor(4.5818, device='cuda:0')
tensor(4.3031, device='cuda:0')
tensor(4.1152, device='cuda:0')
tensor(3.9673, device='cuda:0')
tensor(3.3992, device='cuda:0')
tensor(3.3441, device='cuda:0')
tensor(3.1737, device='cuda:0')
tensor(3.2542, device='cuda:0')
tensor(3.4716, device='cuda:0')
tensor(3.8334, device='cuda:0')
tensor(4.0817, device='cuda:0')
tensor(4.0977, device='cuda:0')
tensor(4.1238, device='cuda:0')
tensor(4.0276, device='cuda:0')
tensor(3.2116, device='cuda:0')
tensor(3.2303, device='cuda:0')
tensor(3.3711, device='cuda:0')
tensor(3.5202, device='cuda:0')
tensor(3.7234, device='cuda:0')
tensor(3.8291, device='cuda:0')
tensor(4.0545, device='cuda:0')
tensor(4.3198, device='cuda:0')
tensor(4.2116, device='cuda:0')
tensor(4.5760, device='cuda:0')
tensor(4.0660, device='cuda:0')
tensor(4.1639, device='cuda:0')
tensor(4.0467, device='cuda:0')
tensor(3.5088, device='cuda:0')
tensor(3.7168, device='cuda:0')
tensor(3.7611, device='cuda:0')
tensor(3.7073, device='cuda:0')
tensor(3

Iteration:   0%|          | 0/758 [00:00<?, ?it/s]

06/19/2022 15:56:35 - INFO - __main__ -   Creating features from dataset file at cached
06/19/2022 15:56:36 - INFO - __main__ -   Saving features into cached file cached\gpt2_cached_lm_512
06/19/2022 15:56:36 - INFO - __main__ -   ***** Running evaluation  *****
06/19/2022 15:56:36 - INFO - __main__ -     Num examples = 374
06/19/2022 15:56:36 - INFO - __main__ -     Batch size = 2


Evaluating:   0%|          | 0/187 [00:00<?, ?it/s]

tensor(5.4231, device='cuda:0')
tensor(5.1068, device='cuda:0')
tensor(4.8349, device='cuda:0')
tensor(4.5921, device='cuda:0')
tensor(3.8613, device='cuda:0')
tensor(3.7518, device='cuda:0')
tensor(3.6029, device='cuda:0')
tensor(3.7538, device='cuda:0')
tensor(4.0245, device='cuda:0')
tensor(4.5014, device='cuda:0')
tensor(4.7813, device='cuda:0')
tensor(4.8256, device='cuda:0')
tensor(4.7407, device='cuda:0')
tensor(4.5944, device='cuda:0')
tensor(3.5514, device='cuda:0')
tensor(3.5789, device='cuda:0')
tensor(3.9615, device='cuda:0')
tensor(4.1413, device='cuda:0')
tensor(4.3321, device='cuda:0')
tensor(4.3620, device='cuda:0')
tensor(4.6022, device='cuda:0')
tensor(4.8316, device='cuda:0')
tensor(4.8075, device='cuda:0')
tensor(5.3185, device='cuda:0')
tensor(4.6643, device='cuda:0')
tensor(4.8611, device='cuda:0')
tensor(4.6127, device='cuda:0')
tensor(4.0033, device='cuda:0')
tensor(4.2407, device='cuda:0')
tensor(4.3164, device='cuda:0')
tensor(4.2579, device='cuda:0')
tensor(4

Iteration:   0%|          | 0/758 [00:00<?, ?it/s]

06/19/2022 15:57:53 - INFO - __main__ -   Creating features from dataset file at cached
06/19/2022 15:57:53 - INFO - __main__ -   Saving features into cached file cached\gpt2_cached_lm_512
06/19/2022 15:57:53 - INFO - __main__ -   ***** Running evaluation  *****
06/19/2022 15:57:53 - INFO - __main__ -     Num examples = 374
06/19/2022 15:57:53 - INFO - __main__ -     Batch size = 2


Evaluating:   0%|          | 0/187 [00:00<?, ?it/s]

tensor(5.8473, device='cuda:0')
tensor(5.5320, device='cuda:0')
tensor(5.2341, device='cuda:0')
tensor(4.9742, device='cuda:0')
tensor(4.1755, device='cuda:0')
tensor(4.0430, device='cuda:0')
tensor(3.9148, device='cuda:0')
tensor(4.0934, device='cuda:0')
tensor(4.3878, device='cuda:0')
tensor(4.8891, device='cuda:0')
tensor(5.1592, device='cuda:0')
tensor(5.1957, device='cuda:0')
tensor(5.0078, device='cuda:0')
tensor(4.8295, device='cuda:0')
tensor(3.7081, device='cuda:0')
tensor(3.8266, device='cuda:0')
tensor(4.2386, device='cuda:0')
tensor(4.4225, device='cuda:0')
tensor(4.6351, device='cuda:0')
tensor(4.6232, device='cuda:0')
tensor(4.8839, device='cuda:0')
tensor(5.1013, device='cuda:0')
tensor(5.1066, device='cuda:0')
tensor(5.7224, device='cuda:0')
tensor(5.0470, device='cuda:0')
tensor(5.3060, device='cuda:0')
tensor(4.9715, device='cuda:0')
tensor(4.2979, device='cuda:0')
tensor(4.5398, device='cuda:0')
tensor(4.6439, device='cuda:0')
tensor(4.5719, device='cuda:0')
tensor(4

06/19/2022 15:57:56 - INFO - __main__ -    global_step = 2274, average loss = 1.8839206161631128
06/19/2022 15:57:56 - INFO - __main__ -   Saving model checkpoint to output-small


tensor(5.9819, device='cuda:0')
tensor(4.9808, device='cuda:0')
tensor(4.7165, device='cuda:0')
tensor(5.3673, device='cuda:0')
tensor(5.6573, device='cuda:0')
tensor(5.5446, device='cuda:0')
tensor(4.4543, device='cuda:0')
tensor(4.9177, device='cuda:0')
tensor(4.3173, device='cuda:0')
tensor(5.0997, device='cuda:0')
epochwise results: {'perplexity_2274': tensor(134.3033)}


06/19/2022 15:57:59 - INFO - __main__ -   Evaluate the following checkpoints: ['output-small']
06/19/2022 15:58:01 - INFO - __main__ -   Creating features from dataset file at cached
06/19/2022 15:58:01 - INFO - __main__ -   Saving features into cached file cached\gpt2_cached_lm_512
06/19/2022 15:58:01 - INFO - __main__ -   ***** Running evaluation  *****
06/19/2022 15:58:01 - INFO - __main__ -     Num examples = 374
06/19/2022 15:58:01 - INFO - __main__ -     Batch size = 2


Evaluating:   0%|          | 0/187 [00:00<?, ?it/s]

tensor(5.8473, device='cuda:0')
tensor(5.5320, device='cuda:0')
tensor(5.2341, device='cuda:0')
tensor(4.9742, device='cuda:0')
tensor(4.1755, device='cuda:0')
tensor(4.0430, device='cuda:0')
tensor(3.9148, device='cuda:0')
tensor(4.0934, device='cuda:0')
tensor(4.3878, device='cuda:0')
tensor(4.8891, device='cuda:0')
tensor(5.1592, device='cuda:0')
tensor(5.1957, device='cuda:0')
tensor(5.0078, device='cuda:0')
tensor(4.8295, device='cuda:0')
tensor(3.7081, device='cuda:0')
tensor(3.8266, device='cuda:0')
tensor(4.2386, device='cuda:0')
tensor(4.4225, device='cuda:0')
tensor(4.6351, device='cuda:0')
tensor(4.6232, device='cuda:0')
tensor(4.8839, device='cuda:0')
tensor(5.1013, device='cuda:0')
tensor(5.1066, device='cuda:0')
tensor(5.7224, device='cuda:0')
tensor(5.0470, device='cuda:0')
tensor(5.3060, device='cuda:0')
tensor(4.9715, device='cuda:0')
tensor(4.2979, device='cuda:0')
tensor(4.5398, device='cuda:0')
tensor(4.6439, device='cuda:0')
tensor(4.5719, device='cuda:0')
tensor(4

06/19/2022 15:58:05 - INFO - __main__ -   ***** Eval results  *****
06/19/2022 15:58:05 - INFO - __main__ -     perplexity = tensor(134.3033)


tensor(4.9808, device='cuda:0')
tensor(4.7165, device='cuda:0')
tensor(5.3673, device='cuda:0')
tensor(5.6573, device='cuda:0')
tensor(5.5446, device='cuda:0')
tensor(4.4543, device='cuda:0')
tensor(4.9177, device='cuda:0')
tensor(4.3173, device='cuda:0')
tensor(5.0997, device='cuda:0')


{'perplexity_': tensor(134.3033)}
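
The numbers in this log can be checked by hand. The sketch below uses only the standard library to reproduce the step count from the training log and to show how `evaluate` turns per-batch losses into perplexity (the exponential of the mean loss). The helper names are illustrative, and the step arithmetic assumes the training loader drops the last incomplete batch, which the 758 iterations per epoch in the log suggest.

```python
import math


def optimization_steps(num_examples, batch_size, epochs, grad_accum=1):
    """Batches per epoch (last incomplete batch dropped) times the number of epochs."""
    batches_per_epoch = num_examples // batch_size
    return batches_per_epoch // grad_accum * epochs


def perplexity(batch_losses):
    """exp(mean of per-batch losses), the same formula evaluate() applies."""
    return math.exp(sum(batch_losses) / len(batch_losses))


# Values taken from the training log above
print(optimization_steps(1517, 2, 3))   # 2274
print(374 // 2)                         # 187 evaluation batches

# Inverting the formula gives the mean validation loss behind the reported perplexity
print(round(math.log(134.3033), 2))     # 4.9

# Toy losses (made up for illustration): a lower mean loss means lower perplexity
print(perplexity([2.0, 2.0]) < perplexity([3.0, 3.0]))  # True
```

A mean loss of about 4.9 is in line with the per-batch values printed during the final evaluation.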

In [14]:
val_df

| Unnamed: 0 | response | context | context/0 | context/1 | context/2 | context/3 | context/4 | context/5 |
|---|---|---|---|---|---|---|---|---|
| 0 | Why? Morty, I defeat gagoos more powerful tha... | Rick, this really bums me out. It-It's embarra... | Uh, the adventure is the favor, Morty. Me slee... | Rick, since it's my adventure and all, could y... | Jesus... How awesome is that? I mean, they wan... | I... think the personality conflict might have... | Don't worry, Morty, they love you. Superheroes... | This article says the reason we weren't involv... |
| 1 | Yeah, but not heroes. | Why? Morty, I defeat gagoos more powerful tha... | Rick, this really bums me out. It-It's embarra... | Uh, the adventure is the favor, Morty. Me slee... | Rick, since it's my adventure and all, could y... | Jesus... How awesome is that? I mean, they wan... | I... think the personality conflict might have... | Don't worry, Morty, they love you. Superheroes... |
| 2 | Oh, please. They just call themselves heroes s... | Yeah, but not heroes. | Why? Morty, I defeat gagoos more powerful tha... | Rick, this really bums me out. It-It's embarra... | Uh, the adventure is the favor, Morty. Me slee... | Rick, since it's my adventure and all, could y... | Jesus... How awesome is that? I mean, they wan... | I... think the personality conflict might have... |
| 3 | I'm calling them that, Rick! They're my heroes... | Oh, please. They just call themselves heroes s... | Yeah, but not heroes. | Why? Morty, I defeat gagoos more powerful tha... | Rick, this really bums me out. It-It's embarra... | Uh, the adventure is the favor, Morty. Me slee... | Rick, since it's my adventure and all, could y... | Jesus... How awesome is that? I mean, they wan... |
| 4 | Huh... no accounting for taste. I'm gonna go g... | I'm calling them that, Rick! They're my heroes... | Oh, please. They just call themselves heroes s... | Yeah, but not heroes. | Why? Morty, I defeat gagoos more powerful tha... | Rick, this really bums me out. It-It's embarra... | Uh, the adventure is the favor, Morty. Me slee... | Rick, since it's my adventure and all, could y... |
| ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 369 | That was amazing! | Whoa!! Hahaha, yeah! Atlantis, baby! | Holy crap... Slick's wish came true. | I don't know, and I don't have to know. I've b... | To what? | No graduation. No new Ricks. The school's curr... | Did we miss graduation? Where are the new Ricks? | But... I violated at least a dozen departmenta... |
| 370 | Got some of that mermaid puss! | That was amazing! | Whoa!! Hahaha, yeah! Atlantis, baby! | Holy crap... Slick's wish came true. | I don't know, and I don't have to know. I've b... | To what? | No graduation. No new Ricks. The school's curr... | Did we miss graduation? Where are the new Ricks? |
| 371 | I'm really hoping it wasn't a one-off thing an... | Got some of that mermaid puss! | That was amazing! | Whoa!! Hahaha, yeah! Atlantis, baby! | Holy crap... Slick's wish came true. | I don't know, and I don't have to know. I've b... | To what? | No graduation. No new Ricks. The school's curr... |
| 372 | Pssh! Not at all, Morty. That place will never... | I'm really hoping it wasn't a one-off thing an... | Got some of that mermaid puss! | That was amazing! | Whoa!! Hahaha, yeah! Atlantis, baby! | Holy crap... Slick's wish came true. | I don't know, and I don't have to know. I've b... | To what? |


## Chatting with Rick

The model is ready, so it's time to chat with Rick. But don't forget that Rick can be rude. I warned you.

A variety of decoding methods can be used for response generation. You can find more details about these methods in this [article](https://huggingface.co/blog/how-to-generate).
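
As a rough illustration of what `top_k`, `top_p`, and `temperature` do in the `generate` call used in this section, here is a minimal pure-Python sketch of the filtering step applied to a toy next-token distribution. The vocabulary, the logits, and the `filter_distribution` helper are all made up for illustration; the real implementation in Transformers works on tensors and differs in detail.

```python
import math
import random


def filter_distribution(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return {token: probability} after temperature, top-k and top-p filtering."""
    # Temperature < 1 sharpens the distribution, > 1 flattens it
    scaled = {tok: logit / temperature for tok, logit in logits.items()}
    # Softmax (subtract the max for numerical stability)
    peak = max(scaled.values())
    exps = {tok: math.exp(v - peak) for tok, v in scaled.items()}
    total = sum(exps.values())
    probs = sorted(((tok, e / total) for tok, e in exps.items()),
                   key=lambda pair: pair[1], reverse=True)
    # Top-k: keep only the k most likely tokens
    if top_k > 0:
        probs = probs[:top_k]
    # Top-p (nucleus): keep the smallest prefix whose cumulative probability reaches top_p
    kept, cumulative = [], 0.0
    for tok, p in probs:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break
    # Renormalise the survivors so they sum to 1
    norm = sum(p for _, p in kept)
    return {tok: p / norm for tok, p in kept}


# Toy logits for the token that follows "Oh, " (values are invented)
toy_logits = {"geez": 3.0, "man": 2.0, "Morty": 1.5, "no": 0.5, "banana": -2.0}

candidates = filter_distribution(toy_logits, temperature=0.8, top_k=100, top_p=0.7)
print(candidates)  # only the most likely tokens survive the nucleus cut

# Sample the next token from the filtered distribution
random.seed(0)
tokens, weights = zip(*candidates.items())
print(random.choices(tokens, weights=weights, k=1)[0])
```

Setting `top_k=1` in this sketch reduces sampling to greedy decoding, since only the single most likely token remains.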

![alt text](https://media.giphy.com/media/ftl8NiWz2C3tuXOn3w/giphy.gif)

Image from [Giphy](https://giphy.com/)

In [13]:
tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-small')
model = AutoModelWithLMHead.from_pretrained('output-small')

# Let's chat for 5 lines
for step in range(5):
    # encode the new user input, append the eos_token, and return a PyTorch tensor
    new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')
    # print(new_user_input_ids)

    # append the new user input tokens to the chat history
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

    # generate a response; max_length caps the whole sequence (history + reply) at 200 tokens
    chat_history_ids = model.generate(
        bot_input_ids, max_length=200,
        pad_token_id=tokenizer.eos_token_id,  
        no_repeat_ngram_size=3,       
        do_sample=True, 
        top_k=100, 
        top_p=0.7,
        temperature = 0.8
    )
    
    # pretty-print only the newly generated tokens (the bot's reply)
    print("RickBot: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))

>> User:How are you, Rick?
RickBot: I'm fine.
>> User:Where is Morty?
RickBot: He's doing great things. He's doing just fine. Where are we going?



**User**:How are you, Rick? <br>
**RickBot**: I'm fine. <br>
**User**:Where is Morty? <br>
**RickBot**: He's in a hospital. <br>
**User**:Why? <br>
**RickBot**: Well, that depends on who breaks first -- me or the hospital. I mean, you know, the one that keeps giving me pills. <br>
**User**:I don't understand <br>
**RickBot**: Whoa! Whoa! <br>
**User**:I see, you took your pills <br>
**RickBot**: Hey, I-I'm not proud of this. <br>
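
It may help to see how the conversation history is tracked in the chat loop without any model involved. The sketch below replaces the tokenizer and `model.generate` with hypothetical stand-ins (`fake_encode`, `fake_generate`) that operate on plain lists of integers. Only the bookkeeping is the point: append the user turn to the history, generate, then slice off everything before the reply. It also reflects that `max_length` in `generate` bounds the whole sequence, prompt included, so a long history leaves less room for the answer.

```python
EOS = 0  # stand-in for tokenizer.eos_token_id (hypothetical value)


def fake_encode(text):
    """Hypothetical tokenizer: one integer per word, followed by EOS."""
    return [len(word) for word in text.split()] + [EOS]


def fake_generate(input_ids, max_length):
    """Hypothetical model: copies the prompt, adds up to three reply tokens (7)
    while leaving room for EOS, then appends EOS if the length limit allows."""
    output = list(input_ids)
    for _ in range(3):
        if len(output) >= max_length - 1:
            break
        output.append(7)
    if len(output) < max_length:
        output.append(EOS)
    return output


chat_history = []
replies = []
for step, user_text in enumerate(["How are you, Rick?", "Where is Morty?"]):
    new_user_ids = fake_encode(user_text)
    # Same idea as torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    bot_input = chat_history + new_user_ids if step > 0 else new_user_ids
    chat_history = fake_generate(bot_input, max_length=16)
    # Same idea as chat_history_ids[:, bot_input_ids.shape[-1]:]
    reply = chat_history[len(bot_input):]
    replies.append(reply)
    print(step, "prompt tokens:", len(bot_input), "reply tokens:", reply)
```

In this toy run the second reply is one token shorter than the first, because the accumulated history has already used most of the `max_length` budget.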

Here are a few more dialogues to show just how Ricky our Rick is now.

![alt text](https://media.giphy.com/media/eNpXWzGIMRjIo4lXT8/giphy.gif)

Image from [Giphy](https://giphy.com/)

In [16]:
tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-small')
model = AutoModelWithLMHead.from_pretrained('output-small')

# Let's chat for 5 lines
for step in range(5):
    # encode the new user input, append the eos_token, and return a PyTorch tensor
    new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')
    # print(new_user_input_ids)

    # append the new user input tokens to the chat history
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

    # generate a response; max_length caps the whole sequence (history + reply) at 200 tokens
    chat_history_ids = model.generate(
        bot_input_ids, max_length=200,
        pad_token_id=tokenizer.eos_token_id,  
        no_repeat_ngram_size=3,       
        do_sample=True, 
        top_k=100, 
        top_p=0.7,
        temperature = 0.8
    )
    
    # pretty-print only the newly generated tokens (the bot's reply)
    print("RickBot: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))

>> User:Who are you?
RickBot: The head of the Federation of Jan-Michael Vincent's.
>> User:That's so cool
RickBot: Jan-Michael Vindicators.



**User**:Who are you? <br>
**RickBot**: A musician, I'm not a musician, but I'm pretty close to you. <br>
**User**:What kind of musician are you? <br>
**RickBot**: Who cares? I'm a musician. <br>
**User**:Ok, play a song for me <br>
**RickBot**: Shut up and listen to me, Morty. Listen to me. Listen! I'm the leader of a bunch of different death worlds, and I'm just trying to assemble a band. I'm gonna make this one-piece, and you're gonna be part of it. <br>
**User**:ok <br>
**RickBot**: Whoa, whoa, hey, who's this? <br>
**User**:This is Morty <br>
**RickBot**: Geez!!! <br>

In [22]:
tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-small')
model = AutoModelWithLMHead.from_pretrained('output-small')

# Let's chat for 5 lines
for step in range(5):
    # encode the new user input, append the eos_token, and return a PyTorch tensor
    new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt')
    # print(new_user_input_ids)

    # append the new user input tokens to the chat history
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

    # generate a response; max_length caps the whole sequence (history + reply) at 200 tokens
    chat_history_ids = model.generate(
        bot_input_ids, max_length=200,
        pad_token_id=tokenizer.eos_token_id,  
        no_repeat_ngram_size=3,       
        do_sample=True, 
        top_k=100, 
        top_p=0.7,
        temperature = 0.8
    )
    
    # pretty-print only the newly generated tokens (the bot's reply)
    print("RickBot: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))

>> User:Rick what's going on in your lab?
RickBot: A whole world populated by intelligent dogs. A whole world of intelligent dogs, one house at a time.
>> User:Whoa, that's cool
RickBot: I think I can rent a dog now.
>> User:More than that, you have an entire world of intelligent dogs!
RickBot: That's enough, Rick! We don't want you injuring your ball-throwing arm, you're scaring us.



**User**:What do you think about Elon Musk? <br>
**RickBot**: I think about him a lot. <br>
**User**:But what exactly do you think? <br>
**RickBot**: He's not a bad guy. He's just... not a very charismatic one. <br>
**User**:I think he is charismatic <br>
**RickBot**: He is. <br>
**User**:What else you can say about him? <br>
**RickBot**: You don't have to say anything. He is the leader of your class. <br>
**User**:ok, thank you <br>
**RickBot**: !!!

![alt text](https://i.kym-cdn.com/news/posts/desktop/000/000/247/Screenshot_21.jpg)

Image from [news.knowyourmeme.com](https://news.knowyourmeme.com/news/elon-musks-rick-and-morty-cameo-is-a-meme-reference)

![alt text](https://media.giphy.com/media/YSTmV6MWEkgW0HblZQ/giphy.gif)

Image from [Giphy](https://giphy.com/)

**User**:hey, go get me a burger <br>
**RickBot**: Whoa, whoa, hold on a second, let's not get crazy. This is gonna be good. <br>
**User**:I'm hungry! c'mon, buy me a burger or better two <br>
**RickBot**: Holy crap. <br>
**User**:no crappy burger, a good one! <br>
**RickBot**: What the hell? <br>
**User**:don't be so bad, feed me <br>
**RickBot**: I'll be back as fast as I can, Morty. <br>
**User**:good boy, and don't forget ketchup and fries <br>
**RickBot**: !!!!!! <br>
**User**:what a hell? can you reply with words? <br>
**RickBot**: Yes!! I will!! We will go out on a date. <br>

## Conclusion

![alt text](https://media.giphy.com/media/YQBWwGLJtFDQIATck5/giphy.gif)

Image from [Giphy](https://giphy.com/)

Congratulations! Our virtual Rick is alive (almost)! By fine-tuning our model on a small dataset, we were able to create a virtual character with whom we can hold interesting dialogs.

Using the proposed approach, you can create many interesting virtual characters from any dialog dataset (just a CSV file with lines of dialog, one line per row).
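
For a dataset of your own, the rows need the same shape as the `val_df` frame shown earlier: one `response` column plus the preceding lines as `context`, `context/0`, and so on. Below is a minimal sketch of that conversion using only the standard library. The `build_rows` helper, the layout with seven context lines, and the sample lines are assumptions made for illustration, not the notebook's own preprocessing code.

```python
import csv
import io


def build_rows(lines, n_context=7):
    """Pair each line with the n_context lines before it, most recent first."""
    columns = ["response", "context"] + ["context/{}".format(i) for i in range(n_context - 1)]
    rows = []
    for i in range(n_context, len(lines)):
        # lines[i] is the response; lines[i-1], lines[i-2], ... are its context
        window = [lines[j] for j in range(i, i - n_context - 1, -1)]
        rows.append(dict(zip(columns, window)))
    return columns, rows


# A stand-in for a CSV file with one line of dialog per row (invented sample data)
raw_csv = io.StringIO("line\nHi\nHello\nHow are you?\nFine\nGood\nBye\nSee you\nLater\nOk\n")
lines = [record["line"] for record in csv.DictReader(raw_csv)]

columns, rows = build_rows(lines)
print(columns)
print(len(rows))             # 2 rows from 9 lines with 7 lines of context
print(rows[0]["response"])   # Later
print(rows[0]["context"])    # See you
```

Each row's `context` is the previous row's `response`, which matches the staircase pattern visible in `val_df`.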