In [None]:
%run notes_events_network_join.ipynb

# Hyperparameters experiment
We train and evaluate the model while varying the following hyperparameters:

* cohorts: unbalanced vs balanced
* optimizers: Adam vs SGD
* learning rates
* epochs
* number of samples

In [None]:
# Grid search: train and evaluate a fresh model for every combination of the
# hyperparameters below.
# results[c, o, l, e, s] stores [p, r, f, roc_auc] as returned by train_and_eval
# (presumably precision, recall, F1, ROC AUC) for
# cohorts[c], optimizers[o], learning_rates[l], epochs[e], samples[s].
import itertools
import random

cohorts = ['unbalanced', 'balanced']
optimizers = [torch.optim.Adam, torch.optim.SGD]
learning_rates = [0.0001, 0.001, 0.01]
epochs = [5, 10, 15]
samples = [100, 1000, 0]  # NOTE(review): 0 presumably means "all samples" - confirm in the dataloader helpers
# Re-applied before every run so all configurations start from the same RNG
# state (same initial weights / sampling), making their metrics comparable.
SEED = 42

results = np.empty(shape=(len(cohorts), len(optimizers), len(learning_rates), len(epochs), len(samples)), dtype='object')
criterion = nn.BCELoss()  # not used in this cell; kept in case train_and_eval reads it as a global

# itertools.product varies the last axis fastest, i.e. the same order as the
# original five nested loops.
grid = itertools.product(
    enumerate(cohorts),
    enumerate(optimizers),
    enumerate(learning_rates),
    enumerate(epochs),
    enumerate(samples),
)
for (c, cohort), (o, optim), (l, learning_rate), (e, n_epochs), (s, n_samples) in grid:
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # The helper's default optimizer is discarded; we build the one under test.
    model = create_model_and_optimizer()[0]
    optimizer = optim(model.parameters(), lr=learning_rate)

    print ("Training for:\n")
    print ("Cohort        \t= ", cohort)
    print ("Optimizer     \t= ", optim.__name__)
    print ("Learning rate \t= ", learning_rate)
    print ("No. of epochs \t= ", n_epochs)
    print ("No. of samples\t= ", n_samples)
    print ('---------------')
    # Optimizer is encoded by its index in `optimizers` (unchanged naming scheme).
    model_filename = f"{cohort}-{o}-{learning_rate}-{n_epochs}-{n_samples}.pt"

    if cohort == 'unbalanced':
        train_loader, val_loader = get_unbalanced_dataloaders(n_samples)
    else:
        train_loader, val_loader = get_balanced_dataloaders(n_samples)
    p, r, f, roc_auc = train_and_eval(model, train_loader, val_loader, n_epochs, model_filename)
    results[c, o, l, e, s] = [p, r, f, roc_auc]
    print ('---------------\n\n')


## Saving the results

In [27]:
# Persist the grid definition together with the results in one timestamped
# pickle so the experiment can be reloaded later without retraining.
# Note: `optimizers` holds torch.optim classes, which pickle stores by
# reference, so unpickling requires torch to be importable. Only unpickle
# files you trust.
now = datetime.datetime.now()
r = {
    'cohorts': cohorts,
    'optimizers': optimizers,
    'learning_rates': learning_rates,
    'epochs': epochs,
    'samples': samples,
    'results': results,
}
results_path = "hyperparam-exp-" + now.strftime("%Y%m%d-%H%M") + ".p"
# `with` guarantees the file is flushed and closed even if pickling fails.
with open(results_path, "wb") as results_file:
    pickle.dump(r, results_file)

## Plotting the results