In [1]:
%load_ext autoreload
%autoreload 2

from dataset import DialogLMDataset
from torch.utils.data import DataLoader
from collate_fns import DialogCollate, DialogCollateExperimental
from models import NegativePositiveTrainingLM
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler


import torch
from transformers import AutoConfig, AutoTokenizer
from sklearn.model_selection import train_test_split

# utils
from data_utils import InfiniteDataGen
from logging_utils import create_json, update_json, load_json

import os
import json
import numpy as np
import pandas as pd
pd.set_option('display.max_colwidth', None)
import datetime

In [2]:
# # data leakage version

# # read data
# all_rick = pd.read_csv('data/rick_and_morty_conversations/RickAndMortyScripts.csv')
# all_rick.head(10)

# # contextualize
# contexted = []
# n = 7
# for i in range(n, len(all_rick['line'])):
#     row = []
#     prev = i - 1 - n # we additionally subtract 1, so row will contain current response and 7 previous responses  
#     for j in range(i, prev, -1):
#         row.append(all_rick['line'][j])
#     contexted.append(row)
# contexted_context_joined = {
#     'context': [],
#     'response': []
# }
# for row in contexted:
#     response = row[0]
#     context_joined = ' EOS '.join(row[1:][::-1])
#     contexted_context_joined['context'].append(context_joined)
#     contexted_context_joined['response'].append(response)
# df = pd.DataFrame(contexted_context_joined)
# df.head(5)

# # split
# df_train, df_val = train_test_split(df, test_size = 0.1, random_state=42)

In [3]:
# no data leakage

# read data
all_rick = pd.read_csv('data/rick_and_morty_conversations/RickAndMortyScripts.csv')
all_rick.head(10)

all_rick_train, all_rick_val = all_rick.iloc[:int(0.8*len(all_rick))], all_rick.iloc[int(0.8*len(all_rick)):]
# all_rick_train, all_rick_val = train_test_split(all_rick, test_size=0.2, random_state=42)
all_rick_train = all_rick_train.reset_index()
all_rick_val   = all_rick_val.reset_index()


# construct df with prompts and responses as columns
contexted = []
n = 7
for i in range(n, len(all_rick_train['line'])):
    row = []
    prev = i - 1 - n # we additionally subtract 1, so row will contain current response and 7 previous responses  
    for j in range(i, prev, -1):
        row.append(all_rick_train['line'][j])
    contexted.append(row)
contexted_context_joined = {
    'context': [],
    'response': []
}
for row in contexted:
    response = row[0]
    context_joined = ' EOS '.join(row[1:][::-1])
    contexted_context_joined['context'].append(context_joined)
    contexted_context_joined['response'].append(response)
df_train = pd.DataFrame(contexted_context_joined)
df_train.head(5)

contexted = []
n = 7
for i in range(n, len(all_rick_val['line'])):
    row = []
    prev = i - 1 - n # we additionally subtract 1, so row will contain current response and 7 previous responses  
    for j in range(i, prev, -1):
        row.append(all_rick_val['line'][j])
    contexted.append(row)
contexted_context_joined = {
    'context': [],
    'response': []
}
for row in contexted:
    response = row[0]
    context_joined = ' EOS '.join(row[1:][::-1])
    contexted_context_joined['context'].append(context_joined)
    contexted_context_joined['response'].append(response)
df_val = pd.DataFrame(contexted_context_joined)
df_val.head(5)


Unnamed: 0,context,response
0,"This article says the reason we weren't involved was... ""personality conflicts"". EOS Don't worry, Morty, they love you. Superheroes need a wide-eyed unremarkable to tag along and react to everything like it's mind-blowing. EOS I... think the personality conflict might have been... you. EOS Jesus... How awesome is that? I mean, they wanted to not need me so bad, they murdered three innocent heroes of color, and they still had to bring me back? EOS Rick, since it's my adventure and all, could you do me a favor? EOS Uh, the adventure is the favor, Morty. Me sleeping on these linens is the favor. I mean, w-w-w-what--what are we vindicating? Comfort? EOS Rick, this really bums me out. It-It's embarrassing to find out these guys don't like us.","Why? Morty, I defeat gagoos more powerful than these guys every week."
1,"Don't worry, Morty, they love you. Superheroes need a wide-eyed unremarkable to tag along and react to everything like it's mind-blowing. EOS I... think the personality conflict might have been... you. EOS Jesus... How awesome is that? I mean, they wanted to not need me so bad, they murdered three innocent heroes of color, and they still had to bring me back? EOS Rick, since it's my adventure and all, could you do me a favor? EOS Uh, the adventure is the favor, Morty. Me sleeping on these linens is the favor. I mean, w-w-w-what--what are we vindicating? Comfort? EOS Rick, this really bums me out. It-It's embarrassing to find out these guys don't like us. EOS Why? Morty, I defeat gagoos more powerful than these guys every week.","Yeah, but not heroes."
2,"I... think the personality conflict might have been... you. EOS Jesus... How awesome is that? I mean, they wanted to not need me so bad, they murdered three innocent heroes of color, and they still had to bring me back? EOS Rick, since it's my adventure and all, could you do me a favor? EOS Uh, the adventure is the favor, Morty. Me sleeping on these linens is the favor. I mean, w-w-w-what--what are we vindicating? Comfort? EOS Rick, this really bums me out. It-It's embarrassing to find out these guys don't like us. EOS Why? Morty, I defeat gagoos more powerful than these guys every week. EOS Yeah, but not heroes.","Oh, please. They just call themselves heroes so they can..."
3,"Jesus... How awesome is that? I mean, they wanted to not need me so bad, they murdered three innocent heroes of color, and they still had to bring me back? EOS Rick, since it's my adventure and all, could you do me a favor? EOS Uh, the adventure is the favor, Morty. Me sleeping on these linens is the favor. I mean, w-w-w-what--what are we vindicating? Comfort? EOS Rick, this really bums me out. It-It's embarrassing to find out these guys don't like us. EOS Why? Morty, I defeat gagoos more powerful than these guys every week. EOS Yeah, but not heroes. EOS Oh, please. They just call themselves heroes so they can...","I'm calling them that, Rick! They're my heroes! Mine!"
4,"Rick, since it's my adventure and all, could you do me a favor? EOS Uh, the adventure is the favor, Morty. Me sleeping on these linens is the favor. I mean, w-w-w-what--what are we vindicating? Comfort? EOS Rick, this really bums me out. It-It's embarrassing to find out these guys don't like us. EOS Why? Morty, I defeat gagoos more powerful than these guys every week. EOS Yeah, but not heroes. EOS Oh, please. They just call themselves heroes so they can... EOS I'm calling them that, Rick! They're my heroes! Mine!",Huh... no accounting for taste. I'm gonna go get a drink.


In [4]:
# tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    'microsoft/DialoGPT-small'
)

# train sets
dataset_train = DialogLMDataset(
    df = df_train,
    src_col = 'context',
    src_eos_token = ' EOS ',
    tokenizer = tokenizer,
    tgt_col = 'response',
)

# val sets
dataset_val = DialogLMDataset(
    df = df_val,
    src_col = 'context',
    src_eos_token = ' EOS ',
    tokenizer = tokenizer,
    tgt_col = 'response',
)
print()
print(dataset_val[3])
tokenizer.decode(dataset_val[-1]['src'])

Processing row: 0 of 3741517
{'src': [28219, 986, 1374, 7427, 318, 326, 30, 314, 1612, 11, 484, 2227, 284, 407, 761, 502, 523, 2089, 11, 484, 12864, 1115, 10218, 10281, 286, 3124, 11, 290, 484, 991, 550, 284, 2222, 502, 736, 30, 50256, 33048, 11, 1201, 340, 338, 616, 8855, 290, 477, 11, 714, 345, 466, 502, 257, 2661, 30, 50256, 34653, 11, 262, 8855, 318, 262, 2661, 11, 30395, 13, 2185, 11029, 319, 777, 9493, 641, 318, 262, 2661, 13, 314, 1612, 11, 266, 12, 86, 12, 86, 12, 10919, 438, 10919, 389, 356, 29178, 12364, 30, 45769, 30, 50256, 33048, 11, 428, 1107, 275, 5700, 502, 503, 13, 632, 12, 1026, 338, 18997, 284, 1064, 503, 777, 3730, 836, 470, 588, 514, 13, 50256, 5195, 30, 30395, 11, 220, 314, 7433, 308, 3839, 418, 517, 3665, 621, 777, 3730, 790, 1285, 13, 50256, 10995, 11, 475, 407, 10281, 13, 50256, 5812, 11, 3387, 13, 1119, 655, 869, 2405, 10281, 523, 484, 460, 986], 'tgt': [40, 1101, 4585, 606, 326, 11, 8759, 0, 1119, 821, 616, 10281, 0, 11517, 0]}


"I don't know, and I don't have to know. I've been fired.  Good luck, turds.<|endoftext|>Holy crap... Slick's wish came true.<|endoftext|>Whoa!! Hahaha, yeah! Atlantis, baby!<|endoftext|>That was amazing!<|endoftext|>Got some of that mermaid puss!<|endoftext|>I'm really hoping it wasn't a one-off thing and I can see her again. By the way, hey, um... still not curious about what might've happened at that crazy Citadel place?<|endoftext|>Pssh! Not at all, Morty. That place will never have any bearing over our lives ever again. Unlike that mermaid puss! Yeah!! We're going back for seconds! We're gonna do that shit every week, man! That was Atlantis!"

In [5]:
# example dataset and dataloader usage

# collate object
collate_fn = DialogCollateExperimental(
    tokenizer = tokenizer,
    max_len = 512,
    _targets_ignore_index = -100,
    _pad_token_id = tokenizer.eos_token_id  # 0
)

# dataloader
loader_train = DataLoader(
    dataset_train,
    batch_size = 2,
    shuffle = False,
    collate_fn = collate_fn
)
loader_val = DataLoader(
    dataset_val,
    batch_size = 2,
    shuffle = False,
    collate_fn = collate_fn
)

# example usage
for batch in loader_train:
    break

example_id = 1
print('input_ids:\n',       batch['input_ids'][example_id])
# print('position_ids:\n',    batch['position_ids'][example_id])
# print('token_type_ids:\n',  batch['token_type_ids'][example_id])
# print('target_ids:\n',      batch['target_ids'][example_id])
# print('attention_masks:\n', batch['attention_masks'][example_id])

tokenizer.decode(batch['input_ids'][example_id])

input_ids:
 tensor([ 2061,    11,  8759,    30,  1867,   447,   247,    82,  1016,   319,
           30, 50256,    40,  1392,   257,  5975,   329,   345,    11, 30395,
           13, 50256,  1026,   338,   262,  3504,   286,   262,  1755,    13,
         1867,   389,   345,  3375,   546,    30, 50256, 16773,   319,    11,
          314,  1392,   257,  5975,   329,   345,    13,   220,  7911,   319,
           11, 23290,   510,    13, 50256,    46,    86,     0, 11960,     0,
          921,   821, 27762,  2667,   502,  1165,  1327,     0, 50256,  1135,
        17753,   467,    11, 17753,   651,   503,  8326,   994,    11,  1282,
          319,    13, 11853,   257,  5975,   329,   345, 30395,    13, 50256,
         2061,   466,   345,   892,   286,   428,   986,  7348,  4038,    11,
        30395,    30,   314,  3170,   340,   503,  8326,  3404,   314,  1043,
          287,   262, 15591,    13, 50256, 10995,    11,  8759,   986,   314,
           12,   270,   338,  1049,    13,  1148,   

"What, Rick? What’s going on?<|endoftext|>I got a surprise for you, Morty.<|endoftext|>It's the middle of the night. What are you talking about?<|endoftext|>Come on, I got a surprise for you.  Come on, hurry up.<|endoftext|>Ow! Ow! You're tugging me too hard!<|endoftext|>We gotta go, gotta get outta here, come on. Got a surprise for you Morty.<|endoftext|>What do you think of this... flying vehicle, Morty? I built it outta stuff I found in the garage.<|endoftext|>Yeah, Rick... I-it's great. Is this the surprise?<|endoftext|>!!!!"

Explanation of target IDs: `-100` indicates that this position will not contribute to the loss. This behavior is specified by setting the `ignore_index` option in the `CrossEntropy` loss function, which by default takes the value `-100`. In order to decode the targets, we replace `-100`'s arbitrarily with `0`'s, and the ID `0` happens to correspond to the exclamation mark.

In [6]:
def train_one_positive_epoch(
    pos_loader, model_wrapper
):
    # loop through the adversarial, negative batches
    # repeat until all negative samples have low enough log probability under the model, 
    # or a maximum number of iterations has occurred
    batch_wise_pos_losses  = []
    for pos_batch_id, pos_batch in enumerate(pos_loader):
        
        model_wrapper.optimizer_pos.zero_grad()
        results = model_wrapper.positive_step_basic(pos_batch)
        batch_wise_pos_losses.append(results['loss'].item())
        
        print(f'\rProcessing pos batch {pos_batch_id} of {len(pos_loader)}, loss {results["loss"].item():.6f}', end='', flush=True)


    return {
        'batch_wise_pos_losses': batch_wise_pos_losses,
        'average_pos_loss': np.array(batch_wise_pos_losses).mean(),
    }

In [7]:
CONFIG = {
    'EXPERIMENT_NAME': 'test_dialogpt_rick_and_morty_basic',
    
    'PERFORM_NEGATIVE_TRAINING': False,
    
    'GRAD_ACCUM_STEPS': 1,
    'START_ITER': 0,
    'TRAIN_ITERS': 3,
    'POSITIVE_RATIO': 1,
    
    'LOGGING_DIR': './logs/',
    'MODEL_SAVE_DIR': './checkpoints/',
    
    'HUGGINGFACE_MODEL_NAME': 'microsoft/DialoGPT-small',  # 'microsoft/DialoGPT-medium'
    'POS_LR': 5e-5,
    'NEG_LR': 5e-5,
    'UPDATES_PER_BATCH': 20,
    
    'EXAMPLE_WEIGHT_MODE': 'decay',
    'EXAMPLE_WEIGHT_CARE_MODE': 'sample_avg',
    'EXAMPLE_WEIGHT_REJECTION_THRESHOLD': -7.0,
    
    'SEED': 0
}

In [20]:
# training
    
# preparing model checkpointing
if not os.path.exists(os.path.join(CONFIG['MODEL_SAVE_DIR'], CONFIG['EXPERIMENT_NAME'])): 
    os.makedirs(os.path.join(CONFIG['MODEL_SAVE_DIR'], CONFIG['EXPERIMENT_NAME']))

# data loaders
train_sampler = RandomSampler(dataset_train)
loader_train = DataLoader(
    dataset_train,
    sampler = train_sampler,
    batch_size = 2,
    # shuffle = True,
    collate_fn = collate_fn,
    drop_last = True
)
eval_sampler = SequentialSampler(dataset_val)
loader_val = DataLoader(
    dataset_val,
    sampler = eval_sampler,
    batch_size = 2,
    # shuffle = False,
    collate_fn = collate_fn,
    drop_last = True
)

# model
model_config = AutoConfig.from_pretrained(CONFIG['HUGGINGFACE_MODEL_NAME'])
gradient_accumulation_steps = 1
model_wrapper = NegativePositiveTrainingLM(
    model_name = CONFIG['HUGGINGFACE_MODEL_NAME'],
    model_config = model_config,
    pos_lr = CONFIG['POS_LR'], 
    neg_lr = CONFIG['NEG_LR'],
    num_opt_steps = len(loader_train) // CONFIG['GRAD_ACCUM_STEPS'] * CONFIG['TRAIN_ITERS'],
    tokenizer = tokenizer,
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)

# evaluate before training
train_loss = model_wrapper.perplexity_on_dataset(loader_train)
val_loss   = model_wrapper.perplexity_on_dataset(loader_val)
print('\nTrain perplexity:', train_loss)
print('\nVal perplexity:', val_loss)

# reproducibility
np.random.seed(CONFIG['SEED'])
torch.manual_seed(CONFIG['SEED'])
torch.cuda.manual_seed_all(CONFIG['SEED'])



Evaluating batch: 0 of 758tensor(7.5369, device='cuda:0')
tensor(6.7540, device='cuda:0')
tensor(7.3920, device='cuda:0')
tensor(7.5851, device='cuda:0')
tensor(8.8062, device='cuda:0')
tensor(8.7754, device='cuda:0')
tensor(8.0506, device='cuda:0')
tensor(8.9373, device='cuda:0')
tensor(6.6687, device='cuda:0')
tensor(8.7962, device='cuda:0')
Evaluating batch: 10 of 758tensor(8.8260, device='cuda:0')
tensor(9.2716, device='cuda:0')
tensor(7.4334, device='cuda:0')
tensor(7.4631, device='cuda:0')
tensor(9.5139, device='cuda:0')
tensor(6.8227, device='cuda:0')
tensor(7.8557, device='cuda:0')
tensor(7.1340, device='cuda:0')
tensor(8.1026, device='cuda:0')
tensor(7.4078, device='cuda:0')
Evaluating batch: 20 of 758tensor(7.5020, device='cuda:0')
tensor(8.0904, device='cuda:0')
tensor(8.0355, device='cuda:0')
tensor(7.5116, device='cuda:0')
tensor(8.8476, device='cuda:0')
tensor(7.2681, device='cuda:0')
tensor(7.4217, device='cuda:0')
tensor(9.5003, device='cuda:0')
tensor(5.8660, device='c

tensor(7.3050, device='cuda:0')
tensor(7.0959, device='cuda:0')
Evaluating batch: 240 of 758tensor(8.8075, device='cuda:0')
tensor(7.3851, device='cuda:0')
tensor(6.3826, device='cuda:0')
tensor(6.2934, device='cuda:0')
tensor(7.4519, device='cuda:0')
tensor(7.2733, device='cuda:0')
tensor(8.4424, device='cuda:0')
tensor(8.0625, device='cuda:0')
tensor(8.4501, device='cuda:0')
tensor(5.8408, device='cuda:0')
Evaluating batch: 250 of 758tensor(8.4389, device='cuda:0')
tensor(8.8006, device='cuda:0')
tensor(7.4285, device='cuda:0')
tensor(6.7186, device='cuda:0')
tensor(7.5426, device='cuda:0')
tensor(7.9738, device='cuda:0')
tensor(6.8052, device='cuda:0')
tensor(7.4035, device='cuda:0')
tensor(7.7601, device='cuda:0')
tensor(7.7673, device='cuda:0')
Evaluating batch: 260 of 758tensor(6.5047, device='cuda:0')
tensor(8.3865, device='cuda:0')
tensor(7.7395, device='cuda:0')
tensor(9.0396, device='cuda:0')
tensor(7.6757, device='cuda:0')
tensor(7.9160, device='cuda:0')
tensor(7.6842, devic

tensor(7.4543, device='cuda:0')
tensor(9.0619, device='cuda:0')
tensor(8.4585, device='cuda:0')
tensor(6.9439, device='cuda:0')
tensor(8.0454, device='cuda:0')
tensor(8.1068, device='cuda:0')
tensor(6.9577, device='cuda:0')
Evaluating batch: 480 of 758tensor(7.8301, device='cuda:0')
tensor(7.7178, device='cuda:0')
tensor(7.6830, device='cuda:0')
tensor(8.5914, device='cuda:0')
tensor(7.3416, device='cuda:0')
tensor(6.7740, device='cuda:0')
tensor(8.6134, device='cuda:0')
tensor(8.7789, device='cuda:0')
tensor(7.2894, device='cuda:0')
tensor(8.8068, device='cuda:0')
Evaluating batch: 490 of 758tensor(8.3288, device='cuda:0')
tensor(7.2044, device='cuda:0')
tensor(8.2811, device='cuda:0')
tensor(9.2070, device='cuda:0')
tensor(8.0408, device='cuda:0')
tensor(7.9858, device='cuda:0')
tensor(7.7563, device='cuda:0')
tensor(5.8553, device='cuda:0')
tensor(7.9115, device='cuda:0')
tensor(6.9002, device='cuda:0')
Evaluating batch: 500 of 758tensor(7.2458, device='cuda:0')
tensor(8.3055, devic

Evaluating batch: 710 of 758tensor(7.9047, device='cuda:0')
tensor(7.9731, device='cuda:0')
tensor(8.3382, device='cuda:0')
tensor(7.7031, device='cuda:0')
tensor(8.7249, device='cuda:0')
tensor(6.2310, device='cuda:0')
tensor(8.2943, device='cuda:0')
tensor(7.3412, device='cuda:0')
tensor(7.9137, device='cuda:0')
tensor(7.0908, device='cuda:0')
Evaluating batch: 720 of 758tensor(7.1810, device='cuda:0')
tensor(7.1652, device='cuda:0')
tensor(6.5795, device='cuda:0')
tensor(5.8108, device='cuda:0')
tensor(6.2867, device='cuda:0')
tensor(8.4254, device='cuda:0')
tensor(6.5213, device='cuda:0')
tensor(7.0902, device='cuda:0')
tensor(7.4057, device='cuda:0')
tensor(6.5064, device='cuda:0')
Evaluating batch: 730 of 758tensor(6.4947, device='cuda:0')
tensor(7.7285, device='cuda:0')
tensor(7.9038, device='cuda:0')
tensor(7.3166, device='cuda:0')
tensor(6.6491, device='cuda:0')
tensor(6.4405, device='cuda:0')
tensor(7.8652, device='cuda:0')
tensor(7.5166, device='cuda:0')
tensor(7.5524, devic

Val perplexity: tensor(2766.9617)


KeyError: 'SEED'

In [21]:
# detect previous progress here
epoch_wise_neg_losses = []
epoch_wise_pos_losses = []
batch_wise_neg_losses = []
batch_wise_pos_losses = []

# train for `CONFIG['TRAIN_ITERS']` "negative epochs"
# one "negative epoch" = one pass through all batches in the entire negative dataset:
# for each negative batch, do (one negative update + POSITIVE_RATIO positive updates)*20
# notice the difference with usual training: here we do 20 gradient updates for each batch in the negative dataloader
for it in range(CONFIG['START_ITER'], CONFIG['START_ITER'] + CONFIG['TRAIN_ITERS']):  

    # restart the optimizer at each pos-neg iteration
    model_wrapper.reset_optimizers()
    
    # train on one epoch
    res = train_one_positive_epoch(
        pos_loader    = loader_train, 
        model_wrapper = model_wrapper
    )
        
    # evaluate val set
    # val_loss = model_wrapper.nll_on_dataset(loader_val)
    train_loss = model_wrapper.perplexity_on_dataset(loader_train)
    val_loss   = model_wrapper.perplexity_on_dataset(loader_val)
    print('\nTrain loss:', train_loss)
    print('\nVal loss:', val_loss)
    
    # record losses
    epoch_wise_neg_losses.append(res.get('average_neg_loss', None))
    epoch_wise_pos_losses.append(res.get('average_pos_loss', None))
    batch_wise_neg_losses.extend(res.get('batch_wise_neg_losses', []))
    batch_wise_pos_losses.extend(res.get('batch_wise_pos_losses', []))
    
    # log results
    log_dict = {
        'config': CONFIG,
        'log_datetime': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'num_negative_epochs': it,
        'train_neg_loss': res.get('average_neg_loss', None),
        'train_pos_loss': res.get('average_pos_loss', None),
        'train_num_updates_per_batch': res.get('batch_wise_num_updates', None),
        'train_epoch_wise_neg_losses': epoch_wise_neg_losses,
        'train_epoch_wise_pos_losses': epoch_wise_pos_losses,
        'train_batch_wise_neg_losses': batch_wise_neg_losses,
        'train_batch_wise_pos_losses': batch_wise_pos_losses,
    }
    create_json(
        target_dir = os.path.join(CONFIG['LOGGING_DIR'], CONFIG['EXPERIMENT_NAME']), 
        filename = f'epoch-{it}_trainloss-{res.get("average_pos_loss", np.nan):.6f}_valloss-{val_loss:.6f}.json',
        dict_to_save = log_dict
    )
    
    # save model
    model_wrapper.save_model_and_optimizers(
        info_dict = log_dict,  # additional information you want to log
        save_path = os.path.join(
            CONFIG["MODEL_SAVE_DIR"], 
            CONFIG["EXPERIMENT_NAME"], 
            f'epoch-{it}_trainloss-{res.get("average_pos_loss", np.nan):.6f}_valloss-{val_loss:.6f}'
        )
    )

Evaluating batch: 0 of 758f 758, loss 0.501138tensor(0.5939, device='cuda:0')
tensor(0.4689, device='cuda:0')
tensor(0.5585, device='cuda:0')
tensor(0.3989, device='cuda:0')
tensor(0.2947, device='cuda:0')
tensor(0.3964, device='cuda:0')
tensor(0.4438, device='cuda:0')
tensor(0.6116, device='cuda:0')
tensor(0.3415, device='cuda:0')
tensor(0.5320, device='cuda:0')
Evaluating batch: 10 of 758tensor(0.3728, device='cuda:0')
tensor(0.4147, device='cuda:0')
tensor(0.5168, device='cuda:0')
tensor(0.3375, device='cuda:0')
tensor(0.3632, device='cuda:0')
tensor(0.6684, device='cuda:0')
tensor(0.3769, device='cuda:0')
tensor(0.4829, device='cuda:0')
tensor(0.4777, device='cuda:0')
tensor(0.4214, device='cuda:0')
Evaluating batch: 20 of 758tensor(0.5728, device='cuda:0')
tensor(0.6461, device='cuda:0')
tensor(0.3957, device='cuda:0')
tensor(0.3410, device='cuda:0')
tensor(0.9389, device='cuda:0')
tensor(0.3848, device='cuda:0')
tensor(0.5809, device='cuda:0')
tensor(0.4288, device='cuda:0')
tens

tensor(0.6613, device='cuda:0')
tensor(0.4144, device='cuda:0')
tensor(0.3518, device='cuda:0')
tensor(0.4137, device='cuda:0')
tensor(0.5120, device='cuda:0')
Evaluating batch: 240 of 758tensor(0.4388, device='cuda:0')
tensor(0.4752, device='cuda:0')
tensor(0.6541, device='cuda:0')
tensor(0.4600, device='cuda:0')
tensor(0.6378, device='cuda:0')
tensor(0.6823, device='cuda:0')
tensor(0.7009, device='cuda:0')
tensor(0.4058, device='cuda:0')
tensor(0.4602, device='cuda:0')
tensor(0.5148, device='cuda:0')
Evaluating batch: 250 of 758tensor(0.6429, device='cuda:0')
tensor(0.4277, device='cuda:0')
tensor(0.6816, device='cuda:0')
tensor(0.8814, device='cuda:0')
tensor(0.3149, device='cuda:0')
tensor(0.8470, device='cuda:0')
tensor(0.5829, device='cuda:0')
tensor(0.6176, device='cuda:0')
tensor(0.2366, device='cuda:0')
tensor(0.4483, device='cuda:0')
Evaluating batch: 260 of 758tensor(0.4653, device='cuda:0')
tensor(0.3835, device='cuda:0')
tensor(0.5497, device='cuda:0')
tensor(0.5763, devic

tensor(0.5186, device='cuda:0')
tensor(0.4509, device='cuda:0')
tensor(0.3551, device='cuda:0')
tensor(0.5757, device='cuda:0')
tensor(0.4068, device='cuda:0')
tensor(0.4456, device='cuda:0')
tensor(0.5667, device='cuda:0')
tensor(0.5107, device='cuda:0')
Evaluating batch: 480 of 758tensor(0.4920, device='cuda:0')
tensor(0.8303, device='cuda:0')
tensor(0.7057, device='cuda:0')
tensor(0.4264, device='cuda:0')
tensor(0.4153, device='cuda:0')
tensor(0.3112, device='cuda:0')
tensor(0.3414, device='cuda:0')
tensor(0.5140, device='cuda:0')
tensor(0.4927, device='cuda:0')
tensor(0.3968, device='cuda:0')
Evaluating batch: 490 of 758tensor(0.4665, device='cuda:0')
tensor(0.5027, device='cuda:0')
tensor(0.5614, device='cuda:0')
tensor(0.4368, device='cuda:0')
tensor(0.3996, device='cuda:0')
tensor(0.7238, device='cuda:0')
tensor(0.5476, device='cuda:0')
tensor(0.4453, device='cuda:0')
tensor(0.6094, device='cuda:0')
tensor(0.3111, device='cuda:0')
Evaluating batch: 500 of 758tensor(0.4496, devic

tensor(0.3523, device='cuda:0')
tensor(0.4937, device='cuda:0')
Evaluating batch: 710 of 758tensor(0.4831, device='cuda:0')
tensor(0.5082, device='cuda:0')
tensor(0.2882, device='cuda:0')
tensor(0.3828, device='cuda:0')
tensor(0.5819, device='cuda:0')
tensor(1.1620, device='cuda:0')
tensor(0.3966, device='cuda:0')
tensor(0.5404, device='cuda:0')
tensor(0.7517, device='cuda:0')
tensor(0.3745, device='cuda:0')
Evaluating batch: 720 of 758tensor(0.4949, device='cuda:0')
tensor(0.5240, device='cuda:0')
tensor(0.2975, device='cuda:0')
tensor(0.5873, device='cuda:0')
tensor(0.4068, device='cuda:0')
tensor(0.5767, device='cuda:0')
tensor(0.3636, device='cuda:0')
tensor(0.5634, device='cuda:0')
tensor(0.5846, device='cuda:0')
tensor(0.4947, device='cuda:0')
Evaluating batch: 730 of 758tensor(0.5444, device='cuda:0')
tensor(0.7032, device='cuda:0')
tensor(0.3811, device='cuda:0')
tensor(0.7635, device='cuda:0')
tensor(0.3941, device='cuda:0')
tensor(0.6040, device='cuda:0')
tensor(0.4277, devic


Train loss: tensor(1.6232)

Val loss: tensor(393.1899)
Evaluating batch: 0 of 758f 758, loss 0.124292tensor(0.2149, device='cuda:0')
tensor(0.0964, device='cuda:0')
tensor(0.0564, device='cuda:0')
tensor(0.1253, device='cuda:0')
tensor(0.0973, device='cuda:0')
tensor(0.0592, device='cuda:0')
tensor(0.1141, device='cuda:0')
tensor(0.0668, device='cuda:0')
tensor(0.0714, device='cuda:0')
tensor(0.0991, device='cuda:0')
Evaluating batch: 10 of 758tensor(0.1155, device='cuda:0')
tensor(0.1254, device='cuda:0')
tensor(0.0851, device='cuda:0')
tensor(0.0916, device='cuda:0')
tensor(0.0991, device='cuda:0')
tensor(0.0951, device='cuda:0')
tensor(0.0571, device='cuda:0')
tensor(0.0681, device='cuda:0')
tensor(0.0590, device='cuda:0')
tensor(0.0687, device='cuda:0')
Evaluating batch: 20 of 758tensor(0.0739, device='cuda:0')
tensor(0.1007, device='cuda:0')
tensor(0.0710, device='cuda:0')
tensor(0.0800, device='cuda:0')
tensor(0.1689, device='cuda:0')
tensor(0.0913, device='cuda:0')
tensor(0.092

tensor(0.0632, device='cuda:0')
tensor(0.0972, device='cuda:0')
tensor(0.0606, device='cuda:0')
tensor(0.0870, device='cuda:0')
tensor(0.0784, device='cuda:0')
tensor(0.0769, device='cuda:0')
Evaluating batch: 240 of 758tensor(0.1124, device='cuda:0')
tensor(0.0863, device='cuda:0')
tensor(0.0656, device='cuda:0')
tensor(0.0860, device='cuda:0')
tensor(0.0962, device='cuda:0')
tensor(0.0500, device='cuda:0')
tensor(0.1181, device='cuda:0')
tensor(0.1912, device='cuda:0')
tensor(0.1229, device='cuda:0')
tensor(0.1436, device='cuda:0')
Evaluating batch: 250 of 758tensor(0.1220, device='cuda:0')
tensor(0.0867, device='cuda:0')
tensor(0.2489, device='cuda:0')
tensor(0.1164, device='cuda:0')
tensor(0.0951, device='cuda:0')
tensor(0.1398, device='cuda:0')
tensor(0.0625, device='cuda:0')
tensor(0.0946, device='cuda:0')
tensor(0.2542, device='cuda:0')
tensor(0.0834, device='cuda:0')
Evaluating batch: 260 of 758tensor(0.0873, device='cuda:0')
tensor(0.1173, device='cuda:0')
tensor(0.0768, devic

Evaluating batch: 470 of 758tensor(0.0738, device='cuda:0')
tensor(0.2770, device='cuda:0')
tensor(0.0728, device='cuda:0')
tensor(0.0722, device='cuda:0')
tensor(0.0833, device='cuda:0')
tensor(0.0782, device='cuda:0')
tensor(0.0482, device='cuda:0')
tensor(0.1189, device='cuda:0')
tensor(0.0900, device='cuda:0')
tensor(0.1224, device='cuda:0')
Evaluating batch: 480 of 758tensor(0.0650, device='cuda:0')
tensor(0.0879, device='cuda:0')
tensor(0.1234, device='cuda:0')
tensor(0.0614, device='cuda:0')
tensor(0.1035, device='cuda:0')
tensor(0.0807, device='cuda:0')
tensor(0.0785, device='cuda:0')
tensor(0.0999, device='cuda:0')
tensor(0.0821, device='cuda:0')
tensor(0.0937, device='cuda:0')
Evaluating batch: 490 of 758tensor(0.1354, device='cuda:0')
tensor(0.0888, device='cuda:0')
tensor(0.0745, device='cuda:0')
tensor(0.0933, device='cuda:0')
tensor(0.0923, device='cuda:0')
tensor(0.0984, device='cuda:0')
tensor(0.1377, device='cuda:0')
tensor(0.0865, device='cuda:0')
tensor(0.1272, devic

tensor(0.1572, device='cuda:0')
tensor(0.0771, device='cuda:0')
tensor(0.1674, device='cuda:0')
tensor(0.0658, device='cuda:0')
Evaluating batch: 710 of 758tensor(0.0788, device='cuda:0')
tensor(0.1880, device='cuda:0')
tensor(0.0875, device='cuda:0')
tensor(0.0616, device='cuda:0')
tensor(0.0987, device='cuda:0')
tensor(0.0712, device='cuda:0')
tensor(0.0937, device='cuda:0')
tensor(0.1171, device='cuda:0')
tensor(0.1657, device='cuda:0')
tensor(0.2120, device='cuda:0')
Evaluating batch: 720 of 758tensor(0.1041, device='cuda:0')
tensor(0.1456, device='cuda:0')
tensor(0.0729, device='cuda:0')
tensor(0.7689, device='cuda:0')
tensor(0.1047, device='cuda:0')
tensor(0.0794, device='cuda:0')
tensor(0.0952, device='cuda:0')
tensor(0.0784, device='cuda:0')
tensor(0.1117, device='cuda:0')
tensor(0.1187, device='cuda:0')
Evaluating batch: 730 of 758tensor(0.1231, device='cuda:0')
tensor(0.0912, device='cuda:0')
tensor(0.0778, device='cuda:0')
tensor(0.0629, device='cuda:0')
tensor(0.1103, devic

tensor(7.1732, device='cuda:0')
tensor(6.8222, device='cuda:0')
tensor(7.7584, device='cuda:0')

Train loss: tensor(1.1037)

Val loss: tensor(1695.8442)
Evaluating batch: 0 of 758f 758, loss 0.158988tensor(0.0427, device='cuda:0')
tensor(0.0339, device='cuda:0')
tensor(0.1263, device='cuda:0')
tensor(0.1293, device='cuda:0')
tensor(0.0455, device='cuda:0')
tensor(0.0510, device='cuda:0')
tensor(0.0831, device='cuda:0')
tensor(0.0694, device='cuda:0')
tensor(0.0347, device='cuda:0')
tensor(0.0714, device='cuda:0')
Evaluating batch: 10 of 758tensor(0.0290, device='cuda:0')
tensor(0.0753, device='cuda:0')
tensor(0.0582, device='cuda:0')
tensor(0.0755, device='cuda:0')
tensor(0.0528, device='cuda:0')
tensor(0.0557, device='cuda:0')
tensor(0.1011, device='cuda:0')
tensor(0.0865, device='cuda:0')
tensor(0.1948, device='cuda:0')
tensor(0.0871, device='cuda:0')
Evaluating batch: 20 of 758tensor(0.0751, device='cuda:0')
tensor(0.0554, device='cuda:0')
tensor(0.0587, device='cuda:0')
tensor(0.05

Evaluating batch: 230 of 758tensor(0.0425, device='cuda:0')
tensor(0.1085, device='cuda:0')
tensor(0.0613, device='cuda:0')
tensor(0.0590, device='cuda:0')
tensor(0.0798, device='cuda:0')
tensor(0.0995, device='cuda:0')
tensor(0.0755, device='cuda:0')
tensor(0.0657, device='cuda:0')
tensor(0.1080, device='cuda:0')
tensor(0.0638, device='cuda:0')
Evaluating batch: 240 of 758tensor(0.0845, device='cuda:0')
tensor(0.0448, device='cuda:0')
tensor(0.0611, device='cuda:0')
tensor(0.0718, device='cuda:0')
tensor(0.0534, device='cuda:0')
tensor(0.0588, device='cuda:0')
tensor(0.0661, device='cuda:0')
tensor(0.0408, device='cuda:0')
tensor(0.0728, device='cuda:0')
tensor(0.0678, device='cuda:0')
Evaluating batch: 250 of 758tensor(0.1130, device='cuda:0')
tensor(0.1141, device='cuda:0')
tensor(0.0510, device='cuda:0')
tensor(0.0492, device='cuda:0')
tensor(0.0811, device='cuda:0')
tensor(0.0514, device='cuda:0')
tensor(0.0788, device='cuda:0')
tensor(0.0632, device='cuda:0')
tensor(0.1513, devic

tensor(0.0718, device='cuda:0')
tensor(0.1277, device='cuda:0')
tensor(0.0703, device='cuda:0')
Evaluating batch: 470 of 758tensor(0.0688, device='cuda:0')
tensor(0.0485, device='cuda:0')
tensor(0.0903, device='cuda:0')
tensor(0.0622, device='cuda:0')
tensor(0.0474, device='cuda:0')
tensor(0.0497, device='cuda:0')
tensor(0.0647, device='cuda:0')
tensor(0.0460, device='cuda:0')
tensor(0.0445, device='cuda:0')
tensor(0.0567, device='cuda:0')
Evaluating batch: 480 of 758tensor(0.0480, device='cuda:0')
tensor(0.0668, device='cuda:0')
tensor(0.0395, device='cuda:0')
tensor(0.0626, device='cuda:0')
tensor(0.0596, device='cuda:0')
tensor(0.0683, device='cuda:0')
tensor(0.0496, device='cuda:0')
tensor(0.0581, device='cuda:0')
tensor(0.0651, device='cuda:0')
tensor(0.0699, device='cuda:0')
Evaluating batch: 490 of 758tensor(0.0694, device='cuda:0')
tensor(0.0364, device='cuda:0')
tensor(0.0932, device='cuda:0')
tensor(0.0835, device='cuda:0')
tensor(0.0692, device='cuda:0')
tensor(0.0827, devic

tensor(0.0568, device='cuda:0')
tensor(0.0811, device='cuda:0')
tensor(0.0420, device='cuda:0')
tensor(0.0686, device='cuda:0')
tensor(0.0384, device='cuda:0')
tensor(0.0578, device='cuda:0')
tensor(0.0538, device='cuda:0')
Evaluating batch: 710 of 758tensor(0.0744, device='cuda:0')
tensor(0.0709, device='cuda:0')
tensor(0.0422, device='cuda:0')
tensor(0.0871, device='cuda:0')
tensor(0.0586, device='cuda:0')
tensor(0.0782, device='cuda:0')
tensor(0.0286, device='cuda:0')
tensor(0.0428, device='cuda:0')
tensor(0.0803, device='cuda:0')
tensor(0.0657, device='cuda:0')
Evaluating batch: 720 of 758tensor(0.0656, device='cuda:0')
tensor(0.0465, device='cuda:0')
tensor(0.0312, device='cuda:0')
tensor(0.0512, device='cuda:0')
tensor(0.0617, device='cuda:0')
tensor(0.0911, device='cuda:0')
tensor(0.0626, device='cuda:0')
tensor(0.0690, device='cuda:0')
tensor(0.0761, device='cuda:0')
tensor(0.0408, device='cuda:0')
Evaluating batch: 730 of 758tensor(0.0588, device='cuda:0')
tensor(0.1028, devic

tensor(7.8436, device='cuda:0')
tensor(8.1890, device='cuda:0')
tensor(6.8010, device='cuda:0')
tensor(7.5606, device='cuda:0')
tensor(7.0326, device='cuda:0')
tensor(8.1775, device='cuda:0')

Train loss: tensor(1.0726)

Val loss: tensor(2582.9404)


In [11]:
# which experiment to load?
EXPERIMENT_TO_LOAD = 'test_dialogpt_rick_and_morty_basic'
CHECKPOINT_TO_LOAD = 'epoch-2_trainloss-0.108243_valloss-2873.574951'
LOGGING_DIR        = './logs/'
MODEL_SAVE_DIR     = './checkpoints/'

# load config
log_dict = load_json(
    target_dir = os.path.join(LOGGING_DIR, EXPERIMENT_TO_LOAD),
    filename   = CHECKPOINT_TO_LOAD + '.json'
)

# grab the config from the logs
CONFIG = log_dict['config']

# model
model_config = AutoConfig.from_pretrained(CONFIG['HUGGINGFACE_MODEL_NAME'])
gradient_accumulation_steps = 1
model_wrapper = NegativePositiveTrainingLM(
    model_name = CONFIG['HUGGINGFACE_MODEL_NAME'],
    model_config = model_config,
    pos_lr = CONFIG['POS_LR'], 
    neg_lr = CONFIG['NEG_LR'],
    num_opt_steps = len(loader_train) // CONFIG['GRAD_ACCUM_STEPS'] * CONFIG['TRAIN_ITERS'],
    tokenizer = tokenizer,
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)

# load
model_wrapper.load_model(
    checkpoint_root = CONFIG['MODEL_SAVE_DIR'],
    experiment_name = CONFIG['EXPERIMENT_NAME'],
    checkpoint_name = CHECKPOINT_TO_LOAD,
)

# Chat

In [16]:
df_train

Unnamed: 0,context,response
0,"Morty! You gotta come on. Jus'... you gotta come with me. EOS What, Rick? What’s going on? EOS I got a surprise for you, Morty. EOS It's the middle of the night. What are you talking about? EOS Come on, I got a surprise for you. Come on, hurry up. EOS Ow! Ow! You're tugging me too hard! EOS We gotta go, gotta get outta here, come on. Got a surprise for you Morty.","What do you think of this... flying vehicle, Morty? I built it outta stuff I found in the garage."
1,"What, Rick? What’s going on? EOS I got a surprise for you, Morty. EOS It's the middle of the night. What are you talking about? EOS Come on, I got a surprise for you. Come on, hurry up. EOS Ow! Ow! You're tugging me too hard! EOS We gotta go, gotta get outta here, come on. Got a surprise for you Morty. EOS What do you think of this... flying vehicle, Morty? I built it outta stuff I found in the garage.","Yeah, Rick... I-it's great. Is this the surprise?"
2,"I got a surprise for you, Morty. EOS It's the middle of the night. What are you talking about? EOS Come on, I got a surprise for you. Come on, hurry up. EOS Ow! Ow! You're tugging me too hard! EOS We gotta go, gotta get outta here, come on. Got a surprise for you Morty. EOS What do you think of this... flying vehicle, Morty? I built it outta stuff I found in the garage. EOS Yeah, Rick... I-it's great. Is this the surprise?","Morty. I had to... I had to do it. I had— I had to— I had to make a bomb, Morty. I had to create a bomb."
3,"It's the middle of the night. What are you talking about? EOS Come on, I got a surprise for you. Come on, hurry up. EOS Ow! Ow! You're tugging me too hard! EOS We gotta go, gotta get outta here, come on. Got a surprise for you Morty. EOS What do you think of this... flying vehicle, Morty? I built it outta stuff I found in the garage. EOS Yeah, Rick... I-it's great. Is this the surprise? EOS Morty. I had to... I had to do it. I had— I had to— I had to make a bomb, Morty. I had to create a bomb.",What?! A bomb?!
4,"Come on, I got a surprise for you. Come on, hurry up. EOS Ow! Ow! You're tugging me too hard! EOS We gotta go, gotta get outta here, come on. Got a surprise for you Morty. EOS What do you think of this... flying vehicle, Morty? I built it outta stuff I found in the garage. EOS Yeah, Rick... I-it's great. Is this the surprise? EOS Morty. I had to... I had to do it. I had— I had to— I had to make a bomb, Morty. I had to create a bomb. EOS What?! A bomb?!","We're gonna drop it down there just get a whole fresh start, Morty. Create a whole fresh start."
...,...,...
1512,"I never forget a kid. What do you say, Vindicators? Let's make this three for three? EOS Did he say ""three for three""? EOS Did he say he never forgets a kid? EOS You mean ""two for two"", right, Vance? EOS Actually, we assembled a second time last summer to fight Doomnomitron. EOS So, this is... Vindicators 3? And you guys did Vindicators 2... w-without us? EOS I sense... insecurity.",Are you sure there's not just a picnic nearby.
1513,"Did he say ""three for three""? EOS Did he say he never forgets a kid? EOS You mean ""two for two"", right, Vance? EOS Actually, we assembled a second time last summer to fight Doomnomitron. EOS So, this is... Vindicators 3? And you guys did Vindicators 2... w-without us? EOS I sense... insecurity. EOS Are you sure there's not just a picnic nearby.","I guess he found his crowd. Pretty toothless stuff, guys."
1514,"Did he say he never forgets a kid? EOS You mean ""two for two"", right, Vance? EOS Actually, we assembled a second time last summer to fight Doomnomitron. EOS So, this is... Vindicators 3? And you guys did Vindicators 2... w-without us? EOS I sense... insecurity. EOS Are you sure there's not just a picnic nearby. EOS I guess he found his crowd. Pretty toothless stuff, guys.","I hope you're happy with the adventure so far, Morty. These guys are even lamer than last time."
1515,"You mean ""two for two"", right, Vance? EOS Actually, we assembled a second time last summer to fight Doomnomitron. EOS So, this is... Vindicators 3? And you guys did Vindicators 2... w-without us? EOS I sense... insecurity. EOS Are you sure there's not just a picnic nearby. EOS I guess he found his crowd. Pretty toothless stuff, guys. EOS I hope you're happy with the adventure so far, Morty. These guys are even lamer than last time.","We weren't here ""last time"", remember? They did a whole Vindicators without us. A bunch of them got killed, too. They lost Lady Katana, Calypso, Diablo Verde..."


In [19]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Let's chat for 5 lines
for step in range(5):
    # encode the new user input, add the eos_token and return a tensor in Pytorch
    new_user_input_ids = tokenizer.encode(input(">> User:") + tokenizer.eos_token, return_tensors='pt').to(device)
    # print(new_user_input_ids)

    # append the new user input tokens to the chat history
    bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1) if step > 0 else new_user_input_ids

    # generated a response while limiting the total chat history to 1000 tokens, 
    chat_history_ids = model_wrapper.model.generate(
        bot_input_ids, max_length=200,
        pad_token_id=tokenizer.eos_token_id,  
        no_repeat_ngram_size=3,       
        do_sample=True, 
        top_k=100, 
        top_p=0.7,
        temperature=0.8
    )
    
    # pretty print last ouput tokens from bot
    print("RickBot: {}".format(tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True)))

>> User:Ow! Ow! You're tugging me too hard! 
RickBot: We gotta go, gotta get outta here, come on. Got a surprise for you Morty.


KeyboardInterrupt: Interrupted by user

# Responses to validation set

In [8]:
df_train

Unnamed: 0,response,context,context/0
0,"I got a surprise for you, Morty.","What, Rick? What’s going on?",Morty! You gotta come on. Jus'... you gotta co...
1,It's the middle of the night. What are you tal...,"I got a surprise for you, Morty.","What, Rick? What’s going on?"
2,"Come on, I got a surprise for you. Come on, h...",It's the middle of the night. What are you tal...,"I got a surprise for you, Morty."
3,Ow! Ow! You're tugging me too hard!,"Come on, I got a surprise for you. Come on, h...",It's the middle of the night. What are you tal...
4,"We gotta go, gotta get outta here, come on. Go...",Ow! Ow! You're tugging me too hard!,"Come on, I got a surprise for you. Come on, h..."
...,...,...,...
1517,Are you sure there's not just a picnic nearby.,I sense... insecurity.,"So, this is... Vindicators 3? And you guys did..."
1518,I guess he found his crowd. Pretty toothless s...,Are you sure there's not just a picnic nearby.,I sense... insecurity.
1519,"I hope you're happy with the adventure so far,...",I guess he found his crowd. Pretty toothless s...,Are you sure there's not just a picnic nearby.
1520,"We weren't here ""last time"", remember? They di...","I hope you're happy with the adventure so far,...",I guess he found his crowd. Pretty toothless s...


In [9]:
df_val

Unnamed: 0,response,context,context/0
0,I... think the personality conflict might have...,"Don't worry, Morty, they love you. Superheroes...",This article says the reason we weren't involv...
1,"Jesus... How awesome is that? I mean, they wan...",I... think the personality conflict might have...,"Don't worry, Morty, they love you. Superheroes..."
2,"Rick, since it's my adventure and all, could y...","Jesus... How awesome is that? I mean, they wan...",I... think the personality conflict might have...
3,"Uh, the adventure is the favor, Morty. Me slee...","Rick, since it's my adventure and all, could y...","Jesus... How awesome is that? I mean, they wan..."
4,"Rick, this really bums me out. It-It's embarra...","Uh, the adventure is the favor, Morty. Me slee...","Rick, since it's my adventure and all, could y..."
...,...,...,...
374,That was amazing!,"Whoa!! Hahaha, yeah! Atlantis, baby!",Holy crap... Slick's wish came true.
375,Got some of that mermaid puss!,That was amazing!,"Whoa!! Hahaha, yeah! Atlantis, baby!"
376,I'm really hoping it wasn't a one-off thing an...,Got some of that mermaid puss!,That was amazing!
377,"Pssh! Not at all, Morty. That place will never...",I'm really hoping it wasn't a one-off thing an...,Got some of that mermaid puss!


In [72]:
prompt = ["What do you think about Elon Musk?<|endoftext|>"] * 1  # generate a batch of 10 independent samples
inputs = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model_wrapper.device)
prompt_length = len(tokenizer.decode(inputs[0]))

In [73]:
outputs = model_wrapper.model.generate(
    inputs, 
    max_length=200,
    pad_token_id=tokenizer.eos_token_id,  
#     no_repeat_ngram_size=3,       
    do_sample=True, 
    top_k=100, 
    top_p=0.7,
    temperature = 0.8
).cpu()
decoded_outputs = [tokenizer.decode(output) for output in outputs]

In [74]:
decoded_outputs

["What do you think about Elon Musk?<|endoftext|>you, i see, i know. i mean, it's not like, but, uh, but, but, uh, but, uh, uh, but, uh, but, uh, but, uh, uh, uh, but, uh, uh, but, uh, uh, but, uh, uh, uh, uh, uh, uh, but, uh, uh, so, uh, uh, but, but, can you.<|endoftext|>"]

In [35]:
loader = DataLoader(
    dataset_val,
    batch_size = 1,
    shuffle = False,
    collate_fn = collate_fn
)
num_batches_to_show = 100

responses_dict = {}
example_id = 0
for i, batch in enumerate(loader):
    break

In [15]:
print(batch)
tokenizer.decode(batch['input_ids'][example_id])

{'input_ids': tensor([[ 9099,   470,  5490,   837,  5596,    88,   837,   484,  1842,   345,
           764, 40896,   761,   257,  3094, 45320, 49733, 45543,   284,  7621,
          1863,   290,  6324,   284,  2279,   588,   340,   338,  2000, 19280,
           764, 50256,    72,  2644,   892,   262,  8806,  5358,  1244,   423,
           587,  2644,   345,   764,     0]]), 'position_ids': tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43,  0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'target_ids': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100, 

"don't worry, morty, they love you. superheroes need a wide eyed unremarkable to tag along and react to everything like it's mind blowing.<|endoftext|>i... think the personality conflict might have been... you.!"

In [None]:
# each tensor to device
input_ids       = batch['input_ids'].to(self.device)
# position_ids    = batch['position_ids'].to(self.device)
# token_type_ids  = batch['token_type_ids'].to(self.device)
# target_ids      = batch['target_ids'].to(self.device)
attention_mask  = batch['attention_masks'].to(self.device)

outputs = self.model.generate(
    inputs = input_ids, 
    max_new_tokens = max_new_tokens,
    pad_token_id = self.tokenizer.eos_token_id,
    do_sample = True,  # do_sample = True; otherwise all questions will be identical
    top_p = 0.95,      # nucleus sampling
    top_k = 0,         # deactivate top-k words sampling
    attention_mask = attention_mask  # do not pay attention to padding
).cpu()

decoded_outputs = [self.tokenizer.decode(output) for output in outputs]