In [1]:
# Switch to the parent directory so the relative paths used later (data/, .vector_cache/)
# and, presumably, the project packages (utils, models, configs) resolve from there.
# NOTE(review): os.chdir('../') is relative, so re-running this cell climbs one more directory
# every time -- restart the kernel instead of re-running it.
# NOTE(review): the '../' appended to sys.path is relative too and is resolved against the
# *new* working directory at import time -- confirm it is still needed.
import os
import sys
os.chdir('../')
sys.path.append('../')

import torch
from utils.utils import train,prepare,evaluate,tune
from models.Interactors import FIM_Interactor, KNRM_Interactor
# NOTE(review): FIM_Encoder and SFI_gating are redefined in later cells of this notebook,
# which shadow these imported versions once those cells run.
from models.Encoders.FIM import FIM_Encoder
from models.SFI import SFI_gating, SFI_gating_MultiView
from configs.ManualConfig import hparams

In [2]:
# Experiment name plus the hyper-parameter overrides for this run,
# applied on top of the defaults from configs.ManualConfig.
name = 'sfi'
hparams.update({
    'k': 30,            # history news kept per candidate by SFI_gating's top-k selection
    'his_size': 50,     # number of clicked news per sample
    'select': 'gating',
    'onehot': True,
    'device': 'cuda:0',
    # 'threshold': 0.5,  # switches SFI_gating to its cosine-threshold mode (requires k == his_size)
})

In [3]:
# Build the vocabulary and the data loaders described by hparams.
# loaders[0] is used below to draw a sample batch; presumably it is the training loader
# -- confirm against utils.utils.prepare.
# hparams['validate'] = True
vocab, loaders = prepare(hparams, pin_memory=False)

[2021-04-21 16:26:19,548] INFO (root) Hyper Parameters are
{'scale': 'demo', 'mode': 'train', 'batch_size': 10, 'title_size': 20, 'abs_size': 40, 'his_size': 50, 'vert_num': 18, 'subvert_num': 293, 'npratio': 4, 'dropout_p': 0.2, 'query_dim': 200, 'embedding_dim': 300, 'filter_num': 150, 'value_dim': 16, 'head_num': 16, 'epochs': 8, 'metrics': 'auc,mean_mrr,ndcg@5,ndcg@10', 'device': 'cuda:0', 'attrs': ['title'], 'k': 30, 'select': 'gating', 'save_step': [0], 'news_id': False, 'validate': False, 'interval': 10, 'spadam': True, 'onehot': True, 'val_freq': 1, 'schedule': None}
[2021-04-21 16:26:19,550] INFO (root) preparing dataset...
[2021-04-21 16:26:23,267] INFO (torchtext.vocab) Loading vectors from .vector_cache/glove.840B.300d.txt.pt


In [4]:
# Draw one batch for smoke-testing the model. Judging by SFI_gating.forward_, it is a dict of tensors.
record = next(iter(loaders[0]))

In [5]:
import logging
import math,random
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.Interactors import FIM_Interactor
from models.Attention import Attention


class SFI_gating(nn.Module):
    """ SFI with hard top-k ("gating") selection over the clicked-news history.

    For every candidate, the ``k`` history news that score highest against it (dot product of
    the selection representations) are gathered and passed, together with the candidate, to
    ``interactor``; ``learningToRank`` maps the interactor output to a click score. When
    ``hparams['threshold']`` is set, a soft cosine-similarity gate over all histories is used
    instead of the top-k selection.

    NOTE: this notebook-local class shadows the ``SFI_gating`` imported from ``models.SFI``.
    """

    def __init__(self, hparams, encoder, interactor=None):
        """
        Args:
            hparams: dict; reads 'npratio', 'batch_size', 'his_size', 'title_size', 'k',
                'device' and, optionally, 'threshold'
            encoder: news encoder exposing ``level``, ``hidden_dim`` and ``name``; called in
                ``forward_`` and expected to return (token embeddings, news repr, selection repr)
            interactor: module exposing ``name`` that fuses candidate and selected-history
                embeddings; defaults to ``FIM_Interactor(encoder.level)``

        Raises:
            ValueError: if a threshold is configured while ``k != his_size``
        """
        super().__init__()

        # npratio > 0 gives npratio + 1 candidates per sample (presumably 1 clicked + npratio negatives)
        self.cdd_size = (hparams['npratio'] + 1) if hparams['npratio'] > 0 else 1
        self.batch_size = hparams['batch_size']
        self.his_size = hparams['his_size']
        self.signal_length = hparams['title_size']

        # number of history news kept per candidate
        self.k = hparams['k']

        # contrastive learning is deprecated; the attribute is kept for compatibility
        self.contra_num = 0

        self.encoder = encoder
        self.level = encoder.level
        self.hidden_dim = encoder.hidden_dim

        self.device = hparams['device']
        # elements in the slice along dim will sum up to 1
        self.softmax = nn.Softmax(dim=-1)

        if not interactor:
            self.interactor = FIM_Interactor(self.level)
        else:
            self.interactor = interactor

        # Size of the flattened interactor output. The formula assumes an interactor that shrinks
        # the k and signal_length axes by 3 twice and ends with 16 channels (like the FIM
        # interactor's Conv3d stack) -- TODO confirm when plugging in another interactor.
        final_dim = int(self.k // 3 // 3) * int(self.signal_length // 3 // 3) ** 2 * 16
        self.learningToRank = nn.Linear(final_dim, 1)

        self.name = '-'.join(['sfi-gating', self.encoder.name, self.interactor.name])

        # Optional hook: if a ``selectionProject`` (nn.Linear or an iterable of layers) is present,
        # it is xavier-initialised here and applied before the cosine gate in threshold mode.
        if hasattr(self, 'selectionProject'):
            if isinstance(self.selectionProject, nn.Linear):
                nn.init.xavier_normal_(self.selectionProject.weight)
            else:
                for layer in self.selectionProject:
                    if isinstance(layer, nn.Linear):
                        nn.init.xavier_normal_(layer.weight)

        nn.init.xavier_normal_(self.learningToRank.weight)

        if 'threshold' in hparams and hparams['threshold']:
            # registered as a buffer so that .to(device) moves it together with the model
            self.register_buffer('threshold', torch.tensor([hparams['threshold']]))
            # the threshold gate keeps every history, so the k axis must equal his_size
            if self.k != self.his_size:
                raise ValueError("K value not matched!")

    def _news_attention(self, cdd_repr, his_repr, his_embedding, his_mask):
        """ select (or gate) the history news each candidate attends to

        Args:
            cdd_repr: tensor of [batch_size, cdd_size, hidden_dim]
            his_repr: tensor of [batch_size, his_size, hidden_dim]
            his_embedding: tensor of [batch_size, his_size, signal_length, level, hidden_dim]
            his_mask: boolean tensor of [batch_size, his_size, 1]; True marks histories to
                exclude (e.g. padding). Only used by the top-k path.

        Returns:
            (his_activated, his_focus)
            - top-k mode: his_activated is [batch_size, cdd_size, k, signal_length, level, hidden_dim],
              the embeddings of the k best-scored histories (highest first); his_focus is the
              one-hot selection tensor [batch_size, cdd_size, k, his_size]
            - threshold mode: his_activated is [batch_size, cdd_size, his_size, signal_length, level,
              hidden_dim], every history scaled by its cosine similarity to the candidate and zeroed
              below the threshold; his_focus is None
        """
        if hasattr(self, 'threshold'):
            # forward_ already passes the encoder's *selection* representations, so only project
            # again when this module has a selectionProject of its own (none is defined by default)
            if hasattr(self, 'selectionProject'):
                cdd_repr = self.selectionProject(cdd_repr)
                his_repr = self.selectionProject(his_repr)

            # [bs, cs, hs] cosine similarity
            attn_weights = F.normalize(cdd_repr, dim=-1).matmul(
                F.normalize(his_repr, dim=-1).transpose(-2, -1))
            gate = attn_weights.masked_fill(attn_weights < self.threshold, 0)
            his_activated = his_embedding.unsqueeze(dim=1) * gate.view(
                self.batch_size, self.cdd_size, self.k, 1, 1, 1)
            return his_activated, None

        # [bs, cs, hs] relevance of every history to every candidate
        attn_weights = cdd_repr.matmul(his_repr.transpose(-1, -2))
        # Excluded histories get -inf so they rank last. Softmax is monotonic, so it does not change
        # the ranking below; it is kept to preserve the previous behavior exactly.
        attn_weights = self.softmax(attn_weights.masked_fill(his_mask.transpose(-1, -2), -float("inf")))

        # The selection is a hard, non-differentiable top-k: only the ranking is used.
        _, ranked_index = attn_weights.detach().sort(dim=-1, descending=True)
        topk_index = ranked_index[:, :, :self.k].unsqueeze(dim=-1)

        # One-hot rows via scatter (faster than F.one_hot). Allocated on the embeddings' device
        # and dtype so the matmul below never mixes devices or precisions.
        his_focus = torch.zeros(
            self.batch_size, self.cdd_size, self.k, self.his_size,
            device=his_embedding.device, dtype=his_embedding.dtype,
        ).scatter_(-1, topk_index, 1.0)

        # [bs, cs, k, sl, level, hidden_dim]
        his_activated = torch.matmul(
            his_focus, his_embedding.reshape(self.batch_size, 1, self.his_size, -1)
        ).view(self.batch_size, self.cdd_size, self.k, -1, self.level, self.hidden_dim)
        return his_activated, his_focus

    def _click_predictor(self, fusion_tensors):
        """ calculate a batch of raw click scores

        Args:
            fusion_tensors: tensor of [batch_size, cdd_size, final_dim]

        Returns:
            score: tensor of [batch_size, cdd_size] holding unnormalized scores; ``forward``
                applies log_softmax (cdd_size > 1) or sigmoid (cdd_size == 1)
        """
        return self.learningToRank(fusion_tensors).squeeze(dim=-1)

    def forward_(self, x):
        """ compute unnormalized click scores for a batch

        Args:
            x: dict of batch tensors; reads 'candidate_title', 'candidate_title_pad', 'clicked_title',
               'clicked_title_pad', 'user_index', 'cdd_id', 'his_id' and 'his_mask'

        Returns:
            score: tensor of [batch_size, cdd_size]

        Raises:
            ValueError: if the interactor's name is neither 'knrm' nor 'fim'
        """
        # batches may be smaller than hparams['batch_size'] (e.g. the last one)
        if x['candidate_title'].shape[0] != self.batch_size:
            self.batch_size = x['candidate_title'].shape[0]

        user_index = x['user_index'].long().to(self.device)

        cdd_news = x['candidate_title'].long().to(self.device)
        cdd_news_embedding, cdd_news_repr, cdd_news_repr_selection = self.encoder(
            cdd_news,
            user_index=user_index,
            news_id=x['cdd_id'].long().to(self.device),
            attn_mask=x['candidate_title_pad'].to(self.device))

        his_news = x['clicked_title'].long().to(self.device)
        his_news_embedding, his_news_repr, his_news_repr_selection = self.encoder(
            his_news,
            user_index=user_index,
            news_id=x['his_id'].long().to(self.device),
            attn_mask=x['clicked_title_pad'].to(self.device))

        his_activated, his_focus = self._news_attention(
            cdd_news_repr_selection, his_news_repr_selection, his_news_embedding, x['his_mask'].to(self.device))

        if self.interactor.name == 'knrm':
            cdd_pad = x['candidate_title_pad'].float().to(self.device).view(
                self.batch_size, self.cdd_size, 1, 1, -1, 1)
            if his_focus is not None:
                # pick the padding rows of the selected histories with the same one-hot selection
                his_pad = torch.matmul(his_focus, x['clicked_title_pad'].float().to(self.device).reshape(
                    self.batch_size, 1, self.his_size, -1)).view(self.batch_size, self.cdd_size, self.k, 1, 1, -1, 1)
            else:
                # threshold mode keeps every history (k == his_size), shared by all candidates
                his_pad = x['clicked_title_pad'].float().to(self.device).view(
                    self.batch_size, 1, self.k, 1, 1, self.signal_length, 1
                ).expand(self.batch_size, self.cdd_size, self.k, 1, 1, self.signal_length, 1)
            fusion_tensors = self.interactor(cdd_news_embedding, his_activated, cdd_pad=cdd_pad, his_pad=his_pad)

        elif self.interactor.name == 'fim':
            fusion_tensors = self.interactor(cdd_news_embedding, his_activated)

        else:
            raise ValueError("unsupported interactor: {}".format(self.interactor.name))

        return self._click_predictor(fusion_tensors)

    def forward(self, x):
        """ click scores normalized per sample: log-probabilities over the candidates when
        cdd_size > 1, otherwise sigmoid probabilities """
        score = self.forward_(x)
        if self.cdd_size > 1:
            return F.log_softmax(score, dim=1)
        return torch.sigmoid(score)

In [6]:
class FIM_Encoder(nn.Module):
    """ FIM-style news encoder: hierarchical dilated convolutions plus one LSTM level.

    Each token gets ``level`` (4) representations. Index 0 is an LSTM pass over the word
    embeddings whose initial hidden state is the news' selection representation. Indices 1-3 are
    1-D convolutions with dilation 1, 2 and 3.

    NOTE: this notebook-local class shadows the ``FIM_Encoder`` imported from
    ``models.Encoders.FIM``.
    """

    def __init__(self, hparams, vocab):
        """
        Args:
            hparams: dict; reads 'filter_num', 'embedding_dim', 'dropout_p', 'device' and 'attrs'
            vocab: object whose ``vectors`` tensor initialises the trainable, sparse embedding
        """
        super().__init__()
        self.name = 'fim'

        self.kernel_size = 3
        # 3 dilated-CNN levels + 1 LSTM level (see forward)
        self.level = 4

        self.hidden_dim = hparams['filter_num']
        self.embedding_dim = hparams['embedding_dim']

        # pretrained embedding, fine-tuned with sparse gradients
        self.embedding = nn.Embedding.from_pretrained(vocab.vectors, sparse=True, freeze=False)

        self.ReLU = nn.ReLU()
        # a single LayerNorm shared by the three dilation levels
        self.LayerNorm = nn.LayerNorm(self.hidden_dim)
        self.DropOut = nn.Dropout(p=hparams['dropout_p'])

        self.query_words = nn.Parameter(torch.randn(
            (1, self.hidden_dim), requires_grad=True))
        self.query_levels = nn.Parameter(torch.randn(
            (1, self.hidden_dim), requires_grad=True))

        # padding == dilation keeps the sequence length unchanged for kernel_size 3
        self.CNN_d1 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.hidden_dim,
                                kernel_size=self.kernel_size, dilation=1, padding=1)
        self.CNN_d2 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.hidden_dim,
                                kernel_size=self.kernel_size, dilation=2, padding=2)
        self.CNN_d3 = nn.Conv1d(in_channels=self.embedding_dim, out_channels=self.hidden_dim,
                                kernel_size=self.kernel_size, dilation=3, padding=3)

        self.lstm = nn.LSTM(self.embedding_dim, self.hidden_dim, batch_first=True)
        self.selectionProject = nn.Linear(self.hidden_dim, self.hidden_dim)

        # kept for callers that read it; tensors created here follow their inputs' device instead
        self.device = hparams['device']
        self.attrs = hparams['attrs']

        nn.init.xavier_normal_(self.CNN_d1.weight)
        nn.init.xavier_normal_(self.CNN_d2.weight)
        nn.init.xavier_normal_(self.CNN_d3.weight)
        nn.init.xavier_normal_(self.selectionProject.weight)

        # orthogonal init for every weight matrix of the LSTM (biases are left as-is)
        for layer_weights in self.lstm.all_weights:
            for weight in layer_weights:
                if len(weight.size()) > 1:
                    nn.init.orthogonal_(weight)

    def _HDC(self, news_embedding_set):
        """ stack 1d CNN with dilation rate expanding from 1 to 3

        Args:
            news_embedding_set: tensor of [set_size, signal_length, embedding_dim]

        Returns:
            news_embedding_dilations: tensor of [set_size, signal_length, 3, filter_num], each level
                being ReLU(LayerNorm(conv)) of one dilation rate
        """
        # Conv1d expects channels first: [set_size, embedding_dim, signal_length]
        channels_first = news_embedding_set.transpose(-2, -1)
        dilations = [
            self.ReLU(self.LayerNorm(cnn(channels_first).transpose(-2, -1)))
            for cnn in (self.CNN_d1, self.CNN_d2, self.CNN_d3)
        ]
        # stacking keeps the input's device/dtype instead of a pre-allocated buffer on hparams['device']
        return torch.stack(dilations, dim=-2)

    def forward(self, news_batch, **kwargs):
        """ encode a set of news into token-level and news-level representations

        Args:
            news_batch: token ids of [batch_size, set_size, signal_length]
            **kwargs: accepted for interface compatibility (e.g. user_index, news_id, attn_mask); unused

        Returns:
            news_embedding: [batch_size, set_size, signal_length, level, hidden_dim]; level 0 is the
                LSTM output, levels 1-3 the dilated convolutions
            news_repr: [batch_size, set_size, hidden_dim]
            news_repr_selection: selectionProject(news_repr), [batch_size, set_size, hidden_dim]
        """
        batch_size, set_size, signal_length = news_batch.shape
        num_news = batch_size * set_size

        # [num_news, signal_length, embedding_dim]
        news_embedding_pretrained = self.DropOut(
            self.embedding(news_batch)).view(-1, signal_length, self.embedding_dim)

        # [batch_size, set_size, signal_length, level - 1, hidden_dim]
        news_embedding = self._HDC(news_embedding_pretrained).view(
            news_batch.shape + (self.level - 1, self.hidden_dim))

        # presumably pools over the dilation levels, then over the words (Attention is project code)
        news_embedding_attn = Attention.ScaledDpAttention(
            self.query_levels, news_embedding, news_embedding).squeeze(dim=-2)
        news_repr = Attention.ScaledDpAttention(
            self.query_words, news_embedding_attn, news_embedding_attn
        ).squeeze(dim=-2).view(batch_size, set_size, self.hidden_dim)
        news_repr_selection = self.selectionProject(news_repr)

        # LSTM level: h0 is the selection representation, c0 is zeros on the same device/dtype
        h0 = news_repr_selection.view(1, num_news, self.hidden_dim)
        c0 = h0.new_zeros(1, num_news, self.hidden_dim)
        news_embedding_selection, _ = self.lstm(news_embedding_pretrained, (h0, c0))

        # prepend the LSTM output as level 0
        news_embedding = torch.cat(
            [news_embedding_selection.view(news_batch.shape + (1, self.hidden_dim)), news_embedding], dim=-2)

        return news_embedding, news_repr, news_repr_selection

In [7]:
# Build the (notebook-local) FIM encoder and wrap it in SFI_gating, which uses its default interactor.
encoder = FIM_Encoder(hparams, vocab)

# Alternatives kept for reference:
# interactor = KNRM_Interactor()
# sfi = SFI_gating_MultiView(hparams, encoder, interactor).to(hparams['device'])

# experiment name, e.g. 'sfi-fim-gating'
hparams['name'] = '-'.join((name, encoder.name, hparams['select']))
sfi = SFI_gating(hparams, encoder).to(hparams['device'])

# To resume from a checkpoint (path relative to the repo, e.g. under data/model_params/<name>/):
# sfi.load_state_dict(torch.load(<checkpoint .model path>, map_location=hparams['device'])['model'])

In [8]:
# Snapshot of the encoder's selectionProject weights (independent copy, outside autograd), for comparison after training.
a = sfi.encoder.selectionProject.weight.detach().clone()

In [11]:
# display the selectionProject weight snapshot taken before training
a

tensor([[-0.0928,  0.0415, -0.1030,  ..., -0.1363, -0.0342,  0.0186],
        [-0.0662,  0.0276,  0.1351,  ...,  0.0953,  0.0177, -0.1659],
        [ 0.0164, -0.0031,  0.0452,  ...,  0.0770,  0.0047,  0.0582],
        ...,
        [-0.0755, -0.1510,  0.0390,  ..., -0.1298,  0.0647, -0.0166],
        [-0.1098,  0.0861,  0.0469,  ..., -0.0462, -0.0411,  0.0489],
        [ 0.1958, -0.1618, -0.1387,  ...,  0.0561, -0.0147, -0.0092]],
       device='cuda:0')

In [47]:
# Sanity check: one forward pass on the sample batch. With cdd_size > 1 the output is
# log_softmax over the candidates (dim 1); otherwise it is sigmoid probabilities.
sfi(record)

tensor([[-1.2239, -1.4873, -1.5518, -1.9387, -2.0860],
        [-1.8369, -2.2684, -1.4035, -1.4035, -1.4035],
        [-2.0791, -1.3006, -1.6752, -1.1655, -2.2678],
        [-2.4069, -1.2226, -1.8529, -1.1013, -2.0698],
        [-2.4358, -1.1874, -1.9216, -2.0400, -1.1055],
        [-1.3759, -2.6017, -1.3875, -1.7756, -1.3697],
        [-1.9593, -1.8125, -1.4701, -2.0441, -1.0895],
        [-1.1525, -1.3465, -1.8123, -1.7633, -2.4164],
        [-1.2692, -1.9558, -1.0408, -2.2987, -2.0881],
        [-1.9164, -0.8806, -1.8325, -3.0058, -1.4749]], device='cuda:0',
       grad_fn=<LogSoftmaxBackward>)

In [9]:
# Quick run with a single epoch. This overwrites hparams['epochs'] for the rest of the session.
hparams['epochs'] = 1
train(sfi, hparams, loaders)

[2021-04-21 16:26:37,578] INFO (root) training...
epoch 1 , step 290 , loss: 1.5594: 100%|██████████| 295/295 [00:20<00:00, 14.45it/s]
[2021-04-21 16:26:58,968] INFO (root) saved model of step 0, epoch 1 at data/model_params/sfi-fim-gating/demo_epoch1_step0_[hs=50,topk=30,attrs=title].model
[2021-04-21 16:26:59,055] INFO (root) evaluating...
100%|██████████| 1812/1812 [00:46<00:00, 39.21it/s]
[2021-04-21 16:27:45,786] INFO (root) evaluation results:{'auc': 0.5432, 'mean_mrr': 0.2457, 'ndcg@5': 0.2554, 'ndcg@10': 0.3265, 'epoch': 1, 'step': 0}


SFI_gating(
  (encoder): FIM_Encoder(
    (embedding): Embedding(54316, 300, sparse=True)
    (ReLU): ReLU()
    (LayerNorm): LayerNorm((150,), eps=1e-05, elementwise_affine=True)
    (DropOut): Dropout(p=0.2, inplace=False)
    (CNN_d1): Conv1d(300, 150, kernel_size=(3,), stride=(1,), padding=(1,))
    (CNN_d2): Conv1d(300, 150, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
    (CNN_d3): Conv1d(300, 150, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
    (lstm): LSTM(300, 150, batch_first=True)
    (selectionProject): Linear(in_features=150, out_features=150, bias=True)
  )
  (softmax): Softmax(dim=-1)
  (interactor): FIM_Interactor(
    (SeqCNN3D): Sequential(
      (0): Conv3d(4, 32, kernel_size=[3, 3, 3], stride=(1, 1, 1), padding=(1, 1, 1))
      (1): ReLU()
      (2): MaxPool3d(kernel_size=[3, 3, 3], stride=[3, 3, 3], padding=0, dilation=1, ceil_mode=False)
      (3): Conv3d(32, 16, kernel_size=[3, 3, 3], stride=(1, 1, 1), padding=(1, 1, 1))
      (4): 

In [10]:
# elementwise: True where the snapshot `a` still equals the current selectionProject weight
c = torch.eq(a, sfi.encoder.selectionProject.weight)

In [12]:
# current selectionProject weights, to compare by eye with the snapshot `a`
sfi.encoder.selectionProject.weight

Parameter containing:
tensor([[-0.0823,  0.0512, -0.0861,  ..., -0.1184, -0.0415,  0.0027],
        [-0.0967,  0.0019,  0.0927,  ...,  0.0536, -0.0065, -0.1693],
        [-0.0074, -0.0223,  0.0086,  ...,  0.0305,  0.0033,  0.0437],
        ...,
        [-0.0722, -0.1471,  0.0411,  ..., -0.1110,  0.0561, -0.0364],
        [-0.0979,  0.0891,  0.0523,  ..., -0.0213, -0.0411,  0.0343],
        [ 0.2000, -0.1572, -0.1392,  ...,  0.0425, -0.0041,  0.0112]],
       device='cuda:0', requires_grad=True)

In [13]:
# Snapshot values at the positions that differ from the current weights; empty if nothing changed.
# `~c` is the idiomatic logical-not of a boolean tensor (instead of `c == False`).
a[~c]

tensor([-0.0928,  0.0415, -0.1030,  ...,  0.0561, -0.0147, -0.0092],
       device='cuda:0')