# Specify, train, and evaluate a PyTorch-based model using AbstractTorchFSMolModel

Flexible implementations of few-shot models can be evaluated according to our benchmarking procedure by building them on the `AbstractTorchFSMolModel` base class.

In [1]:
import os
import sys
import torch

sys.path.insert(0, os.path.join(os.getcwd(), "../"))
sys.path.insert(0, os.path.join(os.getcwd(), "../fs_mol"))


# 1a. Specifying a new model

As an example of a model implementation that makes use of this tooling and the new dataset, we turn to `gnn_multitask.py`, which implements `AbstractTorchFSMolModel`.

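```python
from __future__ import annotations

from abc import abstractmethod
from typing import Any, Dict, Generic, Optional, TypeVar

import torch

BatchFeaturesType = TypeVar("BatchFeaturesType")


class AbstractTorchFSMolModel(Generic[BatchFeaturesType], torch.nn.Module):
    @abstractmethod
    def forward(self, batch: BatchFeaturesType) -> Any:
        """Run the model on a batch, returning predictions used to compute the loss."""
        raise NotImplementedError()

    @abstractmethod
    def get_model_state(self) -> Dict[str, Any]:
        """Return everything needed to save the model."""
        raise NotImplementedError()

    @abstractmethod
    def is_param_task_specific(self, param_name: str) -> bool:
        """Return True for parameters that belong to a single training task."""
        raise NotImplementedError()

    @abstractmethod
    def load_model_weights(
        self,
        path: str,
        load_task_specific_weights: bool,
        quiet: bool = False,
        device: Optional[torch.device] = None,
    ) -> None:
        """Load model weights from a saved checkpoint."""
        raise NotImplementedError()

    # `abc.abstractclassmethod` is deprecated; stacking the two decorators is equivalent
    @classmethod
    @abstractmethod
    def build_from_model_file(
        cls,
        model_file: str,
        config_overrides: Dict[str, Any] = {},
        quiet: bool = False,
        device: Optional[torch.device] = None,
    ) -> AbstractTorchFSMolModel[BatchFeaturesType]:
        """Build the model architecture based on a saved checkpoint."""
        raise NotImplementedError()
```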

**forward**

The `GNNMultitaskModel` consists of an initial per-node linear projection layer, a shared `GNN` with a configuration specified by a `GNNConfig`, a readout layer that converts node representations to full graph embeddings, and a tail MLP that acts as a 'head' with the number of outputs determined by the number of tasks the model is being trained/evaluated on simultaneously (one output unit per task in this binary classification problem).


As discussed in the dataset notebook, `GNNMultitaskModel` uses `FSMolMultitaskBatch`, an extension of `FSMolBatch` with an additional `sample_to_task_id` attribute that propagates the task id of each sample through the model so that the right output can be selected.

The output of this method should be predictions that can be used to calculate a differentiable loss when combined with graph labels from the batch.

```python
from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass(frozen=True)
class FSMolBatch:
    num_graphs: int
    num_nodes: int
    num_edges: int
    node_features: np.ndarray  # [V, atom_features] float
    # list of length num_edge_types; each element is a [num_edges, 2] int array of node index pairs
    adjacency_lists: List[np.ndarray]
    # list of length num_edge_types; each element is a [num_edges, ED] float array
    edge_features: List[np.ndarray]
    node_to_graph: np.ndarray  # [V] long: index of the graph each node belongs to


@dataclass(frozen=True)
class FSMolMultitaskBatch(FSMolBatch):
    sample_to_task_id: np.ndarray  # [num_graphs] int: task id of each graph in the batch
```
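To make the role of `sample_to_task_id` concrete, here is a minimal plain-Python sketch (no PyTorch) of how a multitask head's output could be reduced to one prediction per graph: the head produces one output unit per task for every graph, and `sample_to_task_id` picks the unit belonging to that graph's own task. The function `select_task_outputs` is hypothetical; this notebook only states that the ids allow selection on the outputs, not exactly where `GNNMultitaskModel` performs it.

```python
from typing import List, Sequence


def select_task_outputs(
    logits: Sequence[Sequence[float]],  # [num_graphs, num_tasks]
    sample_to_task_id: Sequence[int],  # [num_graphs]
) -> List[float]:
    """Return, for each graph, the output unit of the task that graph belongs to."""
    if len(logits) != len(sample_to_task_id):
        raise ValueError("need exactly one task id per graph")
    return [row[task_id] for row, task_id in zip(logits, sample_to_task_id)]


# three graphs, a head with one output unit per task (four tasks)
logits = [
    [0.1, 2.0, -1.0, 0.3],
    [1.5, 0.2, 0.0, -0.4],
    [-0.7, 0.9, 3.1, 0.0],
]
print(select_task_outputs(logits, [1, 0, 2]))  # [2.0, 1.5, 3.1]
```

With tensors, the same selection is a single advanced-indexing step, `logits[torch.arange(len(ids)), ids]`.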

**model state**

`get_model_state` returns a dictionary that is used to save the model. For `GNNMultitaskModel`, an additional configuration dictionary is also stored.
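Storing the configuration next to the weights is what makes it possible to rebuild the architecture before loading weights, which is the job of `build_from_model_file`. The plain-Python sketch below illustrates that idea under stated assumptions: the dataclass `ToyConfig` and the keys `"model_config"` and `"model_state_dict"` are hypothetical and are not necessarily the keys `GNNMultitaskModel` writes.

```python
import dataclasses
import json
from typing import Any, Dict


@dataclasses.dataclass
class ToyConfig:  # hypothetical stand-in for a model config such as GNNMultitaskConfig
    num_tasks: int
    hidden_dim: int = 128
    num_layers: int = 8


def get_model_state(config: ToyConfig, weights: Dict[str, Any]) -> Dict[str, Any]:
    """Bundle the weights with the config needed to rebuild the architecture."""
    return {
        "model_config": dataclasses.asdict(config),
        "model_state_dict": weights,
    }


def config_from_state(state: Dict[str, Any]) -> ToyConfig:
    """Recreate the config from a saved state, e.g. before building an empty model."""
    return ToyConfig(**state["model_config"])


state = get_model_state(ToyConfig(num_tasks=4, num_layers=10), {"tail_mlp.weight": [0.0]})
# the config part is plain data, so it survives serialisation
restored = config_from_state(json.loads(json.dumps(state)))
print(restored)  # ToyConfig(num_tasks=4, hidden_dim=128, num_layers=10)
```

A real model would save such a dictionary with `torch.save` rather than JSON; the round trip here only shows that the config can be recovered on its own.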

**task specific parameters**

In multitask models, for example, some parameters may be task specific and are therefore not used in later evaluation steps. For `GNNMultitaskModel`, `is_param_task_specific` returns `True` for all parameters of the tail MLP layers.
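A name-based check is one simple way to implement `is_param_task_specific`. The sketch below assumes, hypothetically, that all tail-MLP parameters are registered under a `tail_mlp.` prefix; check the actual parameter names of your model (for a `torch.nn.Module`, `model.named_parameters()` lists them) before relying on any prefix.

```python
from typing import Iterable, List, Tuple


def is_param_task_specific(param_name: str) -> bool:
    # hypothetical naming convention: every tail-MLP parameter starts with "tail_mlp."
    return param_name.startswith("tail_mlp.")


def split_param_names(names: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Partition parameter names into (shared, task_specific)."""
    shared: List[str] = []
    task_specific: List[str] = []
    for name in names:
        (task_specific if is_param_task_specific(name) else shared).append(name)
    return shared, task_specific


names = [
    "init_node_proj.weight",
    "gnn.layers.0.weight",
    "readout.weight",
    "tail_mlp.0.weight",
    "tail_mlp.0.bias",
]
shared, specific = split_param_names(names)
print(shared)    # ['init_node_proj.weight', 'gnn.layers.0.weight', 'readout.weight']
print(specific)  # ['tail_mlp.0.weight', 'tail_mlp.0.bias']
```

A partition like this is presumably also what lets `create_optimizer`, later in this notebook, give task-specific layers their own learning rate (`task_specific_lr`).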

**load model weights**

`load_model_weights` loads model weights from a saved checkpoint; in the multitask model example, the checkpoint also holds the state of the optimizers at the time of saving.
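When `load_task_specific_weights` is `False`, task-specific parameters (such as the multitask head, whose size depends on the set of training tasks) should be skipped. Below is a minimal sketch of that filtering step on a plain dictionary; `filter_state_dict` is a hypothetical helper, and a PyTorch implementation would then pass the filtered dictionary to `torch.nn.Module.load_state_dict(..., strict=False)` so the missing head weights are tolerated.

```python
from typing import Any, Callable, Dict


def filter_state_dict(
    state_dict: Dict[str, Any],
    is_param_task_specific: Callable[[str], bool],
    load_task_specific_weights: bool,
) -> Dict[str, Any]:
    """Return a copy of state_dict, dropping task-specific entries unless requested."""
    if load_task_specific_weights:
        return dict(state_dict)
    return {k: v for k, v in state_dict.items() if not is_param_task_specific(k)}


saved = {"gnn.layers.0.weight": "W0", "readout.weight": "R", "tail_mlp.0.weight": "T"}
kept = filter_state_dict(saved, lambda name: name.startswith("tail_mlp."), False)
print(sorted(kept))  # ['gnn.layers.0.weight', 'readout.weight']
```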



# 1b. Loading a new model and dataset

As an example, we will load a sample model and dataset.

In [2]:
from models.gnn_multitask import (
    GNNMultitaskConfig,
    GNNMultitaskModel,
    GNNConfig,
    create_model,
)

from data.fsmol_dataset import FSMolDataset, DataFold

We use an `FSMolDataset` (see `notebooks/dataset.ipynb`).

In [3]:
fsmol_dataset = FSMolDataset.from_directory(
        directory=os.path.join(os.getcwd(), "../dataset/"),
    )

`GNNMultitaskModel` is specified by a `GNNMultitaskConfig`. Because the model is built around `GNN`, a general-purpose GNN module, this config contains a `GNNConfig`, which is passed to `GNN` when it is initialised.

```python
from dataclasses import dataclass
from typing import Optional


# defined first, because GNNMultitaskConfig refers to it in a field annotation
@dataclass
class GNNConfig:
    type: str = "MultiHeadAttention"
    num_edge_types: int = 3
    hidden_dim: int = 128
    num_heads: int = 4
    per_head_dim: int = 32
    intermediate_dim: int = 512
    message_function_depth: int = 1
    num_layers: int = 8
    dropout_rate: float = 0.0
    use_rezero_scaling: bool = True
    make_edges_bidirectional: bool = True


@dataclass
class GNNMultitaskConfig:
    num_tasks: int
    gnn_config: GNNConfig
    node_feature_dim: int = 32
    num_outputs: int = 1
    readout_type: str = "sum"
    readout_use_only_last_timestep: bool = False
    readout_dim: Optional[int] = None
    readout_num_heads: int = 12
    readout_head_dim: int = 64
    num_tail_layers: int = 1
```
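Because both configs are plain dataclasses, a variant configuration can be derived without restating every field. The sketch below uses `dataclasses.replace` on trimmed-down, renamed copies of the two classes (`ToyGNNConfig` and `ToyMultitaskConfig`, so that running it does not shadow the real imports). The dotted `"gnn_config.<field>"` key convention is a hypothetical choice; this is one reasonable way to apply something like the `config_overrides` argument of `build_from_model_file`, not necessarily how FS-Mol applies it.

```python
import dataclasses
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class ToyGNNConfig:  # trimmed copy of GNNConfig
    type: str = "MultiHeadAttention"
    hidden_dim: int = 128
    num_layers: int = 8


@dataclass
class ToyMultitaskConfig:  # trimmed copy of GNNMultitaskConfig
    num_tasks: int
    gnn_config: ToyGNNConfig
    readout_type: str = "sum"


def apply_overrides(config: ToyMultitaskConfig, overrides: Dict[str, Any]) -> ToyMultitaskConfig:
    """Return a new config; keys prefixed with 'gnn_config.' target the nested GNN config."""
    prefix = "gnn_config."
    gnn_overrides = {k[len(prefix):]: v for k, v in overrides.items() if k.startswith(prefix)}
    top_overrides = {k: v for k, v in overrides.items() if not k.startswith(prefix)}
    new_gnn = dataclasses.replace(config.gnn_config, **gnn_overrides)
    return dataclasses.replace(config, gnn_config=new_gnn, **top_overrides)


base = ToyMultitaskConfig(num_tasks=10, gnn_config=ToyGNNConfig(type="PNA"))
variant = apply_overrides(base, {"readout_type": "combined", "gnn_config.num_layers": 10})
print(variant.readout_type, variant.gnn_config.type, variant.gnn_config.num_layers)  # combined PNA 10
print(base.gnn_config.num_layers)  # 8 (the original is untouched)
```

`dataclasses.replace` always builds new objects, so the base configuration can be reused safely for several variants.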

In [4]:
# set up an output directory in which to save a model
out_dir = os.path.join(os.getcwd(), "test")
os.makedirs(out_dir, exist_ok=True)

# set up the model configuration that completely specifies a GNNMultitaskModel
model_config = GNNMultitaskConfig(
        num_tasks=fsmol_dataset.get_num_fold_tasks(DataFold.TRAIN),  # one output per training task
        node_feature_dim=32, # fixed in our data preprocessing
        gnn_config=GNNConfig(
            type="PNA",
            hidden_dim=128,
            num_edge_types=3,
            num_heads=4, 
            per_head_dim=64,
            intermediate_dim=1024, # intermediate representation used in BOOM layer
            message_function_depth=1,
            num_layers=10, # number of gnn layers
        ),
        num_outputs=1,
        readout_type="combined",
        readout_use_only_last_timestep=False, # use all intermediate GNN activations in the final readout
        num_tail_layers=2,
    )

We build a model using the configuration specified.

In [5]:
# create an instance of a model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = create_model(model_config, device=device)

In [6]:
print(f"\tNum parameters {sum(p.numel() for p in model.parameters())}")
print(f"\tDevice: {device}")
# print(f"\tModel:\n{model}") # prints out full description of model 

	Num parameters 18690924
	Device: cuda


# 2. Training an AbstractTorchFSMolModel

As an example, model training is demonstrated in `train_loop` in `abstract_torch_fsmol_model.py`. The training loop accepts:

- the model
- the components associated with optimization, that is, the optimizer and the learning rate scheduler (if one is used)
- the training data as an iterable over minibatches
- a validation callable that allows validation to be performed throughout meta-training, for example by adapting the model to meta-validation tasks
- other parameters, including the model saving directory, the training patience, and logging of training values.


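```python
import logging
import os
from typing import Callable, Iterable, Tuple

import numpy as np
import torch

# logger, save_model, run_on_data_iterable, MetricType, AbstractTorchFSMolModel and
# BatchFeaturesType are defined in or imported by abstract_torch_fsmol_model.py.


def train_loop(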
    model: AbstractTorchFSMolModel[BatchFeaturesType],
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    train_data: Iterable[Tuple[BatchFeaturesType, np.ndarray]],
    valid_fn: Callable[[AbstractTorchFSMolModel[BatchFeaturesType]], float],
    output_folder: str,
    metric_to_use: MetricType = "avg_precision",
    max_num_epochs: int = 100,
    patience: int = 5,
    aml_run=None,
    quiet: bool = False,
):
    if quiet:
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
    initial_valid_metric = float("-inf")
    best_valid_metric = initial_valid_metric
    logger.log(log_level, f"  Initial validation metric: {best_valid_metric:.5f}")

    save_model(os.path.join(output_folder, "best_model.pt"), model, optimizer, -1)

    epochs_since_best = 0
    for epoch in range(0, max_num_epochs):
        logger.log(log_level, f"== Epoch {epoch}")
        logger.log(log_level, f"  = Training")
        train_loss, train_metrics = run_on_data_iterable(
            model,
            data_iterable=train_data,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            quiet=quiet,
            metric_name_prefix="train_",
            aml_run=aml_run,
        )
        mean_train_metric = np.mean(
            [getattr(task_metrics, metric_to_use) for task_metrics in train_metrics.values()]
        )
        logger.log(log_level, f"  Mean train loss: {train_loss:.5f}")
        logger.log(log_level, f"  Mean train {metric_to_use}: {mean_train_metric:.5f}")
        logger.log(log_level, f"  = Validation")
        valid_metric = valid_fn(model)
        logger.log(log_level, f"  Validation metric: {valid_metric:.5f}")

        if valid_metric > best_valid_metric:
            logger.log(
                log_level,
                f"   New best validation result {valid_metric:.5f} (increased from {best_valid_metric:.5f}).",
            )
            best_valid_metric = valid_metric
            epochs_since_best = 0

            save_model(os.path.join(output_folder, "best_model.pt"), model, optimizer, epoch)
        else:
            epochs_since_best += 1
            logger.log(log_level, f"   Now had {epochs_since_best} epochs since best result.")
            if epochs_since_best >= patience:
                break

    return best_valid_metric
```
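The control flow of `train_loop` reduces to patience-based early stopping on the validation metric. The sketch below replays that rule on a precomputed list of validation scores (standing in for repeated `valid_fn(model)` calls) so it can be checked in isolation; the function name is hypothetical and checkpointing is reduced to recording the best epoch.

```python
from typing import List, Tuple


def early_stopping_trace(valid_metrics: List[float], patience: int = 5) -> Tuple[float, int, int]:
    """Return (best_metric, best_epoch, epochs_run) using the same rule as train_loop.

    best_epoch stays -1 if no epoch beats the initial -inf (train_loop's initial save uses epoch -1).
    """
    best_metric, best_epoch = float("-inf"), -1
    epochs_since_best = 0
    epochs_run = 0
    for epoch, metric in enumerate(valid_metrics):
        epochs_run = epoch + 1
        if metric > best_metric:  # strict improvement only; ties count as "no improvement"
            best_metric, best_epoch = metric, epoch  # train_loop saves best_model.pt here
            epochs_since_best = 0
        else:
            epochs_since_best += 1
            if epochs_since_best >= patience:
                break
    return best_metric, best_epoch, epochs_run


print(early_stopping_trace([0.50, 0.55, 0.54, 0.56, 0.53, 0.52], patience=2))  # (0.56, 3, 6)
```

Because the comparison is strict, a validation score that merely equals the best so far still counts towards the patience budget.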

As an example of creating a suitable iterable of training batches, we use `MultitaskTaskSampleBatchIterable`, which uses `FSMolDataset` to build an iterable over the `FSMolMultitaskBatch` objects described above.

In [7]:
from data.multitask import MultitaskTaskSampleBatchIterable

# note we need this dictionary to connect task numbers with names
train_task_name_to_id = {
        name: i for i, name in enumerate(fsmol_dataset.get_task_names(data_fold=DataFold.TRAIN))
    }

train_data=MultitaskTaskSampleBatchIterable(
            fsmol_dataset,
            data_fold=DataFold.TRAIN,
            task_name_to_id=train_task_name_to_id,
            max_num_graphs=256,
        )

An iterable such as `train_data` is the required input to `run_on_data_iterable`, which runs the model over every batch the iterable yields. We suggest `train_loop` as a good example outline for using the `FS-Mol` data loaders.
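To show what an iterable like `MultitaskTaskSampleBatchIterable` has to produce, here is a simplified plain-Python sketch. It walks over (task name, samples) pairs, maps each task name to an integer id with the same `enumerate` pattern used for `train_task_name_to_id`, and cuts the stream into minibatches of at most `max_num_graphs` samples, each carrying its `sample_to_task_id` list. The function `iterate_multitask_batches` and the task names are hypothetical; the real class also reads tasks from disk and builds the graph feature arrays and labels.

```python
from typing import Dict, Iterable, Iterator, List, Sequence, Tuple


def iterate_multitask_batches(
    tasks: Iterable[Tuple[str, Sequence[str]]],
    task_name_to_id: Dict[str, int],
    max_num_graphs: int,
) -> Iterator[Tuple[List[str], List[int]]]:
    """Yield (samples, sample_to_task_id) minibatches of at most max_num_graphs samples."""
    samples: List[str] = []
    task_ids: List[int] = []
    for task_name, task_samples in tasks:
        for sample in task_samples:
            samples.append(sample)
            task_ids.append(task_name_to_id[task_name])
            if len(samples) == max_num_graphs:
                yield samples, task_ids
                samples, task_ids = [], []  # start fresh lists; yielded ones stay intact
    if samples:  # final, possibly smaller, batch
        yield samples, task_ids


toy_tasks = [("CHEMBL_A", ["a1", "a2", "a3"]), ("CHEMBL_B", ["b1", "b2"])]
toy_name_to_id = {name: i for i, name in enumerate(name for name, _ in toy_tasks)}
for batch, ids in iterate_multitask_batches(toy_tasks, toy_name_to_id, max_num_graphs=2):
    print(batch, ids)
# ['a1', 'a2'] [0, 0]
# ['a3', 'b1'] [0, 1]
# ['b2'] [1]
```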

**Validation function**

Many forms of validation are possible; here we give an example of validation by finetuning on a set of validation tasks.

The validation method is passed to the training loop as a callable.

We discuss the validation function in more detail below, as it has much in common with our model evaluation function `eval_model`.

In [8]:
from fs_mol.multitask_train import validate_by_finetuning_on_tasks

In [9]:
from functools import partial

# create a validation callable: valid_fn(model) returns a metric of performance on the
# validation tasks (the FSMolDataset and the finetuning hyperparameters are already bound).
valid_fn = partial(
        validate_by_finetuning_on_tasks,
        dataset=fsmol_dataset,
        learning_rate=0.00005,
        task_specific_learning_rate=0.0001,
        batch_size=256,
        metric_to_use="avg_precision",
    )

In [10]:
from models.abstract_torch_fsmol_model import (
    train_loop,
    create_optimizer,
)

# create a specific optimizer with learning rate for training
optimizer, lr_scheduler = create_optimizer(
        model,
        lr=0.00005,
        task_specific_lr=0.0001, # we allow task specific layers to adapt faster than the core GNN here
        warmup_steps=100,
        task_specific_warmup_steps=100,
    )
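`create_optimizer` takes separate learning rates and warmup lengths for shared and task-specific parameters. Its exact schedule is defined in `abstract_torch_fsmol_model.py`; the sketch below assumes a simple linear warmup to a constant rate, which may differ from the real one, and only computes the learning rate of each parameter group at a given optimizer step. `warmup_lr` and `group_lrs` are hypothetical helpers.

```python
from typing import Dict


def warmup_lr(step: int, base_lr: float, warmup_steps: int) -> float:
    """Assumed schedule: linear warmup to base_lr over warmup_steps, then constant."""
    if warmup_steps <= 0:
        return base_lr
    return base_lr * min(1.0, (step + 1) / warmup_steps)


def group_lrs(step: int) -> Dict[str, float]:
    # values mirror the create_optimizer call above: shared lr 5e-5, task-specific lr 1e-4
    return {
        "shared": warmup_lr(step, base_lr=0.00005, warmup_steps=100),
        "task_specific": warmup_lr(step, base_lr=0.0001, warmup_steps=100),
    }


for step in (0, 49, 99, 500):
    print(step, group_lrs(step))
```

In PyTorch, a multiplier like this would typically be handed to `torch.optim.lr_scheduler.LambdaLR`, which accepts one function per parameter group.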

In [None]:
# run a training loop on the model; the weights are saved to best_model.pt in out_dir,
# and every improvement on the validation metric overwrites them with the new weights.
best_metric = train_loop(
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_data=train_data,
        valid_fn=valid_fn,
        output_folder=out_dir,
        max_num_epochs=1,
    )

The example above made use of a validation callable, `validate_by_finetuning_on_tasks`, in which a copy of the model is finetuned on a validation task. This is repeated for every validation task available from `fsmol_dataset`, and the results are combined into an aggregate metric representing performance across all validation tasks.
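The aggregation step can be sketched independently of the model: per-task results are reduced to the single float that `train_loop` compares across epochs. The sketch below takes, for each task, the mean of the chosen metric over repeated sample draws, and then the mean over tasks, weighting tasks equally and skipping tasks with no successful draws; whether `validate_by_finetuning_on_tasks` aggregates exactly this way is an assumption.

```python
from statistics import mean
from typing import Dict, List


def aggregate_validation_metric(per_task_results: Dict[str, List[float]]) -> float:
    """Average each task's repeated runs, then average across tasks (tasks weighted equally)."""
    task_means = [mean(runs) for runs in per_task_results.values() if runs]
    if not task_means:
        raise ValueError("no validation results to aggregate")
    return mean(task_means)


# task_c: every sample draw failed, so it contributes nothing
results = {"task_a": [0.5, 0.75], "task_b": [0.25], "task_c": []}
print(aggregate_validation_metric(results))  # 0.4375
```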

# 3. Evaluating a model -- Benchmarking Procedure



## The eval_model() function

`fs_mol.utils.test_utils.eval_model()` is a general-purpose model evaluation function that allows any model to be run against the full set of test tasks. Its behaviour defines the standard benchmarking procedure for any few-shot model on the test tasks.

This method requires the following input:

1. A `test_model_fn` callable: this helper function accepts an `FSMolTaskSample`, a temporary output folder and a seed, and returns test metrics for that sample.
    - The callable should operate on the task sample with the model and return a `BinaryEvalMetrics` object containing all metrics calculated from the model output and the labels of the task sample.
    
```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BinaryEvalMetrics:
    size: int
    acc: float
    balanced_acc: float
    f1: float
    prec: float
    recall: float
    roc_auc: float
    avg_precision: float
    kappa: float
```
See `fs_mol/baseline_test.py` for a simple example.

2. An `FSMolDataset` containing information about all tasks, from which a file reading iterable over tasks is made.

3. A `task_reader_fn` callable (optional): this is passed to `dataset.get_task_reading_iterable()` to apply additional transformations to the data as it is read from disk; see datasets.ipynb.

4. train_set_sample_sizes: also known as support set sample sizes. This is supplied as a list, in recognition that a model can be evaluated at multiple sizes of support set. This happens in both validation and final benchmarking evaluation. 

5. Output directory: if this is passed, a summary of the evaluation of each task is written to a CSV file. Beware: running this during training will result in repeated overwrites.

6. num_samples: the number of random splits to draw for the task undergoing evaluation. 

7. The data fold (`fold`): the fold to evaluate on, typically the test fold. It also decides whether the order of evaluated tasks is shuffled, which only matters when evaluation runs inside the training loop.
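For reference, most fields of `BinaryEvalMetrics` are standard confusion-matrix quantities. The plain-Python sketch below (the hypothetical `threshold_metrics`) computes a subset of them from hard 0/1 predictions. `roc_auc` and `avg_precision` need ranked scores rather than thresholded predictions, and in practice a library such as scikit-learn (`sklearn.metrics`) would be used for all of them.

```python
from typing import Dict, Sequence


def threshold_metrics(labels: Sequence[int], preds: Sequence[int]) -> Dict[str, float]:
    """Accuracy, balanced accuracy, precision, recall and F1 for binary 0/1 labels."""
    pairs = list(zip(labels, preds))
    tp = sum(1 for y, p in pairs if y == 1 and p == 1)
    tn = sum(1 for y, p in pairs if y == 0 and p == 0)
    fp = sum(1 for y, p in pairs if y == 0 and p == 1)
    fn = sum(1 for y, p in pairs if y == 1 and p == 0)

    def safe_div(a: float, b: float) -> float:
        return a / b if b else 0.0

    recall = safe_div(tp, tp + fn)  # true positive rate
    specificity = safe_div(tn, tn + fp)  # true negative rate
    prec = safe_div(tp, tp + fp)
    return {
        "acc": safe_div(tp + tn, len(pairs)),
        "balanced_acc": (recall + specificity) / 2,
        "prec": prec,
        "recall": recall,
        "f1": safe_div(2 * prec * recall, prec + recall),
    }


print(threshold_metrics([1, 1, 0, 0], [1, 0, 0, 1]))
```

Balanced accuracy averages the per-class recall, so on a task with few actives it is not inflated by a model that always predicts "inactive".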

The returned results are a dictionary indexed by task name, mapping each task to a list of evaluation results (`FSMolTaskSampleEvalResults`), one for each successfully drawn sample.

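```python
import dataclasses
import logging
import os
import tempfile
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

from dpu_utils.utils import RichPath

# FSMolDataset, FSMolTask, FSMolTaskSample, DataFold, StratifiedTaskSampler, the sampling
# exceptions, FSMolTaskSampleEvalResults, BinaryEvalMetrics, prefix_log_msgs and
# write_csv_summary are defined in or imported by fs_mol/utils/test_utils.py.
logger = logging.getLogger(__name__)


def eval_model(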
    test_model_fn: Callable[[FSMolTaskSample, str, int], BinaryEvalMetrics],
    dataset: FSMolDataset,
    train_set_sample_sizes: List[int],
    out_dir: Optional[str] = None,
    num_samples: int = 5,
    valid_size_or_ratio: Union[int, float] = 0.0,
    test_size_or_ratio: Optional[Union[int, float, Tuple[int, int]]] = None,
    fold: DataFold = DataFold.TEST,
    task_reader_fn: Optional[Callable[[List[RichPath], int], Iterable[FSMolTask]]] = None,
    seed: int = 0,
) -> Dict[str, List[FSMolTaskSampleEvalResults]]:
    """Evaluate a model on the FSMolDataset passed.

    Args:
        test_model_fn: A callable directly evaluating the model of interest on a single task
            sample in the form of an FSMolTaskSample. The test_model_fn should act on the task
            sample with the model, using a temporary output folder and seed. All other required
            variables should be defined in the same context as the callable. The function should
            return a BinaryEvalMetrics object from the task.
        dataset: An FSMolDataset with paths to the data to be evaluated supplied.
        train_set_sample_sizes: List[int], a list of the support set sizes at which to evaluate,
            this is the train_samples size in a TaskSample.
        out_dir: final output directory for evaluation results.
        num_samples: number of repeated draws from the task's data on which to evaluate the model.
        valid_size_or_ratio: size of validation set in a TaskSample.
        test_size_or_ratio: size of the test set in a TaskSample.
        fold: the fold of FSMolDataset on which to perform evaluation, typically will be the test fold.
        task_reader_fn: Callable allowing additional transformations on the data prior to its batching
            and passing through a model.
        seed: a base external seed value. Run i of each task and size uses seed + i.
    """
    task_reading_kwargs = {"task_reader_fn": task_reader_fn} if task_reader_fn is not None else {}
    task_to_results: Dict[str, List[FSMolTaskSampleEvalResults]] = {}

    for task in dataset.get_task_reading_iterable(fold, **task_reading_kwargs):
        test_results: List[FSMolTaskSampleEvalResults] = []
        for train_size in train_set_sample_sizes:
            task_sampler = StratifiedTaskSampler(
                train_size_or_ratio=train_size,
                valid_size_or_ratio=valid_size_or_ratio,
                test_size_or_ratio=test_size_or_ratio,
                allow_smaller_test=True,
            )

            for run_idx in range(num_samples):
                logger.info(f"=== Evaluating on {task.name}, #train {train_size}, run {run_idx}")
                with prefix_log_msgs(
                    f" Test - Task {task.name} - Size {train_size:3d} - Run {run_idx}"
                ), tempfile.TemporaryDirectory() as temp_out_folder:
                    local_seed = seed + run_idx

                    try:
                        task_sample = task_sampler.sample(task, seed=local_seed)
                    except (
                        DatasetTooSmallException,
                        DatasetClassTooSmallException,
                        FoldTooSmallException,
                        ValueError,
                    ) as e:
                        logger.warning(
                            f"Failed to draw sample with {train_size} train points for {task.name}. Skipping."
                        )
                        logger.debug("Sampling error: " + str(e))
                        continue

                    test_metrics = test_model_fn(task_sample, temp_out_folder, local_seed)

                    test_results.append(
                        FSMolTaskSampleEvalResults(
                            task_name=task.name,
                            seed=local_seed,
                            num_train=train_size,
                            num_test=len(task_sample.test_samples),
                            fraction_pos_train=task_sample.train_pos_label_ratio,
                            fraction_pos_test=task_sample.test_pos_label_ratio,
                            **dataclasses.asdict(test_metrics),
                        )
                    )

        task_to_results[task.name] = test_results

        if out_dir is not None:
            write_csv_summary(os.path.join(out_dir, f"{task.name}_eval_results.csv"), test_results)

    return task_to_results
```
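The nested loop structure of `eval_model` (tasks, then support-set sizes, then repeated draws seeded with `seed + run_idx`, skipping draws that cannot be sampled) can be reproduced in a few lines. The sketch below is a plain-Python stand-in: `toy_sample` replaces `StratifiedTaskSampler.sample` with a hypothetical stratified draw that raises `ValueError` when a class is too small, and the scoring function returns a single number instead of `BinaryEvalMetrics`.

```python
import random
from typing import Callable, Dict, List, Sequence, Tuple

Split = Tuple[List[int], List[int]]  # (train indices, test indices)


def toy_sample(labels: Sequence[int], train_size: int, seed: int) -> Split:
    """Hypothetical stratified draw: train_size // 2 per class for training, rest is test."""
    rng = random.Random(seed)
    pos = [i for i, y in enumerate(labels) if y == 1]
    neg = [i for i, y in enumerate(labels) if y == 0]
    per_class = train_size // 2
    if len(pos) <= per_class or len(neg) <= per_class:  # keep at least one of each for test
        raise ValueError("class too small for requested train size")
    train = rng.sample(pos, per_class) + rng.sample(neg, per_class)
    train_set = set(train)
    test = [i for i in range(len(labels)) if i not in train_set]
    return train, test


def toy_eval_model(
    tasks: Dict[str, Sequence[int]],
    score_fn: Callable[[Sequence[int], Split, int], float],
    train_set_sample_sizes: List[int],
    num_samples: int = 5,
    seed: int = 0,
) -> Dict[str, List[Tuple[int, int, float]]]:
    """Return {task_name: [(train_size, seed, score), ...]} mirroring eval_model's loops."""
    task_to_results: Dict[str, List[Tuple[int, int, float]]] = {}
    for task_name, labels in tasks.items():
        results: List[Tuple[int, int, float]] = []
        for train_size in train_set_sample_sizes:
            for run_idx in range(num_samples):
                local_seed = seed + run_idx
                try:
                    split = toy_sample(labels, train_size, local_seed)
                except ValueError:
                    continue  # eval_model logs a warning and skips this draw
                results.append((train_size, local_seed, score_fn(labels, split, local_seed)))
        task_to_results[task_name] = results
    return task_to_results


def positive_rate_fn(labels: Sequence[int], split: Split, seed: int) -> float:
    """Hypothetical 'metric': fraction of positives in the test split."""
    _, test = split
    return sum(labels[i] for i in test) / len(test)


toy_task_labels = {"task_a": [1, 0] * 10, "task_b": [1, 0, 0, 0]}
toy_results = toy_eval_model(toy_task_labels, positive_rate_fn, [4, 16], num_samples=2)
print({name: len(runs) for name, runs in toy_results.items()})  # {'task_a': 4, 'task_b': 0}
```

As in `eval_model`, a task whose draws all fail still appears in the output, with an empty list.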

The `validate_by_finetuning_on_tasks` example in `multitask_train.py` shows a further extension of the general `test_model_fn`: provided a method returns a `BinaryEvalMetrics` object for the `FSMolTaskSample` passed to it by `eval_model`, it can be evaluated in the manner of the `FS-Mol` benchmark.

In this case, the evaluation method rebuilds a fresh copy of the model from file and finetunes it on the `FSMolTaskSample` passed by `eval_model`. The training loop carefully uses the batching machinery described elsewhere to ensure the model receives correctly batched data, and rearranges the validation metrics into the required `BinaryEvalMetrics`.

## Simple Example

A simple example is given in `baseline_test.py`. Here the test function builds a baseline kNN model, evaluates it on each `TaskSample`, and returns a `BinaryEvalMetrics` object. The `eval_model` function collects the results by running over all test tasks and over repeated draws from each task.

In [39]:
from fs_mol.utils.metrics import BinaryEvalMetrics
from fs_mol.baseline_test import test
from fs_mol.data.fsmol_task import FSMolTask, FSMolTaskSample

# testing function for a single task with the correct signature
def test_model_fn(
    task_sample: FSMolTaskSample, temp_out_folder: str, seed: int
) -> BinaryEvalMetrics:
    return test(
        model_name="kNN",
        task_sample=task_sample,
        use_grid_search=False,
    )
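The kNN baseline idea can be illustrated without any libraries. This sketch is not the `baseline_test.py` implementation: it assumes molecules are represented as sets of "on" fingerprint bits, measures similarity with the Tanimoto (Jaccard) coefficient, and predicts the majority label among the `k` most similar support-set molecules. `tanimoto` and `knn_predict` are hypothetical names.

```python
from collections import Counter
from typing import FrozenSet, Sequence, Tuple

Fingerprint = FrozenSet[int]  # indices of the "on" bits of a molecular fingerprint


def tanimoto(a: Fingerprint, b: Fingerprint) -> float:
    """Size of the intersection over size of the union (1.0 for two empty fingerprints)."""
    union = len(a | b)
    return len(a & b) / union if union else 1.0


def knn_predict(support: Sequence[Tuple[Fingerprint, int]], query: Fingerprint, k: int = 3) -> int:
    """Majority 0/1 label among the k support molecules most similar to the query."""
    neighbours = sorted(support, key=lambda item: tanimoto(item[0], query), reverse=True)[:k]
    votes = Counter(label for _, label in neighbours)
    return 1 if votes[1] > votes[0] else 0


support_set = [
    (frozenset({1, 2, 3}), 1),
    (frozenset({1, 2, 4}), 1),
    (frozenset({7, 8, 9}), 0),
    (frozenset({7, 8}), 0),
]
print(knn_predict(support_set, frozenset({1, 2}), k=3))  # 1
print(knn_predict(support_set, frozenset({8, 9}), k=3))  # 0
```

An odd `k` avoids tied votes in the binary setting; with an even `k`, this sketch breaks ties towards the inactive class.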

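In the interests of time, evaluation could be run on an `FSMolDataset` restricted to a single example task; this has not been set up yet.

The next cell defines an alternative `test_model_fn` that finetunes and evaluates the trained `GNNMultitaskModel` on each task sample instead of using the kNN baseline.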

In [40]:
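from data.multitask import get_multitask_inference_batcher  # assumed location
from fs_mol.multitask_train import eval_model_by_finetuning_on_task  # assumed location

# weights saved by train_loop above
model_weights_file = os.path.join(out_dir, "best_model.pt")


def test_model_fn(
    task_sample: FSMolTaskSample, temp_out_folder: str, seed: int
) -> BinaryEvalMetrics:
    # finetune a fresh copy of the trained model on the task sample, then evaluate it;
    # the finetuning hyperparameters match those used for valid_fn above
    return eval_model_by_finetuning_on_task(
        model_weights_file,
        model_cls=GNNMultitaskModel,
        task_sample=task_sample,
        temp_out_folder=temp_out_folder,
        batcher=get_multitask_inference_batcher(max_num_graphs=256),
        learning_rate=0.00005,
        task_specific_learning_rate=0.0001,
        metric_to_use="avg_precision",
        seed=seed,
        quiet=True,
        device=device,
    )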

In [41]:
from fs_mol.utils.test_utils import eval_model

eval_model(
    test_model_fn=test_model_fn,
    dataset=fsmol_dataset,
    train_set_sample_sizes=[16],
    out_dir=out_dir,
    fold=DataFold.TEST,
    num_samples=1,
    seed=0,
)


To evaluate any model in the same way with `eval_model`, a `test_model_fn` with this signature must be defined for it.
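One way to meet that requirement for a new model is to write its fit-and-predict logic once and wrap it in a closure with the `(task_sample, temp_out_folder, seed)` signature that `eval_model` calls. The sketch below uses a hypothetical `ToyTaskSample`, whose `train_samples` and `test_samples` are lists of (features, label) pairs, in place of `FSMolTaskSample`, and returns plain accuracy in place of `BinaryEvalMetrics`; a real version would build the full metrics object from the model's predictions.

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Sequence, Tuple

Labelled = Tuple[Any, int]  # (features, 0/1 label)
FitPredict = Callable[[Sequence[Labelled], Sequence[Any], int], List[int]]


@dataclass(frozen=True)
class ToyTaskSample:  # hypothetical stand-in for FSMolTaskSample
    train_samples: List[Labelled]
    test_samples: List[Labelled]


def make_test_model_fn(fit_predict: FitPredict) -> Callable[[ToyTaskSample, str, int], float]:
    """Wrap a fit-and-predict routine into the callable shape eval_model expects."""

    def toy_test_model_fn(task_sample: ToyTaskSample, temp_out_folder: str, seed: int) -> float:
        features = [x for x, _ in task_sample.test_samples]
        labels = [y for _, y in task_sample.test_samples]
        # temp_out_folder is unused here; a real model might write checkpoints to it
        preds = fit_predict(task_sample.train_samples, features, seed)
        return sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)

    return toy_test_model_fn


def majority_class(train: Sequence[Labelled], queries: Sequence[Any], seed: int) -> List[int]:
    """Trivial model: predict the most frequent support-set label for every query."""
    positives = sum(y for _, y in train)
    label = 1 if 2 * positives > len(train) else 0
    return [label] * len(queries)


sample = ToyTaskSample(
    train_samples=[("m1", 1), ("m2", 1), ("m3", 0)],
    test_samples=[("q1", 1), ("q2", 0)],
)
print(make_test_model_fn(majority_class)(sample, "/tmp/unused", 0))  # 0.5
```

`functools.partial`, as used for `valid_fn` earlier, achieves the same binding when the model-specific arguments can be passed as keywords.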