In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import networkx as nx
import numpy as np

# ---------- 1. Grafo bipartito y labels ----------
def make_bipartite_graph(n_per_class=50, p_edge=0.1, seed=0):
    rng = np.random.RandomState(seed)
    
    # Creamos un grafo bipartito A-B
    A_nodes = list(range(n_per_class))
    B_nodes = list(range(n_per_class, 2 * n_per_class))
    
    G = nx.Graph()
    G.add_nodes_from(A_nodes, bipartite=0, label=0) 
    G.add_nodes_from(B_nodes, bipartite=1, label=1)
    
    # Añadimos aristas aleatorias entre A y B
    for a in A_nodes:
        for b in B_nodes:
            if rng.rand() < p_edge:
                G.add_edge(a, b)
    
    # Nos aseguramos de que el grafo sea conectado
    if not nx.is_connected(G):
        # Conectar componentes a lo bruto
        comps = list(nx.connected_components(G))
        for i in range(len(comps) - 1):
            u = next(iter(comps[i]))
            v = next(iter(comps[i+1]))
            G.add_edge(u, v)
    
    labels = np.array([G.nodes[i]['label'] for i in G.nodes()])
    
    # Lista de aristas como pares (u, v)
    edges = np.array(list(G.edges()), dtype=np.int64)
    return G, edges, labels

# ---------- 2. Features Gaussianos solapados ----------
def make_gaussian_features(labels, mu=1.0, sigma=1.0, seed=0):
    rng = np.random.RandomState(seed)
    n = len(labels)
    fdim = 2  # 2D para poder visualizar
    X = np.zeros((n, fdim), dtype=np.float32)
    
    # Clase 0: N((-mu, 0), sigma^2 I)
    # Clase 1: N((+mu, 0), sigma^2 I)
    for i, y in enumerate(labels):
        mean = np.array([-mu, 0.0]) if y == 0 else np.array([+mu, 0.0])
        X[i] = rng.normal(loc=mean, scale=sigma, size=(fdim,))
    return X.astype(np.float32)

# Ejemplo de uso:
G, edges, labels = make_bipartite_graph(n_per_class=100, p_edge=0.05, seed=0)
X0 = make_gaussian_features(labels, mu=1.0, sigma=1.5, seed=0)

X0 = torch.from_numpy(X0)          # [n, f]
y  = torch.from_numpy(labels)      # [n]
edges_torch = torch.from_numpy(edges)  # [m, 2]


In [2]:
def build_sheaf_laplacian_scalar(n_nodes, edges, F_ve, F_ue):
    """
    n_nodes: número de nodos
    edges: [m, 2] tensor con (v, u)
    F_ve, F_ue: [m] tensores con los mapas F_{v->e}, F_{u->e}
    
    Devuelve ∆_F (nd x nd) con d=1, así que es [n, n]
    """
    m = edges.shape[0]
    L = torch.zeros((n_nodes, n_nodes), dtype=torch.float32)
    
    # Para cada arista e = (v, u):
    for e_idx in range(m):
        v = edges[e_idx, 0].item()
        u = edges[e_idx, 1].item()
        fv = F_ve[e_idx]
        fu = F_ue[e_idx]
        
        # Según Definition 2 del paper:
        # L_F(x)_v = sum_{v,u~e} F_v^T(F_v x_v - F_u x_u)
        # Lo que induce:
        # L[v, v] += fv^2
        # L[v, u] -= fv * fu
        # L[u, u] += fu^2
        # L[u, v] -= fu * fv
        L[v, v] += fv * fv
        L[u, u] += fu * fu
        L[v, u] -= fv * fu
        L[u, v] -= fu * fv
    
    # Normalización por D^{-1/2} (como en el paper)
    d = torch.diag(L)
    # Evitar división por cero
    d_clamped = torch.clamp(d, min=1e-6)
    D_inv_sqrt = torch.diag(1.0 / torch.sqrt(d_clamped))
    Delta = D_inv_sqrt @ L @ D_inv_sqrt
    return Delta


In [3]:
class PhiGeneral(nn.Module):
    """
    Φ general: produce F_{v->e} y F_{u->e} distintos (modelo no simétrico).
    """
    def __init__(self, in_dim, hidden_dim=32):
        super().__init__()
        self.mlp_v = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.mlp_u = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, x_v, x_u):
        """
        x_v, x_u: [m, f]
        devuelve: F_ve, F_ue: [m]
        """
        inp = torch.cat([x_v, x_u], dim=-1)  # [m, 2f]
        F_ve = self.mlp_v(inp).squeeze(-1)
        F_ue = self.mlp_u(inp).squeeze(-1)
        
        # Para evitar mapas degenerados, puedes forzar que no sean ~0
        return F_ve, F_ue


class PhiSymmetric(nn.Module):
    """
    Φ simétrica: F_{v->e} = F_{u->e} (modelo tipo weighted graph Laplacian).
    """
    def __init__(self, in_dim, hidden_dim=32):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )
    
    def forward(self, x_v, x_u):
        inp = torch.cat([x_v, x_u], dim=-1)  # [m, 2f]
        F_e = self.mlp(inp).squeeze(-1)
        return F_e, F_e  # F_v = F_u


In [4]:
class SheafDiffusionModel(nn.Module):
    def __init__(self, in_features, hidden_phi=32, symmetric=False, T=20, n_classes=2):
        super().__init__()
        self.T = T
        in_dim_phi = 2 * in_features  # [x_v || x_u]
        if symmetric:
            self.phi = PhiSymmetric(in_dim_phi, hidden_dim=hidden_phi)
        else:
            self.phi = PhiGeneral(in_dim_phi, hidden_dim=hidden_phi)
        
        # Clasificador lineal sobre X^T
        self.classifier = nn.Linear(in_features, n_classes)
    
    def forward(self, X0, edges):
        """
        X0: [n, f]
        edges: [m, 2]
        """
        n_nodes = X0.shape[0]
        X = X0
        
        # Construimos F a partir de X0 (como en el paper: se aprende a t=0)
        v_idx = edges[:, 0]
        u_idx = edges[:, 1]
        x_v = X0[v_idx]
        x_u = X0[u_idx]
        F_ve, F_ue = self.phi(x_v, x_u)  # [m]
        
        # Laplaciano de sheaf ∆_F (scalar, d=1)
        Delta = build_sheaf_laplacian_scalar(n_nodes, edges, F_ve, F_ue)  # [n, n]
        
        # Difusión "vanilla": X^{t+1} = X^t - ∆_F X^t
        for _ in range(self.T):
            X = X - Delta @ X  # [n, f]
        
        # Clasificador lineal
        logits = self.classifier(X)  # [n, n_classes]
        return logits


In [5]:
# Parámetros
n_per_class = 100
G, edges, labels = make_bipartite_graph(n_per_class=n_per_class, p_edge=0.05, seed=0)
X0 = make_gaussian_features(labels, mu=1.0, sigma=1.5, seed=0)

X0 = torch.from_numpy(X0)
y  = torch.from_numpy(labels)
edges_torch = torch.from_numpy(edges)

# Split train/test
n = X0.shape[0]
perm = torch.randperm(n)
train_size = int(0.8 * n)
train_idx = perm[:train_size]
test_idx  = perm[train_size:]

def train_model(symmetric=False):
    model = SheafDiffusionModel(
        in_features=X0.shape[1],
        hidden_phi=32,
        symmetric=symmetric,
        T=20,
        n_classes=2
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
    
    for epoch in range(200):
        model.train()
        optimizer.zero_grad()
        
        logits = model(X0, edges_torch)
        loss = F.cross_entropy(logits[train_idx], y[train_idx])
        loss.backward()
        optimizer.step()
        
        if (epoch + 1) % 50 == 0:
            model.eval()
            with torch.no_grad():
                pred = logits.argmax(dim=-1)
                train_acc = (pred[train_idx] == y[train_idx]).float().mean().item()
                test_acc  = (pred[test_idx] == y[test_idx]).float().mean().item()
            print(f"Epoch {epoch+1} | loss={loss.item():.4f} | train_acc={train_acc:.3f} | test_acc={test_acc:.3f}")
    
    return model

print("=== Modelo simétrico (weighted Laplacian) ===")
sym_model = train_model(symmetric=True)

print("\n=== Modelo general de sheaf (no simétrico) ===")
gen_model = train_model(symmetric=False)


=== Modelo simétrico (weighted Laplacian) ===
Epoch 50 | loss=0.4862 | train_acc=0.781 | test_acc=0.750
Epoch 100 | loss=0.3327 | train_acc=0.994 | test_acc=0.925
Epoch 150 | loss=0.2229 | train_acc=1.000 | test_acc=0.975
Epoch 200 | loss=0.1499 | train_acc=1.000 | test_acc=1.000

=== Modelo general de sheaf (no simétrico) ===
Epoch 50 | loss=0.3186 | train_acc=1.000 | test_acc=1.000
Epoch 100 | loss=0.1671 | train_acc=1.000 | test_acc=1.000
Epoch 150 | loss=0.1047 | train_acc=1.000 | test_acc=1.000
Epoch 200 | loss=0.0720 | train_acc=1.000 | test_acc=1.000
