In [None]:
""" 
Create file experiments/config_002.yml and add:

trainer:
  logger: True
"""

In [None]:
# PyTorch Lightning has TensorBoard integrated, so training progress and metrics can be
# tracked in real time and inspected later. Logs are saved in the folder lightning_logs/,
# and TensorBoard can be opened from VS Code via the Command Palette (Ctrl+Shift+P).

## CSV Logger

Another option is to store all metrics in CSV files with Lightning's `CSVLogger`.

In [None]:
from src import *
from pathlib import Path
import pytorch_lightning as pl
import yaml
import sys

# Default run configuration. The `if __name__ == '__main__'` block merges an
# optional YAML file over it with deep_update before training.
#   datamodule -> keyword arguments for MNISTDataModule
#   trainer    -> keyword arguments for pl.Trainer
#   logger     -> name of a class in pl.loggers, or None to leave the
#                 trainer's own `logger` setting untouched
config = dict(
    datamodule=dict(
        path=Path('dataset'),
        batch_size=25,
    ),
    trainer=dict(
        max_epochs=10,
        enable_checkpointing=False,
        overfit_batches=0,
    ),
    logger=None,
)


def train(config):
    dm = MNISTDataModule(**config['datamodule'])
    module = MNISTModule(config)
    # configure logger
    if config['logger'] is not None:
        config['trainer']['logger'] = getattr(pl.loggers, config['logger'])(
            **config['logger_params'])
    trainer = pl.Trainer(**config['trainer'])
    trainer.fit(module, dm)
    trainer.save_checkpoint('final.ckpt')


if __name__ == '__main__':
    if len(sys.argv) > 1:
        config_file = sys.argv[1]
        if config_file:
            with open(config_file, 'r') as stream:
                loaded_config = yaml.safe_load(stream)
            deep_update(config, loaded_config)
    print(config)
    train(config)

In [None]:
"""
Create experiments/config_003.yml and add:

logger: CSVLogger
logger_params:
  save_dir: logs
  name: "003"
"""

In [None]:
from pathlib import Path

import pandas as pd

# CSVLogger(save_dir="logs", name="003") writes logs/003/version_<n>/metrics.csv.
# A new version_<n> directory is presumably created on every run, so read the
# highest-numbered one instead of hard-coding version_0; otherwise re-running
# experiment 003 would leave this cell showing the first run's metrics.
LOG_DIR = Path('logs/003')


def _version_number(metrics_path):
    """Return n for a .../version_<n>/metrics.csv path, or -1 if n is not an integer."""
    suffix = metrics_path.parent.name[len('version_'):]
    return int(suffix) if suffix.isdigit() else -1


metrics_files = sorted(LOG_DIR.glob('version_*/metrics.csv'), key=_version_number)
if not metrics_files:
    raise FileNotFoundError(f'no version_*/metrics.csv found under {LOG_DIR}')
logs = pd.read_csv(metrics_files[-1])

logs

In [None]:
import matplotlib.pyplot as plt

# (validation column, training column, panel title) for each subplot.
METRIC_PANELS = [
    ('val_loss', 'loss', 'Loss'),
    ('val_acc', 'acc', 'Accuracy'),
]

fig, axes = plt.subplots(1, len(METRIC_PANELS), figsize=(6, 3))
for ax, (val_col, train_col, title) in zip(axes, METRIC_PANELS):
    # Validation and training values presumably sit on different rows of
    # metrics.csv (NaN elsewhere), so drop NaNs per column; the kept row index
    # puts both curves on a common logging-order x-axis.
    logs[val_col].dropna().plot(ax=ax)
    logs[train_col].dropna().plot(ax=ax)
    ax.legend([val_col, train_col])
    ax.set_title(title)
    ax.set_xlabel('metrics.csv row')
    ax.grid(True)
fig.tight_layout()
plt.show()

## Weights and Biases

It is also possible to send training runs from different computers and team members to a centralized cloud service such as Weights & Biases (W&B).

In [None]:
"""
Create experiments/config_004.yml and add:

logger: WandbLogger
logger_params:
  project: dlops-mnist
  name: "004"

$ pip install wandb
then enter your W&B API key when prompted (e.g. by running `wandb login`)
"""

In [None]:
"""
Create experiments/config_005.yml and add:

logger: WandbLogger
logger_params:
  project: dlops-mnist
  name: "005"
callbacks:
  - name: WandBCallback
    lib: src.utils
    params:
      labels:
        - "no 3"
        - "3"
"""

In [None]:
# main.py:

from src import *
from pathlib import Path
import pytorch_lightning as pl
import yaml
import sys
import importlib


# Default run configuration. The `if __name__ == '__main__'` block merges an
# optional YAML file over it with deep_update before training.
#   datamodule -> keyword arguments for MNISTDataModule
#   trainer    -> keyword arguments for pl.Trainer
#   logger     -> name of a class in pl.loggers, or None
#   callbacks  -> list of {name, lib, params} callback specs, or None
config = dict(
    datamodule=dict(
        path=Path('dataset'),
        batch_size=25,
    ),
    trainer=dict(
        max_epochs=10,
        enable_checkpointing=False,
        overfit_batches=0,
    ),
    logger=None,
    callbacks=None,
)


def train(config):
    dm = MNISTDataModule(**config['datamodule'])
    module = MNISTModule(config)
    # configure logger
    if config['logger'] is not None:
        if config['logger'] == 'WandbLogger':
            config['trainer']['logger'] = getattr(pl.loggers, config['logger'])(
                **config['logger_params'], config=config)
        else:
            config['trainer']['logger'] = getattr(
                pl.loggers, config['logger'])(**config['logger_params'])
    # configure callbacks
    if config['callbacks'] is not None:
        callbacks = []
        for callback in config['callbacks']:
            if callback['name'] == 'WandBCallback':
                dm.setup()
                callback['params']['dl'] = dm.val_dataloader()
            cb = getattr(importlib.import_module(callback['lib']), callback['name'])(
                **callback['params'])
            callbacks.append(cb)
            config['trainer']['callbacks'] = callbacks
    # train
    trainer = pl.Trainer(**config['trainer'])
    trainer.fit(module, dm)
    trainer.save_checkpoint('final.ckpt')


if __name__ == '__main__':
    if len(sys.argv) > 1:
        config_file = sys.argv[1]
        if config_file:
            with open(config_file, 'r') as stream:
                loaded_config = yaml.safe_load(stream)
            deep_update(config, loaded_config)
    print(config)
    train(config)

## Saving models

In [None]:
"""
Create experiments/config_006.yml and add:

logger: WandbLogger
logger_params:
  project: dlops-mnist
  name: "006"
callbacks:
  - name: WandBCallback
    lib: src.utils
    params:
      labels:
        - "no 3"
        - "3"
  - name: ModelCheckpoint
    lib: pytorch_lightning.callbacks
    params:
      dirpath: checkpoints
      filename: "006"
      save_top_k: 1
      monitor: val_loss
      mode: min

trainer:
  enable_checkpointing: True
"""

In [None]:
# main.py:

from src import *
from pathlib import Path
import pytorch_lightning as pl
import yaml
import sys
import importlib


# Default run configuration. The `if __name__ == '__main__'` block merges an
# optional YAML file over it with deep_update before training.
#   datamodule -> keyword arguments for MNISTDataModule
#   trainer    -> keyword arguments for pl.Trainer (a YAML file may switch
#                 enable_checkpointing on, as config_006 does)
#   logger     -> name of a class in pl.loggers, or None
#   callbacks  -> list of {name, lib, params} callback specs, or None
config = dict(
    datamodule=dict(
        path=Path('dataset'),
        batch_size=25,
    ),
    trainer=dict(
        max_epochs=10,
        enable_checkpointing=False,
        overfit_batches=0,
    ),
    logger=None,
    callbacks=None,
)


def train(config):
    dm = MNISTDataModule(**config['datamodule'])
    module = MNISTModule(config)
    # configure logger
    if config['logger'] is not None:
        if config['logger'] == 'WandbLogger':
            config['trainer']['logger'] = getattr(pl.loggers, config['logger'])(
                **config['logger_params'], config=config)
        else:
            config['trainer']['logger'] = getattr(
                pl.loggers, config['logger'])(**config['logger_params'])
    # configure callbacks
    if config['callbacks'] is not None:
        callbacks = []
        for callback in config['callbacks']:
            if callback['name'] == 'WandBCallback':
                dm.setup()
                callback['params']['dl'] = dm.val_dataloader()
            elif callback['name'] == 'ModelCheckpoint':
                callback['params']['filename'] = f'{callback["params"]["filename"]}-{{val_loss:.5f}}-{{epoch}}'
            cb = getattr(importlib.import_module(callback['lib']), callback['name'])(
                **callback['params'])
            callbacks.append(cb)
            config['trainer']['callbacks'] = callbacks
    # train
    trainer = pl.Trainer(**config['trainer'])
    trainer.fit(module, dm)
    trainer.save_checkpoint('checkpoints/final.ckpt')


if __name__ == '__main__':
    if len(sys.argv) > 1:
        config_file = sys.argv[1]
        if config_file:
            with open(config_file, 'r') as stream:
                loaded_config = yaml.safe_load(stream)
            deep_update(config, loaded_config)
    print(config)
    train(config)