# Train DNNs with SecML-Torch

In this notebook, we will use the basic training functionalities of SecML-Torch to train a regular PyTorch Deep Neural Network (DNN) classifier.

We will train a classifier for the MNIST dataset.
First, we define the model as a `torch.nn.Module`, as is usual in `torch`. The model is a simple fully-connected network with three linear layers and ReLU activations. Notice that this is standard PyTorch code: SecML-Torch is designed to work with existing PyTorch models without modifications.

In [1]:
%%capture --no-stderr
try:
    import secmlt
except ImportError:
    %pip install git+https://github.com/pralab/secml-torch

In [2]:
import torch


class MNISTNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 200)
        self.fc2 = torch.nn.Linear(200, 200)
        self.fc3 = torch.nn.Linear(200, 10)

    def forward(self, x):
        x = x.flatten(1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


net = MNISTNet()
device = "cpu"
net = net.to(device)
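Because `forward` calls `x.flatten(1)`, the network accepts image batches of shape `(N, 1, 28, 28)` directly and returns one logit per digit class. A quick shape check on a random dummy batch (not real MNIST data) illustrates this:

```python
import torch


class MNISTNet(torch.nn.Module):
    # same architecture as above: 784 -> 200 -> 200 -> 10
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 200)
        self.fc2 = torch.nn.Linear(200, 200)
        self.fc3 = torch.nn.Linear(200, 10)

    def forward(self, x):
        x = x.flatten(1)  # (N, 1, 28, 28) -> (N, 784)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)  # raw logits, one per class


dummy = torch.rand(4, 1, 28, 28)  # random batch shaped like MNIST images
logits = MNISTNet()(dummy)
print(logits.shape)  # torch.Size([4, 10])
```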


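We import the training and testing datasets of MNIST from `torchvision` and wrap them in dedicated data loaders. We use a batch size of 64 and set `shuffle=False`, so that the batches are always visited in the same order; full run-to-run reproducibility would additionally require fixing the random seed (e.g., with `torch.manual_seed`) before the weights are initialized. The `ToTensor()` transform converts each PIL image into a float tensor of shape `(1, 28, 28)` and scales pixel values to the [0, 1] range.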

In [3]:
%%capture
import torchvision.datasets
from torch.utils.data import DataLoader

dataset_path = "data/datasets/"
training_dataset = torchvision.datasets.MNIST(
    transform=torchvision.transforms.ToTensor(),
    train=True,
    root=dataset_path,
    download=True,
)
training_data_loader = DataLoader(training_dataset, batch_size=64, shuffle=False)
test_dataset = torchvision.datasets.MNIST(
    transform=torchvision.transforms.ToTensor(),
    train=False,
    root=dataset_path,
    download=True,
)
test_data_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)


Next, we initialize the optimizer used to train the model: Adam with a learning rate of 1e-3, which is also PyTorch's default value for Adam and a common starting point for small networks like this one.

In [4]:
from torch.optim import Adam
optimizer = Adam(lr=1e-3, params=net.parameters())
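An optimizer stores its hyperparameters in `param_groups` and keeps references to the exact parameter tensors it was built with. The sketch below, which uses a small `torch.nn.Linear` as a stand-in for `MNISTNet`, checks both; the second check is why an optimizer created for one network cannot update a different network instance.

```python
import torch
from torch.optim import Adam

net = torch.nn.Linear(784, 10)  # small stand-in for MNISTNet
optimizer = Adam(net.parameters(), lr=1e-3)

group = optimizer.param_groups[0]
print(group["lr"])  # 0.001

# the optimizer holds references to net's own parameter tensors...
tracks_net = all(p is q for p, q in zip(group["params"], net.parameters()))
# ...and none of the parameters of a freshly created network
other = torch.nn.Linear(784, 10)
tracks_other = any(p is q for p in group["params"] for q in other.parameters())
print(tracks_net, tracks_other)  # True False
```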

Now we will start using the SecML-Torch functionalities to train the previously-defined model on the MNIST dataset just loaded.

We will use the class `secmlt.models.pytorch.base_pytorch_trainer.BasePyTorchTrainer` to prepare a training loop.
This class implements the standard training loop: for a given number of epochs (the `epochs` parameter), it iterates over the batches of the data loader and performs one optimization step per batch with the chosen optimizer. The trainer takes care of the forward pass, loss computation, backward pass, and parameter updates.

We wrap the model into a `secmlt.models.pytorch.base_pytorch_nn.BasePytorchClassifier`, which provides the APIs for using models that subclass `torch.nn.Module` within SecML-Torch. The wrapper does not modify the model; it adds methods that integrate it with SecML-Torch's attacks and defenses.
Then, we can train our model by calling `model.train(dataloader=training_data_loader)`, a single call that replaces the usual hand-written PyTorch training loop. The call returns the trained `torch.nn.Module`, which is why the notebook prints the network structure below.


In [5]:
from secmlt.models.pytorch.base_pytorch_nn import BasePytorchClassifier
from secmlt.models.pytorch.base_pytorch_trainer import BasePyTorchTrainer

# Training MNIST model
trainer = BasePyTorchTrainer(optimizer=optimizer, epochs=1)
model = BasePytorchClassifier(model=net, trainer=trainer)
model.train(dataloader=training_data_loader)


MNISTNet(
  (fc1): Linear(in_features=784, out_features=200, bias=True)
  (fc2): Linear(in_features=200, out_features=200, bias=True)
  (fc3): Linear(in_features=200, out_features=10, bias=True)
)
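For reference, the loop that `model.train` saves us from writing looks roughly like the sketch below. This is a hand-written plain-PyTorch approximation, not SecML-Torch's source: the `train_loop` helper name and the cross-entropy default loss are assumptions, and the toy data is random tensors shaped like MNIST batches.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def train_loop(model, dataloader, optimizer, epochs=1, loss_fn=None):
    """Hypothetical helper: forward, loss, backward, step for every batch and epoch."""
    # assumption: cross-entropy as the default classification loss
    loss_fn = loss_fn if loss_fn is not None else torch.nn.CrossEntropyLoss()
    device = next(model.parameters()).device
    model.train()  # enable training-mode behaviour (dropout, batch norm, ...)
    for _ in range(epochs):
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()  # clear gradients from the previous step
            loss = loss_fn(model(x), y)  # forward pass + loss
            loss.backward()  # backward pass
            optimizer.step()  # parameter update
    return model


# toy usage on random data shaped like MNIST batches (hypothetical data)
torch.manual_seed(0)
xs = torch.rand(256, 1, 28, 28)
ys = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(xs, ys), batch_size=64, shuffle=False)
net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(784, 10))
before = net[1].weight.detach().clone()
opt = torch.optim.Adam(net.parameters(), lr=1e-3)
trained = train_loop(net, loader, opt, epochs=2)
print(torch.equal(before, trained[1].weight))  # False: the weights were updated
```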

We can check how the model performs on the test set by using the `secmlt.metrics.classification.Accuracy` wrapper.
It implements the scoring loop: it queries the model with every batch, counts how many predictions are correct, and aggregates the counts across batches. The result is the accuracy over the entire test set, returned as a scalar tensor (hence the `tensor(...)` in the printed output).

In [6]:
from secmlt.metrics.classification import Accuracy

# Test MNIST model
accuracy = Accuracy()(model, test_data_loader)
print("test accuracy: ", accuracy)

test accuracy:  tensor(0.9498)
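The scoring loop described above can be sketched in plain PyTorch as follows. The `accuracy` helper is hypothetical and is not the `Accuracy` class from SecML-Torch; it takes the arg-max of the logits as the predicted class and returns the fraction of correct predictions as a 0-dim tensor. The toy model is rigged so that it always predicts class 0, which makes the expected result easy to verify.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def accuracy(model, dataloader):
    """Hypothetical helper: fraction of correctly classified samples, as a 0-dim tensor."""
    device = next(model.parameters()).device
    model.eval()
    correct, total = 0, 0
    for x, y in dataloader:
        preds = model(x.to(device)).argmax(dim=1).cpu()  # predicted class per sample
        correct += (preds == y).sum().item()
        total += y.shape[0]
    return torch.tensor(correct / total)


# toy check: a linear model whose bias always favours class 0 (hypothetical data)
lin = torch.nn.Linear(4, 3)
with torch.no_grad():
    lin.weight.zero_()
    lin.bias.copy_(torch.tensor([1.0, 0.0, 0.0]))
xs = torch.randn(10, 4)
ys = torch.tensor([0] * 5 + [1] * 5)  # half of the labels are class 0
acc = accuracy(lin, DataLoader(TensorDataset(xs, ys), batch_size=4))
print(acc)  # tensor(0.5000)
```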


Finally, we can save the model weights with the standard `torch` saving functionality.
To access the underlying `torch.nn.Module`, we use the `model` attribute of the `secmlt.models.pytorch.base_pytorch_nn.BasePytorchClassifier` wrapper and save its `state_dict()`.

In [7]:
from pathlib import Path

model_path = Path("data/models/mnist")
# exist_ok=True makes this a no-op when the directory already exists
model_path.mkdir(parents=True, exist_ok=True)
torch.save(model.model.state_dict(), model_path / "mnist_model.pt")
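The saved `state_dict` is an ordered mapping from parameter names to tensors; it stores only the weights, not the class definition, which is why `MNISTNet` must be instantiated again before loading. The sketch below lists the entries for the architecture defined earlier and counts its parameters.

```python
import torch


class MNISTNet(torch.nn.Module):
    # same architecture as the one trained above
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 200)
        self.fc2 = torch.nn.Linear(200, 200)
        self.fc3 = torch.nn.Linear(200, 10)

    def forward(self, x):
        x = x.flatten(1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


state = MNISTNet().state_dict()
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# fc1.weight (200, 784)
# fc1.bias (200,)
# fc2.weight (200, 200)
# fc2.bias (200,)
# fc3.weight (10, 200)
# fc3.bias (10,)

n_params = sum(t.numel() for t in state.values())
print(n_params)  # 199210
```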

After saving, we can reload the weights for further evaluations: we instantiate a fresh `MNISTNet`, read the state dict with `torch.load`, copy it into the network with `load_state_dict`, and wrap the network again with `secmlt.models.pytorch.base_pytorch_nn.BasePytorchClassifier`.

In [8]:
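trained_net = MNISTNet()
model_weights_path = model_path / "mnist_model.pt"
# read the saved state dict and copy it into the fresh network
model_weights = torch.load(model_weights_path, map_location="cpu")
trained_net.load_state_dict(model_weights)
trained_net.eval()  # switch to inference mode for evaluation
# no trainer is needed for evaluation; the earlier `trainer` holds an
# optimizer bound to the parameters of `net`, not of `trained_net`
trained_model = BasePytorchClassifier(model=trained_net)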
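A simple way to confirm that saving and loading preserved the model is to compare the outputs of the original and the restored network on the same inputs. The sketch below does this round trip in a temporary directory with an untrained `MNISTNet` and random inputs (both placeholders for the real trained model and test data).

```python
import tempfile
from pathlib import Path

import torch


class MNISTNet(torch.nn.Module):
    # same architecture as the one trained above
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(784, 200)
        self.fc2 = torch.nn.Linear(200, 200)
        self.fc3 = torch.nn.Linear(200, 10)

    def forward(self, x):
        x = x.flatten(1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


torch.manual_seed(0)
original = MNISTNet().eval()
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "mnist_model.pt"
    torch.save(original.state_dict(), path)  # save weights only
    restored = MNISTNet()
    restored.load_state_dict(torch.load(path, map_location="cpu"))
    restored.eval()

x = torch.rand(8, 1, 28, 28)  # random inputs shaped like MNIST images
with torch.no_grad():
    same = torch.equal(original(x), restored(x))
print(same)  # True: identical weights give identical outputs
```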