In [1]:
import torch

# Utils

In [2]:
from functools import singledispatch
import torch
import numpy as np
from typing import Type

@singledispatch
def to_one_hot(arr: torch.Tensor, num_classes: int):
    """One-hot encode an array of integer class indices.

    The implementation is chosen by the type of ``arr`` (single dispatch).
    This generic version is the fallback for every type without a registered
    overload and forwards to ``torch.nn.functional.one_hot``; a NumPy
    overload is registered right after it.

    Args:
        arr: Class indices (a ``torch.Tensor`` for this implementation).
        num_classes: Width of the one-hot axis.

    Returns:
        Whatever ``torch.nn.functional.one_hot(arr, num_classes)`` returns.
    """
    return torch.nn.functional.one_hot(arr, num_classes)


@to_one_hot.register
def _(arr: np.ndarray, num_classes: int):
    """NumPy overload of ``to_one_hot``.

    Args:
        arr: Array whose first axis enumerates the samples; element ``i`` is
            the class index (or indices) to switch on in row ``i``.
        num_classes: Number of columns in the result.

    Returns:
        An int64 array of shape ``(arr.shape[0], num_classes)`` holding 0/1.

    Raises:
        IndexError: If an index does not fit in ``num_classes`` columns.
    """
    # np.int64 instead of np.long: the np.long alias does not exist in every
    # NumPy release (it was removed in 1.24), which made this overload fail
    # with AttributeError there.
    oh = np.zeros((arr.shape[0], num_classes), dtype=np.int64)
    for i, a in enumerate(arr):
        oh[i, a] = 1
    return oh

In [105]:
def dict_join(d1, d2, d3=None, join_fn=None, default1=..., default2=..., keys=None, mode: str = 'union'):
    """Merge two dictionaries key by key.

    ``...`` (Ellipsis) is used as the "no value" sentinel throughout.

    Args:
        d1: Left dictionary.
        d2: Right dictionary.
        d3: Optional dictionary to write the result into (mutated in place).
            A new dict is created when omitted.
        join_fn: ``join_fn(v1, v2)`` combines the two values when both sides
            have one. Defaults to keeping the right-hand value.
        default1: Value substituted when a key is missing from ``d1``.
        default2: Value substituted when a key is missing from ``d2``.
        keys: Explicit iterable of keys to process. When given, ``mode`` is
            ignored.
        mode: Which keys to process when ``keys`` is None: ``'union'``,
            ``'intersection'``, ``'left'`` (keys of ``d1``) or ``'right'``
            (keys of ``d2``).

    Returns:
        ``d3``, filled with the merged values.

    Raises:
        ValueError: If ``keys`` is None and ``mode`` is not recognized.
    """
    if d3 is None:
        d3 = dict()

    if join_fn is None:
        # join_fn is only invoked when both values are present, so the
        # default reduces to "right-hand side wins".
        def join_fn(v1, v2):
            return v2

    if keys is None:
        # Build the key list from dict iteration order instead of sets so the
        # result order is reproducible (set order of str keys varies between
        # interpreter runs).
        if mode == 'union':
            keys = list(dict.fromkeys([*d1, *d2]))
        elif mode == 'intersection':
            keys = [k for k in d1 if k in d2]
        elif mode == 'left':
            keys = list(d1)
        elif mode == 'right':
            keys = list(d2)
        else:
            raise ValueError("mode '{}' not recognized.".format(mode))

    for k in keys:
        v1 = d1.get(k, default1)
        v2 = d2.get(k, default2)
        if v1 is ... and v2 is ...:
            # No value on either side and no default: skip the key rather
            # than storing the sentinel in the output.
            continue
        if v1 is ...:
            v3 = v2
        elif v2 is ...:
            v3 = v1
        else:
            v3 = join_fn(v1, v2)
        d3[k] = v3
    return d3


# Worked example: key 1 exists in both dicts, so its values are added (2 + 3).
# Keys 3 and 2 each exist on one side only, so the missing side is replaced
# by its default (1) before join_fn adds the pair.
d1 = {1: 2, 3: 6}
d2 = {1: 3, 2: 5}
dict_join(d1, d2, join_fn=lambda a, b: a + b, default1=1, default2=1, mode='union')

{1: 5, 2: 6, 3: 7}

In [106]:
import networkx as nx
import random

def random_graph(n, e=None, d=None):
    """Create a random graph whose size is itself drawn at random.

    Args:
        n: ``(low, high)`` inclusive bounds for the number of nodes.
        e: Optional ``(low, high)`` inclusive bounds for the number of edges.
        d: ``(low, high)`` bounds for the edge density. Only used when ``e``
            is None; the edge count is then ``density * n * (n - 1) / 2``
            truncated to an int.

    Returns:
        The graph produced by ``nx.generators.dense_gnm_random_graph``.

    Raises:
        TypeError: If neither ``e`` nor ``d`` is supplied.
    """
    n = random.randint(*n)
    if e is None:
        if d is None:
            raise TypeError("random_graph() requires either `e` or `d`")
        # Sample uniformly between the two bounds. The previous expression
        # (random() * d[1] + d[0]) ranged over [d[0], d[0] + d[1]) and could
        # exceed the requested upper bound.
        density = random.uniform(d[0], d[1])
        e = int((density * n * (n - 1)) / 2)
    else:
        e = random.randint(*e)
    # Always request at least one edge.
    e = max(1, e)
    return nx.generators.dense_gnm_random_graph(n, e)


def annotate_shortest_path(g):
    """Pick a random source/target pair and flag the shortest path between them.

    The graph is modified in place:

    * the chosen source node gets ``'source' = True`` and the chosen target
      node gets ``'target' = True``;
    * every node and every edge gets a boolean ``'shortest_path'`` attribute,
      True only for nodes/edges on the path found by ``nx.shortest_path``.

    Source and target are drawn with replacement, so they may be the same
    node. When no path exists, all ``'shortest_path'`` flags are False.

    Args:
        g: Graph to annotate.
    """
    nodes = list(g.nodes)
    # Draw positions rather than labels so the node objects keep their
    # original Python type when used as lookup keys.
    i, j = np.random.choice(len(nodes), size=(2,))
    source, target = nodes[i], nodes[j]

    g.nodes[source]['source'] = True
    g.nodes[target]['target'] = True
    try:
        path = nx.shortest_path(g, source=source, target=target)
    except nx.NetworkXNoPath:
        path = []

    # Reset every node first, then switch on the ones that lie on the path.
    for _, ndata in g.nodes(data=True):
        ndata['shortest_path'] = False
    for n in path:
        g.nodes[n]['shortest_path'] = True

    # Same two-step scheme for edges.
    for _, _, edata in g.edges(data=True):
        edata['shortest_path'] = False
    for n1, n2 in nx.utils.pairwise(path):
        g[n1][n2]['shortest_path'] = True


def cat_property(from_key, to_key):
    """Placeholder: not implemented yet. Ignores its arguments and returns None."""
    return None


def from_property_graph_to_graph_data():
    """Placeholder: not implemented yet. Takes no arguments and returns None."""
    return None


import networkx as nx
from typing import Tuple

def collect_nx_features(g, to_key: str, keys: Tuple[str, ...] = None):
    """Placeholder: not implemented yet. Ignores its arguments and returns None."""
    return None


def create_random_graph(n_nodes: int = 30, n_edges: int = 30):
    """Build a random test graph with placeholder node attributes.

    Args:
        n_nodes: Number of nodes passed to the generator (default 30).
        n_edges: Number of edges passed to the generator (default 30).

    Returns:
        The graph from ``nx.generators.dense_gnm_random_graph`` in which every
        node carries ``'features' = True`` and ``'target' = False``. Edges
        get no attributes.
    """
    g = nx.generators.dense_gnm_random_graph(n_nodes, n_edges)
    for _, ndata in g.nodes(data=True):
        ndata['features'] = True
        ndata['target'] = False
    return g

def extract_one_hot(datalist, key, num_classes: int, default=..., nodelist=None):
    """Turn one field of a list of dictionaries into a one-hot array.

    Each distinct value found under ``key`` is assigned a class index in
    order of first appearance (0, 1, 2, ...); the resulting index list is
    handed to ``to_one_hot``.

    Args:
        datalist: Iterable of dictionaries.
        key: Field to read from every dictionary.
        num_classes: Number of classes passed on to ``to_one_hot``.
        default: Value used when ``key`` is missing. With the ``...`` sentinel
            (the default) a missing key raises ``KeyError``.
        nodelist: Unused.

    Returns:
        The result of ``to_one_hot`` applied to the array of class indices.
    """
    class_of = {}
    codes = []
    for record in datalist:
        value = record[key] if default is ... else record.get(key, default)
        # Unwrap single-element arrays so the value can serve as a dict key.
        if isinstance(value, np.ndarray):
            value = value.item()
        # An unseen value receives the next free index.
        codes.append(class_of.setdefault(value, len(class_of)))
    return to_one_hot(np.array(codes), num_classes)

# Testing

In [171]:
# Build a sample graph to exercise the helpers defined above.
g = create_random_graph()

# One attribute dict per node, in the graph's node iteration order.
datalist = [ndata for _, ndata in g.nodes(data=True)]
# create_random_graph sets 'features' to True on every node, so all rows
# end up with the same class index.
data = extract_one_hot(datalist, 'features', 10)

# Per-node dict holding that node's one-hot row under '_target'.
d2 = {n: {'_target': d} for n, d in zip(g.nodes, data)}

# Node -> attribute-dict mapping taken straight from the graph.
d1 = dict(g.nodes(data=True))

def collate(d1, d2, join_fn):
    """Merge ``d1`` and ``d2`` through ``dict_join``.

    Values are combined by calling ``join_fn(a, b)``; every other
    ``dict_join`` option keeps its default.
    """
    def _combine(left, right):
        return join_fn(left, right)

    return dict_join(d1, d2, join_fn=_combine)


import functools
def curry(f):
    """Wrap ``f`` so it can be called with its arguments spread over several calls.

    Calling the wrapper with a complete set of arguments calls ``f`` and
    returns its result. Calling it with a valid but incomplete set returns a
    new curried callable that remembers the arguments given so far.

    Whether a call is complete is decided from the signature of ``f``, not by
    calling ``f`` and catching whatever it raises. Exceptions raised inside
    ``f``, and calls with too many or unknown arguments, therefore propagate
    to the caller instead of being silently turned into another partial.

    Args:
        f: Callable to curry.

    Returns:
        The currying wrapper around ``f``.
    """
    import inspect  # local import: only this helper needs it

    try:
        sig = inspect.signature(f)
    except (TypeError, ValueError):
        # Some callables (e.g. certain builtins) expose no signature.
        sig = None

    def _needs_more_args(args, kwargs):
        """Return True when the arguments are valid so far but incomplete."""
        try:
            sig.bind(*args, **kwargs)
            return False
        except TypeError:
            pass
        try:
            sig.bind_partial(*args, **kwargs)
        except TypeError:
            # Too many / unknown arguments: not a partial call, let f raise.
            return False
        return True

    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        if sig is None:
            # Without a signature, fall back to trying the call; only a
            # TypeError is taken to mean "arguments missing".
            try:
                return f(*args, **kwargs)
            except TypeError:
                return curry(functools.partial(f, *args, **kwargs))
        if _needs_more_args(args, kwargs):
            return curry(functools.partial(f, *args, **kwargs))
        return f(*args, **kwargs)
    return wrapped

def pipe(*funcs):
    """Compose ``funcs`` left to right into a single one-argument callable.

    The returned callable feeds its argument to the first function, passes
    each result on to the next one, and returns the final result. With no
    functions it returns its argument unchanged.
    """
    def run(data):
        return functools.reduce(lambda acc, fn: fn(acc), funcs, data)
    return run

def fn_partial(*args, **kwargs):
    """Capture arguments now and apply a function to them later.

    Returns a one-argument callable: given a function ``f`` it returns
    ``f(*args, **kwargs)``.
    """
    def apply_to(f):
        outcome = f(*args, **kwargs)
        return outcome
    return apply_to

def fn_curry_partial(*args, **kwargs):
    """Capture arguments now and apply them to a curried function later.

    Returns a one-argument callable: given a function ``f`` it returns
    ``curry(f)(*args, **kwargs)``.
    """
    def apply_to(f):
        curried = curry(f)
        return curried(*args, **kwargs)
    return apply_to



# fn = fn_curry_partial(b=3)
# d3 = collate(d1, d2, lambda a, b: dict_join(a, b, mode='right', join_fn=lambda a, b: np.hstack([a, b])))
# d4 = collate(d3, d3, lambda a, b: dict_join(a, b, mode='right', join_fn=lambda a, b: np.hstack([a, b])))
# d4

In [125]:
# Two node-attribute dicts with identical contents: a one-hot '_target' row
# plus two boolean flags.
a = {'_target': np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
  'target': False,
  'features': True}

b =  {'_target': np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
  'target': False,
  'features': True}

# Every key is present on both sides, so join_fn runs for each of them:
# np.hstack concatenates the two arrays and stacks the scalar flags into
# two-element arrays.
dict_join(a, b, join_fn=lambda a, b: np.hstack([a, b]))

{'_target': array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
 'target': array([False, False]),
 'features': array([ True,  True])}

In [None]:
# np.concatenate takes a sequence of arrays, so the single array is wrapped
# in a list (the original line also lacked its closing parenthesis).
np.concatenate([np.array([0])])