# Imports

In [76]:
import numpy as np
import torch
import gudhi as gd
import itertools
import os

# import graph_tool as gt
# import graph_tool.topology as top
import networkx as nx

from tqdm import tqdm
from tsp_sc.common.simp_complex import Cochain, SimplicialComplex
from tsp_sc.graph_classification.data.dataset import ComplexDataset
from typing import List, Dict, Optional, Union
from torch import Tensor
from torch_geometric.typing import Adj
from torch_scatter import scatter
from torch_geometric.utils import from_networkx
# from data.parallel import ProgressParallel
from joblib import delayed
from scipy.sparse import coo_matrix


# Input

In [20]:
def get_house_complex():
    """
    Returns the `house graph` below with dummy features.
    The `house graph` (3-2-4 is a filled triangle):
       4
      / \
     3---2
     |   |
     0---1

       .
      4 5
     . 2 .
     3   1
     . 0 .

       .
      /0\
     .---.
     |   |
     .---.
    """
    v_up_index = torch.tensor([[0, 1, 0, 3, 1, 2, 2, 3, 2, 4, 3, 4],
                               [1, 0, 3, 0, 2, 1, 3, 2, 4, 2, 4, 3]], dtype=torch.long)
    v_shared_coboundaries = torch.tensor([0, 0, 3, 3, 1, 1, 2, 2, 5, 5, 4, 4], dtype=torch.long)
    v_x = torch.tensor([[1], [2], [3], [4], [5]], dtype=torch.float)
    yv = torch.tensor([0, 0, 0, 0, 0], dtype=torch.long)
    v_cochain = Cochain(dim=0, x=v_x, upper_index=v_up_index, shared_coboundaries=v_shared_coboundaries, y=yv)

    e_boundaries = [[0, 1], [1, 2], [2, 3], [0, 3], [3, 4], [2, 4]]
    e_boundary_index = torch.stack([
        torch.LongTensor(e_boundaries).view(-1),
        torch.LongTensor([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5]).view(-1)], 0)

    e_up_index = torch.tensor([[2, 4, 2, 5, 4, 5],
                               [4, 2, 5, 2, 5, 4]], dtype=torch.long)
    e_shared_coboundaries = torch.tensor([0, 0, 0, 0, 0, 0], dtype=torch.long)
    e_down_index = torch.tensor([[0, 1, 0, 3, 1, 2, 1, 5, 2, 3, 2, 4, 2, 5, 3, 4, 4, 5],
                                 [1, 0, 3, 0, 2, 1, 5, 1, 3, 2, 4, 2, 5, 2, 4, 3, 5, 4]],
        dtype=torch.long)
    e_shared_boundaries = torch.tensor([1, 1, 0, 0, 2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 3, 3, 4, 4],
        dtype=torch.long)
    e_x = torch.tensor([[1], [2], [3], [4], [5], [6]], dtype=torch.float)
    ye = torch.tensor([1, 1, 1, 1, 1, 1], dtype=torch.long)
    e_cochain = Cochain(dim=1, x=e_x, upper_index=e_up_index, lower_index=e_down_index,
        shared_coboundaries=e_shared_coboundaries, shared_boundaries=e_shared_boundaries,
        boundary_index=e_boundary_index, y=ye)

    t_boundaries = [[2, 4, 5]]
    t_boundary_index = torch.stack([
        torch.LongTensor(t_boundaries).view(-1),
        torch.LongTensor([0, 0, 0]).view(-1)], 0)
    t_x = torch.tensor([[1]], dtype=torch.float)
    yt = torch.tensor([2], dtype=torch.long)
    t_cochain = Cochain(dim=2, x=t_x, y=yt, boundary_index=t_boundary_index)
    
    y = torch.LongTensor([v_x.shape[0]])
    
    return Complex(v_cochain, e_cochain, t_cochain, y=y)

In [25]:
G = nx.house_graph()
data = from_networkx(G)
edge_index = data['edge_index']
num_nodes = data['num_nodes']

# Simplicial tree creation

In [22]:
def pyg_to_simplex_tree(edge_index: Tensor, size: int):
    """Constructs a simplex tree from a PyG graph.

    Args:
        edge_index: The edge_index of the graph (a tensor of shape [2, num_edges])
        size: The number of nodes in the graph.
    """
    st = gd.SimplexTree()
    # Add vertices to the simplex.
    for v in range(size):
        st.insert([v])

    # Add the edges to the simplex.
    edges = edge_index.numpy()
    for e in range(edges.shape[1]):
        edge = [edges[0][e], edges[1][e]]
        st.insert(edge)

    return st

In [33]:
simplex_tree = pyg_to_simplex_tree(edge_index, num_nodes)

In [44]:
expansion_dim = 2
simplex_tree.expansion(expansion_dim)  # Computes the clique complex up to the desired dim.
complex_dim = simplex_tree.dimension()

# Tables

In [88]:
def build_tables(simplex_tree, size):
    complex_dim = simplex_tree.dimension()
    # Each of these data structures has a separate entry per dimension.
    id_maps = [{} for _ in range(complex_dim+1)] # simplex -> id
    simplex_tables = [[] for _ in range(complex_dim+1)] # matrix of simplices
    boundaries_tables = [[] for _ in range(complex_dim+1)]

    simplex_tables[0] = [[v] for v in range(size)]
    id_maps[0] = {frozenset([v]): v for v in range(size)}

    for simplex, _ in simplex_tree.get_simplices():
        dim = len(simplex) - 1
        if dim == 0:
            continue

        # Assign this simplex the next unused ID
        next_id = len(simplex_tables[dim])
        id_maps[dim][frozenset(simplex)] = next_id
        simplex_tables[dim].append(simplex)

    return simplex_tables, id_maps

In [89]:
tables, id_maps = build_tables(simplex_tree, num_nodes)

In [90]:
print(tables)
print(id_maps)

[[[0], [1], [2], [3], [4]], [[0, 1], [0, 2], [1, 3], [2, 3], [2, 4], [3, 4]], [[2, 3, 4]]]
[{frozenset({0}): 0, frozenset({1}): 1, frozenset({2}): 2, frozenset({3}): 3, frozenset({4}): 4}, {frozenset({0, 1}): 0, frozenset({0, 2}): 1, frozenset({1, 3}): 2, frozenset({2, 3}): 3, frozenset({2, 4}): 4, frozenset({3, 4}): 5}, {frozenset({2, 3, 4}): 0}]


# Boundaries and coboundaries

In [91]:
def get_simplex_boundaries(simplex):
    boundaries = itertools.combinations(simplex, len(simplex) - 1)
    return [tuple(boundary) for boundary in boundaries]

In [94]:

def extract_boundaries_and_coboundaries_from_simplex_tree(
    simplex_tree, id_maps, complex_dim: int
):
    """Build two maps simplex -> its coboundaries and simplex -> its boundaries"""
    # The extra dimension is added just for convenience to avoid treating it as a special case.
    boundaries = [{} for _ in range(complex_dim + 2)]  # simplex -> boundaries
    coboundaries = [{} for _ in range(complex_dim + 2)]  # simplex -> coboundaries
    boundaries_tables = [[] for _ in range(complex_dim + 1)]

    for simplex, _ in simplex_tree.get_simplices():
        print(simplex)
        # Extract the relevant boundary and coboundary maps
        simplex_dim = len(simplex) - 1
        level_coboundaries = coboundaries[simplex_dim]
        level_boundaries = boundaries[simplex_dim + 1]

        # Add the boundaries of the simplex to the boundaries table
        # (2, 3, 4) --> [(2, 3), (2, 4), (3, 4)] --> [id[(2, 3)], id[(2, 4)], id[(3, 4)]
        if simplex_dim > 0:
            boundaries_ids = [
                id_maps[simplex_dim - 1][frozenset(boundary)]
                for boundary in get_simplex_boundaries(simplex)
            ]
            boundaries_tables[simplex_dim].append(boundaries_ids)

        simplex_coboundaries = simplex_tree.get_cofaces(simplex, codimension=1)
        for coboundary, _ in simplex_coboundaries:
            assert len(coboundary) == len(simplex) + 1

            if tuple(simplex) not in level_coboundaries:
                level_coboundaries[tuple(simplex)] = list()
            level_coboundaries[tuple(simplex)].append(tuple(coboundary))

            if tuple(coboundary) not in level_boundaries:
                level_boundaries[tuple(coboundary)] = list()
            level_boundaries[tuple(coboundary)].append(tuple(simplex))

    return boundaries_tables, boundaries, coboundaries

In [95]:
# Extracts the boundaries and coboundaries of each simplex in the complex
boundaries_tables, boundaries, co_boundaries = (
    extract_boundaries_and_coboundaries_from_simplex_tree(simplex_tree, id_maps, complex_dim))

[0, 1]
[0, 2]
[0]
[1, 3]
[1]
[2, 3, 4]
[2, 3]
[2, 4]
[2]
[3, 4]
[3]
[4]


In [96]:
print(boundaries_tables)

[[], [[0, 1], [0, 2], [1, 3], [2, 3], [2, 4], [3, 4]], [[3, 4, 5]]]


In [97]:
print(boundaries)

[{}, {(0, 1): [(0,), (1,)], (0, 2): [(0,), (2,)], (1, 3): [(1,), (3,)], (2, 3): [(2,), (3,)], (2, 4): [(2,), (4,)], (3, 4): [(3,), (4,)]}, {(2, 3, 4): [(2, 3), (2, 4), (3, 4)]}, {}]


In [98]:
print(co_boundaries)

[{(0,): [(0, 1), (0, 2)], (1,): [(0, 1), (1, 3)], (2,): [(0, 2), (2, 3), (2, 4)], (3,): [(1, 3), (2, 3), (3, 4)], (4,): [(2, 4), (3, 4)]}, {(2, 3): [(2, 3, 4)], (2, 4): [(2, 3, 4)], (3, 4): [(2, 3, 4)]}, {}, {}]


In [99]:

def build_boundaries(id_maps):
    """
    Build the boundary operators from a list of simplices.

    Parameters
    ----------
    simplices:
                List of dictionaries, one per dimension d.
                The size of the dictionary is the number of d-simplices.
                The dictionary's keys are sets (of size d+1) of the vertices that constitute the d-simplices.
                The dictionary's values are the indexes of the simplices in the boundary and Laplacian matrices.
    Returns
    -------
    boundaries:
                List of boundary operators, one per dimension: i-th boundary is in (i-1)-th position
    """
    boundaries = list()

    for dim in range(1, len(id_maps)):
        idx_simplices, idx_faces, values = [], [], []

        # simplex is a frozenset of vertices, idx_simplex is the integer progressive id of the simplex
        for simplex, idx_simplex in id_maps[dim].items():
            simplices_list_sorted = np.sort(list(simplex))

            for i, left_out in enumerate(simplices_list_sorted):
                # linear combination of the face of the simplex obtained by removing
                # the i-th vertex
                idx_simplices.append(idx_simplex)
                values.append((-1) ** i)
                face = simplex.difference({left_out})
                idx_faces.append(id_maps[dim - 1][face])

        assert len(values) == (dim + 1) * len(id_maps[dim])
        boundary = coo_matrix(
            (values, (idx_faces, idx_simplices)),
            dtype=np.float32,
            shape=(len(id_maps[dim - 1]), len(id_maps[dim])),
        )
        boundaries.append(boundary)
    return boundaries

In [100]:
boundaries = build_boundaries(id_maps)

# Features

In [103]:

def construct_features(vx: Tensor, cell_tables, init_method: str) -> List:
    """Combines the features of the component vertices to initialise the cell features"""
    features = [vx]
    for dim in range(1, len(cell_tables)):
        aux_1 = []
        aux_0 = []
        for c, cell in enumerate(cell_tables[dim]):
            aux_1 += [c for _ in range(len(cell))]
            aux_0 += cell
        node_cell_index = torch.LongTensor([aux_0, aux_1])
        in_features = vx.index_select(0, node_cell_index[0])
        features.append(
            scatter(
                in_features,
                node_cell_index[1],
                dim=0,
                dim_size=len(cell_tables[dim]),
                reduce=init_method,
            )
        )

    return features

In [104]:
xs = construct_features(x, simplex_tables, init_method)

# Initialise the node / complex labels
v_y, complex_y = extract_labels(y, size)

cochains = []
for i in range(complex_dim+1):
    y = v_y if i == 0 else None
    cochain = generate_cochain(i, xs[i], upper_idx, lower_idx, shared_boundaries, shared_coboundaries,
                           simplex_tables, boundaries_tables, complex_dim=complex_dim, y=y)
    cochains.append(cochain)

return Complex(*cochains, y=complex_y, dimension=complex_dim)

NameError: name 'x' is not defined