## Requirements ##

In [None]:
# @title
!pip install torch==1.11.0
!pip install -U fvcore

!pip install tensorly
# !pip install https://data.pyg.org/whl/torch-1.11.0%2Bcu102/pyg_lib-0.1.0%2Bpt111cu102-cp310-cp310-linux_x86_64.whl
# !pip install https://data.pyg.org/whl/torch-1.11.0%2Bcu102/torch_scatter-2.0.9-cp310-cp310-linux_x86_64.whl
# !pip install torchviz

Collecting torch==1.11.0
  Downloading torch-1.11.0-cp310-cp310-manylinux1_x86_64.whl (750.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m750.6/750.6 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch
  Attempting uninstall: torch
    Found existing installation: torch 2.2.1+cu121
    Uninstalling torch-2.2.1+cu121:
      Successfully uninstalled torch-2.2.1+cu121
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchaudio 2.2.1+cu121 requires torch==2.2.1, but you have torch 1.11.0 which is incompatible.
torchdata 0.7.1 requires torch>=2, but you have torch 1.11.0 which is incompatible.
torchtext 0.17.1 requires torch==2.2.1, but you have torch 1.11.0 which is incompatible.
torchvision 0.17.1+cu121 requires torch==2.2.1, but you have torch 1.11.0 which is incompatible.[0m[31m
[0mSuccessfully install


###Manifolds###

Poincare, Lorentz and Euclidean Manifolds

In [None]:
# @title
#!/usr/bin/env python3
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

##########Manifold#######################

# @title
#!/usr/bin/env python3
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

##########Manifold#######################



from abc import abstractmethod
from torch.nn import Embedding


class Manifold(object):
    """Common interface for the embedding manifolds defined in this file.

    The geometric primitives are marked with ``@abstractmethod`` and raise
    ``NotImplementedError``.  The class does not use ``ABCMeta``, so the
    decorator is not enforced at instantiation time; the error only appears
    when an un-overridden primitive is actually called.
    """

    def allocate_lt(self, N, dim, sparse):
        """Create an ``N x dim`` embedding lookup table."""
        table = Embedding(N, dim, sparse=sparse)
        return table

    def normalize(self, u):
        """Return ``u`` unchanged (the base manifold applies no constraint)."""
        return u

    @abstractmethod
    def distance(self, u, v):
        """Distance between ``u`` and ``v``; must be overridden."""
        raise NotImplementedError

    def init_weights(self, w, scale=1e-4):
        """Fill ``w.weight`` in place with samples from U(-scale, scale)."""
        weights = w.weight.data
        weights.uniform_(-scale, scale)

    @abstractmethod
    def expm(self, p, d_p, lr=None, out=None):
        """Exponential map; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def logm(self, x, y):
        """Logarithmic map; must be overridden."""
        raise NotImplementedError

    @abstractmethod
    def ptransp(self, x, y, v, ix=None, out=None):
        """Parallel transport; must be overridden."""
        raise NotImplementedError

    def norm(self, u, **kwargs):
        """L2 norm over the last dimension.

        ``u`` may be a tensor or an ``Embedding``; for an ``Embedding`` the
        norm of its weight matrix rows is returned.  Extra keyword arguments
        are accepted and ignored.
        """
        values = u.weight if isinstance(u, Embedding) else u
        squared = values.pow(2)
        return squared.sum(dim=-1).sqrt()

    @abstractmethod
    def half_aperture(self, u):
        """Half aperture of the entailment cone at ``u``; must be overridden.

        As in: https://arxiv.org/pdf/1804.01882.pdf
        """
        raise NotImplementedError

    @abstractmethod
    def angle_at_u(self, u, v):
        """Angle between the half-lines (0u) and (uv); must be overridden."""
        raise NotImplementedError

#############Euclidean Manifold####################
#!/usr/bin/env python3
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch as th
import numpy as np


class EuclideanManifold(Manifold):
    """Flat Euclidean manifold with an optional cap on the embedding norm."""

    __slots__ = ["max_norm"]

    def __init__(self, max_norm=None, K=None, **kwargs):
        """Store ``max_norm`` and ``K``; derive ``inner_radius`` when ``K`` is given.

        Extra keyword arguments are accepted and ignored.
        """
        self.max_norm = max_norm
        self.K = K
        if K is not None:
            root = np.sqrt(1 + 4 * self.K * self.K)
            self.inner_radius = 2 * self.K / (1 + root)

    def normalize(self, u):
        """Clip rows of ``u`` in place to L2 norm ``max_norm`` (if set); return ``u``."""
        width = u.size(-1)
        if self.max_norm:
            flat = u.view(-1, width)
            flat.renorm_(2, 0, self.max_norm)
        return u

    def distance(self, u, v):
        """Sum of squared coordinate differences over the last dimension.

        Note that no square root is taken.
        """
        delta = u - v
        return delta.pow(2).sum(dim=-1)

    def rgrad(self, p, d_p):
        """Return the gradient ``d_p`` unchanged."""
        return d_p

    def half_aperture(self, u):
        """Return ``asin(inner_radius / ||u||)``.

        Requires ``inner_radius``, which is only set when ``K`` was given.
        """
        squared_norm = u.pow(2).sum(dim=-1)
        return th.asin(self.inner_radius / squared_norm.sqrt())

    def angle_at_u(self, u, v):
        """Return ``acos`` of the ratio built from the norms and ``distance(v, u)``.

        NOTE(review): ``distance`` returns a squared quantity, and it is
        squared again in the numerator below -- confirm this is intended.
        """
        len_u = self.norm(u)
        len_v = self.norm(v)
        separation = self.distance(v, u)
        numerator = len_u.pow(2) - len_v.pow(2) - separation.pow(2)
        denominator = 2 * len_v * separation
        return (numerator / denominator).acos()

    def expm(self, p, d_p, normalize=False, lr=None, out=None):
        """Add ``d_p`` to ``out`` (or ``p`` when ``out`` is None) in place.

        When ``lr`` is given, ``d_p`` is first scaled in place by ``-lr``.
        The updated tensor is optionally normalised and then returned.
        """
        if lr is not None:
            d_p.mul_(-lr)
        target = p if out is None else out
        target.add_(d_p)
        if normalize:
            self.normalize(target)
        return target

    def logm(self, p, d_p, out=None):
        """Return ``p - d_p``; ``out`` is accepted but not used."""
        return p - d_p

    def ptransp(self, p, x, y, v):
        """Copy the values of sparse ``v`` into the matching rows of ``p`` in place.

        ``x`` and ``y`` are accepted but not used.
        """
        rows = v._indices().squeeze()
        values = v._values()
        return p.index_copy_(0, rows, values)

########################Poincare Manifold############################
#!/usr/bin/env python3
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch as th
from torch.autograd import Function
import numpy as np


class PoincareManifold(EuclideanManifold):
    """Poincare-ball manifold; embeddings are kept within norm ``1 - eps``."""

    def __init__(self, eps=1e-5, K=None, **kwargs):
        """Store ``eps`` and ``K``; the parent is initialised with ``max_norm=1 - eps``.

        Extra keyword arguments are accepted and ignored.
        """
        self.eps = eps
        super().__init__(max_norm=1 - eps)
        # The parent constructor is called without K, so K is assigned here.
        self.K = K
        if K is not None:
            root = np.sqrt(1 + 4 * K * self.K)
            self.inner_radius = 2 * K / (1 + root)

    def distance(self, u, v):
        """Distance computed by the custom ``Distance`` autograd function."""
        return Distance.apply(u, v, self.eps)

    def half_aperture(self, u):
        """Return ``asin(inner_radius * (1 - |u|^2) / |u|)`` with clamping.

        The squared norm is clamped to ``[0, 1 - eps]`` and the ``asin``
        argument to ``[-1 + eps, 1 - eps]``.
        """
        margin = self.eps
        squared_norm = u.pow(2).sum(dim=-1)
        squared_norm.clamp_(min=0, max=1 - margin)
        ratio = self.inner_radius * (1 - squared_norm) / th.sqrt(squared_norm)
        return th.asin(ratio.clamp(min=-1 + margin, max=1 - margin))

    def angle_at_u(self, u, v):
        """Return ``acos`` of the clamped ratio built from norms, dot product and distance."""
        len_u = u.norm(2, dim=-1)
        len_v = v.norm(2, dim=-1)
        inner = (u * v).sum(dim=-1)
        gap = (u - v).norm(2, dim=-1)  # euclidean distance
        numerator = (inner * (1 + len_v ** 2) - len_v ** 2 * (1 + len_u ** 2))
        denominator = (len_v * gap * (1 + len_v**2 * len_u**2 - 2 * inner).sqrt())
        cosine = (numerator / denominator).clamp_(min=-1 + self.eps, max=1 - self.eps)
        return cosine.acos()

    def rgrad(self, p, d_p):
        """Scale the gradient ``d_p`` by ``(1 - |p|^2)^2 / 4``.

        For a sparse ``d_p`` only the referenced rows of ``p`` are used, the
        scaled values are re-normed in place (``renorm_(2, 0, 5)``) and a new
        sparse double tensor is returned.
        """
        if not d_p.is_sparse:
            squared_norm = th.sum(p ** 2, dim=-1, keepdim=True)
            return d_p * ((1 - squared_norm) ** 2 / 4).expand_as(d_p)

        indices = d_p._indices()
        values = d_p._values()
        squared_norm = th.sum(
            p[indices[0].squeeze()] ** 2, dim=1, keepdim=True
        ).expand_as(values)
        scaled = values * ((1 - squared_norm) ** 2) / 4
        scaled.renorm_(2, 0, 5)
        return th.sparse.DoubleTensor(indices, scaled, d_p.size())


class Distance(Function):
    """Custom autograd function for the distance used by ``PoincareManifold``.

    ``forward`` evaluates ``arcosh(1 + 2 * |u - v|^2 / ((1 - |u|^2)(1 - |v|^2)))``
    with the squared norms clamped to ``[0, 1 - eps]``; ``backward`` uses the
    hand-written gradient in ``grad`` instead of autograd's.
    """

    @staticmethod
    def grad(x, v, sqnormx, sqnormv, sqdist, eps):
        """Gradient of the distance with respect to ``x``.

        ``sqnormx`` / ``sqnormv`` are the (clamped) squared norms saved by
        ``forward`` and ``sqdist`` the squared Euclidean distance.  ``eps`` is
        the lower bound applied to the denominator.
        """
        alpha = (1 - sqnormx)
        beta = (1 - sqnormv)
        # z is the arcosh argument, recomputed from the saved quantities.
        z = 1 + 2 * sqdist / (alpha * beta)
        a = ((sqnormv - 2 * th.sum(x * v, dim=-1) + 1) / th.pow(alpha, 2))\
            .unsqueeze(-1).expand_as(x)
        a = a * x - v / alpha.unsqueeze(-1).expand_as(v)
        z = th.sqrt(th.pow(z, 2) - 1)
        # Clamp the denominator from below so the division stays finite.
        z = th.clamp(z * beta, min=eps).unsqueeze(-1)
        return 4 * a / z.expand_as(x)

    @staticmethod
    def forward(ctx, u, v, eps):
        """Return the distance between ``u`` and ``v`` along the last dimension."""
        # Squared norms are clamped to [0, 1 - eps] before use.
        squnorm = th.clamp(th.sum(u * u, dim=-1), 0, 1 - eps)
        sqvnorm = th.clamp(th.sum(v * v, dim=-1), 0, 1 - eps)
        sqdist = th.sum(th.pow(u - v, 2), dim=-1)
        ctx.eps = eps
        ctx.save_for_backward(u, v, squnorm, sqvnorm, sqdist)
        x = sqdist / ((1 - squnorm) * (1 - sqvnorm)) * 2 + 1
        # arcosh
        z = th.sqrt(th.pow(x, 2) - 1)
        return th.log(x + z)

    @staticmethod
    def backward(ctx, g):
        """Return gradients for ``u`` and ``v``; ``eps`` gets ``None``."""
        u, v, squnorm, sqvnorm, sqdist = ctx.saved_tensors
        # Add a trailing axis so the incoming gradient broadcasts over coordinates.
        g = g.unsqueeze(-1)
        gu = Distance.grad(u, v, squnorm, sqvnorm, sqdist, ctx.eps)
        gv = Distance.grad(v, u, sqvnorm, squnorm, sqdist, ctx.eps)
        return g.expand_as(gu) * gu, g.expand_as(gv) * gv, None


###########################LorentzManifold################################


#!/usr/bin/env python3
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch as th
from torch.autograd import Function
import numpy as np
from torch.nn import Embedding


class LorentzManifold(Manifold):
    """Hyperboloid (Lorentz) model.

    Points carry one extra leading coordinate (see ``allocate_lt``); the
    remaining coordinates are referred to below as the "spatial" part.
    """

    __slots__ = ["eps", "_eps", "norm_clip", "max_norm", "debug"]

    def __init__(self, eps=1e-12, _eps=1e-5, norm_clip=1, max_norm=1e6,
            debug=False, K=None, **kwargs):
        """Store the numerical tolerances and clipping settings.

        ``inner_radius`` is only defined when ``K`` is given.  Extra keyword
        arguments are accepted and ignored.
        """
        self.eps = eps
        self._eps = _eps
        self.norm_clip = norm_clip
        self.max_norm = max_norm
        self.debug = debug
        self.K = K
        if K is not None:
            self.inner_radius = 2 * self.K / (1 + np.sqrt(1 + 4 * self.K * self.K))

    def allocate_lt(self, N, dim, sparse):
        """Create an ``N x (dim + 1)`` embedding table (one extra coordinate)."""
        return Embedding(N, dim + 1, sparse=sparse)

    def init_weights(self, w, irange=1e-5):
        """Fill ``w.weight`` uniformly in ``[-irange, irange]`` and normalize it in place."""
        w.weight.data.uniform_(-irange, irange)
        self.normalize(w.weight.data)

    @staticmethod
    def ldot(u, v, keepdim=False):
        """Lorentzian scalar product over the last dimension.

        The product of the first coordinates enters with a negative sign.
        """
        uv = u * v
        # Flip the sign of the first-coordinate product in the fresh tensor `uv`.
        uv.narrow(-1, 0, 1).mul_(-1)
        return th.sum(uv, dim=-1, keepdim=keepdim)

    def to_poincare_ball(self, u):
        """Return the spatial coordinates divided by ``(first coordinate + 1)``."""
        x = u.clone()
        d = x.size(-1) - 1
        return x.narrow(-1, 1, d) / (x.narrow(-1, 0, 1) + 1)

    def distance(self, u, v):
        """Return ``acosh`` of the negated Lorentz product of ``u`` and ``v``."""
        d = -LorentzDot.apply(u, v)
        # Clamp through `.data` so the acosh argument is at least 1.
        d.data.clamp_(min=1)
        return acosh(d, self._eps)

    # NOTE(review): a second `norm` method is defined at the end of this class.
    # The later definition replaces this one when the class body is executed,
    # so this implementation is never the one bound to `LorentzManifold.norm`.
    def norm(self, u):
        return th.sqrt(th.sum(th.pow(self.to_poincare_ball(u), 2), dim=-1))

    def normalize(self, w):
        """Normalize vector such that it is located on the hyperboloid

        Works in place on ``w`` and returns it: the spatial part is optionally
        norm-clipped and pushed outside ``inner_radius``, then the first
        coordinate is overwritten with ``sqrt(1 + |spatial|^2)``.
        """
        d = w.size(-1) - 1
        narrowed = w.narrow(-1, 1, d)
        if self.max_norm:
            narrowed.view(-1, d).renorm_(p=2, dim=0, maxnorm=self.max_norm)

        if self.K is not None:
            # Push embeddings outside of `inner_radius`
            w0 = w.narrow(-1, 0, 1).squeeze()
            wnrm = th.sqrt(th.sum(th.pow(narrowed, 2), dim=-1)) / (1 + w0)
            scal = th.ones_like(wnrm)
            ix = wnrm < (self.inner_radius + self._eps)
            # Only rows below the threshold are rescaled; the others keep factor 1.
            scal[ix] = (self.inner_radius + self._eps) / wnrm[ix]
            narrowed.mul_(scal.unsqueeze(-1))

        tmp = 1 + th.sum(th.pow(narrowed, 2), dim=-1, keepdim=True)
        tmp.sqrt_()
        w.narrow(-1, 0, 1).copy_(tmp)
        return w

    def normalize_tan(self, x_all, v_all):
        """Overwrite the first coordinate of ``v_all`` in place and return it.

        The new value is ``<x_spatial, v_spatial> / sqrt(1 + |x_spatial|^2)``,
        with the denominator clamped from below by ``_eps``.  Indexing uses
        dimension 1, so two-dimensional inputs are expected.
        """
        d = v_all.size(1) - 1
        x = x_all.narrow(1, 1, d)
        xv = th.sum(x * v_all.narrow(1, 1, d), dim=1, keepdim=True)
        tmp = 1 + th.sum(th.pow(x_all.narrow(1, 1, d), 2), dim=1, keepdim=True)
        tmp.sqrt_().clamp_(min=self._eps)
        v_all.narrow(1, 0, 1).copy_(xv / tmp)
        return v_all

    def rgrad(self, p, d_p):
        """Riemannian gradient for hyperboloid

        Modifies ``d_p`` in place (its values, when sparse) and returns it.
        """
        if d_p.is_sparse:
            u = d_p._values()
            x = p.index_select(0, d_p._indices().squeeze())
        else:
            u = d_p
            x = p
        # Flip the sign of the first coordinate, then add ldot(x, u) * x.
        u.narrow(-1, 0, 1).mul_(-1)
        u.addcmul_(self.ldot(x, u, keepdim=True).expand_as(x), x)
        return d_p

    def expm(self, p, d_p, lr=None, out=None, normalize=False):
        """Exponential map for hyperboloid

        Updates ``p`` in place; there is no ``return`` statement, so the
        result is ``None``.  ``out`` is assigned below but not used afterwards.
        ``lr`` is only consulted in the dense branch.
        """
        if out is None:
            out = p
        if d_p.is_sparse:
            ix, d_val = d_p._indices().squeeze(), d_p._values()
            # This pulls `ix` out of the original embedding table, which could
            # be in a corrupted state.  normalize it to fix it back to the
            # surface of the hyperboloid...
            # TODO: we should only do the normalize if we know that we are
            # training with multiple threads, otherwise this is a bit wasteful
            p_val = self.normalize(p.index_select(0, ix))
            ldv = self.ldot(d_val, d_val, keepdim=True)
            if self.debug:
                assert all(ldv > 0), "Tangent norm must be greater 0"
                assert all(ldv == ldv), "Tangent norm includes NaNs"
            nd_p = ldv.clamp_(min=0).sqrt_()
            # `t` is the step length capped at norm_clip; nd_p is then floored
            # at eps for use as the divisor.
            t = th.clamp(nd_p, max=self.norm_clip)
            nd_p.clamp_(min=self.eps)
            newp = (th.cosh(t) * p_val).addcdiv_(th.sinh(t) * d_val, nd_p)
            if normalize:
                newp = self.normalize(newp)
            p.index_copy_(0, ix, newp)
        else:
            if lr is not None:
                # Same in-place projection as `rgrad`, followed by scaling with -lr.
                d_p.narrow(-1, 0, 1).mul_(-1)
                d_p.addcmul_((self.ldot(p, d_p, keepdim=True)).expand_as(p), p)
                d_p.mul_(-lr)
            ldv = self.ldot(d_p, d_p, keepdim=True)
            if self.debug:
                assert all(ldv > 0), "Tangent norm must be greater 0"
                assert all(ldv == ldv), "Tangent norm includes NaNs"
            nd_p = ldv.clamp_(min=0).sqrt_()
            t = th.clamp(nd_p, max=self.norm_clip)
            nd_p.clamp_(min=self.eps)
            newp = (th.cosh(t) * p).addcdiv_(th.sinh(t) * d_p, nd_p)
            if normalize:
                newp = self.normalize(newp)
            p.copy_(newp)

    def logm(self, x, y):
        """Logarithmic map on the Lorentz manifold.

        The Lorentz product of ``x`` and ``y`` is clamped to at most -1 before
        use; the result is passed through ``normalize_tan``.
        """
        xy = th.clamp(self.ldot(x, y).unsqueeze(-1), max=-1)
        v = acosh(-xy, self.eps).div_(
            th.clamp(th.sqrt(xy * xy - 1), min=self._eps)
        ) * th.addcmul(y, xy, x)
        return self.normalize_tan(x, v)

    def ptransp(self, x, y, v, ix=None, out=None):
        """Parallel transport for hyperboloid

        Rows are selected either by the explicit ``ix`` or by the indices of a
        sparse ``v``; a dense ``v`` without ``ix`` raises
        ``NotImplementedError``.  Returns the transported vectors when ``out``
        is None, otherwise writes them into ``out`` and returns ``None``.
        """
        if ix is not None:
            v_ = v
            x_ = x.index_select(0, ix)
            y_ = y.index_select(0, ix)
        elif v.is_sparse:
            ix, v_ = v._indices().squeeze(), v._values()
            x_ = x.index_select(0, ix)
            y_ = y.index_select(0, ix)
        else:
            raise NotImplementedError
        xy = self.ldot(x_, y_, keepdim=True).expand_as(x_)
        vy = self.ldot(v_, y_, keepdim=True).expand_as(x_)
        vnew = v_ + vy / (1 - xy) * (x_ + y_)
        if out is None:
            return vnew
        else:
            out.index_copy_(0, ix, vnew)

    def half_aperture(self, u):
        """Return ``asin(inner_radius * (1 - s) / sqrt(s))`` with clamping.

        ``s`` is the squared spatial norm divided by ``(1 + first coordinate)^2``.
        Requires ``inner_radius``, which is only set when ``K`` was given.
        """
        eps = self.eps
        d = u.size(-1) - 1
        sqnu = th.sum(u.narrow(-1, 1, d) ** 2, dim=-1) / (1 + u.narrow(-1, 0, 1)
            .squeeze(-1)) ** 2
        sqnu.clamp_(min=0, max=1 - eps)
        return th.asin((self.inner_radius * (1 - sqnu) / th.sqrt(sqnu))
            .clamp(min=-1 + eps, max=1 - eps))

    def angle_at_u(self, u, v):
        """Return ``acos`` of the clamped ratio built from ``u``, ``v`` and their Lorentz product."""
        uvldot = LorentzDot.apply(u, v)
        u0 = u.narrow(-1, 0, 1).squeeze(-1)
        num = th.add(v.narrow(-1, 0, 1).squeeze(-1), th.mul(u0, uvldot))
        tmp = th.pow(uvldot, 2) - 1.
        den = th.sqrt(th.pow(u0, 2) - 1.) * th.sqrt(tmp.clamp_(min=self.eps))
        frac = th.div(num, den)
        # Debug hook: drops into the ipdb debugger when NaNs appear (needs ipdb installed).
        if self.debug and (frac != frac).any():
            import ipdb; ipdb.set_trace()
        frac.data.clamp_(min=-1 + self.eps, max=1 - self.eps)
        ksi = frac.acos()
        return ksi

    def norm(self, u):
        """Return ``|spatial| / (1 + first coordinate)`` along the last dimension.

        Accepts a tensor or an ``Embedding`` (its weight is used).  This is
        the definition that ends up bound to ``norm``; an earlier one in this
        class body is overwritten by it.
        """
        if isinstance(u, Embedding):
            u = u.weight
        d = u.size(-1) - 1
        sqnu = th.sum(u.narrow(-1, 1, d) ** 2, dim=-1)
        sqnu = sqnu / (1 + u.narrow(-1, 0, 1).squeeze(-1)) ** 2
        return sqnu.sqrt()


class LorentzDot(Function):
    """Lorentzian scalar product as an autograd function with an explicit gradient."""

    @staticmethod
    def forward(ctx, u, v):
        """Save both operands and return ``LorentzManifold.ldot(u, v)``."""
        ctx.save_for_backward(u, v)
        product = LorentzManifold.ldot(u, v)
        return product

    @staticmethod
    def backward(ctx, g):
        """Return the gradients with respect to ``u`` and ``v``.

        The incoming gradient is broadcast over the coordinate axis and its
        first coordinate is negated, mirroring the sign flip in ``ldot``.
        """
        lhs, rhs = ctx.saved_tensors
        # clone() gives a writable tensor; expand_as alone returns a view.
        signed = g.unsqueeze(-1).expand_as(lhs).clone()
        signed.narrow(-1, 0, 1).mul_(-1)
        grad_lhs = signed * rhs
        grad_rhs = signed * lhs
        return grad_lhs, grad_rhs

###Hype###

Contains the checkpoint code and the remaining utility modules (data pre-processing, data splitting, etc.), as well as some mathematical helpers such as the Poisson kernel.

In [None]:
# @title
#!/usr/bin/env python3
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


####################checkpoint.py#######################
import os
import time
import torch
import warnings
from fvcore.common.file_io import PathManager


def upgrade_state_dict(state):
    '''
    Bring an old checkpoint up to date with breaking config changes.

    The ``state['conf']`` dict is modified in place and ``state`` is returned.
    '''
    conf = state['conf']

    # Current checkpoints already name their model; nothing to upgrade.
    if 'model' in conf:
        return state

    # Older checkpoints only stored `-manifold`; fall back to the "distance"
    # model for them and tell the user about the assumption.
    warnings.warn(
        'Missing `model` field in checkpoint config.'
        '  Assuming `distance`.'
    )
    conf['model'] = 'distance'
    return state


class LocalCheckpoint(object):
    """Save and restore a checkpoint dictionary at a single file path.

    Attributes set by ``__init__``:
        path: file the checkpoint is read from and written to.
        start_fresh: when true, ``initialize`` ignores an existing file.
        include_in_all: dict merged into every saved checkpoint; entries of
            the ``params`` passed to ``save`` win on key clashes.
    """

    def __init__(self, path, include_in_all=None, start_fresh=False):
        self.path = path
        self.start_fresh = start_fresh
        self.include_in_all = {} if include_in_all is None else include_in_all

    def _read(self):
        """Deserialize and return the checkpoint stored at ``self.path``."""
        print(f'Loading checkpoint from {self.path}')
        # NOTE: torch.load unpickles the file, which can execute arbitrary
        # code -- only load checkpoints from trusted sources.
        with PathManager.open(self.path, 'rb') as fin:
            return torch.load(fin)

    def initialize(self, params):
        """Return the stored checkpoint if one should be resumed, else ``params``.

        The file is used only when ``start_fresh`` is false and ``path`` is an
        existing file.
        """
        if not self.start_fresh and os.path.isfile(self.path):
            return self._read()
        return params

    def load(self):
        """Return the stored checkpoint.

        Raises:
            FileNotFoundError: if ``path`` is not an existing file.  (The
                previous ``raise NotImplemented`` produced a ``TypeError``,
                because ``NotImplemented`` is not an exception.)
        """
        if os.path.isfile(self.path):
            return self._read()
        print('not a valid path to load from')
        raise FileNotFoundError(f'No checkpoint file at {self.path}')

    def save(self, params, tries=10):
        """Write ``include_in_all`` merged with ``params`` to ``path``.

        Saving is best effort: on any exception the call waits 60 seconds and
        retries, up to ``tries`` additional attempts, then gives up with a
        message instead of raising.
        """
        try:
            with PathManager.open(self.path, 'wb') as fout:
                torch.save({**self.include_in_all, **params}, fout)
        except Exception as err:
            if tries > 0:
                print(f'Exception while saving ({err})\nRetrying ({tries})')
                time.sleep(60)
                self.save(params, tries=(tries - 1))
            else:
                print("Giving up on saving...")



###############common.py######################
import torch as th
from torch.autograd import Function


class Acosh(Function):
    """Inverse hyperbolic cosine whose gradient denominator is clamped by ``eps``."""

    @staticmethod
    def forward(ctx, x, eps):
        """Return ``log(x + sqrt(x^2 - 1))`` and keep the square root for backward."""
        root = th.sqrt(x * x - 1)
        ctx.save_for_backward(root)
        ctx.eps = eps
        return th.log(x + root)

    @staticmethod
    def backward(ctx, g):
        """Return ``g / max(sqrt(x^2 - 1), eps)``; ``eps`` itself gets no gradient."""
        (root,) = ctx.saved_tensors
        denominator = th.clamp(root, min=ctx.eps)
        return g / denominator, None


# Functional alias used by the manifold code.
acosh = Acosh.apply

#####################hyla_utils.py##################################

import torch
import numpy as np
import scipy.sparse as sp
from sklearn.metrics import average_precision_score, accuracy_score, f1_score
import os
import pickle as pkl
import sys
import networkx as nx
import torch.nn.functional as F
import json
from networkx.readwrite import json_graph
import pdb
from scipy.sparse.linalg import eigsh
#from scipy.sparse.linalg.eigen.arpack import eigsh
from scipy.sparse import csr_matrix
import re
from time import perf_counter
import tabulate
sys.setrecursionlimit(99999)


def sample_boundary(n_Bs, d, cls):
    """Sample ``n_Bs`` unit-norm points according to the scheme ``cls``.

    * ``'RandomUniform'`` (or any ``cls`` when ``d > 2``): normalised
      Gaussian samples of dimension ``d``.
    * ``'FixedUniform'``: evenly spaced angles on the unit circle.
    * ``'RandomDisk'``: uniformly random angles on the unit circle.

    Raises ``NotImplementedError`` for any other ``cls``.
    """
    if cls == 'RandomUniform' or d > 2:
        raw = torch.randn(n_Bs, d)
        return raw / torch.norm(raw, dim=-1, keepdim=True)

    # The remaining schemes only differ in how the angles are chosen.
    if cls == 'FixedUniform':
        angles = torch.arange(0, 2 * np.pi, 2 * np.pi / n_Bs)
    elif cls == 'RandomDisk':
        angles = 2 * np.pi * torch.rand(n_Bs)
    else:
        raise NotImplementedError
    return torch.stack([torch.cos(angles), torch.sin(angles)], 1)

def PoissonKernel(X, b):
    """Return ``(1 - |x|^2) / |x - b|^2`` for every row ``x`` of ``X`` against ``b``.

    ``X`` is reshaped to ``(X.size(0), 1, X.size(-1))`` so the subtraction
    broadcasts against ``b``.
    """
    points = X.view(X.size(0), 1, X.size(-1))
    numerator = 1 - torch.norm(points, 2, dim=-1) ** 2
    denominator = torch.norm(points - b, 2, dim=-1) ** 2
    return numerator / denominator

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a float64 torch sparse tensor.

    NOTE: a function with the same name is defined again further down in
    this file; at import time that later definition replaces this one.
    """
    coo = sparse_mx.tocoo().astype(np.float64)
    coordinates = np.vstack((coo.row, coo.col)).astype(np.int64)
    return torch.sparse.DoubleTensor(
        torch.from_numpy(coordinates),
        torch.from_numpy(coo.data),
        torch.Size(coo.shape),
    )


def accuracy(output, labels):
    """Fraction of rows of ``output`` whose arg-max column equals the label.

    Returns a double-precision scalar tensor.
    """
    predicted = output.max(1)[1].type_as(labels)
    hits = predicted.eq(labels).double().sum()
    return hits / len(labels)

def aug_normalized_adjacency(adj):
    """Return ``D^-1/2 (A + I) D^-1/2`` in COO format.

    ``D`` holds the row sums of ``A + I``; infinite inverse square roots are
    replaced by zero.
    """
    with_loops = sp.coo_matrix(adj + sp.eye(adj.shape[0]))
    degrees = np.array(with_loops.sum(1))
    inv_sqrt = np.power(degrees, -0.5).flatten()
    inv_sqrt[np.isinf(inv_sqrt)] = 0.
    scaling = sp.diags(inv_sqrt)
    return scaling.dot(with_loops).dot(scaling).tocoo()

def add_self_loop(adj):
    """Return ``adj`` plus an identity matrix of matching size."""
    identity = sp.eye(adj.shape[0])
    return adj + identity

def sgc_precompute(adj, features, degree):
    """Multiply ``features`` by ``adj`` ``degree`` times.

    Returns ``(features, nonzero_perc)`` where ``nonzero_perc`` lists, as
    ``"%.2f"`` strings, the percentage of non-zero entries after each
    multiplication.  With ``degree == 0`` the raw features are returned
    together with their own percentage.
    """
    def nonzero_share(matrix):
        # Percentage of non-zero entries, formatted with two decimals.
        number_nonzero = (matrix != 0).sum().item()
        percentage = number_nonzero*1.0/matrix.size(0)/matrix.size(1)*100.0
        return "%.2f" % percentage

    if degree == 0:
        shares = [nonzero_share(features)]
        print('input order 0, return raw feature')
        return features, shares

    shares = []
    for _ in range(degree):
        features = torch.spmm(adj, features)
        shares.append(nonzero_share(features))
    return features, shares


def adj_compute(adj, degree):
    """Return the dense form of ``adj`` raised to the power ``degree + 1``.

    With ``degree == 0`` this is simply a dense copy of ``adj``.
    """
    base = adj.to_dense().clone()
    power = adj.to_dense().clone()
    for _ in range(degree):
        power = torch.spmm(base, power)
    return power

def acc_f1(output, labels, average='micro'):
    """Return ``(accuracy, f1)`` for the arg-max predictions of ``output``.

    Predictions and labels are moved to the CPU when the predictions live on
    a CUDA device.  ``average`` is forwarded to ``f1_score``.
    """
    predicted = output.max(1)[1].type_as(labels)
    if predicted.is_cuda:
        predicted, labels = predicted.cpu(), labels.cpu()
    acc = accuracy_score(predicted, labels)
    f1 = f1_score(predicted, labels, average=average)
    return acc, f1

def measure_tensor_size(a):
    """Return the storage needed by the elements of ``a`` in MB (10**6 bytes)."""
    n_bytes = a.element_size() * a.nelement()
    return n_bytes * 0.000001

# ###################################################
# data loading

def load_data(args, datapath):
    """Load a node-classification dataset and convert it to torch sparse tensors.

    Reads ``args.dataset``, ``args.use_feats`` and ``args.split_seed``.  The
    training adjacency is normalised with ``aug_normalized_adjacency`` before
    conversion; the features are converted as they are.
    """
    data = load_data_nc(args.dataset, args.use_feats, datapath, args.split_seed)
    normalized_adj = aug_normalized_adjacency(data['adj_train'])
    data['adj_train'] = sparse_mx_to_torch_sparse_tensor(normalized_adj)
    data['features'] = sparse_mx_to_torch_sparse_tensor(data['features'])
    return data

# ############### FEATURES PROCESSING ####################################


def process(adj, features, normalize_adj, normalize_feats):
    """Turn ``adj`` into a torch sparse tensor and ``features`` into a dense tensor.

    Sparse features are densified first.  Features are row-normalised when
    ``normalize_feats`` is set; ``adj`` gets self-loops and row normalisation
    when ``normalize_adj`` is set.
    """
    dense_features = np.array(features.todense()) if sp.isspmatrix(features) else features
    if normalize_feats:
        dense_features = normalize(dense_features)
    feature_tensor = torch.Tensor(dense_features)

    if normalize_adj:
        adj = normalize(adj + sp.eye(adj.shape[0]))
    adj_tensor = sparse_mx_to_torch_sparse_tensor(adj)
    return adj_tensor, feature_tensor


def normalize(mx):
    """Row-normalize a matrix so that every non-empty row sums to one.

    Rows whose sum is zero are left as zeros (their infinite inverse is
    replaced by 0).
    """
    row_totals = np.array(mx.sum(1))
    inverse = np.power(row_totals, -1).flatten()
    inverse[np.isinf(inverse)] = 0.
    scaling = sp.diags(inverse)
    return scaling.dot(mx)


def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse ``FloatTensor``.

    NOTE: this re-defines a function of the same name that appears earlier in
    this file; being the later definition, it is the one in effect at import.
    """
    coo = sparse_mx.tocoo()
    coordinates = np.vstack((coo.row, coo.col)).astype(np.int64)
    return torch.sparse.FloatTensor(
        torch.from_numpy(coordinates),
        torch.Tensor(coo.data),
        torch.Size(coo.shape),
    )


def augment(adj, features, normalize_feats=True):
    """Append a one-hot degree encoding and a constant column to ``features``.

    Degrees are the column sums of ``adj`` capped at 5, giving six one-hot
    columns, followed by a single column of ones.  ``normalize_feats`` is
    accepted but not used.
    """
    degrees = np.squeeze(np.sum(adj, axis=0).astype(int))
    degrees[degrees > 5] = 5
    degree_one_hot = torch.tensor(np.eye(6)[degrees], dtype=torch.float).squeeze()
    ones_column = torch.ones(features.size(0), 1)
    return torch.cat((features, degree_one_hot, ones_column), dim=1)


# ############### DATA SPLITS #####################################################


def mask_edges(adj, val_prop, test_prop, seed):
    """Split the edges of ``adj`` into train / validation / test sets.

    Positive edges are the non-zeros of the upper triangle of ``adj``;
    negative edges are the non-zeros of the upper triangle of ``1 - adj``.
    Both lists are shuffled with ``np.random`` seeded by ``seed``.

    Returns ``(adj_train, train_edges, train_edges_false, val_edges,
    val_edges_false, test_edges, test_edges_false)`` where ``adj_train`` is
    the symmetric sparse adjacency of the training edges and the rest are
    ``LongTensor`` edge lists.
    """
    np.random.seed(seed)

    # True-positive edges.
    rows, cols = sp.triu(adj).nonzero()
    pos_edges = np.array(list(zip(rows, cols)))
    np.random.shuffle(pos_edges)

    # True-negative edges.
    rows, cols = sp.triu(sp.csr_matrix(1. - adj.toarray())).nonzero()
    neg_edges = np.array(list(zip(rows, cols)))
    np.random.shuffle(neg_edges)

    n_pos = len(pos_edges)
    n_val = int(n_pos * val_prop)
    n_test = int(n_pos * test_prop)
    held_out = n_test + n_val

    val_edges = pos_edges[:n_val]
    test_edges = pos_edges[n_val:held_out]
    train_edges = pos_edges[held_out:]
    val_edges_false = neg_edges[:n_val]
    test_edges_false = neg_edges[n_val:held_out]
    train_edges_false = np.concatenate([neg_edges, val_edges, test_edges], axis=0)

    weights = np.ones(train_edges.shape[0])
    adj_train = sp.csr_matrix(
        (weights, (train_edges[:, 0], train_edges[:, 1])), shape=adj.shape
    )
    adj_train = adj_train + adj_train.T

    return (
        adj_train,
        torch.LongTensor(train_edges),
        torch.LongTensor(train_edges_false),
        torch.LongTensor(val_edges),
        torch.LongTensor(val_edges_false),
        torch.LongTensor(test_edges),
        torch.LongTensor(test_edges_false),
    )


def split_data(labels, val_prop, test_prop, seed):
    """Split node indices into validation, test and training lists.

    Non-zero labels are treated as positives, the rest as negatives.  Both
    groups are shuffled with ``np.random`` seeded by ``seed``; the validation
    and test sizes are proportions of the smaller group, taken from each
    group alike.  Returns ``(idx_val, idx_test, idx_train)``.
    """
    np.random.seed(seed)
    positives = labels.nonzero()[0]
    negatives = (1. - labels).nonzero()[0]
    np.random.shuffle(positives)
    np.random.shuffle(negatives)
    positives, negatives = positives.tolist(), negatives.tolist()

    smaller_group = min(len(positives), len(negatives))
    nb_val = round(val_prop * smaller_group)
    nb_test = round(test_prop * smaller_group)
    cut = nb_val + nb_test

    idx_val = positives[:nb_val] + negatives[:nb_val]
    idx_test = positives[nb_val:cut] + negatives[nb_val:cut]
    idx_train = positives[cut:] + negatives[cut:]
    return idx_val, idx_test, idx_train



from collections import Counter

def split_data2(labels, val_prop, test_prop, seed):
    """Split node indices into validation, test and training lists.

    Nodes whose label equals 1 are the positives; every node whose label is
    not 1 (``1 - label`` non-zero) is a negative.  Both groups are shuffled
    with ``np.random`` seeded by ``seed``; the validation and test sizes are
    proportions of the smaller group and are taken from each group alike.

    Args:
        labels: one-dimensional numpy array of labels.
        val_prop: fraction of the smaller group used for validation.
        test_prop: fraction of the smaller group used for testing.
        seed: seed passed to ``np.random.seed``.

    Returns:
        ``(idx_val, idx_test, idx_train)`` as lists of indices.
    """
    np.random.seed(seed)
    # Vectorised lookup of the positives.  The earlier Python-level loop and
    # the unused Counter / index-range computations were dropped; the shuffles
    # below consume the RNG exactly as before, so the split is unchanged.
    pos_idx = np.flatnonzero(labels == 1)
    neg_idx = (1 - labels).nonzero()[0]
    np.random.shuffle(pos_idx)
    np.random.shuffle(neg_idx)
    pos_idx = pos_idx.tolist()
    neg_idx = neg_idx.tolist()

    nb_pos_neg = min(len(pos_idx), len(neg_idx))
    nb_val = round(val_prop * nb_pos_neg)
    nb_test = round(test_prop * nb_pos_neg)
    cut = nb_val + nb_test

    idx_val = pos_idx[:nb_val] + neg_idx[:nb_val]
    idx_test = pos_idx[nb_val:cut] + neg_idx[nb_val:cut]
    idx_train = pos_idx[cut:] + neg_idx[cut:]
    return idx_val, idx_test, idx_train


def bin_feat(feat, bins):
    """Return the bin index of every value in ``feat``, shifted to start at 0."""
    bucket_ids = np.digitize(feat, bins)
    lowest = bucket_ids.min()
    return bucket_ids - lowest


# ############### LINK PREDICTION DATA LOADERS ####################################


def load_data_lp(dataset, use_feats, data_path):
    """Load adjacency and features for a link-prediction dataset.

    Supported names are ``'cora'``, ``'pubmed'``, ``'disease_lp'`` and
    ``'airport'``; anything else raises ``FileNotFoundError``.  Returns a
    dict with the keys ``'adj_train'`` and ``'features'``.
    """
    if dataset in ('cora', 'pubmed'):
        loaded = load_citation_data(dataset, use_feats, data_path)
        adj, features = loaded[:2]
    elif dataset == 'disease_lp':
        loaded = load_synthetic_data(dataset, use_feats, data_path)
        adj, features = loaded[:2]
    elif dataset == 'airport':
        adj, features = load_data_airport(dataset, data_path, return_label=False)
    else:
        raise FileNotFoundError('Dataset {} is not supported.'.format(dataset))
    return {'adj_train': adj, 'features': features}


# ############### NODE CLASSIFICATION DATA LOADERS ####################################


def load_data_nc(dataset, use_feats, data_path, split_seed):
    """Load a node-classification dataset together with its index splits.

    Citation datasets (``'cora'``, ``'pubmed'``, ``'citeseer'``) come with
    their splits from ``load_citation_data``; ``'disease_nc'`` and
    ``'airport'`` are split here with ``split_data`` / ``split_data2``.
    Any other name raises ``FileNotFoundError``.

    Returns a dict with the keys ``'adj_train'``, ``'features'``,
    ``'labels'`` (a ``LongTensor``), ``'idx_train'``, ``'idx_val'`` and
    ``'idx_test'``.
    """
    if dataset in ('cora', 'pubmed', 'citeseer'):
        adj, features, labels, idx_train, idx_val, idx_test = load_citation_data(
            dataset, use_feats, data_path, split_seed
        )
    elif dataset == 'disease_nc':
        adj, features, labels = load_synthetic_data(dataset, use_feats, data_path)
        val_prop, test_prop = 0.10, 0.60
        idx_val, idx_test, idx_train = split_data(labels, val_prop, test_prop, seed=split_seed)
    elif dataset == 'airport':
        adj, features, labels = load_data_airport(dataset, data_path, return_label=True)
        val_prop, test_prop = 0.05, 0.10
        idx_val, idx_test, idx_train = split_data2(labels, val_prop, test_prop, seed=split_seed)
        # Report the split sizes: train, validation, test.
        for part in (idx_train, idx_val, idx_test):
            print(len(part))
    else:
        raise FileNotFoundError('Dataset {} is not supported.'.format(dataset))

    return {
        'adj_train': adj,
        'features': features,
        'labels': torch.LongTensor(labels),
        'idx_train': idx_train,
        'idx_val': idx_val,
        'idx_test': idx_test,
    }


# ############### DATASETS ####################################

def loadRedditFromNPZ(dataset_dir):
    """Load the Reddit adjacency matrix and data arrays from ``dataset_dir``.

    ``dataset_dir`` is used as a plain string prefix for the file names
    ``reddit_adj.npz`` and ``reddit.npz``, so it must already end with a path
    separator.  Returns ``(adj, feats, y_train, y_val, y_test, train_index,
    val_index, test_index)``.
    """
    adj = sp.load_npz(dataset_dir+"reddit_adj.npz")
    arrays = np.load(dataset_dir+"reddit.npz")

    keys = ('feats', 'y_train', 'y_val', 'y_test', 'train_index', 'val_index', 'test_index')
    return (adj, *(arrays[key] for key in keys))

def load_reddit_data(data_path):
    """Load the Reddit dataset and prepare normalised torch inputs.

    Labels are assembled from the three splits, the adjacency is symmetrised
    by adding its transpose, and the features are standardised per column.
    Both the full and the training-only adjacency are normalised with
    ``aug_normalized_adjacency`` and converted to torch sparse tensors.

    Returns a dict with the keys ``'adj_all'``, ``'adj_train'``,
    ``'features'``, ``'labels'``, ``'idx_train'``, ``'idx_val'`` and
    ``'idx_test'``.
    """
    (adj, features, y_train, y_val, y_test,
     train_index, val_index, test_index) = loadRedditFromNPZ(data_path)

    labels = np.zeros(adj.shape[0])
    for index, values in ((train_index, y_train), (val_index, y_val), (test_index, y_test)):
        labels[index] = values

    adj = adj + adj.T # remove maybe?
    train_adj = adj[train_index, :][:, train_index]

    features = torch.tensor(np.array(features))
    column_mean = features.mean(dim=0)
    column_std = features.std(dim=0)
    features = (features-column_mean)/column_std

    adj = sparse_mx_to_torch_sparse_tensor(aug_normalized_adjacency(adj))
    train_adj = sparse_mx_to_torch_sparse_tensor(aug_normalized_adjacency(train_adj))
    labels = torch.LongTensor(labels)

    return {
        'adj_all': adj,
        'adj_train': train_adj,
        'features': features,
        'labels': labels,
        'idx_train': train_index,
        'idx_val': val_index,
        'idx_test': test_index,
    }


def load_citation_data(dataset_str, use_feats, data_path, split_seed=None):
    """Load a citation dataset stored as pickled ``ind.<dataset>.<part>`` files.

    Args:
        dataset_str: dataset name used in the file names; ``'citeseer'`` gets
            extra handling for test rows missing from the index file.
        use_feats: when false, the features are replaced by an identity matrix.
        data_path: directory containing the ``ind.*`` files.
        split_seed: accepted but not used by this function.

    Returns:
        ``(adj, features, labels, idx_train, idx_val, idx_test)``.
    """
    suffixes = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    loaded = []
    for suffix in suffixes:
        file_path = os.path.join(data_path, "ind.{}.{}".format(dataset_str, suffix))
        with open(file_path, 'rb') as handle:
            if sys.version_info > (3, 0):
                loaded.append(pkl.load(handle, encoding='latin1'))
            else:
                loaded.append(pkl.load(handle))
    x, y, tx, ty, allx, ally, graph = loaded

    index_path = os.path.join(data_path, "ind.{}.test.index".format(dataset_str))
    test_idx_reorder = parse_index_file(index_path)
    test_idx_range = np.sort(test_idx_reorder)

    if dataset_str == 'citeseer':
        # Some test indices are absent from the index file (isolated nodes):
        # allocate the full contiguous range and leave the gaps as zero rows.
        full_len = len(range(min(test_idx_reorder), max(test_idx_reorder)+1))
        row_offsets = test_idx_range-min(test_idx_range)

        padded_tx = sp.lil_matrix((full_len, x.shape[1]))
        padded_tx[row_offsets, :] = tx
        tx = padded_tx

        padded_ty = np.zeros((full_len, y.shape[1]))
        padded_ty[row_offsets, :] = ty
        ty = padded_ty

    # Stack train+test rows, then move the test rows to the positions given
    # by the (unsorted) index file.
    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]

    one_hot = np.vstack((ally, ty))
    one_hot[test_idx_reorder, :] = one_hot[test_idx_range, :]
    labels = np.argmax(one_hot, 1)

    n_labeled = len(y)
    idx_test = test_idx_range.tolist()
    idx_train = list(range(n_labeled))
    idx_val = range(n_labeled, n_labeled + 500)

    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
    if not use_feats:
        features = sp.eye(adj.shape[0])
    for split in (idx_train, idx_val, idx_test):
        print(len(split))
    return adj, features, labels, idx_train, idx_val, idx_test


def parse_index_file(filename):
    """Read a text file containing one integer per line.

    Leading/trailing whitespace on each line is ignored. The file is opened
    in a context manager so the handle is always closed.

    Returns:
        list of int, in file order.
    """
    with open(filename) as index_file:
        return [int(line.strip()) for line in index_file]


def load_synthetic_data(dataset_str, use_feats, data_path):
    """Load a synthetic graph from ``<dataset>.edges.csv`` plus its labels.

    Node names found in the edge list are mapped to consecutive integer ids in
    order of first appearance. The adjacency matrix is filled symmetrically.

    Args:
        dataset_str: file-name prefix of the dataset.
        use_feats: load ``<dataset>.feats.npz`` when true, otherwise use an
            identity matrix as features.
        data_path: directory containing the dataset files.

    Returns:
        ``(adj, features, labels)`` with ``adj`` as a CSR matrix.
    """
    node_ids = {}
    edge_pairs = []
    edge_path = os.path.join(data_path, "{}.edges.csv".format(dataset_str))
    with open(edge_path, 'r') as edge_file:
        raw_lines = edge_file.readlines()

    for raw in raw_lines:
        src_name, dst_name = raw.rstrip().split(',')
        # setdefault hands out the next free id on first sight of a name.
        src = node_ids.setdefault(src_name, len(node_ids))
        dst = node_ids.setdefault(dst_name, len(node_ids))
        edge_pairs.append((src, dst))

    n_nodes = len(node_ids)
    dense_adj = np.zeros((n_nodes, n_nodes))
    for src, dst in edge_pairs:
        dense_adj[src, dst] = 1.  # comment this line for directed adjacency matrix
        dense_adj[dst, src] = 1.

    if not use_feats:
        features = sp.eye(dense_adj.shape[0])
    else:
        features = sp.load_npz(os.path.join(data_path, "{}.feats.npz".format(dataset_str)))

    labels = np.load(os.path.join(data_path, "{}.labels.npy".format(dataset_str)))
    return sp.csr_matrix(dense_adj), features, labels


def load_data_airport(dataset_str, data_path, return_label=False,num_nodes=3188):
    """Load the pickled airport graph and return its adjacency and features.

    Only the first ``num_nodes`` nodes (in the graph's node order) are kept.

    Args:
        dataset_str: base name of the pickle file (``<dataset_str>.p``).
        data_path: directory containing the pickle file.
        return_label: when true, feature column 4 is split off, binned with
            ``bin_feat`` and returned as labels; only the columns before it
            are kept as features.
        num_nodes: number of leading nodes to keep.

    Returns:
        ``(adj, features)`` or ``(adj, features, labels)`` with ``adj`` and
        ``features`` as CSR matrices.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only point this
    # at trusted dataset files.
    # The context manager closes the handle, which the bare open() did not.
    with open(os.path.join(data_path, dataset_str + '.p'), 'rb') as graph_file:
        graph = pkl.load(graph_file)
    selected_nodes = list(graph.nodes)[:num_nodes]
    subgraph = graph.subgraph(selected_nodes)
    adj = nx.adjacency_matrix(subgraph)
    features = np.array([subgraph.nodes[u]['feat'] for u in subgraph.nodes()])
    if return_label:
        label_idx = 4
        labels = features[:, label_idx]
        features = features[:, :label_idx]
        labels = bin_feat(labels, bins=[7.0/7, 8.0/7, 9.0/7])
        return sp.csr_matrix(adj), sp.csr_matrix(features), labels
    else:
        return sp.csr_matrix(adj), sp.csr_matrix(features)

# ############### Loading ppi ####################################
# adapted from PetarV/GAT
def run_dfs(adj, msk, u, ind, nb_nodes):
    """Label every unvisited node reachable from ``u`` with ``ind``.

    Uses an explicit stack instead of recursion so that large connected
    components do not exceed Python's recursion limit.

    Args:
        adj: sparse adjacency matrix supporting ``adj[u, :].nonzero()``.
        msk: per-node label array, modified in place; ``-1`` marks unvisited.
        u: start node.
        ind: label written to every newly visited node.
        nb_nodes: unused; kept for signature compatibility.
    """
    if msk[u] != -1:
        return
    pending = [u]
    while pending:
        node = pending.pop()
        if msk[node] != -1:
            # Already labelled via another path while waiting on the stack.
            continue
        msk[node] = ind
        pending.extend(v for v in adj[node, :].nonzero()[1] if msk[v] == -1)

def dfs_split(adj):
    """Assign a connected-component id to every node of ``adj``.

    ``adj`` is assumed to have shape ``[nb_nodes, nb_nodes]``. Component ids
    start at 0 and increase in order of the lowest-numbered node of each
    component.

    Returns:
        int32 array of length ``nb_nodes`` holding the component id per node.
    """
    num_nodes = adj.shape[0]
    component_of = np.full(num_nodes, -1, dtype=np.int32)

    next_id = 0
    for node in range(num_nodes):
        if component_of[node] != -1:
            continue
        run_dfs(adj, component_of, node, next_id, num_nodes)
        next_id += 1

    return component_of

def test(adj, mapping):
    """Return True when no edge of ``adj`` joins nodes with different mappings.

    Args:
        adj: sparse adjacency matrix supporting ``adj[i, :].nonzero()``.
        mapping: per-node group id, indexable by node number.
    """
    return all(
        mapping[src] == mapping[dst]
        for src in range(adj.shape[0])
        for dst in adj[src, :].nonzero()[1]
    )

def find_split(adj, mapping, ds_label):
    """Map each sub-graph id to the split ('train', 'val' or 'test') it is in.

    Walks every edge of ``adj``. Edges touching group 0 only register key 0
    with value ``None``. For an edge inside a single group, both endpoints
    must carry identical ``'val'``/``'test'`` flags and agree with the split
    already recorded for that group; otherwise a message is printed and
    ``None`` is returned. Edges between two different non-zero groups are
    ignored.

    Args:
        adj: sparse adjacency matrix supporting ``adj[i, :].nonzero()``.
        mapping: per-node group id.
        ds_label: per-node mapping with boolean ``'val'`` and ``'test'`` entries.

    Returns:
        dict from group id to split name, or ``None`` on inconsistency.
    """
    def _initial_split(flags):
        # Used when a group is first seen: 'val' is checked before 'test'.
        if flags['val']:
            return 'val'
        if flags['test']:
            return 'test'
        return 'train'

    def _observed_split(flags):
        # Used for the consistency check: 'test' is checked before 'val'.
        if flags['test']:
            return 'test'
        if flags['val']:
            return 'val'
        return 'train'

    group_split = {}
    for src in range(adj.shape[0]):
        for dst in adj[src, :].nonzero()[1]:
            if mapping[src]==0 or mapping[dst]==0:
                group_split[0]=None
                continue
            if mapping[src] != mapping[dst]:
                continue

            src_flags, dst_flags = ds_label[src], ds_label[dst]
            flags_match = (src_flags['val'] == dst_flags['val']
                           and src_flags['test'] == dst_flags['test'])
            if not flags_match:
                print ('label of both nodes different, exiting!!')
                return None

            group = mapping[src]
            if group not in group_split.keys():
                group_split[group] = _initial_split(src_flags)
            elif group_split[group] != _observed_split(src_flags):
                print ('inconsistent labels within a graph exiting!!!')
                return None
    return group_split

def load_ppi(data_path):
    """Load the PPI dataset and split it into padded per-sub-graph tensors.

    Reads ``ppi-G.json``, ``ppi-id_map.json``, ``ppi-feats.npy`` and
    ``ppi-class_map.json`` from ``data_path`` (used as a plain string prefix,
    so it must end with a path separator). The graph is split into connected
    components, every component is padded to the size of the largest one, and
    the components are grouped into train/val/test according to the node flags.

    Changes relative to the previous revision:
      * ``G.nodes[n]`` replaces the removed ``G.node[n]`` accessor.
      * The label buffer is freshly allocated for full-size sub-graphs instead
        of reusing (or lacking) the one from an earlier iteration.
      * Test adjacencies are converted in their own loop instead of inside the
        validation loop, so differing val/test counts no longer break.
      * A leftover ``pdb.set_trace()`` and a no-op debugging loop were removed.

    Returns:
        ``(train_adj_list, val_adj_list, test_adj_list, train_feat, val_feat,
        test_feat, train_labels, val_labels, test_labels, train_nodes,
        val_nodes, test_nodes)``.
    """
    def _normalized_adj_tensors(dense_stack):
        """Symmetrise, normalise and convert every dense adjacency in a stack."""
        tensors = []
        for dense in dense_stack:
            sub_adj = sp.coo_matrix(dense)
            # Keep the larger of a_ij / a_ji so the matrix becomes symmetric.
            sub_adj = sub_adj + sub_adj.T.multiply(sub_adj.T > sub_adj) - sub_adj.multiply(sub_adj.T > sub_adj)
            tmp = aug_normalized_adjacency(sub_adj)
            tensors.append(sparse_mx_to_torch_sparse_tensor(tmp))
        return tensors

    print ('Loading G...')
    with open(data_path + 'ppi-G.json') as jsonfile:
        g_data = json.load(jsonfile)
    G = json_graph.node_link_graph(g_data)

    #Extracting adjacency matrix
    adj=nx.adjacency_matrix(G)

    with open(data_path + 'ppi-id_map.json') as jsonfile:
        id_map = json.load(jsonfile)

    # Each id is wrapped in a one-element list so that train_ids below becomes
    # a 2-D array indexed as train_ids[:, 0].
    id_map = {int(k): [int(v)] for k, v in id_map.items()}

    print ('Loading features...')
    features_=np.load(data_path + 'ppi-feats.npy')

    #standarizing features (scaler fitted on training nodes only)
    from sklearn.preprocessing import StandardScaler

    train_ids = np.array([id_map[n] for n in G.nodes() if not G.nodes[n]['val'] and not G.nodes[n]['test']])
    train_feats = features_[train_ids[:,0]]
    scaler = StandardScaler()
    scaler.fit(train_feats)
    features_ = scaler.transform(features_)

    features = sp.csr_matrix(features_).tolil()

    print ('Loading class_map...')
    with open(data_path + 'ppi-class_map.json') as jsonfile:
        class_map = json.load(jsonfile)

    #Split graph into sub-graphs
    splits=dfs_split(adj)

    #Rearrange sub-graph index and append sub-graphs with 1 or 2 nodes to bigger sub-graphs
    # NOTE(review): the target ids 1 / 21 / 23 are hard-coded; presumably they
    # are the first train, val and test sub-graph of the PPI data - confirm.
    list_splits=splits.tolist()
    group_inc=1

    for i in range(np.max(list_splits)+1):
        if list_splits.count(i)>=3:
            splits[np.array(list_splits) == i] =group_inc
            group_inc+=1
        else:
            ind_nodes=np.argwhere(np.array(list_splits) == i)
            ind_nodes=ind_nodes[:,0].tolist()
            split=None

            for ind_node in ind_nodes:
                if g_data['nodes'][ind_node]['val']:
                    if split is None or split=='val':
                        splits[np.array(list_splits) == i] = 21
                        split='val'
                    else:
                        raise ValueError('new node is VAL but previously was {}'.format(split))
                elif g_data['nodes'][ind_node]['test']:
                    if split is None or split=='test':
                        splits[np.array(list_splits) == i] = 23
                        split='test'
                    else:
                        raise ValueError('new node is TEST but previously was {}'.format(split))
                else:
                    if split is None or split == 'train':
                        splits[np.array(list_splits) == i] = 1
                        split='train'
                    else:
                        raise ValueError('new node is TRAIN but previously was {}'.format(split))

    #counting number of nodes per sub-graph
    list_splits=splits.tolist()
    nodes_per_graph=[]
    for i in range(1,np.max(list_splits) + 1):
        nodes_per_graph.append(list_splits.count(i))

    #Splitting adj matrix into sub-graphs
    subgraph_nodes=np.max(nodes_per_graph)
    adj_sub=np.empty((len(nodes_per_graph), subgraph_nodes, subgraph_nodes))
    feat_sub = np.empty((len(nodes_per_graph), subgraph_nodes, features.shape[1]))
    # 121 label columns per node (hard-coded).
    labels_sub = np.empty((len(nodes_per_graph), subgraph_nodes, 121))

    for i in range(1, np.max(list_splits) + 1):
        #Creating same size sub-graphs
        indexes = np.where(splits == i)[0]
        subgraph_=adj[indexes,:][:,indexes]

        if subgraph_.shape[0]<subgraph_nodes or subgraph_.shape[1]<subgraph_nodes:
            # Smaller than the largest sub-graph: pad with identity rows in the
            # adjacency and zero rows in features and labels.
            subgraph=np.identity(subgraph_nodes)
            feats=np.zeros([subgraph_nodes, features.shape[1]])
            labels=np.zeros([subgraph_nodes,121])
            #adj
            subgraph = sp.csr_matrix(subgraph).tolil()
            subgraph[0:subgraph_.shape[0],0:subgraph_.shape[1]]=subgraph_
            adj_sub[i-1,:,:]=subgraph.todense()
            #feats
            feats[0:len(indexes)]=features[indexes,:].todense()
            feat_sub[i-1,:,:]=feats
            #labels
            for j,node in enumerate(indexes):
                labels[j,:]=np.array(class_map[str(node)])
            labels[indexes.shape[0]:subgraph_nodes,:]=np.zeros([121])
            labels_sub[i - 1, :, :] = labels

        else:
            adj_sub[i - 1, :, :] = subgraph_.todense()
            feat_sub[i - 1, :, :]=features[indexes,:].todense()
            # Fresh buffer: previously this branch wrote into whatever
            # `labels` an earlier iteration had left behind.
            labels=np.zeros([subgraph_nodes,121])
            for j,node in enumerate(indexes):
                labels[j,:]=np.array(class_map[str(node)])
            labels_sub[i-1, :, :] = labels

    # Get relation between id sub-graph and tran,val or test set
    dict_splits = find_split(adj, splits, g_data['nodes'])

    #Splitting tensors into train,val and test
    train_split=[]
    val_split=[]
    test_split=[]
    for key, value in dict_splits.items():
        if value=='train':
            train_split.append(int(key)-1)
        elif value == 'val':
            val_split.append(int(key)-1)
        elif value == 'test':
            test_split.append(int(key)-1)

    train_adj=adj_sub[train_split,:,:]
    val_adj=adj_sub[val_split,:,:]
    test_adj=adj_sub[test_split,:,:]

    train_feat=feat_sub[train_split,:,:]
    val_feat = feat_sub[val_split, :, :]
    test_feat = feat_sub[test_split, :, :]

    train_labels = labels_sub[train_split, :, :]
    val_labels = labels_sub[val_split, :, :]
    test_labels = labels_sub[test_split, :, :]

    # NOTE(review): these slices assume each split occupies a contiguous range
    # of sub-graph ids - confirm against the data.
    train_nodes=np.array(nodes_per_graph[train_split[0]:train_split[-1]+1])
    val_nodes = np.array(nodes_per_graph[val_split[0]:val_split[-1]+1])
    test_nodes = np.array(nodes_per_graph[test_split[0]:test_split[-1]+1])

    #Masks with ones
    # NOTE: the masks are built but not part of the returned tuple.
    tr_msk = np.zeros((len(nodes_per_graph[train_split[0]:train_split[-1]+1]), subgraph_nodes))
    vl_msk = np.zeros((len(nodes_per_graph[val_split[0]:val_split[-1] + 1]), subgraph_nodes))
    ts_msk = np.zeros((len(nodes_per_graph[test_split[0]:test_split[-1]+1]), subgraph_nodes))

    for i in range(len(train_nodes)):
        for j in range(train_nodes[i]):
            tr_msk[i][j] = 1

    for i in range(len(val_nodes)):
        for j in range(val_nodes[i]):
            vl_msk[i][j] = 1

    for i in range(len(test_nodes)):
        for j in range(test_nodes[i]):
            ts_msk[i][j] = 1

    # Each split is converted independently of the others.
    train_adj_list = _normalized_adj_tensors(train_adj)
    val_adj_list = _normalized_adj_tensors(val_adj)
    test_adj_list = _normalized_adj_tensors(test_adj)

    train_feat = torch.tensor(train_feat)
    val_feat = torch.tensor(val_feat)
    test_feat = torch.tensor(test_feat)

    train_labels = torch.tensor(train_labels)
    val_labels = torch.tensor(val_labels)
    test_labels = torch.tensor(test_labels)

    tr_msk = torch.LongTensor(tr_msk)
    vl_msk = torch.LongTensor(vl_msk)
    ts_msk = torch.LongTensor(ts_msk)

    return train_adj_list,val_adj_list,test_adj_list,train_feat,val_feat,test_feat,train_labels,val_labels, test_labels, train_nodes, val_nodes, test_nodes

# ############### Loading for TextHyLa ####################################
# adapted from Tiiiger/SGC
def parse_index_file(filename):
    """Parse index file.

    Reads one integer per line (surrounding whitespace is ignored) and closes
    the file when done.

    Returns:
        list of int, in file order.
    """
    with open(filename) as index_file:
        return [int(line.strip()) for line in index_file]

def load_corpus(data_dir, dataset_str, inductive=False):
    """
    Loads the pickled text corpus files from ``data_dir``.

    For every phase in ``train``, ``val`` and ``test`` the following files are
    read:

    {data_dir}/ind.{dataset_str}.{phase}.x => stored in ``index_dict[phase]``;
    {data_dir}/ind.{dataset_str}.{phase}.y => stored in ``label_dict[phase]``.

    The adjacency is read from ``ind.{dataset_str}.B.adj`` when ``inductive``
    is true, otherwise from ``ind.{dataset_str}.BCD.adj``; it is cast to
    float32 and, in the non-inductive case only, passed through
    ``aug_normalized_adjacency``.

    All objects above must be saved using python pickle module.

    :param data_dir: Directory containing the ``ind.*`` files
    :param dataset_str: Dataset name
    :param inductive: Selects which adjacency file is loaded (see above)
    :return: ``(adj, index_dict, label_dict)``
    """
    def load_pkl(path):
        # `path` is already fully formatted by the caller. It used to be
        # passed through str.format() again, which corrupted any path that
        # contained braces and silently depended on the loop variable.
        # NOTE(security): pickle can execute arbitrary code - trusted files only.
        with open(path, 'rb') as f:
            return pkl.load(f, encoding='latin1')

    index_dict = {}
    label_dict = {}
    for phase in ("train", "val", "test"):
        index_dict[phase] = load_pkl("{}/ind.{}.{}.x".format(data_dir, dataset_str, phase))
        label_dict[phase] = load_pkl("{}/ind.{}.{}.y".format(data_dir, dataset_str, phase))

    if inductive:
        adj = load_pkl("{}/ind.{}.B.adj".format(data_dir, dataset_str))
        adj = adj.astype(np.float32)
    else:
        adj = load_pkl("{}/ind.{}.BCD.adj".format(data_dir, dataset_str))
        adj = adj.astype(np.float32)
        adj = aug_normalized_adjacency(adj)

    return adj, index_dict, label_dict

def aug_normalized_adjacency(adj):
    """Return the symmetrically normalised adjacency with self-loops added.

    Computes ``D^-1/2 (A + I) D^-1/2`` where ``D`` holds the row sums of
    ``A + I``. Entries of ``D^-1/2`` that come out infinite are set to zero.

    Returns:
        The normalised matrix in COO format.
    """
    with_loops = sp.coo_matrix(adj + sp.eye(adj.shape[0]))
    degrees = np.array(with_loops.sum(1))
    inv_sqrt_deg = np.power(degrees, -0.5).flatten()
    # A zero row sum yields inf; zero it so that row/column is wiped out.
    inv_sqrt_deg[np.isinf(inv_sqrt_deg)] = 0.
    scaling = sp.diags(inv_sqrt_deg)
    scaled_rows = scaling.dot(with_loops)
    return scaled_rows.dot(scaling).tocoo()

def loadWord2Vec(filename):
    """Read Word Vectors

    Each line is split on single spaces; lines with more than two fields are
    treated as ``<word> <v1> <v2> ...`` and everything else is skipped. The
    file is streamed line by line and closed even if a value fails to parse.

    Returns:
        ``(vocab, embd, word_vector_map)``: the words in file order, their
        vectors as lists of float, and a dict from word to vector.

    Raises:
        ValueError: if a vector component cannot be converted to float.
    """
    vocab = []
    embd = []
    word_vector_map = {}
    with open(filename, 'r') as file:
        for line in file:
            row = line.strip().split(' ')
            if len(row) > 2:
                word = row[0]
                vector = [float(value) for value in row[1:]]
                vocab.append(word)
                embd.append(vector)
                word_vector_map[word] = vector
    print('Loaded Word Vectors!')
    return vocab, embd, word_vector_map

def clean_str(string):
    """Normalise text to lower-case ASCII letters, digits and single spaces.

    Every character other than ``a-z``, ``A-Z``, ``0-9`` and the space is
    removed, runs of spaces are collapsed to one, and the result is stripped
    and lower-cased.

    The earlier revision followed the character filter with a series of
    substitutions on apostrophes, commas, brackets and ``?``/``!``. Those
    characters had already been removed by then, so the substitutions could
    never match (and some used invalid escape sequences such as ``"\\("``).
    They were dropped; the output is unchanged.
    """
    string = re.sub(r'[^a-zA-Z0-9 ]', r'', string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower()

def sparse_to_torch_sparse(sparse_mx, device='cuda'):
    """Convert a scipy sparse matrix to a torch sparse tensor.

    The matrix is converted to COO format with float32 values and int64
    indices.

    Args:
        sparse_mx: scipy sparse matrix.
        device: target device. Besides ``'cuda'`` and ``'cpu'`` any value
            accepted by ``torch.sparse_coo_tensor`` (e.g. ``'cuda:1'`` or a
            ``torch.device``) now works; previously such values left the
            result unassigned.

    Returns:
        A float32 sparse COO tensor on ``device``.
    """
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse_coo_tensor(indices, values, shape, device=device)

def sparse_to_torch_dense(sparse, device='cuda'):
    """Densify a scipy sparse matrix and return it as a float32 torch tensor.

    Args:
        sparse: object providing ``todense()``.
        device: device the resulting tensor is moved to.
    """
    as_float32 = sparse.todense().astype(np.float32)
    return torch.from_numpy(as_float32).to(device=device)

def sgc_precompute_text(adj, features, degree, index_dict):
    """Propagate and min-max scale the features of each phase.

    For every phase the columns ``features[:, index_dict[phase]]`` are
    multiplied by ``adj`` ``degree`` times and transposed. Min and range are
    computed per feature on the training phase only; features whose training
    range is zero are dropped, and all phases are scaled with the training
    statistics.

    Changes relative to the previous revision:
      * ``val``/``test`` features are propagated ``degree`` times like the
        training features (they used to be propagated once regardless of
        ``degree``); results for ``degree == 1`` are unchanged.
      * Useful features are selected with a boolean mask, so the outputs stay
        2-D when exactly one feature has a non-zero range.

    Args:
        adj: matrix accepted as first argument of ``torch.spmm``.
        features: 2-D tensor whose columns are indexed by ``index_dict``.
        degree: number of propagation steps, must be positive.
        index_dict: mapping with ``'train'``, ``'val'`` and ``'test'`` indices.

    Returns:
        ``(feat_dict, precompute_time)``: double-precision features per phase
        and the elapsed time in seconds.
    """
    assert degree > 0, 'invalid degree as 0'
    start = perf_counter()

    def _propagate(phase):
        feats = features[:, index_dict[phase]]#.cuda()
        for _ in range(degree):
            feats = torch.spmm(adj, feats)
        return feats.t()

    train_feats = _propagate("train")
    train_feats_max, _ = train_feats.max(dim=0, keepdim=True)
    train_feats_min, _ = train_feats.min(dim=0, keepdim=True)
    train_feats_range = train_feats_max-train_feats_min
    # Features that are constant over the training set cannot be scaled.
    useful_features = train_feats_range.squeeze(0).gt(0)
    train_feats = train_feats[:, useful_features]
    train_feats_range = train_feats_range[:, useful_features]
    train_feats_min = train_feats_min[:, useful_features]

    feat_dict = {}
    feat_dict["train"] = ((train_feats-train_feats_min)/train_feats_range).double()
    for phase in ["test", "val"]:
        feats = _propagate(phase)[:, useful_features]
        feat_dict[phase] = ((feats-train_feats_min)/train_feats_range).cpu().double() # adj is symmetric!
    precompute_time = perf_counter()-start
    return feat_dict, precompute_time

def sgc_precompute_text_v1(adj, features, degree, index_dict):
    """Propagate all features, then min-max scale the rows of each phase.

    ``features`` is multiplied by ``adj`` ``degree`` times; the rows selected
    by ``index_dict[phase]`` are then scaled with the per-feature min and
    range of the training rows. Features whose training range is zero are
    dropped.

    Useful features are selected with a boolean mask so that the outputs stay
    2-D even when exactly one feature has a non-zero range (the previous
    ``nonzero().squeeze()`` index collapsed that case to 1-D).

    Args:
        adj: matrix accepted as first argument of ``torch.spmm``.
        features: 2-D tensor whose rows are indexed by ``index_dict``.
        degree: number of propagation steps, must be positive.
        index_dict: mapping with ``'train'``, ``'val'`` and ``'test'`` indices.

    Returns:
        ``(feat_dict, precompute_time)``: double-precision features per phase
        and the elapsed time in seconds.
    """
    assert degree > 0, 'invalid degree as 0'
    feat_dict = {}
    start = perf_counter()
    for _ in range(degree):
        features = torch.spmm(adj, features)
    train_feats = features[index_dict["train"], :].double()
    train_feats_max, _ = train_feats.max(dim=0, keepdim=True)
    train_feats_min, _ = train_feats.min(dim=0, keepdim=True)
    train_feats_range = train_feats_max-train_feats_min
    # Features that are constant over the training set cannot be scaled.
    useful_features = train_feats_range.squeeze(0).gt(0)
    train_feats = train_feats[:, useful_features]
    train_feats_range = train_feats_range[:, useful_features]
    train_feats_min = train_feats_min[:, useful_features]
    feat_dict["train"] = (train_feats-train_feats_min)/train_feats_range
    for phase in ["test", "val"]:
        feats = features[index_dict[phase], :].double()
        feats = feats[:, useful_features]
        feat_dict[phase] = ((feats-train_feats_min)/train_feats_range).cpu() # adj is symmetric!
    precompute_time = perf_counter()-start
    return feat_dict, precompute_time

def set_seed(seed, cuda):
    """Seed the numpy and torch random generators.

    Args:
        seed: value passed to every generator.
        cuda: when truthy, the CUDA generator is seeded as well.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not cuda:
        return
    torch.cuda.manual_seed(seed)

def print_table(values, columns, epoch):
    """Print one row of a running results table.

    The row is rendered with ``tabulate`` in ``'simple'`` format. When
    ``epoch`` is a multiple of 40 the whole rendered table is printed with its
    second line repeated on top; otherwise only the third rendered line is
    printed.

    Args:
        values: the cell values of the row.
        columns: the column headers.
        epoch: current epoch number.
    """
    rendered = tabulate.tabulate([values], columns, tablefmt='simple', floatfmt='8.4f')
    lines = rendered.split('\n')
    if epoch % 40 == 0:
        output = '\n'.join([lines[1]] + lines)
    else:
        output = lines[2]
    print(output)





###RSGD/SGD/PGD optimizer###

In [None]:
# @title
#!/usr/bin/env python3
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torch.optim.optimizer import Optimizer, required
import tensorly as tl

class RiemannianSGD(Optimizer):
    r"""Riemannian stochastic gradient descent.

    Args:
        params: iterable of parameters or dicts defining parameter groups
        lr (float): learning rate
        rgrad (Function): Function to compute the Riemannian gradient
           from the Euclidean gradient; called as ``rgrad(p.data, grad)``
        expm (Function): Function that applies the scaled Riemannian
           gradient to the parameter; called as ``expm(p.data, update)``
    """

    def __init__(
            self,
            params,
            lr=required,
            rgrad=required,
            expm=required,
    ):
        defaults = {
            'lr': lr,
            'rgrad': rgrad,
            'expm': expm,
        }
        super(RiemannianSGD, self).__init__(params, defaults)

    def step(self, lr=None, counts=None, **kwargs):
        """Performs a single optimization step.

        Arguments:
            lr (float, optional): learning rate for the current update. When
                omitted (or falsy) each parameter group uses its own ``lr``.
            counts: unused; accepted for call compatibility.

        Returns:
            Always ``None``.
        """
        loss = None

        for group in self.param_groups:
            # Resolve the rate per group without overwriting `lr`; previously
            # the first group's rate was reused for every later group.
            group_lr = lr or group['lr']
            rgrad = group['rgrad']
            expm = group['expm']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # make sure we have no duplicates in sparse tensor
                if d_p.is_sparse:
                    d_p = d_p.coalesce()
                d_p = rgrad(p.data, d_p)
                d_p.mul_(-group_lr)
                expm(p.data, d_p)

        return loss


from torch.optim.sgd import SGD
from torch.optim.optimizer import required
from torch.optim import Optimizer
import torch
import sklearn
import numpy as np
import scipy.sparse as sp

class PGD(Optimizer):
    """Proximal gradient descent.

    ``step`` only applies the proximal operators to the parameters; it does
    not perform a gradient update itself.

    Parameters
    ----------
    params : iterable
        iterable of parameters to optimize or dicts defining parameter groups
    proxs : iterable
        iterable of proximal operators, each called as
        ``prox(param.data, alpha=...)``
    alphas : iterable
        iterable of coefficients, one per proximal operator
    lr : float
        learning rate; each coefficient is multiplied by it
    momentum : float
        momentum factor (default: 0); stored in the group, not used by ``step``
    weight_decay : float
        weight decay (L2 penalty) (default: 0); stored, not used by ``step``
    dampening : float
        dampening for momentum (default: 0); stored, not used by ``step``

    """

    def __init__(self, params, proxs, alphas, lr=required, momentum=0, dampening=0, weight_decay=0):
        # The constructor arguments are stored as given (they used to be
        # replaced by hard-coded zeros).
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=False)

        super(PGD, self).__init__(params, defaults)

        for group in self.param_groups:
            group.setdefault('proxs', proxs)
            group.setdefault('alphas', alphas)

    def __setstate__(self, state):
        super(PGD, self).__setstate__(state)
        # 'proxs' and 'alphas' are stored in the param groups by __init__ and
        # travel with the state; the earlier attempt to default them here
        # referenced undefined names and raised NameError.
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, delta=0, closure=None):
        """Apply every proximal operator to every parameter.

        ``delta`` and ``closure`` are accepted for call compatibility and are
        not used.
        """
        for group in self.param_groups:
            lr = group['lr']
            proxs = group['proxs']
            alphas = group['alphas']

            # apply the proximal operator to each parameter in a group
            for param in group['params']:
                for prox_operator, alpha in zip(proxs, alphas):
                    param.data = prox_operator(param.data, alpha=alpha*lr)

class ProxOperators():
    """Proximal Operators.

    Every nuclear-norm operator stores the sum of the singular values it
    computed (before shrinking) in ``self.nuclear_norm``.
    """

    def __init__(self):
        self.nuclear_norm = None

    def prox_l1(self, data, alpha):
        """Proximal operator for l1 norm (element-wise soft-thresholding).
        """
        data = torch.mul(torch.sign(data), torch.clamp(torch.abs(data)-alpha, min=0))
        return data

    def prox_nuclear(self, data, alpha):
        """Proximal operator for nuclear norm (trace norm).

        Uses a full SVD computed with numpy on the CPU and shrinks every
        singular value by ``alpha``.
        NOTE(review): the final product only has matching shapes for square
        ``data`` - confirm callers only pass square matrices.
        """
        device = data.device
        U, S, V = np.linalg.svd(data.cpu())
        U, S, V = torch.FloatTensor(U).to(device), torch.FloatTensor(S).to(device), torch.FloatTensor(V).to(device)
        self.nuclear_norm = S.sum()
        # print("nuclear norm: %.4f" % self.nuclear_norm)

        diag_S = torch.diag(torch.clamp(S-alpha, min=0))
        return torch.matmul(torch.matmul(U, diag_S), V)

    def prox_nuclear_truncated_2(self, data, alpha, k=50):
        """Nuclear-norm proximal operator using tensorly's truncated SVD.

        Only the ``k`` singular triplets returned by ``tl.truncated_svd`` are
        kept; the singular values are shrunk by ``alpha``.
        """
        device = data.device
        tl.set_backend('pytorch')
        U, S, V = tl.truncated_svd(data.cpu(), n_eigenvecs=k)
        U, S, V = torch.FloatTensor(U).to(device), torch.FloatTensor(S).to(device), torch.FloatTensor(V).to(device)
        self.nuclear_norm = S.sum()
        # print("nuclear norm: %.4f" % self.nuclear_norm)

        S = torch.clamp(S-alpha, min=0)

        # make diag_S sparse matrix
        indices = torch.tensor((range(0, len(S)), range(0, len(S)))).to(device)
        values = S
        diag_S = torch.sparse.FloatTensor(indices, values, torch.Size((len(S), len(S))))
        V = torch.spmm(diag_S, V)
        V = torch.matmul(U, V)
        return V

    def prox_nuclear_truncated(self, data, alpha, k=50):
        """Nuclear-norm proximal operator using scipy's sparse truncated SVD.

        The non-zero entries of the 2-D tensor ``data`` are copied into a CSR
        matrix and the ``k`` singular triplets returned by ``svds`` are used;
        the singular values are shrunk by ``alpha``.

        The CSR matrix is created with the explicit shape of ``data``;
        without it scipy infers the shape from the largest non-zero index, so
        trailing all-zero rows or columns shrank the result. The SVD outputs
        are made contiguous before conversion because torch rejects arrays
        with negative strides.
        """
        device = data.device
        indices = torch.nonzero(data).t()
        values = data[indices[0], indices[1]] # modify this based on dimensionality
        data_sparse = sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()),
                                    shape=tuple(data.shape))
        U, S, V = sp.linalg.svds(data_sparse, k=k)
        U, S, V = (torch.FloatTensor(np.ascontiguousarray(factor)).to(device)
                   for factor in (U, S, V))
        self.nuclear_norm = S.sum()
        diag_S = torch.diag(torch.clamp(S-alpha, min=0))
        return torch.matmul(torch.matmul(U, diag_S), V)

    def prox_nuclear_cuda(self, data, alpha):
        """Nuclear-norm proximal operator computed with ``torch.svd``.

        Runs on whatever device ``data`` lives on.
        NOTE(review): the sparse diagonal is sized from ``U.shape``, which
        presumably requires square ``data`` - confirm.
        """
        device = data.device
        U, S, V = torch.svd(data)
        # self.nuclear_norm = S.sum()
        # print(f"rank = {len(S.nonzero())}")
        self.nuclear_norm = S.sum()
        S = torch.clamp(S-alpha, min=0)
        indices = torch.tensor([range(0, U.shape[0]),range(0, U.shape[0])]).to(device)
        values = S
        diag_S = torch.sparse.FloatTensor(indices, values, torch.Size(U.shape))
        # diag_S = torch.diag(torch.clamp(S-alpha, min=0))
        # print(f"rank_after = {len(diag_S.nonzero())}")
        V = torch.spmm(diag_S, V.t_())
        V = torch.matmul(U, V)
        return V


class SGD(Optimizer):
    """Stochastic gradient descent with optional momentum and weight decay.

    Args:
        params: iterable of parameters or dicts defining parameter groups
        lr (float): learning rate
        momentum (float): momentum factor (default: 0)
        dampening (float): dampening for momentum (default: 0)
        weight_decay (float): weight decay (L2 penalty) (default: 0)
        nesterov (bool): enables Nesterov momentum (default: False)

    Raises:
        ValueError: for a negative ``lr``, ``momentum`` or ``weight_decay``,
            or when ``nesterov`` is set without momentum or with dampening.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super(SGD, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SGD, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']

            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # The keyword form add_(tensor, alpha=...) replaces the
                # deprecated positional add_(scalar, tensor) overload.
                if weight_decay != 0:
                    # In place: the stored gradient itself receives the decay term.
                    d_p.add_(p.data, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf

                p.data.add_(d_p, alpha=-group['lr'])

        return loss

# Module-level ProxOperators instance created once when this cell runs.
prox_operators = ProxOperators()

##Pro-GNN##

In [None]:
# @title
import time
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# from deeprobust.graph.utils import accuracy
# from deeprobust.graph.defense.pgd import PGD, prox_operators
import warnings

class ProGNN:
    """ProGNN (Properties Graph Neural Network) trainer for a two-part model.

    See "Graph Structure Learning for Robust Graph Neural Networks", KDD 2020,
    https://arxiv.org/abs/2005.10203 and https://github.com/ChandlerBang/Pro-GNN.

    Parameters
    ----------
    model_f :
        Feature model. It is called with no arguments and must provide
        ``optim_params()`` (consumed by ``RiemannianSGD``).
    model_c :
        Classifier applied to ``propagated_features @ model_f()``.
    args :
        Pro-GNN configuration. The attributes read here are ``symmetric``,
        ``lr_adj``, ``alpha``, ``beta``, ``gamma``, ``lambda_``, ``phi``,
        ``epochs``, ``outer_steps``, ``inner_steps`` and ``debug``.
    device : str
        'cpu' or 'cuda'.
    """

    def __init__(self, model_f, model_c, args, device):
        self.device = device
        self.args = args
        self.best_val_acc = 0
        self.best_val_loss = 10
        self.best_graph = None
        self.weights_f = None
        self.weights_c = None
        self.estimator = None
        # Created here so save_hyla_features() is safe even if train_gcn()
        # never ran (e.g. inner_steps == 0).
        self.hyla_features_list = []
        self.model_f = model_f.to(device)
        self.model_c = model_c.to(device)

    def fit(self, features, adj, labels, idx_train, idx_val, opt, is_airport=False, title="", **kwargs):
        """Train Pro-GNN.

        Parameters
        ----------
        features :
            node features
        adj :
            the (possibly perturbed) adjacency matrix used to initialise the
            estimator and as the reference in the Frobenius term
        labels :
            node labels
        idx_train :
            node training indices
        idx_val :
            node validation indices
        opt :
            options object; ``lr_e``, ``lr_c``, ``order`` and ``device`` are read
        is_airport : bool
            forwarded to the propagation step (see ``_propagate``)
        title : str
            suffix for the saved feature file name
        **kwargs :
            accepted for call compatibility and ignored
        """
        args = self.args

        self.optimizer_f = RiemannianSGD(self.model_f.optim_params(), lr=opt.lr_e)
        self.optimizer_c = torch.optim.Adam(self.model_c.parameters(), lr=opt.lr_c)

        estimator = EstimateAdj(adj, symmetric=args.symmetric, device=self.device).to(self.device)
        self.estimator = estimator
        self.optimizer_adj = optim.SGD(estimator.parameters(),
                                       momentum=0.9, lr=args.lr_adj)

        self.optimizer_l1 = PGD(estimator.parameters(),
                                proxs=[prox_operators.prox_l1],
                                lr=args.lr_adj, alphas=[args.alpha])

        # The nuclear-norm proximal step already uses prox_nuclear_cuda, so the
        # old warning that advised switching to it was stale and was removed.
        self.optimizer_nuclear = PGD(estimator.parameters(),
                                     proxs=[prox_operators.prox_nuclear_cuda],
                                     lr=args.lr_adj, alphas=[args.beta])

        # Train model: alternate adjacency updates and classifier updates.
        # (The original `args.only_gcn` shortcut is disabled in this notebook.)
        t_total = time.time()
        for epoch in range(args.epochs):
            for _ in range(int(args.outer_steps)):
                self.train_adj(epoch, features, adj, labels,
                               idx_train, idx_val, opt, is_airport)

            for _ in range(int(args.inner_steps)):
                self.train_gcn(epoch, features, estimator.estimated_adj,
                               labels, idx_train, idx_val, opt, is_airport)

        self.save_hyla_features(title)
        print("Optimization Finished!")
        print("Total time elapsed: {:.4f}s".format(time.time() - t_total))
        print(args)

    @staticmethod
    def _propagate(adj, features, order, is_airport):
        """Run ``sgc_precompute`` and return only the propagated features.

        When ``is_airport`` is set, the adjacency itself is used as the feature
        matrix and one propagation step fewer is requested.
        """
        if is_airport:
            propagated, _ = sgc_precompute(adj, adj, order - 1)
        else:
            propagated, _ = sgc_precompute(adj, features, order)
        return propagated

    def _track_best(self, graph, acc_val, loss_val):
        """Snapshot graph and model weights when validation acc or loss improves."""
        better_acc = acc_val > self.best_val_acc
        better_loss = loss_val < self.best_val_loss
        if better_acc:
            self.best_val_acc = acc_val
        if better_loss:
            self.best_val_loss = loss_val
        if better_acc or better_loss:
            self.best_graph = graph.detach()
            self.weights_f = deepcopy(self.model_f.state_dict())
            self.weights_c = deepcopy(self.model_c.state_dict())
        if self.args.debug:
            if better_acc:
                print('\t=== saving current graph/gcn, best_val_acc: %s' % self.best_val_acc.item())
            if better_loss:
                print('\t=== saving current graph/gcn, best_val_loss: %s' % self.best_val_loss.item())

    def train_gcn(self, epoch, features, adj, labels, idx_train, idx_val, opt, is_airport):
        """One optimisation step of model_f / model_c on the current estimated graph.

        The ``adj`` argument is ignored; the normalised adjacency of
        ``self.estimator`` is used instead.
        """
        # Only the features of the most recent step are kept for saving.
        self.hyla_features_list = []
        # Detached: this step updates only model_f / model_c, so no gradient
        # with respect to the estimated adjacency is needed.
        adj = self.estimator.normalize().detach()
        t = time.time()
        self.model_f.train()
        self.model_c.train()
        self.optimizer_f.zero_grad()
        self.optimizer_c.zero_grad()

        HyLa_features = self.model_f()
        new_features = self._propagate(adj, features, opt.order, is_airport)
        train_labels = labels[idx_train].to(opt.device)

        HyLa_features = torch.mm(new_features[idx_train].to(opt.device), HyLa_features)
        predictions = self.model_c(HyLa_features)
        del HyLa_features  # free the intermediate projection
        loss_train = F.cross_entropy(predictions, train_labels)
        loss_train.backward()
        self.optimizer_f.step()
        self.optimizer_c.step()
        acc_train, f1_train = acc_f1(predictions, train_labels)

        # Validation pass: eval mode and no autograd graph, so the stored
        # best_val_loss does not keep a graph alive.
        self.model_f.eval()
        self.model_c.eval()
        with torch.no_grad():
            HyLa_features = self.model_f()
            self.hyla_features_list.append(HyLa_features.detach().cpu().numpy())
            val_labels = labels[idx_val].to(opt.device)
            HyLa_features = torch.mm(new_features[idx_val].to(HyLa_features.device), HyLa_features)
            predictions = self.model_c(HyLa_features)
            del HyLa_features
            loss_val = F.cross_entropy(predictions, val_labels)
            acc_val, f1_val = acc_f1(predictions, val_labels)

        self._track_best(adj, acc_val, loss_val)

        print('Epoch: {:04d}'.format(epoch+1),
              'loss_train: {:.4f}'.format(loss_train.item()),
              'acc_train: {:.4f}'.format(acc_train.item()),
              'loss_val: {:.4f}'.format(loss_val.item()),
              'acc_val: {:.4f}'.format(acc_val.item()),
              'time: {:.4f}s'.format(time.time() - t))

    def save_hyla_features(self, title=""):
        """Save the recorded model_f outputs to ``hyla_features<title>.npy``.

        Emits a warning and writes nothing when no features were recorded.
        """
        if not self.hyla_features_list:
            warnings.warn("No HyLa features were recorded; nothing saved.")
            return
        hyla_features_array = np.concatenate(self.hyla_features_list, axis=0)
        filename = "hyla_features" + title + ".npy"
        np.save(filename, hyla_features_array)

    def train_adj(self, epoch, features, adj, labels, idx_train, idx_val, opt, is_airport):
        """One update of the estimated adjacency, followed by a validation pass."""
        estimator = self.estimator
        args = self.args
        if args.debug:
            print("\n=== train_adj ===")
        t = time.time()
        estimator.train()
        self.optimizer_adj.zero_grad()

        loss_l1 = torch.norm(estimator.estimated_adj, 1)
        loss_fro = torch.norm(estimator.estimated_adj - adj, p='fro')
        normalized_adj = estimator.normalize()

        if args.lambda_:
            loss_smooth_feat = self.feature_smoothing(estimator.estimated_adj, features)
        else:
            loss_smooth_feat = 0 * loss_l1

        self.model_f.eval()
        self.model_c.eval()
        HyLa_features = self.model_f()
        # Propagate over the *estimated* graph so the classification loss
        # back-propagates into estimated_adj. (Previously the raw input `adj`
        # was used here, which left gamma * loss_gcn without any gradient
        # path to the estimator.)
        new_features = self._propagate(normalized_adj, features, opt.order, is_airport)
        train_labels = labels[idx_train].to(opt.device)

        HyLa_features = torch.mm(new_features[idx_train].to(HyLa_features.device), HyLa_features)
        output = self.model_c(HyLa_features)
        del HyLa_features

        loss_gcn = F.cross_entropy(output, train_labels)
        acc_train, f1_train = acc_f1(output, train_labels)

        loss_symmetric = torch.norm(estimator.estimated_adj
                                    - estimator.estimated_adj.t(), p="fro")

        loss_diffiential = (loss_fro
                            + args.gamma * loss_gcn
                            + args.lambda_ * loss_smooth_feat
                            + args.phi * loss_symmetric)
        loss_diffiential.backward()
        self.optimizer_adj.step()

        # Proximal steps for the non-differentiable nuclear and L1 terms.
        loss_nuclear = 0 * loss_fro
        if args.beta != 0:
            self.optimizer_nuclear.zero_grad()
            self.optimizer_nuclear.step()
            loss_nuclear = prox_operators.nuclear_norm

        self.optimizer_l1.zero_grad()
        self.optimizer_l1.step()

        total_loss = loss_fro \
                    + args.gamma * loss_gcn \
                    + args.alpha * loss_l1 \
                    + args.beta * loss_nuclear \
                    + args.phi * loss_symmetric

        # Project the estimate back into [0, 1].
        estimator.estimated_adj.data.copy_(torch.clamp(
            estimator.estimated_adj.data, min=0, max=1))

        # Validation on the updated graph; no autograd graph is needed.
        self.model_f.eval()
        self.model_c.eval()
        with torch.no_grad():
            normalized_adj = estimator.normalize()
            new_features = self._propagate(normalized_adj, features, opt.order, is_airport)
            HyLa_features = self.model_f()
            val_labels = labels[idx_val].to(opt.device)
            HyLa_features = torch.mm(new_features[idx_val].to(HyLa_features.device), HyLa_features)
            output = self.model_c(HyLa_features)
            del HyLa_features
            loss_val = F.cross_entropy(output, val_labels)
            acc_val, f1_val = acc_f1(output, val_labels)

        if args.debug:
            print('Epoch: {:04d}'.format(epoch+1),
                  'acc_train: {:.4f}'.format(acc_train.item()),
                  'loss_val: {:.4f}'.format(loss_val.item()),
                  'acc_val: {:.4f}'.format(acc_val.item()),
                  'time: {:.4f}s'.format(time.time() - t))

        self._track_best(normalized_adj, acc_val, loss_val)

        if args.debug:
            print('Epoch: {:04d}'.format(epoch+1),
                  'loss_fro: {:.4f}'.format(loss_fro.item()),
                  'loss_gcn: {:.4f}'.format(loss_gcn.item()),
                  'loss_feat: {:.4f}'.format(loss_smooth_feat.item()),
                  'loss_symmetric: {:.4f}'.format(loss_symmetric.item()),
                  'delta_l1_norm: {:.4f}'.format(torch.norm(estimator.estimated_adj-adj, 1).item()),
                  'loss_l1: {:.4f}'.format(loss_l1.item()),
                  'loss_total: {:.4f}'.format(total_loss.item()),
                  'loss_nuclear: {:.4f}'.format(loss_nuclear.item()))

    def test(self, features, labels, idx_test, adj, order, is_airport=False, title=""):
        """Evaluate the performance of ProGNN on test set.

        Uses the estimator's normalised adjacency; the ``adj`` and ``title``
        arguments are not used. Returns the ``(acc, f1)`` pair from ``acc_f1``.
        """
        print("\t=== testing ===")
        self.model_f.eval()
        self.model_c.eval()
        with torch.no_grad():
            HyLa_features = self.model_f()
            normalized_adj = self.estimator.normalize()
            new_features = self._propagate(normalized_adj, features, order, is_airport)
            HyLa_features = torch.mm(new_features[idx_test].to(HyLa_features.device), HyLa_features)
            predictions = self.model_c(HyLa_features)
            del HyLa_features
            acc, f1 = acc_f1(predictions, labels[idx_test])
        return acc, f1

    def feature_smoothing(self, adj, X):
        """Return ``trace(X^T L X)`` for the normalised Laplacian of ``adj``.

        ``adj`` is symmetrised first; ``L = D^-1/2 (D - A) D^-1/2`` is built in
        float64 with ``1e-3`` added to the degrees before the inverse square
        root. ``X`` is cast to the dtype/device of ``L`` so float32 features do
        not trigger a dtype mismatch in ``matmul``.
        """
        adj = (adj.t() + adj) / 2
        degree = adj.sum(1).flatten()
        L = (torch.diag(degree) - adj).to(torch.float64)
        X = X.to(device=L.device, dtype=L.dtype)

        r_inv = degree.to(torch.float64) + 1e-3
        r_inv = r_inv.pow(-1/2).flatten()
        r_inv[torch.isinf(r_inv)] = 0.

        # Row/column scaling, equivalent to diag(r_inv) @ L @ diag(r_inv)
        # without the dense diagonal matmuls.
        L = r_inv.unsqueeze(1) * L * r_inv.unsqueeze(0)
        XLXT = torch.matmul(torch.matmul(X.t(), L), X)
        loss_smooth_feat = torch.trace(XLXT)
        return loss_smooth_feat


class EstimateAdj(nn.Module):
    """Provide a pytorch parameter matrix for estimated
    adjacency matrix and corresponding operations.

    Parameters
    ----------
    adj :
        Initial adjacency; a square tensor or array-like accepted by
        ``torch.as_tensor``.
    symmetric : bool
        If True, ``normalize`` averages the estimate with its transpose first.
    device :
        Stored as ``self.device`` for callers. ``normalize`` itself follows the
        device of the parameter.
    """

    def __init__(self, adj, symmetric=False, device='cpu'):
        super(EstimateAdj, self).__init__()
        n = len(adj)
        # float32, matching the torch.FloatTensor storage used previously.
        self.estimated_adj = nn.Parameter(torch.empty(n, n, dtype=torch.float32))
        self._init_estimation(adj)
        self.symmetric = symmetric
        self.device = device

    def _init_estimation(self, adj):
        """Copy ``adj`` into the parameter without tracking gradients."""
        with torch.no_grad():
            self.estimated_adj.data.copy_(torch.as_tensor(adj))

    def forward(self):
        """Return the raw (un-normalised) estimated adjacency parameter."""
        return self.estimated_adj

    def normalize(self):
        """Return ``D^-1/2 (A + I) D^-1/2`` for the current estimate."""
        if self.symmetric:
            adj = (self.estimated_adj + self.estimated_adj.t()) / 2
        else:
            adj = self.estimated_adj

        # Build the identity on the parameter's own device so this keeps
        # working after the module is moved. The dtype is deliberately left at
        # the torch default, as before, so type promotion is unchanged.
        identity = torch.eye(adj.shape[0], device=adj.device)
        return self._normalize(adj + identity)

    def _normalize(self, mx):
        """Scale rows and columns of ``mx`` by the inverse square root of its row sums."""
        rowsum = mx.sum(1)
        r_inv = rowsum.pow(-1/2).flatten()
        r_inv[torch.isinf(r_inv)] = 0.
        # Equivalent to diag(r_inv) @ mx @ diag(r_inv) without forming the
        # dense diagonal matrices.
        return r_inv.unsqueeze(1) * mx * r_inv.unsqueeze(0)

##########################see loss details, merge!

In [None]:
        # ###############pasting here###########################################
        # if (is_tsne):
        #   tsne_model = TSNE(n_components=2, random_state=42)
        #   # Fit and transform your data using t-SNE
        #   X_tsne = tsne_model.fit_transform(new_features)
        #   colors = plt.cm.tab10(labels)
        #   # Plot the results with colors based on labels
        #   plt.title('Raw Input Features visualized using t-SNE')
        #   plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=colors)
        #   plt.savefig(f'/content/raw_features.jpg', bbox_inches='tight', dpi=600)
        #   plt.show()

        # #print(HyLa_features.shape)
        # if (epoch==opt.epochs-1 and is_tsne):
        #   #related to t_sne
        #   HyLa_features2 = self.model_f()
        #   HyLa_features2 = torch.mm(new_features.to(opt.device), HyLa_features2)
        #   X_tsne = tsne_model.fit_transform(HyLa_features2.cpu().detach().numpy())
        #   plt.title('HyLa output Features visualized using t-SNE')
        #   plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=colors)
        #   plt.savefig(f'/content/hyla_features.jpg', bbox_inches='tight', dpi=600)
        #   plt.show()

### Models ###

In [None]:
# @title
import torch
import torch.nn as nn
from torch.nn import Module
import torch.nn.functional as F
import math
import numpy as np

class HyLa(nn.Module):
    """Feature model built from a learnable embedding table on a manifold.

    ``forward()`` takes no input and returns the feature matrix computed from
    the whole embedding table ``self.lt``.

    ``Lambdas``, ``boundary`` and ``bias`` are drawn randomly at construction
    and are not trained. They are registered as *non-persistent* buffers so
    they follow ``module.to(...)`` while ``state_dict()`` keeps exactly the
    same keys as before. Consequently they are not stored in checkpoints: a
    freshly constructed model gets new random draws.
    """

    def __init__(self, manifold, dim, size, HyLa_fdim, scale=0.1, sparse=False, **kwargs):
        super(HyLa, self).__init__()
        self.manifold = manifold
        self.lt = manifold.allocate_lt(size, dim, sparse)
        self.manifold.init_weights(self.lt)
        self.dim = dim
        # Same sampling order as before, so seeded runs draw identical values.
        self.register_buffer('Lambdas', scale * torch.randn(HyLa_fdim), persistent=False)
        self.register_buffer(
            'boundary',
            sample_boundary(HyLa_fdim, self.dim, cls='RandomUniform'),
            persistent=False,
        )
        self.register_buffer('bias', 2 * np.pi * torch.rand(HyLa_fdim), persistent=False)

    def forward(self):
        """Compute features from the normalised embeddings and the fixed random draws."""
        with torch.no_grad():
            e_all = self.manifold.normalize(self.lt.weight)
        device = e_all.device
        # The .to(device) calls are no-ops once the module was moved together
        # with its buffers; they are kept as a safety net.
        PsK = PoissonKernel(e_all, self.boundary.to(device))
        angles = self.Lambdas.to(device) / 2.0 * torch.log(PsK)
        eigs = torch.cos(angles + self.bias.to(device)) * torch.sqrt(PsK) ** (self.dim - 1)
        return eigs

    def optim_params(self):
        """Return the parameter group (embedding params plus manifold ops) for the optimizer."""
        return [{
            'params': self.lt.parameters(),
            'rgrad': self.manifold.rgrad,
            'expm': self.manifold.expm,
            'logm': self.manifold.logm,
            'ptransp': self.manifold.ptransp,
        }]


class RFF(nn.Module):
    """Cosine feature model over a learnable embedding table.

    ``forward()`` takes no input and returns
    ``norm * sqrt(2) * cos(E @ Lambdas + bias)`` where ``E`` is the normalised
    embedding table. ``Lambdas`` and ``bias`` are sampled once with numpy at
    construction and stored as frozen parameters.
    """

    def __init__(self, manifold, dim, size, HyLa_fdim, scale=0.1, sparse=False, **kwargs):
        super(RFF, self).__init__()
        self.manifold = manifold
        self.lt = manifold.allocate_lt(size, dim, sparse)
        self.manifold.init_weights(self.lt)
        self.norm = 1. / np.sqrt(dim)
        # numpy samples are float64; cast to the embedding dtype so that
        # `e_all @ self.Lambdas` does not fail when the table is float32.
        dtype = self.lt.weight.dtype
        lambdas = torch.from_numpy(np.random.normal(loc=0, scale=scale, size=(dim, HyLa_fdim)))
        bias = torch.from_numpy(np.random.uniform(0, 2 * np.pi, size=HyLa_fdim))
        self.Lambdas = nn.Parameter(lambdas.to(dtype), requires_grad=False)
        self.bias = nn.Parameter(bias.to(dtype), requires_grad=False)

    def forward(self):
        """Return the cosine features for every row of the embedding table."""
        with torch.no_grad():
            e_all = self.manifold.normalize(self.lt.weight)
        # Plain Python float so the scale factor never changes the tensor dtype.
        amplitude = float(self.norm * np.sqrt(2))
        features = amplitude * torch.cos(e_all @ self.Lambdas + self.bias)
        return features

    def optim_params(self):
        """Return the parameter group (embedding params plus manifold ops) for the optimizer."""
        return [{
            'params': self.lt.parameters(),
            'rgrad': self.manifold.rgrad,
            'expm': self.manifold.expm,
            'logm': self.manifold.logm,
            'ptransp': self.manifold.ptransp,
        }]

class SGC(nn.Module):
    """
    Single linear layer used as the classifier (logistic-regression style).
    The input features are expected to be preprocessed already, e.g. with
    k-step graph propagation.
    """

    def __init__(self, nfeat, nclass):
        super().__init__()
        # Kept under the attribute name `W` so state_dict keys stay `W.*`.
        self.W = nn.Linear(in_features=nfeat, out_features=nclass)

    def forward(self, x):
        """Return the class logits for ``x``."""
        logits = self.W(x)
        return logits

def build_model(opt, N):
    """Instantiate the feature model selected by ``opt``.

    ``opt`` is either an ``argparse.Namespace`` or a mapping. The keys read are
    ``manifold``, ``model``, ``he_dim``, ``hyla_dim``, ``lambda_scale`` and
    ``sparse``. ``N`` is passed to the model as its ``size``.
    """
    cfg = vars(opt) if isinstance(opt, argparse.Namespace) else opt
    manifold = MANIFOLDS[cfg['manifold']](K=None)
    model_cls = MODELS[cfg['model']]
    return model_cls(
        manifold,
        dim=cfg['he_dim'],
        size=N,
        HyLa_fdim=cfg['hyla_dim'],
        scale=cfg['lambda_scale'],
        sparse=cfg['sparse'],
    )

def get_model(model_opt, nfeat, nclass, adj=None, dropout=0.0):
    """Return the classifier named by ``model_opt``.

    Only ``"SGC"`` is supported; any other name raises ``NotImplementedError``.
    ``adj`` and ``dropout`` are accepted but not used.
    """
    if model_opt != "SGC":
        raise NotImplementedError('model:{} is not implemented!'.format(model_opt))
    return SGC(nfeat=nfeat, nclass=nclass)

# Registry: manifold name (the `-manifold` option) -> manifold class.
MANIFOLDS = dict(
    lorentz=LorentzManifold,
    poincare=PoincareManifold,
    euclidean=EuclideanManifold,
)

# Registry: feature-model name (the `-model` option) -> model class.
MODELS = dict(
    hyla=HyLa,
    rff=RFF,
)

##Mounting Google Drive for ease of accessing datasets##

In [None]:
# @title
# Colab-only: mount Google Drive at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


**Training loop**



*   **generate_ckpt(opt, model, path):** Creates or loads a checkpoint for a model. It initializes (or restores) the checkpoint state, loads that state into the model, and sets `opt.epoch_start` from the stored epoch. The specific behavior of the checkpointing process depends on the implementation of the `LocalCheckpoint` class, which is not shown in this cell.
*   **test_regression(model_f, model_c, features, test_labels, test_index=None, metric='acc', title=""):** Multiplies the given features with the HyLa output of model_f, uses model_c to classify the result, and returns the accuracy and F1 score.
* **train(model_f,model_c,optimizer_f,optimizer_c,data,opt,log,progress=False,ckps=None,is_tsne=True):** model_f is the HyLa model and model_c is the classification model (here SGC). data is a dictionary that contains the node features, labels, and the training and validation indices. is_tsne defaults to True in the signature; because t-SNE is computationally expensive, pass False to skip it. When enabled, it produces t-SNE visualizations of the raw input features and of the HyLa features.



In [None]:
# @title
import os
import sys
import inspect
import numpy as np
import torch
import logging
import argparse
import json
import torch.nn.functional as F
import timeit
import gc
from sklearn.metrics import roc_auc_score, average_precision_score
import matplotlib
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def generate_ckpt(opt, model, path):
    """Create a ``LocalCheckpoint`` at ``path`` and sync ``model`` / ``opt`` with it.

    The checkpoint is initialised with epoch 0 and the model's current state.
    Whatever state ``initialize`` returns is loaded into ``model``, and its
    epoch is written to ``opt.epoch_start``. Returns the checkpoint object.
    """
    ckpt = LocalCheckpoint(
        path,
        include_in_all={'conf' : vars(opt)},
        start_fresh=opt.fresh,
    )
    fallback_state = {'epoch': 0, 'model': model.state_dict()}
    restored = ckpt.initialize(fallback_state)
    model.load_state_dict(restored['model'])
    opt.epoch_start = restored['epoch']
    return ckpt

def test_regression(model_f, model_c, features, test_labels, test_index=None, metric='acc', title=""):
    """Evaluate ``model_c`` on ``features @ model_f()`` and return ``(acc, f1)``.

    Both models are switched to eval mode and stay in it. As a side effect the
    projected features are written to ``hyla_features<title>.npy`` on every
    call. ``test_index`` and ``metric`` are accepted but not used.
    """
    with torch.no_grad():
        for module in (model_f, model_c):
            module.eval()
        basis = model_f()
        projected = torch.mm(features.to(basis.device), basis)
        predictions = model_c(projected)
        # Persist the projected features before releasing them.
        np.save("hyla_features" + title + ".npy", projected.cpu().numpy())
        del projected
        scores = acc_f1(predictions, test_labels)
    acc, f1 = scores
    return acc, f1

def train(model_f,
          model_c,
          optimizer_f,
          optimizer_c,
          data,
          opt,
          log,
          progress=False,
          ckps=None,
          is_tsne=True,
):
    """Jointly train ``model_f`` and ``model_c`` for ``opt.epoch_start..opt.epochs``.

    Args:
        model_f: feature model, called with no arguments.
        model_c: classifier applied to ``features @ model_f()``.
        optimizer_f, optimizer_c: optimizers for the two models.
        data: mapping with ``features``, ``features_train``, ``labels``,
            ``idx_train`` and ``idx_val``.
        opt: options; ``epoch_start``, ``epochs``, ``device`` and ``metric`` are read.
        log: logger used when ``progress`` is true.
        progress: log running statistics every epoch.
        ckps: optional pair of checkpoints ``(for model_f, for model_c)`` saved
            whenever the validation accuracy improves.
        is_tsne: plot t-SNE of the raw features before training and of the
            projected features in the last epoch (saved under ``/content``).

    Returns:
        ``(train_acc, train_acc_best, val_acc, val_acc_best)``. ``train_acc`` and
        ``val_acc`` are from the last epoch, or ``0.0`` if no epoch was run.
    """
    model_f.train()
    model_c.train()
    val_acc_best = 0.0
    train_acc_best = 0.0
    # Defined up front so the return statement is valid even when the epoch
    # range is empty (e.g. resuming with epoch_start >= epochs).
    train_acc = 0.0
    val_acc = 0.0

    if is_tsne:
        tsne_model = TSNE(n_components=2, random_state=42)
        X_tsne = tsne_model.fit_transform(data['features'])
        colors = plt.cm.tab10(data['labels'])
        plt.title('Raw Input Features visualized using t-SNE')
        plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=colors)
        plt.savefig(f'/content/raw_features.jpg', bbox_inches='tight', dpi=600)
        plt.show()

    for epoch in range(opt.epoch_start, opt.epochs):
        t_start = timeit.default_timer()
        # test_regression() leaves both models in eval mode; switch back to
        # training mode at the start of every epoch.
        model_f.train()
        model_c.train()
        optimizer_f.zero_grad()
        optimizer_c.zero_grad()

        if epoch == opt.epochs - 1 and is_tsne:
            # Visualisation only: no autograd graph needed.
            with torch.no_grad():
                all_features = torch.mm(data['features'].to(opt.device), model_f())
            X_tsne = tsne_model.fit_transform(all_features.cpu().numpy())
            del all_features
            plt.title('HyLa output Features visualized using t-SNE')
            plt.scatter(X_tsne[:, 0], X_tsne[:, 1], c=colors)
            plt.savefig(f'/content/hyla_features.jpg', bbox_inches='tight', dpi=600)
            plt.show()

        train_labels = data['labels'][data['idx_train']].to(opt.device)
        HyLa_features = model_f()
        HyLa_features = torch.mm(data['features_train'].to(opt.device), HyLa_features)
        predictions = model_c(HyLa_features)
        del HyLa_features  # free the intermediate projection
        loss = F.cross_entropy(predictions, train_labels)
        loss.backward()
        optimizer_f.step()
        optimizer_c.step()

        # Metrics after the update (this puts both models in eval mode).
        train_acc, train_f1 = test_regression(
            model_f, model_c, data['features_train'], train_labels, metric=opt.metric)
        val_acc, val_f1 = test_regression(
            model_f, model_c, data['features'][data['idx_val']],
            data['labels'][data['idx_val']].to(opt.device), metric=opt.metric)

        if val_acc > val_acc_best:
            val_acc_best = val_acc
            if ckps is not None:
                ckps[0].save({
                    'model': model_f.state_dict(),
                    'epoch': epoch,
                    'val_acc_best': val_acc_best,
                })
                ckps[1].save({
                    'model': model_c.state_dict(),
                    'epoch': epoch,
                    'val_acc_best': val_acc_best,
                })
        if train_acc > train_acc_best:
            train_acc_best = train_acc
        if progress:
            log.info(
                'running stats: {'
                f'"epoch": {epoch}, '
                f'"elapsed": {timeit.default_timer()-t_start:.2f}, '
                f'"train_acc": {train_acc*100.0:.2f}%, '
                f'"val_acc": {val_acc*100.0:.2f}%, '
                f'"loss_c": {loss.cpu().item():.4f}, '
                '}'
            )
        gc.collect()
        torch.cuda.empty_cache()

    return train_acc, train_acc_best, val_acc, val_acc_best

##Main HyLa Function##
  * Add all the hyperparameters using new_args

  * want_tsne (2nd argument of main, default False): when set to True, shows t-SNE visualizations of the raw input features and the HyLa output features

In [None]:
# @title
#!/usr/bin/env python3

def are_tensors_equal(tensor1, tensor2):
    """Return True when both inputs, converted with ``np.array``, are equal.

    Equality follows ``np.array_equal``: same shape and same elements.
    """
    first, second = np.array(tensor1), np.array(tensor2)
    return np.array_equal(first, second)

def main(new_args, want_tsne=False, adj_arg=None, feat_arg=None, labels=None,
         idx_test=None, idx_train=None, idx_val=None, proGNN=False,
         prognn_args=None, title=""):
    """Train and evaluate a HyLa-SGC node classifier, optionally through ProGNN.

    Args:
        new_args: list of command-line style tokens parsed by the argparse
            parser built below (e.g. ``['-dataset', 'airport', '-quiet']``).
        want_tsne: forwarded to ``train`` as ``is_tsne``.
        adj_arg: optional adjacency that replaces ``data['adj_train']``. It is
            passed through ``aug_normalized_adjacency`` first unless
            ``proGNN`` is set.
        feat_arg: optional replacement for ``data['features']``.
        labels: optional replacement for ``data['labels']``.
        idx_test: optional replacement for ``data['idx_test']``.
        idx_train: optional replacement for ``data['idx_train']``.
        idx_val: optional replacement for ``data['idx_val']``.
        proGNN: when true, fitting and testing go through ``ProGNN`` instead
            of ``train`` / ``test_regression``.
        prognn_args: settings object handed to the ``ProGNN`` constructor.
        title: forwarded to ``prognn.fit`` / ``prognn.test`` /
            ``test_regression``.

    Returns:
        With ``proGNN`` set, whatever ``prognn.test`` returns; otherwise the
        list ``[test_acc * 100.0, test_f1 * 100.0]``.

    Raises:
        NotImplementedError: for an unsupported dataset or classifier
            optimizer, or (outside the ProGNN path) for a manifold that has
            no embedding optimizer configured here.
    """
    parser = argparse.ArgumentParser(description='Train HyLa-SGC for node classification tasks')
    parser.add_argument('-checkpoint', action='store_true', default=False)
    parser.add_argument('-task', type=str, default='nc', help='learning task')
    parser.add_argument('-dataset', type=str, required=False, default = 'cora',
                        help='Dataset identifier [cora|disease_nc|pubmed|citeseer|reddit|airport]')
    parser.add_argument('-he_dim', type=int, default=2,
                        help='Hyperbolic Embedding dimension')
    parser.add_argument('-hyla_dim', type=int, default=100,
                        help='HyLa feature dimension')
    parser.add_argument('-order', type=int, default=2,
                        help='order of adjaceny matrix in SGC precomputation')
    parser.add_argument('-manifold', type=str, default='poincare',
                        choices=MANIFOLDS.keys(), help='model of hyperbolic space')
    parser.add_argument('-model', type=str, default='hyla',
                        choices=MODELS.keys(), help='feature model class, hyla|rff')
    parser.add_argument('-lr_e', type=float, default=0.1,
                        help='Learning rate for hyperbolic embedding')
    parser.add_argument('-lr_c', type=float, default=0.1,
                        help='Learning rate for the classifier SGC')
    parser.add_argument('-epochs', type=int, default=100,
                        help='Number of epochs')
    parser.add_argument('-strategy', type=int, default=0,
                        help='Epochs of burn in, some advanced definition')
    parser.add_argument('-eval_each', type=int, default=1,
                        help='Run evaluation every n-th epoch')
    parser.add_argument('-fresh', action='store_true', default=False,
                        help='Override checkpoint')
    parser.add_argument('-debug', action='store_true', default=False,
                        help='Print debuggin output')
    parser.add_argument('-gpu', default=0, type=int,
                        help='Which GPU to run on (-1 for no gpu)')
    parser.add_argument('-seed', default=43, type=int, help='random seed')
    # NOTE(review): store_true combined with default=True means -sparse and
    # -quiet are True whether or not the flag is passed.
    parser.add_argument('-sparse', default=True, action='store_true',
                        help='Use sparse gradients for embedding table')
    parser.add_argument('-quiet', action='store_true', default=True)
    parser.add_argument('-lre_type', choices=['scale', 'constant'], default='constant')
    parser.add_argument('-optim_type', choices=['adam', 'sgd'], default='adam', help='optimizer used for the classification SGC model')
    parser.add_argument('-metric', choices=['acc', 'f1'], default='acc', help='what metrics to report')
    parser.add_argument('-lambda_scale', type=float, default=0.07, help='scale of lambdas when generating HyLa features')
    parser.add_argument('-inductive', action='store_true', default=False, help='inductive training, used for reddit.')
    parser.add_argument('-use_feats', action='store_true', default=False, help='whether embed in the feature level, otherwise node level')
    parser.add_argument('-tuned', action='store_true', default=False, help='whether use tuned hyper-parameters')
    opt = parser.parse_args(new_args)

    # Tuned hyper-parameters, when requested, replace the parsed values.
    if opt.tuned:
        with open(f'/content/drive/MyDrive/hyper_parameters_{opt.he_dim}d.json',) as f:
            hyper_parameters = json.load(f)[opt.dataset]
        opt.he_dim = hyper_parameters['he_dim']
        opt.hyla_dim = hyper_parameters['hyla_dim']
        opt.order = hyper_parameters['order']
        opt.lambda_scale = hyper_parameters['lambda_scale']
        opt.lr_e = hyper_parameters['lr_e']
        opt.lr_c = hyper_parameters['lr_c']
        opt.epochs = hyper_parameters['epochs']

    # The metric is derived from the dataset; this overrides any -metric token.
    opt.metric = 'f1' if opt.dataset == 'reddit' else 'acc'

    # Set the starting epoch for training
    opt.epoch_start = 0

    # Set random seeds for reproducibility
    torch.manual_seed(opt.seed)
    np.random.seed(opt.seed)

    # Set split_seed for consistent data splits
    opt.split_seed = opt.seed

    # Display progress information during training unless opt.quiet is True
    opt.progress = not opt.quiet

    # Set up debugging and logging
    log_level = logging.DEBUG if opt.debug else logging.INFO
    log = logging.getLogger('HyLa')
    logging.basicConfig(level=log_level, format='%(message)s', stream=sys.stdout)

    # Set default tensor type to DoubleTensor
    torch.set_default_tensor_type('torch.DoubleTensor')

    # Set device for computation (GPU if available, otherwise CPU)
    opt.device = torch.device(f'cuda:{opt.gpu}' if opt.gpu >= 0 and torch.cuda.is_available() else 'cpu')

    # Specify the path to the dataset
    data_path = f'/content/drive/MyDrive/datasets/{opt.dataset}/'

    # Load data based on the dataset
    if opt.dataset in ['cora', 'disease_nc', 'pubmed', 'citeseer', 'airport']:
        data = load_data(opt, data_path)
    elif opt.dataset in ['reddit']:
        data = load_reddit_data(data_path)
    else:
        raise NotImplementedError

    # Caller-supplied overrides. Identity tests are used on purpose: `!= None`
    # on an array-like argument performs a rich comparison, not a None check.
    if adj_arg is not None:
        if not proGNN:
            adj_n = aug_normalized_adjacency(adj_arg)
            data['adj_train'] = sparse_mx_to_torch_sparse_tensor(adj_n)
        else:
            # On the ProGNN path the adjacency is converted without normalisation.
            data['adj_train'] = sparse_mx_to_torch_sparse_tensor(adj_arg)
    if feat_arg is not None:
        data['features'] = feat_arg
    if labels is not None:
        data['labels'] = labels
    if idx_train is not None:
        data['idx_train'] = idx_train
    if idx_test is not None:
        data['idx_test'] = idx_test
    if idx_val is not None:
        data['idx_val'] = idx_val

    # Set up dataset parameters and settings
    if opt.use_feats or opt.inductive:
        # Feature dimension when using hyperbolic Laplacian features at the feature level
        feature_dim = data['features'].size(1)
    else:
        # Feature dimension when using hyperbolic Laplacian features at the node level
        feature_dim = data['adj_train'].size(1)

    # Display information about the data
    if opt.progress:
        log.info(f'Training set size: {len(data["idx_train"])}, Validation set size: {len(data["idx_val"])}, Test set size: {len(data["idx_test"])}')
        log.info(f'Size of original feature matrix: {data["features"].size()}, Number of classes: {data["labels"].max().item()+1}')
        log.info('Precomputing features')

    # Stays None on the one branch below that skips sgc_precompute.
    nonzero_perc = None
    if opt.inductive:
        features = data['features']
        data['features'], _ = sgc_precompute(data['adj_all'], features, opt.order)
        data['features_train'], nonzero_perc = sgc_precompute(data['adj_train'], features[data['idx_train']], opt.order)
    elif not opt.use_feats:
        # Node level: the dense adjacency itself is the input feature matrix,
        # with or without ProGNN.
        features = data['adj_train'].to_dense()
        data['features'], nonzero_perc = sgc_precompute(data['adj_train'], features, opt.order-1)
        data['features_train'] = data['features'][data['idx_train']]
    else:
        features = data['features'].to_dense()
        if not proGNN:
            data['features'], nonzero_perc = sgc_precompute(data['adj_train'], features, opt.order)
            data['features_train'] = data['features'][data['idx_train']]
        else:
            # No precomputation here; data['features'] is left as loaded/supplied.
            data['features_train'] = features[data['idx_train']]

    # Display information about the percentage of non-zero elements during adjacency matrix precomputations
    if opt.progress and nonzero_perc is not None:
        log.info(f'Nonzero percentage during adjacency matrix precomputations: {nonzero_perc}%')

    # Build feature model and set up optimizer
    model_f = build_model(opt, feature_dim).to(opt.device)

    # Scale learning rate if lre_type is 'scale'
    if opt.lre_type == 'scale':
        opt.lr_e = opt.lr_e * len(data['idx_train'])

    # Set up optimizer for feature model based on the specified manifold
    if opt.manifold == 'euclidean':
        optimizer_f = torch.optim.SGD(model_f.parameters(), lr=opt.lr_e)
    elif opt.manifold == 'poincare':
        optimizer_f = RiemannianSGD(model_f.optim_params(), lr=opt.lr_e)
    elif proGNN:
        # optimizer_f is never handed to ProGNN, so other manifolds may proceed.
        optimizer_f = None
    else:
        raise NotImplementedError(
            f'no embedding optimizer configured for manifold {opt.manifold!r}')

    # Build classification model and set up optimizer
    model_c = get_model("SGC", opt.hyla_dim, data['labels'].max().item()+1).to(opt.device)
    if opt.optim_type == 'sgd':
        optimizer_c = torch.optim.SGD(model_c.parameters(), lr=opt.lr_c)
    elif opt.optim_type == 'adam':
        optimizer_c = torch.optim.Adam(model_c.parameters(), lr=opt.lr_c)
    else:
        raise NotImplementedError

    # Set up checkpoints if enabled
    ckps = None
    if opt.checkpoint:
        ckp_fm = generate_ckpt(opt, model_f, f'/content/drive/MyDrive/HyLa/HyLa-master/nc/datasets/{opt.dataset}/fm.pt')
        ckp_cm = generate_ckpt(opt, model_c, f'/content/drive/MyDrive/HyLa/HyLa-master/nc/datasets/{opt.dataset}/cm.pt')
        ckps = (ckp_fm, ckp_cm)

    # Start training and record accuracies
    t_start_all = timeit.default_timer()

    if proGNN:
        # ProGNN path: fit, test and return; nothing below this block runs.
        prognn = ProGNN(model_f, model_c, prognn_args, opt.device)
        # pth = f"/content/drive/MyDrive/plots material/airport/airport_rnd/"

        # # adj = trainer.best_graph.cpu().detach()
        # # torch.save(adj, pth + f"clean_adj_{opt.ptb_lvl}.pt")

        # Hyla_model = prognn.model_f
        # Z_emb = Hyla_model.lt
        # torch.save(Z_emb, pth + f"poincare_emb_max_rnd_epoch100.pt")

        # Hyla_model.eval()
        # HyLa_features = Hyla_model()
        # if opt.use_feats:
        #     HyLa_features = torch.mm(data['features'].to(opt.device), HyLa_features)
        # torch.save(HyLa_features, pth + f"hyla_features_max_rnd_epoch100.pt")
        if not opt.use_feats:
            features = data['features']
            prognn.fit(features.to(opt.device), data['adj_train'].to_dense().to(opt.device), data['labels'], data['idx_train'], data['idx_val'], opt, True, title)
        else:
            features = data['features'].to_dense()
            # NOTE(review): `title` sits in the positional slot where the
            # branch above passes True - confirm against ProGNN.fit's signature.
            prognn.fit(features.to(opt.device), data['adj_train'].to_dense().to(opt.device), data['labels'], data['idx_train'], data['idx_val'], opt, title)

        return prognn.test(features.to(opt.device), data['labels'], data['idx_test'], data['adj_train'].to_dense().to(opt.device), opt.order, not opt.use_feats, title)

    train_acc, train_acc_best, val_acc, val_acc_best = train(
        model_f, model_c, optimizer_f, optimizer_c,
        data, opt, log, progress=opt.progress, ckps=ckps, is_tsne=want_tsne)

    # Display total elapsed time
    if opt.progress:
        log.info(f'TOTAL ELAPSED: {timeit.default_timer()-t_start_all:.2f}')
    print("Total time taken: ", timeit.default_timer()-t_start_all)

    # Load the best model from the checkpoints if applicable
    if opt.checkpoint and ckps is not None:
        state_fm = ckps[0].load()
        state_cm = ckps[1].load()
        model_f.load_state_dict(state_fm['model'])
        model_c.load_state_dict(state_cm['model'])
        if opt.progress:
            log.info(f'Early stopping, loading from epoch: {state_fm["epoch"]} with val_acc_best: {state_fm["val_acc_best"]}')

    # Test the model on the test set
    test_acc, test_f1 = test_regression(model_f, model_c, data['features'][data['idx_test']], data['labels'][data['idx_test']].to(opt.device), metric='acc', title=title)

    # Display training and test accuracies
    print("Training accuracy: ", train_acc * 100.0)
    print("Best training accuracy: ", train_acc_best * 100.0)
    print("Validation accuracy: ", val_acc * 100.0)
    print("Best validation accuracy: ", val_acc_best * 100.0)
    print("Test accuracy: ", test_acc * 100.0)

    # Log the results
    log.info(
        f'"|| Last train_acc": {train_acc*100.0:.2f}%, '
        f'"|| Best train_acc": {train_acc_best*100.0:.2f}%, '
        f'"|| Last val_acc": {val_acc*100.0:.2f}%, '
        f'"|| Best val_acc": {val_acc_best*100.0:.2f}%, '
        f'"|| Test_acc": {test_acc*100.0:.2f}%.'
    )

    return [test_acc*100.0, test_f1*100.0]

## Airport Random Attack ##

In [None]:
class Args:
    """Plain attribute container whose fields fall back to fixed defaults.

    Every recognised keyword argument becomes an attribute of the same name;
    a keyword that is not in the table below is ignored.
    """

    def __init__(self, **kwargs):
        # (attribute name, value used when the keyword is absent)
        fallbacks = (
            ('alpha', 0.064),
            ('beta', 5),
            ('lambda_', 320),
            ('debug', False),
            ('lr', 0.001),
            ('lr_adj', 0.001),
            ('gamma', 16),
            ('phi', 0.5),
            ('outer_steps', 1),
            ('inner_steps', 100),
            ('weight_decay', 0.0001),
            ('epochs', 15),
            ('symmetric', True),
            ('only_gcn', False),
            ('pos_weight', False),
            ('n_classes', None),
        )
        for field, fallback in fallbacks:
            value = kwargs[field] if field in kwargs else fallback
            setattr(self, field, value)

# Settings for this experiment; the resulting `args` object is handed to
# main(...) as prognn_args further down. 'lr' is not listed, so Args keeps
# its built-in default for that field.
args_dict = dict(
    alpha=0.01,
    beta=2,
    lambda_=0.01,
    debug=False,
    lr_adj=0.01,
    gamma=1,
    phi=0,
    outer_steps=1,
    inner_steps=2,
    weight_decay=0.0005,
    epochs=100,
    symmetric=True,
    only_gcn=False,
    pos_weight=False,
    n_classes=4,
)
args = Args(**args_dict)

# @title
import numpy as np
from scipy.sparse import coo_matrix

# Driver for the airport experiment: load the saved arrays from Drive, then
# run HyLa through ProGNN once per entry of `arr`.
arr = ['1.0']#,'0.2','0.4','0.6','0.8','1.0']#,'1.2','1.4','1.6','1.8']
new_args =  ['-he_dim', '50', '-hyla_dim', '1000', '-dataset', 'airport', '-order', '2', '-lambda_scale', '0.01', '-lr_e', '0.1', '-lr_c', '0.1', '-epochs', '100','-quiet']
file_path3 = '/content/drive/MyDrive/attacks//random/airport/labels.npy'
file_path2 = '/content/drive/MyDrive/attacks//random/airport/features.npy'
file_path4 = '/content/drive/MyDrive/attacks//random/airport/idx_train.npy'
file_path5 = '/content/drive/MyDrive/attacks//random/airport/idx_test.npy'
file_path6 = '/content/drive/MyDrive/attacks//random/airport/idx_val.npy'

# None of these inputs depends on the loop variable, so load them once.
feat_matrix = np.load(file_path2, allow_pickle=True)
idx_train = np.load(file_path4, allow_pickle=True).tolist()
idx_test = np.load(file_path5, allow_pickle=True).tolist()
idx_val = np.load(file_path6, allow_pickle=True).tolist()
labels = torch.tensor(np.load(file_path3)).to(torch.long)

# COO pieces of the dense feature matrix: a (ndim, nnz) index tensor, the
# matching values, and the overall shape as a tuple of ints.
nonzero_idx = feat_matrix.nonzero()
coo_indices = torch.from_numpy(np.vstack(nonzero_idx))
coo_values = torch.from_numpy(feat_matrix[nonzero_idx])
size = tuple(feat_matrix.shape)

for i in arr:
    # NOTE(review): the path does not use `i`, so every entry of `arr` loads
    # the same adjacency file - confirm whether a per-level file was intended.
    file_path = '/content/drive/MyDrive/original_datasets/airport_adj_train.npy'
    dense_adj_matrix = np.load(file_path, allow_pickle=True)
    sparse_adj_matrix = coo_matrix(dense_adj_matrix)
    feat_tensor = torch.sparse_coo_tensor(coo_indices, coo_values, size=size, dtype=torch.double)

    # One ProGNN run per entry. main() reseeds torch and numpy from -seed on
    # every call, so repeating the identical call should only redo the same work.
    # labels is already a tensor: clone().detach() gives main its own copy
    # without the copy-construct warning that torch.tensor(tensor) emits.
    arr_temp2 = main(new_args, adj_arg=sparse_adj_matrix, feat_arg=feat_tensor,
                     labels=labels.clone().detach(), idx_train=idx_train,
                     idx_val=idx_val, idx_test=idx_test,
                     proGNN=True, prognn_args=args)
    print("with pro", arr_temp2)


2756
144
288




Epoch: 0001 loss_train: 1.6350 acc_train: 0.1082 loss_val: 33.8945 acc_val: 0.4943 time: 16.7506s
Epoch: 0001 loss_train: 27.2738 acc_train: 0.4400 loss_val: 48.1699 acc_val: 0.4943 time: 15.9232s
Epoch: 0002 loss_train: 41.5547 acc_train: 0.4400 loss_val: 58.2063 acc_val: 0.2405 time: 16.3490s
Epoch: 0002 loss_train: 45.7920 acc_train: 0.3505 loss_val: 39.4659 acc_val: 0.2405 time: 16.3310s
Epoch: 0003 loss_train: 28.8695 acc_train: 0.3505 loss_val: 75.0300 acc_val: 0.1050 time: 16.0497s
Epoch: 0003 loss_train: 69.8260 acc_train: 0.1014 loss_val: 52.1980 acc_val: 0.4943 time: 16.2668s
Epoch: 0004 loss_train: 50.7360 acc_train: 0.4400 loss_val: 54.7884 acc_val: 0.4943 time: 16.1106s
Epoch: 0004 loss_train: 54.8447 acc_train: 0.4400 loss_val: 46.1329 acc_val: 0.4943 time: 16.8006s
Epoch: 0005 loss_train: 46.0287 acc_train: 0.4400 loss_val: 29.2347 acc_val: 0.4943 time: 16.0667s
Epoch: 0005 loss_train: 27.6982 acc_train: 0.4400 loss_val: 14.0899 acc_val: 0.2405 time: 16.6890s
Epoch: 0006