In [1]:
# from .peptides_functional.pyg.peptides_functional import *
# from experiments.peptides_structural.pyg.peptides_structural import *

In [1]:
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

from graph_hscn.loader.dataset.peptides_functional import PeptidesFunctionalDataset
from graph_hscn.loader.dataset.peptides_structural import PeptidesStructuralDataset

In [2]:
"""Spectral Clustering GNN layer definition."""
import os.path as osp
import torch
from torch.nn import Linear

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GraphConv, dense_mincut_pool
from torch_geometric import utils
from torch_geometric.nn import Sequential
from torch_geometric.nn.conv.gcn_conv import gcn_norm

from sklearn.metrics import normalized_mutual_info_score as NMI


import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter


class Net(torch.nn.Module):
    """GraphConv encoder + MLP head that yields soft cluster assignments.

    The message-passing stack embeds the nodes, the MLP maps each embedding to
    ``n_clusters`` logits, and ``dense_mincut_pool`` is used only for its two
    auxiliary losses (the pooled outputs are discarded).

    Parameters
    ----------
    mp_units : list[int]
        Output width of every GraphConv layer; must contain at least one entry.
    mp_act : str
        Name of a ``torch.nn`` activation class used after each GraphConv.
    in_channels : int
        Input node feature width.
    n_clusters : int
        Number of clusters (width of the final linear layer).
    mlp_units : list[int] | None
        Hidden widths of the MLP head; ``None`` (default) means no hidden layer.
    mlp_act : str
        Name of a ``torch.nn`` activation class used after each hidden MLP layer.
    """

    def __init__(self,
                 mp_units,
                 mp_act,
                 in_channels,
                 n_clusters,
                 mlp_units=None,
                 mlp_act="Identity"):
        super().__init__()

        # ``None`` replaces the former mutable ``[]`` default.
        mlp_units = [] if mlp_units is None else list(mlp_units)

        mp_act = self._make_activation(mp_act)
        mlp_act = self._make_activation(mlp_act)

        # Message passing layers (the same activation instance is reused
        # after every GraphConv).
        mp = [
            (GraphConv(in_channels, mp_units[0]), 'x, edge_index, edge_weight -> x'),
            mp_act
        ]
        for i in range(len(mp_units)-1):
            mp.append((GraphConv(mp_units[i], mp_units[i+1]), 'x, edge_index, edge_weight -> x'))
            mp.append(mp_act)
        self.mp = Sequential('x, edge_index, edge_weight', mp)
        out_chan = mp_units[-1]

        # MLP layers
        self.mlp = torch.nn.Sequential()
        for units in mlp_units:
            self.mlp.append(Linear(out_chan, units))
            out_chan = units
            self.mlp.append(mlp_act)
        self.mlp.append(Linear(out_chan, n_clusters))

    @staticmethod
    def _make_activation(name):
        """Instantiate ``torch.nn.<name>``, in-place when the class supports it."""
        act_cls = getattr(torch.nn, name)
        try:
            return act_cls(inplace=True)
        except TypeError:
            # Activations such as Tanh or Sigmoid take no ``inplace`` argument.
            return act_cls()

    def forward(self, x, edge_index, edge_weight):
        """Return ``(soft_assignments, mincut_loss, ortho_loss, dense_adj)``.

        ``soft_assignments`` is the softmax of the MLP logits over the last
        dimension; ``dense_adj`` is the dense adjacency built from
        ``edge_index``.
        """
        # Propagate node feats
        x = self.mp(x, edge_index, edge_weight)

        # Cluster assignments (logits)
        s = self.mlp(x)

        # Obtain MinCutPool losses. ``max_num_nodes`` pins the adjacency to the
        # number of feature rows so that trailing isolated nodes (absent from
        # ``edge_index``) cannot produce an adjacency smaller than ``x``/``s``.
        # NOTE(review): ``edge_weight`` is not forwarded as ``edge_attr``, so
        # the adjacency used for the loss is unweighted -- confirm this is
        # intended.
        adj = utils.to_dense_adj(edge_index, max_num_nodes=x.size(0))
        _, _, mc_loss, o_loss = dense_mincut_pool(x, adj, s)

        return torch.softmax(s, dim=-1), mc_loss, o_loss, adj

In [3]:
import logging
import time
from typing import Literal

from torch_geometric.data import Data
from torch_geometric.graphgym.checkpoint import clean_ckpt, save_ckpt
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loss import compute_loss
from torch_geometric.graphgym.model_builder import GraphGymModule
from torch_geometric.graphgym.register import register_train
from torch_geometric.graphgym.utils.epoch import is_ckpt_epoch, is_eval_epoch

from sklearn.metrics import normalized_mutual_info_score as NMI

In [4]:
from torch_geometric.data import HeteroData
from torch_geometric.nn import GCNConv, GATConv, GINConv, global_mean_pool, global_add_pool
import torch.nn.functional as F
from tqdm import tqdm
from torch_geometric.data import InMemoryDataset

In [5]:
# Pick the compute device once for the whole notebook: CUDA when available,
# otherwise CPU. Uncomment the next line to force CPU.
# device = torch.device('cpu')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [6]:
# Load the peptides-functional dataset rooted at ../datasets (path is relative
# to the notebook's working directory).
dataset = PeptidesFunctionalDataset("../datasets")

In [7]:
# Display the train/val/test index split; the result is only shown, not stored.
dataset.get_idx_split()

{'train': tensor([    1,     3,     4,  ..., 15531, 15533, 15534]),
 'val': tensor([    6,    11,    17,  ..., 15496, 15499, 15527]),
 'test': tensor([    0,     2,     5,  ..., 15521, 15528, 15532])}

In [8]:
from yacs.config import CfgNode as CN
from graph_hscn.config.posenc_config import set_cfg_posenc

# Build a fresh config node for positional encodings. This rebinds the name
# `cfg`, which an earlier cell imported from torch_geometric.graphgym.config.
cfg = CN()
# presumably registers the default posenc_* sub-nodes on cfg -- verify against
# graph_hscn.config.posenc_config.
set_cfg_posenc(cfg)
# SignNet is the only encoding explicitly switched on in this cell.
cfg.posenc_SignNet.model = "DeepSet"
cfg.posenc_SignNet.post_layers = 1
cfg.posenc_SignNet.eigen.max_freqs = 20
cfg.posenc_SignNet.enable = True

"""Functions for precomputing positional encoding stats."""
from typing import Literal

import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import (  # noqa
    get_laplacian,
    to_scipy_sparse_matrix,
    to_undirected,
)
from yacs.config import CfgNode
from functools import partial
from typing import Callable

Normalization = Literal["L1", "L2", "abs-max"]

def pre_transform_in_memory(
    dataset: Data, transform_func: Callable, show_progress: bool = False
) -> Data | None:
    """Apply a transform function to InMemoryDataset in pre_transform stage.

    Parameters
    ----------
    dataset : Data
        Dataset to pre-transform.
    transform_func : Callable
        Transform function to apply.
    show_progress : bool
        Whether to show progress in tqdm.

    Returns
    -------
    Data | None
        Dataset if no transform_func specified, None otherwise.
    """
    if transform_func is None:
        return dataset

    data_list = [
        transform_func(dataset.get(i))
        for i in tqdm(
            range(len(dataset)),
            disable=not show_progress,
            mininterval=10,
            miniters=len(dataset) // 20,
        )
    ]
    data_list = list(filter(None, data_list))
    dataset._indices = None
    dataset._data_list = data_list
    dataset.data, dataset.slices = dataset.collate(data_list)


def compute_posenc_stats(
    data: Data, pe_types: list[str], is_undirected: bool, cfg: CfgNode
):
    """Precompute positional encodings for the given graph.

    Parameters
    ----------
    data : Data
        PyG graph object.
    pe_types : list[str]
        Positional encoding types to precompute statistics for. Supported
        values are "LapPE", "EquivStableLapPE" and "SignNet".
    is_undirected : bool
        Whether the graph is expected to be undirected. When False the edge
        index is symmetrized before the Laplacian is built.
    cfg : CfgNode
        Configuration node for experiment. Only the ``posenc_*`` sub-nodes of
        the requested encodings are read.

    Returns
    -------
    Data
        The same graph object, extended with ``eig_vals``/``eig_vecs`` (LapPE
        or EquivStableLapPE) and/or ``eigvals_sn``/``eigvecs_sn`` (SignNet).

    Raises
    ------
    ValueError
        If ``pe_types`` contains an unsupported encoding name.
    """
    supported = ("LapPE", "EquivStableLapPE", "SignNet")
    for t in pe_types:
        if t not in supported:
            raise ValueError(
                f"Unexpected PE stats selection {t} in {pe_types}"
            )

    if hasattr(data, "num_nodes"):
        N = data.num_nodes
    else:
        N = data.x.shape[0]

    if is_undirected:
        undir_edge_index = data.edge_index
    else:
        undir_edge_index = to_undirected(data.edge_index)

    # The original condition tested for "LaPE" (typo), so a plain "LapPE"
    # request silently computed nothing.
    if "LapPE" in pe_types or "EquivStableLapPE" in pe_types:
        # Read the LapPE config only when it is needed, so a SignNet-only
        # request does not depend on the posenc_LapPE node.
        laplacian_norm_type = cfg.posenc_LapPE.eigen.laplacian_norm.lower()

        if laplacian_norm_type == "none":
            laplacian_norm_type = None

        L = to_scipy_sparse_matrix(
            *get_laplacian(
                undir_edge_index,
                normalization=laplacian_norm_type,
                num_nodes=N,
            )
        )
        evals, evects = np.linalg.eigh(L.toarray())

        # LapPE settings take precedence when both encodings are requested.
        if "LapPE" in pe_types:
            max_freqs = cfg.posenc_LapPE.eigen.max_freqs
            eigvec_norm = cfg.posenc_LapPE.eigen.eigvec_norm
        else:
            max_freqs = cfg.posenc_EquivStableLapPE.eigen.max_freqs
            eigvec_norm = cfg.posenc_EquivStableLapPE.eigen.eigvec_norm

        data.eig_vals, data.eig_vecs = get_lap_decomp_stats(
            evals=evals,
            evects=evects,
            max_freqs=max_freqs,
            eigvec_norm=eigvec_norm,
        )

    if "SignNet" in pe_types:
        norm_type = cfg.posenc_SignNet.eigen.laplacian_norm.lower()

        if norm_type == "none":
            norm_type = None

        L = to_scipy_sparse_matrix(
            *get_laplacian(
                undir_edge_index, normalization=norm_type, num_nodes=N
            )
        )
        evals_sn, evects_sn = np.linalg.eigh(L.toarray())
        data.eigvals_sn, data.eigvecs_sn = get_lap_decomp_stats(
            evals=evals_sn,
            evects=evects_sn,
            max_freqs=cfg.posenc_SignNet.eigen.max_freqs,
            eigvec_norm=cfg.posenc_SignNet.eigen.eigvec_norm,
        )

    return data


def get_lap_decomp_stats(
    evals: torch.Tensor,
    evects: torch.Tensor,
    max_freqs: int,
    eigvec_norm: Normalization = "L2",
) -> tuple[torch.Tensor, torch.Tensor]:
    """Turn a Laplacian eigendecomposition into per-node PE tensors.

    Parameters
    ----------
    evals : torch.Tensor
        Precomputed eigenvalues (converted with ``torch.from_numpy``, so a
        NumPy array is expected in practice).
    evects : torch.Tensor
        Precomputed eigenvectors, one per column (same remark as ``evals``).
    max_freqs : int
        Number of smallest-eigenvalue frequencies to keep.
    eigvec_norm : Normalization
        Normalization applied to the kept eigenvectors.

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor]
        Eigenvalues of shape (num_nodes, max_freqs, 1), repeated for each node,
        and eigenvectors of shape (num_nodes, max_freqs). Missing frequencies
        are filled with NaN when the graph has fewer nodes than ``max_freqs``.
    """
    num_nodes = len(evals)  # counts disconnected nodes too

    # Keep the ``max_freqs`` smallest eigenpairs, in ascending eigenvalue order.
    keep = evals.argsort()[:max_freqs]
    vals = torch.from_numpy(np.real(evals[keep])).clamp_min(0)
    vecs = torch.from_numpy(np.real(evects[:, keep])).float()
    vecs = eigvec_normalizer(vecs, vals, normalization=eigvec_norm)

    # Right-pad both tensors with NaN up to ``max_freqs`` columns.
    missing = max_freqs - num_nodes
    if missing > 0:
        nan = float("nan")
        vecs = F.pad(vecs, (0, missing), value=nan)
        vals = F.pad(vals, (0, missing), value=nan)

    # Broadcast the eigenvalue row to every node and add a trailing axis.
    eig_vals = vals.unsqueeze(0).repeat(num_nodes, 1).unsqueeze(2)

    return eig_vals, vecs


def eigvec_normalizer(
    eig_vecs: torch.Tensor,
    eig_vals: torch.Tensor,
    normalization: Normalization = "L2",
    eps: float = 1e-12,
):
    """Rescale each eigenvector (column) by the chosen norm.

    Parameters
    ----------
    eig_vecs : torch.Tensor
        Eigenvectors, one per column.
    eig_vals : torch.Tensor
        Eigenvalues; accepted for interface symmetry but not used here.
    normalization : Normalization
        "L1", "L2" or "abs-max".
    eps: float
        Lower bound applied to the denominator to avoid division by zero.

    Returns
    -------
    torch.Tensor
        Normalized eigenvectors.

    Raises
    ------
    ValueError
        If ``normalization`` is not one of the supported schemes.
    """
    if normalization in ("L1", "L2"):
        # L1: eigvec / sum(abs(eigvec)); L2: eigvec / sqrt(sum(eigvec^2))
        order = 1 if normalization == "L1" else 2
        scale = eig_vecs.norm(p=order, dim=0, keepdim=True)
    elif normalization == "abs-max":
        # eigvec / max(|eigvec|)
        scale = torch.max(eig_vecs.abs(), dim=0, keepdim=True).values
    else:
        raise ValueError(f"Unsupported normalization `{normalization}`")

    return eig_vecs / scale.clamp_min(eps).expand_as(eig_vecs)


# Collect the names of all enabled positional encodings from the config
# (keys of the form "posenc_<Name>" whose node has enable=True).
pe_enabled_list = []
for key, pecfg in cfg.items():
    if key.startswith('posenc_') and pecfg.enable:
        pe_name = key.split('_', 1)[1]
        pe_enabled_list.append(pe_name)
        if hasattr(pecfg, 'kernel'):
            # Generate kernel times if functional snippet is set.
            # NOTE(review): eval() executes the config string as Python code;
            # only use with trusted configuration.
            if pecfg.kernel.times_func:
                pecfg.kernel.times = list(eval(pecfg.kernel.times_func))
            logging.info(f"Parsed {pe_name} PE kernel times / steps: "
                         f"{pecfg.kernel.times}")
# Precompute the statistics for every graph; this rewrites `dataset` in place.
if pe_enabled_list:
    start = time.perf_counter()
    logging.info(f"Precomputing Positional Encoding statistics: "
                 f"{pe_enabled_list} for all graphs...")
    # Estimate directedness based on 10 graphs to save time.
    is_undirected = all(d.is_undirected() for d in dataset[:10])
    logging.info(f"  ...estimated to be undirected: {is_undirected}")
    pre_transform_in_memory(dataset,
                            partial(compute_posenc_stats,
                                    pe_types=pe_enabled_list,
                                    is_undirected=is_undirected,
                                    cfg=cfg),
                            show_progress=True
                            )
    elapsed = time.perf_counter() - start
    # HH:MM:SS followed by the last three characters of the elapsed seconds
    # formatted to two decimals (i.e. the ".xx" fraction).
    timestr = time.strftime('%H:%M:%S', time.gmtime(elapsed)) \
              + f'{elapsed:.2f}'[-3:]
    logging.info(f"Done! Took {timestr}")


100%|██████████| 15535/15535 [02:23<00:00, 107.92it/s]


In [None]:
# Run the SignNet node encoder over every graph and accumulate the results.
# NOTE(review): SignNetNodeEncoder is not imported in any cell visible in this
# notebook -- confirm where it comes from before a Restart & Run All.
enc = SignNetNodeEncoder(cfg, dataset.num_features, 32)
new_dataset = []
for data in tqdm(dataset):
    # NOTE(review): `+=` extends the list with the elements of enc(data);
    # verify this is intended rather than append().
    new_dataset += enc(data)

  7%|▋         | 1063/15535 [00:12<03:46, 64.02it/s]

In [10]:
# Train the clustering network: one 16-unit GraphConv layer with ELU, no MLP
# hidden layers, 5 clusters.
model = Net([16], "ELU", dataset.num_features, 5).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model.train()

# Per-graph loss history across all epochs.
all_loss = []
for epoch in range(1, 11):
    # One optimizer step per graph (no mini-batching).
    for data in tqdm(dataset):
        # Add self-loops and apply GCN normalization to the edge weights.
        data.edge_index, data.edge_weight = gcn_norm(  
                 data.edge_index, data.edge_weight, data.num_nodes,
                 add_self_loops=True)
        data = data.to(device)

        optimizer.zero_grad()
        _, mc_loss, o_loss, adj = model(data.x.float(), data.edge_index, data.edge_weight)
        # Objective is the sum of the two losses returned by the model.
        loss = mc_loss + o_loss
        all_loss.append(loss.item())
        loss.backward()
        optimizer.step()

100%|██████████| 15535/15535 [00:30<00:00, 512.08it/s]
100%|██████████| 15535/15535 [00:35<00:00, 435.86it/s]
100%|██████████| 15535/15535 [00:49<00:00, 317.03it/s]
100%|██████████| 15535/15535 [00:24<00:00, 627.18it/s]
100%|██████████| 15535/15535 [00:24<00:00, 640.76it/s]
100%|██████████| 15535/15535 [00:24<00:00, 645.15it/s]
 87%|████████▋ | 13490/15535 [00:24<00:04, 469.43it/s]

In [11]:
# Hard cluster label (argmax of the soft assignment) for every node of every
# graph, one NumPy array per graph.
# NOTE(review): hg_dataset.process() reads "color_all_lst.npy", but nothing in
# the visible cells writes that file -- confirm how it is produced.
color_all_lst = []
# Pure inference: no gradients are needed, so skip building autograd graphs.
# eval() is defensive; Net as defined in this notebook has no
# train/eval-dependent layers.
model.eval()
with torch.no_grad():
    for data in tqdm(dataset):
        # Same preprocessing as in training: self-loops + GCN normalization.
        data.edge_index, data.edge_weight = gcn_norm(
            data.edge_index, data.edge_weight, data.num_nodes,
            add_self_loops=True)
        data = data.to(device)
        clust, _, _, adj = model(data.x.float(), data.edge_index, data.edge_weight)
        colors = clust.max(1)[1].cpu().numpy()
        color_all_lst.append(colors)

In [40]:
class hg_dataset(InMemoryDataset):
    """Heterogeneous version of the peptides dataset with cluster nodes.

    Every source graph becomes a ``HeteroData`` with:

    * ``'local'`` nodes  -- the original nodes (features ``x``, labels ``y``),
    * ``'virtual'`` nodes -- one per cluster, feature = mean of member features,
    * ``local -> local`` edges   -- the original edges,
    * ``local -> virtual`` edges -- each node linked to its own cluster,
    * ``virtual -> virtual`` edges -- see the note in ``process``.

    ``process`` reads the notebook-level ``dataset`` and the per-graph cluster
    labels stored in ``color_all_lst.npy`` (relative to the working directory).
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['peptides_func_hetero.npy']

    @property
    def processed_file_names(self):
        return ['peptides_func_hetero.pt']

    def download(self):
        # Nothing to download; inputs are produced by earlier notebook cells.
        pass

    def process(self):
        """Build the hetero graphs and save them to ``processed_paths[0]``.

        Differences from the earlier implementation (existing processed files
        must be regenerated to pick them up):

        * virtual node ``c`` now holds the mean of cluster ``c``; previously a
          ``- 1`` offset stored the clusters in rotated order, so the
          ``local -> virtual`` edges pointed at another cluster's features;
        * ``pre_filter`` / ``pre_transform`` are applied to the built graphs
          instead of an empty list;
        * tensors are kept on CPU so the saved file loads on any machine.
        """
        color_all_lst = np.load("color_all_lst.npy", allow_pickle=True)
        h_data_lst = []

        for graph_idx in tqdm(range(len(dataset))):
            data = dataset[graph_idx]

            # Remap the cluster labels of this graph to contiguous ids 0..k-1.
            raw_colors = np.asarray(color_all_lst[graph_idx]).reshape(-1)
            unique_colors, colors = np.unique(raw_colors, return_inverse=True)
            colors = colors.reshape(-1)
            num_clust = len(unique_colors)

            # Row c = mean feature vector of the nodes assigned to cluster c.
            node_feats = data.x.cpu().numpy()
            clust_mean = np.stack(
                [node_feats[colors == c].mean(axis=0) for c in range(num_clust)]
            )

            h_data = HeteroData()
            h_data['local'].x = data.x.float()
            h_data['local'].y = data.y
            h_data['virtual'].x = torch.FloatTensor(clust_mean)
            h_data['local', 'to', 'local'].edge_index = data.edge_index

            # NOTE(review): this yields pairs (i, j) with j < num_clust - i,
            # which is neither a full nor a triangular connection pattern --
            # kept unchanged, confirm the intended virtual-virtual topology.
            col = np.concatenate([[i] * (num_clust - i) for i in range(num_clust)])
            row = np.concatenate([list(range(num_clust - i)) for i in range(num_clust)])
            h_data['virtual', 'to', 'virtual'].edge_index = torch.LongTensor([list(col), list(row)])

            # One edge per local node, pointing at its cluster's virtual node.
            node_ids = np.arange(len(colors), dtype=np.int64)
            h_data['local', 'to', 'virtual'].edge_index = torch.from_numpy(
                np.stack([node_ids, colors.astype(np.int64)])
            )

            h_data_lst.append(h_data)

        if self.pre_filter is not None:
            h_data_lst = [data for data in h_data_lst if self.pre_filter(data)]

        if self.pre_transform is not None:
            h_data_lst = [self.pre_transform(data) for data in h_data_lst]

        data, slices = self.collate(h_data_lst)
        torch.save((data, slices), self.processed_paths[0])

In [41]:
# Build (or load) the heterogeneous dataset.
h_dataset = hg_dataset("../datasets/peptides_func_hetero")
# NOTE(review): the constructor has already loaded the processed file into
# h_dataset.data; this explicit call rebuilds and re-saves the file but does
# not reload it into the object -- confirm whether it is still needed.
h_dataset.process()

In [37]:
# Quick check of the heterogeneous dataset's summary repr.
print(h_dataset)

In [None]:
# Earlier, per-graph variant of the hetero-graph construction: each graph is
# re-clustered until 5 distinct clusters are found, then converted.
# NOTE(review): `run()` and `h_data_lst` are not defined in any cell visible
# in this notebook, so this cell cannot run on a fresh kernel -- confirm
# whether it is superseded by hg_dataset.process().
for data in tqdm(dataset):
    data.edge_index, data.edge_weight = gcn_norm(  
                data.edge_index, data.edge_weight, data.num_nodes,
                add_self_loops=False)
    data = data.to(device)

    colors = [0]
    num_clust = 1
    # Retry until the clustering uses 5 distinct labels.
    while num_clust < 5:
        colors, clust, x = run()
        num_clust = len(np.unique(colors))
        #print(num_clust)

    G = to_networkx(data, node_attrs=["x"])
    G = G.to_undirected()


    nodelist = G.nodes()
    clust_node = [[] for idx in range(5)]
    for idx in range(len(nodelist)):
        # NOTE(review): the "- 1" sends label 0 to the last slot (index -1),
        # while the local->virtual edges below use the unshifted label --
        # verify which convention is intended.
        clust_num = colors[idx] - 1
        clust_node[clust_num].append(nodelist[idx]['x'])

    clust_node = [lst for lst in clust_node if len(lst) != 0]

    # Mean feature vector of each non-empty cluster.
    clust_mean = [np.mean(clust_lst, axis=0) for clust_lst in clust_node]
    clust_mean = np.array(clust_mean)

    h_data = HeteroData()
    h_data['local'].x = data.x.float()
    h_data['local'].y = data.y
    h_data['virtual'].x = torch.FloatTensor(clust_mean)
    h_data['local', 'to', 'local'].edge_index = data.edge_index
    col = np.concatenate([[idx] * (num_clust - idx) for idx in range(num_clust)])
    row = np.concatenate([[idx for idx in range(num_clust - index)] for index in range(num_clust)])
    h_data['virtual', 'to', 'virtual'].edge_index = torch.LongTensor([list(col), list(row)])

    # One [node, cluster] pair per local node.
    edge_lst = []
    for idx in range(len(colors)):
        clust_num = colors[idx]
        edge_lst.append([idx, clust_num])

    h_data['local', 'to', 'virtual'].edge_index = torch.LongTensor(edge_lst).T

    h_data_lst.append(h_data)    
    #print("RUN: " + str(len(h_data_lst)))
    #print("RUN: " + str(len(h_data_lst)))

  1%|▉                                                                                                                                                                 | 89/15535 [02:22<7:56:22,  1.85s/it]

In [63]:
class HeteroGNN(nn.Module):
    """Stack of heterogeneous conv layers followed by a two-layer readout.

    Each layer wraps three relation-specific convolutions (local->virtual,
    local->local, virtual->virtual) in a ``HeteroConv`` with sum aggregation.
    The readout mean-pools the ``"local"`` node embeddings and maps them to
    ``out_channels`` values.
    """

    def __init__(self, lv_conv, ll_conv, vv_conv, hidden_channels, out_channels, num_layers):
        super().__init__()
        conv_options = dict(add_self_loops=False, cached=False)

        def make_layer():
            # Construction order (lv, ll, vv) mirrors the relation order below.
            relation_convs = {}
            relation_convs[("local", "to", "virtual")] = lv_conv(
                (-1, -1), hidden_channels, **conv_options
            )
            relation_convs[("local", "to", "local")] = ll_conv(
                -1, hidden_channels, **conv_options
            )
            relation_convs[("virtual", "to", "virtual")] = vv_conv(
                -1, hidden_channels, **conv_options
            )
            return HeteroConv(relation_convs, aggr="sum")

        self.convs = nn.ModuleList([make_layer() for _ in range(num_layers)])

        self.lin1 = Linear(hidden_channels, hidden_channels)
        self.lin2 = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict, batch):
        """Return the pooled prediction for the graph(s) described by ``batch``."""
        features = x_dict
        for layer in self.convs:
            updated = layer(features, edge_index_dict)
            # ReLU on every node type produced by the layer.
            features = {node_type: emb.relu() for node_type, emb in updated.items()}

        pooled = global_mean_pool(features["local"], batch)
        hidden = F.relu(self.lin1(pooled))
        return self.lin2(hidden)

In [64]:
from torch_geometric.nn import HeteroConv, GCNConv, SAGEConv, GATConv, Linear

In [65]:
def mse(output, target):
    """Return the element-wise squared error between ``output`` and ``target``.

    No reduction is applied; the result has the broadcast shape of the inputs.
    """
    residual = output - target
    return torch.square(residual)

In [66]:
# Heterogeneous model: GATConv for local->virtual, GCNConv for the other two
# relations; 32 hidden channels, 11 outputs, 3 layers.
# This rebinds `model` and `optimizer` from the clustering cells above.
model = HeteroGNN(GATConv, GCNConv, GCNConv, 32, 11, 3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

In [67]:
# Train the heterogeneous model one graph at a time; range(1, 10) gives
# 9 epochs.
for epoch in range(1, 10):
    model.train()
    # Per-graph, per-target squared errors for this epoch.
    all_loss = []
    for data in tqdm(h_dataset):
        data = data.to(device)
        optimizer.zero_grad()
        # batch=None is passed through to global_mean_pool inside the model.
        output = model(data.x_dict, data.edge_index_dict, None)
        loss = mse(output, data['local'].y)
        all_loss.append(loss.view(-1).tolist())
        # Optimize the sum of the per-target squared errors.
        loss.sum().backward()
        optimizer.step()
    # Mean squared error per target over the epoch.
    print(np.mean(all_loss, axis=0))

100%|████████████████████████████████████████| 15535/15535 [01:00<00:00, 256.01it/s]


[0.99976401 0.99930188 0.99046993 0.99911508 1.00095701 0.9963645
 0.99832419 0.99449364 0.99664531 0.99440601 0.99661492]


100%|████████████████████████████████████████| 15535/15535 [00:59<00:00, 262.13it/s]


[0.98978915 0.98806335 0.93322838 0.99114505 0.99336086 0.96013113
 0.9901561  0.96133322 0.96821076 0.97264736 0.95505515]


100%|████████████████████████████████████████| 15535/15535 [00:59<00:00, 259.32it/s]


[0.95802834 0.95903133 0.84499626 0.96558172 0.96958096 0.89081626
 0.9496801  0.89185419 0.9028089  0.96494372 0.88447096]


  7%|██▉                                      | 1109/15535 [00:04<00:55, 259.96it/s]


KeyboardInterrupt: 

In [None]:
# NOTE(review): the recorded output of this cell is a NameError, and the path
# ("../../datasets") differs from the "../datasets" used when `dataset` was
# loaded -- looks like a leftover scratch cell; confirm before keeping.
PeptidesFunctionalDataset("../../datasets")

NameError: name 'PeptidesFunctionalDataset' is not defined

In [69]:
# Inspect the label tensor of `data`, the loop variable left over from the
# most recent loop that ran (hidden notebook state).
data['local'].y

tensor([[-0.3878, -0.3690, -0.5128, -0.4187, -0.3985, -0.5848, -0.1651, -0.5372,
         -0.4686, -0.1788, -0.0827]])