In [6]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import dgl
import pandas

# Pandas Graph Builder
Builds a heterogeneous graph from pandas DataFrames.

In [7]:
from collections import namedtuple

from pandas.api.types import(
    is_categorical,
    is_categorical_dtype,
    is_numeric_dtype
)

def _series_to_tensor(series):
    if is_categorical(series):
        # Make categorical values to int64
        # cat.codes -> Make categorical values to single int type(cat1 -> 1, cat2 -> 2 , ... ,catn -> n)
        return torch.LongTensor(series.cat.codes.values.astype("int64")) 
    
    else: #Numeric
        return torch.FloatTensor(series.values)


In [None]:
class PandasGraphBuilder(object):
    """
    Creates a heterogeneous graph from multiple pandas dataframes
    Examples
    --------
    Let's say we have the following three pandas dataframes:
    User table ``users``:
    ===========  ===========  =======
    ``user_id``  ``country``  ``age``
    ===========  ===========  =======
    XYZZY        U.S.         25
    FOO          China        24
    BAR          China        23
    ===========  ===========  =======
    Game table ``games``:
    ===========  =========  ==============  ==================
    ``game_id``  ``title``  ``is_sandbox``  ``is_multiplayer``
    ===========  =========  ==============  ==================
    1            Minecraft  True            True
    2            Tetris 99  False           True
    ===========  =========  ==============  ==================
    Play relationship table ``plays``:
    ===========  ===========  =========
    ``user_id``  ``game_id``  ``hours``
    ===========  ===========  =========
    XYZZY        1            24
    FOO          1            20
    FOO          2            16
    BAR          2            28
    ===========  ===========  =========
    Usage example to make bidirectional bipartite graph:
    >>> builder = PandasGraphBuilder()
    >>> builder.add_entities(users, 'user_id', 'user')
    >>> builder.add_entities(games, 'game_id', 'game')
    >>> builder.add_binary_relations(plays, 'user_id', 'game_id', 'plays')
    >>> builder.add_binary_relations(plays, 'game_id', 'user_id', 'played-by')
    >>> g = builder.build()
    >>> g.num_nodes('user')
    3
    >>> g.num_edges('plays')
    4
    """
    def __init__(self):
        """ 
        init func
        """
        # Store entitys
        self.entity_table = {}
        # Store relations
        self.relation_tables = {}
        # Mapping from primary key name to entity name
        self.entity_pk_to_name = (
            {}
        )
        # Mapping from entity name to primary key
        self.entity_pk = {} 
        # Mapping from entity names to primary key values
        self.entity_key_map=(
            {}
        ) 
        self.num_nodes_per_type = {}
        self.edges_per_relation = {}
        self.relation_name_to_etype = {}

        # Mapping from relation name to source key 
        self.relation_src_key = {}
        # Mapping from relation name to destination key
        self.relation_dst_key = (
            {}
        )

    def add_entities(self, entity_table, primary_key, name):
        """ 
        Add new entity to entity table
        Parameter =>
        entity_table: Pandas dataframe
        primary_key: str -> colunm name of dataframe
        name: str -> table name
        """
        entities = entity_table[primary_key].astype("category")
        if not (entities.value_counts() == 1).all():
            raise ValueError(
                "Different entity with the same primary key detected."
            )
        # Preserve the category order in the original entity table
        entities = entities.cat.reorder_categories(
            entity_table[primary_key].values
        )
        # pk to name mapping
        self.entity_pk_to_name[primary_key] = name
        # name to pk mapping
        self.entity_pk[name] = primary_key
        # 
        self.num_nodes_per_type[name] = entity_table.shape[0]
        self.entity_key_map[name] = entities
        self.entity_tables[name] = entity_table

    def add_binary_relations(self, source_key, destination_key, name):
        
