In [1]:
from relbench.datasets import dataset_names, get_dataset

In [2]:
# Dataset names exported by relbench.datasets.
dataset_names

['rel-amazon', 'rel-stackex', 'rel-fake']

In [2]:
# Load the "rel-stackex" dataset. In the recorded run, process=True logged that
# the Database object was built from raw files and that pkeys/fkeys were
# reindexed (roughly 50 seconds in total).
dataset = get_dataset(name="rel-stackex", process=True)
dataset

making Database object from raw files...
done in 43.73 seconds.
reindexing pkeys and fkeys...
done in 7.87 seconds.


StackExDataset()

In [4]:
# Names of the tables in the dataset's database.
dataset.db.table_dict.keys()

dict_keys(['badges', 'comments', 'postHistory', 'postLinks', 'posts', 'users', 'votes'])

In [35]:
# Column names of the `users` table.
dataset.db.table_dict["users"].df.columns


Index(['Id', 'AccountId', 'DisplayName', 'Location', 'ProfileImageUrl',
       'WebsiteUrl', 'AboutMe', 'CreationDate'],
      dtype='object')

In [36]:
# Column names of the `votes` table.
dataset.db.table_dict["votes"].df.columns

Index(['Id', 'UserId', 'PostId', 'VoteTypeId', 'CreationDate'], dtype='object')

In [37]:
# Column names of the `posts` table.
dataset.db.table_dict["posts"].df.columns

Index(['Id', 'OwnerUserId', 'LastEditorUserId', 'PostTypeId',
       'AcceptedAnswerId', 'ParentId', 'OwnerDisplayName',
       'LastEditorDisplayName', 'Title', 'Tags', 'ContentLicense', 'Body',
       'CreationDate'],
      dtype='object')

In [39]:
# Number of rows in the `users` table.
len(dataset.db.table_dict["users"].df)

255360

In [7]:
# Names of the tasks available on this dataset.
dataset.task_names

['rel-stackex-engage', 'rel-stackex-votes']

In [17]:
# Select the "rel-stackex-votes" task; its train/val/test tables are inspected next.
task = dataset.get_task("rel-stackex-votes")

In [18]:
# Training table: rows of (PostId, timestamp, popularity), with PostId a foreign
# key into `posts`. `popularity` is presumably the prediction target -- confirm.
task.train_table

Table(df=
        PostId  timestamp  popularity
0       152675 2018-07-05           0
1       152676 2018-07-05           0
2       152677 2018-07-05           0
3       152679 2018-07-05           0
4       152681 2018-07-05           0
...        ...        ...         ...
389884   14465 2010-08-16           0
389885   14602 2010-08-16           1
389886   14602 2010-02-17           0
389887   14602 2009-08-21           0
389888   14602 2009-02-22           0

[389889 rows x 3 columns],
  fkey_col_to_pkey_table={'PostId': 'posts'},
  pkey_col=None,
  time_col=timestamp)

In [30]:
# Column names of the training table.
task.train_table.df.columns

Index(['PostId', 'timestamp', 'popularity'], dtype='object')

In [27]:
# Number of rows in the training table.
len(task.train_table.df)

389889

In [31]:
# Column names of the validation table.
task.val_table.df.columns

Index(['PostId', 'timestamp', 'popularity'], dtype='object')

In [26]:
# Number of rows in the validation table.
len(task.val_table.df)

40725

In [34]:
# Column names of the test table. In the recorded run it has only `timestamp`
# and `PostId`, i.e. no `popularity` column.
task.test_table.df.columns

Index(['timestamp', 'PostId'], dtype='object')

In [28]:
# Number of rows in the test table.
len(task.test_table.df)

40063

Making the graph

In [3]:
from torch_frame.config.text_embedder import TextEmbedderConfig
from torch_frame.testing.text_embedder import HashTextEmbedder

from relbench.external.graph import get_stype_proposal, make_pkey_fkey_graph


# Semantic-type proposal for the columns of the database's tables.
stype_proposal = get_stype_proposal(dataset.db)

# Text-embedding configuration. HashTextEmbedder is imported from
# torch_frame.testing, so it is presumably a lightweight placeholder rather
# than a real language-model embedder -- TODO confirm before using for results.
text_embedder_cfg = TextEmbedderConfig(
    text_embedder=HashTextEmbedder(8),
    batch_size=None,
)

# Build the primary-key/foreign-key graph together with per-table column stats.
data, col_stats_dict = make_pkey_fkey_graph(
    dataset.db,
    stype_proposal,
    text_embedder_cfg=text_embedder_cfg,
)

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Print the graph summary: one node store per table (a TensorFrame plus a
# `time` vector) and paired f2p_*/p2f_* edge types between tables.
print(data)

HeteroData(
  comments={
    tf=TensorFrame([623967, 4]),
    time=[623967],
  },
  badges={
    tf=TensorFrame([463463, 4]),
    time=[463463],
  },
  postLinks={
    tf=TensorFrame([77337, 2]),
    time=[77337],
  },
  postHistory={
    tf=TensorFrame([1175368, 7]),
    time=[1175368],
  },
  votes={
    tf=TensorFrame([1317876, 2]),
    time=[1317876],
  },
  users={
    tf=TensorFrame([255360, 6]),
    time=[255360],
  },
  posts={
    tf=TensorFrame([333893, 7]),
    time=[333893],
  },
  (comments, f2p_UserId, users)={ edge_index=[2, 612288] },
  (users, p2f_UserId, comments)={ edge_index=[2, 612288] },
  (comments, f2p_PostId, posts)={ edge_index=[2, 623962] },
  (posts, p2f_PostId, comments)={ edge_index=[2, 623962] },
  (badges, f2p_UserId, users)={ edge_index=[2, 463463] },
  (users, p2f_UserId, badges)={ edge_index=[2, 463463] },
  (postLinks, f2p_PostId, posts)={ edge_index=[2, 61171] },
  (posts, p2f_PostId, postLinks)={ edge_index=[2, 61171] },
  (postLinks, f2p_RelatedPo

In [8]:
# Node store for the `votes` table: its TensorFrame and its `time` tensor.
data["votes"]


{'tf': TensorFrame(
  num_cols=2,
  num_rows=1317876,
  categorical (1): ['VoteTypeId'],
  timestamp (1): ['CreationDate'],
  has_target=False,
  device='cpu',
), 'time': tensor([1233532800, 1233532800, 1233532800,  ..., 1609459200, 1609459200,
        1609459200])}

In [10]:
# col_stats_dict is keyed by table name.
col_stats_dict.keys()

dict_keys(['comments', 'badges', 'postLinks', 'postHistory', 'votes', 'users', 'posts'])

In [12]:
# Per-column statistics for `postLinks` (in the recorded run: COUNT for
# LinkTypeId and year-range/oldest/newest/median time stats for CreationDate).
col_stats_dict["postLinks"]

{'LinkTypeId': {<StatType.COUNT: 'COUNT'>: ([1, 3], [66588, 10749])},
 'CreationDate': {<StatType.YEAR_RANGE: 'YEAR_RANGE'>: [2010, 2020],
  <StatType.NEWEST_TIME: 'NEWEST_TIME'>: tensor([2020,   11,   30,    3,   21,   25,   24]),
  <StatType.OLDEST_TIME: 'OLDEST_TIME'>: tensor([2010,    6,   20,    2,   14,   47,   33]),
  <StatType.MEDIAN_TIME: 'MEDIAN_TIME'>: tensor([2017,    5,    7,    3,    0,   59,   25])}}

Create model

In [None]:
from relbench.external.nn import HeteroEncoder, HeteroGraphSAGE, HeteroTemporalEncoder
from torch_geometric.nn import MLP

# For every node type in the graph, look up the column-name dictionary of its
# TensorFrame; the encoder is constructed from this mapping.
node_to_col_names_dict = {
    name: data[name].tf.col_names_dict
    for name in data.node_types
}

In [None]:
# Width shared by every module below. The encoders, the GNN and the head's
# input are all built with the same size (presumably they must agree so the
# modules can be chained), so it is defined once instead of repeating a literal.
CHANNELS = 64

# Column encoder, built from the per-node-type column names and column stats.
encoder = HeteroEncoder(CHANNELS, node_to_col_names_dict, col_stats_dict)

# Temporal encoder, restricted to node types whose store has a `time` entry.
temporal_encoder = HeteroTemporalEncoder(
    node_types=[
        node_type for node_type in data.node_types if "time" in data[node_type]
    ],
    channels=CHANNELS,
)

# GraphSAGE over all node and edge types of the graph.
gnn = HeteroGraphSAGE(data.node_types, data.edge_types, CHANNELS)

# One-layer MLP head with a single output channel.
head = MLP(CHANNELS, out_channels=1, num_layers=1)

Old demo tutorial (evaluates an all-zeros baseline prediction on the test table):

In [41]:
import numpy as np

# Constant baseline: predict 0 for every row of the test table and score it.
# np.zeros allocates the array directly (no intermediate Python list) and keeps
# an integer dtype even if the test table is empty.
pred = np.zeros(len(task.test_table.df), dtype=int)
task.evaluate(pred)



{'mae': 0.09447619998502359, 'rmse': 0.4515018605279}