# Saving and loading multi-module models

Ablator is a flexible framework, and you can override its methods to adapt it to your use case. In this tutorial, we will show how ablator can be customized to save and load multi-module models. This is helpful when your model consists of multiple modules and you want to save each module's state separately in a checkpoint file and load the whole model back later on. Sample use cases include encoder and decoder blocks in a transformer model, ensemble models, etc.

For this tutorial, we will create an ensemble of 3 simple 1-hidden layer neural networks, train them on the breast cancer dataset for 30 epochs, save the ensemble as a 3-module model, and load it back and train for another 30 epochs. 

```python
from ablator import ModelConfig, OptimizerConfig, TrainConfig, ParallelConfig
from ablator import ModelWrapper, ParallelTrainer
from ablator.main.configs import SearchSpace

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

import shutil
import os
```

## Preparing the data

```python
class BreastCancerDataset(Dataset):
    """Min-max scaled feature rows paired with integer class labels.

    Args:
        data: 2-D array-like of raw feature rows.
        targets: Sequence of class labels aligned with ``data``.
        scaler: Optional, already fitted scaler. When omitted, a new
            ``MinMaxScaler`` is fitted on ``data``. Pass the training
            dataset's ``scaler`` when building a validation/test dataset so
            that both splits are scaled with the same statistics.
    """

    def __init__(self, data, targets, scaler=None):
        if scaler is None:
            # No scaler supplied: learn the per-feature min/max from this data.
            self.scaler = MinMaxScaler()
            self.data = self.scaler.fit_transform(data)
        else:
            # Reuse statistics learned elsewhere (normally on the train split).
            self.scaler = scaler
            self.data = self.scaler.transform(data)
        self.targets = targets

    def __getitem__(self, index):
        features = torch.tensor(self.data[index], dtype=torch.float32)
        label = torch.tensor(self.targets[index], dtype=torch.long)
        return features, label

    def __len__(self):
        return len(self.data)


# Load dataset from scikit-learn
breast_cancer = load_breast_cancer()
data = breast_cancer.data
targets = breast_cancer.target

# Split the data into train and test sets
train_data, test_data, train_targets, test_targets = train_test_split(
    data, targets, test_size=0.2, random_state=42
)

# Create train and test datasets. The test set reuses the scaler fitted on the
# training data so that both splits share the same feature scaling.
train_dataset = BreastCancerDataset(train_data, train_targets)
test_dataset = BreastCancerDataset(
    test_data, test_targets, scaler=train_dataset.scaler
)
```

## Build the ensemble model

### Simple 1-hidden layer neural network module

```python
class NNet(nn.Module):
    """Feed-forward classifier with one hidden layer: 30 inputs -> 50 -> 2 outputs."""

    def __init__(self) -> None:
        super().__init__()
        # Layers are registered in the same order as they are applied.
        self.fc1 = nn.Linear(30, 50)
        self.fc2 = nn.Linear(50, 2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the raw (unnormalized) class scores for a batch ``x``."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
```

### Assemble the ensemble

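The ensemble holds three independent `NNet` modules (`nnet1`, `nnet2`, and `nnet3`). Each module receives the same input batch, and their outputs are summed element-wise to form the ensemble score. The cross-entropy loss and the predicted class (the `argmax` over the two classes) are then computed from this combined score.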

```python
class MyEnsemble(nn.Module):
    """Ensemble of three independent ``NNet`` classifiers.

    Every sub-network sees the same input; their output logits are summed and
    the class with the highest summed logit is the prediction.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Extra positional/keyword arguments are accepted but not used here.
        super().__init__()
        self.nnet1 = NNet()
        self.nnet2 = NNet()
        self.nnet3 = NNet()

    def forward(self, x, labels=None):
        """Run the ensemble on a batch.

        Args:
            x: Batch of input features.
            labels: Optional class indices for the batch.

        Returns:
            A tuple ``(outputs, loss)`` where ``outputs`` is a dict with the
            keys ``"preds"`` and ``"labels"``, and ``loss`` is the
            cross-entropy loss, or ``None`` when ``labels`` is not given.
        """
        logits = self.nnet1(x) + self.nnet2(x) + self.nnet3(x)

        # argmax over logits equals argmax over softmax(logits), so no
        # explicit softmax is needed for the predictions.
        preds = torch.argmax(logits, dim=1)

        # F.cross_entropy applies log-softmax internally and therefore expects
        # raw logits; feeding it probabilities would apply softmax twice.
        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits, labels)

        return {"preds": preds, "labels": labels}, loss
```

## Define configuration
Now it's time we set up the ablation experiment by defining needed configurations.

### Model configuration
Since we're not ablating the model architecture, no custom model configuration is needed.

```python
# Default model configuration: no model hyperparameters are set or ablated.
model_config = ModelConfig()
```

### Optimizer configuration
We will use the Adam optimizer, with the learning rate initialized to 0.001.

```python
# Adam optimizer with a learning rate of 0.001.
optimizer_arguments = {"lr": 0.001}
optimizer_config = OptimizerConfig(name="adam", arguments=optimizer_arguments)
```

We will also define a search space for different learning rate values:

```python
# Maps the dotted path of the ablated setting to the range of values to try:
# here, the optimizer learning rate as a float between 0.001 and 0.01.
search_space = {
    "train_config.optimizer_config.arguments.lr": SearchSpace(
        value_range=[0.001, 0.01],
        value_type="float",
    ),
}
```

### Training configuration

```python
# Training setup: 30 epochs with batches of 32 samples, using the optimizer
# defined above and no learning-rate scheduler.
train_config = TrainConfig(
    dataset="breast-cancer",
    batch_size=32,
    epochs=30,
    optimizer_config=optimizer_config,
    scheduler_config=None,
    rand_weights_init=True,
)
```

### Running configuration (parallel config)
Combine model configuration, train configuration, and search space into one

```python
# Bundle the training, model, and search-space settings together with the
# resources and trial counts used for the parallel experiment.
run_config = ParallelConfig(
    train_config=train_config,
    model_config=model_config,
    metrics_n_batches=800,
    experiment_dir="/tmp/experiments/",
    device="cuda",
    amp=True,
    random_seed=42,
    total_trials=5,
    concurrent_trials=3,
    search_space=search_space,
    optim_metrics={"val_loss": "min"},
    gpu_mb_per_experiment=1024,
    cpus_per_experiment=1,
)
```

## Model wrapper
Besides providing the dataloaders (and, optionally, evaluation functions), the model wrapper is where we customize how the model is saved and loaded.

### Multi-module model saving
We will override the `save_dict()` method to save the model as a dictionary of modules.
```python
    def save_dict(self):
        """Return the checkpoint dict with the model stored one entry per sub-network."""
        checkpoint = super().save_dict()
        # Replace the "model" entry with a mapping from sub-network name to
        # that sub-network's own state dict.
        checkpoint["model"] = {
            name: getattr(self.model, name).state_dict()
            for name in ("nnet1", "nnet2", "nnet3")
        }
        return checkpoint
```
By default, ablator saves the model as a whole, i.e., `saved_dict["model"] = self.model.state_dict()`.

In our example, as you can see, modules `nnet1`, `nnet2`, and `nnet3` from `MyEnsemble` model can be accessed via `self.model.nnet1`, `self.model.nnet2`, and `self.model.nnet3` respectively, and we will save these modules' state dictionaries into `saved_dict["model"]`.

This way, the model saved will be a dictionary of modules:
```
saved_dict = {
    "model": {
        "nnet1": {"fc1.weight": ..., "fc1.bias": ..., "fc2.weight": ..., "fc2.bias": ...},
        "nnet2": {"fc1.weight": ..., "fc1.bias": ..., "fc2.weight": ..., "fc2.bias": ...},
        "nnet3": {"fc1.weight": ..., "fc1.bias": ..., "fc2.weight": ..., "fc2.bias": ...},
    },
    ...
}
```
After running the experiment, you can use `torch.load(<path_to_checkpoint>)` to verify this, where `<path_to_checkpoint>` is the path to one of the models that are saved in the experiment directory.

### Multi-module model loading
Now that we have saved a multi-module model, we also need to change how ablator loads the model. This is done by overriding the `create_model()` method.
```python
    def create_model(
        self,
        save_dict: dict | None = None,
        strict_load: bool = True,
    ) -> None:
        """Create the model, flattening a per-module checkpoint if one is given.

        Args:
            save_dict: Optional checkpoint dict. Its ``"model"`` entry may map
                sub-network names to their state dicts (as written by
                ``save_dict``); such entries are flattened to
                ``"<sub-network>.<parameter>"`` keys before loading.
            strict_load: Forwarded unchanged to the parent ``create_model``.
        """
        if save_dict is not None:
            flat_state = {}
            for prefix, entry in save_dict["model"].items():
                if isinstance(entry, dict):
                    # Per-module state dict: prefix every key with the module name.
                    for key, value in entry.items():
                        flat_state[f"{prefix}.{key}"] = value
                else:
                    # Already a flat entry; keep it as is.
                    flat_state[prefix] = entry
            # Build a new dict so the caller's checkpoint is left untouched.
            save_dict = {**save_dict, "model": flat_state}
        super().create_model(save_dict=save_dict, strict_load=strict_load)
```
By default, ablator loads the model as a whole, i.e., `model.load_state_dict(save_dict["model"])`.

So in our example, the keys are rewritten to `nnet1.fc1.weight`, `nnet1.fc1.bias`, `nnet1.fc2.weight`, `nnet1.fc2.bias`, and likewise for `nnet2` and `nnet3`. These match the parameter names of `MyEnsemble`, so when we call the parent method the model is loaded correctly.

Below is the complete script for the model wrapper:

```python
class MyEnsembleWrapper(ModelWrapper):
    """Model wrapper that saves and loads ``MyEnsemble`` one sub-network at a time."""

    def make_dataloader_train(self, run_config: ParallelConfig):
        """Return a shuffled dataloader over the training dataset."""
        return DataLoader(
            train_dataset,
            batch_size=run_config.train_config.batch_size,
            shuffle=True,
        )

    def make_dataloader_val(self, run_config: ParallelConfig):
        """Return a non-shuffled dataloader over the test dataset."""
        return DataLoader(
            test_dataset,
            batch_size=run_config.train_config.batch_size,
            shuffle=False,
        )

    def save_dict(self):
        """Return the checkpoint dict with the model stored one entry per sub-network."""
        saved_dict = super().save_dict()
        saved_dict["model"] = {
            "nnet1": self.model.nnet1.state_dict(),
            "nnet2": self.model.nnet2.state_dict(),
            "nnet3": self.model.nnet3.state_dict(),
        }
        return saved_dict

    def create_model(self, save_dict=None, strict_load=True):
        """Create the model, flattening a per-module checkpoint if one is given.

        Args:
            save_dict: Optional checkpoint dict. Its ``"model"`` entry may map
                sub-network names to their state dicts (as written by
                ``save_dict``); such entries are flattened to
                ``"<sub-network>.<parameter>"`` keys before loading.
            strict_load: Forwarded unchanged to the parent ``create_model``.
        """
        if save_dict is not None:
            flat_state = {}
            for prefix, entry in save_dict["model"].items():
                if isinstance(entry, dict):
                    # Per-module state dict: prefix every key with the module name.
                    for key, value in entry.items():
                        flat_state[f"{prefix}.{key}"] = value
                else:
                    # Already a flat entry; keep it as is.
                    flat_state[prefix] = entry
            # Build a new dict so the caller's checkpoint is left untouched.
            save_dict = {**save_dict, "model": flat_state}
        super().create_model(save_dict=save_dict, strict_load=strict_load)
```

### Custom evaluation (Optional)
We will use accuracy and f1 as evaluation metrics

```python
from sklearn.metrics import f1_score, accuracy_score

def my_accuracy(preds, labels):
    """Return the fraction of predictions in ``preds`` that equal ``labels``."""
    # scikit-learn metrics take the ground truth first, then the predictions.
    return accuracy_score(labels.flatten(), preds.flatten())

def my_f1_score(preds, labels):
    """Return the support-weighted F1 score of ``preds`` against ``labels``."""
    # scikit-learn metrics take the ground truth first, then the predictions.
    # The order matters here: the "weighted" average weights each class by its
    # number of true instances.
    return f1_score(labels.flatten(), preds.flatten(), average='weighted')
```

## Launch the ablation experiment
Everything is ready, now we can launch the ablation experiment.

```python
# Start from a clean slate: remove results left over from a previous run.
if os.path.exists(run_config.experiment_dir):
    shutil.rmtree(run_config.experiment_dir)

wrapper = MyEnsembleWrapper(
    model_class=MyEnsemble,
)

ablator = ParallelTrainer(
    wrapper=wrapper,
    run_config=run_config,
)
metrics = ablator.launch(working_directory=os.getcwd(), ray_head_address="auto")
```

After the experiment finishes and the checkpoints are saved, rerun the ablation experiment. This time, specify the `init_chkpt` parameter in the running configuration to load the model from a checkpoint saved earlier (replace the path below with the path to one of your own checkpoints).

```python
# Checkpoint written by the first run. The hashed directory names and the
# iteration number differ from run to run: replace this with the path to one
# of your own checkpoints.
init_chkpt = "/tmp/experiments/experiment_7ae3_9991/2ca5_9991/best_checkpoints/MyEnsemble_0000000210.pt"

# Write the resumed run to its own directory. Cleaning the first run's
# directory here would delete the checkpoint we are about to load.
resumed_experiment_dir = "/tmp/experiments_resumed/"
if os.path.exists(resumed_experiment_dir):
    shutil.rmtree(resumed_experiment_dir)

run_config = ParallelConfig(
    train_config=train_config,
    model_config=model_config,
    metrics_n_batches=800,
    experiment_dir=resumed_experiment_dir,
    device="cuda",
    amp=True,
    random_seed=42,
    total_trials=5,
    concurrent_trials=3,
    search_space=search_space,
    optim_metrics={"val_loss": "min"},
    gpu_mb_per_experiment=1024,
    cpus_per_experiment=1,
    init_chkpt=init_chkpt,
)

wrapper = MyEnsembleWrapper(
    model_class=MyEnsemble,
)

ablator = ParallelTrainer(
    wrapper=wrapper,
    run_config=run_config,
)
metrics = ablator.launch(working_directory=os.getcwd(), ray_head_address="auto")
```

The experiment will load the model from the checkpoint and continue training. And that's it, this is an example that shows how customizable ablator is, so that you can customize it to fit your needs. 