# Dependencies

In [1]:
# --- Dependencies ---
import pickle
import numpy as np

from sklearn.metrics import f1_score, accuracy_score, roc_auc_score
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [2]:
with open('train.pkl', 'rb') as f:
  train = pickle.load(f)
with open('validation.pkl', 'rb') as f:
  validation = pickle.load(f)
with open('test.pkl', 'rb') as f:
  test = pickle.load(f)

In [3]:
X_train = [sequence for sequence, target in train]
X_val = [sequence for sequence, target in validation]
X_test = [sequence for sequence, target in test]

Y_train = [target for sequence, target in train]
Y_val = [target for sequence, target in validation]
Y_test = [target for sequence, target in test]

In [4]:
INPUT_SIZE = len(X_train[0][0])

# Dataset / DataLoader

In [5]:
class ReviewDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        
    def __len__(self):
        return len(self.y)
        
    def __getitem__(self, idx):
        return torch.from_numpy(self.X[idx]).float(), torch.tensor(self.y[idx]).float()

train_dataset = ReviewDataset(X_train, Y_train)
val_dataset = ReviewDataset(X_val, Y_val)
test_dataset = ReviewDataset(X_test, Y_test)

In [6]:
# Pad the sequences
def collate_fn(batch):
    inputs = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    
    # Pad sequences so they all have the same length within one batch
    inputs_padded = pad_sequence(inputs, batch_first=True, padding_value=0)

    return inputs_padded, torch.stack(targets)

In [7]:
batch_size = 64
train_loader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1)
test_loader = DataLoader(test_dataset, batch_size=1)

# Model

In [8]:
class GRUModel(nn.Module):
  def __init__(self, input_size, hidden_size, num_layers, dropout):
    super(GRUModel, self).__init__()

    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.dropout = dropout

    # Define GRU layer
    self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)

    # Define output layer
    self.out = nn.Linear(hidden_size, 2)  # Two output nodes for λ and k
    #self.relu = nn.ReLU()

  def forward(self, x):
    # Forward propagate GRU
    y, hidden = self.gru(x)  # shape = (batch_size, seq_length, hidden_size)

    # Decode the hidden state of the last time step
    out = self.out(y[:, -1, :])
    #out = self.relu(out)

    # Get parameters of the Weibull distribution
    #λ = torch.exp(out[:, 0])
    λ, k = torch.exp(out[:, 0]), torch.exp(out[:, 1])  # exp to ensure positive values
    #λ, k, z1 = torch.exp(out[:, 0]), torch.exp(out[:, 1]), torch.exp(out[:, 2])
    return λ, k

In [9]:
def weibull_survival_function(λ, k, t):
  # small constant to prevent division by zero
  epsilon = 1e-7
  λ = torch.clamp(λ, min=epsilon)
  return torch.exp(-(t / λ) ** k)

# Train

In [10]:
device = torch.device('cuda')

In [11]:
from sklearn.metrics import roc_auc_score, fbeta_score, accuracy_score, precision_score, recall_score

def calculate_metrics(model, loader, device, threshold, metrics, print_results=False):
    model.eval()
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            λ, k = model(inputs)
            review_time = targets[:, 0]
            remembered = targets[:, 1]
            predicted_survival = weibull_survival_function(λ, k, review_time)

            preds = predicted_survival
            all_preds.append(preds.cpu().numpy())
            all_targets.append(remembered.cpu().numpy())

    # Flatten the lists of arrays for computing metrics
    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    
    threshold_preds = (all_preds > threshold).astype(int)

    results = {}
    for metric in metrics:
        if metric == "roc_auc":
            result = roc_auc_score(all_targets, all_preds)
        elif metric == "f2":
            result = fbeta_score(all_targets, threshold_preds, beta=2)
        elif metric == "accuracy":
            result = accuracy_score(all_targets, threshold_preds)
        elif metric == "precision":
            result = precision_score(all_targets, threshold_preds)
        elif metric == "recall":
            result = recall_score(all_targets, threshold_preds)
        else:
            raise ValueError(f"Unsupported metric: {metric}")
        
        results[metric] = result
    if print_results:
        print(results)
    return results


In [12]:
def log_scores(epoch_scores, score_dict=None):
    """
    This function takes the epoch scores and a dictionary of previous scores (or None),
    and returns an updated dictionary of scores.
    """

    if score_dict is None:
        # Initialize score_dict with empty lists if it's None
        score_dict = {key: [] for key in epoch_scores.keys()}

    for key in epoch_scores.keys():
        if key in score_dict:
            score_dict[key].append(epoch_scores[key])
        else:
            score_dict[key] = [epoch_scores[key]]

    return score_dict

In [13]:
def train(*, epochs = 5, hidden_size = 128, num_layers = 2, lr = 0.001, dropout = 0.2, weight_decay = 0):
  model = GRUModel(INPUT_SIZE, hidden_size, num_layers, dropout)  # Adjust the parameters as needed
  model = model.to(device)
  
  optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)  # Adjust the learning rate as needed
  criterion = nn.BCELoss()

  scores = {}
  for epoch in range(epochs):
      model.train()

      train_loss = 0.0

      progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
      for inputs, targets in progress:
          inputs, targets = inputs.to(device), targets.to(device)
          
          λ, k = model(inputs)
          review_time = targets[:, 0]
          remembered = targets[:, 1]
          
          predicted_survival = weibull_survival_function(λ, k, review_time)

          loss = criterion(predicted_survival, remembered)

          optimizer.zero_grad()

          # clip grads
          torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

          loss.backward() 
          optimizer.step()

          train_loss += loss.item()
          progress.set_postfix({'loss': train_loss / (progress.n + 1)})

      epoch_loss = train_loss / len(train_loader)
      print(epoch_loss)

      if epoch % 2 == 0:
        #epoch_scores = calculate_metrics(model, train_loader, device, metrics=["roc_auc", "recall", "precision"], print_results=True)
        epoch_scores = calculate_metrics(model, val_loader, device, threshold=0.6, metrics=["roc_auc"], print_results=True)
        #epoch_scores = calculate_metrics(model, val_loader, device, threshold=0.7, metrics=["roc_auc", "recall", "precision"], print_results=True)
        # epoch_scores = calculate_metrics(model, val_loader, device, threshold=0.8, metrics=["roc_auc", "recall", "precision"], print_results=True)
        # epoch_scores = calculate_metrics(model, test_loader, device, metrics=["roc_auc", "recall", "precision"], print_results=True)
        scores = log_scores(epoch_scores, scores)

  return scores

In [14]:
def print_max_values(input_dict):
    for outer_key, inner_dict in input_dict.items():
        print(outer_key)
        for inner_key, values_list in inner_dict.items():
            max_value_index = np.argmax(values_list)
            max_value = values_list[max_value_index]
            print(f'{inner_key}: {round(max_value, 5)} (at {max_value_index})')

In [16]:
  # Set up logging
from itertools import product

# Define the grid of hyperparameters
grid = {
    'lr': [0.001],
    'hidden_size': [128],
    'num_layers': [1, 2, 3],
    'dropout': [0.5],
    'weight_decay': [0.1, 0.01, 0.001],
}

all_results = {}

for hyperparams in product(*grid.values()):
    lr, hidden_size, num_layers, dropout, weight_decay = hyperparams
    if num_layers == 1 and dropout:
      continue
    desc = f"lr={lr}, hidden={hidden_size}, nl={num_layers}, dropout={dropout}, decay={weight_decay}"

    results = train(epochs=40, hidden_size=hidden_size, num_layers=num_layers, lr=lr, dropout=dropout)
    all_results[desc] = results

    print_max_values(all_results)

    # Update the best hyperparameters and the best score if the current score is higher
    #print(score)
    #if score > best_score:
     #   best_hyperparams = hyperparams
       # best_score = score

# Print the best hyperparameters and the best score
#print(f"Best Hyperparameters: {best_hyperparams}")
#print(f"Best ROC AUC Score: {best_score}")

Epoch 1/40: 100%|██████████| 416/416 [00:03<00:00, 107.77it/s, loss=0.401]


0.3908064153690178
{'roc_auc': 0.5886040648046369}


Epoch 2/40: 100%|██████████| 416/416 [00:03<00:00, 124.99it/s, loss=0.371]


0.3682910423869124


Epoch 3/40: 100%|██████████| 416/416 [00:03<00:00, 125.15it/s, loss=0.374]


0.36780605762480545
{'roc_auc': 0.6289593682624282}


Epoch 4/40: 100%|██████████| 416/416 [00:03<00:00, 126.74it/s, loss=0.365]


0.36481593251944733


Epoch 5/40: 100%|██████████| 416/416 [00:03<00:00, 124.33it/s, loss=0.368]


0.3631306928940691
{'roc_auc': 0.6687022763819259}


Epoch 6/40: 100%|██████████| 416/416 [00:03<00:00, 112.72it/s, loss=0.369]


0.36129565784134543


Epoch 7/40: 100%|██████████| 416/416 [00:03<00:00, 117.53it/s, loss=0.363]


0.3622849896693459
{'roc_auc': 0.6673943108391529}


Epoch 8/40: 100%|██████████| 416/416 [00:03<00:00, 123.73it/s, loss=0.368]


0.36174925284173626


Epoch 9/40: 100%|██████████| 416/416 [00:03<00:00, 116.11it/s, loss=0.368]


0.3613704265668415
{'roc_auc': 0.6809346897903668}


Epoch 10/40: 100%|██████████| 416/416 [00:03<00:00, 123.11it/s, loss=0.365]


0.36017938868071026


Epoch 11/40: 100%|██████████| 416/416 [00:03<00:00, 124.91it/s, loss=0.363]


0.3603791996406821
{'roc_auc': 0.6903872687004824}


Epoch 12/40: 100%|██████████| 416/416 [00:03<00:00, 109.09it/s, loss=0.363]


0.3592173269806573


Epoch 13/40: 100%|██████████| 416/416 [00:03<00:00, 120.80it/s, loss=0.36]


0.35984383763458866
{'roc_auc': 0.6858648392115675}


Epoch 14/40: 100%|██████████| 416/416 [00:03<00:00, 124.83it/s, loss=0.364]


0.3600102938138522


Epoch 15/40: 100%|██████████| 416/416 [00:03<00:00, 110.61it/s, loss=0.365]


0.3591740599546868
{'roc_auc': 0.6879617428233138}


Epoch 16/40: 100%|██████████| 416/416 [00:03<00:00, 121.88it/s, loss=0.368]


0.3585169714374038


Epoch 17/40: 100%|██████████| 416/416 [00:03<00:00, 123.06it/s, loss=0.365]


0.35846707540062756
{'roc_auc': 0.6934500608318139}


Epoch 18/40: 100%|██████████| 416/416 [00:03<00:00, 113.58it/s, loss=0.361]


0.35967556963889646


Epoch 19/40: 100%|██████████| 416/416 [00:03<00:00, 121.63it/s, loss=0.369]


0.35880620354929793
{'roc_auc': 0.6986427818894068}


Epoch 20/40: 100%|██████████| 416/416 [00:03<00:00, 116.46it/s, loss=0.359]


0.35840271313029987


Epoch 21/40: 100%|██████████| 416/416 [00:03<00:00, 109.05it/s, loss=0.363]


0.35729072185663074
{'roc_auc': 0.6860764458564239}


Epoch 22/40: 100%|██████████| 416/416 [00:03<00:00, 122.99it/s, loss=0.366]


0.3578529538443455


Epoch 23/40: 100%|██████████| 416/416 [00:03<00:00, 120.58it/s, loss=0.368]


0.35747941934431976
{'roc_auc': 0.6961662910207024}


Epoch 24/40: 100%|██████████| 416/416 [00:03<00:00, 114.51it/s, loss=0.358]


0.35699611350607413


Epoch 25/40: 100%|██████████| 416/416 [00:03<00:00, 121.02it/s, loss=0.368]


0.35696862598594564
{'roc_auc': 0.6975549851100681}


Epoch 26/40: 100%|██████████| 416/416 [00:03<00:00, 125.50it/s, loss=0.361]


0.35653834825811476


Epoch 27/40: 100%|██████████| 416/416 [00:03<00:00, 111.05it/s, loss=0.364]


0.3561308947391808
{'roc_auc': 0.6880608187668592}


Epoch 28/40: 100%|██████████| 416/416 [00:03<00:00, 123.70it/s, loss=0.359]


0.3551270059094979


Epoch 29/40: 100%|██████████| 416/416 [00:03<00:00, 121.72it/s, loss=0.364]


0.3551671480974899
{'roc_auc': 0.6922974365832417}


Epoch 30/40: 100%|██████████| 416/416 [00:03<00:00, 110.17it/s, loss=0.364]


0.3544555199332535


Epoch 31/40: 100%|██████████| 416/416 [00:03<00:00, 122.79it/s, loss=0.361]


0.3542618492904764
{'roc_auc': 0.689850301549662}


Epoch 32/40: 100%|██████████| 416/416 [00:03<00:00, 121.56it/s, loss=0.362]


0.35463494825391817


Epoch 33/40: 100%|██████████| 416/416 [00:03<00:00, 105.77it/s, loss=0.362]


0.35322609955731493
{'roc_auc': 0.6953532974757244}


Epoch 34/40: 100%|██████████| 416/416 [00:03<00:00, 122.81it/s, loss=0.359]


0.3531861180224671


Epoch 35/40: 100%|██████████| 416/416 [00:03<00:00, 122.12it/s, loss=0.362]


0.3520050071752988
{'roc_auc': 0.6889708496557212}


Epoch 36/40: 100%|██████████| 416/416 [00:03<00:00, 113.42it/s, loss=0.354]


0.3517244094624542


Epoch 37/40: 100%|██████████| 416/416 [00:03<00:00, 120.89it/s, loss=0.357]


0.350469437141258
{'roc_auc': 0.6903057247140253}


Epoch 38/40: 100%|██████████| 416/416 [00:03<00:00, 122.46it/s, loss=0.357]


0.34961223881691694


Epoch 39/40: 100%|██████████| 416/416 [00:03<00:00, 106.28it/s, loss=0.353]


0.34950511279301
{'roc_auc': 0.6909425832482559}


Epoch 40/40: 100%|██████████| 416/416 [00:03<00:00, 107.68it/s, loss=0.35]


0.3485754510531059
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.1
roc_auc: 0.69864 (at 9)


Epoch 1/40: 100%|██████████| 416/416 [00:03<00:00, 121.58it/s, loss=0.392]


0.3918728253875787
{'roc_auc': 0.5692096430656626}


Epoch 2/40: 100%|██████████| 416/416 [00:03<00:00, 114.13it/s, loss=0.374]


0.37015045867659724


Epoch 3/40: 100%|██████████| 416/416 [00:03<00:00, 120.73it/s, loss=0.369]


0.3675766293890774
{'roc_auc': 0.6309600499701549}


Epoch 4/40: 100%|██████████| 416/416 [00:03<00:00, 114.06it/s, loss=0.37]


0.36642893067059606


Epoch 5/40: 100%|██████████| 416/416 [00:03<00:00, 110.30it/s, loss=0.365]


0.3638006957391134
{'roc_auc': 0.6641957479703702}


Epoch 6/40: 100%|██████████| 416/416 [00:03<00:00, 120.62it/s, loss=0.363]


0.36175307577762467


Epoch 7/40: 100%|██████████| 416/416 [00:03<00:00, 117.14it/s, loss=0.37]


0.36174070928245783
{'roc_auc': 0.6749754552600764}


Epoch 8/40: 100%|██████████| 416/416 [00:03<00:00, 121.29it/s, loss=0.372]


0.36170236768129355


Epoch 9/40: 100%|██████████| 416/416 [00:03<00:00, 121.63it/s, loss=0.361]


0.3608325906097889
{'roc_auc': 0.684148338296644}


Epoch 10/40: 100%|██████████| 416/416 [00:03<00:00, 114.16it/s, loss=0.368]


0.36116627865256024


Epoch 11/40: 100%|██████████| 416/416 [00:03<00:00, 111.54it/s, loss=0.368]


0.35990321134718567
{'roc_auc': 0.6906212999416145}


Epoch 12/40: 100%|██████████| 416/416 [00:03<00:00, 120.69it/s, loss=0.369]


0.35991487361920566


Epoch 13/40: 100%|██████████| 416/416 [00:03<00:00, 116.54it/s, loss=0.368]


0.35962947490266883
{'roc_auc': 0.68633616345329}


Epoch 14/40: 100%|██████████| 416/416 [00:03<00:00, 117.60it/s, loss=0.369]


0.3591491976896158


Epoch 15/40: 100%|██████████| 416/416 [00:03<00:00, 118.65it/s, loss=0.365]


0.35837494134186554
{'roc_auc': 0.6875674776487933}


Epoch 16/40: 100%|██████████| 416/416 [00:03<00:00, 107.31it/s, loss=0.359]


0.35892867991844046


Epoch 17/40: 100%|██████████| 416/416 [00:03<00:00, 117.19it/s, loss=0.361]


0.3591611950586622
{'roc_auc': 0.6855480408241814}


Epoch 18/40: 100%|██████████| 416/416 [00:03<00:00, 117.05it/s, loss=0.359]


0.3581081974821595


Epoch 19/40: 100%|██████████| 416/416 [00:03<00:00, 114.04it/s, loss=0.367]


0.358430869734058
{'roc_auc': 0.6835102566026166}


Epoch 20/40: 100%|██████████| 416/416 [00:03<00:00, 120.37it/s, loss=0.368]


0.3576365092840905


Epoch 21/40: 100%|██████████| 416/416 [00:03<00:00, 119.07it/s, loss=0.362]


0.35711904203232664
{'roc_auc': 0.687381557359671}


Epoch 22/40: 100%|██████████| 416/416 [00:03<00:00, 104.30it/s, loss=0.366]


0.3561296099796891


Epoch 23/40: 100%|██████████| 416/416 [00:03<00:00, 118.45it/s, loss=0.365]


0.3555582071511218
{'roc_auc': 0.6891486155461979}


Epoch 24/40: 100%|██████████| 416/416 [00:03<00:00, 119.22it/s, loss=0.362]


0.3569485991118619


Epoch 25/40: 100%|██████████| 416/416 [00:03<00:00, 111.58it/s, loss=0.357]


0.3561312746829711
{'roc_auc': 0.6880110769351203}


Epoch 26/40: 100%|██████████| 416/416 [00:03<00:00, 115.54it/s, loss=0.358]


0.3547549583017826


Epoch 27/40: 100%|██████████| 416/416 [00:03<00:00, 116.36it/s, loss=0.356]


0.35420871622717154
{'roc_auc': 0.6846098772599916}


Epoch 28/40: 100%|██████████| 416/416 [00:03<00:00, 112.87it/s, loss=0.359]


0.353168411896779


Epoch 29/40: 100%|██████████| 416/416 [00:03<00:00, 114.31it/s, loss=0.363]


0.3533674943117568
{'roc_auc': 0.6821121849548084}


Epoch 30/40: 100%|██████████| 416/416 [00:03<00:00, 116.15it/s, loss=0.356]


0.35321670487666357


Epoch 31/40: 100%|██████████| 416/416 [00:03<00:00, 108.83it/s, loss=0.356]


0.35187065923729766
{'roc_auc': 0.6875903099650014}


Epoch 32/40: 100%|██████████| 416/416 [00:03<00:00, 118.69it/s, loss=0.356]


0.3504740116902842


Epoch 33/40: 100%|██████████| 416/416 [00:03<00:00, 120.70it/s, loss=0.354]


0.35079775338705915
{'roc_auc': 0.688223091299909}


Epoch 34/40: 100%|██████████| 416/416 [00:03<00:00, 110.88it/s, loss=0.358]


0.34985944976170474


Epoch 35/40: 100%|██████████| 416/416 [00:03<00:00, 117.16it/s, loss=0.35]


0.34846898916965496
{'roc_auc': 0.6870977842868}


Epoch 36/40: 100%|██████████| 416/416 [00:03<00:00, 117.07it/s, loss=0.353]


0.3473340136428865


Epoch 37/40: 100%|██████████| 416/416 [00:03<00:00, 110.26it/s, loss=0.353]


0.34673692379146814
{'roc_auc': 0.680608513844538}


Epoch 38/40: 100%|██████████| 416/416 [00:03<00:00, 119.09it/s, loss=0.349]


0.3449445837893738


Epoch 39/40: 100%|██████████| 416/416 [00:03<00:00, 118.05it/s, loss=0.343]


0.34293622971297455
{'roc_auc': 0.6781034825805736}


Epoch 40/40: 100%|██████████| 416/416 [00:03<00:00, 106.01it/s, loss=0.348]


0.3426130780090506
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.1
roc_auc: 0.69864 (at 9)
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.01
roc_auc: 0.69062 (at 5)


Epoch 1/40: 100%|██████████| 416/416 [00:03<00:00, 116.70it/s, loss=0.399]


0.39468963938550305
{'roc_auc': 0.5724367463297051}


Epoch 2/40: 100%|██████████| 416/416 [00:03<00:00, 120.52it/s, loss=0.379]


0.3703582728496538


Epoch 3/40: 100%|██████████| 416/416 [00:03<00:00, 110.89it/s, loss=0.369]


0.36523613039977276
{'roc_auc': 0.6419974362570657}


Epoch 4/40: 100%|██████████| 416/416 [00:03<00:00, 121.36it/s, loss=0.368]


0.36510205978097826


Epoch 5/40: 100%|██████████| 416/416 [00:03<00:00, 114.95it/s, loss=0.368]


0.3639529809140815
{'roc_auc': 0.6668887381231183}


Epoch 6/40: 100%|██████████| 416/416 [00:03<00:00, 110.65it/s, loss=0.37]


0.36233696757027734


Epoch 7/40: 100%|██████████| 416/416 [00:03<00:00, 114.09it/s, loss=0.363]


0.3620369224092708
{'roc_auc': 0.682689516378925}


Epoch 8/40: 100%|██████████| 416/416 [00:03<00:00, 118.48it/s, loss=0.363]


0.36207010296101755


Epoch 9/40: 100%|██████████| 416/416 [00:03<00:00, 114.41it/s, loss=0.369]


0.36052398996141094
{'roc_auc': 0.6823535551547216}


Epoch 10/40: 100%|██████████| 416/416 [00:03<00:00, 119.21it/s, loss=0.366]


0.36024586293989647


Epoch 11/40: 100%|██████████| 416/416 [00:03<00:00, 118.21it/s, loss=0.369]


0.36015138291538906
{'roc_auc': 0.6867748701004296}


Epoch 12/40: 100%|██████████| 416/416 [00:03<00:00, 111.83it/s, loss=0.365]


0.3609957517697834


Epoch 13/40: 100%|██████████| 416/416 [00:03<00:00, 111.85it/s, loss=0.371]


0.3591063803849885
{'roc_auc': 0.682625096629624}


Epoch 14/40: 100%|██████████| 416/416 [00:03<00:00, 118.49it/s, loss=0.367]


0.35997471790044355


Epoch 15/40: 100%|██████████| 416/416 [00:03<00:00, 114.88it/s, loss=0.362]


0.35923252587851423
{'roc_auc': 0.6856854424413616}


Epoch 16/40: 100%|██████████| 416/416 [00:03<00:00, 120.00it/s, loss=0.364]


0.35858225618274164


Epoch 17/40: 100%|██████████| 416/416 [00:03<00:00, 118.34it/s, loss=0.369]


0.35833024169103456
{'roc_auc': 0.6824041124263249}


Epoch 18/40: 100%|██████████| 416/416 [00:03<00:00, 114.13it/s, loss=0.364]


0.3577275124665063


Epoch 19/40: 100%|██████████| 416/416 [00:03<00:00, 112.13it/s, loss=0.365]


0.3625009669564091
{'roc_auc': 0.6865334999005164}


Epoch 20/40: 100%|██████████| 416/416 [00:03<00:00, 117.76it/s, loss=0.365]


0.35827025258913636


Epoch 21/40: 100%|██████████| 416/416 [00:03<00:00, 115.94it/s, loss=0.36]


0.3575618965551257
{'roc_auc': 0.6806770107931621}


Epoch 22/40: 100%|██████████| 416/416 [00:03<00:00, 114.72it/s, loss=0.367]


0.35787842236459255


Epoch 23/40: 100%|██████████| 416/416 [00:03<00:00, 114.04it/s, loss=0.365]


0.35653543391694814
{'roc_auc': 0.6876049878825635}


Epoch 24/40: 100%|██████████| 416/416 [00:03<00:00, 113.72it/s, loss=0.366]


0.35666066966950893


Epoch 25/40: 100%|██████████| 416/416 [00:03<00:00, 107.06it/s, loss=0.364]


0.35653754241334706
{'roc_auc': 0.6903244798309104}


Epoch 26/40: 100%|██████████| 416/416 [00:03<00:00, 118.81it/s, loss=0.365]


0.3570920842198225


Epoch 27/40: 100%|██████████| 416/416 [00:03<00:00, 117.73it/s, loss=0.364]


0.3555891584031857
{'roc_auc': 0.6856454858879977}


Epoch 28/40: 100%|██████████| 416/416 [00:03<00:00, 114.17it/s, loss=0.363]


0.355603308775104


Epoch 29/40: 100%|██████████| 416/416 [00:03<00:00, 118.29it/s, loss=0.365]


0.3546346676750825
{'roc_auc': 0.686023849985159}


Epoch 30/40: 100%|██████████| 416/416 [00:03<00:00, 114.28it/s, loss=0.358]


0.35418117705446023


Epoch 31/40: 100%|██████████| 416/416 [00:03<00:00, 108.53it/s, loss=0.361]


0.3536785119929566
{'roc_auc': 0.6808955486768673}


Epoch 32/40: 100%|██████████| 416/416 [00:03<00:00, 119.83it/s, loss=0.357]


0.35277601940414083


Epoch 33/40: 100%|██████████| 416/416 [00:03<00:00, 120.27it/s, loss=0.356]


0.3523934569854576
{'roc_auc': 0.6848716334565191}


Epoch 34/40: 100%|██████████| 416/416 [00:03<00:00, 110.26it/s, loss=0.353]


0.35331382405442685


Epoch 35/40: 100%|██████████| 416/416 [00:03<00:00, 116.23it/s, loss=0.355]


0.35126916258237684
{'roc_auc': 0.6749970644164875}


Epoch 36/40: 100%|██████████| 416/416 [00:03<00:00, 115.40it/s, loss=0.354]


0.35103680616101396


Epoch 37/40: 100%|██████████| 416/416 [00:03<00:00, 111.24it/s, loss=0.358]


0.3501825934061064
{'roc_auc': 0.6772529788018252}


Epoch 38/40: 100%|██████████| 416/416 [00:03<00:00, 116.74it/s, loss=0.356]


0.34976397619511074


Epoch 39/40: 100%|██████████| 416/416 [00:03<00:00, 117.60it/s, loss=0.349]


0.3481724185224336
{'roc_auc': 0.6797490402272793}


Epoch 40/40: 100%|██████████| 416/416 [00:03<00:00, 116.06it/s, loss=0.353]


0.3480253074891292
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.1
roc_auc: 0.69864 (at 9)
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.01
roc_auc: 0.69062 (at 5)
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.001
roc_auc: 0.69032 (at 12)


Epoch 1/40: 100%|██████████| 416/416 [00:03<00:00, 106.55it/s, loss=0.393]


0.38982709067372173
{'roc_auc': 0.6185103218378057}


Epoch 2/40: 100%|██████████| 416/416 [00:03<00:00, 116.75it/s, loss=0.373]


0.3687841926629727


Epoch 3/40: 100%|██████████| 416/416 [00:03<00:00, 104.38it/s, loss=0.375]


0.36826589932808507
{'roc_auc': 0.643090125675592}


Epoch 4/40: 100%|██████████| 416/416 [00:03<00:00, 111.00it/s, loss=0.373]


0.36495059959662074


Epoch 5/40: 100%|██████████| 416/416 [00:03<00:00, 107.53it/s, loss=0.364]


0.3641633437230037
{'roc_auc': 0.6599191736006237}


Epoch 6/40: 100%|██████████| 416/416 [00:03<00:00, 105.81it/s, loss=0.368]


0.3621184014213773


Epoch 7/40: 100%|██████████| 416/416 [00:03<00:00, 106.75it/s, loss=0.365]


0.3628242314220048
{'roc_auc': 0.6568319182733551}


Epoch 8/40: 100%|██████████| 416/416 [00:03<00:00, 111.26it/s, loss=0.37]


0.3623768165349387


Epoch 9/40: 100%|██████████| 416/416 [00:03<00:00, 105.21it/s, loss=0.369]


0.36206314268593603
{'roc_auc': 0.6681094516003823}


Epoch 10/40: 100%|██████████| 416/416 [00:03<00:00, 113.86it/s, loss=0.363]


0.3614378795744135


Epoch 11/40: 100%|██████████| 416/416 [00:03<00:00, 107.99it/s, loss=0.363]


0.36115670376099074
{'roc_auc': 0.6773198448707201}


Epoch 12/40: 100%|██████████| 416/416 [00:03<00:00, 105.28it/s, loss=0.371]


0.36107305659411043


Epoch 13/40: 100%|██████████| 416/416 [00:03<00:00, 108.64it/s, loss=0.361]


0.3611259121232881
{'roc_auc': 0.6796161235293543}


Epoch 14/40: 100%|██████████| 416/416 [00:03<00:00, 109.57it/s, loss=0.361]


0.3597952211633898


Epoch 15/40: 100%|██████████| 416/416 [00:04<00:00, 99.13it/s, loss=0.359]


0.3594162012450397
{'roc_auc': 0.6890784877178449}


Epoch 16/40: 100%|██████████| 416/416 [00:03<00:00, 112.25it/s, loss=0.366]


0.36025371531454414


Epoch 17/40: 100%|██████████| 416/416 [00:03<00:00, 108.52it/s, loss=0.366]


0.3593438217201485
{'roc_auc': 0.6813032686091531}


Epoch 18/40: 100%|██████████| 416/416 [00:03<00:00, 109.34it/s, loss=0.363]


0.35969387439007944


Epoch 19/40: 100%|██████████| 416/416 [00:03<00:00, 106.15it/s, loss=0.36]


0.35875417317192143
{'roc_auc': 0.6753269098417068}


Epoch 20/40: 100%|██████████| 416/416 [00:04<00:00, 103.32it/s, loss=0.359]


0.358602231905724


Epoch 21/40: 100%|██████████| 416/416 [00:03<00:00, 105.46it/s, loss=0.368]


0.3592077687454338
{'roc_auc': 0.6738599335253422}


Epoch 22/40: 100%|██████████| 416/416 [00:03<00:00, 109.91it/s, loss=0.369]


0.3588670352832056


Epoch 23/40: 100%|██████████| 416/416 [00:03<00:00, 107.81it/s, loss=0.367]


0.35859854698467714
{'roc_auc': 0.6652081165622359}


Epoch 24/40: 100%|██████████| 416/416 [00:03<00:00, 111.01it/s, loss=0.364]


0.35724735582390654


Epoch 25/40: 100%|██████████| 416/416 [00:03<00:00, 106.41it/s, loss=0.36]


0.35756412495930606
{'roc_auc': 0.6717177730011122}


Epoch 26/40: 100%|██████████| 416/416 [00:04<00:00, 100.47it/s, loss=0.361]


0.35627225446156585


Epoch 27/40: 100%|██████████| 416/416 [00:03<00:00, 106.97it/s, loss=0.362]


0.35770912640369856
{'roc_auc': 0.6767139730513433}


Epoch 28/40: 100%|██████████| 416/416 [00:03<00:00, 107.96it/s, loss=0.359]


0.3561399055358309


Epoch 29/40: 100%|██████████| 416/416 [00:04<00:00, 103.09it/s, loss=0.363]


0.35509825956362945
{'roc_auc': 0.681141811515968}


Epoch 30/40: 100%|██████████| 416/416 [00:03<00:00, 113.21it/s, loss=0.358]


0.3547489572531329


Epoch 31/40: 100%|██████████| 416/416 [00:03<00:00, 108.16it/s, loss=0.361]


0.3550462440157739
{'roc_auc': 0.6586866362453234}


Epoch 32/40: 100%|██████████| 416/416 [00:03<00:00, 105.87it/s, loss=0.36]


0.35484979222886837


Epoch 33/40: 100%|██████████| 416/416 [00:03<00:00, 109.37it/s, loss=0.362]


0.35364616952406674
{'roc_auc': 0.6757688782483047}


Epoch 34/40: 100%|██████████| 416/416 [00:03<00:00, 104.49it/s, loss=0.36]


0.35345929194814885


Epoch 35/40: 100%|██████████| 416/416 [00:03<00:00, 104.27it/s, loss=0.358]


0.3534962309237856
{'roc_auc': 0.6746863818280857}


Epoch 36/40: 100%|██████████| 416/416 [00:03<00:00, 110.25it/s, loss=0.353]


0.35228341069215763


Epoch 37/40: 100%|██████████| 416/416 [00:03<00:00, 106.00it/s, loss=0.353]


0.35081810280322456
{'roc_auc': 0.6687854512481123}


Epoch 38/40: 100%|██████████| 416/416 [00:03<00:00, 111.32it/s, loss=0.361]


0.3514937419229402


Epoch 39/40: 100%|██████████| 416/416 [00:03<00:00, 108.30it/s, loss=0.357]


0.351033529624916
{'roc_auc': 0.6608027026938872}


Epoch 40/40: 100%|██████████| 416/416 [00:03<00:00, 111.92it/s, loss=0.357]


0.3502346588513599
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.1
roc_auc: 0.69864 (at 9)
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.01
roc_auc: 0.69062 (at 5)
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.001
roc_auc: 0.69032 (at 12)
lr=0.001, hidden=128, nl=3, dropout=0.5, decay=0.1
roc_auc: 0.68908 (at 7)


Epoch 1/40: 100%|██████████| 416/416 [00:03<00:00, 105.24it/s, loss=0.397]


0.39178701801798665
{'roc_auc': 0.5771275641506541}


Epoch 2/40: 100%|██████████| 416/416 [00:03<00:00, 110.59it/s, loss=0.379]


0.370042606901664


Epoch 3/40: 100%|██████████| 416/416 [00:03<00:00, 105.15it/s, loss=0.375]


0.3673086941170578
{'roc_auc': 0.6376462491397109}


Epoch 4/40: 100%|██████████| 416/416 [00:03<00:00, 110.58it/s, loss=0.373]


0.3660018094098912


Epoch 5/40: 100%|██████████| 416/416 [00:03<00:00, 107.92it/s, loss=0.37]


0.3642632997451493
{'roc_auc': 0.652831370297766}


Epoch 6/40: 100%|██████████| 416/416 [00:03<00:00, 108.86it/s, loss=0.371]


0.36265970358195215


Epoch 7/40: 100%|██████████| 416/416 [00:03<00:00, 108.87it/s, loss=0.367]


0.3619167504545588
{'roc_auc': 0.6819009860298842}


Epoch 8/40: 100%|██████████| 416/416 [00:03<00:00, 112.23it/s, loss=0.367]


0.36176895060075015


Epoch 9/40: 100%|██████████| 416/416 [00:03<00:00, 107.43it/s, loss=0.368]


0.36227712977247745
{'roc_auc': 0.6807210445458489}


Epoch 10/40: 100%|██████████| 416/416 [00:03<00:00, 116.07it/s, loss=0.368]


0.36114625783207327


Epoch 11/40: 100%|██████████| 416/416 [00:03<00:00, 111.36it/s, loss=0.367]


0.3597259815925589
{'roc_auc': 0.6788422710978756}


Epoch 12/40: 100%|██████████| 416/416 [00:03<00:00, 106.78it/s, loss=0.362]


0.3608039159041185


Epoch 13/40: 100%|██████████| 416/416 [00:03<00:00, 108.55it/s, loss=0.371]


0.36071993826100457
{'roc_auc': 0.6755837733990468}


Epoch 14/40: 100%|██████████| 416/416 [00:03<00:00, 108.95it/s, loss=0.36]


0.3604384086524638


Epoch 15/40: 100%|██████████| 416/416 [00:03<00:00, 107.10it/s, loss=0.362]


0.3600783826281818
{'roc_auc': 0.6807646705786036}


Epoch 16/40: 100%|██████████| 416/416 [00:03<00:00, 114.01it/s, loss=0.362]


0.35904719290108633


Epoch 17/40: 100%|██████████| 416/416 [00:03<00:00, 111.91it/s, loss=0.368]


0.3597744628070639
{'roc_auc': 0.6643388576666025}


Epoch 18/40: 100%|██████████| 416/416 [00:03<00:00, 116.24it/s, loss=0.36]


0.3576845498397373


Epoch 19/40: 100%|██████████| 416/416 [00:03<00:00, 110.97it/s, loss=0.362]


0.3579494279021254
{'roc_auc': 0.6728072006601801}


Epoch 20/40: 100%|██████████| 416/416 [00:03<00:00, 112.95it/s, loss=0.363]


0.3591708813865597


Epoch 21/40: 100%|██████████| 416/416 [00:03<00:00, 111.70it/s, loss=0.365]


0.3584564125451904
{'roc_auc': 0.6695095618478519}


Epoch 22/40: 100%|██████████| 416/416 [00:03<00:00, 112.91it/s, loss=0.362]


0.3584832069822229


Epoch 23/40: 100%|██████████| 416/416 [00:03<00:00, 109.10it/s, loss=0.368]


0.3578549209767236
{'roc_auc': 0.6800686926541916}


Epoch 24/40: 100%|██████████| 416/416 [00:03<00:00, 116.15it/s, loss=0.367]


0.35734937279126966


Epoch 25/40: 100%|██████████| 416/416 [00:03<00:00, 109.14it/s, loss=0.36]


0.3569111946540383
{'roc_auc': 0.6719326414054269}


Epoch 26/40: 100%|██████████| 416/416 [00:03<00:00, 117.33it/s, loss=0.362]


0.35629536548199564


Epoch 27/40: 100%|██████████| 416/416 [00:03<00:00, 113.81it/s, loss=0.364]


0.35517391193514836
{'roc_auc': 0.6647119214046442}


Epoch 28/40: 100%|██████████| 416/416 [00:03<00:00, 118.85it/s, loss=0.366]


0.3562752189687811


Epoch 29/40: 100%|██████████| 416/416 [00:03<00:00, 112.31it/s, loss=0.363]


0.35645741607564
{'roc_auc': 0.6725748002987773}


Epoch 30/40: 100%|██████████| 416/416 [00:03<00:00, 108.72it/s, loss=0.36]


0.3550710128620267


Epoch 31/40: 100%|██████████| 416/416 [00:03<00:00, 109.35it/s, loss=0.357]


0.3538264657657307
{'roc_auc': 0.6637550027235691}


Epoch 32/40: 100%|██████████| 416/416 [00:03<00:00, 107.79it/s, loss=0.357]


0.3535194080322981


Epoch 33/40: 100%|██████████| 416/416 [00:03<00:00, 111.05it/s, loss=0.362]


0.3532288156163234
{'roc_auc': 0.6764726028514301}


Epoch 34/40: 100%|██████████| 416/416 [00:03<00:00, 115.47it/s, loss=0.355]


0.3522057962030746


Epoch 35/40: 100%|██████████| 416/416 [00:03<00:00, 107.68it/s, loss=0.362]


0.353086305459818
{'roc_auc': 0.6693395426360889}


Epoch 36/40: 100%|██████████| 416/416 [00:03<00:00, 117.39it/s, loss=0.36]


0.3511432670773222


Epoch 37/40: 100%|██████████| 416/416 [00:03<00:00, 107.96it/s, loss=0.357]


0.3520211905527573
{'roc_auc': 0.6705223381596501}


Epoch 38/40: 100%|██████████| 416/416 [00:03<00:00, 117.17it/s, loss=0.355]


0.3502546463591548


Epoch 39/40: 100%|██████████| 416/416 [00:03<00:00, 110.20it/s, loss=0.355]


0.34963189939466804
{'roc_auc': 0.6673771866019969}


Epoch 40/40: 100%|██████████| 416/416 [00:03<00:00, 115.90it/s, loss=0.356]


0.34741672235899246
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.1
roc_auc: 0.69864 (at 9)
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.01
roc_auc: 0.69062 (at 5)
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.001
roc_auc: 0.69032 (at 12)
lr=0.001, hidden=128, nl=3, dropout=0.5, decay=0.1
roc_auc: 0.68908 (at 7)
lr=0.001, hidden=128, nl=3, dropout=0.5, decay=0.01
roc_auc: 0.6819 (at 3)


Epoch 1/40: 100%|██████████| 416/416 [00:03<00:00, 110.11it/s, loss=0.404]


0.3931394934725876
{'roc_auc': 0.6008372936529423}


Epoch 2/40: 100%|██████████| 416/416 [00:03<00:00, 112.40it/s, loss=0.38]


0.37041098757002217


Epoch 3/40: 100%|██████████| 416/416 [00:03<00:00, 110.89it/s, loss=0.374]


0.3665484352252231
{'roc_auc': 0.6567328423298096}


Epoch 4/40: 100%|██████████| 416/416 [00:03<00:00, 109.29it/s, loss=0.366]


0.3660539271166691


Epoch 5/40: 100%|██████████| 416/416 [00:03<00:00, 106.54it/s, loss=0.365]


0.363893526355521
{'roc_auc': 0.6590401294266153}


Epoch 6/40: 100%|██████████| 416/416 [00:03<00:00, 114.72it/s, loss=0.372]


0.36436239518941593


Epoch 7/40: 100%|██████████| 416/416 [00:03<00:00, 108.19it/s, loss=0.372]


0.362414280561587
{'roc_auc': 0.67195995864089}


Epoch 8/40: 100%|██████████| 416/416 [00:03<00:00, 115.91it/s, loss=0.366]


0.3630519116369004


Epoch 9/40: 100%|██████████| 416/416 [00:03<00:00, 110.91it/s, loss=0.368]


0.3613425996512748
{'roc_auc': 0.6778645587002541}


Epoch 10/40: 100%|██████████| 416/416 [00:03<00:00, 113.12it/s, loss=0.366]


0.3621252879070548


Epoch 11/40: 100%|██████████| 416/416 [00:03<00:00, 111.73it/s, loss=0.37]


0.36152388635449684
{'roc_auc': 0.6874084668752019}


Epoch 12/40: 100%|██████████| 416/416 [00:03<00:00, 110.19it/s, loss=0.362]


0.3603097188214843


Epoch 13/40: 100%|██████████| 416/416 [00:03<00:00, 110.39it/s, loss=0.362]


0.36095595015929294
{'roc_auc': 0.6857710636271418}


Epoch 14/40: 100%|██████████| 416/416 [00:03<00:00, 113.33it/s, loss=0.361]


0.3597225519613578


Epoch 15/40: 100%|██████████| 416/416 [00:03<00:00, 114.99it/s, loss=0.37]


0.3606788679026067
{'roc_auc': 0.6811434423956971}


Epoch 16/40: 100%|██████████| 416/416 [00:03<00:00, 115.79it/s, loss=0.363]


0.359930979875991


Epoch 17/40: 100%|██████████| 416/416 [00:03<00:00, 109.38it/s, loss=0.365]


0.3600090860317533
{'roc_auc': 0.6941521545552102}


Epoch 18/40: 100%|██████████| 416/416 [00:03<00:00, 114.96it/s, loss=0.359]


0.35879998919195855


Epoch 19/40: 100%|██████████| 416/416 [00:03<00:00, 110.41it/s, loss=0.358]


0.35798987989815384
{'roc_auc': 0.6946214401972712}


Epoch 20/40: 100%|██████████| 416/416 [00:03<00:00, 113.96it/s, loss=0.359]


0.3591798092955007


Epoch 21/40: 100%|██████████| 416/416 [00:03<00:00, 111.76it/s, loss=0.367]


0.35872897799485004
{'roc_auc': 0.6925649008588213}


Epoch 22/40: 100%|██████████| 416/416 [00:03<00:00, 112.13it/s, loss=0.365]


0.357792466878891


Epoch 23/40: 100%|██████████| 416/416 [00:03<00:00, 108.01it/s, loss=0.357]


0.3568300736638216
{'roc_auc': 0.6864095530411014}


Epoch 24/40: 100%|██████████| 416/416 [00:03<00:00, 108.00it/s, loss=0.365]


0.3562572594338025


Epoch 25/40: 100%|██████████| 416/416 [00:03<00:00, 106.15it/s, loss=0.362]


0.3571192070913429
{'roc_auc': 0.6800389290991347}


Epoch 26/40: 100%|██████████| 416/416 [00:03<00:00, 110.82it/s, loss=0.366]


0.3558407780141212


Epoch 27/40: 100%|██████████| 416/416 [00:03<00:00, 111.64it/s, loss=0.364]


0.35584940073581844
{'roc_auc': 0.6901907476931206}


Epoch 28/40: 100%|██████████| 416/416 [00:03<00:00, 109.96it/s, loss=0.359]


0.3560719278712685


Epoch 29/40: 100%|██████████| 416/416 [00:03<00:00, 108.98it/s, loss=0.361]


0.35561849126735556
{'roc_auc': 0.6830491253592013}


Epoch 30/40: 100%|██████████| 416/416 [00:03<00:00, 109.37it/s, loss=0.364]


0.3555788997059258


Epoch 31/40: 100%|██████████| 416/416 [00:04<00:00, 102.31it/s, loss=0.362]


0.35405828803777695
{'roc_auc': 0.6857710636271417}


Epoch 32/40: 100%|██████████| 416/416 [00:03<00:00, 104.48it/s, loss=0.361]


0.3536813332197758


Epoch 33/40: 100%|██████████| 416/416 [00:03<00:00, 111.74it/s, loss=0.357]


0.35314215873726285
{'roc_auc': 0.6851268661341301}


Epoch 34/40: 100%|██████████| 416/416 [00:03<00:00, 110.48it/s, loss=0.354]


0.3527523840849216


Epoch 35/40: 100%|██████████| 416/416 [00:03<00:00, 108.86it/s, loss=0.358]


0.3517169152171566
{'roc_auc': 0.6807740481370461}


Epoch 36/40: 100%|██████████| 416/416 [00:03<00:00, 108.58it/s, loss=0.354]


0.35173339772826206


Epoch 37/40: 100%|██████████| 416/416 [00:03<00:00, 113.96it/s, loss=0.355]


0.3520322762644635
{'roc_auc': 0.6808776089998467}


Epoch 38/40: 100%|██████████| 416/416 [00:03<00:00, 104.09it/s, loss=0.357]


0.35038532119674176


Epoch 39/40: 100%|██████████| 416/416 [00:04<00:00, 101.54it/s, loss=0.354]


0.3494231758209375
{'roc_auc': 0.6744446039082402}


Epoch 40/40: 100%|██████████| 416/416 [00:03<00:00, 105.94it/s, loss=0.348]

0.3474076883509182
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.1
roc_auc: 0.69864 (at 9)
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.01
roc_auc: 0.69062 (at 5)
lr=0.001, hidden=128, nl=2, dropout=0.5, decay=0.001
roc_auc: 0.69032 (at 12)
lr=0.001, hidden=128, nl=3, dropout=0.5, decay=0.1
roc_auc: 0.68908 (at 7)
lr=0.001, hidden=128, nl=3, dropout=0.5, decay=0.01
roc_auc: 0.6819 (at 3)
lr=0.001, hidden=128, nl=3, dropout=0.5, decay=0.001
roc_auc: 0.69462 (at 9)



