# Build homogeneous dataset

## Setting up the environment

---



### Loading libraries

In [1]:
import sys
import os

sys.path.insert(0, os.path.abspath(".."))

import pickle

from torch_geometric.data import Data
from box import Box

from util.postgres import create_sqlalchemy_engine
from util.homogeneous.dataset import DatasetEuCoHM, assert_bidirectional_edges

### Global variables

In [2]:
# -------------------- GLOBAL VARIABLES --------------------
# Location of the YAML configuration file, relative to this notebook's directory.
PATH_TO_CONFIG_FILE = '../config.yaml'

# -------------------- LOAD CONFIGURATION --------------------
# Parse the YAML file into a Box, so sections and keys are reachable as
# attributes (e.g. config.POSTGRES.HOST).
config = Box.from_yaml(filename=PATH_TO_CONFIG_FILE)

# Fraction of the data used for training, e.g. 0.7 == 70 % (a fraction, not a percentage).
num_train = 0.7

# SQLAlchemy engine handed to the dataset builder for all Postgres queries.
# Every connection setting comes from the POSTGRES section of the config file.
pg_engine = create_sqlalchemy_engine(
    username=config.POSTGRES.USERNAME,
    password=config.POSTGRES.PASSWORD,
    host=config.POSTGRES.HOST,
    port=config.POSTGRES.PORT,
    database=config.POSTGRES.DATABASE,
    schema=config.POSTGRES.SCHEMA,
)

## Data preparation

---



In [3]:
def unit_testing(data):
    """Run sanity checks on a built homogeneous graph.

    Checks performed:
    * the positive and negative test edge-index tensors contain the same
      number of elements (``numel``);
    * the positive train and test edges pass ``assert_bidirectional_edges``.

    Args:
        data: graph object exposing ``train_pos_edge_index``,
            ``test_pos_edge_index`` and ``test_neg_edge_index`` tensors
            (presumably the ``Data`` object returned by ``build_dataset``).

    Raises:
        AssertionError: if the test positive/negative element counts differ;
            the message reports both counts. Failures detected by
            ``assert_bidirectional_edges`` propagate as raised by it.
    """
    num_pos_elements = data.test_pos_edge_index.numel()
    num_neg_elements = data.test_neg_edge_index.numel()
    # Test: the positive test edge index must hold as many elements as the negative one.
    assert num_pos_elements == num_neg_elements, (
        f'Test edge indices are unbalanced: {num_pos_elements} positive elements '
        f'vs {num_neg_elements} negative elements'
    )

    # Test: all positive edges (train and test) must be bidirectional.
    assert_bidirectional_edges(edges=data.train_pos_edge_index)
    assert_bidirectional_edges(edges=data.test_pos_edge_index)
    print('All tests passed')

def save_dataset(dataset, data_dir='../data'):
    """Close the dataset's Postgres engine and pickle the dataset to disk.

    The file is written to ``{data_dir}/{dataset.get_dataset_name()}.pkl``.
    ``data_dir`` is created first if it does not exist yet.

    Args:
        dataset: the built dataset (e.g. a ``DatasetEuCoHM``); must provide
            ``get_dataset_name()`` and ``close_engine()``.
        data_dir: output directory. Defaults to ``'../data'`` (relative to
            this notebook), the location used previously.

    Returns:
        str: path of the written pickle file.

    Note:
        ``close_engine()`` is called before pickling, so after this call the
        dataset's engine is closed (it presumably cannot query Postgres
        any more).
    """
    dataset_save_filepath = f'{data_dir}/{dataset.get_dataset_name()}.pkl'
    # Before saving the dataset, the engine used to connect to the Postgres DB
    # must be closed (presumably a live engine cannot be pickled).
    dataset.close_engine()
    # Without this, saving into a fresh checkout fails with FileNotFoundError.
    if data_dir:
        os.makedirs(data_dir, exist_ok=True)
    with open(dataset_save_filepath, 'wb') as output:
        pickle.dump(dataset, output, pickle.HIGHEST_PROTOCOL)
    # Report success only after the file has been flushed and closed.
    print(f'Dataset saved to {dataset_save_filepath}')
    return dataset_save_filepath

def build_dataset(use_periodical_embedding_decay: bool,
                  use_top_keywords: bool,
                  num_train: float,
                  engine=None):
    """Build the homogeneous graph for one dataset configuration.

    Args:
        use_periodical_embedding_decay: forwarded to ``DatasetEuCoHM``.
        use_top_keywords: forwarded to ``DatasetEuCoHM``.
        num_train: fraction of the data used for training (e.g. ``0.7``),
            forwarded to ``DatasetEuCoHM``.
        engine: SQLAlchemy engine used to query Postgres. When ``None``
            (the default), the notebook-level ``pg_engine`` from the
            configuration cell is used, so existing calls keep working while
            an explicit engine can now be injected.

    Returns:
        tuple: ``(dataset, data, author_node_id_map, author_id_map)``:
        the ``DatasetEuCoHM`` instance followed by the three values returned
        by ``dataset.build_homogeneous_graph()`` (a ``Data`` graph and two
        dicts).
    """
    if engine is None:
        # Fall back to the global engine to stay backward compatible.
        engine = pg_engine
    dataset: DatasetEuCoHM = DatasetEuCoHM(
        pg_engine=engine,
        num_train=num_train,
        use_periodical_embedding_decay=use_periodical_embedding_decay,
        use_top_keywords=use_top_keywords,
    )
    data, author_node_id_map, author_id_map = dataset.build_homogeneous_graph()
    return dataset, data, author_node_id_map, author_id_map

In [4]:
# One dict per dataset variant; the keys are exactly build_dataset's keyword
# arguments. Uncomment an entry to build that variant as well.
dataset_configurations = [
    # {'use_periodical_embedding_decay': False, 'use_top_keywords': False, 'num_train': num_train},
    {'use_periodical_embedding_decay': True, 'use_top_keywords': False, 'num_train': num_train},
    # {'use_periodical_embedding_decay': False, 'use_top_keywords': True, 'num_train': num_train},
    # {'use_periodical_embedding_decay': True, 'use_top_keywords': True, 'num_train': num_train},
    # {'use_periodical_embedding_decay': True, 'use_top_keywords': True, 'num_train': 1.0},
]

for conf in dataset_configurations:
    print(f'Processing dataset configuration {conf}...')
    # Each dict maps one-to-one onto build_dataset's parameters.
    dataset, data, author_node_id_map, author_id_map = build_dataset(**conf)
    # Sanity-check the graph before persisting it.
    unit_testing(data)
    # Close the DB engine and pickle the dataset under ../data.
    save_dataset(dataset)

Processing dataset configuration {'use_periodical_embedding_decay': True, 'use_top_keywords': False, 'num_train': 0.7}...
Querying n-th time percentile...
Querying co-authorship edge data...
Querying author nodes...
All tests passed
Dataset saved to ../data/dataset_homogeneous_periodical_decay.pkl


In [5]:
# Spot-check: show row 1 of the feature matrix `x` of the last dataset built above.
dataset.data.x[1, :]

tensor([ 1.7237e-02,  1.5160e-01,  3.4230e-02,  2.1417e-01,  1.3798e-01,
        -2.2870e-01,  1.7449e-01,  8.2475e-01, -8.3116e-02,  9.4696e-01,
         8.8759e-02,  1.5662e-01, -5.6863e-01, -8.6023e-02,  5.7454e-01,
        -4.5274e-01,  1.9931e-03, -3.0011e-01, -8.8429e-01, -1.6260e+00,
         1.8703e-01,  3.0959e-01,  2.5564e-01, -6.7080e-01,  1.3150e+00,
        -4.2844e-01, -8.2218e-01,  1.3699e+00, -1.5035e+00, -2.7883e-02,
        -1.0746e+00,  1.3148e+00,  4.7388e-01, -1.7548e+00, -1.7506e+00,
        -3.7658e-01,  1.1853e+00, -2.3729e-01,  4.7733e-01,  4.1621e-01,
        -2.5128e+00, -5.4416e-01,  2.4020e-01, -1.5024e+00, -1.9300e+00,
        -1.2976e-01, -6.8205e-01, -5.6408e-01,  2.9860e-01, -1.0191e+00,
         6.9700e-01, -4.9185e-02,  3.0137e+00, -1.5486e+00,  3.0218e+00,
         4.3260e-01,  1.1412e+00,  5.4124e-01,  9.1090e-01, -5.6550e-01,
        -5.7027e-01, -6.1581e-01, -1.5972e-01,  9.8645e-01,  1.3603e+00,
        -8.1084e-01, -5.7289e-01, -3.3520e-01, -2.3