In [9]:
# Imports and process-level setup for the GNN feature-extraction cell.
# NOTE(review): `os`, `torch` and `torch.nn as nn` are imported more than once
# below; the repeats are harmless but could be consolidated.
import sys, os

# Put a user-level site-packages directory first on the import path so that
# packages installed there are found before any other copy.
# NOTE(review): this is an absolute, user-specific path — it will not exist on
# other machines.
sys.path.insert(0,"/dartfs-hpc/rc/home/w/f003k8w/.local/lib/python3.7/site-packages/")

from pathflowai.utils import load_sql_df
import torch
import pickle
import os 
# import umap, numba
from sklearn.preprocessing import LabelEncoder
from torch_cluster import knn_graph
from torch_geometric.data import Data 
import numpy as np
from torch_geometric.utils import train_test_split_edges
import os
# Restrict CUDA to the device with index 0.
# NOTE(review): this is set after `import torch` above; presumably fine because
# CUDA is only initialised on first use — TODO confirm.
os.environ['CUDA_VISIBLE_DEVICES']="0"
import argparse
from torch_geometric.utils.convert import to_networkx
from torch_geometric.data import InMemoryDataset,DataLoader
import os,glob, pandas as pd
from sklearn.metrics import f1_score
import copy
import torch
import torch.nn.functional as F
from collections import Counter
from torch import nn
from torch_geometric.nn import GCNConv, GATConv, DeepGraphInfomax, SAGEConv
from torch_geometric.nn import DenseGraphConv
from torch_geometric.utils import to_dense_batch, to_dense_adj, dense_to_sparse
from torch_geometric.nn import GINEConv
from torch_geometric.utils import dropout_adj
from torch_geometric.nn import APPNP
import torch.nn as nn
import fire

# Small constant; not referenced anywhere in the visible part of this notebook.
EPS = 1e-15

class GCNNet(torch.nn.Module):
    """Stack of GATConv layers followed by a linear read-out.

    Args:
        inp_dim: Number of input features per node.
        out_dim: Number of outputs produced by the final linear layer.
        hidden_topology: Output width of each successive GATConv layer.
        p: Probability used by the feature dropout placed before the
            linear layer.
        p2: Probability passed to ``dropout_adj`` when edges are dropped.
        drop_each: If true, edges are re-sampled before every GATConv layer
            while the module is in training mode.
    """

    def __init__(self, inp_dim, out_dim, hidden_topology=[32,64,128,128], p=0.5, p2=0.1, drop_each=True):
        super().__init__()
        self.out_dim = out_dim

        # Consecutive (in, out) widths: inp_dim -> h[0] -> h[1] -> ...
        widths = [inp_dim] + list(hidden_topology)
        self.convs = nn.ModuleList(
            [GATConv(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])]
        )

        def _drop_edges(edge_index):
            # dropout_adj returns (edge_index, edge_attr); only the index is kept.
            return dropout_adj(edge_index, p=p2)[0]

        self.drop_edge = _drop_edges
        self.dropout = nn.Dropout(p)
        self.fc = nn.Linear(hidden_topology[-1], out_dim)
        self.drop_each = drop_each

    def forward(self, x, edge_index, edge_attr=None):
        """Return the output of the linear layer for every node.

        Each GATConv output is passed through ReLU. In training mode the
        features are passed through dropout before the linear layer, and
        (when ``drop_each`` is set) edges are dropped before each layer.
        """
        # NOTE(review): edge_attr is forwarded unchanged even when edges have
        # been dropped from edge_index — confirm callers never pass edge_attr
        # together with drop_each.
        resample_edges = self.drop_each and self.training
        for layer in self.convs:
            if resample_edges:
                edge_index = self.drop_edge(edge_index)
            x = F.relu(layer(x, edge_index, edge_attr))
        if self.training:
            x = self.dropout(x)
        return self.fc(x)
    
class GCNFeatures(torch.nn.Module):
    """Expose the last-layer embedding and the read-out of a trained ``gcn``.

    Args:
        gcn: Module providing ``convs`` (iterable of layers called as
            ``conv(x, edge_index, edge_attr)``) and a linear ``fc`` layer.
        bayes: If true, every forward pass is stochastic: edges are dropped
            before each layer and feature dropout is applied to the
            embedding, independent of the train/eval mode of ``gcn``.
        p: Feature-dropout probability.
        p2: Edge-dropout probability passed to ``dropout_adj``.

    Side effects:
        ``gcn.drop_edge`` and ``gcn.dropout`` are replaced on the passed-in
        module, so constructing a second wrapper around the same ``gcn``
        overrides the settings of the first one.
    """
    def __init__(self, gcn, bayes=False, p=0.05, p2=0.1):
        super(GCNFeatures, self).__init__()
        self.gcn=gcn
        self.drop_each=bayes
        self.gcn.drop_edge = lambda edge_index: dropout_adj(edge_index,p=p2)[0]
        self.gcn.dropout = nn.Dropout(p)

    def forward(self, x, edge_index, edge_attr=None):
        """Return ``(z, y)``.

        ``z`` is the output of the last conv layer (no ReLU; with dropout
        applied when ``drop_each`` is set) and ``y`` is ``gcn.fc(relu(z))``.
        """
        last = len(self.gcn.convs) - 1
        for i,conv in enumerate(self.gcn.convs):
            if self.drop_each:
                edge_index=self.gcn.drop_edge(edge_index)
            x = conv(x, edge_index, edge_attr)
            # ReLU between layers only; the last layer's raw output is the embedding.
            if i < last:
                x = F.relu(x)
        if self.drop_each:
            # Use the functional form with training=True: the nn.Dropout module
            # stored on gcn becomes a no-op as soon as the caller runs
            # gcn.eval() / gcn.train(False), which would silently remove the
            # feature dropout from the posterior samples.
            x = F.dropout(x, p=self.gcn.dropout.p, training=True)
        y = self.gcn.fc(F.relu(x))
        return x,y

def extract_features(cv_split=2,
                graph_data='datasets/graph_dataset_no_pretrain.pkl',
                cv_splits='cv_splits/cv_splits.pkl',
                models_dir="models_no_pretrain/",
                out_dir='predictions_no_pretrain',
                hidden_topology=[32,64,128,128],
                p=0.5,
                p2=0.3,
                n_posterior=50
                ):
    """Run a trained GCNNet over the train+val graphs of one CV fold and save the results.

    For every graph, ``n_posterior`` stochastic forward passes (GCNFeatures
    with ``bayes=True``) are summarised by their mean and standard deviation,
    and one deterministic pass provides the embedding and prediction.

    Args:
        cv_split: Key of the fold inside the pickled CV splits; also used to
            locate ``{cv_split}.model.pth`` and name the output file.
        graph_data: Pickle holding ``'graph_dataset'`` (sequence of graphs)
            and ``'df'`` (table with an ``'annotation'`` column).
        cv_splits: Pickle holding, per fold, ``'train_idx'`` and ``'val_idx'``.
        models_dir: Directory containing the saved state dict.
        out_dir: Directory the predictions are written to (created if missing).
        hidden_topology: Layer widths; must match the saved model.
        p: Feature-dropout probability used for the stochastic passes.
        p2: Edge-dropout probability used for the stochastic passes.
        n_posterior: Number of stochastic forward passes per graph (>= 1).

    Raises:
        ValueError: If ``n_posterior`` is smaller than 1.

    Output:
        ``{out_dir}/{cv_split}.predictions.pth`` — a list with one dict per
        graph holding ``y``, ``G``, ``xy``, ``y_pred_posterior``, ``y_std``,
        ``z`` and ``y_pred``.
    """
    if n_posterior < 1:
        raise ValueError("n_posterior must be at least 1")

    # Fail early (before any GPU work) if the output location is unusable.
    os.makedirs(out_dir, exist_ok=True)

    # prep data
    # NOTE: pickle.load can execute arbitrary code; only point these at trusted files.
    with open(graph_data,'rb') as f:
        datasets=pickle.load(f)
    with open(cv_splits,'rb') as f:
        split=pickle.load(f)[cv_split]
    # Features are extracted for the training and validation graphs together.
    eval_idx=np.hstack((split['train_idx'],split['val_idx']))
    eval_dataset=[datasets['graph_dataset'][i] for i in eval_idx]

    # load model and its saved weights
    model=GCNNet(datasets['graph_dataset'][0].x.shape[1],datasets['df']['annotation'].nunique(),hidden_topology=hidden_topology,p=p,p2=p2)
    model=model.cuda()
    model.load_state_dict(torch.load(os.path.join(models_dir,f"{cv_split}.model.pth")))

    # A single wrapper serves both passes; building a second GCNFeatures around
    # the same model would overwrite the dropout settings installed here.
    feature_extractor=GCNFeatures(model,bayes=True,p=p,p2=p2).cuda()
    model.eval()
    # model.eval() also switches off the nn.Dropout installed by GCNFeatures,
    # which would leave edge dropout as the only source of randomness in the
    # posterior samples. Keep the feature dropout active explicitly.
    model.dropout.train()

    loader=DataLoader(eval_dataset,shuffle=False)
    graphs=[]
    with torch.no_grad():
        for data in loader:
            x=data.x.cuda()
            edge_index=data.edge_index.cuda()

            # uncertainty estimate: n_posterior stochastic passes
            feature_extractor.drop_each=True
            samples=torch.stack([feature_extractor(x,edge_index)[1] for _ in range(n_posterior)]).cpu().numpy()

            # point estimate: one deterministic pass
            feature_extractor.drop_each=False
            z,y_pred=feature_extractor(x,edge_index)

            graphs.append(dict(y=data.y.numpy(),
                               G=to_networkx(data).to_undirected(),
                               xy=data.pos.numpy(),
                               y_pred_posterior=samples.mean(0),
                               y_std=samples.std(0),
                               z=z.cpu().numpy(),
                               y_pred=y_pred.cpu().numpy()))
            del x,edge_index
    torch.save(graphs,os.path.join(out_dir,f"{cv_split}.predictions.pth"))
    
class Commands(object):
    """Namespace object whose methods mirror the module-level entry points."""

    def __init__(self):
        """No state is kept on the instance."""

    def extract_features(self,cv_split=2,
                graph_data='datasets/graph_dataset_no_pretrain.pkl',
                cv_splits='cv_splits/cv_splits.pkl',
                models_dir="models_no_pretrain/",
                out_dir='predictions_no_pretrain',
                hidden_topology=[32,64,128,128],
                p=0.5,
                p2=0.3,
                n_posterior=50
                ):
        """Forward every argument unchanged to the module-level ``extract_features``."""
        extract_features(cv_split=cv_split,
                         graph_data=graph_data,
                         cv_splits=cv_splits,
                         models_dir=models_dir,
                         out_dir=out_dir,
                         hidden_topology=hidden_topology,
                         p=p,
                         p2=p2,
                         n_posterior=n_posterior)

In [10]:
# Settings for this run: fold 4, inputs and outputs under 'bcc/', three hidden
# layers of width 32, and 100 stochastic passes per graph.
your_args={
    'cv_split': 4,
    'graph_data': 'bcc/graph_datasets/pretrain_graph_data_256.pkl',
    'cv_splits': 'bcc/graph_datasets/cv_splits_256_pretrain.pkl',
    'models_dir': "bcc/gnn_models/",
    'out_dir': 'bcc/predictions',
    'hidden_topology': [32, 32, 32],
    'p': 0.3,
    'p2': 0.2,
    'n_posterior': 100,
}
runner=Commands()
runner.extract_features(**your_args)

In [20]:
! free -g 

              total        used        free      shared  buff/cache   available
Mem:            376          61         276           1          38         311
Swap:             3           2           1
