In [None]:
# For colab

#!pip install dgl-cu100
#!pip install --upgrade tables

In [None]:
import dgl
import pandas as pd
import torch
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
import json
import glob
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
import os

## Exercise 3, part 2

### Message passing graph networks

Goals of this assignment:

1. Learn to construct a message passing network (learn about DGL's update functions; build edge and node networks).
2. Learn to train an edge classifier and a node classifier.

First, we download the dataset. The dataset is made from random graphs, where two points on the graph have been labeled as the start and end points of a path. 

<b> The task is to classify the edges of the graph to determine if they are "part of" the shortest path between the two nodes. </b>

In [None]:
!wget https://www.dropbox.com/s/2s7yhlrdpovyxtk/Dataset.zip

In [None]:
!unzip Dataset.zip

In [None]:
# Automatically re-import edited local modules (e.g. MPNN_model.py,
# shortest_path_dataloader.py) before each cell executes, so code changes are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

### DataSet

It is already built; the following code is here to show you how to deal with a directed graph in the future.

In [None]:
# Load a single example graph (selected by position in the sorted file list) and draw it.
path = 'Dataset/training_data/'
idx = 18

# glob returns files in an arbitrary, filesystem-dependent order; sort so that
# `idx` selects the same graph on every machine and every run.
filelist = sorted(glob.glob(os.path.join(path, '*.json')))
if idx >= len(filelist):
    raise IndexError(f'idx={idx} but only {len(filelist)} graph files found in {path!r}')
filelist = filelist[idx:idx+1]

for fname in filelist:
    with open(fname) as jfile:
        json_data = json.load(jfile)
    # Build a directed networkx graph: one DGL edge per JSON link.
    json_data['directed'] = True
    graph = nx.node_link_graph(json_data)
    g = dgl.from_networkx(graph, node_attrs=['node_features'], edge_attrs=['distance', 'on_path'])

# Use an explicit figure instead of changing the global rcParams figure size.
fig, ax = plt.subplots(figsize=(3, 3))
nx.draw(g.to_networkx(), with_labels=True, ax=ax)
plt.show()

In [None]:
# Edge list as (source ids, destination ids) together with the per-edge features.
g.edges(), g.edata

In [None]:
# Reload the example graph and make it bidirectional: every edge is duplicated in the
# reverse direction, and a per-node 'on_path' label is derived from the edge labels.
path = 'Dataset/training_data/'
idx = 18

# Sort so that `idx` picks the same file regardless of filesystem ordering.
filelist = sorted(glob.glob(os.path.join(path, '*.json')))
if idx >= len(filelist):
    raise IndexError(f'idx={idx} but only {len(filelist)} graph files found in {path!r}')
filelist = filelist[idx:idx+1]

for fname in filelist:
    with open(fname) as jfile:
        json_data = json.load(jfile)
    json_data['directed'] = True
    graph = nx.node_link_graph(json_data)
    g = dgl.from_networkx(graph, node_attrs=['node_features'], edge_attrs=['distance', 'on_path'])

    # The reversed copies carry the same features as the originals. add_edges appends
    # the new edges after the existing ones, so [original, original] lines up with
    # the final edge order.
    edge_distance = torch.cat([g.edata['distance'], g.edata['distance']], dim=0)
    edge_target = torch.cat([g.edata['on_path'], g.edata['on_path']], dim=0)

    src, dst = g.edges()
    g.add_edges(dst, src)

    g.edata['distance'] = edge_distance
    g.edata['on_path'] = edge_target.float()

    # Node 'on_path' = max of 'on_path' over the node's incoming edges.
    # dgl.function.copy_e is the current name of the deprecated copy_edge.
    g.update_all(dgl.function.copy_e('on_path', 'on_path'), dgl.function.max('on_path', 'on_path'))

fig, ax = plt.subplots(figsize=(3, 3))
nx.draw(g.to_networkx(), with_labels=True, ax=ax)
plt.show()

In [None]:
# After add_edges above, each edge appears once per direction.
g.edges()

In [None]:
# Node data (features plus the derived 'on_path') and edge data ('distance', 'on_path').
g.ndata, g.edata

In [None]:
# Locate the start node (feature row [1, 0]) and the end node (feature row [0, 1]).
# `features == [1, 0]` compares element-wise, so a row like [0, 0] would match in its
# second column; `.all(axis=1)` requires the whole row to match.
node_features = g.ndata['node_features'].numpy()
start_nodes = np.where((node_features == [1, 0]).all(axis=1))[0]
end_nodes = np.where((node_features == [0, 1]).all(axis=1))[0]
start_nodes, end_nodes

In [None]:
from shortest_path_dataloader import ShortestPathDataset, collate_graphs

In [None]:
# Wrap the unzipped training and validation folders as datasets of graphs.
training_dataset, validation_dataset = (
    ShortestPathDataset('Dataset/training_data/'),
    ShortestPathDataset('Dataset/validation_data/'),
)

Everything about the target is stored in the graph: there are nodes, edges, node features, and a property called 'on_path' that will be used as the training target.

In [None]:
# Fetch one training graph (index 18) from the dataset.
g = training_dataset[18]

In [None]:
# DGL's summary view of the graph.
g

The nodes of the graph have no real "features" - the node features mark the starting point (1,0) and ending point (0,1) of our trajectory. 

In [None]:
# One row per node: (1, 0) marks the start node and (0, 1) marks the end node.
g.ndata['node_features']

Each edge has a distance associated with it, and the target for training is also stored on the edge data - saying if the edge is part of the path or not.

In [None]:
# Per-edge distance feature.
g.edata['distance']

In [None]:
# Per-edge training target: whether the edge lies on the shortest path.
g.edata['on_path']

In [None]:
# Per-node target (node lies on the path). Presumably derived from the incident
# edges as in the bidirectional demo cell — TODO confirm in ShortestPathDataset.
g.ndata['on_path']

In [None]:
# Training example: draw the graph and highlight in red the edges labelled as on the path.

fig, ax = plt.subplots(figsize=(3, 3), dpi=150)

src, dst = (t.tolist() for t in g.edges())
# reshape(-1) tolerates both (E,) and (E, 1) label tensors.
on_path = g.edata['on_path'].reshape(-1).tolist()

nx_graph = nx.DiGraph()
nx_graph.add_nodes_from(g.nodes().tolist())
nx_graph.add_edges_from(zip(src, dst))

edge_list = [(s, d) for s, d, on in zip(src, dst, on_path) if on > 0]

# spring_layout is randomised; a fixed seed makes the figure reproducible.
pos = nx.spring_layout(nx_graph, seed=0)

nx.draw(nx_graph, pos=pos, ax=ax, node_size=5, arrows=False)
nx.draw_networkx_edges(nx_graph, pos=pos, edgelist=edge_list, width=2, edge_color='r', ax=ax, arrows=False)

plt.show()

In [None]:
from torch.utils.data import Dataset, DataLoader

BATCH_SIZE = 300  # graphs per batch, shared by both loaders

# Both loaders batch graphs with the project's collate function.
data_loader = DataLoader(
    training_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,  # reshuffle training graphs every epoch
    collate_fn=collate_graphs,
)
validation_data_loader = DataLoader(
    validation_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # fixed order for comparable validation metrics
    collate_fn=collate_graphs,
)

In [None]:
# Take a single training batch (as produced by collate_graphs) to inspect.
# next(iter(...)) raises StopIteration immediately if the loader is empty, instead of
# leaving batched_g undefined for a later cell.
batched_g = next(iter(data_loader))

In [None]:
# Display the batch produced by collate_graphs.
batched_g

## The model

### Explanation of the structure

In order to implement the edge and node updates, we use DGL's "update_all" interface.

-------------------

For details look at the DGL documentation, https://docs.dgl.ai/tutorials/blitz/index.html

In MPNN_model.py you have the basic skeleton. You implement an "edge network" and a "node network".

The edge network will act on all the edges in your graph - it will look at the nodes at the "src" and "dst" (source and destination) of the edge and apply a fully connected network to it.

Then the node network will have access to a "mailbox" with all the information sent by the edges connected to each node. You sum that "mailbox", add to it the existing node hidden representation, and then apply a fully connected network to update the node representation.

After each step of the update, the networks take the node and edge representations and apply a binary classifier to decide whether each node/edge is "on path" or not.

The prediction is added to the previous step prediction, and the final result is compared to the target with BCEWithLogitsLoss, for both nodes and edges.

We want to implement the model such that it decorates the nodes and edges of the graph with the prediction (the forward pass of the model doesn't actually return anything).

<img src="gn_structure.jpeg" width="800" height="400">
<img src="gn_iterations.jpeg" width="800" height="400">

In [None]:
from MPNN_model import Classifier

In [None]:
# Seed torch's global RNG so the weight initialisation is reproducible. This also fixes
# the shuffle order of data_loader, which by default draws from the same global RNG.
torch.manual_seed(0)
net = Classifier()

In [None]:
# Take a fresh training batch to run the (untrained) model on.
# next(iter(...)) fails loudly on an empty loader instead of leaving batched_g stale.
batched_g = next(iter(data_loader))

In [None]:
# Forward pass: the model writes its outputs into batched_g.ndata / edata['prediction']
# instead of returning them.
net(batched_g)

In [None]:
# Per-node logits written by the forward pass (scored with BCEWithLogitsLoss below).
batched_g.ndata['prediction']

In [None]:
# Per-edge logits written by the forward pass.
batched_g.edata['prediction']

In [None]:
# Far more EDGES are off the path than on it (this histogram counts edge labels),
# so this is an imbalanced classification task. The printed off/on ratio is relevant
# for choosing the `pos_weight` of the loss below.
edge_labels = batched_g.edata['on_path'].detach().cpu().numpy().reshape(-1)

fig, ax = plt.subplots(figsize=(3, 3), dpi=150)
h = ax.hist(edge_labels, bins=[-0.5, 0.5, 1.5])
ax.set_xticks([0, 1])
ax.set_xlabel('on_path')
ax.set_ylabel('number of edges')
print(h[0])
n_off_path, n_on_path = h[0]
# Explicit guard instead of relying on numpy's divide-by-zero warning.
print(n_off_path / n_on_path if n_on_path else float('inf'))
plt.show()

## Training and testing the model

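We want both the edge and the node network to reach a score above 85%, measured with the F1 score computed below (plain accuracy is misleading on this imbalanced task)!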

In [None]:
POS_WEIGHT = 10.0     # up-weights the rare positive (on-path) class to counter the imbalance
LEARNING_RATE = 0.001

loss_func = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(POS_WEIGHT))
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

In [None]:
# We don't use accuracy: there are far more 0 labels than 1 labels, so a network that
# always predicts 0 would score a high accuracy without learning the task.
# Instead we report the F1 score, built from true positives, false positives and
# false negatives.

def _f1_score(true_pos, false_pos, false_neg):
    """Return F1 = TP / (TP + (FP + FN) / 2), or 0.0 when TP + FP + FN == 0."""
    denominator = true_pos + 0.5 * (false_pos + false_neg)
    return true_pos / denominator if denominator > 0 else 0.0


def _confusion_counts(pred, target):
    """Count (TP, FP, FN) of logits `pred` against 0/1 labels `target`.

    A logit > 0 (sigmoid > 0.5) is a positive prediction. Everything else, including
    a logit of exactly 0, is negative, so every target==1 element is either a TP or
    an FN.
    """
    predicted_pos = pred > 0
    actual_pos = target == 1
    true_pos = (predicted_pos & actual_pos).sum().item()
    false_pos = (predicted_pos & (target == 0)).sum().item()
    false_neg = (~predicted_pos & actual_pos).sum().item()
    return true_pos, false_pos, false_neg


def compute_f1_and_loss(dataloader, net, loss_fn=None):
    """Evaluate `net` over `dataloader` and return (f1_edge, f1_node, mean_loss).

    The net is moved to the GPU when one is available and is put in eval mode. Each
    batch is run through the net, which writes 'prediction' logits into the batch's
    ndata/edata; these are compared with the 'on_path' labels.

    Args:
        dataloader: iterable of batched graphs carrying 'on_path' in ndata and edata.
        net: model whose forward pass decorates the graph with 'prediction'.
        loss_fn: criterion applied to (logits, targets); defaults to the notebook's
            `loss_func`.

    Returns:
        f1_edge, f1_node: F1 scores over all batches (0.0 if there are no positives
            and no positive predictions at all).
        loss: per-batch average of (edge loss + node loss); NaN for an empty loader.
    """
    if loss_fn is None:
        loss_fn = loss_func  # notebook-level criterion

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        net.cuda()
    net.eval()

    edge_counts = (0, 0, 0)  # (TP, FP, FN) accumulated over batches
    node_counts = (0, 0, 0)
    total_loss = 0.0
    n_batches = 0

    with torch.no_grad():
        for batched_g in dataloader:
            n_batches += 1

            if use_cuda:
                batched_g = batched_g.to(torch.device('cuda'))

            net(batched_g)

            # Predictions and targets are read back from the graph.
            edge_target = batched_g.edata['on_path']
            edge_pred = batched_g.edata['prediction']

            node_target = batched_g.ndata['on_path']
            node_pred = batched_g.ndata['prediction']

            total_loss += loss_fn(edge_pred, edge_target).item() + loss_fn(node_pred, node_target).item()

            edge_counts = tuple(a + b for a, b in zip(edge_counts, _confusion_counts(edge_pred, edge_target)))
            node_counts = tuple(a + b for a, b in zip(node_counts, _confusion_counts(node_pred, node_target)))

    f1_edge = _f1_score(*edge_counts)
    f1_node = _f1_score(*node_counts)
    loss = total_loss / n_batches if n_batches else float('nan')
    return f1_edge, f1_node, loss

In [None]:
# Move the model to the GPU when one is present; nothing happens on CPU-only machines.
if torch.cuda.is_available():
    net = net.cuda()

In [None]:
# Baseline (f1_edge, f1_node, loss) of the still-untrained model on the validation set.
compute_f1_and_loss(validation_data_loader,net)

In [None]:
# Run it on colab! Training only runs when a GPU is available; on CPU this cell is
# skipped and a previously saved 'trained_model.pt' is loaded further below.

if torch.cuda.is_available():
    n_epochs = 100
    device = torch.device('cuda')

    training_loss_vs_epoch = []
    validation_loss_vs_epoch = []

    training_f1_edge_vs_epoch = []
    training_f1_node_vs_epoch = []
    validation_f1_edge_vs_epoch = []
    validation_f1_node_vs_epoch = []

    # Lowest validation loss seen so far; the checkpoint on disk always holds that model.
    best_valid_loss = float('inf')

    pbar = tqdm(range(n_epochs))

    for epoch in pbar:
        net.train()  # put the net into "training mode"
        # leave=False: don't keep one finished inner bar per epoch in the output.
        for batched_g in tqdm(data_loader, leave=False):
            batched_g = batched_g.to(device)

            optimizer.zero_grad()
            net(batched_g)  # writes 'prediction' into ndata / edata
            edge_target = batched_g.edata['on_path']
            edge_pred = batched_g.edata['prediction']

            node_target = batched_g.ndata['on_path']
            node_pred = batched_g.ndata['prediction']

            loss = loss_func(edge_pred, edge_target) + loss_func(node_pred, node_target)
            loss.backward()
            optimizer.step()

        net.eval()  # put the net into evaluation mode
        train_f1_edge, train_f1_node, train_loss = compute_f1_and_loss(data_loader, net)
        valid_f1_edge, valid_f1_node, valid_loss = compute_f1_and_loss(validation_data_loader, net)

        training_loss_vs_epoch.append(train_loss)
        training_f1_edge_vs_epoch.append(train_f1_edge)
        training_f1_node_vs_epoch.append(train_f1_node)

        validation_f1_edge_vs_epoch.append(valid_f1_edge)
        validation_f1_node_vs_epoch.append(valid_f1_node)
        validation_loss_vs_epoch.append(valid_loss)

        # Show this epoch's validation scores right away (previously only from epoch 3 on).
        pbar.set_description(
              ' val f1 node:'+'{0:.5f}'.format(valid_f1_node)+
               ' val f1 edge:'+'{0:.5f}'.format(valid_f1_edge) )

        # Checkpoint only on a new best validation loss. Comparing with just the previous
        # epoch would overwrite the best model with a worse one after any temporary bump.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(net.state_dict(), 'trained_model.pt')

In [None]:
# Learning curves. The histories only exist when the training cell ran, i.e. on a GPU.
if torch.cuda.is_available():
    fig, ax = plt.subplots(1, 3, figsize=(8, 3))

    ax[0].plot(training_loss_vs_epoch, label='training')
    ax[0].plot(validation_loss_vs_epoch, label='validation')
    ax[0].set_title('loss')
    ax[0].legend()  # the labels above were previously never shown

    ax[1].plot(training_f1_edge_vs_epoch, label='training')
    ax[1].plot(validation_f1_edge_vs_epoch, label='validation')
    ax[1].set_title('edge F1')

    ax[2].plot(training_f1_node_vs_epoch, label='training')
    ax[2].plot(validation_f1_node_vs_epoch, label='validation')
    ax[2].set_title('node F1')

    for axis in ax:
        axis.set_xlabel('epoch')

    fig.tight_layout()
    plt.show()

In [None]:
# Restore the checkpoint saved by the training loop. map_location='cpu' lets a
# GPU-trained checkpoint load on a CPU-only machine. torch.load unpickles the file,
# so only load checkpoints you trust.
state_dict = torch.load('trained_model.pt', map_location='cpu')
net.load_state_dict(state_dict)

In [None]:
# Run the trained model on the first validation batch and collect per-edge outputs.
batched_g = next(iter(validation_data_loader))

net.eval()
if torch.cuda.is_available():
    net.cuda()
    batched_g = batched_g.to(torch.device('cuda'))

with torch.no_grad():  # inference only: no autograd graph needed
    net(batched_g)

edge_logits = batched_g.edata['prediction']
predictions = edge_logits.cpu().numpy()                          # raw logits
sigmoid_predictions = torch.sigmoid(edge_logits).cpu().numpy()   # probabilities in (0, 1)
targets = batched_g.edata['on_path'].cpu().numpy()

In [None]:
import numpy as np
# Distribution of the model outputs on validation edges, split by true label:
# left = sigmoid probabilities, right = raw logits.
on_path_mask = targets == 1
off_path_mask = targets == 0
prob_bins = np.linspace(0, 1, 50)

fig, ax = plt.subplots(1, 2, figsize=(7, 3))

ax[0].hist(sigmoid_predictions[on_path_mask], histtype='step', bins=prob_bins, density=True, label='on path')
ax[0].hist(sigmoid_predictions[off_path_mask], histtype='step', bins=prob_bins, density=True, label='off path')
ax[0].set_xlabel('sigmoid(prediction)')
ax[0].legend()

ax[1].hist(predictions[on_path_mask], histtype='step', bins=50, density=True, label='on path')
ax[1].hist(predictions[off_path_mask], histtype='step', bins=50, density=True, label='off path')
ax[1].set_xlabel('prediction (logit)')

plt.tight_layout()
plt.show()

In [None]:
# Compare the target (left) with the model's predicted on-path edges (right) on one
# validation graph.
fig, ax = plt.subplots(1, 2, figsize=(6, 3), dpi=150)

ax[0].set_title('Target')
ax[1].set_title('Model Prediction')
net.eval()
net.cpu()

g = validation_dataset[666]
with torch.no_grad():  # inference only: no autograd graph needed
    net(g)

output_pred = torch.sigmoid(g.edata['prediction']).numpy()

src, dst = (t.tolist() for t in g.edges())
# reshape(-1) tolerates both (E,) and (E, 1) tensors.
target_on_path = g.edata['on_path'].reshape(-1).tolist()
predicted_prob = output_pred.reshape(-1).tolist()

nx_graph = nx.DiGraph()
nx_graph.add_nodes_from(g.nodes().tolist())
nx_graph.add_edges_from(zip(src, dst))

edge_list = [(s, d) for s, d, t in zip(src, dst, target_on_path) if t > 0]
# An edge counts as predicted on-path when its sigmoid probability exceeds 0.5.
predicted_edge_list = [(s, d) for s, d, p in zip(src, dst, predicted_prob) if p > 0.5]

# spring_layout is randomised; a fixed seed makes the figure reproducible.
pos = nx.spring_layout(nx_graph, seed=0)

for axis, highlighted in ((ax[0], edge_list), (ax[1], predicted_edge_list)):
    nx.draw(nx_graph, pos=pos, ax=axis, node_size=5, arrows=False)
    nx.draw_networkx_edges(nx_graph, pos=pos, edgelist=highlighted, width=2, edge_color='r', ax=axis, arrows=False)

plt.show()

In [None]:
from test_part2 import *

In [None]:
# Run the provided checks for part 2 (imported from the test_part2 module).
test_part2()