In [1]:
import argparse
import numpy as np
from torch import nn
from src.config import TrainConfig 
from src.ultis import *
from src.data_helper import prepare_preprocessed_data
from src.data_load import *
from src.metrics import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch.nn.functional as F
import torch.optim as optim

In [3]:
from torch_geometric.data import Data, Batch
from torch_geometric.utils import subgraph

from torch.utils.data import IterableDataset
from torch_geometric.loader import DataLoader as GraphDataLoader

In [5]:
cfg = TrainConfig

In [6]:
# device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device: torch.device = torch.device("cpu")

In [7]:
set_random_seed(cfg.random_seed)

In [8]:
logging.info("Start")
"""
0. Definite Parameters & Functions

"""
logging.info("Prepare the dataset")
prepare_preprocessed_data(cfg)

2024-10-18 18:42:00,271 INFO Start
2024-10-18 18:42:00,272 INFO Prepare the dataset


Target_file is not exist. New behavior file in ./data/MINDsmall_train\behaviors_np4_0.tsv
./data/MINDsmall_train\behaviors_np4_0.tsv ./data/MINDsmall_train\behaviors.tsv


156965it [00:02, 53810.46it/s]


[train]Writing files...
Target_file is not exist. New behavior file in ./data/MINDsmall_val\behaviors_np4_0.tsv
./data/MINDsmall_val\behaviors_np4_0.tsv ./data/MINDsmall_val\behaviors.tsv


73152it [00:00, 645371.19it/s]

[val]Writing files...



[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Admin\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[train]Processing raw news: 100%|██████████████████████████████████████████████| 51282/51282 [00:05<00:00, 8549.81it/s]
Processing parsed news: 100%|████████████████████████████████████████████████| 51282/51282 [00:00<00:00, 229964.32it/s]
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Admin\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Glove token preprocess finish.


[val]Processing raw news: 100%|████████████████████████████████████████████████| 42416/42416 [00:04<00:00, 9529.32it/s]
Processing parsed news: 100%|████████████████████████████████████████████████| 65238/65238 [00:00<00:00, 227297.69it/s]


Glove token preprocess finish.


[train] Processing behaviors news to News Graph: 100%|█████████████████████| 156965/156965 [00:00<00:00, 257329.37it/s]
Processing news edge list: 100%|█████████████████████████████████████████████| 48304/48304 [00:00<00:00, 284088.86it/s]


Data(x=[51283, 22], edge_index=[2, 632044], edge_attr=[632044], num_nodes=51283)
[train] Finish News Graph Construction, 
Graph Path: data\MINDsmall_train\nltk_news_graph.pt 
Graph Info: Data(x=[51283, 22], edge_index=[2, 632044], edge_attr=[632044], num_nodes=51283)
[val] Finish nltk News Graph Construction, 
Graph Path: data\MINDsmall_val\nltk_news_graph.pt
Graph Info: Data(x=[65239, 22], edge_index=[2, 632044], edge_attr=[632044], num_nodes=65239)
[train] Start to process neighbors list
[train] Finish news Neighbor dict 
Dict Path: data\MINDsmall_train\news_neighbor_dict.bin, 
Weight Dict: data\MINDsmall_train\news_weights_dict.bin
[val] Start to process neighbors list
[val] Finish news Neighbor dict 
Dict Path: data\MINDsmall_val\news_neighbor_dict.bin, 
Weight Dict: data\MINDsmall_val\news_weights_dict.bin
news_graph, Data(x=[51283, 22], edge_index=[2, 632044], edge_attr=[632044], num_nodes=51283)
entity_indices,  (51283, 5)
[train] Finish Entity Graph Construction, 
 Graph Path: 

In [9]:
mode='train'
data_dir = {"train": cfg.data_dir + '_train', "val": cfg.data_dir + '_val', "test": cfg.data_dir}

# ------------- load news.tsv-------------
news_index = pickle.load(open(Path(data_dir[mode]) / "news_dict.bin", "rb"))

news_input = pickle.load(open(Path(data_dir[mode]) / "nltk_token_news.bin", "rb"))

In [10]:
target_file = Path(data_dir[mode]) / f"behaviors_np{cfg.npratio}_0.tsv"
news_graph = torch.load(Path(data_dir[mode]) / "nltk_news_graph.pt")

if cfg.directed is False:
    news_graph.edge_index, news_graph.edge_attr = to_undirected(news_graph.edge_index, news_graph.edge_attr)
print(f"[{mode}] News Graph Info: {news_graph}")

news_neighbors_dict = pickle.load(open(Path(data_dir[mode]) / "news_neighbor_dict.bin", "rb"))

[train] News Graph Info: Data(x=[51283, 22], edge_index=[2, 1171966], edge_attr=[1171966], num_nodes=51283)


In [11]:
class Dataset_PANEL_1(Dataset):
    def __init__(self, filename, news_index, news_combined, cfg, neighbor_dict, news_graph):
        super(Dataset_PANEL_1).__init__()
        self.filename = filename
        self.news_index = news_index
        self.news_combined = news_combined
        self.user_log_length = cfg.his_size
        self.npratio = cfg.npratio
        self.cfg = cfg
        self.neighbor_dict = neighbor_dict
        self.news_graph = news_graph
        self.news_graph.x = self.news_graph.x.float()
        self.prepare()

    def trans_to_nindex(self, nids):
        return [self.news_index[i] if i in self.news_index else 0 for i in nids]

    def pad_to_fix_len(self, x, fix_length, padding_front=True, padding_value=0):
        if padding_front:
            pad_x = [padding_value] * (fix_length - len(x)) + x[-fix_length:]
            mask = [0] * (fix_length - len(x)) + [1] * min(fix_length, len(x))
        else:
            pad_x = x[-fix_length:] + [padding_value] * (fix_length - len(x))
            mask = [1] * min(fix_length, len(x)) + [0] * (fix_length - len(x))
        return pad_x, np.array(mask, dtype='float32')

    def build_k_hop(self, click_doc):
        click_idx = [x for x in click_doc]
        source_idx = [x for x in click_doc]
        for _ in range(self.cfg.k_hops) :
            current_hop_idx = []
            for news_idx in source_idx:
                current_hop_idx.extend(self.neighbor_dict[news_idx][:self.cfg.num_neighbors])
            source_idx = current_hop_idx
            click_idx.extend(current_hop_idx)
        return list(set(click_idx))
        
    def prepare(self):
        self.preprocessDT = []
        with open(self.filename) as f:
            for line in tqdm(f):
                g, dt = self.line_mapper(line)
                if len(g) == 0:
                    continue
                self.preprocessDT.append([g,dt])
                if len(self.preprocessDT) > 10000:
                    break
    
    def line_mapper(self, line):
        line = line.strip().split('\t')
        click_docs = line[3].split()
        sess_pos = line[4].split()
        sess_neg = line[5].split()
        click_docs = self.trans_to_nindex(click_docs)

        # build sub-graph
        k_hops_click = self.build_k_hop(click_docs)
        
        # subemb = self.news_graph.x[k_hops_click]
        # sub_edge_index, sub_edge_attr = subgraph(k_hops_click, self.news_graph.edge_index, self.news_graph.edge_attr, \
                                                 # relabel_nodes=True, num_nodes=self.news_graph.num_nodes)
        # sub_news_graph = Data(x=subemb, edge_index=sub_edge_index, edge_attr=sub_edge_attr)

        
        click_docs, log_mask = self.pad_to_fix_len(click_docs, self.user_log_length)
        user_feature = self.news_combined[click_docs]

        pos = self.trans_to_nindex(sess_pos)
        neg = self.trans_to_nindex(sess_neg)

        label = random.randint(0, self.npratio)
        sample_news = neg[:label] + pos + neg[label:]
        news_feature = self.news_combined[sample_news]
        return k_hops_click, [torch.from_numpy(user_feature), torch.from_numpy(log_mask), \
        torch.from_numpy(news_feature), torch.tensor(label)]

    # def __iter__(self):
    #     file_iter = open(self.filename)
    #     return map(self.line_mapper, file_iter)

    def __getitem__(self, idx):
        k_hops_click, dt =  self.preprocessDT[idx]
        subemb = self.news_graph.x[k_hops_click]
        sub_edge_index, sub_edge_attr = subgraph(k_hops_click, self.news_graph.edge_index, self.news_graph.edge_attr, \
                                                 relabel_nodes=True, num_nodes=self.news_graph.num_nodes)
        sub_news_graph = Data(x=subemb, edge_index=sub_edge_index, edge_attr=sub_edge_attr).cuda()
        return sub_news_graph, dt

    def __len__(self):
        return len(self.preprocessDT)

In [12]:
dataset = Dataset_PANEL_1(
                filename=target_file,
                news_index=news_index,
                news_combined=news_input,
                cfg=cfg,
                neighbor_dict=news_neighbors_dict,
                news_graph=news_graph
)
dataloader = GraphDataLoader(dataset, batch_size=128)

10172it [00:01, 7344.24it/s]


In [13]:
iterator = iter(dataloader)
data_batch = next(iterator)
sub_news_graph, [user_feature, log_mask, news_feature, label] = data_batch
user_feature.shape

torch.Size([128, 100, 22])

In [14]:
max(sub_news_graph.batch)

tensor(127, device='cuda:0')

In [14]:
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

In [72]:
# x, edge_index, batch = sub_news_graph.x, sub_news_graph.edge_index, sub_news_graph.batch
# x = x.float()
# x = tmpGCN(x, edge_index)
# x = global_mean_pool(x, batch)
# x.shape


In [15]:
class AttentionPooling(nn.Module):
    def __init__(self, emb_size, hidden_size):
        super(AttentionPooling, self).__init__()
        self.att_fc1 = nn.Linear(emb_size, hidden_size)
        self.att_fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x, attn_mask=None):
        """
        Args:
            x: batch_size, candidate_size, emb_dim
            attn_mask: batch_size, candidate_size
        Returns:
            (shape) batch_size, emb_dim
        """
        e = self.att_fc1(x)
        e = nn.Tanh()(e)
        alpha = self.att_fc2(e)
        alpha = torch.exp(alpha)

        if attn_mask is not None:
            alpha = alpha * attn_mask.unsqueeze(2)

        alpha = alpha / (torch.sum(alpha, dim=1, keepdim=True) + 1e-8)
        if len(x.shape) == 3:
            x = torch.bmm(x.permute(0, 2, 1), alpha).squeeze(dim=-1)
        else:
            x = torch.bmm(x.unsqueeze(-1), alpha.unsqueeze(-1)).squeeze(dim=-1)
        return x


class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k

    def forward(self, Q, K, V, attn_mask=None):
        '''
            Q: batch_size, n_head, candidate_num, d_k
            K: batch_size, n_head, candidate_num, d_k
            V: batch_size, n_head, candidate_num, d_v
            attn_mask: batch_size, n_head, candidate_num
            Return: batch_size, n_head, candidate_num, d_v
        '''
        scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(self.d_k)
        scores = torch.exp(scores)

        if attn_mask is not None:
            scores = scores * attn_mask.unsqueeze(dim=-2)

        attn = scores / (torch.sum(scores, dim=-1, keepdim=True) + 1e-8)
        context = torch.matmul(attn, V)
        return context


class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_k, d_v):
        super(MultiHeadSelfAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v

        self.W_Q = nn.Linear(d_model, d_k * n_heads)
        self.W_K = nn.Linear(d_model, d_k * n_heads)
        self.W_V = nn.Linear(d_model, d_v * n_heads)

        self.scaled_dot_product_attn = ScaledDotProductAttention(self.d_k)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=1)

    def forward(self, Q, K, V, mask=None):
        '''
            Q: batch_size, candidate_num, d_model
            K: batch_size, candidate_num, d_model
            V: batch_size, candidate_num, d_model
            mask: batch_size, candidate_num
        '''
        batch_size = Q.shape[0]
        if mask is not None:
            mask = mask.unsqueeze(dim=1).expand(-1, self.n_heads, -1)

        q_s = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        k_s = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        v_s = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)

        context = self.scaled_dot_product_attn(q_s, k_s, v_s, mask)
        output = context.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_v)
        return output

In [16]:
class NewsEncoder(nn.Module):
    def __init__(self, embedding_matrix, num_category, num_subcategory):
        super(NewsEncoder, self).__init__()
        self.embedding_matrix = embedding_matrix
        self.drop_rate = 0.2
        self.num_words_title = 20
        self.use_category = True
        self.use_subcategory = True
        category_emb_dim = 100
        news_dim = 400
        news_query_vector_dim = 200
        word_embedding_dim = 300
        self.category_emb = nn.Embedding(num_category + 1, category_emb_dim, padding_idx=0)
        self.category_dense = nn.Linear(category_emb_dim, news_dim)
        self.subcategory_emb = nn.Embedding(num_subcategory + 1, category_emb_dim, padding_idx=0)
        self.subcategory_dense = nn.Linear(category_emb_dim, news_dim)
        self.final_attn = AttentionPooling(news_dim, news_query_vector_dim)
        self.cnn = nn.Conv1d(
            in_channels=word_embedding_dim,
            out_channels=news_dim,
            kernel_size=3,
            padding=1
        )
        self.attn = AttentionPooling(news_dim, news_query_vector_dim)

    def forward(self, x, mask=None):
        '''
            x: batch_size, word_num
            mask: batch_size, word_num
        '''
        title = torch.narrow(x, -1, 0, self.num_words_title).long()
        word_vecs = F.dropout(self.embedding_matrix(title),
                              p=self.drop_rate,
                              training=self.training)
        context_word_vecs = self.cnn(word_vecs.transpose(1, 2)).transpose(1, 2)
        
        title_vecs = self.attn(context_word_vecs, mask)
        all_vecs = [title_vecs]

        start = self.num_words_title
        if self.use_category:
            category = torch.narrow(x, -1, start, 1).squeeze(dim=-1).long()
            category_vecs = self.category_dense(self.category_emb(category))
            all_vecs.append(category_vecs)
            start += 1
        if self.use_subcategory:
            subcategory = torch.narrow(x, -1, start, 1).squeeze(dim=-1).long()
            subcategory_vecs = self.subcategory_dense(self.subcategory_emb(subcategory))
            all_vecs.append(subcategory_vecs)

        if len(all_vecs) == 1:
            news_vecs = all_vecs[0]
        else:
            all_vecs = torch.stack(all_vecs, dim=1)
            
            news_vecs = self.final_attn(all_vecs)
        return news_vecs

In [17]:
class UserEncoder(nn.Module):
    def __init__(self):
        super(UserEncoder, self).__init__()
        news_dim = 400
        user_query_vector_dim = 200
        self.user_log_length = 50
        self.user_log_mask = False
        self.attn = AttentionPooling(news_dim, user_query_vector_dim)
        self.pad_doc = nn.Parameter(torch.empty(1, news_dim).uniform_(-1, 1)).type(torch.FloatTensor)

    def forward(self, news_vecs, log_mask=None):
        '''
            news_vecs: batch_size, history_num, news_dim
            log_mask: batch_size, history_num
        '''
        bz = news_vecs.shape[0]
        if self.user_log_mask:
            user_vec = self.attn(news_vecs, log_mask)
        else:
            padding_doc = self.pad_doc.unsqueeze(dim=0).expand(bz, self.user_log_length, -1)
            news_vecs = news_vecs * log_mask.unsqueeze(dim=-1) + padding_doc * (1 - log_mask.unsqueeze(dim=-1))
            user_vec = self.attn(news_vecs)
        return user_vec

In [18]:
class NAML(torch.nn.Module):
    def __init__(self, embedding_matrix, num_category, num_subcategory, **kwargs):
        super(NAML, self).__init__()
        pretrained_word_embedding = torch.from_numpy(embedding_matrix).float()
        word_embedding = nn.Embedding.from_pretrained(pretrained_word_embedding,
                                                      freeze=False,
                                                      padding_idx=0)

        self.news_encoder = NewsEncoder( word_embedding, num_category, num_subcategory)
        self.user_encoder = UserEncoder()
        self.outG = 32
        self.gcn = GCNConv(22,self.outG)
        self.news_dim = 400
        self.attn = AttentionPooling(self.outG+self.news_dim, 128)
        self.ln = nn.Linear(self.outG+self.news_dim, self.news_dim)
        self.gln = nn.Linear(self.outG, self.outG)
        self.loss_fn = nn.CrossEntropyLoss()
        self.npratio = 4
        self.user_log_length = 50

    def forward(self, graph_batch, history, history_mask, candidate, label):
        '''
            history: batch_size, history_length, num_word_title
            history_mask: batch_size, history_length
            candidate: batch_size, 1+K, num_word_title
            label: batch_size, 1+K
        '''
        graph_vec, edge_index, batch = graph_batch.x, graph_batch.edge_index, graph_batch.batch
        print(graph_vec.shape)
        print(history.shape)
        stop
        graph_vec = self.gcn(graph_vec, edge_index)
        graph_vec = graph_vec.relu()
        graph_vec = global_mean_pool(graph_vec, batch)
        graph_vec = F.dropout(graph_vec, p=0.2, training=self.training)
        graph_vec = self.gln(graph_vec)
        num_words = history.shape[-1]
        candidate_news = candidate.reshape(-1, num_words)
        candidate_news_vecs = self.news_encoder(candidate_news).reshape(-1, 1 + self.npratio, self.news_dim)
        history_news = history.reshape(-1, num_words)
        history_news_vecs = self.news_encoder(history_news).reshape(-1, self.user_log_length, self.news_dim)
        user_vec = self.user_encoder(history_news_vecs, history_mask)
        uservec = torch.cat((user_vec, graph_vec), dim=1)
        uservec = self.attn(uservec)
        uservec = self.ln(uservec)
        score = torch.bmm(candidate_news_vecs, user_vec.unsqueeze(dim=-1)).squeeze(dim=-1)
        loss = self.loss_fn(score, label)
        return loss, score

In [19]:
def acc(y_true, y_hat):
    y_hat = torch.argmax(y_hat, dim=-1)
    tot = y_true.shape[0]
    hit = torch.sum(y_true == y_hat)
    return hit.data.float() * 1.0 / tot

In [20]:
category_dict = pickle.load(open(os.path.join(cfg.data_dir + '_train', "category_dict.bin"), "rb"))
subcategory_dict = pickle.load(open(os.path.join(cfg.data_dir + '_train', "subcategory_dict.bin"), "rb"))
word_dict = pickle.load(open(os.path.join(cfg.data_dir + '_train', "word_dict.bin"), "rb"))
glove_emb = load_pretrain_emb(cfg.glove_path, word_dict, cfg.word_emb_dim)
len(word_dict)

-----------------------------------------------------
Dict length: 12506
Have words: 11947
Missing rate: 0.0446985446985447


12506

In [21]:
model = NAML(glove_emb, len(category_dict), len(subcategory_dict))
optimizer = optim.Adam(model.parameters(), lr=0.0003)
# model = model.to("cuda")
torch.set_grad_enabled(True)
model.train()

NAML(
  (news_encoder): NewsEncoder(
    (embedding_matrix): Embedding(12507, 300, padding_idx=0)
    (category_emb): Embedding(18, 100, padding_idx=0)
    (category_dense): Linear(in_features=100, out_features=400, bias=True)
    (subcategory_emb): Embedding(265, 100, padding_idx=0)
    (subcategory_dense): Linear(in_features=100, out_features=400, bias=True)
    (final_attn): AttentionPooling(
      (att_fc1): Linear(in_features=400, out_features=200, bias=True)
      (att_fc2): Linear(in_features=200, out_features=1, bias=True)
    )
    (cnn): Conv1d(300, 400, kernel_size=(3,), stride=(1,), padding=(1,))
    (attn): AttentionPooling(
      (att_fc1): Linear(in_features=400, out_features=200, bias=True)
      (att_fc2): Linear(in_features=200, out_features=1, bias=True)
    )
  )
  (user_encoder): UserEncoder(
    (attn): AttentionPooling(
      (att_fc1): Linear(in_features=400, out_features=200, bias=True)
      (att_fc2): Linear(in_features=200, out_features=1, bias=True)
    )
 

In [22]:
for ep in range(6):
    loss = 0.0
    accuary = 0.0
    print("EPOCH: " + str(ep))
    for cnt, (g, [log_ids, log_mask, input_ids, targets]) in tqdm(enumerate(dataloader)):
        # log_ids = log_ids.cuda()
        # log_mask = log_mask.cuda()
        # input_ids = input_ids.cuda()
        # targets = targets.cuda()

        bz_loss, y_hat = model(g, log_ids, log_mask, input_ids, targets)
        loss += bz_loss.data.float()
        accuary += acc(targets, y_hat)
        optimizer.zero_grad()
        bz_loss.backward()
        optimizer.step()
        # stop
    print(loss, accuary)

EPOCH: 0


0it [00:00, ?it/s]

torch.Size([29297, 22])
torch.Size([128, 100, 22])





NameError: name 'stop' is not defined

In [20]:
# torch.save(model.state_dict(), 'Graph_naml_model.pth')

In [21]:
model.load_state_dict(torch.load('Graph_naml_model.pth'))

<All keys matched successfully>

# Validate

In [22]:
from torch.utils.data import DataLoader

In [39]:
model.eval()
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x14e82388be0>

In [24]:
mode = "val"
data_dir = {"train": cfg.data_dir + '_train', "val": cfg.data_dir + '_val', "test": cfg.data_dir}

In [25]:
news_index = pickle.load(open(Path(data_dir[mode]) / "news_dict.bin", "rb"))
news_input = pickle.load(open(Path(data_dir[mode]) / "nltk_token_news.bin", "rb"))
data_dir[mode]

'./data/MINDsmall_val'

In [26]:
news_dataset = NewsDataset(news_input)
news_dataloader = DataLoader(news_dataset, batch_size=128)

## news -> scoring

In [27]:
news_scoring = []
with torch.no_grad():
    for input_ids in tqdm(news_dataloader):
        input_ids = input_ids.cuda()
        news_vec = model.news_encoder(input_ids)
        news_vec = news_vec.to(torch.device("cpu")).detach().numpy()
        news_scoring.extend(news_vec)

news_scoring = np.array(news_scoring)

100%|████████████████████████████████████████████████████████████████████████████████| 510/510 [00:06<00:00, 82.39it/s]


## Val loader and compute score

In [28]:
class ValidDataset_PANEL_1(Dataset_PANEL_1):
    def __init__(self, filename, news_index, news_score, cfg, neighbor_dict, news_graph):
        super(Dataset_PANEL_1).__init__()
        self.filename = filename
        self.news_index = news_index
        self.news_score = news_score
        self.user_log_length = cfg.his_size
        self.npratio = cfg.npratio
        self.cfg = cfg
        self.neighbor_dict = neighbor_dict
        self.news_graph = news_graph
        self.news_graph.x = self.news_graph.x.float()
        self.prepare()
        

    def line_mapper(self, line):
        line = line.strip().split('\t')
        click_docs = line[3].split()
        
        candidate_news = self.trans_to_nindex([i.split('-')[0] for i in line[4].split()])
        label = np.array([int(i.split('-')[1]) for i in line[4].split()])
        
        click_docs = self.trans_to_nindex(click_docs)

        # build sub-graph
        k_hops_click = self.build_k_hop(click_docs)
        
        click_docs, log_mask = self.pad_to_fix_len(click_docs, self.user_log_length)
        user_feature = self.news_score[click_docs]

        news_feature = self.news_score[candidate_news]
        
        return k_hops_click,  [torch.from_numpy(user_feature), torch.from_numpy(log_mask), \
        torch.from_numpy(news_feature), torch.tensor(label)]


In [29]:
valid_target_file = Path(data_dir[mode]) / f"behaviors.tsv"
valid_target_file

WindowsPath('data/MINDsmall_val/behaviors.tsv')

In [41]:
news_graph = torch.load(Path(data_dir[mode]) / "nltk_news_graph.pt")
news_neighbors_dict = pickle.load(open(Path(data_dir[mode]) / "news_neighbor_dict.bin", "rb"))

In [42]:
valid_dataset = ValidDataset_PANEL_1(
                filename=valid_target_file,
                news_index=news_index,
                news_score=news_scoring,
                cfg=cfg,
                neighbor_dict=news_neighbors_dict,
                news_graph=news_graph
)
valid_dataloader = GraphDataLoader(valid_dataset, batch_size=1)

73152it [00:40, 1800.33it/s]


In [32]:
iterator = iter(valid_dataloader)
data_batch = next(iterator)

In [33]:
g,[ uf, lm, nf, l] = data_batch


In [34]:
AUC = []
MRR = []
nDCG5 = []
nDCG10 = []

def print_metrics(cnt, x):
    print(cnt, x)

def get_mean(arr):
    return [np.array(i).mean() for i in arr]

def get_sum(arr):
    return [np.array(i).sum() for i in arr]

In [43]:
for cnt, (g, [log_vecs, log_mask, news_vecs, labels]) in enumerate(valid_dataloader):
    log_vecs = log_vecs.cuda()
    log_mask = log_mask.cuda()
    graph_vec, edge_index, batch = g.x, g.edge_index, g.batch
    graph_vec = model.gcn(graph_vec, edge_index)
    graph_vec = graph_vec.relu()
    graph_vec = global_mean_pool(graph_vec, batch)
    graph_vec = model.gln(graph_vec)

    user_vecs = model.user_encoder(log_vecs, log_mask)
    user_vecs = torch.cat((user_vecs, graph_vec), dim=1)
    user_vecs = model.attn(user_vecs)
    user_vecs = model.ln(user_vecs).to(torch.device("cpu")).detach().numpy()
    news_vecs = news_vecs.to(torch.device("cpu")).detach().numpy()
    labels = labels.to(torch.device("cpu")).detach().numpy()
    
    for user_vec, news_vec, label in zip(user_vecs, news_vecs, labels):
        tmp = np.mean(label)
        if tmp == 0 or tmp == 1:
            continue

        score = np.dot(news_vec, user_vec)
        auc = roc_auc_score(label, score)
        mrr = mrr_score(label, score)
        ndcg5 = ndcg_score(label, score, k=5)
        ndcg10 = ndcg_score(label, score, k=10)

        AUC.append(auc)
        MRR.append(mrr)
        nDCG5.append(ndcg5)
        nDCG10.append(ndcg10)

    if cnt % 10000 == 0:
        print_metrics(cnt, get_mean([AUC, MRR, nDCG5, nDCG10]))

print_metrics(cnt, get_mean([AUC, MRR, nDCG5, nDCG10]))

0 [0.5332264532074023, 0.23396181814502737, 0.2510158487004132, 0.3100420405560214]


KeyboardInterrupt: 