In [None]:
import sys

import torch
import numpy as np

from attacks import AttackContainer
from basemodels import ModelContainer, MnistCnnCW
from datasets import (DATASET_LIST, DataContainer, get_image_list,
                      get_quantitative_list)
from defences import DefenceContainer

%load_ext autoreload
%autoreload 2

In [None]:
# Record the interpreter version and the module search path (one entry per
# line) so problems importing the local packages above are easy to diagnose.
print(sys.version)
print('\n'.join(sys.path))

In [None]:
# Show the datasets the project provides for each kind. These are presumably
# the valid NAME values for each TYPE in the config cell below.
# (Fixes the "Avaliable" typo in the headers.)
for dataset_type, lister in (('image', get_image_list),
                             ('quantitative', get_quantitative_list)):
    print(f'Available {dataset_type} datasets:')
    # Call the lister lazily so the print/call order matches a straight script.
    print(lister())
    print()

In [None]:
# 1. Choose a dataset.
DATA_ROOT = 'data'
BATCH_SIZE = 64
TYPE = 'image'  # image or quantitative
# image: 'MNIST', 'CIFAR10', 'SVHN'
# quantitative: 'BankNote', 'BreastCancerWisconsin', 'WheatSeed', 'HTRU2'
NAME = 'MNIST'

# Seed the torch and NumPy global RNGs so Restart & Run All is reproducible.
# Later cells draw from them: torch.rand / np.random smoke tests and,
# presumably, model weight init and batch shuffling.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Fail fast with the valid choices instead of a bare KeyError from the lookup.
if TYPE not in DATASET_LIST:
    raise ValueError(
        f'Unknown TYPE {TYPE!r}; expected one of {list(DATASET_LIST)}')
if NAME not in DATASET_LIST[TYPE]:
    raise ValueError(
        f'Unknown {TYPE} dataset {NAME!r}; '
        f'expected one of {list(DATASET_LIST[TYPE])}')

print(f'Starting {NAME} data container...')
# Holds the selected dataset entry even when TYPE is 'quantitative';
# the historical name is kept in case later cells refer to it.
IMAGE_DATASET = DATASET_LIST[TYPE][NAME]
dc = DataContainer(IMAGE_DATASET, DATA_ROOT)
dc(BATCH_SIZE)

In [None]:
# 2. choose a model
model = MnistCnnCW()

In [None]:
# 3. train/load the model
# train, save, load
mc = ModelContainer(model, dc)

In [None]:
mc.fit(epochs=20)

In [None]:
# Smoke-test prediction on random torch inputs: a batch of 2 images shaped
# (2, 1, 28, 28) via `pred`, and one unbatched (1, 28, 28) image via `pred_one`
# (presumably (N, C, H, W) / (C, H, W) — confirm against the model).
# Only a cell's last expression is displayed, so collect both outputs and show
# them together; previously the batch prediction was silently discarded.
x = torch.rand(2, 1, 28, 28)
batch_output = mc.pred(x, require_output=True)
x = torch.rand(1, 28, 28)
single_output = mc.pred_one(x, require_output=True)
batch_output, single_output

In [None]:
# Same check with NumPy input: 5 random samples shaped (5, 1, 28, 28),
# cast to float32 (presumably to match the model's parameter dtype).
numpy_batch = np.random.rand(5, 1, 28, 28)
x = numpy_batch.astype(np.float32)
mc.pred(x, require_output=False)

In [None]:
mc.save('MNIST_CNN1')

In [None]:
mc.load('MNIST_CNN1.pt')

In [None]:
# Candidate hyperparameters. The names suggest a step learning-rate scheduler;
# TODO confirm where these are consumed. The mapping is the cell's last
# expression so both keys and values are displayed (unpacking it into print
# would show only the keys).
args = {'step_size': 1, 'gamma': 0.9}
args