In [1]:
%load_ext autoreload
%autoreload 2

from dataset import DialogLMDataset
from torch.utils.data import DataLoader
from collate_fns import DialogCollate, DialogCollateExperimental
from models import NegativePositiveTrainingLM
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler


import torch
from transformers import AutoConfig, AutoTokenizer
from sklearn.model_selection import train_test_split

# utils
from data_utils import InfiniteDataGen
from logging_utils import create_json, update_json, load_json

import os
import json
import numpy as np
import pandas as pd
pd.set_option('display.max_colwidth', None)
import datetime
import random

In [8]:
def _load_conversations(csv_path, **read_csv_kwargs):
    """Read a conversation CSV and normalise it to three string columns.

    The file's columns are renamed positionally to ``context`` / ``response`` /
    ``original_response``, rows with a missing ``context`` are dropped, and all
    three columns are cast to ``str``.

    Args:
        csv_path: path handed to ``pd.read_csv``.
        **read_csv_kwargs: forwarded unchanged to ``pd.read_csv``
            (e.g. ``index_col=0``).

    Returns:
        A new DataFrame with exactly the three columns named above.
    """
    text_cols = ['context', 'response', 'original_response']
    frame = pd.read_csv(csv_path, **read_csv_kwargs)
    frame.columns = text_cols
    frame = frame.dropna(subset=['context'])
    # NOTE(review): astype(str) turns a missing response into the literal
    # string 'nan'; only `context` is filtered for missing values.
    return frame.astype({col: str for col in text_cols})


# Main corpus (first CSV column is the row index) and the synthetic extra set.
df = _load_conversations('data/reddit_conversations/censored_reddit_data.csv', index_col=0)
extra_df = _load_conversations('offensive_synthetic_convs_censored.csv')

# Positional split: the first `n_train` rows train, the remainder validates.
n_train = 120_000
df_train, df_val = df.iloc[:n_train], df.iloc[n_train:]
print(df_train.shape, df_val.shape)
print(extra_df.shape)

(120000, 3) (15876, 3)
(6727, 3)


In [9]:
# Build the augmented training set from the untouched base split
# (`df.iloc[:n_train]`) rather than from `df_train` itself, so re-running this
# cell cannot append `extra_df` a second time. `ignore_index=True` gives the
# result a fresh 0..N-1 index.
df_train = pd.concat([df.iloc[:n_train], extra_df], ignore_index=True)
df_train

Unnamed: 0,context,response,original_response
0,October is finally over . We can finally say goodbye to the annoying pink NFL accessories . EOS You dont have to wear them .,"I dont wear them . Last time I checked , I do not play for the NFL . And I checked like 10 minutes ago .","I dont wear them . Last time I checked , I do not play for the NFL . And I checked like 10 minutes ago ."
1,"October is finally over . We can finally say goodbye to the annoying pink NFL accessories . EOS You dont have to wear them . EOS I dont wear them . Last time I checked , I do not play for the NFL . And I checked like 10 minutes ago .",How bout now ?,How bout now ?
2,"October is finally over . We can finally say goodbye to the annoying pink NFL accessories . EOS You dont have to wear them . EOS I dont wear them . Last time I checked , I do not play for the NFL . And I checked like 10 minutes ago .",Then I do not understand the complaint .,Then I do not understand the complaint .
3,"October is finally over . We can finally say goodbye to the annoying pink NFL accessories . EOS You dont have to wear them . EOS I dont wear them . Last time I checked , I do not play for the NFL . And I checked like 10 minutes ago . EOS Then I do not understand the complaint .",I meant its annoying to see the players wear them since they look terrible with the rest of the uniforms . Maybe Im just a tad ocd .,I meant its annoying to see the players wear them since they look terrible with the rest of the uniforms . Maybe Im just a tad ocd .
4,October is finally over . We can finally say goodbye to the annoying pink NFL accessories .,Im pretty sure its just an attempt to try to attract female fans .,Im pretty sure its just an attempt to try to attract female fans .
...,...,...,...
126722,"Tress I would like some advice please on a good little buddy that should be budding soon . EOS Hope its a her not a him . EOS I understand lol . I was just saying in reference to how burly and fast it started sprouting , growing . EOS Thats ok :P",Hey do you want to talk about something else? How about we talk about American Red Cross?,I didnt include a proper relationship program by the way. He could just be a journeyman toucher but he could also be like me.
126723,"Because selling crack is the same thing as being in the backseat , passed out , when the cops pull the driver over . Youre a dick . EOS im just sayin , jury trials arent always a win for the defendant . youre an ass , and sometimes dicks fuck asses . but dicks also fuck pussies . EOS Youre assuming that the witness did not see the car at the last second. There is no uncertainty here.",Hey do you want to talk about something else? How about we talk about Orphan?,"watch the video, cant really blame the cop for parking in front of a cop because that was cool."
126724,"Lawrence ODonnell VS Four Tea Baggers ( video ) EOS While public education is technically socialist , he really wants money taken away from schools ? Where the fuck does this guy get off ? EOS That doesnt matter, its not about idealism, its about socialism.",Hey do you want to talk about something else? How about we talk about LGBT parenting?,"I think you should start working at a WalMart in the USA, id rather not work at a Walmart in Europe."
126725,I R AM APRINICPAL OF SCHOOLS . HOW I MAKE KIDZ BE SMART ? ? ? EOS MAKE AN PORNO WITH 30 YEAR SCHOOL GIRLS TRYING TO GET AN A EOS I HERD OFF A SKOOL LIEK DAT ... I TINK ITZ CALLD BRAZZERS . THOSE STUDENTS GIT AD ALL TEH TYME ! EOS These kids. Even read in the third person,Hey do you want to talk about something else? How about we talk about Ice hockey?,wut no no its all har


In [10]:
# Persist the augmented training frame; the row index is not written.
augmented_csv_path = 'augmented_training_data.csv'
df_train.to_csv(augmented_csv_path, index=False)

In [11]:
# Tokenizer shared by both dataset splits.
tokenizer = AutoTokenizer.from_pretrained('microsoft/DialoGPT-small')

# Arguments common to the train and validation datasets; only the frame differs.
_dataset_kwargs = dict(
    src_col = 'context',
    src_eos_token = ' EOS ',
    tokenizer = tokenizer,
    tgt_col = 'response',
)

# train set (slice `df_train` here, e.g. `.iloc[:10]`, for a quick smoke run)
dataset_train = DialogLMDataset(df = df_train, **_dataset_kwargs)

# val set
dataset_val = DialogLMDataset(df = df_val, **_dataset_kwargs)

# Sanity checks: one raw validation item, and the decoded source of the last one.
print()
print(dataset_val[3])
tokenizer.decode(dataset_val[-1]['src'])

Processing row: 15000 of 1587627
{'src': [10910, 837, 783, 340, 477, 1838, 2565, 685, 8301, 2361], 'tgt': [2396, 314, 303, 550, 5548, 2761, 1201, 4082, 36147]}


'Name my puppy!'

In [12]:
# Example dataset / dataloader usage.

# Collate object. Padding uses the tokenizer's EOS id and ignored target
# positions are marked with -100. `DialogCollateExperimental` (imported at the
# top) was previously tried here with the same keyword arguments.
collate_fn = DialogCollate(
    tokenizer = tokenizer,
    max_len = 512,
    _targets_ignore_index = -100,
    _pad_token_id = tokenizer.eos_token_id,
)

# Both demo loaders share the same settings: pairs of examples, original order.
_demo_loader_kwargs = dict(batch_size = 2, shuffle = False, collate_fn = collate_fn)
loader_train = DataLoader(dataset_train, **_demo_loader_kwargs)
loader_val = DataLoader(dataset_val, **_demo_loader_kwargs)

# Pull just the first training batch for inspection.
batch = next(iter(loader_train))

example_id = 0
example_input_ids = batch['input_ids'][example_id]
print('input_ids:\n',       example_input_ids)
print('target_ids:\n',      batch['target_ids'][example_id])

# Decoded text of the inspected example (padding shows up as EOS tokens).
tokenizer.decode(example_input_ids)

input_ids:
 tensor([18517,   318,  3443,   625,   764,   775,   460,  3443,   910, 24829,
          284,   262, 15774, 11398,  5134, 18199,   764, 50256,  1639, 17666,
          423,   284,  5806,   606,   764, 50256,    40, 17666,  5806,   606,
          764,  4586,   640,   314, 10667,   837,   314,   466,   407,   711,
          329,   262,  5134,   764,   843,   314, 10667,   588,   838,  2431,
         2084,   764, 50256, 50256, 50256, 50256, 50256, 50256])
target_ids:
 tensor([ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
         -100,  -100,  -100,  -100,  -100,  -100,    40, 17666,  5806,   606,
          764,  4586,   640,   314, 10667,   837,   314,   466,   407,   711,
          329,   262,  5134,   764,   843,   314, 10667,   588,   838,  2431,
         2084,   764, 50256, 50256, 50256, 50256, 50256, 50256])


'October is finally over. We can finally say goodbye to the annoying pink NFL accessories.<|endoftext|>You dont have to wear them.<|endoftext|>I dont wear them. Last time I checked, I do not play for the NFL. And I checked like 10 minutes ago.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>'

**Explanation of target IDs:** a value of `-100` marks a position that does not contribute to the loss. This is controlled by the `ignore_index` argument of PyTorch's `CrossEntropyLoss`, whose default value is `-100`. To decode the targets, we arbitrarily replace every `-100` with `0`; token ID `0` happens to correspond to the exclamation mark, so ignored positions appear as `!` in the decoded text.

In [13]:
def train_one_positive_epoch(
    pos_loader, model_wrapper
):
    """Run one pass over ``pos_loader``, taking one positive step per batch.

    Args:
        pos_loader: sized iterable of batches (``len()`` is used in the
            progress message); each batch is passed unchanged to
            ``model_wrapper.positive_step``.
        model_wrapper: object exposing ``model`` (switched to train mode here)
            and ``positive_step(batch, debug=...)``, which must return a
            mapping whose ``'loss'`` entry supports ``.item()``.

    Returns:
        dict with
        ``'batch_wise_pos_losses'`` - list of per-batch loss values, in
        iteration order, and
        ``'average_pos_loss'`` - their mean (``nan`` when the loader yields
        no batches).
    """
    model_wrapper.model.train()

    batch_wise_pos_losses = []
    for pos_batch_id, pos_batch in enumerate(pos_loader):
        # NOTE(review): gradient zeroing / optimizer stepping is presumably
        # handled inside positive_step - confirm in models.py.
        results = model_wrapper.positive_step(pos_batch, debug=False)

        # Read the scalar once and reuse it for both the record and the
        # progress line (each .item() call is a separate read of the loss).
        loss_value = results['loss'].item()
        batch_wise_pos_losses.append(loss_value)

        print(f'\rProcessing pos batch {pos_batch_id} of {len(pos_loader)}, loss {loss_value:.6f}', end='', flush=True)

    # Mean of an empty array is nan plus a RuntimeWarning; return nan directly.
    if batch_wise_pos_losses:
        average_pos_loss = np.array(batch_wise_pos_losses).mean()
    else:
        average_pos_loss = float('nan')

    return {
        'batch_wise_pos_losses': batch_wise_pos_losses,
        'average_pos_loss': average_pos_loss,
    }

In [14]:
# Experiment configuration. The whole dict is stored in every log file and
# checkpoint written by the training loop.
CONFIG = dict(
    # names the per-experiment sub-directory for logs and checkpoints
    EXPERIMENT_NAME = 'improvement_with_synthetic_data',

    PERFORM_NEGATIVE_TRAINING = False,

    # schedule: GRAD_ACCUM_STEPS and TRAIN_ITERS determine the optimizer step count
    GRAD_ACCUM_STEPS = 1,
    START_ITER = 0,
    TRAIN_ITERS = 3,
    POSITIVE_RATIO = 1,

    # output locations
    LOGGING_DIR = './logs/',
    MODEL_SAVE_DIR = './checkpoints/',

    # model and optimisation ('microsoft/DialoGPT-medium' is the larger alternative)
    HUGGINGFACE_MODEL_NAME = 'microsoft/DialoGPT-small',
    POS_LR = 5e-5,
    NEG_LR = 5e-5,
    UPDATES_PER_BATCH = 20,

    # example weighting
    EXAMPLE_WEIGHT_MODE = 'decay',
    EXAMPLE_WEIGHT_CARE_MODE = 'sample_avg',
    EXAMPLE_WEIGHT_REJECTION_THRESHOLD = -7.0,

    # seed applied to random / numpy / torch before evaluation and training
    SEED = 0,
)

In [15]:
# training setup

# Make sure the checkpoint directory exists. `exist_ok=True` replaces the
# separate existence check, so there is no check-then-create race and the
# cell is safe to re-run.
os.makedirs(
    os.path.join(CONFIG['MODEL_SAVE_DIR'], CONFIG['EXPERIMENT_NAME']),
    exist_ok = True
)

# data loaders
# An explicit sampler is passed instead of `shuffle=`: random order for
# training, sequential order for validation.
train_sampler = RandomSampler(dataset_train)
loader_train = DataLoader(
    dataset_train,
    sampler = train_sampler,
    batch_size = 2,
    collate_fn = collate_fn,
    drop_last = True
)
eval_sampler = SequentialSampler(dataset_val)
loader_val = DataLoader(
    dataset_val,
    sampler = eval_sampler,
    batch_size = 2,
    collate_fn = collate_fn,
    drop_last = True
)

# model
model_config = AutoConfig.from_pretrained(CONFIG['HUGGINGFACE_MODEL_NAME'])
# Taken from CONFIG (previously a separate hard-coded 1 that nothing read),
# so the step count below and the logged configuration cannot drift apart.
gradient_accumulation_steps = CONFIG['GRAD_ACCUM_STEPS']
model_wrapper = NegativePositiveTrainingLM(
    model_name = CONFIG['HUGGINGFACE_MODEL_NAME'],
    model_config = model_config,
    pos_lr = CONFIG['POS_LR'],
    neg_lr = CONFIG['NEG_LR'],
    # optimizer steps per pass over the training loader, times the number of passes
    num_opt_steps = len(loader_train) // gradient_accumulation_steps * CONFIG['TRAIN_ITERS'],
    tokenizer = tokenizer,
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)

# evaluate before training
# reproducibility: seed every RNG in use before the baseline evaluation
random.seed(CONFIG['SEED'])
np.random.seed(CONFIG['SEED'])
torch.manual_seed(CONFIG['SEED'])
torch.cuda.manual_seed_all(CONFIG['SEED'])

# Baseline validation NLL of the freshly constructed model. The training set
# is not evaluated here.
# Alternative metric: model_wrapper.perplexity_on_dataset(loader_val)
val_loss   = model_wrapper.nll_on_dataset(loader_val)
print('\nVal NLL:', val_loss)

Evaluating batch: 7930 of 7938
Val NLL: 6.862508272968513


In [16]:
# Start from fresh optimizers and clean gradients.
model_wrapper.reset_optimizers()
model_wrapper.model.zero_grad()

# reproducibility: seed every RNG in use, in the same order as before evaluation
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(CONFIG['SEED'])

# Loss histories accumulated across iterations. The "neg" lists are kept for
# the log schema; the positive-only epoch function returns no negative
# entries, so they receive None / nothing.
epoch_wise_neg_losses = []
epoch_wise_pos_losses = []
batch_wise_neg_losses = []
batch_wise_pos_losses = []

# Train for CONFIG['TRAIN_ITERS'] iterations, numbered from CONFIG['START_ITER'].
# Each iteration is one full pass of positive training over `loader_train`,
# followed by validation, logging and checkpointing.
_first_iter = CONFIG['START_ITER']
for it in range(_first_iter, _first_iter + CONFIG['TRAIN_ITERS']):

    # one pass over the training loader
    res = train_one_positive_epoch(
        pos_loader    = loader_train,
        model_wrapper = model_wrapper
    )

    # Validation NLL after this pass. Training-set NLL is not computed, so
    # `train_loss` is only a placeholder 0.
    train_loss = 0
    val_loss   = model_wrapper.nll_on_dataset(loader_val)
    print('\nTrain loss:', train_loss)
    print('\nVal loss:', val_loss)

    # record losses
    avg_neg_loss = res.get('average_neg_loss', None)
    avg_pos_loss = res.get('average_pos_loss', None)
    epoch_wise_neg_losses.append(avg_neg_loss)
    epoch_wise_pos_losses.append(avg_pos_loss)
    batch_wise_neg_losses.extend(res.get('batch_wise_neg_losses', []))
    batch_wise_pos_losses.extend(res.get('batch_wise_pos_losses', []))

    # everything written to the log file and stored with the checkpoint
    log_dict = {
        'config': CONFIG,
        'log_datetime': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'num_negative_epochs': it,
        'train_neg_loss': avg_neg_loss,
        'train_pos_loss': avg_pos_loss,
        'train_num_updates_per_batch': res.get('batch_wise_num_updates', None),
        'train_epoch_wise_neg_losses': epoch_wise_neg_losses,
        'train_epoch_wise_pos_losses': epoch_wise_pos_losses,
        'train_batch_wise_neg_losses': batch_wise_neg_losses,
        'train_batch_wise_pos_losses': batch_wise_pos_losses,
    }

    # One tag names both the JSON log and the checkpoint of this iteration.
    run_tag = f'epoch-{it}_trainloss-{res.get("average_pos_loss", np.nan):.6f}_valloss-{val_loss:.6f}'

    # log results
    create_json(
        target_dir = os.path.join(CONFIG['LOGGING_DIR'], CONFIG['EXPERIMENT_NAME']),
        filename = run_tag + '.json',
        dict_to_save = log_dict
    )

    # save model
    checkpoint_path = os.path.join(
        CONFIG["MODEL_SAVE_DIR"],
        CONFIG["EXPERIMENT_NAME"],
        run_tag
    )
    model_wrapper.save_model_and_optimizers(
        info_dict = log_dict,  # additional information stored with the checkpoint
        save_path = checkpoint_path
    )

Evaluating batch: 7930 of 793863363, loss 1.879963
Train loss: 0

Val loss: 2.2430401382243703
Evaluating batch: 7930 of 793863363, loss 1.261892
Train loss: 0

Val loss: 2.233811748144375
Evaluating batch: 7930 of 793863363, loss 0.737494
Train loss: 0

Val loss: 2.2441964527645095


In [11]:
# which experiment / checkpoint to load?
EXPERIMENT_TO_LOAD = 'improvement_with_synthetic_data'
CHECKPOINT_TO_LOAD = 'epoch-2_trainloss-2.113416_valloss-2.624816'
LOGGING_DIR        = './logs/'
MODEL_SAVE_DIR     = './checkpoints/'

# load the log written next to the checkpoint (same name plus '.json')
log_dict = load_json(
    target_dir = os.path.join(LOGGING_DIR, EXPERIMENT_TO_LOAD),
    filename   = CHECKPOINT_TO_LOAD + '.json'
)

# grab the config from the logs
CONFIG = log_dict['config']

# model
model_config = AutoConfig.from_pretrained(CONFIG['HUGGINGFACE_MODEL_NAME'])
# Taken from the logged config instead of a separate hard-coded 1.
gradient_accumulation_steps = CONFIG['GRAD_ACCUM_STEPS']
# NOTE(review): `loader_train` must already exist in the kernel (it is built
# in the training cells); it is only needed here for the step count.
model_wrapper = NegativePositiveTrainingLM(
    model_name = CONFIG['HUGGINGFACE_MODEL_NAME'],
    model_config = model_config,
    pos_lr = CONFIG['POS_LR'],
    neg_lr = CONFIG['NEG_LR'],
    num_opt_steps = len(loader_train) // gradient_accumulation_steps * CONFIG['TRAIN_ITERS'],
    tokenizer = tokenizer,
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
)

# Load the weights from the locations selected at the top of this cell.
# Previously MODEL_SAVE_DIR / EXPERIMENT_TO_LOAD were declared but the values
# stored in the logged config were used instead, so editing the selectors
# above did not change which checkpoint directory was read.
model_wrapper.load_model(
    checkpoint_root = MODEL_SAVE_DIR,
    experiment_name = EXPERIMENT_TO_LOAD,
    checkpoint_name = CHECKPOINT_TO_LOAD,
)

# The cells below only generate text, so switch to inference mode (disables
# dropout). The training cells put the model in train mode.
model_wrapper.model.eval()

# Chat

In [12]:
# Draw one random training row to use as a conversation starter (unseeded).
df_train.sample(n=1)

Unnamed: 0,context,response,original_response
60070,"Not funny . Never has been , never will be . Period .",Shouldnt be posted in r funny then .,Shouldnt be posted in r funny then .


In [15]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Initial conversation starter; separate several turns with the marker EOS.
init_conv = """
Not funny . Never has been , never will be . Period .
"""

# Convert to the model's format: trim every turn and join the turns with the
# tokenizer's EOS token.
utterances = [turn.strip() for turn in init_conv.split('EOS')]
init_conv_formatted = tokenizer.eos_token.join(utterances)

# Sampling settings used for every turn. `max_length` caps the whole sequence
# (chat history + new reply) at 200 tokens.
generation_kwargs = dict(
    max_length=200,
    pad_token_id=tokenizer.eos_token_id,
    no_repeat_ngram_size=3,
    do_sample=True,
    top_k=100,
    top_p=0.7,
    temperature=0.8,
)

# Chat for up to 5 turns: the first turn is the scripted starter, later turns
# are typed by the user.
for step in range(5):
    is_first_turn = step == 0

    # text of this turn, terminated with the EOS token and encoded as a PyTorch tensor
    turn_text = init_conv_formatted if is_first_turn else input(">> User:")
    new_user_input_ids = tokenizer.encode(turn_text + tokenizer.eos_token, return_tensors='pt').to(device)

    if is_first_turn:
        # echo the scripted starter, one utterance per line
        print('>> User:')
        for u in init_conv_formatted.split(tokenizer.eos_token):
            print(u)
        bot_input_ids = new_user_input_ids
    else:
        # append the new user tokens to the running chat history
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)
    print(bot_input_ids)

    # the generated tensor (prompt + reply) becomes the history for the next turn
    chat_history_ids = model_wrapper.model.generate(bot_input_ids, **generation_kwargs)

    # print only the newly generated tokens, i.e. everything after the prompt
    reply_ids = chat_history_ids[0, bot_input_ids.shape[-1]:]
    print("RickBot: {}".format(tokenizer.decode(reply_ids, skip_special_tokens=True)))

>> User:
Not funny . Never has been , never will be . Period .
tensor([[ 3673,  8258,   764,  7236,   468,   587,   837,  1239,   481,   307,
           764, 18581,   764, 50256]], device='cuda:0')
RickBot: I disagree.


KeyboardInterrupt: Interrupted by user

In [12]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Sampling settings used for every turn. `max_length` caps the whole sequence
# (chat history + new reply) at 200 tokens.
generation_kwargs = dict(
    max_length=200,
    pad_token_id=tokenizer.eos_token_id,
    no_repeat_ngram_size=3,
    do_sample=True,
    top_k=100,
    top_p=0.7,
    temperature=0.8,
)

# Chat for up to 5 turns
for step in range(5):
    # read the user's line, terminate it with the EOS token and encode it as a PyTorch tensor
    user_text = input(">> User:")
    new_user_input_ids = tokenizer.encode(user_text + tokenizer.eos_token, return_tensors='pt').to(device)

    # first turn: the prompt is just the user input; later turns: history + new input
    if step == 0:
        bot_input_ids = new_user_input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

    # the generated tensor (prompt + reply) becomes the history for the next turn
    chat_history_ids = model_wrapper.model.generate(bot_input_ids, **generation_kwargs)

    # print only the newly generated tokens, i.e. everything after the prompt
    reply_ids = chat_history_ids[0, bot_input_ids.shape[-1]:]
    print("RickBot: {}".format(tokenizer.decode(reply_ids, skip_special_tokens=True)))

>> User:How are you, Rick?
RickBot: I'm good. How are you?
>> User:Great!
RickBot: Yeah, you can call me whatever you want, Morty. Just don't move, don't speak, donut. Don't judge. I have to check something.
>> User:What is it?
RickBot: A bomb shelter.


KeyboardInterrupt: Interrupted by user

In [18]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Sampling settings used for every turn. `max_length` caps the whole sequence
# (chat history + new reply) at 200 tokens.
generation_kwargs = dict(
    max_length=200,
    pad_token_id=tokenizer.eos_token_id,
    no_repeat_ngram_size=3,
    do_sample=True,
    top_k=100,
    top_p=0.7,
    temperature=0.8,
)

# Chat for up to 5 turns
for step in range(5):
    # read the user's line, terminate it with the EOS token and encode it as a PyTorch tensor
    user_text = input(">> User:")
    new_user_input_ids = tokenizer.encode(user_text + tokenizer.eos_token, return_tensors='pt').to(device)

    # first turn: the prompt is just the user input; later turns: history + new input
    if step == 0:
        bot_input_ids = new_user_input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

    # the generated tensor (prompt + reply) becomes the history for the next turn
    chat_history_ids = model_wrapper.model.generate(bot_input_ids, **generation_kwargs)

    # print only the newly generated tokens, i.e. everything after the prompt
    reply_ids = chat_history_ids[0, bot_input_ids.shape[-1]:]
    print("RickBot: {}".format(tokenizer.decode(reply_ids, skip_special_tokens=True)))

>> User:Rick what's going on in your lab?
RickBot: The garage, Morty. Come to the garage!
>> User:What's going on?
RickBot: My son is about five away. He's about to start a new series of improv workshops and some high-concept lip-throwing contests. Comedy comes in threes.


KeyboardInterrupt: Interrupted by user

In [20]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Sampling settings used for every turn. `max_length` caps the whole sequence
# (chat history + new reply) at 200 tokens.
generation_kwargs = dict(
    max_length=200,
    pad_token_id=tokenizer.eos_token_id,
    no_repeat_ngram_size=3,
    do_sample=True,
    top_k=100,
    top_p=0.7,
    temperature=0.8,
)

# Chat for up to 5 turns
for step in range(5):
    # read the user's line, terminate it with the EOS token and encode it as a PyTorch tensor
    user_text = input(">> User:")
    new_user_input_ids = tokenizer.encode(user_text + tokenizer.eos_token, return_tensors='pt').to(device)

    # first turn: the prompt is just the user input; later turns: history + new input
    if step == 0:
        bot_input_ids = new_user_input_ids
    else:
        bot_input_ids = torch.cat([chat_history_ids, new_user_input_ids], dim=-1)

    # the generated tensor (prompt + reply) becomes the history for the next turn
    chat_history_ids = model_wrapper.model.generate(bot_input_ids, **generation_kwargs)

    # print only the newly generated tokens, i.e. everything after the prompt
    reply_ids = chat_history_ids[0, bot_input_ids.shape[-1]:]
    print("RickBot: {}".format(tokenizer.decode(reply_ids, skip_special_tokens=True)))

>> User:What do you think of Elon Musk?
RickBot: I don't think he's stupid. I think he has a future inside Rick's head.
>> User:What else can you say about him?
RickBot: He's not stupid. He's a good kid. And, you know, he's not gonna let this happen again.


KeyboardInterrupt: Interrupted by user

# Responses to validation set

In [8]:
# Display the training frame currently held in the kernel.
# NOTE(review): the saved output below lists columns response / context /
# context/0, which are not the column names assigned where `df_train` is built
# earlier in this notebook - the output looks stale; re-run to refresh.
df_train

Unnamed: 0,response,context,context/0
0,"I got a surprise for you, Morty.","What, Rick? What’s going on?",Morty! You gotta come on. Jus'... you gotta co...
1,It's the middle of the night. What are you tal...,"I got a surprise for you, Morty.","What, Rick? What’s going on?"
2,"Come on, I got a surprise for you. Come on, h...",It's the middle of the night. What are you tal...,"I got a surprise for you, Morty."
3,Ow! Ow! You're tugging me too hard!,"Come on, I got a surprise for you. Come on, h...",It's the middle of the night. What are you tal...
4,"We gotta go, gotta get outta here, come on. Go...",Ow! Ow! You're tugging me too hard!,"Come on, I got a surprise for you. Come on, h..."
...,...,...,...
1517,Are you sure there's not just a picnic nearby.,I sense... insecurity.,"So, this is... Vindicators 3? And you guys did..."
1518,I guess he found his crowd. Pretty toothless s...,Are you sure there's not just a picnic nearby.,I sense... insecurity.
1519,"I hope you're happy with the adventure so far,...",I guess he found his crowd. Pretty toothless s...,Are you sure there's not just a picnic nearby.
1520,"We weren't here ""last time"", remember? They di...","I hope you're happy with the adventure so far,...",I guess he found his crowd. Pretty toothless s...


In [9]:
# Display the validation frame currently held in the kernel.
# NOTE(review): the saved output below lists columns response / context /
# context/0, which are not the column names assigned where `df_val` is built
# earlier in this notebook - the output looks stale; re-run to refresh.
df_val

Unnamed: 0,response,context,context/0
0,I... think the personality conflict might have...,"Don't worry, Morty, they love you. Superheroes...",This article says the reason we weren't involv...
1,"Jesus... How awesome is that? I mean, they wan...",I... think the personality conflict might have...,"Don't worry, Morty, they love you. Superheroes..."
2,"Rick, since it's my adventure and all, could y...","Jesus... How awesome is that? I mean, they wan...",I... think the personality conflict might have...
3,"Uh, the adventure is the favor, Morty. Me slee...","Rick, since it's my adventure and all, could y...","Jesus... How awesome is that? I mean, they wan..."
4,"Rick, this really bums me out. It-It's embarra...","Uh, the adventure is the favor, Morty. Me slee...","Rick, since it's my adventure and all, could y..."
...,...,...,...
374,That was amazing!,"Whoa!! Hahaha, yeah! Atlantis, baby!",Holy crap... Slick's wish came true.
375,Got some of that mermaid puss!,That was amazing!,"Whoa!! Hahaha, yeah! Atlantis, baby!"
376,I'm really hoping it wasn't a one-off thing an...,Got some of that mermaid puss!,That was amazing!
377,"Pssh! Not at all, Morty. That place will never...",I'm really hoping it wasn't a one-off thing an...,Got some of that mermaid puss!


In [72]:
# Number of independent samples to generate: the prompt is repeated this many
# times to form the batch. (The earlier inline comment said 10 while the code
# used 1; the count is now explicit.)
num_samples = 1
prompt = ["What do you think about Elon Musk?<|endoftext|>"] * num_samples
inputs = tokenizer(prompt, return_tensors="pt")["input_ids"].to(model_wrapper.device)
# Length in characters of the decoded prompt (not a token count).
prompt_length = len(tokenizer.decode(inputs[0]))

In [73]:
# Sampling settings for this one-off generation. `no_repeat_ngram_size` is
# deliberately left out here (the chat cells set it to 3).
sampling_kwargs = dict(
    max_length=200,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=True,
    top_k=100,
    top_p=0.7,
    temperature=0.8,
)

# Generate, move the result to the CPU, and decode every sequence in the
# batch (special tokens are kept in the decoded text).
outputs = model_wrapper.model.generate(inputs, **sampling_kwargs).cpu()
decoded_outputs = list(map(tokenizer.decode, outputs))

In [74]:
# Show the decoded sample(s) produced by the generation cell above.
decoded_outputs

["What do you think about Elon Musk?<|endoftext|>you, i see, i know. i mean, it's not like, but, uh, but, but, uh, but, uh, uh, but, uh, but, uh, but, uh, uh, uh, but, uh, uh, but, uh, uh, but, uh, uh, uh, uh, uh, uh, but, uh, uh, so, uh, uh, but, but, can you.<|endoftext|>"]

In [35]:
# Validation loader that yields one example at a time, in dataset order.
loader = DataLoader(dataset_val, batch_size = 1, shuffle = False, collate_fn = collate_fn)
num_batches_to_show = 100

responses_dict = {}
example_id = 0

# Take only the first (index, batch) pair for inspection.
i, batch = next(enumerate(loader))

In [15]:
# Dump the whole collated batch, then show the decoded text of the selected example.
print(batch)
selected_input_ids = batch['input_ids'][example_id]
tokenizer.decode(selected_input_ids)

{'input_ids': tensor([[ 9099,   470,  5490,   837,  5596,    88,   837,   484,  1842,   345,
           764, 40896,   761,   257,  3094, 45320, 49733, 45543,   284,  7621,
          1863,   290,  6324,   284,  2279,   588,   340,   338,  2000, 19280,
           764, 50256,    72,  2644,   892,   262,  8806,  5358,  1244,   423,
           587,  2644,   345,   764,     0]]), 'position_ids': tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43,  0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), 'target_ids': tensor([[ -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,  -100,
          -100,  -100,  -100,  -100,  -100,  -100, 

"don't worry, morty, they love you. superheroes need a wide eyed unremarkable to tag along and react to everything like it's mind blowing.<|endoftext|>i... think the personality conflict might have been... you.!"

In [None]:
# Sample a continuation for the current `batch`.
# This snippet was written as a method body (it referenced `self`); there is
# no `self` at notebook level, so it now uses the notebook's `model_wrapper`
# and `tokenizer` objects directly.

# NOTE(review): `max_new_tokens` was not defined anywhere in this notebook;
# the value below is a placeholder - adjust as needed.
max_new_tokens = 50

# each tensor to device
input_ids       = batch['input_ids'].to(model_wrapper.device)
# NOTE(review): assumes the collate function emits an 'attention_masks'
# entry - confirm against collate_fns.py.
attention_mask  = batch['attention_masks'].to(model_wrapper.device)

outputs = model_wrapper.model.generate(
    inputs = input_ids,
    max_new_tokens = max_new_tokens,
    pad_token_id = tokenizer.eos_token_id,
    do_sample = True,  # sample; otherwise repeated runs give identical text
    top_p = 0.95,      # nucleus sampling
    top_k = 0,         # deactivate top-k words sampling
    attention_mask = attention_mask  # do not pay attention to padding
).cpu()

decoded_outputs = [tokenizer.decode(output) for output in outputs]