In [None]:
from pathflowai.utils import load_sql_df
import torch
import pickle
import os 
import sys, os
import umap, numba
from sklearn.preprocessing import LabelEncoder
from torch_cluster import knn_graph
from torch_geometric.data import Data 
import numpy as np
from torch_geometric.utils import train_test_split_edges
import os
os.environ['CUDA_VISIBLE_DEVICES']="0"
import argparse
from dgm.dgm import DGM
from dgm.plotting import *
from dgm.utils import *
from dgm.models import GraphClassifier, DGILearner
from torch_geometric.utils.convert import to_networkx
from torch_geometric.data import InMemoryDataset,DataLoader
import os,glob, pandas as pd
from sklearn.metrics import f1_score
import copy
import torch
import torch.nn.functional as F
from collections import Counter
from torch import nn
from torch_geometric.nn import GCNConv, GATConv, DeepGraphInfomax, SAGEConv
from torch_geometric.nn import DenseGraphConv
from torch_geometric.utils import to_dense_batch, to_dense_adj, dense_to_sparse
from torch_geometric.nn import GINEConv
from torch_geometric.utils import dropout_adj
from torch_geometric.nn import APPNP
import torch.nn as nn
import fire

# Small constant added to the square-rooted cluster degrees in
# dense_mincut_pool so that normalizing the coarsened adjacency never divides
# by exactly zero.
EPS = 1e-15


def dense_mincut_pool(x, adj, s, mask=None):
    r"""Dense MinCut pooling, after `"Mincut Pooling in Graph Neural
    Networks" <https://arxiv.org/abs/1907.00481>`_.

    The raw assignment scores ``s`` are converted to soft cluster memberships
    with a softmax over the last dimension and then used to coarsen both the
    node features and the adjacency matrix:

    .. math::
        \mathbf{X}^{\prime} = {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{X}

        \mathbf{A}^{\prime} = {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})

    Two regularization terms are returned alongside the pooled tensors.
    The minCUT loss

    .. math::
        \mathcal{L}_c = - \frac{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{A}
        \mathbf{S})} {\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{D}
        \mathbf{S})}

    with :math:`\mathbf{D}` the degree matrix of :math:`\mathbf{A}`, and the
    orthogonality loss

    .. math::
        \mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}}
        {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}}
        \right\|}_F.

    Both losses are averaged over the batch.

    Args:
        x (Tensor): Node features :math:`\mathbf{X} \in \mathbb{R}^{B
            \times N \times F}` (batch size :math:`B`, at most :math:`N`
            nodes per graph, :math:`F` features). A 2-D tensor is treated as
            a batch of one.
        adj (Tensor): Symmetrically normalized adjacency
            :math:`\mathbf{A} \in \mathbb{R}^{B \times N \times N}`. A 2-D
            tensor is treated as a batch of one.
        s (Tensor): Un-normalized assignment scores
            :math:`\mathbf{S} \in \mathbb{R}^{B \times N \times C}` for
            :math:`C` clusters; the softmax is applied inside this function.
            A 2-D tensor is treated as a batch of one.
        mask (BoolTensor, optional): :math:`\mathbf{M} \in {\{ 0, 1 \}}^{B
            \times N}` marking the valid nodes of each graph.
            (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`,
        :class:`Tensor`) -- pooled features, coarsened and symmetrically
        normalized adjacency, minCUT loss, orthogonality loss.
    """
    # Promote un-batched inputs to a batch of one.
    if x.dim() == 2:
        x = x.unsqueeze(0)
    if adj.dim() == 2:
        adj = adj.unsqueeze(0)
    if s.dim() == 2:
        s = s.unsqueeze(0)

    batch_size, num_nodes, _ = x.size()
    num_clusters = s.size(-1)

    s = torch.softmax(s, dim=-1)

    if mask is not None:
        # Zero out features and assignments of padded nodes.
        node_mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
        x = x * node_mask
        s = s * node_mask

    s_t = s.transpose(1, 2)
    out = torch.matmul(s_t, x)
    out_adj = torch.matmul(torch.matmul(s_t, adj), s)

    # MinCUT regularization. The numerator must be read from out_adj before
    # its diagonal is zeroed in place further down.
    cut = _rank3_trace(out_adj)
    degree = _rank3_diag(torch.einsum('ijk->ij', adj))
    volume = _rank3_trace(torch.matmul(torch.matmul(s_t, degree), s))
    mincut_loss = torch.mean(-(cut / volume))

    # Orthogonality regularization.
    ss = torch.matmul(s_t, s)
    identity = torch.eye(num_clusters).type_as(ss)
    ss_unit = ss / torch.norm(ss, dim=(-1, -2), keepdim=True, p=2)
    identity_unit = identity / torch.norm(identity)
    ortho_loss = torch.mean(
        torch.norm(ss_unit - identity_unit, dim=(-1, -2), p=2))

    # Remove self-loops from the coarsened adjacency, then normalize it
    # symmetrically by the square roots of its row sums.
    diag_idx = torch.arange(num_clusters, device=out_adj.device)
    out_adj[:, diag_idx, diag_idx] = 0
    root_degree = torch.sqrt(torch.einsum('ijk->ij', out_adj))[:, None] + EPS
    out_adj = (out_adj / root_degree) / root_degree.transpose(1, 2)

    return out, out_adj, mincut_loss, ortho_loss


def _rank3_trace(x):
    return torch.einsum('ijj->i', x)


def _rank3_diag(x):
    eye = torch.eye(x.size(1)).type_as(x)
    out = eye * x.unsqueeze(2).expand(*x.size(), x.size(1))
    return out

class GCNNet(torch.nn.Module):
    """Stack of ``GATConv`` layers followed by a linear output layer.

    Args:
        inp_dim: Number of input features per node.
        out_dim: Size of the final linear layer's output.
        hidden_topology: Output width of each successive ``GATConv`` layer.
        p: Dropout probability applied to the last hidden representation.
        p2: Edge-dropout probability handed to ``dropout_adj``.
        drop_each: If true, edges are re-sampled with ``dropout_adj`` before
            every convolution while the module is in training mode.
    """

    def __init__(self, inp_dim, out_dim, hidden_topology=[32,64,128,128], p=0.5, p2=0.1, drop_each=True):
        super().__init__()
        self.out_dim = out_dim

        # Consecutive (in, out) widths: inp_dim -> h[0] -> h[1] -> ...
        widths = [inp_dim, *hidden_topology]
        self.convs = nn.ModuleList(
            [GATConv(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

        def _thin_edges(edge_index):
            # dropout_adj returns (edge_index, edge_attr); keep the index only.
            return dropout_adj(edge_index, p=p2)[0]

        self.drop_edge = _thin_edges
        self.dropout = nn.Dropout(p)
        self.fc = nn.Linear(hidden_topology[-1], out_dim)
        self.drop_each = drop_each

    def forward(self, x, edge_index, edge_attr=None):
        """Return the linear layer's output for every node.

        Each convolution is followed by a ReLU. Edge dropout (when enabled)
        and feature dropout are only applied in training mode.
        """
        resample_edges = self.drop_each and self.training
        for layer in self.convs:
            if resample_edges:
                edge_index = self.drop_edge(edge_index)
            x = F.relu(layer(x, edge_index, edge_attr))
        if self.training:
            x = self.dropout(x)
        return self.fc(x)
    
class GCNFeatures(torch.nn.Module):
    """Wrap a trained network to expose its last-layer embeddings and class
    probabilities.

    The wrapped network must provide ``convs`` (an iterable of graph
    convolutions called as ``conv(x, edge_index, edge_attr)``) and ``fc``
    (the output layer).

    Note that constructing this wrapper modifies ``gcn`` in place: its
    ``drop_edge`` callable and ``dropout`` module are replaced with new ones
    built from ``p2`` and ``p``.

    Args:
        gcn: The network to wrap.
        bayes: If true, edges are re-sampled with ``dropout_adj`` before every
            convolution and ``gcn.dropout`` is applied to the embedding, so
            repeated forward passes can yield different outputs.
        p: Probability for the replacement ``nn.Dropout``.
        p2: Probability for the replacement edge dropout.
    """

    def __init__(self, gcn, bayes=False, p=0.05, p2=0.1):
        super(GCNFeatures, self).__init__()
        self.gcn = gcn
        self.drop_each = bayes
        self.gcn.drop_edge = lambda edge_index: dropout_adj(edge_index, p=p2)[0]
        self.gcn.dropout = nn.Dropout(p)

    def forward(self, x, edge_index, edge_attr=None):
        """Return ``(z, y)``: the output of the last convolution and the
        softmax of ``gcn.fc`` applied to its ReLU.

        ``z`` is taken before the final ReLU (and, in ``bayes`` mode, after
        ``gcn.dropout``).
        """
        last = len(self.gcn.convs) - 1
        for i, conv in enumerate(self.gcn.convs):
            if self.drop_each:
                edge_index = self.gcn.drop_edge(edge_index)
            x = conv(x, edge_index, edge_attr)
            # The final ReLU is deferred so the returned embedding is the raw
            # output of the last convolution.
            if i < last:
                x = F.relu(x)
        if self.drop_each:
            # NOTE(review): nn.Dropout is the identity while the module is in
            # eval mode, so this only perturbs x if gcn.dropout is left in
            # training mode by the caller -- confirm that is what is wanted.
            x = self.gcn.dropout(x)
        # Normalize over the class axis explicitly. Calling softmax without
        # `dim` is deprecated and warns on every call; for the 2-D
        # (nodes, classes) input used here the implicit choice was this axis.
        y = F.softmax(self.gcn.fc(F.relu(x)), dim=-1)
        return x, y

def extract_features(cv_split=2,
                graph_data='datasets/graph_dataset_no_pretrain.pkl',
                cv_splits='cv_splits/cv_splits.pkl',
                models_dir="models_no_pretrain/",
                out_dir='predictions_no_pretrain',
                hidden_topology=[32,64,128,128],
                p=0.5,
                p2=0.3,
                n_posterior=50
                ):
    """Run a saved ``GCNNet`` over the validation and test graphs of one
    cross-validation split and save per-graph predictions.

    For every graph the saved record is a dict with keys ``y`` (labels),
    ``G`` (undirected networkx graph), ``xy`` (node positions),
    ``y_pred_posterior`` / ``y_std`` (mean and standard deviation over
    ``n_posterior`` stochastic forward passes), ``z`` (embedding from a
    deterministic pass) and ``y_pred`` (class probabilities from that pass).
    The list of records is written with ``torch.save`` to
    ``<out_dir>/<cv_split>.predictions.pth``.

    Requires a CUDA device: the model and every graph are moved with
    ``.cuda()``.

    Args:
        cv_split: Key of the split to use inside the ``cv_splits`` pickle;
            also the prefix of the model and prediction file names.
        graph_data: Path to a pickle holding ``'graph_dataset'`` (indexable
            collection of graphs with ``x``, ``pos``, ``edge_index`` and
            ``y``) and ``'df'`` (a table with an ``'annotation'`` column whose
            number of distinct values sets the output size).
        cv_splits: Path to a pickle mapping split keys to dicts with
            ``'val_idx'`` and ``'test_idx'`` index arrays.
        models_dir: Directory containing ``<cv_split>.model.pth``.
        out_dir: Directory for the predictions file; created if missing.
        hidden_topology: Hidden layer widths used to rebuild ``GCNNet``; must
            match the saved weights.
        p: Feature-dropout probability passed to the model and to the
            stochastic feature extractor.
        p2: Edge-dropout probability passed to the model and to the
            stochastic feature extractor.
        n_posterior: Number of stochastic forward passes per graph.
    """
    # NOTE: pickle can execute arbitrary code while loading; only point these
    # paths at files from a trusted source.
    with open(graph_data, 'rb') as handle:
        datasets = pickle.load(handle)
    with open(cv_splits, 'rb') as handle:
        split = pickle.load(handle)[cv_split]

    # Validation and test graphs are evaluated together, in that order.
    graph_dataset = datasets['graph_dataset']
    eval_idx = np.hstack((split['val_idx'], split['test_idx']))
    val_loader = DataLoader([graph_dataset[i] for i in eval_idx], shuffle=False)

    # Rebuild the network and restore the weights saved for this split.
    model = GCNNet(graph_dataset[0].x.shape[1],
                   datasets['df']['annotation'].nunique(),
                   hidden_topology=hidden_topology, p=p, p2=p2)
    model = model.cuda()
    model.load_state_dict(torch.load(os.path.join(models_dir, f"{cv_split}.model.pth")))

    graphs = []

    # Pass 1: stochastic predictions. The two passes cannot be merged because
    # constructing a GCNFeatures wrapper overwrites model.drop_edge and
    # model.dropout, so the second wrapper would change the first one's
    # dropout settings.
    feature_extractor = GCNFeatures(model, bayes=True, p=p, p2=p2).cuda()
    # eval() is called after the wrapper is built so that the nn.Dropout it
    # just attached is in eval mode too, as in the original per-batch
    # model.train(False).
    # NOTE(review): with the dropout module in eval mode, the spread between
    # posterior samples presumably comes only from the edge dropout (which is
    # invoked without a training flag) -- confirm this is intended.
    model.eval()
    with torch.no_grad():
        for data in val_loader:
            x = data.x.cuda()
            edge_index = data.edge_index.cuda()
            samples = torch.stack(
                [feature_extractor(x, edge_index)[1] for _ in range(n_posterior)]
            ).cpu().numpy()
            graphs.append(dict(y=data.y.numpy(),
                               G=to_networkx(data).to_undirected(),
                               xy=data.pos.numpy(),
                               y_pred_posterior=samples.mean(0),
                               y_std=samples.std(0)))
            del x, edge_index

    # Pass 2: deterministic embeddings and predictions, stored on the records
    # created above (the loader is unshuffled, so positions line up).
    feature_extractor = GCNFeatures(model, bayes=False).cuda()
    model.eval()
    with torch.no_grad():
        for i, data in enumerate(val_loader):
            x = data.x.cuda()
            edge_index = data.edge_index.cuda()
            z, y_pred = feature_extractor(x, edge_index)
            graphs[i].update(dict(z=z.detach().cpu().numpy(),
                                  y_pred=y_pred.detach().cpu().numpy()))
            del x, edge_index

    # torch.save does not create missing directories.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(graphs, os.path.join(out_dir, f"{cv_split}.predictions.pth"))
    
class Commands(object):
    """Groups the pipeline entry points as methods.

    NOTE(review): presumably meant to be exposed as a command-line interface
    (the file imports ``fire``) -- confirm against how this class is launched.
    """

    def __init__(self):
        """No state is kept on the instance."""

    def extract_features(self,cv_split=2,
                graph_data='datasets/graph_dataset_no_pretrain.pkl',
                cv_splits='cv_splits/cv_splits.pkl',
                models_dir="models_no_pretrain/",
                out_dir='predictions_no_pretrain',
                hidden_topology=[32,64,128,128],
                p=0.5,
                p2=0.3,
                n_posterior=50
                ):
        """Forward every argument unchanged to the module-level
        ``extract_features`` function, which has the same parameters."""
        extract_features(
            cv_split=cv_split,
            graph_data=graph_data,
            cv_splits=cv_splits,
            models_dir=models_dir,
            out_dir=out_dir,
            hidden_topology=hidden_topology,
            p=p,
            p2=p2,
            n_posterior=n_posterior,
        )
        

In [None]:
# Settings for one feature-extraction run. The dataset, model and output
# locations given here replace the "*_no_pretrain" defaults of
# Commands.extract_features.
your_args = {
    'cv_split': 2,
    'graph_data': 'datasets/graph_dataset.pkl',
    'cv_splits': 'cv_splits/cv_splits.pkl',
    'models_dir': "models/",
    'out_dir': 'predictions',
    'hidden_topology': [32, 64, 128, 128],
    'p': 0.5,
    'p2': 0.3,
    'n_posterior': 50,
}
Commands().extract_features(**your_args)