# Specify, train, and evaluate a PyTorch-based model using AbstractTorchFSMolModel

Flexible implementations of few-shot models can be evaluated with our benchmarking procedure by building them on the `AbstractTorchFSMolModel` base class.

In [None]:
import os
import sys
import torch

sys.path.insert(0, os.path.join(os.getcwd(), "../"))
sys.path.insert(0, os.path.join(os.getcwd(), "../fs_mol"))

from models import AbstractTorchFSMolModel

# 1a. Specifying a new model

As an example of a model implementation that uses this tooling and the new dataset, we turn to `gnn_multitask.py`. It contains all of the required methods: `forward`, `get_model_state`, `is_param_task_specific`, `load_model_weights`, and `build_from_model_file`.

**forward**

The `GNNMultitaskModel` consists of an initial per-node linear projection layer, a shared `GNN` whose configuration is specified by a `GNNConfig`, a readout layer that converts node representations into whole-graph embeddings, and a tail MLP that acts as a 'head'. The number of head outputs is determined by the number of tasks the model is trained or evaluated on simultaneously (one output unit per task in this binary classification problem).

The `GNNMultitaskModel` uses the `FSMolMultitaskBatch`, an implementation of `FSMolBatch` with an additional `sample_to_task_id` attribute, so that for each graph only the prediction from the output unit of its own task is used.

The output of this method should be predictions that can be used to calculate a differentiable loss when combined with graph labels from the batch.
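A minimal sketch of what happens to the model output, assuming one logit per task for every graph and that `sample_to_task_id[i]` is the integer task ID of graph `i`: each graph keeps only its own task's logit, and a binary cross-entropy loss is computed from those logits and the labels. Plain Python lists stand in for tensors and the helper names are hypothetical; in PyTorch the selection would typically be done with tensor indexing and the loss with `torch.nn.functional.binary_cross_entropy_with_logits`.

```python
import math
from typing import List, Sequence


def select_task_logits(
    logits: Sequence[Sequence[float]], sample_to_task_id: Sequence[int]
) -> List[float]:
    """For each graph, keep only the output unit that belongs to its own task.

    logits[i][t] is the logit for graph i from the output unit of task t.
    """
    return [row[task_id] for row, task_id in zip(logits, sample_to_task_id)]


def bce_with_logits(logits: Sequence[float], labels: Sequence[float]) -> float:
    """Mean binary cross-entropy, computed stably from logits:
    max(x, 0) - x * y + log(1 + exp(-|x|)) for each logit x and label y.
    """
    losses = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, labels)
    ]
    return sum(losses) / len(losses)


# Three graphs, two tasks: graphs 0 and 2 belong to task 1, graph 1 to task 0.
batch_logits = [[0.3, 2.0], [-1.5, 0.7], [0.1, -0.4]]
sample_to_task_id = [1, 0, 1]
labels = [1.0, 0.0, 0.0]

per_graph_logits = select_task_logits(batch_logits, sample_to_task_id)
print(per_graph_logits)  # [2.0, -1.5, -0.4]
print(bce_with_logits(per_graph_logits, labels))
```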

**model state**

`get_model_state` returns a dictionary that is used to save the model. For `GNNMultitaskModel`, this dictionary additionally stores the model's configuration.

**task specific parameters**

In multitask models, for example, some parameters may be task specific and therefore not used in later evaluation steps. `is_param_task_specific` identifies these; for `GNNMultitaskModel` it returns `True` for all parameters of the tail MLP layers.
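One simple way to implement such a test is by parameter name, which also lets the optimizer put task-specific parameters in their own group with a separate learning rate (compare `task_specific_lr` in `create_optimizer` later in this notebook). The sketch assumes, hypothetically, that all tail-MLP parameters are named with a `tail_mlp.` prefix, and it works on a list of names instead of a real `torch.nn.Module`; the other parameter names are made up for illustration.

```python
from typing import Dict, Iterable, List


def is_param_task_specific(param_name: str) -> bool:
    # Hypothetical naming convention: every tail-MLP parameter starts with "tail_mlp.".
    return param_name.startswith("tail_mlp.")


def split_param_groups(
    param_names: Iterable[str], lr: float, task_specific_lr: float
) -> List[Dict[str, object]]:
    """Split parameters into a shared group and a task-specific group."""
    shared: List[str] = []
    task_specific: List[str] = []
    for name in param_names:
        (task_specific if is_param_task_specific(name) else shared).append(name)
    return [
        {"params": shared, "lr": lr},
        {"params": task_specific, "lr": task_specific_lr},
    ]


names = [
    "init_node_proj.weight",
    "gnn.gnn_blocks.0.weight",
    "readout.weight",
    "tail_mlp.0.weight",
    "tail_mlp.0.bias",
]
for group in split_param_groups(names, lr=5e-5, task_specific_lr=1e-4):
    print(group["lr"], group["params"])
```

With a real module, the same loop over `model.named_parameters()` would collect tensors instead of names, and a list of dicts with `params` and `lr` keys is the form `torch.optim` optimizers accept as parameter groups.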

**load model weights**

`load_model_weights` implements loading saved model weights; in the multitask model example, it also handles the state of the optimizers at the time of saving.
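Taken together, `get_model_state`, `load_model_weights` and `build_from_model_file` form a save/rebuild round trip. The stdlib-only sketch below illustrates that pattern with a hypothetical `TinyModel` whose "weights" are a dict of floats: the saved state stores the configuration next to the weights (and, optionally, an optimizer state), so a fresh model can be constructed from the file alone. A real PyTorch model would store `state_dict()` objects with `torch.save`, and the dictionary keys used here are assumptions.

```python
import os
import pickle
import tempfile
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Optional


@dataclass
class TinyConfig:
    num_tasks: int
    hidden_dim: int


@dataclass
class TinyModel:
    config: TinyConfig
    weights: Dict[str, float] = field(default_factory=dict)

    def get_model_state(self) -> Dict[str, Any]:
        # Store the config next to the weights so the model can be rebuilt later.
        return {"model_config": asdict(self.config), "model_state_dict": dict(self.weights)}

    def save(self, path: str, optimizer_state: Optional[Dict[str, Any]] = None) -> None:
        state = self.get_model_state()
        if optimizer_state is not None:
            state["optimizer_state_dict"] = optimizer_state
        with open(path, "wb") as f:
            pickle.dump(state, f)

    def load_model_weights(self, path: str) -> Optional[Dict[str, Any]]:
        # Load weights in place and return the saved optimizer state, if any.
        with open(path, "rb") as f:
            state = pickle.load(f)
        self.weights = dict(state["model_state_dict"])
        return state.get("optimizer_state_dict")

    @classmethod
    def build_from_model_file(cls, path: str) -> "TinyModel":
        # Read the config first, construct a fresh model, then load its weights.
        with open(path, "rb") as f:
            state = pickle.load(f)
        model = cls(TinyConfig(**state["model_config"]))
        model.load_model_weights(path)
        return model


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "best_model.pkl")
    original = TinyModel(TinyConfig(num_tasks=3, hidden_dim=8), {"w": 0.5, "b": -1.0})
    original.save(path, optimizer_state={"step": 10})
    rebuilt = TinyModel.build_from_model_file(path)
    optimizer_state = rebuilt.load_model_weights(path)

print(rebuilt == original, optimizer_state)  # True {'step': 10}
```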



# 1b. Loading a new model and dataset

As an example, we load the dataset and create a sample model.

In [None]:
from models.gnn_multitask import (
    GNNMultitaskConfig,
    GNNMultitaskModel,
    GNNConfig,
    create_model,
)

from data.fsmol_dataset import FSMolDataset, DataFold

In [None]:
# grab a full FSMolDataset (see notebooks/dataset.ipynb)
fsmol_dataset = FSMolDataset.from_directory(
    directory=os.path.join(os.getcwd(), "../dataset/"),
)

In [None]:
# set up an output directory in which to save a model
out_dir = os.path.join(os.getcwd(), "test")
os.makedirs(out_dir, exist_ok=True)

# set up the model configuration that completely specifies a GNNMultitaskModel
model_config = GNNMultitaskConfig(
    num_tasks=fsmol_dataset.get_num_fold_tasks(DataFold.TRAIN),  # one output per training task
    node_feature_dim=32,  # fixed in our data preprocessing
    gnn_config=GNNConfig(
        type="PNA",
        hidden_dim=128,
        num_edge_types=3,
        num_heads=4,
        per_head_dim=64,
        intermediate_dim=1024,  # intermediate representation used in the BOOM layer
        message_function_depth=1,
        num_layers=10,  # number of GNN layers
    ),
    num_outputs=1,
    readout_type="combined",
    readout_use_only_last_timestep=False,  # use all intermediate GNN activations in the final readout
    num_tail_layers=2,
)

In [None]:
# create an instance of a model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = create_model(model_config, device=device)

In [None]:
print(f"\tNum parameters {sum(p.numel() for p in model.parameters())}")
print(f"\tDevice: {device}")
# print(f"\tModel:\n{model}") # prints out full description of model 

# 2. Training an AbstractTorchFSMolModel

The training of the model is demonstrated in `train_loop` of `abstract_torch_fsmol_model.py`. The training loop accepts:

- the model
- the optimization components: the optimizer and, if one is used, the learning rate scheduler
- the training data as an iterable over minibatches
- a validation callable that allows validation to be performed throughout meta-training, for example by adapting the model to meta-validation tasks
- a selection of other parameters, including the model saving directory, training patience, and logging of training values.
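The control flow of such a loop can be sketched as follows: train for an epoch, call the validation callable, keep a copy of the weights whenever the metric improves, and stop early once `patience` consecutive epochs pass without improvement. This is a hypothetical, framework-free outline, not the actual `train_loop`; `train_step` and the in-memory copy are placeholders for the real optimizer step and for saving `best_model.pt`.

```python
import copy
from typing import Any, Callable, Iterable, Optional, Tuple


def simple_train_loop(
    model: Any,
    get_train_batches: Callable[[], Iterable[Any]],
    train_step: Callable[[Any, Any], None],
    valid_fn: Callable[[Any], float],
    max_num_epochs: int,
    patience: int,
) -> Tuple[float, Optional[Any]]:
    """Return (best validation metric, copy of the best model); higher is better."""
    best_metric = float("-inf")
    best_model: Optional[Any] = None
    epochs_without_improvement = 0
    for _ in range(max_num_epochs):
        for batch in get_train_batches():
            train_step(model, batch)  # forward pass, loss, backward pass, optimizer step
        metric = valid_fn(model)
        if metric > best_metric:
            best_metric = metric
            best_model = copy.deepcopy(model)  # stands in for writing best_model.pt
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # early stopping
    return best_metric, best_model


# Toy usage: each "epoch" nudges a single weight; validation peaks when w == 3.
toy_model = {"w": 0.0}


def toy_step(m: dict, _batch: Any) -> None:
    m["w"] += 1.0


toy_best_metric, toy_best = simple_train_loop(
    toy_model,
    get_train_batches=lambda: [None],
    train_step=toy_step,
    valid_fn=lambda m: 3.0 - abs(m["w"] - 3.0),
    max_num_epochs=10,
    patience=2,
)
print(toy_best_metric, toy_best, toy_model)  # 3.0 {'w': 3.0} {'w': 5.0}
```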

To demonstrate how to create a suitable iterable of training batches, we use a `MultitaskTaskSampleBatchIterable`, which simply uses `FSMolDataset` to build an iterable over the `FSMolMultitaskBatch` objects mentioned above.

In [None]:
from data.multitask import MultitaskTaskSampleBatchIterable

# this dictionary maps each training task name to an integer task ID
train_task_name_to_id = {
    name: i for i, name in enumerate(fsmol_dataset.get_task_names(data_fold=DataFold.TRAIN))
}

train_data = MultitaskTaskSampleBatchIterable(
    fsmol_dataset,
    data_fold=DataFold.TRAIN,
    task_name_to_id=train_task_name_to_id,
    max_num_graphs=256,
)

An iterable such as `train_data` is the required input to `run_on_data_iterable`, which simply runs the model over the data loader. We suggest `train_loop` as a good example outline for using the `FS-Mol` data loaders.
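As a simplified, hypothetical picture of what such an iterable yields: datapoints from many training tasks are pooled, shuffled, and cut into batches of at most `max_num_graphs` graphs, each paired with the list of task IDs that later plays the role of `sample_to_task_id`. The real `MultitaskTaskSampleBatchIterable` reads tasks from disk and builds graph tensors; this sketch only tracks molecule identifiers and task IDs, and the task names are made up.

```python
import random
from typing import Dict, Iterator, List, Sequence, Tuple


def multitask_batches(
    tasks: Dict[str, Sequence[str]],
    task_name_to_id: Dict[str, int],
    max_num_graphs: int,
    seed: int = 0,
) -> Iterator[Tuple[List[str], List[int]]]:
    """Yield (datapoints, sample_to_task_id) batches with at most max_num_graphs items."""
    pooled = [
        (datapoint, task_name_to_id[name])
        for name, datapoints in tasks.items()
        for datapoint in datapoints
    ]
    random.Random(seed).shuffle(pooled)  # mix tasks within each batch
    for start in range(0, len(pooled), max_num_graphs):
        chunk = pooled[start : start + max_num_graphs]
        yield [datapoint for datapoint, _ in chunk], [task_id for _, task_id in chunk]


toy_tasks = {"task_a": ["m1", "m2", "m3"], "task_b": ["m4", "m5"]}
toy_task_name_to_id = {name: i for i, name in enumerate(toy_tasks)}
for datapoints, sample_to_task_id in multitask_batches(toy_tasks, toy_task_name_to_id, max_num_graphs=2):
    print(datapoints, sample_to_task_id)
```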

**Validation function**

Many validation modalities are possible; here we give an example of validation by finetuning on a set of validation tasks.

In this instance, the validation method is passed to the training loop as a callable.

We discuss the validation function in more detail below, as it has much in common with our model evaluation.

In [None]:
from fs_mol.multitask_train import validate_by_finetuning_on_tasks

In [None]:
from functools import partial

# create a validation callable: valid_fn(model) returns a performance metric on the
# validation tasks (the FSMolDataset and the finetuning settings are already bound here)
valid_fn = partial(
    validate_by_finetuning_on_tasks,
    dataset=fsmol_dataset,
    learning_rate=0.00005,
    task_specific_learning_rate=0.0001,
    batch_size=256,
    metric_to_use="avg_precision",
)

In [None]:
from models.abstract_torch_fsmol_model import (
    train_loop,
    create_optimizer,
)

# create the optimizer and learning rate scheduler used for training
optimizer, lr_scheduler = create_optimizer(
    model,
    lr=0.00005,
    task_specific_lr=0.0001,  # task-specific layers are allowed to adapt faster than the core GNN
    warmup_steps=100,
    task_specific_warmup_steps=100,
)
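The `warmup_steps` and `task_specific_warmup_steps` arguments point to a learning rate warmup. One common form, which may or may not match what `create_optimizer` does internally, is a linear ramp up to the base learning rate over the first `warmup_steps` optimizer steps; the function below computes that multiplicative factor for one parameter group.

```python
from functools import partial


def linear_warmup_factor(step: int, warmup_steps: int) -> float:
    """Multiplier for the base learning rate at a given optimizer step.

    Grows linearly from 1 / warmup_steps at step 0 to 1.0 at step warmup_steps - 1,
    then stays at 1.0.
    """
    if warmup_steps <= 0:
        return 1.0
    return min(1.0, (step + 1) / warmup_steps)


shared_lr, task_specific_lr = 5e-5, 1e-4  # the values passed to create_optimizer above
shared_factor = partial(linear_warmup_factor, warmup_steps=100)
task_factor = partial(linear_warmup_factor, warmup_steps=100)
for step in (0, 49, 99, 500):
    print(step, shared_lr * shared_factor(step), task_specific_lr * task_factor(step))
```

In PyTorch, factors like these could be passed as `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[shared_factor, task_factor])`, assuming an optimizer with exactly two parameter groups in that order.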

In [None]:
# run a training loop on the model; whenever the validation metric improves,
# the new weights are saved as best_model.pt in the output folder
best_metric = train_loop(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    train_data=train_data,
    valid_fn=valid_fn,
    output_folder=out_dir,
    max_num_epochs=1,
)

The example above uses a validation callable, `validate_by_finetuning_on_tasks`, in which a copy of the model is finetuned on a validation task. This is repeated for every validation task available from `fsmol_dataset`, returning an aggregate metric that represents performance across all validation tasks.
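Stripped of the model details, that aggregation looks roughly like the sketch below. Each validation task gets its own deep copy of the trained model, so finetuning on one task never leaks into the next or into the model being trained, and the per-task scores are averaged. `finetune` and `evaluate` are hypothetical placeholders for the real finetuning and metric code, each task is reduced to support and query label lists, and a plain mean is assumed as the aggregate.

```python
import copy
import statistics
from typing import Any, Callable, Dict, Sequence, Tuple

# Hypothetical task format: (support labels, query labels).
Task = Tuple[Sequence[int], Sequence[int]]


def validate_by_finetuning(
    model: Any,
    validation_tasks: Dict[str, Task],
    finetune: Callable[[Any, Sequence[int]], None],
    evaluate: Callable[[Any, Sequence[int]], float],
) -> float:
    """Finetune a fresh copy of the model on each task; return the mean query score."""
    scores = []
    for support, query in validation_tasks.values():
        task_model = copy.deepcopy(model)  # the model being trained is never modified
        finetune(task_model, support)
        scores.append(evaluate(task_model, query))
    return statistics.mean(scores)


def finetune(m: Dict[str, float], support_labels: Sequence[int]) -> None:
    # Toy "finetuning": set the bias to the support set's positive rate.
    m["bias"] = sum(support_labels) / len(support_labels)


def evaluate(m: Dict[str, float], query_labels: Sequence[int]) -> float:
    # Accuracy of predicting a single label for every query molecule.
    prediction = 1 if m["bias"] >= 0.5 else 0
    return sum(label == prediction for label in query_labels) / len(query_labels)


base_model = {"bias": 0.0}
toy_validation_tasks = {
    "task_a": ([1, 1, 0], [1, 1, 0, 1]),
    "task_b": ([0, 0, 1], [0, 1, 0, 0]),
}
toy_score = validate_by_finetuning(base_model, toy_validation_tasks, finetune, evaluate)
print(toy_score, base_model)  # 0.75 {'bias': 0.0}
```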

# 3. Evaluating a model



## The eval_model() function

`fs_mol.utils.test_utils.eval_model()` is a general-purpose model evaluation function that allows any model to be run against the full set of test tasks. It requires the following inputs:

1. A `test_model_fn` callable: this helper function accepts an `FSMolTaskSample`, that is, a sample of datapoints (`MoleculeDatapoint`s) from a single task, created by a `TaskSampler` (see notebooks/dataset.ipynb).
    - The callable should run the model on the task sample and return a `BinaryEvalMetrics` object containing all metrics calculated from the model output and the labels of the task sample.
    - It may also accept other arguments, which are fixed in advance rather than passed directly to `eval_model`. See `fs_mol/baseline_test.py` for a simple example.

2. An `FSMolDataset` containing information about all tasks. A file-reading iterable over the tasks is built from it, and the `TaskSampler` acts on that iterable to produce the necessary samples. If only a single task is being evaluated, for example one passed on the command line, the `FSMolDataset` contains only that task.

3. A `task_reader_fn` callable: this is simply passed to `dataset.get_task_reading_iterable()` to apply additional transformations to the data as it is read from disk. The default simply loads the data into `FSMolTask` objects with no further changes, but users may define alternative readers.

4. `train_set_sample_sizes`: also known as support set sizes. This is supplied as a list because a model can be evaluated at multiple support set sizes, both in validation and in the final benchmarking evaluation.

5. `out_dir`: if this is passed, a summary of the evaluation is written to a CSV file. Beware: running the evaluation repeatedly during training with an output directory set overwrites this file each time.

6. `num_samples`: the number of random splits to draw for each task under evaluation.

7. `fold`: the data fold is used here to decide whether to shuffle the order of the evaluated tasks; this is only relevant in the training loop.

`eval_model` returns a dictionary indexed by task name, containing a list of evaluation metrics for each task.
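How these inputs fit together can be pictured with the hypothetical outline below, which is not the actual `eval_model` source: for every task, every size in `train_set_sample_sizes`, and each of `num_samples` draws, a seeded random split is made and handed to the test function, the results are collected per task name, and a summary is optionally written to a CSV file in `out_dir`. Tasks are plain lists of `(molecule_id, label)` pairs, the split is a simple shuffle rather than the `TaskSampler` logic, the CSV file name is made up, and the test function returns a single float instead of a `BinaryEvalMetrics`.

```python
import csv
import os
import random
import tempfile
from typing import Callable, Dict, List, Optional, Sequence, Tuple

Datapoint = Tuple[str, int]  # (molecule identifier, binary label)
TestFn = Callable[[Sequence[Datapoint], Sequence[Datapoint], int], float]


def simple_eval_model(
    test_fn: TestFn,
    tasks: Dict[str, Sequence[Datapoint]],
    train_set_sample_sizes: Sequence[int],
    num_samples: int,
    seed: int = 0,
    out_dir: Optional[str] = None,
) -> Dict[str, List[Tuple[int, float]]]:
    """Return {task_name: [(support_size, metric), ...]} over all draws."""
    results: Dict[str, List[Tuple[int, float]]] = {}
    for task_name, datapoints in tasks.items():
        for support_size in train_set_sample_sizes:
            for sample_idx in range(num_samples):
                run_seed = seed + sample_idx
                shuffled = list(datapoints)
                random.Random(run_seed).shuffle(shuffled)
                support, query = shuffled[:support_size], shuffled[support_size:]
                metric = test_fn(support, query, run_seed)
                results.setdefault(task_name, []).append((support_size, metric))
    if out_dir is not None:
        # The same file is rewritten on every call, hence the overwrites during training.
        with open(os.path.join(out_dir, "eval_summary.csv"), "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["task", "support_size", "metric"])
            for task_name, rows in results.items():
                for support_size, metric in rows:
                    writer.writerow([task_name, support_size, metric])
    return results


def majority_baseline(support: Sequence[Datapoint], query: Sequence[Datapoint], seed: int) -> float:
    """Predict the support set's majority label for every query molecule (seed unused)."""
    positive_rate = sum(label for _, label in support) / len(support)
    prediction = int(positive_rate >= 0.5)
    return sum(label == prediction for _, label in query) / len(query)


toy_tasks = {
    "task_a": [(f"a{i}", int(i % 3 == 0)) for i in range(12)],
    "task_b": [(f"b{i}", int(i % 2 == 0)) for i in range(12)],
}
with tempfile.TemporaryDirectory() as tmp:
    toy_results = simple_eval_model(
        majority_baseline, toy_tasks, train_set_sample_sizes=[4, 8], num_samples=3, out_dir=tmp
    )
    with open(os.path.join(tmp, "eval_summary.csv")) as f:
        num_csv_rows = sum(1 for _ in f)

print({name: len(rows) for name, rows in toy_results.items()})  # {'task_a': 6, 'task_b': 6}
print(num_csv_rows)  # 13: a header plus 2 tasks x 2 sizes x 3 draws
```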

The `validate_by_finetuning_on_tasks` example in `multitask_train.py` shows a more elaborate `test_model_fn`: as long as the function returns a `BinaryEvalMetrics` for the `FSMolTaskSample` passed to it by `eval_model`, the model can be evaluated in the manner of the `FS-Mol` benchmark.

In this case, the evaluation method rebuilds a fresh copy of the model from file and finetunes it on the `FSMolTaskSample` passed by `eval_model`. The training loop carefully uses the batching machinery described elsewhere to ensure the model receives correctly batched data, and rearranges the validation metrics into the required `BinaryEvalMetrics`.
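At its core, producing the metrics means turning the model's predicted probabilities and the true labels of the query set into scores. A minimal sketch follows, computing accuracy at a 0.5 threshold and average precision (the `avg_precision` metric chosen via `metric_to_use` earlier in this notebook) as the mean of the precision at the rank of each positive, without special handling of tied scores. `SimpleBinaryMetrics` is a hypothetical stand-in; the fields of the real `BinaryEvalMetrics` may differ.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class SimpleBinaryMetrics:
    size: int
    accuracy: float
    avg_precision: float


def average_precision(scores: Sequence[float], labels: Sequence[int]) -> float:
    """Mean of precision@k taken at the rank k of every positive (ties not handled)."""
    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    hits = 0
    precisions: List[float] = []
    for rank, (_, label) in enumerate(ranked, start=1):
        if label == 1:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


def compute_metrics(probabilities: Sequence[float], labels: Sequence[int]) -> SimpleBinaryMetrics:
    predictions = [int(p >= 0.5) for p in probabilities]
    accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)
    return SimpleBinaryMetrics(
        size=len(labels),
        accuracy=accuracy,
        avg_precision=average_precision(probabilities, labels),
    )


toy_metrics = compute_metrics([0.9, 0.8, 0.3, 0.6], [1, 0, 1, 0])
print(toy_metrics)  # SimpleBinaryMetrics(size=4, accuracy=0.25, avg_precision=0.75)
```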

## Simple Example

A simple example is given in `baseline_test.py`. There, the test function builds a baseline kNN model, evaluates it on each `FSMolTaskSample`, and returns a `BinaryEvalMetrics`. The `eval_model` function collects the results by running over all test tasks, with repeated draws from each task.
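To make the idea of a kNN baseline concrete, here is a small self-contained sketch, not the `baseline_test.py` implementation: molecules are represented by hypothetical fingerprints given as sets of "on" bit indices, similarity is the Tanimoto coefficient, and each query molecule takes the majority label among its `k` most similar support molecules. The real baseline may use different features, a different distance, and a different `k`.

```python
from collections import Counter
from typing import FrozenSet, List, Sequence, Tuple

Fingerprint = FrozenSet[int]  # hypothetical: indices of the "on" bits of a fingerprint


def tanimoto(a: Fingerprint, b: Fingerprint) -> float:
    """Tanimoto (Jaccard) similarity of two bit sets."""
    union = len(a | b)
    return len(a & b) / union if union else 0.0


def knn_predict(
    support: Sequence[Tuple[Fingerprint, int]], query: Sequence[Fingerprint], k: int = 3
) -> List[int]:
    """Label each query with the majority label of its k most similar support molecules."""
    predictions = []
    for fp in query:
        neighbours = sorted(support, key=lambda item: tanimoto(fp, item[0]), reverse=True)[:k]
        votes = Counter(label for _, label in neighbours)
        # With binary labels, an odd k (and at least k support molecules) rules out ties.
        predictions.append(votes.most_common(1)[0][0])
    return predictions


toy_support = [
    (frozenset({1, 2, 3}), 1),
    (frozenset({1, 2, 4}), 1),
    (frozenset({7, 8, 9}), 0),
    (frozenset({7, 8, 10}), 0),
    (frozenset({1, 7}), 0),
]
toy_query = [frozenset({1, 2, 3, 4}), frozenset({7, 8, 9, 10})]
print(knn_predict(toy_support, toy_query, k=3))  # [1, 0]
```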

In [None]:
from fs_mol.utils.metrics import BinaryEvalMetrics
from fs_mol.baseline_test import test
from fs_mol.data.fsmol_task import FSMolTaskSample

# testing function for a single task with the signature eval_model expects;
# temp_out_folder and seed are part of that signature but unused by this kNN baseline
def test_model_fn(
    task_sample: FSMolTaskSample, temp_out_folder: str, seed: int
) -> BinaryEvalMetrics:
    return test(
        model_name="kNN",
        task_sample=task_sample,
        use_grid_search=False,
    )

In [None]:
from fs_mol.utils.test_utils import eval_model

eval_model(
    test_model_fn=test_model_fn,
    dataset=fsmol_dataset,
    train_set_sample_sizes=[16],
    out_dir=out_dir,
    fold=DataFold.TEST,
    num_samples=1,
    seed=0,
)

To evaluate a new model in the same way, it must be usable with `eval_model`, so a `test_model_fn` needs to be defined for it.
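One way to structure such a function for a new model is sketched below. It keeps the `(task_sample, temp_out_folder, seed)` signature of the `test_model_fn` example above and fixes a model-specific argument in advance with `functools.partial`, so `eval_model` never needs to know about it. `TaskSampleStub`, `fit_and_score` and `threshold` are hypothetical, the stub's field names are only loosely modelled on `FSMolTaskSample`, and the toy "model" returns a single accuracy value where a real `test_model_fn` must return `BinaryEvalMetrics`.

```python
from dataclasses import dataclass
from functools import partial
from typing import List, Tuple


@dataclass
class TaskSampleStub:
    """Hypothetical stand-in for FSMolTaskSample: (SMILES, label) pairs per split."""

    name: str
    train_samples: List[Tuple[str, int]]  # support set
    test_samples: List[Tuple[str, int]]  # query set


def fit_and_score(
    task_sample: TaskSampleStub, temp_out_folder: str, seed: int, threshold: float
) -> float:
    """'Train' on the support set, then return accuracy on the query set."""
    # temp_out_folder and seed are unused by this toy model; a real model might write
    # checkpoints to the folder and use the seed to make its training reproducible.
    support_labels = [label for _, label in task_sample.train_samples]
    prediction = int(sum(support_labels) / len(support_labels) >= threshold)
    query_labels = [label for _, label in task_sample.test_samples]
    return sum(label == prediction for label in query_labels) / len(query_labels)


# Bind the model-specific argument now; the result keeps the three-argument signature.
toy_test_model_fn = partial(fit_and_score, threshold=0.5)

toy_sample = TaskSampleStub(
    name="toy_task",
    train_samples=[("CCO", 1), ("CCN", 1), ("CCC", 0)],
    test_samples=[("CCCl", 1), ("CCBr", 0), ("CCF", 1), ("CCI", 1)],
)
print(toy_test_model_fn(toy_sample, "unused_folder", 0))  # 0.75
```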