In [1]:
import torch
from torch.utils.data import DataLoader, Dataset

from transformers import GPT2Config, GPT2Tokenizer, BertModel, BertTokenizer, DistilBertModel, DistilBertTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup

from utils.InductiveAttentionModels import GPT2InductiveAttentionHeadModel
from utils.SequenceCrossEntropyLoss import SequenceCrossEntropyLoss

import numpy as np
import math
import random
import time
import tqdm
import os
import string

# from nltk.translate import bleu_score
# from nltk.translate.bleu_score import sentence_bleu

from annoy import AnnoyIndex

In [2]:
bert_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
bert_model_recall = DistilBertModel.from_pretrained('distilbert-base-uncased')
bert_model_rerank = DistilBertModel.from_pretrained('distilbert-base-uncased')
gpt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2InductiveAttentionHeadModel.from_pretrained('gpt2')

# REC_TOKEN = "R"
# REC_END_TOKEN = "E"
REC_TOKEN = "[REC]"
REC_END_TOKEN = "[REC_END]"
SEP_TOKEN = "[SEP]"
PLACEHOLDER_TOKEN = "[MOVIE_ID]"
gpt_tokenizer.add_tokens([REC_TOKEN, REC_END_TOKEN, SEP_TOKEN, PLACEHOLDER_TOKEN])
gpt2_model.resize_token_embeddings(len(gpt_tokenizer)) 
original_token_emb_size = gpt2_model.get_input_embeddings().weight.shape[0]

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.weight', 'vocab_layer_norm.weight']
- T

In [3]:
class MovieRecDataset(Dataset):
    def __init__(self, data, bert_tok, gpt2_tok):
        self.data = data
        self.bert_tok = bert_tok
        self.gpt2_tok = gpt2_tok
        self.turn_ending = torch.tensor([[628, 198]]) # end of turn, '\n\n\n'
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        dialogue = self.data[index]
        
        dialogue_tokens = []
        
        for utterance, gt_ind in dialogue:
            utt_tokens = self.gpt2_tok(utterance, return_tensors="pt")['input_ids']
            dialogue_tokens.append( ( torch.cat( (utt_tokens, self.turn_ending), dim=1), gt_ind) )
            
        role_ids = None
        previous_role_ids = None
        if role_ids == None:
            role_ids = [ 0 if item[0] == 'B' else 1 for item, _ in dialogue]
            previous_role_ids = role_ids
        else:
            role_ids = [ 0 if item[0] == 'B' else 1 for item, _ in dialogue]
            if not np.array_equal(role_ids, previous_role_ids):
                raise Exception("Role ids dont match between languages")
            previous_role_ids = role_ids
        
        return role_ids, dialogue_tokens
    
    def collate(self, unpacked_data):
        return unpacked_data
    

In [4]:
train_path = "/local-scratch1/data/by2299/redial_full_train_placeholder"
test_path = "/local-scratch1/data/by2299/redial_full_test_placeholder"
items_db_path = "/local-scratch1/data/by2299/redial_full_movie_db_placeholder"

In [5]:
train_dataset = MovieRecDataset(torch.load(train_path), bert_tokenizer, gpt_tokenizer)
test_dataset = MovieRecDataset(torch.load(test_path), bert_tokenizer, gpt_tokenizer)
train_dataloader = DataLoader(dataset=train_dataset, shuffle=False, batch_size=1, collate_fn=train_dataset.collate)
test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=1, collate_fn=test_dataset.collate)


In [6]:
items_db = torch.load(items_db_path)

def sample_ids_from_db(item_db,
                       gt_id, # ground truth id
                       num_samples, # num samples to return
                       include_gt # if we want gt_id to be included
                      ):
    ids_2_sample_from = list(item_db.keys())
    ids_2_sample_from.remove(gt_id)
    if include_gt:
        results = random.sample(ids_2_sample_from, num_samples-1)
        results.append(gt_id)
    else:
        results = random.sample(ids_2_sample_from, num_samples)
    return results

In [7]:
class UniversalCRSModel(torch.nn.Module):
    def __init__(self, 
                 language_model, # backbone of Pretrained LM such as GPT2
                 encoder, # backbone of item encoder such as bert
                 recall_encoder,
                 lm_tokenizer, # language model tokenizer
                 encoder_tokenizer, # item encoder tokenizer
                 device, # Cuda device
                 items_db, # {id:info}, information of all items to be recommended
                 annoy_base_recall=None, # annoy index base of encoded recall embeddings of items
                 annoy_base_rerank=None, # annoy index base of encoded rerank embeddings of items, for validation and inference only
                 recall_item_dim=768, # dimension of each item to be stored in annoy base
                 lm_trim_offset=100, # offset to trim language model wte inputs length = (1024-lm_trim_offset)
                 rec_token_str="[REC]", # special token indicating recommendation and used for recall
                 rec_end_token_str="[REC_END]", # special token indicating recommendation ended, conditional generation starts
                 sep_token_str="[SEP]",
                 placeholder_token_str="[MOVIE_ID]"
                ):
        super(UniversalCRSModel, self).__init__()
        
        #models and tokenizers
        self.language_model = language_model
        self.item_encoder = encoder
        self.recall_encoder = recall_encoder
        self.lm_tokenizer = lm_tokenizer
        self.encoder_tokenizer = encoder_tokenizer
        self.device = device
        
        # item db and annoy index base
        self.items_db = items_db
        self.annoy_base_recall = annoy_base_recall
        self.annoy_base_rerank = annoy_base_rerank
        
        # hyperparameters
        self.recall_item_dim = recall_item_dim
        self.lm_trim_offset = lm_trim_offset
        
        #constants
        self.REC_TOKEN_STR = rec_token_str
        self.REC_END_TOKEN_STR = rec_end_token_str
        self.SEP_TOKEN_STR = sep_token_str
        self.PLACEHOLDER_TOKEN_STR = placeholder_token_str
        
        # map language model hidden states to a vector to query annoy-item-base for recall
        self.recall_lm_query_mapper = torch.nn.Linear(self.language_model.config.n_embd, self.recall_item_dim) # default [768,768]
        # map output of self.item_encoder to vectors to be stored in annoy-item-base 
        self.recall_item_wte_mapper = torch.nn.Linear(self.recall_encoder.config.hidden_size, self.recall_item_dim) # default [768,768]
        # map output of self.item_encoder to a wte of self.language_model
        self.rerank_item_wte_mapper = torch.nn.Linear(self.item_encoder.config.hidden_size, self.language_model.config.n_embd) # default [768,768]
        # map language model hidden states of item wte to a one digit logit for softmax computation
        self.rerank_logits_mapper = torch.nn.Linear(self.language_model.config.n_embd, 1) # default [768,1]
    
    def get_sep_token_wtes(self):
        sep_token_input_ids = self.lm_tokenizer(self.SEP_TOKEN_STR, return_tensors="pt")["input_ids"].to(self.device)
        return self.language_model.transformer.wte(sep_token_input_ids) # [1, 1, self.language_model.config.n_embd]

    def get_rec_token_wtes(self):
        rec_token_input_ids = self.lm_tokenizer(self.REC_TOKEN_STR, return_tensors="pt")["input_ids"].to(self.device)
        return self.language_model.transformer.wte(rec_token_input_ids) # [1, 1, self.language_model.config.n_embd]
    
    def get_rec_end_token_wtes(self):
        rec_end_token_input_ids = self.lm_tokenizer(self.REC_END_TOKEN_STR, return_tensors="pt")["input_ids"].to(self.device)
        return self.language_model.transformer.wte(rec_end_token_input_ids) # [1, 1, self.language_model.config.n_embd]
    
    def get_movie_title(self, m_id):
        title = self.items_db[m_id]
        title = title.split('[SEP]')[0].strip()
        return title
    
    # compute BERT encoded item hidden representation
    # output can be passed to self.recall_item_wte_mapper or self.rerank_item_wte_mapper
    def compute_encoded_embeddings_for_items(self, 
                                             encoder_to_use,
                                             item_ids, # an array of ids, single id should be passed as [id]
                                             items_db_to_use # item databse to use
                                            ):
        chunk_ids = item_ids
        chunk_infos = [items_db_to_use[key] for key in chunk_ids ]
        chunk_tokens = self.encoder_tokenizer(chunk_infos, padding=True, truncation=True, return_tensors="pt")
        chunk_input_ids = chunk_tokens['input_ids'].to(self.device)
        chunk_attention_mask = chunk_tokens['attention_mask'].to(self.device)
        chunk_hiddens = encoder_to_use(input_ids=chunk_input_ids, attention_mask=chunk_attention_mask).last_hidden_state

        # average of non-padding tokens
        expanded_mask_size = list(chunk_attention_mask.size())
        expanded_mask_size.append(encoder_to_use.config.hidden_size)
        expanded_mask = chunk_attention_mask.unsqueeze(-1).expand(expanded_mask_size)
        chunk_masked = torch.mul(chunk_hiddens, expanded_mask) # [num_example, len, 768]
        chunk_pooled = torch.sum(chunk_masked, dim=1) / torch.sum(chunk_attention_mask, dim=1).unsqueeze(-1)
        
        # [len(item_ids), encoder_to_use.config.hidden_size], del chunk_hiddens to free up GPU memory
        return chunk_pooled, chunk_hiddens
    
    # annoy_base_constructor, constructs
    def annoy_base_constructor(self, items_db=None, distance_type='angular', chunk_size=50, n_trees=10):
        items_db_to_use = self.items_db if items_db == None else items_db
        all_item_ids = list(items_db_to_use.keys())
        
        total_pooled = []
        # break into chunks/batches for model concurrency
        num_chunks = math.ceil(len(all_item_ids) / chunk_size)
        for i in range(num_chunks):
            chunk_ids = all_item_ids[i*chunk_size: (i+1)*chunk_size]
            chunk_pooled, chunk_hiddens = self.compute_encoded_embeddings_for_items(self.recall_encoder, chunk_ids, items_db_to_use)
            chunk_pooled = chunk_pooled.cpu().detach().numpy()
            del chunk_hiddens
            total_pooled.append(chunk_pooled)
        total_pooled = np.concatenate(total_pooled, axis=0)
        
        pooled_tensor = torch.tensor(total_pooled).to(self.device)
        
        #build recall annoy index
        annoy_base_recall = AnnoyIndex(self.recall_item_wte_mapper.out_features, distance_type)
        pooled_recall = self.recall_item_wte_mapper(pooled_tensor) # [len(items_db_to_use), self.recall_item_wte_mapper.out_features]
        pooled_recall = pooled_recall.cpu().detach().numpy()
        for i, vector in zip(all_item_ids, pooled_recall):
            annoy_base_recall.add_item(i, vector)
        annoy_base_recall.build(n_trees)
        
        total_pooled = []
        # break into chunks/batches for model concurrency
        num_chunks = math.ceil(len(all_item_ids) / chunk_size)
        for i in range(num_chunks):
            chunk_ids = all_item_ids[i*chunk_size: (i+1)*chunk_size]
            chunk_pooled, chunk_hiddens = self.compute_encoded_embeddings_for_items(self.item_encoder, chunk_ids, items_db_to_use)
            chunk_pooled = chunk_pooled.cpu().detach().numpy()
            del chunk_hiddens
            total_pooled.append(chunk_pooled)
        total_pooled = np.concatenate(total_pooled, axis=0)
        
        pooled_tensor = torch.tensor(total_pooled).to(self.device)
        
        #build rerank annoy index, for validation and inference only
        annoy_base_rerank = AnnoyIndex(self.rerank_item_wte_mapper.out_features, distance_type)
        pooled_rerank = self.rerank_item_wte_mapper(pooled_tensor) # [len(items_db_to_use), self.recall_item_wte_mapper.out_features]
        pooled_rerank = pooled_rerank.cpu().detach().numpy()
        for i, vector in zip(all_item_ids, pooled_rerank):
            annoy_base_rerank.add_item(i, vector)
        annoy_base_rerank.build(n_trees)
        
        del pooled_tensor
        
        self.annoy_base_recall = annoy_base_recall
        self.annoy_base_rerank = annoy_base_rerank
    
    def annoy_loader(self, path, annoy_type, distance_type="angular"):
        if annoy_type == "recall":
            annoy_base = AnnoyIndex(self.recall_item_wte_mapper.out_features, distance_type)
            annoy_base.load(path)
            return annoy_base
        elif annoy_type == "rerank":
            annoy_base = AnnoyIndex(self.rerank_item_wte_mapper.out_features, distance_type)
            annoy_base.load(path)
            return annoy_base
        else:
            return None
    
    def lm_expand_wtes_with_items_annoy_base(self):
        all_item_ids = list(self.items_db.keys())
        total_pooled = []
        for i in all_item_ids:
            total_pooled.append(self.annoy_base_rerank.get_item_vector(i))
        total_pooled = np.asarray(total_pooled) # [len(all_item_ids), 768]
        pooled_tensor = torch.tensor(total_pooled, dtype=torch.float).to(self.device)
        
        old_embeddings = self.language_model.get_input_embeddings()
        item_id_2_lm_token_id = {}
        for k in all_item_ids:
            item_id_2_lm_token_id[k] = len(item_id_2_lm_token_id) + old_embeddings.weight.shape[0]
        new_embeddings = torch.cat([old_embeddings.weight, pooled_tensor], 0)
        new_embeddings = torch.nn.Embedding.from_pretrained(new_embeddings)
        self.language_model.set_input_embeddings(new_embeddings)
        self.language_model.to(device)
        return item_id_2_lm_token_id
    
    def lm_restore_wtes(self, original_token_emb_size):
        old_embeddings = self.language_model.get_input_embeddings()
        new_embeddings = torch.nn.Embedding(original_token_emb_size, old_embeddings.weight.size()[1])
        new_embeddings.to(self.device, dtype=old_embeddings.weight.dtype)
        new_embeddings.weight.data[:original_token_emb_size, :] = old_embeddings.weight.data[:original_token_emb_size, :]
        self.language_model.set_input_embeddings(new_embeddings)
        self.language_model.to(self.device)
        assert(self.language_model.get_input_embeddings().weight.shape[0] == original_token_emb_size)
    
    def trim_lm_wtes(self, wtes):
        trimmed_wtes = wtes
        if trimmed_wtes.shape[1] > self.language_model.config.n_positions:
            trimmed_wtes = trimmed_wtes[:,-self.language_model.config.n_positions + self.lm_trim_offset:,:]
        return trimmed_wtes # [batch, self.language_model.config.n_positions - self.lm_trim_offset, self.language_model.config.n_embd]
    
    def trim_positional_ids(self, p_ids, num_items_wtes):
        trimmed_ids = p_ids
        if trimmed_ids.shape[1] > self.language_model.config.n_positions:
            past_ids = trimmed_ids[:,:self.language_model.config.n_positions - self.lm_trim_offset - num_items_wtes]
#             past_ids = trimmed_ids[:, self.lm_trim_offset + num_items_wtes:self.language_model.config.n_positions]
            item_ids = trimmed_ids[:,-num_items_wtes:]
            trimmed_ids = torch.cat((past_ids, item_ids), dim=1)
        return trimmed_ids # [batch, self.language_model.config.n_positions - self.lm_trim_offset]
    
    def compute_inductive_attention_mask(self, length_language, length_rerank_items_wtes):
        total_length = length_language + length_rerank_items_wtes
        language_mask_to_add = torch.zeros((length_language, total_length), dtype=torch.float, device=self.device)
        items_mask_to_add = torch.ones((length_rerank_items_wtes, total_length), dtype=torch.float, device=self.device)
        combined_mask_to_add = torch.cat((language_mask_to_add, items_mask_to_add), dim=0)
        return combined_mask_to_add #[total_length, total_length]
    
    def forward_pure_language_turn(self, 
                                   past_wtes, # past word token embeddings, [1, len, 768]
                                   current_tokens # tokens of current turn conversation, [1, len]
                                  ):
        train_logits, train_targets = None, None
        current_wtes = self.language_model.transformer.wte(current_tokens)
        
        if past_wtes == None:
            lm_outputs = self.language_model(inputs_embeds=current_wtes)
            train_logits = lm_outputs.logits[:, :-1, :]
            train_targets = current_tokens[:,1:]
        else:
            all_wtes = torch.cat((past_wtes, current_wtes), dim=1)
            all_wtes = self.trim_lm_wtes(all_wtes)
            lm_outputs = self.language_model(inputs_embeds=all_wtes)
            train_logits = lm_outputs.logits[:, -current_wtes.shape[1]:-1, :] # skip the last one
            train_targets = current_tokens[:,1:]
        
        # torch.Size([batch, len_cur, lm_vocab]), torch.Size([batch, len_cur]), torch.Size([batch, len_past+len_cur, lm_emb(768)])
        return train_logits, train_targets
        
    def forward_recall(self, 
                       past_wtes, # past word token embeddings, [1, len, 768]
                       current_tokens, # tokens of current turn conversation, [1, len]
                       gt_item_id, # id, ex. 0
                       num_samples # num examples to sample for training, including groud truth id
                      ):
        # recall step 1. construct LM sequence output
        # LM input composition: [past_wtes, REC_wtes, gt_item_wte, (gt_item_info_wtes), REC_END_wtes, current_wtes ]
        
        REC_wtes = self.get_rec_token_wtes() # [1, 1, self.language_model.config.n_embd]
        gt_item_wte, _ = self.compute_encoded_embeddings_for_items(self.recall_encoder, [gt_item_id], self.items_db)
        gt_item_wte = self.rerank_item_wte_mapper(gt_item_wte) # [1, self.rerank_item_wte_mapper.out_features]
        
        REC_END_wtes = self.get_rec_end_token_wtes() # [1, 1, self.language_model.config.n_embd]
        current_wtes = self.language_model.transformer.wte(current_tokens) #[1, current_tokens.shape[1], self.language_model.config.n_embd]
        
        REC_wtes_len = REC_wtes.shape[1] # 1 by default
        gt_item_wte_len = gt_item_wte.shape[0] # 1 by default
        REC_END_wtes_len = REC_END_wtes.shape[1] # 1 by default
        current_wtes_len = current_wtes.shape[1]
        
        lm_wte_inputs = torch.cat(
            (past_wtes, # [batch (1), len, self.language_model.config.n_embd]
             REC_wtes,
             gt_item_wte.unsqueeze(0), # reshape to [1,1,self.rerank_item_wte_mapper.out_features]
             REC_END_wtes,
             current_wtes # [batch (1), len, self.language_model.config.n_embd]
            ),
            dim=1
        )
        lm_wte_inputs = self.trim_lm_wtes(lm_wte_inputs) # trim for len > self.language_model.config.n_positions
        
        # recall step 2. get gpt output logits and hidden states
        lm_outputs = self.language_model(inputs_embeds=lm_wte_inputs, output_hidden_states=True)
        
        # recall step 3. pull logits (recall, rec_token and language logits of current turn) and compute
        
        # recall logit(s)
        rec_token_start_index = -current_wtes_len-REC_END_wtes_len-gt_item_wte_len-REC_wtes_len
        rec_token_end_index = -current_wtes_len-REC_END_wtes_len-gt_item_wte_len
        # [batch (1), REC_wtes_len, self.language_model.config.n_embd]
        rec_token_hidden = lm_outputs.hidden_states[-1][:, rec_token_start_index:rec_token_end_index, :]
        # [batch (1), self.recall_lm_query_mapper.out_features]
        rec_query_vector = self.recall_lm_query_mapper(rec_token_hidden).squeeze(1)
        
        # sample num_samples item ids to train recall with "recommendation as classification"
        sampled_item_ids = sample_ids_from_db(self.items_db, gt_item_id, num_samples, include_gt=True)
        gt_item_id_index = sampled_item_ids.index(gt_item_id)
        
        # [num_samples, self.item_encoder.config.hidden_size]
        encoded_items_embeddings, _ = self.compute_encoded_embeddings_for_items(self.recall_encoder, sampled_item_ids, self.items_db)
        # to compute dot product with rec_query_vector
        items_key_vectors = self.recall_item_wte_mapper(encoded_items_embeddings) # [num_samples, self.recall_item_wte_mapper.out_features]
        expanded_rec_query_vector = rec_query_vector.expand(items_key_vectors.shape[0], rec_query_vector.shape[1]) # [num_samples, self.recall_item_wte_mapper.out_features]
        recall_logits = torch.sum(expanded_rec_query_vector * items_key_vectors, dim=1) # torch.size([num_samples])
        
        # REC_TOKEN prediction and future sentence prediction
        # hidden rep of the token that's right before REC_TOKEN
        token_before_REC_logits = lm_outputs.logits[:, rec_token_start_index-1:rec_token_end_index-1, :]
        REC_targets = self.lm_tokenizer(self.REC_TOKEN_STR, return_tensors="pt")['input_ids'].to(self.device) # [1, 1]
        
        #language logits and targets
        current_language_logits = lm_outputs.logits[:, -current_wtes_len:-1, :]
        current_language_targets = current_tokens[:,1:]
        
        # REC token and language, their logits and targets
        # [batch, current_wtes_len+REC_wtes_len, lm_vocab]
        all_wte_logits = torch.cat((token_before_REC_logits, current_language_logits), dim=1)
        # [current_wtes_len+REC_wtes_len, lm_vocab]
        all_wte_targets = torch.cat((REC_targets, current_language_targets), dim=1)
        
        # torch.size([num_samples]), id, [batch, current_wtes_len+REC_wtes_len, lm_vocab], [current_wtes_len+REC_wtes_len, lm_vocab]
        return recall_logits, gt_item_id_index, all_wte_logits, all_wte_targets
        
    def forward_rerank(self,
                       past_wtes, # past word token embeddings, [1, len, 768]
                       gt_item_id, # tokens of current turn conversation, [1, len]
                       num_samples, # num examples to sample for training, including groud truth id
                       rerank_items_chunk_size=10, # batch size for encoder GPU computation
                      ):
        # REC wte
        REC_wtes = self.get_rec_token_wtes() # [batch (1), 1, self.language_model.config.n_embd]
        
        #  items wtes to compute rerank loss
        # sample rerank examples
        sampled_item_ids = sample_ids_from_db(self.items_db, gt_item_id, num_samples, include_gt=True)
        gt_item_id_index = sampled_item_ids.index(gt_item_id)
        # compute item wtes by batch
        num_chunks = math.ceil(len(sampled_item_ids) / rerank_items_chunk_size)
        total_wtes = []
        for i in range(num_chunks):
            chunk_ids = sampled_item_ids[i*rerank_items_chunk_size: (i+1)*rerank_items_chunk_size]
            chunk_pooled, _ = self.compute_encoded_embeddings_for_items(self.item_encoder, chunk_ids, self.items_db) # [rerank_items_chunk_size, self.item_encoder.config.hidden_size]
            chunk_wtes = self.rerank_item_wte_mapper(chunk_pooled)
            total_wtes.append(chunk_wtes)
        total_wtes = torch.cat(total_wtes, dim=0) # [num_samples, self.language_model.config.n_embd]
        
        past_wtes_len = past_wtes.shape[1]
        REC_wtes_len = REC_wtes.shape[1] # 1 by default
        total_wtes_len = total_wtes.shape[0]
        
        # compute positional ids, all rerank item wte should have the same positional encoding id 0
        position_ids = torch.arange(0, past_wtes_len + REC_wtes_len, dtype=torch.long, device=self.device)
        items_position_ids = torch.zeros(total_wtes.shape[0], dtype=torch.long, device=device)
#         items_position_ids = torch.tensor([1023] * total_wtes.shape[0], dtype=torch.long, device=device)
        combined_position_ids = torch.cat((position_ids, items_position_ids), dim=0)
        combined_position_ids = combined_position_ids.unsqueeze(0) # [1, past_wtes_len+REC_wtes_len+total_wtes_len]
        
        # compute concatenated lm wtes
        lm_wte_inputs = torch.cat(
            (past_wtes, # [batch (1), len, self.language_model.config.n_embd]
             REC_wtes, # [batch (1), 1, self.language_model.config.n_embd]
             total_wtes.unsqueeze(0), # [1, num_samples, self.language_model.config.n_embd]
            ),
            dim=1
        ) # [1, past_len + REC_wtes_len + num_samples, self.language_model.config.n_embd]

        # trim sequence to smaller length (len < self.language_model.config.n_positions-self.lm_trim_offset)
        combined_position_ids_trimmed = self.trim_positional_ids(combined_position_ids, total_wtes_len) # [1, len]
        lm_wte_inputs_trimmed = self.trim_lm_wtes(lm_wte_inputs) # [1, len, self.language_model.config.n_embd]
        assert(combined_position_ids.shape[1] == lm_wte_inputs.shape[1])
        
        # compute inductive attention mask
        #     Order of recommended items shouldn't affect their score, thus every item 
        # should have full attention over the entire sequence: they should know each other and the entire
        # conversation history
        inductive_attention_mask = self.compute_inductive_attention_mask(
            lm_wte_inputs_trimmed.shape[1]-total_wtes.shape[0], 
            total_wtes.shape[0]
        )
        rerank_lm_outputs = self.language_model(inputs_embeds=lm_wte_inputs_trimmed,
                  inductive_attention_mask=inductive_attention_mask,
                  position_ids=combined_position_ids_trimmed,
                  output_hidden_states=True)
        
        rerank_lm_hidden = rerank_lm_outputs.hidden_states[-1][:, -total_wtes.shape[0]:, :]
        rerank_logits = self.rerank_logits_mapper(rerank_lm_hidden).squeeze() # torch.Size([num_samples])
        
        return rerank_logits, gt_item_id_index
    
    def validation_perform_recall(self, past_wtes, topk):
        REC_wtes = self.get_rec_token_wtes()
        lm_wte_inputs = torch.cat(
            (past_wtes, # [batch (1), len, self.language_model.config.n_embd]
             REC_wtes # [1, 1, self.language_model.config.n_embd]
            ),
            dim=1
        )
        lm_wte_inputs = self.trim_lm_wtes(lm_wte_inputs) # trim for len > self.language_model.config.n_positions
        lm_outputs = self.language_model(inputs_embeds=lm_wte_inputs, output_hidden_states=True)
        
        rec_token_hidden = lm_outputs.hidden_states[-1][:, -1, :]
        # [batch (1), self.recall_lm_query_mapper.out_features]
        rec_query_vector = self.recall_lm_query_mapper(rec_token_hidden).squeeze(0) # [768]
        rec_query_vector = rec_query_vector.cpu().detach().numpy()
        recall_results = self.annoy_base_recall.get_nns_by_vector(rec_query_vector, topk)
        return recall_results
    
    def validation_perform_rerank(self, past_wtes, recalled_ids):
        REC_wtes = self.get_rec_token_wtes()
        
        total_wtes = [ self.annoy_base_rerank.get_item_vector(r_id) for r_id in recalled_ids]
        total_wtes = [ torch.tensor(wte).reshape(-1, self.language_model.config.n_embd).to(self.device) for wte in total_wtes]
        total_wtes = torch.cat(total_wtes, dim=0) # [len(recalled_ids), 768]
        
        past_wtes_len = past_wtes.shape[1]
        REC_wtes_len = REC_wtes.shape[1] # 1 by default
        total_wtes_len = total_wtes.shape[0]
        
        # compute positional ids, all rerank item wte should have the same positional encoding id 0
        position_ids = torch.arange(0, past_wtes_len + REC_wtes_len, dtype=torch.long, device=self.device)
        items_position_ids = torch.zeros(total_wtes.shape[0], dtype=torch.long, device=device)
        combined_position_ids = torch.cat((position_ids, items_position_ids), dim=0)
        combined_position_ids = combined_position_ids.unsqueeze(0) # [1, past_wtes_len+REC_wtes_len+total_wtes_len]
        
        # compute concatenated lm wtes
        lm_wte_inputs = torch.cat(
            (past_wtes, # [batch (1), len, self.language_model.config.n_embd]
             REC_wtes, # [batch (1), 1, self.language_model.config.n_embd]
             total_wtes.unsqueeze(0), # [1, num_samples, self.language_model.config.n_embd]
            ),
            dim=1
        ) # [1, past_len + REC_wtes_len + num_samples, self.language_model.config.n_embd]

        # trim sequence to smaller length (len < self.language_model.config.n_positions-self.lm_trim_offset)
        combined_position_ids_trimmed = self.trim_positional_ids(combined_position_ids, total_wtes_len) # [1, len]
        lm_wte_inputs_trimmed = self.trim_lm_wtes(lm_wte_inputs) # [1, len, self.language_model.config.n_embd]
        assert(combined_position_ids.shape[1] == lm_wte_inputs.shape[1])
        
        inductive_attention_mask = self.compute_inductive_attention_mask(
            lm_wte_inputs_trimmed.shape[1]-total_wtes.shape[0], 
            total_wtes.shape[0]
        )
        rerank_lm_outputs = self.language_model(inputs_embeds=lm_wte_inputs_trimmed,
                  inductive_attention_mask=inductive_attention_mask,
                  position_ids=combined_position_ids_trimmed,
                  output_hidden_states=True)
        
        rerank_lm_hidden = rerank_lm_outputs.hidden_states[-1][:, -total_wtes.shape[0]:, :]
        rerank_logits = self.rerank_logits_mapper(rerank_lm_hidden).squeeze()
        
        return rerank_logits
    
    # TODO: Modify
    def validation_generate_sentence(self, past_wtes, recommended_id):
        if recommended_id == None: #  pure language
            lm_wte_inputs = self.trim_lm_wtes(past_wtes)
            lm_outputs = self.language_model(inputs_embeds=lm_wte_inputs, output_hidden_states=True)
            generated = model.language_model.generate(
                past_key_values=lm_outputs.past_key_values, 
                max_length=50, 
                num_return_sequences=1, 
                do_sample=True, 
                eos_token_id=198,
                pad_token_id=198
            )
            generated_sen = gpt_tokenizer.decode(generated[0], skip_special_tokens=True)
            return generated_sen, None
        else:
            REC_wtes = self.get_rec_token_wtes() # [1, 1, self.language_model.config.n_embd]
            gt_item_wte, _ = self.compute_encoded_embeddings_for_items([recommended_id], self.items_db)
            gt_item_wte = self.rerank_item_wte_mapper(gt_item_wte) # [1, self.rerank_item_wte_mapper.out_features]
            REC_END_wtes = self.get_rec_end_token_wtes() # [1, 1, self.language_model.config.n_embd]

            REC_wtes_len = REC_wtes.shape[1] # 1 by default
            gt_item_wte_len = gt_item_wte.shape[0] # 1 by default
            REC_END_wtes_len = REC_END_wtes.shape[1] # 1 by default

            lm_wte_inputs = torch.cat(
                (past_wtes, # [batch (1), len, self.language_model.config.n_embd]
                 REC_wtes,
                 gt_item_wte.unsqueeze(0), # reshape to [1,1,self.rerank_item_wte_mapper.out_features]
                 REC_END_wtes
                ),
                dim=1
            )
            lm_wte_inputs = self.trim_lm_wtes(lm_wte_inputs) # trim for len > self.language_model.config.n_positions
            lm_outputs = self.language_model(inputs_embeds=lm_wte_inputs, output_hidden_states=True)
            generated = model.language_model.generate(
                past_key_values=lm_outputs.past_key_values, 
                max_length=50, 
                num_return_sequences=1, 
                do_sample=True, 
                eos_token_id=198,
                pad_token_id=198
            )
            generated_sen = gpt_tokenizer.decode(generated[0], skip_special_tokens=True)
            return generated_sen, self.get_movie_title(recommended_id)

In [8]:
device = torch.device(3)
model = UniversalCRSModel(
    gpt2_model, 
    bert_model_recall, 
    bert_model_rerank, 
    gpt_tokenizer, 
    bert_tokenizer, 
    device, 
    items_db, 
    rec_token_str=REC_TOKEN, 
    rec_end_token_str=REC_END_TOKEN
)
# model.load_state_dict(torch.load())

model.to(device)
pass

In [9]:
start = time.time()
model.annoy_base_constructor()
end = time.time()
print(end-start)
# model.annoy_base_recall.save('/local-scratch1/data/by2299/INITIAL_RECALL_ANNOY_BASE_REDIAL_TRAIN_BERT_DISTIL_MULTIPLE.ann')
# model.annoy_base_rerank.save('/local-scratch1/data/by2299/INITIAL_RERANK_ANNOY_BASE_REDIAL_TRAIN_BERT_DISTIL_MULTIPLE.ann')

25.420650720596313


In [10]:
# parameters
batch_size = 1
num_epochs = 10
num_gradients_accumulation = 1
num_train_optimization_steps = len(train_dataset) * num_epochs // batch_size // num_gradients_accumulation

num_samples_recall_train = 100
num_samples_rerank_train = 150
rerank_encoder_chunk_size = int(num_samples_rerank_train / 15)
validation_recall_size = 500

temperature = 1.2

language_loss_train_coeff = 0.15
language_loss_train_coeff_beginnging_turn = 1.0
recall_loss_train_coeff = 0.8
rerank_loss_train_coeff = 1.0

# loss
criterion_language = SequenceCrossEntropyLoss()
criterion_recall = torch.nn.CrossEntropyLoss()
rerank_class_weights = torch.FloatTensor([1] * (num_samples_rerank_train-1) + [30]).to(model.device)
criterion_rerank_train = torch.nn.CrossEntropyLoss(weight=rerank_class_weights)

# optimizer and scheduler
param_optimizer = list(model.language_model.named_parameters()) + \
    list(model.recall_encoder.named_parameters()) + \
    list(model.item_encoder.named_parameters()) + \
    list(model.recall_lm_query_mapper.named_parameters()) + \
    list(model.recall_item_wte_mapper.named_parameters()) + \
    list(model.rerank_item_wte_mapper.named_parameters()) + \
    list(model.rerank_logits_mapper.named_parameters())

no_decay = ['bias', 'ln', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

optimizer = AdamW(optimizer_grouped_parameters, 
                  lr=3e-5,
                  eps=1e-06)

scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=len(train_dataset) // num_gradients_accumulation , num_training_steps = num_train_optimization_steps)

update_count = 0
progress_bar = tqdm.notebook.tqdm
start = time.time()

In [23]:
def past_wtes_constructor(past_list, model):
    past_wtes = []
    for language_tokens, recommended_ids in past_list:
        if language_tokens == None and recommended_ids != None: # rec turn
            # append REC, gt_item_wte, REC_END
            REC_wte = model.get_rec_token_wtes() # [1, 1, 768]
            gt_item_wte, _ = model.compute_encoded_embeddings_for_items(
                model.item_encoder,
                recommended_ids, 
                model.items_db
            ) # [1, 768]
            gt_item_wte = model.rerank_item_wte_mapper(gt_item_wte)
            
            REC_END_wte = model.get_rec_end_token_wtes() # [1, 1, 768]
            combined_wtes = torch.cat(
                (REC_wte,
                 gt_item_wte.unsqueeze(0), # [1, 1, 768]
                 REC_END_wte
                ), 
                dim=1
            ) # [1, 3, 768]
            past_wtes.append(combined_wtes)
        elif recommended_ids == None and language_tokens != None: # language turn simply append wtes
            wtes = model.language_model.transformer.wte(language_tokens) # [1, len, 768]
            past_wtes.append(wtes)
        elif recommended_ids != None and language_tokens != None: # user mentioned turn
            l_wtes = model.language_model.transformer.wte(language_tokens)
            
            SEP_wte = model.get_sep_token_wtes()
            gt_item_wte, _ = model.compute_encoded_embeddings_for_items(
                model.item_encoder,
                recommended_ids, 
                model.items_db
            ) # [1, 768]
            gt_item_wte = model.rerank_item_wte_mapper(gt_item_wte)
            SEP_wte = model.get_sep_token_wtes()
            combined_wtes = torch.cat(
                (l_wtes,
                 SEP_wte,
                 gt_item_wte.unsqueeze(0), # [1, 1, 768]
                 SEP_wte
                ), 
                dim=1
            )
            past_wtes.append(combined_wtes)
            
    
    past_wtes = torch.cat(past_wtes, dim=1)
    # don't trim since we already dealt with length in model functions
    return past_wtes

def train_one_iteration(batch, model):
    role_ids, dialogues = batch
    dialog_tensors = [torch.LongTensor(utterance).to(model.device) for utterance, _ in dialogues]
    
    past_list = []
    ppl_history = []
#     language_logits, language_targets = [], []
    for turn_num in range(len(role_ids)):
        current_tokens = dialog_tensors[turn_num]
        _, recommended_ids = dialogues[turn_num]
        
        if past_list == []:
            past_list.append((current_tokens, recommended_ids))
            continue
        
        if recommended_ids == None: # no rec
            if role_ids[turn_num] == 0: # user
                past_list.append((current_tokens, None))
            else: #system
                past_wtes = past_wtes_constructor(past_list, model)
                language_logits, language_targets = model.forward_pure_language_turn(past_wtes, current_tokens)
                
                # loss backward
                language_targets_mask = torch.ones_like(language_targets).float()
                loss_ppl = criterion_language(language_logits, language_targets, language_targets_mask, label_smoothing=0.02, reduce="batch")
                loss_ppl = language_loss_train_coeff * loss_ppl
                loss_ppl.backward()
                perplexity = np.exp(loss_ppl.item())
                ppl_history.append(perplexity)
                
                # append to past list
                past_list.append((current_tokens, None))
        else: # rec!
            
            if role_ids[turn_num] == 0: #user mentioned
                past_list.append((current_tokens, recommended_ids))
                continue
            for recommended_id in recommended_ids:
                #system recommend turn
                past_wtes = past_wtes_constructor(past_list, model)

                # recall
                recall_logits, recall_true_index, all_wte_logits, all_wte_targets = model.forward_recall(
                    past_wtes, 
                    current_tokens, 
                    recommended_id, 
                    num_samples_recall_train
                )
                
                # recall items loss
                recall_targets = torch.LongTensor([recall_true_index]).to(model.device)
                loss_recall = criterion_recall(recall_logits.unsqueeze(0), recall_targets)

                # language loss in recall turn, REC_TOKEN, Language on conditional generation
                all_wte_targets_mask = torch.ones_like(all_wte_targets).float()
                loss_ppl = criterion_language(all_wte_logits, all_wte_targets, all_wte_targets_mask, label_smoothing=0.02, reduce="batch")
                perplexity = np.exp(loss_ppl.item())
                ppl_history.append(perplexity)

                # combined loss
                recall_total_loss = loss_recall * recall_loss_train_coeff + loss_ppl * language_loss_train_coeff
                recall_total_loss.backward()

                # rerank
                past_wtes = past_wtes_constructor(past_list, model)
                rerank_logits, rerank_true_index = model.forward_rerank(
                    past_wtes, 
                    recommended_id, 
                    num_samples_rerank_train, 
                    rerank_encoder_chunk_size
                )
                
                rerank_logits /= temperature

                # rerank loss 
                rerank_targets = torch.LongTensor([rerank_true_index]).to(model.device)
                loss_rerank = criterion_rerank_train(rerank_logits.unsqueeze(0), rerank_targets)
                loss_rerank *= rerank_loss_train_coeff
                loss_rerank.backward()

            past_list.append((None, recommended_ids))
            past_list.append((current_tokens, None))
    return np.mean(ppl_history)


def validate_one_iteration(batch, model):
    role_ids, dialogues = batch
    dialog_tensors = [torch.LongTensor(utterance).to(model.device) for utterance, _ in dialogues]
    
    past_list = []
    ppl_history = []
    recall_loss_history = []
    rerank_loss_history = []
    total = 0
    recall_top100, recall_top300, recall_top500 = 0, 0, 0,
    rerank_top1, rerank_top10, rerank_top50 = 0, 0, 0
    
    for turn_num in range(len(role_ids)):
        current_tokens = dialog_tensors[turn_num]
        _, recommended_ids = dialogues[turn_num]
        
        if past_list == []:
            past_list.append((current_tokens, None))
            continue
        
        if recommended_ids == None: # no rec
            if role_ids[turn_num] == 0: # user
                past_list.append((current_tokens, None))
            else: #system
                past_wtes = past_wtes_constructor(past_list, model)
                language_logits, language_targets = model.forward_pure_language_turn(past_wtes, current_tokens)
                
                # loss backward
                language_targets_mask = torch.ones_like(language_targets).float()
                loss_ppl = criterion_language(language_logits, language_targets, language_targets_mask, label_smoothing=-1, reduce="sentence")
                perplexity = np.exp(loss_ppl.item())
                ppl_history.append(perplexity)
                del loss_ppl
                
                # append to past list
                past_list.append((current_tokens, None))
        else: # rec!
            
            if role_ids[turn_num] == 0: #user mentioned
                past_list.append((current_tokens, recommended_ids))
                continue
            for recommended_id in recommended_ids:
                past_wtes = past_wtes_constructor(past_list, model)

                total += 1

                # recall
                recall_logits, recall_true_index, all_wte_logits, all_wte_targets = model.forward_recall(
                    past_wtes, 
                    current_tokens, 
                    recommended_id, 
                    num_samples_recall_train
                )

                # recall items loss
                recall_targets = torch.LongTensor([recall_true_index]).to(model.device)
                loss_recall = criterion_recall(recall_logits.unsqueeze(0), recall_targets)
                recall_loss_history.append(loss_recall.item())
                del loss_recall; del recall_logits; del recall_targets

                # language loss in recall turn, REC_TOKEN, Language on conditional generation
                all_wte_targets_mask = torch.ones_like(all_wte_targets).float()
                loss_ppl = criterion_language(all_wte_logits, all_wte_targets, all_wte_targets_mask, label_smoothing=-1, reduce="sentence")
                perplexity = np.exp(loss_ppl.item())
                ppl_history.append(perplexity)
                del loss_ppl; del all_wte_logits; del all_wte_targets

                recalled_ids = model.validation_perform_recall(past_wtes, validation_recall_size)

                if recommended_id in recalled_ids[:500]:
                    recall_top500 += 1
                if recommended_id in recalled_ids[:400]:
                    recall_top300 += 1
                if recommended_id in recalled_ids[:300]:
                    recall_top100 += 1

                if recommended_id not in recalled_ids:
                    continue # no need to compute rerank since recall is unsuccessful

                # rerank
                past_wtes = past_wtes_constructor(past_list, model)
                rerank_true_index = recalled_ids.index(recommended_id)
                rerank_logits = model.validation_perform_rerank(past_wtes, recalled_ids)
    #             print(rerank_logits)
                reranks = np.argsort(rerank_logits.cpu().detach().numpy())
                if rerank_true_index in reranks[-50:]:
                    rerank_top50 += 1
                if rerank_true_index in reranks[-10:]:
                    rerank_top10 += 1
                if rerank_true_index in reranks[-1:]:
                    rerank_top1 += 1
                    
    #             print(recalled_ids[reranks[-1]],recalled_ids[reranks[-2]],recalled_ids[reranks[-3]],
    #                  recalled_ids[reranks[-4]],recalled_ids[reranks[-5]],recalled_ids[reranks[-6]],
    #                  recalled_ids[reranks[-7]],recalled_ids[reranks[-8]],recalled_ids[reranks[-9]],)
    #             print(recommended_id)
                rerank_targets = torch.LongTensor([rerank_true_index]).to(model.device)
    #             loss_rerank = criterion_rerank(rerank_logits.unsqueeze(0), rerank_targets)
                rerank_loss_val = torch.nn.CrossEntropyLoss()
                loss_rerank = rerank_loss_val(rerank_logits.unsqueeze(0), rerank_targets)
                rerank_loss_history.append(loss_rerank.item())
                del loss_rerank; del rerank_logits; del rerank_targets
            
            past_list.append((None, recommended_ids))
            past_list.append((current_tokens, None))
    return ppl_history, recall_loss_history, rerank_loss_history, \
            total, recall_top100, recall_top300, recall_top500, \
            rerank_top1, rerank_top10, rerank_top50

In [24]:
def distinct_metrics(outs):
    # outputs is a list which contains several sentences, each sentence contains several words
    unigram_count = 0
    bigram_count = 0
    trigram_count=0
    quagram_count=0
    unigram_set = set()
    bigram_set = set()
    trigram_set=set()
    quagram_set=set()
    for sen in outs:
        for word in sen:
            unigram_count += 1
            unigram_set.add(word)
        for start in range(len(sen) - 1):
            bg = str(sen[start]) + ' ' + str(sen[start + 1])
            bigram_count += 1
            bigram_set.add(bg)
        for start in range(len(sen)-2):
            trg=str(sen[start]) + ' ' + str(sen[start + 1]) + ' ' + str(sen[start + 2])
            trigram_count+=1
            trigram_set.add(trg)
        for start in range(len(sen)-3):
            quag=str(sen[start]) + ' ' + str(sen[start + 1]) + ' ' + str(sen[start + 2]) + ' ' + str(sen[start + 3])
            quagram_count+=1
            quagram_set.add(quag)
    dis1 = len(unigram_set) / len(outs)#unigram_count
    dis2 = len(bigram_set) / len(outs)#bigram_count
    dis3 = len(trigram_set)/len(outs)#trigram_count
    dis4 = len(quagram_set)/len(outs)#quagram_count
    return dis1, dis2, dis3, dis4

In [25]:
def validate_language_metrics_batch(batch, model, item_id_2_lm_token_id):
    role_ids, dialogues = batch
    dialog_tensors = [torch.LongTensor(utterance).to(model.device) for utterance, _ in dialogues]
    
#     past_list = []
    past_tokens = None
    tokenized_sentences = []
    integration_total, integration_cnt = 0, 0
    
    for turn_num in range(len(role_ids)):
        dial_turn_inputs = dialog_tensors[turn_num]
        _, recommended_ids = dialogues[turn_num]
        
        item_ids = []; 
        if recommended_ids != None:
            for r_id in recommended_ids:
                item_ids.append(item_id_2_lm_token_id[r_id])
            item_ids = torch.tensor([item_ids]).to(device)
        
        if turn_num == 0:
            past_tokens = dial_turn_inputs
        if role_ids[turn_num] == 0:
            if turn_num != 0:
                past_tokens = torch.cat((past_tokens, dial_turn_inputs), dim=1)
        else:
            if turn_num != 0:
                if item_ids != []:
                    rec_start_token = model.lm_tokenizer(model.REC_TOKEN_STR, return_tensors="pt")["input_ids"].to(model.device)
                    rec_end_token = model.lm_tokenizer(model.REC_END_TOKEN_STR, return_tensors="pt")["input_ids"].to(model.device)
                    past_tokens = torch.cat((past_tokens, rec_start_token, item_ids, rec_end_token), dim=1)
                else:
                    past_tokens = past_tokens
                
            total_len = past_tokens.shape[1]
            if total_len >= 1024: break
#                 print("Original Rec: " + gpt_tokenizer.decode(dial_turn_inputs[0], skip_special_tokens=True))
            generated = model.language_model.generate(
                input_ids= torch.cat((past_tokens, torch.tensor([[32, 25]]).to(device)), dim=1),
                max_length=1024,
                num_return_sequences=1,
                do_sample=True,
                num_beams=2,
                top_k=50,
                temperature=1.05,
                eos_token_id=628,
                pad_token_id=628,
#                 no_repeat_ngram_size=3,
#                         length_penalty=3.0

            )
            generated_sen =  gpt_tokenizer.decode(generated[0][past_tokens.shape[1]:], skip_special_tokens=True)
#                 print("Generated Rec: " + generated_sen)
            tokenized_sen = generated_sen.strip().split(' ')
            tokenized_sentences.append(tokenized_sen)
            if recommended_ids != None:
                integration_total += 1                        
                if "[MOVIE_ID]" in generated_sen:
                    integration_cnt += 1
            
            if turn_num != 0:
                past_tokens = torch.cat((past_tokens, dial_turn_inputs), dim=1)
            
    return tokenized_sentences, integration_cnt, integration_total

In [14]:
output_file_path = "Outputs/CRS_Train.txt"
model_saved_path = "/local-scratch1/data/by2299/CRS_Redial_Train_Same_BERT_"

In [26]:
for ep in range(num_epochs):

    #"Training"
    pbar = progress_bar(train_dataloader)
    model.train()
    for batch in pbar:
        # batch size of train_dataloader is 1
        avg_ppl = train_one_iteration(batch[0], model)
        update_count +=1
        if update_count % num_gradients_accumulation == num_gradients_accumulation - 1:
            
            # update for gradient accumulation
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            
            # speed measure
            end = time.time()
            speed = batch_size * num_gradients_accumulation / (end - start)
            start = end
            
            # show progress
            pbar.set_postfix(ppl=avg_ppl, speed=speed)
            
    model.eval()
    model.annoy_base_constructor()
    
    pbar = progress_bar(test_dataloader)
    ppls, recall_losses, rerank_losses = [],[],[]
    total_val, recall_top100_val, recall_top300_val, recall_top500_val, \
        rerank_top1_val, rerank_top10_val, rerank_top50_val = 0,0,0,0,0,0,0
    for batch in pbar:
        ppl_history, recall_loss_history, rerank_loss_history, \
        total, recall_top100, recall_top300, recall_top500, \
        rerank_top1, rerank_top10, rerank_top50 = validate_one_iteration(batch[0], model)
        ppls += ppl_history; recall_losses += recall_loss_history; rerank_losses += rerank_loss_history
        total_val += total; 
        recall_top100_val += recall_top100; recall_top300_val += recall_top300; recall_top500_val += recall_top500
        rerank_top1_val += rerank_top1; rerank_top10_val += rerank_top10; rerank_top50_val += rerank_top50
    
    item_id_2_lm_token_id = model.lm_expand_wtes_with_items_annoy_base()
    pbar = progress_bar(test_dataloader)
    total_sentences = []
    integration_cnt, total_int_cnt = 0, 0
    for batch in pbar:
        sentences, ic, tc = validate_language_metrics_batch(batch[0], model, item_id_2_lm_token_id)
        for s in sentences:
            total_sentences.append(s)

        integration_cnt += ic; total_int_cnt += tc
    integration_ratio = integration_cnt / total_int_cnt
    dist1, dist2, dist3, dist4 = distinct_metrics(total_sentences)
    model.lm_restore_wtes(original_token_emb_size)
    
    output_file = open(output_file_path, 'a')
    output_file.writelines([f"Epcoh {ep} ppl: {np.mean(ppls)}, recall_loss: {np.mean(recall_losses)}, rerank_loss: {np.mean(rerank_losses)}"])
    output_file.write('\n')
    output_file.writelines([f"recall top100: {recall_top100_val/total_val}, top300: {recall_top300_val/total_val}, top500: {recall_top500_val/total_val}"])
    output_file.write('\n')
    output_file.writelines([f"rerank top1: {rerank_top1_val/total_val}, top10: {rerank_top10_val/total_val}, top50: {rerank_top50_val/total_val}"])
    output_file.write('\n')
    output_file.writelines([f"Integration Ratio: {integration_ratio}"])
    output_file.write('\n')
    output_file.writelines([f"Dist1: {dist1}, Dist2: {dist2}, Dist3: {dist3}, Dist4: {dist4}"])
    output_file.write('\n\n')
    output_file.close()
    
    torch.save(model.state_dict(), model_saved_path + str(ep) +".pt")

  0%|          | 0/1342 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [12]:
def distinct_metrics(outs):
    # outputs is a list which contains several sentences, each sentence contains several words
    unigram_count = 0
    bigram_count = 0
    trigram_count=0
    quagram_count=0
    unigram_set = set()
    bigram_set = set()
    trigram_set=set()
    quagram_set=set()
    for sen in outs:
        sen = [ i.translate(str.maketrans('', '', string.punctuation)) for i in sen][1:]
        for i in range(len(sen)):
            sen[i] = sen[i].lower()
        if len(sen) >= 100: continue
        for word in sen:
            unigram_count += 1
            unigram_set.add(word)
        for start in range(len(sen) - 1):
            bg = str(sen[start]) + ' ' + str(sen[start + 1])
            bigram_count += 1
            bigram_set.add(bg)
        for start in range(len(sen)-2):
            trg=str(sen[start]) + ' ' + str(sen[start + 1]) + ' ' + str(sen[start + 2])
            trigram_count+=1
            trigram_set.add(trg)
        for start in range(len(sen)-3):
            quag=str(sen[start]) + ' ' + str(sen[start + 1]) + ' ' + str(sen[start + 2]) + ' ' + str(sen[start + 3])
            quagram_count+=1
            quagram_set.add(quag)
    dis1 = len(unigram_set) / len(outs)#unigram_count
    dis2 = len(bigram_set) / len(outs)#bigram_count
    dis3 = len(trigram_set)/len(outs)#trigram_count
    dis4 = len(quagram_set)/len(outs)#quagram_count
    return dis1, dis2, dis3, dis4

def bleu_calc_one(ref, hyp):
    for i in range(len(ref)):
        ref[i] = ref[i].lower()
    for i in range(len(hyp)):
        hyp[i] = hyp[i].lower()
    bleu1 = sentence_bleu([ref], hyp, weights=(1, 0, 0, 0), smoothing_function=bleu_score.SmoothingFunction(epsilon=1e-12).method7)
    bleu2 = sentence_bleu([ref], hyp, weights=(1/2, 1/2, 0, 0), smoothing_function=bleu_score.SmoothingFunction(epsilon=1e-12).method7)
    bleu3 = sentence_bleu([ref], hyp, weights=(1/3, 1/3, 1/3, 0), smoothing_function=bleu_score.SmoothingFunction(epsilon=1e-12).method7)
    bleu4 = sentence_bleu([ref], hyp, weights=(1/4, 1/4, 1/4, 1/4), smoothing_function=bleu_score.SmoothingFunction(epsilon=1e-12).method7)
    return bleu1, bleu2, bleu3, bleu4

def bleu_calc_all(originals, generated):
    bleu1_total, bleu2_total, bleu3_total, bleu4_total = 0, 0, 0, 0
    total = 0
    for o, g in zip(originals, generated):
        r = [ i.translate(str.maketrans('', '', string.punctuation)) for i in o][1:]
        h = [ i.translate(str.maketrans('', '', string.punctuation)) for i in g][1:]
        if '[MOVIE_ID]' in r: continue
#         if len(g) >= 500: continue
        if len(g) >= 100: continue
        bleu1, bleu2, bleu3, bleu4 = bleu_calc_one(r, h)
        bleu1_total += bleu1; bleu2_total += bleu2; bleu3_total += bleu3; bleu4_total += bleu4;
        total += 1
    return bleu1_total / total, bleu2_total / total, bleu3_total / total, bleu4_total / total

In [13]:
def replace_placeholder(sentence, movie_titles):
    sen = sentence
    for title in movie_titles:
        sen = sen.replace("[MOVIE_ID]", title, 1)
    return sen
        

def validate_language_metrics_batch(batch, model, item_id_2_lm_token_id):
    role_ids, dialogues = batch
    dialog_tensors = [torch.LongTensor(utterance).to(model.device) for utterance, _ in dialogues]
    
#     past_list = []
    past_tokens = None
    original_sentences = []
    tokenized_sentences = []
    integration_total, integration_cnt = 0, 0
    valid_gen_selected_cnt = 0; total_gen_cnt = 0; response_with_items = 0; original_response_with_items = 0
    
    for turn_num in range(len(role_ids)):
        dial_turn_inputs = dialog_tensors[turn_num]
        _, recommended_ids = dialogues[turn_num]
        
        item_ids = []; item_titles = []
        if recommended_ids != None:
            for r_id in recommended_ids:
                item_ids.append(item_id_2_lm_token_id[r_id])
                title = model.items_db[r_id]
                title = title.split('[SEP]')[0].strip()
                item_titles.append(title)
            item_ids = torch.tensor([item_ids]).to(device)
        
#         if turn_num == 0:
#             past_tokens = dial_turn_inputs
        if role_ids[turn_num] == 0:
            if turn_num == 0:
                past_tokens = dial_turn_inputs
            elif turn_num != 0:
                past_tokens = torch.cat((past_tokens, dial_turn_inputs), dim=1)
        else:
            if turn_num != 0:
                if item_ids != []:
                    rec_start_token = model.lm_tokenizer(model.REC_TOKEN_STR, return_tensors="pt")["input_ids"].to(model.device)
                    rec_end_token = model.lm_tokenizer(model.REC_END_TOKEN_STR, return_tensors="pt")["input_ids"].to(model.device)
                    past_tokens = torch.cat((past_tokens, rec_start_token, item_ids, rec_end_token), dim=1)
                else:
                    past_tokens = past_tokens
                
                total_len = past_tokens.shape[1]
                if total_len >= 1024: break

                original_sen = gpt_tokenizer.decode(dial_turn_inputs[0], skip_special_tokens=True)


                generated = model.language_model.generate(
                    input_ids= torch.cat((past_tokens, torch.tensor([[32, 25]]).to(device)), dim=1),
                    max_length=1024,
                    num_return_sequences=5,
                    do_sample=True,
                    num_beams=5,
                    top_k=50,
                    temperature=1.25,
                    eos_token_id=628,
                    pad_token_id=628,
    #                 no_repeat_ngram_size=3,
                    output_scores=True,
                    return_dict_in_generate=True
                )
                # check valid generations, equal num [MOVIE_ID] placeholders
                total_gen_cnt += 1
                valid_gens = []; valid_gens_scores = []
                final_gen = None
                if len(item_ids) == 0: # no rec items
                    for i in range(len(generated.sequences)):
                        gen_sen = gpt_tokenizer.decode(generated.sequences[i][past_tokens.shape[1]:], skip_special_tokens=True)
                        if gen_sen.count("[MOVIE_ID]") == 0:
                            valid_gens.append(gen_sen); valid_gens_scores.append(generated.sequences_scores[i].item())
                    if valid_gens == [] and valid_gens_scores == []: # no valid, pick with highest score
                        i = torch.argmax(generated.sequences_scores).item()
                        final_gen = gpt_tokenizer.decode(generated.sequences[i][past_tokens.shape[1]:], skip_special_tokens=True)
                    else: # yes valid
                        i = np.argmax(valid_gens_scores)
                        final_gen = valid_gens[i]
                        valid_gen_selected_cnt += 1
                else:
                    original_response_with_items += 1
                    for i in range(len(generated.sequences)):
                        gen_sen = gpt_tokenizer.decode(generated.sequences[i][past_tokens.shape[1]:], skip_special_tokens=True)
                        if gen_sen.count("[MOVIE_ID]") == original_sen.count("[MOVIE_ID]"):
                            valid_gens.append(gen_sen); valid_gens_scores.append(generated.sequences_scores[i].item())
                    if valid_gens == [] and valid_gens_scores == []: # no valid, pick with highest score
                        i = torch.argmax(generated.sequences_scores).item()
                        final_gen = gpt_tokenizer.decode(generated.sequences[i][past_tokens.shape[1]:], skip_special_tokens=True)
                        if "[MOVIE_ID]" in final_gen:
                            response_with_items += 1
                        final_gen = replace_placeholder(final_gen, item_titles)
                    else:
                        i = np.argmax(valid_gens_scores)
                        final_gen = valid_gens[i]
                        if "[MOVIE_ID]" in final_gen:
                            response_with_items += 1
                        final_gen = replace_placeholder(final_gen, item_titles)
                        valid_gen_selected_cnt += 1

    #             generated_sen =  gpt_tokenizer.decode(generated[0][past_tokens.shape[1]:], skip_special_tokens=True)
    #             print("Generated Rec: " + final_gen)
                tokenized_sen = final_gen.strip().split(' ')
                tokenized_sentences.append(tokenized_sen)
                original_sen = replace_placeholder(original_sen, item_titles).replace("\n\n\n", "")
    #             print("Original Rec: " + original_sen)
                original_sentences.append( original_sen.strip().split(' ') )
                if recommended_ids != None:
                    integration_total += 1                        
                    if "[MOVIE_ID]" in final_gen:
                        integration_cnt += 1

                if turn_num != 0:
                    past_tokens = torch.cat((past_tokens, dial_turn_inputs), dim=1)
            elif turn_num == 0:
                original_sen = gpt_tokenizer.decode(dial_turn_inputs[0], skip_special_tokens=True)


                generated = model.language_model.generate(
                    input_ids= torch.tensor([[32, 25]]).to(device),
                    max_length=1024,
                    num_return_sequences=5,
                    do_sample=True,
                    num_beams=5,
                    top_k=50,
                    temperature=1.25,
                    eos_token_id=628,
                    pad_token_id=628,
    #                 no_repeat_ngram_size=3,
                    output_scores=True,
                    return_dict_in_generate=True
                )
                # check valid generations, equal num [MOVIE_ID] placeholders
                total_gen_cnt += 1
                valid_gens = []; valid_gens_scores = []
                final_gen = None
                if len(item_ids) == 0: # no rec items
                    for i in range(len(generated.sequences)):
                        gen_sen = gpt_tokenizer.decode(generated.sequences[i], skip_special_tokens=True)
                        if gen_sen.count("[MOVIE_ID]") == 0:
                            valid_gens.append(gen_sen); valid_gens_scores.append(generated.sequences_scores[i].item())
                    if valid_gens == [] and valid_gens_scores == []: # no valid, pick with highest score
                        i = torch.argmax(generated.sequences_scores).item()
                        final_gen = gpt_tokenizer.decode(generated.sequences[i], skip_special_tokens=True)
                    else: # yes valid
                        i = np.argmax(valid_gens_scores)
                        final_gen = valid_gens[i]
                        valid_gen_selected_cnt += 1
                else:
                    original_response_with_items += 1
                    for i in range(len(generated.sequences)):
                        gen_sen = gpt_tokenizer.decode(generated.sequences[i], skip_special_tokens=True)
                        if gen_sen.count("[MOVIE_ID]") == original_sen.count("[MOVIE_ID]"):
                            valid_gens.append(gen_sen); valid_gens_scores.append(generated.sequences_scores[i].item())
                    if valid_gens == [] and valid_gens_scores == []: # no valid, pick with highest score
                        i = torch.argmax(generated.sequences_scores).item()
                        final_gen = gpt_tokenizer.decode(generated.sequences[i], skip_special_tokens=True)
                        if "[MOVIE_ID]" in final_gen:
                            response_with_items += 1
                        final_gen = replace_placeholder(final_gen, item_titles)
                    else:
                        i = np.argmax(valid_gens_scores)
                        final_gen = valid_gens[i]
                        if "[MOVIE_ID]" in final_gen:
                            response_with_items += 1
                        final_gen = replace_placeholder(final_gen, item_titles)
                        valid_gen_selected_cnt += 1

    #             generated_sen =  gpt_tokenizer.decode(generated[0][past_tokens.shape[1]:], skip_special_tokens=True)
    #             print("Generated Rec: " + final_gen)
                tokenized_sen = final_gen.strip().split(' ')
                tokenized_sentences.append(tokenized_sen)
                original_sen = replace_placeholder(original_sen, item_titles).replace("\n\n\n", "")
    #             print("Original Rec: " + original_sen)
                original_sentences.append( original_sen.strip().split(' ') )
                if recommended_ids != None:
                    integration_total += 1                        
                    if "[MOVIE_ID]" in final_gen:
                        integration_cnt += 1
                
                if turn_num == 0:
                    past_tokens = dial_turn_inputs
            
    return original_sentences, tokenized_sentences, integration_cnt, integration_total, valid_gen_selected_cnt, total_gen_cnt, response_with_items, original_response_with_items

In [14]:
model.eval()
model.annoy_base_constructor()

# pbar = progress_bar(test_dataloader)

item_id_2_lm_token_id = model.lm_expand_wtes_with_items_annoy_base()
pbar = progress_bar(test_dataloader)
total_sentences_original = []; total_sentences_generated = []
integration_cnt, total_int_cnt = 0, 0
valid_cnt, total_gen_cnt, response_with_items = 0, 0, 0
for batch in pbar:
    original_sens, sentences, ic, tc, vc, tgc, rwi, group = validate_language_metrics_batch(batch[0], model, item_id_2_lm_token_id)
#     for s in original_sens:
#         total_sentences_original.append(s)
#     for s in sentences:
#         total_sentences_generated.append(s)
    total_sentences_original.append(original_sens)
    total_sentences_generated.append(sentences)
    
    integration_cnt += ic; total_int_cnt += tc
    valid_cnt += vc; total_gen_cnt += tgc; response_with_items += rwi
integration_ratio = integration_cnt / total_int_cnt
valid_gen_ratio = valid_cnt / total_gen_cnt
# dist1, dist2, dist3, dist4 = distinct_metrics(total_sentences_generated)
# bleu1, bleu2, bleu3, bleu4 = bleu_calc_all(total_sentences_original, total_sentences_generated)
model.lm_restore_wtes(original_token_emb_size)

  0%|          | 0/1342 [00:00<?, ?it/s]

  next_indices = next_tokens // vocab_size


In [15]:
valid_cnt / total_gen_cnt, response_with_items / total_gen_cnt

(0.9145149160190524, 0.4413386813737779)

In [17]:
torch.save(total_sentences_generated, '../human_eval/mese2.pt')

In [22]:
valid_cnt / total_gen_cnt, response_with_items / total_gen_cnt

(0.9173978440711957, 0.4402105790925044)

In [23]:
dist1, dist2, dist3, dist4 = distinct_metrics(total_sentences_generated)
bleu1, bleu2, bleu3, bleu4 = bleu_calc_all(total_sentences_original, total_sentences_generated)
print(dist1, dist2, dist3, dist4)
print(bleu1, bleu2, bleu3, bleu4)

0.2602155928804212 0.7049385810980195 1.0142892955627978 1.156054148909501
0.342754064465557 0.2512299152344354 0.1892696463143165 0.14407982061678162
