In [1]:
import torch 
import torch.nn as nn 
import numpy as np
from collections import OrderedDict

from utils.losses import protop_loss
from utils.sample_parameters import SingleGenerator, HiddenAndKernelGenerator, MonotonicGenerator
from utils.sample_parameters import ParamGenerators
from models.ProtICU import ProtICU
from utils.train_n_test import TrainTest

In [2]:
train, val, test = torch.load('data/in-hospital-mortality/tensors.pkl')

In [3]:
# Calculating class weights
perc_mort = np.concatenate((train[1], val[1], test[1])).mean()
class_weights = torch.Tensor([perc_mort, 1-perc_mort])

In [14]:
generators_dict = {
    # HYPERMARAMETER RANGES
    'BATCH_SIZE': SingleGenerator([128, 256, 512]),
    'EPOCHS': SingleGenerator([20, 40, 50]),
    'OPTIMIZER': SingleGenerator([torch.optim.Adam]),
    'LEARNING_RATE': SingleGenerator([1e-5, 5e-5, 1e-4, 5e-4]),
    'LOSS': SingleGenerator([protop_loss(class_weights, .5 , .5), protop_loss(class_weights, 1, 1),
                             protop_loss(class_weights, .1 , .1), protop_loss(class_weights, 1 , .5),
                             protop_loss(class_weights, .5 , .1), protop_loss(class_weights, .3 , .1)]),
    'EARLY_STOPPING': SingleGenerator([True]),
    'PATIENCE': SingleGenerator(list(range(2,4))),
    'MIN_DELTA': SingleGenerator([5e-5, 1e-4, 5e-4, 1e-3, 5e-3]), 

    # NETWORK PARAMETER RANGES
    'HIDDEN_AND_KERNEL_SIZES': HiddenAndKernelGenerator(range(1,4), [64, 128, 256, 512], [3, 5, 7, 9], 
                               ascending=(True, False)),
    'MAXPOOL': SingleGenerator([2]),
    'OBO_SIZES': MonotonicGenerator(range(1,3), [64, 128, 256, 512], ascending=False),
    'PROTOTYPE_NUM': SingleGenerator([10,20]),
    'DROPOUT': SingleGenerator(np.arange(8)/10)
} # each of these params are sampled INDEPENDENTLY of one another

In [16]:
N = 30
gen = ParamGenerators(generators_dict)
param_samples = gen.sample(N)

In [7]:
stats = []
for i in np.arange(N):
    print(param_samples[i])
    run = TrainTest(ProtICU, (train, val, test), param_samples[i])
    run.train()
    stats.append(run.test())

  0%|          | 0/115 [00:00<?, ?it/s]

{'BATCH_SIZE': 128, 'EPOCHS': 40, 'OPTIMIZER': <class 'torch.optim.adam.Adam'>, 'LEARNING_RATE': 5e-05, 'LOSS': protop_loss_0.5_0.5, 'EARLY_STOPPING': True, 'PATIENCE': 2, 'MIN_DELTA': 0.0001, 'MAXPOOL': 2, 'OBO_SIZES': array([256, 256]), 'PROTOTYPE_NUM': 20, 'DROPOUT': 0.5, 'HIDDEN_SIZES': array([ 64, 128]), 'KERNEL_SIZES': array([7, 3])}


100%|██████████| 115/115 [00:16<00:00,  6.79it/s]
100%|██████████| 16/16 [00:00<00:00, 20.23it/s]
  1%|          | 1/115 [00:00<00:14,  8.00it/s]

Epoch: 0, train_loss: 0.29564544558525085, valid_loss: -0.20622141659259796


100%|██████████| 115/115 [00:20<00:00,  5.58it/s]
100%|██████████| 16/16 [00:01<00:00, 15.82it/s]
  0%|          | 0/115 [00:00<?, ?it/s]

Epoch: 1, train_loss: -0.5244238376617432, valid_loss: -0.6580662131309509


100%|██████████| 115/115 [00:24<00:00,  4.61it/s]
100%|██████████| 16/16 [00:01<00:00, 15.89it/s]
  1%|          | 1/115 [00:00<00:18,  6.21it/s]

Epoch: 2, train_loss: -0.7162585854530334, valid_loss: -0.697081446647644


100%|██████████| 115/115 [00:16<00:00,  6.96it/s]
100%|██████████| 16/16 [00:00<00:00, 18.55it/s]
  1%|          | 1/115 [00:00<00:14,  7.90it/s]

Epoch: 3, train_loss: -0.756279706954956, valid_loss: -0.7457778453826904


100%|██████████| 115/115 [00:16<00:00,  6.87it/s]
100%|██████████| 16/16 [00:00<00:00, 18.26it/s]
  1%|          | 1/115 [00:00<00:17,  6.40it/s]

Epoch: 4, train_loss: -0.7732405066490173, valid_loss: -0.7476105690002441


100%|██████████| 115/115 [00:15<00:00,  7.27it/s]
100%|██████████| 16/16 [00:00<00:00, 20.73it/s]
  1%|          | 1/115 [00:00<00:14,  7.76it/s]

Epoch: 5, train_loss: -0.7823200821876526, valid_loss: -0.7563960552215576


100%|██████████| 115/115 [00:14<00:00,  7.82it/s]
100%|██████████| 16/16 [00:00<00:00, 20.45it/s]
  1%|          | 1/115 [00:00<00:13,  8.16it/s]

Epoch: 6, train_loss: -0.7867077589035034, valid_loss: -0.767099142074585


100%|██████████| 115/115 [00:15<00:00,  7.62it/s]
100%|██████████| 16/16 [00:00<00:00, 19.56it/s]
  1%|          | 1/115 [00:00<00:13,  8.35it/s]

val loss increased, patience count:  1
Epoch: 7, train_loss: -0.7923399806022644, valid_loss: -0.7627156972885132


100%|██████████| 115/115 [00:16<00:00,  6.78it/s]
100%|██████████| 16/16 [00:00<00:00, 17.81it/s]
  1%|          | 1/115 [00:00<00:14,  7.69it/s]

Epoch: 8, train_loss: -0.797744631767273, valid_loss: -0.7722408771514893


100%|██████████| 115/115 [00:17<00:00,  6.67it/s]
100%|██████████| 16/16 [00:00<00:00, 17.87it/s]
  1%|          | 1/115 [00:00<00:19,  5.99it/s]

Epoch: 9, train_loss: -0.7994142174720764, valid_loss: -0.7746787071228027


100%|██████████| 115/115 [00:15<00:00,  7.43it/s]
100%|██████████| 16/16 [00:00<00:00, 20.65it/s]
  1%|          | 1/115 [00:00<00:16,  6.78it/s]

Epoch: 10, train_loss: -0.8060867786407471, valid_loss: -0.7773119211196899


100%|██████████| 115/115 [00:16<00:00,  7.09it/s]
100%|██████████| 16/16 [00:00<00:00, 21.65it/s]
  1%|          | 1/115 [00:00<00:12,  9.04it/s]

Epoch: 11, train_loss: -0.8092324137687683, valid_loss: -0.7888762950897217


100%|██████████| 115/115 [00:14<00:00,  7.99it/s]
100%|██████████| 16/16 [00:00<00:00, 21.79it/s]
  1%|          | 1/115 [00:00<00:13,  8.60it/s]

val loss increased, patience count:  1
Epoch: 12, train_loss: -0.809702455997467, valid_loss: -0.7843737006187439


100%|██████████| 115/115 [00:14<00:00,  8.08it/s]
100%|██████████| 16/16 [00:00<00:00, 20.90it/s]
  1%|          | 1/115 [00:00<00:13,  8.43it/s]

val loss increased, patience count:  2
Epoch: 13, train_loss: -0.8138859868049622, valid_loss: -0.7818833589553833


100%|██████████| 115/115 [00:14<00:00,  8.07it/s]
100%|██████████| 16/16 [00:00<00:00, 18.43it/s]


Early stopped at Epoch:  14


In [8]:
stats

[{'epoch_stopped': 14,
  'auroc': 0.7195216900196326,
  'auprc': 0.3724263590994855,
  'acc': 0.8538315988647115,
  'f1': 0.0}]

In [9]:
for i in np.arange(N):
    for key,val in stats[i].items():
        param_samples[i][key] = val

In [None]:
torch.save(param_samples, 'results/Proto_experiment_nopush_N'+str(N)+'.pkl')

In [57]:
model = ProtICU(test[0].shape, 2, [64, 128], [5,3], 2, [256, 128], [0,0,0,0,0,1,1,1,1,1], .2)
out, min_dis = model(test[0][:14])