In [2]:
import sys
sys.path.insert(0, '..')
import caldera

In [87]:
# flake8: noqa

from torch import nn
from caldera import gnn
import torch
from collections import OrderedDict
from caldera.data import GraphBatch
from typing import Mapping, Callable
from dataclasses import dataclass
        

@dataclass
class Connection(object):
    
    src: nn.Module
    dest: nn.Module
    src_map: Callable
    dest_map: Callable
    idx_map: Callable
        
    def __str__(self):
        return "Conection"
    
    def __repr__(self):
        return "Connection"
        
class ConnectionMapping(object):
    
    def __init__(self):
        self.connections = []
    
    def add(self, connection: Connection):
        self.connections.append(connection)
        
    def successors(self, src):
        carr = []
        for c in self.connections:
            if src is c.src:
                carr.append(c)
        return carr
    
    def predecessors(self, dest):
        carr = []
        for c in self.connections:
            if dest is c.dest:
                carr.append(c)
        return carr
    
class MessagePassing(nn.Module):

    def __init__(self):
        super().__init__()
        self.connections = ConnectionMapping()

    def register_connection(self, source, destination, mapping):
        self.connections.add(source, destination, mapping)


    def propogate(self, mod):
        """Go through the graph of messages and start """
        print('propogate')
        def wrapped(x):
            results = [dest_mapping(x)]
            connections = self.connections.predecessors[mod]
            print(connections)
            for source, (src_mapping, dest_mapping, index_mapping) in connections:
                out = source(src_mapping(x))
                print(out.shape)
                reduced = out[index_mapping(x)]
                print(reduced.shape)
                results.append(reduced)
            return mod(torch.cat(results, 1))
        return wrapped
    
class GraphCore(MessagePassing):

    def __init__(self):
        super().__init__()
        self.node = gnn.Flex(nn.Linear)(..., 1)
        self.edge = gnn.Flex(nn.Linear)(..., 1)
        
        self.connections = ConnectionMapping()
        self.connections.add(
            Connection(
                src=self.node,
                dest=self.edge,
                src_map=lambda data: data.x,
                dest_map=lambda data: data.e,
                idx_map=lambda data: data.edges[0]
            )
        )
        self.connections.add(
            Connection(
                lambda data: data.g,
                self.edge,
                lambda data: data,
                lambda data: data.e,
                lambda data: data.edge_idx
            )
        )
        
#         self.mapping('node', 'edge', )
#         self.emit(self.node, 'node')
#         self.listen(self.edge, 'node')
#         self.emit(self.edge, 'edge')
#         self.register_connection('node', 'edge')
#         self.register_connection(self.node, self.edge, (lambda d: d.x, lambda d: d.e, lambda d: d.edges[0]))
# #         self.register_connection(lambda d: self.net1(d.x), self.net2, lambda x: x.node_idx)
# #         self.register_connection(lambda d: self.net2(d.e), self.net2, lambda x: x.edge_idx)
    
    def propogate(self, mod, x):
        messages = []
        for c in self.connections.predecessors(mod):
            print(c)
            # emit message
            src_out = c.src(c.src_map(x))

            # receive message
            src_to_dest_map = c.idx_map(x)
            msg = src_out[src_to_dest_map]
            messages.append(msg)
            
        # collect and apply
        cat = torch.cat([c.dest_map(x)] + messages, 1)
        return c.dest(cat)
    
    def forward(self, data):
        return self.propogate(self.edge, data)
#         # we apply a callable to the data and pass to source module
#         node_out = self.node(data.x)
        
#         # we collect some kind of mapping from source_out to dest_in
#         mapping = data.edges[0]
        
#         # we apply the mapping
#         reduced_node_out = node_out[mapping]
#         print(reduced_node_out.shape)
#         print(data.e.shape)
        
#         # we apply a callable to the data to pass to the dest module
#         # we concat the values
#         cat = torch.cat([data.e, reduced_node_out], 1)
#         print(cat.shape)
        
#         # we apply dest module
#         edge_out = self.edge(cat)
 
b = GraphBatch.random_batch(10, 5, 10, 5)

core = GraphCore()
core(b)

Conection
Conection


tensor([[0.5315],
        [0.5863]], grad_fn=<IndexBackward>)

In [157]:
b.x.data_ptr()

94406940553728

In [155]:
class ConnectionMod(nn.Module):
    
    def __init__(self, src, dest):
        super().__init__()
        self.src = src
        self.dest = dest

class Adapter(nn.Module):
    
    def __init__(self, mod: nn.Module, func: Callable):
        super().__init__()
        self.mod = mod
        self.func = func
        
    def forward(self, x):
        return self.mod(self.func(x))
        
class Foo(nn.Module):
    
    def __init__(self):
        super().__init__()
        self.node = gnn.Flex(nn.Linear)(..., 1)
        self.edge = gnn.Flex(nn.Linear)(..., 1)
        self.connections = [
            (lambda data: data.x, self.node),
            (lambda data: data.e, self.edge),
            (self.node, self.edge, lambda data: data.edges[0]),
            (lambda data: data.g, self.edge, lambda data: data.edge_idx)
        ]
        
    def propogate(self, mod, data):
        connections = [c for c in self.connections if c[1] is mod]
        if not connections:
            out = mod(data)
            return out
        else:
            results = []
            for c in connections:
                a = self.propogate(c[0], data)
                if len(c) == 3:
                    a = a[c[2](data)]
                results.append(a)
            b = c[1](torch.cat(results, 1))
            return b
#                 if len(c) == 2:
#                     out = c[1](c[0](data))
#                     resutls.append(out)
#                 out = c[0](data)
#                 print(out)
                
    def forward(self, data):
        return self.propogate(self.edge, data)
        
Foo()(b)

tensor([[-0.5630],
        [-0.3908],
        [-1.0178],
        [-0.0021],
        [-0.1794],
        [-0.4257],
        [-0.5536],
        [-0.1457],
        [-1.0274],
        [ 0.1126],
        [ 0.0405],
        [-1.1641],
        [-1.0706],
        [-0.1394],
        [-0.9175],
        [-0.3514],
        [-0.5424],
        [ 0.4762],
        [-1.1267],
        [-0.7587],
        [ 0.0159],
        [ 0.3151],
        [ 0.2521],
        [-0.9224],
        [-0.2634],
        [-0.1773],
        [-0.1731],
        [-0.8779],
        [-0.1679],
        [-0.3020],
        [-1.0015],
        [ 0.0305],
        [-0.1856],
        [-0.2791],
        [ 0.0666],
        [-0.1582],
        [-0.2379],
        [-0.2196],
        [-0.0162],
        [-0.3696],
        [-0.0938],
        [-0.2692],
        [-0.5411],
        [-0.1139],
        [ 0.3954],
        [-0.2293],
        [-0.0981],
        [ 0.3287],
        [-0.6210],
        [-1.1494],
        [-0.6490],
        [ 0.1185],
        [-0.

In [105]:
mapping = (
    lambda data: data.x, 
    lambda data: data.e, 
    lambda data: data.edges[0]
)

b.x is mapping[0](b)

True