In [1]:
import torch
import os
from tqdm import tqdm
import logging
import random
from collections import Counter
from nltk.tokenize import word_tokenize
import numpy as np

In [2]:
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import IterableDataset, Dataset
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score

In [3]:
def get_sample(all_elements, num_sample):
    if num_sample > len(all_elements):
        return random.sample(all_elements * (num_sample // len(all_elements) + 1), num_sample)
    else:
        return random.sample(all_elements, num_sample)

In [4]:
train_data_dir = './data/MINDsmall_train'

In [5]:
def prepare_training_data(seed = 1009, npratio = 4):
    random.seed(seed)
    behaviors = []
    
    behavior_file_path = os.path.join(train_data_dir, 'behaviors.tsv')
    with open(behavior_file_path, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            iid, uid, time, history, imp = line.strip().split('\t')
            his = history.split(' ')
            # print(his, len(his))
            if his[0] == '':
                continue
            impressions = [x.split('-') for x in imp.split(' ')]
            pos, neg = [], []
            for news_ID, label in impressions:
                if label == '0':
                    neg.append(news_ID)
                elif label == '1':
                    pos.append(news_ID)
            if len(pos) == 0 or len(neg) == 0:
                continue
            for pos_id in pos:
                neg_candidate = get_sample(neg, npratio)
                neg_str = ' '.join(neg_candidate)
                new_line = '\t'.join([iid, uid, time, history, pos_id, neg_str]) + '\n'
                behaviors.append(new_line)

    random.shuffle(behaviors)
    processed_file_path = os.path.join(train_data_dir, f'behaviors_np{npratio}.tsv')
    with open(processed_file_path, 'w') as f:
        f.writelines(behaviors)
    return len(behaviors)

In [6]:
prepare_training_data()

156965it [00:03, 50302.59it/s]


231530

In [7]:
def overlap_info(set1, set2):
    overlap_keys = set(set1.keys()) & set(set2.keys())
    print(f"Number of overlapping keys between news2int and news3int: {len(overlap_keys)}, lenset: {len(set1), len(set2)}")
    print("****set1****")
    for key in set1.keys():
        if key not in overlap_keys:
            print(key)
    print("****set2****")
    for key in set2.keys():
        if key not in overlap_keys:
            print(key)


def get_indexCate(dictCat, key):
    if key in dictCat:
        return dictCat[key]
    print(len(dictCat), key)
    return 0

def update_dict(dict, key, value=None):
    if key not in dict:
        if value is None:
            dict[key] = len(dict) + 1
        else:
            dict[key] = value
            
def read_news(news_path, tt_mat = None, mode = 'train'):
    news = {}
    newsVT = {}
    news_index = {}
    tt_emb = {}
    counter = 0
    with open(news_path, 'r', encoding="utf8") as ifile:
        news_collection = ifile.readlines()
        for line in tqdm(news_collection):
            newsid, category, subcategory, title, abstract, _, _, _ = line.strip().split("\t")
            if newsid in news:
                continue
            emtt = tt_mat[counter]
            counter += 1
            update_dict(news, newsid, [emtt, category, subcategory])
            update_dict(news_index, newsid)
            if mode == 'train':
                update_dict(category_dict, category)
                update_dict(subcategory_dict, subcategory)
            ft = np.concatenate((emtt, [get_indexCate(category_dict, category), get_indexCate(subcategory_dict, subcategory)]))
            update_dict(newsVT, newsid, ft)

    assert  tt_mat.shape[0] == len(news)
    return news, news_index, newsVT


In [93]:
title_encoded = np.load('./data/tt_mat_train.npy')
abs_encoded = np.load('./data/at_mat_train.npy')
tt = np.hstack([title_encoded, abs_encoded])
category_dict = {}
subcategory_dict = {}
news, news_index, newsVT = read_news(os.path.join(train_data_dir, 'news.tsv'), tt)

100%|█████████████████████████████████████████████████████████████████████████| 51282/51282 [00:01<00:00, 50312.30it/s]


### Prepare train Dataset
Trim history of users which longer than 100

In [9]:
def listN2feat(listN, features):
    listF = []
    for newsID in listN:
        listF.append(features[newsID])
    return np.asarray(listF)
def pad_to_fix_len(x, fix_length, padding_front=True, padding_value=0):
    if padding_front:
        pad_x = [padding_value] * (fix_length - len(x)) + x[-fix_length:]
        mask = [0] * (fix_length - len(x)) + [1] * min(fix_length, len(x))
    else:
        pad_x = x[-fix_length:] + [padding_value] * (fix_length - len(x))
        mask = [1] * min(fix_length, len(x)) + [0] * (fix_length - len(x))
    return pad_x, np.array(mask, dtype='float32')

def pad_matrix(matrix, pad_length):
  """Pads a numpy matrix with zeros to a specific shape.

  Args:
    matrix: The numpy matrix to pad.

  Returns:
    A new numpy matrix with zeros padded to shape (x,y).
  """
  padded_matrix = np.zeros((pad_length, matrix.shape[1]))
  padded_matrix[-matrix.shape[0]:, -matrix.shape[1]:] = matrix
  return padded_matrix

In [10]:
class DatasetTrain(Dataset):
    def __init__(self, path, news, fullVT):
        self.data = path
        self.news = news
        self.user_log_length = 100
        self.npratio = 4
        self.behaviors = []
        self.newsVT = fullVT
        with open(path, 'r', encoding='utf-8') as f:
            for line in tqdm(f):
                iid, uid, time, history, pos, neg = line.strip().split('\t')
                negs = [x.split('-')[0] for x in neg.split(' ')]
                histories = [x.split('-')[0] for x in history.split(' ')]
                if len(histories) > 100:
                    histories = histories[-100:]
                self.behaviors.append([uid, histories, pos, negs])

    def __getitem__(self, idx):
        uid, histories, pos, negs = self.behaviors[idx]
        user_feature = pad_matrix(listN2feat(histories, self.newsVT), self.user_log_length)
        _, log_mask = pad_to_fix_len(histories, self.user_log_length)
        label = random.randint(0, self.npratio)
        sample_news = negs[:label]
        sample_news.append(pos)
        sample_news.extend(negs[label:])
        new_ft = listN2feat(sample_news, self.newsVT)
        
        return user_feature, log_mask, new_ft, label

    def __len__(self):
        return len(self.behaviors)

class DatasetTest(Dataset):
    def __init__(self, path, scoring):
        self.data = path
        self.user_log_length = 100
        self.behaviors = []
        self.scoring = scoring
        with open(path, 'r', encoding='utf-8') as f:
            for line in tqdm(f):
                iid, uid, time, history, imp = line.strip().split('\t')
                histories = [x.split('-')[0] for x in history.split(' ')]
                if histories[0] == '':
                    continue
                candidates = [x.split('-')[0] for x in imp.split(' ')]
                labels = np.array([int(i.split('-')[1]) for i in imp.split()])
                if len(histories) > 100:
                    histories = histories[-100:]
                self.behaviors.append([uid, histories, candidates, labels])

    def __getitem__(self, idx):
        uid, histories, candidates, labels = self.behaviors[idx]
        
        user_feature = pad_matrix(listN2feat(histories, self.scoring), self.user_log_length)
        _, log_mask = pad_to_fix_len(histories, self.user_log_length)
        
        new_ft = listN2feat(candidates, self.scoring)
        
        return user_feature, log_mask, new_ft, labels

    def __len__(self):
        return len(self.behaviors)

class NewsDataset(Dataset):
    def __init__(self, data):
        self.data = data
        self.listK = list(data.keys())

    def __getitem__(self, idx):
        return self.listK[idx], self.data[self.listK[idx]]

    def __len__(self):
        return len(self.listK)



def collate_fn(tuple_list):
    log_vecs = torch.FloatTensor([x[0] for x in tuple_list])
    log_mask = torch.FloatTensor([x[1] for x in tuple_list])
    news_vecs = [x[2] for x in tuple_list]
    labels = [x[3] for x in tuple_list]
    return (log_vecs, log_mask, news_vecs, labels)

In [94]:
dataset = DatasetTrain(train_data_dir + '/behaviors_np4.tsv', news, newsVT)
dataloader = DataLoader(dataset, batch_size=128)


231530it [00:03, 71091.92it/s]


In [95]:
class AttentionPooling(nn.Module):
    def __init__(self, emb_size, hidden_size):
        super(AttentionPooling, self).__init__()
        self.att_fc1 = nn.Linear(emb_size, hidden_size)
        self.att_fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x, attn_mask=None):
        """
        Args:
            x: batch_size, candidate_size, emb_dim
            attn_mask: batch_size, candidate_size
        Returns:
            (shape) batch_size, emb_dim
        """
        e = self.att_fc1(x)
        e = nn.Tanh()(e)
        alpha = self.att_fc2(e)
        alpha = torch.exp(alpha)

        if attn_mask is not None:
            alpha = alpha * attn_mask.unsqueeze(2)

        alpha = alpha / (torch.sum(alpha, dim=1, keepdim=True) + 1e-8)
        x = torch.bmm(x.permute(0, 2, 1), alpha).squeeze(dim=-1)
        return x


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k

    def forward(self, Q, K, V, attn_mask=None):
        '''
            Q: batch_size, n_head, candidate_num, d_k
            K: batch_size, n_head, candidate_num, d_k
            V: batch_size, n_head, candidate_num, d_v
            attn_mask: batch_size, n_head, candidate_num
            Return: batch_size, n_head, candidate_num, d_v
        '''
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
        scores = torch.exp(scores)

        if attn_mask is not None:
            scores = scores * attn_mask.unsqueeze(dim=-2)

        attn = scores / (torch.sum(scores, dim=-1, keepdim=True) + 1e-8)
        context = torch.matmul(attn, V)
        return context


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_k, d_v):
        super(MultiHeadSelfAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)

        self.scaled_dot_product_attn = ScaledDotProductAttention(self.d_k)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=1)

    def forward(self, Q, K, V, mask=None):
        '''
            Q: batch_size, candidate_num, d_model
            K: batch_size, candidate_num, d_model
            V: batch_size, candidate_num, d_model
            mask: batch_size, candidate_num
        '''
        batch_size = Q.shape[0]
        if mask is not None:
            mask = mask.unsqueeze(dim=1).expand(-1, self.n_heads, -1)

        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)

        context = self.scaled_dot_product_attn(q_s, k_s, v_s, mask)
        output = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)
        return output




def print_metrics(cnt, x):
    print(cnt, x)

def get_mean(arr):
    return [np.array(i).mean() for i in arr]

def get_sum(arr):
    return [np.array(i).sum() for i in arr]


def dcg_score(y_true, y_score, k=10):
    order = np.argsort(y_score)[::-1]
    y_true = np.take(y_true, order[:k])
    gains = 2**y_true - 1
    discounts = np.log2(np.arange(len(y_true)) + 2)
    return np.sum(gains / discounts)


def ndcg_score(y_true, y_score, k=10):
    best = dcg_score(y_true, y_true, k)
    actual = dcg_score(y_true, y_score, k)
    return actual / best


def mrr_score(y_true, y_score):
    order = np.argsort(y_score)[::-1]
    y_true = np.take(y_true, order)
    rr_score = y_true / (np.arange(len(y_true)) + 1)
    return np.sum(rr_score) / np.sum(y_true)


def ctr_score(y_true, y_score, k=1):
    order = np.argsort(y_score)[::-1]
    y_true = np.take(y_true, order[:k])
    return np.mean(y_true)

In [96]:
class NewsEncoder(nn.Module):
    def __init__(self, num_category, num_subcategory):
        super(NewsEncoder, self).__init__()
        self.drop_rate = 0.2
        category_emb_dim = 100
        news_dim = 768
        news_query_vector_dim = 200
        self.ttEmb = nn.Linear(news_dim, news_dim)
        self.category_emb = nn.Embedding(num_category + 1, category_emb_dim, padding_idx=0)
        self.category_dense = nn.Linear(category_emb_dim, news_dim)
        self.subcategory_emb = nn.Embedding(num_subcategory + 1, category_emb_dim, padding_idx=0)
        self.subcategory_dense = nn.Linear(category_emb_dim, news_dim)
        self.final_attn = AttentionPooling(news_dim, news_query_vector_dim)

    def forward(self, x, mask=None):
        '''
            x: batch_size, word_num
            mask: batch_size, word_num
        '''
        title = x.reshape(-1, 770)[:,:-2].float()
        category = x[..., -2].reshape(-1).long()
        subcat = x[..., -1].reshape(-1).long()
    
        title = self.ttEmb(title)
        all_vecs = [title]

        category_vecs = self.category_dense(self.category_emb(category))
        all_vecs.append(category_vecs)
        subcategory_vecs = self.subcategory_dense(self.subcategory_emb(subcat))
        all_vecs.append(subcategory_vecs)
        
        # all_vecs = torch.cat(all_vecs, dim=1)
        all_vecs = torch.stack(all_vecs, dim=1)
        news_vecs = self.final_attn(all_vecs)
        return news_vecs

In [97]:

class UserEncoder(nn.Module):
    def __init__(self):
        super(UserEncoder, self).__init__()
        news_dim = 768
        user_query_vector_dim = 200
        self.user_log_length = 100
        self.user_log_mask = False
        self.attn = AttentionPooling(news_dim, user_query_vector_dim)
        self.pad_doc = nn.Parameter(torch.empty(1, news_dim).uniform_(-1, 1)).type(torch.FloatTensor)

    def forward(self, news_vecs, log_mask=None):
        '''
            news_vecs: batch_size, history_num, news_dim
            log_mask: batch_size, history_num
        '''
        bz = news_vecs.shape[0]
        if self.user_log_mask:
            user_vec = self.attn(news_vecs, log_mask)
        else:
            padding_doc = self.pad_doc.unsqueeze(dim=0).expand(bz, self.user_log_length, -1)
            news_vecs = news_vecs * log_mask.unsqueeze(dim=-1) + padding_doc * (1 - log_mask.unsqueeze(dim=-1))
            user_vec = self.attn(news_vecs)
        return user_vec

In [98]:
class NAML(torch.nn.Module):
    def __init__(self, num_category, num_subcategory, **kwargs):
        super(NAML, self).__init__()
        self.news_encoder = NewsEncoder(num_category, num_subcategory)
        self.user_encoder = UserEncoder()
        self.loss_fn = nn.CrossEntropyLoss()
        self.npratio = 4
        self.news_dim = 768
        self.user_log_length = 100

    def forward(self, history, history_mask, candidate, label):
        '''
            history: batch_size, history_length, num_word_title
            history_mask: batch_size, history_length
            candidate: batch_size, 1+K, num_word_title
            label: batch_size, 1+K
        '''
        candidate_news_vecs = self.news_encoder(candidate).reshape(-1, 1 + self.npratio, self.news_dim)
        # print("candidate: ")
        history_news_vecs = self.news_encoder(history).reshape(-1, self.user_log_length, self.news_dim)
        # print("history_news_vecs: ")
        user_vec = self.user_encoder(history_news_vecs, history_mask)
        score = torch.bmm(candidate_news_vecs, user_vec.unsqueeze(dim=-1)).squeeze(dim=-1)
        loss = self.loss_fn(score, label)
        # stop
        return loss, score

In [45]:

def acc(y_true, y_hat):
    y_hat = torch.argmax(y_hat, dim=-1)
    tot = y_true.shape[0]
    hit = torch.sum(y_true == y_hat)
    return hit.data.float() * 1.0 / tot

In [99]:
# train():
model = NAML(len(category_dict), len(subcategory_dict))
optimizer = optim.Adam(model.parameters(), lr=0.0003)
model = model.cuda()
torch.set_grad_enabled(True)
model.train()

NAML(
  (news_encoder): NewsEncoder(
    (ttEmb): Linear(in_features=768, out_features=768, bias=True)
    (category_emb): Embedding(18, 100, padding_idx=0)
    (category_dense): Linear(in_features=100, out_features=768, bias=True)
    (subcategory_emb): Embedding(265, 100, padding_idx=0)
    (subcategory_dense): Linear(in_features=100, out_features=768, bias=True)
    (final_attn): AttentionPooling(
      (att_fc1): Linear(in_features=768, out_features=200, bias=True)
      (att_fc2): Linear(in_features=200, out_features=1, bias=True)
    )
  )
  (user_encoder): UserEncoder(
    (attn): AttentionPooling(
      (att_fc1): Linear(in_features=768, out_features=200, bias=True)
      (att_fc2): Linear(in_features=200, out_features=1, bias=True)
    )
  )
  (loss_fn): CrossEntropyLoss()
)

In [100]:
for ep in range(5):
    loss = 0.0
    accuary = 0.0
    print("EPOCH: " + str(ep))
    for cnt, (log_ids, log_mask, input_ids, targets) in tqdm(enumerate(dataloader)):
        log_ids = log_ids.cuda()
        log_mask = log_mask.cuda()
        input_ids = input_ids.cuda()
        targets = targets.cuda()

        bz_loss, y_hat = model(log_ids, log_mask, input_ids, targets)
        loss += bz_loss.data.float()
        accuary += acc(targets, y_hat)
        optimizer.zero_grad()
        bz_loss.backward()
        optimizer.step()
        # stop
    print(loss, accuary)

    #     if rank == 0 and cnt != 0 and cnt % args.save_steps == 0:
    #         ckpt_path = os.path.join(args.model_dir, f'epoch-{ep+1}-{cnt}.pt')
    #         torch.save(
    #             {
    #                 'model_state_dict':
    #                     {'.'.join(k.split('.')[1:]): v for k, v in model.state_dict().items()}
    #                     if is_distributed else model.state_dict(),
    #                 'category_dict': category_dict,
    #                 'word_dict': word_dict,
    #                 'subcategory_dict': subcategory_dict
    #             }, ckpt_path)
    #         logging.info(f"Model saved to {ckpt_path}.")

    # logging.info('Training finish.')

    # if rank == 0:
    #     ckpt_path = os.path.join(args.model_dir, f'epoch-{ep+1}.pt')
    #     torch.save(
    #         {
    #             'model_state_dict':
    #                 {'.'.join(k.split('.')[1:]): v for k, v in model.state_dict().items()}
    #                 if is_distributed else model.state_dict(),
    #             'category_dict': category_dict,
    #             'subcategory_dict': subcategory_dict,
    #             'word_dict': word_dict,
    #         }, ckpt_path)
    #     logging.info(f"Model saved to {ckpt_path}.")

EPOCH: 0


1809it [01:51, 16.25it/s]


tensor(2530.5469, device='cuda:0') tensor(781.0856, device='cuda:0')
EPOCH: 1


1809it [01:50, 16.40it/s]


tensor(2424.9500, device='cuda:0') tensor(832.4359, device='cuda:0')
EPOCH: 2


1809it [01:50, 16.33it/s]


tensor(2392.2183, device='cuda:0') tensor(851.2969, device='cuda:0')
EPOCH: 3


1809it [01:50, 16.37it/s]


tensor(2368.9851, device='cuda:0') tensor(864.1279, device='cuda:0')
EPOCH: 4


1809it [01:50, 16.32it/s]

tensor(2349.6792, device='cuda:0') tensor(874.0905, device='cuda:0')





### Testing

In [32]:
test_data_dir = './data/MINDsmall_dev'

In [101]:
model.cuda()
model.eval()
torch.set_grad_enabled(False)
title_encoded = np.load('./data/tt_mat_test.npy')
abs_encoded = np.load('./data/at_mat_test.npy')
tt = np.hstack([title_encoded, abs_encoded])
news, news_index, newsVT_test = read_news(os.path.join(test_data_dir, 'news.tsv'), tt, 'test')

 68%|████████████████████████████████████████████████▉                       | 28820/42416 [00:00<00:00, 133625.40it/s]

264 lifestyleanimals
264 shop-computers-electronics
264 lifestyletravel
17 games
264 games-news
264 newsvideos


100%|████████████████████████████████████████████████████████████████████████| 42416/42416 [00:00<00:00, 130458.02it/s]

264 newstechnology





In [102]:
news_dataset = NewsDataset(newsVT_test)
news_dataloader = DataLoader(news_dataset, batch_size=128)

In [103]:
news_scoring = {}
with torch.no_grad():
    for k, input_ids in tqdm(news_dataloader):
        input_ids = input_ids.cuda()
        news_vec = model.news_encoder(input_ids)
        news_vec = news_vec.to(torch.device("cpu")).detach().numpy()
        for idx, eachK in enumerate(k):
            news_scoring[eachK] = news_vec[idx]


100%|███████████████████████████████████████████████████████████████████████████████| 332/332 [00:00<00:00, 361.05it/s]


In [104]:
data_file_path = os.path.join(test_data_dir, f'behaviors.tsv')

datasetTest = DatasetTest(data_file_path, news_scoring)
dataloaderTest = DataLoader(datasetTest, batch_size=128, collate_fn=collate_fn)

73152it [00:02, 26983.24it/s]


In [105]:
AUC = []
MRR = []
nDCG5 = []
nDCG10 = []
for cnt, (log_vecs, log_mask, news_vecs, labels) in enumerate(dataloaderTest):
    log_vecs = log_vecs.cuda()
    log_mask = log_mask.cuda()

    user_vecs = model.user_encoder(log_vecs, log_mask).to(torch.device("cpu")).detach().numpy()

    for user_vec, news_vec, label in zip(user_vecs, news_vecs, labels):
        if label.mean() == 0 or label.mean() == 1:
            continue

        score = np.dot(news_vec, user_vec)

        auc = roc_auc_score(label, score)
        mrr = mrr_score(label, score)
        ndcg5 = ndcg_score(label, score, k=5)
        ndcg10 = ndcg_score(label, score, k=10)

        AUC.append(auc)
        MRR.append(mrr)
        nDCG5.append(ndcg5)
        nDCG10.append(ndcg10)

    if cnt % 100 == 0:
        print_metrics(cnt, get_mean([AUC, MRR, nDCG5, nDCG10]))

print_metrics(cnt, get_mean([AUC, MRR, nDCG5, nDCG10]))

0 [0.6587261480798572, 0.3255909975479099, 0.3560354229500778, 0.4207368253952161]
100 [0.6515517596013197, 0.3162897624862829, 0.3464775701666316, 0.4085870939247561]


KeyboardInterrupt: 