# Training the MACGNN model

## Manual approach

In [1]:
import torch
from deep_neuronmorpho.utils import Config
from deep_neuronmorpho.models import MACGNN
from deep_neuronmorpho.engine import ContrastiveTrainer, setup_dataloaders, setup_seed
from deep_neuronmorpho.train_contrastive import train_model

In [2]:
# pick the GPU when torch reports one as available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# read the experiment configuration from its YAML file
config_file = "../config/cfg-macgnn_demo.yml"
conf = Config.from_yaml(config_file=config_file)

# seed with the value stored in the training section of the config
setup_seed(conf.training.random_seed)

# build the dataloaders for the three named datasets
dataloaders = setup_dataloaders(
    conf,
    datasets=["contra_train", "eval_train", "eval_test"],
)

# instantiate the model from the `model` section of the config
macgnn = MACGNN(conf.model)

# hand model, config, dataloaders and device over to the trainer
trainer = ContrastiveTrainer(model=macgnn, config=conf, dataloaders=dataloaders, device=device)

In [3]:
# Inspect the `dual_aug_loss` training option; the bare expression is echoed as the cell output.
conf.training.dual_aug_loss

False

In [4]:
# Start training. Per the log captured below, each epoch reports the train loss and a
# benchmark test accuracy, then saves a checkpoint file under the experiment's `ckpts/` directory.
trainer.fit()

Training macgnn_demo-pink_owl on 'cpu' for 5 epochs with random_seed 42.Loss uses dual augmentation: False.
Processing batch:[7/7]100%|██████████ [00:16<00:00]
Epoch 1/5: Train Loss: 5.0748
Epoch 1/5: Benchmark Test accuracy: 0.5977
Saving checkpoint: ../datasets/demo_gnn_data/expt_results/2023_11_16_12h_27m-macgnn_demo-pink_owl/ckpts/macgnn_demo-pink_owl-epoch_0001.pt
Processing batch:[7/7]100%|██████████ [00:14<00:00]
Epoch 2/5: Train Loss: 4.5266
Epoch 2/5: Benchmark Test accuracy: 0.5865
Saving checkpoint: ../datasets/demo_gnn_data/expt_results/2023_11_16_12h_27m-macgnn_demo-pink_owl/ckpts/macgnn_demo-pink_owl-epoch_0002.pt
Processing batch:[7/7]100%|██████████ [00:14<00:00]
Epoch 3/5: Train Loss: 4.2506
Epoch 3/5: Benchmark Test accuracy: 0.6015
Saving checkpoint: ../datasets/demo_gnn_data/expt_results/2023_11_16_12h_27m-macgnn_demo-pink_owl/ckpts/macgnn_demo-pink_owl-epoch_0003.pt
Processing batch:[7/7]100%|██████████ [00:14<00:00]
Epoch 4/5: Train Loss: 4.1062


### train from checkpoint
```python
expt = "macgnn_demo-DATETIME" # replace DATETIME with the timestamp from the experiment
epoch = 3 # replace 3 with the epoch number to start from
```

In [None]:
# # train from checkpoint
# # NOTE(review): `Path` is not imported in the first cell; uncomment the next line as well.
# from pathlib import Path
# # NOTE(review): the run logged above saved files named `macgnn_demo-pink_owl-epoch_0001.pt`
# # inside a `2023_11_16_12h_27m-macgnn_demo-pink_owl/ckpts` directory, which does not match
# # the name pattern built below -- confirm the current naming before use.
# expt = "macgnn_demo-2023_08_27_15h_51m"
# epoch = 2  # replace 2 with the epoch number to start from
# ckpt_name = f"{expt}_checkpoint-epoch_{epoch:03d}.pt"
# ckpt_dir = Path(f"{conf.dirs.expt_results}/{expt}/ckpts")
# ckpt_file = ckpt_dir / ckpt_name

# trainer.fit(ckpt_file=ckpt_file)

## Using the `train_contrastive` module

We can do the same as above, but using the `train_model` function from the `train_contrastive.py` module (imported in the first cell).  
This is a convenient wrapper over the above code, which allows us to train a model with a single function call.

In [None]:
# # `config_file` already holds this same path from the setup cell near the top of the notebook.
# config_file = "../config/cfg-macgnn_demo.yml"
# # train from scratch
# train_model(config_file)

### train from checkpoint
```python
expt = "macgnn_demo-DATETIME" # replace DATETIME with the timestamp from the experiment
epoch = 3 # replace 3 with the epoch number to start from
```

In [None]:
# train from checkpoint
# # NOTE(review): `Path` is not imported in the first cell; uncomment the next line as well.
# from pathlib import Path
# # NOTE(review): the run logged earlier in this notebook saved files named
# # `macgnn_demo-pink_owl-epoch_0001.pt` inside a `2023_11_16_12h_27m-macgnn_demo-pink_owl/ckpts`
# # directory, which does not match the name pattern built below -- confirm before use.
# expt = "macgnn_demo-DATETIME"  # replace DATETIME with the timestamp from the experiment
# epoch = 3  # replace 3 with the epoch number to start from
# ckpt_name = f"{expt}_checkpoint-epoch_{epoch:03d}.pt"

# conf = Config.from_yaml(config_file)  # only needed to get the ckpts_dir
# ckpt_dir = Path(f"{conf.dirs.expt_results}/{expt}/ckpts")
# ckpt_file = ckpt_dir / ckpt_name

# # pass the checkpoint path through the `checkpoint` keyword of `train_model`
# train_model(config_file, checkpoint=ckpt_file)