# Model training

This notebook shows how to perform a full training run of the PlaNet model using PyTorch `lightning`. Before running it, make sure that `planet` is installed (see the README.md file).

The training consists of the following steps:
1. instantiate the model and the datamodule, as well as some useful callbacks;
2. train the model, logging the training status to Weights and Biases;
3. save the model and the related data (config file and scaler) for inference.

You can also train the model by running the `make train` command in your terminal.

In [1]:
from pathlib import Path

from planet.config import Config
from planet.train import main_train

**Note:** to use Weights and Biases, you need a valid account and must be logged in. To log in, run the following command in your terminal:

```shell
wandb login --relogin
```

## Out of the box training

A full training can be run with two commands. First, define a `Config` object holding the whole configuration. Alternatively, create a `config.yml` file like [this one](../config/config.yml) and load it using `planet.utils.load_config`.

In [2]:
config = Config(
    save_path='trained_models/test/',  # path where the model is saved
    dataset_path='planet_sample_dataset.h5',  # path to your dataset (see notebook 1_dataset_creation.ipynb)
    is_physics_informed=True,  # if True, also compute the physics-informed term in the loss function
    do_super_resolution=False,  # super-resolution is very expensive! If True, num_workers should be ~batch_size//2
    batch_size=16,  # training batch size
    epochs=10,  # number of training epochs
    log_to_wandb=True,  # if True, log to wandb
    wandb_project='planet_test',  # wandb project name
    save_checkpoints=True,  # if True, save checkpoints at the best validation loss, keeping 2
    resume_from_checkpoint=False,  # if True, resume training from the last checkpoint
    num_workers=0,  # number of workers in the dataloader. If == -1, the value is set automatically
    planet_config={
        'hidden_dim': 128,
        'nr': 64,  # must match the nr of your input grids
        'nz': 64,  # must match the nz of your input grids
        'n_measures': 302,  # must match the total input dimension (see notebook 1_dataset_creation.ipynb)
    },
)
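
The comments in the cell above state three constraints: `nr` and `nz` must match the input grids, `n_measures` must match the total input dimension, and with super-resolution enabled `num_workers` should be roughly `batch_size // 2`. A minimal sketch of a pre-flight check is given here. The helpers `check_planet_config` and `suggest_num_workers` are hypothetical and not part of `planet`; the `(nr, nz)` ordering of the grid shape and the cap at the CPU count are assumptions to adapt to your own dataset and machine.

```python
import os


def check_planet_config(planet_config, grid_shape, n_inputs):
    """Return a list of mismatch messages (empty when everything agrees).

    grid_shape is assumed to be (nr, nz); n_inputs is the total input
    dimension of one sample. Both must be read from your own dataset.
    """
    expected = {"nr": grid_shape[0], "nz": grid_shape[1], "n_measures": n_inputs}
    return [
        f"{key}: config has {planet_config[key]}, dataset has {value}"
        for key, value in expected.items()
        if planet_config[key] != value
    ]


def suggest_num_workers(batch_size, do_super_resolution, cpu_count=None):
    """Apply the ~batch_size // 2 rule of thumb, capped at the CPU count."""
    if not do_super_resolution:
        return 0  # same value as in the example config above
    cpus = cpu_count if cpu_count is not None else (os.cpu_count() or 1)
    return max(1, min(batch_size // 2, cpus))


planet_config = {"hidden_dim": 128, "nr": 64, "nz": 64, "n_measures": 302}
print(check_planet_config(planet_config, grid_shape=(64, 64), n_inputs=302))  # []
print(suggest_num_workers(batch_size=16, do_super_resolution=True, cpu_count=8))  # 8
```

Running such a check before training surfaces shape mismatches as readable messages instead of tensor-shape errors deep inside the model.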

Then call the `main_train` function, which performs steps 1 to 3 above and stores the model in `config.save_path`.

In [3]:
main_train(config=config)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
wandb: Currently logged in as: matteob-90-hotmail-it to https://api.wandb.ai. Use `wandb login --relogin` to force relogin



  | Name        | Type       | Params | Mode 
---------------------------------------------------
0 | model       | PlaNetCore | 1.8 M  | train
1 | loss_module | PlaNetLoss | 0      | train
---------------------------------------------------
1.8 M     Trainable params
0         Non-trainable params
1.8 M     Total params
7.121     Total estimated model params size (MB)
81        Modules in train mode
0         Modules in eval mode


Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

/Users/matte/Documents/RESEARCH/PlaNet_Equil_reconstruction/venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


                                                                           

/Users/matte/Documents/RESEARCH/PlaNet_Equil_reconstruction/venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/matte/Documents/RESEARCH/PlaNet_Equil_reconstruction/venv/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (4) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 9: 100%|██████████| 4/4 [00:00<00:00, 13.01it/s, v_num=6m2v, train_loss=26.40, val_loss=30.40]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 4/4 [00:00<00:00,  9.73it/s, v_num=6m2v, train_loss=26.40, val_loss=30.40]
Loading best model from checkpoint: trained_models/test/ckp/epoch=9-step=40.ckpt
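
The output above reports 4 training batches per epoch and a run of 10 epochs, and the best checkpoint is named `epoch=9-step=40.ckpt`. The arithmetic behind the step count, and a way to pick a logging interval that avoids the `log_every_n_steps=50` warning, can be sketched as follows. The function names are illustrative, and the step count assumes one optimizer step per batch (no gradient accumulation).

```python
def total_steps(batches_per_epoch, epochs):
    """Optimizer steps over the whole run, one step per batch."""
    return batches_per_epoch * epochs


def log_interval(batches_per_epoch, default=50):
    """Largest interval <= default that still logs at least once per epoch."""
    return max(1, min(default, batches_per_epoch))


print(total_steps(batches_per_epoch=4, epochs=10))  # 40
print(log_interval(batches_per_epoch=4))  # 4
```

The resulting interval can be passed as `log_every_n_steps` to the `Trainer` when you build it yourself, as in the custom training section.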


## Custom training
The following cell reproduces the body of the `main_train` function. You can edit any part of it to customize the training.

In [8]:
from pathlib import Path
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, Callback
from lightning.pytorch.loggers import WandbLogger, Logger

from planet.train import LightningPlaNet, DataModule
from planet.utils import get_accelerator, last_ckp_path, save_model_and_scaler

save_dir = Path(config.save_path)
save_dir.mkdir(exist_ok=True, parents=True)

### instantiate model and datamodule
model = LightningPlaNet(config=config)
datamodule = DataModule(config=config)

### define some callbacks
callbacks = []
if config.save_checkpoints:  # boolean flag: add the callback only when enabled
    callbacks.append(
        # keep the 2 checkpoints with the lowest validation loss
        ModelCheckpoint(
            dirpath=save_dir / Path("ckp"), save_top_k=2, monitor="val_loss"
        )
    )

# get the logger
logger = None
if config.log_to_wandb:
    logger = WandbLogger(project=config.wandb_project)

### train the model
trainer = Trainer(
    max_epochs=config.epochs,
    accelerator=get_accelerator(),
    devices="auto",
    callbacks=callbacks,
    logger=logger,
)
trainer.fit(
    model=model,
    datamodule=datamodule,
    ckpt_path=(
        last_ckp_path(save_dir / Path("ckp"))
        if config.resume_from_checkpoint
        else None
    ),
)

### save model + scaler for inference
save_model_and_scaler(trainer, datamodule.dataset.scaler, config)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/matte/Documents/RESEARCH/PlaNet_Equil_reconstruction/venv/lib/python3.10/site-packages/lightning/pytorch/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/Users/matte/Documents/RESEARCH/PlaNet_Equil_reconstruction/venv/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/matte/Documents/RESEARCH/PlaNet_Equil_reconstruction/notebooks/trained_models/ckp exists and is not empty.

  | Name        | Type       | Params | Mode 
---------------------------------------------------
0 | model       | PlaNetCore | 1.8 M  | train
1 | loss_module | PlaNetLoss | 0      | train
---------------------------------------------------
1.8 M     Trainable params
0         

                                                                           

/Users/matte/Documents/RESEARCH/PlaNet_Equil_reconstruction/venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/matte/Documents/RESEARCH/PlaNet_Equil_reconstruction/venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
/Users/matte/Documents/RESEARCH/PlaNet_Equil_reconstruction/venv/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (4) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to

Epoch 9: 100%|██████████| 4/4 [00:00<00:00, 16.32it/s, v_num=vryl, train_loss=57.90, val_loss=53.90]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 4/4 [00:00<00:00, 11.63it/s, v_num=vryl, train_loss=57.90, val_loss=53.90]
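
In the custom training cell, `trainer.fit` receives `ckpt_path=last_ckp_path(...)` when `config.resume_from_checkpoint` is set, and `None` otherwise. The implementation of `planet.utils.last_ckp_path` is not shown in this notebook. As an illustration only, a stand-in with the hypothetical name `find_last_checkpoint` could pick the most recently modified `*.ckpt` file in a directory; the real helper may use a different rule.

```python
from pathlib import Path


def find_last_checkpoint(ckpt_dir):
    """Return the most recently modified .ckpt file, or None if there is none.

    Hypothetical stand-in: not the actual planet.utils.last_ckp_path.
    """
    ckpt_dir = Path(ckpt_dir)
    if not ckpt_dir.is_dir():
        return None
    candidates = list(ckpt_dir.glob("*.ckpt"))
    if not candidates:
        return None
    # "last" is taken to mean the newest file by modification time
    return max(candidates, key=lambda p: p.stat().st_mtime)


print(find_last_checkpoint("trained_models/test/ckp"))
```

Returning `None` when no checkpoint exists is convenient here, because `ckpt_path=None` makes `trainer.fit` start training from scratch.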
