# Training a Hugging Face image classifier with PyTorch Lightning

Running the following cells trains the model using the model settings and Trainer flags shown below.

In [None]:
from datetime import datetime

import torch

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.profilers import PyTorchProfiler

from datamodule import AutoImageProcessorDataModule
from module import ImageClassificationModule
from utils import create_dirs
from config import Config, DataModuleConfig, ModuleConfig

First, let's configure our settings.

In [None]:
# --- model and dataset settings ---
model_name, dataset_name = ModuleConfig.model_name, DataModuleConfig.dataset_name
batch_size, lr = 16, 5e-05

# --- filesystem locations, all read from Config ---
cache_dir, log_dir, ckpt_dir, prof_dir, perf_dir = (
    Config.cache_dir,
    Config.log_dir,
    Config.ckpt_dir,
    Config.prof_dir,
    Config.perf_dir,
)
# create_dirs (from utils) is used to (re)create these directories so later
# writes don't fail if one of them has been deleted since the last run.
create_dirs([cache_dir, log_dir, ckpt_dir, prof_dir, perf_dir])

# Let float32 matmuls use faster, lower-precision internal algorithms.
# https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
torch.set_float32_matmul_precision("medium")

Now we can define our LightningDataModule, which the Trainer uses to build its DataLoaders.

In [None]:
# LightningDataModule that the Trainer will pull its DataLoaders from.
lit_datamodule = AutoImageProcessorDataModule(
    dataset_name=dataset_name,
    model_name=model_name,
    batch_size=batch_size,
    cache_dir=cache_dir,
)

Here's our [custom LightningModule](module.py) equipped with ResNet.

LightningModules are the second most important feature of PyTorch Lightning's Core API after Trainer, as this is the class that Trainer interacts with to train the model. <br>
Be sure to check out the code contained in the [module.py](module.py) to gain an understanding of how LightningModule is used 🙂

In [None]:
# Build the LightningModule; the learning rate comes from the settings cell.
lit_model = ImageClassificationModule(
    learning_rate=lr,
)

Next, we define some common callbacks and our most basic logger, CSVLogger.

The EarlyStopping callback ends training early if a convergence criterion is met before the `max_epochs` limit is reached.

ModelCheckpoint saves the model periodically. After training finishes, its `best_model_path` attribute gives the path to the best checkpoint file, and `best_model_score` gives that checkpoint's score.

In [None]:
# Both callbacks track the validation accuracy metric "val-acc"
# (assumes module.py logs a metric under exactly this name — confirm).
# Accuracy should go *up*, so both use mode="max". With mode="min",
# EarlyStopping would treat rising accuracy as "no improvement" and halt
# training after `patience` validation checks.
callbacks = [
    EarlyStopping(monitor="val-acc", mode="max"),
    # Without `monitor`, ModelCheckpoint just keeps the most recent epoch:
    # `best_model_path` would point at the last checkpoint and
    # `best_model_score` would be None. Monitoring val-acc makes them
    # refer to the best-scoring checkpoint.
    ModelCheckpoint(
        dirpath=ckpt_dir,
        filename="model",
        monitor="val-acc",
        mode="max",
    ),
]

In [None]:
# Plain CSV metrics logger, writing under log_dir in a "csv-logs" subdirectory.
logger = CSVLogger(name="csv-logs", save_dir=log_dir)

Finally, we create our Trainer, passing in our flags (settings), the callbacks, and the logger. Then we call `fit`!

In [None]:
# fp16 automatic mixed precision ("16-mixed") is a CUDA feature; Lightning
# does not support it on CPU. Because accelerator="auto" may resolve to a
# non-CUDA device, only request fp16 AMP when a CUDA GPU is present and
# otherwise fall back to full float32 precision.
lit_trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    strategy="auto",
    precision="16-mixed" if torch.cuda.is_available() else "32-true",
    max_epochs=10,
    # Requests deterministic algorithms; note that no global seed is set in
    # this notebook, so run-to-run randomness is not fully pinned down.
    deterministic=True,
    logger=logger,
    callbacks=callbacks,
)

In [None]:
# Start training; the datamodule supplies the DataLoaders.
lit_trainer.fit(lit_model, datamodule=lit_datamodule)