In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell executes, so edits to project code (e.g. the team36 package) are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from team36.mnist.vgg import VGG
from team36.attacks.fast_gradient_attack_data_set import FastSignGradientAttackDataSet

DIR = '.'
DATA_DIR = f'{DIR}/data'
CHECKPOINT_PATH = f'{DIR}/checkpoints/mnist-vgg.pth'
FGSM_EPSILON = 0.25          # perturbation size passed to the FGSM attack dataset
VALIDATION_FRACTION = 0.1    # share of the combined set held out for validation

# MNIST training split; ToTensor converts each PIL image to a float tensor.
training_set = torchvision.datasets.MNIST(root=DATA_DIR, train=True, download=True,
                                          transform=transforms.ToTensor())

# Previously trained VGG whose loss gradients are used to craft the attack samples.
prev_model = VGG()
state_dict = torch.load(CHECKPOINT_PATH, map_location=torch.device('cpu'))
prev_model.load_state_dict(state_dict)
# Put the model in inference mode so that any dropout / batch-norm layers do not
# make the gradients behind the adversarial examples depend on training-mode noise.
prev_model.eval()
prev_criterion = nn.CrossEntropyLoss()

attack_training_set = FastSignGradientAttackDataSet(training_set, prev_model, prev_criterion,
                                                    epsilon=FGSM_EPSILON)

# Clean samples first, then attack samples.
combined_training_set = torch.utils.data.ConcatDataset([training_set, attack_training_set])


def _dataset_targets(dataset):
    """Return the integer class label of every sample in *dataset*, in index order.

    Uses the dataset's ``targets`` attribute when present (torchvision MNIST has one).
    Otherwise it iterates the dataset and reads the label from each
    ``(sample, label)`` pair. That fallback can be slow if samples are computed
    on access.
    """
    targets = getattr(dataset, 'targets', None)
    if targets is None:
        return [int(label) for _, label in dataset]
    return [int(label) for label in targets]


# ConcatDataset exposes no ``targets`` attribute, so the stratification labels are
# assembled from its constituent datasets in concatenation order. This keeps
# combined_targets[i] aligned with combined_training_set[i].
combined_targets = [
    label
    for dataset in combined_training_set.datasets
    for label in _dataset_targets(dataset)
]

# NOTE(review): if attack_training_set[i] is the perturbed copy of training_set[i],
# a random split can place an image in training and its adversarial copy in
# validation (or vice versa), leaking information between the splits. Consider
# splitting the original indices first. TODO confirm the attack dataset's indexing.
training_indices, validation_indices = train_test_split(
    list(range(len(combined_training_set))),
    stratify=combined_targets,
    test_size=VALIDATION_FRACTION,
)
training_split = torch.utils.data.Subset(combined_training_set, training_indices)
validation_split = torch.utils.data.Subset(combined_training_set, validation_indices)

print(f"{len(training_split)} in training set")
print(f"{len(validation_split)} in validation set")