# Fed-BioMed Researcher - Saving and Loading breakpoints

This example uses the `MNIST` dataset deployed on the nodes. See the `README` in the notebooks directory for instructions on loading the `MNIST` dataset.

## Create an experiment to train a model on the data found

Declare a torch training plan class, `MyTrainingPlan`, to send to the nodes for training.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # referenced as `F` in Net.forward
from fedbiomed.common.training_plans import TorchTrainingPlan
from fedbiomed.common.data import DataManager
from torchvision import datasets, transforms

# Here we define the training plan. 
class MyTrainingPlan(TorchTrainingPlan):
    
    # Defines and returns the model
    def init_model(self, model_args):
        return self.Net(model_args = model_args)
    
    # Defines and returns the optimizer
    def init_optimizer(self, optimizer_args):
        return torch.optim.Adam(self.model().parameters(), lr = optimizer_args["lr"])
    
    # Declares and returns the dependencies
    def init_dependencies(self):
        deps = ["from torchvision import datasets, transforms"]
        return deps
    
    class Net(nn.Module):
        def __init__(self, model_args):
            super().__init__()
            self.conv1 = nn.Conv2d(1, 32, 3, 1)
            self.conv2 = nn.Conv2d(32, 64, 3, 1)
            self.dropout1 = nn.Dropout(0.25)
            self.dropout2 = nn.Dropout(0.5)
            self.fc1 = nn.Linear(9216, 128)
            self.fc2 = nn.Linear(128, 10)

        def forward(self, x):
            x = self.conv1(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
            x = self.dropout1(x)
            x = torch.flatten(x, 1)
            x = self.fc1(x)
            x = F.relu(x)
            x = self.dropout2(x)
            x = self.fc2(x)
            output = F.log_softmax(x, dim=1)
            return output

    def training_data(self):
        # Custom torch Dataloader for MNIST data
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        dataset1 = datasets.MNIST(self.dataset_path, train=True, download=False, transform=transform)
        train_kwargs = {'shuffle': True}
        return DataManager(dataset=dataset1, **train_kwargs)
    
    def training_step(self, data, target):
        output = self.model().forward(data)
        loss   = torch.nn.functional.nll_loss(output, target)
        return loss
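
The input size of `fc1` (9216) follows from the layer shapes in `Net`. The sketch below reproduces that arithmetic in plain Python, assuming the standard 28x28 single-channel MNIST input; the helper `conv2d_out` is a hypothetical name introduced only for this illustration.

```python
def conv2d_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output side length of a square 2D convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


side = 28                    # assumed MNIST image side length
side = conv2d_out(side, 3)   # conv1: 3x3 kernel, stride 1 -> 26
side = conv2d_out(side, 3)   # conv2: 3x3 kernel, stride 1 -> 24
side = side // 2             # max_pool2d with window 2 -> 12

# conv2 emits 64 channels, so flattening gives 64 * 12 * 12 features
flat_features = 64 * side * side
print(flat_features)  # 9216
```

If the kernel sizes, the pooling window, or the input resolution change, the first argument of `nn.Linear(9216, 128)` has to change accordingly.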


In [None]:
model_args = {}

training_args = {
    'loader_args': { 'batch_size': 48, }, 
    'optimizer_args': {
        "lr" : 1e-3
    },
    'epochs': 1, 
    'dry_run': False,  
    'batch_maxnum': 100 # Fast pass for development : only use ( batch_maxnum * batch_size ) samples
}
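
As a rough check of what the `batch_maxnum` comment implies, the sketch below multiplies the values used in `training_args`. It assumes the cap applies per epoch on each node and that the node holds at least that many samples; the name `max_samples_per_epoch` is only illustrative.

```python
batch_size = 48     # from training_args['loader_args']
batch_maxnum = 100  # from training_args

# Upper bound on the samples seen in one epoch on a node (assumption: per-epoch cap)
max_samples_per_epoch = batch_size * batch_maxnum
print(max_samples_per_epoch)  # 4800
```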

Let `${FEDBIOMED_DIR}` denote the base directory where you cloned Fed-BioMed.
By default, breakpoints are saved inside the experiment's `Experiment_xxxx` folder, at `${FEDBIOMED_DIR}/var/experiments/Experiment_xxxx/breakpoint_yyyy`.

In [None]:
from fedbiomed.researcher.federated_workflows import Experiment
from fedbiomed.researcher.aggregators.fedavg import FedAverage

tags =  ['#MNIST', '#dataset']
rounds = 2

exp = Experiment(tags=tags,
                 model_args=model_args,
                 training_plan_class=MyTrainingPlan,
                 training_args=training_args,
                 round_limit=rounds,
                 aggregator=FedAverage(),
                 node_selection_strategy=None,
                 save_breakpoints=True)

You can interrupt `exp.run()` after one round, then reload the breakpoint and continue training.

In [None]:
exp.run()

Save the trained model to a file.

In [None]:
exp.training_plan().export_model('./trained_model')

## Delete experiment

Here we simulate losing the ongoing experiment by deleting it.
Because breakpoints were saved, the experiment's parameters can be retrieved
with the `load_breakpoint` method.

In [None]:
del exp

## Resume an experiment

While the experiment is running, you can shut it down (after the first round) and resume it from the next cell, or wait for the experiment to complete.


**To load the latest breakpoint of the latest experiment**

Run `Experiment.load_breakpoint()`. It reloads the latest breakpoint and bypasses the `search` method.

Then use the `.run` method as you would with an existing experiment.

**To load a specific breakpoint**, specify the breakpoint folder.

- absolute path: use `Experiment.load_breakpoint(f"{config.root}/var/experiments/Experiment_xxxx/breakpoint_yyyy")`. Replace `xxxx` and `yyyy` with the real values.
- relative path from a notebook: a notebook runs from the `fbm-researcher/notebooks` directory, so use `Experiment.load_breakpoint("../../fbm-researcher/var/experiments/Experiment_xxxx/breakpoint_yyyy")`. Replace `xxxx` and `yyyy` with the real values.
- relative path from a script: if you launch the script from the default `Researcher` directory (e.g. `python fbm-researcher/notebooks/general-breakpoint-save-resume.py`), use a path relative to the current directory, e.g. `Experiment.load_breakpoint("../../fbm-researcher/var/experiments/Experiment_xxxx/breakpoint_yyyy")`.
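
When the `xxxx` and `yyyy` values are not known in advance, the folder can be located by listing the experiments directory. The following is a minimal sketch and not part of the Fed-BioMed API: `find_latest_breakpoint` and `_numbered_dirs` are hypothetical helpers, and "latest" is assumed here to mean the highest-numbered `breakpoint_yyyy` folder inside the highest-numbered `Experiment_xxxx` folder that contains one.

```python
import re
from pathlib import Path
from typing import List, Optional


def _numbered_dirs(parent: Path, prefix: str) -> List[Path]:
    """Return sub-directories named '<prefix>_<digits>', sorted by their number."""
    pattern = re.compile(rf"^{re.escape(prefix)}_(\d+)$")
    found = []
    for child in parent.iterdir():
        match = pattern.match(child.name)
        if match and child.is_dir():
            found.append((int(match.group(1)), child))
    return [path for _, path in sorted(found, key=lambda item: item[0])]


def find_latest_breakpoint(experiments_dir) -> Optional[Path]:
    """Absolute path of the highest-numbered breakpoint folder, or None if there is none."""
    experiments = _numbered_dirs(Path(experiments_dir), "Experiment")
    # Walk experiments from newest to oldest until one has a breakpoint
    for experiment in reversed(experiments):
        breakpoints = _numbered_dirs(experiment, "breakpoint")
        if breakpoints:
            return breakpoints[-1].resolve()
    return None
```

The returned path is absolute, so it does not depend on the directory the notebook or script runs from; convert it with `str(...)` before passing it to `Experiment.load_breakpoint`. The function raises `FileNotFoundError` if `experiments_dir` does not exist.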

In [None]:
from fedbiomed.researcher.federated_workflows import Experiment

loaded_exp = Experiment.load_breakpoint()

In [None]:
# A relative path can also be used
# (`config.root` is available to build an absolute path instead)
from fedbiomed.researcher.config import config
Experiment.load_breakpoint("../../fbm-researcher/var/experiments/Experiment_0000/breakpoint_0001")

In [None]:
print(f'Experimentation folder: {loaded_exp.experimentation_folder()}')
print(f'Loaded experiment path: {loaded_exp.experimentation_path()}')

Continue training the experiment loaded from the breakpoint. If you already ran all the rounds and loaded the last breakpoint, there are no more rounds to run.

In [None]:
loaded_exp.run(rounds=3, increase=True)

Save the trained model to a file.

In [None]:
loaded_exp.training_plan().export_model('./trained_model')

In [None]:
exp = loaded_exp
print("______________ loaded training replies_________________")
print("\nList the training rounds : ", exp.training_replies().keys())

print("\nList the nodes for the last training round and their timings : ")
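# Use the highest recorded round index: `rounds - 1` no longer points at the
# last round once extra rounds were run after loading the breakpoint
last_round = max(exp.training_replies().keys())
round_data = exp.training_replies()[last_round]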
for r in round_data.values():
    print("\t- {id} :\
    \n\t\trtime_training={rtraining:.2f} seconds\
    \n\t\tptime_training={ptraining:.2f} seconds\
    \n\t\trtime_total={rtotal:.2f} seconds".format(id = r['node_id'],
        rtraining = r['timing']['rtime_training'],
        ptraining = r['timing']['ptime_training'],
        rtotal = r['timing']['rtime_total']))
print('\n')
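
The per-node timings printed above can also be aggregated. The sketch below averages each timing field across nodes for one round. It runs on a placeholder dictionary that only mirrors the shape used in the loop above (`node_id` and `timing` keys); the node names and values are made up for illustration, and `summarize_timings` is a hypothetical helper.

```python
from statistics import mean

# Placeholder replies for a single round; values are invented for illustration
round_replies = {
    "node_a": {"node_id": "node_a",
               "timing": {"rtime_training": 12.0, "ptime_training": 10.0, "rtime_total": 15.0}},
    "node_b": {"node_id": "node_b",
               "timing": {"rtime_training": 8.0, "ptime_training": 6.0, "rtime_total": 11.0}},
}


def summarize_timings(replies: dict) -> dict:
    """Mean of each timing field (in seconds) across the nodes of one round."""
    fields = ("rtime_training", "ptime_training", "rtime_total")
    return {
        field: mean(reply["timing"][field] for reply in replies.values())
        for field in fields
    }


summary = summarize_timings(round_replies)
for field, seconds in summary.items():
    print(f"{field}: {seconds:.2f} seconds")
```

With a real experiment, the same function could be applied to the `round_data` dictionary, assuming it has the structure shown in the loop above.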

Federated parameters for each round are available via `exp.aggregated_params()` (indexed from 0 to `rounds` - 1).
For example, you can list the federated parameters for each round of the experiment:

In [None]:
print("\nList the training rounds : ", loaded_exp.aggregated_params().keys())

print("\nAccess the federated params for training rounds : ")
for round_idx in loaded_exp.aggregated_params().keys():
    print("round {r}".format(r=round_idx))
    print("\t- parameter data: ", loaded_exp.aggregated_params()[round_idx]['params'].keys())
