### Imports

In [None]:
import syft as sy
import torch
from tools import models, datasets
import numpy as np
import pandas as pd
import opacus
from tools.utils import train, test

### Loading the data

In [None]:
# On the very first run, download=True has to be passed to load_MedNIST so that the data is downloaded automatically!
train_ds, test_ds, val_ds = datasets.Loader.load_MedNIST(sample_size=0.04, test_size=0.0872, val_size=0.125)

# Unpack every split into its (data, labels) pair as returned by as_tensor()
train_data, train_labels = train_ds.as_tensor()
test_data, test_labels = test_ds.as_tensor()
val_data, val_labels = val_ds.as_tensor()

# Report the number of samples per split, in the order train / test / validation
for split_data in (train_data, test_data, val_data):
    print(len(split_data))

### Parameters and constants

In [None]:
# Constants for tracking purposes (written into the results .csv by run())
MODEL = 'Deep2DNet'
DATASET = 'MedNIST'
TRACKING = True  # Whether or not this run should be tracked in the results .csv
DP = True  # Whether or not Differential Privacy should be applied

# Training schedule
BATCH_SIZE = 100
EPOCHS = 30

# Set to be less than the inverse of the training dataset size
# (from https://opacus.ai/tutorials/building_image_classifier)
DELTA = 1e-4

# Parameters for training
length = len(train_data)
# NOTE: The current implementation only trains on whole batches, so the last
# `length % BATCH_SIZE` samples are not used for training.
SAMPLE_SIZE = (length // BATCH_SIZE) * BATCH_SIZE
# Fraction of the usable training samples contained in one batch.
# NOTE: with fewer than BATCH_SIZE training samples SAMPLE_SIZE is 0 and this
# line raises ZeroDivisionError.
SAMPLE_RATE = BATCH_SIZE / SAMPLE_SIZE

### Training and testing

In [None]:
def run(learning_rate, noise_multiplier, max_grad_norm):
    """Train, validate and (optionally) log one Deep2DNet run.

    The function builds a fresh model, moves it to the GPU when one is
    available, trains it with ``tools.utils.train``, evaluates it on the
    validation split with ``tools.utils.test`` and, when ``TRACKING`` is set,
    adds one row describing the run to ``./Results/1DS.csv``.

    Args:
        learning_rate: Learning rate handed to the Adam optimizer.
        noise_multiplier: Noise multiplier handed to the Opacus privacy
            engine. Only used for training when ``DP`` is true, but always
            written to the results row.
        max_grad_norm: Gradient clipping norm handed to the Opacus privacy
            engine. Only used for training when ``DP`` is true, but always
            written to the results row.

    Returns:
        None. Results are printed and, if ``TRACKING`` is true, persisted.
    """
    # Getting model
    model = models.Deep2DNet(torch)

    # Setting device to train on
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
        model.cuda(device)
    else:
        device = torch.device('cpu')
        model.cpu()

    # Optimizer and Loss Function
    optim = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
    loss_function = torch.nn.CrossEntropyLoss()

    # Setting up Differential Privacy Engine
    if DP:
        # The engine is built around model.real_module rather than model itself
        # (presumably the plain torch module behind the wrapper -- confirm in tools.models).
        privacy_engine = opacus.privacy_engine.PrivacyEngine(
            model.real_module, sample_rate=SAMPLE_RATE,
            noise_multiplier=noise_multiplier, max_grad_norm=max_grad_norm
        )
        privacy_engine.attach(optim)
    else:
        privacy_engine = None

    # Training
    # NOTE(review): [1, 64, 64] looks like the per-sample input shape expected by train() -- confirm.
    losses, test_accs, test_losses, epsilons, alphas, epoch_times = train(
        BATCH_SIZE, EPOCHS, DELTA,
        model, torch,
        optim, loss_function,
        train_data, train_labels,
        test_data, test_labels,
        [1, 64, 64], device, privacy_engine
    )

    # Validation
    val_acc, val_loss = test(model, loss_function, torch, val_data, val_labels, device)

    print(f'Validation Accuracy: {val_acc} ---- Validation Loss: {val_loss}')

    # Tracking all interesting variables and results in .csv file
    if TRACKING:
        results_path = './Results/1DS.csv'
        d = {
            'model': MODEL,
            'dataset': DATASET,
            'batch_size': BATCH_SIZE,
            'epochs': EPOCHS,
            'learning_rate': learning_rate,
            'train_sample_size': SAMPLE_SIZE,
            'test_sample_size': len(test_data),
            'val_sample_size': len(val_data),
            'delta': DELTA,
            'noise_multiplier': noise_multiplier,
            'max_grad_norm': max_grad_norm,
            'dp_used': DP,
            'epsilons': epsilons,
            'alphas': alphas,
            'train_losses': losses,
            'test_accs': test_accs,
            'test_losses': test_losses,
            'val_acc': val_acc,
            'val_loss': val_loss,
            'epoch_times': epoch_times
        }
        # A one-element list of dicts yields exactly one row; list-valued
        # entries (e.g. the per-epoch losses) stay inside a single cell.
        new_row = pd.DataFrame([d])
        try:
            previous = pd.read_csv(results_path)
        except (FileNotFoundError, pd.errors.EmptyDataError):
            # No usable log yet: start one instead of crashing after a
            # finished training run and losing its result.
            df = new_row
        else:
            # DataFrame.append was removed in pandas 2.0; pd.concat is the
            # replacement and also works on older pandas versions.
            df = pd.concat([previous, new_row], ignore_index=True)
        df.to_csv(results_path, index=False)

### Gridsearch

In [None]:
import itertools

def gridsearch(lrs, noises, norms):
    """Call run() once for every hyperparameter setting of the search.

    With DP enabled, every combination of learning rate, noise multiplier and
    max grad norm is tried (learning rate varying slowest). With DP disabled,
    only the learning rates are iterated and run() receives 0 for both the
    noise multiplier and the max grad norm.

    Args:
        lrs: Iterable of learning rates.
        noises: Iterable of noise multipliers (only used when DP is true).
        norms: Iterable of max grad norms (only used when DP is true).
    """
    # The product is built up front regardless of DP, so all three iterables
    # are read here in both modes.
    grid = itertools.product(lrs, noises, norms)

    if DP:
        settings = grid
    else:
        settings = ((learning_rate, 0, 0) for learning_rate in lrs)

    for num, (learning_rate, noise_multiplier, max_grad_norm) in enumerate(settings, start=1):
        print(f'################################# RUN No. {num} #################################')
        run(learning_rate, noise_multiplier, max_grad_norm)
            
    
# Search space matching the parameters of gridsearch(lrs, noises, norms).
# When DP is False, gridsearch only iterates over `lrs`.
lrs = [0.0025]  # learning rates
noises = [  # noise multipliers
    0.1, 0.5, 1.0, 1.5, 2.0,
    2.5, 3.0, 3.5, 4.0,
]
norms = [5.0]  # max grad norms