<a href="https://colab.research.google.com/github/eisbetterthanpi/hypergraph/blob/main/hgnn_list.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# @title dgl data
!pip install dgl

# "co-cite" relationship: hyperedge includes all the other papers it cited, as well as the paper itself
# then incidence mat = incidence mat + id
import torch
from dgl.data import CoraGraphDataset # https://github.com/dmlc/dgl/blob/master/python/dgl/data/citation_graph.py

def load_data():
    """Load Cora and build the "co-cite" hypergraph incidence matrix.

    Every citation edge (src, dst) becomes an entry H[src, dst] = 1, and the
    identity is added so that each paper also belongs to its own hyperedge.

    Returns:
        tuple: ``(H, X, Y, num_classes, train_mask, val_mask, test_mask)`` where
        ``H`` is a coalesced sparse COO incidence matrix and the remaining
        tensors are read from ``graph.ndata`` (per the original notes:
        H [2708, 2708], X [2708, 1433] bag-of-words, Y [2708] labels 0-6).
    """
    dataset = CoraGraphDataset()
    graph = dataset[0]
    num_nodes = graph.num_nodes()
    indices = torch.stack(graph.edges())
    # Give the size explicitly: when it is inferred from the largest index, a
    # trailing node without any citation edge would silently shrink H.
    H = torch.sparse_coo_tensor(
        indices=indices,
        values=torch.ones(indices.shape[1]),
        size=(num_nodes, num_nodes),
    ).coalesce()
    # Sparse identity (equivalent to torch.eye(num_nodes)); named `eye` so the
    # builtin `id` is not shadowed.
    eye = torch.sparse.spdiags(torch.ones(num_nodes), torch.tensor(0), (num_nodes, num_nodes))
    H = (eye + H).coalesce() # each vertex gets its own hyperedge: everything it cites plus itself

    X = graph.ndata["feat"] # node features
    Y = graph.ndata["label"] # class labels
    train_mask = graph.ndata["train_mask"]
    val_mask = graph.ndata["val_mask"]
    test_mask = graph.ndata["test_mask"]
    return H, X, Y, dataset.num_classes, train_mask, val_mask, test_mask

H, X, Y, num_classes, train_mask, val_mask, test_mask = load_data()
# print(H.shape, X.shape, Y.shape) # [2708, 2708], [2708, 1433], [2708]

# Rows of H index vertices, columns index hyperedges.
n_v, n_e = H.shape
elst = [[] for _ in range(n_e)] # edge list H(E)={e1,e2,e3}={{A,D},{D,E},{A,B,C}}
ilst = [[] for _ in range(n_v)] # incidence list {A:{e1,e3}, B:{e3}, C:{e3}, D:{e1,e2}, E:{e2}}
# Each stored entry is (row, col) = (vertex, hyperedge). The edge list is keyed
# by hyperedge and the incidence list by vertex; indexing them the other way
# round only works by accident while H is square.
for v, e in H.indices().T.tolist():
    elst[e].append(v)
    ilst[v].append(e)


In [None]:
# @title models
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()
def hypergraph_laplacian(H):
    """Return D_v^-1/2 @ H @ B @ D_e^-1 @ H^T @ D_v^-1/2 for incidence matrix H.

    Args:
        H: sparse incidence matrix of shape (N, M); rows are summed to get the
            node degrees and columns to get the hyperedge degrees.

    Returns:
        The (N, N) product of the sparse factors. It is computed under
        ``torch.no_grad()``, so no gradient flows back into ``H``.
    """
    N,M = H.shape
    d_V = H.sum(1).to_dense() # node deg
    d_E = H.sum(0).to_dense() # edge deg
    # A zero degree would become inf after the negative power and turn into NaN
    # as soon as it meets an explicitly stored zero; mapping it to inf first
    # makes the inverse factor exactly 0 instead.
    d_V[d_V == 0] = float('inf')
    d_E[d_E == 0] = float('inf')
    D_v_invsqrt = torch.sparse.spdiags(d_V**-0.5,torch.tensor(0),(N,N)) # torch.diag(d_V**-0.5)
    D_e_inv = torch.sparse.spdiags(d_E**-1,torch.tensor(0),(M,M)) # torch.diag(d_E**-1)
    B = torch.sparse.spdiags(torch.ones(M),torch.tensor(0),(M,M)) # torch.eye(M): identity hyperedge weights
    return D_v_invsqrt @ H @ B @ D_e_inv @ H.T @ D_v_invsqrt # Laplacian

class HGNN(nn.Module): # https://github.com/dmlc/dgl/blob/master/notebooks/sparse/hgnn.ipynb
    """Two-layer hypergraph convolution using a Laplacian precomputed from H."""

    def __init__(self, H, in_size, out_size, hidden_dims=16):
        super().__init__()
        self.W1 = nn.Linear(in_size, hidden_dims)
        self.W2 = nn.Linear(hidden_dims, out_size)
        self.dropout = nn.Dropout(0.5)
        # Built once at construction; forward() always reuses this matrix.
        self.L = hypergraph_laplacian(H)

    def forward(self, H, X):
        """Return the per-vertex outputs; the ``H`` argument is not read here."""
        projected = self.W1(self.dropout(X)) # embed, then propagate
        hidden = F.relu(self.L @ projected)
        logits = self.W2(self.dropout(hidden))
        return self.L @ logits

# Hypergraph Convolution and Hypergraph Attention https://arxiv.org/pdf/1901.08150.pdf
class HypergraphAttention(nn.Module): # https://github.com/dmlc/dgl/blob/master/examples/sparse/hypergraphatt.py
    """Hypergraph attention layer: scores every (vertex, hyperedge) incidence.

    ``P`` projects the features and ``a`` maps a concatenated
    (vertex, hyperedge) embedding pair to one scalar score.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        self.P = nn.Linear(in_size, out_size)
        self.a = nn.Linear(2 * out_size, 1)

    def forward(self, H, X, X_edges): # H: sparse [n_vert, n_edge]; X: [n_vert, in_size]
        """Return the attention-weighted propagation of the projected vertices.

        Args:
            H: coalesced sparse COO incidence matrix.
            X: vertex features.
            X_edges: hyperedge features, projected with the same ``P``.
        """
        idx = H.indices() # fetched once and reused for the gathers and the rebuild
        rows, cols = idx[0], idx[1]
        Z = self.P(X) # emb verts [n_vert,out_size]
        Z_edges = self.P(X_edges) # emb edges
        sim = self.a(torch.cat([Z[rows], Z_edges[cols]], 1)) # one score per stored incidence
        sim = F.leaky_relu(sim, 0.2).squeeze(1)
        # Keep H's shape explicitly; inferring it from the largest index could
        # yield a smaller matrix than H.
        H_att = torch.sparse_coo_tensor(indices=idx, values=sim, size=H.shape).coalesce()
        H_att = torch.sparse.softmax(H_att,1)
        # NOTE(review): hypergraph_laplacian is decorated with torch.no_grad in
        # this file, so it looks like `self.a` gets no gradient through this
        # product (only `self.P` via Z) - confirm this is intended.
        return hypergraph_laplacian(H_att) @ Z # [n_vert, out_size]

class Net(nn.Module):
    """Two stacked HypergraphAttention layers with an ELU in between."""

    def __init__(self, in_size, out_size, hidden_size=16):
        super().__init__()
        self.layer1 = HypergraphAttention(in_size, hidden_size)
        self.layer2 = HypergraphAttention(hidden_size, out_size)

    def forward(self, H, X):
        """Run both layers; each layer receives the same tensor as vertex and edge features."""
        hidden = F.elu(self.layer1(H, X, X)) # [n_vert, hidden_size]
        return self.layer2(H, hidden, hidden) # [n_vert, out_size]



In [None]:
# @title models down
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()
def hypergraph_laplacian(H):
    """Return D_v^-1/2 @ H @ B @ D_e^-1 @ H^T @ D_v^-1/2 for incidence matrix H.

    Args:
        H: sparse incidence matrix of shape (num_verts, num_edges).

    Returns:
        The (num_verts, num_verts) product of the sparse factors, computed
        under ``torch.no_grad()`` (no gradient flows back into ``H``).
    """
    N,M = H.shape # num_verts, num_edges
    d_V = H.sum(1).to_dense() # node deg
    d_E = H.sum(0).to_dense() # edge deg
    # Zero degrees are mapped to inf so the negative powers below give exactly
    # 0 rather than inf, which would become NaN on explicitly stored zeros.
    d_V[d_V == 0] = float('inf')
    d_E[d_E == 0] = float('inf')
    D_v_invsqrt = torch.sparse.spdiags(d_V**-0.5,torch.tensor(0),(N,N)) # torch.diag(d_V**-0.5)
    D_e_inv = torch.sparse.spdiags(d_E**-1,torch.tensor(0),(M,M)) # torch.diag(d_E**-1)
    B = torch.sparse.spdiags(torch.ones(M),torch.tensor(0),(M,M)) # torch.eye(M): identity hyperedge weights
    return D_v_invsqrt @ H @ B @ D_e_inv @ H.T @ D_v_invsqrt # Laplacian

class HGNN(nn.Module): # https://github.com/dmlc/dgl/blob/master/notebooks/sparse/hgnn.ipynb
    """Two-layer hypergraph convolution; this variant disables dropout (p=0)."""

    def __init__(self, H, in_size, out_size, hidden_dims=16):
        super().__init__()
        self.W1 = nn.Linear(in_size, hidden_dims)
        self.W2 = nn.Linear(hidden_dims, out_size)
        self.dropout = nn.Dropout(0.) # og 0.5
        # Built once at construction; forward() always reuses this matrix.
        self.L = hypergraph_laplacian(H)

    def forward(self, H, X):
        """Return the per-vertex outputs; the ``H`` argument is not read here."""
        projected = self.W1(self.dropout(X)) # embed, then propagate
        hidden = F.relu(self.L @ projected)
        logits = self.W2(self.dropout(hidden))
        return self.L @ logits

# Hypergraph Convolution and Hypergraph Attention https://arxiv.org/pdf/1901.08150.pdf
class HypergraphAttention(nn.Module): # https://github.com/dmlc/dgl/blob/master/examples/sparse/hypergraphatt.py
    """Hypergraph attention layer (modified variant).

    Differences from the referenced example, as marked by the author: the
    vertex projection is reused for the hyperedge side, and ReLU replaces
    LeakyReLU(0.2) on the raw scores.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        self.P = nn.Linear(in_size, out_size)
        self.a = nn.Linear(2 * out_size, 1) # og

    def forward(self, H, X, X_edges): # H: sparse [n_vert, n_edge]; X: [n_vert, in_size]
        """Return the attention-weighted propagation of the projected vertices.

        Args:
            H: coalesced sparse COO incidence matrix.
            X: vertex features.
            X_edges: accepted for signature compatibility; not read by this variant.
        """
        idx = H.indices() # fetched once and reused for the gathers and the rebuild
        rows, cols = idx[0], idx[1]
        Z = self.P(X) # emb verts [n_vert,out_size]
        # og: Z_edges = self.P(X_edges) and Z_edges[cols] on the hyperedge side
        sim = self.a(torch.cat([Z[rows], Z[cols]], 1)) # vertemb, edgeemb(=vertemb)
        # og: F.leaky_relu(sim, 0.2)
        sim = F.relu(sim).squeeze(1) # me
        # Keep H's shape explicitly; inferring it from the largest index could
        # yield a smaller matrix than H.
        H_att = torch.sparse_coo_tensor(indices=idx, values=sim, size=H.shape).coalesce()
        H_att = torch.sparse.softmax(H_att,1)
        # NOTE(review): hypergraph_laplacian is decorated with torch.no_grad in
        # this file, so it looks like `self.a` gets no gradient through this
        # product (only `self.P` via Z) - confirm this is intended.
        return hypergraph_laplacian(H_att) @ Z # [n_vert, out_size]

class Net(nn.Module):
    """Two stacked HypergraphAttention layers with a ReLU in between (og: ELU)."""

    def __init__(self, in_size, out_size, hidden_size=16):
        super().__init__()
        self.layer1 = HypergraphAttention(in_size, hidden_size)
        self.layer2 = HypergraphAttention(hidden_size, out_size)

    def forward(self, H, X):
        """Run both layers; each layer receives the same tensor as vertex and edge features."""
        hidden = F.relu(self.layer1(H, X, X)) # [n_vert, hidden_size]
        return self.layer2(H, hidden, hidden) # [n_vert, out_size]



In [None]:
# @title theory
@torch.no_grad()
def hypergraph_laplacian(H):
    """Return D_v^-1/2 @ H @ B @ D_e^-1 @ H^T @ D_v^-1/2 for incidence matrix H.

    Args:
        H: sparse incidence matrix of shape (num_verts, num_edges).

    Returns:
        The (num_verts, num_verts) product of the sparse factors, computed
        under ``torch.no_grad()`` (no gradient flows back into ``H``).
    """
    N,M = H.shape # num_verts, num_edges
    d_V = H.sum(1).to_dense() # node deg
    d_E = H.sum(0).to_dense() # edge deg
    # Zero degrees are mapped to inf so the negative powers below give exactly
    # 0 rather than inf, which would become NaN on explicitly stored zeros.
    d_V[d_V == 0] = float('inf')
    d_E[d_E == 0] = float('inf')
    D_v_invsqrt = torch.sparse.spdiags(d_V**-0.5,torch.tensor(0),(N,N)) # torch.diag(d_V**-0.5)
    D_e_inv = torch.sparse.spdiags(d_E**-1,torch.tensor(0),(M,M)) # torch.diag(d_E**-1)
    B = torch.sparse.spdiags(torch.ones(M),torch.tensor(0),(M,M)) # torch.eye(M): identity hyperedge weights
    return D_v_invsqrt @ H @ B @ D_e_inv @ H.T @ D_v_invsqrt # Laplacian


# Scratch cell: the body of a would-be `hypergraph_laplacian1` unrolled at module
# level so the individual normalisation factors can be inspected and recombined.
# It reads the global incidence matrix `H`.
# NOTE(review): the lower part of this cell reads like pseudo-code - it uses
# `V` and `self`, neither of which is defined at module level in the visible
# source, so running the cell top to bottom is expected to stop with a NameError.
# @torch.no_grad
# def hypergraph_laplacian1(H, B):
N,M = H.shape # num_verts, num_edges
d_V = H.sum(1).to_dense() # node deg
d_E = H.sum(0).to_dense() # edge deg
D_v_invsqrt = torch.sparse.spdiags(d_V**-0.5,torch.tensor(0),(N,N)) # torch.diag(d_V**-0.5)
D_v_inv = torch.sparse.spdiags(d_V**-1,torch.tensor(0),(N,N)) # torch.diag(d_V**-1)
D_e_inv = torch.sparse.spdiags(d_E**-1,torch.tensor(0),(M,M)) # torch.diag(d_E**-1)
D_e_invsqrt = torch.sparse.spdiags(d_E**-0.5,torch.tensor(0),(M,M)) # torch.diag(d_E**-0.5)
B = torch.sparse.spdiags(torch.ones(M),torch.tensor(0),(M,M)) # torch.eye(M) # B is id, dim n_edges
# return D_v_invsqrt @ H @ B @ D_e_inv @ H.T @ D_v_invsqrt # Laplacian
# return H @ B @ D_e_inv @ H.T @ D_v_inv

# hl=hypergraph_laplacian(H)
# hl1=hypergraph_laplacian1(H)
# print(hl.to_dense()[:5,:5])
# print(hl1.to_dense()[:5,:5])

# Random vertex / hyperedge embeddings used to try the propagation formulas below.
n_v, n_e = H.shape
d_model=16
vemb=torch.rand(n_v,d_model)
eemb=torch.rand(n_e,d_model)
# B = torch.sparse.spdiags(eemb,torch.tensor(0),(M,M,d_model)) # torch.diag(d_E**-1)


                    # [n_vert,n_edge] @ [num_edge, d_model]
# vemb1 = (D_v_invsqrt @ H @ B @ D_e_inv @ H.T @ D_v_invsqrt @ self.fv(vemb))
# Same chain as hypergraph_laplacian(H), applied directly to the vertex embeddings.
out = D_v_invsqrt @ H @ B @ D_e_inv @ H.T @ D_v_invsqrt @ vemb
print(out.to_dense()[:5,:5])

# # out = D_v_invsqrt @ H @ eemb @ D_e_inv @ H.T @ D_v_invsqrt @ vemb
                    # [n_vert,n_edge] @ [num_edge, d_model]
# Variant with D_e^-1 split into two D_e^-1/2 factors around B.
vemb1 = D_v_invsqrt @ H @ D_e_invsqrt @ B @ D_e_invsqrt @ H.T @ D_v_invsqrt @ vemb
                    # [num_edge,n_edge] @ [n_vert, d_model]
# NOTE(review): `V` is not defined anywhere in the visible source; presumably a
# vertex-side weight matrix analogous to `B` - confirm before running.
eemb1 = D_e_invsqrt @ H.T @ D_v_invsqrt @ V @ D_v_invsqrt @ H @ D_e_invsqrt @ eemb
# NOTE(review): prints `out` again, not the vemb1 / eemb1 just computed.
print(out.to_dense()[:5,:5])


# Sketch of one message-passing round (vertex msg -> edge msg -> updated embeddings).
# NOTE(review): `self` does not exist at module level; these lines look like a
# draft of a module's forward() body rather than runnable cell code.
vmsg = self.fv(vemb) # vmsg = self.fv(torch.cat((vemb, semsg), 1))
svmsg = D_e_invsqrt @ H.T @ D_v_invsqrt @ vmsg # [num_edge, d_model]
emsg = self.fw(torch.cat((eemb, svmsg), 1))
semsg = D_v_invsqrt @ H @ D_e_invsqrt @ emsg

vemb1 = self.gv(torch.cat((vemb, semsg), 1))
eemb1 = self.gw(torch.cat((eemb, svmsg), 1))


# Follow-up round in which the vertex message also sees the aggregated edge message.
vmsg = self.fv(torch.cat((vemb, semsg), 1)) # vmsg = self.fv(vemb)
svmsg # H.T @ D_v_invsqrt
emsg = self.fw(torch.cat((eemb, svmsg), 1))
semsg # D_v_invsqrt @ H @ D_e_inv


In [108]:
# @title adj, inc drop, transform
# https://arxiv.org/pdf/2203.16995.pdf
import torch
import torch.nn as nn
# Preferred compute device as a string: "cuda" when a CUDA device is usable, else "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

class AdjDropout(nn.Module): # adjacency dropout
    """Randomly drop whole hyperedges (columns of H) while training.

    Args:
        p: probability of dropping each column. In eval mode the input is
            returned untouched.
    """

    def __init__(self, p=0.7):
        super().__init__()
        self.p = p

    def forward(self, H): # randomly remove hyperedges
        """Return ``H`` with a random subset of its columns zeroed (training only)."""
        if not self.training: # apply AdjDropout only during training
            return H
        n_v, n_e = H.shape
        # One keep/drop decision per column (1 -> keep), drawn on H's own device
        # so the product below does not mix CPU and GPU tensors.
        keep = (torch.rand(n_e, device=H.device) >= self.p).float()
        return H * keep.expand(n_v, n_e) # zero out the dropped columns (hyperedges)

class IncDropout(nn.Module):
    """Randomly zero individual incidence entries of a sparse H while training.

    Args:
        p: probability of zeroing each stored value. In eval mode the input is
            returned untouched.
    """

    def __init__(self, p=0.7):
        super().__init__()
        self.p = p

    def forward(self, H): # randomly set incidence to 0
        """Return a coalesced sparse COO copy of ``H`` with random values zeroed.

        ``H`` must be a coalesced sparse COO tensor (``values()`` is called on it).
        """
        if not self.training:
            return H
        Hval = H.values()
        # One keep/drop decision per stored entry (1 -> keep), drawn on the same
        # device as the values.
        keep = (torch.rand(len(Hval), device=Hval.device) >= self.p).float()
        # Pass size=H.shape: if the shape were inferred from the indices, an H
        # with trailing empty rows/columns would come back smaller.
        return torch.sparse_coo_tensor(indices=H.indices(), values=Hval*keep, size=H.shape).coalesce() # removes verts from hyperedges


class TrainTransform:
    """Produce two independently augmented views of an incidence matrix.

    Each view goes through its own IncDropout(0.5) -> AdjDropout(0.5) pipeline.
    """

    def __init__(self):
        def make_pipeline():
            return nn.Sequential(IncDropout(0.5), AdjDropout(0.5),)

        self.transform = make_pipeline()
        self.transform_prime = make_pipeline()

    def __call__(self, H):
        """Return ``(view, view_prime)`` computed from the same ``H``."""
        view = self.transform(H)
        view_prime = self.transform_prime(H)
        return view, view_prime


# Shared augmentation instance for the training cells.
trs = TrainTransform()

# print(H.to_dense()[:5,:5])
# print(trs(H)[1].to_dense()[:5,:5])


In [None]:
# @title HMPNN me H attn vic
# https://arxiv.org/pdf/2203.16995.pdf
import torch
import torch.nn as nn
import torch.nn.functional as F
# Preferred compute device as a string: "cuda" when a CUDA device is usable, else "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
# Message-passing scheme implemented by this cell (see the paper linked above):
# Vert msg = fv(vert emb, sum of edge msgs)
# Edge msg = fw(edge emb, sum of vert msgs)
# Vert emb1 = gv(vert emb, sum of edge msgs)
# Edge emb1 = gw(edge emb, sum of vert msgs)

def off_diagonal(x):
    """Return every off-diagonal element of a square matrix as a flat 1-D tensor.

    Raises AssertionError when ``x`` is not square.
    """
    rows, cols = x.shape
    assert rows == cols
    # Dropping the last element and viewing the rest as (n-1, n+1) lines the
    # diagonal entries up in column 0, which is then sliced away.
    shifted = x.flatten()[:-1].view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()

class MultiHeadAttention(nn.Module):
    """Scaled dot-product multi-head attention with bias-free q/k/v projections.

    Args:
        d_model: feature size of queries, keys, values and of the output.
        n_heads: number of heads; must divide ``d_model``.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model, n_heads, dropout=0):
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by n_heads ({n_heads})")
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q = nn.Linear(d_model, d_model, bias=False)
        self.k = nn.Linear(d_model, d_model, bias=False)
        self.v = nn.Linear(d_model, d_model, bias=False)
        self.lin = nn.Linear(d_model, d_model)
        self.drop = nn.Dropout(dropout)
        # Registered as a non-persistent buffer so it follows `.to(device)` with
        # the rest of the module and stays out of the state_dict. A plain tensor
        # created on a global device would clash with CPU inputs on a CUDA host.
        self.register_buffer(
            "scale",
            torch.sqrt(torch.tensor((self.head_dim,), dtype=torch.float)),
            persistent=False,
        )

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        The first dimension of each input is treated as the batch; the rest is
        reshaped to (batch, seq, n_heads, head_dim).

        Args:
            mask: optional tensor broadcastable to the
                (batch, n_heads, q_len, k_len) scores; positions where
                ``mask == 0`` are excluded from the softmax.

        Returns:
            Tensor of shape (batch, q_len, d_model).
        """
        batch_size = query.shape[0]
        Q = self.q(query).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        K = self.k(key).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        V = self.v(value).view(batch_size, -1, self.n_heads, self.head_dim).transpose(1, 2)
        attn = Q @ K.transpose(2, 3) / self.scale
        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e10)
        attention = torch.softmax(attn, dim=-1)
        x = self.drop(attention) @ V
        x = x.transpose(1, 2).reshape(batch_size, -1, self.d_model)
        x = self.lin(x)
        return x#, attention

from torch.nn.utils.rnn import pad_sequence
def get_idx(H): # column indices of the stored entries of each row
    """Return the padded per-row column indices of ``H`` and a padding mask.

    Returns:
        (sidx, mask): ``sidx`` is [n_rows, max_entries_per_row] with unused
        slots filled with -1; ``mask`` is True exactly at those padded slots.
    """
    csr = H.to_sparse_csr()
    # Consecutive differences of the CSR row pointer give each row's entry count.
    row_lengths = torch.diff(csr.crow_indices()).tolist() # https://stackoverflow.com/a/44536294/13359815
    per_row = torch.split(csr.col_indices(), row_lengths)
    sidx = pad_sequence(per_row, batch_first=True, padding_value=-1)
    return sidx, sidx < 0 # [n_rows, num_idx]

class ff(nn.Module):
    """Single feed-forward block: Linear -> Sigmoid -> Dropout(0.5).

    Alternatives the author tried and left disabled: a 3-layer ReLU MLP,
    Linear+BatchNorm, and BatchNorm-first with Dropout(0.1).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        layers = [nn.Linear(in_dim, out_dim), nn.Sigmoid(), nn.Dropout(p=0.5)]
        self.lin = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.lin(x)

class MsgPass(nn.Module):
    """One hypergraph message-passing round with attention aggregation.

    Vertex and hyperedge messages are aggregated over the incidence structure
    of ``H`` with single-head attention instead of a plain sum.
    """

    def __init__(self, d_model):
        super().__init__()
        # ffw / fgv / fgw, adjdrop and drop are not read by forward(); they are
        # kept so the parameter set and initialisation order stay unchanged.
        self.ffv = ff(d_model, d_model)
        self.ffw = ff(2*d_model, d_model)
        self.fgv = ff(2*d_model, d_model)
        self.fgw = ff(2*d_model, d_model)
        drop=0.
        self.fv = MultiHeadAttention(d_model, n_heads=1, dropout=drop)
        self.fw = MultiHeadAttention(d_model, n_heads=1, dropout=drop)
        self.gv = MultiHeadAttention(d_model, n_heads=1, dropout=drop)
        self.gw = MultiHeadAttention(d_model, n_heads=1, dropout=drop)
        # self.adjdrop = AdjDropout(0.7) # 0.7 "Adjacency dropout must be applied in neighborhood creation steps of Equations 3 through 5"
        self.adjdrop = AdjDropout(0)
        self.drop = nn.Dropout(0.5)

    @staticmethod
    def _neighbours(H):
        """Return padded neighbour indices of each row of H and a validity mask.

        get_idx marks PADDED slots with True, while MultiHeadAttention masks
        out positions where ``mask == 0``. The flag is therefore inverted here
        so that real neighbours are attended to and padding is ignored.
        """
        idx, is_pad = get_idx(H) # [n_rows, num_idx]
        valid = (~is_pad).unsqueeze(1).unsqueeze(2) # [n_rows, 1, 1, num_idx]
        return idx, valid

    def forward(self, H, vemb, eemb, emsg=None):
        """Run one round of message passing.

        Args:
            H: sparse incidence matrix (rows: vertices, columns: hyperedges).
            vemb: vertex embeddings.
            eemb: hyperedge embeddings.
            emsg: hyperedge messages from a previous round, or None on the
                first round.

        Returns:
            (vemb1, eemb1, emsg): updated vertex and hyperedge embeddings, plus
            the hyperedge messages to hand to the next round.
        """
        # H does not change inside a round, so both index tables are built once.
        ridx, rmask = self._neighbours(H)   # hyperedges incident to each vertex
        cidx, cmask = self._neighbours(H.T) # vertices belonging to each hyperedge

        if emsg is None:
            vmsg = self.ffv(vemb)
        else:
            semsg = emsg[ridx] # [n_rows, num_idx, ...]; padded slots gather row -1 but are masked
            vmsg = self.fv(vemb, semsg, semsg, rmask) # [n_rows, 1, d_model]

        svmsg = vmsg[cidx] # vertex messages gathered per hyperedge
        emsg = self.fw(eemb, svmsg, svmsg, cmask) # [n_cols, 1, d_model]

        semsg = emsg[ridx] # hyperedge messages gathered per vertex
        vemb1 = self.gv(vemb, semsg, semsg, rmask) # [n_rows, 1, d_model]
        eemb1 = self.gw(eemb, svmsg, svmsg, cmask) # [n_cols, 1, d_model]

        # Drop only the length-1 query axis; a bare squeeze() would also
        # collapse a batch of one row or a feature size of one.
        return vemb1.squeeze(1), eemb1.squeeze(1), emsg

class HMPNN(nn.Module):
    """Two-round hypergraph message-passing network with a VICReg auxiliary loss.

    Args:
        in_dim: size of the raw vertex features.
        d_model: embedding size used throughout message passing.
        out_dim: number of classifier outputs.
    """

    def __init__(self, in_dim, d_model, out_dim):
        super().__init__()
        self.venc = nn.Linear(in_dim, d_model, bias=False)
        self.eenc = nn.Linear(in_dim, d_model, bias=False)
        self.msgpass = MsgPass(d_model)
        self.msgpass2 = MsgPass(d_model)

        # f=[d_model,256,256,256]
        f=[d_model,32,32,32]
        # Expander head used only by the VICReg loss.
        self.exp = nn.Sequential(
            nn.Linear(f[0], f[1]), nn.BatchNorm1d(f[1]), nn.ReLU(),
            nn.Linear(f[1], f[2]), nn.BatchNorm1d(f[2]), nn.ReLU(),
            nn.Linear(f[-2], f[-1], bias=False)
            )
        self.classifier = nn.Linear(d_model, out_dim)

    def forward(self, H, vemb, classify=True):
        """Encode the vertices, run both message-passing rounds and classify.

        Args:
            H: sparse incidence matrix.
            vemb: raw vertex features; the hyperedge embeddings are encoded
                from the same features with a separate linear layer.
            classify: when True (default) return classifier outputs; when
                False return the d_model-sized vertex embeddings.
        """
        vemb, eemb = self.venc(vemb), self.eenc(vemb)
        vemb, eemb, emsg = self.msgpass(H, vemb, eemb)
        vemb, eemb, emsg = self.msgpass2(H, vemb, eemb, emsg=emsg)
        if not classify:
            return vemb
        return self.classifier(vemb)

    # https://arxiv.org/pdf/2105.04906.pdf
    def vicreg(self, x, y): # https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py
        """Return the VICReg loss between two batches of 2-D embeddings.

        The three weighted terms are also printed on every call.
        """
        # invariance loss
        repr_loss = F.mse_loss(x, y) # s(Z, Z')

        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        # variance loss
        std_x = torch.sqrt(x.var(dim=0) + 0.0001) #ϵ=0.0001
        std_y = torch.sqrt(y.var(dim=0) + 0.0001)
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2

        # Inputs must be 2-D: this unpack fails for anything else, which is why
        # the former 1-D unsqueeze guards placed after it could never run.
        batch_size, num_features = x.shape
        sim_coeff=10.0 # 25.0 # λ
        std_coeff=10.0 # 25.0 # µ
        cov_coeff=1.0 # 1.0 # ν

        # covariance loss
        cov_x = (x.T @ x) / (batch_size - 1) #C(Z)
        cov_y = (y.T @ y) / (batch_size - 1)
        cov_loss = off_diagonal(cov_x).pow_(2).sum().div(num_features)\
         + off_diagonal(cov_y).pow_(2).sum().div(num_features) #c(Z)
        loss = (sim_coeff * repr_loss + std_coeff * std_loss + cov_coeff * cov_loss)
        print("in vicreg ",(sim_coeff * repr_loss).item() , (std_coeff * std_loss).item() , (cov_coeff * cov_loss).item())
        return loss

    def loss(self, H1, H2, vemb):
        """Return the VICReg loss between the embeddings of two views of H."""
        # The expander consumes d_model-sized embeddings, so the classifier
        # must be skipped here; feeding its out_dim-sized outputs into
        # self.exp is a shape mismatch whenever out_dim != d_model.
        sx = self.forward(H1, vemb, classify=False)
        sy = self.forward(H2, vemb, classify=False)
        vx, vy = self.exp(sx), self.exp(sy)
        loss = self.vicreg(vx,vy)
        return loss

    def classify(self, x):
        """Apply the classifier head to precomputed embeddings."""
        return self.classifier(x)


# Instantiate the attention-based HMPNN on the globally loaded features X.
num_v,vdim=X.shape
# print("num_v,vembdim",num_v,vembdim) # 2708, 1433
# NOTE(review): hard-coded 7 replaces the num_classes value unpacked from load_data().
num_classes=7

# in_dim = feature width of X, d_model = 16, out_dim = num_classes
model=HMPNN(X.shape[1],16,num_classes)


In [115]:
# @title HMPNN me H vicreg
# https://arxiv.org/pdf/2203.16995.pdf
import torch
import torch.nn as nn
import torch.nn.functional as F
# Preferred compute device as a string: "cuda" when a CUDA device is usable, else "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
# Message-passing scheme implemented by this cell (see the paper linked above):
# Vert msg = fv(vert emb, sum of edge msgs)
# Edge msg = fw(edge emb, sum of vert msgs)
# Vert emb1 = gv(vert emb, sum of edge msgs)
# Edge emb1 = gw(edge emb, sum of vert msgs)

def off_diagonal(x):
    """Return every off-diagonal element of a square matrix as a flat 1-D tensor.

    Raises AssertionError when ``x`` is not square.
    """
    side, other = x.shape
    assert side == other
    # With the last element removed, an (n-1, n+1) view puts all diagonal
    # entries in the first column; slicing that column off leaves the rest.
    without_last = x.flatten()[:-1]
    grid = without_last.view(side - 1, side + 1)
    return grid[:, 1:].flatten()

class ff(nn.Module):
    """Single feed-forward block: Linear -> Sigmoid.

    Alternatives the author tried and left disabled: a 3-layer Sigmoid MLP,
    BatchNorm-first, and a trailing Dropout(0.5).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        layers = [nn.Linear(in_dim, out_dim), nn.Sigmoid()]
        self.lin = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.lin(x)

class MsgPass(nn.Module):
    """One degree-normalised hypergraph message-passing round with residuals."""

    def __init__(self, d_model):
        super().__init__()
        self.fv = ff(2*d_model, d_model)
        self.fw = ff(2*d_model, d_model)
        self.gv = ff(2*d_model, d_model)
        self.gw = ff(2*d_model, d_model)
        # self.adjdrop = AdjDropout(0.3) # 0.7 "Adjacency dropout must be applied in neighborhood creation steps of Equations 3 through 5"
        self.adjdrop = AdjDropout(0)

    def forward(self, H, vemb, eemb, semsg=None):
        """Run one round of message passing.

        Args:
            H: sparse incidence matrix of shape (N, M).
            vemb: vertex embeddings.
            eemb: hyperedge embeddings.
            semsg: aggregated hyperedge messages per vertex from the previous
                round; when None, ``vemb`` is used in its place.

        Returns:
            (vemb1, eemb1, semsg): updated vertex embeddings, updated hyperedge
            embeddings and the aggregated hyperedge messages for the next round.
        """
        N,M = H.shape
        d_V = H.sum(1).to_dense() # node deg
        d_E = H.sum(0).to_dense() # edge deg
        # Zero degrees -> inf so that the inverse square root below is exactly 0.
        d_V[d_V==0] = float('inf')
        d_E[d_E==0] = float('inf')
        D_v_invsqrt = torch.sparse.spdiags(d_V**-0.5,torch.tensor(0),(N,N))
        D_e_invsqrt = torch.sparse.spdiags(d_E**-0.5,torch.tensor(0),(M,M))

        # Identity test: `== None` on a tensor goes through Tensor.__eq__.
        if semsg is None: semsg = vemb
        vmsg = vemb + self.fv(torch.cat((vemb, semsg), 1))
        # adjdrop is applied separately at each aggregation step.
        svmsg = D_e_invsqrt @ self.adjdrop(H).T @ D_v_invsqrt @ vmsg # [num_edge, d_model]
        emsg = svmsg + self.fw(torch.cat((eemb, svmsg), 1))
        semsg = D_v_invsqrt @ self.adjdrop(H) @ D_e_invsqrt @ emsg

        vemb1 = semsg + self.gv(torch.cat((vemb, semsg), 1))
        eemb1 = svmsg + self.gw(torch.cat((eemb, svmsg), 1))
        return vemb1, eemb1, semsg

class HMPNN(nn.Module):
    """Three-round hypergraph message-passing network with a VICReg auxiliary loss.

    Args:
        in_dim: size of the raw vertex features.
        d_model: embedding size used throughout message passing.
        out_dim: number of classifier outputs.
    """

    def __init__(self, in_dim, d_model, out_dim):
        super().__init__()
        self.venc = nn.Linear(in_dim, d_model)
        self.eenc = nn.Linear(in_dim, d_model)
        self.msgpass = MsgPass(d_model)
        self.msgpass2 = MsgPass(d_model)
        self.msgpass3 = MsgPass(d_model)
        f=[d_model,256,256,256]
        # f=[d_model,32,32,32]
        # Expander head used only by the VICReg loss.
        self.exp = nn.Sequential(
            nn.Linear(f[0], f[1]), nn.BatchNorm1d(f[1]), nn.ReLU(),
            nn.Linear(f[1], f[2]), nn.BatchNorm1d(f[2]), nn.ReLU(),
            nn.Linear(f[-2], f[-1], bias=False)
            )
        self.classifier = nn.Linear(d_model, out_dim)
        # Re-initialise every weight matrix; biases and other 1-D parameters
        # keep their default initialisation.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_normal_(p) # xavier_uniform_ xavier_normal_

    def forward(self, H, X, classify=True):
        """Encode, run three residual message-passing rounds, optionally classify.

        Args:
            H: sparse incidence matrix.
            X: raw vertex features; the hyperedge embeddings are encoded from
                the same features with a separate linear layer.
            classify: when True (default) return classifier outputs; when
                False return the d_model-sized vertex embeddings.
        """
        vemb, eemb = self.venc(X), self.eenc(X)
        vemb1, eemb1, semsg = self.msgpass(H, vemb, eemb)
        vemb, eemb = vemb+vemb1, eemb+eemb1
        vemb1, eemb1, semsg = self.msgpass2(H, vemb, eemb, semsg=semsg)
        vemb, eemb = vemb+vemb1, eemb+eemb1
        vemb1, eemb1, semsg = self.msgpass3(H, vemb, eemb, semsg=semsg)
        vemb, eemb = vemb+vemb1, eemb+eemb1
        if not classify:
            return vemb
        return self.classifier(vemb)

    # https://arxiv.org/pdf/2105.04906.pdf
    def vicreg(self, x, y): # https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py
        """Return the VICReg loss between two batches of 2-D embeddings.

        The three weighted terms are also printed on every call.
        """
        # invariance loss
        repr_loss = F.mse_loss(x, y) # s(Z, Z')

        x = x - x.mean(dim=0)
        y = y - y.mean(dim=0)

        # variance loss
        std_x = torch.sqrt(x.var(dim=0) + 0.0001) #ϵ=0.0001
        std_y = torch.sqrt(y.var(dim=0) + 0.0001)
        std_loss = torch.mean(F.relu(1 - std_x)) / 2 + torch.mean(F.relu(1 - std_y)) / 2

        # Inputs must be 2-D: this unpack fails for anything else, which is why
        # the former 1-D unsqueeze guards placed after it could never run.
        batch_size, num_features = x.shape
        sim_coeff=5.0 # 25.0 # λ
        std_coeff=10.0 # 25.0 # µ
        cov_coeff=1.0 # 1.0 # ν

        # covariance loss
        cov_x = (x.T @ x) / (batch_size - 1) #C(Z)
        cov_y = (y.T @ y) / (batch_size - 1)
        cov_loss = off_diagonal(cov_x).pow_(2).sum().div(num_features)\
         + off_diagonal(cov_y).pow_(2).sum().div(num_features) #c(Z)
        loss = (sim_coeff * repr_loss + std_coeff * std_loss + cov_coeff * cov_loss)
        print("in vicreg ",(sim_coeff * repr_loss).item() , (std_coeff * std_loss).item() , (cov_coeff * cov_loss).item())
        return loss

    def loss(self, H1, H2, X):
        """Return the VICReg loss between the embeddings of two views of H."""
        sx, sy = self.forward(H1, X, classify=False), self.forward(H2, X, classify=False)
        vx, vy = self.exp(sx), self.exp(sy)
        loss = self.vicreg(vx,vy)
        return loss

    def classify(self, x):
        """Apply the classifier head to precomputed embeddings."""
        return self.classifier(x)

# Instantiate the VICReg HMPNN on the globally loaded features X.
num_v,in_dim=X.shape # 2708, 1433
# NOTE(review): hard-coded 7 replaces the num_classes value unpacked from load_data().
num_classes=7
# in_dim = feature width of X, d_model = 16, out_dim = num_classes
model=HMPNN(X.shape[1],16,num_classes)


In [59]:
# @title HMPNN me H
# https://arxiv.org/pdf/2203.16995.pdf
import torch
import torch.nn as nn
import torch.nn.functional as F
# Preferred compute device as a string: "cuda" when a CUDA device is usable, else "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"
# Message-passing scheme implemented by this cell (see the paper linked above):
# Vert msg = fv(vert emb, sum of edge msgs)
# Edge msg = fw(edge emb, sum of vert msgs)
# Vert emb1 = gv(vert emb, sum of edge msgs)
# Edge emb1 = gw(edge emb, sum of vert msgs)

class ff(nn.Module):
    """Single feed-forward block: Linear -> Sigmoid.

    Alternatives the author tried and left disabled: a 3-layer Sigmoid MLP,
    BatchNorm-first, and a trailing Dropout(0.5).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        projection = nn.Linear(in_dim, out_dim)
        self.lin = nn.Sequential(projection, nn.Sigmoid())

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.lin(x)

class MsgPass(nn.Module):
    """One degree-normalised hypergraph message-passing round with residuals."""

    def __init__(self, d_model):
        super().__init__()
        self.fv = ff(2*d_model, d_model)
        self.fw = ff(2*d_model, d_model)
        self.gv = ff(2*d_model, d_model)
        self.gw = ff(2*d_model, d_model)
        # self.adjdrop = AdjDropout(0.3) # 0.7 "Adjacency dropout must be applied in neighborhood creation steps of Equations 3 through 5"
        self.adjdrop = AdjDropout(0)

    def forward(self, H, vemb, eemb, semsg=None):
        """Run one round of message passing.

        Args:
            H: sparse incidence matrix of shape (N, M).
            vemb: vertex embeddings.
            eemb: hyperedge embeddings.
            semsg: aggregated hyperedge messages per vertex from the previous
                round; when None, ``vemb`` is used in its place.

        Returns:
            (vemb1, eemb1, semsg): updated vertex embeddings, updated hyperedge
            embeddings and the aggregated hyperedge messages for the next round.
        """
        N,M = H.shape
        d_V = H.sum(1).to_dense() # node deg
        d_E = H.sum(0).to_dense() # edge deg
        # Guard against empty rows/columns (e.g. after incidence dropout): a
        # zero degree is mapped to inf so its inverse square root is exactly 0.
        # Otherwise it becomes inf and turns explicitly stored zeros into NaN.
        d_V[d_V==0] = float('inf')
        d_E[d_E==0] = float('inf')
        D_v_invsqrt = torch.sparse.spdiags(d_V**-0.5,torch.tensor(0),(N,N))
        D_e_invsqrt = torch.sparse.spdiags(d_E**-0.5,torch.tensor(0),(M,M))

        # Identity test: `== None` on a tensor goes through Tensor.__eq__.
        if semsg is None: semsg = vemb
        vmsg = vemb + self.fv(torch.cat((vemb, semsg), 1))
        # adjdrop is applied separately at each aggregation step.
        svmsg = D_e_invsqrt @ self.adjdrop(H).T @ D_v_invsqrt @ vmsg # [num_edge, d_model]
        emsg = svmsg + self.fw(torch.cat((eemb, svmsg), 1))
        semsg = D_v_invsqrt @ self.adjdrop(H) @ D_e_invsqrt @ emsg

        vemb1 = semsg + self.gv(torch.cat((vemb, semsg), 1))
        eemb1 = svmsg + self.gw(torch.cat((eemb, svmsg), 1))
        return vemb1, eemb1, semsg

class HMPNN(nn.Module):
    """Hypergraph message-passing classifier with residual rounds.

    Three MsgPass modules are constructed, but forward() only runs the first
    two (the third round is disabled by the author).
    """

    def __init__(self, in_dim, d_model, out_dim):
        super().__init__()
        self.venc = nn.Linear(in_dim, d_model)
        self.eenc = nn.Linear(in_dim, d_model)
        self.msgpass = MsgPass(d_model)
        self.msgpass2 = MsgPass(d_model)
        self.msgpass3 = MsgPass(d_model)
        self.classifier = nn.Linear(d_model, out_dim)
        # Xavier-normal init for every weight matrix; 1-D parameters keep their defaults.
        weight_matrices = [p for p in self.parameters() if p.dim() > 1]
        for weight in weight_matrices:
            nn.init.xavier_normal_(weight) # xavier_uniform_ xavier_normal_

    def forward(self, H, vemb):
        """Return classifier outputs for every vertex.

        Both the vertex and the hyperedge embeddings are encoded from the same
        input features ``vemb``.
        """
        features = vemb
        vemb = self.venc(features)
        eemb = self.eenc(features)
        semsg = None
        # third round disabled: self.msgpass3 is intentionally left out of this tuple
        for layer in (self.msgpass, self.msgpass2):
            vemb1, eemb1, semsg = layer(H, vemb, eemb, semsg=semsg)
            vemb = vemb + vemb1
            eemb = eemb + eemb1
        return self.classifier(vemb)


# Instantiate the plain (no VICReg head) HMPNN on the globally loaded features X.
num_v,in_dim=X.shape # 2708, 1433
# NOTE(review): hard-coded 7 replaces the num_classes value unpacked from load_data().
num_classes=7
# in_dim = feature width of X, d_model = 16, out_dim = num_classes
model=HMPNN(X.shape[1],16,num_classes)


In [86]:
# @title train/ eval
import torch
import torch.nn.functional as F

def train(model, optimizer, H, X, Y, train_mask):
    """Take one supervised optimisation step and return the loss as a float.

    The cross-entropy is computed only on the rows selected by ``train_mask``.
    """
    model.train()
    optimizer.zero_grad()
    logits = model(H, X)
    step_loss = F.cross_entropy(logits[train_mask], Y[train_mask])
    step_loss.backward()
    optimizer.step()
    return step_loss.item()

def victrain(model, optimizer, H, X):
    """Take one optimisation step on ``model.loss`` and return the loss as a float.

    Both views handed to ``model.loss`` are currently the unaugmented ``H``
    (the trs(H) augmentation is switched off).
    """
    model.train()
    # view_a, view_b = trs(H)
    view_a = view_b = H
    optimizer.zero_grad()
    step_loss = model.loss(view_a, view_b, X)
    step_loss.backward()
    optimizer.step()
    return step_loss.item()

def accuracy(yhat, y):
    """Return the percentage (0-100) of positions where ``yhat`` equals ``y``."""
    n_correct = (yhat == y).type(torch.float).sum().item()
    return 100 * n_correct / y.shape[0]

def evaluate(model, H, X, Y, val_mask, test_mask):
    """Return ``(val_acc, test_acc)`` in percent, with the model in eval mode."""
    model.eval()
    with torch.no_grad():
        predicted = model(H, X).argmax(1) # predicted class per row
    val_acc = accuracy(predicted[val_mask], Y[val_mask])
    test_acc = accuracy(predicted[test_mask], Y[test_mask])
    return val_acc, test_acc



In [None]:
# @title run


# model = HGNN(H, X.shape[1], num_classes) # hg conv
# model = Net(X.shape[1], num_classes) # hg att
# model = HMPNN(num_classes, vembdim, eembdim) # hg msg pass
# model = Network(in_channels=X.shape[1], hidden_channels=8, out_channels=num_classes, n_layers=2, task_level="node")
# model=HMPNN(num_classes, vembdim, eembdim, vmsgdim, emsgdim)
# model=HMPNN(X.shape[1],16,num_classes)

import time
# Wall-clock timing for the whole training run.
start = time.time()

# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2, betas=(0.9, 0.999), eps=1e-08, weight_decay=3e-6) # vicreg1e-4
# coptimizer = torch.optim.AdamW(model.parameters(), lr=3e-2) # vicreg1e-3
# coptimizer = torch.optim.AdamW(model.classifier.parameters(), lr=3e-2) # vicreg1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # 0.001 # og
# optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) # 0.001 # og
# optimizer.param_groups[0]['lr']=0.01

# Each epoch: one supervised step, then 7 steps on model.loss (victrain), then evaluation.
# The same optimizer instance is shared by both kinds of step.
for epoch in range(200):
    loss = train(model, optimizer, H, X, Y, train_mask)
    for _ in range(7):
        # overwrites `loss`, so the value printed below is the last victrain loss, not the supervised one
        loss = victrain(model, optimizer, H, X)
    val_acc, test_acc = evaluate(model, H, X, Y, val_mask, test_mask)

    # loss = ctrain(model, optimizer, H, X, Y, train_mask)
    # val_acc, test_acc = vicevaluate(model, H, X, Y, val_mask, test_mask)
    # NOTE(review): the label says "test loss" but `loss` comes from a training step.
    print(f"{epoch+1} test loss: {loss:.5f}, Val acc: {val_acc:.2f}, Test acc: {test_acc:.2f}")

end = time.time()
print("time: ",end - start)


# dhg
# HGNN 200epoch 14sec test loss: 0.15217, Val acc: 0.79600, Test acc: 0.79600
# HGNN drop0 200 test loss: 0.06214, Val acc: 78.00, Test acc: 79.50

# attn lr=0.01 200epoch 27 sec test loss: 0.02389, Val acc: 0.77400, Test acc: 0.78900
# attn Ponce 200 test loss: 0.04812, Val acc: 77.20, Test acc: 76.70

# hmpnn relu 200 epoch test loss: 0.00001, Val acc: 0.53600, Test acc: 0.51900
# hmpnn sigmoid 200 epoch test loss: 0.00885, Val acc: 0.55000, Test acc: 0.52200
# 2hmpnn sigmoid 200 epoch test loss: 1.94591, Val acc: 0.31600, Test acc: 0.31900  Val acc: 0.11400, Test acc: 0.10300
# 2hmpnn 2lin sigmoid 200 epoch test loss: 1.94591, Val acc: 0.11400, Test acc: 0.10300
# 2hmpnn 2lin sigmoid noadjdrop 200 epoch test loss: 1.94591, Val acc: 0.16200, Test acc: 0.14900 Val acc: 0.05800, Test acc: 0.06400
# 2hmpnn 2lin sigmoid noadjdrop res 200 epoch test loss: 0.01424, Val acc: 0.37400, Test acc: 0.39700
# 2hmpnn 2lin sigmoid noadjdrop nodrop res 200 epoch 5m47s test loss: 0.08539, Val acc: 0.31000, Test acc: 0.36100
# 2hmpnn sigmoid noadjdrop nodrop res 200 epoch 11m6s test loss: 0.00635, Val acc: 0.55200, Test acc: 0.53700
# 3hmpnn sigmoid noadjdrop nodrop res 200 epoch test loss: 0.00638, Val acc: 0.54200, Test acc: 0.52900

# request
# HGNN 200epoch 12sec test loss: 0.15115, Val acc: 0.26200, Test acc: 0.23200
# attn 200epoch test loss: 0.00675, Val acc: 0.25000, Test acc: 0.22800
# hmpnn relu 7 epoch test loss: 0.00779, Val acc: 0.53800, Test acc: 0.47500
# hmpnn relu 200 epoch test loss: 0.00000, Val acc: 0.43400, Test acc: 0.41000
# hmpnn 2lin relu 200 epoch test loss: 0.02837, Val acc: 0.44000, Test acc: 0.38000
# hmpnn sigmoid test loss: 0.00064, Val acc: 0.53000, Test acc: 0.45100
# 2hmpnn sigmoid test loss: 1.81603, Val acc: 0.35000, Test acc: 0.29500
# 2hmpnn relu test loss: 0.00000, Val acc: 0.42800, Test acc: 0.45100
# 2hmpnn 2lin relu 200 epoch test loss: 1.81581, Val acc: 0.35000, Test acc: 0.29500
# 2hmpnn 2lin sigmoid 200 epoch test loss: 1.81585, Val acc: 0.35000, Test acc: 0.29500

# @ 3hpnn 1lin relu 400 test loss: 0.04027, Val acc: 0.28, Test acc: 0.28
# @ 3hpnn 2lin relu 200 test loss: 0.09888, Val acc: 0.34, Test acc: 0.36 400 test loss: 0.00260, Val acc: 0.42, Test acc: 0.40
# @ 3hpnn 2lin sigmoid 200 test loss: 0.45359, Val acc: 31.40, Test acc: 28.40
# @ 2hpnn 2lin sigmoid 200 test loss: 0.20277, Val acc: 32.20, Test acc: 32.00
# @ hpnn 2lin sigmoid 200 test loss: 0.09099, Val acc: 42.20, Test acc: 41.50
# @ hpnn 2lin relu 200 test loss: 0.00447, Val acc: 33.60, Test acc: 38.80
# @ hpnn lin relu 200 test loss: 0.15811, Val acc: 33.80, Test acc: 33.50; 400 test loss: 0.01975, Val acc: 36.40, Test acc: 35.00
# @ hpnn lin sigmoid 200 test loss: 0.47640, Val acc: 33.80, Test acc: 33.70; 400test loss: 0.09173, Val acc: 33.60, Test acc: 34.50
# @ hpnn 2lin relu hdim8 200 test loss: 0.00264, Val acc: 46.80, Test acc: 46.70
# @ 3hpnn 2lin relu hdim8 200 test loss: 0.00477, Val acc: 28.00, Test acc: 27.10; 400 test loss: 0.00045, Val acc: 31.00, Test acc: 29.80
# @ hpnn3 2lin relu hdim8 200 test loss: 0.00310, Val acc: 49.20, Test acc: 50.00
# @ hpnn5 2lin relu hdim8 200 test loss: 0.00337, Val acc: 46.20, Test acc: 46.70

# @ hpnn2 2lin relu hdim2 drop0.5 adjdrop0.7 200 test loss: 0.53219, Val acc: 36.00, Test acc: 38.80 ; test loss: 0.21715, Val acc: 38.20, Test acc: 40.40 ;
# @ hpnn2 2lin relu hdim2 adjdrop 200 test loss: 0.01526, Val acc: 33.20, Test acc: 32.70
# @ hpnn2 2lin relu hdim2 adjdrop drop 200 test loss: 0.36543, Val acc: 30.40, Test acc: 35.20 ; test loss: 0.18227, Val acc: 32.80, Test acc: 36.20
# @ hpnn2 2lin sig hdim2 adjdrop drop 200 test loss: 0.28160, Val acc: 31.60, Test acc: 34.60 ; test loss: 0.10077, Val acc: 33.80, Test acc: 38.00
# @ hpnn2 lin sig hdim2 adjdrop drop 200 test loss: 0.77846, Val acc: 28.20, Test acc: 30.00 ; test loss: 0.17251, Val acc: 32.20, Test acc: 33.20
# @ hpnn2 lin sig hdim16 adjdrop drop 200 test loss: 0.00632, Val acc: 50.00, Test acc: 50.70
# @ hpnn2 lin sig hdim16 drop adjdrop 200 test loss: 0.00569, Val acc: 55.20, Test acc: 57.70
# hpnn2 lin sig hdim16 drop0.5 noadjdrop eemb=eenc(vemb) res12 100 test loss: 0.77957, Val acc: 64.40 starts decreasing, Test acc: 63.20
# hpnn2 lin sig hdim16 drop0.5 adjdrop0.7123 eemb=eenc(vemb) res12 200 test loss: 0.38732, Val acc: 64.00 starts decreasing, Test acc: 62.50
# hpnn2 lin sig hdim16 drop0.5 adjdrop0.7 eemb=eenc(vemb) res12 200 test loss: 0.39054, Val acc: 60.40, Test acc: 57.50
# follow noadjdrop 200 test loss: 0.00675, Val acc: 63.60, Test acc: 62.60
# follow noadjdrop normalise 200 test loss: 0.00582, Val acc: 63.40, Test acc: 62.60
# follow linbiasF noadjdrop normalise 200 test loss: 0.00923, Val acc: 61.80, Test acc: 58.70
# res1 nope
# hpnn3 res123 follow noadjdrop normalise 800? test loss: 0.00241, Val acc: 65.20, Test acc: 63.20
# hpnn2 res12 follow noadjdrop 200 test loss: 0.00748, Val acc: 65.60, Test acc: 65.20
# hpnn2 almostcopy noadjdrop 200 test loss: 0.00376, Val acc: 72.40, Test acc: 74.80

# @ hpnn2 lin aggsig hdim16 drop adjdrop 200 test loss: 0.00634, Val acc: 54.80, Test acc: 56.70
# nores test loss: 0.10421, Val acc: 44.00, Test acc: 42.50
# @ hpnn2 lin sigagg hdim16 drop adjdrop 200 test loss: 0.00952, Val acc: 48.00, Test acc: 47.20

# @ hpnn2 lin sigagg hdim16 drop adjdrop relu vembX 100 test loss: 0.00001, Val acc: 50.00, Test acc: 50.20
# @ hpnn2 lin sigagg hdim16 drop adjdrop sig vembX 100 test loss: 0.00071, Val acc: 52.20, Test acc: 50.40

#
# @ hpnn2 lin ggBDLS ffLBDS hdim16 drop adjdrop345 100 test loss: 0.01240, Val acc: 45.60, Test acc: 44.60
# @ hpnn2 lin ggBDLS ffLBDS hdim16 drop adjdrop 100 test loss: 0.03899, Val acc: 40.60, Test acc: 42.50
# @ hpnn2 lin ggffBLSD hdim16 drop adjdrop 100 test loss: 0.00002, Val acc: 28.80, Test acc: 31.10
# @ hpnn2 lin ggffBLSD hdim16 drop0.1 adjdrop0 100 test loss: 0.00000, Val acc: 30.60, Test acc: 31.30
# @ hpnn2 lin hdim16 drop0.1 adjdrop0 100 test loss: 1.87673, Val acc: 12.80, Test acc: 12.10
# @ hpnn2 lin hdim16 drop0.1 adjdrop0.7 100 test loss: 0.92645, Val acc: 36.60, Test acc: 37.60
# @ hpnn2 lin hdim16 drop0.5 adjdrop0.7 100 test loss: 1.09457, Val acc: 42.00, Test acc: 37.10
# adj train~40

# 100 test loss: 0.00412, Val acc: 36.60, Test acc: 39.30

# 10,10,1 noadjdrop 1000 test loss: 6.19388, Val acc: 24.20, Test acc: 25.90
# 15,15,1 trs 0.3 1000 test loss: 0.00289, Val acc: 43.40, Test acc: 39.30
# 15,15,1 trs 0.5 137 test loss: 0.06801, Val acc: 35.00, Test acc: 41.60
# 15,15,1 trs 0.7 357 test loss: 0.00001, Val acc: 28.60, Test acc: 31.00
# 15,15,1 trs 0.1 154 test loss: 0.08117, Val acc: 34.40, Test acc: 35.60

# eemb=vemb 1000 test loss: 0.07657, Val acc: 50.80, Test acc: 54.90

# mha d_model2 828 test loss: 0.00944, Val acc: 27.20, Test acc: 24.90
# mha d_model2 vicreg 26 test loss: 1.94425, Val acc: 7.40, Test acc: 9.90
# mha d_model16 vicreg exp32 54 test loss: 0.00315, Val acc: 48.40, Test acc: 47.30 in vicreg  0.0 0.06320717930793762 1.4348915815353394
# mha d_model16 96 test loss: 0.03089, Val acc: 48.80, Test acc: 48.70 # 276 test loss: 0.00349, Val acc: 47.60, Test acc: 49.20

# mha d_model16 vicreg10,10,1 exp32 2708, 169, 1433
# mha d_model16 nores2 158 test loss: 0.00906, Val acc: 50.20, Test acc: 50.70
# mha d_model16 nores 219 test loss: 0.00544, Val acc: 49.60, Test acc: 49.60
# mha d_model16 nores embbiasT 112 test loss: 0.02248, Val acc: 50.20, Test acc: 48.70

# mha d_model128 nores 18 test loss: 0.45603, Val acc: 51.80, Test acc: 51.20 37 test loss: 0.00388, Val acc: 50.80, Test acc: 51.40

# mha drop0.5 adam1e-2 101 test loss: 0.02396, Val acc: 51.80, Test acc: 50.00
# mhadrop0.5 178 test loss: 1.94971, Val acc: 31.60, Test acc: 31.90

# 2lin drop0.5 adjdrop0.7123 200 test loss: 1.05092, Val acc: 17.80, Test acc: 20.70
# 2lin drop0.5 adjdrop0.7 200 test loss: 0.95048, Val acc: 33.80, Test acc: 28.60
# 2lin gelu drop0.5 adjdrop0.7 200 test loss: 1.03655, Val acc: 19.60, Test acc: 21.10
# 2lin sig drop0.5 adjdrop0.7 200 test loss: 1.43882, Val acc: 16.40 huge variation, Test acc: 17.10

# lin sig 200 drop0.5 adjdrop0.7123 test loss: 1.86198, Val acc: 56.60, Test acc: 60.90
# better than batchnorm ,adjdrop0.7, 2lin
# 3lin sig nodrop noadjdrop 600? test loss: 0.08037, Val acc: 26.00, Test acc: 27.30

# # 3lin sig nodrop noadjdrop in vicreg  0.08287973701953888 0.6470489501953125 1.871908187866211
# test acc:  100.0
# 591 test loss: 0.00912, Val acc: 15.80, Test acc: 17.20
# 5,5,1in vicreg  0.007232982665300369 0.007199370302259922 0.4727526307106018
# test acc:  100.0
# 100 test loss: 0.00160, Val acc: 20.00, Test acc: 19.40

# lin D_e_inv@emsg 100 test loss: 2.38527, Val acc: 36.80, Test acc: 33.40

# res 200 test loss: 0.00703, Val acc: 57.60, Test acc: 54.30
# resfgvw 200 test loss: 0.00358, Val acc: 72.60, Test acc: 74.80
# resfgvw encbiasT 200 test loss: 0.00470, Val acc: 76.40, Test acc: 76.10
# resfgvw encbiasT nodrop 200 test loss: 0.00243, Val acc: 77.40, Test acc: 75.40
# resfgvw semsg=zeros encbiasT nodrop 200 test loss: 0.00329, Val acc: 76.80, Test acc: 76.50
# resfgvw semsg=zeros encbiasT nodrop adjdrop0.7 nope 50+?
# resfgvw semsg=zeros encbiasT nodrop adjdrop0.712 200 test loss: 0.00292, Val acc: 71.20, Test acc: 74.30
# resfgvw semsg=zeros encbiasT nodrop adjdrop0.312 200 test loss: 0.00302, Val acc: 74.60, Test acc: 77.30
# resfgvw 2lin semsg=zeros encbiasT nodrop noadjdrop 200 test loss: 0.00020, Val acc: 74.20, Test acc: 74.80
# resfgvw hmpnn3 1lin semsg=zeros encbiasT nodrop noadjdrop 200 test loss: 0.00543, Val acc: 76.80, Test acc: 77.70 79at49epochs
# inv 200 test loss: 0.00271, Val acc: 54.00, Test acc: 54.40

# resfgvw ls semsg=vemb 200 test loss: 0.00239, Val acc: 78.00, Test acc: 75.80 again 200 test loss: 0.00183, Val acc: 76.60, Test acc: 77.10
# lbsd 200 test loss: 0.00846, Val acc: 67.00, Test acc: 65.60
# lbds 200 test loss: 0.00780, Val acc: 67.00, Test acc: 68.30
# lds 200 test loss: 0.00686, Val acc: 74.00, Test acc: 73.60
# dls 200 test loss: 0.00442, Val acc: 76.60, Test acc: 77.90
# lsd 200 test loss: 0.00501, Val acc: 76.20, Test acc: 75.00
# bls 200 test loss: 0.00167, Val acc: 67.20, Test acc: 68.10
# lsl 200 test loss: 0.00024, Val acc: 73.40, Test acc: 74.30
# lsl hdim*2 200 test loss: 0.00059, Val acc: 74.20, Test acc: 75.90
# lsls hdim*1 200 test loss: 0.00447, Val acc: 76.80, Test acc: 77.80

# xavier_uniform 200 test loss: 0.00451, Val acc: 74.00, Test acc: 74.10 200 test loss: 0.00697, Val acc: 75.80, Test acc: 77.00 200 test loss: 0.00655, Val acc: 77.20, Test acc: 77.10
# xavier_normal 200 test loss: 0.00638, Val acc: 77.20, Test acc: 77.40 200 test loss: 0.00653, Val acc: 77.60, Test acc: 78.90 200 test loss: 0.00394, Val acc: 75.80, Test acc: 76.10
# xavier_normal semsg=zeros 200 test loss: 0.00778, Val acc: 77.40, Test acc: 77.90 200 test loss: 0.00777, Val acc: 74.20, Test acc: 75.90
# xavier_normal semsg=vemb 200 test loss: 0.00669, Val acc: 78.00, Test acc: 77.30 200 test loss: 0.00502, Val acc: 75.20, Test acc: 75.70
# lr
# lg
# le
# ll-r


# 1vic10,10,1 200 test loss: 0.42472, Val acc: 70.80, Test acc: 74.00
# 1vic5,10,1 200 test loss: 0.40446, Val acc: 74.40, Test acc: 75.50 # in vicreg  0.0 0.018776636570692062 0.3856853246688843
# 7vic5,10,1 trs0.1 153 test loss: 0.00000, Val acc: 75.20, Test acc: 75.40 #in vicreg  0.0 0.0 1.8084348596403288e-07
# 7vic5,10,1 trs0.3 200 test loss: 0.00000, Val acc: 72.20, Test acc: 77.00 # in vicreg  0.0 0.0 1.730809202626915e-07
# 7vic5,10,1 trs0.5 200 test loss: 0.00000, Val acc: 71.40, Test acc: 74.10 # in vicreg  0.0 0.0 5.946303360815364e-08
# hpnn3 7vic5,10,1 trs0.5 200 test loss: 0.00000, Val acc: 78.60, Test acc: 79.90 # in vicreg  0.0 0.0 1.24092821351951e-07 # in vicreg  0.0 0.0 6.852277e-08 200 test loss: 0.00000, Val acc: 74.60, Test acc: 76.10
# hpnn3 7vic5,10,1 trs0.5 exp256 200 test loss: 6.25101, Val acc: 72.00, Test acc: 72.60 # in vicreg  0.0 5.1014814376831055 1.1495264768600464








### trash


## Hypergraph Neural Network (HGNN) Layer

The [HGNN layer](https://arxiv.org/pdf/1809.09401.pdf) is defined as:

$$f(X^{(l)}, H; W^{(l)}) = \sigma(L X^{(l)} W^{(l)})$$

$$L = D_v^{-1/2} H B D_e^{-1} H^\top D_v^{-1/2}$$

where

* $H \in \mathbb{R}^{N \times M}$ is the incidence matrix of the hypergraph with $N$ nodes and $M$ hyperedges.
* $D_v \in \mathbb{R}^{N \times N}$ is a diagonal matrix representing node degrees, whose $i$-th diagonal element is $\sum_{j=1}^M H_{ij}$.
* $D_e \in \mathbb{R}^{M \times M}$ is a diagonal matrix representing hyperedge degrees, whose $j$-th diagonal element is $\sum_{i=1}^N H_{ij}$.
* $B \in \mathbb{R}^{M \times M}$ is a diagonal matrix representing the hyperedge weights, whose $j$-th diagonal element is the weight of the $j$-th hyperedge. In our example, $B$ is the identity matrix.

The following code builds a two-layer HGNN.

In [None]:
# @title test
# https://colab.research.google.com/github/dmlc/dgl/blob/master/notebooks/sparse/hgnn.ipynb
# https://github.com/dmlc/dgl/blob/master/notebooks/sparse/hgnn.ipynb
# https://github.com/dmlc/dgl/blob/master/examples/sparse/hgnn.py
import torch

# Toy incidence data: row 0 holds vertex ids, row 1 holds hyperedge ids ("vert _ is in hyperedge _").
# torch.tensor (rather than the legacy torch.Tensor constructor) keeps the indices int64 instead of float32.
cite=torch.tensor([[0, 1, 2, 2, 2, 2, 3, 4, 5, 5, 5, 5, 6, 7, 7, 8, 8, 9, 9, 10],
                    [0, 0, 0, 1, 3, 4, 2, 1, 0, 2, 3, 4, 2, 1, 3, 1, 3, 2, 4, 4]])
H = torch.sparse_coo_tensor(indices=cite, values=torch.ones(cite.shape[1]),).coalesce()
# uncoalesced tensors may have duplicate coords in the indices; in that case the value at that index is the sum of all duplicate entries
# vert _ is in hyperedge _
print(H.to_dense()) # cols: hyperedges ; rows: verts

# print(H) # indices = [[x1,x2,...], [y1,y2,y3,...]]
# print(H.to_sparse_csr()) # crow_indices=[row1 got ? elements, row2... , ... ] , col_indices= col idx # https://stackoverflow.com/questions/52299420/scipy-csr-matrix-understand-indptr
# print(H.to_sparse_csc()) # ccol_indices = [start count num elements in col], row_indices = row ind
# print(H.to_dense().to_sparse_bsr())
# print(H.to_sparse_bsc())

csr=H.to_sparse_csr()
# csr.crow_indices
# csr.col_indices
# import numpy as np
# ss=np.split(csr.col_indices(), csr.crow_indices())[1:-1]
# Differences of the CSR row pointer are the number of stored entries per row (hyperedges per vertex);
# they are passed to torch.split as plain Python ints.
row_lengths = torch.diff(csr.crow_indices()).tolist()
ss=torch.split(csr.col_indices(), row_lengths)
# ss=torch.split(csr.col_indices(), torch.diff(csr.crow_indices()))
print(ss)

from torch.nn.utils.rnn import pad_sequence
# Pad the ragged per-vertex hyperedge lists to a rectangle; -1 marks padding.
pp=pad_sequence(ss, batch_first=True, padding_value=-1)
print(pp)
mask=pp<0 # True where the entry is padding
print(mask)
# node_degrees = H.sum(1)
# print("Node degrees", node_degrees)
# hyperedge_degrees = H.sum(0)
# print("Hyperedge degrees", hyperedge_degrees.values())


# vmsg=torch.rand(11,2)
# svmsg=torch.stack([torch.sum(vmsg[v.to_dense().to(torch.bool)],0) for v in H.T]) # given e, get all vmsgs then aggregate
# # print(svmsg)



In [None]:
# @title requests data
import requests
url = 'https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz'
# The download call used to be commented out, which left `response` undefined on a fresh run.
response = requests.get(url)
response.raise_for_status() # fail loudly rather than saving an HTTP error page as the archive
with open("cora.tgz", "wb") as archive_file:
    archive_file.write(response.content)

import tarfile # os, sys,
# The archive is extracted under /content while the parsing code reads the relative path "cora/...";
# presumably /content is the notebook's working directory (Colab) — confirm when running elsewhere.
with tarfile.open('cora.tgz', 'r') as tar:
    tar.extractall('/content')

import torch

# Parse cora.content: one tab-separated row per paper = paper id, bag-of-words flags, category name (all str).
with open("cora/cora.content", "r") as content:
    # print(content.read(10000))
    # Drop empty rows explicitly instead of slicing off the last one, so a file
    # without a trailing newline does not lose its final record.
    rlst = [row for row in content.read().split('\n') if row]
pid = [] # paper id
bow = [] # bag of words
cls = [] # classes
# category: Case_Based, Genetic_Algorithms, Neural_Networks, Probabilistic_Methods, Reinforcement_Learning, Rule_Learning, Theory
category = {'Case_Based':0, 'Genetic_Algorithms':1, 'Neural_Networks':2, 'Probabilistic_Methods':3, 'Reinforcement_Learning':4, 'Rule_Learning':5, 'Theory':6} # cora
for r in rlst:
    rr=r.split('\t')
    pid.append(int(rr[0]))
    bow.append(list(map(float, rr[1:-1]))) # must be float
    cls.append(category[rr[-1]])
pid=torch.tensor(pid)
X=torch.tensor(bow)
Y=torch.tensor(cls)
num_classes=len(category) # 7 for cora; derived from the label table instead of hard-coded

# https://stellargraph.readthedocs.io/en/v1.0.0rc1/demos/node-classification/gcn/gcn-cora-node-classification-example.html
# The Cora dataset consists of 2708 scientific publications
# classified into one of seven classes.
# The citation network consists of 5429 links
# Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary.
# The dictionary consists of 1433 unique words

# Parse cora.cites: one tab-separated pair of raw paper ids per row.
with open("cora/cora.cites", "r") as cites: # cite relation
    clst = [row for row in cites.read().split('\n') if row]
cite = [] #
for c in clst:
    cc=c.split('\t')
    cite.append([int(cc[0]),int(cc[1])])
cite=torch.tensor(cite) # [5429]

# Re-index papers by their ROW POSITION in cora.content. X and Y are stored in that
# row order, so vertex i of the graph must be the paper in row i. The previous mapping
# used the rank of the raw id among the sorted ids (torch.unique), which only matches
# the row order if the file happens to be sorted by id.
# Assumes raw paper ids are unique within cora.content.
ukeys = pid.clone()
uvals = torch.arange(len(ukeys))
udict = dict(zip(ukeys.tolist(), uvals.tolist())) # assign new id to each paper
pid = pid.apply_(udict.get)
cite = cite.apply_(udict.get)

num_v = pid.shape[0]
# Off-diagonal entries come from the citation pairs; size=(2708, 2708), nnz=5429, layout=torch.sparse_coo
H = torch.sparse_coo_tensor(
    indices=cite.T,
    values=torch.ones(cite.shape[0]),
    size=(num_v, num_v),
).coalesce()
# Adding the identity gives every vertex its own hyperedge containing itself plus
# its citation partners: [2708, 2708] incidence matrix, |V| hyperedges.
id = torch.sparse.spdiags(torch.ones(H.shape[0]), torch.tensor(0), H.shape)
H = (id + H).coalesce()


# Fixed positional split (cora mask): first 140 train, next 500 validation, last 1000 test.
train_mask, val_mask, test_mask = torch.zeros(3, num_v, dtype=torch.bool).unbind(0)
train_mask[:140] = True
val_mask[140:640] = True
test_mask[-1000:] = True
# print(len(train_mask))
# print(train_mask)
# H, X, Y, num_classes, train_mask, val_mask, test_mask = load_data()

# print(train_mask, val_mask, test_mask)
# print(sum(train_mask), sum(val_mask), sum(test_mask)) # 140), (500), (1000)
# print(sum(test_mask[-1000:]))
# print(len(test_mask)) # 2708
# print(train_mask[140])
# [:140], [140:640], [-1000:]

# print(H.shape, X.shape, Y.shape) # [2708, 2708], [2708, 1433], [2708]

# @title edge/ incidence list

# edic = dict((id, [id]) for id in pid.tolist()) # edge list H(E)={e1,e2,e3}={{A,D},{D,E},{A,B,C}}
# idic = dict((id, [id]) for id in pid.tolist()) # incidence list {A:{e1,e3}, B:{e3}, C:{e3}, D:{e1,e2}, E:{e2}}
# Each list is seeded with its own index (the self entry that the identity adds to H).
# Seeding by position rather than by pid values keeps elst[a] / ilst[b] consistent with
# how they are indexed in the loop below, and avoids shadowing the builtin `id`.
elst = [[i] for i in range(len(pid))] # edge list H(E)={e1,e2,e3}={{A,D},{D,E},{A,B,C}}
ilst = [[i] for i in range(len(pid))] # incidence list {A:{e1,e3}, B:{e3}, C:{e3}, D:{e1,e2}, E:{e2}}
for a,b in cite.tolist():
    elst[a].append(b)
    ilst[b].append(a)
# Left as ragged Python lists: the rows have different lengths, so torch.tensor(elst)
# raises ValueError.
# print(elst)
# print(ilst)


In [None]:
# @title gpt HMPNN
import torch.nn.functional as F

class HMPNNLayer(nn.Module):
    """Experimental message-passing layer (vertex -> hyperedge -> vertex).

    Holds two linear maps (`fc_v` for vertex messages, `fc_w` for hyperedge
    messages) and a BatchNorm1d over the output width.
    NOTE(review): the tensor shapes expected by `forward` are not documented in
    the original; both linear maps take `input_dim` features.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc_v = nn.Linear(input_dim, output_dim)
        self.fc_w = nn.Linear(input_dim, output_dim)
        self.batch_norm = nn.BatchNorm1d(output_dim)

    def forward(self, X_v, W_e, M_v):
        # Vertex -> hyperedge: transform the messages, weight them by W_e, sum over dim 1.
        vertex_msg = F.dropout(F.relu(self.fc_v(M_v)), p=0.5, training=self.training)
        edge_state = (W_e * vertex_msg).sum(dim=1)

        # Hyperedge -> vertex: transform, tile along a new last axis, weight by X_v, sum over dim 1.
        edge_msg = F.dropout(F.relu(self.fc_w(edge_state)), p=0.5, training=self.training)
        tiled = edge_msg.unsqueeze(2).repeat(1, 1, X_v.size(1))
        pooled = (tiled * X_v).sum(dim=1)

        # Normalise the aggregated result.
        return self.batch_norm(pooled)

class HMPNN(nn.Module):
    """Two stacked HMPNNLayer blocks followed by a sigmoid."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.layer1 = HMPNNLayer(input_dim, hidden_dim)
        self.layer2 = HMPNNLayer(hidden_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, X_v, W_e, M_v):
        # The first layer's output becomes the message input of the second layer;
        # X_v and W_e are passed unchanged to both.
        hidden = self.layer1(X_v, W_e, M_v)
        refined = self.layer2(X_v, W_e, hidden)

        # Final activation
        return self.sigmoid(refined)

# Example usage
input_dim = 64  # Adjust based on your input data
hidden_dim = 32
output_dim = 1  # Assuming binary classification (was commented out, leaving output_dim undefined below)
model = HMPNN(input_dim, hidden_dim, output_dim)
# Disabled: the HMPNN defined in this cell takes three arguments, and vembdim/eembdim/vmsgdim/emsgdim
# are not defined in this cell.
# model=HMPNN(num_classes, vembdim, eembdim, vmsgdim, emsgdim)



import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BatchNorm1d, Dropout

class HMPNNLayer(nn.Module):
    """Linear -> BatchNorm1d -> ReLU -> Dropout feature transform."""

    def __init__(self, in_features, out_features, dropout_rate):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.batch_norm = BatchNorm1d(out_features)
        self.dropout = Dropout(p=dropout_rate)

    def forward(self, x):
        normalised = self.batch_norm(self.linear(x))
        return self.dropout(F.relu(normalised))

class HMPNN(nn.Module):
    """Single round of vertex -> hyperedge -> vertex message passing on a dense H.

    `fv` transforms vertex embeddings into messages and `fw` transforms the
    aggregated hyperedge embeddings; the result is added to the input
    embeddings and squashed with a sigmoid.
    """

    def __init__(self, input_size, hidden_size, output_size, dropout_rate_v, dropout_rate_e):
        super().__init__()
        self.fv = HMPNNLayer(input_size, hidden_size, dropout_rate_v)
        self.fw = HMPNNLayer(hidden_size, output_size, dropout_rate_e)

    def forward(self, vemb, H):
        vertex_msg = self.fv(vemb)
        # Dropout on the adjacency is used only for the vertex -> hyperedge aggregation.
        thinned_H = F.dropout(H, p=0.7, training=self.training)
        edge_msg = self.fw(thinned_H @ vertex_msg)
        # The hyperedge -> vertex aggregation uses the full (transposed) H.
        updated = vemb + H.T @ edge_msg
        return F.sigmoid(updated)

# Example Usage:
input_size = 64  # Input feature size for vertices
hidden_size = 32  # Hidden layer size
output_size = 1  # Output size (for binary classification, for example)
dropout_rate_v = 0.5  # Dropout rate for vertices
dropout_rate_e = 0.5  # Dropout rate for hyperedges
batch_size = 8  # Number of rows in the dummy data (was undefined); >1 so BatchNorm1d works in training mode

# Instantiate the HMPNN model
hmpnn_model = HMPNN(input_size, hidden_size, output_size, dropout_rate_v, dropout_rate_e)

# Dummy data
vemb = torch.randn((batch_size, input_size))
hyperedge_adjacency_matrix = torch.randn((batch_size, batch_size))

# Forward pass
output = hmpnn_model(vemb, hyperedge_adjacency_matrix)




In [None]:
# @title pyt-team/TopoModelX data
!pip install torch_geometric

import torch
import torch_geometric.datasets as geom_datasets
from sklearn.metrics import accuracy_score
import numpy as np


# Seed both RNGs used in this cell.
torch.manual_seed(0)
np.random.seed(0)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Take the first (index 0) item of the Planetoid "cora" dataset, stored under tmp/.
dataset = geom_datasets.Planetoid(root="tmp/", name="cora")[0]

# Sparse matrix with a 1 at every edge_index coordinate, stored as int64.
# NOTE(review): this is built before `dataset` is moved to `device` and is not itself
# moved in this cell — confirm it ends up on the same device as x_0s / y.
incidence_1 = torch.sparse_coo_tensor(dataset["edge_index"], torch.ones(dataset["edge_index"].shape[1]), dtype=torch.long)
dataset = dataset.to(device)

# Node features and labels, read after the device move.
x_0s = dataset["x"]
y = dataset["y"]
# print(incidence_1.shape, x_0s.shape, y.shape) # [2708, 2708], [2708, 1433], [2708]


In [None]:
# @title pyt-team/TopoModelX hmpnn
# https://arxiv.org/pdf/2203.16995.pdf
# https://github.com/pyt-team/TopoModelX/tree/main/topomodelx/nn/hypergraph
# https://github.com/pyt-team/TopoModelX/blob/main/topomodelx/nn/hypergraph/hmpnn.py
# https://github.com/pyt-team/TopoModelX/blob/main/tutorials/hypergraph/hmpnn_train.ipynb
import torch
from torch import nn
from torch.nn import functional as F

# https://github.com/pyt-team/TopoModelX/blob/main/topomodelx/utils/scatter.py

def broadcast(src, other, dim):
    """Expand `src` to the shape of `other`.

    A 1-D `src` is first aligned with axis `dim` of `other` (negative `dim`
    counts from the end) by prepending singleton axes; singleton axes are then
    appended until the ranks match, and the result is expanded to
    ``other.size()``.
    """
    axis = other.dim() + dim if dim < 0 else dim
    if src.dim() == 1:
        # Indexing with `axis` Nones prepends that many singleton axes.
        src = src[(None,) * axis]
    while src.dim() < other.dim():
        src = src.unsqueeze(-1)
    return src.expand(other.size())


def scatter_sum(src, index, dim = -1, out = None, dim_size = None,):
    """Sum the entries of `src` into `out` at the positions given by `index`.

    `index` is first broadcast to the shape of `src`. When `out` is None a
    zero tensor is allocated with the shape of `src`, except along `dim`,
    whose length is `dim_size` if given, 0 for an empty index, and
    ``index.max() + 1`` otherwise. The accumulation is done in place on `out`,
    which is also returned.
    """
    index = broadcast(index, src, dim)
    if out is None:
        if dim_size is not None:
            length = dim_size
        elif index.numel() == 0:
            length = 0
        else:
            length = int(index.max()) + 1
        out_shape = list(src.size())
        out_shape[dim] = length
        out = torch.zeros(out_shape, dtype=src.dtype, device=src.device)
    return out.scatter_add_(dim, index, src)


def scatter_add(src, index, dim = -1, out = None, dim_size = None,):
    """Alias of `scatter_sum`: forwards all arguments unchanged."""
    return scatter_sum(src, index, dim=dim, out=out, dim_size=dim_size)

def scatter_mean(src, index, dim = -1, out = None, dim_size = None,):
    """Average the entries of `src` that share the same `index` along `dim`.

    The sums are computed with `scatter_sum`, then divided in place by the
    number of contributions per slot. Slots that received nothing use a
    divisor of 1. Integer outputs use floor division.
    """
    out = scatter_sum(src, index, dim, out, dim_size)
    n_slots = out.size(dim)

    # Axis of `index` along which to count: `dim` made non-negative, clamped to
    # the last axis of `index` when `index` has fewer dimensions.
    count_dim = dim + src.dim() if dim < 0 else dim
    count_dim = min(count_dim, index.dim() - 1)

    unit = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    counts = scatter_sum(unit, index, count_dim, None, n_slots)
    counts.masked_fill_(counts < 1, 1)  # avoid dividing by zero for empty slots
    counts = broadcast(counts, out, dim)

    if out.is_floating_point():
        out.true_divide_(counts)
    else:
        out.div_(counts, rounding_mode="floor")
    return out


# Aggregation name -> implementation; "add" is a synonym for "sum".
SCATTER_DICT = {"sum": scatter_sum, "mean": scatter_mean, "add": scatter_sum}


def scatter(scatter: str):
    """Return the aggregation function registered under the name `scatter`.

    Raises:
        ValueError: if `scatter` is not a string key of SCATTER_DICT.
    """
    known = isinstance(scatter, str) and scatter in SCATTER_DICT
    if not known:
        raise ValueError(f"scatter must be callable or string: {list(SCATTER_DICT.keys())}")
    return SCATTER_DICT[scatter]


import math

# https://github.com/pyt-team/TopoModelX/blob/main/topomodelx/base/message_passing.py
class MessagePassing(torch.nn.Module):
    """Base class for sparse-neighbourhood message passing.

    `forward` builds one message per stored entry of the neighbourhood matrix,
    scales it by that entry's value (optionally times an attention
    coefficient) and aggregates per target index with the function named by
    `aggr_func`. Subclasses are expected to provide `weight` (and `att_weight`
    when `att` is set) before calling `reset_parameters`.
    """

    def __init__(self, aggr_func = "sum", att = False, initialization = "xavier_uniform", initialization_gain = 1.414,):
        # aggr_func: one of "sum", "mean", "add"
        # initialization: one of "uniform", "xavier_uniform", "xavier_normal"
        super().__init__()
        self.aggr_func = aggr_func
        self.att = att
        self.initialization = initialization
        self.initialization_gain = initialization_gain

    def reset_parameters(self):
        """Re-initialise `weight` (and `att_weight` if attention is on) in place."""
        scheme = self.initialization
        if scheme == "uniform":
            if self.weight is not None:
                bound = 1.0 / math.sqrt(self.weight.size(1))
                self.weight.data.uniform_(-bound, bound)
            if self.att:
                bound = 1.0 / math.sqrt(self.att_weight.size(1))
                self.att_weight.data.uniform_(-bound, bound)
            return

        if scheme == "xavier_uniform":
            xavier_init = torch.nn.init.xavier_uniform_
        elif scheme == "xavier_normal":
            xavier_init = torch.nn.init.xavier_normal_
        else:
            raise ValueError(f"Initialization {self.initialization} not recognized.")

        if self.weight is not None:
            xavier_init(self.weight, gain=self.initialization_gain)
        if self.att:
            # initialised through a column-vector view of the attention weights
            xavier_init(self.att_weight.view(-1, 1), gain=self.initialization_gain)

    def message(self, x_source, x_target=None):
        """Identity message: the source features are passed through unchanged."""
        return x_source

    def attention(self, x_source, x_target=None):
        """Return one ELU-activated attention score per (source, target) pair."""
        target_features = x_source if x_target is None else x_target
        paired = torch.cat(
            [x_source[self.source_index_j], target_features[self.target_index_i]],
            dim=1,
        )
        scores = torch.matmul(paired, self.att_weight)
        return torch.nn.functional.elu(scores)

    def aggregate(self, x_message):
        """Reduce per-entry messages onto their target indices along dim 0."""
        reducer = scatter(self.aggr_func)
        return reducer(x_message, self.target_index_i, 0)

    def forward(self, x_source, neighborhood, x_target=None):
        neighborhood = neighborhood.coalesce()
        # Row indices are the targets, column indices the sources.
        self.target_index_i, self.source_index_j = neighborhood.indices()
        weights = neighborhood.values()

        per_entry = self.message(x_source=x_source, x_target=x_target)
        per_entry = per_entry.index_select(-2, self.source_index_j)

        if self.att:
            weights = torch.multiply(weights, self.attention(x_source=x_source, x_target=x_target))

        return self.aggregate(weights.view(-1, 1) * per_entry)



class _AdjacencyDropoutMixin:
    """Mixin that applies dropout to the stored values of a sparse neighbourhood matrix."""

    def apply_dropout(self, neighborhood, dropout_rate):
        """Return a coalesced copy of `neighborhood` with dropout applied to its values.

        Dropout is only active while the host module is in training mode.
        Previously `F.dropout` was called with its default ``training=True``,
        so entries were also dropped at evaluation time. If the host object has
        no `training` attribute, dropout stays on, as before.

        Args:
            neighborhood: sparse COO tensor; indices and size are kept.
            dropout_rate: drop probability passed to `F.dropout`.
        """
        neighborhood = neighborhood.coalesce()
        is_training = getattr(self, "training", True)
        dropped_values = F.dropout(
            neighborhood.values().to(torch.float), dropout_rate, training=is_training
        )
        return torch.sparse_coo_tensor(neighborhood.indices(), dropped_values, neighborhood.size(),).coalesce()


class _NodeToHyperedgeMessenger(MessagePassing, _AdjacencyDropoutMixin):
    """Computes node messages and aggregates them onto hyperedges.

    `messaging_func` maps node features to messages; `adjacency_dropout` is
    the dropout rate applied to the incidence values before aggregation.
    """

    def __init__(self, messaging_func, adjacency_dropout = 0.7, aggr_func = "sum",):
        super().__init__(aggr_func)
        self.messaging_func = messaging_func
        self.adjacency_dropout = adjacency_dropout

    def message(self, x_source):
        """Apply the messaging function to the node features."""
        return self.messaging_func(x_source)

    def forward(self, x_source, neighborhood):
        """Return ``(messages aggregated per hyperedge, raw node messages)``."""
        thinned = self.apply_dropout(neighborhood, self.adjacency_dropout)
        # Rows of the incidence matrix index the source nodes, columns the target hyperedges.
        node_index, self.target_index_i = thinned.indices()
        x_message = self.message(x_source)
        gathered = x_message.index_select(-2, node_index)
        return self.aggregate(gathered), x_message


class _HyperedgeToNodeMessenger(MessagePassing, _AdjacencyDropoutMixin):
    """Computes hyperedge messages and aggregates them back onto nodes.

    `messaging_func` is called with the hyperedge features and the node
    messages aggregated per hyperedge; `adjacency_dropout` is the rate passed
    to `apply_dropout`.
    """
    def __init__(self, messaging_func, adjacency_dropout = 0.7, aggr_func = "sum",):
        super().__init__(aggr_func)
        self.messaging_func = messaging_func
        self.adjacency_dropout = adjacency_dropout

    def message(self, x_source, neighborhood, node_messages):
        """Build one message per hyperedge from its features and its nodes' messages."""
        # First dropout draw: used only to aggregate node messages per hyperedge.
        hyperedge_neighborhood = self.apply_dropout(neighborhood, self.adjacency_dropout)
        # Row indices select the node messages, column indices are the hyperedges they are reduced onto.
        source_index_j, target_index_i = hyperedge_neighborhood.indices()
        node_messages_aggregated = scatter(self.aggr_func)(node_messages.index_select(-2, source_index_j), target_index_i, 0)
        return self.messaging_func(x_source, node_messages_aggregated)

    def forward(self, x_source, neighborhood, node_messages):
        """Return the hyperedge messages aggregated per node."""
        x_message = self.message(x_source, neighborhood, node_messages)
        # Second, separate dropout call: used for the hyperedge -> node aggregation.
        neighborhood = self.apply_dropout(neighborhood, self.adjacency_dropout)
        # Here rows are the target nodes and columns the source hyperedges.
        self.target_index_i, source_index_j = neighborhood.indices()
        x_message_aggregated = self.aggregate(x_message.index_select(-2, source_index_j))
        return x_message_aggregated

class _DefaultHyperedgeToNodeMessagingFunc(nn.Module):
    """Default hyperedge message: sigmoid of a linear map over concatenated inputs.

    The two inputs are concatenated along dim 1 and projected from
    ``2 * in_channels`` back to ``in_channels``.
    """

    def __init__(self, in_channels):
        # These two statements were on one line, which is a SyntaxError.
        super().__init__()
        self.linear = nn.Linear(2 * in_channels, in_channels)

    def forward(self, x_1, m_0):
        """Return ``sigmoid(linear(cat(x_1, m_0)))``."""
        return F.sigmoid(self.linear(torch.cat((x_1, m_0), dim=1)))

class _DefaultUpdatingFunc(nn.Module):
    """Default update rule: sigmoid of features plus aggregated messages.

    `in_channels` is accepted for interface compatibility but not used; the
    module has no parameters.
    """

    def __init__(self, in_channels):
        super().__init__()

    def forward(self, x, m):
        combined = x + m
        return F.sigmoid(combined)


class HMPNNLayer(nn.Module):
    """One hypergraph message-passing layer (node <-> hyperedge exchange).

    Parameters
    ----------
    in_channels : int
        Feature width of both node and hyperedge representations.
    node_to_hyperedge_messaging_func : callable, optional
        Defaults to ``Linear(in_channels, in_channels)`` followed by a sigmoid.
    hyperedge_to_node_messaging_func : callable, optional
        Defaults to ``_DefaultHyperedgeToNodeMessagingFunc(in_channels)``.
    adjacency_dropout : float
        Forwarded to both messengers.
    aggr_func : str
        Forwarded to both messengers.
    updating_dropout : float
        Probability passed to the Bernoulli distribution that samples the
        feature mask in :meth:`apply_regular_dropout`.
    updating_func : callable, optional
        Defaults to ``_DefaultUpdatingFunc(in_channels)``.
    """

    def __init__(
        self,
        in_channels,
        node_to_hyperedge_messaging_func=None,
        hyperedge_to_node_messaging_func=None,
        adjacency_dropout=0.7,
        aggr_func="sum",
        updating_dropout=0.5,
        updating_func=None,
    ):
        super().__init__()
        if node_to_hyperedge_messaging_func is None:
            node_to_hyperedge_messaging_func = nn.Sequential(nn.Linear(in_channels, in_channels), nn.Sigmoid())
        self.node_to_hyperedge_messenger = _NodeToHyperedgeMessenger(node_to_hyperedge_messaging_func, adjacency_dropout, aggr_func)
        if hyperedge_to_node_messaging_func is None:
            hyperedge_to_node_messaging_func = _DefaultHyperedgeToNodeMessagingFunc(in_channels)
        self.hyperedge_to_node_messenger = _HyperedgeToNodeMessenger(hyperedge_to_node_messaging_func, adjacency_dropout, aggr_func)
        self.node_batchnorm = nn.BatchNorm1d(in_channels)
        self.hyperedge_batchnorm = nn.BatchNorm1d(in_channels)
        # NOTE(review): a sampled 1 means "keep", so ``updating_dropout`` is
        # effectively the keep probability. Identical to a drop rate only at
        # the default 0.5 -- confirm the intended meaning.
        self.dropout = torch.distributions.Bernoulli(updating_dropout)

        if updating_func is None:
            updating_func = _DefaultUpdatingFunc(in_channels)
        self.updating_func = updating_func

    def apply_regular_dropout(self, x):
        """Mask features and rescale the survivors by ``(d + k) / d``.

        ``d`` is the number of features per row and ``k`` the number of
        masked features in that row, so the factor equals
        ``(2 * d - kept) / d``. In eval mode ``x`` is returned unchanged.

        ``x`` is expected to be 2-D ``(rows, features)``; the input tensor is
        not modified.
        """
        if self.training:
            mask = self.dropout.sample(x.shape).to(dtype=torch.float, device=x.device)
            # ``mask.sum(dim=1)`` counts kept features per row, so ``d`` must
            # be the feature count (dim 1), not the row count (dim 0).
            d = x.size(1)
            scale = (2 * d - mask.sum(dim=1, keepdim=True)) / d
            x = x * mask * scale
        return x

    def forward(self, x_0, x_1, incidence_1):
        """Run one round of message passing.

        Returns the updated ``(x_0, x_1)`` pair (node features, hyperedge
        features).
        """
        node_messages_aggregated, node_messages = self.node_to_hyperedge_messenger(x_0, incidence_1)
        hyperedge_messages_aggregated = self.hyperedge_to_node_messenger(x_1, incidence_1, node_messages)
        x_0 = self.updating_func(self.apply_regular_dropout(self.node_batchnorm(x_0)), hyperedge_messages_aggregated)
        x_1 = self.updating_func(self.apply_regular_dropout(self.hyperedge_batchnorm(x_1)), node_messages_aggregated)
        return x_0, x_1


class HMPNN(torch.nn.Module):
    """Stack of ``HMPNNLayer`` blocks behind linear input projections.

    Node and hyperedge inputs are each projected from ``in_channels`` to
    ``hidden_channels`` and then passed through ``n_layers`` layers.
    """

    def __init__(self, in_channels, hidden_channels, n_layers=2, adjacency_dropout_rate=0.7, regular_dropout_rate=0.5,):
        super().__init__()
        self.linear_node = torch.nn.Linear(in_channels, hidden_channels)
        self.linear_edge = torch.nn.Linear(in_channels, hidden_channels)
        stack = []
        for _ in range(n_layers):
            stack.append(
                HMPNNLayer(
                    hidden_channels,
                    adjacency_dropout=adjacency_dropout_rate,
                    updating_dropout=regular_dropout_rate,
                )
            )
        self.layers = torch.nn.ModuleList(stack)

    def forward(self, x_0, x_1, incidence_1):
        """Return the final ``(node_features, hyperedge_features)`` pair."""
        node_feats = self.linear_node(x_0)
        edge_feats = self.linear_edge(x_1)
        for block_layer in self.layers:
            node_feats, edge_feats = block_layer(node_feats, edge_feats, incidence_1)
        return node_feats, edge_feats


class Network(torch.nn.Module):
    """HMPNN backbone followed by a linear readout.

    ``task_level`` is either ``"graph"`` (node features are max-pooled over
    dim 0 before the readout) or anything else, e.g. ``"node"`` (the readout
    is applied to every node). Extra keyword arguments go to ``HMPNN``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, task_level="graph", **kwargs):
        super().__init__()
        self.base_model = HMPNN(in_channels=in_channels, hidden_channels=hidden_channels, **kwargs)
        self.linear = torch.nn.Linear(hidden_channels, out_channels)
        self.out_pool = task_level == "graph"

    def forward(self, incidence_1, x_0):
        """Return readout logits; hyperedge features start as zeros shaped like ``x_0``."""
        x_1 = torch.zeros_like(x_0)
        x_0, x_1 = self.base_model(x_0, x_1, incidence_1)
        readout_input = torch.max(x_0, dim=0)[0] if self.out_pool is True else x_0
        return self.linear(readout_input)

# Base model hyperparameters: the input width follows the feature matrix.
in_channels = x_0s.size(1)
hidden_channels, n_layers = 128, 1

# Readout hyperparameters: one output per distinct label value.
out_channels = torch.unique(y).size(0)
task_level = "node" if out_channels != 1 else "graph"

# model = Network(in_channels=in_channels, hidden_channels=hidden_channels, out_channels=out_channels, n_layers=n_layers, task_level=task_level,).to(device)
# print(in_channels, hidden_channels, out_channels, n_layers, task_level) # 1433, 128, 7, 1, node



In [None]:
# @title pyt-team/TopoModelX run

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

train_mask = dataset["train_mask"]
val_mask = dataset["val_mask"]
test_mask = dataset["test_mask"]

torch.manual_seed(0)

# Not consumed by the loop below (the model is called with the incidence
# matrix and node features only); kept because it is a module-level name.
initial_x_1 = torch.zeros_like(x_0s)
for epoch in range(2):
    # ---- optimisation step ----
    model.train()
    optimizer.zero_grad()
    y_hat = model(incidence_1, x_0s)
    loss = loss_fn(y_hat[train_mask], y[train_mask])
    loss.backward()
    optimizer.step()

    train_loss = loss.item()
    y_pred = y_hat.argmax(dim=-1)
    train_acc = accuracy_score(y[train_mask].cpu(), y_pred[train_mask].cpu())

    # ---- evaluation: no autograd graph is needed ----
    model.eval()
    with torch.no_grad():
        y_hat = model(incidence_1, x_0s)
        val_loss = loss_fn(y_hat[val_mask], y[val_mask]).item()
        test_loss = loss_fn(y_hat[test_mask], y[test_mask]).item()
        y_pred = y_hat.argmax(dim=-1)
    test_acc = accuracy_score(y[test_mask].cpu(), y_pred[test_mask].cpu())
    print(f"{epoch + 1} train loss: {train_loss:.4f} test loss: {test_loss:.4f} test acc: {test_acc:.2f}") # val loss: {val_loss:.4f} val acc: {val_acc:.2f}


In [None]:
# @title HMPNN me H
# https://arxiv.org/pdf/2203.16995.pdf
import torch
import torch.nn as nn
import torch.nn.functional as F
# Vert msg = fv(vert ebd) , Sum edge msgs
# Edge msg = fw(edge emb, Sum Vert msgs)
# Vert emb1 = gv(vert emb, Sum edge msgs)
# Edge emb1 = gw(edge emb, Sum Vert msgs)

class MsgPass(nn.Module):
    """One HMPNN message-passing step driven by an incidence matrix ``H``.

    Vert msg  = sigmoid(fv(vert emb))
    Edge msg  = sigmoid(fw(edge emb, sum of vert msgs))
    Vert emb1 = gv(vert emb, sum of edge msgs)
    Edge emb1 = gw(edge emb, sum of vert msgs)

    Parameters
    ----------
    vembdim, eembdim : int
        Widths of the vertex / hyperedge embeddings.
    vmsgdim, emsgdim : int
        Widths of the vertex / hyperedge messages.
    """

    def __init__(self, vembdim, eembdim, vmsgdim, emsgdim):
        super(MsgPass, self).__init__()
        self.h_dim = 16  # not used by the single-layer maps below
        # Each map stays wrapped in nn.Sequential so parameter names keep
        # the ``<name>.0.*`` form.
        self.fv = nn.Sequential(nn.Linear(vembdim, vmsgdim))
        self.fw = nn.Sequential(nn.Linear(eembdim + vmsgdim, emsgdim))
        self.gv = nn.Sequential(nn.Linear(vembdim + emsgdim, vembdim))
        self.gw = nn.Sequential(nn.Linear(eembdim + vmsgdim, eembdim))
        self.drop = nn.Dropout(0.5)
        self.adjdrop = AdjDropout(0.7)
        self.sig = nn.Sigmoid()

    def forward(self, H, vemb, eemb, emsg=None):
        """Return ``(vemb1, eemb1, emsg)`` for one message-passing round.

        ``H`` is indexed (vertex, hyperedge): ``H.T @ x`` sums vertex rows
        per hyperedge and ``H @ x`` sums hyperedge rows per vertex. The
        incoming ``emsg`` argument is accepted for call compatibility but
        is not read.
        """
        vemb, eemb = self.drop(vemb), self.drop(eemb)
        H = self.adjdrop(H)
        HT = H.T

        vmsg = self.sig(self.fv(vemb))
        svmsg = HT @ vmsg  # sum-aggregate vertex messages per hyperedge

        emsg = self.sig(self.fw(torch.cat((eemb, svmsg), 1)))
        semsg = H @ emsg  # sum-aggregate hyperedge messages per vertex

        vemb1 = self.gv(torch.cat((vemb, semsg), 1))
        # Reuse the aggregate computed above: recomputing it used to push
        # ``vmsg`` through the sigmoid a second time, so the edge update saw
        # a different "sum of vert msgs" than the edge message did.
        eemb1 = self.gw(torch.cat((eemb, svmsg), 1))
        return vemb1, eemb1, emsg

class AdjDropout(nn.Module):
    """Adjacency dropout: zero out whole hyperedge columns of ``H``.

    Parameters
    ----------
    p : float
        Probability of dropping each hyperedge (column), in ``[0, 1]``.
    """

    def __init__(self, p=0.7):
        super(AdjDropout, self).__init__()
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"dropout probability has to be between 0 and 1, but got {p}")
        self.p = p  # forward() previously read an undefined name ``p``

    def forward(self, H):
        """Return ``H`` with randomly chosen columns zeroed (training only).

        In eval mode ``H`` is returned unchanged.
        """
        if not self.training:
            return H
        # Sizes come from H itself rather than from notebook globals.
        n_v, n_e = H.shape
        # One draw per hyperedge, broadcast down every vertex row. 1 -> keep.
        mask = (torch.rand(n_e, device=H.device) >= self.p).float().expand(n_v, n_e)
        return H * mask

class HMPNN(nn.Module):
    """Two residual ``MsgPass`` rounds followed by a linear classifier.

    Parameters
    ----------
    outdim : int
        Number of output classes.
    vembdim, eembdim : int
        Widths of the vertex / hyperedge embeddings.
    vmsgdim, emsgdim : int
        Widths of the vertex / hyperedge messages.
    indim : int, optional
        Width of the raw vertex features. When omitted it falls back to the
        notebook-level feature matrix ``X``, which the constructor already
        depended on.
    """

    def __init__(self, outdim, vembdim, eembdim, vmsgdim, emsgdim, indim=None):
        super(HMPNN, self).__init__()
        if indim is None:
            indim = X.size(-1)
        # Project raw features to the embedding width so every MsgPass
        # round, the residual sums and the classifier agree on ``vembdim``.
        self.ve = nn.Linear(indim, vembdim, bias=False)
        self.msgpass = MsgPass(vembdim, eembdim, vmsgdim, emsgdim)
        self.msgpass2 = MsgPass(vembdim, eembdim, vmsgdim, emsgdim)
        self.msgpass3 = MsgPass(vembdim, eembdim, vmsgdim, emsgdim)  # not used by forward()
        self.lin = nn.Linear(vembdim, outdim)
        self.eemb = None
        self.eembdim = eembdim

    def forward(self, H, X):
        """Return class logits, one row per vertex.

        ``H`` is the (vertex, hyperedge) incidence matrix and ``X`` the raw
        vertex features. Hyperedge embeddings start at zero.
        """
        vemb = self.ve(X)  # ``vemb`` was previously never assigned
        # One embedding row per hyperedge, sized from H instead of a global.
        eemb = torch.zeros(H.shape[1], self.eembdim, dtype=vemb.dtype, device=vemb.device)
        vemb1, eemb1, emsg = self.msgpass(H, vemb, eemb)
        vemb, eemb = vemb + vemb1, eemb + eemb1
        vemb1, eemb1, emsg = self.msgpass2(H, vemb, eemb, emsg=emsg)
        vemb, eemb = vemb + vemb1, eemb + eemb1
        return self.lin(vemb)

# Smoke test: build the model and run one forward pass.
num_v, vdim = X.shape  # e.g. 2708, 1433
vembdim = eembdim = vmsgdim = emsgdim = 16

num_classes = 7
model = HMPNN(num_classes, vembdim, eembdim, vmsgdim, emsgdim)
yhat = model(H, X)
print(yhat.shape)  # [2708, 7]

# Implementation Details Our model uses two layers of HMPNN with sigmoid
# activation and a hidden representation of size 2. We use sum as the message
# aggregation functions, with adjacency matrix dropout with rate 0.7, as well as
# dropout with rate 0.5 for vertex and hyperedge representation.

# print(len(X[0]))

torch.Size([2708, 7])


In [None]:
# @title HMPNN elst, ilst
import torch
import torch.nn as nn
import torch.nn.functional as F
# Vert msg = fv(vert emb)
# Edge msg = fw(edge emb, Sum Vert msgs)
# Vert emb1 = gv(vert emb, Sum edge msgs)
# Edge emb1 = gw(edge emb, Sum Vert msgs)

class MsgPass(nn.Module):
    """One HMPNN message-passing step driven by edge / incidence lists.

    Vert msg  = fv(vert emb)
    Edge msg  = fw(edge emb, sum of vert msgs)
    Vert emb1 = gv(vert emb, sum of edge msgs)
    Edge emb1 = gw(edge emb, sum of vert msgs)

    Every map is a single linear layer followed by a sigmoid.
    """

    def __init__(self, vembdim, eembdim, vmsgdim, emsgdim):
        super(MsgPass, self).__init__()
        self.h_dim = 4  # not used by the single-layer maps below
        self.fv = nn.Sequential(nn.Linear(vembdim, vmsgdim), nn.Sigmoid())
        self.fw = nn.Sequential(nn.Linear(eembdim + vmsgdim, emsgdim), nn.Sigmoid())
        self.gv = nn.Sequential(nn.Linear(vembdim + emsgdim, vembdim), nn.Sigmoid())
        self.gw = nn.Sequential(nn.Linear(eembdim + vmsgdim, eembdim), nn.Sigmoid())
        self.vmsgdim = vmsgdim
        self.emsgdim = emsgdim

    @staticmethod
    def _sum_groups(messages, groups):
        """Stack, for each index list in ``groups``, the sum of those rows."""
        return torch.stack([torch.sum(messages[members], 0) for members in groups])

    def forward(self, vemb, eemb, elst, ilst):
        """Return ``(vemb1, eemb1)`` for one message-passing round.

        ``elst[e]`` lists the vertices in hyperedge ``e`` and ``ilst[v]``
        the hyperedges containing vertex ``v``.
        """
        vmsg = self.fv(vemb)

        # Given e, sum the messages of its vertices. The same aggregate
        # feeds both the edge message and the edge update, so it is
        # computed once.
        svmsg = self._sum_groups(vmsg, elst)
        emsg = self.fw(torch.cat((eemb, svmsg), 1))

        # Given v, sum the messages of its hyperedges. Sum rather than mean:
        # an isolated vertex has no hyperedges, and a mean would divide by 0.
        semsg = self._sum_groups(emsg, ilst)
        vemb1 = self.gv(torch.cat((vemb, semsg), 1))

        eemb1 = self.gw(torch.cat((eemb, svmsg), 1))
        return vemb1, eemb1

# Vert msg = fv(vert ebd)
# Edge msg = fw(edge emb, Sum Vert msgs)
# Vert emb1 = gv(vert emb, Sum edge msgs)
# Edge emb1 = gw(edge emb, Sum Vert msgs)

    # def forward(self, vemb, eemb, elst, ilst):
    #     vmsg = self.fv(vemb)
    #     return vmsg

    # def forward(self, vemb, eemb, elst, ilst):
    #     svmsg=torch.stack([torch.sum(vmsg[v],0) for v in elst]) # given e, get all vmsgs then aggregate
    #     eemb1 = self.gw(torch.cat((eemb, svmsg), 1))

    #     svmsg=torch.stack([torch.sum(vmsg[v],0) for v in elst]) # given e, get all vmsgs then aggregate
    #     emsg = self.fw(torch.cat((meemb, svmsg), 1))

    #     semsg=torch.stack([torch.sum(emsg[e],0) for e in ilst]) # given v, get all emsgs then aggregate # cannot be mean bec vert in ilst may be isolated, divide by 0 hyperedges
    #     vemb1 = self.gv(torch.cat((vemb, semsg), 1))

    #     return vemb1, eemb1

def ilst_from_elst(elst, n_v=None):
    """Generate the incidence list from an edge list.

    Parameters
    ----------
    elst : list[list[int]]
        ``elst[e]`` lists the vertices belonging to hyperedge ``e``.
    n_v : int, optional
        Number of vertices. When omitted, the length of the module-level
        ``ilst`` is used, looked up at call time.

    Returns
    -------
    list[list[int]]
        ``result[v]`` lists the hyperedges that contain vertex ``v``;
        vertices in no hyperedge get an empty list.
    """
    if n_v is None:
        n_v = len(ilst)  # module-level incidence list
    incidence = [[] for _ in range(n_v)]
    for e, vs in enumerate(elst):
        for v in vs:
            incidence[v].append(e)
    return incidence

def adjdrop(vemb, eemb, elst, ilst, p=0.7):
    """Adjacency dropout on list-form hypergraphs.

    Each hyperedge is kept when its uniform draw is ``>= p``. Returns
    ``(vemb, kept_edge_embeddings, kept_edge_list, rebuilt_incidence_list)``;
    ``vemb`` is passed through untouched and the incoming ``ilst`` is not read.
    Could perhaps be replaced with slicing of sparse tensors if PyTorch
    implements it.
    """
    keep = torch.rand(len(elst)) >= p  # True -> keep
    surviving = []
    for edge, flag in zip(elst, keep):
        if flag:
            surviving.append(edge)
    kept_eemb = eemb[keep == True]
    return vemb, kept_eemb, surviving, ilst_from_elst(surviving)


class HMPNN(nn.Module):
    """Three residual rounds of one shared ``MsgPass`` plus a linear classifier.

    Parameters
    ----------
    outdim : int
        Number of output classes.
    vembdim, eembdim : int
        Widths of the vertex / hyperedge embeddings. Message widths are
        fixed at 2.
    """

    def __init__(self, outdim, vembdim, eembdim):
        super(HMPNN, self).__init__()
        self.msgpass = MsgPass(vembdim, eembdim, vmsgdim=2, emsgdim=2)
        self.lin = nn.Linear(vembdim, outdim)
        # Remember the constructor value instead of reading a global named
        # ``eembdim`` at forward time.
        self.eembdim = eembdim

    def forward(self, x, elst=elst, ilst=ilst):
        """Return class logits, one row per vertex.

        ``elst`` / ``ilst`` default to the module-level lists that exist
        when the class is defined; neither list is mutated here.
        """
        vemb = x
        eemb = torch.zeros(len(elst), self.eembdim, dtype=x.dtype, device=x.device)
        for _ in range(3):
            vemb1, eemb1 = self.msgpass(vemb, eemb, elst=elst, ilst=ilst)
            vemb, eemb = vemb + vemb1, eemb + eemb1
        return self.lin(vemb)


def trainl(model, optimizer, elst, ilst, X, Y, train_mask):
    """Run one optimisation step on the training rows and return the loss.

    The model is switched to train mode, called as ``model(X, elst, ilst)``
    and fitted with cross-entropy on the rows selected by ``train_mask``.
    """
    model.train()
    logits = model(X, elst, ilst)
    objective = F.cross_entropy(logits[train_mask], Y[train_mask])
    optimizer.zero_grad()
    objective.backward()
    optimizer.step()
    return objective.item()

def evaluatel(model, elst, ilst, X, Y, val_mask, test_mask):
    """Return ``(val_acc, test_acc)`` for ``model`` in eval mode.

    The forward pass runs under ``torch.no_grad()`` because no gradients
    are needed for evaluation.
    """
    model.eval()
    with torch.no_grad():
        Y_hat = model(X, elst, ilst)
    # NOTE(review): ``accuracy`` receives already-argmaxed predictions here,
    # while the ``accuracy`` defined in the "old og hg attn" cell applies
    # ``argmax(1)`` itself -- confirm which definition is in scope.
    val_acc = accuracy(Y_hat[val_mask].argmax(1), Y[val_mask])
    test_acc = accuracy(Y_hat[test_mask].argmax(1), Y[test_mask])
    return val_acc, test_acc

# val_acc, test_acc = evaluatel(model, elst, ilst, X, Y, val_mask, test_mask)
# val_acc, test_acc = evaluatel(model, elst, ilst, X, Y, val_mask, train_mask)
# print(val_acc, test_acc)



# Smoke test: vertex embeddings are the raw features, so their width is X's.
num_v, vembdim = X.shape
eembdim, num_classes = 2, 7
model = HMPNN(num_classes, vembdim, eembdim)
yhat = model(X, elst, ilst)
# print(yhat.shape) # [2708, 7]

# print(len(X[0]))

In [None]:
# @title HMPNN H
import torch
import torch.nn as nn
# Vert msg = fv(vert emb)
# Edge msg = fw(edge emb, Sum Vert msgs)
# Vert emb1 = gv(vert emb, Sum edge msgs)
# Edge emb1 = gw(edge emb, Sum Vert msgs)

class MsgPass(nn.Module):
    """One HMPNN message-passing step driven by an incidence matrix ``H``.

    Vert msg  = fv(vert emb)
    Edge msg  = fw(edge emb, sum of vert msgs)
    Vert emb1 = gv(vert emb, sum of edge msgs)
    Edge emb1 = gw(edge emb, sum of vert msgs)

    Every map is ``Linear -> Sigmoid -> Linear`` with hidden width ``h_dim``.
    """

    def __init__(self, vembdim, eembdim, vmsgdim, emsgdim):
        super(MsgPass, self).__init__()
        self.h_dim = 4
        self.fv = nn.Sequential(
            nn.Linear(vembdim, self.h_dim), nn.Sigmoid(),
            nn.Linear(self.h_dim, vmsgdim),
            )
        self.fw = nn.Sequential(
            nn.Linear(eembdim+vmsgdim, self.h_dim), nn.Sigmoid(),
            nn.Linear(self.h_dim, emsgdim),
            )
        self.gv = nn.Sequential(
            nn.Linear(vembdim+emsgdim, self.h_dim), nn.Sigmoid(),
            nn.Linear(self.h_dim, vembdim),
            )
        self.gw = nn.Sequential(
            nn.Linear(eembdim+vmsgdim, self.h_dim), nn.Sigmoid(),
            nn.Linear(self.h_dim, eembdim),
            )
        self.vmsgdim = vmsgdim
        self.emsgdim = emsgdim

    @staticmethod
    def _membership(H, dtype):
        """Return a 0/1 copy of ``H`` (non-zero -> 1) in ``dtype``.

        Sparse COO input stays sparse; other non-strided layouts are
        densified first.
        """
        if H.layout == torch.sparse_coo:
            H = H.coalesce()
            return torch.sparse_coo_tensor(H.indices(), (H.values() != 0).to(dtype), H.shape)
        if H.layout != torch.strided:
            H = H.to_dense()
        return (H != 0).to(dtype)

    def forward(self, H, vemb, eemb):
        """Return ``(vemb1, eemb1)`` for one message-passing round.

        ``H`` is indexed (vertex, hyperedge). Only membership matters: any
        non-zero entry counts once, matching the boolean row masks that the
        per-row loops used to build.
        """
        B = self._membership(H, vemb.dtype)
        vmsg = self.fv(vemb)

        # Given e, sum the messages of its vertices. One matmul replaces the
        # per-hyperedge Python loop, and the result is reused for the edge
        # update instead of being recomputed.
        svmsg = B.t() @ vmsg
        memsg = self.fw(torch.cat((eemb, svmsg), 1))  # was the undefined name ``meemb``
        zeros = torch.zeros(self.emsgdim, dtype=memsg.dtype, device=memsg.device)
        ones = torch.ones(self.emsgdim, dtype=memsg.dtype, device=memsg.device)
        memsg = nn.functional.batch_norm(memsg, zeros, ones)
        # Respect train/eval mode; functional dropout is always on by default.
        memsg = nn.functional.dropout(memsg, p=0.5, training=self.training)

        # Given v, sum the messages of its hyperedges. Sum rather than mean:
        # an isolated vertex has no hyperedges, and a mean would divide by 0.
        semsg = B @ memsg
        vemb1 = self.gv(torch.cat((vemb, semsg), 1))

        eemb1 = self.gw(torch.cat((eemb, svmsg), 1))
        return vemb1, eemb1

# def ilst_from_elst(elst, n_v=len(ilst)): # generate incidence list from edge list
#     ilst = [[] for id in range(n_v)]
#     for e,vs in enumerate(elst):
#         [ilst[v].append(e) for v in vs]
#     return ilst

def adjdrop(vemb, eemb, elst, ilst, p=0.7):
    """Randomly drop hyperedges from the list representation.

    A hyperedge survives when its uniform draw is ``>= p``. The incidence
    list is rebuilt from the survivors via ``ilst_from_elst``; the incoming
    ``ilst`` is not read and ``vemb`` is returned as given.
    """
    survive = torch.rand(len(elst)) >= p  # True -> keep
    kept_edges = [members for members, ok in zip(elst, survive) if ok]
    kept_embeddings = eemb[survive == True]
    rebuilt = ilst_from_elst(kept_edges)
    return vemb, kept_embeddings, kept_edges, rebuilt


class HMPNN(nn.Module):
    """Two rounds of one shared ``MsgPass`` followed by a linear classifier.

    Parameters
    ----------
    outdim : int
        Number of output classes.
    vembdim, eembdim : int
        Widths of the vertex / hyperedge embeddings. Message widths are
        fixed at 2.
    """

    def __init__(self, outdim, vembdim, eembdim):
        super(HMPNN, self).__init__()
        self.msgpass = MsgPass(vembdim, eembdim, vmsgdim=2, emsgdim=2)
        self.lin = nn.Linear(vembdim, outdim)
        # Keep the constructor value rather than reading a global at
        # forward time.
        self.eembdim = eembdim

    def forward(self, H, X):
        """Return class logits, one row per vertex.

        ``H`` is the (vertex, hyperedge) incidence matrix and ``X`` the
        vertex features, used directly as the initial vertex embeddings.
        """
        vemb = X  # previously referenced the undefined name ``x``
        # One zero embedding per hyperedge (column of H).
        eemb = torch.zeros(H.shape[1], self.eembdim, dtype=X.dtype, device=X.device)
        for _ in range(2):
            # MsgPass.forward takes (H, vemb, eemb); it has no elst/ilst
            # keyword arguments.
            vemb, eemb = self.msgpass(H, vemb, eemb)
        return self.lin(vemb)

# Smoke test: vertex embeddings are the raw features, so their width is X's.
num_v, vembdim = X.shape
eembdim, num_classes = 2, 7
model = HMPNN(num_classes, vembdim, eembdim)
Y_hat = model(H, X)
# print(yhat.shape) # [2708, 7]

# print(len(X[0]))

In [None]:
# @title old og hg attn
# Hypergraph Convolution and Hypergraph Attention
# (https://arxiv.org/pdf/1901.08150.pdf).
import argparse
import dgl.sparse as dglsp
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from dgl.data import CoraGraphDataset
def accuracy(yhat, y):
    """Return the fraction of rows whose arg-max over dim 1 equals ``y``."""
    hits = (yhat.argmax(1) == y).type(torch.float).sum().item()
    return hits / y.shape[0]


def hypergraph_laplacian(H):
    """Return ``D_V^(-1/2) H W D_E^(-1) H^T D_V^(-1/2)`` with ``W`` the identity.

    Built with the DGL sparse-matrix API: degrees are row and column sums
    of ``H``.
    """
    node_deg = H.sum(1)
    edge_deg = H.sum(0)
    inv_sqrt_node = dglsp.diag(node_deg**-0.5)  # D_V ** (-1/2)
    inv_edge = dglsp.diag(edge_deg**-1)  # D_E ** (-1)
    num_edges = edge_deg.shape[0]
    edge_weight = dglsp.identity((num_edges, num_edges))
    return inv_sqrt_node @ H @ edge_weight @ inv_edge @ H.T @ inv_sqrt_node


class HypergraphAttention(nn.Module):
    """Hypergraph Attention module as in the paper
    `Hypergraph Convolution and Hypergraph Attention
    <https://arxiv.org/pdf/1901.08150.pdf>`_.

    Parameters
    ----------
    in_size : int
        Width of the input node / hyperedge features.
    out_size : int
        Width of the projected features.
    """

    def __init__(self, in_size, out_size):
        super().__init__()

        self.P = nn.Linear(in_size, out_size)  # shared projection for nodes and edges
        self.a = nn.Linear(2 * out_size, 1)  # scores one (node, hyperedge) pair

    def forward(self, H, X, X_edges):
        """Return attention-weighted, Laplacian-smoothed node features.

        One score is computed per stored entry of ``H`` from the projected
        features of its row (node) and column (hyperedge). The scores
        replace the values of ``H`` and are normalised with its ``softmax``.
        """
        Z = self.P(X)
        Z_edges = self.P(X_edges)
        # No debug printing here: dumping the full tensor on every forward
        # pass floods the output.
        sim = self.a(torch.cat([Z[H.row], Z_edges[H.col]], 1))
        sim = F.leaky_relu(sim, 0.2).squeeze(1)
        # Reassign the hypergraph new weights.
        H_att = dglsp.val_like(H, sim)
        H_att = H_att.softmax()
        return hypergraph_laplacian(H_att) @ Z


class Net(nn.Module):
    """Two stacked ``HypergraphAttention`` layers with an ELU in between."""

    def __init__(self, in_size, out_size, hidden_size=16):
        super().__init__()
        self.layer1 = HypergraphAttention(in_size, hidden_size)
        self.layer2 = HypergraphAttention(hidden_size, out_size)

    def forward(self, H, X):
        """Return the second layer's output; each layer gets its input as both node and edge features."""
        hidden = F.elu(self.layer1(H, X, X))
        return self.layer2(H, hidden, hidden)


def train(model, optimizer, H, X, Y, train_mask):
    """Run one optimisation step on the training rows and return the loss.

    The model is switched to train mode, called as ``model(H, X)`` and
    fitted with cross-entropy on the rows selected by ``train_mask``.
    """
    model.train()
    logits = model(H, X)
    objective = F.cross_entropy(logits[train_mask], Y[train_mask])
    optimizer.zero_grad()
    objective.backward()
    optimizer.step()
    return objective.item()


def evaluate(model, H, X, Y, val_mask, test_mask, num_classes):
    """Return ``(val_acc, test_acc)`` for ``model`` in eval mode.

    The forward pass runs under ``torch.no_grad()`` because no gradients
    are needed for evaluation. ``num_classes`` is accepted for call
    compatibility but is not used.
    """
    model.eval()
    with torch.no_grad():
        Y_hat = model(H, X)
    val_acc = accuracy(Y_hat[val_mask], Y[val_mask])
    test_acc = accuracy(Y_hat[test_mask], Y[test_mask])
    return val_acc, test_acc


def load_data():
    """Load Cora and build its hypergraph incidence matrix.

    The incidence matrix is the graph's edge index matrix plus the identity.

    Returns
    -------
    tuple
        ``(H, X, Y, num_classes, train_mask, val_mask, test_mask)``.
    """
    dataset = CoraGraphDataset()
    graph = dataset[0]
    edge_index = torch.stack(graph.edges())
    incidence = dglsp.spmatrix(edge_index)
    incidence = incidence + dglsp.identity(incidence.shape)
    features = graph.ndata["feat"]
    labels = graph.ndata["label"]
    masks = tuple(graph.ndata[key] for key in ("train_mask", "val_mask", "test_mask"))
    return (incidence, features, labels, dataset.num_classes, *masks)


# Load the data, build the attention network and train for two epochs.
H, X, Y, num_classes, train_mask, val_mask, test_mask = load_data()
model = Net(X.shape[1], num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

with tqdm.trange(2) as tq:
    for epoch in tq:
        loss = train(model, optimizer, H, X, Y, train_mask)
        val_acc, test_acc = evaluate(model, H, X, Y, val_mask, test_mask, num_classes)
        progress = {
            "Loss": f"{loss:.5f}",
            "Val acc": f"{val_acc:.5f}",
            "Test acc": f"{test_acc:.5f}",
        }
        tq.set_postfix(progress, refresh=False)

print(f"Test acc: {test_acc:.3f}")

