In [8]:
import sys

# Append the directory two levels above the working directory to the module
# search path. Presumably this is what makes the local `ctu_relational`
# package importable from this notebook's location — TODO confirm.
sys.path.append("../..")

In [9]:
from typing import Dict

import numpy as np
import pandas as pd
import torch
from scipy.stats import mode
from torch_geometric.seed import seed_everything

from torch_frame.data.stats import StatType
 
from relbench.base import Dataset, EntityTask, Table, TaskType
from relbench.datasets import get_dataset, get_dataset_names
from relbench.tasks import get_task, get_task_names

import ctu_relational
from ctu_relational.tasks import CTUEntityTask

# Render every column and up to 200 rows when a DataFrame is displayed.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 200)
# (Re)load IPython's autoreload extension; mode 2 re-imports all modules
# before each cell executes, so edits to local packages are picked up live.
%reload_ext autoreload
%autoreload 2

In [10]:
# CPU device handle. NOTE(review): `device` is not referenced by any of the
# cells visible here — confirm a later cell uses it.
device = torch.device("cpu")
# Fix the global seed so the `np.random.rand`-based "random" baselines below
# produce the same numbers on every Restart & Run All.
seed_everything(42)

In [11]:
def evaluate(
    task: CTUEntityTask, train_table: Table, pred_table: Table, name: str
) -> Dict[str, float]:
    """Score a trivial, feature-free baseline with ``task.evaluate``.

    The baseline selected by ``name`` derives whatever statistic it needs
    from the target column of ``train_table``, produces one prediction per
    row of ``pred_table`` and passes those predictions to ``task.evaluate``.

    Args:
        task: Task object; ``target_col``, ``stats`` and ``evaluate`` are used.
        train_table: Table whose target column the baseline statistic
            (mean, median, majority label, ...) is computed from.
        pred_table: Table to predict for. When it lacks the target column,
            ``task.evaluate`` is called with ``None`` as the target table.
        name: Baseline to run. One of ``"global_zero"``, ``"global_mean"``,
            ``"global_median"``, ``"random"``, ``"majority"``,
            ``"majority_multiclass"``, ``"random_multiclass"``,
            ``"majority_multilabel"`` or ``"random_multilabel"``.

    Returns:
        The result of ``task.evaluate`` for the baseline predictions.

    Raises:
        ValueError: If ``name`` is not one of the baselines listed above.
    """
    # A prediction table without the target column cannot be passed as the
    # target table; ``task.evaluate`` then receives ``None`` instead.
    is_test = task.target_col not in pred_table.df

    # --- regression baselines -------------------------------------------
    if name == "global_zero":
        pred = np.zeros(len(pred_table))
    elif name == "global_mean":
        mean = train_table.df[task.target_col].astype(float).values.mean()
        pred = np.ones(len(pred_table)) * mean
    elif name == "global_median":
        median = np.median(train_table.df[task.target_col].astype(float).values)
        pred = np.ones(len(pred_table)) * median

    # --- binary classification baselines --------------------------------
    elif name == "random":
        pred = np.random.rand(len(pred_table))
    elif name == "majority":
        past_target = train_table.df[task.target_col].astype(int)
        majority_label = int(past_target.mode().iloc[0])
        pred = np.ones(len(pred_table), float) * majority_label

    # --- multiclass classification baselines ----------------------------
    elif name == "majority_multiclass":
        past_target = train_table.df[task.target_col]
        majority = int(past_target.mode().iloc[0])
        num_labels = len(task.stats[StatType.COUNT][0])
        # One-hot scores: all mass on the most frequent training class.
        pred = np.zeros((len(pred_table.df), num_labels), float)
        pred[:, majority] = 1
    elif name == "random_multiclass":
        num_labels = len(task.stats[StatType.COUNT][0])
        pred = np.random.rand(len(pred_table), num_labels)

    # --- multilabel classification baselines ----------------------------
    elif name == "majority_multilabel":
        past_target = train_table.df[task.target_col]
        # Depending on the SciPy version, ``mode(..., axis=0).mode`` either
        # keeps a leading length-1 axis or is already squeezed. Flattening
        # yields the per-label majority vector in both cases, whereas
        # indexing with ``[0]`` would return a scalar on the squeezed form.
        majority = np.ravel(mode(np.stack(past_target.values), axis=0).mode)
        pred = np.stack([majority] * len(pred_table.df))
    elif name == "random_multilabel":
        num_labels = train_table.df[task.target_col].values[0].shape[0]
        pred = np.random.rand(len(pred_table), num_labels)

    else:
        # f-string so the offending name actually appears in the message.
        raise ValueError(f"Unknown eval name called {name}.")
    return task.evaluate(pred, None if is_test else pred_table)

In [12]:
# Keep only the dataset names returned by `get_dataset_names()` that start
# with the "ctu" prefix.
ctu_datasets = [
    dataset_name
    for dataset_name in get_dataset_names()
    if dataset_name.startswith("ctu")
]

In [13]:
# Per-task baseline results: one dict of columns per task, split by task family.
# NOTE(review): both dicts are keyed by `task_name` alone, so two datasets
# exposing a task with the same name would overwrite each other — confirm
# that task names are unique across datasets.
regression_info = {}
classification_info = {}

for dataset_name in ctu_datasets:
    print(f"Processing {dataset_name}")
    for task_name in get_task_names(dataset_name):
        task: CTUEntityTask = get_task(dataset_name, task_name)

        # No baselines are run for these task types.
        if task.task_type in [TaskType.LINK_PREDICTION, TaskType.MULTILABEL_CLASSIFICATION]:
            continue

        stats = {}

        train_table = task.get_table("train")
        val_table = task.get_table("val")
        # Keep the input columns unmasked so the test table can be passed to
        # `evaluate` like the other splits (it checks for the target column).
        test_table = task.get_table("test", mask_input_cols=False)

        # Baselines scored on the test split are fitted on train + val.
        trainval_table_df = pd.concat([train_table.df, val_table.df], axis=0)
        trainval_table = Table(
            df=trainval_table_df,
            fkey_col_to_pkey_table=train_table.fkey_col_to_pkey_table,
            pkey_col=train_table.pkey_col,
            time_col=train_table.time_col,
        )

        # Choose the baselines, the result dict and the summary stats that
        # match the task type.
        if task.task_type == TaskType.REGRESSION:
            eval_name_list = ["global_zero", "global_mean", "global_median"]
            info = regression_info
            stats = {k.name: v for k, v in task.stats.items()}

        elif task.task_type in [TaskType.BINARY_CLASSIFICATION]:
            eval_name_list = ["random", "majority"]
            info = classification_info
            stats = {"class_count": len(task.stats[StatType.COUNT][0])}

        elif task.task_type in [TaskType.MULTICLASS_CLASSIFICATION]:
            eval_name_list = ["random_multiclass", "majority_multiclass"]
            info = classification_info
            stats = {"class_count": len(task.stats[StatType.COUNT][0])}

        else:
            # Without this branch an unhandled task type would silently reuse
            # `eval_name_list` / `info` left over from the previous task (or
            # raise NameError on the very first one).
            print(
                f"Skipping {dataset_name}/{task_name}: "
                f"unsupported task type {task.task_type}"
            )
            continue

        info[task_name] = {
            "dataset_name": dataset_name,
            "task_name": task_name,
            "task_type": task.task_type.name,
            **stats,
            "num_train": len(train_table),
            "num_val": len(val_table),
            "num_test": len(test_table),
        }

        # (column prefix, table the baseline is fitted on, table it is scored on)
        split_setups = (
            ("train", train_table, train_table),
            ("val", train_table, val_table),
            ("test", trainval_table, test_table),
        )
        for name in eval_name_list:
            for split, fit_table, eval_table in split_setups:
                for m, v in evaluate(task, fit_table, eval_table, name=name).items():
                    info[task_name][f"{split}_{name}_{m}"] = v

        # NOTE(review): both CSVs are rewritten after every task — presumably
        # a checkpoint so partial results survive an interrupted run; confirm.
        # `index=False` drops the task-name index; the name is still present
        # as the `task_name` column.
        regression_info_df = pd.DataFrame(regression_info).T
        regression_info_df.to_csv("./regression_info.csv", index=False)
        classification_info_df = pd.DataFrame(classification_info).T
        classification_info_df.to_csv("./classification_info.csv", index=False)

Processing ctu-accidents
Loading Database object from /home/jakub/.cache/relbench/ctu-accidents/db...
Done in 0.60 seconds.
Processing ctu-adventureworks
Making task table for train split from scratch...
(You can also use `get_task(..., download=True)` for tasks prepared by the RelBench team.)
Loading Database object from /home/jakub/.cache/relbench/ctu-adventureworks/db...
Done in 0.30 seconds.
Done in 0.34 seconds.
Making task table for val split from scratch...
(You can also use `get_task(..., download=True)` for tasks prepared by the RelBench team.)
Done in 0.02 seconds.
Making task table for test split from scratch...
(You can also use `get_task(..., download=True)` for tasks prepared by the RelBench team.)
Done in 0.02 seconds.
Making task table for train split from scratch...
(You can also use `get_task(..., download=True)` for tasks prepared by the RelBench team.)
Done in 0.02 seconds.
Making task table for val split from scratch...
(You can also use `get_task(..., download=Tru