# Using PyTorch Lightning for classification

In [None]:
import os

from datasets import load_dataset
import transformers
import pandas as pd
import plotly.express as px

from module import ImageClassificationModule

transformers.utils.logging.set_verbosity_error()  # suppress image processor warning

## Loading a model checkpoint with PyTorch Lightning

It is possible to load checkpoints directly into a LightningModule, either to continue training or to run inference. Here, we load the model so that we can use it to classify images.

In [None]:
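# Keep only .ckpt files and sort them, so the choice below does not depend on os.listdir's arbitrary order
checkpoints = sorted(f for f in os.listdir("checkpoints") if f.endswith(".ckpt"))
print(f"We will use the {checkpoints[0]} checkpoint for inference")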
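`os.listdir` returns entries in arbitrary order, and even a plain alphabetical sort ranks `epoch=10` ahead of `epoch=9`, so the first entry is not necessarily the latest checkpoint. If you want the most recent one, the following is a minimal sketch. It assumes the files follow Lightning's default `epoch=<N>-step=<M>.ckpt` naming, which may not match your `ModelCheckpoint` settings. `pick_latest` is a hypothetical helper, not part of Lightning.

```python
import re

# Assumed filename pattern: Lightning's default "epoch=<N>-step=<M>.ckpt"
CKPT_PATTERN = re.compile(r"epoch=(\d+)-step=(\d+)\.ckpt$")


def pick_latest(filenames):
    """Hypothetical helper: return the checkpoint name with the highest (epoch, step)."""
    scored = []
    for name in filenames:
        match = CKPT_PATTERN.search(name)
        if match:  # silently skip files that are not checkpoints
            scored.append(((int(match.group(1)), int(match.group(2))), name))
    if not scored:
        raise FileNotFoundError("no file matching epoch=<N>-step=<M>.ckpt")
    # Tuples compare numerically, so (10, 1000) beats (9, 900)
    return max(scored)[1]


names = ["epoch=9-step=900.ckpt", "epoch=10-step=1000.ckpt", "notes.txt"]
print(pick_latest(names))  # epoch=10-step=1000.ckpt
```

In the notebook, this would be called as `pick_latest(os.listdir("checkpoints"))`. Parsing the numbers avoids the string-comparison pitfall that a plain `sorted` call has.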

In [None]:
%%capture
model = ImageClassificationModule.load_from_checkpoint(os.path.join("checkpoints", checkpoints[0]))
model.eval()  # switch off training-only behavior such as dropout before predicting

## Using the LightningModule's `predict_step` to classify images

Our `visualizing_logs_metrics_cost.ipynb` notebook showed that each model reached a final validation accuracy of around 80%. That is reasonable, though not ideal, so we should expect most, but not all, predictions to be correct.
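To get a feel for how much a 15-image sample can swing around an 80% accuracy, here is a minimal sketch. It assumes each prediction is an independent trial that is correct with probability 0.8. This is a simplification, because real errors depend on which classes happen to be in the sample. The sketch computes the binomial distribution of correct predictions with the standard library only.

```python
from math import comb


def binomial_pmf(k: int, n: int, p: float) -> float:
    """Probability of exactly k correct predictions out of n, each correct with probability p."""
    return comb(n, k) * p**k * (1 - p) ** (n - k)


n, p = 15, 0.8  # sample size used below; assumed per-image accuracy
expected_correct = n * p  # mean of a Binomial(n, p) distribution

# Sample accuracy below 80% means 11 or fewer correct out of 15 (12/15 is exactly 80%)
prob_below = sum(binomial_pmf(k, n, p) for k in range(0, 12))

print(f"Expected correct: {expected_correct:.1f} of {n}")
print(f"P(sample accuracy < 80%): {prob_below:.3f}")
```

Under these assumptions, a 15-image sample falls below 80% roughly a third of the time, so a single small sample says little about the model on its own.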

Below, we load the CIFAR-100 test split, then pass a handful of its images, one at a time, to our LightningModule's `predict_step` and compare the predictions with the true labels:

In [None]:
labels = model.model.config.id2label  # class-ID-to-name mapping stored in the model config
test_dataset = load_dataset("cifar100", cache_dir="data", split="test")  # CIFAR-100 test split

Next, we load the fine-label mapping so that both the predicted class ID and the ground-truth class ID can be translated into readable class names:

In [None]:
cifar_label_map = pd.read_csv("data/cifar_fine_label_map.csv", index_col=0)
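The lookups below rely on row *i* of `cifar_fine_label_map.csv` holding the name of class ID *i*, with `.strip()` removing stray whitespace. The file's exact layout is not shown here, so this minimal sketch assumes a hypothetical two-column layout (`id,label`). It builds the same ID-to-name mapping with the standard `csv` module and decodes a made-up score vector with an argmax.

```python
import csv
import io

# Hypothetical excerpt of the label-map file: class ID, then class name (with stray spaces)
SAMPLE_CSV = """id,label
0, apple
1, aquarium_fish
2, baby
"""


def read_label_map(f) -> dict[int, str]:
    """Build {class_id: class_name}, stripping whitespace around each name."""
    reader = csv.reader(f)
    next(reader)  # skip the header row
    return {int(row[0]): row[1].strip() for row in reader if row}


label_map = read_label_map(io.StringIO(SAMPLE_CSV))
scores = [0.1, 2.3, -0.5]  # hypothetical class scores for one image
pred_id = max(range(len(scores)), key=scores.__getitem__)  # argmax without NumPy
print(label_map[pred_id])  # aquarium_fish
```

Keying the mapping by the explicit ID column, rather than by row position as `.iloc` does, keeps the lookup correct even if the file's rows are ever reordered.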

## Classifying images

Next, let's take a small sample, the first 15 images of the test split, and use the model to predict what each image shows.

In [None]:
images = test_dataset[:15]  # slicing returns a dict mapping column names to lists
results = []

for image, fine_label in zip(images["img"], images["fine_label"]):
    # predict_step (defined in module.py) is expected to return class scores for one image
    pred = model.predict_step(image)
    pred_label = cifar_label_map.iloc[pred.argmax(-1).item()].item().strip()
    truth_label = cifar_label_map.iloc[fine_label].item().strip()
    is_correct = pred_label == truth_label
    results.append(is_correct)
    print(f"Our fine-tuned model classifies this image as: {pred_label}. The actual label is: {truth_label}. The classification is {is_correct}.")

# True counts as 1 and False as 0, so the sum is the number of correct predictions
accuracy = sum(results) / len(results) * 100
print(f"\nThe accuracy for this sample of {len(results)} images is {round(accuracy, 4)}%")

## Conclusion

The accuracy on this sample is below our validation accuracy. However, the sample is tiny (15 images) and consists of the first images in the test split, not a random or label-stratified sample, so it is a noisy estimate.
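A stratified sample draws the same number of images from every class, so no single class dominates a small evaluation set. The following is a minimal sketch using only the standard library. It assumes the labels are available as a plain list of integer class IDs, for example `test_dataset["fine_label"]`. `stratified_indices` is a hypothetical helper, and the resulting indices could then be passed to `test_dataset.select(...)`.

```python
import random
from collections import defaultdict


def stratified_indices(labels, per_class, seed=0):
    """Return up to `per_class` row indices for each distinct label, sampled without replacement."""
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)
    rng = random.Random(seed)  # seeded so the sample is reproducible
    chosen = []
    for label in sorted(by_label):
        rows = by_label[label]
        chosen.extend(rng.sample(rows, min(per_class, len(rows))))
    return sorted(chosen)


# Toy labels standing in for a test split: class 0 is over-represented
toy_labels = [0, 0, 0, 0, 0, 0, 1, 1, 2, 2, 2]
sample = stratified_indices(toy_labels, per_class=2)
print([toy_labels[i] for i in sample])  # [0, 0, 1, 1, 2, 2]
```

With CIFAR-100's 100 fine labels, even `per_class=1` yields a 100-image sample that covers every class.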

The accuracy of each training checkpoint is not optimal: around 80% is unlikely to be good enough for production. Additional steps could improve it, such as replacing the classifier layers and freezing the encoder layers. However, those steps are outside the scope of this work and would require further experimentation to produce a better-performing model 🙂