In [1]:
import torch

if 'google.colab' in str(get_ipython()):
  print('Running on Colab')
  running_on_colab = True
else:
  print('Not running on Colab')
  running_on_colab = False

if running_on_colab:
    print(torch.__version__)
    !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
    !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
    !pip install -q git+https://github.com/snap-stanford/deepsnap.git

    from google.colab import drive
    drive.mount('/content/drive')
    filepath = '/content/drive/MyDrive/GCNN/graph_data/graphsage_prototype/'
    data_folder = filepath+"data/"
    models_folder = filepath+"models/"
    experiments_folder = filepath+"experiments/"

else:
    data_folder = "../../data/processed/graph_data_nohubs/"
    models_folder = "../../data/models/"
    experiments_folder = "../../data/experiments/"

Not running on Colab


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import datetime
import pickle
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch_sparse import matmul
import deepsnap.hetero_gnn
import torch_geometric.transforms as T
from torch_geometric.data import HeteroData

from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score

In [3]:
import random
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

# Utility

In [4]:
def load_node_csv(path, index_col,type_col, **kwargs):
    """Returns node dataframe and a dict of mappings for each node type. 
    Each mapping maps from original df index to "heterodata index" { node_type : { dataframe_index : heterodata_index}}"""
    df = pd.read_csv(path, **kwargs,index_col=index_col)
    node_types = df[type_col].unique()
    mappings_dict = dict()
    for node_type in node_types:
        mapping = {index: i for i, index in enumerate(df[df[type_col] == node_type].index.unique())}
        mappings_dict[node_type] = mapping

    return df,mappings_dict

def load_edge_csv(path, src_index_col, dst_index_col, mappings, edge_type_col,src_type_col,dst_type_col, **kwargs):
    """Returns edge dataframe and a dict of edge indexes. Nodes are indexed according to the "heterodata index", using the node mappings from load_node_csv. 
    Edge indexes are tensors of shape [2, num_edges]. Dict is indexed by triplets of shape (src_type, edge_type, dst_type)."""
    df = pd.read_csv(path, **kwargs)
    df["edge_triple"] = list(zip(df[src_type_col],df[edge_type_col], df[dst_type_col]))
    edge_triplets = df["edge_triple"].unique()

    edge_index_dict = dict()
    for edge_triplet in edge_triplets:

        sub_df = df[df.edge_triple == edge_triplet]
        src_type,edge_type,dst_type = edge_triplet

        src_mapping = mappings[src_type]
        dst_mapping = mappings[dst_type]

        src = [src_mapping[index] for index in sub_df[src_index_col]]
        dst = [dst_mapping[index] for index in sub_df[dst_index_col]]
        edge_index = torch.tensor([src, dst])
        edge_index_dict[edge_triplet] = edge_index

    return df, edge_index_dict

def create_heterodata(node_map, edge_index):
    """Initializes HeteroData object from torch_geometric and creates corresponding nodes and edges, without any features"""
    data = HeteroData()
    for node_type,vals in node_map.items():
        # Initialize all node types without features
        data[node_type].num_nodes = len(vals)
    
    for edge_triplet, index in edge_index.items():
        src_type, edge_type, dst_type = edge_triplet
        data[src_type, edge_type, dst_type].edge_index = index
    
    return data

def get_reverse_types(edge_types):
    newlist = []
    for edge in edge_types:
        rev = tuple(reversed(edge))
        if rev != edge:
            if edge not in newlist:
                newlist.append(rev)
        else:
            newlist.append(rev)

    reversed_newlist = [tuple(reversed(edge)) for edge in newlist]

    return newlist, reversed_newlist

def initialize_features(data_object,feature,dim):
    for nodetype, store in data_object.node_items():
        if feature == "random":
            data_object[nodetype].x = torch.rand(store["num_nodes"],dim)
        if feature == "ones":
            data_object[nodetype].x = torch.ones(store["num_nodes"],dim)
    return data_object

# Model

In [5]:
quiet = True

def talk(msg, quiet=quiet):
    if not quiet:
        print(msg)

def generate_convs(hetero_graph, conv, hidden_size, first_layer=False):
    convs = {}

    msg_types = hetero_graph.edge_types
    for key in msg_types:
        if first_layer:
            dst_feature_dim = hetero_graph.num_node_features[key[2]]
            src_feature_dim = hetero_graph.num_node_features[key[0]]
            convs[key] = conv(src_feature_dim, dst_feature_dim, hidden_size)
        else:
            convs[key] = conv(hidden_size, hidden_size, hidden_size)

    return convs

def hetero_apply_function(x: dict,func) -> dict:
    """X es el diccionario de node embeddings o features, {node_type: tensor}.
    Aplica func a cada entrada del diccionario, devuelve un dict de la misma forma."""
    x_transformed = {}
    for key,val in x.items():
        transformed_val = func(val)
        x_transformed[key] = transformed_val
    
    return x_transformed

class HeteroGNNConv(pyg_nn.MessagePassing):
    def __init__(self, in_channels_src, in_channels_dst, out_channels):
        super().__init__(aggr="mean")

        self.in_channels_src = in_channels_src
        self.in_channels_dst = in_channels_dst
        self.out_channels = out_channels

        self.lin_dst = nn.Linear(in_channels_dst, out_channels)
        self.lin_src = nn.Linear(in_channels_src, out_channels)
        self.lin_update = nn.Linear(2*out_channels, out_channels)

    def forward(self,node_feature_src, node_feature_dst,edge_index):
        talk("HeteroGNN forward")
        talk(f"Node feature src shape: {node_feature_src.shape}, Node feature dst shape: {node_feature_dst.shape}, edge index shape: {edge_index.sparse_sizes()}")
        out = self.propagate(edge_index, node_feature_src=node_feature_src,node_feature_dst=node_feature_dst)
        return out

    def message_and_aggregate(self, edge_index, node_feature_src):
        talk("HeteroGNN msg and agg")
        talk(f"node_feature src shape: {node_feature_src.shape}")
        out = matmul(edge_index, node_feature_src, reduce=self.aggr)
        return out

    def update(self, aggr_out, node_feature_dst):
        talk("HeteroGNN update")
        talk(f"Aggr_out shape: {aggr_out.shape}")
        talk(f"Dst feature shape: {node_feature_dst.shape}")
        dst_msg = self.lin_dst(node_feature_dst)
        src_msg = self.lin_src(aggr_out)

        talk(f"Concat: dst_msg shape: {dst_msg.shape}, src_msg shape: {src_msg.shape}")
        full_msg = torch.concat((dst_msg, src_msg), dim=1)

        talk(f"Full msg shape: {full_msg.shape}")
        out = self.lin_update(full_msg)

        talk(f"After update shape: {out.shape}")
        return out


class HeteroGNNWrapperConv(deepsnap.hetero_gnn.HeteroConv):
    def __init__(self, convs, aggr="mean"):
        super().__init__(convs, None)
        self.aggr = aggr

        # Map the index and message type
        self.mapping = {}

    def reset_parameters(self):
        super().reset_parameters()

    def forward(self, node_features, edge_indices):
        talk("\n ------ Wrapper forward ------ ")
        message_type_emb = {}

        for message_key, adj in edge_indices.items():
            talk(f"\n{message_key}\n")
            src_type, edge_type, dst_type = message_key
            node_feature_src = node_features[src_type]
            node_feature_dst = node_features[dst_type]

            message_type_emb[message_key] = self.convs[message_key](node_feature_src,node_feature_dst,adj)

        # {dst: [] for src, type, dst in message_type.emb.keys()}
        # {tipo de nodo: [lista de embeddings obtenidos]}
        node_emb = {dst: [] for _, _, dst in message_type_emb.keys()}
        mapping = {}

        for (src, edge_type, dst), item in message_type_emb.items():
            #esto es para saber que indice es cada terna/msg type
            mapping[len(node_emb[dst])] = (src, edge_type, dst)

            #Agrego el embedding de la terna (src,type,dst) al la lista de embeddings de dst
            node_emb[dst].append(item)

        self.mapping = mapping

        #Ahora hago aggregation sobre las listas de embeddings, para cada tipo de nodo DST
        talk("\n------ Wrapper agg ------")
        for node_type, embs in node_emb.items():
            talk(f"\nAggregate {node_type} embeddings")

            # Si hay un solo embedding en la lista, me quedo con ese solito
            if len(embs) == 1:
                talk(f"Num embeddings = 1, no AGG needed")
                node_emb[node_type] = embs[0]
            
            #Si hay más de uno hago aggregation
            else:
                node_emb[node_type] = self.aggregate(embs)
        return node_emb

    def aggregate(self, xs):
        # Tomo la lista de embeddings para cada tipo de nodo y los "agrego". En este caso solo los promedio
        # Stackeo los embeddings
        talk(f"Num embeddings: {len(xs)}")
        talk(f"Shape embeddings: {[e.shape for e in xs]}")
        stacked = torch.stack(xs, dim=0)
        talk(f"Stacked shape: {stacked.shape}")
        out = torch.mean(stacked,dim=0)
        talk(f"Final aggregated shape: {out.shape}")
        return out



class HeteroGNN(torch.nn.Module):
    def __init__(self, hetero_graph, pred_mode, hidden_size=32, aggr="mean"):
        super().__init__()

        self.aggr = aggr
        self.pred_mode = pred_mode
        self.hidden_size = hidden_size
        self.bns1 = torch.nn.ModuleDict()
        self.relus1 = torch.nn.ModuleDict()
        self.loss_fn = torch.nn.BCEWithLogitsLoss()
        
        # if head=="distmult":
        #   self.distmult_head = distmult_head(hetero_graph,self.hidden_size)

        convs1 = generate_convs(hetero_graph, HeteroGNNConv, self.hidden_size, first_layer=True)
        convs2 = generate_convs(hetero_graph, HeteroGNNConv, self.hidden_size, first_layer=False)
        self.convs1 = HeteroGNNWrapperConv(convs1, aggr=self.aggr)
        self.convs2 = HeteroGNNWrapperConv(convs2, aggr=self.aggr)
        for node_type in hetero_graph.node_types:
            self.bns1[node_type] = torch.nn.BatchNorm1d(self.hidden_size)
            self.relus1[node_type] = torch.nn.LeakyReLU()
    
    def encode(self,graph):
        talk(" ------ ENCODER ------ ")
        x = {k:v["x"] for (k,v) in graph.node_items()}
        adj = {k:v["adj_t"] for (k,v) in graph.edge_items()}
        
        talk("Conv 1")
        x = self.convs1(x, edge_indices=adj)

        talk("\n BNS 1")
        x = deepsnap.hetero_gnn.forward_op(x, self.bns1)

        talk("\n Relu 1")
        x = deepsnap.hetero_gnn.forward_op(x, self.relus1)

        talk("\n Conv 2")
        x = self.convs2(x, edge_indices=adj)

        talk("\n----------")
        talk(f"Node embeddings done. Dimentions: {[(k,item.shape) for k,item in x.items()]}")
        talk("---------")

        return x
    
    def decode_train(self,x,graph):
        supervision_types = [item[0] for item in graph.edge_items() if "edge_label_index" in item[1].keys()]
        edge_label_index = {k:v["edge_label_index"] for (k,v) in graph.edge_items() if k in supervision_types}

        talk("\n ------ DECODER ------ ")
        pred = {}
        if self.pred_mode == "all":
            for message_type, edge_index in edge_label_index.items():
                talk(f"\n Decoding edge type: {message_type}")
                src_type = message_type[0]
                trg_type = message_type[2]

                x_source = x[src_type]
                x_target = x[trg_type]

                nodes_src = x_source[edge_index[0]]
                nodes_trg = x_target[edge_index[1]]

                talk(f"\n Multiplying shapes: {nodes_src.shape}, {nodes_trg.shape}")
                pred[message_type] = torch.sum(nodes_src * nodes_trg, dim=-1)

        elif self.pred_mode == "gda_only":
            keys = [edge for edge in supervision_types if "gda" in edge]
            for message_type in keys:
                talk(f"\n Decoding edge type: {message_type}")
                edge_index = edge_label_index[message_type]
                src_type = message_type[0]
                trg_type = message_type[2]

                x_source = x[src_type]
                x_target = x[trg_type]

                nodes_src = x_source[edge_index[0]]
                nodes_trg = x_target[edge_index[1]]
                talk(f"\n Multiplying shapes: {nodes_src.shape}, {nodes_trg.shape}")
                pred[message_type] = torch.sum(nodes_src * nodes_trg, dim=-1)

        return pred
    
    def decode_pred(self,x1,x2):
        talk(f"\n Multiplying shapes: {x1.shape}, {x2.shape}")
        pred_adj = torch.matmul(x1, x2.t())
        pred_adj = torch.sigmoid(pred_adj)

        return pred_adj

    def forward(self, graph):
        x = self.encode(graph)
        pred = self.decode_train(x,graph)
        
        return pred
          
    def loss(self, prediction_dict, ground_truth_dict):
        loss = 0
        num_types = len(prediction_dict.keys())
        # sets = torch.tensor(len(pred.keys()))
        for edge_type,pred in prediction_dict.items():
            y = ground_truth_dict[edge_type]
            loss += self.loss_fn(pred, y.type(pred.dtype))
        return loss/num_types

# Training and testing

In [6]:
@torch.no_grad()
def hits_at_k(y_true,x_prob,k,key) -> dict:
    """Dados los tensores x_prob y edge_label, calcula cuantas predicciones hizo correctamente en los primeros k puntajes.
    x_prob es la predicción del modelo luego de aplicar sigmoid (sin redondear, osea, el puntaje crudo)"""

    #ordeno los puntajes de mayor a menor
    x_prob, indices = torch.sort(x_prob, descending=True)

    #me quedo solo con los k mayor punteados
    x_prob = x_prob[:k]
    indices = indices[:k]

    if any(x_prob < 0.5):
      threshold_index = (x_prob < 0.5).nonzero()[0].item()
      print(f"Top {k} scores for {key} below classification threshold 0.5, threshold index: {threshold_index}")

    #busco que label tenían esas k preds
    labels = y_true[indices]

    #cuento cuantas veces predije uno positivo en el top k
    hits = labels.sum().item()

    return hits

def train(model, optimizer, graph, printb):
    model.train()
    optimizer.zero_grad()
    preds = model(graph) # acá no paso por sigmoid porque mi loss_fn es BCE with logits, que aplica sigmoid internamente!
    edge_label = {k:v["edge_label"] for (k,v) in graph.edge_items() if "edge_label" in v.keys()}
    loss = model.loss(preds, edge_label)
    loss.backward()
    optimizer.step()
    if printb:
        print(loss.item())
    return loss.item()

def get_metrics(y_true, x_pred):
   acc = round(accuracy_score(y_true,x_pred),2)
   ap = round(average_precision_score(y_true, x_pred),2)
   roc_auc = round(roc_auc_score(y_true,x_pred),2)

   return acc,ap ,roc_auc
  

@torch.no_grad()
def test(model,data,metric):
  model.eval()
  preds = model(data)
  edge_label = {k:v["edge_label"] for (k,v) in data.edge_items() if "edge_label" in v.keys()}
  all_preds = []
  all_true = []
  for key,pred in preds.items():
      probabilities = torch.sigmoid(pred)
      pred_label = torch.round(probabilities)
      ground_truth = edge_label[key]
      all_preds.append(pred_label)
      all_true.append(ground_truth)
  total_predictions = torch.cat(all_preds, dim=0).cpu().numpy()
  total_true = torch.cat(all_true, dim=0).cpu().numpy()
  score = metric(total_true,total_predictions)
  return score
  

@torch.no_grad()
def full_test(model,data,k,global_score=True):
  model.eval()
  preds = model(data)
  edge_label = {k:v["edge_label"] for (k,v) in data.edge_items() if "edge_label" in v.keys()}
  metrics = {}

  if global_score:
    all_scores = []
    all_preds = []
    all_true = []
    for key,pred in preds.items():
        probabilities = torch.sigmoid(pred)
        pred_label = torch.round(probabilities)
        ground_truth = edge_label[key]
        all_scores.append(probabilities)
        all_preds.append(pred_label)
        all_true.append(ground_truth)

    total_predictions = torch.cat(all_preds, dim=0).cpu().numpy()
    total_true = torch.cat(all_true, dim=0).cpu().numpy()
    total_scores = torch.cat(all_scores,dim=0).cpu().numpy()

    acc, ap, roc_auc =  get_metrics(total_true, total_predictions)
    hits_k = hits_at_k(total_true,total_scores,k,"all")
    metrics["all"] = [acc,ap,roc_auc,hits_k]

  else:
    for key,pred in preds.items():
        probabilities = torch.sigmoid(pred)
        pred_label = torch.round(probabilities)
        ground_truth = edge_label[key]
        acc, ap, roc_auc = get_metrics(ground_truth.cpu().numpy(), pred_label.cpu().numpy())
        hits_k = hits_at_k(ground_truth,probabilities,k,key)
        metrics[key] = [acc,ap, roc_auc,hits_k]
  
  return metrics

# Load parameters

In [17]:
model_name = "graphsage_prototype_best_31_03_23__07_49"

with open(models_folder+"params_"+model_name+".pickle", 'rb') as handle:
    params = pickle.load(handle)

params["pred_mode"] = "all"
params

{'weight_decay': 0.001,
 'pred_mode': 'all',
 'num_features': 10,
 'max_epochs': 400,
 'lr': 0.01,
 'hidden_size': 64,
 'feature_type': 'random',
 'aggr': 'sum'}

# Data preparation

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load data from csv and create heterodata object
node_data, node_map = load_node_csv(data_folder+"nohub_graph_nodes.csv","node_index","node_type")
edge_data, edge_index = load_edge_csv(data_folder+"nohub_graph_edge_data.csv","x_index","y_index",node_map,"edge_type","x_type","y_type")
data = create_heterodata(node_map,edge_index)

#Split the dataset
edge_types, rev_edge_types = get_reverse_types(data.edge_types)
data = initialize_features(data,params["feature_type"],params["num_features"])
split_transform = T.RandomLinkSplit(num_val=0.3, num_test=0.3, is_undirected=True, add_negative_train_samples=True, disjoint_train_ratio=0.2,edge_types=edge_types,rev_edge_types=rev_edge_types)
transform_dataset = T.Compose([split_transform, T.ToSparseTensor(remove_edge_index=False),T.ToDevice(device)])

train_data, val_data, test_data = transform_dataset(data)

In [8]:
def plot_training_stats(title, losses, train_metric,val_metric,metric_str):

  fig, ax = plt.subplots()
  ax2 = ax.twinx()

  ax.set_xlabel("Training Epochs")
  ax2.set_ylabel("Performance Metric")
  ax.set_ylabel("Loss")

  plt.title(title)
  p1, = ax.plot(losses, "b-", label="training loss")
  p2, = ax2.plot(val_metric, "r-", label=f"val {metric_str}")
  p3, = ax2.plot(train_metric, "o-", label=f"train {metric_str}")
  plt.legend(handles=[p1, p2, p3])
  plt.show()

# Train model

In [18]:
model = HeteroGNN(data,params["pred_mode"],params["hidden_size"],params["aggr"]).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])

In [19]:
losses = []
train_scores = []
val_scores = []
metric = accuracy_score
metric_name = "Global accuracy"

for epoch in range(params["max_epochs"]):
    if epoch%10 == 0:
        printb = True
    else:
        printb = False

    loss = train(model,optimizer,train_data,printb)
    train_score = test(model,train_data,metric)
    val_score = test(model,val_data,metric)
    losses.append(loss)
    train_scores.append(train_score)
    val_scores.append(val_score)

plot_training_stats("GraphSAGE prototype training stats",losses,train_scores,val_scores,metric_name)

0.7044072151184082
0.4958636164665222
0.3989097476005554
0.3634708821773529


KeyboardInterrupt: 

In [None]:
def save_model(model,param_dict,folder_path,model_name):
    date = datetime.datetime.now()
    fdate = date.strftime("%d_%m_%y__%I_%M")
    fname = f"{model_name}_{fdate}"

    torch.save(model.state_dict(), f"{folder_path}{fname}.pth")

    with open(f"{folder_path}params_{fname}.pickle", 'wb') as handle:
        pickle.dump(param_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
save_model(model,params,models_folder,"graphsage_all_types_experiment")