In [None]:
# Install required packages.
# torch-scatter and torch-sparse are resolved against the PyTorch Geometric
# wheel index for torch 1.8.0 + CUDA 10.1 (the -f URL below).
# NOTE(review): presumably the runtime's torch build has to match that
# torch/CUDA combination - confirm before running on another environment.
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
# None of the packages is version-pinned, so a later re-run may install
# different releases than the ones this notebook was written against.
!pip install -q torch-geometric
!pip install -q neptune-client
# Unlike the lines above, this install is not silenced with -q.
!pip install psutil



In [None]:
%matplotlib inline
import torch
import random
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from tqdm.notebook import tqdm

from torch_geometric.utils import degree
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch.utils.data import random_split

# Global seed for torch, numpy and the stdlib RNG.
# Leave it at None to skip seeding entirely; set an int for repeatable runs.
SEED = None
if SEED is not None:
    # Same order as before: torch first, then numpy, then `random`.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(SEED)

In [None]:
# Load the PROTEINS graph-classification benchmark from the TU collection.
dataset = TUDataset(root='data/TUDataset', name='PROTEINS')

# 80% train / 10% test; the validation split takes whatever is left so the
# three sizes always add up to the full dataset length.
n = len(dataset)
n_train = int(n*0.8)
n_test = int(n*0.1)
n_val = n - (n_train+n_test)
train_set, val_set, test_set = random_split(dataset, [n_train,n_val,n_test])

# Dataset-level summary, printed as one block.
summary_lines = [
    '',
    f'Dataset: {dataset}:',
    '====================',
    f'Number of graphs: {len(dataset)}',
    f'Number of features: {dataset.num_features}',
    f'Number of classes: {dataset.num_classes}',
]
print('\n'.join(summary_lines))

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('=============================================================')

# Gather some statistics about the first graph.
node_count = data.num_nodes
edge_count = data.num_edges
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {edge_count / node_count:.2f}')
print(f'Contains isolated nodes: {data.contains_isolated_nodes()}')
print(f'Contains self-loops: {data.contains_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')


Dataset: PROTEINS(1113):
Number of graphs: 1113
Number of features: 3
Number of classes: 2

Data(edge_index=[2, 162], x=[42, 3], y=[1])
Number of nodes: 42
Number of edges: 162
Average node degree: 3.86
Contains isolated nodes: False
Contains self-loops: False
Is undirected: True


In [None]:
from torch_geometric.utils import softmax

class LiftPooling(torch.nn.Module):
    """Edge "lift" layer: scores every edge, then rewires the best-scoring ones.

    Each edge gets a score from a linear layer applied to the concatenated
    features of its two endpoints. The top fraction of edges is then processed
    greedily: a selected edge ``source -> target`` is paired with the
    best-scoring still-available edge ``target -> dest`` (``dest != source``);
    both are removed from the graph and an edge ``target -> source`` is added.
    Node features are passed through a ``GCNConv`` restricted to the consumed
    ``target -> dest`` edges. The number of nodes is never changed.

    ``GCNConv`` and ``torch.nn.functional as F`` are looked up in the notebook
    namespace when the layer is built / called, so they must be imported
    before the first instantiation.

    Args:
        in_channels (int): Size of each input sample.
        dropout (float, optional): The probability with
            which to drop edge scores during training. (default: 0)
        add_to_edge_score (float, optional): This is added to each
            computed edge score. (default: 0.5)
        edge_score_method (str, optional): Method to use to compute edge
            scores, one of ``'softmax'``, ``'tanh'`` or ``'sigmoid'``.
            (default: 'softmax')
    """

    def __init__(self, in_channels, dropout=0, add_to_edge_score=0.5,
                 edge_score_method='softmax'):
        super(LiftPooling, self).__init__()
        self.in_channels = in_channels
        self.aggregate = GCNConv(in_channels,in_channels)
        self.compute_edge_score = self.compute_edge_score_fct(edge_score_method)
        self.add_to_edge_score = add_to_edge_score
        self.dropout = dropout

        self.lin = torch.nn.Linear(2 * in_channels, 1)

    def compute_edge_score_fct(self,fct):
        """Return the score-normalisation function registered under ``fct``.

        Raises:
            KeyError: if ``fct`` is not one of the supported method names.
        """
        assoc = {'softmax': self.compute_edge_score_softmax,
                 'tanh': self.compute_edge_score_tanh,
                 'sigmoid': self.compute_edge_score_sigmoid}
        try:
            return assoc[fct]
        except KeyError:
            # Same exception type as a plain dict lookup, but with a usable message.
            raise KeyError(f'Unknown edge_score_method {fct!r}; '
                           f'expected one of {sorted(assoc)}') from None

    @staticmethod
    def compute_edge_score_softmax(raw_edge_score, edge_index, num_nodes):
        """Softmax of the raw scores, grouped by the edge's target node."""
        return softmax(raw_edge_score, edge_index[1], num_nodes=num_nodes)

    @staticmethod
    def compute_edge_score_tanh(raw_edge_score, edge_index, num_nodes):
        """Element-wise tanh of the raw scores (graph structure is ignored)."""
        return torch.tanh(raw_edge_score)

    @staticmethod
    def compute_edge_score_sigmoid(raw_edge_score, edge_index, num_nodes):
        """Element-wise sigmoid of the raw scores (graph structure is ignored)."""
        return torch.sigmoid(raw_edge_score)

    def forward(self, x, edge_index, batch):
        """Forward computation which computes the raw edge score, normalizes
        it, and lifts the edges.

        Args:
            x (Tensor): The node features.
            edge_index (LongTensor): The edge indices.
            batch (LongTensor): Batch vector

        Return types:
            x (Tensor): The node features after aggregation (same node count).
            edge_index (LongTensor): The rewired edge indices.
            batch (LongTensor): The batch vector, returned unchanged.
        """
        e = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=-1)
        e = self.lin(e).view(-1)
        e = F.dropout(e, p=self.dropout, training=self.training)
        e = self.compute_edge_score(e, edge_index, x.size(0))
        e = e + self.add_to_edge_score

        x, edge_index, batch = self.__lift_edges__(x, edge_index, batch, e)

        return x, edge_index, batch


    def __lift_edges__(self, x, edge_index, batch, edge_score,percent=0.4):
        """Greedily lift the ``percent`` best-scoring edges.

        Args:
            x (Tensor): The node features.
            edge_index (LongTensor): The edge indices, one column per edge.
            batch: Passed through untouched.
            edge_score (Tensor): One score per edge.
            percent (float, optional): Fraction of the edges (by descending
                score) that are considered for lifting. (default: 0.4)

        Returns:
            Tuple ``(x, edge_index, batch)`` where ``x`` went through
            ``self.aggregate`` and ``edge_index`` holds the newly created
            edges followed by the untouched original edges.
        """
        n_edges = edge_index.size(1)
        # Follow the inputs' device instead of relying on a notebook-level
        # `device` global that is only defined in a later cell.
        out_device = edge_index.device

        # Keeps track of the edges which stay the same after the procedure
        edges_remaining = np.ones(n_edges,dtype=bool)

        # Sort edges by their scores
        edge_argsort = torch.argsort(edge_score, descending=True)

        new_edges = []
        old_edges = []
        edge_index_cpu = edge_index.detach().cpu().numpy()
        edge_score_cpu = edge_score.detach().cpu().numpy()

        # Iterate through the top k edges
        for edge_idx in edge_argsort[:int(percent*n_edges)].tolist():
            # Case where current edge has been used by previous lifts
            if not edges_remaining[edge_idx]:
                continue
            source = int(edge_index_cpu[0, edge_idx])
            target = int(edge_index_cpu[1, edge_idx])

            # Second edge of the lift: still available, leaves `target`,
            # and does not go straight back to `source`.
            candidate_mask = ((edge_index_cpu[0] == target)
                              & (edge_index_cpu[1] != source)
                              & edges_remaining)
            if not candidate_mask.any():
                # Nothing to pair with: leave this edge alone. (An argmax over
                # an all-zero array would otherwise silently pick edge 0.)
                continue

            # Mask with -inf rather than multiplying by the mask, so that
            # valid candidates with negative scores still outrank invalid ones.
            nxt = int(np.argmax(np.where(candidate_mask, edge_score_cpu, -np.inf)))
            dest = int(edge_index_cpu[1, nxt])

            new_edges.append([target,source])
            # NOTE(review): the same (target, dest) pair is appended twice, as
            # in the original code; possibly one of them was meant to be a
            # different edge (e.g. the reverse direction) - confirm intent.
            old_edges.append([target,dest])
            old_edges.append([target,dest])
            edges_remaining[edge_idx] = False
            edges_remaining[nxt] = False

        # Explicit (k, 2) shapes keep the tensors well-formed when k == 0.
        new_edges = torch.tensor(new_edges, dtype=torch.int64,
                                 device=out_device).reshape(len(new_edges), 2)
        old_edges = torch.tensor(old_edges, dtype=torch.int64,
                                 device=out_device).reshape(len(old_edges), 2)

        x = self.aggregate(x, old_edges.t().contiguous())

        # Index with a torch bool mask on the right device, not a numpy array.
        keep_mask = torch.tensor(edges_remaining, dtype=torch.bool,
                                 device=out_device)
        remaining_edges = edge_index.t()[keep_mask]
        new_edges = torch.cat((new_edges, remaining_edges)).t().contiguous()
        return x, new_edges, batch

# Flip to True to run a quick smoke test of LiftPooling.__lift_edges__.
testing = False
if testing:
    from torch_geometric.data import Data
    # LiftPooling.__init__ needs GCNConv, which the notebook only imports in
    # a later cell; import it here so this cell can run on its own.
    from torch_geometric.nn import GCNConv

    liftpool = LiftPooling(1)
    # 4-node cycle with both directions. Node ids are 0-based so that they
    # are valid row indices into the 4-row feature matrix below.
    edge_index = torch.tensor([[0,1,2,3,1,2,3,0],
                               [1,2,3,0,0,1,2,3]], dtype=torch.long)
    # One feature per node, matching LiftPooling(in_channels=1).
    x = torch.zeros(4, 1)
    data = Data(x=x, edge_index=edge_index)
    edge_score = torch.tensor([1,3,4,1,1,3,4,1], dtype=torch.float)
    print(liftpool.__lift_edges__(data.x,data.edge_index,None,edge_score))

In [None]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import EdgePooling
from torch_geometric.nn import TopKPooling
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    """Graph classifier: four GCNConv + LiftPooling stages, mean readout, MLP head.

    Input and output sizes are read from the notebook-level ``dataset``.

    Args:
        hidden_channels (int): Width of every graph convolution.
        dropout (float, optional): Edge-score dropout handed to each
            LiftPooling layer. (default: 0.1)
    """

    def __init__(self, hidden_channels,dropout=0.1):
        super().__init__()
        self.dropout = dropout

        # Layers are created in the same order as they are applied.
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.pool1 = LiftPooling(hidden_channels, dropout=dropout)

        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.pool2 = LiftPooling(hidden_channels, dropout=dropout)

        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.pool3 = LiftPooling(hidden_channels, dropout=dropout)

        self.conv4 = GCNConv(hidden_channels, hidden_channels)
        self.pool4 = LiftPooling(hidden_channels, dropout=dropout)

        # Classifier head: hidden -> hidden/2 -> hidden/4 -> classes.
        half_width = hidden_channels//2
        quarter_width = hidden_channels//4
        self.lin1 = Linear(hidden_channels, half_width)
        self.lin2 = Linear(half_width, quarter_width)
        self.lin3 = Linear(quarter_width, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        """Return per-graph class logits for a mini-batch of graphs."""
        # 1. Convolution + ReLU + edge lifting, four times over
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
            (self.conv4, self.pool4),
        )
        for conv, pool in stages:
            x = F.relu(conv(x, edge_index))
            x, edge_index, batch = pool(x, edge_index, batch)

        # 2. Readout layer
        graph_repr = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        hidden = F.relu(self.lin1(graph_repr))
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        hidden = F.relu(self.lin2(hidden))
        return self.lin3(hidden)

class GCN_EdgePool(torch.nn.Module):
    """Graph classifier alternating LiftPooling and EdgePooling after each GCNConv.

    Stages 1 and 3 use LiftPooling, stages 2 and 4 use EdgePooling; a mean
    readout and a three-layer MLP head follow. Input and output sizes are
    read from the notebook-level ``dataset``.

    Args:
        hidden_channels (int): Width of every graph convolution.
        dropout (float, optional): Dropout handed to every pooling layer.
            (default: 0.1)
    """

    def __init__(self, hidden_channels,dropout=0.1):
        super().__init__()
        self.dropout = dropout

        # Layers are created in the same order as they are applied.
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.pool1 = LiftPooling(hidden_channels, dropout=dropout)

        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.pool2 = EdgePooling(hidden_channels, dropout=dropout)

        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.pool3 = LiftPooling(hidden_channels, dropout=dropout)

        self.conv4 = GCNConv(hidden_channels, hidden_channels)
        self.pool4 = EdgePooling(hidden_channels, dropout=dropout)

        # Classifier head: hidden -> hidden/2 -> hidden/4 -> classes.
        half_width = hidden_channels//2
        quarter_width = hidden_channels//4
        self.lin1 = Linear(hidden_channels, half_width)
        self.lin2 = Linear(half_width, quarter_width)
        self.lin3 = Linear(quarter_width, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        """Return per-graph class logits for a mini-batch of graphs."""
        # 1. Convolution + ReLU + pooling, four times over
        stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
            (self.conv4, self.pool4),
        )
        for conv, pool in stages:
            x = F.relu(conv(x, edge_index))
            pooled = pool(x, edge_index, batch)
            # LiftPooling yields three values, EdgePooling a fourth one that
            # is not needed here; only the first three are kept.
            x, edge_index, batch = pooled[:3]

        # 2. Readout layer
        graph_repr = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        hidden = F.relu(self.lin1(graph_repr))
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        hidden = F.relu(self.lin2(hidden))
        return self.lin3(hidden)

In [None]:
# Hyper-parameters of this run; the same dict is handed to Neptune as the
# experiment parameters further down.
params = {'batch_size': 64,
          'learning_rate': 1e-4,  # identical value to the former `10e-5`, written canonically
          'epochs': 50,
          'dropout': 0.5,
          'hidden_channels': 512,
          }

# Only the training loader is shuffled. Accuracy does not depend on batch
# order, so shuffling the evaluation loaders would merely consume random
# numbers and make evaluation order non-deterministic.
train_loader = DataLoader(train_set, batch_size=params['batch_size'], shuffle=True)
test_loader = DataLoader(test_set, batch_size=params['batch_size'], shuffle=False)
val_loader = DataLoader(val_set, batch_size=params['batch_size'], shuffle=False)

# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GCN(hidden_channels=params['hidden_channels'],
            dropout=params['dropout']).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
criterion = torch.nn.CrossEntropyLoss()

def train():
    """Run one optimisation pass over ``train_loader``.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``device``. Returns nothing.
    """
    model.train()
    for batch in train_loader:  # One mini-batch of graphs at a time.
        batch = batch.to(device)
        logits = model(batch.x, batch.edge_index, batch.batch)  # Forward pass.
        criterion(logits, batch.y).backward()  # Loss, then gradients.
        optimizer.step()  # Apply the update.
        optimizer.zero_grad()  # Reset gradients for the next batch.

def test(loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    Uses the notebook-level ``model`` and ``device``.

    Args:
        loader: Iterable of mini-batches exposing ``x``, ``edge_index``,
            ``batch`` and ``y``; ``loader.dataset`` must support ``len``.

    Returns:
        float: Number of correct predictions divided by ``len(loader.dataset)``.
    """
    model.eval()
    correct = 0
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch.
    with torch.no_grad():
        for data in loader:  # Iterate in batches over the training/test dataset.
            data = data.to(device)
            out = model(data.x, data.edge_index, data.batch)
            pred = out.argmax(dim=1)  # Use the class with highest score.
            correct += int((pred == data.y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.


"""
--------------------------------------------------------------------------------
-----------------------------E-X-P-E-R-I-M-E-N-T-S------------------------------
--------------------------------------------------------------------------------
"""

import os

import neptune

# Credentials come from the environment instead of being written into the
# notebook (the original assignments were left empty, which is a syntax
# error). Export NEPTUNE_PROJECT ("workspace/project") and NEPTUNE_API_TOKEN
# before starting the kernel.
PROJECT_NAME = os.environ.get('NEPTUNE_PROJECT')
NEPTUNE_API_TOKEN = os.environ.get('NEPTUNE_API_TOKEN')
neptune.init(project_qualified_name=PROJECT_NAME,
             api_token=NEPTUNE_API_TOKEN)
neptune.create_experiment(params=params,
                          name='EdgePool')

class Logger():
    """Callable that sends every entry of a metrics dict to Neptune."""

    def __call__(self,logs):
        # One `log_metric` call per (name, value) pair, in dict order.
        for key, value in logs.items():
            neptune.log_metric(key, value)

logger = Logger()
# Train for the configured number of epochs, logging train/test accuracy to
# Neptune after each one. The try/finally closes the Neptune experiment even
# if training is interrupted or raises.
# NOTE(review): `val_loader` is built earlier but never evaluated here; the
# test split is what gets monitored every epoch - confirm this is intended.
try:
    with tqdm(range(params['epochs'])) as t:
        for epoch in t:
            train()
            train_acc = test(train_loader)
            test_acc = test(test_loader)

            logs = {
                'train_acc':train_acc,
                'test_acc':test_acc
            }
            logger(logs)
            t.set_postfix(train_acc=train_acc,  test_acc=test_acc)
finally:
    neptune.stop()

https://ui.neptune.ai/lmagne/lift-pool/e/LFT-104


HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))


