In [4]:
import torch
from torch.utils.data import DataLoader,Dataset

from transformers import GPT2Config, GPT2Tokenizer, BertModel, BertTokenizer, DistilBertModel, DistilBertTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup

from InductiveAttentionModels import GPT2InductiveAttentionHeadModel
from loss import SequenceCrossEntropyLoss

from trainer import Trainer
import time
import tqdm
from mese import UniversalCRSModel
from engine import Engine
import numpy as np

In [17]:
class MovieRecDataset(Dataset):
    def __init__(self, data, bert_tok, gpt2_tok):
        self.data = data
        self.bert_tok = bert_tok
        self.gpt2_tok = gpt2_tok
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        dialogue = self.data[index]
        
        utterances = [utt for utt,_ in dialogue]
        gt_indices = [gt_ind for _,gt_ind in dialogue]
        dialogue_tokens = self.gpt2_tok(utterances, padding='max_length', truncation=True, max_length=64, return_tensors="pt")['input_ids']
            
        role_ids = None
        previous_role_ids = None
        if role_ids == None:
            role_ids = [ 0 if item[0] == 'B' else 1 for item, _ in dialogue]
            previous_role_ids = role_ids
        else:
            role_ids = [ 0 if item[0] == 'B' else 1 for item, _ in dialogue]
            if not np.array_equal(role_ids, previous_role_ids):
                raise Exception("Role ids dont match between languages")
            previous_role_ids = role_ids
        
        return role_ids, dialogue_tokens, gt_indices
    
    def collate(self, unpacked_data):
        return unpacked_data

In [18]:
bert_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
bert_model_recall = DistilBertModel.from_pretrained('distilbert-base-uncased')
bert_model_rerank = DistilBertModel.from_pretrained('distilbert-base-uncased')
gpt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2InductiveAttentionHeadModel.from_pretrained('gpt2')

REC_TOKEN = "[REC]"
REC_END_TOKEN = "[REC_END]"
SEP_TOKEN = "[SEP]"
PLACEHOLDER_TOKEN = "[MOVIE_ID]"
PAD_TOKEN = "[PAD]"
gpt_tokenizer.add_tokens([REC_TOKEN, REC_END_TOKEN, SEP_TOKEN, PLACEHOLDER_TOKEN])
gpt_tokenizer.pad_token = PAD_TOKEN
gpt2_model.resize_token_embeddings(len(gpt_tokenizer)) 

train_path = "data/processed/durecdial2_full_train_placeholder"
test_path = "data/processed/durecdial2_full_dev_placeholder"
items_db_path = "data/processed/durecdial2_full_movie_db_placeholder"
items_db = torch.load(items_db_path)

train_dataset = MovieRecDataset(torch.load(train_path), bert_tokenizer, gpt_tokenizer)
test_dataset = MovieRecDataset(torch.load(test_path), bert_tokenizer, gpt_tokenizer)


Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight']
- T

In [23]:
role_ids, dialogue_ids, labels = train_dataset[1]

In [25]:
train_dataloader = DataLoader(dataset=train_dataset, shuffle=False, batch_size=2)

In [26]:
for batch in train_dataloader:
    print(batch)

RuntimeError: each element in list of batch should be of equal size