# In [None]:
import pickle
import os 
import sys, os
import umap, numba
from sklearn.preprocessing import LabelEncoder
from torch_cluster import knn_graph
from torch_geometric.data import Data 
import numpy as np
from torch_geometric.utils import train_test_split_edges
import os
import argparse
from torch_geometric.utils.convert import to_networkx
from torch_geometric.data import InMemoryDataset,DataLoader
import os,glob, pandas as pd
from sklearn.metrics import f1_score
import copy
import torch
import torch.nn.functional as F
from collections import Counter
from torch import nn
from torch_geometric.nn import GCNConv, GATConv, DeepGraphInfomax, SAGEConv
from torch_geometric.nn import DenseGraphConv
from torch_geometric.utils import to_dense_batch, to_dense_adj, dense_to_sparse
from torch_geometric.nn import GINEConv
from torch_geometric.utils import dropout_adj
from torch_geometric.nn import APPNP
import torch.nn as nn
import fire
from visdom import Visdom

# Small constant added to the square-rooted cluster degrees in
# dense_mincut_pool to guard the adjacency normalization against division
# by zero.
EPS = 1e-15


def dense_mincut_pool(x, adj, s, mask=None):
    r"""Pool a dense batch of graphs with the MinCut operator.

    Implements the pooling from `"Mincut Pooling in Graph Neural Networks"
    <https://arxiv.org/abs/1907.00481>`_.  With the row-wise softmax of the
    assignment scores :math:`\mathbf{S}` the pooled quantities are

    .. math::
        \mathbf{X}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{X}

        \mathbf{A}^{\prime} &= {\mathrm{softmax}(\mathbf{S})}^{\top} \cdot
        \mathbf{A} \cdot \mathrm{softmax}(\mathbf{S})

    Two auxiliary losses are returned alongside them: the minCUT loss

    .. math::
        \mathcal{L}_c = - \frac{\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{A}
        \mathbf{S})} {\mathrm{Tr}(\mathbf{S}^{\top} \mathbf{D}
        \mathbf{S})}

    (:math:`\mathbf{D}` being the degree matrix of :math:`\mathbf{A}`) and
    the orthogonality loss

    .. math::
        \mathcal{L}_o = {\left\| \frac{\mathbf{S}^{\top} \mathbf{S}}
        {{\|\mathbf{S}^{\top} \mathbf{S}\|}_F} -\frac{\mathbf{I}_C}{\sqrt{C}}
        \right\|}_F.

    Both losses are averaged over the batch.

    Args:
        x (Tensor): Node features of shape :math:`B \times N \times F`.
            A 2-D tensor is treated as a batch of one graph.
        adj (Tensor): Symmetrically normalized adjacency of shape
            :math:`B \times N \times N` (2-D input is treated as a batch
            of one graph).
        s (Tensor): Raw assignment scores of shape
            :math:`B \times N \times C`; the softmax over the cluster
            dimension is applied inside this function.
        mask (BoolTensor, optional): :math:`B \times N` mask of valid
            nodes; masked-out nodes contribute neither features nor
            assignments. (default: :obj:`None`)

    :rtype: (:class:`Tensor`, :class:`Tensor`, :class:`Tensor`,
        :class:`Tensor`) -- pooled features, coarsened adjacency (diagonal
        zeroed, then degree-normalized), minCUT loss, orthogonality loss.
    """

    # Promote unbatched inputs to a batch containing a single graph.
    if x.dim() == 2:
        x = x.unsqueeze(0)
    if adj.dim() == 2:
        adj = adj.unsqueeze(0)
    if s.dim() == 2:
        s = s.unsqueeze(0)

    batch_size, num_nodes, _ = x.size()
    k = s.size(-1)

    s = torch.softmax(s, dim=-1)

    if mask is not None:
        mask = mask.view(batch_size, num_nodes, 1).to(x.dtype)
        x = x * mask
        s = s * mask

    # Pooled features and coarsened (not yet normalized) adjacency.
    out = torch.matmul(s.transpose(1, 2), x)
    assigned_adj = torch.matmul(s.transpose(1, 2), adj)
    out_adj = torch.matmul(assigned_adj, s)

    # MinCUT term: -Tr(S^T A S) / Tr(S^T D S), averaged over the batch.
    mincut_num = _rank3_trace(out_adj)
    node_degrees = torch.einsum('ijk->ij', adj)
    degree_matrix = _rank3_diag(node_degrees)
    assigned_degree = torch.matmul(s.transpose(1, 2), degree_matrix)
    mincut_den = _rank3_trace(torch.matmul(assigned_degree, s))
    mincut_loss = torch.mean(-(mincut_num / mincut_den))

    # Orthogonality term: distance between the normalized S^T S and the
    # normalized identity, averaged over the batch.
    ss = torch.matmul(s.transpose(1, 2), s)
    i_s = torch.eye(k).type_as(ss)
    ss_unit = ss / torch.norm(ss, dim=(-1, -2), keepdim=True, p=2)
    eye_unit = i_s / torch.norm(i_s)
    ortho_loss = torch.mean(
        torch.norm(ss_unit - eye_unit, dim=(-1, -2), p=2))

    # Zero the self-loops in place (the trace above was taken before this),
    # then apply the D^{-1/2} A D^{-1/2} normalization.
    diag_idx = torch.arange(k, device=out_adj.device)
    out_adj[:, diag_idx, diag_idx] = 0
    cluster_degrees = torch.einsum('ijk->ij', out_adj)
    scale = torch.sqrt(cluster_degrees)[:, None] + EPS
    out_adj = (out_adj / scale) / scale.transpose(1, 2)

    return out, out_adj, mincut_loss, ortho_loss


def _rank3_trace(x):
    """Return the trace of every matrix in a rank-3 batch as a 1-D tensor."""
    # 'ijj->i' sums the diagonal entries of each matrix in the batch.
    per_matrix_trace = torch.einsum('ijj->i', x)
    return per_matrix_trace


def _rank3_diag(x):
    """Turn each row of ``x`` into a diagonal matrix.

    For ``x`` of shape ``(B, N)`` the result has shape ``(B, N, N)`` with
    ``x[b]`` on the diagonal of the ``b``-th matrix.
    """
    width = x.size(1)
    identity = torch.eye(width).type_as(x)
    # Repeat every entry along a new trailing axis; multiplying by the
    # identity then keeps only the diagonal positions.
    repeated = x.unsqueeze(2).expand(*x.size(), width)
    return identity * repeated

class GCNNet(torch.nn.Module):
    """Stack of ``GATConv`` layers followed by a linear output layer.

    Args:
        inp_dim: Width of the input node features.
        out_dim: Width of the output of the final linear layer.
        hidden_topology: Output width of each graph convolution, in order.
            The default list is only read, never mutated.
        p: Dropout probability applied to the node features right before
            the linear layer (only while the module is in training mode).
        p2: Edge-dropout probability handed to ``dropout_adj``.
        drop_each: When true, edges are dropped before every convolution
            while the module is in training mode.
    """

    def __init__(self, inp_dim, out_dim, hidden_topology=[32,64,128,128], p=0.5, p2=0.1, drop_each=True):
        super(GCNNet, self).__init__()
        self.out_dim = out_dim
        # Stored so that ``drop_edge`` can be a regular method; a lambda kept
        # on the instance would make the whole module unpicklable.
        self.p2 = p2
        layer_dims = [inp_dim, *hidden_topology]
        self.convs = nn.ModuleList(
            [GATConv(n_in, n_out)
             for n_in, n_out in zip(layer_dims[:-1], layer_dims[1:])]
        )
        self.dropout = nn.Dropout(p)
        self.fc = nn.Linear(hidden_topology[-1], out_dim)
        self.drop_each = drop_each

    def drop_edge(self, edge_index):
        """Return ``edge_index`` after ``dropout_adj`` with probability ``p2``."""
        return dropout_adj(edge_index, p=self.p2)[0]

    def forward(self, x, edge_index, edge_attr=None):
        """Run the convolution stack and return the linear-layer output.

        ``edge_index`` is reassigned on every edge drop, so each layer sees
        the edges that survived the drops of all previous layers.
        """
        for conv in self.convs:
            if self.drop_each and self.training:
                edge_index = self.drop_edge(edge_index)
            x = F.relu(conv(x, edge_index, edge_attr))
        if self.training:
            x = self.dropout(x)
        x = self.fc(x)
        return x
    
class GCNFeatures(torch.nn.Module):
    """Expose the pre-classifier features and class probabilities of a network.

    Args:
        gcn: Module providing ``convs`` (iterable of graph convolutions),
            ``drop_edge``, ``dropout`` and ``fc``, e.g. a ``GCNNet``.
        bayes: When true, ``gcn.drop_edge`` is applied before every
            convolution and ``gcn.dropout`` before the linear layer.
            NOTE(review): ``gcn.dropout`` presumably only drops values while
            it is in training mode -- confirm the wrapped model's mode when
            stochastic passes are expected.
    """

    def __init__(self, gcn, bayes=False):
        super(GCNFeatures, self).__init__()
        self.gcn = gcn
        self.drop_each = bayes

    def forward(self, x, edge_index, edge_attr=None):
        """Return ``(features, probabilities)`` for the given graph.

        ``features`` are the activations fed to ``gcn.fc``; ``probabilities``
        are the softmax of ``gcn.fc(features)`` over the last dimension.
        """
        for conv in self.gcn.convs:
            if self.drop_each:
                edge_index = self.gcn.drop_edge(edge_index)
            x = F.relu(conv(x, edge_index, edge_attr))
        if self.drop_each:
            x = self.gcn.dropout(x)
        # Explicit dim: the implicit-dim form is deprecated and normalizes
        # over dim 0 for 3-D input instead of over the class dimension.
        y = F.softmax(self.gcn.fc(x), dim=-1)
        return x, y

def fit_model(cv_split=2,
                graph_data='datasets/graph_dataset_no_pretrain_final.pkl',
                cv_splits='cv_splits/cv_splits_final.pkl',
                use_weights=False,
                n_batches_backward=1,
                f1_metric='weighted',
                n_epochs=1500,
                out_dir='models_no_pretrain_final',
                lr=1e-2,
                eta_min=1e-4,
                T_max=20,
                wd=0,
                hidden_topology=[32,64,128,128],
                p=0.5,
                p2=0.3,
                burnin=400,
                warmup=100,
                gpu_id=0,
                batch_size=1,
                vis_every=0,
                port=5555
                ):
    """Train a ``GCNNet`` on one cross-validation split and save the best weights.

    Writes ``<cv_split>.model.pth`` (state dict) and ``<cv_split>.f1.log.pth``
    (list of per-epoch ``(train_f1, val_f1)`` tuples) into ``out_dir``.

    Args:
        cv_split: Key/index of the split to use inside the ``cv_splits`` file.
        graph_data: Pickle holding ``'graph_dataset'`` (list of graphs),
            ``'weight'`` (class weights) and ``'df'`` (frame with an
            ``'annotation'`` column whose number of unique values sets the
            output width).
        cv_splits: Pickle whose ``cv_split`` entry has ``'train_idx'`` and
            ``'val_idx'``.
        use_weights: Use ``datasets['weight']`` as cross-entropy class weights.
        n_batches_backward: Number of batches whose gradients are accumulated
            before an optimizer step.
        f1_metric: ``average`` argument passed to ``f1_score``.
        n_epochs: Number of epochs.
        out_dir: Output directory (created if missing).
        lr: Adam learning rate.
        eta_min: Minimum learning rate of the cosine annealing schedule.
        T_max: ``T_max`` of the cosine annealing schedule.
        wd: Adam weight decay.
        hidden_topology: Convolution widths forwarded to ``GCNNet``.
        p: Feature dropout probability forwarded to ``GCNNet``.
        p2: Edge dropout probability forwarded to ``GCNNet``.
        burnin: First epoch from which validation checkpoints are kept.
        warmup: For epochs ``<= warmup`` only the first batch of an
            unshuffled, default-batch-size loader is used per epoch.
        gpu_id: CUDA device index.
        batch_size: Training batch size after warmup.
        vis_every: Plot the first validation batch to Visdom every this many
            epochs; ``0`` disables plotting.
        port: Visdom server port.
    """
    print(gpu_id)
    torch.cuda.set_device(gpu_id)
    vis = Visdom(port=port) if vis_every else None

    # Create the output directory up front so a long run cannot fail at the
    # final save because the directory is missing.
    os.makedirs(out_dir, exist_ok=True)

    # NOTE(review): pickle can execute arbitrary code while loading; only
    # point these paths at trusted files.
    with open(graph_data, 'rb') as fh:
        datasets = pickle.load(fh)
    with open(cv_splits, 'rb') as fh:
        split = pickle.load(fh)[cv_split]
    train_dataset = [datasets['graph_dataset'][i] for i in split['train_idx']]
    val_dataset = [datasets['graph_dataset'][i] for i in split['val_idx']]
    weights = datasets['weight']

    # load model
    model = GCNNet(datasets['graph_dataset'][0].x.shape[1],
                   datasets['df']['annotation'].nunique(),
                   hidden_topology=hidden_topology, p=p, p2=p2)
    model = model.cuda()

    # load optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max, eta_min=eta_min, last_epoch=-1)
    criterion = nn.CrossEntropyLoss(weight=torch.tensor(weights).float() if use_weights else None)
    criterion = criterion.cuda()

    # Best validation score seen so far among checkpoint-eligible epochs.
    best_val_f1 = 0
    best_model_dict = None

    # dataloaders
    dataloaders = {
        'train': DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
        'val': DataLoader(val_dataset, shuffle=True),
        'warmup': DataLoader(train_dataset, shuffle=False),
    }
    train_loader = dataloaders['warmup']

    n_total_batches = 0
    train_val_f1 = []
    for epoch in range(n_epochs):
        # ---- training ----
        Y, Y_Pred = [], []
        model.train(True)
        # Index of the final batch (not the final sample), so leftover
        # accumulated gradients are flushed at the end of every epoch even
        # when batch_size > 1.
        last_batch = len(train_loader) - 1
        for i, data in enumerate(train_loader):
            n_total_batches += 1
            x = data.x.cuda()
            edge_index = data.edge_index.cuda()
            y = data.y.cuda()
            y_out = model(x, edge_index)
            loss = criterion(y_out, y) / n_batches_backward
            loss.backward()
            if n_total_batches % n_batches_backward == 0 or i == last_batch:
                optimizer.step()
                optimizer.zero_grad()
            Y_Pred.append(F.softmax(y_out, dim=1).argmax(1).detach().cpu().numpy().flatten())
            Y.append(y.detach().cpu().numpy().flatten())
            del x, edge_index, loss, y_out
            if epoch <= warmup:
                # Warmup epochs train on a single batch only.
                break
        if epoch == warmup:
            train_loader = dataloaders['train']
        train_f1 = f1_score(np.hstack(Y), np.hstack(Y_Pred), average=f1_metric)
        scheduler.step()

        # ---- validation ----
        Y, Y_Pred = [], []
        model.train(False)
        # No gradients are needed here; skipping autograd saves memory.
        with torch.no_grad():
            for i, data in enumerate(dataloaders['val']):
                x = data.x.cuda()
                edge_index = data.edge_index.cuda()
                y = data.y.cuda()
                y_out = model(x, edge_index)
                y_prob = F.softmax(y_out, dim=1).cpu().numpy()
                y_pred = y_prob.argmax(1).flatten()
                y_true = y.cpu().numpy().flatten()
                Y_Pred.append(y_pred)
                Y.append(y_true)
                # Guard on ``vis`` first: with vis_every == 0 the modulo
                # below would raise ZeroDivisionError.
                if vis is not None and i == 0 and epoch % vis_every == 0:
                    vis.scatter(data.pos.numpy(),opts=dict(markercolor=(y_pred*255).astype(int),webgl=False,markerborderwidth=0,markersize=5),win="pred")
                    vis.scatter(data.pos.numpy(),opts=dict(markercolor=y_true*255,webgl=False,markerborderwidth=0,markersize=5),win="true")
                del x, edge_index, y_out
        val_f1 = f1_score(np.hstack(Y), np.hstack(Y_Pred), average=f1_metric)

        # Checkpoints are only kept once the burn-in period is over.
        if epoch >= burnin and val_f1 >= best_val_f1:
            best_model_dict = copy.deepcopy(model.state_dict())
            best_val_f1 = val_f1
        print(epoch, train_f1, val_f1, flush=True)
        train_val_f1.append((train_f1, val_f1))

    if best_model_dict is not None:
        model.load_state_dict(best_model_dict)
    else:
        # Happens e.g. when n_epochs <= burnin; keep the final weights
        # instead of failing on an unbound checkpoint.
        print('no checkpoint met the saving criteria; saving final-epoch weights', flush=True)
    torch.save(model.state_dict(), os.path.join(out_dir, f"{cv_split}.model.pth"))
    torch.save(train_val_f1, os.path.join(out_dir, f"{cv_split}.f1.log.pth"))
    

class Commands(object):
    """Command container whose methods forward to module-level functions."""

    def __init__(self):
        pass

    def fit_model(
        self,
        cv_split=2,
        graph_data='datasets/graph_dataset_pretrain.pkl',
        cv_splits='cv_splits/cv_splits.pkl',
        use_weights=False,
        n_batches_backward=1,
        f1_metric='weighted',
        n_epochs=1500,
        out_dir='models_no_pretrain',
        lr=1e-2,
        eta_min=1e-4,
        T_max=20,
        wd=0,
        hidden_topology=[32,64,128,128],
        p=0.5,
        p2=0.3,
        burnin=400,
        warmup=100,
        gpu_id=0,
        batch_size=1,
        vis_every=0,
        port=5555,
    ):
        """Forward every argument unchanged to the module-level ``fit_model``.

        The defaults for ``graph_data``, ``cv_splits`` and ``out_dir`` here
        differ from those of the module-level function; the values given
        here are the ones that take effect.
        """
        fit_model(
            cv_split=cv_split,
            graph_data=graph_data,
            cv_splits=cv_splits,
            use_weights=use_weights,
            n_batches_backward=n_batches_backward,
            f1_metric=f1_metric,
            n_epochs=n_epochs,
            out_dir=out_dir,
            lr=lr,
            eta_min=eta_min,
            T_max=T_max,
            wd=wd,
            hidden_topology=hidden_topology,
            p=p,
            p2=p2,
            burnin=burnin,
            warmup=warmup,
            gpu_id=gpu_id,
            batch_size=batch_size,
            vis_every=vis_every,
            port=port,
        )
                

# In [None]:
# Example run: train on cross-validation split 4 of a user-supplied dataset.
# The 'your_data_set/...' paths are placeholders to be replaced.
# `port` is not passed, so the Visdom connection uses the default 5555.
Commands().fit_model(cv_split=4,
                graph_data='your_data_set/graph_datasets/imagenet_graph_data.pkl',
                cv_splits='your_data_set/graph_datasets/cv_splits.pkl',
                use_weights=True,  # class-weighted cross entropy
                n_batches_backward=4,  # accumulate gradients over 4 batches
                f1_metric='macro',
                n_epochs=1500,
                # NOTE(review): 'gnn_modoels' looks like a typo for
                # 'gnn_models' -- confirm before renaming, since the saved
                # files are written under this exact path.
                out_dir='your_data_set/gnn_modoels',
                lr=1e-2,
                eta_min=1e-5,
                T_max=20,
                wd=1e-5,
                hidden_topology=[32]*3,  # three convolutions of width 32
                p=0.2,
                p2=0.2,
                burnin=400,
                warmup=25,
                gpu_id=0,
                batch_size=16,
                vis_every=10)  # non-zero: fit_model connects to a Visdom server