# Graph Data Augmentation for Few-Shot Graph Classification

In [None]:
import torch
TORCH = torch.__version__.split('+')[0]
CUDA = 'cu' + torch.version.cuda.replace('.','')

!pip install pytorch-lightning
!pip install pyyaml==5.4.1
!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric

In [69]:
import os
from typing import Dict, List, Tuple, Union, \
                   Sequence, TypeVar, Generic, \
                   Optional
import networkx as nx
import numpy as np
import plotly
import plotly.express as px
import plotly.graph_objects as go

from torch.utils.data import Dataset
from torchvision import transforms

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, \
                                        ModelCheckpoint
import torch_geometric as gtorch
import torch_geometric.nn as gnn
import torch_geometric.data as gdata
import torch_geometric.loader as gloader
import torch_geometric.utils as gutils

import random
from collections import defaultdict

## Load dataset

The dataset I will load is TRIANGLES, one of the datasets used in the AS-MAML paper ([link](https://arxiv.org/pdf/2003.08246.pdf)).

In [61]:
# All dataset files are expected under "<current working directory>/TRIANGLES".
ROOT_PATH = os.getcwd()
GRAPH_ATTRIBUTE  = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_graph_attributes.txt")
GRAPH_LABELS     = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_graph_labels.txt")
# NOTE: the name is spelled "NATTRIBUTE" (sic); it is kept because the
# file-reading cell refers to it under this exact name.
NODE_NATTRIBUTE  = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_node_attributes.txt")
GRAPH_INDICATOR  = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_graph_indicator.txt")
GRAPH_A          = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_A.txt")

# Generic type variable used in the collater's type annotations.
T = TypeVar('T')

# Despite the name, this flag controls whether the archive is *downloaded*
# and unpacked; set it to True on a fresh environment.
LOAD_DATASET = False

if LOAD_DATASET:
    !wget https://cloud-storage.eu-central-1.linodeobjects.com/TRIANGLES.zip
    !unzip TRIANGLES.zip

In [5]:
# Read every dataset file into a list of raw lines. Each file is opened in a
# `with` block so its handle is closed as soon as the lines are loaded.
with open(GRAPH_ATTRIBUTE) as fh:
    graph_attribute = fh.readlines()

with open(GRAPH_LABELS) as fh:
    graph_labels = fh.readlines()

with open(NODE_NATTRIBUTE) as fh:
    graph_node_attribute = fh.readlines()

with open(GRAPH_INDICATOR) as fh:
    graph_indicator = fh.readlines()

with open(GRAPH_A) as fh:
    graph_a = fh.readlines()

In [6]:
class GeneratorTxt2Graph:
    """Build labelled networkx graphs from the raw lines of the dataset files.

    The constructor takes keyword arguments only; every key listed in
    ``__init__`` must be present. Only ``graph_labels``, ``node_attribute``,
    ``graph_indicator`` and ``graph_adjacency`` are read by the methods of
    this class; the remaining values are stored but currently unused.

    Node ids are the 1-based line numbers of the graph-indicator file, and
    graph ids are the (string) values found on those lines.
    """

    def __init__(self, **kargs) -> None:
        self.__graph_attribute  = kargs['graph_attribute']
        self.__graph_labels     = kargs['graph_labels']
        self.__node_attribute   = kargs['node_attribute']
        self.__graph_indicator  = kargs['graph_indicator']
        self.__graph_adjacency  = kargs['graph_adjacency']
        self.__node_labels      = kargs['node_labels']
        self.__edge_labels      = kargs['edge_labels']
        self.__edge_attributes  = kargs['edge_attributes']

    def _collect_nodes(self) -> Tuple[Dict[str, List[int]], Dict[int, list]]:
        """Group node ids by graph using the graph-indicator lines.

        Returns
        -------
        nodes : dict mapping graph id -> list of node ids of that graph
        i_nodes : dict mapping node id -> ``[graph id, node id]``; later
            steps append the node attribute dictionary to this list
        """
        print("--- Collecting Nodes ... ")

        nodes, i_nodes = dict(), dict()
        for node_id, line in enumerate(self.__graph_indicator, start=1):
            # strip() instead of [:-1]: the last line may lack a newline.
            graph_id = line.strip()
            if not graph_id:
                # Blank line: no node here, but the line still consumes an id.
                continue

            nodes.setdefault(graph_id, []).append(node_id)
            i_nodes[node_id] = [graph_id, node_id]

        return nodes, i_nodes

    def _collect_edges(self, i_nodes: Dict[int, list],
                             direct: bool=False) -> Dict[str, List[Tuple[int, int]]]:
        """Group edges by graph using the adjacency lines ("a, b" per line).

        Parameters
        ----------
        i_nodes : mapping node id -> ``[graph id, node id, ...]``
        direct : accepted for interface compatibility; currently unused

        Returns
        -------
        dict mapping graph id -> list of ``(node_a, node_b)`` tuples

        Raises
        ------
        AssertionError
            If the two endpoints of an edge belong to different graphs.
        """
        print("--- Collecting Edges ...")

        edges = dict()
        for line in self.__graph_adjacency:
            if not line.strip():
                continue

            # Split on the bare comma so both "a, b" and "a,b" are accepted.
            a, b = (part.strip() for part in line.split(","))

            # Only the first two entries are read, so this also works once
            # the attribute dictionary has been appended to the node record.
            graph_a, node_a = i_nodes[int(a)][:2]
            graph_b, node_b = i_nodes[int(b)][:2]

            # Raised explicitly so the check survives `python -O`.
            if graph_a != graph_b:
                raise AssertionError(f"Two graphs are not equal: {graph_a} != {graph_b}")

            edges.setdefault(graph_a, []).append((node_a, node_b))

        return edges

    def _collect_node_attributes(self, i_nodes: Dict[int, list]) -> None:
        """Append ``{"attr0": ..., "attr1": ...}`` to each node record.

        Line ``k`` (1-based) of the node-attribute lines describes node ``k``.
        Values are kept as stripped strings; blank lines are skipped.
        """
        print("--- Collecting Node Attributes ...")
        for node_id, line in enumerate(self.__node_attribute, start=1):
            if not line.strip():
                continue

            values = [value.strip() for value in line.split(",")]
            i_nodes[node_id].append(
                {f"attr{pos}": value for pos, value in enumerate(values)}
            )

    def _collect_graph_labels(self, graphs: "Dict[str, nx.Graph]") -> None:
        """Replace each ``graphs[id]`` with the pair ``(graph, label)``.

        Line ``k`` (1-based) of the graph-label lines labels graph ``str(k)``.
        The dictionary is modified in place; blank lines are skipped.
        """
        print("--- Collecting Graph Labels ...")
        for position, line in enumerate(self.__graph_labels, start=1):
            label = line.strip()
            if not label:
                continue

            key = str(position)
            graphs[key] = (graphs[key], label)

    # TODO: _collect_node_labels, _collect_edge_labels, _collect_edge_attributes

    def generate(self) -> "Dict[str, Tuple[nx.Graph, str]]":
        """Return a dictionary ``{graph id: (Graph, label)}``."""
        # Get Nodes and Edges
        nodes, i_nodes = self._collect_nodes()
        edges          = self._collect_edges(i_nodes, False)

        # Set attributes for nodes
        self._collect_node_attributes(i_nodes)

        # Create the graphs. Iterate over `nodes` (not `edges`) so that a
        # graph without any edge is still created and can receive its label.
        graphs = dict()
        for graph_id, node_ids in nodes.items():
            g = nx.Graph()
            g_nodes = [
                (n, i_nodes[n][2] if len(i_nodes[n]) > 2 else {})
                for n in node_ids
            ]

            g.add_nodes_from(g_nodes)
            g.add_edges_from(edges.get(graph_id, []))

            graphs[graph_id] = g

        # Set labels for graph
        self._collect_graph_labels(graphs)

        return graphs

In [7]:
# Hand the raw file lines to the generator. node_labels, edge_labels and
# edge_attributes are passed as None: the generator requires the keys but
# has no collector for them yet.
graphs_gen = GeneratorTxt2Graph(graph_attribute=graph_attribute,
                                graph_labels=graph_labels,
                                node_attribute=graph_node_attribute,
                                graph_indicator=graph_indicator,
                                graph_adjacency=graph_a,
                                node_labels=None,
                                edge_labels=None,
                                edge_attributes=None)

In [8]:
%%time 
# Parse the raw lines into {graph id: (graph, label)}; the recorded run of this cell took about 18 s.
graphs = graphs_gen.generate()

--- Collecting Nodes ... 
--- Collecting Edges ...
--- Collecting Node Attributes ...
--- Collecting Graph Labels ...
CPU times: user 16.2 s, sys: 1.67 s, total: 17.9 s
Wall time: 18 s


In [9]:
def plot_graph(G : Union[nx.Graph, nx.DiGraph], name: str) -> None:
    """
    Plot a graph as an interactive 3D plotly figure.

    Node positions come from a seeded 3D spring layout, so the same graph
    is always drawn the same way.

    Parameters
    ----------
    G : Union[nx.Graph, nx.DiGraph]
        The graph to draw
    name  : str
        The name of the graph, used in the figure title

    Returns
    -------
    None
    """
    # Getting the 3D Spring layout
    layout = nx.spring_layout(G, dim=3, seed=18)

    # Getting nodes coordinate (one entry per node, in layout order)
    node_ids = list(layout.keys())
    x_nodes = [layout[node][0] for node in node_ids]  # x-coordinates of nodes
    y_nodes = [layout[node][1] for node in node_ids]  # y-coordinates of nodes
    z_nodes = [layout[node][2] for node in node_ids]  # z-coordinates of nodes

    # Build the edge coordinate lists; the trailing None after each pair
    # breaks the line so consecutive edges are not joined together.
    x_edges, y_edges, z_edges = [], [], []
    for src, dst in G.edges():
        x_edges += [layout[src][0], layout[dst][0], None]
        y_edges += [layout[src][1], layout[dst][1], None]
        z_edges += [layout[src][2], layout[dst][2], None]

    # One colour value per node. (np.linspace(0, n) would always return its
    # default of 50 samples, regardless of the number of nodes.)
    colors = np.arange(len(node_ids))

    # Create a trace for the edges
    etrace = go.Scatter3d(x=x_edges,
                          y=y_edges,
                          z=z_edges,
                          mode='lines',
                          line=dict(color='rgb(125,125,125)', width=1),
                          hoverinfo='none'
                         )

    # Create a trace for the nodes
    ntrace = go.Scatter3d(x=x_nodes,
                          y=y_nodes,
                          z=z_nodes,
                          mode='markers',
                          marker=dict(
                              symbol='circle',
                              size=6,
                              color=colors,
                              colorscale='Viridis',
                              line=dict(color='rgb(50,50,50)', width=.5)),
                          text=node_ids,
                          hoverinfo='text'
                         )

    # Set the axis (hide every axis decoration)
    axis = dict(showbackground=False,
                showline=False,
                zeroline=False,
                showgrid=False,
                showticklabels=False,
                title='')

    # Create a layout for the plot
    go_layout = go.Layout(title=f"{name} Network Graph",
                          width=600,
                          height=600,
                          showlegend=False,
                          scene=dict(xaxis=dict(axis),
                                     yaxis=dict(axis),
                                     zaxis=dict(axis)),
                          margin=dict(t=100),
                          hovermode='closest'
                         )

    # Plot
    data = [etrace, ntrace]
    fig = go.Figure(data=data, layout=go_layout)
    fig.show()

In [None]:
# Visual sanity check: draw the graph stored under id "1" (element 0 of its (graph, label) pair).
plot_graph(graphs["1"][0], "1")

In [10]:
class GraphDataset(gdata.Dataset):
    """PyTorch Geometric dataset over a ``{graph id: (graph, label)}`` dict.

    Dataset index ``idx`` maps to the dictionary key ``str(idx + 1)``, so the
    keys are expected to be the strings "1" .. "N".
    """

    def __init__(self, graphs_ds: Dict[str, Tuple[nx.Graph, str]]) -> None:
        super(GraphDataset, self).__init__()
        self.graphs_ds = graphs_ds

    def len(self) -> int:
        """Return the number of graphs in the dataset."""
        return len(self.graphs_ds)

    def targets(self) -> torch.Tensor:
        """Return all the labels, ordered by dataset index.

        The labels are looked up with the same ``str(idx + 1)`` key used by
        ``get`` so that ``targets()[idx]`` always belongs to ``get(idx)``,
        whatever the insertion order of the underlying dictionary is.
        """
        return torch.tensor(
            [int(self.graphs_ds[str(idx + 1)][1]) for idx in range(self.len())]
        )

    def get(self, idx: int) -> gdata.Data:
        """Return the graph at ``idx`` as a ``Data`` object.

        The result carries the node feature matrix ``x`` (float), the
        ``edge_index`` (long, shape ``[2, num_edges]``, both directions of
        every undirected edge) and the label ``y`` (int, shape ``[1]``).
        """
        graph    = self.graphs_ds[str(idx + 1)]
        g, label = graph[0].to_directed(), graph[1]

        # Retrieve nodes attributes
        attrs = list(g.nodes(data=True))
        x     = torch.tensor([list(map(int, a.values())) for _, a in attrs], dtype=torch.float)

        # Node ids are numbered across the whole dataset; translate them to
        # row positions in `x` so that edge_index refers to this graph only.
        position = {node: row for row, (node, _) in enumerate(attrs)}

        # Retrieve edges. reshape(-1, 2) keeps the [2, 0] shape when the
        # graph has no edge at all.
        edge_index = torch.tensor([[position[u], position[v]] for u, v in g.edges],
                                  dtype=torch.long) \
                          .reshape(-1, 2)           \
                          .t()                      \
                          .contiguous()

        # Retrieve ground truth labels
        y = torch.tensor([int(label)], dtype=torch.int)

        return gdata.Data(x=x, edge_index=edge_index, y=y)

In [11]:
# Wrap the parsed {graph id: (graph, label)} dictionary in the dataset class.
graph_dataset = GraphDataset(graphs)

In [42]:
class NWayKShotSampler(torch.utils.data.Sampler):
    """
    In few-shot classification, and in particular in Meta-Learning,
    we use a specific way of sampling batches from the training/val/test
    set. This way is called N-way-K-shot, where N is the number of classes
    to sample per batch and K is the number of examples to sample per class
    in the batch. The sample batch on which we train our model is also called
    `support` set, while the one on which we test is called `query` set.

    This class is an N-way-K-shot sampler meant to be used as a batch_sampler
    for a dataloader. Each iteration yields a list with one index tensor per
    sampled class slot.

    Attributes:
        labels: PyTorch tensor of the labels of the data elements
        n_way: Number of classes to sample per batch
        k_shot: Number of examples to sample per class in the batch
                (doubled when include_query is True)
        include_query: If True, 2 * K examples are drawn per class so the
                       batch can hold both the support and the query set.
        shuffle: If True, examples and classes are shuffled at each iteration
        shuffle_once: If True, examples and classes are shuffled only once
                      and remain constant across iterations.
        batch_size: The size of the batch N * K. If include_query is True
                    then it will be N * K * 2.
        indices_per_class: Dataset indices of the examples of each class
        batches_per_class: Number of K-shot batches for each class
        iterations: Number of iterations
        class_list: Contains every class as many times as
                    defined in batches_per_class

    Raises:
        ValueError: if n_way or k_shot is smaller than 1.
    """
    def __init__(self, labels       : torch.Tensor,
                       n_way        : int,
                       k_shot       : int,
                       include_query: bool=False,
                       shuffle      : bool=True,
                       shuffle_once : bool=False) -> None:
        # Fail early with a clear message instead of a ZeroDivisionError below.
        if n_way < 1 or k_shot < 1:
            raise ValueError(
                f"n_way and k_shot must be >= 1, got n_way={n_way}, k_shot={k_shot}"
            )

        # The base Sampler has taken a (deprecated) data_source argument;
        # fall back to the argument-free form where it is not accepted.
        try:
            super().__init__(None)
        except TypeError:
            super().__init__()

        self.labels        = labels
        self.n_way         = n_way
        self.k_shot        = k_shot
        self.shuffle       = shuffle
        self.shuffle_once  = shuffle_once
        self.include_query = include_query

        if include_query:
            self.k_shot *= 2

        self.batch_size = self.n_way * self.k_shot

        self.indices_per_class = dict()
        self.batches_per_class = dict()

        self.classes = torch.unique(self.labels).tolist()
        for cl in self.classes:
            self.indices_per_class[cl] = torch.where(self.labels == cl)[0]
            self.batches_per_class[cl] = self.indices_per_class[cl].shape[0] // self.k_shot

        self.iterations = sum(self.batches_per_class.values()) // self.n_way
        self.class_list = [cl for cl in self.classes for _ in range(self.batches_per_class[cl])]

        if self.shuffle_once or self.shuffle:
            self.shuffle_data()
        else:
            # Without shuffling, class_list would stay grouped by class and a
            # batch would repeat one class N times. Interleave the classes
            # round-robin so consecutive slots cycle through the classes.
            slots = [(turn, pos, cl)
                     for pos, cl in enumerate(self.classes)
                     for turn in range(self.batches_per_class[cl])]
            slots.sort(key=lambda slot: slot[:2])
            self.class_list = [cl for _, _, cl in slots]

    def shuffle_data(self) -> None:
        """
        Shuffle the examples inside every class and the order of the
        class slots from which the batches are built.
        """
        for cl in self.classes:
            perm = torch.randperm(self.indices_per_class[cl].shape[0])
            self.indices_per_class[cl] = self.indices_per_class[cl][perm]

        # Finally shuffle the class list from which we sample
        random.shuffle(self.class_list)

    def __iter__(self) -> List[torch.Tensor]:
        """Yield, for every iteration, a list of index tensors (one per slot)."""
        # Shuffle data
        if self.shuffle:
            self.shuffle_data()

        # Next unread position inside indices_per_class, per class
        indexes = defaultdict(int)
        for it in range(self.iterations):
            # Select N class slots for the batch
            class_batch = self.class_list[it * self.n_way:(it + 1) * self.n_way]
            index_batch = []

            for cl in class_batch:
                idx = indexes[cl]
                index_batch += [self.indices_per_class[cl][idx:idx + self.k_shot]]
                indexes[cl] += self.k_shot

            if self.include_query:
                # NOTE(review): index_batch holds one tensor per class slot,
                # so this moves the even-positioned slots in front of the
                # odd-positioned ones; it does not interleave individual
                # indices. Confirm this matches how consumers split the
                # batch into support and query sets.
                index_batch = index_batch[::2] + index_batch[1::2]

            yield index_batch

    def __len__(self) -> int:
        """Return the number of batches produced per epoch."""
        return self.iterations

In [87]:
# Needs to create a new dataset with a new collater
class GraphCollater(gloader.dataloader.Collater):
    """Collater that also accepts batches made of ``GraphDataset`` objects.

    When the first element of the batch is a ``GraphDataset``, every element
    is expanded into the list of its items and the expanded batch is collated
    again; any other batch is handled by the parent collater.
    """

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def __call__(self, batch: Generic[T]) -> Generic[T]:
        first = batch[0]
        if not isinstance(first, GraphDataset):
            # Regular batch: defer to the stock collating logic.
            return super().__call__(batch)

        # Turn each dataset of the batch into a plain list of its items.
        expanded = []
        for subset in batch:
            expanded.append([subset[pos] for pos in range(len(subset))])

        return self(expanded)


class NKDataLoader(torch.utils.data.DataLoader):
    """Data loader that always collates its batches with ``GraphCollater``.

    A ``collate_fn`` passed by the caller is discarded; all other keyword
    arguments are forwarded unchanged to the parent ``DataLoader``.
    """

    def __init__(self, dataset     : GraphDataset,
                       batch_size  : int=1,
                       shuffle     : bool=False,
                       follow_batch: Optional[List[str]]=None,
                       exclude_keys: Optional[List[str]]=None,
                       **kwargs) -> None:

        # The collate function is fixed by this class.
        kwargs.pop('collate_fn', None)

        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        collater = GraphCollater(follow_batch, exclude_keys)
        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=collater,
            **kwargs,
        )

In [88]:
# Episode configuration: number of classes per batch and examples per class.
N_WAY, K_SHOT = 5, 4

# Training loader whose batches are drawn by the N-way-K-shot sampler;
# include_query=True doubles the examples drawn per class.
graph_train_loader = NKDataLoader(
    dataset=graph_dataset,
    batch_sampler=NWayKShotSampler(
        labels=graph_dataset.targets(),
        n_way=N_WAY,
        k_shot=K_SHOT,
        include_query=True,
        shuffle=True,
    ),
)

## Adaptive-Step MAML