In [None]:
import os
import sys
os.chdir('../')
sys.path.append('../')

import torch
from utils.utils import train,prepare,evaluate,tune
from models.Interactors import FIM_Interactor, KNRM_Interactor
from models.Encoders.FIM import FIM_Encoder
from models.SFI import SFI_gating, SFI_gating_MultiView
from configs.ManualConfig import hparams

In [None]:
# Experiment configuration: overrides applied on top of the project defaults in `hparams`.
name = 'sfi'
# number of clicked news selected per candidate
# NOTE(review): the model-building cell resets k to his_size before constructing the model.
hparams['k'] = 30
# size of the clicked-history dimension (presumably max clicked news kept per user — confirm)
hparams['his_size'] = 50
hparams['select'] = 'gating'
hparams['onehot'] = True
# Fall back to CPU so the notebook (including checkpoint loading) also runs without a GPU.
hparams['device'] = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Threshold-based selection (hparams['threshold']) is configured where the model is built.

In [None]:
# Obtain the vocabulary and the data loaders from the project's prepare() helper.
# NOTE(review): a `hparams['validate'] = True` switch was left disabled here — confirm its effect on prepare().
vocab, loaders = prepare(
    hparams,
    pin_memory=False,
)

In [None]:
# A single batch from the first loader, used further down for a smoke-test forward pass.
record = next(
    iter(loaders[0])
)

In [None]:
import logging
import math,random
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.Interactors import FIM_Interactor

class SFI_gating(nn.Module):
    """SFI model with hard ("gating") selection over the user's clicked history.

    NOTE: this notebook-local definition shadows the ``SFI_gating`` imported from
    ``models.SFI`` in the first cell.

    Two selection modes:
      * top-k (default): for every candidate the ``k`` clicked news with the highest
        attention weight are picked via a one-hot matrix;
      * threshold (``hparams['threshold']`` set): every clicked news is kept, scaled by
        its cosine similarity to the candidate, and zeroed when that similarity is
        below the threshold (requires ``k == his_size``).
    The selected history is fused with the candidate by ``interactor`` and scored by
    a linear layer.

    Args:
        hparams: dict-like config; reads 'npratio', 'batch_size', 'his_size',
            'title_size', 'k', 'device' and optionally 'threshold'.
        encoder: news encoder module exposing ``level``, ``hidden_dim`` and ``name``.
        interactor: fusion module exposing ``name`` ('fim' or 'knrm'); defaults to
            ``FIM_Interactor(encoder.level)``.

    Raises:
        ValueError: if a threshold is configured and ``k != his_size``.
    """

    def __init__(self, hparams, encoder, interactor=None):
        super().__init__()

        # npratio + 1 candidates per sample when npratio > 0, otherwise a single one
        self.cdd_size = hparams['npratio'] + 1 if hparams['npratio'] > 0 else 1
        self.batch_size = hparams['batch_size']
        self.his_size = hparams['his_size']
        self.signal_length = hparams['title_size']

        self.k = hparams['k']

        # contrastive learning is deprecated; attribute kept for compatibility
        self.contra_num = 0

        self.encoder = encoder
        self.level = encoder.level
        self.hidden_dim = encoder.hidden_dim

        self.device = hparams['device']
        # normalizes over the last (history) dimension
        self.softmax = nn.Softmax(dim=-1)

        if not interactor:
            self.interactor = FIM_Interactor(self.level)
        else:
            self.interactor = interactor

        # Flattened size of the interactor output.
        # NOTE(review): presumably mirrors FIM_Interactor's conv/pool stack (two stride-3
        # reductions over k and signal_length, 16 output channels) — keep in sync with it.
        final_dim = (self.k // 3 // 3) * (self.signal_length // 3 // 3) ** 2 * 16
        self.learningToRank = nn.Linear(final_dim, 1)

        self.name = '-'.join(['sfi-gating', self.encoder.name, self.interactor.name])

        self.selectionProject = nn.Linear(self.hidden_dim, self.hidden_dim)

        # NOTE(review): the LSTM is never used in forward; it is kept so that existing
        # checkpoints (presumably containing its weights) still load with strict=True.
        self.lstm = nn.LSTM(self.hidden_dim, self.hidden_dim // 2, batch_first=True, bidirectional=True)

        # Initialization order is kept unchanged so RNG consumption stays reproducible.
        nn.init.xavier_normal_(self.selectionProject.weight)
        nn.init.xavier_normal_(self.learningToRank.weight)

        threshold = hparams.get('threshold')
        if threshold:
            # a buffer follows the module across .to(device) and is saved in the state_dict
            self.register_buffer('threshold', torch.tensor([threshold]))
            # the threshold path keeps every history slot, so k must cover all of them
            if self.k != self.his_size:
                raise ValueError("K value not matched!")

    def _news_attention(self, cdd_repr, his_repr, his_embedding, his_mask):
        """Select, for every candidate, the clicked news it interacts with.

        Args:
            cdd_repr: tensor of [batch_size, cdd_size, hidden_dim]
            his_repr: tensor of [batch_size, his_size, hidden_dim]
            his_embedding: tensor of [batch_size, his_size, signal_length, level, hidden_dim]
            his_mask: bool tensor of [batch_size, his_size, 1]; True entries are excluded
                from top-k selection (presumably padded history — confirm)

        Returns:
            tuple ``(his_activated, his_focus)``:
            his_activated: tensor of [batch_size, cdd_size, k, signal_length, level, hidden_dim],
                the selected (top-k) or similarity-scaled (threshold) history embeddings
            his_focus: one-hot tensor of [batch_size, cdd_size, k, his_size] recording which
                history fills each selected slot; None on the threshold path
        """
        if hasattr(self, 'threshold'):
            # [bs, cs, hs] cosine similarity between projected candidates and histories
            cdd_proj = F.normalize(self.selectionProject(cdd_repr), dim=-1)
            his_proj = F.normalize(self.selectionProject(his_repr), dim=-1)
            attn_weights = cdd_proj.matmul(his_proj.transpose(-2, -1))
            # Formerly an unconditional print of attn_weights[0][1], which spammed every
            # forward pass and raised IndexError when cdd_size == 1. Logged lazily instead.
            logging.getLogger(__name__).debug(
                "thresholded attention weights of the first sample: %s", attn_weights[0])

            # NOTE(review): his_mask is not applied here; masked histories are only dropped
            # if their similarity falls below the threshold — confirm this is intended.
            gate = attn_weights.masked_fill(attn_weights < self.threshold, 0)
            # k == his_size is enforced in __init__, so the history axis doubles as the k axis
            his_activated = his_embedding.unsqueeze(dim=1) * gate.view(
                self.batch_size, self.cdd_size, self.k, 1, 1, 1)
            return his_activated, None

        cdd_proj = self.selectionProject(cdd_repr)
        his_proj = self.selectionProject(his_repr)

        # [bs, cs, hs]
        attn_weights = cdd_proj.matmul(his_proj.transpose(-1, -2))
        # Masked histories get -inf logits (zero weight after softmax), so they rank last
        # and are only picked when fewer than k histories remain.
        attn_weights = self.softmax(attn_weights.masked_fill(his_mask.transpose(-1, -2), -float("inf")))

        # history indices ordered by descending weight; detached, so the hard selection
        # itself does not propagate gradients
        ranked_index = attn_weights.detach().argsort(dim=-1, descending=True)

        # one-hot encode the top-k indices with scatter (faster than F.one_hot): [bs, cs, k, hs]
        his_focus = torch.zeros(
            self.batch_size, self.cdd_size, self.k, self.his_size, device=self.device
        ).scatter_(-1, ranked_index[:, :, :self.k].unsqueeze(dim=-1), 1.0)

        # [bs, cs, k, hs] @ [bs, 1, hs, sl*level*hd] -> [bs, cs, k, sl, level, hd]
        his_activated = torch.matmul(
            his_focus, his_embedding.reshape(self.batch_size, 1, self.his_size, -1)
        ).view(self.batch_size, self.cdd_size, self.k, -1, self.level, self.hidden_dim)

        return his_activated, his_focus

    def _click_predictor(self, fusion_tensors):
        """Score every candidate.

        Args:
            fusion_tensors: tensor of [batch_size, cdd_size, final_dim]

        Returns:
            tensor of [batch_size, cdd_size] holding unnormalized logits
            (normalization is applied in ``forward``)
        """
        return self.learningToRank(fusion_tensors).squeeze(dim=-1)

    def forward_(self, x):
        """Compute unnormalized click scores for one batch.

        Args:
            x: batch dict; reads 'candidate_title', 'candidate_title_pad', 'cdd_id',
               'clicked_title', 'clicked_title_pad', 'his_id', 'his_mask' and 'user_index'

        Returns:
            tensor of [batch_size, cdd_size]

        Raises:
            ValueError: if the interactor's name is neither 'knrm' nor 'fim'
        """
        # follow the actual batch size (e.g. a smaller final batch)
        self.batch_size = x['candidate_title'].shape[0]

        user_index = x['user_index'].long().to(self.device)

        cdd_news_embedding, cdd_news_repr = self.encoder(
            x['candidate_title'].long().to(self.device),
            user_index=user_index,
            news_id=x['cdd_id'].long().to(self.device),
            attn_mask=x['candidate_title_pad'].to(self.device))

        his_news_embedding, his_news_repr = self.encoder(
            x['clicked_title'].long().to(self.device),
            user_index=user_index,
            news_id=x['his_id'].long().to(self.device),
            attn_mask=x['clicked_title_pad'].to(self.device))

        his_activated, his_focus = self._news_attention(
            cdd_news_repr, his_news_repr, his_news_embedding, x['his_mask'].to(self.device))

        if self.interactor.name == 'knrm':
            cdd_pad = x['candidate_title_pad'].float().to(self.device).view(
                self.batch_size, self.cdd_size, 1, 1, -1, 1)
            his_pad_all = x['clicked_title_pad'].float().to(self.device)
            if his_focus is not None:
                # gather the pad masks of the selected histories with the same one-hot matrix
                his_pad = torch.matmul(his_focus, his_pad_all.reshape(
                    self.batch_size, 1, self.his_size, -1)).view(self.batch_size, self.cdd_size, self.k, 1, 1, -1, 1)
            else:
                his_pad = his_pad_all.view(
                    self.batch_size, 1, self.k, 1, 1, self.signal_length, 1
                ).expand(self.batch_size, self.cdd_size, self.k, 1, 1, self.signal_length, 1)

            fusion_tensors = self.interactor(cdd_news_embedding, his_activated, cdd_pad=cdd_pad, his_pad=his_pad)

        elif self.interactor.name == 'fim':
            fusion_tensors = self.interactor(cdd_news_embedding, his_activated)

        else:
            # previously fell through and raised UnboundLocalError on fusion_tensors
            raise ValueError(f"unsupported interactor: {self.interactor.name!r}")

        return self._click_predictor(fusion_tensors)

    def forward(self, x):
        """Normalized click scores.

        Returns log-softmax over candidates when ``cdd_size > 1``, otherwise the
        sigmoid click probability; shape [batch_size, cdd_size].
        """
        score = self.forward_(x)
        if self.cdd_size > 1:
            return F.log_softmax(score, dim=1)
        return torch.sigmoid(score)

In [None]:
# Build the SFI-gating model on top of a FIM encoder and restore a trained checkpoint.
# (Alternative setup previously tried: KNRM_Interactor() with SFI_gating_MultiView.)
encoder = FIM_Encoder(hparams, vocab)
hparams['name'] = '-'.join([name, encoder.name, hparams['select']])

# Threshold-based selection keeps every history slot, and SFI_gating raises
# ValueError unless k == his_size in that mode, hence the override of k.
hparams['k'] = hparams['his_size']
hparams['threshold'] = 0.3
sfi = SFI_gating(hparams, encoder).to(hparams['device'])

# Set SFI_CHECKPOINT_DIR to load from another location instead of editing the
# machine-specific default below.
CHECKPOINT_DIR = os.environ.get(
    'SFI_CHECKPOINT_DIR',
    '/home/peitian_zhang/Codes/News-Recommendation/data/model_params/sfi-fim-fim-gating',
)
CHECKPOINT_FILE = 'large_epoch5_step8458_[hs=50,topk=50,attrs=title].model'
checkpoint = torch.load(os.path.join(CHECKPOINT_DIR, CHECKPOINT_FILE), map_location=hparams['device'])
# NOTE(review): `threshold` is a registered buffer of the model, so strict loading
# requires the checkpoint to contain it as well — confirm for checkpoints trained with top-k.
sfi.load_state_dict(checkpoint['model'])

In [None]:
# Smoke-test a forward pass on one batch. no_grad keeps the displayed result (which
# IPython caches in Out/_) from pinning the whole autograd graph in device memory.
with torch.no_grad():
    smoke_scores = sfi(record)
smoke_scores

In [None]:
# Snapshot of selectionProject's weights before training (presumably to check later
# whether they change). detach() makes it a plain tensor rather than a node attached
# to the autograd graph of the parameter.
a = sfi.selectionProject.weight.detach().clone()

In [None]:
# Train the loaded model for a single epoch with the project's training loop.
hparams.update(epochs=1)
train(sfi, hparams, loaders)