# Adversarial training (Hugging Face)

In [None]:
# IPython setup: auto-reload imported modules before each cell runs, so edits
# to project code are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2
# render matplotlib figures inline in the notebook
%matplotlib inline

import sys
# make modules in the parent directory importable
# (presumably where `adv_utils` lives — confirm against the repo layout)
sys.path.append('..')

In [None]:
import matplotlib.pyplot as plt
import torch
from lightning.pytorch import Trainer

from adv_utils import (
    Flowers102DataModule,
    AdversarialHFClassifier
)

## Load data

In [None]:
# Flowers102 data module. Images are normalized with mean=0.5 and std=0.5
# (presumably per channel, mapping [0, 1] pixels to [-1, 1] — confirm in adv_utils).
flowers = Flowers102DataModule(
    data_dir='../run/data/',
    batch_size=32,
    mean=0.5,
    std=0.5,
)

# Download the dataset on first use, then build only the test split.
flowers.prepare_data()
flowers.setup(stage='test')

In [None]:
# Build the test DataLoader (reused later by the evaluation cells) and pull
# its first (images, labels) batch for visualization.
test_loader = flowers.test_dataloader()
(x_batch, y_batch) = next(iter(test_loader))

In [None]:
# show example images: a 4x4 grid of the first test batch, each panel titled
# with its integer class label
fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(5, 6))
num_images = len(x_batch)
for idx, ax in enumerate(axes.ravel()):
    if idx >= num_images:
        # the batch is smaller than the grid: hide the leftover panels
        # instead of failing with an IndexError
        ax.axis('off')
        continue
    # undo the normalization and reorder CHW -> HWC for imshow; clamp to
    # [0, 1] because imshow would clip out-of-range floats anyway (with a warning)
    image = flowers.renormalize(x_batch[idx]).permute(1, 2, 0).clamp(0, 1).numpy()
    label = y_batch[idx].item()
    ax.imshow(image)
    ax.set_title(label)
    ax.set(xticks=[], yticks=[], xlabel='', ylabel='')
fig.tight_layout()

## Load models

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [None]:
# load standard model
ckpt_file = '../run/train_std/version_0/checkpoints/last.ckpt'

# map_location loads the checkpoint tensors directly onto `device`. Without it,
# loading fails if the checkpoint was saved on a GPU index that does not exist
# on this machine.
std_model = AdversarialHFClassifier.load_from_checkpoint(ckpt_file, map_location=device)

# evaluation mode (affects layers such as dropout / batch norm, if present);
# the explicit .to(device) is kept so placement never depends on the loader
std_model = std_model.eval()
std_model = std_model.to(device)

In [None]:
# load adversarially trained model
ckpt_file = '../run/train_adv/version_0/checkpoints/last.ckpt'

# map_location loads the checkpoint tensors directly onto `device`. Without it,
# loading fails if the checkpoint was saved on a GPU index that does not exist
# on this machine.
adv_model = AdversarialHFClassifier.load_from_checkpoint(ckpt_file, map_location=device)

# evaluation mode (affects layers such as dropout / batch norm, if present);
# the explicit .to(device) is kept so placement never depends on the loader
adv_model = adv_model.eval()
adv_model = adv_model.to(device)

## Test models

In [None]:
# create trainer
# Evaluation-only trainer: logging is disabled and Lightning picks the
# accelerator automatically. devices=1 keeps testing on a single device. On a
# multi-GPU machine a distributed strategy would shard the test set with a
# DistributedSampler, which can duplicate samples and skew the reported metrics.
trainer = Trainer(logger=False, accelerator='auto', devices=1)

In [None]:
# Evaluate the standard (non-adversarially trained) model on the test loader
# with verbose output, then also print the returned metrics object.
std_metrics = trainer.test(
    dataloaders=test_loader,
    model=std_model,
    verbose=True,
)
print(std_metrics)

In [None]:
# Evaluate the adversarially trained model on the same test loader with
# verbose output, then also print the returned metrics object.
adv_metrics = trainer.test(
    dataloaders=test_loader,
    model=adv_model,
    verbose=True,
)
print(adv_metrics)