# Few-Shot Graph Classification

In [None]:
import torch
TORCH = torch.__version__.split('+')[0]
CUDA = 'cu' + torch.version.cuda.replace('.','')

!pip install pytorch-lightning
!pip install pyyaml==5.4.1
!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric

In [None]:
import os
import random
from collections import defaultdict
from typing import Dict, List, Tuple, Union, \
                   Sequence, TypeVar, Generic, \
                   Optional, Iterator

import networkx as nx
import numpy as np
import plotly
import plotly.express as px
import plotly.graph_objects as go

from torchvision import transforms
import torch
from torch import optim
from torch.nn import functional as F
import torch.nn as nn

import torch_geometric as gtorch
import torch_geometric.nn as gnn
import torch_geometric.data as gdata
import torch_geometric.loader as gloader
import torch_geometric.utils as gutils

## Load dataset

The dataset loaded here is TRIANGLES, as used in the AS-MAML paper ([link](https://arxiv.org/pdf/2003.08246.pdf)).

In [None]:
# Paths of the raw TRIANGLES text files, resolved against the current
# working directory (the archive is expected to be unpacked there).
ROOT_PATH = os.getcwd()
GRAPH_ATTRIBUTE  = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_graph_attributes.txt")
GRAPH_LABELS     = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_graph_labels.txt")
NODE_NATTRIBUTE  = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_node_attributes.txt")
GRAPH_INDICATOR  = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_graph_indicator.txt")
GRAPH_A          = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_A.txt")

# Generic type variable used by annotations further down in this file.
T = TypeVar('T')

# Set to True to download and unzip the TRIANGLES archive into the
# working directory before the files above are read.
LOAD_DATASET = False

if LOAD_DATASET:
    !wget https://cloud-storage.eu-central-1.linodeobjects.com/TRIANGLES.zip
    !unzip TRIANGLES.zip

In [None]:
def _read_lines(path: str) -> List[str]:
    """Return all lines of the text file at *path*, closing the file afterwards."""
    with open(path) as handle:
        return handle.readlines()


# Raw lines of every TRIANGLES file; each list keeps the trailing newlines.
graph_attribute      = _read_lines(GRAPH_ATTRIBUTE)
graph_labels         = _read_lines(GRAPH_LABELS)
graph_node_attribute = _read_lines(NODE_NATTRIBUTE)
graph_indicator      = _read_lines(GRAPH_INDICATOR)
graph_a              = _read_lines(GRAPH_A)

In [None]:
class GeneratorTxt2Graph:
    """Build one ``networkx`` graph per graph ID from the raw dataset text lines.

    The constructor takes the already-read lines of each dataset file as
    keyword arguments:

    * ``graph_labels``    -- one label per line; line ``i`` belongs to graph ``i`` (1-based)
    * ``node_attribute``  -- one comma-separated attribute row per node (1-based)
    * ``graph_indicator`` -- line ``i`` holds the graph ID of node ``i`` (1-based)
    * ``graph_adjacency`` -- one ``"a, b"`` edge per line, using node IDs
    * ``graph_attribute``, ``node_labels``, ``edge_labels``,
      ``edge_attributes`` -- stored but not used yet (optional)
    """

    def __init__(self, **kargs) -> None:
        # Required inputs: a missing key raises KeyError, as before.
        self.__graph_labels     = kargs['graph_labels']
        self.__node_attribute   = kargs['node_attribute']
        self.__graph_indicator  = kargs['graph_indicator']
        self.__graph_adjacency  = kargs['graph_adjacency']
        # Inputs that no method reads yet may be omitted by the caller.
        self.__graph_attribute  = kargs.get('graph_attribute')
        self.__node_labels      = kargs.get('node_labels')
        self.__edge_labels      = kargs.get('edge_labels')
        self.__edge_attributes  = kargs.get('edge_attributes')

    def _collect_nodes(self) -> Tuple[Dict[str, List[int]], Dict[int, List]]:
        """ Look at the graph_indicator.txt lines and return two dictionaries:
        ``nodes`` maps each graph ID to the list of its node IDs, while
        ``i_nodes`` maps each node ID to the list ``[graph_id, node_id]`` """
        print("--- Collecting Nodes ... ")

        nodes, i_nodes = dict(), dict()
        for node_id, line in enumerate(self.__graph_indicator, start=1):
            # strip() instead of [:-1]: the last line may lack a trailing
            # newline, in which case [:-1] would cut a digit of the ID.
            graph_id = line.strip()
            nodes.setdefault(graph_id, []).append(node_id)
            i_nodes[node_id] = [graph_id, node_id]

        return nodes, i_nodes

    def _collect_edges(self, i_nodes: Dict[int, List],
                             direct: bool=False) -> Dict[str, List[Tuple[int, int]]]:
        """ Look at the graph_A.txt lines and return a dictionary
        containing as keys the ID of the graph and as values
        a list of edges of that graph. ``direct`` is currently unused. """
        print("--- Collecting Edges ...")

        edges = dict()
        for line in self.__graph_adjacency:
            line = line.strip()
            if not line:
                continue

            a, b = (part.strip() for part in line.split(","))

            # Index instead of unpacking so this also works once the
            # attribute dictionary has been appended to the entries.
            entry_a, entry_b = i_nodes[int(a)], i_nodes[int(b)]
            graph_a, node_a = entry_a[0], entry_a[1]
            graph_b, node_b = entry_b[0], entry_b[1]

            assert graph_a == graph_b, f"Two graphs are not equal: {graph_a} != {graph_b}"

            edges.setdefault(graph_a, []).append((node_a, node_b))

        return edges

    def _collect_node_attributes(self, i_nodes: Dict[int, List]) -> None:
        """ Append to every node entry a dictionary ``{"attr0": ..., "attr1": ...}``
        holding the (string) attributes read from its line """
        print("--- Collecting Node Attributes ...")
        for node_id, line in enumerate(self.__node_attribute, start=1):
            values = [value.strip() for value in line.strip().split(",")]
            i_nodes[node_id].append({f"attr{k}" : value for k, value in enumerate(values)})

    def _collect_graph_labels(self, graphs: Dict[str, "nx.Graph"]) -> None:
        """ Replace, in place, each graph with the tuple ``(graph, label)`` """
        print("--- Collecting Graph Labels ...")
        for i, label in enumerate(self.__graph_labels, start=1):
            key = str(i)
            graphs[key] = (graphs[key], label.strip())

    # TODO: _collect_node_labels, _collect_edge_labels, _collect_edge_attributes

    def generate(self) -> Dict[str, Tuple["nx.Graph", str]]:
        """ Return a dictionary of {graph_id : (Graph, label)} """
        # Get Nodes and Edges
        nodes, i_nodes = self._collect_nodes()
        edges          = self._collect_edges(i_nodes, False)

        # Set attributes for nodes
        self._collect_node_attributes(i_nodes)

        # Create the graphs. Iterate over `nodes` rather than `edges` so that
        # a graph without any edge is still created (otherwise the labelling
        # step below fails with KeyError on that graph).
        graphs = dict()
        for graph_id, node_ids in nodes.items():
            g = nx.Graph()
            g_nodes = []
            for n in node_ids:
                entry = i_nodes[n]
                # entry is [graph_id, node_id] plus the attribute dictionary
                # when an attribute line exists for this node
                attrs = entry[-1] if len(entry) > 2 else {}
                g_nodes.append((entry[1], attrs))

            g.add_nodes_from(g_nodes)
            g.add_edges_from(edges.get(graph_id, []))

            graphs[graph_id] = g

        # Set labels for graph
        self._collect_graph_labels(graphs)

        return graphs

In [None]:
# Wire the raw file lines into the generator; node labels, edge labels and
# edge attributes are passed as None.
graphs_gen = GeneratorTxt2Graph(graph_attribute=graph_attribute,
                                graph_labels=graph_labels,
                                node_attribute=graph_node_attribute,
                                graph_indicator=graph_indicator,
                                graph_adjacency=graph_a,
                                node_labels=None,
                                edge_labels=None,
                                edge_attributes=None)

In [None]:
%%time 
# Parse the raw lines into {graph_id: (nx.Graph, label)}
graphs = graphs_gen.generate()

--- Collecting Nodes ... 
--- Collecting Edges ...
--- Collecting Node Attributes ...
--- Collecting Graph Labels ...
CPU times: user 15.7 s, sys: 1.43 s, total: 17.2 s
Wall time: 17.2 s


In [None]:
def plot_graph(G : Union[nx.Graph, nx.DiGraph], name: str) -> None:
    """
    Plot a graph as an interactive 3D figure

    Parameters
    ----------
    G    : Union[nx.Graph, nx.DiGraph]
        The graph to draw
    name : str
        The name of the graph, used in the figure title

    Returns
    -------
    None
    """
    # Getting the 3D Spring layout (fixed seed, so the picture is reproducible)
    layout = nx.spring_layout(G, dim=3, seed=18)

    # Getting nodes coordinate
    x_nodes = [pos[0] for pos in layout.values()]  # x-coordinates of nodes
    y_nodes = [pos[1] for pos in layout.values()]  # y-coordinates of nodes
    z_nodes = [pos[2] for pos in layout.values()]  # z-coordinates of nodes

    # Getting a list of edges and create a list with coordinates.
    # The trailing None separates consecutive segments of the line trace.
    x_edges, y_edges, z_edges = [], [], []
    for src, dst in G.edges():
        x_edges += [layout[src][0], layout[dst][0], None]
        y_edges += [layout[src][1], layout[dst][1], None]
        z_edges += [layout[src][2], layout[dst][2], None]

    # One colour value per node. np.linspace(0, n) would always return 50
    # values (its default `num`), whatever the number of nodes.
    colors = np.arange(len(x_nodes))

    # Create a trace for the edges
    etrace = go.Scatter3d(x=x_edges,
                          y=y_edges,
                          z=z_edges,
                          mode='lines',
                          line=dict(color='rgb(125,125,125)', width=1),
                          hoverinfo='none'
                         )

    # Create a trace for the nodes
    ntrace = go.Scatter3d(x=x_nodes,
                          y=y_nodes,
                          z=z_nodes,
                          mode='markers',
                          marker=dict(
                              symbol='circle',
                              size=6,
                              color=colors,
                              colorscale='Viridis',
                              line=dict(color='rgb(50,50,50)', width=.5)),
                          text=list(layout.keys()),
                          hoverinfo='text'
                         )

    # Set the axis
    axis = dict(showbackground=False,
                showline=False,
                zeroline=False,
                showgrid=False,
                showticklabels=False,
                title='')

    # Create a layout for the plot
    go_layout = go.Layout(title=f"{name} Network Graph",
                          width=600,
                          height=600,
                          showlegend=False,
                          scene=dict(xaxis=dict(axis),
                                     yaxis=dict(axis),
                                     zaxis=dict(axis)),
                          margin=dict(t=100),
                          hovermode='closest'
                         )

    # Plot
    data = [etrace, ntrace]
    fig = go.Figure(data=data, layout=go_layout)
    fig.show()

In [None]:
# Draw the graph stored under ID "1"; element 0 of the stored pair is the graph itself
plot_graph(graphs["1"][0], "1")

## Dataset, Few-Shot Sampler and DataLoader

In [None]:
class GraphDataset(gdata.Dataset):
    """Dataset wrapping a ``{graph_id: (nx.Graph, label)}`` dictionary.

    Items are addressed by position (0 .. len - 1) in the dictionary's
    iteration order, which is also the order used by :meth:`targets`.
    """

    def __init__(self, graphs_ds: Dict[str, Tuple[nx.Graph, str]]) -> None:
        super(GraphDataset, self).__init__()
        self.graphs_ds = graphs_ds
        # Position -> graph ID. A filtered dataset keeps the original IDs as
        # keys, so they are not 1..n and cannot be derived from the position.
        self._graph_ids = list(graphs_ds)

    def __repr__(self) -> str:
        return f"GraphDataset(classes={set(self.targets().tolist())},n_graphs={self.len()})"

    def len(self) -> int:
        return len(self._graph_ids)

    def targets(self) -> torch.Tensor:
        """ Return all the labels, in the same order used by `get` """
        return torch.tensor([int(self.graphs_ds[graph_id][1]) for graph_id in self._graph_ids])

    def get(self, idx: int) -> gdata.Data:
        """ Return the `Data` object (node features, edge index, label)
        of the graph stored at position `idx` """
        graph    = self.graphs_ds[self._graph_ids[idx]]
        g, label = graph[0].to_directed(), graph[1]

        # Row i of `x` describes node_ids[i]; edge_index must refer to these
        # row positions, not to the dataset-wide node IDs stored in the graph.
        node_ids = list(g.nodes)
        position = {node: i for i, node in enumerate(node_ids)}

        # Retrieve nodes attributes
        x = torch.tensor([[float(value) for value in g.nodes[node].values()] for node in node_ids],
                         dtype=torch.float)

        # Retrieve edges
        if g.number_of_edges() > 0:
            edge_index = torch.tensor([[position[u], position[v]] for u, v in g.edges],
                                      dtype=torch.long).t().contiguous()
        else:
            # keep the (2, num_edges) layout even when there is no edge
            edge_index = torch.empty((2, 0), dtype=torch.long)

        # Retrieve ground truth labels
        y = torch.tensor([int(label)], dtype=torch.int)

        return gdata.Data(x=x, edge_index=edge_index, y=y)

    @classmethod
    def dataset_from_labels(cls, mask   : torch.Tensor,
                                 classes: torch.Tensor,
                                 graphs : Dict[str, Tuple[nx.Graph, str]]
    ) -> 'GraphDataset':
        """ Return a new Dataset containing only the graphs whose label is
        one of the values of `classes` that also appear in `mask` """
        selected = classes[(mask[:, None] == classes[None, :]).any(dim=0)]
        # Labels are stored as strings in `graphs`
        wanted = {str(label) for label in selected.tolist()}

        filtered_graphs = {k : v for k, v in graphs.items() if v[1] in wanted}

        return cls(filtered_graphs)



def get_all_labels(graphs: Dict[str, Tuple["nx.Graph", str]]) -> torch.Tensor:
    """ Return a 1-D integer tensor with the distinct labels of the dataset,
    sorted in ascending order (empty tensor for an empty dictionary) """
    # Sorting makes the order independent of set iteration order; the explicit
    # dtype keeps an empty result integer-typed as well.
    labels = sorted({int(v[1]) for v in graphs.values()})
    return torch.tensor(labels, dtype=torch.long)


def generate_train_val_test(graphs    : Dict[str, Tuple[nx.Graph, str]],
                            perc_test : float,
                            perc_train: float,
                            perc_val  : float
) -> Tuple[GraphDataset, GraphDataset, GraphDataset]:
    """ Split the dataset by class and return the datasets in the
    order (train, test, validation).

    `perc_train` and `perc_test` are percentages (0-100) of the number of
    classes. The validation split receives all remaining classes, so
    `perc_val` is accepted for symmetry but does not influence the result.
    """
    classes = get_all_labels(graphs)
    n_class = len(classes)
    # Shuffle the label values themselves, so the split does not rely on
    # the labels being exactly 1..n_class.
    perm    = classes[torch.randperm(n_class)]

    # int(): float percentages would otherwise give float slice bounds
    q_train = int(n_class * perc_train // 100)
    q_test  = int(n_class * perc_test  // 100)

    train_perm = perm[:q_train]
    test_perm  = perm[q_train: q_train + q_test]
    val_perm   = perm[q_train + q_test:]

    train_ds = GraphDataset.dataset_from_labels(train_perm, classes, graphs)
    test_ds  = GraphDataset.dataset_from_labels(test_perm,  classes, graphs)
    val_ds   = GraphDataset.dataset_from_labels(val_perm,   classes, graphs)

    return train_ds, test_ds, val_ds

In [None]:
# Split by class: 50% of the classes for training, 30% for testing, the rest for validation
train_ds, test_ds, val_ds = generate_train_val_test(graphs, perc_train=50, perc_test=30, perc_val=20)
train_ds, test_ds, val_ds

(GraphDataset(classes={1, 3, 5, 7, 8},n_graphs=22500),
 GraphDataset(classes={2, 10, 4},n_graphs=13500),
 GraphDataset(classes={9, 6},n_graphs=9000))

In [None]:
class NWayKShotSampler(torch.utils.data.Sampler):
    """
    In few-shot classification, and in particular in Meta-Learning, 
    we use a specific way of sampling batches from the training/val/test 
    set. This way is called N-way-K-shot, where N is the number of classes 
    to sample per batch and K is the number of examples to sample per class 
    in the batch. The sample batch on which we train our model is also called 
    `support` set, while the one on which we test is called `query` set.

    This class is an N-way-K-shot sampler meant to be used as a batch_sampler
    for a dataloader. Each batch it yields is a flat list of integer indices;
    when `include_query` is True the first half of the list is the support
    set and the second half is the query set.

    Attributes:
        labels: PyTorch tensor of the labels of the data elements
        n_way: Number of classes to sample per batch
        k_shot: Number of examples to sample per class in the batch
                (doubled internally when include_query is True)
        include_query: If True, returns batch of size N * K * 2, which
                       can be split into support and query set. Simplifies
                       the implementation of sampling the same classes but
                       distinct examples for support and query set.
        shuffle: If True, examples and classes are shuffled at each iteration
        shuffle_once: If True, examples and classes are shuffled once,
                      at construction time
        batch_size: The size of the batch N * K. If include_query is True
                    then it will be N * K * 2.
        indices_per_class: For each class, the indices of its examples
        batches_per_class: Number of K-shot batches for each class
        iterations: Number of iterations
        class_list: Contains every class repeated the number of times
                    defined in batches_per_class
    """
    def __init__(self, labels       : torch.Tensor, 
                       n_way        : int, 
                       k_shot       : int, 
                       include_query: bool=False, 
                       shuffle      : bool=True, 
                       shuffle_once : bool=False) -> None:
        super().__init__(None)
        self.labels        = labels
        self.n_way         = n_way
        self.k_shot        = k_shot
        self.shuffle       = shuffle
        self.shuffle_once  = shuffle_once
        self.include_query = include_query

        # Support and query examples of a class are drawn together
        if include_query: 
            self.k_shot *= 2

        self.batch_size = self.n_way * self.k_shot

        self.indices_per_class = dict()
        self.batches_per_class = dict()

        self.classes = torch.unique(self.labels).tolist()
        for cl in self.classes:
            self.indices_per_class[cl] = torch.where(self.labels == cl)[0]
            self.batches_per_class[cl] = self.indices_per_class[cl].shape[0] // self.k_shot
        
        self.iterations = sum(self.batches_per_class.values()) // self.n_way
        self.class_list = [cl for cl in self.classes for _ in range(self.batches_per_class[cl])]

        if self.shuffle_once or self.shuffle:
            self.shuffle_data()
    
    def shuffle_data(self) -> None:
        """
        Shuffle, in place, the examples of every class and the order in
        which the classes are visited.
        """
        for cl in self.classes:
            perm = torch.randperm(self.indices_per_class[cl].shape[0])
            self.indices_per_class[cl] = self.indices_per_class[cl][perm]
        
        # Finally shuffle the class list from which we sample
        random.shuffle(self.class_list)

    def __iter__(self) -> Iterator[List[int]]:
        # Shuffle data
        if self.shuffle:
            self.shuffle_data()

        # Next unused position inside the index list of each class
        next_start = defaultdict(int)
        for it in range(self.iterations):
            # Select N classes for the batch
            class_batch = self.class_list[it * self.n_way:(it + 1) * self.n_way]
            index_batch = []

            for cl in class_batch:
                start = next_start[cl]
                # extend (not append): the batch is a flat list of indices
                index_batch.extend(self.indices_per_class[cl][start:start + self.k_shot].tolist())
                next_start[cl] += self.k_shot

            if self.include_query:
                # Every class contributed an even number (2K) of consecutive
                # indices: the even positions form the support set and the
                # odd positions the query set, K examples per class in each.
                index_batch = index_batch[::2] + index_batch[1::2]
            
            yield index_batch

    def __len__(self) -> int:
        return self.iterations

In [None]:
# Needs to create a new dataset with a new collater
class GraphCollater(gloader.dataloader.Collater):
    """Collater that additionally accepts batches whose elements are
    `GraphDataset` objects: each of them is expanded into the list of its
    items and the resulting batch is collated again."""

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def __call__(self, batch: Generic[T]) -> Generic[T]:
        first = batch[0]

        # Anything that is not a GraphDataset is handled by the parent collater
        if not isinstance(first, GraphDataset):
            return super().__call__(batch)

        expanded = [[subset[pos] for pos in range(len(subset))] for subset in batch]
        return self(expanded)


class NKDataLoader(torch.utils.data.DataLoader):
    """`torch.utils.data.DataLoader` that always collates with a
    `GraphCollater`; a `collate_fn` passed by the caller is discarded."""

    def __init__(self, dataset     : GraphDataset, 
                       batch_size  : int=1,
                       shuffle     : bool=False,
                       follow_batch: Optional[List[str]]=None,
                       exclude_keys: Optional[List[str]]=None,
                       **kwargs) -> None:

        # The collate function is fixed by this class
        kwargs.pop('collate_fn', None)

        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        collater = GraphCollater(follow_batch, exclude_keys)
        super().__init__(dataset, batch_size, shuffle, collate_fn=collater, **kwargs)

In [None]:
# 5-way 4-shot sampler over the training labels; support and query
# examples are drawn together (include_query=True)
nwshot = NWayKShotSampler(train_ds.targets(),
                          include_query=True,
                          n_way=5,
                          k_shot=4,
                          shuffle=True)

In [None]:
# Draw the first batch of indices produced by the sampler
l = next(iter(nwshot))

[tensor([ 6409, 20574, 16251,  8373,  6495,  8390, 16475,  8877]), tensor([ 3899,  5545, 15614, 18216,  3040, 18373,  5232,  3834])]
[tensor([10639, 19354, 21620,  9090, 11960, 11764,  9634, 19168]), tensor([ 6409, 20574, 16251,  8373,  6495,  8390, 16475,  8877]), tensor([ 8480,  8387,  7348,  6934, 21096, 22488,  8953,  7456]), tensor([ 3899,  5545, 15614, 18216,  3040, 18373,  5232,  3834]), tensor([15533,  4987,  5463,  5009, 15643,  5149, 22377,  4203])]


In [None]:
# Display the sampled batch
l

[tensor([10639, 19354, 21620,  9090, 11960, 11764,  9634, 19168]),
 tensor([ 8480,  8387,  7348,  6934, 21096, 22488,  8953,  7456]),
 tensor([15533,  4987,  5463,  5009, 15643,  5149, 22377,  4203]),
 tensor([ 6409, 20574, 16251,  8373,  6495,  8390, 16475,  8877]),
 tensor([ 3899,  5545, 15614, 18216,  3040, 18373,  5232,  3834])]

In [None]:
# Number of classes per batch (N) and of examples per class (K)
N_WAY  = 5
K_SHOT = 4

# NOTE(review): `graph_dataset` is not defined in the cells visible above
# (the splits are named train_ds / test_ds / val_ds) -- confirm which
# dataset is meant here.
graph_train_loader = NKDataLoader(
    graph_dataset,
    batch_sampler=NWayKShotSampler(graph_dataset.targets(),
                                   include_query=True,
                                   n_way=N_WAY,
                                   k_shot=K_SHOT,
                                   shuffle=True)
)

In [None]:
# Fetch the first collated batch from the loader
sample = next(iter(graph_train_loader))

In [None]:
# Display the collated batch
sample

[DataBatch(x=[21, 1], edge_index=[2, 72], y=[1], batch=[21], ptr=[2]),
 DataBatch(x=[15, 1], edge_index=[2, 48], y=[1], batch=[15], ptr=[2]),
 DataBatch(x=[80, 1], edge_index=[2, 244], y=[1], batch=[80], ptr=[2]),
 DataBatch(x=[18, 1], edge_index=[2, 50], y=[1], batch=[18], ptr=[2]),
 DataBatch(x=[12, 1], edge_index=[2, 42], y=[1], batch=[12], ptr=[2]),
 DataBatch(x=[19, 1], edge_index=[2, 62], y=[1], batch=[19], ptr=[2]),
 DataBatch(x=[7, 1], edge_index=[2, 24], y=[1], batch=[7], ptr=[2]),
 DataBatch(x=[19, 1], edge_index=[2, 54], y=[1], batch=[19], ptr=[2]),
 DataBatch(x=[7, 1], edge_index=[2, 14], y=[1], batch=[7], ptr=[2]),
 DataBatch(x=[21, 1], edge_index=[2, 46], y=[1], batch=[21], ptr=[2]),
 DataBatch(x=[22, 1], edge_index=[2, 56], y=[1], batch=[22], ptr=[2]),
 DataBatch(x=[11, 1], edge_index=[2, 32], y=[1], batch=[11], ptr=[2]),
 DataBatch(x=[23, 1], edge_index=[2, 58], y=[1], batch=[23], ptr=[2]),
 DataBatch(x=[6, 1], edge_index=[2, 14], y=[1], batch=[6], ptr=[2]),
 DataBatch(

## Adaptive-Step MAML

In [None]:
class GCN4MAML(nn.Module):
    """ Graph Convolutional Network used in AS-MAML (placeholder: no layers defined yet). """

    def __init__(self) -> None:
        super().__init__()

class StopControl(nn.Module):
    """LSTM cell followed by a linear layer and a sigmoid, producing a value
    in (0, 1) together with the updated LSTM state.

    The initial hidden and cell states are learnable parameters, used
    whenever `forward` is called without a state.
    """

    def __init__(self, input_size: int, hidden_size: int) -> None:
        super(StopControl, self).__init__()
        self.lstm = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
        self.output_layer = nn.Linear(hidden_size, 1)
        # Start with a zero bias on the output layer
        self.output_layer.bias.data.fill_(0.0)
        self.h_0 = nn.Parameter(torch.randn((hidden_size, ), requires_grad=True))
        self.c_0 = nn.Parameter(torch.randn((hidden_size, ), requires_grad=True))

    def forward(self, inputs: torch.Tensor,
                      hx: Optional[Tuple[torch.Tensor, torch.Tensor]]=None
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run one LSTM step.

        Args:
            inputs: input of the LSTM cell
            hx: previous `(h, c)` state; when None the learned initial
                state is used, with a leading dimension of size 1

        Returns:
            The pair `(value, (h, c))`, where `value` is the sigmoid of the
            output layer with an extra leading dimension and `(h, c)` is
            the new LSTM state.
        """
        if hx is None:
            hx = (self.h_0.unsqueeze(0), self.c_0.unsqueeze(0))

        h, c = self.lstm(inputs, hx)
        return torch.sigmoid(self.output_layer(h).unsqueeze(0)), (h, c)

class AdaptiveStepMAML(nn.Module):
    """ The Meta-Learner Class """
    def __init__(self, model, 
                       inner_lr    : float=1e-02, 
                       outer_lr    : float=1e-03, 
                       stop_lr     : float=1e-04, 
                       weight_decay: float=1e-05) -> None:
        """Store the model and hyper-parameters, then build the stop gate,
        the meta-optimizer, the loss and the learning-rate scheduler.

        Args:
            model: the network to meta-train
            inner_lr: learning rate stored as `self.inner_lr`
            outer_lr: learning rate of the model parameters in the meta-optimizer
            stop_lr: learning rate of the stop-gate parameters in the meta-optimizer
            weight_decay: weight decay of the meta-optimizer
        """
        # nn.Module must be initialised before any sub-module is assigned
        super(AdaptiveStepMAML, self).__init__()

        self.net          = model
        self.inner_lr     = inner_lr
        self.outer_lr     = outer_lr
        self.stop_lr      = stop_lr
        self.weight_decay = weight_decay

        self.stop_prob = 0.5
        # StopControl is a module-level class, not an attribute of this object
        self.stop_gate = StopControl(2, 20)

        self.meta_optim = self.configure_optimizers()

        self.loss      = nn.BCEWithLogitsLoss()
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.meta_optim,
                                                          gamma=.5,
                                                          last_epoch=-1,
                                                          verbose=True)

    def configure_optimizers(self):
        return optim.Adam([
                           {'params': self.net.parameters(),       'lr': self.outer_lr},
                           {'params': self.stop_gate.parameters(), 'lr': self.stop_lr}],
                          lr=1e-04, weight_decay=self.weight_decay
               )
        
    def compute_loss(self, logits: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """Apply `self.loss` to the squeezed logits and the squeezed labels
        (the labels are cast to double first) and return the result."""
        # NOTE(review): the labels become float64 while the dtype of `logits`
        # is left untouched -- confirm the two dtypes match in practice.
        return self.loss(logits.squeeze(), label.double().squeeze())

    