In [1]:
import argparse
import numpy as np
from torch import nn
from src.config import TrainConfig 
from src.ultis import *
from src.data_helper import prepare_preprocessed_data
from src.data_load import *
from src.metrics import *

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Admin\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# Build the configuration object with no constructor overrides; later cells
# adjust individual fields on this instance before it is passed around.
cfg = TrainConfig()

In [16]:
# NOTE(review): use_graph is not read by any code visible in this notebook;
# presumably it is consumed inside the src package — confirm.
cfg.use_graph = True
# prepare_news_graph only skips its work when the graph file exists AND
# cfg.reprocess is False, so setting it to True forces a rebuild every run.
cfg.reprocess = True

In [17]:
def prepare_news_graph(cfg, mode='train'):
    """Build the news graph for one data split and save it as ``nltk_news_graph.pt``.

    For ``mode='train'`` the graph is built from ``behaviors.tsv`` in the train
    directory: column 1 is treated as the user id (only the first row seen per
    user is used) and column 3 as the whitespace-separated click history.
    Each pair of clicked news becomes a directed edge whose weight is the number
    of times that pair occurs over all kept histories.

    For ``mode='val'`` / ``mode='test'`` no edges are computed; the edge index
    and edge weights are copied from the already-saved *train* graph and paired
    with the node features of the requested split. The train graph therefore
    has to exist on disk before a val/test graph can be produced.

    Args:
        cfg: Configuration object. Reads ``cfg.data_dir`` (path prefix; the
            split directories are ``<data_dir>_train`` / ``_val`` / ``_test``),
            ``cfg.reprocess`` (rebuild even if the file exists) and, for the
            train split, ``cfg.use_graph_type`` (0: edges between consecutive
            clicks only; 1: edges in both directions between every pair of
            clicks in a history).
        mode: One of ``'train'``, ``'val'``, ``'test'``.

    Returns:
        The constructed ``Data`` object, or ``None`` when the graph file
        already exists and ``cfg.reprocess`` is false (nothing is loaded in
        that case).

    Raises:
        KeyError: If ``mode`` is not one of the three known splits, or a news
            id in a click history is missing from ``news_dict``.
        AssertionError: If ``cfg.use_graph_type`` is neither 0 nor 1
            (train split only).
    """
    data_dir = {"train": cfg.data_dir + '_train', "val": cfg.data_dir + '_val', "test": cfg.data_dir + '_test'}
    nltk_target_path = Path(cfg.data_dir + f'_{mode}') / "nltk_news_graph.pt"

    # Skip only when the graph is already on disk and no rebuild was requested.
    if nltk_target_path.exists() and not cfg.reprocess:
        print(f"[{mode}] All graphs exist !")
        return

    # -----------------------------------------News Graph------------------------------------------------
    # Behaviors and the reference graph always come from the train split,
    # regardless of which split is being prepared.
    behavior_path = Path(data_dir['train']) / "behaviors.tsv"
    origin_graph_path = Path(data_dir['train']) / "nltk_news_graph.pt"

    # SECURITY: pickle executes arbitrary code on load; only use on files
    # produced by this project's own preprocessing.
    with open(Path(data_dir[mode]) / "news_dict.bin", "rb") as f:
        news_dict = pickle.load(f)
    with open(Path(data_dir[mode]) / "nltk_token_news.bin", "rb") as f:
        nltk_token_news = pickle.load(f)

    # Keep only the title and category features, excluding entity indices.
    # The same slice is applied to every split so node features have the same
    # width in the train graph and in the val/test graphs that reuse its edges.
    node_feat_dim = 22
    node_feat = nltk_token_news[:, :node_feat_dim]
    # NOTE(review): the +1 presumably reserves index 0 for padding — confirm
    # against how news_dict is built.
    num_nodes = len(news_dict) + 1

    # ------------------- Build Graph -------------------------------
    if mode == 'train':
        # Validate up front (and without `assert`, which is stripped under -O)
        # so a bad setting fails before any work is done or anything is saved.
        if cfg.use_graph_type not in (0, 1):
            raise AssertionError(
                f"Unsupported cfg.use_graph_type: {cfg.use_graph_type!r} (expected 0 or 1)"
            )

        # Count lines first so tqdm can show a total; stream instead of
        # materialising the whole file.
        with open(behavior_path, 'r', encoding='utf-8') as f:
            num_line = sum(1 for _ in f)

        edge_list, user_set = [], set()
        with open(behavior_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, total=num_line, desc=f"[{mode}] Processing behaviors news to News Graph"):
                fields = line.strip().split('\t')

                # Use only the first behaviour row encountered for each user.
                user_id = fields[1]
                if user_id in user_set:
                    continue
                user_set.add(user_id)

                # A history with fewer than two clicks cannot form an edge.
                history = fields[3].split()
                if len(history) > 1:
                    edge_list.append([news_dict[news_id] for news_id in history])

        # Count edge occurrences directly; Counter keeps first-seen key order,
        # which determines the column order of edge_index below.
        edge_weights = Counter()
        for edge in tqdm(edge_list, total=len(edge_list), desc=f"Processing news edge list"):
            if cfg.use_graph_type == 0:
                # Consecutive clicks only: (n0, n1), (n1, n2), ...
                edge_weights.update(zip(edge, edge[1:]))
            else:
                # Every pair of clicks in the history, in both directions.
                for i in range(len(edge) - 1):
                    for j in range(i + 1, len(edge)):
                        edge_weights[(edge[i], edge[j])] += 1
                        edge_weights[(edge[j], edge[i])] += 1

        if edge_weights:
            edge_index = torch.tensor(list(zip(*edge_weights.keys())), dtype=torch.long)
        else:
            # zip(*[]) would produce a 1-D empty tensor; keep the [2, E] layout.
            edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.tensor(list(edge_weights.values()), dtype=torch.long)

        data = Data(x=torch.from_numpy(node_feat),
                    edge_index=edge_index, edge_attr=edge_attr,
                    num_nodes=num_nodes)

        torch.save(data, nltk_target_path)
        print(data)
        print(f"[{mode}] Finish News Graph Construction, \nGraph Path: {nltk_target_path} \nGraph Info: {data}")

    else:
        # mode is 'val' or 'test' here: any other value already raised
        # KeyError at the data_dir[mode] lookup above.
        # SECURITY: weights_only=False unpickles arbitrary objects; the file
        # must be the graph written by the train branch of this function.
        origin_graph = torch.load(origin_graph_path, weights_only = False)

        # Reuse the train edges with this split's node features.
        data = Data(x=torch.from_numpy(node_feat),
                    edge_index=origin_graph.edge_index, edge_attr=origin_graph.edge_attr,
                    num_nodes=num_nodes)

        torch.save(data, nltk_target_path)
        print(f"[{mode}] Finish nltk News Graph Construction, \nGraph Path: {nltk_target_path}\nGraph Info: {data}")
    return data

In [19]:
from collections import Counter

In [20]:
# Order matters: the 'val' call loads the train split's nltk_news_graph.pt to
# reuse its edges, so the 'train' graph has to be written first.
# Either call returns None instead of a graph if it takes the "already exists"
# early exit (not the case here while cfg.reprocess is True).
g1 = prepare_news_graph(cfg, 'train')
g2 = prepare_news_graph(cfg, 'val')

[train] Processing behaviors news to News Graph: 100%|█████████████████████| 156965/156965 [00:00<00:00, 304196.56it/s]
Processing news edge list: 100%|█████████████████████████████████████████████| 48304/48304 [00:00<00:00, 333132.72it/s]


Data(x=[51283, 22], edge_index=[2, 632044], edge_attr=[632044], num_nodes=51283)
[train] Finish News Graph Construction, 
Graph Path: data\MINDsmall_train\nltk_news_graph.pt 
Graph Info: Data(x=[51283, 22], edge_index=[2, 632044], edge_attr=[632044], num_nodes=51283)
[val] Finish nltk News Graph Construction, 
Graph Path: data\MINDsmall_val\nltk_news_graph.pt
Graph Info: Data(x=[65239, 22], edge_index=[2, 632044], edge_attr=[632044], num_nodes=65239)


In [None]:
# NOTE(review): prepare_neighbor_list is not defined in this notebook; it is
# presumably provided by one of the star imports from src — confirm. It likely
# depends on the graph files written by prepare_news_graph above — verify.
prepare_neighbor_list(cfg, 'train', 'news')
prepare_neighbor_list(cfg, 'val', 'news')
# prepare_neighbor_list(cfg, 'test', 'news')