# TODO:

Input Baseline ----> <br>
* input: context_1 <sep> context_2 <sep> response <eos> <eos> <eos> <eos> <eos>  <br>
* labels: input <br>

Input Masked ------><br> 
* input: <speaker> context_1 <bot> context_2 <speaker> context_2 <bot> response <eos> <pad> <pad> <pad> <pad><br>
* labels: [-100,-100,...,-100,<bot> response, -100,-100,-100]<br>

Input Context -----><br>
* input: <context> situeation_context <speaker> context_1 <bot> context_2 <speaker> context_2 <bot> response <eos> <pad> <pad> <pad> <pad><br>
* labels: [-100,-100,...,-100,<bot> response, -100,-100,-100]<br>


In [1]:
DRIVE=True


if DRIVE:
    # conect google drive
    from google.colab import drive
    drive.mount('/content/drive')

    currently_path = '/content/drive/MyDrive/03. chatbot_with_personality'
    %cd $currently_path  

    %pip install -r requirements.txt  

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/03. chatbot_with_personality


In [2]:
# Essential libraries
import torch

# load and split data in train & test
from sklearn.model_selection import train_test_split
import pandas as pd

# Create Dataset
import os
import pickle
import pandas as pd
from torch.utils.data import Dataset

# pad sequences
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Libraries for find checkpoint
import re
import glob
import shutil

# animation for iteration
from tqdm.notebook import tqdm, trange

# Parameters 
import config as cfg

from transformers import (
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
)

import logging
# Instance logger
logger = logging.getLogger(__name__)
# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO
)

In [3]:
def setup_device_for_train():
    try:
        device = torch.device('cuda') 
    except:
        device = torch.device('cpu') #cpu
    cfg.n_gpu = torch.cuda.device_count()
    cfg.device = device

def _sorted_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False):
    ordering_and_checkpoint_path = []

    glob_checkpoints = glob.glob(os.path.join(args.output_dir, "{}-*".format(checkpoint_prefix)))

    for path in glob_checkpoints:
        if use_mtime:
            ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
        else:
            regex_match = re.match(".*{}-([0-9]+)".format(checkpoint_prefix), path)
            if regex_match and regex_match.groups():
                ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

    checkpoints_sorted = sorted(ordering_and_checkpoint_path)
    checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
    return checkpoints_sorted

def _rotate_checkpoints(args, checkpoint_prefix="checkpoint", use_mtime=False) -> None:
    if not args.save_total_limit:
        return
    if args.save_total_limit <= 0:
        return

    # Check if we should delete older checkpoint(s)
    checkpoints_sorted = _sorted_checkpoints(args, checkpoint_prefix, use_mtime)
    if len(checkpoints_sorted) <= args.save_total_limit:
        return

    number_of_checkpoints_to_delete = max(0, len(checkpoints_sorted) - args.save_total_limit)
    checkpoints_to_be_deleted = checkpoints_sorted[:number_of_checkpoints_to_delete]
    for checkpoint in checkpoints_to_be_deleted:
        logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
        shutil.rmtree(checkpoint)

# Added here for reproducibility
def set_seed(seed):
    torch.manual_seed(seed)
    if cfg.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)


def set_tokens(tokenizer,model,Model_version=None):
    
    # Set tokens
    #tokenizer.bos_token = '<|beginningoftext|>'
    #tokenizer.eos_token = '<|endoftext|>'
    #tokenizer.sep_token = '<|sep|>'
    #tokenizer.pad_token = '<|pad|>'
    #tokenizer.unk_token = '<|unk|>'

    SPECIAL_TOKENS = {
        'bos_token': '<|beginningoftext|>',
        'eos_token': '<|endoftext|>',
        'sep_token': '<|sep|>',
        'pad_token': '<|pad|>',
        'unk_token': '<|unk|>'
        }
    tokenizer.add_special_tokens(SPECIAL_TOKENS)
    
    if Model_version is not None:
        # Only for model v2.
        ## Set special tokens
        speaker_token = '<|speaker|>'
        bot_token = '<|bot|>'
        tokenizer.add_tokens([speaker_token, bot_token], special_tokens=True)

    # set tokens in model
    model.resize_token_embeddings(len(tokenizer))

    return tokenizer, model


def transform_dataset_to_features(row, tokenizer, BOT_TOKEN=False):
    """ Add EOS token to each sentecenes from row --> tokenized ---> flatten tokenized row
    Input Baseline ---->
        input: context_1 <sep> context_2 <sep> response <eos> <pad> <pad> <pad> ... <pad> 
        labels: [-100,-100,...,-100,<bot> response, -100,-100,-100]

    Args:
        row (list of sentences): the first N elements is the context and the last element is the response or target
        tokenizer (object): tokenizer from respective model

    Returns:
        (list of ints): [len(flatten(row))] (shape is variable because not have padding) flatten row tokenized 
            >>>>  [651,5513,86,24905,287,...,640,284,651,5513,86,50256] --> 50256=eos_token
    """

    flatten = lambda l: [item for sublist in l for item in sublist]
    token_features = []
    labels = []

    for i,x in enumerate(row):
        
        if BOT_TOKEN == False:
            if cfg.bot_token in x:
                x = x.replace(cfg.bot_token+' ', '')
                BOT=True
            elif cfg.user_token in x:
                x = x.replace(cfg.user_token+' ', '')
                BOT=False

        tokens = tokenizer.encode(x) + [tokenizer.eos_token_id]
        token_features.append(tokens)

        # add labels
        if BOT:
            labels.append(tokens)
        else:
            labels.append([-100]*len(tokens))

    token_features = flatten(token_features)
    labels = flatten(labels)
    return token_features, labels

In [4]:
# Create Dataset
class My_convertional_dataset(Dataset):
    def __init__(self, df_prepared, tokenizer, cfg, save_features=False):
        """
        Args:
            df_prepared (pd.DataFrame): dataframe created 0.0.prepare_dataset
                [
                    [context_0, context_1, ..., context_N, response], # sample_0
                    [context_0, context_1, ..., context_N, response], # sample_1
                    ...,
                    [context_0, context_1, ..., context_N, response] # sample_L
                ]
            tokenizer (object)
            cfg (object): parameters
        """
        # Initialize Dataset
        super(Dataset, self).__init__()

        # Create the name for file that to store the features tokenized
        directory = cfg.cache_dir
        cached_features_file = os.path.join(directory, cfg.model_type + "_cached_lm_" + str(tokenizer.model_max_length))

        # print logging info
        logger.info("Creating features from dataset file at %s", cached_features_file)

        # Transform dataset to features and store in {cached_features_file} file
        ## transform data
        self.examples = [transform_dataset_to_features(row, tokenizer) for _,row in df_prepared.iterrows()] # list of tuples (x,y)

        if save_features:
            ## stored features (tokenized dataset)
            with open(cached_features_file, 'wb') as handle:
                pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)

            # print logging info
            logger.info("Features created and saved at %s", cached_features_file)
        
    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        token_featues, labels = self.examples[idx]
        return torch.tensor(token_featues, dtype=torch.long), torch.tensor(labels)

# Train

In [5]:
def train(cfg, train_dataset, model, tokenizer):
    def padding_fn(examples):
        # Only uses for batch's in inference it's not necessary
        x,y = list(zip(*examples))
        x_pad = pad_sequence(x, batch_first=True, padding_value=tokenizer.eos_token_id)
        y_pad = pad_sequence(y, batch_first=True, padding_value=-100)
        return x_pad, y_pad

    # Build DataLoadet
    train_dataloader = DataLoader(
                            dataset=train_dataset,
                            batch_size=cfg.train_batch_size,
                            shuffle=True,
                            collate_fn=padding_fn
                            )


    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": cfg.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.learning_rate, eps=cfg.adam_epsilon)
    t_total = len(train_dataloader) // cfg.num_train_epochs # paramter for schedule lr
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=cfg.warmup_steps, num_training_steps=t_total
    )

    # Check if saved optimizer or scheduler states exist
    if (
        cfg.model_name_or_path
        and os.path.isfile(os.path.join(cfg.model_name_or_path, "optimizer.pt"))
        and os.path.isfile(os.path.join(cfg.model_name_or_path, "scheduler.pt"))
    ):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(cfg.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(torch.load(os.path.join(cfg.model_name_or_path, "scheduler.pt")))

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", cfg.num_train_epochs)
    logger.info("  Num steps by Epoch = %d", len(train_dataset)//cfg.train_batch_size)


    # Initizalize variables
    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    tr_loss = 0.0

    # Check if continuing training from a checkpoint
    if cfg.model_name_or_path and os.path.exists(cfg.model_name_or_path):
        try:
            # set global_step to gobal_step of last saved checkpoint from model path
            checkpoint_suffix = cfg.model_name_or_path.split("-")[-1].split("/")[0] # TODO: check this
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // len(train_dataloader)
            steps_trained_in_current_epoch = global_step % len(train_dataloader)

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    # Reset accumulate gradient
    model.zero_grad()
    # switch model to train mode
    model.train()

    # iteration by epoch animation
    epoch_generator = trange(epochs_trained, int(cfg.num_train_epochs), desc="Epoch")
    for _ in epoch_generator: # iterate for epoch
        train_generator = tqdm(train_dataloader, desc="Iteration")
        for step,batch in enumerate(train_generator): # iterate for batch

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue
            
            inputs, labels = batch
            # pass the next batch if the seuqence length is grether than 1024
            if inputs.shape[1] > 1024: continue

            # Load tensor to device
            inputs = inputs.to(cfg.device)
            labels = labels.to(cfg.device)

            # Compute prediction error
            ## Forward --> when labels is provided returns a tuple and the first postion is the loss (see documentation)
            outputs = model(inputs, labels=labels)
            loss = outputs[0]
            
            # Backpropagation
            ## limit parameters to avoid gradient exploding
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            ## Calculate gradient from current loss
            loss.backward()
            ## Update Neural Network parameters with current gradient
            optimizer.step()
            ## Update learning rate schedule
            scheduler.step()  
            ## Set cumulative gradient to zero
            model.zero_grad()
            
            global_step += 1
            tr_loss += loss.item()

            logging.info(f"{global_step}. loss ---------> {tr_loss/ global_step:.4f}")

            # Save model, tokenizer, optimizer, scheduler each cfg.save_steps
            if cfg.save_steps > 0 and global_step % cfg.save_steps == 0:
                checkpoint_prefix = "checkpoint"
                # Save model checkpoint
                output_dir = os.path.join(cfg.output_dir, "{}-{}".format(checkpoint_prefix, global_step))
                os.makedirs(output_dir, exist_ok=True)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)

                #torch.save(cfg, os.path.join(output_dir, "training_cfg.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                _rotate_checkpoints(cfg, checkpoint_prefix)

                torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s", output_dir)
            
            if cfg.max_steps > 0 and global_step > cfg.max_steps:
                train_generator.close()
                break
        if cfg.max_steps > 0 and global_step > cfg.max_steps:
            epoch_generator.close()
            break
    return global_step, tr_loss / global_step

In [6]:
def evaluate(cfg, model, tokenizer, val_dataset, prefix=""):
    # Loop to handle MNLI double evaluation (matched, mis-matched)
    eval_output_dir = cfg.output_dir
    os.makedirs(eval_output_dir, exist_ok=True)

    def padding_fn(examples):
        # Only uses for batch's in inference it's not necessary
        x,y = list(zip(*examples))
        x_pad = pad_sequence(x, batch_first=True, padding_value=tokenizer.eos_token_id)
        y_pad = pad_sequence(y, batch_first=True, padding_value=-100)
        return x_pad, y_pad

    # Build DataLoadet
    eval_dataloader = DataLoader(
                            dataset=val_dataset,
                            batch_size=cfg.train_batch_size,
                            shuffle=True,
                            collate_fn=padding_fn
                            )

    # Eval!
    logger.info("***** Running evaluation {} *****".format(prefix))

    eval_loss = 0.0
    nb_eval_steps = 0
    model.eval()

    for batch in tqdm(eval_dataloader, desc="Evaluating"):
        inputs, labels = batch
        inputs = inputs.to(cfg.device)
        labels = labels.to(cfg.device)

        with torch.no_grad():
            outputs = model(inputs, labels=labels)
            lm_loss = outputs[0]
            eval_loss += lm_loss.mean().item()
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    perplexity = torch.exp(torch.tensor(eval_loss))

    result = {"perplexity": perplexity}

    output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results {} *****".format(prefix))
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))
    return result


# Main

In [7]:
def main(cfg, df_train, df_val):
    # train from checkpoint
    if cfg.should_continue:
        sorted_checkpoints = _sorted_checkpoints(cfg)
        if len(sorted_checkpoints) == 0:
            raise ValueError("Used --should_continue but no checkpoint was found in --output_dir.")
        else:
            cfg.model_name_or_path = sorted_checkpoints[-1]

    # Setup CUDA, GPU
    setup_device_for_train()

    # print info about device
    logger.warning(f"Info device: {cfg.device}, n_gpu: {cfg.n_gpu}")

    # Set seed for reproducibility
    set_seed(cfg.seed)

    # Instance config, model and tokenizer
    ## Save model & tokenizer in {cfg.cached_dir} folder 
    config = AutoConfig.from_pretrained(cfg.config_name, cache_dir=cfg.cache_dir)
    tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_name, cache_dir=cfg.cache_dir)

    #model = AutoModelWithLMHead.from_pretrained( # AutoModelWithLMHead --> Language Model #AutoModelForCausalLM
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name_or_path,
        from_tf=False,
        config=config,
        cache_dir=cfg.cache_dir,
    )

    # setting special tokens (only if not checkpoints)
    #tokenizer, model = set_tokens(tokenizer,model,Model_version=None)

    # Load model to device
    model.to(cfg.device)

    if cfg.do_train:
        train_dataset = My_convertional_dataset(df_train, tokenizer, cfg)

        global_step, tr_loss = train(cfg, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best-practices: if you use save_pretrained for the model and tokenizer, you can reload them using from_pretrained()
    if cfg.do_train:
        # Create output directory if needed
        os.makedirs(cfg.output_dir, exist_ok=True)

        logger.info("Saving model checkpoint to %s", cfg.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(cfg.output_dir)
        tokenizer.save_pretrained(cfg.output_dir)

        # Load a trained model and vocabulary that you have fine-tuned
        model = AutoModelForCausalLM.from_pretrained(cfg.output_dir)
        tokenizer = AutoTokenizer.from_pretrained(cfg.output_dir)
        model.to(cfg.device)

        # Evaluation
        results = {}
        if cfg.do_eval:
            eval_dataset = My_convertional_dataset(df_val, tokenizer, cfg)
            checkpoints = [cfg.output_dir]
            if cfg.eval_all_checkpoints:
                checkpoints = list(
                    os.path.dirname(c) for c in sorted(glob.glob(cfg.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
                )
                logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
            logger.info("Evaluate the following checkpoints: %s", checkpoints)
            for checkpoint in checkpoints:
                global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
                prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

                model = AutoModelForCausalLM.from_pretrained(checkpoint)
                model.to(cfg.device)
                result = evaluate(cfg, model, tokenizer, eval_dataset, prefix=prefix)
                result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
                results.update(result)
    return results

In [8]:
# load and split dataset
PATH_DATASET = "Rick_dataset_preproced.csv"
df = pd.read_csv(PATH_DATASET)
df_train, df_val = train_test_split(df,test_size=0.1, random_state=cfg.seed)

cfg.max_steps = -1
cfg.should_continue = False
cfg.num_train_epochs = 20
cfg.output_dir = "Bot_Rick"
main(cfg, df_train, df_val)

10/05/2021 11:39:04 - INFO - __main__ -   Creating features from dataset file at cached/gpt2_cached_lm_1024
10/05/2021 11:39:05 - INFO - __main__ -   ***** Running training *****
10/05/2021 11:39:05 - INFO - __main__ -     Num examples = 374
10/05/2021 11:39:05 - INFO - __main__ -     Num Epochs = 20
10/05/2021 11:39:05 - INFO - __main__ -     Num steps by Epoch = 93


Epoch:   0%|          | 0/20 [00:00<?, ?it/s]

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:39:06 - INFO - root -   1. loss ---------> 8.3301
10/05/2021 11:39:06 - INFO - root -   2. loss ---------> 8.1997
10/05/2021 11:39:06 - INFO - root -   3. loss ---------> 8.0326
10/05/2021 11:39:06 - INFO - root -   4. loss ---------> 7.9416
10/05/2021 11:39:06 - INFO - root -   5. loss ---------> 7.8512
10/05/2021 11:39:07 - INFO - root -   6. loss ---------> 7.6588
10/05/2021 11:39:07 - INFO - root -   7. loss ---------> 7.5469
10/05/2021 11:39:07 - INFO - root -   8. loss ---------> 7.4836
10/05/2021 11:39:07 - INFO - root -   9. loss ---------> 7.3712
10/05/2021 11:39:07 - INFO - root -   10. loss ---------> 7.3064
10/05/2021 11:39:07 - INFO - root -   11. loss ---------> 7.2201
10/05/2021 11:39:08 - INFO - root -   12. loss ---------> 7.2453
10/05/2021 11:39:08 - INFO - root -   13. loss ---------> 7.3512
10/05/2021 11:39:08 - INFO - root -   14. loss ---------> 7.3999
10/05/2021 11:39:08 - INFO - root -   15. loss ---------> 7.3861
10/05/2021 11:39:08 - INFO - root 

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:39:22 - INFO - root -   95. loss ---------> 7.5391
10/05/2021 11:39:22 - INFO - root -   96. loss ---------> 7.5364
10/05/2021 11:39:22 - INFO - root -   97. loss ---------> 7.5361
10/05/2021 11:39:23 - INFO - root -   98. loss ---------> 7.5321
10/05/2021 11:39:23 - INFO - root -   99. loss ---------> 7.5225
10/05/2021 11:39:23 - INFO - root -   100. loss ---------> 7.5220
10/05/2021 11:39:23 - INFO - root -   101. loss ---------> 7.5176
10/05/2021 11:39:23 - INFO - root -   102. loss ---------> 7.5413
10/05/2021 11:39:23 - INFO - root -   103. loss ---------> 7.5380
10/05/2021 11:39:23 - INFO - root -   104. loss ---------> 7.5445
10/05/2021 11:39:24 - INFO - root -   105. loss ---------> 7.5431
10/05/2021 11:39:24 - INFO - root -   106. loss ---------> 7.5502
10/05/2021 11:39:24 - INFO - root -   107. loss ---------> 7.5522
10/05/2021 11:39:24 - INFO - root -   108. loss ---------> 7.5478
10/05/2021 11:39:24 - INFO - root -   109. loss ---------> 7.5416
10/05/2021 11:3

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:39:39 - INFO - root -   189. loss ---------> 7.5150
10/05/2021 11:39:39 - INFO - root -   190. loss ---------> 7.5476
10/05/2021 11:39:39 - INFO - root -   191. loss ---------> 7.5488
10/05/2021 11:39:39 - INFO - root -   192. loss ---------> 7.5487
10/05/2021 11:39:39 - INFO - root -   193. loss ---------> 7.5419
10/05/2021 11:39:39 - INFO - root -   194. loss ---------> 7.5401
10/05/2021 11:39:40 - INFO - root -   195. loss ---------> 7.5447
10/05/2021 11:39:40 - INFO - root -   196. loss ---------> 7.5447
10/05/2021 11:39:40 - INFO - root -   197. loss ---------> 7.5410
10/05/2021 11:39:40 - INFO - root -   198. loss ---------> 7.5376
10/05/2021 11:39:40 - INFO - root -   199. loss ---------> 7.5340
10/05/2021 11:39:40 - INFO - root -   200. loss ---------> 7.5270
10/05/2021 11:39:41 - INFO - root -   201. loss ---------> 7.5250
10/05/2021 11:39:41 - INFO - root -   202. loss ---------> 7.5248
10/05/2021 11:39:41 - INFO - root -   203. loss ---------> 7.5232
10/05/2021

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:39:55 - INFO - root -   283. loss ---------> 7.5153
10/05/2021 11:39:55 - INFO - root -   284. loss ---------> 7.5149
10/05/2021 11:39:55 - INFO - root -   285. loss ---------> 7.5139
10/05/2021 11:39:55 - INFO - root -   286. loss ---------> 7.5154
10/05/2021 11:39:56 - INFO - root -   287. loss ---------> 7.5159
10/05/2021 11:39:56 - INFO - root -   288. loss ---------> 7.5159
10/05/2021 11:39:56 - INFO - root -   289. loss ---------> 7.5121
10/05/2021 11:39:56 - INFO - root -   290. loss ---------> 7.5095
10/05/2021 11:39:56 - INFO - root -   291. loss ---------> 7.5073
10/05/2021 11:39:56 - INFO - root -   292. loss ---------> 7.5068
10/05/2021 11:39:57 - INFO - root -   293. loss ---------> 7.5060
10/05/2021 11:39:57 - INFO - root -   294. loss ---------> 7.5052
10/05/2021 11:39:57 - INFO - root -   295. loss ---------> 7.5073
10/05/2021 11:39:57 - INFO - root -   296. loss ---------> 7.5147
10/05/2021 11:39:57 - INFO - root -   297. loss ---------> 7.5130
10/05/2021

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:40:11 - INFO - root -   377. loss ---------> 7.5033
10/05/2021 11:40:12 - INFO - root -   378. loss ---------> 7.5029
10/05/2021 11:40:12 - INFO - root -   379. loss ---------> 7.5046
10/05/2021 11:40:12 - INFO - root -   380. loss ---------> 7.5038
10/05/2021 11:40:12 - INFO - root -   381. loss ---------> 7.5042
10/05/2021 11:40:12 - INFO - root -   382. loss ---------> 7.5037
10/05/2021 11:40:13 - INFO - root -   383. loss ---------> 7.5039
10/05/2021 11:40:13 - INFO - root -   384. loss ---------> 7.5019
10/05/2021 11:40:13 - INFO - root -   385. loss ---------> 7.5065
10/05/2021 11:40:13 - INFO - root -   386. loss ---------> 7.5060
10/05/2021 11:40:13 - INFO - root -   387. loss ---------> 7.5065
10/05/2021 11:40:13 - INFO - root -   388. loss ---------> 7.5056
10/05/2021 11:40:14 - INFO - root -   389. loss ---------> 7.5067
10/05/2021 11:40:14 - INFO - root -   390. loss ---------> 7.5059
10/05/2021 11:40:14 - INFO - root -   391. loss ---------> 7.5062
10/05/2021

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:40:28 - INFO - root -   471. loss ---------> 7.5037
10/05/2021 11:40:28 - INFO - root -   472. loss ---------> 7.5016
10/05/2021 11:40:28 - INFO - root -   473. loss ---------> 7.5018
10/05/2021 11:40:29 - INFO - root -   474. loss ---------> 7.5018
10/05/2021 11:40:29 - INFO - root -   475. loss ---------> 7.5014
10/05/2021 11:40:29 - INFO - root -   476. loss ---------> 7.4991
10/05/2021 11:40:29 - INFO - root -   477. loss ---------> 7.5003
10/05/2021 11:40:29 - INFO - root -   478. loss ---------> 7.5009
10/05/2021 11:40:30 - INFO - root -   479. loss ---------> 7.5000
10/05/2021 11:40:30 - INFO - root -   480. loss ---------> 7.5027
10/05/2021 11:40:30 - INFO - root -   481. loss ---------> 7.5055
10/05/2021 11:40:30 - INFO - root -   482. loss ---------> 7.5040
10/05/2021 11:40:30 - INFO - root -   483. loss ---------> 7.5046
10/05/2021 11:40:30 - INFO - root -   484. loss ---------> 7.5037
10/05/2021 11:40:31 - INFO - root -   485. loss ---------> 7.5021
10/05/2021

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:40:45 - INFO - root -   565. loss ---------> 7.4923
10/05/2021 11:40:45 - INFO - root -   566. loss ---------> 7.4915
10/05/2021 11:40:45 - INFO - root -   567. loss ---------> 7.4913
10/05/2021 11:40:45 - INFO - root -   568. loss ---------> 7.4909
10/05/2021 11:40:45 - INFO - root -   569. loss ---------> 7.4912
10/05/2021 11:40:46 - INFO - root -   570. loss ---------> 7.4929
10/05/2021 11:40:46 - INFO - root -   571. loss ---------> 7.4939
10/05/2021 11:40:46 - INFO - root -   572. loss ---------> 7.4947
10/05/2021 11:40:46 - INFO - root -   573. loss ---------> 7.4964
10/05/2021 11:40:47 - INFO - root -   574. loss ---------> 7.4954
10/05/2021 11:40:47 - INFO - root -   575. loss ---------> 7.4964
10/05/2021 11:40:47 - INFO - root -   576. loss ---------> 7.4967
10/05/2021 11:40:47 - INFO - root -   577. loss ---------> 7.4972
10/05/2021 11:40:47 - INFO - root -   578. loss ---------> 7.4981
10/05/2021 11:40:47 - INFO - root -   579. loss ---------> 7.4969
10/05/2021

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:41:01 - INFO - root -   659. loss ---------> 7.4900
10/05/2021 11:41:01 - INFO - root -   660. loss ---------> 7.4910
10/05/2021 11:41:02 - INFO - root -   661. loss ---------> 7.4900
10/05/2021 11:41:02 - INFO - root -   662. loss ---------> 7.4918
10/05/2021 11:41:02 - INFO - root -   663. loss ---------> 7.4931
10/05/2021 11:41:02 - INFO - root -   664. loss ---------> 7.4933
10/05/2021 11:41:02 - INFO - root -   665. loss ---------> 7.4919
10/05/2021 11:41:02 - INFO - root -   666. loss ---------> 7.4914
10/05/2021 11:41:03 - INFO - root -   667. loss ---------> 7.4917
10/05/2021 11:41:03 - INFO - root -   668. loss ---------> 7.4911
10/05/2021 11:41:03 - INFO - root -   669. loss ---------> 7.4913
10/05/2021 11:41:03 - INFO - root -   670. loss ---------> 7.4921
10/05/2021 11:41:03 - INFO - root -   671. loss ---------> 7.4915
10/05/2021 11:41:04 - INFO - root -   672. loss ---------> 7.4896
10/05/2021 11:41:04 - INFO - root -   673. loss ---------> 7.4892
10/05/2021

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:41:18 - INFO - root -   753. loss ---------> 7.4967
10/05/2021 11:41:18 - INFO - root -   754. loss ---------> 7.4965
10/05/2021 11:41:18 - INFO - root -   755. loss ---------> 7.4956
10/05/2021 11:41:18 - INFO - root -   756. loss ---------> 7.4951
10/05/2021 11:41:18 - INFO - root -   757. loss ---------> 7.4955
10/05/2021 11:41:18 - INFO - root -   758. loss ---------> 7.4961
10/05/2021 11:41:19 - INFO - root -   759. loss ---------> 7.4958
10/05/2021 11:41:19 - INFO - root -   760. loss ---------> 7.4956
10/05/2021 11:41:19 - INFO - root -   761. loss ---------> 7.4989
10/05/2021 11:41:19 - INFO - root -   762. loss ---------> 7.4981
10/05/2021 11:41:19 - INFO - root -   763. loss ---------> 7.4980
10/05/2021 11:41:20 - INFO - root -   764. loss ---------> 7.4983
10/05/2021 11:41:20 - INFO - root -   765. loss ---------> 7.4994
10/05/2021 11:41:20 - INFO - root -   766. loss ---------> 7.4997
10/05/2021 11:41:20 - INFO - root -   767. loss ---------> 7.4982
10/05/2021

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:41:34 - INFO - root -   847. loss ---------> 7.4903
10/05/2021 11:41:34 - INFO - root -   848. loss ---------> 7.4899
10/05/2021 11:41:34 - INFO - root -   849. loss ---------> 7.4912
10/05/2021 11:41:35 - INFO - root -   850. loss ---------> 7.4908
10/05/2021 11:41:35 - INFO - root -   851. loss ---------> 7.4958
10/05/2021 11:41:35 - INFO - root -   852. loss ---------> 7.4961
10/05/2021 11:41:35 - INFO - root -   853. loss ---------> 7.4956
10/05/2021 11:41:35 - INFO - root -   854. loss ---------> 7.4958
10/05/2021 11:41:36 - INFO - root -   855. loss ---------> 7.4958
10/05/2021 11:41:36 - INFO - root -   856. loss ---------> 7.4961
10/05/2021 11:41:36 - INFO - root -   857. loss ---------> 7.4957
10/05/2021 11:41:36 - INFO - root -   858. loss ---------> 7.4955
10/05/2021 11:41:36 - INFO - root -   859. loss ---------> 7.4965
10/05/2021 11:41:37 - INFO - root -   860. loss ---------> 7.4965
10/05/2021 11:41:37 - INFO - root -   861. loss ---------> 7.4964
10/05/2021

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:41:51 - INFO - root -   941. loss ---------> 7.4915
10/05/2021 11:41:51 - INFO - root -   942. loss ---------> 7.4917
10/05/2021 11:41:51 - INFO - root -   943. loss ---------> 7.4917
10/05/2021 11:41:51 - INFO - root -   944. loss ---------> 7.4921
10/05/2021 11:41:52 - INFO - root -   945. loss ---------> 7.4933
10/05/2021 11:41:52 - INFO - root -   946. loss ---------> 7.4924
10/05/2021 11:41:52 - INFO - root -   947. loss ---------> 7.4924
10/05/2021 11:41:52 - INFO - root -   948. loss ---------> 7.4941
10/05/2021 11:41:52 - INFO - root -   949. loss ---------> 7.4939
10/05/2021 11:41:53 - INFO - root -   950. loss ---------> 7.4935
10/05/2021 11:41:53 - INFO - root -   951. loss ---------> 7.4935
10/05/2021 11:41:53 - INFO - root -   952. loss ---------> 7.4927
10/05/2021 11:41:53 - INFO - root -   953. loss ---------> 7.4920
10/05/2021 11:41:53 - INFO - root -   954. loss ---------> 7.4916
10/05/2021 11:41:53 - INFO - root -   955. loss ---------> 7.4921
10/05/2021

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:42:07 - INFO - root -   1035. loss ---------> 7.4911
10/05/2021 11:42:08 - INFO - root -   1036. loss ---------> 7.4910
10/05/2021 11:42:08 - INFO - root -   1037. loss ---------> 7.4906
10/05/2021 11:42:08 - INFO - root -   1038. loss ---------> 7.4915
10/05/2021 11:42:08 - INFO - root -   1039. loss ---------> 7.4913
10/05/2021 11:42:08 - INFO - root -   1040. loss ---------> 7.4917
10/05/2021 11:42:08 - INFO - root -   1041. loss ---------> 7.4925
10/05/2021 11:42:09 - INFO - root -   1042. loss ---------> 7.4919
10/05/2021 11:42:09 - INFO - root -   1043. loss ---------> 7.4914
10/05/2021 11:42:09 - INFO - root -   1044. loss ---------> 7.4922
10/05/2021 11:42:09 - INFO - root -   1045. loss ---------> 7.4925
10/05/2021 11:42:09 - INFO - root -   1046. loss ---------> 7.4926
10/05/2021 11:42:10 - INFO - root -   1047. loss ---------> 7.4930
10/05/2021 11:42:10 - INFO - root -   1048. loss ---------> 7.4928
10/05/2021 11:42:10 - INFO - root -   1049. loss ---------> 7.

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:42:24 - INFO - root -   1129. loss ---------> 7.4933
10/05/2021 11:42:24 - INFO - root -   1130. loss ---------> 7.4925
10/05/2021 11:42:24 - INFO - root -   1131. loss ---------> 7.4940
10/05/2021 11:42:25 - INFO - root -   1132. loss ---------> 7.4940
10/05/2021 11:42:25 - INFO - root -   1133. loss ---------> 7.4952
10/05/2021 11:42:25 - INFO - root -   1134. loss ---------> 7.4947
10/05/2021 11:42:25 - INFO - root -   1135. loss ---------> 7.4948
10/05/2021 11:42:25 - INFO - root -   1136. loss ---------> 7.4954
10/05/2021 11:42:25 - INFO - root -   1137. loss ---------> 7.4950
10/05/2021 11:42:26 - INFO - root -   1138. loss ---------> 7.4944
10/05/2021 11:42:26 - INFO - root -   1139. loss ---------> 7.4938
10/05/2021 11:42:26 - INFO - root -   1140. loss ---------> 7.4933
10/05/2021 11:42:26 - INFO - root -   1141. loss ---------> 7.4931
10/05/2021 11:42:26 - INFO - root -   1142. loss ---------> 7.4927
10/05/2021 11:42:27 - INFO - root -   1143. loss ---------> 7.

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:42:40 - INFO - root -   1223. loss ---------> 7.4921
10/05/2021 11:42:41 - INFO - root -   1224. loss ---------> 7.4959
10/05/2021 11:42:41 - INFO - root -   1225. loss ---------> 7.4957
10/05/2021 11:42:41 - INFO - root -   1226. loss ---------> 7.4956
10/05/2021 11:42:41 - INFO - root -   1227. loss ---------> 7.4958
10/05/2021 11:42:41 - INFO - root -   1228. loss ---------> 7.4956
10/05/2021 11:42:42 - INFO - root -   1229. loss ---------> 7.4953
10/05/2021 11:42:42 - INFO - root -   1230. loss ---------> 7.4958
10/05/2021 11:42:42 - INFO - root -   1231. loss ---------> 7.4958
10/05/2021 11:42:42 - INFO - root -   1232. loss ---------> 7.4948
10/05/2021 11:42:42 - INFO - root -   1233. loss ---------> 7.4952
10/05/2021 11:42:42 - INFO - root -   1234. loss ---------> 7.4953
10/05/2021 11:42:43 - INFO - root -   1235. loss ---------> 7.4951
10/05/2021 11:42:43 - INFO - root -   1236. loss ---------> 7.4938
10/05/2021 11:42:43 - INFO - root -   1237. loss ---------> 7.

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:42:57 - INFO - root -   1317. loss ---------> 7.4902
10/05/2021 11:42:57 - INFO - root -   1318. loss ---------> 7.4900
10/05/2021 11:42:58 - INFO - root -   1319. loss ---------> 7.4898
10/05/2021 11:42:58 - INFO - root -   1320. loss ---------> 7.4906
10/05/2021 11:42:58 - INFO - root -   1321. loss ---------> 7.4904
10/05/2021 11:42:58 - INFO - root -   1322. loss ---------> 7.4909
10/05/2021 11:42:58 - INFO - root -   1323. loss ---------> 7.4910
10/05/2021 11:42:58 - INFO - root -   1324. loss ---------> 7.4915
10/05/2021 11:42:59 - INFO - root -   1325. loss ---------> 7.4922
10/05/2021 11:42:59 - INFO - root -   1326. loss ---------> 7.4917
10/05/2021 11:42:59 - INFO - root -   1327. loss ---------> 7.4912
10/05/2021 11:42:59 - INFO - root -   1328. loss ---------> 7.4910
10/05/2021 11:42:59 - INFO - root -   1329. loss ---------> 7.4910
10/05/2021 11:42:59 - INFO - root -   1330. loss ---------> 7.4903
10/05/2021 11:43:00 - INFO - root -   1331. loss ---------> 7.

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:43:14 - INFO - root -   1411. loss ---------> 7.4923
10/05/2021 11:43:14 - INFO - root -   1412. loss ---------> 7.4929
10/05/2021 11:43:14 - INFO - root -   1413. loss ---------> 7.4936
10/05/2021 11:43:14 - INFO - root -   1414. loss ---------> 7.4933
10/05/2021 11:43:14 - INFO - root -   1415. loss ---------> 7.4928
10/05/2021 11:43:15 - INFO - root -   1416. loss ---------> 7.4943
10/05/2021 11:43:15 - INFO - root -   1417. loss ---------> 7.4947
10/05/2021 11:43:15 - INFO - root -   1418. loss ---------> 7.4947
10/05/2021 11:43:15 - INFO - root -   1419. loss ---------> 7.4945
10/05/2021 11:43:15 - INFO - root -   1420. loss ---------> 7.4949
10/05/2021 11:43:16 - INFO - root -   1421. loss ---------> 7.4947
10/05/2021 11:43:16 - INFO - root -   1422. loss ---------> 7.4942
10/05/2021 11:43:16 - INFO - root -   1423. loss ---------> 7.4938
10/05/2021 11:43:16 - INFO - root -   1424. loss ---------> 7.4940
10/05/2021 11:43:16 - INFO - root -   1425. loss ---------> 7.

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:43:30 - INFO - root -   1505. loss ---------> 7.4907
10/05/2021 11:43:30 - INFO - root -   1506. loss ---------> 7.4904
10/05/2021 11:43:31 - INFO - root -   1507. loss ---------> 7.4907
10/05/2021 11:43:31 - INFO - root -   1508. loss ---------> 7.4905
10/05/2021 11:43:31 - INFO - root -   1509. loss ---------> 7.4904
10/05/2021 11:43:31 - INFO - root -   1510. loss ---------> 7.4899
10/05/2021 11:43:31 - INFO - root -   1511. loss ---------> 7.4908
10/05/2021 11:43:32 - INFO - root -   1512. loss ---------> 7.4913
10/05/2021 11:43:32 - INFO - root -   1513. loss ---------> 7.4908
10/05/2021 11:43:32 - INFO - root -   1514. loss ---------> 7.4900
10/05/2021 11:43:32 - INFO - root -   1515. loss ---------> 7.4899
10/05/2021 11:43:32 - INFO - root -   1516. loss ---------> 7.4899
10/05/2021 11:43:32 - INFO - root -   1517. loss ---------> 7.4890
10/05/2021 11:43:33 - INFO - root -   1518. loss ---------> 7.4889
10/05/2021 11:43:33 - INFO - root -   1519. loss ---------> 7.

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:43:47 - INFO - root -   1599. loss ---------> 7.4884
10/05/2021 11:43:47 - INFO - root -   1600. loss ---------> 7.4878
10/05/2021 11:43:47 - INFO - root -   1601. loss ---------> 7.4879
10/05/2021 11:43:47 - INFO - root -   1602. loss ---------> 7.4870
10/05/2021 11:43:47 - INFO - root -   1603. loss ---------> 7.4879
10/05/2021 11:43:47 - INFO - root -   1604. loss ---------> 7.4876
10/05/2021 11:43:48 - INFO - root -   1605. loss ---------> 7.4875
10/05/2021 11:43:48 - INFO - root -   1606. loss ---------> 7.4879
10/05/2021 11:43:48 - INFO - root -   1607. loss ---------> 7.4879
10/05/2021 11:43:48 - INFO - root -   1608. loss ---------> 7.4881
10/05/2021 11:43:48 - INFO - root -   1609. loss ---------> 7.4885
10/05/2021 11:43:49 - INFO - root -   1610. loss ---------> 7.4891
10/05/2021 11:43:49 - INFO - root -   1611. loss ---------> 7.4886
10/05/2021 11:43:49 - INFO - root -   1612. loss ---------> 7.4883
10/05/2021 11:43:49 - INFO - root -   1613. loss ---------> 7.

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:44:03 - INFO - root -   1693. loss ---------> 7.4853
10/05/2021 11:44:03 - INFO - root -   1694. loss ---------> 7.4848
10/05/2021 11:44:04 - INFO - root -   1695. loss ---------> 7.4851
10/05/2021 11:44:04 - INFO - root -   1696. loss ---------> 7.4849
10/05/2021 11:44:04 - INFO - root -   1697. loss ---------> 7.4850
10/05/2021 11:44:04 - INFO - root -   1698. loss ---------> 7.4847
10/05/2021 11:44:04 - INFO - root -   1699. loss ---------> 7.4841
10/05/2021 11:44:04 - INFO - root -   1700. loss ---------> 7.4853
10/05/2021 11:44:05 - INFO - root -   1701. loss ---------> 7.4851
10/05/2021 11:44:05 - INFO - root -   1702. loss ---------> 7.4850
10/05/2021 11:44:05 - INFO - root -   1703. loss ---------> 7.4850
10/05/2021 11:44:05 - INFO - root -   1704. loss ---------> 7.4855
10/05/2021 11:44:05 - INFO - root -   1705. loss ---------> 7.4853
10/05/2021 11:44:05 - INFO - root -   1706. loss ---------> 7.4855
10/05/2021 11:44:06 - INFO - root -   1707. loss ---------> 7.

Iteration:   0%|          | 0/94 [00:00<?, ?it/s]

10/05/2021 11:44:20 - INFO - root -   1787. loss ---------> 7.4842
10/05/2021 11:44:20 - INFO - root -   1788. loss ---------> 7.4840
10/05/2021 11:44:20 - INFO - root -   1789. loss ---------> 7.4837
10/05/2021 11:44:20 - INFO - root -   1790. loss ---------> 7.4842
10/05/2021 11:44:20 - INFO - root -   1791. loss ---------> 7.4837
10/05/2021 11:44:21 - INFO - root -   1792. loss ---------> 7.4837
10/05/2021 11:44:21 - INFO - root -   1793. loss ---------> 7.4832
10/05/2021 11:44:21 - INFO - root -   1794. loss ---------> 7.4826
10/05/2021 11:44:21 - INFO - root -   1795. loss ---------> 7.4830
10/05/2021 11:44:21 - INFO - root -   1796. loss ---------> 7.4833
10/05/2021 11:44:21 - INFO - root -   1797. loss ---------> 7.4835
10/05/2021 11:44:22 - INFO - root -   1798. loss ---------> 7.4833
10/05/2021 11:44:22 - INFO - root -   1799. loss ---------> 7.4830
10/05/2021 11:44:22 - INFO - root -   1800. loss ---------> 7.4827
10/05/2021 11:44:22 - INFO - root -   1801. loss ---------> 7.

Evaluating:   0%|          | 0/11 [00:00<?, ?it/s]

10/05/2021 11:44:45 - INFO - __main__ -   ***** Eval results  *****
10/05/2021 11:44:45 - INFO - __main__ -     perplexity = tensor(572.6240)


{'perplexity_': tensor(572.6240)}

loss = 6.359944536423516 <br>
perplexity = 255.2563

# Inference

In [9]:
tokenizer = AutoTokenizer.from_pretrained(cfg.output_dir)
model = AutoModelForCausalLM.from_pretrained(cfg.output_dir)

# Let's chat for 5 lines
# encode the new user input, add the eos_token and return a tensor in Pytorch
# "Hello Rick, how are you?" "Where is Morty"
new_user_input_ids = tokenizer.encode("Hello Rick, how are you?" + tokenizer.eos_token, return_tensors='pt')
# print(new_user_input_ids)

# append the new user input tokens to the chat history
bot_input_ids = new_user_input_ids

# generated a response while limiting the total chat history to 1000 tokens, 
chat_history_ids = model.generate(
    bot_input_ids, max_length=200,
    pad_token_id=tokenizer.eos_token_id,
    no_repeat_ngram_size=3,       
)
# pretty print last ouput tokens from bot
print("1.RickBot: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))


# generated a response while limiting the total chat history to 1000 tokens, 
chat_history_ids = model.generate(
    bot_input_ids, max_length=200,
    pad_token_id=tokenizer.eos_token_id,
    no_repeat_ngram_size=3,       
    num_beams=5,
    #early_stopping=True
)
# pretty print last ouput tokens from bot
print("2.RickBot: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))

# generated a response while limiting the total chat history to 1000 tokens, 
chat_history_ids = model.generate(
    bot_input_ids, max_length=200,
    pad_token_id=tokenizer.eos_token_id,
    no_repeat_ngram_size=3,       
    do_sample=True, 
    top_k=100, 
    top_p=0.7,
    temperature = 0.4,
    #num_beams=5,
    #early_stopping=True
)
# pretty print last ouput tokens from bot
print("3.RickBot: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True))) 


# generated a response while limiting the total chat history to 1000 tokens, 
chat_history_ids = model.generate(
    bot_input_ids, max_length=200,
    pad_token_id=tokenizer.eos_token_id,
    no_repeat_ngram_size=3,       
    do_sample=True, 
    temperature = 0.1,
)
# pretty print last ouput tokens from bot
print("4.RickBot: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True))) 


1.RickBot: I'm a little late to the party, but I'm here to say hi.
2.RickBot: I'm Rick. Nice to meet you.
3.RickBot: I'm a little late to the party, but I'm here for the party.
4.RickBot: I'm a guy and I'm a girl.


In [11]:
tokenizer = AutoTokenizer.from_pretrained(cfg.output_dir)
model = AutoModelForCausalLM.from_pretrained(cfg.output_dir)

# Let's chat for 5 lines
for step in range(7):
    # encode the new user input, add the eos_token and return a tensor in Pytorch
    new_user_input_ids = tokenizer.encode(input(">> User: ") + tokenizer.eos_token, return_tensors='pt')
    # print(new_user_input_ids)

    # append the new user input tokens to the chat history
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

    # generated a response while limiting the total chat history to 1000 tokens, 
    chat_history_ids = model.generate(
        bot_input_ids, max_length=200,
        pad_token_id=tokenizer.eos_token_id,
        no_repeat_ngram_size=3,       
        do_sample=True, 
        top_k=100, 
        top_p=0.7,
        temperature = 0.7,
        #num_beams=5,
        #early_stopping=True
    )
    
    # pretty print last ouput tokens from bot
    print("RickBot: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))

>> User: pickle Rick!!
RickBot: I've got a jar of pickle rick in the trunk of my car.
>> User: Are you pckle Rick?
RickBot: I'm a pickle Rick
>> User: hey pickle Rick where is pickle Morty
RickBot: I like to think of it as a Rick and Morty reference.
>> User: yes it is
RickBot: I was just thinking of that.
>> User: nice thinking
RickBot: I think it's a Rick And Morty reference


KeyboardInterrupt: ignored

In [None]:
df_train