# Benign Training

In [None]:

from advsecurenet.models.model_factory import ModelFactory
from advsecurenet.datasets.dataset_factory import DatasetFactory
from advsecurenet.dataloader.data_loader_factory import DataLoaderFactory
from advsecurenet.shared.types.configs.preprocess_config import \
    PreprocessConfig, PreprocessStep
from advsecurenet.shared.types.configs.device_config import DeviceConfig
from advsecurenet.shared.types.configs import TrainConfig
from advsecurenet.trainer.trainer import Trainer
from advsecurenet.shared.types.configs.model_config import CreateModelConfig


In [None]:
# Define the model: a ResNet-18 with a 10-class output (one per CIFAR-10
# class); pretrained=True requests pretrained weights.
# The arguments are wrapped in a CreateModelConfig, the same config type the
# external-model example passes to ModelFactory.create_model, so both cells
# use one calling convention.
model = ModelFactory.create_model(
    CreateModelConfig(
        model_name='resnet18',
        num_classes=10,
        pretrained=True,
    )
)

In [None]:
# Let's define the preprocessing pipeline; the steps are applied to every
# image in the order listed. Step names/params are resolved by advsecurenet
# (presumably onto the torchvision transforms of the same name — TODO confirm).
# NOTE(review): the Normalize mean/std are the standard ImageNet channel
# statistics (as used by torchvision's pretrained models), presumably chosen to
# match the pretrained weights rather than CIFAR-10's own statistics.
preprocess_steps = [
    PreprocessStep(name='Resize', params={'size': 32}),
    PreprocessStep(name='CenterCrop', params={'size': 32}),
    PreprocessStep(name='ToTensor'),
    PreprocessStep(name='ToDtype', params={'dtype': 'torch.float32', 'scale': True}),
    PreprocessStep(name='Normalize', params={'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}),
]
preprocess_config = PreprocessConfig(steps=preprocess_steps)

# Define the dataset. With return_loaded=False the factory hands back the
# dataset object (presumably unloaded), and each split is then loaded
# explicitly with the preprocessing pipeline above.
dataset = DatasetFactory.create_dataset(
    dataset_type="cifar10",
    preprocess_config=preprocess_config,
    return_loaded=False,
)
train_data = dataset.load_dataset(train=True)
# NOTE(review): test_data is not used by the cells shown here.
test_data = dataset.load_dataset(train=False)

In [None]:
# Define the data loader over the CIFAR-10 training split, in batches of 32.
# Shuffling is requested explicitly: stochastic gradient training should see
# the samples in a fresh random order each epoch (assuming the factory forwards
# `shuffle` to torch.utils.data.DataLoader — confirm against the installed
# advsecurenet version).
dataloader = DataLoaderFactory.create_dataloader(
    dataset=train_data,
    batch_size=32,
    shuffle=True,
)

In [None]:
# Define the training config: train `model` on the training data loader for
# two epochs. `processor` selects the device to train on; set it to match your
# hardware, e.g. "cpu", "gpu", "mps".
# NOTE(review): "mps" is only available on Apple-silicon Macs.
config = TrainConfig(
    model=model,
    train_loader=dataloader,
    epochs=2,
    processor="mps",
)

In [None]:
# Create the trainer from the training config defined above.
trainer = Trainer(config)

# Train the model (presumably for config's `epochs`, on its `processor` device).
trainer.train()

## Using an External Model

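You can also train a custom model that is defined outside the `advsecurenet` library. The following example loads the `Net` class from `./external_model.py`, passing `is_external=True` so the factory loads it from that file, and then trains it on the same CIFAR-10 data loader.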

In [None]:
# Describe the external model: with is_external=True the class named by
# `model_name` ("Net") is taken from the Python file at `model_arch_path`
# instead of from advsecurenet's built-in models.
# NOTE(review): the relative path is presumably resolved against the
# notebook's current working directory — confirm when running elsewhere.
model_config = CreateModelConfig(
    is_external=True,
    model_arch_path="./external_model.py",
    model_name="Net",
    num_classes=10,
    pretrained=False,
)

# Instantiate it through the same factory used for the built-in model.
external_model = ModelFactory.create_model(model_config)

In [None]:
# Check that the external model was loaded by printing it
# (for a torch nn.Module this shows its layer structure).
print(external_model)

In [None]:
# Update the training config to train the external model for a single epoch,
# reusing the same training data loader. Set `processor` to match your
# hardware, e.g. "cpu", "gpu", "mps".
# NOTE(review): "mps" is only available on Apple-silicon Macs.
config = TrainConfig(
    model=external_model,
    train_loader=dataloader,
    epochs=1,
    processor="mps",
)

# Build a fresh trainer for the new config, then train the external model.
trainer = Trainer(config)
trainer.train()