### TODO - CNN portion of project

1. Update the metrics to match those in the paper (F1 score, precision, recall, AUC).
2. Integrate the phenotype dictionary so the phenotype used as the target (Y) value in an experiment can be selected easily (i.e., without a hard-coded integer).
3. Run experiments across all 10 phenotypes used in the paper with default parameters.
4. Repose this code
5. Add a README.
6. Create a figure comparing the CNN's F1 scores across phenotypes.

### Setup

In [1]:
# Mount Google Drive at /content/drive so the shared project folder is reachable.
from google.colab import drive

DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Make modules under the project's `src` directory importable directly.
import sys

SRC_DIR = '/content/drive/MyDrive/Project/src'
# Guard so re-running this cell does not append duplicate sys.path entries.
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [3]:
root = "/content/drive/MyDrive/Project"  # project folder on the mounted Google Drive

In [4]:
# Work from the project root so relative paths such as `src/...` resolve.
import os

os.chdir(root)
os.getcwd()

'/content/drive/.shortcut-targets-by-id/1AB0jDnt7V3OmP_E50BE3q1El8K68T-6-/Project'

In [5]:
# Detect PY file updates and reload
%load_ext autoreload
%autoreload 0.5

In [6]:
# Sanity check: list the current directory to confirm the Drive project folder is
# mounted and the chdir above took effect (exploratory output; presumably removable).
%ls

 BDH_reproducibility_challenge.pdf
 [0m[01;34mcse6250-project[0m/
 CSE_6250_ProjectPaperSelection2.pdf
 CSE6250_ProjectPaperSelection.pdf
 CSE6250_ProjectProposal.pdf
 [01;34mdata[0m/
 [01;34mnotebooks[0m/
'Paper Notes.gdoc'
 [01;34mPapers[0m/
 ProjectNotes.gdoc
 ProjectTimeline.gsheet
 [01;34msrc[0m/
'Team Registration & Paper Selection.gsheet'
 [01;34mwandb[0m/


In [7]:
# Show the current working directory again (same check as the chdir cell above).
%pwd

'/content/drive/.shortcut-targets-by-id/1AB0jDnt7V3OmP_E50BE3q1El8K68T-6-/Project'

### Installations

In [8]:
!pip install wandb -qqq

In [9]:
!pip install torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [10]:
# Authenticate with Weights & Biases; the returned status is displayed.
# NOTE(review): wandb reports it cannot detect the notebook name; setting
# os.environ["WANDB_NOTEBOOK_NAME"] before this cell enables code saving.
import wandb

wandb_logged_in = wandb.login()
wandb_logged_in

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33malphabitten[0m ([33mcs7643-teamscam[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [11]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import logging
import time
import h5py
from platform import python_version
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import Adam, Adadelta
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split
from platform import python_version
from torch.utils import data

# Local imports
import src.CNN.CNN_NLP as cnn_model
from src.CNN.data_load import get_data
from src.CNN.run_model import run_model

In [12]:
# Choose the compute device: the CUDA GPU when one is visible, otherwise the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")

if not has_cuda:
    print('No GPU available, using the CPU instead.')
else:
    print(f'There are {torch.cuda.device_count()} GPU(s) available.')
    print('Device name:', torch.cuda.get_device_name(0))

There are 1 GPU(s) available.
Device name: A100-SXM4-40GB


### Conditions Dictionary

In [13]:
# Condition labels keyed by the integer index that the Experiment cell passes to
# `get_data` as the phenotype selector (`phenotype_value`).
conditions = {
    0: 'cohort',
    1: 'Obesity',
    2: 'Non.Adherence',
    3: 'Developmental.Delay.Retardation',
    4: 'Advanced.Heart.Disease',
    5: 'Advanced.Lung.Disease',
    6: 'Schizophrenia.and.other.Psychiatric.Disorders',
    7: 'Alcohol.Abuse',
    8: 'Other.Substance.Abuse',
    9: 'Chronic.Pain.Fibromyalgia',
    10: 'Chronic.Neurological.Dystrophies',
    11: 'Advanced.Cancer',
    12: 'Depression',
    13: 'Dementia',
    14: 'Unsure',
}

# Reverse lookup so an experiment can select its phenotype by name instead of a
# hard-coded integer, e.g. condition_ids['Obesity'] -> 1.
condition_ids = {name: idx for idx, name in conditions.items()}

### Experiment

In [14]:
import torch.optim as optim

# W&B destination shared by the sweep and every run it launches.
WANDB_PROJECT = "cs6250-project"
WANDB_ENTITY = "cs7643-teamscam"

# Index into `conditions` of the phenotype used as the binary target (1 == 'Obesity').
phenotype = 1

sweep_config = {
    'method': 'random',  # grid, random
    'metric': {
        # Rank runs by validation F1: the validation set is ~89% negative (an
        # all-negative model scores ~89% accuracy with F1 = 0), so accuracy
        # barely separates models, and F1 is the metric the paper reports.
        'name': 'val_f1',
        'goal': 'maximize'
    },
    'parameters': {
        'h5py_file': {
            'value': 'src/phenotyping/their-embeddings/data-nobatch.h5'
        },
        'batch_size': {
            'values': [32, 64, 128]
        },
        'filter_sizes': {
            'value': [2, 3, 4, 5]
        },
        'num_filters': {
            'value': [100, 100, 100, 100]
        },
        'num_classes': {
            'value': 2
        },
        'dropout': {
            'values': [0.3, 0.5]
        },
        'learning_rate': {
            'values': [1e-1, 1e-2, 1e-3]
        },
        'phenotype_value': {
            'value': phenotype
        },
        'phenotype_description': {
            'value': conditions[phenotype]
        },
        'epochs': {
            'values': [100]
        },
        'opt': {
            'values': ['ada']
        },
        'rho': {
            'values': [0.9, 0.95]
        },
        'freeze_embeddings': {
            'values': [True]
        }
    }
}


def _build_optimizer(name, params, learning_rate, rho):
    """Create the optimizer selected by the sweep's `opt` value.

    Args:
        name: 'adam' for Adam, 'ada' for Adadelta.
        params: iterable of model parameters to optimize.
        learning_rate: learning rate passed to the optimizer.
        rho: Adadelta decay coefficient (unused for Adam).

    Raises:
        ValueError: if `name` is not one of the supported keys, instead of
            silently falling back to Adadelta.
    """
    if name == 'adam':
        return optim.Adam(params, lr=learning_rate)
    if name == 'ada':
        return optim.Adadelta(params, lr=learning_rate, rho=rho)
    raise ValueError(f"Unknown optimizer {name!r}; expected 'adam' or 'ada'.")


def run():
    """Execute one sweep trial: load data, build the CNN, then train and validate.

    All hyperparameters are read from `wandb.config`, which the sweep agent fills
    from `sweep_config`. Relies on the notebook globals `device`, `cnn_model`,
    `get_data` and `run_model`.
    """
    with wandb.init(project=WANDB_PROJECT, entity=WANDB_ENTITY):
        config = wandb.config

        # Parameters
        H5PY_FILE = config["h5py_file"]
        BATCH_SIZE = config["batch_size"]
        FILTER_SIZES = config["filter_sizes"]
        NUM_FILTERS = config["num_filters"]
        NUM_CLASSES = config["num_classes"]
        DROPOUT = config["dropout"]
        LEARNING_RATE = config["learning_rate"]
        RHO = config["rho"]
        PHENOTYPE = config["phenotype_value"]
        EPOCHS = config["epochs"]
        FREEZE_EMBEDDINGS = config["freeze_embeddings"]

        # Get Train and Validation DataLoaders for the selected phenotype.
        train_dataloader, val_dataloader, embeddings_tensor = get_data(
            H5PY_FILE, device, BATCH_SIZE, PHENOTYPE)

        # Instantiate CNN model; dropout comes from the sweep so that varying
        # it actually changes the model.
        model = cnn_model.CNN_NLP(pretrained_embedding=embeddings_tensor,
                                  freeze_embedding=FREEZE_EMBEDDINGS,
                                  vocab_size=None,
                                  embed_dim=300,
                                  filter_sizes=FILTER_SIZES,
                                  num_filters=NUM_FILTERS,
                                  num_classes=NUM_CLASSES,
                                  dropout=DROPOUT)

        # Send model to `device` (GPU/CPU)
        model.to(device)

        optimizer = _build_optimizer(config['opt'], model.parameters(),
                                     LEARNING_RATE, RHO)

        # Specify loss function
        loss_fn = nn.CrossEntropyLoss()

        # Named `trainer` so it does not shadow this function or the W&B run.
        trainer = run_model(model, optimizer, loss_fn, device)

        # Run the train/validation loop.
        # NOTE(review): run_model's final message prints the best F1 with a '%'
        # suffix although the value is a fraction in [0, 1] (0.65 means 65%);
        # fix the format string in src/CNN/run_model.
        trainer.train(train_dataloader, val_dataloader, EPOCHS)


count = 1  # number of sweep trials this agent executes
sweep_id = wandb.sweep(sweep_config, project=WANDB_PROJECT, entity=WANDB_ENTITY)
wandb.agent(sweep_id, function=run, count=count)




Create sweep with ID: 72sokxx6
Sweep URL: https://wandb.ai/cs7643-teamscam/cs6250-project/sweeps/72sokxx6


[34m[1mwandb[0m: Agent Starting Run: 7gncof33 with config:
[34m[1mwandb[0m: 	batch_size: 64
[34m[1mwandb[0m: 	dropout: 0.5
[34m[1mwandb[0m: 	epochs: 100
[34m[1mwandb[0m: 	filter_sizes: [2, 3, 4, 5]
[34m[1mwandb[0m: 	freeze_embeddings: True
[34m[1mwandb[0m: 	h5py_file: src/phenotyping/their-embeddings/data-nobatch.h5
[34m[1mwandb[0m: 	learning_rate: 0.1
[34m[1mwandb[0m: 	num_classes: 2
[34m[1mwandb[0m: 	num_filters: [100, 100, 100, 100]
[34m[1mwandb[0m: 	opt: ada
[34m[1mwandb[0m: 	phenotype_description: Obesity
[34m[1mwandb[0m: 	phenotype_value: 1
[34m[1mwandb[0m: 	rho: 0.9
ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Start training...

 Epoch  |  Train Loss  |  Train Acc  |  Val Loss  |  Val Acc  |  Train F1   |  Val F1  |  Elapsed 
------------------------------------------------------------


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

   1    |   0.352237   |   89.03   |  0.331342  |   89.24   |   0.003897   |   0.000000   |   21.96  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

   2    |   0.332966   |   89.29   |  0.323233  |   89.24   |   0.008356   |   0.000000   |   19.36  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

   3    |   0.324036   |   89.28   |  0.315963  |   89.24   |   0.012929   |   0.000000   |   19.37  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

   4    |   0.313642   |   89.57   |  0.305551  |   89.36   |   0.058260   |   0.015800   |   19.37  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

   5    |   0.302518   |   90.03   |  0.294879  |   90.06   |   0.126324   |   0.122062   |   19.42  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

   6    |   0.291252   |   90.42   |  0.283732  |   90.99   |   0.199978   |   0.246654   |   19.45  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

   7    |   0.278928   |   90.84   |  0.271919  |   91.47   |   0.248181   |   0.304143   |   19.41  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

   8    |   0.270067   |   91.20   |  0.268450  |   91.01   |   0.296616   |   0.244249   |   19.41  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

   9    |   0.257850   |   91.51   |  0.254078  |   91.70   |   0.331499   |   0.332745   |   19.38  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  10    |   0.246679   |   91.88   |  0.243464  |   92.29   |   0.376784   |   0.395436   |   19.39  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  11    |   0.238537   |   92.19   |  0.240744  |   92.08   |   0.410726   |   0.372696   |   19.37  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  12    |   0.230382   |   92.38   |  0.231856  |   92.68   |   0.437196   |   0.431764   |   19.40  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  13    |   0.221128   |   92.63   |  0.233238  |   92.42   |   0.465402   |   0.407452   |   19.39  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  14    |   0.212607   |   92.76   |  0.225421  |   92.78   |   0.477227   |   0.444038   |   19.40  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  15    |   0.205909   |   93.05   |  0.221961  |   93.00   |   0.508696   |   0.475511   |   19.35  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  16    |   0.199991   |   93.15   |  0.211752  |   93.31   |   0.513929   |   0.509024   |   19.36  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  17    |   0.194232   |   93.25   |  0.212307  |   93.21   |   0.535505   |   0.495099   |   19.40  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  18    |   0.188080   |   93.48   |  0.212382  |   93.19   |   0.549452   |   0.499944   |   19.49  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  19    |   0.182067   |   93.70   |  0.204352  |   93.43   |   0.573709   |   0.529029   |   19.38  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  20    |   0.179415   |   93.68   |  0.208495  |   93.31   |   0.568913   |   0.509700   |   19.38  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  21    |   0.172996   |   93.91   |  0.216525  |   93.14   |   0.594574   |   0.493300   |   19.36  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  22    |   0.168729   |   94.00   |  0.203805  |   93.45   |   0.604083   |   0.533543   |   19.38  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  23    |   0.163844   |   94.26   |  0.203452  |   93.40   |   0.623024   |   0.526191   |   19.38  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  24    |   0.159017   |   94.39   |  0.197334  |   93.75   |   0.636199   |   0.574860   |   19.38  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  25    |   0.153529   |   94.52   |  0.201245  |   93.49   |   0.648455   |   0.537631   |   19.37  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  26    |   0.151090   |   94.59   |  0.202166  |   93.47   |   0.656315   |   0.534868   |   19.42  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  27    |   0.146461   |   94.77   |  0.217245  |   93.31   |   0.673810   |   0.510349   |   19.39  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  28    |   0.143589   |   94.85   |  0.214270  |   93.24   |   0.677215   |   0.508653   |   19.38  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  29    |   0.139845   |   94.93   |  0.201499  |   93.66   |   0.679741   |   0.558026   |   19.40  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  30    |   0.136098   |   95.10   |  0.205048  |   93.59   |   0.697340   |   0.548671   |   19.40  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  31    |   0.132301   |   95.22   |  0.211034  |   93.49   |   0.704175   |   0.536645   |   19.37  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  32    |   0.130330   |   95.28   |  0.195501  |   93.80   |   0.716693   |   0.571519   |   19.39  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  33    |   0.129304   |   95.30   |  0.207492  |   93.54   |   0.713046   |   0.545382   |   19.38  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  34    |   0.125105   |   95.42   |  0.201954  |   93.63   |   0.724789   |   0.554583   |   19.38  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  35    |   0.123527   |   95.58   |  0.203403  |   93.71   |   0.736833   |   0.563292   |   19.35  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  36    |   0.120297   |   95.63   |  0.214060  |   93.59   |   0.736691   |   0.545467   |   19.35  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  37    |   0.117056   |   95.76   |  0.209493  |   93.56   |   0.750326   |   0.546885   |   19.36  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  38    |   0.114081   |   95.87   |  0.195450  |   93.91   |   0.745815   |   0.587220   |   19.39  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  39    |   0.110479   |   95.92   |  0.194659  |   93.91   |   0.755408   |   0.589091   |   19.35  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  40    |   0.108931   |   96.00   |  0.213113  |   93.70   |   0.762585   |   0.561646   |   19.50  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  41    |   0.107617   |   96.08   |  0.210038  |   93.71   |   0.773424   |   0.563213   |   19.37  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  42    |   0.105331   |   96.06   |  0.192707  |   94.00   |   0.765082   |   0.598188   |   19.41  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  43    |   0.104027   |   96.16   |  0.211653  |   93.63   |   0.776955   |   0.557473   |   19.39  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  44    |   0.103630   |   96.20   |  0.213520  |   93.70   |   0.781263   |   0.570722   |   19.42  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  45    |   0.101560   |   96.30   |  0.198583  |   93.98   |   0.788004   |   0.595710   |   19.38  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  46    |   0.098783   |   96.46   |  0.216792  |   93.66   |   0.796912   |   0.567505   |   19.41  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  47    |   0.097915   |   96.45   |  0.209233  |   93.84   |   0.796546   |   0.584290   |   19.40  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  48    |   0.093932   |   96.56   |  0.211681  |   93.73   |   0.803443   |   0.576059   |   19.39  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  49    |   0.094046   |   96.57   |  0.211711  |   93.71   |   0.799712   |   0.578027   |   19.40  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  50    |   0.090151   |   96.73   |  0.198789  |   93.93   |   0.816820   |   0.606715   |   19.40  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  51    |   0.090355   |   96.66   |  0.193959  |   94.05   |   0.811719   |   0.612079   |   19.40  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  52    |   0.088253   |   96.74   |  0.212063  |   93.87   |   0.810801   |   0.595429   |   19.43  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  53    |   0.088434   |   96.82   |  0.225273  |   93.71   |   0.812156   |   0.568127   |   19.38  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  54    |   0.087376   |   96.85   |  0.210043  |   94.01   |   0.821790   |   0.595236   |   19.40  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  55    |   0.086122   |   96.87   |  0.217124  |   93.89   |   0.820759   |   0.590512   |   19.38  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  56    |   0.084251   |   96.98   |  0.223947  |   93.77   |   0.834190   |   0.583505   |   19.36  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  57    |   0.082196   |   97.04   |  0.246626  |   93.61   |   0.832691   |   0.556408   |   19.39  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  58    |   0.079515   |   97.11   |  0.203576  |   94.14   |   0.838691   |   0.623993   |   19.45  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  59    |   0.082270   |   97.03   |  0.213091  |   93.93   |   0.838132   |   0.600677   |   19.43  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  60    |   0.080275   |   97.13   |  0.209664  |   94.03   |   0.840298   |   0.607066   |   19.47  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  61    |   0.079834   |   97.13   |  0.211311  |   94.14   |   0.842794   |   0.620626   |   19.48  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  62    |   0.077080   |   97.17   |  0.238524  |   93.89   |   0.843806   |   0.586419   |   19.55  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  63    |   0.075256   |   97.37   |  0.214711  |   94.17   |   0.855048   |   0.621075   |   19.50  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  64    |   0.074828   |   97.24   |  0.223049  |   94.05   |   0.847258   |   0.608580   |   19.47  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  65    |   0.074986   |   97.33   |  0.211397  |   94.03   |   0.853285   |   0.619249   |   19.40  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  66    |   0.073889   |   97.31   |  0.216884  |   94.10   |   0.846445   |   0.614220   |   19.45  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  67    |   0.073090   |   97.40   |  0.224181  |   94.07   |   0.860657   |   0.607474   |   19.52  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  68    |   0.070552   |   97.40   |  0.235711  |   93.84   |   0.859178   |   0.586464   |   19.68  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  69    |   0.071313   |   97.47   |  0.231372  |   93.98   |   0.858365   |   0.604506   |   19.60  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  70    |   0.068913   |   97.51   |  0.227769  |   93.94   |   0.863473   |   0.601629   |   19.59  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  71    |   0.069000   |   97.60   |  0.228384  |   94.12   |   0.867926   |   0.616341   |   19.46  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  72    |   0.067270   |   97.67   |  0.248032  |   94.01   |   0.874230   |   0.598133   |   19.51  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  73    |   0.066378   |   97.63   |  0.239102  |   93.94   |   0.870175   |   0.599842   |   19.62  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  74    |   0.069228   |   97.48   |  0.260858  |   93.70   |   0.858904   |   0.574712   |   19.66  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  75    |   0.065151   |   97.65   |  0.257589  |   93.86   |   0.870053   |   0.588279   |   19.60  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  76    |   0.065545   |   97.65   |  0.238396  |   94.01   |   0.874766   |   0.610288   |   19.67  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  77    |   0.064766   |   97.65   |  0.233369  |   94.14   |   0.871639   |   0.618307   |   19.55  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  78    |   0.063454   |   97.65   |  0.216591  |   94.17   |   0.870989   |   0.633250   |   19.61  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  79    |   0.064547   |   97.63   |  0.244868  |   94.03   |   0.871278   |   0.606273   |   19.55  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  80    |   0.062159   |   97.74   |  0.251836  |   93.93   |   0.874988   |   0.594401   |   19.74  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  81    |   0.062621   |   97.67   |  0.232753  |   94.26   |   0.875185   |   0.630486   |   19.71  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  82    |   0.061059   |   97.76   |  0.225481  |   94.31   |   0.877695   |   0.639490   |   19.72  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  83    |   0.059318   |   97.82   |  0.274027  |   93.77   |   0.884296   |   0.583248   |   19.71  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  84    |   0.062298   |   97.76   |  0.242770  |   94.00   |   0.876036   |   0.612619   |   19.73  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  85    |   0.059472   |   97.86   |  0.221251  |   94.21   |   0.884347   |   0.631596   |   19.77  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  86    |   0.062110   |   97.77   |  0.285812  |   93.68   |   0.881463   |   0.574050   |   19.81  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  87    |   0.057221   |   97.93   |  0.258362  |   93.94   |   0.889799   |   0.605781   |   19.84  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  88    |   0.058965   |   97.86   |  0.274934  |   93.77   |   0.884871   |   0.590216   |   19.81  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  89    |   0.056748   |   98.04   |  0.269219  |   93.75   |   0.889253   |   0.586410   |   19.81  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  90    |   0.057801   |   97.96   |  0.262891  |   93.93   |   0.890228   |   0.605732   |   19.68  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  91    |   0.055326   |   98.13   |  0.270624  |   93.87   |   0.899652   |   0.597030   |   19.59  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  92    |   0.058075   |   97.89   |  0.258553  |   93.94   |   0.889324   |   0.605891   |   19.61  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  93    |   0.056369   |   98.01   |  0.260710  |   94.00   |   0.892407   |   0.608964   |   19.59  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  94    |   0.054050   |   98.04   |  0.278742  |   93.91   |   0.895763   |   0.595367   |   19.44  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  95    |   0.056302   |   98.00   |  0.264533  |   94.03   |   0.892536   |   0.605742   |   19.36  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  96    |   0.055300   |   98.04   |  0.268945  |   94.03   |   0.894803   |   0.611245   |   19.37  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  97    |   0.054090   |   98.09   |  0.274414  |   93.91   |   0.898732   |   0.599066   |   19.38  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  98    |   0.054485   |   98.10   |  0.229850  |   94.24   |   0.896146   |   0.645134   |   19.41  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  99    |   0.054021   |   97.99   |  0.289781  |   93.73   |   0.891494   |   0.583749   |   19.40  


  0%|          | 0/622 [00:00<?, ?it/s]

  0%|          | 0/89 [00:00<?, ?it/s]

  100   |   0.051139   |   98.13   |  0.252435  |   94.07   |   0.899019   |   0.614161   |   19.40  


Training complete! Best f1-score: 0.65%.


VBox(children=(Label(value='0.001 MB of 0.019 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.038734…

0,1
avg_train_f1,▁▁▃▃▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████████████████
avg_train_loss,█▇▇▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
time_elapsed,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂▂▂▂▂▂▂▂▁▁▁
train_acc,▁▁▂▃▃▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████
train_accuracy,▁▁▂▃▃▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████████
train_f1,▁▁▃▃▄▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████████████████
train_loss,█▇▇▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▁▃▃▅▅▇▆▆▇▇▇▇▇▇▇▇▇▇▇██▇▇█████▇████▇▇▇▇▇█
val_f1,▁▁▄▄▅▅▇▆▆▇▇▇▇▇▇▇▇▇▇▇██▇██████▇█████▇████

0,1
avg_train_f1,0.89902
avg_train_loss,0.05114
epoch,99.0
time_elapsed,19.39552
train_acc,98.12852
train_accuracy,98.12852
train_f1,0.89902
train_loss,0.05114
val_accuracy,94.06601
val_f1,0.61416
