# Few-Shot Graph Classification

Most graph classification methods overlook the fact that, in many situations, labeled graphs are scarce. *Few-Shot Learning* has started being used to overcome this problem. It is a type of Machine Learning method in which the training dataset contains only limited information. The usual practice is to feed a machine learning model as much data as possible, since this leads to better predictions; few-shot learning, instead, aims to build accurate machine learning models with less training data. Few-Shot Learning, and in this case Few-Shot classification in particular, aims to reduce the cost of collecting and labeling a huge amount of data.

*What is the idea behind Few-Shot Learning* (on graphs)? Given graph data $\mathcal{G} = \{(G_1, \mathbf{y}_1), ..., (G_n, \mathbf{y}_n)\}$, we split it into a train dataset, $\{(G^{train}, \mathbf{y}^{train})\}$, and a test dataset, $\{(G^{test}, \mathbf{y}^{test})\}$. Notice that $\mathbf{y}^{train}$ and $\mathbf{y}^{test}$ must have no classes in common. Training uses the episodic method: at each training step the algorithm samples a so-called *Task*, i.e., a pair (*support* set, *query* set), where the support set is $D_{sup}^{train} = \{(G_i^{train}, \mathbf{y}_{i}^{train})\}_{i=1}^s$ with $s = N \times K$, and the query set is $D_{que}^{train} = \{(G_i^{train}, \mathbf{y}_{i}^{train})\}_{i=1}^q$, where $q$ is the number of query graphs. Given the labeled support data, the goal is to predict the labels of the query data. Note that within a single task, support data and query data share the same class space. This setting is also called **N-way-K-shot** learning, where **N** is the number of sampled classes and **K** is the number of samples for each of the N classes. At test stage, when performing classification tasks on unseen classes, we first fine-tune the meta-learner on the support data of the test classes, then report classification performance on the test query set.
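The episodic sampling just described can be sketched in plain Python. This is a minimal illustration, not the AS-MAML implementation: `sample_task` is a hypothetical helper, the `label2graphs` dictionary (label to list of graph IDs) mirrors the format used later in this notebook, and the toy IDs are made up. With $Q$ query graphs per class, the query set has $q = N \times Q$ elements.

```python
import random
from typing import Dict, List, Tuple


def sample_task(label2graphs: Dict[int, List[int]], n_way: int, k_shot: int,
                n_query: int, rng: random.Random
                ) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
    """Sample one N-way-K-shot task as (support, query) lists of (graph_id, label)."""
    # Pick N distinct classes for this task
    classes = rng.sample(sorted(label2graphs), n_way)
    support, query = [], []
    for label in classes:
        # Draw K + Q distinct graphs of this class, without replacement
        graphs = rng.sample(label2graphs[label], k_shot + n_query)
        support += [(g, label) for g in graphs[:k_shot]]
        query += [(g, label) for g in graphs[k_shot:]]
    return support, query


# Toy data (hypothetical IDs): 4 classes with 20 graphs each
toy = {c: list(range(c * 100, c * 100 + 20)) for c in range(4)}
support, query = sample_task(toy, n_way=3, k_shot=2, n_query=3, rng=random.Random(0))
print(len(support), len(query))  # 6 9
```

Support and query graphs of a task never overlap, but they always share the same N classes.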

In the following, I present some approaches to few-shot learning. First, a *Meta-Learning Framework* based on Fast Weight Adaptation, taken from the paper [Adaptive-Step Graph Meta-Learner for Few-Shot Graph Classification](https://arxiv.org/pdf/2003.08246.pdf) (Ning Ma et al.). Second, I compare it with different GDA (graph data augmentation) techniques used to enrich the dataset for the novel classes (i.e., those with less data), taken from a second paper, [Graph Data Augmentation for Graph Machine Learning: A Survey](https://arxiv.org/pdf/2202.08871.pdf) (Tong Zhao et al.).

---

## Modules and Constants

In [None]:
import torch
TORCH = torch.__version__.split('+')[0]
CUDA = 'cu' + torch.version.cuda.replace('.', '') if torch.version.cuda is not None else 'cpu'  # CPU-only builds have no CUDA version

!pip install pytorch-lightning
!pip install pyyaml==5.4.1
!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric

!mkdir -p models

In [69]:
from typing import (
    Any, Dict, List, Tuple, 
    Union, Generic, Optional,
    TypeVar
)

from tqdm.notebook import tqdm
from functools import wraps
import plotly.graph_objects as go
import networkx as nx
import numpy as np
import pickle
import os
import shutil
import logging
import random
import time
import requests
import zipfile
import math
import sys

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter

import torch_geometric.data as gdata
import torch_geometric.loader as gloader
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn import global_mean_pool, global_max_pool
from torch_geometric.nn.inits import uniform
from torch_geometric.nn.pool.topk_pool import topk, filter_adj
from torch_geometric.nn.pool.sag_pool import SAGPooling
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.utils import (
    add_remaining_self_loops, 
    add_self_loops, 
    remove_self_loops,
    softmax
)

from torch_scatter import scatter_add

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

In [98]:
TRIANGLES_ZIP_URL = "https://cloud-storage.eu-central-1.linodeobjects.com/TRIANGLES.zip"
COIL_DEL_ZIP_URL = "https://cloud-storage.eu-central-1.linodeobjects.com/COIL-DEL.zip"
R52_ZIP_URL = "https://cloud-storage.eu-central-1.linodeobjects.com/R52.zip"
LETTER_HIGH_ZIP_URL = "https://cloud-storage.eu-central-1.linodeobjects.com/Letter-High.zip"

DATASETS = {
    "TRIANGLES"   : TRIANGLES_ZIP_URL, 
    "COIL-DEL"    : COIL_DEL_ZIP_URL, 
    "R52"         : R52_ZIP_URL, 
    "Letter-High" : LETTER_HIGH_ZIP_URL
}

DEFAULT_DATASET = "TRIANGLES"

T = TypeVar('T')

DEVICE = "cpu"
DOWNLOAD_DATASET = False
SAVE_PICKLE  = True
EDGELIMIT_PRINT = 2000
SAVE_PRETRAINED = True
DATA_PATH = os.path.abspath(os.getcwd()) if not DOWNLOAD_DATASET else None
MODELS_SAVE_PATH = "models"

NUM_FEATURES = {"TRIANGLES": 1, "R52": 1, "Letter-High": 2, "COIL-DEL": 2}


class ASMAMLConfig:
    NHID = 128
    POOLING_RATIO = 0.5
    DROPOUT_RATIO = 0.3

    OUTER_LR     = 0.001
    INNER_LR     = 0.01
    STOP_LR      = 0.0001
    WEIGHT_DECAY = 1E-05

    MAX_STEP      = 15
    MIN_STEP      = 5
    STEP_TEST     = 15
    FLEXIBLE_STEP = True
    STEP_PENALITY = 0.001
    USE_SCORE     = True
    USE_GRAD      = False
    USE_LOSS      = True

    TRAIN_SHOT         = 10   # K-shot for training set
    VAL_SHOT           = 10   # K-shot for validation (or test) set
    TRAIN_QUERY        = 15   # Number of query for the training set
    VAL_QUERY          = 15   # Number of query for the validation (or test) set
    TRAIN_WAY          = 3    # N-way for training set
    TEST_WAY           = 3    # N-way for test set
    VAL_EPISODE        = 200  # Number of episodes for validation
    TRAIN_EPISODE      = 200  # Number of episodes for training
    BATCH_PER_EPISODES = 5    # Number of batches per episode
    EPOCHS             = 500  # Number of training epochs
    PATIENCE           = 35
    GRAD_CLIP          = 5

    # Stop Control configurations
    STOP_CONTROL_INPUT_SIZE = 2
    STOP_CONTROL_HIDDEN_SIZE = 20
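
As a quick sanity check on these settings, the sizes of one training episode follow directly from the N-way-K-shot definition. The sketch below only redoes that arithmetic with the values copied from `ASMAMLConfig`; it assumes (this is not stated explicitly in the config) that `BATCH_PER_EPISODES` is the number of tasks collated into one batch.

```python
# Values copied from ASMAMLConfig above
TRAIN_WAY, TRAIN_SHOT, TRAIN_QUERY, BATCH_PER_EPISODES = 3, 10, 15, 5

support_per_task = TRAIN_WAY * TRAIN_SHOT       # s = N x K
query_per_task = TRAIN_WAY * TRAIN_QUERY        # q = N x Q
graphs_per_task = support_per_task + query_per_task
graphs_per_batch = BATCH_PER_EPISODES * graphs_per_task  # assumes 5 tasks per batch
min_graphs_per_class = TRAIN_SHOT + TRAIN_QUERY  # each sampled class needs K + Q graphs

print(support_per_task, query_per_task, graphs_per_batch, min_graphs_per_class)
# 30 45 375 25
```

The last value matters for the dataset: a class with fewer than 25 graphs cannot be used in a training task with these settings.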

---

## Utility Functions

In [71]:
def scandir(root_path: str) -> List[str]:
    """Recursively scan a directory looking for files"""
    root_path = os.path.abspath(root_path)
    content = []
    for file in os.listdir(root_path):
        new_path = os.path.join(root_path, file)
        if os.path.isfile(new_path):
            content.append(new_path)
            continue
        
        content += scandir(new_path)
    
    return content


def download_zipped_data(url: str, path2extract: str, dataset_name: str) -> List[str]:
    """Download a ZIP file from URL and extract it. Return the sorted list of extracted file paths"""
    print(f"--- Downloading from {url} ---")
    response = requests.get(url)
    response.raise_for_status()  # Fail early on HTTP errors instead of writing an invalid ZIP

    abs_path2extract = os.path.abspath(path2extract)
    zip_path = os.path.join(abs_path2extract, f"{dataset_name}.zip")
    with open(zip_path, mode="wb") as iofile:
        iofile.write(response.content)

    # Extract the file
    print("--- Extracting files from the archive ---")
    with zipfile.ZipFile(zip_path, mode="r") as zip_ref:
        zip_ref.extractall(abs_path2extract)

    print(f"--- Removing {zip_path} ---")
    os.remove(zip_path)

    return sorted(scandir(os.path.join(path2extract, dataset_name)))


def delete_data_folder(path2delete: str) -> None:
    """Delete the folder containing data"""
    print("--- Removing Content Data ---")
    shutil.rmtree(path2delete)
    print("--- Removal Finished Successfully ---")

In [72]:
def elapsed_time(func):
    """Just a simple wrapper that prints the elapsed time from start to end of a call"""
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        print("Elapsed Time: {:.6f}".format(end - start))
        return result  # Propagate the wrapped function's return value
    
    return wrapper

In [73]:
def save_with_pickle(path2save: str, content: Any) -> None:
    """Save content inside a .pickle file denoted by path2save"""
    path2save = path2save if path2save.endswith(".pickle") else path2save + ".pickle"
    with open(path2save, mode="wb") as iostream:
        pickle.dump(content, iostream)


def load_with_pickle(path2load: str) -> Any:
    """Load a content from a .pickle file"""
    with open(path2load, mode="rb") as iostream:
        return pickle.load(iostream)

In [74]:
def setup_seed(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True

In [75]:
def plot_graph(G : Union[nx.Graph, nx.DiGraph], name: str) -> None:
    """
    Plot a graph
    
    Parameters
    ----------
    G : Union[nx.Graph, nx.DiGraph]
        The networkx graph to plot
    name  : str
        The name of the graph
        
    Returns
    -------
    None
    """
    # Getting the 3D Spring layout
    layout = nx.spring_layout(G, dim=3, seed=18)
    
    # Getting nodes coordinate
    x_nodes = [layout[i][0] for i in layout]  # x-coordinates of nodes
    y_nodes = [layout[i][1] for i in layout]  # y-coordinates of nodes
    z_nodes = [layout[i][2] for i in layout]  # z-coordinates of nodes
    
    # Getting a list of edges and create a list with coordinates
    elist = G.edges()
    x_edges, y_edges, z_edges = [], [], []
    for edge in elist:
        x_edges += [layout[edge[0]][0], layout[edge[1]][0], None]
        y_edges += [layout[edge[0]][1], layout[edge[1]][1], None]
        z_edges += [layout[edge[0]][2], layout[edge[1]][2], None]

    colors = np.arange(len(x_nodes))  # One color value per node
        
    # Create a trace for the edges
    etrace = go.Scatter3d(x=x_edges,
                          y=y_edges,
                          z=z_edges,
                          mode='lines',
                          line=dict(color='rgb(125,125,125)', width=1),
                          hoverinfo='none'
                         )
    
    # Create a trace for the nodes
    ntrace = go.Scatter3d(x=x_nodes,
                          y=y_nodes,
                          z=z_nodes,
                          mode='markers',
                          marker=dict(
                              symbol='circle',
                              size=6,
                              color=colors,
                              colorscale='Viridis',
                              line=dict(color='rgb(50,50,50)', width=.5)),
                          text=list(layout.keys()),
                          hoverinfo='text'
                         )
    
    # Set the axis
    axis = dict(showbackground=False,
                showline=False,
                zeroline=False,
                showgrid=False,
                showticklabels=False,
                title='')
    
    # Create a layout for the plot
    go_layout = go.Layout(title=f"{name} Network Graph",
                          width=600,
                          height=600,
                          showlegend=False,
                          scene=dict(xaxis=dict(axis),
                                     yaxis=dict(axis),
                                     zaxis=dict(axis)),
                          margin=dict(t=100),
                          hovermode='closest'
                         )
    
    # Plot
    data = [etrace, ntrace]
    fig = go.Figure(data=data, layout=go_layout)
    fig.show()

In [76]:
def rename_edge_indexes(data_list: List[gdata.Data]) -> List[gdata.Data]:
    """
    Takes as input a bunch of :obj:`torch_geometric.data.Data` and renames
    each node appearing in the edges (x, y) with an ID from 0 to the total number
    of nodes minus one. For instance, edge_index = [[1234, 1235, 1236, 1237],
    [1238, 1239, 1240, 1241]] becomes edge_index = [[0, 1, 2, 3], [4, 5, 6, 7]], and so on.

    :param data_list: the list of :obj:`torch_geometric.data.Data`
    :return: a new list of data
    """
    # First of all let's compute the total number of nodes overall
    total_number_nodes = 0
    for data in data_list:
        total_number_nodes += data.x.shape[0]
    
    # Generate the new nodes
    nodes = torch.arange(0, total_number_nodes)
    
    # Takes the old nodes from the edge_index attribute
    old_nodes = None
    for data in data_list:
        x, y = data.edge_index
        x = torch.hstack((x, y)).unique(sorted=False)
        
        if old_nodes is None:
            old_nodes = x
            continue
    
        old_nodes = torch.hstack((old_nodes, x))
    
    # Create mapping from old to new nodes
    mapping = dict(zip(old_nodes.tolist(), nodes.tolist()[:old_nodes.shape[0]]))
    
    # Finally, map the new nodes
    for data in data_list:
        x, y = data.edge_index
        new_x = torch.tensor(list(map(lambda x: mapping[x], x.tolist())), dtype=x.dtype, device=x.device)
        new_y = torch.tensor(list(map(lambda y: mapping[y], y.tolist())), dtype=y.dtype, device=y.device)
        new_edge_index = torch.vstack((new_x, new_y))
        data.edge_index = new_edge_index
    
    return data_list


def data_batch_collate(data_list: List[gdata.Data]) -> gdata.Data:
    """
    Takes as input a list of data and create a new :obj:`torch_geometric.data.Data`
    collating all together. This is a replacement for torch_geometric.data.Batch.from_data_list

    :param data_list: a list of torch_geometric.data.Data objects
    :return: a new torch_geometric.data.Data object
    """
    x = None
    edge_index = None
    batch = []
    num_graphs = 0
    y = None
    
    for i_data, data in enumerate(data_list):
        x = data.x if x is None else torch.vstack((x, data.x))
        edge_index = data.edge_index if edge_index is None else torch.hstack((edge_index, data.edge_index))
        batch += [i_data] * data.x.shape[0]
        num_graphs += 1
        y = data.y if y is None else torch.hstack((y, data.y))

    # Create a mapping between y and a range(0, num_classes_of_y)
    # First we need to compute how many classes do we have
    num_classes = y.unique().shape[0]
    classes = list(range(0, num_classes))
    mapping = dict(zip(y.unique(sorted=False).tolist(), classes))
    
    # This mapping is necessary when computing the cross-entropy loss,
    # which also expects long (int64) class indices as targets
    new_y = torch.tensor(list(map(lambda x: mapping[x], y.tolist())), dtype=torch.long, device=y.device)
    
    data_batch = gdata.Data(
        x=x, edge_index=edge_index, batch=torch.tensor(batch),
        y=new_y, num_graphs=num_graphs, old_classes_mapping=mapping
    )

    return data_batch


def task_sampler_uncollate(task_sampler: 'TaskBatchSampler', data_batch: gdata.Batch):
    """
    Takes as input the task sampler and a batch containing both the 
    support and the query set. It returns two different DataBatch
    respectively for support and query_set.

    Assume L = [x1, x2, x3, ..., xN] is the data_batch
    each xi is a graph. Moreover, we have that
    L[0:K] = support samples for the first class
    L[K:K+Q] = query samples for the first class
    In general, we have that 

            L[i * (K + Q) : (i + 1) * (K + Q)]

    is the (support, query) pair for the i-th class
    Finally, the first batch is the one that goes from
    L[0 : N * (K + Q)], so

            L[i * N * (K + Q) : (i + 1) * N * (K + Q)]

    is the i-th batch.

    :param task_sampler: The task sampler
    :param data_batch: a batch with support and query set
    :return: support batch, query batch
    """
    n_way = task_sampler.task_sampler.n_way
    k_shot = task_sampler.task_sampler.k_shot
    n_query = task_sampler.task_sampler.n_query
    task_batch_size = task_sampler.task_batch_size

    total_support_query_number = n_way * (k_shot + n_query)
    support_plus_query = k_shot + n_query

    # Initialize batch list for support and query set
    support_data_batch = []
    query_data_batch = []

    # Loop over the tasks (batches) collated together
    for batch_number in range(task_batch_size):

        # and over the N classes of each task
        for class_number in range(n_way):

            # First of all let's take the i-th batch
            data_batch_slice = slice(
                batch_number * total_support_query_number,
                (batch_number + 1) * total_support_query_number
            )
            data_batch_per_batch = data_batch[data_batch_slice]

            # Then let's take the (support, query) pair for a class
            support_query_slice = slice(
                class_number * support_plus_query,
                (class_number + 1) * support_plus_query
            )
            support_query_data = data_batch_per_batch[support_query_slice]

            # Divide support from query
            support_data = support_query_data[:k_shot]
            query_data = support_query_data[k_shot:support_plus_query]

            support_data_batch += support_data
            query_data_batch += query_data
    
    # Rename the edges
    support_data = data_batch_collate(rename_edge_indexes(support_data_batch))
    query_data   = data_batch_collate(rename_edge_indexes(query_data_batch))

    # Create new DataBatchs and return
    return support_data, query_data
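
The index layout described in the docstring above can be checked without any graph objects. The sketch below uses a hypothetical `split_support_query` helper that works on plain integer positions instead of `Data` objects, and applies the same two-level slicing: first per task, then per class, keeping the first K items of each class block as support and the remaining Q as query.

```python
from typing import List, Tuple


def split_support_query(flat: List[int], n_way: int, k_shot: int, n_query: int,
                        task_batch_size: int) -> Tuple[List[int], List[int]]:
    """Split a flat batch laid out as [task][class][K support + Q query]."""
    per_class = k_shot + n_query
    per_task = n_way * per_class
    support, query = [], []
    for b in range(task_batch_size):
        # Slice of the b-th task
        task = flat[b * per_task:(b + 1) * per_task]
        for c in range(n_way):
            # (support, query) block of the c-th class inside the task
            block = task[c * per_class:(c + 1) * per_class]
            support += block[:k_shot]
            query += block[k_shot:]
    return support, query


# Positions stand in for graphs: B=2 tasks, N=3, K=2, Q=1 -> 18 items
flat = list(range(2 * 3 * (2 + 1)))
s, q = split_support_query(flat, n_way=3, k_shot=2, n_query=1, task_batch_size=2)
print(s)  # [0, 1, 3, 4, 6, 7, 9, 10, 12, 13, 15, 16]
print(q)  # [2, 5, 8, 11, 14, 17]
```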

In [77]:
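def get_max_acc(accs, step, scores, min_step, test_step):
    """Return the accuracy at the step with the highest score between min_step and test_step.
    Note: the `step` argument is ignored and recomputed here."""
    step = np.argmax(scores[min_step - 1 : test_step]) + min_step - 1
    return accs[step]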


def get_batch_number(databatch, i_batch, n_way, k_shot):
    """From a N batch takes the i-th batch"""
    dim_databatch = n_way * k_shot
    indices = torch.arange(0, ASMAMLConfig.BATCH_PER_EPISODES)
    return gdata.Batch.from_data_list(databatch[indices * dim_databatch + i_batch])


def glorot(tensor):
    """Apply the Glorot NN initialization (also called Xavier)"""
    if tensor is not None:
        stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
        tensor.data.uniform_(-stdv, stdv)


def zeros(tensor):
    """Fill a tensor with zeros if it is not None"""
    if tensor is not None:
        tensor.data.fill_(0)

---

## Dataset, Sampler and DataLoader

### The Dataset

I decided to use the same datasets considered in the AS-MAML paper: TRIANGLES, COIL-DEL, R52 and Letter-High. All of them can be downloaded directly from this [page](https://ls11-www.cs.tu-dortmund.de/staff/morris/graphkerneldatasets), which is the origin of these datasets. Downloading from that page results in a ZIP file with:

- `<dataname>_node_attributes.txt` with the attribute vector for each node of each graph
- `<dataname>_graph_labels.txt` with the class for each graph
- `<dataname>_graph_edges.txt` with the edges for each graph expressed as a pair (nodex, nodey)
- `<dataname>_graph_indicator.txt` that maps each nodes to its corresponding graph

Each dataset has been split into *train*, *test* and *validation* sets, transformed into Python dictionaries and finally saved as `.pickle` files. In this way we have a ready-to-use dataset. Moreover, each ZIP dataset contains four files:

- `<dataname>_node_attributes.pickle` with the node attributes saved as a List or a torch Tensor
- `<dataname>_train_set.pickle` with all the train data as python dictionaries
- `<dataname>_test_set.pickle` with all the test data as python dictionaries
- `<dataname>_val_set.pickle` with all the validation data as python dictionaries
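
The conversion from the raw text files to the dictionaries stored in the `.pickle` files is not shown in this notebook. The following is a minimal sketch of how it could be done, assuming the raw files have already been parsed into Python lists, that node `i` belongs to graph `graph_indicator[i]`, and using 0-based IDs (the original text files may be 1-based). `build_split_dict` is a hypothetical name, and the split of classes into train/test/validation is not shown.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def build_split_dict(graph_indicator: List[int], graph_labels: Dict[int, int],
                     edges: List[Tuple[int, int]]) -> Dict[str, Dict]:
    """Group raw node/edge/label lists into label2graphs, graph2nodes, graph2edges."""
    graph2nodes: Dict[int, List[int]] = defaultdict(list)
    for node_id, graph_id in enumerate(graph_indicator):
        graph2nodes[graph_id].append(node_id)

    graph2edges: Dict[int, List[List[int]]] = defaultdict(list)
    for u, v in edges:
        # Both endpoints belong to the same graph, so the source node decides
        graph2edges[graph_indicator[u]].append([u, v])

    label2graphs: Dict[int, List[int]] = defaultdict(list)
    for graph_id, label in sorted(graph_labels.items()):
        label2graphs[label].append(graph_id)

    return {"label2graphs": dict(label2graphs),
            "graph2nodes": dict(graph2nodes),
            "graph2edges": dict(graph2edges)}


# Toy input: nodes 0-2 form graph 0, nodes 3-4 form graph 1
split = build_split_dict(
    graph_indicator=[0, 0, 0, 1, 1],
    graph_labels={0: 7, 1: 3},
    edges=[(0, 1), (1, 2), (3, 4)],
)
print(split["label2graphs"])  # {7: [0], 3: [1]}
print(split["graph2edges"])   # {0: [[0, 1], [1, 2]], 1: [[3, 4]]}
```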

These are the links from which you can download the datasets: [TRIANGLES](https://drive.google.com/drive/folders/1na8l6DV7qtYIoteFGIp9p7VfQNjmSQxx?usp=sharingwith), [COIL-DEL](https://drive.google.com/drive/folders/1Cq2quq4XNLL91WlwXgXVx3kH_h3_RL9_?usp=sharing), [R52](https://drive.google.com/drive/folders/1pjh1GHn733xb-msqmVP2voZ_IWKKiEYg?usp=sharing) and [Letter-High](https://cloud-storage.eu-central-1.linodeobjects.com/Letter-High.zip).

In [78]:
# An Example of dataset. In this case the TRIANGLES
dataset_name = "TRIANGLES"
download_folder = os.getcwd()

node_attribute, _, train_file, _ = download_zipped_data(
    DATASETS[dataset_name], 
    download_folder, 
    dataset_name
)

data_dir = "/".join(node_attribute.split("/")[:-2])

--- Downloading from https://cloud-storage.eu-central-1.linodeobjects.com/TRIANGLES.zip ---
--- Extracting files from the archive ---
--- Removing /content/TRIANGLES.zip ---


In [79]:
print("Node Attributes Filename --- ", node_attribute)
print("Train Set Filename --- ", train_file)

Node Attributes Filename ---  /content/TRIANGLES/TRIANGLES_node_attributes.pickle
Train Set Filename ---  /content/TRIANGLES/TRIANGLES_train_set.pickle


In [80]:
print("=== Node Attribute Content === ")

# Convert to torch.Tensor for a pretty printing
node_attribute_content = load_with_pickle(node_attribute)
if isinstance(node_attribute_content, list):
    node_attribute_content = torch.tensor(list(map(int, node_attribute_content))).long()

print(node_attribute_content)

=== Node Attribute Content === 
tensor([4, 3, 2,  ..., 2, 3, 2])


As I said, the train (or test or validation) set is a Python dictionary with three keys: `label2graphs`, mapping each label to the list of its graphs; `graph2nodes`, mapping each graph to its nodes; and `graph2edges`, mapping each graph to its edges (a list of node pairs).
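Since every task needs K support and Q query graphs per class, a useful check on `label2graphs` is how many graphs each class has. The sketch below uses a made-up toy dictionary and a hypothetical helper `classes_usable_for_episodes`; in the notebook one would pass `train_set_content["label2graphs"]` instead. The same condition is asserted per class by the `FewShotSampler` defined later.

```python
from typing import Dict, List


def classes_usable_for_episodes(label2graphs: Dict[int, List[int]],
                                k_shot: int, n_query: int) -> List[int]:
    """Return, sorted, the labels that have at least K + Q graphs."""
    return sorted(label for label, graphs in label2graphs.items()
                  if len(graphs) >= k_shot + n_query)


# Toy dictionary with made-up graph IDs
toy_label2graphs = {1: list(range(30)), 3: list(range(30, 50)), 4: list(range(50, 90))}
print({label: len(g) for label, g in toy_label2graphs.items()})  # {1: 30, 3: 20, 4: 40}
print(classes_usable_for_episodes(toy_label2graphs, k_shot=10, n_query=15))  # [1, 4]
```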

In [81]:
print("=== Train Set Content === ")

train_set_content = load_with_pickle(train_file)

print("Keys --- ", train_set_content.keys())

label2graphs = train_set_content["label2graphs"]
graph2nodes = train_set_content["graph2nodes"]
graph2edges = train_set_content["graph2edges"]

print("Label2graph example --- ", label2graphs[1])
print("Graph2nodes example --- ", graph2nodes[1])
print("Graph2edges example --- ", graph2edges[1])

=== Train Set Content === 
Keys ---  dict_keys(['label2graphs', 'graph2nodes', 'graph2edges'])
Label2graph example ---  [0, 1, 2, 3, 4, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19, 20, 21, 22, 25, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 41, 42, 44, 45, 46, 47, 48, 49, 51, 53, 54, 56, 57, 59, 61, 62, 64, 65, 66, 67, 68, 69, 70, 71, 72, 74, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 95, 96, 98, 99, 100, 101, 102, 104, 105, 106, 108, 109, 110, 111, 113, 114, 115, 118, 119, 120, 122, 123, 124, 125, 126, 127, 128, 130, 131, 132, 133, 134, 136, 138, 140, 141, 142, 143, 144, 145, 146, 147, 149, 150, 151, 152, 153, 154, 155, 157, 159, 160, 161, 163, 164, 167, 168, 169, 170, 172, 173, 174, 175, 176, 177, 178, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 191, 192, 193, 194, 195, 196, 197, 198, 200]
Graph2nodes example ---  [21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32]
Graph2edges example ---  [[21, 22], [21, 23], [21, 24], [21, 25], [26, 27], [26, 25], [2

The way I handled datasets differs from the one used in the AS-MAML paper. I decided to represent the dataset in Python with the class `GraphDataset`, which inherits properties and methods from the base class `torch_geometric.data.Dataset`. It is an iterable class and each element (each graph) is of type `torch_geometric.data.Data`. That is, each graph is a `data = Data(x=..., edge_index=..., y=...)`, where `data.x` is a `torch.Tensor` of shape $\mathtt{n\_nodes} \times \mathtt{n\_attributes}$ (one row per node) holding the node attribute vectors, `data.edge_index` is a `torch.Tensor` of shape $2 \times \mathtt{n\_edges}$ holding the edges of the graph (each undirected edge is stored in both directions), and `data.y` is a one-element `torch.Tensor` holding the class of the graph.
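To make the shapes concrete, here is a plain-Python sketch of what one `Data` object contains. It is not the notebook's code path: `graph_to_arrays` is a hypothetical helper, it relabels nodes locally to `0..n_nodes-1` (the notebook instead keeps global node IDs and renames them in `rename_edge_indexes` when building a batch), and it takes a toy dict of attributes rather than the attribute list. Passing the three lists to `torch.tensor` would give the `x`, `edge_index` and `y` tensors.

```python
from typing import Dict, List, Tuple


def graph_to_arrays(nodes: List[int], edges: List[List[int]], attributes: Dict[int, float],
                    label: int) -> Tuple[List[List[float]], List[List[int]], List[int]]:
    """Build x, edge_index and y for one graph, relabelling nodes to 0..n_nodes-1."""
    local = {node: i for i, node in enumerate(nodes)}
    # One row per node, one column per attribute -> shape [n_nodes, 1] here
    x = [[float(attributes[node])] for node in nodes]
    # Undirected edges are stored in both directions, as `to_directed()` does
    src, dst = [], []
    for u, v in edges:
        src += [local[u], local[v]]
        dst += [local[v], local[u]]
    edge_index = [src, dst]  # shape [2, 2 * n_undirected_edges]
    y = [label]              # shape [1]
    return x, edge_index, y


x, edge_index, y = graph_to_arrays(
    nodes=[21, 22, 23], edges=[[21, 22], [21, 23]],
    attributes={21: 2, 22: 1, 23: 1}, label=1)
print(x)           # [[2.0], [1.0], [1.0]]
print(edge_index)  # [[0, 1, 0, 2], [1, 0, 2, 0]]
```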

In [82]:
class GraphDataset(gdata.Dataset):
    def __init__(self, graphs_ds: Dict[str, Tuple[nx.Graph, str]]) -> None:
        super(GraphDataset, self).__init__()
        self.graphs_ds = graphs_ds

    @classmethod
    def get_dataset(cls, attributes: List[Any], data: Dict[str, Any]) -> 'GraphDataset':
        """
        Returns a new instance of GraphDataset filled with graphs inside data. 'attributes'
        is the list with all the attributes (not only those beloging to nodes in 'data').

        :param data: a dictionary with label2graphs, graph2nodes and graph2edges
        :param attributes: a list with node attributes
        :return: a new instance of GraphDataset
        """
        graphs = dict()

        label2graphs = data["label2graphs"]
        graph2nodes  = data["graph2nodes"]
        graph2edges  = data["graph2edges"]

        for label, graph_list in label2graphs.items():
            for graph_id in graph_list:
                graph_nodes = graph2nodes[graph_id]
                graph_edges = graph2edges[graph_id]
                nodes_attributes = [[attributes[node_id]] for node_id in graph_nodes]
                nodes = []
                for node, attribute in zip(graph_nodes, nodes_attributes):
                    nodes.append((node, {f"attr{i}" : a for i, a in enumerate(attribute)}))

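                g = nx.Graph()

                # Add nodes first, so that g.nodes follows the order of graph2nodes
                # (and isolated nodes are kept), then connect them
                g.add_nodes_from(nodes)
                g.add_edges_from(graph_edges)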
            
                graphs[graph_id] = (g, label)

        graphs = dict(sorted(graphs.items(), key=lambda x: x[0]))
        graph_dataset = cls(graphs)

        return graph_dataset

    def __repr__(self) -> str:
        return f"GraphDataset(classes={set(self.targets().tolist())},n_graphs={self.len()})"

    def indices(self) -> List[str]:
        """ Return all the graph IDs """
        return list(self.graphs_ds.keys())

    def len(self) -> int:
        return len(self.graphs_ds.keys())

    def targets(self) -> torch.Tensor:
        """ Return all the labels """
        targets = []
        for _, graph in self.graphs_ds.items():
            targets.append(int(graph[1]))

        return torch.tensor(targets)

    def get(self, idx: Union[int, str]) -> gdata.Data:
        """ Return the graph with ID idx as a torch_geometric.data.Data object """
        if isinstance(idx, str):
            idx = int(idx)

        graph = self.graphs_ds[idx]
        g, label = graph[0].to_directed(), graph[1]

        # Retrieve nodes attributes
        attrs = list(g.nodes(data=True))
        x = torch.tensor([list(map(int, a.values())) for _, a in attrs], dtype=torch.float)

        # Retrieve edges
        edge_index = torch.tensor([list(e) for e in g.edges], dtype=torch.long) \
                          .t()                                                  \
                          .contiguous()                                         \
                          .long()

        # Retrieve ground truth labels
        y = torch.tensor([int(label)], dtype=torch.int)

        return gdata.Data(x=x, edge_index=edge_index, y=y)


def get_all_labels(graphs: Dict[str, Tuple[nx.Graph, str]]) -> torch.Tensor:
    """ Return a tensor containing all the distinct labels of the dataset """
    return torch.tensor(list(set([int(v[1]) for _, v in graphs.items()])))


def generate_train_val_test(dataset_name: str,
                            data_dir: Optional[str]=None, 
                            download: bool=True,
                            download_folder: str="../data"
) -> Tuple[GraphDataset, GraphDataset, GraphDataset, str]:
    """ Return the datasets for training, testing and validation, plus the data directory """
    print("--- Generating Train, Test and Validation datasets --- ")
    
    assert download or data_dir is not None, "At least one between: data_dir and download must be given"

    node_attribute = None
    test_file = None
    train_file = None
    val_file = None

    if data_dir is not None:
        node_attribute = os.path.join(data_dir, f"{dataset_name}/{dataset_name}_node_attributes.pickle")
        test_file = os.path.join(data_dir, f"{dataset_name}/{dataset_name}_test_set.pickle")
        train_file = os.path.join(data_dir, f"{dataset_name}/{dataset_name}_train_set.pickle")
        val_file = os.path.join(data_dir, f"{dataset_name}/{dataset_name}_val_set.pickle")

    if download:
        node_attribute, test_file, train_file, val_file = download_zipped_data(
            DATASETS[dataset_name], 
            download_folder, 
            dataset_name
        )

        data_dir = os.path.dirname(os.path.dirname(node_attribute))  # Parent of the dataset folder, portable across OSes

    node_attribute_data = load_with_pickle(node_attribute)
    test_data = load_with_pickle(test_file)
    train_data = load_with_pickle(train_file)
    val_data = load_with_pickle(val_file)

    train_ds = GraphDataset.get_dataset(node_attribute_data, train_data)
    test_ds  = GraphDataset.get_dataset(node_attribute_data,  test_data)
    val_ds   = GraphDataset.get_dataset(node_attribute_data,   val_data)

    return train_ds, test_ds, val_ds, data_dir



def get_dataset(download: bool=False, 
                dataset_name: str="TRIANGLES", 
                data_dir: str="../data") -> Tuple[GraphDataset, GraphDataset, GraphDataset, str]:
    """Generate the train, test and validation dataset"""
    data_dir = data_dir if not download else None
    train_ds, test_ds, val_ds, data_dir = generate_train_val_test(
        data_dir=data_dir,
        download=download,
        dataset_name=dataset_name
    )
    return train_ds, test_ds, val_ds, data_dir

In [83]:
# Example of dataset
train_ds, test_ds, val_ds, _ = get_dataset(data_dir=os.getcwd())

--- Generating Train, Test and Validation datasets --- 


In [84]:
print("--- Training Set --- ", train_ds)
print("--- Test Set --- ", test_ds)
print("--- Validation Set --- ", val_ds)

--- Training Set ---  GraphDataset(classes={1, 3, 4, 6, 7, 8, 9},n_graphs=1127)
--- Test Set ---  GraphDataset(classes={2, 10, 5},n_graphs=603)
--- Validation Set ---  GraphDataset(classes={1, 3, 4, 6, 7, 8, 9},n_graphs=280)


### The Samplers

Since we need a specific way to sample from the dataset, namely N-way-K-shot (for both the support and the query set), I needed two samplers: `FewShotSampler` and `TaskBatchSampler`, both inheriting from `torch.utils.data.Sampler`. The former returns a list of indices indicating which graphs belong to a single N-way-K-shot sample. The latter, by iteratively sampling from `FewShotSampler`, creates mini-batches containing the number of N-way-K-shot samples the user wants in a single batch. The `TaskBatchSampler` is passed as the `batch_sampler` argument of the DataLoader.

In [85]:
class FewShotSampler(torch.utils.data.Sampler):
    """
    In few-shot classification, and in particular in Meta-Learning, 
    we use a specific way of sampling batches from the training/val/test 
    set. This way is called N-way-K-shot, where N is the number of classes 
    to sample per batch and K is the number of examples to sample per class 
    in the batch. The sample batch on which we train our model is also called 
    `support` set, while the one on which we test is called `query` set.

    This class is an N-way-K-shot sampler; it is wrapped by :obj:`TaskBatchSampler`,
    which is used as the batch_sampler of the :obj:`torch_geometric.loader.DataLoader`.
    This sampler returns batches of indices that correspond to support and query set batches.

    Attributes:
        labels: PyTorch tensor of the labels of the data elements
        n_way: Number of classes to sample per batch
        k_shot: Number of examples to sample per class in the batch
        n_query: Number of query examples to sample per class in the batch
        shuffle: If True, examples and classes are shuffled at each iteration
        indices_per_class: How many indices per classes
        classes: list of all classes
        epoch_size: number of batches per epoch
    """

    def __init__(self, labels: torch.Tensor,
                 n_way: int,
                 k_shot: int,
                 n_query: int,
                 epoch_size: int,
                 shuffle: bool = True) -> None:
        super().__init__(None)
        self.labels = labels
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query
        self.shuffle = shuffle
        self.epoch_size = epoch_size

        self.classes = torch.unique(self.labels).tolist()
        self.indices_per_class = dict()
        for cl in self.classes:
            self.indices_per_class[cl] = torch.where(self.labels == cl)[0]

    def shuffle_data(self) -> None:
        """Shuffle the indices of each class"""
        for cl in self.classes:
            perm = torch.randperm(self.indices_per_class[cl].shape[0])
            self.indices_per_class[cl] = self.indices_per_class[cl][perm]

    def __iter__(self) -> List[torch.Tensor]:
        # Shuffle the data
        if self.shuffle:
            self.shuffle_data()

        for _ in range(self.epoch_size):
            # Draw a fresh set of N classes for every task (episode)
            target_classes = random.sample(self.classes, self.n_way)
            n_way_k_shot_n_query = []
            for cl in target_classes:
                indices_of_class = self.indices_per_class[cl]
                assert len(indices_of_class) >= self.k_shot + self.n_query
                # K support + Q query indices, sampled without replacement
                selected_data = random.sample(
                    indices_of_class.tolist(), self.k_shot + self.n_query)
                n_way_k_shot_n_query.append(selected_data)

            # Shape: [n_way, k_shot + n_query]
            yield torch.tensor(n_way_k_shot_n_query)

    def __len__(self) -> int:
        return self.epoch_size
    
    def __repr__(self) -> str:
        """Return a descriptive string"""
        return "{name}(classes={cls}, \n\t\t\tsupport_set_size={sts}, \n\t\t\tquery_set_size={qts}, \n\t\t\tsize={size})".format(
            name=self.__class__.__name__, cls=self.classes, sts=f"{self.n_way} x {self.k_shot}",
            qts=f"{self.n_way} x {self.n_query}", size=self.__len__()
        )


class TaskBatchSampler(torch.utils.data.Sampler):
    """Group `batch_size` tasks drawn from FewShotSampler into one flat list of indices"""

    def __init__(self, dataset_targets: torch.Tensor,
                 batch_size: int,
                 n_way: int,
                 k_shot: int,
                 n_query: int,
                 epoch_size: int,
                 shuffle: bool = True) -> None:

        super().__init__(None)
        self.task_sampler = FewShotSampler(
            dataset_targets,
            n_way=n_way,
            k_shot=k_shot,
            n_query=n_query,
            epoch_size=epoch_size,
            shuffle=shuffle
        )

        self.task_batch_size = batch_size

    def __iter__(self):
        mini_batches = []
        for task_idx, task in enumerate(self.task_sampler):
            mini_batches.extend(task.tolist())
            if (task_idx + 1) % self.task_batch_size == 0:
                yield torch.tensor(mini_batches).flatten().tolist()
                mini_batches = []

    def __len__(self):
        return len(self.task_sampler) // self.task_batch_size

    def uncollate(self, data_batch):
        """Invoke the uncollate from utils.utils"""
        return task_sampler_uncollate(self, data_batch)

    def __repr__(self) -> str:
        """Return a descriptive string"""
        return "{name}(task_batch_size={tbs},\n\t\ttask_sampler={ts},\n\t\tsize={size})".format(
            name=self.__class__.__name__, tbs=self.task_batch_size,
            ts=self.task_sampler.__repr__(), size=self.__len__()
        )

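To see the index layout without PyTorch, here is a framework-free sketch of a single N-way-K-shot episode. The helpers `sample_episode` and `split_support_query` are hypothetical, and the split they use (the first K indices of each class row go to the support set, the rest to the query set) is an assumption about what the uncollate step does, since `task_sampler_uncollate` is not shown here.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def sample_episode(labels: Sequence[int], n_way: int, k_shot: int,
                   n_query: int, rng: random.Random) -> List[List[int]]:
    """Return an n_way x (k_shot + n_query) grid of dataset indices (hypothetical helper)."""
    indices_per_class: Dict[int, List[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        indices_per_class[label].append(idx)

    # Draw a fresh set of N classes for this episode
    classes = rng.sample(sorted(indices_per_class), n_way)
    rows = []
    for cl in classes:
        pool = indices_per_class[cl]
        if len(pool) < k_shot + n_query:
            raise ValueError(f"class {cl} has only {len(pool)} examples")
        # Sampling without replacement: support and query never overlap
        rows.append(rng.sample(pool, k_shot + n_query))
    return rows


def split_support_query(rows: List[List[int]],
                        k_shot: int) -> Tuple[List[int], List[int]]:
    """Assumed layout: first k_shot indices of each row -> support, rest -> query."""
    support = [idx for row in rows for idx in row[:k_shot]]
    query = [idx for row in rows for idx in row[k_shot:]]
    return support, query


labels = [0] * 5 + [1] * 5 + [2] * 5 + [3] * 5   # 4 classes, 5 graphs each
rows = sample_episode(labels, n_way=3, k_shot=2, n_query=1, rng=random.Random(0))
support, query = split_support_query(rows, k_shot=2)
print(len(support), len(query))  # 6 3
```

With N=3, K=10 and Q=15, as in the run below, the same layout gives 30 support and 45 query graphs, which matches `num_graphs=30` and `num_graphs=45` in the printed batches.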

### The DataLoader

In this case each element of the dataset is a `torch_geometric.data.Data` object rather than a plain `torch.Tensor`. For this reason, I created a simple custom dataloader, `FewShotDataLoader`, that inherits from `torch.utils.data.DataLoader`. There is a further problem: `GraphDataset` is not a type known to the default *collate* function of either PyTorch or PyTorch Geometric, so I wrote a custom collater, `GraphCollater`, to handle this case.

In [86]:
class GraphCollater(gloader.dataloader.Collater):
    """A Collater to handle batches of GraphDataset instances"""
    def __init__(self, *args) -> None:
        super(GraphCollater, self).__init__(*args)

    def __call__(self, batch: Generic[T]) -> Generic[T]:
        elem = batch[0]

        # All elements inside the batch are just repetitions
        # of the first element, so we can keep only the
        # first one and collate it recursively
        if isinstance(elem, GraphDataset):
            return self(elem)

        return super(GraphCollater, self).__call__(batch)


class FewShotDataLoader(torch.utils.data.DataLoader):
    """Custom DataLoader for GraphDataset"""

    def __init__(self, dataset: GraphDataset,
                 batch_size: int = 1,
                 shuffle: bool = False,
                 follow_batch: Optional[List[str]] = None,
                 exclude_keys: Optional[List[str]] = None,
                 **kwargs) -> None:

        if 'collate_fn' in kwargs:
            del kwargs["collate_fn"]

        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        # Take the batch sampler
        self.batch_sampler = kwargs["batch_sampler"]

        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=GraphCollater(follow_batch, exclude_keys),
            **kwargs,
        )

    def __iter__(self):
        for x in super().__iter__():
            support_batch, query_batch = self.batch_sampler.uncollate(x)
            yield support_batch, query_batch
    
    def __repr__(self) -> str:
        """Return a descriptive string"""
        return "{name}(dataset={ds},\n\tbatch_sampler={bs},\n\tsize={size})".format(
            name=self.__class__.__name__, ds=self.dataset, 
            bs=self.batch_sampler.__repr__(), size=self.__len__()
        )


def get_dataloader(
    ds: GraphDataset, n_way: int, k_shot: int, n_query: int, 
    epoch_size: int, shuffle: bool, batch_size: int
) -> FewShotDataLoader:
    """Return a dataloader instance"""
    return FewShotDataLoader(
        dataset=ds,
        batch_sampler=TaskBatchSampler(
            dataset_targets=ds.targets(),
            n_way=n_way,
            k_shot=k_shot,
            n_query=n_query,
            epoch_size=epoch_size,
            shuffle=shuffle,
            batch_size=batch_size
        )
    )

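A dedicated collater is needed because graphs of different sizes cannot be stacked like tensors; PyTorch Geometric instead merges a list of graphs into one large disconnected graph, shifting node indices and recording a `batch` vector. The pure-Python function below (`collate_graphs`, hypothetical) sketches that idea; it is a simplification, not PyG's actual `Collater`.

```python
from typing import Dict, List, Tuple

# A toy graph: (node features, directed edge list, graph label)
Graph = Tuple[List[List[float]], List[Tuple[int, int]], int]


def collate_graphs(graphs: List[Graph]) -> Dict[str, object]:
    """Merge several graphs into one disconnected graph (simplified sketch)."""
    xs: List[List[float]] = []
    sources: List[int] = []
    targets: List[int] = []
    batch: List[int] = []
    ys: List[int] = []
    offset = 0
    for graph_id, (x, edges, y) in enumerate(graphs):
        xs.extend(x)
        for u, v in edges:
            # Shift node ids so they point into the concatenated node list
            sources.append(u + offset)
            targets.append(v + offset)
        batch.extend([graph_id] * len(x))  # which graph each node belongs to
        ys.append(y)
        offset += len(x)
    return {"x": xs, "edge_index": [sources, targets], "batch": batch,
            "y": ys, "num_graphs": len(graphs)}


g1 = ([[1.0], [2.0]], [(0, 1), (1, 0)], 6)
g2 = ([[3.0], [4.0], [5.0]], [(0, 2)], 4)
out = collate_graphs([g1, g2])
print(out["edge_index"])  # [[0, 1, 2], [1, 0, 4]]
print(out["batch"])       # [0, 0, 1, 1, 1]
```

This is why the support batch printed below has `x=[446, 1]` and `batch=[446]`: the nodes of all 30 support graphs are concatenated, and `batch` records which graph each node came from.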
In [87]:
# Example using the previously generated sets
train_dataloader = get_dataloader(
    ds=train_ds, n_way=ASMAMLConfig.TRAIN_WAY,
    k_shot=ASMAMLConfig.TRAIN_SHOT, n_query=ASMAMLConfig.TRAIN_QUERY,
    epoch_size=ASMAMLConfig.TRAIN_EPISODE, shuffle=True, batch_size=1
)

In [88]:
print("--- Train DataLoader --- ")
print(train_dataloader)

--- Train DataLoader --- 
FewShotDataLoader(dataset=GraphDataset(classes={1, 3, 4, 6, 7, 8, 9},n_graphs=1127),
	batch_sampler=TaskBatchSampler(task_batch_size=1,
		task_sampler=FewShotSampler(classes=[1, 3, 4, 6, 7, 8, 9], 
			support_set_size=3 x 10, 
			query_set_size=3 x 15, 
			size=200),
		size=200),
	size=200)


In [89]:
print("--- First Sample ---")
sample = next(iter(train_dataloader))
support_data, query_data = sample

print("--- Support Data ---")
print(support_data)
print()

print("--- Query Data ---")
print(query_data)

--- First Sample ---
--- Support Data ---
Data(
  x=[446, 1],
  edge_index=[2, 1542],
  y=[30],
  batch=[446],
  num_graphs=30,
  old_classes_mapping={
    6=0,
    4=1,
    9=2
  }
)

--- Query Data ---
Data(
  x=[731, 1],
  edge_index=[2, 2490],
  y=[45],
  batch=[731],
  num_graphs=45,
  old_classes_mapping={
    6=0,
    4=1,
    9=2
  }
)

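The `old_classes_mapping` field shows that the original class ids (6, 4, 9) are relabelled to 0, 1, 2 inside the task. `nn.CrossEntropyLoss`, used later by the meta-learner, expects targets in the range [0, C), so a remapping of this kind is needed. A minimal sketch, assuming classes are numbered in order of first appearance (this agrees with the printed mapping but is not confirmed by the code shown here); `remap_labels` is hypothetical:

```python
from typing import Dict, List, Tuple


def remap_labels(labels: List[int]) -> Tuple[List[int], Dict[int, int]]:
    """Map arbitrary class ids to 0..N-1 in order of first appearance."""
    mapping: Dict[int, int] = {}
    for y in labels:
        if y not in mapping:
            mapping[y] = len(mapping)
    return [mapping[y] for y in labels], mapping


new_labels, mapping = remap_labels([6, 6, 4, 4, 9, 9])
print(new_labels)  # [0, 0, 1, 1, 2, 2]
print(mapping)     # {6: 0, 4: 1, 9: 2}
```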

---

## Models

### AS-MAML

This framework consists of a graph *meta-learner*, which uses GNN-based modules for fast adaptation on graph data, and a *step controller*, which makes the meta-learner more robust and improves its generalization. The authors were inspired by [**Model-Agnostic Meta-Learning** (MAML)](https://arxiv.org/pdf/1703.03400.pdf) because of its fast adaptation mechanism. However, applying MAML directly to graphs is suboptimal for two reasons: reaching good generalization requires a painstaking hyperparameter search, and, unlike images, graphs have an arbitrary number of nodes and arbitrary sub-structures, which makes adaptation uncertain.

<center>
    <img src="https://i.imgur.com/SwTvlOE.png" width=600>
</center>

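The fast-adaptation idea can be illustrated without any GNN. In the toy sketch below, a scalar linear model `y = w * x` keeps a slow weight and a `fast` copy, mirroring the `weight.fast` attribute used by the layers in this notebook: the inner loop updates only the fast copy on the support set, and the number of inner steps is derived from a stop probability with the same rule as `AdaptiveStepMAML.forward`, i.e. `min(MAX_STEP, MIN_STEP + int(1 / max(stop_prob, 0.1)))`. The class and function names, the toy data and the `min_step`/`max_step` values are assumptions; the outer (meta) update and the learned step controller are omitted.

```python
from typing import List, Optional, Tuple

Pairs = List[Tuple[float, float]]


class ToyFastLinear:
    """Scalar model y = w * x with a slow weight and an optional fast weight."""

    def __init__(self, weight: float) -> None:
        self.weight = weight
        self.fast: Optional[float] = None

    def __call__(self, x: float) -> float:
        # Use the adapted (fast) weight when present, as the GNN layers do
        w = self.weight if self.fast is None else self.fast
        return w * x


def mse_grad(w: float, data: Pairs) -> float:
    """Gradient of the mean squared error of y = w * x with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in data) / len(data)


def adaptation_steps(stop_prob: float, min_step: int, max_step: int) -> int:
    """Same rule as AdaptiveStepMAML.forward: a low stop probability means more steps."""
    stop_prob = max(stop_prob, 0.1)
    return min(max_step, min_step + int(1.0 / stop_prob))


def inner_loop(model: ToyFastLinear, support: Pairs, lr: float, steps: int) -> None:
    model.fast = None  # every task starts from the slow weight
    for _ in range(steps):
        w = model.weight if model.fast is None else model.fast
        model.fast = w - lr * mse_grad(w, support)  # slow weight stays untouched


model = ToyFastLinear(weight=0.0)
support = [(1.0, 2.0), (2.0, 4.0)]  # the task's true slope is 2
steps = adaptation_steps(stop_prob=0.5, min_step=3, max_step=10)
inner_loop(model, support, lr=0.1, steps=steps)
print(steps, model.weight, round(model.fast, 4))  # 5 0.0 1.9375
```

Because the fast weight is only a copy, every new task restarts from the meta-learned slow weight; in MAML the slow weight is then updated by the outer optimizer from the query loss, which this toy leaves out.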
In [90]:
# A modification of torch_geometric.nn.conv.gcn_conv.GCNConv
# Link: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/gcn_conv.html#GCNConv
class GCNConv(MessagePassing):
    """
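    GCN convolutional layer, modified for fast weight adaptation: when
    ``weight.fast`` / ``bias.fast`` are set, they are used instead of the
    slow parameters.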

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        improved (bool, optional): If set to :obj:`True`, the layer computes
            :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`.
            (default: :obj:`False`)
        cached (bool, optional): If set to :obj:`True`, the layer will cache
            the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
            \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the
            cached version for further executions.
            This parameter should only be set to :obj:`True` in transductive
            learning scenarios. (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, in_channels: int,
                       out_channels: int,
                       improved: bool=False,
                       cached: bool=False,
                       bias: bool=True,
                       **kwargs):
        super().__init__(aggr="add", **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))
        self.weight.fast = None

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
            self.bias.fast = None
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)
        self.cached_result = None
        self.cached_num_edges = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight=None, improved=False, dtype=None):
        """Compute the symmetric normalization D^{-1/2} (A + fill_value * I) D^{-1/2} of the edges"""
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, device=edge_index.device)
        
        fill_value = 1 if not improved else 2
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes
        )

        row, col = edge_index

        # src = edge_weight
        # index = row
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
    
    def forward(self, x, edge_index, edge_weight=None):
        """The forward method"""
        x = x @ (self.weight if self.weight.fast is None else self.weight.fast)

        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}. Please '
                    'disable the caching behavior of this layer by removing '
                    'the `cached=True` argument in its constructor.'.format(
                        self.cached_num_edges, edge_index.size(1)))

        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            edge_index, norm = self.norm(
                edge_index, x.size(self.node_dim), edge_weight,
                self.improved, x.dtype
            )
            self.cached_result = edge_index, norm
        
        edge_index, norm = self.cached_result
        return self.propagate(edge_index, x=x, norm=norm)
    
    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j
    
    def update(self, aggr_out):
        if self.bias is not None:
            if self.bias.fast is not None:
                aggr_out += self.bias.fast
            else:
                aggr_out += self.bias
        
        return aggr_out
    
    def __repr__(self):
        return '{}({}, {})'.format(
            self.__class__.__name__, 
            self.in_channels, 
            self.out_channels
        )


# A modification of torch_geometric.nn.conv.sage_conv.SAGEConv
# Link: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/sage_conv.html#SAGEConv
class SAGEConv(MessagePassing):
    """
    The GraphSAGE operator, modified for the fast weight adaptation

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        normalize (bool, optional): If set to :obj:`True`, output features
            will be :math:`\ell_2`-normalized. (default: :obj:`False`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, in_channels: int, out_channels: int,
                       normalize: bool=False, bias: bool=True,
                       **kwargs) -> None:
        super().__init__(aggr='mean', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize

        self.weight = Parameter(torch.Tensor(self.in_channels, self.out_channels))
        self.weight.fast = None

        if bias:
            self.bias = Parameter(torch.Tensor(self.out_channels))
            self.bias.fast = None
        else:
            self.register_parameter('bias', None)
        
        self.reset_parameters()

    def reset_parameters(self):
        uniform(self.in_channels, self.weight)
        uniform(self.in_channels, self.bias)
    
    def forward(self, x, edge_index, edge_weight=None, size=None):
        if size is None and torch.is_tensor(x):
            edge_index, edge_weight = add_remaining_self_loops(
                edge_index, edge_weight, 1, x.size(0)
            )

        if self.weight.fast is not None:
            weight = self.weight.fast
        else:
            weight = self.weight

        if torch.is_tensor(x):
            x = x @ weight
        else:
            x0 = None if x[0] is None else x[0] @ weight
            x1 = None if x[1] is None else x[1] @ weight
            x = (x0, x1)
    
        return self.propagate(edge_index, size=size, x=x, edge_weight=edge_weight)

    def message(self, x_j, edge_weight):
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def update(self, aggr_out):
        if self.bias is not None:
            if self.bias.fast is not None:
                aggr_out = aggr_out + self.bias.fast
            else:
                aggr_out = aggr_out + self.bias
        if self.normalize:
            aggr_out = F.normalize(aggr_out, p=2, dim=-1)
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


# A modification of torch_geometric.nn.conv.graph_conv.GraphConv
# Link: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/conv/graph_conv.html#GraphConv
class GraphConv(MessagePassing):
    """
    The GraphConv operator, modified for fast weight adaptation.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        aggr (string, optional): The aggregation scheme to use
            (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`).
            (default: :obj:`"add"`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, in_channels, out_channels, aggr='add', bias=True,
                 **kwargs):
        super(GraphConv, self).__init__(aggr=aggr, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.weight = Parameter(torch.Tensor(in_channels, out_channels))
        self.lin = LinearModel(in_channels, out_channels)
        self.weight.fast = None
        self.lin.weight.fast = None
        self.lin.bias.fast = None

        self.reset_parameters()

    def reset_parameters(self):
        uniform(self.in_channels, self.weight)
        self.lin.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None, size=None):
        """The forward method"""
        if self.weight.fast is not None:
            h = x @ self.weight.fast
        else:
            h = x @ self.weight

        return self.propagate(edge_index, size=size, x=x, h=h,
                              edge_weight=edge_weight)

    def message(self, h_j, edge_weight):
        return h_j if edge_weight is None else edge_weight.view(-1, 1) * h_j

    def update(self, aggr_out, x):
        return aggr_out + self.lin(x)

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

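To make `GCNConv.norm` concrete, here is a pure-Python sketch of the same symmetric normalization on a tiny graph: self-loops with weight 1 (2 when `improved=True`) are added, degrees are summed over the source index `row`, and each edge gets the weight $\hat{d}_i^{-1/2}\,w_{ij}\,\hat{d}_j^{-1/2}$. `propagate` then sums `norm * x_j` at each target node, as `message` plus `aggr="add"` do. For simplicity it assumes the input has no self-loops and unit edge weights; the function names are hypothetical.

```python
import math
from typing import List, Tuple

Edge = Tuple[int, int]


def gcn_norm(edges: List[Edge], num_nodes: int,
             improved: bool = False) -> Tuple[List[Edge], List[float]]:
    """Symmetric GCN normalization; assumes no self-loops and unit edge weights."""
    fill_value = 2.0 if improved else 1.0
    edges = list(edges) + [(i, i) for i in range(num_nodes)]
    weights = [1.0] * (len(edges) - num_nodes) + [fill_value] * num_nodes

    deg = [0.0] * num_nodes
    for (row, _), w in zip(edges, weights):
        deg[row] += w  # degree summed over the source index
    deg_inv_sqrt = [1.0 / math.sqrt(d) if d > 0 else 0.0 for d in deg]

    norm = [deg_inv_sqrt[r] * w * deg_inv_sqrt[c]
            for (r, c), w in zip(edges, weights)]
    return edges, norm


def propagate(edges: List[Edge], norm: List[float], x: List[float]) -> List[float]:
    """out[target] = sum over incoming edges of norm * x[source] (aggr='add')."""
    out = [0.0] * len(x)
    for (src, dst), w in zip(edges, norm):
        out[dst] += w * x[src]
    return out


# Undirected edge 0 - 1, stored in both directions as in PyG
edges, norm = gcn_norm([(0, 1), (1, 0)], num_nodes=2)
print([round(v, 3) for v in norm])                                # [0.5, 0.5, 0.5, 0.5]
print([round(v, 3) for v in propagate(edges, norm, [1.0, 3.0])])  # [2.0, 2.0]
```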
In [91]:
class LinearModel(nn.Linear):
    """A Simple Linear model implementation for fast weights"""
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features, bias=True)
        self.weight.fast = None
        self.bias.fast = None
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.weight.fast is not None and self.bias.fast is not None:
            return F.linear(x, self.weight.fast, self.bias.fast)
        return super().forward(x)

In [92]:
# A modification of torch_geometric.nn.pool.topk_pool.TopKPooling
# Link: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/pool/topk_pool.html#TopKPooling
class TopKPooling(nn.Module):
    """
    Top-k pooling layer, modified for fast weight adaptation.

    Args:
        in_channels (int): Size of each input sample.
        ratio (float): Graph pooling ratio, which is used to compute
            :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`.
            This value is ignored if min_score is not None.
            (default: :obj:`0.5`)
        min_score (float, optional): Minimal node score :math:`\tilde{\alpha}`
            which is used to compute indices of pooled nodes
            :math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`.
            When this value is not :obj:`None`, the :obj:`ratio` argument is
            ignored. (default: :obj:`None`)
        multiplier (float, optional): Coefficient by which features get
            multiplied after pooling. This can be useful for large graphs and
            when :obj:`min_score` is used. (default: :obj:`1`)
        nonlinearity (torch.nn.functional, optional): The nonlinearity to use.
            (default: :obj:`torch.tanh`)
    """
    def __init__(self, in_channels, 
                       ratio: float=0.5,
                       min_score: Optional[float]=None,
                       multiplier: int=1,
                       nonlinearity=torch.tanh) -> None:
        super().__init__()
        
        self.in_channels = in_channels
        self.ratio = ratio
        self.min_score = min_score
        self.multiplier = multiplier
        self.nonlinearity = nonlinearity

        self.weight = nn.Parameter(torch.Tensor(1, in_channels))
        self.weight.fast = None
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the parameters"""
        size = self.in_channels
        uniform(size, self.weight)

    def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None):
        """The forward method"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        attn = x if attn is None else attn
        attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn

        if self.weight.fast is not None:
            score = (attn * self.weight.fast).sum(dim=-1)
        else:
            score = (attn * self.weight).sum(dim=-1)
        
        if self.min_score is None:
            if self.weight.fast is not None:
                score = self.nonlinearity(score / self.weight.fast.norm(p=2, dim=-1))
            else:
                score = self.nonlinearity(score / self.weight.norm(p=2, dim=-1))
        else:
            score = softmax(score, batch)
        
        perm = topk(score, self.ratio, batch, self.min_score)
        x = x[perm] * score[perm].view(-1, 1)
        x = self.multiplier * x if self.multiplier != 1 else x

        batch = batch[perm]
        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm,
            num_nodes=score.size(0)
        )

        return x, edge_index, edge_attr, batch, perm, score[perm]

    def __repr__(self):
        return '{}({}, {}={}, multiplier={})'.format(
            self.__class__.__name__, self.in_channels,
            'ratio' if self.min_score is None else 'min_score',
            self.ratio if self.min_score is None else self.min_score,
            self.multiplier)
        

# A very simple modification of torch_geometric.nn.pool.sag_pool.SAGPooling
# Link: https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/nn/pool/sag_pool.html#SAGPooling
class SAGPool4MAML(SAGPooling):
    """
    SAGPooling for AS-MAML. It only changes the default scoring GNN
    (the fast-weight GraphConv defined above) and the __repr__ method.
    """
    def __init__(self, in_channels: int, ratio: float=0.5,
                       GNN: nn.Module=GraphConv, min_score: Optional[float]=None,
                       multiplier: int=1, nonlinearity=torch.tanh, **kwargs) -> None:
        super().__init__(
            in_channels=in_channels, ratio=ratio,
            GNN=GNN, min_score=min_score, multiplier=multiplier,
            nonlinearity=nonlinearity, **kwargs
        )

    def __repr__(self) -> str:
        return '{}({}, {}, {}={}, multiplier={})'.format(
            self.__class__.__name__, self.gnn.__class__.__name__,
            self.in_channels,
            'ratio' if self.min_score is None else 'min_score',
            self.ratio if self.min_score is None else self.min_score,
            self.multiplier)

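The selection rule of `TopKPooling` (with `min_score=None`) can be sketched for a single graph in pure Python: each node gets the score $\tanh(\mathbf{x}_i \cdot \mathbf{p} / \lVert\mathbf{p}\rVert)$, the $k = \lceil \mathrm{ratio} \cdot N \rceil$ highest-scoring nodes are kept, their features are multiplied by their scores, and only edges between kept nodes survive (the job of `filter_adj`). The function `topk_pool` is hypothetical and handles one graph rather than a batch.

```python
import math
from typing import List, Tuple


def topk_pool(x: List[List[float]], edges: List[Tuple[int, int]],
              p: List[float], ratio: float = 0.5):
    """Single-graph sketch of TopKPooling with min_score=None."""
    p_norm = math.sqrt(sum(w * w for w in p))
    scores = [math.tanh(sum(xi * wi for xi, wi in zip(row, p)) / p_norm) for row in x]

    k = math.ceil(ratio * len(x))
    perm = sorted(range(len(x)), key=lambda i: scores[i], reverse=True)[:k]

    # Gate kept features by their scores (in the real layer this is what
    # lets the projection weight receive gradients)
    pooled = [[v * scores[i] for v in x[i]] for i in perm]

    # Keep only edges whose endpoints both survive, re-indexed to 0..k-1
    new_id = {old: new for new, old in enumerate(perm)}
    kept_edges = [(new_id[u], new_id[v]) for u, v in edges
                  if u in new_id and v in new_id]
    return pooled, kept_edges, perm


x = [[1.0], [-2.0], [3.0], [0.5]]
edges = [(0, 2), (2, 0), (1, 3)]
pooled, kept_edges, perm = topk_pool(x, edges, p=[1.0])
print(perm)        # [2, 0]
print(kept_edges)  # [(1, 0), (0, 1)]
```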
In [93]:
class StopControl(nn.Module):
    """LSTM-based step controller that outputs the probability of stopping the adaptation"""
    def __init__(self, input_size: int, hidden_size: int) -> None:
        super(StopControl, self).__init__()
        self.lstm = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
        self.output_layer = nn.Linear(hidden_size, 1)
        self.output_layer.bias.data.fill_(0.0)
        self.h_0 = nn.Parameter(torch.randn((hidden_size, ), requires_grad=True))
        self.c_0 = nn.Parameter(torch.randn((hidden_size, ), requires_grad=True))

    def forward(self, inputs, hx) -> torch.Tensor:
        if hx is None:
            hx = (self.h_0.unsqueeze(0), self.c_0.unsqueeze(0))
        
        h, c = self.lstm(inputs, hx)
        return torch.sigmoid(self.output_layer(h).unsqueeze(0)), (h, c)

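Before the loss and the node information score reach the LSTM of `StopControl`, `AdaptiveStepMAML.smooth` rescales them with the log/sign preprocessing introduced in "Learning to learn by gradient descent by gradient descent" (Andrychowicz et al., 2016): for $|x| \ge e^{-p}$ the input becomes $(\log|x| / p, \operatorname{sgn}(x))$, otherwise $(-1, e^{p} x)$. The scalar sketch below (`preprocess`, hypothetical) reproduces that rule, including the small `eps` inside the logarithm; note that `stop` feeds only the first component (`self.smooth(inputs)[0]`) to the controller.

```python
import math
from typing import Tuple


def preprocess(x: float, p: float = 10.0, eps: float = 1e-10) -> Tuple[float, float]:
    """Scalar version of AdaptiveStepMAML.smooth: (log-magnitude, sign) features."""
    if abs(x) < math.exp(-p):
        # Tiny inputs: constant log feature and a rescaled value
        return -1.0, math.exp(p) * x
    sign = 1.0 if x > 0 else -1.0
    return math.log(abs(x) + eps) / p, sign


for value in (2.3, 0.05, 1e-6):
    print(value, preprocess(value))
```

The log feature compresses values that can span several orders of magnitude (a loss of 2.3 and a score of 0.05 end up on comparable scales), which helps keep the LSTM inputs in a reasonable range.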
In [94]:
class NodeInformationScore(MessagePassing):
    """Computes (I - D^{-1/2} A D^{-1/2}) X, from which the node information score is derived"""
    def __init__(self, improved=False, cached=False, **kwargs):
        super().__init__(aggr='add', **kwargs)

        self.improved = improved
        self.cached = cached
        self.cached_result = None
        self.cached_num_edges = None
    
    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, dtype=None):
        # Keep edge_weight aligned with edge_index when self-loops are removed
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)

        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype, device=edge_index.device)
        
        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        edge_index, edge_weight = add_self_loops(edge_index, edge_weight, 0, num_nodes)

        row, col = edge_index
        expand_deg = torch.zeros((edge_weight.size(0), ), dtype=dtype, device=edge_index.device)
        expand_deg[-num_nodes:] = torch.ones((num_nodes, ), dtype=dtype, device=edge_index.device)

        return edge_index, expand_deg - deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None):
        if self.cached and self.cached_result is not None:
            if edge_index.size(1) != self.cached_num_edges:
                raise RuntimeError(
                    'Cached {} number of edges, but found {}'.format(self.cached_num_edges, edge_index.size(1)))

        if not self.cached or self.cached_result is None:
            self.cached_num_edges = edge_index.size(1)
            edge_index, norm = self.norm(edge_index, x.size(0), edge_weight, x.dtype)
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result

        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        return aggr_out


class GCN4MAML(nn.Module):
    """GCN for AS-MAML"""
    def __init__(self, num_features: int=1, num_classes: int=30) -> None:
        super().__init__()

        self.num_features = num_features
        self.num_classes  = num_classes
        
        # Define convolutional layers
        self.conv1 = GCNConv(self.num_features, ASMAMLConfig.NHID)
        self.conv2 = GCNConv(ASMAMLConfig.NHID, ASMAMLConfig.NHID)
        self.conv3 = GCNConv(ASMAMLConfig.NHID, ASMAMLConfig.NHID)

        self.calc_information_score = NodeInformationScore()

        # Define Pooling layers
        self.pool1 = TopKPooling(ASMAMLConfig.NHID, ASMAMLConfig.POOLING_RATIO)
        self.pool2 = TopKPooling(ASMAMLConfig.NHID, ASMAMLConfig.POOLING_RATIO)
        self.pool3 = TopKPooling(ASMAMLConfig.NHID, ASMAMLConfig.POOLING_RATIO)

        # Define Linear Layers
        self.linear1 = LinearModel(ASMAMLConfig.NHID * 2, ASMAMLConfig.NHID)
        self.linear2 = LinearModel(ASMAMLConfig.NHID, ASMAMLConfig.NHID // 2)
        self.linear3 = LinearModel(ASMAMLConfig.NHID // 2, self.num_classes)

        # Define activation function
        self.relu = F.leaky_relu

    def forward(self, x, edge_index, batch):
        edge_attr = None

        x = self.relu(self.conv1(x, edge_index, edge_attr), negative_slope=0.1)
        x, edge_index, edge_attr, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)

        x = self.relu(self.conv2(x, edge_index, edge_attr), negative_slope=0.1)
        x, edge_index, edge_attr, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)

        x = self.relu(self.conv3(x, edge_index, edge_attr), negative_slope=0.1)
        x, edge_index, edge_attr, batch, _, _ = self.pool3(x, edge_index, None, batch)

        x_information_score = self.calc_information_score(x, edge_index)
        score = torch.sum(torch.abs(x_information_score), dim=1)
        x3 = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)

        x = self.relu(x1, negative_slope=0.1) + \
            self.relu(x2, negative_slope=0.1) + \
            self.relu(x3, negative_slope=0.1)
        
        x = self.relu(self.linear1(x), negative_slope=0.1)
        x = self.relu(self.linear2(x), negative_slope=0.1)
        x = self.linear3(x)

        return x, score.mean(), None


class SAGE4MAML(nn.Module):
    """GraphSAGE model for AS-MAML"""
    def __init__(self, num_features: int=1, num_classes: int=30) -> None:
        super().__init__()

        self.num_features = num_features
        self.num_classes  = num_classes
        
        # Define convolutional layers
        self.conv1 = SAGEConv(self.num_features, ASMAMLConfig.NHID)
        self.conv2 = SAGEConv(ASMAMLConfig.NHID, ASMAMLConfig.NHID)
        self.conv3 = SAGEConv(ASMAMLConfig.NHID, ASMAMLConfig.NHID)

        self.calc_information_score = NodeInformationScore()

        # Define Pooling layers
        self.pool1 = SAGPool4MAML(ASMAMLConfig.NHID, ASMAMLConfig.POOLING_RATIO)
        self.pool2 = SAGPool4MAML(ASMAMLConfig.NHID, ASMAMLConfig.POOLING_RATIO)
        self.pool3 = SAGPool4MAML(ASMAMLConfig.NHID, ASMAMLConfig.POOLING_RATIO)

        # Define Linear Layers
        self.linear1 = LinearModel(ASMAMLConfig.NHID * 2, ASMAMLConfig.NHID)
        self.linear2 = LinearModel(ASMAMLConfig.NHID, ASMAMLConfig.NHID // 2)
        self.linear3 = LinearModel(ASMAMLConfig.NHID // 2, self.num_classes)

        # Define activation function
        self.relu = F.leaky_relu

    def forward(self, x, edge_index, batch):
        edge_attr = None

        x = self.relu(self.conv1(x, edge_index, edge_attr), negative_slope=0.1)
        x, edge_index, edge_attr, batch, _, _ = self.pool1(x, edge_index, None, batch)
        x1 = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)

        x = self.relu(self.conv2(x, edge_index, edge_attr), negative_slope=0.1)
        x, edge_index, edge_attr, batch, _, _ = self.pool2(x, edge_index, None, batch)
        x2 = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)

        x = self.relu(self.conv3(x, edge_index, edge_attr), negative_slope=0.1)
        x, edge_index, edge_attr, batch, _, _ = self.pool3(x, edge_index, None, batch)
        x3 = torch.cat([global_max_pool(x, batch), global_mean_pool(x, batch)], dim=1)

        x_information_score = self.calc_information_score(x, edge_index)
        score = torch.sum(torch.abs(x_information_score), dim=1)

        x = self.relu(x1, negative_slope=0.1) + \
            self.relu(x2, negative_slope=0.1) + \
            self.relu(x3, negative_slope=0.1)

        graph_emb = x

        x = self.relu(self.linear1(x), negative_slope=0.1)
        x = self.relu(self.linear2(x), negative_slope=0.1)
        x = self.linear3(x)

        return x, score.mean(), graph_emb

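`GCN4MAML` and `SAGE4MAML` compute two graph-level quantities: after each pooling stage, a readout that concatenates global max and global mean pooling (hence the `NHID * 2` input size of `linear1`); and, after the last stage, the node information score $\lVert(\mathbf{I} - \mathbf{D}^{-1/2}\mathbf{A}\mathbf{D}^{-1/2})\mathbf{X}\rVert_1$ computed row by row, whose mean is passed to the step controller. The pure-Python helpers below (`readout`, `node_information_score`) are hypothetical sketches of those two computations, not the PyG implementations.

```python
import math
from typing import Dict, List, Tuple


def readout(x: List[List[float]], batch: List[int]) -> List[List[float]]:
    """Per graph: [global max pool | global mean pool] of the node features."""
    groups: Dict[int, List[List[float]]] = {}
    for row, graph_id in zip(x, batch):
        groups.setdefault(graph_id, []).append(row)
    out = []
    for graph_id in sorted(groups):
        rows = groups[graph_id]
        maxes = [max(col) for col in zip(*rows)]
        means = [sum(col) / len(rows) for col in zip(*rows)]
        out.append(maxes + means)
    return out


def node_information_score(x: List[List[float]], edges: List[Tuple[int, int]],
                           num_nodes: int) -> List[float]:
    """L1 norm per node of (I - D^-1/2 A D^-1/2) X, as in NodeInformationScore."""
    edges = [(u, v) for u, v in edges if u != v]  # remove self-loops
    deg = [0.0] * num_nodes
    for src, _ in edges:
        deg[src] += 1.0
    inv_sqrt = [1.0 / math.sqrt(d) if d > 0 else 0.0 for d in deg]
    out = [list(row) for row in x]  # identity term I X
    for src, dst in edges:
        w = inv_sqrt[src] * inv_sqrt[dst]
        for f, value in enumerate(x[src]):
            out[dst][f] -= w * value  # minus the normalized neighbours
    return [sum(abs(v) for v in row) for row in out]


print(readout([[1.0], [3.0], [5.0]], batch=[0, 0, 1]))  # [[3.0, 2.0], [5.0, 5.0]]
edges = [(0, 1), (1, 0)]
print(node_information_score([[1.0], [1.0]], edges, 2))  # [0.0, 0.0]
print(node_information_score([[1.0], [3.0]], edges, 2))  # [2.0, 2.0]
```

Intuitively, a node that looks like its neighbours scores 0, while a node that differs from them gets a high score.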
In [95]:
# The AS-MAML meta-learner
class AdaptiveStepMAML(nn.Module):
    """ The Meta-Learner Class """
    def __init__(self, model: Union[GCN4MAML, SAGE4MAML], inner_lr: float, 
                 outer_lr: float, stop_lr: float, weight_decay: float) -> None:
        super().__init__()
        self.net          = model
        self.inner_lr     = inner_lr
        self.outer_lr     = outer_lr
        self.stop_lr      = stop_lr
        self.weight_decay = weight_decay

        self.task_index = 1
        self.stop_prob = 0.5
        self.stop_gate = StopControl(ASMAMLConfig.STOP_CONTROL_INPUT_SIZE, ASMAMLConfig.STOP_CONTROL_HIDDEN_SIZE)

        self.meta_optim = self.configure_optimizers()

        self.loss      = nn.CrossEntropyLoss()
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.meta_optim, mode="min", factor=0.5, 
            patience=ASMAMLConfig.PATIENCE, verbose=True, min_lr=1e-05
        )

        self.graph_embs = []
        self.graph_labels = []
        self.index = 1

    def configure_optimizers(self):
        """Configure Optimizers"""
        return optim.Adam([
                           {'params': self.net.parameters(),       'lr': self.outer_lr},
                           {'params': self.stop_gate.parameters(), 'lr': self.stop_lr}],
                          lr=1e-04, weight_decay=self.weight_decay
               )
        
    def compute_loss(self, logits: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        """Compute the loss"""
        return self.loss(logits, label.long())

    @staticmethod
    def smooth(weight, p=10, eps=1e-10):
        """
        Log/sign preprocessing of the controller inputs: values with
        |w| >= exp(-p) become (log|w| / p, sign(w)), smaller ones (-1, exp(p) * w).
        """
        weight_abs = weight.abs()
        less = (weight_abs < math.exp(-p)).type(torch.float)
        noless = 1.0 - less
        log_weight = less * -1 + noless * torch.log(weight_abs + eps) / p
        sign = less * math.exp(p) * weight + noless * weight.sign()
        assert torch.sum(torch.isnan(log_weight)) == 0, 'stop_gate input has nan'
        return log_weight, sign

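    def stop(self, step: int, loss: torch.Tensor, node_score: torch.Tensor) -> torch.Tensor:
        """Return the stop probability predicted by the stop controller at this inner step"""
        # This method does not carry the controller's hidden state (stop_hx) across inner steps
        stop_hx = None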
        if ASMAMLConfig.FLEXIBLE_STEP and step < ASMAMLConfig.MAX_STEP:
            inputs = []

            if ASMAMLConfig.USE_LOSS:
                inputs += [loss.detach()]
            if ASMAMLConfig.USE_SCORE:
                score = node_score.detach()
                inputs += [score]

            inputs = torch.stack(inputs, dim=0).unsqueeze(0)
            inputs = self.smooth(inputs)[0]
            stop_gate, stop_hx = self.stop_gate(inputs, stop_hx)

            return stop_gate    

        return loss.new_zeros(1, dtype=torch.float)

    def adapt_meta_learning_rate(self, loss):
        self.scheduler.step(loss)
    
    def get_meta_learning_rate(self):
        epoch_learning_rate = []
        for param_group in self.meta_optim.param_groups:
            epoch_learning_rate.append(param_group['lr'])
        return epoch_learning_rate[0]

    def forward(self, support_data: gdata.batch.Batch, query_data: gdata.batch.Batch):
        # It is just the number of labels to predict in the query set
        query_size = query_data.y.shape[0]

        losses_q = []  # Losses on query data
        corrects, stop_gates, train_losses, train_accs, scores = [], [], [], [], []
        
        fast_parameters = list(self.net.parameters())

        for weight in self.net.parameters():
            weight.fast = None
        
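        step = 0
        # Floor the stop probability at 0.1 so that 1 / stop_prob stays bounded
        self.stop_prob = 0.1 if self.stop_prob < 0.1 else self.stop_prob

        # Number of inner steps: the lower the stop probability left by the previous task, the more steps
        ada_step = min(ASMAMLConfig.MAX_STEP, ASMAMLConfig.MIN_STEP + int(1.0 / self.stop_prob))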

        for k in range(0, ada_step):
            # Inner step k: run the (adapted) model on the support set and compute the loss
            logits, score, _ = self.net(support_data.x, support_data.edge_index, support_data.batch)
            loss = self.compute_loss(logits, support_data.y)

            stop_probability = 0
            if ASMAMLConfig.FLEXIBLE_STEP:
                stop_probability = self.stop(k, loss, score)
                self.stop_prob = stop_probability
            
            stop_gates.append(stop_probability)
            scores.append(score.item())

            with torch.no_grad():
                pred = F.softmax(logits, dim=1).argmax(dim=1)
                correct = torch.eq(pred, support_data.y).sum().item()
                train_accs.append(correct / support_data.y.shape[0])

            step = k
            train_losses.append(loss.item())

            # Gradient of the support loss w.r.t. the current fast weights;
            # create_graph=True keeps second-order terms for the meta-update
            grad = torch.autograd.grad(loss, fast_parameters, create_graph=True)
            fast_parameters = []
            for index, weight in enumerate(self.net.parameters()):
                if weight.fast is not None:
                    weight.fast = weight.fast - self.inner_lr * grad[index]
                else:
                    weight.fast = weight - self.inner_lr * grad[index]
                
                fast_parameters.append(weight.fast)
            
            logits_q, _, _ = self.net(query_data.x, query_data.edge_index, query_data.batch)
            loss_q = self.compute_loss(logits_q, query_data.y)

            losses_q.append(loss_q)

            with torch.no_grad():
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, query_data.y).sum().item()
                corrects.append(correct)
        
        final_loss = losses_q[step]
        accs = np.array(corrects) / (query_size)
        final_acc = accs[step]
        total_loss = 0

        if ASMAMLConfig.FLEXIBLE_STEP:
            # Loop with `i` so that `step` still points at the last inner step (it is returned below)
            for i, (stop_gate, step_acc) in enumerate(zip(stop_gates[ASMAMLConfig.MIN_STEP - 1:], accs[ASMAMLConfig.MIN_STEP - 1:])):
                assert stop_gate >= 0.0 and stop_gate <= 1.0, "stop_gate error value: {:.5f}".format(stop_gate)
                log_prob = torch.log(1 - stop_gate)
                tem_loss = - log_prob * ((final_acc - step_acc - (np.exp(i) - 1) * ASMAMLConfig.STEP_PENALITY))
                total_loss += tem_loss

            total_loss = (total_loss + final_acc + final_loss)
        else:
            total_loss = final_loss

        total_loss.backward()

        if self.task_index == ASMAMLConfig.BATCH_PER_EPISODES:
            if ASMAMLConfig.GRAD_CLIP > 0.1:
                torch.nn.utils.clip_grad_norm_(self.parameters(), ASMAMLConfig.GRAD_CLIP)

            self.meta_optim.step()
            self.meta_optim.zero_grad()
            self.task_index = 1
        else:
            self.task_index += 1
        
        if ASMAMLConfig.FLEXIBLE_STEP:
            stop_gates = [stop_gate.item() for stop_gate in stop_gates]

        return accs * 100, step, final_loss.item(), total_loss.item(), stop_gates, scores, train_losses, train_accs

    def finetuning(self, support_data, query_data):
        # It is just the number of labels to predict in the query set
        query_size = query_data.y.shape[0]

        corrects = []
        step = 0
        stop_gates, scores, query_loss = [], [], []

        fast_parameters = list(self.net.parameters())

        for weight in self.net.parameters():
            weight.fast = None
        
        ada_step = min(ASMAMLConfig.STEP_TEST, ASMAMLConfig.MIN_STEP + int(2 / self.stop_prob))

        for k in range(ada_step):
            logits, score, _ = self.net(support_data.x, support_data.edge_index, support_data.batch)
            loss = self.compute_loss(logits, support_data.y)

            stop_probability = 0

            if ASMAMLConfig.FLEXIBLE_STEP:
                with torch.no_grad():
                    stop_probability = self.stop(k, loss, score)
            
            stop_gates.append(stop_probability)
            step = k
            scores.append(score.item())

            grad = torch.autograd.grad(loss, fast_parameters, create_graph=True)
            fast_parameters = []

            for index, weight in enumerate(self.net.parameters()):
                # Use the meta-learner's own inner learning rate, as in forward()
                if weight.fast is None:
                    weight.fast = weight - self.inner_lr * grad[index]
                else:
                    weight.fast = weight.fast - self.inner_lr * grad[index]

                fast_parameters.append(weight.fast)

            logits_q, _, graph_emb = self.net(query_data.x, query_data.edge_index, query_data.batch)
            self.graph_labels.append(query_data.y)

            self.graph_embs.append(graph_emb)

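            # Note: `self.index % 1 == 0` is always True, so the embedding/label buffers are cleared at every step
            if self.index % 1 == 0: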
                self.index = 1
                self.graph_embs = []
                self.graph_labels = []
            else:
                self.index += 1
            
            with torch.no_grad():
                pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
                correct = torch.eq(pred_q, query_data.y).sum().item()
                corrects.append(correct)
                
                loss_query = self.compute_loss(logits_q, query_data.y)
                query_loss.append(loss_query.item())

        accs = 100 * np.array(corrects) / query_size

        if ASMAMLConfig.FLEXIBLE_STEP:
            stop_gates = [stop_gate.item() for stop_gate in stop_gates]

        return accs, step, stop_gates, scores, query_loss

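Two rules in `forward` decide how long adaptation runs. The number of inner steps is `min(MAX_STEP, MIN_STEP + int(1 / stop_prob))`, with `stop_prob` floored at 0.1 and taken from the last stop probability of the previous task, so a controller that is unlikely to stop allows more steps. The stop gates $g_i$ are then trained with a REINFORCE-style term: from step `MIN_STEP` on, $-\log(1 - g_i)$ is weighted by the final query accuracy minus the accuracy at that step, minus a step penalty growing as $e^{i} - 1$. When later steps still improve accuracy this pushes the gate towards 0 (keep adapting), while the penalty pushes it towards stopping. The sketch below isolates both rules as hypothetical functions `adaptive_step_count` and `stop_policy_loss`; it returns only the policy term, whereas `forward` also adds `final_acc` and the final query loss.

```python
import math
from typing import Sequence

import torch


def adaptive_step_count(stop_prob: float, min_step: int, max_step: int) -> int:
    """Inner-step budget computed as in AdaptiveStepMAML.forward (sketch)."""
    stop_prob = max(stop_prob, 0.1)  # same 0.1 floor as in forward()
    return min(max_step, min_step + int(1.0 / stop_prob))


def stop_policy_loss(stop_gates: Sequence[torch.Tensor], step_accs: Sequence[float],
                     min_step: int, step_penalty: float) -> torch.Tensor:
    """REINFORCE-style term on the stop gates, mirroring the FLEXIBLE_STEP branch."""
    final_acc = step_accs[-1]  # query accuracy after the last inner step
    total = torch.zeros(())
    pairs = zip(stop_gates[min_step - 1:], step_accs[min_step - 1:])
    for i, (gate, acc) in enumerate(pairs):
        # Positive reward if later steps improved accuracy; the penalty grows as e^i - 1
        reward = final_acc - acc - (math.exp(i) - 1) * step_penalty
        total = total - torch.log(1 - gate) * reward
    return total


print(adaptive_step_count(0.5, min_step=2, max_step=20))   # → 4
print(adaptive_step_count(0.01, min_step=2, max_step=20))  # → 12 (floored to 0.1)
gates = [torch.tensor(0.5, requires_grad=True), torch.tensor(0.5, requires_grad=True)]
loss = stop_policy_loss(gates, [0.4, 0.6], min_step=1, step_penalty=0.0)
loss.backward()  # the first gate is pushed towards 0: stopping there would have cost accuracy
```
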
---

## Training and Testing
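
Training relies on how `AdaptiveStepMAML.forward` performs the meta-update: every task calls `total_loss.backward()`, so gradients accumulate in `.grad`, and the meta-optimizer only steps (after optional gradient clipping) once `BATCH_PER_EPISODES` tasks have been processed. The sketch below reproduces that counter logic with a hypothetical `meta_train_loop` on a toy linear model whose "task loss" is just the sum of its outputs; it is a stand-in for the real few-shot tasks, not part of the notebook.

```python
import torch
import torch.nn as nn


def meta_train_loop(model: nn.Module, optimizer: torch.optim.Optimizer,
                    task_inputs: list, batch_per_episodes: int, grad_clip: float = 0.0) -> int:
    """Call backward() once per task and step the optimizer every `batch_per_episodes` tasks."""
    task_index, n_updates = 1, 0
    for x in task_inputs:
        loss = model(x).sum()  # stand-in for the per-task total_loss
        loss.backward()        # gradients add up across tasks until the next step
        if task_index == batch_per_episodes:
            if grad_clip > 0.1:  # same threshold as the notebook
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            task_index, n_updates = 1, n_updates + 1
        else:
            task_index += 1
    return n_updates


model = nn.Linear(2, 1, bias=False)
with torch.no_grad():
    model.weight.fill_(1.0)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
tasks = [torch.tensor([[1.0, 2.0]])] * 6
n_updates = meta_train_loop(model, optimizer, tasks, batch_per_episodes=4)
```

With 6 tasks and `batch_per_episodes=4` there is one update, and the gradients of the last two tasks stay in `.grad` until the next update; in the notebook `task_index` is an attribute of the meta-learner, so such leftovers carry over into the next epoch.
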

In [104]:
# First of all, let's define two Python classes: one for training (optimization) and one for testing
class Optimizer:
    """
    Run Training with train set and validation set

    Attributes:
        train_ds (GraphDataset): the train set
        val_ds (GraphDataset): the validation set
        model_name (str, default="sage"): the name of the model to use ('sage' or 'gcn')
        epochs (int, default=200): number of epochs to run
        dataset_name (str, default="TRIANGLES"): the name of the dataset
        config_class (object, default=ASMAMLConfig): the class holding the hyper-parameters
    """
    def __init__(self, train_ds: GraphDataset, val_ds: GraphDataset, 
                 model_name: str="sage", epochs: int=200, 
                 dataset_name: str="TRIANGLES", config_class: object=ASMAMLConfig
    ) -> None:
        self.train_ds = train_ds
        self.val_ds = val_ds
        self.model_name = model_name
        self.epochs = epochs
        self.dataset_name = dataset_name
        self.config = config_class

        self.model = self.get_model()
        self.meta_model = self.get_meta()

    def get_model(self) -> Union[GCN4MAML, SAGE4MAML]:
        """Return the model to use with the MetaModel"""
        models = {'sage': SAGE4MAML, 'gcn': GCN4MAML}
        model = models[self.model_name](num_classes=self.config.TRAIN_WAY).to(DEVICE)
        print(f"Creating model of type {model.__class__.__name__}")
        return model

    def get_meta(self) -> AdaptiveStepMAML:
        """Return the meta model"""
        print(f"Creating the AS-MAML model")
        return AdaptiveStepMAML(self.model,
                                inner_lr=self.config.INNER_LR,
                                outer_lr=self.config.OUTER_LR,
                                stop_lr=self.config.STOP_LR,
                                weight_decay=self.config.WEIGHT_DECAY).to(DEVICE)

    def get_dataloaders(self) -> Tuple[FewShotDataLoader, FewShotDataLoader]:
        """Return train and validation dataloaders"""
        print("--- Creating the DataLoader for Training ---")
        train_dataloader = get_dataloader(
            ds=self.train_ds, n_way=self.config.TRAIN_WAY, k_shot=self.config.TRAIN_SHOT,
            n_query=self.config.TRAIN_QUERY, epoch_size=self.config.TRAIN_EPISODE,
            shuffle=True, batch_size=1
        )

        print("--- Creating the DataLoader for Validation ---")
        validation_dataloader = get_dataloader(
            ds=self.val_ds, n_way=self.config.TEST_WAY, k_shot=self.config.VAL_SHOT,
            n_query=self.config.VAL_QUERY, epoch_size=self.config.VAL_EPISODE,
            shuffle=True, batch_size=1
        )

        return train_dataloader, validation_dataloader
    
    def run_one_step_train(
        self, support_data: gdata.Data, query_data: gdata.Data, train_accs: List[float],
        train_total_losses: List[float], train_final_losses: List[float], loop_counter: int
    ) -> None:
        """Run one episode, i.e. one or more tasks, of training"""
        # Move support and query data to DEVICE (pinning host memory first when a GPU is used)
        if DEVICE != "cpu":
            support_data = support_data.pin_memory()
            query_data = query_data.pin_memory()
            
        support_data = support_data.to(DEVICE)
        query_data = query_data.to(DEVICE)

        accs, step, final_loss, total_loss, _, _, _, _ = self.meta_model(
            support_data, query_data
        )

        train_accs.append(accs[step])
        train_final_losses.append(final_loss)
        train_total_losses.append(total_loss)

        if (loop_counter + 1) % 50 == 0:
            print(f"({loop_counter + 1})" + " Mean Accuracy: {:.6f}, Mean Final Loss: {:.6f}, Mean Total Loss: {:.6f}".format(
                np.mean(train_accs), np.mean(train_final_losses), np.mean(train_total_losses)
                ))
            
    def run_one_step_validation(self, support_data: gdata.Data, 
                                      query_data: gdata.Data, 
                                      val_accs: List[float], 
                                      loop_counter: int) -> None:
        """Run one episode, i.e. one or more tasks, of validation"""
        if DEVICE != "cpu":
            support_data = support_data.pin_memory()
            query_data = query_data.pin_memory()

        support_data = support_data.to(DEVICE)
        query_data = query_data.to(DEVICE)
        
        accs, step, _, scores, query_losses = self.meta_model.finetuning(support_data, query_data)
        acc = get_max_acc(accs, step, scores, self.config.MIN_STEP, self.config.MAX_STEP)

        val_accs.append(accs[step])
        if (loop_counter + 1) % 50 == 0:
            printable_string = f"Test Number {loop_counter + 1}\n" + \
                                "\tQuery Losses[{l}]: {query_losses}\n\tAccuracies {step}: {accs}\n\tMax Accuracy: {max_acc}\n".format(
                                    l=len(query_losses), query_losses=query_losses, step=step,
                                    accs=np.array([accs[i] for i in range(0, step + 1)]), max_acc=acc
                                )

            print(printable_string)
    
    @elapsed_time
    def optimize(self):
        """Run the optimization (fitting)"""
        train_dl, val_dl = self.get_dataloaders()
        max_val_acc = 0
        print("=" * 40 + " Starting Optimization " + "=" * 40)

        for epoch in range(self.epochs):
            setup_seed(epoch)
            print("=" * 103)
            print("Epoch Number {:04d}".format(epoch))

            self.meta_model.train()
            train_accs, train_final_losses, train_total_losses, val_accs = [], [], [], []

            print("Training Phase")

            for i, data in enumerate(tqdm(train_dl)):
                support_data, query_data = data
                self.run_one_step_train(
                    support_data=support_data, query_data=query_data,
                    train_accs=train_accs, train_total_losses=train_total_losses,
                    train_final_losses=train_final_losses, loop_counter=i
                )
            
            print("Ended Training Phase")
            print("Validation Phase")

            self.meta_model.eval()
            for i, data in enumerate(tqdm(val_dl)):
                support_data, query_data = data
                self.run_one_step_validation(
                    support_data=support_data, query_data=query_data,
                    val_accs=val_accs, loop_counter=i
                )
            
            print("Ended Validation Phase")

            val_acc_avg = np.mean(val_accs)
            train_acc_avg = np.mean(train_accs)
            train_loss_avg = np.mean(train_final_losses)
            val_acc_ci95 = 1.96 * np.std(np.array(val_accs)) / np.sqrt(len(val_accs))

            if val_acc_avg > max_val_acc:
                max_val_acc = val_acc_avg
                printable_string = "Epoch(***Best***) {:04d}\n".format(epoch)

                torch.save({
                        'epoch': epoch, 
                        'embedding': self.meta_model.state_dict()
                    }, os.path.join(MODELS_SAVE_PATH, f'{self.dataset_name}_BestModel.pth')
                )
            else:
                printable_string = "Epoch {:04d}\n".format(epoch)

            printable_string += "\tAvg Train Loss: {:.6f}, Avg Train Accuracy: {:.6f}\n".format(train_loss_avg, train_acc_avg) + \
                                "\tAvg Validation Accuracy: {:.2f} ±{:.2f}\n".format(val_acc_avg, val_acc_ci95) + \
                                "\tMeta Learning Rate: {:.6f}\n".format(self.meta_model.get_meta_learning_rate()) + \
                                "\tBest Current Validation Accuracy: {:.2f}".format(max_val_acc)

            print(printable_string)
            self.meta_model.adapt_meta_learning_rate(train_loss_avg)

        print("Optimization Finished")


class Tester:
    """Run tests on the test set using the best model saved during training"""
    def __init__(self, test_ds: GraphDataset, best_model_path: str, config_class: object=ASMAMLConfig,
                       dataset_name: str="TRIANGLES", model_name: str="sage") -> None:
        self.test_ds = test_ds
        self.dataset_name = dataset_name
        self.model_name = model_name
        self.best_model_path = best_model_path
        self.config = config_class

        self.model = self.get_model()
        self.meta_model = self.get_meta()

        # Use the pre-trained model, i.e. the best model saved during training
        # (map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine)
        saved_models = torch.load(self.best_model_path, map_location=DEVICE)
        self.meta_model.load_state_dict(saved_models["embedding"])
        self.model = self.meta_model.net

    def get_model(self) -> Union[GCN4MAML, SAGE4MAML]:
        """Return the model to use with the MetaModel"""
        models = {'sage': SAGE4MAML, 'gcn': GCN4MAML}
        model = models[self.model_name](num_classes=self.config.TRAIN_WAY).to(DEVICE)
        print(f"Creating model of type {model.__class__.__name__}")
        return model

    def get_meta(self) -> AdaptiveStepMAML:
        """Return the meta model"""
        print(f"Creating the AS-MAML model")
        return AdaptiveStepMAML(self.model,
                                inner_lr=self.config.INNER_LR,
                                outer_lr=self.config.OUTER_LR,
                                stop_lr=self.config.STOP_LR,
                                weight_decay=self.config.WEIGHT_DECAY).to(DEVICE)
    
    def run_one_step_test(self, support_data: gdata.Data, query_data: gdata.Data, 
                                val_accs: List[float], query_losses_list: List[float]) -> None:
        """Run one single step of testing"""
        if DEVICE != "cpu":
            support_data = support_data.pin_memory()
            query_data = query_data.pin_memory()

        support_data = support_data.to(DEVICE)
        query_data = query_data.to(DEVICE)

        accs, step, _, _, query_losses = self.meta_model.finetuning(support_data, query_data)

        val_accs.append(accs[step])
        query_losses_list.extend(query_losses)
    
    def get_dataloader(self) -> FewShotDataLoader:
        """Return test dataloader"""
        print("--- Creating the DataLoader for Testing ---")
        test_dataloader = get_dataloader(
            ds=self.test_ds, n_way=self.config.TEST_WAY, k_shot=self.config.VAL_SHOT,
            n_query=self.config.VAL_QUERY, epoch_size=self.config.VAL_EPISODE,
            shuffle=True, batch_size=1
        )

        return test_dataloader
    
    @elapsed_time
    def test(self):
        """Run testing"""
        setup_seed(1)

        test_dl = self.get_dataloader()

        print("=" * 40 + " Starting Testing " + "=" * 40)

        val_accs = []
        query_losses_list = []
        self.meta_model.eval()

        for _, data in enumerate(tqdm(test_dl)):
            support_data, query_data = data
            self.run_one_step_test(support_data, query_data, val_accs, query_losses_list)
        
        val_acc_avg = np.mean(val_accs)
        val_acc_ci95 = 1.96 * np.std(np.array(val_accs)) / np.sqrt(self.config.VAL_EPISODE)
        query_losses_avg = np.array(query_losses_list).mean()
        query_losses_min = np.array(query_losses_list).min()

        printable_string = (
            "\nTEST FINISHED --- Results\n"        +
            "\tTesting Accuracy: {:.2f} ±{:.2f}\n" + 
            "\tQuery Losses Avg: {:.6f}\n"         +
            "\tMin Query Loss: {:.6f}\n"
            ).format(
                val_acc_avg, val_acc_ci95,
                query_losses_avg, query_losses_min
            )

        print(printable_string)

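`optimize` calls `adapt_meta_learning_rate(train_loss_avg)` once per epoch, which steps the `ReduceLROnPlateau` scheduler built in `AdaptiveStepMAML.__init__` (`mode="min"`, `factor=0.5`, `min_lr=1e-05`). The learning rates are halved once the average training loss has failed to improve for more than `patience` consecutive epochs, and because the scheduler wraps the whole Adam optimizer, the backbone and the stop-controller groups are halved together. The sketch below replays this with dummy parameters, made-up losses, hypothetical learning rates standing in for `OUTER_LR` and `STOP_LR`, and `patience=2` in place of `ASMAMLConfig.PATIENCE`.

```python
import torch
from torch import optim

# Two dummy parameter groups standing in for the backbone and the stop controller
net_params = [torch.nn.Parameter(torch.zeros(1))]
stop_params = [torch.nn.Parameter(torch.zeros(1))]
meta_optim = optim.Adam([
    {"params": net_params, "lr": 1e-3},   # hypothetical OUTER_LR
    {"params": stop_params, "lr": 1e-2},  # hypothetical STOP_LR
], lr=1e-4)  # unused: every group sets its own lr
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    meta_optim, mode="min", factor=0.5, patience=2, min_lr=1e-05
)

lr_history = []
for epoch_loss in [1.0, 0.9, 0.9, 0.9, 0.9]:  # made-up average training losses
    scheduler.step(epoch_loss)
    lr_history.append([group["lr"] for group in meta_optim.param_groups])

for epoch, lrs in enumerate(lr_history):
    print(epoch, lrs)
```

The loss stops improving after the second epoch, so the third non-improving epoch (the fifth call) exceeds `patience=2` and halves both learning rates.
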
In [None]:
# Run training and then testing
torch.set_printoptions(edgeitems=EDGELIMIT_PRINT)

dataset_name = DEFAULT_DATASET
train_ds, test_ds, val_ds, _ = get_dataset(
    download=DOWNLOAD_DATASET, 
    data_dir=DATA_PATH,
    dataset_name=dataset_name
)

print("--- Datasets ---")
print("\n- Train: ", train_ds)
print("- Test : ", test_ds)
print("- Validation: ", val_ds)
print()

print("--- Configurations ---")

configurations = ("\nDEVICE: {device}\n"                            +
                    "DATASET NAME: {dataset_name}\n"                + 
                    "TRAIN SUPPORT SIZE: {train_support_size}\n"    +
                    "TRAIN QUERY SIZE: {train_query_size}\n"        +
                    "VALIDATION SUPPORT SIZE: {val_support_size}\n" +
                    "VALIDATION QUERY SIZE: {val_query_size}\n"     +
                    "TEST SUPPORT SIZE: {test_support_size}\n"      +
                    "TEST QUERY SIZE: {test_query_size}\n"          +
                    "TRAIN EPISODE: {train_episode}\n"              +
                    "VALIDATION EPISODE: {val_episode}\n"           +
                    "NUMBER OF EPOCHS: {number_of_epochs}\n"        +
                    "BATCH PER EPISODES: {batch_per_episodes}\n"
    ).format(
        device=DEVICE, dataset_name=dataset_name,
        train_support_size=f"{ASMAMLConfig.TRAIN_WAY} x {ASMAMLConfig.TRAIN_SHOT}",
        train_query_size=f"{ASMAMLConfig.TRAIN_WAY} x {ASMAMLConfig.TRAIN_QUERY}",
        val_support_size=f"{ASMAMLConfig.TEST_WAY} x {ASMAMLConfig.VAL_SHOT}",
        val_query_size=f"{ASMAMLConfig.TEST_WAY} x {ASMAMLConfig.VAL_QUERY}",
        test_support_size=f"{ASMAMLConfig.TEST_WAY} x {ASMAMLConfig.VAL_SHOT}",
        test_query_size=f"{ASMAMLConfig.TEST_WAY} x {ASMAMLConfig.VAL_QUERY}",
        train_episode=ASMAMLConfig.TRAIN_EPISODE, val_episode=ASMAMLConfig.VAL_EPISODE,
        number_of_epochs=ASMAMLConfig.EPOCHS, batch_per_episodes=ASMAMLConfig.BATCH_PER_EPISODES
    )

print(configurations)

optimizer = Optimizer(train_ds, val_ds, epochs=ASMAMLConfig.EPOCHS, dataset_name=dataset_name)
optimizer.optimize()

best_model_path = os.path.join(MODELS_SAVE_PATH, f"{dataset_name}_BestModel.pth")
tester = Tester(test_ds, best_model_path, dataset_name=dataset_name)
tester.test()
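
Both `optimize` and `Tester.test` report the mean accuracy over episodes together with $1.96 \cdot \sigma / \sqrt{n}$, where $n$ is the number of evaluation episodes: the half-width of a normal-approximation 95% confidence interval for the mean. `np.std` defaults to the population standard deviation (`ddof=0`); the sample version (`ddof=1`) gives a slightly wider interval for small $n$. A minimal standalone check with made-up accuracies and a hypothetical helper `mean_and_ci95`:

```python
import numpy as np


def mean_and_ci95(accs, ddof: int = 0):
    """Mean and 95% CI half-width (normal approximation) over episode accuracies."""
    a = np.asarray(accs, dtype=float)
    return a.mean(), 1.96 * a.std(ddof=ddof) / np.sqrt(len(a))


mean, half_width = mean_and_ci95([50.0, 60.0, 70.0, 80.0])  # made-up episode accuracies
print("Accuracy: {:.2f} ±{:.2f}".format(mean, half_width))  # Accuracy: 65.00 ±10.96
```
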