# Few-Shot Graph Classification

Most graph classification methods overlook the fact that, in many situations, labeled graphs are scarce. *Few-Shot Learning* has been adopted to address this problem. It is a family of machine learning methods designed for settings where the training data contain only a few labeled examples per class. The usual practice is to feed a model as much data as possible, since more data generally leads to better predictions; few-shot learning, by contrast, aims to build accurate models from far less training data. Few-shot learning, and in this case few-shot classification in particular, therefore reduces the cost of collecting and labeling huge amounts of data.

*What is the idea behind Few-Shot Learning on graphs?* Given graph data $\mathcal{G} = \{(G_1, \mathbf{y}_1), ..., (G_n, \mathbf{y}_n)\}$, we split it into a training set $\{(G^{train}, \mathbf{y}^{train})\}$ and a test set $\{(G^{test}, \mathbf{y}^{test})\}$. Note that $\mathbf{y}^{train}$ and $\mathbf{y}^{test}$ must have no classes in common.

Training follows the *episodic* paradigm: at each training step the algorithm samples a so-called *task*, i.e., a pair (*support* set, *query* set). The support set is $D_{sup}^{train} = \{(G_i^{train}, \mathbf{y}_{i}^{train})\}_{i=1}^s$ with $s = N \times K$, while the query set is $D_{que}^{train} = \{(G_i^{train}, \mathbf{y}_{i}^{train})\}_{i=1}^q$, where $q$ is the number of query graphs. Given the labeled support data, the goal is to predict the labels of the query data. Within a single task, support and query data share the same class space. This setting is called **N-way-K-shot** learning, where **N** is the number of sampled classes and **K** is the number of samples for each of the N classes.

At test stage, when performing classification tasks on unseen classes, we first fine-tune the meta-learner on the support data of the test classes, then report classification performance on the test query set.
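The episodic setup above can be sketched in plain Python. This is a minimal illustration, not the sampler used by AS-MAML: graphs are treated as opaque objects (strings in the demo), `split_by_class` and `sample_task` are hypothetical helper names, and `q_query` is assumed to count query graphs *per class*, so a task holds $N \times K$ support graphs and $N \times q_{query}$ query graphs. Labels are remapped to $0, \dots, N-1$ inside each task, a common convention in episodic training.

```python
import random
from collections import defaultdict
from typing import Any, Dict, List, Sequence, Set, Tuple

Sample = Tuple[Any, int]  # (graph, class label); the graph object is opaque here


def split_by_class(dataset: Sequence[Sample], test_classes: Set[int]) -> Tuple[List[Sample], List[Sample]]:
    """Hypothetical helper: split so that train and test share no classes."""
    train = [s for s in dataset if s[1] not in test_classes]
    test = [s for s in dataset if s[1] in test_classes]
    return train, test


def sample_task(data: Sequence[Sample], n_way: int, k_shot: int, q_query: int,
                rng: random.Random) -> Tuple[List[Sample], List[Sample]]:
    """Hypothetical helper: sample one N-way-K-shot task as (support, query)."""
    by_class: Dict[int, List[Any]] = defaultdict(list)
    for graph, label in data:
        by_class[label].append(graph)

    # A class is usable only if it has enough graphs for both support and query
    eligible = sorted(c for c, graphs in by_class.items() if len(graphs) >= k_shot + q_query)
    if len(eligible) < n_way:
        raise ValueError("not enough classes with k_shot + q_query graphs")

    support: List[Sample] = []
    query: List[Sample] = []
    for task_label, cls in enumerate(rng.sample(eligible, n_way)):
        # Draw without replacement so no graph is both a support and a query sample
        picked = rng.sample(by_class[cls], k_shot + q_query)
        support += [(g, task_label) for g in picked[:k_shot]]
        query += [(g, task_label) for g in picked[k_shot:]]
    return support, query


# Toy dataset: 6 classes with 20 graphs each; classes 4 and 5 are held out for testing
dataset = [(f"g{i}", i % 6) for i in range(120)]
train, test = split_by_class(dataset, test_classes={4, 5})
support, query = sample_task(train, n_way=3, k_shot=5, q_query=4, rng=random.Random(0))
print(len(support), len(query))  # 15 12
```

Because `split_by_class` removes the test classes before any task is sampled, training tasks can never contain a test label, which is exactly the disjointness requirement on $\mathbf{y}^{train}$ and $\mathbf{y}^{test}$.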

In the following, I'm going to present some approaches to few-shot learning. First, a *Meta-Learning Framework* based on Fast Weight Adaptation, taken from the paper [Adaptive-Step Graph Meta-Learner for Few-Shot Graph Classification](https://arxiv.org/pdf/2003.08246.pdf) (Ning Ma et al.). Second, I'm going to compare it with different GDA (graph data augmentation) techniques, used to enrich the dataset for the novel classes (i.e., those with less data), taken from a second paper, [Graph Data Augmentation for Graph Machine Learning: A Survey](https://arxiv.org/pdf/2202.08871.pdf) (Tong Zhao et al.).

## Modules and Constants

In [None]:
import torch
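TORCH = torch.__version__.split('+')[0]
# torch.version.cuda is None on CPU-only builds; PyG publishes "cpu" wheels in that case
CUDA = 'cu' + torch.version.cuda.replace('.', '') if torch.version.cuda else 'cpu'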

!pip install pytorch-lightning
!pip install pyyaml==5.4.1
!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric

In [None]:
from typing import (
    Any, Dict, List, Tuple, 
    Union, Generic, Optional,
    TypeVar
)

from tqdm.notebook import tqdm
from functools import wraps
import plotly.graph_objects as go
import networkx as nx
import numpy as np
import pickle
import os
import shutil
import logging
import random
import time
import requests
import zipfile
import math
import sys

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter

import torch_geometric.data as gdata
import torch_geometric.loader as gloader
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn import global_mean_pool, global_max_pool
from torch_geometric.nn.inits import uniform
from torch_geometric.nn.pool.topk_pool import topk, filter_adj
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.utils import (
    add_remaining_self_loops, 
    add_self_loops, 
    remove_self_loops,
    softmax
)

from torch_scatter import scatter_add

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

In [None]:
TRIANGLES_ZIP_URL = "https://cloud-storage.eu-central-1.linodeobjects.com/TRIANGLES.zip"
COIL_DEL_ZIP_URL = "https://cloud-storage.eu-central-1.linodeobjects.com/COIL-DEL.zip"
R52_ZIP_URL = "https://cloud-storage.eu-central-1.linodeobjects.com/R52.zip"
LETTER_HIGH_ZIP_URL = "https://cloud-storage.eu-central-1.linodeobjects.com/Letter-High.zip"

DATASETS = {
    "TRIANGLES"   : TRIANGLES_ZIP_URL, 
    "COIL-DEL"    : COIL_DEL_ZIP_URL, 
    "R52"         : R52_ZIP_URL, 
    "Letter-High" : LETTER_HIGH_ZIP_URL
}

T = TypeVar('T')

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DOWNLOAD_DATASET = True
SAVE_PICLKE  = True
EDGELIMIT_PRINT = 2000

NUM_FEATURES = {"TRIANGLES": 1, "R52": 1, "Letter-High": 2, "COIL-DEL": 2}


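class ASMAMLConfig:
    NHID = 128            # Hidden dimension of the graph encoder
    POOLING_RATIO = 0.5   # Fraction of nodes kept by each pooling layer
    DROPOUT_RATIO = 0.3

    OUTER_LR     = 0.001   # Meta (outer-loop) learning rate
    INNER_LR     = 0.01    # Fast-adaptation (inner-loop) learning rate
    STOP_LR      = 0.0001  # Learning rate of the adaptive step controller
    WEIGHT_DECAY = 1E-05

    MAX_STEP      = 15     # Maximum number of inner-loop adaptation steps
    MIN_STEP      = 5      # Minimum number of inner-loop adaptation steps
    STEP_TEST     = 15     # Adaptation steps at test time
    FLEXIBLE_STEP = True   # Let the step controller choose the number of steps
    STEP_PENALITY = 0.001  # Penalty on the number of adaptation steps
    USE_SCORE     = True   # Feed the node-pooling score to the step controller
    USE_GRAD      = False  # Feed the gradient to the step controller
    USE_LOSS      = True   # Feed the loss to the step controller

    TRAIN_SHOT         = 10   # K-shot for the training set
    VAL_SHOT           = 10   # K-shot for the validation (or test) set
    TRAIN_QUERY        = 15   # Number of query graphs for the training set
    VAL_QUERY          = 15   # Number of query graphs for the validation (or test) set
    TRAIN_WAY          = 3    # N-way for the training set
    TEST_WAY           = 3    # N-way for the test set
    VAL_EPISODE        = 200  # Number of episodes for validation
    TRAIN_EPISODE      = 200  # Number of episodes for training
    BATCH_PER_EPISODES = 5    # Number of batches per episode
    EPOCHS             = 500  # Number of training epochs
    PATIENCE           = 35   # Early-stopping patience
    GRAD_CLIP          = 5    # Gradient-clipping threshold

    # Stop-controller configuration
    STOP_CONTROL_INPUT_SIZE = 2   # One input per enabled signal (score and loss)
    STOP_CONTROL_HIDDEN_SIZE = 20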

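As a worked example, plugging the training values of `ASMAMLConfig` into $s = N \times K$ (with $N$ = `TRAIN_WAY` and $K$ = `TRAIN_SHOT`) gives the support-set size of each training task. The config comment does not say whether `TRAIN_QUERY` counts query graphs per class or per task; the query line below assumes it is per class, so treat it as a sketch under that assumption.

$$
s = N \times K = 3 \times 10 = 30, \qquad q = N \times 15 = 3 \times 15 = 45 \quad \text{(assuming 15 queries per class)}
$$

If `TRAIN_QUERY` is instead the total number of query graphs in a task, then $q = 15$.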
## Utility Functions

In [None]:
def scandir(root_path: str) -> List[str]:
    """Recursively scan a directory looking for files"""
    root_path = os.path.abspath(root_path)
    content = []
    for file in os.listdir(root_path):
        new_path = os.path.join(root_path, file)
        if os.path.isfile(new_path):
            content.append(new_path)
            continue
        
        content += scandir(new_path)
    
    return content


def download_zipped_data(url: str, path2extract: str, dataset_name: str) -> List[str]:
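    """Download and extract a ZIP file from URL. Return the extracted filenames"""
    logging.debug(f"--- Downloading from {url} ---")
    response = requests.get(url)
    response.raise_for_status()  # Fail early on HTTP errors instead of saving an error page

    abs_path2extract = os.path.abspath(path2extract)
    os.makedirs(abs_path2extract, exist_ok=True)  # Make sure the target folder exists
    zip_path = os.path.join(abs_path2extract, f"{dataset_name}.zip")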
    with open(zip_path, mode="wb") as iofile:
        iofile.write(response.content)

    # Extract the file
    logging.debug("--- Extracting files from the archive ---")
    with zipfile.ZipFile(zip_path, mode="r") as zip_ref:
        zip_ref.extractall(abs_path2extract)

    logging.debug(f"--- Removing {zip_path} ---")
    os.remove(zip_path)

    return scandir(os.path.join(path2extract, dataset_name))


def delete_data_folder(path2delete: str) -> None:
    """Delete the folder containing data"""
    logging.debug("--- Removing Content Data ---")
    shutil.rmtree(path2delete)
    logging.debug("--- Removal Finished Successfully ---")

In [None]:
def elapsed_time(func):
    """A simple decorator that logs the elapsed time of each call"""
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        end = time.perf_counter()
        logging.debug("Elapsed Time: {:.6f}s".format(end - start))
        return result  # Propagate the wrapped function's return value
    
    return wrapper

In [None]:
def save_with_pickle(path2save: str, content: Any) -> None:
    """Save content inside a .pickle file denoted by path2save"""
    path2save = path2save if path2save.endswith(".pickle") else path2save + ".pickle"
    with open(path2save, mode="wb") as iostream:
        pickle.dump(content, iostream)


def load_with_pickle(path2load: str) -> Any:
    """Load a content from a .pickle file"""
    with open(path2load, mode="rb") as iostream:
        return pickle.load(iostream)

In [None]:
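def setup_seed(seed=42):
    """Seed Python, NumPy and PyTorch (CPU and all GPUs) for reproducible runs"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # Autotuning can pick non-deterministic kernels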