Example
=======

As a first example, we use an image dataset and a pre-trained model to classify the images. We use the `ReVel` framework to load the dataset and the model, and to perform the classification. We also use the `procedures` module from `SHIELD` to help us with the classification process.

In [5]:
import torch
from torch.utils.data import random_split
import torch.nn.functional as F
from torch.utils.data import DataLoader
from ReVel.perturbations import get_perturbation
from ReVel.load_data import load_data
from SHIELD import SHIELD
from SHIELD.procedures import procedures

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the EfficientNet-B2 classifier with its last layer sized for the
# 102 classes of the Flowers dataset.
classifier = procedures.classifier("efficientnet-b2", num_classes=102)

# Perturbation passed to both dataset loaders below.
# NOTE(review): the meaning of dim/kernel/max_dist/ratio is defined by ReVel;
# check its documentation before tuning these values.
perturbation = get_perturbation(
    name="square",
    dim=9,
    num_classes=102,
    final_size=(224, 224),
    kernel=150.0,
    max_dist=20,
    ratio=0.5,
)

# Train and test partitions of Flowers, stored under ./data/.
train_set = load_data("Flowers", perturbation=perturbation, train=True, dir="./data/")
test_set = load_data("Flowers", perturbation=perturbation, train=False, dir="./data/")
classifier.to(device)

regularization = "SHIELD"  # "SHIELD" or "Baseline"

# Hold out 10% of the training data for validation; the remainder term makes
# the two lengths always sum to len(train_set) despite the int() truncation.
n_train = int(len(train_set) * 0.9)
Train, Val = random_split(train_set, [n_train, len(train_set) - n_train])

# Only the training loader reshuffles its samples.
TrainLoader = DataLoader(Train, batch_size=32, shuffle=True)
ValLoader = DataLoader(Val, batch_size=32, shuffle=False)

def loss_f(y_pred, y_label):
    """Return the cross-entropy between predictions and score-encoded labels.

    Args:
        y_pred: Prediction tensor handed to ``F.cross_entropy`` as its input.
        y_label: Per-class label scores; the index of the largest entry
            along dim 1 is used as the target class (presumably a one-hot
            or soft-label tensor -- confirm against the dataset).

    Returns:
        The value of ``F.cross_entropy`` with its default settings.
    """
    target_idx = y_label.argmax(dim=1)
    return F.cross_entropy(y_pred, target_idx)
epochs = 5  # Change the number of epochs in case you need more
# Lowest validation loss seen so far; starts at +inf so that any finite
# validation loss counts as an improvement.
best_loss = torch.tensor(float("inf"))
# AdamW (AMSGrad variant enabled) over all classifier parameters.
optimizer = torch.optim.AdamW(
    classifier.parameters(),
    lr=0.001,
    weight_decay=0.01,
    amsgrad=True,
)

Training and validation phase
=============================

In [None]:
def shield_reg(x, y):
    """Regularisation term: SHIELD evaluated for model ``x`` on input ``y``."""
    # Parameter names mirror the lambda this replaces, in case the step
    # functions pass the arguments by keyword.
    return SHIELD.shield(model=x, input=y, percentage=5, device=device)


for epoch in range(epochs):
    print(f"Epoch :{epoch+1}, {(epoch+1)/epochs*100:.2f}%")

    # One pass over the training split.
    train_loss, train_acc, train_reg = procedures.train_step(
        ds_loader=TrainLoader,
        model=classifier,
        optimizer=optimizer,
        loss_f=loss_f,
        reg_f=shield_reg,
        device=device,
    )
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}, Train Regularization: {train_reg:.4f}")

    # One pass over the held-out validation split.
    val_loss, val_acc, val_reg = procedures.validation_step(
        ds_loader=ValLoader,
        model=classifier,
        loss_f=loss_f,
        reg_f=shield_reg,
        device=device,
    )
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}, Validation Regularization: {val_reg:.4f}")

    # Checkpoint only when the validation loss reaches a new minimum.
    improved = val_loss < best_loss
    if improved:
        best_loss = val_loss
        torch.save(classifier.state_dict(), "./model.pth")

Testing phase
=============

In [None]:
# Restore the weights of the best checkpoint written during training.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint saved during a CUDA run still loads in a CPU-only session.
classifier.load_state_dict(torch.load("./model.pth", map_location=device))

# Evaluate on the test split in a fixed order (no shuffling).
test = DataLoader(test_set, batch_size=32, shuffle=False)
test_loss, test_acc, test_reg = procedures.validation_step(
    ds_loader=test,
    model=classifier,
    loss_f=loss_f,
    # Same SHIELD regulariser settings as in training and validation.
    reg_f=lambda x, y: (SHIELD.shield(model=x, input=y, percentage=5, device=device)),
    device=device,
)
print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.4f}, Test Regularization: {test_reg:.4f}")