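This notebook shows how to use `AdvSecureNet` to train a ResNet-18 model on CIFAR-10, evaluate it on the test split, save it, and save/resume training from checkpoints.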

In [None]:
import torch

from advsecurenet.models.model_factory import ModelFactory
from advsecurenet.datasets import DatasetFactory
from advsecurenet.dataloader import DataLoaderFactory
from advsecurenet.shared.types import DatasetType
from advsecurenet.utils.trainer import Trainer
from advsecurenet.shared.types.configs.train_config import TrainConfig
from advsecurenet.utils.model_utils import save_model
from advsecurenet.utils.tester import Tester

In [None]:
# ResNet-18 trained from scratch (no pretrained weights), with a 10-way
# output so it matches the 10 CIFAR-10 classes.
model = ModelFactory.create_model(
    "resnet18",
    pretrained=False,
    num_classes=10,
)

In [None]:
# CIFAR-10: one dataset object, two splits (train=True / train=False).
dataset = DatasetFactory.create_dataset(DatasetType.CIFAR10)
train_data, test_data = (
    dataset.load_dataset(train=True),
    dataset.load_dataset(train=False),
)

In [None]:
# Wrap both splits in data loaders. Only the training loader is shuffled;
# the test loader keeps a fixed order.
BATCH_SIZE = 64
train_loader = DataLoaderFactory.create_dataloader(
    train_data, batch_size=BATCH_SIZE, shuffle=True
)
test_loader = DataLoaderFactory.create_dataloader(
    test_data, batch_size=BATCH_SIZE, shuffle=False
)

In [None]:
# Basic training configuration; a single epoch keeps the demo fast.
# The device is chosen at run time rather than pinned to a specific GPU index
# (such as "cuda:2"), which fails on machines with fewer GPUs. Other device
# strings (e.g. "mps") can be set here manually.
device = "cuda" if torch.cuda.is_available() else "cpu"
train_config = TrainConfig(
    model=model,
    train_loader=train_loader,
    epochs=1,  # 1 epoch for simplicity
    device=device,
)


In [None]:
# Run the training loop described by `train_config`.
trainer = Trainer(train_config)
trainer.train()

In [None]:
# Evaluate the trained model on the held-out test split.
# The device is chosen at run time rather than pinned to a specific GPU index,
# which fails on machines with fewer GPUs.
device = "cuda" if torch.cuda.is_available() else "cpu"
tester = Tester(model, test_loader, device=device)
tester.test()

In [None]:
# The model can also be saved to disk under a relative filename.
MODEL_FILENAME = "resnet18_cifar10.pth"
save_model(model=model, filename=MODEL_FILENAME)

In [None]:
# It's also possible to save checkpoints during training.
# Start from a fresh, untrained model so this run is independent of the one above.
# The device is chosen at run time rather than pinned to a GPU index such as
# "cuda:7", which fails on machines with fewer GPUs.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ModelFactory.create_model("resnet18", pretrained=False, num_classes=10)
train_config = TrainConfig(
    model=model,
    train_loader=train_loader,
    epochs=2,  # 2 epochs; the resume example below loads the epoch-2 checkpoint
    device=device,
    save_checkpoint=True,
    checkpoint_interval=1,  # presumably one checkpoint per epoch
)

In [None]:
# Train again, this time with checkpoint saving enabled in `train_config`.
trainer = Trainer(train_config)
trainer.train()

In [None]:
# It's also possible to continue training from a checkpoint.
# The checkpoint path is relative to the notebook's working directory instead of
# an absolute, machine-specific path; adjust it if your checkpoints are saved
# elsewhere. It presumably refers to the file written by the 2-epoch
# checkpointing run above.
CHECKPOINT_PATH = "checkpoints/training/resnet18_CIFAR10_checkpoint_epoch_2.pth"
# Choose the device at run time rather than pinning a specific GPU index.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = ModelFactory.create_model("resnet18", pretrained=False, num_classes=10)
train_config = TrainConfig(
    model=model,
    train_loader=train_loader,
    epochs=3,  # total epochs; presumably resumes after the epoch-2 checkpoint
    device=device,  # e.g. "mps" for Apple MPS, "cuda" for NVIDIA CUDA
    save_checkpoint=True,
    checkpoint_interval=1,
    load_checkpoint=True,
    load_checkpoint_path=CHECKPOINT_PATH,
)


In [None]:
# Resume training from the checkpoint configured in `train_config`.
trainer = Trainer(train_config)
trainer.train()