# Few-Shot Graph Classification

In [1]:
import torch
TORCH = torch.__version__.split('+')[0]
CUDA = 'cu' + torch.version.cuda.replace('.','')

!pip install pytorch-lightning
!pip install pyyaml==5.4.1
!pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
!pip install torch-geometric

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-lightning
  Downloading pytorch_lightning-1.7.1-py3-none-any.whl (701 kB)
[K     |████████████████████████████████| 701 kB 8.7 MB/s 
Collecting PyYAML>=5.4
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 49.3 MB/s 
Collecting torchmetrics>=0.7.0
  Downloading torchmetrics-0.9.3-py3-none-any.whl (419 kB)
[K     |████████████████████████████████| 419 kB 75.7 MB/s 
[?25hCollecting pyDeprecate>=0.3.1
  Downloading pyDeprecate-0.3.2-py3-none-any.whl (10 kB)
Collecting tensorboard>=2.9.1
  Downloading tensorboard-2.9.1-py3-none-any.whl (5.8 MB)
[K     |████████████████████████████████| 5.8 MB 51.1 MB/s 
[?25hCollecting fsspec[http]!=2021.06.0,>=2021.05.0
  Downloading fsspec-2022.7.1-py3-none-any.whl (141 kB)
[K     |██████████████████████

In [2]:
import os
import random
from collections import defaultdict
from typing import Dict, List, Tuple, Union, \
                   Sequence, TypeVar, Generic, \
                   Optional, Iterator

import networkx as nx
import numpy as np
import plotly
import plotly.express as px
import plotly.graph_objects as go

from torchvision import transforms
import torch
from torch import optim
from torch.nn import functional as F
import torch.nn as nn

import torch_geometric as gtorch
import torch_geometric.nn as gnn
import torch_geometric.data as gdata
import torch_geometric.loader as gloader
import torch_geometric.utils as gutils

  defaults = yaml.load(f)


## Load dataset

The dataset I will load is TRIANGLES, one of the datasets used in the AS-MAML paper (here's the [link](https://arxiv.org/pdf/2003.08246.pdf)).

In [3]:
ROOT_PATH = os.getcwd()
GRAPH_ATTRIBUTE  = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_graph_attributes.txt")
GRAPH_LABELS     = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_graph_labels.txt")
NODE_NATTRIBUTE  = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_node_attributes.txt")
GRAPH_INDICATOR  = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_graph_indicator.txt")
GRAPH_A          = os.path.join(ROOT_PATH, "TRIANGLES/TRIANGLES_A.txt")

T = TypeVar('T')

LOAD_DATASET = True

if LOAD_DATASET:
    !wget https://cloud-storage.eu-central-1.linodeobjects.com/TRIANGLES.zip
    !unzip TRIANGLES.zip

--2022-08-10 13:57:32--  https://cloud-storage.eu-central-1.linodeobjects.com/TRIANGLES.zip
Resolving cloud-storage.eu-central-1.linodeobjects.com (cloud-storage.eu-central-1.linodeobjects.com)... 172.105.80.252, 139.162.182.14, 172.105.69.135, ...
Connecting to cloud-storage.eu-central-1.linodeobjects.com (cloud-storage.eu-central-1.linodeobjects.com)|172.105.80.252|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 6339762 (6.0M) [application/x-zip-compressed]
Saving to: ‘TRIANGLES.zip’


2022-08-10 13:57:35 (4.56 MB/s) - ‘TRIANGLES.zip’ saved [6339762/6339762]

Archive:  TRIANGLES.zip
   creating: TRIANGLES/
  inflating: TRIANGLES/README.txt    
  inflating: TRIANGLES/TRIANGLES_A.txt  
  inflating: TRIANGLES/TRIANGLES_graph_attributes.txt  
  inflating: TRIANGLES/TRIANGLES_graph_indicator.txt  
  inflating: TRIANGLES/TRIANGLES_graph_labels.txt  
  inflating: TRIANGLES/TRIANGLES_node_attributes.txt  


In [4]:
def _read_lines(path: str) -> List[str]:
    """Return every line of the text file at `path` (newlines kept), closing the file afterwards."""
    with open(path, encoding="utf-8") as fh:
        return fh.readlines()

graph_attribute      = _read_lines(GRAPH_ATTRIBUTE)
graph_labels         = _read_lines(GRAPH_LABELS)
graph_node_attribute = _read_lines(NODE_NATTRIBUTE)
graph_indicator      = _read_lines(GRAPH_INDICATOR)
graph_a              = _read_lines(GRAPH_A)

In [5]:
class GeneratorTxt2Graph:
    """Build networkx graphs from the TU-format text files of a dataset.

    Every keyword argument holds the raw lines of one file (as returned by
    ``readlines()``):

    - ``graph_indicator``: line ``i`` holds the graph ID of node ``i`` (1-based).
    - ``graph_adjacency``: one ``"a, b"`` edge per line, using global node IDs.
    - ``node_attribute``: comma-separated attributes of node ``i`` (optional).
    - ``graph_labels``: line ``j`` holds the label of graph ``j``.
    - ``graph_attribute``, ``node_labels``, ``edge_labels``, ``edge_attributes``:
      stored but not used yet (see the TODO below); they may be omitted.
    """
    def __init__(self, **kargs) -> None:
        self.__graph_attribute  = kargs.get('graph_attribute')
        self.__graph_labels     = kargs['graph_labels']
        self.__node_attribute   = kargs.get('node_attribute') or []
        self.__graph_indicator  = kargs['graph_indicator']
        self.__graph_adjacency  = kargs['graph_adjacency']
        self.__node_labels      = kargs.get('node_labels')
        self.__edge_labels      = kargs.get('edge_labels')
        self.__edge_attributes  = kargs.get('edge_attributes')

    def _collect_nodes(self) -> Tuple[Dict[str, List[int]], Dict[int, list]]:
        """Group the node IDs listed in graph_indicator.txt by graph.

        Returns:
            nodes: {graph_id: [node_id, ...]} in file order.
            i_nodes: {node_id: [graph_id, node_id]}; later steps append the
                node's attribute dict to each entry.
        """
        print("--- Collecting Nodes ... ")

        nodes: Dict[str, List[int]] = {}
        i_nodes: Dict[int, list] = {}
        for node_id, line in enumerate(self.__graph_indicator, start=1):
            # strip() instead of [:-1]: the last line may have no trailing newline,
            # and files written on Windows end lines with "\r\n".
            graph_id = line.strip()
            if not graph_id:
                continue
            nodes.setdefault(graph_id, []).append(node_id)
            i_nodes[node_id] = [graph_id, node_id]

        return nodes, i_nodes

    def _collect_edges(self, i_nodes: Dict[int, list],
                             direct: bool=False) -> Dict[str, List[Tuple[int, int]]]:
        """Group the edges listed in graph_A.txt by graph.

        Args:
            i_nodes: the mapping returned by :meth:`_collect_nodes`.
            direct: currently unused (TODO: directed-edge handling).

        Returns:
            {graph_id: [(node_a, node_b), ...]} with global node IDs.
        """
        print("--- Collecting Edges ...")

        edges: Dict[str, List[Tuple[int, int]]] = {}
        for line in self.__graph_adjacency:
            line = line.strip()
            if not line:
                continue

            # int() tolerates the surrounding spaces of the "a, b" format.
            a, b = (int(token) for token in line.split(","))

            # [:2] keeps this valid even if attribute dicts were already appended.
            graph_a, node_a = i_nodes[a][:2]
            graph_b, node_b = i_nodes[b][:2]

            assert graph_a == graph_b, f"Two graphs are not equal: {graph_a} != {graph_b}"

            edges.setdefault(graph_a, []).append((node_a, node_b))

        return edges

    def _collect_node_attributes(self, i_nodes: Dict[int, list]) -> None:
        """Append a {"attr0": value, "attr1": value, ...} dict (string values) to each i_nodes entry."""
        print("--- Collecting Node Attributes ...")
        for node_id, line in enumerate(self.__node_attribute, start=1):
            line = line.strip()
            if not line:
                continue
            values = [value.strip() for value in line.split(",")]
            i_nodes[node_id].append({f"attr{j}": value for j, value in enumerate(values)})

    def _collect_graph_labels(self, graphs: Dict[str, nx.Graph]) -> None:
        """Replace every graphs[graph_id] with the pair (graph, label); the label is a string."""
        print("--- Collecting Graph Labels ...")
        for graph_num, line in enumerate(self.__graph_labels, start=1):
            label = line.strip()
            if not label:
                continue
            key = str(graph_num)
            graphs[key] = (graphs[key], label)

    # TODO: _collect_node_labels, _collect_edge_labels, _collect_edge_attributes

    def generate(self) -> Dict[str, Tuple[nx.Graph, str]]:
        """Return a dictionary {graph_id: (graph, label)}.

        Every graph listed in graph_indicator.txt is built, including graphs
        without any edge (those would otherwise be missing when labels are set).
        """
        # Get Nodes and Edges
        nodes, i_nodes = self._collect_nodes()
        edges          = self._collect_edges(i_nodes, False)

        # Set attributes for nodes
        self._collect_node_attributes(i_nodes)

        # Create the graphs
        graphs: Dict[str, nx.Graph] = {}
        for graph_id, node_ids in nodes.items():
            g = nx.Graph()
            # Nodes without an attribute line get an empty attribute dict.
            g.add_nodes_from(
                (n, i_nodes[n][2] if len(i_nodes[n]) > 2 else {}) for n in node_ids
            )
            g.add_edges_from(edges.get(graph_id, []))
            graphs[graph_id] = g

        # Set labels for graph
        self._collect_graph_labels(graphs)

        return graphs

In [6]:
# TRIANGLES ships no node-label, edge-label or edge-attribute files,
# so those inputs are explicitly left empty.
graphs_gen = GeneratorTxt2Graph(
    graph_indicator=graph_indicator,
    graph_adjacency=graph_a,
    graph_labels=graph_labels,
    graph_attribute=graph_attribute,
    node_attribute=graph_node_attribute,
    node_labels=None,
    edge_labels=None,
    edge_attributes=None,
)

In [7]:
%%time 
# Parse the raw text files into {graph_id: (networkx graph, label)}.
graphs = graphs_gen.generate()

--- Collecting Nodes ... 
--- Collecting Edges ...
--- Collecting Node Attributes ...
--- Collecting Graph Labels ...
CPU times: user 14.4 s, sys: 1.16 s, total: 15.5 s
Wall time: 15.9 s


In [8]:
def plot_graph(G : Union[nx.Graph, nx.DiGraph], name: str, seed: int = 18) -> None:
    """
    Plot a graph in 3D with plotly, placing the nodes with a spring layout.

    Parameters
    ----------
    G    : Union[nx.Graph, nx.DiGraph]
        The graph to draw
    name : str
        The name of the graph, shown in the figure title
    seed : int, default 18
        Seed of the spring layout, so the same graph is always drawn the same way

    Returns
    -------
    None
    """
    # Getting the 3D Spring layout: {node: array([x, y, z])}
    layout = nx.spring_layout(G, dim=3, seed=seed)

    # Node coordinates, in the same order as layout.keys() (used for the hover text)
    x_nodes = [pos[0] for pos in layout.values()]
    y_nodes = [pos[1] for pos in layout.values()]
    z_nodes = [pos[2] for pos in layout.values()]

    # Edge segments; the trailing None breaks the line between consecutive edges
    x_edges, y_edges, z_edges = [], [], []
    for u, v in G.edges():
        x_edges += [layout[u][0], layout[v][0], None]
        y_edges += [layout[u][1], layout[v][1], None]
        z_edges += [layout[u][2], layout[v][2], None]

    # One colour value per node. np.linspace(0, n) would always return 50
    # values (its default `num`), regardless of the number of nodes.
    colors = np.arange(len(x_nodes))

    # Create a trace for the edges
    etrace = go.Scatter3d(x=x_edges,
                          y=y_edges,
                          z=z_edges,
                          mode='lines',
                          line=dict(color='rgb(125,125,125)', width=1),
                          hoverinfo='none'
                         )

    # Create a trace for the nodes
    ntrace = go.Scatter3d(x=x_nodes,
                          y=y_nodes,
                          z=z_nodes,
                          mode='markers',
                          marker=dict(
                              symbol='circle',
                              size=6,
                              color=colors,
                              colorscale='Viridis',
                              line=dict(color='rgb(50,50,50)', width=.5)),
                          text=list(layout.keys()),
                          hoverinfo='text'
                         )

    # Hide every axis decoration
    axis = dict(showbackground=False,
                showline=False,
                zeroline=False,
                showgrid=False,
                showticklabels=False,
                title='')

    # Create a layout for the plot
    go_layout = go.Layout(title=f"{name} Network Graph",
                          width=600,
                          height=600,
                          showlegend=False,
                          scene=dict(xaxis=dict(axis),
                                     yaxis=dict(axis),
                                     zaxis=dict(axis)),
                          margin=dict(t=100),
                          hovermode='closest'
                         )

    # Plot
    fig = go.Figure(data=[etrace, ntrace], layout=go_layout)
    fig.show()

In [None]:
plot_graph(G=graphs["1"][0], name="1")  # element 0 of the (graph, label) pair

## Dataset, Few-Shot Sampler and DataLoader

#### Some useful definitions

- **Episodic Training**: at each training step the algorithm samples a *Task*
- **Task**: a pair (support set, query set)
- **Support Set**: $D_{sup}^{train} = \{(G_i^{train}, \mathbf{y}_{i}^{train})\}_{i=1}^s$, where $s = N \times K$
- **Query Set**: $D_{que}^{train} = \{(G_i^{train}, \mathbf{y}_{i}^{train})\}_{i=1}^q$, where $q$ is the number of query data

*Problem Definition*

Given graph data $\mathcal{G} = \{(G_1, \mathbf{y}_1), ..., (G_n, \mathbf{y}_n)\}$, we split it into a train set, $\{(G^{train}, \mathbf{y}^{train})\}$, and a test set, $\{(G^{test}, \mathbf{y}^{test})\}$. Notice that $\mathbf{y}^{train}$ and $\mathbf{y}^{test}$ must have no classes in common. For training, we use the episodic training method: given labeled support data, the goal is to predict the labels of the query data. Note that within a single task, support data and query data share the same class space. At test stage, when performing classification tasks on unseen classes, we first fine-tune the meta-learner on the support data of the test classes, then report classification performance on the test query set.

In [9]:
class GraphDataset(gdata.Dataset):
    """In-memory PyG dataset over a {graph_id: (networkx graph, label)} dict.

    :meth:`indices` and :meth:`targets` both follow the insertion order of
    ``graphs_ds``. This relies on torch_geometric's ``Dataset.__getitem__``
    translating a positional index through ``self.indices()`` into a graph ID,
    which is then passed to :meth:`get`. Confirm this against the installed PyG
    version.
    """
    def __init__(self, graphs_ds: Dict[str, Tuple[nx.Graph, str]]) -> None:
        super(GraphDataset, self).__init__()
        self.graphs_ds = graphs_ds

    def __repr__(self) -> str:
        return f"GraphDataset(classes={set(self.targets().tolist())},n_graphs={self.len()})"

    def indices(self) -> List[str]:
        """Return all the graph IDs, in the same order as :meth:`targets`."""
        return list(self.graphs_ds.keys())

    def len(self) -> int:
        return len(self.graphs_ds)

    def targets(self) -> torch.Tensor:
        """Return the integer label of every graph, in :meth:`indices` order."""
        return torch.tensor([int(label) for _, label in self.graphs_ds.values()])

    def get(self, idx: Union[int, str]) -> gdata.Data:
        """Convert graph `idx` into a ``Data(x, edge_index, y)`` object.

        - ``x``: float tensor (num_nodes, num_node_attributes), one row per node.
        - ``edge_index``: long tensor (2, num_directed_edges). It indexes the
          rows of ``x`` (0..num_nodes-1), not the global node IDs of the
          source files.
        - ``y``: int tensor holding the graph label.
        """
        graph_key = str(int(idx))

        nx_graph, label = self.graphs_ds[graph_key]
        directed = nx_graph.to_directed()

        # Node attributes, one row per node in the graph's node order.
        # float() also accepts non-integer attribute strings.
        node_list = list(directed.nodes(data=True))
        x = torch.tensor([[float(value) for value in attrs.values()] for _, attrs in node_list],
                         dtype=torch.float)

        # Nodes carry global IDs (1..total_nodes); PyG needs positions into x.
        position = {node: pos for pos, (node, _) in enumerate(node_list)}
        edge_index = torch.tensor([[position[u], position[v]] for u, v in directed.edges],
                                  dtype=torch.long) \
                          .view(-1, 2)   \
                          .t()           \
                          .contiguous()  # view(-1, 2) keeps shape (2, 0) for edgeless graphs

        # Retrieve ground truth labels
        y = torch.tensor([int(label)], dtype=torch.int)

        return gdata.Data(x=x, edge_index=edge_index, y=y)

    @classmethod
    def dataset_from_labels(cls, mask   : torch.Tensor,
                                 classes: torch.Tensor,
                                 graphs : Dict[str, Tuple[nx.Graph, str]]
    ) -> 'GraphDataset':
        """Return a new dataset holding only the graphs whose label is in both `mask` and `classes`."""
        selected = classes[(mask[:, None] == classes[None, :]).any(dim=0)]
        # Graph labels are stored as strings, so compare against str(label).
        wanted = {str(label) for label in selected.tolist()}

        filtered_graphs = {k: v for k, v in graphs.items() if v[1] in wanted}
        return cls(filtered_graphs)



def get_all_labels(graphs: Dict[str, Tuple[nx.Graph, str]]) -> torch.Tensor:
    """Return the distinct integer labels of `graphs`, sorted ascending.

    Sorting makes the order deterministic; iterating a plain ``set`` would
    follow its hash-table layout. The explicit long dtype keeps an empty input
    from producing a float tensor.
    """
    return torch.tensor(sorted({int(label) for _, label in graphs.values()}), dtype=torch.long)


def generate_train_val_test(graphs    : Dict[str, Tuple[nx.Graph, str]],
                            perc_test : float,
                            perc_train: float,
                            perc_val  : float,
                            generator : Optional[torch.Generator] = None
) -> Tuple[GraphDataset, GraphDataset, GraphDataset]:
    """
    Split the *classes* of `graphs` into disjoint sets and return the datasets
    (train, test, val), so that no class is shared between splits.

    Args:
        graphs: {graph_id: (graph, label)}.
        perc_test / perc_train / perc_val: percentages of the classes for each
            split. The train and test sizes are rounded down. The validation
            split receives all remaining classes, so perc_val is only used to
            check that the three percentages do not exceed 100.
        generator: optional torch.Generator for a reproducible split.

    Raises:
        ValueError: if the percentages sum to more than 100.
    """
    if perc_train + perc_test + perc_val > 100:
        raise ValueError(
            f"perc_train + perc_test + perc_val must be <= 100, "
            f"got {perc_train} + {perc_test} + {perc_val}"
        )

    classes = get_all_labels(graphs)
    n_class = len(classes)
    # Shuffle the actual label values: labels need not be exactly 1..n_class.
    perm    = classes[torch.randperm(n_class, generator=generator)]

    # int(): float percentages would otherwise give float slice bounds.
    q_train = int(n_class * perc_train // 100)
    q_test  = int(n_class * perc_test  // 100)

    train_perm = perm[:q_train]
    test_perm  = perm[q_train: q_train + q_test]
    val_perm   = perm[q_train + q_test:]

    train_ds = GraphDataset.dataset_from_labels(train_perm, classes, graphs)
    test_ds  = GraphDataset.dataset_from_labels(test_perm,  classes, graphs)
    val_ds   = GraphDataset.dataset_from_labels(val_perm,   classes, graphs)

    return train_ds, test_ds, val_ds

In [10]:
# Class-disjoint split: 50% of the classes for training, 30% for testing, the rest for validation.
train_ds, test_ds, val_ds = generate_train_val_test(
    graphs, perc_train=50, perc_test=30, perc_val=20
)
(train_ds, test_ds, val_ds)

(GraphDataset(classes={1, 4, 7, 8, 10},n_graphs=22500),
 GraphDataset(classes={2, 5, 6},n_graphs=13500),
 GraphDataset(classes={9, 3},n_graphs=9000))

**Sampler Pseudo-Code**

```
function iter_sample_NKshot_with_Query(
	Data:
		- G(train)  --> train set
		- d --> dimension of the train set
		- c --> number of classes of the train set
		- N --> Number of classes to select
		- K --> Number of support sample per class
		- Q --> Number of query sample per class
		- epoch_size --> number of batches per epoch
){
	target_classes = random.sample(from=unique(y(train, i), i=1...d), size=N)

	for (i=1...epoch_size) do
	{
		foreach (cl <- target_classes) do 
		{
			filtered_data = filter(data=G(train),by=Lambda(x, x.y == cl))
			f = |filtered_data|

			IMPORTANT: assert(f >= K + Q)

			selected_data = random.sample(from=filtered_data, size=(K + Q))	
			support_data = selected_data.slice(start=0, end=K)
			query_data = selected_data.slice(start=K, end=(K + Q))

			generate(support_data, query_data)
		}
	}
}
```

In [229]:
class NWayKShotSampler(torch.utils.data.Sampler):
    """
    N-way-K-shot task sampler for few-shot (meta-learning) classification.

    Every yielded item is one *task*: a tensor of dataset positions with shape
    ``(n_way, k_shot + n_query)``. Row ``i`` holds ``k_shot + n_query``
    distinct positions whose label is the i-th sampled class. The first
    ``k_shot`` columns are intended as the `support` set and the remaining
    ``n_query`` columns as the `query` set.

    The ``n_way`` classes are drawn once per pass over the sampler (each
    ``__iter__`` call) and reused by all ``epoch_size`` tasks of that pass.

    Attributes:
        labels: PyTorch tensor with the label of every dataset element
        n_way: Number of classes sampled per task
        k_shot: Number of support examples sampled per class
        n_query: Number of query examples sampled per class
        shuffle: If True, the per-class position lists are reshuffled at the
            start of each pass (examples are drawn at random either way)
        indices_per_class: {class: tensor of positions having that label}
        classes: list of all distinct classes
        epoch_size: number of tasks yielded per pass

    Raises:
        ValueError: in ``__init__`` if ``n_way`` exceeds the number of classes;
            during iteration if a sampled class has fewer than
            ``k_shot + n_query`` examples.
    """
    def __init__(self, labels       : torch.Tensor,
                       n_way        : int,
                       k_shot       : int,
                       n_query      : int,
                       epoch_size   : int,
                       shuffle      : bool=True) -> None:
        super().__init__(None)
        self.labels = labels
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query
        self.shuffle = shuffle
        self.epoch_size = epoch_size

        self.classes = torch.unique(self.labels).tolist()
        if n_way > len(self.classes):
            raise ValueError(
                f"n_way={n_way} is larger than the number of classes ({len(self.classes)})"
            )
        self.indices_per_class = {cl: torch.where(self.labels == cl)[0] for cl in self.classes}

    def shuffle_data(self) -> None:
        """Randomly permute, in place, the list of positions stored for each class."""
        for cl in self.classes:
            perm = torch.randperm(self.indices_per_class[cl].shape[0])
            self.indices_per_class[cl] = self.indices_per_class[cl][perm]

    def __iter__(self) -> Iterator[torch.Tensor]:
        # Shuffle the data
        if self.shuffle:
            self.shuffle_data()

        target_classes = random.sample(self.classes, self.n_way)
        per_class = self.k_shot + self.n_query
        # Convert each pool to a Python list once per pass, not once per task.
        pools = {cl: self.indices_per_class[cl].tolist() for cl in target_classes}

        for _ in range(self.epoch_size):
            task = []
            for cl in target_classes:
                pool = pools[cl]
                # Explicit check instead of `assert`, which is stripped under `python -O`.
                if len(pool) < per_class:
                    raise ValueError(
                        f"class {cl} has {len(pool)} examples, "
                        f"but k_shot + n_query = {per_class} are required"
                    )
                task.append(random.sample(pool, per_class))

            yield torch.tensor(task)

    def __len__(self) -> int:
        return self.epoch_size

In [270]:
class TaskBatchSampler(torch.utils.data.Sampler):
    """Group the tasks of an internal NWayKShotSampler into mini-batches of `batch_size` tasks.

    Each yielded item is a flat list of dataset positions: `batch_size` tasks,
    each laid out class by class as `k_shot` support positions followed by
    `n_query` query positions. Trailing tasks that do not fill a whole
    mini-batch are dropped, consistent with ``__len__``.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """

    def __init__(self, dataset_targets: torch.Tensor,
                       batch_size     : int,
                       n_way          : int,
                       k_shot         : int,
                       n_query        : int,
                       epoch_size     : int,
                       shuffle        : bool = True) -> None:
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        super().__init__(None)
        self.task_sampler = NWayKShotSampler(
            dataset_targets,
            n_way=n_way,
            k_shot=k_shot,
            n_query=n_query,
            epoch_size=epoch_size,
            shuffle=shuffle
        )

        self.task_batch_size = batch_size

    def __iter__(self) -> Iterator[List[int]]:
        mini_batch: List[int] = []
        for task_number, task in enumerate(self.task_sampler, start=1):
            # A task is (n_way, k_shot + n_query); flatten it row by row (class-major).
            mini_batch.extend(task.flatten().tolist())
            if task_number % self.task_batch_size == 0:
                yield mini_batch
                mini_batch = []

    def __len__(self) -> int:
        return len(self.task_sampler) // self.task_batch_size

    def create_batches_from_data_batch(self, data_batch: gdata.Batch) -> Tuple[gdata.Batch, gdata.Batch]:
        """
        Split a collated mini-batch back into a support Batch and a query Batch.

        Let L = [x1, x2, ..., xM] be the graphs of `data_batch`, with
        N = n_way, K = k_shot and Q = n_query. Then

              L[t * N * (K + Q) : (t + 1) * N * (K + Q)]

        is the t-th task and, inside a task T,

              T[i * (K + Q) : (i + 1) * (K + Q)]

        is the (support, query) group of the i-th class: T[i*(K+Q) : i*(K+Q) + K]
        are support graphs and the next Q graphs are query graphs.

        Returns:
            (support Batch, query Batch), each holding the graphs of every task
            of the mini-batch, task by task and class by class.
        """
        n_way = self.task_sampler.n_way
        k_shot = self.task_sampler.k_shot
        n_query = self.task_sampler.n_query

        per_class = k_shot + n_query
        per_task = n_way * per_class

        support_graphs = []
        query_graphs = []

        for task_number in range(self.task_batch_size):
            # Slicing a Batch returns a plain list of Data objects; taken once per task.
            task_graphs = data_batch[task_number * per_task:(task_number + 1) * per_task]

            for class_number in range(n_way):
                start = class_number * per_class
                support_graphs += task_graphs[start:start + k_shot]
                query_graphs += task_graphs[start + k_shot:start + per_class]

        return gdata.Batch.from_data_list(support_graphs), gdata.Batch.from_data_list(query_graphs)

In [264]:
class GraphCollater(gloader.dataloader.Collater):
    """PyG Collater that also accepts a batch whose first element is a GraphDataset.

    In that case the first dataset itself is collated; every other batch is
    handed to the parent Collater unchanged.
    """
    def __init__(self, *args) -> None:
        super().__init__(*args)

    def __call__(self, batch: Generic[T]) -> Generic[T]:
        first = batch[0]
        if not isinstance(first, GraphDataset):
            return super().__call__(batch)
        return self(first)

class FewShotDataLoader(torch.utils.data.DataLoader):
    """DataLoader for GraphDataset that yields (support Batch, query Batch) pairs.

    It must be given a ``batch_sampler`` that provides
    ``create_batches_from_data_batch`` (e.g. TaskBatchSampler). That method
    splits every collated mini-batch into its support and query parts. Any
    ``collate_fn`` passed in is ignored: graphs are always collated by
    GraphCollater.

    Raises:
        ValueError: if no suitable ``batch_sampler`` is given.
    """
    def __init__(self, dataset   : GraphDataset,
                       batch_size: int = 1,
                       shuffle   : bool=False,
                       follow_batch: Optional[List[str]]=None,
                       exclude_keys: Optional[List[str]]=None,
                       **kwargs) -> None:

        # The support/query split below depends on GraphCollater's output.
        kwargs.pop("collate_fn", None)

        batch_sampler = kwargs.get("batch_sampler")
        if batch_sampler is None or not hasattr(batch_sampler, "create_batches_from_data_batch"):
            raise ValueError(
                "FewShotDataLoader requires a `batch_sampler` providing "
                "`create_batches_from_data_batch` (e.g. TaskBatchSampler)"
            )

        self.follow_batch = follow_batch
        self.exclude_keys = exclude_keys

        # DataLoader.__init__ stores kwargs["batch_sampler"] as self.batch_sampler.
        super().__init__(
            dataset,
            batch_size,
            shuffle,
            collate_fn=GraphCollater(follow_batch, exclude_keys),
            **kwargs,
        )

    def __iter__(self):
        for x in super().__iter__():
            support_batch, query_batch = self.batch_sampler.create_batches_from_data_batch(x)
            yield support_batch, query_batch

In [147]:
# Task shape for the demo loader: N classes, K support and Q query graphs per class.
N_WAY, K_SHOT, N_QUERY = 5, 5, 5

In [271]:
graph_train_loader = FewShotDataLoader(
    dataset=train_ds,
    batch_sampler=TaskBatchSampler(
        dataset_targets=train_ds.targets(),
        batch_size=2,     # tasks per mini-batch
        epoch_size=10,    # tasks drawn per pass over the loader
        n_way=N_WAY,
        k_shot=K_SHOT,
        n_query=N_QUERY,
        shuffle=True,
    ),
)

In [272]:
# Draw one mini-batch: the support graphs and the query graphs of all of its tasks.
support, query = next(iter(graph_train_loader))

In [273]:
# Support graphs of every task in the mini-batch, merged into a single Batch.
support

DataBatch(x=[1124, 1], edge_index=[2, 3556], y=[50], batch=[1124], ptr=[51])

In [274]:
# Query graphs of every task in the mini-batch, merged into a single Batch.
query

DataBatch(x=[1256, 1], edge_index=[2, 4146], y=[50], batch=[1256], ptr=[51])

## Adaptive-Step MAML

Some important configurations

In [None]:
# --- Model regularisation ---------------------------------------------------
POOLING_RATIO = 0.5
DROPOUT_RATIO = 0.3

# --- Learning rates ---------------------------------------------------------
OUTER_LR     = 1e-3   # lr of the base network's params in the meta optimizer
INNER_LR     = 1e-2   # stored as AdaptiveStepMAML.inner_lr
STOP_LR      = 1e-4   # lr of the StopControl gate's params in the meta optimizer
WEIGHT_DECAY = 1e-5

# --- Adaptive number of inner steps -----------------------------------------
MAX_STEP, MIN_STEP, STEP_TEST = 15, 5, 15
FLEXIBLE_STEP = True
STEP_PENALITY = 1e-3
USE_SCORE, USE_GRAD, USE_LOSS = True, False, True

# --- Episodes: how many tasks to run ----------------------------------------
TRAIN_SHOT, VAL_SHOT   = 10, 10   # K-shot for the training / validation (or test) set
TRAIN_QUERY, VAL_QUERY = 15, 15   # Number of queries for the training / validation (or test) set
TRAIN_WAY, TEST_WAY    = 3, 3     # N-way for the training / test set
VAL_EPISODE        = 200  # Number of episodes for validation
TRAIN_EPISODE      = 200  # Number of episodes for training
BATCH_PER_EPISODES = 5    # How many batch per episode
EPOCHS             = 500  # How many epochs
PATIENCE           = 35
GRAD_CLIP          = 5

In [None]:
class GCN4MAML(nn.Module):
    """Graph Convolutional Network intended as the base learner of AS-MAML.

    TODO: layers and forward() are not implemented yet; for now this is an
    empty module without parameters.
    """
    def __init__(self) -> None:
        super().__init__()

class StopControl(nn.Module):
    """LSTM-based stop gate: maps `inputs` to a stop probability in (0, 1).

    Args:
        input_size: feature size of each input row.
        hidden_size: size of the LSTM cell state. A learned initial state
            (h_0, c_0) is used when no state is passed in.
    """
    def __init__(self, input_size: int, hidden_size: int) -> None:
        super(StopControl, self).__init__()
        self.lstm = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
        self.output_layer = nn.Linear(hidden_size, 1)
        nn.init.zeros_(self.output_layer.bias)
        # nn.Parameter already sets requires_grad=True.
        self.h_0 = nn.Parameter(torch.randn((hidden_size, )))
        self.c_0 = nn.Parameter(torch.randn((hidden_size, )))

    def forward(self, inputs: torch.Tensor, hx=None):
        """Run one LSTM step.

        Args:
            inputs: (batch, input_size) tensor, or (input_size,) for one unbatched row.
            hx: previous (h, c) state, or None to start from the learned (h_0, c_0).

        Returns:
            (probability, (h, c)). For batched input the probability has shape
            (1, batch, 1), because of the leading unsqueeze(0).
        """
        if hx is None:
            if inputs.dim() == 1:
                hx = (self.h_0, self.c_0)
            else:
                # Repeat the learned initial state for every row of the batch
                # (the previous (1, hidden) state only worked for batch size 1).
                batch = inputs.size(0)
                hx = (self.h_0.expand(batch, -1).contiguous(),
                      self.c_0.expand(batch, -1).contiguous())

        h, c = self.lstm(inputs, hx)
        return torch.sigmoid(self.output_layer(h).unsqueeze(0)), (h, c)

class AdaptiveStepMAML(nn.Module):
    """Meta-learner of AS-MAML.

    Holds the base network `model`, a StopControl gate, and the outer-loop
    Adam optimizer over both, with an exponential learning-rate scheduler.
    """
    def __init__(self, model,
                       inner_lr    : float=INNER_LR,
                       outer_lr    : float=OUTER_LR,
                       stop_lr     : float=STOP_LR,
                       weight_decay: float=WEIGHT_DECAY) -> None:
        # nn.Module.__init__ must run before any submodule is assigned, otherwise
        # nn.Module.__setattr__ raises when `self.net = model` is executed.
        super(AdaptiveStepMAML, self).__init__()

        self.net          = model
        self.inner_lr     = inner_lr
        self.outer_lr     = outer_lr
        self.stop_lr      = stop_lr
        self.weight_decay = weight_decay

        self.stop_prob = 0.5
        # StopControl is a module-level class, not an attribute of this class.
        self.stop_gate = StopControl(2, 20)

        self.meta_optim = self.configure_optimizers()

        # NOTE(review): BCEWithLogitsLoss expects targets in [0, 1], while the
        # graph labels in this notebook are class ids. N-way classification
        # usually uses a cross-entropy loss; confirm before training.
        self.loss      = nn.BCEWithLogitsLoss()
        self.scheduler = optim.lr_scheduler.ExponentialLR(self.meta_optim,
                                                          gamma=.5,
                                                          last_epoch=-1,
                                                          verbose=True)

    def configure_optimizers(self):
        """Adam over the base network (lr=outer_lr) and the stop gate (lr=stop_lr)."""
        return optim.Adam([
                           {'params': self.net.parameters(),       'lr': self.outer_lr},
                           {'params': self.stop_gate.parameters(), 'lr': self.stop_lr}],
                          lr=1e-04, weight_decay=self.weight_decay
               )

    def compute_loss(self, logits, label) -> torch.Tensor:
        """Loss between `logits` and `label`, with the target cast to the dtype of the logits."""
        # Casting to logits.dtype (instead of always to double) keeps input and
        # target dtypes consistent.
        return self.loss(logits.squeeze(), label.to(dtype=logits.dtype).squeeze())

    