In [1]:
import torch
from torch.utils.data import DataLoader

from transformers import GPT2Config, GPT2Tokenizer, BertModel, BertTokenizer, DistilBertModel, DistilBertTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup

from InductiveAttentionModels import GPT2InductiveAttentionHeadModel
from loss import SequenceCrossEntropyLoss

import time
import tqdm
from dataset import MovieRecDataset
from mese import UniversalCRSModel

from utilities import get_memory_free_MiB


In [2]:
bert_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
bert_model_recall = DistilBertModel.from_pretrained('distilbert-base-uncased')
bert_model_rerank = DistilBertModel.from_pretrained('distilbert-base-uncased')
gpt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2InductiveAttentionHeadModel.from_pretrained('gpt2')

REC_TOKEN = "[REC]"
REC_END_TOKEN = "[REC_END]"
SEP_TOKEN = "[SEP]"
PLACEHOLDER_TOKEN = "[MOVIE_ID]"
gpt_tokenizer.add_tokens([REC_TOKEN, REC_END_TOKEN, SEP_TOKEN, PLACEHOLDER_TOKEN])
gpt2_model.resize_token_embeddings(len(gpt_tokenizer)) 

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias']
- T

Embedding(50261, 768)

In [3]:
train_path = "data/processed/redial_full_train_placeholder"
test_path = "data/processed/redial_full_test_placeholder"
items_db_path = "data/processed/redial_full_movie_db_placeholder"

In [4]:
items_db = torch.load(items_db_path)

In [5]:
device = torch.device(0)
model = UniversalCRSModel(
    gpt2_model, 
    bert_model_recall, 
    bert_model_rerank, 
    gpt_tokenizer, 
    bert_tokenizer, 
    device, 
    items_db, 
    rec_token_str=REC_TOKEN, 
    rec_end_token_str=REC_END_TOKEN
)

model.to(device)

UniversalCRSModel(
  (language_model): GPT2InductiveAttentionHeadModel(
    (transformer): GPT2InductiveAttention(
      (wte): Embedding(50261, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0): BlockIA(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): AttentionIA(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (1): BlockIA(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): AttentionIA(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn

In [6]:
train_dataset = MovieRecDataset(torch.load(train_path)[:1], bert_tokenizer, gpt_tokenizer)
test_dataset = MovieRecDataset(torch.load(test_path)[:1], bert_tokenizer, gpt_tokenizer)

In [7]:
# parameters
batch_size = 1
num_epochs = 1
num_gradients_accumulation = 1
num_train_optimization_steps = len(train_dataset) * num_epochs // batch_size // num_gradients_accumulation

num_samples_recall_train = 100
num_samples_rerank_train = 150
rerank_encoder_chunk_size = int(num_samples_rerank_train / 15)
validation_recall_size = 500

temperature = 1.2

language_loss_train_coeff = 0.15
language_loss_train_coeff_beginnging_turn = 1.0
recall_loss_train_coeff = 0.8
rerank_loss_train_coeff = 1.0

# loss
criterion_language = SequenceCrossEntropyLoss()
criterion_recall = torch.nn.CrossEntropyLoss()
rerank_class_weights = torch.FloatTensor([1] * (num_samples_rerank_train-1) + [30]).to(model.device)
criterion_rerank_train = torch.nn.CrossEntropyLoss(weight=rerank_class_weights)

# optimizer and scheduler
param_optimizer = list(model.language_model.named_parameters()) + \
    list(model.recall_encoder.named_parameters()) + \
    list(model.item_encoder.named_parameters()) + \
    list(model.recall_lm_query_mapper.named_parameters()) + \
    list(model.recall_item_wte_mapper.named_parameters()) + \
    list(model.rerank_item_wte_mapper.named_parameters()) + \
    list(model.rerank_logits_mapper.named_parameters())

no_decay = ['bias', 'ln', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

optimizer = AdamW(optimizer_grouped_parameters, 
                  lr=3e-5,
                  eps=1e-06)

scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=len(train_dataset) // num_gradients_accumulation , num_training_steps = num_train_optimization_steps)

progress_bar = tqdm.notebook.tqdm
start = time.time()



In [8]:
train_dataloader = DataLoader(dataset=train_dataset, shuffle=False, batch_size=batch_size, collate_fn=train_dataset.collate)
test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size, collate_fn=test_dataset.collate)

In [9]:
from engine import Engine

In [10]:
engine = Engine(device,
                criterion_language,
                criterion_recall,
                criterion_rerank_train,
                language_loss_train_coeff,
                recall_loss_train_coeff,
                rerank_loss_train_coeff,
                num_samples_recall_train,
                num_samples_rerank_train,
                rerank_encoder_chunk_size,
                validation_recall_size,
                temperature)

In [11]:
output_file_path = "out/CRS_Train.txt"
model_saved_path = "runs/CRS_Redial_Train_Same_BERT_"

In [12]:
from trainer import Trainer



In [13]:
## Define Trainer
trainer = Trainer(
    model,
    engine,
    train_dataloader,
    test_dataloader,
    optimizer,
    scheduler,
    progress_bar
)

In [25]:
get_memory_free_MiB(0)

16708

In [26]:
trainer.train(
    num_epochs,
    num_gradients_accumulation,
    batch_size,
    output_file_path,
    model_saved_path
)

  0%|          | 0/1 [00:00<?, ?it/s]

UnboundLocalError: local variable 'update_count' referenced before assignment

In [15]:
valid_cnt / total_gen_cnt, response_with_items / total_gen_cnt

ZeroDivisionError: division by zero

In [17]:
torch.save(total_sentences_generated, '../human_eval/mese2.pt')

In [22]:
valid_cnt / total_gen_cnt, response_with_items / total_gen_cnt

(0.9173978440711957, 0.4402105790925044)

In [23]:
dist1, dist2, dist3, dist4 = distinct_metrics(total_sentences_generated)
bleu1, bleu2, bleu3, bleu4 = bleu_calc_all(total_sentences_original, total_sentences_generated)
print(dist1, dist2, dist3, dist4)
print(bleu1, bleu2, bleu3, bleu4)

0.2602155928804212 0.7049385810980195 1.0142892955627978 1.156054148909501
0.342754064465557 0.2512299152344354 0.1892696463143165 0.14407982061678162
