This notebook shows how to use `AdvSecureNet` to train, evaluate, save, and checkpoint a ResNet-18 model on CIFAR-10.

In [1]:
from advsecurenet.models.model_factory import ModelFactory
from advsecurenet.datasets import DatasetFactory
from advsecurenet.dataloader import DataLoaderFactory
from advsecurenet.shared.types import DatasetType
from advsecurenet.utils.model_utils import train as train_model, test as test_model, save_model
from advsecurenet.shared.types.configs.train_config import TrainConfig
from advsecurenet.shared.types.device import DeviceType

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ResNet-18 trained from scratch (no pretrained weights), configured for the
# 10 output classes of CIFAR-10.
model = ModelFactory.get_model(
    'resnet18',
    pretrained=False,
    num_classes=10,
)

In [3]:
# get cifar10 dataset
dataset = DatasetFactory.load_dataset(DatasetType.CIFAR10)
train_loader = dataset.load_dataset(train=True)
test_loader = dataset.load_dataset(train=False)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# get dataloader
train_loader = DataLoaderFactory.get_dataloader(train_loader, batch_size=64, shuffle=True)
test_loader = DataLoaderFactory.get_dataloader(test_loader, batch_size=64, shuffle=False)

In [5]:
# Train the model for a single epoch to keep the demo fast.
# The device is hard-coded to Apple MPS; change it to match your hardware
# (mps for Apple MPS, cuda for NVIDIA CUDA).
train_config = TrainConfig(
    model=model,
    train_loader=train_loader,
    epochs=1,  # 1 epoch for simplicity
    device=DeviceType.MPS,
)
train_model(train_config)

Training on mps


Epoch 1/1: 100%|██████████| 782/782 [01:24<00:00,  9.30it/s]

Epoch 1 - Average Loss: 1.380936
Training completed.





In [6]:
# Evaluate on the test split; the call returns (average loss, accuracy in %).
test_loss, test_accuracy = test_model(model, test_loader)
test_loss, test_accuracy

Testing on mps


Testing: 100%|██████████| 157/157 [00:27<00:00,  5.64batch/s]


Test set: Average loss: 0.0180, Accuracy: 6064/10000 (60.64%)





(0.018029744523763658, 60.64)

In [7]:
# Persist the trained model to disk so it can be reused later.
model_filename = 'resnet18_cifar10.pth'
save_model(model=model, filename=model_filename)

In [None]:
# Checkpoints can also be saved during training. Use a fresh, untrained model
# under its own name so the `model` trained above is not overwritten; the
# test/save cells can then be re-run against it.
# NOTE(review): the resume cell below loads
# "./checkpoints/resnet18_CIFAR10_checkpoint_2.pth", presumably written by this
# run — confirm the checkpoint naming/location used by train_model.
ckpt_model = ModelFactory.get_model('resnet18', pretrained=False, num_classes=10)
train_config = TrainConfig(
    model=ckpt_model,
    train_loader=train_loader,
    epochs=2,  # 2 epochs for simplicity
    device=DeviceType.MPS,  # mps for Apple MPS, cuda for NVIDIA CUDA
    save_checkpoint=True,
    checkpoint_interval=1,  # presumably: save a checkpoint every epoch — confirm
)
train_model(train_config)

In [None]:
# Training can also continue from a saved checkpoint. Build a fresh model under
# its own name (so `model` from earlier cells is not overwritten) and ask
# train_model to restore state via load_checkpoint / load_checkpoint_path.
# NOTE(review): epochs=3 presumably means "train up to epoch 3", i.e. one more
# epoch on top of the 2-epoch checkpoint — confirm against TrainConfig.
# The checkpoint path assumes the previous (checkpoint-saving) cell has been run.
resumed_model = ModelFactory.get_model('resnet18', pretrained=False, num_classes=10)
train_config = TrainConfig(
    model=resumed_model,
    train_loader=train_loader,
    epochs=3,
    device=DeviceType.MPS,  # mps for Apple MPS, cuda for NVIDIA CUDA
    save_checkpoint=True,
    checkpoint_interval=1,
    load_checkpoint=True,
    load_checkpoint_path="./checkpoints/resnet18_CIFAR10_checkpoint_2.pth",
)
train_model(train_config)