# Few-Shot Graph Classification

Most graph classification methods overlook how scarce labeled graphs are in many settings. *Few-Shot Learning* addresses this problem. It is a family of machine learning methods designed for training sets that contain only a few labeled examples. The general practice is to feed a machine learning model as much data as possible, since this usually leads to better predictions; few-shot learning instead aims to build accurate models from little training data. Few-shot learning, and few-shot classification in particular, therefore aims to reduce the cost of collecting and labeling large amounts of data.

*What is the idea behind Few-Shot Learning on graphs?* Given graph data $\mathcal{G} = \{(G_1, \mathbf{y}_1), ..., (G_n, \mathbf{y}_n)\}$, we split it into a training set, $\{(G^{train}, \mathbf{y}^{train})\}$, and a test set, $\{(G^{test}, \mathbf{y}^{test})\}$. Note that $\mathbf{y}^{train}$ and $\mathbf{y}^{test}$ must have no classes in common.

Training is episodic: at each training step the algorithm samples a so-called *task*, i.e., a pair (*support* set, *query* set). The support set is $D_{sup}^{train} = \{(G_i^{train}, \mathbf{y}_{i}^{train})\}_{i=1}^s$, where $s = N \times K$, while the query set is $D_{que}^{train} = \{(G_i^{train}, \mathbf{y}_{i}^{train})\}_{i=1}^q$, where $q$ is the number of query graphs. Given the labeled support data, the goal is to predict the labels of the query data. Within a single task, support data and query data share the same class space. This setting is called **N-way-K-shot** learning, where **N** is the number of sampled classes and **K** is the number of samples for each of the N classes.

At test time, when classifying unseen classes, we first fine-tune the meta-learner on the support data of the test classes and then report the classification performance on the test query set.
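
To make the episode construction concrete, here is a minimal sketch in plain Python. It is not the sampler used later in this notebook: the function name `sample_task`, the toy dictionary `toy` and the choice of drawing `n_query` query graphs per class are assumptions made for illustration.

```python
import random
from typing import Dict, List, Tuple

Pair = Tuple[int, int]  # (graph_id, class_label)


def sample_task(
    graphs_per_class: Dict[int, List[int]],
    n_way: int,
    k_shot: int,
    n_query: int,
    rng: random.Random,
) -> Tuple[List[Pair], List[Pair]]:
    """Return (support, query) lists of (graph_id, class_label) pairs."""
    # Pick the N classes of this task
    classes = rng.sample(sorted(graphs_per_class), n_way)
    support, query = [], []
    for label in classes:
        # Draw K support graphs and n_query query graphs without overlap
        picked = rng.sample(graphs_per_class[label], k_shot + n_query)
        support += [(g, label) for g in picked[:k_shot]]
        query += [(g, label) for g in picked[k_shot:]]
    return support, query


# Hypothetical toy data: 4 classes with 10 graph ids each
toy = {c: list(range(c * 10, c * 10 + 10)) for c in range(4)}
support, query = sample_task(toy, n_way=3, k_shot=2, n_query=4, rng=random.Random(0))
print(len(support), len(query))  # 6 12
```

The support set has $N \times K = 3 \times 2 = 6$ graphs, and both sets contain the same three classes, as the definition above requires.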

In the following, I present some approaches to few-shot learning. The first is a *Meta-Learning Framework* based on Fast Weight Adaptation, taken from the paper [Adaptive-Step Graph Meta-Learner for Few-Shot Graph Classification](https://arxiv.org/pdf/2003.08246.pdf) (Ning Ma et al.). I then compare it with different GDA (graph data augmentation) techniques, used to enrich the dataset for the novel classes (i.e., those with the least data), taken from a second paper, [Graph Data Augmentation for Graph Machine Learning: A Survey](https://arxiv.org/pdf/2202.08871.pdf) (Tong Zhao et al.).

## Modules and Constants

In [None]:
import torch
TORCH = torch.__version__.split('+')[0]
CUDA = 'cu' + torch.version.cuda.replace('.','')

!pip install pytorch-lightning
!pip install pyyaml==5.4.1
!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric

In [None]:
from typing import (
    Any, Dict, List, Tuple, 
    Union, Generic, Optional,
    TypeVar
)

from tqdm.notebook import tqdm
from functools import wraps
import plotly.graph_objects as go
import networkx as nx
import numpy as np
import pickle
import os
import shutil
import logging
import random
import time
import requests
import zipfile
import math
import sys

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter

import torch_geometric.data as gdata
import torch_geometric.loader as gloader
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn import global_mean_pool, global_max_pool
from torch_geometric.nn.inits import uniform
from torch_geometric.nn.pool.topk_pool import topk, filter_adj
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.utils import (
    add_remaining_self_loops, 
    add_self_loops, 
    remove_self_loops,
    softmax
)

from torch_scatter import scatter_add

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

In [None]:
TRIANGLES_ZIP_URL = "https://cloud-storage.eu-central-1.linodeobjects.com/TRIANGLES.zip"
COIL_DEL_ZIP_URL = "https://cloud-storage.eu-central-1.linodeobjects.com/COIL-DEL.zip"
R52_ZIP_URL = "https://cloud-storage.eu-central-1.linodeobjects.com/R52.zip"
LETTER_HIGH_ZIP_URL = "https://cloud-storage.eu-central-1.linodeobjects.com/Letter-High.zip"

DATASETS = {
    "TRIANGLES"   : TRIANGLES_ZIP_URL, 
    "COIL-DEL"    : COIL_DEL_ZIP_URL, 
    "R52"         : R52_ZIP_URL, 
    "Letter-High" : LETTER_HIGH_ZIP_URL
}

T = TypeVar('T')

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DOWNLOAD_DATASET = True
SAVE_PICLKE  = True
EDGELIMIT_PRINT = 2000

NUM_FEATURES = {"TRIANGLES": 1, "R52": 1, "Letter-High": 2, "COIL-DEL": 2}


class ASMAMLConfig:
    NHID = 128
    POOLING_RATIO = 0.5
    DROPOUT_RATIO = 0.3

    OUTER_LR     = 0.001
    INNER_LR     = 0.01
    STOP_LR      = 0.0001
    WEIGHT_DECAY = 1E-05

    MAX_STEP      = 15
    MIN_STEP      = 5
    STEP_TEST     = 15
    FLEXIBLE_STEP = True
    STEP_PENALITY = 0.001
    USE_SCORE     = True
    USE_GRAD      = False
    USE_LOSS      = True

    TRAIN_SHOT         = 10   # K-shot for training set
    VAL_SHOT           = 10   # K-shot for validation (or test) set
    TRAIN_QUERY        = 15   # Number of query for the training set
    VAL_QUERY          = 15   # Number of query for the validation (or test) set
    TRAIN_WAY          = 3    # N-way for training set
    TEST_WAY           = 3    # N-way for test set
    VAL_EPISODE        = 200  # Number of episodes for validation
    TRAIN_EPISODE      = 200  # Number of episodes for training
    BATCH_PER_EPISODES = 5    # How many batch per episode
    EPOCHS             = 500  # How many epochs
    PATIENCE           = 35
    GRAD_CLIP          = 5

    # Stop Control configurations
    STOP_CONTROL_INPUT_SIZE = 2
    STOP_CONTROL_HIDDEN_SIZE = 20
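
As a quick sanity check on these numbers, the sketch below works out how many graphs a training task and a training mini-batch contain. It assumes that `TRAIN_QUERY` counts query graphs per class (the sampler defined later draws `k_shot + n_query` indices per class) and that `BATCH_PER_EPISODES` is the number of tasks grouped into one mini-batch; the second point is an assumption about how the configuration is used.

```python
# Values copied from ASMAMLConfig above
TRAIN_WAY, TRAIN_SHOT, TRAIN_QUERY = 3, 10, 15
BATCH_PER_EPISODES = 5  # assumed: number of tasks per mini-batch

support_per_task = TRAIN_WAY * TRAIN_SHOT   # s = N * K
query_per_task = TRAIN_WAY * TRAIN_QUERY    # query graphs over all N classes
graphs_per_task = support_per_task + query_per_task
graphs_per_batch = BATCH_PER_EPISODES * graphs_per_task

print(support_per_task, query_per_task, graphs_per_task, graphs_per_batch)
# 30 45 75 375
```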

## Utility Functions

In [None]:
def scandir(root_path: str) -> List[str]:
    """Recursively scan a directory looking for files"""
    root_path = os.path.abspath(root_path)
    content = []
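    # Sort the entries: callers unpack the result positionally
    # (node_attributes, test_set, train_set, val_set), which matches
    # alphabetical order, while os.listdir gives no ordering guarantee
    for file in sorted(os.listdir(root_path)):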
        new_path = os.path.join(root_path, file)
        if os.path.isfile(new_path):
            content.append(new_path)
            continue
        
        content += scandir(new_path)
    
    return content


def download_zipped_data(url: str, path2extract: str, dataset_name: str) -> List[str]:
    """Download and extract a ZIP file from URL. Return the paths of the extracted files"""
    logging.debug(f"--- Downloading from {url} ---")
    response = requests.get(url)
    response.raise_for_status()  # fail early instead of saving an HTTP error page as a ZIP

    abs_path2extract = os.path.abspath(path2extract)
    zip_path = os.path.join(abs_path2extract, f"{dataset_name}.zip")
    with open(zip_path, mode="wb") as iofile:
        iofile.write(response.content)

    # Extract the file
    logging.debug("--- Extracting files from the archive ---")
    with zipfile.ZipFile(zip_path, mode="r") as zip_ref:
        zip_ref.extractall(abs_path2extract)

    logging.debug(f"--- Removing {zip_path} ---")
    os.remove(zip_path)

    return scandir(os.path.join(path2extract, dataset_name))


def delete_data_folder(path2delete: str) -> None:
    """Delete the folder containing data"""
    logging.debug("--- Removing Content Data ---")
    shutil.rmtree(path2delete)
    logging.debug("--- Removal Finished Successfully ---")

In [None]:
def elapsed_time(func):
    """Simple decorator that logs the elapsed time of each call to func"""
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        logging.debug("Elapsed Time: {:.6f}".format(end - start))
        return result  # propagate the return value of the wrapped function
    
    return wrapper

In [None]:
def save_with_pickle(path2save: str, content: Any) -> None:
    """Save content inside a .pickle file denoted by path2save"""
    # Append the extension only when the path does not already end with it
    path2save = path2save if path2save.endswith(".pickle") else path2save + ".pickle"
    with open(path2save, mode="wb") as iostream:
        pickle.dump(content, iostream)


def load_with_pickle(path2load: str) -> Any:
    """Load a content from a .pickle file"""
    with open(path2load, mode="rb") as iostream:
        return pickle.load(iostream)

In [None]:
def setup_seed(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True

In [None]:
def plot_graph(G : Union[nx.Graph, nx.DiGraph], name: str) -> None:
    """
    Plot a graph in 3D with Plotly
    
    Parameters
    ----------
    G : Union[nx.Graph, nx.DiGraph]
        The graph to plot
    name  : str
        The name of the graph
        
    Returns
    -------
    None
    """
    # Getting the 3D Spring layout
    layout = nx.spring_layout(G, dim=3, seed=18)
    
    # Getting nodes coordinate
    x_nodes = [layout[i][0] for i in layout]  # x-coordinates of nodes
    y_nodes = [layout[i][1] for i in layout]  # y-coordinates of nodes
    z_nodes = [layout[i][2] for i in layout]  # z-coordinates of nodes
    
    # Getting a list of edges and create a list with coordinates
    elist = G.edges()
    x_edges, y_edges, z_edges = [], [], []
    for edge in elist:
        x_edges += [layout[edge[0]][0], layout[edge[1]][0], None]
        y_edges += [layout[edge[0]][1], layout[edge[1]][1], None]
        z_edges += [layout[edge[0]][2], layout[edge[1]][2], None]

    # One color value per node (np.linspace(0, n) would always return 50 values)
    colors = np.arange(len(x_nodes))
        
    # Create a trace for the edges
    etrace = go.Scatter3d(x=x_edges,
                          y=y_edges,
                          z=z_edges,
                          mode='lines',
                          line=dict(color='rgb(125,125,125)', width=1),
                          hoverinfo='none'
                         )
    
    # Create a trace for the nodes
    ntrace = go.Scatter3d(x=x_nodes,
                          y=y_nodes,
                          z=z_nodes,
                          mode='markers',
                          marker=dict(
                              symbol='circle',
                              size=6,
                              color=colors,
                              colorscale='Viridis',
                              line=dict(color='rgb(50,50,50)', width=.5)),
                          text=list(layout.keys()),
                          hoverinfo='text'
                         )
    
    # Set the axis
    axis = dict(showbackground=False,
                showline=False,
                zeroline=False,
                showgrid=False,
                showticklabels=False,
                title='')
    
    # Create a layout for the plot
    go_layout = go.Layout(title=f"{name} Network Graph",
                          width=600,
                          height=600,
                          showlegend=False,
                          scene=dict(xaxis=dict(axis),
                                     yaxis=dict(axis),
                                     zaxis=dict(axis)),
                          margin=dict(t=100),
                          hovermode='closest'
                         )
    
    # Plot
    data = [etrace, ntrace]
    fig = go.Figure(data=data, layout=go_layout)
    fig.show()

In [None]:
def rename_edge_indexes(data_list: List[gdata.Data]) -> List[gdata.Data]:
    """
    Takes as input a list of :obj:`torch_geometric.data.Data` and renames
    the node ids used by every edge so that they range from 0 to the total
    number of nodes minus one. For instance, edge_index = [[1234, 1235, 1236, 1237],
    [1238, 1239, 1240, 1241]] becomes edge_index = [[0, 1, 2, 3], [4, 5, 6, 7]].

    :param data_list: the list of :obj:`torch_geometric.data.Data`
    :return: a new list of data
    """
    # First of all let's compute the total number of nodes overall
    total_number_nodes = 0
    for data in data_list:
        total_number_nodes += data.x.shape[0]
    
    # Generate the new nodes
    nodes = torch.arange(0, total_number_nodes)
    
    # Takes the old nodes from the edge_index attribute
    old_nodes = None
    for data in data_list:
        x, y = data.edge_index
        x = torch.hstack((x, y)).unique(sorted=False)
        
        if old_nodes is None:
            old_nodes = x
            continue
    
        old_nodes = torch.hstack((old_nodes, x))
    
    # Create mapping from old to new nodes
    mapping = dict(zip(old_nodes.tolist(), nodes.tolist()[:old_nodes.shape[0]]))
    
    # Finally, map the new nodes
    for data in data_list:
        x, y = data.edge_index
        new_x = torch.tensor(list(map(lambda x: mapping[x], x.tolist())), dtype=x.dtype, device=x.device)
        new_y = torch.tensor(list(map(lambda y: mapping[y], y.tolist())), dtype=y.dtype, device=y.device)
        new_edge_index = torch.vstack((new_x, new_y))
        data.edge_index = new_edge_index
    
    return data_list


def data_batch_collate(data_list: List[gdata.Data]) -> gdata.Data:
    """
    Takes as input a list of data and create a new :obj:`torch_geometric.data.Data`
    collating all together. This is a replacement for torch_geometric.data.Batch.from_data_list

    :param data_list: a list of torch_geometric.data.Data objects
    :return: a new torch_geometric.data.Data object
    """
    x = None
    edge_index = None
    batch = []
    num_graphs = 0
    y = None

    # Do a shuffle of the data
    random.shuffle(data_list)
    
    for i_data, data in enumerate(data_list):
        x = data.x if x is None else torch.vstack((x, data.x))
        edge_index = data.edge_index if edge_index is None else torch.hstack((edge_index, data.edge_index))
        batch += [i_data] * data.x.shape[0]
        num_graphs += 1
        y = data.y if y is None else torch.hstack((y, data.y))

    # Create a mapping between y and a range(0, num_classes_of_y)
    # First we need to compute how many classes do we have
    num_classes = y.unique().shape[0]
    classes = list(range(0, num_classes))
    mapping = dict(zip(y.unique(sorted=False).tolist(), classes))
    
    # This mapping is necessary when computing the cross-entropy-loss
    new_y = torch.tensor(list(map(lambda x: mapping[x], y.tolist())), dtype=y.dtype, device=y.device)
    
    data_batch = gdata.Data(
        x=x, edge_index=edge_index, batch=torch.tensor(batch),
        y=new_y, num_graphs=num_graphs, old_classes_mapping=mapping
    )

    return data_batch


def task_sampler_uncollate(task_sampler: 'TaskBatchSampler', data_batch: gdata.Batch):
    """
    Takes as input the task sampler and a batch containing both the 
    support and the query set. It returns two different DataBatch
    respectively for support and query_set.

    Assume L = [x1, x2, x3, ..., xN] is the data_batch
    each xi is a graph. Moreover, we have that
    L[0:K] = support samples for the first class
    L[K:K+Q] = query samples for the first class
    In general, we have that 

            L[i * (K + Q) : (i + 1) * (K + Q)]

    is the (support, query) pair for the i-th class
    Finally, the first batch is the one that goes from
    L[0 : N * (K + Q)], so

            L[i * N * (K + Q) : (i + 1) * N * (K + Q)]

    is the i-th batch.

    :param task_sampler: The task sampler
    :param data_batch: a batch with support and query set
    :return: support batch, query batch
    """
    n_way = task_sampler.task_sampler.n_way
    k_shot = task_sampler.task_sampler.k_shot
    n_query = task_sampler.task_sampler.n_query
    task_batch_size = task_sampler.task_batch_size

    total_support_query_number = n_way * (k_shot + n_query)
    support_plus_query = k_shot + n_query

    # Initialize batch list for support and query set
    support_data_batch = []
    query_data_batch = []

    # The number of task batches is known, so iterate over them
    for batch_number in range(task_batch_size):

        # The number of classes in a task is known as well
        for class_number in range(n_way):
        for class_number in range(n_way):

            # First of all let's take the i-th batch
            data_batch_slice = slice(
                batch_number * total_support_query_number,
                (batch_number + 1) * total_support_query_number
            )
            data_batch_per_batch = data_batch[data_batch_slice]

            # Then let's take the (support, query) pair for a class
            support_query_slice = slice(
                class_number * support_plus_query,
                (class_number + 1) * support_plus_query
            )
            support_query_data = data_batch_per_batch[support_query_slice]

            # Divide support from query
            support_data = support_query_data[:k_shot]
            query_data = support_query_data[k_shot:support_plus_query]

            support_data_batch += support_data
            query_data_batch += query_data
    
    # Rename the edges
    support_data = data_batch_collate(rename_edge_indexes(support_data_batch))
    query_data   = data_batch_collate(rename_edge_indexes(query_data_batch))

    # Create new DataBatchs and return
    return support_data, query_data
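
The renaming step performed by `rename_edge_indexes` can be illustrated without tensors. The following is a simplified sketch in plain Python, not a replacement for the function above: the name `relabel_edges` is hypothetical, and it assigns new ids in first-seen order, whereas the function above follows the order returned by `torch.unique`. As in the original, only nodes that appear in at least one edge receive a new id.

```python
from typing import Dict, List, Tuple

Edge = Tuple[int, int]


def relabel_edges(edge_lists: List[List[Edge]]) -> Tuple[List[List[Edge]], Dict[int, int]]:
    """Map every node id found in edge_lists to a contiguous index (first-seen order)."""
    mapping: Dict[int, int] = {}
    relabeled = []
    for edges in edge_lists:
        new_edges = []
        for u, v in edges:
            for node in (u, v):
                # A node gets the next free index the first time it is seen
                mapping.setdefault(node, len(mapping))
            new_edges.append((mapping[u], mapping[v]))
        relabeled.append(new_edges)
    return relabeled, mapping


# Two toy graphs whose node ids come from a global numbering
graphs = [[(1234, 1235), (1235, 1236)], [(2000, 2001)]]
relabeled, mapping = relabel_edges(graphs)
print(relabeled)  # [[(0, 1), (1, 2)], [(3, 4)]]
```

Because the ids are contiguous across graphs, the edge lists can be concatenated into a single batch without two graphs sharing a node index.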

In [None]:
def get_max_acc(accs, step, scores, min_step, test_step):
    step = np.argmax(scores[min_step - 1 : test_step]) + min_step - 1
    return accs[step]


def get_batch_number(databatch, i_batch, n_way, k_shot):
    """From a N batch takes the i-th batch"""
    dim_databatch = n_way * k_shot
    indices = torch.arange(0, config.BATCH_PER_EPISODES)
    return gdata.Batch.from_data_list(databatch[indices * dim_databatch + i_batch])


def glorot(tensor):
    """Apply the Glorot NN initialization (also called Xavier)"""
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    """Fill a tensor with zeros if it is not Null"""
    if tensor is not None:
        tensor.data.fill_(0)
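
For reference, the bound used by `glorot` is $a = \sqrt{6 / (\mathrm{fan\_in} + \mathrm{fan\_out})}$, and the weights are drawn uniformly from $[-a, a]$. The short sketch below evaluates it for a square weight matrix; the helper name `glorot_bound` and the 128 x 128 shape (matching `NHID = 128`) are chosen only for illustration.

```python
import math


def glorot_bound(fan_in: int, fan_out: int) -> float:
    """Half-width a of the uniform interval [-a, a] used by Glorot initialization."""
    return math.sqrt(6.0 / (fan_in + fan_out))


# Example: a 128 x 128 weight matrix
bound = glorot_bound(128, 128)
print(round(bound, 4))  # 0.1531
```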

## Dataset, Sampler and DataLoader

### The Dataset

I decided to use the same datasets considered in the paper for AS-MAML: TRIANGLES, COIL-DEL, R52 and Letter-High. All of them can be downloaded directly from this [page](https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets), which is the origin of these datasets. Downloading from the previous page will result in a ZIP file with: 

- `<dataname>_node_attributes.txt` with the attribute vector for each node of each graph
- `<dataname>_graph_labels.txt` with the class for each graph
- `<dataname>_graph_edges.txt` with the edges for each graph expressed as a pair (nodex, nodey)
- `<dataname>_graph_indicator.txt` that maps each node to its corresponding graph

Each dataset has been split into *train*, *test* and *validation* sets, converted into Python dictionaries, and saved as `.pickle` files, so that it is ready to use. Each ZIP archive contains four files:

- `<dataname>_node_attributes.pickle` with the node attributes saved as a List or a torch Tensor
- `<dataname>_train_set.pickle` with all the train data as python dictionaries
- `<dataname>_test_set.pickle` with all the test data as python dictionaries
- `<dataname>_val_set.pickle` with all the validation data as python dictionaries

These are the links from which you can download the datasets: [TRIANGLES](https://drive.google.com/drive/folders/1na8l6DV7qtYIoteFGIp9p7VfQNjmSQxx?usp=sharingwith), [COIL-DEL](https://drive.google.com/drive/folders/1Cq2quq4XNLL91WlwXgXVx3kH_h3_RL9_?usp=sharing), [R52](https://drive.google.com/drive/folders/1pjh1GHn733xb-msqmVP2voZ_IWKKiEYg?usp=sharing) and [Letter-High](https://cloud-storage.eu-central-1.linodeobjects.com/Letter-High.zip).

In [None]:
# An Example of dataset. In this case the TRIANGLES
dataset_name = "TRIANGLES"
download_folder = os.getcwd()

node_attribute, _, train_file, _ = download_zipped_data(
    DATASETS[dataset_name], 
    download_folder, 
    dataset_name
)

data_dir = "/".join(node_attribute.split("/")[:-2])

In [None]:
print("Node Attributes Filename --- ", node_attribute)
print("Train Set Filename --- ", train_file)

In [None]:
print("=== Node Attribute Content === ")

# Convert to torch.Tensor for a pretty printing
node_attribute_content = load_with_pickle(node_attribute)
if isinstance(node_attribute_content, list):
    node_attribute_content = torch.tensor(node_attribute_content)

print(node_attribute_content)

As mentioned above, each train (or test or validation) set is a Python dictionary with three keys: `label2graphs`, which maps each label to the list of its graphs; `graph2nodes`, which maps each graph to its nodes; and `graph2edges`, which maps each graph to its edges (a list of node pairs).

In [None]:
print("=== Train Set Content === ")

train_set_content = load_with_pickle(train_file)

print("Keys --- ", train_set_content.keys())

label2graphs = train_set_content["label2graphs"]
graph2nodes = train_set_content["graph2nodes"]
graph2edges = train_set_content["graph2edges"]

print("Label2graph example --- ", label2graphs[1])
print("Graph2nodes example --- ", graph2nodes[1])
print("Graph2edges example --- ", graph2edges[1])
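
To show how the three dictionaries fit together, here is a small sketch on made-up data. The dictionary `toy_split` and the helper `class_summary` are hypothetical; only the three key names come from the pickled files described above.

```python
from typing import Dict, Tuple

# Hypothetical toy split with two classes and three graphs
toy_split = {
    "label2graphs": {0: [1, 2], 1: [3]},
    "graph2nodes": {1: [1, 2, 3], 2: [4, 5], 3: [6, 7, 8]},
    "graph2edges": {1: [(1, 2), (2, 3)], 2: [(4, 5)], 3: [(6, 7), (7, 8), (6, 8)]},
}


def class_summary(split: dict) -> Dict[int, Tuple[int, int, int]]:
    """Return {label: (number of graphs, total nodes, total edges)}."""
    summary = {}
    for label, graph_ids in split["label2graphs"].items():
        n_nodes = sum(len(split["graph2nodes"][g]) for g in graph_ids)
        n_edges = sum(len(split["graph2edges"][g]) for g in graph_ids)
        summary[label] = (len(graph_ids), n_nodes, n_edges)
    return summary


print(class_summary(toy_split))  # {0: (2, 5, 3), 1: (1, 3, 3)}
```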

The way I handle the datasets differs from the one used in the AS-MAML paper. I represent a dataset with the class `GraphDataset`, which inherits from the base class `torch_geometric.data.Dataset`. It is an iterable class whose elements (the graphs) are of type `torch_geometric.data.Data`. That is, each graph is a `data = Data(x=..., edge_index=..., y=...)`, where `data.x` is a `torch.Tensor` (with dim $\mathtt{n\_attribute} \times 1$) holding the attribute vector of all nodes in the graph, `data.edge_index` is a `torch.Tensor` (with dim $2 \times \mathtt{n\_edges}$) holding the edges of the graph, and `data.y` is a one-element `torch.Tensor` holding the class of that graph.
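
The $2 \times \mathtt{n\_edges}$ layout stores source nodes in the first row and target nodes in the second. Since `GraphDataset.get` converts each graph with `to_directed()`, every undirected edge ends up stored in both directions. The sketch below reproduces that layout with plain lists; the helper name `to_edge_index` is hypothetical, and the column order may differ from the one networkx produces.

```python
from typing import List, Tuple


def to_edge_index(undirected_edges: List[Tuple[int, int]]) -> List[List[int]]:
    """Build a 2 x (2 * n_edges) nested list storing each edge in both directions."""
    sources, targets = [], []
    for u, v in undirected_edges:
        sources += [u, v]  # u -> v and v -> u
        targets += [v, u]
    return [sources, targets]


# A path graph 0 - 1 - 2
edge_index = to_edge_index([(0, 1), (1, 2)])
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```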

In [None]:
class GraphDataset(gdata.Dataset):
    def __init__(self, graphs_ds: Dict[str, Tuple[nx.Graph, str]]) -> None:
        super(GraphDataset, self).__init__()
        self.graphs_ds = graphs_ds

    @classmethod
    def get_dataset(cls, attributes: List[Any], data: Dict[str, Any]) -> 'GraphDataset':
        """
        Returns a new instance of GraphDataset filled with graphs inside data. 'attributes'
        is the list with all the attributes (not only those beloging to nodes in 'data').

        :param data: a dictionary with label2graphs, graph2nodes and graph2edges
        :param attributes: a list with node attributes
        :return: a new instance of GraphDataset
        """
        graphs = dict()

        label2graphs = data["label2graphs"]
        graph2nodes  = data["graph2nodes"]
        graph2edges  = data["graph2edges"]

        for label, graph_list in label2graphs.items():
            for graph_id in graph_list:
                graph_nodes = graph2nodes[graph_id]
                graph_edges = graph2edges[graph_id]
                nodes_attributes = [[attributes[node_id - 1]] for node_id in graph_nodes]
                nodes = []
                for node, attribute in zip(graph_nodes, nodes_attributes):
                    nodes.append((node, {f"attr{i}" : a for i, a in enumerate(attribute)}))

                g = nx.Graph()
                g.add_edges_from(graph_edges)
                g.add_nodes_from(nodes)
            
                graphs[graph_id] = (g, label)

        # Instantiate through the class itself (runs __new__ and __init__)
        graph_dataset = cls(graphs)

        return graph_dataset

    def __repr__(self) -> str:
        return f"GraphDataset(classes={set(self.targets().tolist())},n_graphs={self.len()})"

    def indices(self) -> List[str]:
        """ Return all the graph IDs """
        return list(self.graphs_ds.keys())

    def len(self) -> int:
        return len(self.graphs_ds.keys())

    def targets(self) -> torch.Tensor:
        """ Return all the labels """
        targets = []
        for _, graph in self.graphs_ds.items():
            targets.append(int(graph[1]))

        return torch.tensor(targets)

    def get(self, idx: Union[int, str]) -> gdata.Data:
        """ Return the graph with the given ID as a :obj:`torch_geometric.data.Data` object """
        if isinstance(idx, str):
            idx = int(idx)

        graph = self.graphs_ds[idx]
        g, label = graph[0].to_directed(), graph[1]

        # Retrieve nodes attributes
        attrs = list(g.nodes(data=True))
        x = torch.tensor([list(map(int, a.values())) for _, a in attrs], dtype=torch.float)

        # Retrieve edges
        edge_index = torch.tensor([list(e) for e in g.edges], dtype=torch.long) \
                          .t()                                                  \
                          .contiguous()

        # Retrieve ground truth labels
        y = torch.tensor([int(label)], dtype=torch.int)

        return gdata.Data(x=x, edge_index=edge_index, y=y)


def generate_train_val_test(dataset_name: str,
                            data_dir: Optional[str]=None, 
                            download: bool=True,
                            download_folder: str="../data"
) -> Tuple[GraphDataset, GraphDataset, GraphDataset, str]:
    """ Return the train, test and validation datasets, plus the data directory """
    logging.debug("--- Generating Train, Test and Validation datasets --- ")
    
    assert download or data_dir is not None, "At least one of data_dir and download must be given"

    node_attribute = None
    test_file = None
    train_file = None
    val_file = None

    if data_dir is not None:
        node_attribute = os.path.join(data_dir, f"{dataset_name}/{dataset_name}_node_attributes.pickle")
        test_file = os.path.join(data_dir, f"{dataset_name}/{dataset_name}_test_set.pickle")
        train_file = os.path.join(data_dir, f"{dataset_name}/{dataset_name}_train_set.pickle")
        val_file = os.path.join(data_dir, f"{dataset_name}/{dataset_name}_val_set.pickle")

    if download:
        node_attribute, test_file, train_file, val_file = download_zipped_data(
            DATASETS[dataset_name], 
            download_folder, 
            dataset_name
        )

        data_dir = "/".join(node_attribute.split("/")[:-2])

    node_attribute_data = load_with_pickle(node_attribute)
    test_data = load_with_pickle(test_file)
    train_data = load_with_pickle(train_file)
    val_data = load_with_pickle(val_file)

    train_ds = GraphDataset.get_dataset(node_attribute_data, train_data)
    test_ds  = GraphDataset.get_dataset(node_attribute_data,  test_data)
    val_ds   = GraphDataset.get_dataset(node_attribute_data,   val_data)

    return train_ds, test_ds, val_ds, data_dir

### The Samplers

Since we need a specific way to sample from the dataset, namely N-way-K-shot sampling (for both the support and the query set), I created two samplers: `FewShotSampler` and `TaskBatchSampler`, both inheriting from `torch.utils.data.Sampler`. The former returns a list of indices identifying the graphs that belong to a single N-way-K-shot sample. The latter iteratively draws from `FewShotSampler` and builds mini-batches containing the number of N-way-K-shot samples the user wants in a single batch. The `TaskBatchSampler` is passed as the `batch_sampler` argument of the DataLoader.

In [None]:
class FewShotSampler(torch.utils.data.Sampler):
    """
    In few-shot classification, and in particular in Meta-Learning, 
    we use a specific way of sampling batches from the training/val/test 
    set. This way is called N-way-K-shot, where N is the number of classes 
    to sample per batch and K is the number of examples to sample per class 
    in the batch. The sample batch on which we train our model is also called 
    `support` set, while the one on which we test is called `query` set.

    This class is an N-way-K-shot sampler that will be used as a batch_sampler
    for the :obj:`torch_geometric.loader.DataLoader` dataloader. This sampler
    returns batches of indices that correspond to support and query set batches.

    Attributes:
        labels: PyTorch tensor of the labels of the data elements
        n_way: Number of classes to sample per batch
        k_shot: Number of examples to sample per class in the batch
        n_query: Number of query examples to sample per class in the batch
        shuffle: If True, the examples of each class are shuffled at each iteration
        indices_per_class: Mapping from each class to the indices of its examples
        classes: list of all classes
        epoch_size: number of batches per epoch
    """

    def __init__(self, labels: torch.Tensor,
                 n_way: int,
                 k_shot: int,
                 n_query: int,
                 epoch_size: int,
                 shuffle: bool = True) -> None:
        super().__init__(None)
        self.labels = labels
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query
        self.shuffle = shuffle
        self.epoch_size = epoch_size

        self.classes = torch.unique(self.labels).tolist()
        self.indices_per_class = dict()
        for cl in self.classes:
            self.indices_per_class[cl] = torch.where(self.labels == cl)[0]

    def shuffle_data(self) -> None:
        """Shuffle, in place, the example indices of each class"""
        for cl in self.classes:
            perm = torch.randperm(self.indices_per_class[cl].shape[0])
            self.indices_per_class[cl] = self.indices_per_class[cl][perm]

    def __iter__(self) -> List[torch.Tensor]:
        # Shuffle the data
        if self.shuffle:
            self.shuffle_data()

        target_classes = random.sample(self.classes, self.n_way)
        for _ in range(self.epoch_size):
            n_way_k_shot_n_query = []
            for cl in target_classes:
                labels_per_class = self.indices_per_class[cl]
                assert len(labels_per_class) >= self.k_shot + self.n_query
                selected_data = random.sample(
                    labels_per_class.tolist(), self.k_shot + self.n_query)
                n_way_k_shot_n_query.append(selected_data)

            yield torch.tensor(n_way_k_shot_n_query)

    def __len__(self) -> int:
        return self.epoch_size


class TaskBatchSampler(torch.utils.data.Sampler):
    """Sample a batch of tasks"""

    def __init__(self, dataset_targets: torch.Tensor,
                       batch_size: int,
                       n_way: int,
                       k_shot: int,
                       n_query: int,
                       epoch_size: int,
                       shuffle: bool = True) -> None:

        super().__init__(None)
        self.task_sampler = FewShotSampler(
            dataset_targets,
            n_way=n_way,
            k_shot=k_shot,
            n_query=n_query,
            epoch_size=epoch_size,
            shuffle=shuffle
        )

        self.task_batch_size = batch_size

    def __iter__(self):
        mini_batches = []
        for task_idx, task in enumerate(self.task_sampler):
            mini_batches.extend(task.tolist())
            if (task_idx + 1) % self.task_batch_size == 0:
                yield torch.tensor(mini_batches).flatten().tolist()
                mini_batches = []

    def __len__(self):
        return len(self.task_sampler) // self.task_batch_size

    def uncollate(self, data_batch):
        """Invoke the uncollate from utils.utils"""
        return task_sampler_uncollate(self, data_batch)
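
The index arithmetic used by `uncollate` is easier to follow on a flat list of integers. The sketch below is an illustration under the layout described in the docstring of `task_sampler_uncollate` (task, then class, then K support items followed by Q query items); the function name `split_support_query` is hypothetical and it works on plain lists, not on `Data` objects.

```python
from typing import List, Tuple


def split_support_query(
    items: List[int], n_way: int, k_shot: int, n_query: int, n_tasks: int
) -> Tuple[List[int], List[int]]:
    """Split a flat list laid out as task -> class -> (K support, Q query)."""
    per_class = k_shot + n_query
    per_task = n_way * per_class
    assert len(items) == n_tasks * per_task
    support, query = [], []
    for t in range(n_tasks):
        for c in range(n_way):
            # First item of class c inside task t
            start = t * per_task + c * per_class
            support += items[start : start + k_shot]
            query += items[start + k_shot : start + per_class]
    return support, query


# 2 tasks, 2 classes, 1 support and 2 query items per class: 12 items in total
support, query = split_support_query(list(range(12)), n_way=2, k_shot=1, n_query=2, n_tasks=2)
print(support)  # [0, 3, 6, 9]
print(query)    # [1, 2, 4, 5, 7, 8, 10, 11]
```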

### The DataLoader

In this case each element of the dataset is a `torch_geometric.data.Data` and not just a `torch.Tensor`. For this reason, I created a simple custom dataloader called `FewShotDataLoader` that inherits from `torch.utils.data.DataLoader`. There is another problem: `GraphDataset` is not a type known to the default *collate* function of PyTorch or PyTorch-Geometric. So, I wrote my own collater, called `GraphCollater`, to handle this situation.

In [None]:
class GraphCollater(gloader.dataloader.Collater):
    """A Collater to handle batches of GraphDataset instances"""
    def __init__(self, *args) -> None:
        super(GraphCollater, self).__init__(*args)

    def __call__(self, batch: Generic[T]) -> Generic[T]:
        elem = batch[0]

        # All elements inside the batch are just
        # repetitions of the first element, so
        # we can keep only the first one
        if isinstance(elem, GraphDataset):
            return self(elem)

        return super(GraphCollater, self).__call__(batch)


class FewShotDataLoader(torch.utils.data.DataLoader):
    """Custom DataLoader for GraphDataset"""

    def __init__(self, dataset: GraphDataset,
                 batch_size: int = 1,
                 shuffle: bool = False,
                 follow_batch: Optional[List[str]] = None,
                 exclude_keys: Optional[List[str]] = None,
                 **kwargs) -> None:

        if 'collate_fn' in kwargs:
            del kwargs["collate_fn"]

        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        # Take the batch sampler
        self.batch_sampler = kwargs["batch_sampler"]

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=GraphCollater(follow_batch, exclude_keys),
            **kwargs,
        )

    def __iter__(self):
        for x in super().__iter__():
            support_batch, query_batch = self.batch_sampler.uncollate(x)
            yield support_batch, query_batch