In [None]:
# Mount Google Drive at /content/drive so the project files stored there are
# reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install transformers

In [None]:
import os
# Use the absolute path under the Drive mount point (/content/drive) so this
# cell stays valid when it is re-run after the working directory has changed.
os.chdir("/content/drive/My Drive/target-guided-sat-chatbot")

In [None]:
# List the current working directory to check that the project folder is in place.
!ls

In [None]:
import datetime
import json
import pickle
import random
import time
from itertools import chain

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler
from tqdm.auto import tqdm

from transformers import GPT2Tokenizer, GPT2LMHeadModel
from transformers import AdamW, get_linear_schedule_with_warmup, get_polynomial_decay_schedule_with_warmup

In [None]:
# Name of the pretrained checkpoint that both the tokenizer and the model are loaded from.
mname = "gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(mname)
model = GPT2LMHeadModel.from_pretrained(mname)

In [None]:
# Extra vocabulary entries: a beginning-of-sequence marker plus three markers whose
# order matters -- later cells index additional_special_tokens[0] as the keyword
# marker and [1] / [2] as the two speakers.
special_tokens = dict(
    bos_token="<|startoftext|>",
    additional_special_tokens=["<|keyword|>", "<|speaker1|>", "<|speaker2|>"],
)
tokenizer.add_special_tokens(special_tokens)

# Output recorded by the author:
#   tokens --> ['<|startoftext|>', '<|endoftext|>', '<|keyword|>', '<|speaker1|>', '<|speaker2|>']
#   ids    --> [50261, 50256, 50262, 50263, 50264]
print(tokenizer.all_special_tokens)
print(tokenizer.all_special_ids)

# Resize the embedding matrix to the enlarged tokenizer so the new ids have rows.
model.resize_token_embeddings(len(tokenizer))

Load raw data

In [None]:
# The file is read as JSON Lines: every non-blank line is parsed as one JSON value.
# Blank lines (e.g. a trailing newline at the end of the file) are skipped instead of
# crashing json.loads, and the encoding is explicit rather than platform-dependent.
with open('data/train/concepts_nv.json', encoding='utf-8') as f:
  train_data_json = [json.loads(row) for row in f if row.strip()]
print(f"train length: {len(train_data_json)}")

In [None]:
# The file is read as JSON Lines: every non-blank line is parsed as one JSON value.
# Blank lines (e.g. a trailing newline at the end of the file) are skipped instead of
# crashing json.loads, and the encoding is explicit rather than platform-dependent.
with open('data/dev/concepts_nv.json', encoding='utf-8') as f:
  validation_data_json = [json.loads(row) for row in f if row.strip()]
print(f"validation length: {len(validation_data_json)}")

Dataset Class

In [None]:
class KeywordGenerationDataset(Dataset):
  """Keyword-conditioned response-generation examples built from dialogues.

  For every utterance after the first in each dialogue, one example is created:

      input_ids = <kw>k1<kw>k2... + [bos] + (speaker token + utterance) per context turn + [eos]

  The last context turn is the target response and is always tagged with the
  speaker2 token. `token_type_ids` carry a speaker id for every position, and
  `labels` are -100 everywhere except the target response tokens and the final eos.
  """

  def __init__(self, data_json, tokenizer, max_input_length=1024, max_context_length=5):
    """Tokenise `data_json` into parallel lists of input ids, token types and labels.

    Args:
      data_json: iterable of dicts with a 'dialog' list of utterance strings and a
        'concepts' list holding one keyword list per utterance.
      tokenizer: tokenizer exposing `bos_token_id`, `eos_token_id`, `encode`, and
        `additional_special_tokens(_ids)` ordered as [keyword, speaker1, speaker2].
      max_input_length: examples whose token count exceeds this are dropped.
      max_context_length: maximum number of utterances per example, counting the
        target response.
    """
    self.input_ids = []
    self.token_type_ids = []
    self.labels = []

    bos_id = tokenizer.bos_token_id
    eos_id = tokenizer.eos_token_id
    speaker1_id = tokenizer.additional_special_tokens_ids[1]
    speaker2_id = tokenizer.additional_special_tokens_ids[2]
    kw_token = tokenizer.additional_special_tokens[0]

    for data in tqdm(data_json):
      dialog = data['dialog']
      concepts = data['concepts']

      for idx in range(1, len(dialog)):
        start_idx = max(0, idx-max_context_length+1)
        contexts = dialog[start_idx:idx+1]
        # Alternate the speakers so that the final turn (the target) is always speaker2.
        if len(contexts) % 2 == 0:
          start_speaker_id = speaker1_id
          next_speaker_id = speaker2_id
        else:
          start_speaker_id = speaker2_id
          next_speaker_id = speaker1_id
        encoded_contexts = [[start_speaker_id] + tokenizer.encode(c) if i % 2 == 0 else [next_speaker_id] + tokenizer.encode(c) for i, c in enumerate(contexts)]
        assert encoded_contexts[-1][0] == speaker2_id

        # Shuffle a copy: shuffling concepts[idx] directly would silently reorder the
        # caller's data_json as a side effect of building the dataset.
        keywords = list(concepts[idx])
        random.shuffle(keywords)
        keywords_with_special_tokens = kw_token + kw_token.join(keywords)
        encoded_keywords = tokenizer.encode(keywords_with_special_tokens)

        input_ids = encoded_keywords + [bos_id] + list(chain.from_iterable(encoded_contexts)) + [eos_id]
        if len(input_ids) > max_input_length:
          continue

        token_type_ids_keywords = [speaker2_id] * len(encoded_keywords)
        token_type_ids_contexts = [[start_speaker_id] * len(c) if i % 2 == 0 else [next_speaker_id] * len(c) for i, c in enumerate(encoded_contexts)]
        assert token_type_ids_contexts[-1][0] == speaker2_id
        token_type_ids = token_type_ids_keywords + [speaker1_id] + list(chain.from_iterable(token_type_ids_contexts)) + [speaker2_id]
        assert len(input_ids) == len(token_type_ids)

        # Mask keywords, bos, all earlier turns and the target's speaker token; only the
        # target response tokens and the closing eos contribute to the loss.
        labels = [-100] * (len(encoded_keywords) + 1 + sum([len(c) for c in encoded_contexts[:-1]]) + 1) + encoded_contexts[-1][1:] + [eos_id]
        assert len(input_ids) == len(labels)

        self.input_ids.append(input_ids)
        self.token_type_ids.append(token_type_ids)
        self.labels.append(labels)

  def __len__(self):
    """Number of examples that survived the length filter."""
    return len(self.input_ids)

  def __getitem__(self, idx):
    """Return example `idx` as a dict of three equal-length Python int lists."""
    return {'input_ids' : self.input_ids[idx], 'token_type_ids' : self.token_type_ids[idx], 'labels' : self.labels[idx]}



In [None]:
def collate_fn(batch):
  """Right-pad a list of dataset examples into rectangular LongTensors.

  `input_ids` and `token_type_ids` are padded with the tokenizer's eos id; `labels`
  are padded with -100, the same value the dataset uses for positions without a target.
  Returns a dict with the same three keys, each of shape (batch, longest_sequence).
  """
  eos_id = tokenizer.eos_token_id
  pad_value_for = {'input_ids': eos_id, 'token_type_ids': eos_id, 'labels': -100}

  padded = {}
  for field, pad_value in pad_value_for.items():
    sequences = [torch.LongTensor(example[field]) for example in batch]
    padded[field] = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=pad_value)
  return padded

In [None]:
# Build the training and validation examples from the raw JSON loaded above.
# The next code cell can load pickled datasets instead; running it rebinds these
# two variables.
train_dataset = KeywordGenerationDataset(train_data_json, tokenizer)
validation_dataset = KeywordGenerationDataset(validation_data_json, tokenizer)
print(len(train_dataset))

Use pickled dataset

In [None]:
# Load previously pickled datasets; this rebinds train_dataset / validation_dataset.
# NOTE(review): pickle.load executes code embedded in the file -- only load pickles
# that you created yourself.
# NOTE(review): presumably these pickles hold KeywordGenerationDataset instances, in
# which case the class cell must have been run before this one -- confirm.
with open("keyword_generation_dataset/generation_train_dataset_gpt2.pickle", "rb") as f:
    train_dataset = pickle.load(f)
with open("keyword_generation_dataset/generation_dev_dataset_gpt2.pickle", "rb") as f:
    validation_dataset = pickle.load(f)

In [None]:
batch_size = 8

# Training batches are drawn in random order; validation keeps the dataset order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=RandomSampler(train_dataset),
    collate_fn=collate_fn,
)
validation_dataloader = DataLoader(
    validation_dataset,
    batch_size=batch_size,
    sampler=SequentialSampler(validation_dataset),
    collate_fn=collate_fn,
)

In [None]:
# Hand-tuned training settings.
num_epochs = 5
learning_rate = 2e-5  # alternative noted by the author: 5e-4
warmup_ratio = 0.1

# The scheduler is stepped once per batch, so the schedule length is
# [number of batches] x [number of epochs].
num_training_steps = len(train_dataloader) * num_epochs
warmup_steps = int(warmup_ratio * num_training_steps)  # alternative noted by the author: 1e2

# Print a progress line every `sample_step` training batches.
sample_step = 100

# AdamW here is the class imported from the huggingface library, not torch.optim.
optimizer = AdamW(model.parameters(), lr=learning_rate)

# Warm-up followed by polynomial (power=2) decay. A linear alternative with the same
# arguments would be:
#   get_linear_schedule_with_warmup(optimizer,
#                                   num_warmup_steps=warmup_steps,
#                                   num_training_steps=num_training_steps)
lr_scheduler = get_polynomial_decay_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_training_steps,
    power=2,
)

In [None]:
def format_time(elapsed):
    """Render a duration in seconds as a `datetime.timedelta` string (e.g. '1:01:01').

    The value is rounded to the nearest whole second before formatting.
    """
    whole_seconds = int(round(elapsed))
    duration = datetime.timedelta(seconds=whole_seconds)
    return str(duration)

In [None]:
from tqdm.auto import tqdm

# Fine-tune for `num_epochs` epochs. After every epoch the model is scored on the
# validation set, and a checkpoint (model + optimizer + scheduler state) is written
# whenever the average validation loss improves.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using Cuda?: {torch.cuda.is_available()}")
model.to(device)

total_t0 = time.time()
progress_bar = tqdm(range(num_training_steps))
training_stats = []
# +inf instead of an arbitrary constant: the first epoch always counts as an improvement.
best_loss = float("inf")

for epoch in range(num_epochs):
  print("")
  print('======== Epoch {:} / {:} ========'.format(epoch + 1, num_epochs))

  # ========================================
  #               Training
  # ========================================
  print('Training...')
  t0 = time.time()
  total_train_loss = 0

  model.train()

  for step, batch in enumerate(train_dataloader):
    b_input_ids = batch['input_ids'].to(device)
    b_token_type_ids = batch['token_type_ids'].to(device)
    b_labels = batch['labels'].to(device)
    outputs = model(b_input_ids,
                    token_type_ids=b_token_type_ids,
                    labels=b_labels
                    )
    loss = outputs.loss
    total_train_loss += loss.item()

    # Report progress every `sample_step` batches.
    if step % sample_step == 0 and not step == 0:
      elapsed = format_time(time.time() - t0)
      print('  Batch {:>5,}  of  {:>5,}. Loss: {:>5,}.   Elapsed: {:}.'.format(step, len(train_dataloader), loss, elapsed))

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()
    # The schedule is defined per optimizer step, so it advances once per batch.
    lr_scheduler.step()

    progress_bar.update(1)
  # Calculate the average loss over all of the batches.
  avg_train_loss = total_train_loss / len(train_dataloader)

  # Measure how long this epoch took.
  training_time = format_time(time.time() - t0)
  print("")
  print("  Average training loss: {0:.2f}".format(avg_train_loss))
  print("  Training epoch took: {:}".format(training_time))

  # ========================================
  #               Validation
  # ========================================

  print("")
  print("Running Validation...")

  t0 = time.time()

  model.eval()

  total_eval_loss = 0

  # Evaluate data for one epoch. No gradients are needed here, so disable autograd
  # to avoid building (and holding in memory) a graph for every validation batch.
  with torch.no_grad():
    for batch in validation_dataloader:

      b_input_ids = batch['input_ids'].to(device)
      b_token_type_ids = batch['token_type_ids'].to(device)
      b_labels = batch['labels'].to(device)
      outputs = model(b_input_ids,
                      token_type_ids=b_token_type_ids,
                      labels=b_labels
                      )
      loss = outputs.loss
      total_eval_loss += loss.item()

  avg_val_loss = total_eval_loss / len(validation_dataloader)

  validation_time = format_time(time.time() - t0)

  print("  Validation Loss: {0:.2f}".format(avg_val_loss))
  print("  Validation took: {:}".format(validation_time))

  # Record all statistics from this epoch.
  training_stats.append(
      {
          'epoch': epoch + 1,
          'Training Loss': avg_train_loss,
          'Valid. Loss': avg_val_loss,
          'Training Time': training_time,
          'Validation Time': validation_time
      }
  )

  # Keep a checkpoint for every epoch that improves on the best validation loss so far.
  if avg_val_loss < best_loss:
    best_loss = avg_val_loss
    state_dict = {
                      'model_state_dict': model.state_dict(),
                      'optim_state_dict': optimizer.state_dict(),
                      'sched_state_dict': lr_scheduler.state_dict(),
                      'loss': best_loss,
                      'epoch': epoch + 1
                  }

    torch.save(state_dict, f"best_ckpt_epoch={epoch+1}_valid_loss={round(best_loss, 4)}.ckpt")

print("")
print("Training complete!")
print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))

In [None]:
# Manually write a checkpoint of the current model / optimizer / scheduler state.
# NOTE(review): this rebinds best_loss to the most recent validation loss, which is
# not necessarily the lowest one seen, and both the 'epoch' field and the file name
# are hard-coded to 5 -- confirm this matches the state being saved.
best_loss = avg_val_loss
state_dict = {
                  'model_state_dict': model.state_dict(),
                  'optim_state_dict': optimizer.state_dict(),
                  'sched_state_dict': lr_scheduler.state_dict(),
                  'loss': best_loss,
                  'epoch': 5
              }

torch.save(state_dict, f"best_ckpt_epoch=5_valid_loss={round(best_loss, 4)}.ckpt")


In [None]:

# Continue training for epochs 6-10 with the existing model, optimizer and scheduler,
# appending to the existing `training_stats`. No checkpoints are written here.
# NOTE(review): `lr_scheduler` was built for `num_training_steps` steps (the first
# `num_epochs` epochs). If the first training cell already ran to completion, the
# schedule is exhausted and the learning rate stays at its final value for this whole
# loop -- confirm that is intended, or build a new scheduler first.
total_t0 = time.time()
progress_bar = tqdm(range(num_training_steps))


for epoch in range(5, 10):
  print("")
  print('======== Epoch {:} / {:} ========'.format(epoch + 1, 10))

  # ========================================
  #               Training
  # ========================================
  print('Training...')
  t0 = time.time()
  total_train_loss = 0

  model.train()

  for step, batch in enumerate(train_dataloader):
    b_input_ids = batch['input_ids'].to(device)
    b_token_type_ids = batch['token_type_ids'].to(device)
    b_labels = batch['labels'].to(device)
    outputs = model(b_input_ids,
                    token_type_ids=b_token_type_ids,
                    labels=b_labels
                    )
    loss = outputs.loss
    total_train_loss += loss.item()

    # Report progress every `sample_step` batches.
    if step % sample_step == 0 and not step == 0:
      elapsed = format_time(time.time() - t0)
      print('  Batch {:>5,}  of  {:>5,}. Loss: {:>5,}.   Elapsed: {:}.'.format(step, len(train_dataloader), loss, elapsed))

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()
    lr_scheduler.step()

    progress_bar.update(1)
  # Calculate the average loss over all of the batches.
  avg_train_loss = total_train_loss / len(train_dataloader)

  # Measure how long this epoch took.
  training_time = format_time(time.time() - t0)
  print("")
  print("  Average training loss: {0:.2f}".format(avg_train_loss))
  print("  Training epoch took: {:}".format(training_time))

  # ========================================
  #               Validation
  # ========================================

  print("")
  print("Running Validation...")

  t0 = time.time()

  model.eval()

  total_eval_loss = 0

  # Evaluate data for one epoch. No gradients are needed here, so disable autograd
  # to avoid building (and holding in memory) a graph for every validation batch.
  with torch.no_grad():
    for batch in validation_dataloader:

      b_input_ids = batch['input_ids'].to(device)
      b_token_type_ids = batch['token_type_ids'].to(device)
      b_labels = batch['labels'].to(device)
      outputs = model(b_input_ids,
                      token_type_ids=b_token_type_ids,
                      labels=b_labels
                      )
      loss = outputs.loss
      total_eval_loss += loss.item()

  avg_val_loss = total_eval_loss / len(validation_dataloader)

  validation_time = format_time(time.time() - t0)

  print("  Validation Loss: {0:.2f}".format(avg_val_loss))
  print("  Validation took: {:}".format(validation_time))

  # Record all statistics from this epoch.
  training_stats.append(
      {
          'epoch': epoch + 1,
          'Training Loss': avg_train_loss,
          'Valid. Loss': avg_val_loss,
          'Training Time': training_time,
          'Validation Time': validation_time
      }
  )

print("")
print("Training complete!")
print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))

In [None]:
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()

output_dir = './model_gpt2_5epochs_more_data/'

# Create the output directory if needed (no error when it already exists).
os.makedirs(output_dir, exist_ok=True)

print("Saving model to %s" % output_dir)

# Save a trained model, configuration and tokenizer using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
model_to_save = model.module if hasattr(model, 'module') else model  # Unwrap a distributed/parallel wrapper if present
model_to_save.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

# Good practice: save your training arguments together with the trained model
# torch.save(args, os.path.join(output_dir, 'training_args.bin'))


In [None]:
# Rebuild `training_stats` from hard-coded per-epoch average losses (presumably
# transcribed from a previous run's log -- confirm). Running this cell replaces any
# statistics collected by the training cells.
# The original cell first assigned another pair of lists that was immediately
# overwritten and never used; they are kept here for reference only:
#   train: [3.15, 2.31, 2.12, 2.04, 2.01]   valid: [2.32, 2.10, 2.02, 1.99, 1.99]
training_stats = []
avg_train_loss = [1.65, 1.09, 0.78, 0.63, 0.58]
avg_val_loss = [1.27, 0.98, 0.82, 0.75, 0.74]
for epoch, (train_loss, val_loss) in enumerate(zip(avg_train_loss, avg_val_loss), start=1):
  training_stats.append(
        {
            'epoch': epoch,
            'Training Loss': train_loss,
            'Valid. Loss': val_loss
        }
    )

In [None]:
import pandas as pd

# Optional display tweaks left disabled by the author:
#   pd.set_option('precision', 2)   # two-decimal floats
#   df.style.set_table_styles([dict(selector="th", props=[('max-width', '70px')])])  # wrap headers

# Tabulate the per-epoch statistics, using the epoch number as the row index.
df_stats = pd.DataFrame(data=training_stats).set_index('epoch')

# Display the table.
df_stats

In [None]:
# Persist the per-epoch statistics list with pickle.
# NOTE(review): the file name says "10epochs"; if `training_stats` was rebuilt by the
# hard-coded 5-epoch cell above, the contents cover only 5 epochs -- confirm.
with open("model_gpt2_10epochs_stats.pickle", "wb") as f:
  pickle.dump(training_stats, f)

In [None]:
# Optional seaborn styling left disabled by the author:
#   sns.set(style='darkgrid')
#   sns.set(font_scale=1.5)
import matplotlib.pyplot as plt

plt.rcParams["figure.figsize"] = (6,3)

# Plot the learning curve.
plt.plot(df_stats['Training Loss'], 'b-o', label="Training")
plt.plot(df_stats['Valid. Loss'], 'g-o', label="Validation")

# Label the plot.
plt.title("Training & Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
# One tick per recorded epoch (taken from the table index) instead of a fixed 1..5,
# so the axis stays correct when the statistics cover a different number of epochs.
plt.xticks(list(df_stats.index))

plt.show()

In [None]:
def nucleus_sampling(input_ids, token_type_ids, input_len):
  """Unfinished draft of the sampling routine.

  NOTE(review): as written this only runs the model and takes a softmax for each
  position; nothing is sampled, `output_ids` is never filled, the inputs are never
  extended and the function returns None. A complete function with the same name is
  defined in a later cell of this notebook and replaces this one once that cell runs.
  """
  output_ids = []
  max_len = input_len + 20
  for pos in range(input_len, max_len):
    # Logits at index pos-1 of the sequence dimension.
    # NOTE(review): input_ids never grows here, so from the second iteration on this
    # index presumably falls outside the model output -- confirm.
    output_logits = model(input_ids=input_ids, token_type_ids=token_type_ids)[0][:, pos-1]
    output = F.softmax(output_logits, dim=-1)


In [None]:
# Reload a model and tokenizer by name; this rebinds `model` and `tokenizer`.
# NOTE(review): the save cell in this notebook writes to './model_gpt2_5epochs_more_data/',
# whereas this loads "model_gpt2_5epochs" -- confirm this is the intended checkpoint.
from transformers import GPT2Tokenizer, GPT2LMHeadModel
mname = "model_gpt2_5epochs"
model = GPT2LMHeadModel.from_pretrained(mname)
tokenizer = GPT2Tokenizer.from_pretrained(mname)

In [None]:
def nucleus_sampling(input_ids, token_type_ids, input_len, max_new_tokens=20, top_p=0.9):
  """Generate a continuation with top-p (nucleus) sampling.

  Uses the notebook-level `model` and `tokenizer`.

  Args:
    input_ids: LongTensor of shape (1, input_len) holding the prompt.
    token_type_ids: LongTensor with the same shape as `input_ids`.
    input_len: number of prompt tokens; the logits at index `input_len - 1` predict
      the first generated token.
    max_new_tokens: upper bound on the number of generated tokens.
    top_p: cumulative-probability threshold of the nucleus.

  Returns:
    List of generated token ids. Generation stops early at the eos token, which is
    included in the list.
  """
  output_ids = []
  speaker2_id = tokenizer.additional_special_tokens_ids[2]

  # Inference only: skip autograd so no graph is built for each forward pass.
  with torch.no_grad():
    for pos in range(input_len, input_len + max_new_tokens):
      logits = model(input_ids=input_ids, token_type_ids=token_type_ids)[0][:, pos-1]  # (1, V)
      output = F.softmax(logits, dim=-1)  # (1, V)

      sorted_probs, sorted_idxs = torch.sort(output, descending=True)
      cumsum_probs = torch.cumsum(sorted_probs, dim=-1)  # (1, V)
      idx_remove = cumsum_probs > top_p
      # Shift the mask right by one so the token that crosses the threshold is kept,
      # and always keep the most likely token.
      idx_remove[:, 1:] = idx_remove[:, :-1].clone()
      idx_remove[:, 0] = False
      sorted_probs[idx_remove] = 0.0
      sorted_probs /= torch.sum(sorted_probs, dim=-1, keepdim=True)  # (1, V)

      # Scatter the filtered probabilities back to vocabulary order. zeros_like keeps
      # the device and dtype of the model output, so this also works off-CPU.
      probs = torch.zeros_like(output).scatter_(-1, sorted_idxs, sorted_probs)  # (1, V)
      idx = torch.multinomial(probs, 1)  # (1, 1)

      idx_item = idx.item()
      output_ids.append(idx_item)

      if idx_item == tokenizer.eos_token_id:
        break

      input_ids = torch.cat((input_ids, idx), dim=-1)
      # Every generated token is typed as speaker2, matching the training data.
      next_type_id = torch.full((1, 1), speaker2_id, dtype=token_type_ids.dtype, device=token_type_ids.device)
      token_type_ids = torch.cat((token_type_ids, next_type_id), dim=-1)
      assert input_ids.shape == token_type_ids.shape

  return output_ids

In [None]:
# Chat: interactive demo. The bot opens the conversation; on each user turn the model
# is prompted with that turn's keywords plus the dialogue so far and asked to produce
# the next speaker2 utterance.
bos = tokenizer.bos_token
eos = tokenizer.eos_token
kw_token = tokenizer.additional_special_tokens[0]
s1 = tokenizer.additional_special_tokens[1]
s2 = tokenizer.additional_special_tokens[2]
speaker1_id = tokenizer.additional_special_tokens_ids[1]
speaker2_id = tokenizer.additional_special_tokens_ids[2]

# One keyword list per bot turn.
keywords = [["cat", "park", "lunch", "piano", "sad"],["dog"],["love"],["play","basketball"],["game"]]
conversation_history = ["What did you do today?"]
print("Bot: What did you do today?")
input_ids_history = []
model.to('cpu')
# Iterate over the keyword plan: the chat ends once every keyword list has been used
# instead of indexing past the end. The per-turn list gets its own name so the plan
# itself is not overwritten by the joined prompt string after the first turn.
for i, turn_keywords in enumerate(keywords):
  user_input = input(">> User: ")
  input_ids_history.append(tokenizer.encode(s1 + user_input, add_special_tokens=False))

  keyword_prompt = kw_token + kw_token.join(turn_keywords)
  encoded_kw = tokenizer.encode(keyword_prompt, add_special_tokens=False)

  # Same layout as the training examples: keywords, bos, the turns so far, then the
  # speaker2 token that starts the bot's reply.
  input_ids = encoded_kw + [tokenizer.bos_token_id] + list(chain.from_iterable(input_ids_history)) + [speaker2_id]

  # Token types follow the training data: speaker ids (not 0/1) per position, with the
  # keywords typed as speaker2 and bos as speaker1. History alternates user (speaker1),
  # bot (speaker2), starting with the user.
  history_type_ids = [[speaker1_id if h % 2 == 0 else speaker2_id] * len(hist) for h, hist in enumerate(input_ids_history)]
  token_type_ids = [speaker2_id] * len(encoded_kw) + [speaker1_id] + list(chain.from_iterable(history_type_ids)) + [speaker2_id]
  input_len = len(input_ids)

  assert len(input_ids) == len(token_type_ids)

  input_ids = torch.LongTensor(input_ids).unsqueeze(0)
  token_type_ids = torch.LongTensor(token_type_ids).unsqueeze(0)

  output_ids = nucleus_sampling(input_ids, token_type_ids, input_len)

  res = tokenizer.decode(output_ids, skip_special_tokens=True)

  print(f"Bot: {res}")
  input_ids_history.append(tokenizer.encode(s2 + res, add_special_tokens=False))