In [1]:
import numpy as np
import polars as pl

import torch_geometric
import torch

from constants import *

from torch_geometric.data import Data, InMemoryDataset, download_url

In [21]:
class Wikidata5m(InMemoryDataset):
    """Wikidata5m (transductive split) exposed as a single-graph PyG dataset.

    The triple files are tab-separated ``entity1 / relation / entity2`` rows and
    the corpus file is tab-separated ``id / description`` rows. Leading ``Q`` /
    ``P`` characters are stripped from the id columns and the remainder is cast
    to ``UInt32``.

    The processed graph contains:
      * ``x``          -- corpus entity ids, shape ``[num_corpus_rows, 1]``
      * ``edge_index`` -- training edges, shape ``[2, num_train_edges]``
      * ``y``          -- test edges, shape ``[num_test_edges, 2]``

    All three are stored as ``torch.long`` tensors.

    No ``download`` is implemented; the raw files must already be on disk.

    NOTE(review): a ``saved_dataset.pt`` written by an earlier version of this
    class presumably has to be deleted before ``process`` runs again -- confirm
    against the installed torch_geometric version.
    """

    train = "wikidata5m_transductive_train.txt"
    validate = "wikidata5m_transductive_valid.txt"
    test = "wikidata5m_transductive_test.txt"
    corpus = "wikidata5m_text.txt"

    # Number of corpus rows loaded as nodes. ``None`` reads the whole file.
    corpus_rows = 1000

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        """Load the processed dataset, processing the raw files first if needed.

        Args:
            root: Path to the folder which contains the datasets.
            transform: Optional per-access transform, forwarded to the base class.
            pre_transform: Optional transform applied once in ``process``.
            pre_filter: Optional predicate applied once in ``process``.
        """
        super().__init__(root, transform, pre_transform, pre_filter)
        # NOTE(review): this unpickles arbitrary Python objects, so only load
        # files this class wrote itself. Newer torch releases may also need an
        # explicit ``weights_only=False`` here -- confirm for the pinned version.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """File names of the three triple splits followed by the text corpus."""
        return [self.train, self.validate, self.test, self.corpus]

    @property
    def processed_file_names(self):
        """File name of the single collated graph written by ``process``."""
        return ["saved_dataset.pt"]

    def read_csv_files(self) -> list[pl.DataFrame]:
        """Read the raw files and convert their id columns to integers.

        Returns:
            ``[train, validate, test, corpus]`` data frames, in that order.
        """
        entity_columns = ["entity1", "relation", "entity2"]

        def read_entity_file(file_name: str) -> pl.DataFrame:
            return pl.read_csv(file_name, sep="\t", has_header=False, new_columns=entity_columns)

        # NOTE(review): paths come from ``constants.DatasetPaths`` rather than
        # ``self.raw_paths``, so ``root`` does not influence which raw files are
        # read -- confirm that both point at the same location.
        train_data = read_entity_file(DatasetPaths.train.value)
        validate_data = read_entity_file(DatasetPaths.validate.value)
        test_data = read_entity_file(DatasetPaths.test.value)
        corpus_text = pl.read_csv(
            DatasetPaths.corpus.value,
            sep="\t",
            has_header=False,
            new_columns=["id", "description"],
            n_rows=self.corpus_rows,
        )

        frames = [train_data, validate_data, test_data, corpus_text]
        return [self.convert_df_to_int_indexes(frame) for frame in frames]

    @staticmethod
    def convert_df_to_int_indexes(df: pl.DataFrame) -> pl.DataFrame:
        """Turn the id columns present in ``df`` into ``UInt32`` columns.

        Every leading ``Q`` or ``P`` character is stripped before the cast.
        Columns other than ``id``, ``entity1``, ``entity2`` and ``relation``
        are left untouched.
        """
        int_columns = ["id", "entity1", "entity2", "relation"]
        conversions = [
            pl.col(col).str.lstrip("QP").cast(pl.UInt32)
            for col in df.columns
            if col in int_columns
        ]
        if not conversions:
            return df
        return df.with_columns(conversions)

    @staticmethod
    def get_edges_from_dataset(dataset: pl.DataFrame) -> np.ndarray:
        """Return the ``entity1`` / ``entity2`` columns as a ``[2, num_edges]`` array."""
        # PyG requires format [2, num_edges]
        edges = dataset.select(["entity1", "entity2"]).to_numpy().T
        assert edges.shape[0] == 2
        return edges

    @staticmethod
    def _to_long_tensor(array: np.ndarray) -> torch.Tensor:
        """Copy ``array`` into a contiguous ``torch.long`` tensor."""
        # The ids are UInt32 and the edge arrays are transposed views; PyG
        # indexes with ``edge_index`` and therefore needs int64 tensors.
        return torch.from_numpy(np.ascontiguousarray(array, dtype=np.int64))

    def process(self):
        """Build the single graph and save it to ``self.processed_paths[0]``."""
        # The validation split is read but not used in the graph yet.
        train_data, _validate_data, test_data, corpus_text = self.read_csv_files()

        train_edges = self.get_edges_from_dataset(train_data)
        # NOTE(review): the extra transpose stores the test edges as
        # [num_test_edges, 2], unlike ``edge_index`` -- kept as-is; confirm
        # this orientation is what downstream code expects.
        test_edges = self.get_edges_from_dataset(test_data).T
        nodes = corpus_text.select("id").to_numpy()

        data_list = [
            Data(
                x=self._to_long_tensor(nodes),
                edge_index=self._to_long_tensor(train_edges),
                y=self._to_long_tensor(test_edges),
            )
        ]

        # Honour the hooks accepted by ``__init__``; both default to ``None``.
        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]
        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [22]:
# Build the dataset rooted at ./datasets and display its first graph.
dataset = Wikidata5m(root="datasets")
dataset[0]

Data(x=[1000, 1], edge_index=[2, 20614279], y=[5133, 2])