In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import networkx as nx
import random
import matplotlib.pyplot as plt

# Parameters
num_nodes = 200  # Number of nodes
prob_edge = 0.002  # Per-pair probability used when sprinkling extra edges onto the base graph
extra_edge_prob = prob_edge
SEED = 0  # Fixed seed so that "Restart Kernel -> Run All" reproduces the same graph

# The stdlib RNG drives the extra-edge sampling below (and the random.random() calls in later cells).
random.seed(SEED)

# Step 1: Create the undirected base graph: a random labeled tree on `num_nodes` nodes.
# (An Erdős–Rényi graph was used here previously; the call is kept for reference.)
# G = nx.erdos_renyi_graph(num_nodes, prob_edge)
G = nx.random_labeled_tree(num_nodes, seed=SEED)

# Step 1b: Add extra random edges while ensuring the graph remains simple
for u in range(num_nodes):
    # v starts at u + 1, so no self-loop is ever proposed and each unordered pair is visited once.
    for v in range(u + 1, num_nodes):
        # has_edge is checked first, so no random number is drawn for pairs that are already connected.
        if not G.has_edge(u, v) and random.random() < extra_edge_prob:
            G.add_edge(u, v)


# Step 2: Convert to DAG by directing edges based on node numbering


def dagify(G: nx.Graph) -> nx.DiGraph:
    """Turn an undirected graph into a directed one by orienting edges along node order.

    Each edge ``{u, v}`` of ``G`` becomes the arc ``(u, v)`` when ``u < v`` and ``(v, u)``
    otherwise, so every arc points from the smaller to the larger endpoint.

    Args:
        G: Input graph; its nodes must be comparable with ``<``.

    Returns:
        A new ``nx.DiGraph`` holding all nodes of ``G`` and one arc per edge of ``G``.
    """
    oriented = nx.DiGraph()
    oriented.add_nodes_from(G.nodes())
    # Same orientation rule as an explicit if/else per edge, expressed as one generator.
    oriented.add_edges_from((a, b) if a < b else (b, a) for a, b in G.edges())
    return oriented


DAG = dagify(G)

# Same edges as DAG with every arrow flipped (larger -> smaller node label).
# Derived from DAG instead of re-implementing the orientation rule, so the two cannot drift apart.
RevDAG = DAG.reverse(copy=True)


print(DAG)
# Step 3: Visualize the DAG
plt.figure(figsize=(6, 6))
# Seed positions: x is the node's index in DAG.nodes(), y is uniform noise in [0, 1).
# (random.random() never leaves [0, 1), so no clamping is needed.)
pos = {node: (idx, random.random()) for idx, node in enumerate(DAG.nodes())}
# iterations=0: no force-directed steps are run, the seeded positions are what gets drawn
# (presumably only rescaled by spring_layout -- confirm against the installed networkx version).
pos = nx.spring_layout(DAG, iterations=0, pos=pos)
nx.draw(DAG, pos, with_labels=True, node_color="lightblue", edge_color="gray", arrows=True)
plt.title("Directed Acyclic Graph (DAG) from ER Graph")
plt.show()

In [None]:
max(dict(nx.degree(DAG)).values())

In [None]:
# Regular 4x4 grid DAG on nodes 0..15: node i has an arc to its right neighbour i + 1
# (unless i is in the last column) and to the node one row further, i + 4.
rDAG = nx.DiGraph()
rDAG.add_nodes_from(set(range(16)))
rDAG.add_edges_from([(i, i + 1) for i in range(16) if i % 4 != 3])
rDAG.add_edges_from([(i, i + 4) for i in range(12)])
print(rDAG)

plt.figure(figsize=(6, 6))
# Place the idx-th node at grid cell (column, row) = (idx % 4, idx // 4).
grid_pos = dict((node, (idx % 4, idx // 4)) for idx, node in enumerate(rDAG.nodes()))
# iterations=0: no force-directed steps, the grid coordinates are used as given.
pos = nx.spring_layout(rDAG, pos=grid_pos, iterations=0)
nx.draw(rDAG, pos, with_labels=True, node_color="lightblue", edge_color="gray", arrows=True)
plt.title("Directed Acyclic Graph (DAG) from ER Graph")
plt.show()

In [None]:
from collections import defaultdict


def prune_dag_to_multitree(dag: nx.DiGraph):
    """Greedily drop edges of ``dag`` so that no ordered node pair is joined by two directed paths.

    Nodes are visited in topological order. For each node, its incoming edges are considered
    in the topological order of their tails. An edge ``pred -> current`` is kept unless some
    ancestor of ``pred`` (in the graph built so far) already reaches ``current``; keeping it
    would give that ancestor a second path to ``current``.

    Args:
        dag: Directed acyclic graph. It is not modified.

    Returns:
        A new ``nx.DiGraph`` with all nodes of ``dag`` and the retained subset of its edges.

    Raises:
        AssertionError: If ``dag`` is not a directed acyclic graph.
    """
    assert nx.is_directed_acyclic_graph(dag), "Input must be a DAG."

    topo_order = list(nx.topological_sort(dag))
    # Constant-time rank lookup. Using ``topo_order.index`` as the sort key costs O(n) per
    # predecessor and makes the whole pass quadratic in the number of nodes.
    topo_rank = {node: rank for rank, node in enumerate(topo_order)}

    multitree = nx.DiGraph()
    multitree.add_nodes_from(dag.nodes)

    # ancestor_map[b] = every node a with a directed path a -> ... -> b in ``multitree`` so far.
    # "a already reaches b" is therefore just ``a in ancestor_map[b]``; no mirrored
    # descendant map is needed.
    ancestor_map = defaultdict(set)

    for current in topo_order:
        current_ancestors = ancestor_map[current]
        for pred in sorted(dag.predecessors(current), key=topo_rank.__getitem__):
            prior_ancestors = ancestor_map[pred]
            # Some ancestor of ``pred`` already reaches ``current``: the edge would duplicate a path.
            if not prior_ancestors.isdisjoint(current_ancestors):
                continue

            multitree.add_edge(pred, current)
            # ``current`` inherits ``pred`` and everything above it.
            current_ancestors.add(pred)
            current_ancestors.update(prior_ancestors)

    return multitree

In [None]:
def check_multitree(multitree: nx.DiGraph) -> bool:
    """Return True iff ``multitree`` is a DAG with at most one directed path per ordered node pair.

    Args:
        multitree: Directed graph to test.

    Returns:
        ``False`` if the graph has a cycle or if some pair of distinct nodes is connected by
        two or more simple paths, ``True`` otherwise.
    """
    from itertools import islice

    if not nx.is_directed_acyclic_graph(multitree):
        return False

    for node1 in multitree.nodes:
        for node2 in multitree.nodes:
            if node1 == node2:
                # In a DAG a node has no non-trivial path to itself, so this pair can never fail.
                continue
            # Two paths already disprove the property. Pull at most two from the lazy
            # generator instead of enumerating every simple path, which can be exponential.
            first_paths = list(islice(nx.all_simple_paths(multitree, node1, node2), 2))
            if len(first_paths) > 1:
                return False
    return True

In [None]:
# Prune the random DAG to a multitree and draw it.
pDAG = prune_dag_to_multitree(DAG)
plt.figure(figsize=(6, 6))
print(pDAG)
# Seed positions: x = node index in pDAG.nodes(), y = random.random() (already in [0, 1), so the clamp is a no-op).
pos = {node: (idx, max(min(random.random(), 1.0), -1.0)) for idx, node in enumerate(pDAG.nodes())}
pos = nx.spring_layout(pDAG, iterations=0, pos=pos)  # iterations=0: no force steps are run, seeded positions are kept
nx.draw(pDAG, pos, with_labels=True, node_color="lightblue", edge_color="gray", arrows=True)
plt.title("Directed Acyclic Graph (DAG) from ER Graph")
plt.show()

In [None]:
# Prune the 4x4 grid DAG to a multitree and draw it on the same grid coordinates.
prDAG = prune_dag_to_multitree(rDAG)

plt.figure(figsize=(6, 6))
# The idx-th node goes to (column, row) = (idx % 4, idx // 4).
pos = {node: (idx % 4, idx // 4) for idx, node in enumerate(prDAG.nodes())}
pos = nx.spring_layout(prDAG, iterations=0, pos=pos)  # iterations=0: no force steps are run, grid positions are kept
nx.draw(prDAG, pos, with_labels=True, node_color="lightblue", edge_color="gray", arrows=True)
plt.title("Directed Acyclic Graph (DAG) from ER Graph")
plt.show()

In [None]:
def create_line_graph(dag: nx.DiGraph):
    """Build the directed line graph of ``dag``.

    Every edge ``(u, v)`` of ``dag`` becomes a node of the result. Two such nodes are
    joined by an arc ``(p, n) -> (n, s)`` whenever the first edge ends at the node the
    second edge starts from.

    Args:
        dag: Directed input graph. It is not modified.

    Returns:
        A new ``nx.DiGraph`` whose nodes are the edge tuples of ``dag``.
    """
    linegraph = nx.DiGraph()
    linegraph.add_nodes_from(dag.edges)
    for middle in dag.nodes:
        outgoing = [(middle, head) for head in dag.successors(middle)]
        # Chain every edge arriving at ``middle`` to every edge leaving it.
        for tail in dag.predecessors(middle):
            for out_edge in outgoing:
                linegraph.add_edge((tail, middle), out_edge)
    return linegraph


linegraph = create_line_graph(DAG)
linegraph

In [None]:
# Draw the line graph of the random DAG; its nodes are the DAG's edge tuples (u, v).
lDAG = linegraph
plt.figure(figsize=(6, 6))
print(lDAG)
# Seed positions: x = node index in lDAG.nodes(), y = random.random() (already in [0, 1), so the clamp is a no-op).
pos = {node: (idx, max(min(random.random(), 1.0), -1.0)) for idx, node in enumerate(lDAG.nodes())}
pos = nx.spring_layout(lDAG, iterations=0, pos=pos)  # iterations=0: no force steps are run, seeded positions are kept
nx.draw(lDAG, pos, with_labels=True, node_color="lightblue", edge_color="gray", arrows=True)
plt.title("Directed Acyclic Graph (DAG) from ER Graph")
plt.show()

In [None]:
# Line graph of the random DAG, pruned to a multitree, then drawn.
linegraph = create_line_graph(DAG)
plDAG = prune_dag_to_multitree(linegraph)
plt.figure(figsize=(6, 6))
print(plDAG)
# Seed positions: x = node index in plDAG.nodes(), y = random.random() (already in [0, 1), so the clamp is a no-op).
pos = {node: (idx, max(min(random.random(), 1.0), -1.0)) for idx, node in enumerate(plDAG.nodes())}
pos = nx.spring_layout(plDAG, iterations=0, pos=pos)  # iterations=0: no force steps are run, seeded positions are kept
nx.draw(plDAG, pos, with_labels=True, node_color="lightblue", edge_color="gray", arrows=True)
plt.title("Directed Acyclic Graph (DAG) from ER Graph")
plt.show()

In [None]:
linegraph = create_line_graph(rDAG)
plrDAG = prune_dag_to_multitree(linegraph)

In [None]:
# Line graph of the 4x4 grid DAG, pruned to a multitree, then drawn.
# NOTE: this rebinds the notebook-global `linegraph` to the grid's line graph.
linegraph = create_line_graph(rDAG)
plrDAG = prune_dag_to_multitree(linegraph)
plt.figure(figsize=(6, 6))
print(plrDAG)
# Each node is a grid edge (u, v); draw it at the midpoint between the grid cells of u and v,
# where a grid node i sits at (i % 4, i // 4). `idx` is not used.
pos = {
    node: ((node[0] % 4 + node[1] % 4) / 2, (node[0] // 4 + node[1] // 4) / 2)
    for idx, node in enumerate(plrDAG.nodes())
}
pos = nx.spring_layout(plrDAG, iterations=0, pos=pos)  # iterations=0: no force steps are run, midpoint positions are kept
nx.draw(plrDAG, pos, with_labels=True, node_color="lightblue", edge_color="gray", arrows=True)
plt.title("Directed Acyclic Graph (DAG) from ER Graph")
plt.show()

In [None]:
check_multitree(plDAG)

## pLSTM Layer construction in torch

In [None]:
import torch

In [None]:
DAG.edges

In [None]:
DAG.nodes

In [None]:
# Graph sizes used to shape the tensors below.
# NOTE: `num_nodes` rebinds the parameter of the same name from the graph-generation cell.
num_nodes = len(DAG.nodes)
num_edges = len(DAG.edges)
# Largest in-degree or out-degree over all nodes: the per-node edge-slot count (E) of the padded tensors.
max_edges = max(
    max((d for n, d in DAG.in_degree())), max((d for n, d in DAG.out_degree()))
)  # one bound covering both the incoming and the outgoing edges of every node
max_edges

In [None]:
# Padded adjacency tables, one row per node and `max_edges` slots per row; unused slots hold -1.
# The tables are float tensors (torch.ones default dtype) even though they store node ids.
# Row indexing by `node` assumes node labels are the integers 0..num_nodes-1 -- TODO confirm.
adjacency_backward_array = -torch.ones([num_nodes, max_edges])
adjacency_forward_array = -torch.ones([num_nodes, max_edges])
incoming_edge_nums = torch.zeros([num_nodes])
outgoing_edge_nums = torch.zeros([num_nodes])

# Edge -> slot lookups:
#   adjacency_forward_edgemap[(node, succ)]  = position of `succ` among the successors of `node`
#   adjacency_backward_edgemap[(pred, node)] = position of `pred` among the predecessors of `node`
adjacency_forward_edgemap = {}
adjacency_backward_edgemap = {}


# Incoming side: predecessors of every node, in the order DAG.predecessors yields them.
for node in DAG.nodes:
    pred = list(DAG.predecessors(node))
    adjacency_backward_array[node, : len(pred)] = torch.tensor(pred)
    for idx, pn in enumerate(pred):
        adjacency_backward_edgemap[(pn, node)] = idx
    incoming_edge_nums[node] = len(pred)

# Outgoing side: successors of every node, in the order DAG.successors yields them.
for node in DAG.nodes:
    succ = list(DAG.successors(node))
    adjacency_forward_array[node, : len(succ)] = torch.tensor(succ)
    for idx, sn in enumerate(succ):
        adjacency_forward_edgemap[(node, sn)] = idx
    outgoing_edge_nums[node] = len(succ)

In [None]:
adjacency_backward_array, incoming_edge_nums

In [None]:
adjacency_forward_array, outgoing_edge_nums

In [None]:
## all inputs
# Head count and per-head feature sizes for the dummy (all-zero) inputs below.
num_heads = 4
head_dim = 32
qk_head_dim = 32
v_head_dim = 32

# The next two lines are bare expressions: they only list the graph-side inputs and have no effect.
adjacency_backward_array, incoming_edge_nums
adjacency_forward_array, outgoing_edge_nums

# Per-node, per-head features. `inp` is not read by any later cell visible in this notebook.
inp = torch.zeros([num_heads, num_nodes, head_dim])
query = torch.zeros([num_heads, num_nodes, qk_head_dim], requires_grad=True)
key = torch.zeros([num_heads, num_nodes, qk_head_dim], requires_grad=True)
value = torch.zeros([num_heads, num_nodes, v_head_dim], requires_grad=True)

# Gate tensors, indexed by [head, node, edge slot(, edge slot)] with `max_edges` slots per node.
source = torch.zeros([num_heads, num_nodes, max_edges], requires_grad=True)
transition = torch.zeros([num_heads, num_nodes, max_edges, max_edges], requires_grad=True)
# `transition_mask` is not read by any later cell visible in this notebook.
transition_mask = torch.ones([num_heads, num_nodes, max_edges, max_edges], requires_grad=True)
mark = torch.zeros([num_heads, num_nodes, max_edges], requires_grad=True)
direct = torch.zeros([num_heads, num_nodes], requires_grad=True)

In [None]:
"""
Description of a pLSTM-Graph layer
Vector-Valued inputs at nodes are split into head vectors
Source, Transition and Mark have different scaling "angle" depending on head and number of predecessors / successors
Source doesn't have to scale, Mark doesn't have to scale
Transition should be limited to one in row / column
Transition should distribute differently for different heads in bias -> bias term is not constant but adaptive to number of pred/succ
Example: 
4 heads
node: 2 pred, 2 succ -> "orientation bias" according to succ: 1 angle 
node: 2 pred, 4 succ -> "orientation bias" according to succ: 3 angles

"""

In [None]:
"""
Approach: No parallelization for now -> sequential processing. 
Problem: All edges need a C state potentially. Need to store the state as well for backprop. C-States have size: qkdim x vdim.
"""

In [None]:
# Naive implementation, C state for every edge -> potentially recompute for backward to save memory.
# 1. source terms given
# 2. compute edge states for every edge in DAG order of linegraph, use transitions
# 3. compute outputs via marks

# Zero-initialised accumulators: one [qk_head_dim, v_head_dim] cell state per head and DAG edge,
# and one [v_head_dim] output per head and node. Re-run this cell before re-running the
# accumulation loop, which adds into these buffers in place.
cell_states = torch.zeros([num_heads, num_edges, qk_head_dim, v_head_dim])
outputs = torch.zeros([num_heads, num_nodes, v_head_dim])

In [None]:
# Line graph of DAG via networkx; its nodes are DAG's edge tuples (u, v).
# NOTE: rebinds `lDAG`, which an earlier cell assigned from create_line_graph(DAG).
lDAG = nx.line_graph(DAG)
# Processing order for the edges: every edge appears after all edges that feed into it.
lDAG_edges = list(nx.topological_sort(lDAG))

# Edge tuple -> row along the edge axis of `cell_states` (rows follow the topological order above).
idx_map = {edge: idx for idx, edge in enumerate(lDAG_edges)}

In [None]:
# Naive sequential forward pass over the edges in topological order of the line graph.
# It accumulates in place into `cell_states` / `outputs`, so re-running this cell without
# re-creating those buffers adds the contributions a second time.
# `edge_out_map` / `edge_in_map` are not written or read anywhere else in this notebook.
edge_out_map = {}
edge_in_map = {}
for idx_edge, edge in enumerate(lDAG_edges):
    in_node = edge[0]
    out_node = edge[1]
    # Propagate the states of all edges ending at `in_node` into this edge.
    # NOTE(review): the gates of `in_node` are indexed with forward_edgemap[pred_edge] and
    # backward_edgemap[edge], i.e. with each edge's slot at its *other* endpoint rather than
    # at `in_node`. The same pattern is used for `source` and `mark` below -- confirm intended.
    for pred_edge in lDAG.predecessors(edge):
        cell_states[:, idx_edge] += (
            transition[:, in_node, adjacency_forward_edgemap[pred_edge], adjacency_backward_edgemap[edge], None, None]
            * cell_states[:, idx_map[pred_edge]]
        )
    # Inject the key (outer) value product of the edge's start node, scaled by its source gate.
    cell_states[:, idx_edge] += source[:, in_node, adjacency_backward_edgemap[edge], None, None] * torch.einsum(
        "ha,hb->hab", key[:, in_node], value[:, in_node]
    )
    # Read the finished edge state out at the edge's end node with that node's query, scaled by the mark gate.
    outputs[:, out_node] += mark[:, out_node, adjacency_forward_edgemap[edge], None] * torch.einsum(
        "ha,hab->hb", query[:, out_node], cell_states[:, idx_edge]
    )

# Direct (same-node) term: direct * <key, query> * value for every head and node.
outputs += direct[:, :, None] * torch.sum(key * query, dim=-1, keepdim=True) * value

In [None]:
def plstm_graph(
    query,
    key,
    value,
    source,
    transition,
    mark,
    direct,
    adjancency_forward_edgemap,
    adjacency_backward_edgemap,
    lDAG,
    lDAG_sorted,
    recompute_cell_states: bool = True,
):
    """Sequential pLSTM over the edges of a DAG, with a hand-written backward pass.

    Every edge carries a cell state of shape ``[H, K, V]``. Edges are processed in the order
    given by ``lDAG_sorted``. For an edge ``e = (u, v)``::

        C_e  = sum_{p in lDAG.predecessors(e)} transition[:, u, fwd[p], bwd[e]] * C_p
               + source[:, u, bwd[e]] * key[:, u] (outer) value[:, u]
        out_v += mark[:, v, fwd[e]] * query[:, v] @ C_e

    Finally ``direct * <key, query> * value`` is added for every node.

    Args:
        query: ``[H, N, K]`` tensor.
        key: ``[H, N, K]`` tensor.
        value: ``[H, N, V]`` tensor.
        source: ``[H, N, E]`` tensor of source gates.
        transition: ``[H, N, E, E]`` tensor of transition gates.
        mark: ``[H, N, E]`` tensor of mark gates.
        direct: ``[H, N]`` tensor of direct gates.
        adjancency_forward_edgemap: Mapping edge tuple -> int slot (``fwd`` above). The
            misspelt parameter name is kept so existing keyword callers keep working.
        adjacency_backward_edgemap: Mapping edge tuple -> int slot (``bwd`` above).
        lDAG: Line graph of the DAG; only ``predecessors(edge)`` and ``successors(edge)``
            are called on it.
        lDAG_sorted: Iterable of all edge tuples ``(u, v)`` in an order where every edge comes
            after all of its line-graph predecessors.
        recompute_cell_states: If True (default) the cell states are not stored for backward
            but recomputed there; if False they are saved in the forward pass.

    Returns:
        ``[H, N, V]`` tensor of node outputs, differentiable with respect to the seven
        tensor arguments.

    NOTE(review): the gates at the shared node are indexed with ``fwd`` of the incoming edge
    and ``bwd`` of the outgoing edge. If the maps are built as elsewhere in this notebook
    (slot among the successors of the edge's start node / the predecessors of its end node),
    these are slots at the edges' *other* endpoints. Forward and backward are consistent with
    each other; confirm that this is the intended convention.
    """
    fwd_edgemap = adjancency_forward_edgemap
    bwd_edgemap = adjacency_backward_edgemap
    # Everything that describes the edge ordering is derived from the arguments, never from
    # notebook globals, so the function also works for graphs other than the global one.
    edge_order = list(lDAG_sorted)
    num_edges = len(edge_order)
    idx_map = {edge: idx for idx, edge in enumerate(edge_order)}

    def _compute_cell_states(key, value, source, transition):
        """Run the edge recurrence and return all cell states as ``[H, num_edges, K, V]``."""
        num_heads, _, qk_head_dim = key.shape
        v_head_dim = value.shape[-1]
        # new_zeros keeps dtype and device of the inputs.
        cell_states = key.new_zeros([num_heads, num_edges, qk_head_dim, v_head_dim])
        for idx_edge, edge in enumerate(edge_order):
            in_node = edge[0]
            bwd_idx = bwd_edgemap[edge]
            for pred_edge in lDAG.predecessors(edge):
                cell_states[:, idx_edge] += (
                    transition[:, in_node, fwd_edgemap[pred_edge], bwd_idx, None, None]
                    * cell_states[:, idx_map[pred_edge]]
                )
            cell_states[:, idx_edge] += source[:, in_node, bwd_idx, None, None] * torch.einsum(
                "hk,hv->hkv", key[:, in_node], value[:, in_node]
            )
        return cell_states

    class pLSTMGraph(torch.autograd.Function):
        @staticmethod
        def forward(ctx, query, key, value, source, transition, mark, direct):
            cell_states = _compute_cell_states(key, value, source, transition)

            num_heads, num_nodes, _ = query.shape
            outputs = query.new_zeros([num_heads, num_nodes, value.shape[-1]])
            # A cell state is final once its own edge has been processed, so the read-out can
            # run as a separate pass over the edges.
            for idx_edge, edge in enumerate(edge_order):
                out_node = edge[1]
                outputs[:, out_node] += mark[:, out_node, fwd_edgemap[edge], None] * torch.einsum(
                    "hk,hkv->hv", query[:, out_node], cell_states[:, idx_edge]
                )

            outputs += direct[:, :, None] * torch.sum(key * query, dim=-1, keepdim=True) * value

            saved = [query, key, value, source, transition, mark, direct]
            if not recompute_cell_states:
                saved.append(cell_states)
            ctx.save_for_backward(*saved)

            return outputs

        @staticmethod
        def backward(
            ctx, doutputs
        ) -> tuple[
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
        ]:
            query, key, value, source, transition, mark, direct, *stored = ctx.saved_tensors
            # Presence is tested by count; truth-testing a multi-element tensor would raise.
            if stored:
                cell_states = stored[0]
            else:
                cell_states = _compute_cell_states(key, value, source, transition)

            dquery = torch.zeros_like(query)
            dkey = torch.zeros_like(key)
            dvalue = torch.zeros_like(value)
            dsource = torch.zeros_like(source)
            dtransition = torch.zeros_like(transition)
            dmark = torch.zeros_like(mark)
            ddirect = torch.zeros_like(direct)
            dcell_states = torch.zeros_like(cell_states)

            # Reverse topological order: dC of every successor edge is complete before it is used.
            for idx_edge in range(num_edges - 1, -1, -1):
                edge = edge_order[idx_edge]
                in_node = edge[0]
                out_node = edge[1]
                fwd_idx = fwd_edgemap[edge]
                bwd_idx = bwd_edgemap[edge]

                for succ_edge in lDAG.successors(edge):
                    idx_succ = idx_map[succ_edge]
                    succ_bwd_idx = bwd_edgemap[succ_edge]
                    dcell_states[:, idx_edge] += (
                        transition[:, out_node, fwd_idx, succ_bwd_idx, None, None] * dcell_states[:, idx_succ]
                    )
                    dtransition[:, out_node, fwd_idx, succ_bwd_idx] += torch.einsum(
                        "hkv,hkv->h", cell_states[:, idx_edge], dcell_states[:, idx_succ]
                    )
                # Contribution of the read-out at the edge's end node.
                dcell_states[:, idx_edge] += mark[:, out_node, fwd_idx, None, None] * torch.einsum(
                    "hk,hv->hkv", query[:, out_node], doutputs[:, out_node]
                )

                dquery[:, out_node] += mark[:, out_node, fwd_idx, None] * torch.einsum(
                    "hkv,hv->hk", cell_states[:, idx_edge], doutputs[:, out_node]
                )
                dmark[:, out_node, fwd_idx] += torch.einsum(
                    "hk,hkv,hv->h", query[:, out_node], cell_states[:, idx_edge], doutputs[:, out_node]
                )

                dkey[:, in_node] += source[:, in_node, bwd_idx, None] * torch.einsum(
                    "hkv,hv->hk", dcell_states[:, idx_edge], value[:, in_node]
                )
                dvalue[:, in_node] += source[:, in_node, bwd_idx, None] * torch.einsum(
                    "hkv,hk->hv", dcell_states[:, idx_edge], key[:, in_node]
                )
                # The explicit "->h" keeps one gradient per head; without an output spec einsum
                # sums over the head axis as well and returns a scalar.
                dsource[:, in_node, bwd_idx] += torch.einsum(
                    "hkv,hk,hv->h", dcell_states[:, idx_edge], key[:, in_node], value[:, in_node]
                )

            # Direct term: out = direct * <key, query> * value.
            dquery += torch.einsum("hn,hnk,hnv,hnv->hnk", direct, key, value, doutputs)
            dkey += torch.einsum("hn,hnk,hnv,hnv->hnk", direct, query, value, doutputs)
            # The value gradient lives on the v axis ("hnv"), not on the k axis.
            dvalue += torch.einsum("hn,hnk,hnk,hnv->hnv", direct, query, key, doutputs)
            ddirect += torch.einsum("hnk,hnk,hnv,hnv->hn", query, key, value, doutputs)

            return dquery, dkey, dvalue, dsource, dtransition, dmark, ddirect

    return pLSTMGraph.apply(query, key, value, source, transition, mark, direct)

In [None]:
out = plstm_graph(
    query,
    key,
    value,
    source,
    transition,
    mark,
    direct,
    adjacency_forward_edgemap,
    adjacency_backward_edgemap,
    lDAG,
    lDAG_edges,
)

In [None]:
out = plstm_graph(
    query,
    key,
    value,
    source,
    transition,
    mark,
    direct,
    adjacency_forward_edgemap,
    adjacency_backward_edgemap,
    lDAG,
    lDAG_edges,
)
loss = torch.sum(out)
loss.backward()

### plstm graph transition normalization P mode


In [None]:
"""
Given a list of transitions t: [H, N, E, E], with H heads, N nodes, and E max edges per node, they have to be normalized,
such that torch.sum(torch.abs(t), dim=3) <= 1. 
Also they should be normalized such that the transitions can be between minus one and one, in all cases. 
Given values: t1, t2, t3, t4... , max edges E, real edges e, arbitrary.
Out values: n1, n2, n3, n4, ..., s.t. Sum |n1| + |n2| +... <= 1

Use L1 norm right away with:
ni = ti / (1 + alpha * l1)
"""

In [None]:
"""
Implementation of this.

given: 
vector [H, N, E, E]
actual incoming_edge_nums: [N]

Set non-existent edges to zero:
edge_mask: [N, E] s.t. em = (incoming_edge_nums + 0.5 - arange(E)) > 0

"""

# Test pLSTMGraphLayer

In [None]:
from plstm.torch.plstm_graph_layer import pLSTMGraphLayerConfig, pLSTMGraphLayer
import torch

In [None]:
cfg = pLSTMGraphLayerConfig(input_dim=64, num_heads=4, max_edges=8, mode="P")

In [None]:
layer = pLSTMGraphLayer(cfg)

In [None]:
inp = torch.randn((G.number_of_nodes(), cfg.input_dim))

out = layer(inp, graph=G)

In [None]:
out.sum().backward()

# Test pLSTMGraphEdgeLayer

In [None]:
from plstm.torch.plstm_graph_layer import pLSTMGraphEdgeLayerConfig, pLSTMGraphEdgeLayer, PreparedGraph

In [None]:
g = PreparedGraph.create(G, mode="P")

In [None]:
len(G.edges)

In [None]:
from plstm.graph import dagify

dag2 = dagify(G)
print(len(dag2.edges))

In [None]:
print(len(g.dag.edges))

In [None]:
cfg = pLSTMGraphEdgeLayerConfig(
    input_dim=64,
    num_heads=4,
    edge_input_dim=32,
    max_edges=100,  # not actually used
    mode="P",
)

In [None]:
pge_layer = pLSTMGraphEdgeLayer(cfg)

In [None]:
inp = torch.randn((G.number_of_nodes(), cfg.input_dim))
edge_inp = torch.randn((G.number_of_edges(), cfg.edge_input_dim))

out = pge_layer(inp, edge_inp, graph=g)

In [None]:
out.sum().backward()

In [None]:
# Important TODO:
"""
- pLSTMGraphEdgeLayer: P Mode normalization!!! -> works?
- check if indexing is aligned -> fixed?
- transition biases!!!
- thorough testing!

"""

## pLSTM Graph Block Stack

In [None]:
from plstm.config.graph_block import pLSTMGraphBlockConfig, pLSTMGraphEdgeBlockConfig
from plstm.torch.interfaces import ResidualModule

cfg = pLSTMGraphBlockConfig(input_dim=192, num_heads=12, block_mode="DP", block_type="post_up", max_edges=6)

In [None]:
from compoconf import Registry

Registry._registries

In [None]:
graph_block = cfg.instantiate(ResidualModule)

In [None]:
graph_block

In [None]:
inp = torch.randn((G.number_of_nodes(), cfg.input_dim))
# edge_inp = torch.randn((G.number_of_edges(), cfg.edge_input_dim))

In [None]:
graph_block(inp, graph=g)

In [None]:
cfg = pLSTMGraphEdgeBlockConfig(input_dim=192, num_heads=12, block_mode="DP", block_type="post_up")

In [None]:
cfg

In [None]:
graph_block = cfg.instantiate(ResidualModule)
inp = torch.randn((G.number_of_nodes(), cfg.input_dim))
edge_inp = torch.randn((G.number_of_edges(), cfg.input_dim))

In [None]:
graph_block(inp, edge_features=edge_inp, graph=g)