In [1]:
import torch
from torch.utils.data import DataLoader, Dataset

from transformers import GPT2Config, GPT2Tokenizer, BertModel, BertTokenizer, DistilBertModel, DistilBertTokenizer
from transformers import AdamW, get_linear_schedule_with_warmup

from utils.InductiveAttentionModels import GPT2InductiveAttentionHeadModel
from utils.SequenceCrossEntropyLoss import SequenceCrossEntropyLoss

import numpy as np
import math
import random
import time
import tqdm

from annoy import AnnoyIndex

In [2]:
bert_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
bert_model_recall = DistilBertModel.from_pretrained('distilbert-base-uncased')
bert_model_rerank = DistilBertModel.from_pretrained('distilbert-base-uncased')
# bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
# bert_model = BertModel.from_pretrained('bert-base-uncased')
gpt_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
gpt2_model = GPT2InductiveAttentionHeadModel.from_pretrained('gpt2')

# REC_TOKEN = "R"
# REC_END_TOKEN = "E"
REC_TOKEN = "[REC]"
REC_END_TOKEN = "[REC_END]"
SEP_TOKEN = "[SEP]"
PLACEHOLDER_TOKEN = "[MOVIE_ID]"
gpt_tokenizer.add_tokens([REC_TOKEN, REC_END_TOKEN, SEP_TOKEN, PLACEHOLDER_TOKEN])
gpt2_model.resize_token_embeddings(len(gpt_tokenizer)) 
original_token_emb_size = gpt2_model.get_input_embeddings().weight.shape[0]
print(original_token_emb_size)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_transform.bias']
- T

50261


In [3]:
class MovieRecDataset(Dataset):
    def __init__(self, data, bert_tok, gpt2_tok):
        self.data = data
        self.bert_tok = bert_tok
        self.gpt2_tok = gpt2_tok
        self.turn_ending = torch.tensor([[628, 198]]) # end of turn, '\n\n\n'
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        dialogue = self.data[index]
        
        dialogue_tokens = []
        
        for utterance, gt_ind in dialogue:
            utt_tokens = self.gpt2_tok(utterance, return_tensors="pt")['input_ids']
            dialogue_tokens.append( ( torch.cat( (utt_tokens, self.turn_ending), dim=1), gt_ind) )
            
        role_ids = None
        previous_role_ids = None
        if role_ids == None:
            role_ids = [ 0 if item[0] == 'B' else 1 for item, _ in dialogue]
            previous_role_ids = role_ids
        else:
            role_ids = [ 0 if item[0] == 'B' else 1 for item, _ in dialogue]
            if not np.array_equal(role_ids, previous_role_ids):
                raise Exception("Role ids dont match between languages")
            previous_role_ids = role_ids
        
        return role_ids, dialogue_tokens
    
    def collate(self, unpacked_data):
        return unpacked_data
    

In [4]:
items_db = torch.load('/local-scratch1/data/by2299/redial_full_movie_db_placeholder')

def sample_ids_from_db(item_db,
                       gt_id, # ground truth id
                       num_samples, # num samples to return
                       include_gt # if we want gt_id to be included
                      ):
    ids_2_sample_from = list(item_db.keys())
    ids_2_sample_from.remove(gt_id)
    if include_gt:
        results = random.sample(ids_2_sample_from, num_samples-1)
        results.append(gt_id)
    else:
        results = random.sample(ids_2_sample_from, num_samples)
    return results

In [5]:
class UniversalCRSModel(torch.nn.Module):
    def __init__(self, 
                 language_model, # backbone of Pretrained LM such as GPT2
                 encoder, # backbone of item encoder such as bert
                 recall_encoder,
                 lm_tokenizer, # language model tokenizer
                 encoder_tokenizer, # item encoder tokenizer
                 device, # Cuda device
                 items_db, # {id:info}, information of all items to be recommended
                 annoy_base_recall=None, # annoy index base of encoded recall embeddings of items
                 annoy_base_rerank=None, # annoy index base of encoded rerank embeddings of items, for validation and inference only
                 recall_item_dim=768, # dimension of each item to be stored in annoy base
                 lm_trim_offset=100, # offset to trim language model wte inputs length = (1024-lm_trim_offset)
                 rec_token_str="[REC]", # special token indicating recommendation and used for recall
                 rec_end_token_str="[REC_END]", # special token indicating recommendation ended, conditional generation starts
                 sep_token_str="[SEP]",
                 placeholder_token_str="[MOVIE_ID]"
                ):
        super(UniversalCRSModel, self).__init__()
        
        #models and tokenizers
        self.language_model = language_model
        self.item_encoder = encoder
        self.recall_encoder = recall_encoder
        self.lm_tokenizer = lm_tokenizer
        self.encoder_tokenizer = encoder_tokenizer
        self.device = device
        
        # item db and annoy index base
        self.items_db = items_db
        self.annoy_base_recall = annoy_base_recall
        self.annoy_base_rerank = annoy_base_rerank
        
        # hyperparameters
        self.recall_item_dim = recall_item_dim
        self.lm_trim_offset = lm_trim_offset
        
        #constants
        self.REC_TOKEN_STR = rec_token_str
        self.REC_END_TOKEN_STR = rec_end_token_str
        self.SEP_TOKEN_STR = sep_token_str
        self.PLACEHOLDER_TOKEN_STR = placeholder_token_str
        
        # map language model hidden states to a vector to query annoy-item-base for recall
        self.recall_lm_query_mapper = torch.nn.Linear(self.language_model.config.n_embd, self.recall_item_dim) # default [768,768]
        # map output of self.item_encoder to vectors to be stored in annoy-item-base 
        self.recall_item_wte_mapper = torch.nn.Linear(self.recall_encoder.config.hidden_size, self.recall_item_dim) # default [768,768]
        # map output of self.item_encoder to a wte of self.language_model
        self.rerank_item_wte_mapper = torch.nn.Linear(self.item_encoder.config.hidden_size, self.language_model.config.n_embd) # default [768,768]
        # map language model hidden states of item wte to a one digit logit for softmax computation
        self.rerank_logits_mapper = torch.nn.Linear(self.language_model.config.n_embd, 1) # default [768,1]
    
    def get_sep_token_wtes(self):
        sep_token_input_ids = self.lm_tokenizer(self.SEP_TOKEN_STR, return_tensors="pt")["input_ids"].to(self.device)
        return self.language_model.transformer.wte(sep_token_input_ids) # [1, 1, self.language_model.config.n_embd]

    def get_rec_token_wtes(self):
        rec_token_input_ids = self.lm_tokenizer(self.REC_TOKEN_STR, return_tensors="pt")["input_ids"].to(self.device)
        return self.language_model.transformer.wte(rec_token_input_ids) # [1, 1, self.language_model.config.n_embd]
    
    def get_rec_end_token_wtes(self):
        rec_end_token_input_ids = self.lm_tokenizer(self.REC_END_TOKEN_STR, return_tensors="pt")["input_ids"].to(self.device)
        return self.language_model.transformer.wte(rec_end_token_input_ids) # [1, 1, self.language_model.config.n_embd]
    
    def get_movie_title(self, m_id):
        title = self.items_db[m_id]
        title = title.split('[SEP]')[0].strip()
        return title
    
    # compute BERT encoded item hidden representation
    # output can be passed to self.recall_item_wte_mapper or self.rerank_item_wte_mapper
    def compute_encoded_embeddings_for_items(self, 
                                             encoder_to_use,
                                             item_ids, # an array of ids, single id should be passed as [id]
                                             items_db_to_use # item databse to use
                                            ):
        chunk_ids = item_ids
        chunk_infos = [items_db_to_use[key] for key in chunk_ids ]
        chunk_tokens = self.encoder_tokenizer(chunk_infos, padding=True, truncation=True, return_tensors="pt")
        chunk_input_ids = chunk_tokens['input_ids'].to(self.device)
        chunk_attention_mask = chunk_tokens['attention_mask'].to(self.device)
        chunk_hiddens = encoder_to_use(input_ids=chunk_input_ids, attention_mask=chunk_attention_mask).last_hidden_state
        # TODO: analyze whether averaging or taking [CLS] gives better result.
#             chunk_pooled = torch.mean(chunk_hiddens, dim=1).cpu().detach().numpy()
#             chunk_pooled = chunk_hiddens[:,0,:].cpu().detach().numpy()

        # average of non-padding tokens
        expanded_mask_size = list(chunk_attention_mask.size())
        expanded_mask_size.append(encoder_to_use.config.hidden_size)
        expanded_mask = chunk_attention_mask.unsqueeze(-1).expand(expanded_mask_size)
        chunk_masked = torch.mul(chunk_hiddens, expanded_mask) # [num_example, len, 768]
        chunk_pooled = torch.sum(chunk_masked, dim=1) / torch.sum(chunk_attention_mask, dim=1).unsqueeze(-1)
        
        # [len(item_ids), encoder_to_use.config.hidden_size], del chunk_hiddens to free up GPU memory
        return chunk_pooled, chunk_hiddens
    
    def annoy_base_constructor_emb_analysis(self, items_db=None, distance_type='angular', chunk_size=50, n_trees=10):
        items_db_to_use = self.items_db if items_db == None else items_db
        all_item_ids = list(items_db_to_use.keys())
        
        total_pooled = []
        # break into chunks/batches for model concurrency
        num_chunks = math.ceil(len(all_item_ids) / chunk_size)
        for i in range(num_chunks):
            chunk_ids = all_item_ids[i*chunk_size: (i+1)*chunk_size]
            chunk_pooled, chunk_hiddens = self.compute_encoded_embeddings_for_items(self.item_encoder, chunk_ids, items_db_to_use)
#             chunk_pooled = chunk_hiddens[:,0,:].cpu().detach().numpy()
            chunk_pooled = chunk_pooled.cpu().detach().numpy()
            del chunk_hiddens
            total_pooled.append(chunk_pooled)
        total_pooled = np.concatenate(total_pooled, axis=0)
        
        pooled_tensor = torch.tensor(total_pooled).to(self.device)
        
        #build recall annoy index
#         annoy_base_recall = AnnoyIndex(self.recall_item_wte_mapper.out_features, distance_type)
#         pooled_recall = self.recall_item_wte_mapper(pooled_tensor) # [len(items_db_to_use), self.recall_item_wte_mapper.out_features]
#         pooled_recall = pooled_recall.cpu().detach().numpy()
#         for i, vector in zip(all_item_ids, pooled_recall):
#             annoy_base_recall.add_item(i, vector)
#         annoy_base_recall.build(n_trees)
        
        #build rerank annoy index, for validation and inference only
        annoy_base_rerank = AnnoyIndex(self.rerank_item_wte_mapper.out_features, distance_type)
#         pooled_rerank = self.rerank_item_wte_mapper(pooled_tensor) # [len(items_db_to_use), self.recall_item_wte_mapper.out_features]
        pooled_rerank = pooled_tensor.cpu().detach().numpy()
        for i, vector in zip(all_item_ids, pooled_rerank):
            annoy_base_rerank.add_item(i, vector)
        annoy_base_rerank.build(n_trees)
        
        del pooled_tensor
        
#         self.annoy_base_recall = annoy_base_recall
        self.annoy_base_rerank = annoy_base_rerank
    
    # annoy_base_constructor, constructs
    def annoy_base_constructor(self, items_db=None, distance_type='angular', chunk_size=50, n_trees=10):
        items_db_to_use = self.items_db if items_db == None else items_db
        all_item_ids = list(items_db_to_use.keys())
        
        total_pooled = []
        # break into chunks/batches for model concurrency
        num_chunks = math.ceil(len(all_item_ids) / chunk_size)
        for i in range(num_chunks):
            chunk_ids = all_item_ids[i*chunk_size: (i+1)*chunk_size]
            chunk_pooled, chunk_hiddens = self.compute_encoded_embeddings_for_items(self.recall_encoder, chunk_ids, items_db_to_use)
            chunk_pooled = chunk_pooled.cpu().detach().numpy()
            del chunk_hiddens
            total_pooled.append(chunk_pooled)
        total_pooled = np.concatenate(total_pooled, axis=0)
        
        pooled_tensor = torch.tensor(total_pooled).to(self.device)
        
        #build recall annoy index
        annoy_base_recall = AnnoyIndex(self.recall_item_wte_mapper.out_features, distance_type)
        pooled_recall = self.recall_item_wte_mapper(pooled_tensor) # [len(items_db_to_use), self.recall_item_wte_mapper.out_features]
        pooled_recall = pooled_recall.cpu().detach().numpy()
        for i, vector in zip(all_item_ids, pooled_recall):
            annoy_base_recall.add_item(i, vector)
        annoy_base_recall.build(n_trees)
        
        total_pooled = []
        # break into chunks/batches for model concurrency
        num_chunks = math.ceil(len(all_item_ids) / chunk_size)
        for i in range(num_chunks):
            chunk_ids = all_item_ids[i*chunk_size: (i+1)*chunk_size]
            chunk_pooled, chunk_hiddens = self.compute_encoded_embeddings_for_items(self.item_encoder, chunk_ids, items_db_to_use)
            chunk_pooled = chunk_pooled.cpu().detach().numpy()
            del chunk_hiddens
            total_pooled.append(chunk_pooled)
        total_pooled = np.concatenate(total_pooled, axis=0)
        
        pooled_tensor = torch.tensor(total_pooled).to(self.device)
        
        #build rerank annoy index, for validation and inference only
        annoy_base_rerank = AnnoyIndex(self.rerank_item_wte_mapper.out_features, distance_type)
        pooled_rerank = self.rerank_item_wte_mapper(pooled_tensor) # [len(items_db_to_use), self.recall_item_wte_mapper.out_features]
        pooled_rerank = pooled_rerank.cpu().detach().numpy()
        for i, vector in zip(all_item_ids, pooled_rerank):
            annoy_base_rerank.add_item(i, vector)
        annoy_base_rerank.build(n_trees)
        
        del pooled_tensor
        
        self.annoy_base_recall = annoy_base_recall
        self.annoy_base_rerank = annoy_base_rerank
    
    def annoy_loader(self, path, annoy_type, distance_type="angular"):
        if annoy_type == "recall":
            annoy_base = AnnoyIndex(self.recall_item_wte_mapper.out_features, distance_type)
            annoy_base.load(path)
            return annoy_base
        elif annoy_type == "rerank":
            annoy_base = AnnoyIndex(self.rerank_item_wte_mapper.out_features, distance_type)
            annoy_base.load(path)
            return annoy_base
        else:
            return None
    
    def lm_expand_wtes_with_items_annoy_base(self):
        all_item_ids = list(self.items_db.keys())
        total_pooled = []
        for i in all_item_ids:
            total_pooled.append(self.annoy_base_rerank.get_item_vector(i))
        total_pooled = np.asarray(total_pooled) # [len(all_item_ids), 768]
        pooled_tensor = torch.tensor(total_pooled, dtype=torch.float).to(self.device)
        
        old_embeddings = self.language_model.get_input_embeddings()
        item_id_2_lm_token_id = {}
        for k in all_item_ids:
            item_id_2_lm_token_id[k] = len(item_id_2_lm_token_id) + old_embeddings.weight.shape[0]
        new_embeddings = torch.cat([old_embeddings.weight, pooled_tensor], 0)
        new_embeddings = torch.nn.Embedding.from_pretrained(new_embeddings)
        self.language_model.set_input_embeddings(new_embeddings)
        self.language_model.to(device)
        return item_id_2_lm_token_id
    
    def lm_restore_wtes(self, original_token_emb_size):
        old_embeddings = self.language_model.get_input_embeddings()
        new_embeddings = torch.nn.Embedding(original_token_emb_size, old_embeddings.weight.size()[1])
        new_embeddings.to(self.device, dtype=old_embeddings.weight.dtype)
        new_embeddings.weight.data[:original_token_emb_size, :] = old_embeddings.weight.data[:original_token_emb_size, :]
        self.language_model.set_input_embeddings(new_embeddings)
        self.language_model.to(self.device)
        assert(self.language_model.get_input_embeddings().weight.shape[0] == original_token_emb_size)
    
    def trim_lm_wtes(self, wtes):
        trimmed_wtes = wtes
        if trimmed_wtes.shape[1] > self.language_model.config.n_positions:
            trimmed_wtes = trimmed_wtes[:,-self.language_model.config.n_positions + self.lm_trim_offset:,:]
        return trimmed_wtes # [batch, self.language_model.config.n_positions - self.lm_trim_offset, self.language_model.config.n_embd]
    
    def trim_positional_ids(self, p_ids, num_items_wtes):
        trimmed_ids = p_ids
        if trimmed_ids.shape[1] > self.language_model.config.n_positions:
            past_ids = trimmed_ids[:,:self.language_model.config.n_positions - self.lm_trim_offset - num_items_wtes]
#             past_ids = trimmed_ids[:, self.lm_trim_offset + num_items_wtes:self.language_model.config.n_positions]
            item_ids = trimmed_ids[:,-num_items_wtes:]
            trimmed_ids = torch.cat((past_ids, item_ids), dim=1)
        return trimmed_ids # [batch, self.language_model.config.n_positions - self.lm_trim_offset]
    
    def compute_inductive_attention_mask(self, length_language, length_rerank_items_wtes):
        total_length = length_language + length_rerank_items_wtes
        language_mask_to_add = torch.zeros((length_language, total_length), dtype=torch.float, device=self.device)
        items_mask_to_add = torch.ones((length_rerank_items_wtes, total_length), dtype=torch.float, device=self.device)
        combined_mask_to_add = torch.cat((language_mask_to_add, items_mask_to_add), dim=0)
        return combined_mask_to_add #[total_length, total_length]
    
    def forward_pure_language_turn(self, 
                                   past_wtes, # past word token embeddings, [1, len, 768]
                                   current_tokens # tokens of current turn conversation, [1, len]
                                  ):
        train_logits, train_targets = None, None
        current_wtes = self.language_model.transformer.wte(current_tokens)
        
        if past_wtes == None:
            lm_outputs = self.language_model(inputs_embeds=current_wtes)
            train_logits = lm_outputs.logits[:, :-1, :]
            train_targets = current_tokens[:,1:]
        else:
            all_wtes = torch.cat((past_wtes, current_wtes), dim=1)
            all_wtes = self.trim_lm_wtes(all_wtes)
            lm_outputs = self.language_model(inputs_embeds=all_wtes)
            train_logits = lm_outputs.logits[:, -current_wtes.shape[1]:-1, :] # skip the last one
            train_targets = current_tokens[:,1:]
        
        # torch.Size([batch, len_cur, lm_vocab]), torch.Size([batch, len_cur]), torch.Size([batch, len_past+len_cur, lm_emb(768)])
        # if past_wtes == None, len_cur - 1
        return train_logits, train_targets
        
    def forward_recall(self, 
                       past_wtes, # past word token embeddings, [1, len, 768]
                       current_tokens, # tokens of current turn conversation, [1, len]
                       gt_item_id, # id, ex. 0
                       num_samples # num examples to sample for training, including groud truth id
                      ):
        # recall step 1. construct LM sequence output
        # LM input composition: [past_wtes, REC_wtes, gt_item_wte, (gt_item_info_wtes), REC_END_wtes, current_wtes ]
        
        REC_wtes = self.get_rec_token_wtes() # [1, 1, self.language_model.config.n_embd]
        gt_item_wte, _ = self.compute_encoded_embeddings_for_items(self.recall_encoder, [gt_item_id], self.items_db)
        gt_item_wte = self.rerank_item_wte_mapper(gt_item_wte) # [1, self.rerank_item_wte_mapper.out_features]
        
        REC_END_wtes = self.get_rec_end_token_wtes() # [1, 1, self.language_model.config.n_embd]
        current_wtes = self.language_model.transformer.wte(current_tokens) #[1, current_tokens.shape[1], self.language_model.config.n_embd]
        
        REC_wtes_len = REC_wtes.shape[1] # 1 by default
        gt_item_wte_len = gt_item_wte.shape[0] # 1 by default
        REC_END_wtes_len = REC_END_wtes.shape[1] # 1 by default
        current_wtes_len = current_wtes.shape[1]
        
        lm_wte_inputs = torch.cat(
            (past_wtes, # [batch (1), len, self.language_model.config.n_embd]
             REC_wtes,
             gt_item_wte.unsqueeze(0), # reshape to [1,1,self.rerank_item_wte_mapper.out_features]
             REC_END_wtes,
             current_wtes # [batch (1), len, self.language_model.config.n_embd]
            ),
            dim=1
        )
        lm_wte_inputs = self.trim_lm_wtes(lm_wte_inputs) # trim for len > self.language_model.config.n_positions
        
        # recall step 2. get gpt output logits and hidden states
        lm_outputs = self.language_model(inputs_embeds=lm_wte_inputs, output_hidden_states=True)
        
        # recall step 3. pull logits (recall, rec_token and language logits of current turn) and compute
        
        # recall logit(s)
        rec_token_start_index = -current_wtes_len-REC_END_wtes_len-gt_item_wte_len-REC_wtes_len
        rec_token_end_index = -current_wtes_len-REC_END_wtes_len-gt_item_wte_len
        # [batch (1), REC_wtes_len, self.language_model.config.n_embd]
        rec_token_hidden = lm_outputs.hidden_states[-1][:, rec_token_start_index:rec_token_end_index, :]
        # [batch (1), self.recall_lm_query_mapper.out_features]
        rec_query_vector = self.recall_lm_query_mapper(rec_token_hidden).squeeze(1)
        
        # sample num_samples item ids to train recall with "recommendation as classification"
        sampled_item_ids = sample_ids_from_db(self.items_db, gt_item_id, num_samples, include_gt=True)
        gt_item_id_index = sampled_item_ids.index(gt_item_id)
        
        # [num_samples, self.item_encoder.config.hidden_size]
        encoded_items_embeddings, _ = self.compute_encoded_embeddings_for_items(self.recall_encoder, sampled_item_ids, self.items_db)
        # to compute dot product with rec_query_vector
        items_key_vectors = self.recall_item_wte_mapper(encoded_items_embeddings) # [num_samples, self.recall_item_wte_mapper.out_features]
        expanded_rec_query_vector = rec_query_vector.expand(items_key_vectors.shape[0], rec_query_vector.shape[1]) # [num_samples, self.recall_item_wte_mapper.out_features]
        recall_logits = torch.sum(expanded_rec_query_vector * items_key_vectors, dim=1) # torch.size([num_samples])
        
        # REC_TOKEN prediction and future sentence prediction
        # hidden rep of the token that's right before REC_TOKEN
        token_before_REC_logits = lm_outputs.logits[:, rec_token_start_index-1:rec_token_end_index-1, :]
        REC_targets = self.lm_tokenizer(self.REC_TOKEN_STR, return_tensors="pt")['input_ids'].to(self.device) # [1, 1]
        
        #language logits and targets
        current_language_logits = lm_outputs.logits[:, -current_wtes_len:-1, :]
        current_language_targets = current_tokens[:,1:]
        
        # REC token and language, their logits and targets
        # [batch, current_wtes_len+REC_wtes_len, lm_vocab]
        all_wte_logits = torch.cat((token_before_REC_logits, current_language_logits), dim=1)
        # [current_wtes_len+REC_wtes_len, lm_vocab]
        all_wte_targets = torch.cat((REC_targets, current_language_targets), dim=1)
        
        # torch.size([num_samples]), id, [batch, current_wtes_len+REC_wtes_len, lm_vocab], [current_wtes_len+REC_wtes_len, lm_vocab]
        return recall_logits, gt_item_id_index, all_wte_logits, all_wte_targets
        
    def forward_rerank(self,
                       past_wtes, # past word token embeddings, [1, len, 768]
                       gt_item_id, # tokens of current turn conversation, [1, len]
                       num_samples, # num examples to sample for training, including groud truth id
                       rerank_items_chunk_size=10, # batch size for encoder GPU computation
                      ):
        # REC wte
        REC_wtes = self.get_rec_token_wtes() # [batch (1), 1, self.language_model.config.n_embd]
        
        #  items wtes to compute rerank loss
        # sample rerank examples
        sampled_item_ids = sample_ids_from_db(self.items_db, gt_item_id, num_samples, include_gt=True)
        gt_item_id_index = sampled_item_ids.index(gt_item_id)
        # compute item wtes by batch
        num_chunks = math.ceil(len(sampled_item_ids) / rerank_items_chunk_size)
        total_wtes = []
        for i in range(num_chunks):
            chunk_ids = sampled_item_ids[i*rerank_items_chunk_size: (i+1)*rerank_items_chunk_size]
            chunk_pooled, _ = self.compute_encoded_embeddings_for_items(self.item_encoder, chunk_ids, self.items_db) # [rerank_items_chunk_size, self.item_encoder.config.hidden_size]
            chunk_wtes = self.rerank_item_wte_mapper(chunk_pooled)
            total_wtes.append(chunk_wtes)
        total_wtes = torch.cat(total_wtes, dim=0) # [num_samples, self.language_model.config.n_embd]
        
        past_wtes_len = past_wtes.shape[1]
        REC_wtes_len = REC_wtes.shape[1] # 1 by default
        total_wtes_len = total_wtes.shape[0]
        
        # compute positional ids, all rerank item wte should have the same positional encoding id 0
        position_ids = torch.arange(0, past_wtes_len + REC_wtes_len, dtype=torch.long, device=self.device)
        items_position_ids = torch.zeros(total_wtes.shape[0], dtype=torch.long, device=device)
#         items_position_ids = torch.tensor([1023] * total_wtes.shape[0], dtype=torch.long, device=device)
        combined_position_ids = torch.cat((position_ids, items_position_ids), dim=0)
        combined_position_ids = combined_position_ids.unsqueeze(0) # [1, past_wtes_len+REC_wtes_len+total_wtes_len]
        
        # compute concatenated lm wtes
        lm_wte_inputs = torch.cat(
            (past_wtes, # [batch (1), len, self.language_model.config.n_embd]
             REC_wtes, # [batch (1), 1, self.language_model.config.n_embd]
             total_wtes.unsqueeze(0), # [1, num_samples, self.language_model.config.n_embd]
            ),
            dim=1
        ) # [1, past_len + REC_wtes_len + num_samples, self.language_model.config.n_embd]

        # trim sequence to smaller length (len < self.language_model.config.n_positions-self.lm_trim_offset)
        combined_position_ids_trimmed = self.trim_positional_ids(combined_position_ids, total_wtes_len) # [1, len]
        lm_wte_inputs_trimmed = self.trim_lm_wtes(lm_wte_inputs) # [1, len, self.language_model.config.n_embd]
        assert(combined_position_ids.shape[1] == lm_wte_inputs.shape[1])
        
        # compute inductive attention mask
        #     Order of recommended items shouldn't affect their score, thus every item 
        # should have full attention over the entire sequence: they should know each other and the entire
        # conversation history
        inductive_attention_mask = self.compute_inductive_attention_mask(
            lm_wte_inputs_trimmed.shape[1]-total_wtes.shape[0], 
            total_wtes.shape[0]
        )
        rerank_lm_outputs = self.language_model(inputs_embeds=lm_wte_inputs_trimmed,
                  inductive_attention_mask=inductive_attention_mask,
                  position_ids=combined_position_ids_trimmed,
                  output_hidden_states=True)
        
        rerank_lm_hidden = rerank_lm_outputs.hidden_states[-1][:, -total_wtes.shape[0]:, :]
        rerank_logits = self.rerank_logits_mapper(rerank_lm_hidden).squeeze() # torch.Size([num_samples])
        
        return rerank_logits, gt_item_id_index
    
    def validation_perform_recall(self, past_wtes, topk):
        REC_wtes = self.get_rec_token_wtes()
        lm_wte_inputs = torch.cat(
            (past_wtes, # [batch (1), len, self.language_model.config.n_embd]
             REC_wtes # [1, 1, self.language_model.config.n_embd]
            ),
            dim=1
        )
        lm_wte_inputs = self.trim_lm_wtes(lm_wte_inputs) # trim for len > self.language_model.config.n_positions
        lm_outputs = self.language_model(inputs_embeds=lm_wte_inputs, output_hidden_states=True)
        
        rec_token_hidden = lm_outputs.hidden_states[-1][:, -1, :]
        # [batch (1), self.recall_lm_query_mapper.out_features]
        rec_query_vector = self.recall_lm_query_mapper(rec_token_hidden).squeeze(0) # [768]
        rec_query_vector = rec_query_vector.cpu().detach().numpy()
        recall_results = self.annoy_base_recall.get_nns_by_vector(rec_query_vector, topk)
        return recall_results
    
    def validation_perform_rerank(self, past_wtes, recalled_ids):
        REC_wtes = self.get_rec_token_wtes()
        
        total_wtes = [ self.annoy_base_rerank.get_item_vector(r_id) for r_id in recalled_ids]
        total_wtes = [ torch.tensor(wte).reshape(-1, self.language_model.config.n_embd).to(self.device) for wte in total_wtes]
        total_wtes = torch.cat(total_wtes, dim=0) # [len(recalled_ids), 768]
        
        past_wtes_len = past_wtes.shape[1]
        REC_wtes_len = REC_wtes.shape[1] # 1 by default
        total_wtes_len = total_wtes.shape[0]
        
        # compute positional ids, all rerank item wte should have the same positional encoding id 0
        position_ids = torch.arange(0, past_wtes_len + REC_wtes_len, dtype=torch.long, device=self.device)
        items_position_ids = torch.zeros(total_wtes.shape[0], dtype=torch.long, device=device)
#         items_position_ids = torch.tensor([1023] * total_wtes.shape[0], dtype=torch.long, device=device)
        combined_position_ids = torch.cat((position_ids, items_position_ids), dim=0)
        combined_position_ids = combined_position_ids.unsqueeze(0) # [1, past_wtes_len+REC_wtes_len+total_wtes_len]
        
        # compute concatenated lm wtes
        lm_wte_inputs = torch.cat(
            (past_wtes, # [batch (1), len, self.language_model.config.n_embd]
             REC_wtes, # [batch (1), 1, self.language_model.config.n_embd]
             total_wtes.unsqueeze(0), # [1, num_samples, self.language_model.config.n_embd]
            ),
            dim=1
        ) # [1, past_len + REC_wtes_len + num_samples, self.language_model.config.n_embd]

        # trim sequence to smaller length (len < self.language_model.config.n_positions-self.lm_trim_offset)
        combined_position_ids_trimmed = self.trim_positional_ids(combined_position_ids, total_wtes_len) # [1, len]
        lm_wte_inputs_trimmed = self.trim_lm_wtes(lm_wte_inputs) # [1, len, self.language_model.config.n_embd]
        assert(combined_position_ids.shape[1] == lm_wte_inputs.shape[1])
        
        inductive_attention_mask = self.compute_inductive_attention_mask(
            lm_wte_inputs_trimmed.shape[1]-total_wtes.shape[0], 
            total_wtes.shape[0]
        )
        rerank_lm_outputs = self.language_model(inputs_embeds=lm_wte_inputs_trimmed,
                  inductive_attention_mask=inductive_attention_mask,
                  position_ids=combined_position_ids_trimmed,
                  output_hidden_states=True)
        
        rerank_lm_hidden = rerank_lm_outputs.hidden_states[-1][:, -total_wtes.shape[0]:, :]
        rerank_logits = self.rerank_logits_mapper(rerank_lm_hidden).squeeze()
        
        return rerank_logits
    
    # TODO: Modify
    def validation_generate_sentence(self, past_wtes, recommended_id):
        if recommended_id == None: #  pure language
            lm_wte_inputs = self.trim_lm_wtes(past_wtes)
            lm_outputs = self.language_model(inputs_embeds=lm_wte_inputs, output_hidden_states=True)
            generated = model.language_model.generate(
                past_key_values=lm_outputs.past_key_values, 
                max_length=50, 
                num_return_sequences=1, 
                do_sample=True, 
                eos_token_id=198,
                pad_token_id=198
            )
            generated_sen = gpt_tokenizer.decode(generated[0], skip_special_tokens=True)
            return generated_sen, None
        else:
            REC_wtes = self.get_rec_token_wtes() # [1, 1, self.language_model.config.n_embd]
            gt_item_wte, _ = self.compute_encoded_embeddings_for_items([recommended_id], self.items_db)
            gt_item_wte = self.rerank_item_wte_mapper(gt_item_wte) # [1, self.rerank_item_wte_mapper.out_features]
            REC_END_wtes = self.get_rec_end_token_wtes() # [1, 1, self.language_model.config.n_embd]

            REC_wtes_len = REC_wtes.shape[1] # 1 by default
            gt_item_wte_len = gt_item_wte.shape[0] # 1 by default
            REC_END_wtes_len = REC_END_wtes.shape[1] # 1 by default

            lm_wte_inputs = torch.cat(
                (past_wtes, # [batch (1), len, self.language_model.config.n_embd]
                 REC_wtes,
                 gt_item_wte.unsqueeze(0), # reshape to [1,1,self.rerank_item_wte_mapper.out_features]
                 REC_END_wtes
                ),
                dim=1
            )
            lm_wte_inputs = self.trim_lm_wtes(lm_wte_inputs) # trim for len > self.language_model.config.n_positions
            lm_outputs = self.language_model(inputs_embeds=lm_wte_inputs, output_hidden_states=True)
            generated = model.language_model.generate(
                past_key_values=lm_outputs.past_key_values, 
                max_length=50, 
                num_return_sequences=1, 
                do_sample=True, 
                eos_token_id=198,
                pad_token_id=198
            )
            generated_sen = gpt_tokenizer.decode(generated[0], skip_special_tokens=True)
            return generated_sen, self.get_movie_title(recommended_id)

In [6]:
device = torch.device(0)
model = UniversalCRSModel(gpt2_model, bert_model_rerank, bert_model_recall, gpt_tokenizer, bert_tokenizer, device, items_db, rec_token_str=REC_TOKEN, rec_end_token_str=REC_END_TOKEN)
# model.load_state_dict(torch.load("/local-scratch1/data/by2299/CRS_Redial_Train_1st_normal_lang_sep_annoy_7.pt"))

model.to(device)
pass
# model.recall_item_wte_mapper.out_features

In [7]:
# model.annoy_base_constructor()
model.annoy_base_constructor_emb_analysis()

In [7]:
movie_2_genres = {}
for k, v in items_db.items():
    genres = v.split("[SEP]")[3].strip().split(",")
#     if len(genres) != 1 or genres[0] == '': continue
    if '' in genres:
        continue
    for i in range(len(genres)):
        genres[i] = genres[i].strip()
        if 'Sci-Fi' in genres[i]:
            genres[i] = 'Science Fiction'
        if 'Musical' in genres[i]:
            genres[i] = 'Music'
    movie_2_genres[k] = genres

In [8]:
distinct_genres = {}
for k, genres in movie_2_genres.items():
    for g in genres:
        if g.strip() not in distinct_genres:
            distinct_genres[g.strip()] = len(distinct_genres)
distinct_genres

{'Animation': 0,
 'Action': 1,
 'Adventure': 2,
 'Comedy': 3,
 'Drama': 4,
 'Crime': 5,
 'Fantasy': 6,
 'Horror': 7,
 'Science Fiction': 8,
 'Thriller': 9,
 'Mystery': 10,
 'Romance': 11,
 'Music': 12,
 'Family': 13,
 'Biography': 14,
 'Sport': 15,
 'History': 16,
 'War': 17,
 'Film-Noir': 18,
 'Documentary': 19,
 'TV Movie': 20,
 'Western': 21,
 'Foreign': 22,
 'Short': 23,
 'Adult': 24}

In [9]:
wtes = []
genre_labels = []
for k, genres in movie_2_genres.items():
    wte = model.annoy_base_rerank.get_item_vector(k)
    wtes.append(wte)
    genre_vector = [0] * len(distinct_genres)
    for g in genres:
        genre_vector[distinct_genres[g]] = 1
    genre_labels.append(genre_vector)

In [10]:
portion = 4800
train_wtes = wtes[:portion]; train_labels = genre_labels[:portion]
test_wtes = wtes[portion:]; test_labels = genre_labels[portion:]

In [11]:
class MultiClassModel(torch.nn.Module):
    def __init__(self):
        super(MultiClassModel, self).__init__()
        self.core = torch.nn.Linear(768, len(distinct_genres))
    def forward(self, vectors):
        return self.core(vectors)
model = MultiClassModel()

In [12]:
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.03)

In [13]:
#eval
def eval_func(features_test, labels_test):
    total = 0; correct = 0
    features_test = torch.tensor( features_test, dtype=torch.float )
    labels_test = torch.tensor( labels_test, dtype=torch.float )
    preds = model.forward(features_test)
    sig = torch.nn.Sigmoid()
    preds = sig(preds)
    for pred, gt in zip(preds, labels_test):
        c_vec = [0] * len(distinct_genres)
        for i in range(len(gt)):
#             if pred[i] >= 0.5:
#                 c_vec[i] = 1
#             else:
#                 c_vec[i] = 0
#         print(gt)
#         print(c_vec)
#         print()
            if gt[i] == 1:
                total += 1
                if pred[i] >= 0.5:
                    correct += 1
#             elif gt[i] == 0 and pred[i] < 0.5:
#                 correct += 1
    return correct, total
                

In [14]:
#train
batch_size = 100
epochs = 50
for e in range(epochs):
    model.train()
    for i in range(int(portion/batch_size)):
        features_train = torch.tensor( train_wtes[i*batch_size : (i+1)*batch_size], dtype=torch.float )
        labels_train = torch.tensor( train_labels[i*batch_size : (i+1)*batch_size], dtype=torch.float )
        outputs = model.forward(features_train)
        loss = criterion(outputs, labels_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    model.eval()
    correct, total = eval_func(test_wtes, test_labels)
    print(f"Epoch {e}, ratio {correct/total}")

Epoch 0, ratio 0.5061728395061729
Epoch 1, ratio 0.5653021442495126
Epoch 2, ratio 0.5912930474333983
Epoch 3, ratio 0.6140350877192983
Epoch 4, ratio 0.6231319038336582
Epoch 5, ratio 0.6367771280051981
Epoch 6, ratio 0.6432748538011696
Epoch 7, ratio 0.6510721247563352
Epoch 8, ratio 0.6588693957115009
Epoch 9, ratio 0.6653671215074723
Epoch 10, ratio 0.6686159844054581
Epoch 11, ratio 0.6757634827810266
Epoch 12, ratio 0.6848602988953866
Epoch 13, ratio 0.6900584795321637
Epoch 14, ratio 0.6939571150097466
Epoch 15, ratio 0.6952566601689408
Epoch 16, ratio 0.6985055230669266
Epoch 17, ratio 0.7030539311241065
Epoch 18, ratio 0.7069525666016894
Epoch 19, ratio 0.7108512020792722
Epoch 20, ratio 0.7166991552956465
Epoch 21, ratio 0.7179987004548408
Epoch 22, ratio 0.7179987004548408
Epoch 23, ratio 0.7205977907732294
Epoch 24, ratio 0.723196881091618
Epoch 25, ratio 0.7257959714100065
Epoch 26, ratio 0.7264457439896036
Epoch 27, ratio 0.7277452891487979
Epoch 28, ratio 0.7283950617283

In [None]:
#base 0.743
# ep0 0.871
# ep1 0.860
# ep2 0.793
# ep3 0.782
# ep4 0.719

In [29]:
#train
batch_size = 100
epochs = 50
for e in range(epochs):
    model.train()
    for i in range(int(4800/batch_size)):
        features_train = torch.tensor( train_wtes[i*100 : (i+1)*100], dtype=torch.float )
        labels_train = torch.tensor( train_labels[i*100 : (i+1)*100], dtype=torch.float )
        outputs = model.forward(features_train)
        loss = criterion(outputs, labels_train)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    model.eval()
    correct, total = eval_func(test_wtes, test_labels)
    print(f"Epoch {e}, ratio {correct/total}")

Epoch 0, ratio 0.2014294996751137
Epoch 1, ratio 0.2378167641325536
Epoch 2, ratio 0.28914879792072773
Epoch 3, ratio 0.3398310591293047
Epoch 4, ratio 0.3742690058479532
Epoch 5, ratio 0.39246263807667314
Epoch 6, ratio 0.41325536062378165
Epoch 7, ratio 0.4301494476933073
Epoch 8, ratio 0.44834307992202727
Epoch 9, ratio 0.45938921377517866
Epoch 10, ratio 0.47043534762833006
Epoch 11, ratio 0.4821312540610786
Epoch 12, ratio 0.48927875243664715
Epoch 13, ratio 0.49577647823261856
Epoch 14, ratio 0.5029239766081871
Epoch 15, ratio 0.50682261208577
Epoch 16, ratio 0.5139701104613386
Epoch 17, ratio 0.5224171539961013
Epoch 18, ratio 0.5282651072124757
Epoch 19, ratio 0.5308641975308642
Epoch 20, ratio 0.5386614684860299
Epoch 21, ratio 0.5432098765432098
Epoch 22, ratio 0.5477582846003899
Epoch 23, ratio 0.5516569200779727
Epoch 24, ratio 0.557504873294347
Epoch 25, ratio 0.5594541910331384
Epoch 26, ratio 0.5607537361923327
Epoch 27, ratio 0.5627030539311241
Epoch 28, ratio 0.5659519

In [7]:
single_genre = {}
for k, v in items_db.items():
    if len(v.split("[SEP]")) < 5: continue
    genres = v.split("[SEP]")[3].strip().split(",")
    if genres[0] == "Documentary":
        continue
    if len(genres) == 1 and genres[0] != '':
        cleaned_genre = genres[0]
        if genres[0] == 'Music':
            cleaned_genre = 'Musical'
        if genres[0] == 'Sci-Fi':
            cleaned_genre = 'Science Fiction'
        single_genre[k] = cleaned_genre

In [8]:
from collections import Counter
Counter(single_genre.values())

Counter({'Drama': 192,
         'Comedy': 218,
         'Romance': 5,
         'Horror': 69,
         'Fantasy': 3,
         'Thriller': 17,
         'Animation': 5,
         'Science Fiction': 6,
         'Mystery': 4,
         'Action': 7,
         'Western': 3,
         'Adventure': 3,
         'Adult': 1,
         'Short': 3,
         'Musical': 3,
         'Family': 5,
         'Crime': 1})

In [10]:
wtes = []
genre_labels = []
for k,v in single_genre.items():
    wte = model.annoy_base_rerank.get_item_vector(k)
    wtes.append(wte)
    genre_labels.append(v)

In [17]:
from sklearn.cluster import KMeans
from collections import defaultdict
sample_times = 20
avg_acc = []
for rs in range(sample_times):
    kmeans = KMeans(n_clusters=3, random_state=rs).fit(wtes)
    id_2_main_genre = {}
    for i in range(len(wtes)):
        cluster_id = kmeans.labels_[i]
        genre_label = genre_labels[i]
        if cluster_id not in id_2_main_genre:
            id_2_main_genre[cluster_id] = defaultdict(int)
        id_2_main_genre[cluster_id][genre_label] += 1
    class_name = []
    acc = []
    for key in id_2_main_genre.keys():
        cluster = id_2_main_genre[key]
        cnter = Counter(cluster)
        majority = cnter.most_common(1)[0]
        total_cnt = sum(cluster.values())
        main_acc = majority[1] / total_cnt
        class_name.append(majority[0])
        acc.append(main_acc)
    avg_acc.append(np.mean(acc))

In [18]:
np.mean(avg_acc)

0.49274422012473573

In [None]:
# no training
3: 0.49274422012473573
4: 0.5136649549291342
5: 0.5738269301657036
    
#no genre, no plot
# 3: 0.6044874826639764
# 4: 0.6664281020406956
# 5: 0.689556041130807
3: 0.5554836134916321
4: 0.589166330158624
5: 0.6060811410293794

# sep annoy
3: 0.694558391935933
4: 0.7253916217797459
5: 0.7375563891288288

In [29]:
# without Documentary
old:
    3, 0.4936322713287639
    4, 0.5144524780708778
    5, 0.5736495385033692

old CLS:
    3, 0.43372707112345416
    4, 0.45186175234274845
    5, 0.490816081925875
    
trained:
    3, 0.6068449328399261
    4, 0.7047642984671186
    5, 0.7459912269026595
    
# with documentary
old:
    3, 0.5419993871238592
    4, 0.5654777377625189
    5, 0.5780730722362244

old CLS:
    3, 0.5253262398313966
    4, 0.5285883323843146
    5, 0.512210301337264

trained:
    3, 0.5557598070972909
    4, 0.6259810961551399
    5, 0.7178085603955927

SyntaxError: invalid syntax (3318023016.py, line 1)

In [11]:

id_2_main_genre = {}
for i in range(len(wtes)):
    cluster_id = kmeans.labels_[i]
    genre_label = genre_labels[i]
    if cluster_id not in id_2_main_genre:
        id_2_main_genre[cluster_id] = defaultdict(int)
    id_2_main_genre[cluster_id][genre_label] += 1
        

In [12]:
class_name = []
acc = []
for rs in range(20):
for key in id_2_main_genre.keys():
    cluster = id_2_main_genre[key]
    cnter = Counter(cluster)
    majority = cnter.most_common(1)[0]
    total_cnt = sum(cluster.values())
    main_acc = majority[1] / total_cnt
    class_name.append(majority[0])
    acc.append(main_acc)