# 1 - Setup

In [1]:
# imports from Python standard library

# imports requiring installation
#   connection to Google Cloud Storage
from google.cloud import storage            # pip install google-cloud-storage
from google.oauth2 import service_account   # pip install google-auth

#  data science packages
import numpy as np                          # pip install numpy
import pandas as pd                         # pip install pandas

In [2]:
# imports from tweet_turing.py
import tweet_turing as tur      # note - different import approach from prior notebooks

# imports from tweet_turing_paths.py
from tweet_turing_paths import local_data_paths, local_snapshot_paths, gcp_data_paths, \
    gcp_snapshot_paths, gcp_project_name, gcp_bucket_name, gcp_key_file

In [3]:
# pandas display configuration: never truncate the contents of wide (text) columns
pd.options.display.max_colwidth = None

## Local or Cloud?

Decide here whether to run notebook with local data or GCP bucket data
 - if the working directory of this notebook has a "../data/" folder with data loaded (e.g. working on local computer or have data files loaded to a cloud VM) then use the "local files" option and comment out the "gcp bucket files" option
 - if this notebook is being run from a GCP VM (preferably in the `us-central1` location) then use the "gcp bucket files" option and comment out the "local files" option

In [4]:
# Select the data source for this run. Exactly one of the two assignments
# below should be active; comment out the other.
local_or_cloud: str = "local"      # option: local files
#local_or_cloud: str = "cloud"     # option: gcp bucket files

# Nothing below this line needs to be edited.
# Validate the switch once, here; later cells branch on `local_or_cloud`
# without repeating this check.
if local_or_cloud not in ("local", "cloud"):
    raise ValueError("Variable 'local_or_cloud' can only take on one of two values, 'local' or 'cloud'.")

# Bind the path dictionaries that match the selected data source.
data_paths, snapshot_paths = (
    (local_data_paths, local_snapshot_paths)
    if local_or_cloud == "local"
    else (gcp_data_paths, gcp_snapshot_paths)
)

In [5]:
# GCP handles are only created for cloud runs. In local mode both names stay
# bound to None, so running this cell is harmless either way.
gcp_storage_client: storage.Client = None
gcp_bucket: storage.Bucket = None

if local_or_cloud == "cloud":
    # Build the storage client first; the bucket lookup needs it.
    gcp_storage_client = tur.get_gcp_storage_client(
        project_name=gcp_project_name,
        key_file=gcp_key_file,
    )
    gcp_bucket = tur.get_gcp_bucket(
        storage_client=gcp_storage_client,
        bucket_name=gcp_bucket_name,
    )

In [6]:
# Load the post-EDA snapshot into `df_full`.
# note this cell requires package `pyarrow` to be installed in environment
parq_filename: str = "data_after_03_eda.parquet.gz"
parq_path: str = f"{snapshot_paths['parq_snapshot']}{parq_filename}"

if (local_or_cloud == "local"):
    df_full = pd.read_parquet(parq_path, engine='pyarrow')
elif (local_or_cloud == "cloud"):
    # TODO: implement loading of cloud file
    # Fail loudly here: silently skipping the load would leave `df_full`
    # undefined and surface later as an unrelated NameError.
    raise NotImplementedError(
        f"Loading '{parq_path}' from the GCP bucket is not implemented yet; "
        "set local_or_cloud = 'local' to run this notebook."
    )

# BERT With Tabular Data 

In [7]:
from dataclasses import dataclass, field
import json
import logging
import os
from typing import Optional

import numpy as np
import pandas as pd
from transformers import (
    AutoTokenizer,
    AutoConfig,
    Trainer,
    EvalPrediction,
    set_seed
)
from transformers.training_args import TrainingArguments

from multimodal_transformers.data import load_data_from_folder
from multimodal_transformers.model import TabularConfig
from multimodal_transformers.model import AutoModelWithTabular

# Turn off Comet experiment tracking through its environment switch.
os.environ["COMET_MODE"] = "DISABLED"
# Emit INFO-level log messages from the root logger.
logging.basicConfig(level=logging.INFO)

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# Work on a reproducible random sample of 8,000 rows from the full dataset.
# All columns are kept; the text/feature/label columns are selected later.
bert_df = df_full.sample(8000, random_state=1)

In [9]:
# Preview the first rows of the sample to sanity-check its columns and values
bert_df.head()

Unnamed: 0,external_author_id,author,content,region,language,following,followers,updates,post_type,is_retweet,...,tco1_step1,data_source,has_url,emoji_text,emoji_count,publish_date,class,following_ratio,class_numeric,RUS_lett_count
1459252,3252787160,FUNDDIET,http://t.co/ZUs02Wqv2c Rooftop 5 movables so exercise phuket emT @geeoharee @apol___ @Llydisblur @LaylaRalphs @RachelAucker @Billygibson94,United States,en,5,347,12844,,0.0,...,https://twitter.com/safety/unsafe_link_warning?unsafe_link=http%3A%2F%2Fwww.LoseFat.pw%2FRooftop--movables-so-exercise-phuket-emT.html,Troll,1,[],0,2015-07-11 19:52:00+00:00,Troll,0.014368,1,0
2711438,115430404,dpbrugler,"@215khalil 6-0, 225. We'll see what the Combine says.",Northeast Ohio,en,842,135515,1,replied_to,0.0,...,,verified_random,0,[],0,2017-12-03 01:24:51+00:00,Verified,0.006213,0,0
2052811,902000000000000000,PAMELA_SHARKY13,Toxic masculinity and privilege. https://t.co/EE9xVit4d4,United States,en,1305,587,64,RETWEET,1.0,...,https://twitter.com/ABC/status/901907098325352448/photo/1,Troll,1,[],0,2017-08-28 01:28:00+00:00,Troll,2.219388,1,0
3556050,208775627,B_McGee32,RT @MrMcGee33: Yessirrr headed to Hard Rock with my bro @B_McGee21 #Vibin #Family #Brothers,FLL/NYC/LA,en,2522,8203,1,retweeted,1.0,...,,verified_random,0,[],0,2013-07-06 02:44:33+00:00,Verified,0.307411,0,0
1188826,2494112058,DAILYSANJOSE,#SanJose Oakland Singer Kehlani Survives Suicide Attempt,United States,en,4282,16829,15573,,0.0,...,,Troll,0,[],0,2016-03-29 23:04:00+00:00,Troll,0.254427,1,0


In [10]:
# Shuffle the sample and cut it into 80% train / 10% validation / 10% test.
# The shuffle is seeded so the split (and the CSV files written below) is
# identical on every run; the seed matches the one used to draw `bert_df`.
shuffled_df = bert_df.sample(frac=1, random_state=1)
train_end = int(.8 * len(shuffled_df))
val_end = int(.9 * len(shuffled_df))

# Positional slices give three disjoint, contiguous partitions of the shuffle.
train_df = shuffled_df.iloc[:train_end]
val_df = shuffled_df.iloc[train_end:val_end]
test_df = shuffled_df.iloc[val_end:]

print('Num examples train-val-test')
print(len(train_df), len(val_df), len(test_df))

# Written to the working directory; the data-loading cell below reads
# train.csv / val.csv / test.csv from that same folder (data_path='.').
train_df.to_csv('train.csv')
val_df.to_csv('val.csv')
test_df.to_csv('test.csv')

Num examples train-val-test
6400 800 800


# Experiment Parameters

In [11]:
@dataclass
class ModelArguments:
    """
    Identifies the pretrained model, config and tokenizer to fine-tune from.

    Only `model_name_or_path` is required; the remaining fields default to
    None.
    """

    # Required: local path or huggingface.co/models identifier.
    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier from huggingface.co/models")
    )
    # Optional override when the config differs from the model name.
    config_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained config name or path if not the same as model_name"),
    )
    # Optional override when the tokenizer differs from the model name.
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata=dict(help="Pretrained tokenizer name or path if not the same as model_name"),
    )
    # Optional download/cache directory for pretrained artifacts.
    cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(help="Where do you want to store the pretrained models downloaded from s3"),
    )


@dataclass
class MultimodalDataTrainingArguments:
    """
    Arguments pertaining to how we combine tabular features.

    Using `HfArgumentParser` we can turn this class into argparse arguments
    to be able to specify them on the command line.

    The column layout is supplied either directly as `column_info` or as a
    JSON file via `column_info_path`; when only the path is given the file is
    loaded into `column_info` during `__post_init__`.
    """

    # Required: location of the dataset.
    data_path: str = field(metadata={
        'help': 'the path to the csv file containing the dataset'
    })
    # Optional JSON file describing the columns (alternative to `column_info`).
    column_info_path: Optional[str] = field(
        default=None,
        metadata={
            'help': 'the path to the json file detailing which columns are text, categorical, numerical, and the label'
        })

    # Optional in-memory column description (alternative to `column_info_path`).
    column_info: Optional[dict] = field(
        default=None,
        metadata={
            'help': 'a dict referencing the text, categorical, numerical, and label columns; '
                    'its keys are text_cols, num_cols, cat_cols, and label_col'
        })

    categorical_encode_type: str = field(
        default='ohe',
        metadata={
            'help': 'sklearn encoder to use for categorical data',
            'choices': ['ohe', 'binary', 'label', 'none']
        })
    numerical_transformer_method: str = field(
        default='yeo_johnson',
        metadata={
            'help': 'sklearn numerical transformer to preprocess numerical data',
            'choices': ['yeo_johnson', 'box_cox', 'quantile_normal', 'none']
        })
    task: str = field(
        default="classification",
        metadata={
            "help": "The downstream training task",
            "choices": ["classification", "regression"]
        })

    mlp_division: int = field(
        default=4,
        metadata={
            'help': 'the ratio of the number of '
                    'hidden dims in a current layer to the next MLP layer'
        })
    combine_feat_method: str = field(
        default='individual_mlps_on_cat_and_numerical_feats_then_concat',
        metadata={
            'help': 'method to combine categorical and numerical features, '
                    'see README for all the method'
        })
    mlp_dropout: float = field(
        default=0.1,
        metadata={
            'help': 'dropout ratio used for MLP layers'
        })
    numerical_bn: bool = field(
        default=True,
        metadata={
            'help': 'whether to use batchnorm on numerical features'
        })
    # Annotated as bool to match its default (it was previously declared `str`).
    use_simple_classifier: bool = field(
        default=True,
        metadata={
            'help': 'whether to use single layer or MLP as final classifier'
        })
    mlp_act: str = field(
        default='relu',
        metadata={
            'help': 'the activation function to use for finetuning layers',
            'choices': ['relu', 'prelu', 'sigmoid', 'tanh', 'linear']
        })
    gating_beta: float = field(
        default=0.2,
        metadata={
            'help': "the beta hyperparameters used for gating tabular data "
                    "see https://www.aclweb.org/anthology/2020.acl-main.214.pdf"
        })

    def __post_init__(self):
        """Validate the column description and load it from JSON if needed.

        Raises:
            AssertionError: if `column_info` and `column_info_path` compare
                equal, which in practice means neither was provided.
        """
        assert self.column_info != self.column_info_path, (
            "provide the column layout via `column_info` or `column_info_path`"
        )
        # An explicitly passed dict wins; the file is only read as a fallback.
        if self.column_info is None and self.column_info_path:
            with open(self.column_info_path, 'r', encoding='utf-8') as f:
                self.column_info = json.load(f)

In [12]:
# Define which dataframe columns feed each input type of the model
text_cols = ['content']

# For a simple first model only a handful of columns are used. Change accordingly!
categorical_cols = ['post_type']
numerical_cols = ['has_url','emoji_count', 'following_ratio', 'following','followers']


#### The code below can be used to pick up a large mass of features by dtype instead.
#### Keep in mind to remove the account category, data_source, class and class_numeric columns.
# categorical_cols = bert_df.select_dtypes(include='category').columns.tolist()
# numerical_cols = bert_df.select_dtypes(include=['int64','float64','uint64']).columns.tolist()

column_info_dict = {
    'text_cols': text_cols,
    'num_cols': numerical_cols,
    'cat_cols': categorical_cols,
    'label_col': 'class_numeric',
    # Class names indexed by the value of `class_numeric`: in the data shown
    # above, 0 is a 'Verified' account and 1 is a 'Troll' account.
    'label_list': ['Verified', 'Troll']
}


# Pretrained checkpoint to fine-tune; config and tokenizer default to the same name.
model_args = ModelArguments(model_name_or_path="distilbert-base-uncased-finetuned-sst-2-english")

# Data/feature-combination settings. The CSV splits are read from the current
# folder; every option not listed here keeps its dataclass default.
data_args = MultimodalDataTrainingArguments(
    task="classification",
    column_info=column_info_dict,
    combine_feat_method="weighted_feature_sum_on_transformer_cat_and_numerical_feats",
    data_path=".",
)

# Alternative kept for reference (the original default combine method):
# data_args = MultimodalDataTrainingArguments(
#     data_path='.',
#     combine_feat_method='gating_on_cat_and_num_feats_then_sum',
#     column_info=column_info_dict,
#     task='classification'
# )

# Trainer hyperparameters for this experiment.
# NOTE(review): `evaluate_during_training` belongs to the transformers release
# this notebook was executed with; confirm the name before upgrading transformers.
training_args = TrainingArguments(
    # where checkpoints and logs are written
    output_dir="./logs/model_name",
    logging_dir="./logs/runs",
    overwrite_output_dir=True,
    # what to run
    do_train=True,
    do_eval=True,
    evaluate_during_training=True,
    # optimisation settings
    num_train_epochs=1,
    per_device_train_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    # reporting cadence, in optimisation steps
    logging_steps=25,
    eval_steps=30,
)

# Default arguments, kept for reference:
# training_args = TrainingArguments(
#     output_dir="./logs/model_name",
#     logging_dir="./logs/runs",
#     overwrite_output_dir=True,
#     do_train=True,
#     do_eval=True,
#     per_device_train_batch_size=32,
#     num_train_epochs=1,
#     evaluate_during_training=True,
#     logging_steps=25,
#     eval_steps=250
# )

# Seed the random number generators with the seed carried by the training arguments.
set_seed(training_args.seed)

In [13]:
# Use the dedicated tokenizer name when one is set, otherwise the model's own.
tokenizer_path_or_name = model_args.tokenizer_name or model_args.model_name_or_path
print('Specified tokenizer: ', tokenizer_path_or_name)

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path_or_name, cache_dir=model_args.cache_dir)

Specified tokenizer:  distilbert-base-uncased-finetuned-sst-2-english


In [14]:
# Build the train / validation / test datasets from the CSV files found in
# `data_args.data_path`, using the column roles declared in `column_info`.
train_dataset, val_dataset, test_dataset = load_data_from_folder(
    data_args.data_path,
    data_args.column_info['text_cols'],
    tokenizer,
    # tabular feature columns
    categorical_cols=data_args.column_info['cat_cols'],
    numerical_cols=data_args.column_info['num_cols'],
    # target column and its class names
    label_col=data_args.column_info['label_col'],
    label_list=data_args.column_info['label_list'],
    # separator string taken from the tokenizer's own sep token
    sep_text_token_str=tokenizer.sep_token,
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[num_cols] = df[num_cols].fillna(df[num_cols].median())
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df[num_cols] = df[num_cols].fillna(df[num_cols].median())
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_g

In [15]:
# Number of distinct classes present in the training labels
num_labels = np.unique(train_dataset.labels).size
num_labels

2

In [16]:
# Load the transformer config, preferring an explicit config name when given.
config = AutoConfig.from_pretrained(
    model_args.config_name or model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
)

# Describe the tabular side of the model: feature widths come from the
# training set, every other option from the data arguments.
tabular_config = TabularConfig(
    numerical_feat_dim=train_dataset.numerical_feats.shape[1],
    cat_feat_dim=train_dataset.cat_feats.shape[1],
    num_labels=num_labels,
    **vars(data_args),
)

# Attach the tabular settings to the transformer config.
config.tabular_config = tabular_config

In [17]:
# Instantiate the text+tabular model. Weights always come from
# `model_name_or_path`; `config_name` only selects the configuration, which
# is already applied through `config=`. (Previously a non-empty `config_name`
# was also used as the weights location, ignoring `model_name_or_path`.)
model = AutoModelWithTabular.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir
    )

Some weights of DistilBertWithTabular were not initialized from the model checkpoint at distilbert-base-uncased-finetuned-sst-2-english and are newly initialized: ['tabular_combiner.weight_cat', 'tabular_combiner.weight_num', 'tabular_combiner.num_bn.weight', 'tabular_combiner.num_bn.bias', 'tabular_combiner.num_bn.running_mean', 'tabular_combiner.num_bn.running_var', 'tabular_combiner.cat_layer.weight', 'tabular_combiner.cat_layer.bias', 'tabular_combiner.num_layer.weight', 'tabular_combiner.num_layer.bias', 'tabular_combiner.layer_norm.weight', 'tabular_combiner.layer_norm.bias', 'tabular_classifier.weight', 'tabular_classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
import numpy as np
from scipy.special import softmax
from sklearn.metrics import (
    auc,
    precision_recall_curve,
    roc_auc_score,
    f1_score,
    confusion_matrix,
    matthews_corrcoef,
)

In [19]:
def calc_classification_metrics(p: EvalPrediction):
    """Compute evaluation metrics from raw model outputs.

    Args:
        p: prediction bundle; `p.predictions` holds per-class scores (one
            column per class) and `p.label_ids` the true labels.

    Returns:
        dict of metrics.

        When the true labels contain exactly two distinct values:
        `roc_auc`, `pr_auc`, the decision `threshold` that maximises F1 on
        the precision-recall curve, and the `recall` / `precision` / `f1`
        at that threshold, plus `tn` / `fp` / `fn` / `tp`. Note the
        confusion-matrix counts are computed from the argmax labels, not
        from the F1-optimal threshold, so they need not agree with the
        reported precision/recall.

        Otherwise: `acc`, `f1`, `acc_and_f1` and `mcc`.
    """
    pred_labels = np.argmax(p.predictions, axis=1)
    labels = p.label_ids

    if len(np.unique(labels)) == 2:  # binary classification
        # Probability assigned to class 1 drives the threshold-based metrics.
        pred_scores = softmax(p.predictions, axis=1)[:, 1]
        roc_auc_pred_score = roc_auc_score(labels, pred_scores)
        precisions, recalls, thresholds = precision_recall_curve(labels, pred_scores)

        # 0/0 occurs where precision and recall are both zero; silence the
        # warning and map those points to an F-score of 0.
        with np.errstate(divide='ignore', invalid='ignore'):
            fscore = (2 * precisions * recalls) / (precisions + recalls)
        fscore[np.isnan(fscore)] = 0

        ix = np.argmax(fscore)
        # `thresholds` is one element shorter than precisions/recalls; clamp
        # defensively so the lookup can never run past its end.
        threshold = thresholds[min(ix, len(thresholds) - 1)].item()
        pr_auc = auc(recalls, precisions)
        tn, fp, fn, tp = confusion_matrix(labels, pred_labels, labels=[0, 1]).ravel()
        result = {'roc_auc': roc_auc_pred_score,
                  'threshold': threshold,
                  'pr_auc': pr_auc,
                  'recall': recalls[ix].item(),
                  'precision': precisions[ix].item(), 'f1': fscore[ix].item(),
                  'tn': tn.item(), 'fp': fp.item(), 'fn': fn.item(), 'tp': tp.item()
                  }
    else:
        acc = (pred_labels == labels).mean()
        # f1_score's default average='binary' raises on multiclass targets,
        # so switch to macro averaging as soon as a class other than 0/1
        # appears in the labels or predictions.
        observed_classes = np.union1d(labels, pred_labels)
        average = 'binary' if np.isin(observed_classes, [0, 1]).all() else 'macro'
        f1 = f1_score(y_true=labels, y_pred=pred_labels, average=average)
        result = {
            "acc": acc,
            "f1": f1,
            "acc_and_f1": (acc + f1) / 2,
            "mcc": matthews_corrcoef(labels, pred_labels)
        }

    return result

In [20]:
# Wire the model, the training arguments and the datasets into a Trainer.
# The validation split is used for evaluation and scored with the custom
# metric function defined above.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=calc_classification_metrics,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)

In [21]:
%%time
# Run fine-tuning with the Trainer built above; the %%time cell magic reports
# how long the whole cell took.
trainer.train()

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

{'loss': 0.6647732543945313, 'learning_rate': 2.5e-06, 'epoch': 0.03125, 'step': 25}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:33<00:00,  2.94it/s]


{'eval_loss': 0.6497141176462173, 'eval_roc_auc': 0.7498064359254922, 'eval_threshold': 0.3665297329425812, 'eval_pr_auc': 0.80073216574209, 'eval_recall': 0.9392624728850325, 'eval_precision': 0.6744548286604362, 'eval_f1': 0.7851314596554851, 'eval_tn': 291, 'eval_fp': 48, 'eval_fn': 265, 'eval_tp': 196, 'epoch': 0.0375, 'step': 30}




{'loss': 0.6794577026367188, 'learning_rate': 5e-06, 'epoch': 0.0625, 'step': 50}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:33<00:00,  2.97it/s]


{'eval_loss': 0.6128870996832848, 'eval_roc_auc': 0.7860749044977252, 'eval_threshold': 0.50742506980896, 'eval_pr_auc': 0.8319046507260692, 'eval_recall': 0.8221258134490239, 'eval_precision': 0.7460629921259843, 'eval_f1': 0.782249742002064, 'eval_tn': 193, 'eval_fp': 146, 'eval_fn': 77, 'eval_tp': 384, 'epoch': 0.075, 'step': 60}




{'loss': 0.6057778930664063, 'learning_rate': 7.5e-06, 'epoch': 0.09375, 'step': 75}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:34<00:00,  2.93it/s]


{'eval_loss': 0.576728375852108, 'eval_roc_auc': 0.8370350462954076, 'eval_threshold': 0.5470166206359863, 'eval_pr_auc': 0.878688707919508, 'eval_recall': 0.8785249457700651, 'eval_precision': 0.7555970149253731, 'eval_f1': 0.8124373119358074, 'eval_tn': 158, 'eval_fp': 181, 'eval_fn': 33, 'eval_tp': 428, 'epoch': 0.1125, 'step': 90}




{'loss': 0.5927204895019531, 'learning_rate': 1e-05, 'epoch': 0.125, 'step': 100}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:33<00:00,  2.95it/s]


{'eval_loss': 0.5636800557374955, 'eval_roc_auc': 0.837860493092482, 'eval_threshold': 0.5075240135192871, 'eval_pr_auc': 0.8861573177432802, 'eval_recall': 0.8459869848156182, 'eval_precision': 0.7707509881422925, 'eval_f1': 0.8066184074457082, 'eval_tn': 217, 'eval_fp': 122, 'eval_fn': 69, 'eval_tp': 392, 'epoch': 0.15, 'step': 120}




{'loss': 0.6069303894042969, 'learning_rate': 1.25e-05, 'epoch': 0.15625, 'step': 125}




{'loss': 0.5098452758789063, 'learning_rate': 1.5e-05, 'epoch': 0.1875, 'step': 150}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:34<00:00,  2.92it/s]


{'eval_loss': 0.5252826407551765, 'eval_roc_auc': 0.8442976983471867, 'eval_threshold': 0.39766940474510193, 'eval_pr_auc': 0.8931826456684213, 'eval_recall': 0.9457700650759219, 'eval_precision': 0.6844583987441131, 'eval_f1': 0.7941712204007285, 'eval_tn': 190, 'eval_fp': 149, 'eval_fn': 70, 'eval_tp': 391, 'epoch': 0.1875, 'step': 150}




{'loss': 0.5280322265625, 'learning_rate': 1.75e-05, 'epoch': 0.21875, 'step': 175}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:33<00:00,  2.97it/s]


{'eval_loss': 0.47772188782691954, 'eval_roc_auc': 0.9069740656134221, 'eval_threshold': 0.6926343441009521, 'eval_pr_auc': 0.9333277356570475, 'eval_recall': 0.824295010845987, 'eval_precision': 0.8816705336426914, 'eval_f1': 0.852017937219731, 'eval_tn': 178, 'eval_fp': 161, 'eval_fn': 23, 'eval_tp': 438, 'epoch': 0.225, 'step': 180}




{'loss': 0.471778564453125, 'learning_rate': 2e-05, 'epoch': 0.25, 'step': 200}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:35<00:00,  2.81it/s]


{'eval_loss': 0.4343381017446518, 'eval_roc_auc': 0.9101414777417312, 'eval_threshold': 0.7685887217521667, 'eval_pr_auc': 0.937415681591803, 'eval_recall': 0.7830802603036876, 'eval_precision': 0.9093198992443325, 'eval_f1': 0.8414918414918414, 'eval_tn': 200, 'eval_fp': 139, 'eval_fn': 45, 'eval_tp': 416, 'epoch': 0.2625, 'step': 210}




{'loss': 0.4869427490234375, 'learning_rate': 2.25e-05, 'epoch': 0.28125, 'step': 225}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:33<00:00,  2.96it/s]


{'eval_loss': 0.41604275032877924, 'eval_roc_auc': 0.9286212478963904, 'eval_threshold': 0.5311585068702698, 'eval_pr_auc': 0.9521255763263328, 'eval_recall': 0.8546637744034707, 'eval_precision': 0.9016018306636155, 'eval_f1': 0.8775055679287306, 'eval_tn': 282, 'eval_fp': 57, 'eval_fn': 61, 'eval_tp': 400, 'epoch': 0.3, 'step': 240}




{'loss': 0.3952154541015625, 'learning_rate': 2.5e-05, 'epoch': 0.3125, 'step': 250}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:37<00:00,  2.67it/s]


{'eval_loss': 0.3758574853092432, 'eval_roc_auc': 0.9225999654464132, 'eval_threshold': 0.16759763658046722, 'eval_pr_auc': 0.9441077660882561, 'eval_recall': 0.9609544468546638, 'eval_precision': 0.7771929824561403, 'eval_f1': 0.8593598448108632, 'eval_tn': 283, 'eval_fp': 56, 'eval_fn': 87, 'eval_tp': 374, 'epoch': 0.3375, 'step': 270}




{'loss': 0.3702630615234375, 'learning_rate': 2.7500000000000004e-05, 'epoch': 0.34375, 'step': 275}




{'loss': 0.2929193115234375, 'learning_rate': 3e-05, 'epoch': 0.375, 'step': 300}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:35<00:00,  2.82it/s]


{'eval_loss': 0.3088532718271017, 'eval_roc_auc': 0.9523032525163329, 'eval_threshold': 0.587332010269165, 'eval_pr_auc': 0.9648151382433632, 'eval_recall': 0.8806941431670282, 'eval_precision': 0.9206349206349206, 'eval_f1': 0.9002217294900222, 'eval_tn': 279, 'eval_fp': 60, 'eval_fn': 42, 'eval_tp': 419, 'epoch': 0.375, 'step': 300}




{'loss': 0.3318585205078125, 'learning_rate': 3.2500000000000004e-05, 'epoch': 0.40625, 'step': 325}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:35<00:00,  2.81it/s]


{'eval_loss': 0.3064725860580802, 'eval_roc_auc': 0.9529943242534186, 'eval_threshold': 0.3137039542198181, 'eval_pr_auc': 0.9662839101974545, 'eval_recall': 0.9262472885032538, 'eval_precision': 0.8661257606490872, 'eval_f1': 0.8951781970649896, 'eval_tn': 293, 'eval_fp': 46, 'eval_fn': 55, 'eval_tp': 406, 'epoch': 0.4125, 'step': 330}




{'loss': 0.2799078369140625, 'learning_rate': 3.5e-05, 'epoch': 0.4375, 'step': 350}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:35<00:00,  2.82it/s]


{'eval_loss': 0.32553794866427777, 'eval_roc_auc': 0.9610696254775114, 'eval_threshold': 0.8730078935623169, 'eval_pr_auc': 0.9729411116888735, 'eval_recall': 0.8872017353579176, 'eval_precision': 0.9337899543378996, 'eval_f1': 0.9098998887652948, 'eval_tn': 232, 'eval_fp': 107, 'eval_fn': 19, 'eval_tp': 442, 'epoch': 0.45, 'step': 360}




{'loss': 0.33830078125, 'learning_rate': 3.7500000000000003e-05, 'epoch': 0.46875, 'step': 375}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:35<00:00,  2.80it/s]


{'eval_loss': 0.2349852030724287, 'eval_roc_auc': 0.9713269217233281, 'eval_threshold': 0.5543552041053772, 'eval_pr_auc': 0.9804327249408961, 'eval_recall': 0.9327548806941431, 'eval_precision': 0.9110169491525424, 'eval_f1': 0.9217577706323689, 'eval_tn': 287, 'eval_fp': 52, 'eval_fn': 29, 'eval_tp': 432, 'epoch': 0.4875, 'step': 390}




{'loss': 0.1958428955078125, 'learning_rate': 4e-05, 'epoch': 0.5, 'step': 400}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:35<00:00,  2.80it/s]


{'eval_loss': 0.38239360524341465, 'eval_roc_auc': 0.9551635216503817, 'eval_threshold': 0.8944015502929688, 'eval_pr_auc': 0.9696489973157499, 'eval_recall': 0.8676789587852495, 'eval_precision': 0.9195402298850575, 'eval_f1': 0.8928571428571429, 'eval_tn': 248, 'eval_fp': 91, 'eval_fn': 37, 'eval_tp': 424, 'epoch': 0.525, 'step': 420}




{'loss': 0.2552130126953125, 'learning_rate': 4.25e-05, 'epoch': 0.53125, 'step': 425}




{'loss': 0.275621337890625, 'learning_rate': 4.5e-05, 'epoch': 0.5625, 'step': 450}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:35<00:00,  2.80it/s]


{'eval_loss': 0.23003127660602332, 'eval_roc_auc': 0.9772842160495012, 'eval_threshold': 0.5292146801948547, 'eval_pr_auc': 0.9832305263752742, 'eval_recall': 0.9674620390455532, 'eval_precision': 0.9065040650406504, 'eval_f1': 0.9359916054564533, 'eval_tn': 283, 'eval_fp': 56, 'eval_fn': 13, 'eval_tp': 448, 'epoch': 0.5625, 'step': 450}




{'loss': 0.248895263671875, 'learning_rate': 4.75e-05, 'epoch': 0.59375, 'step': 475}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:34<00:00,  2.92it/s]


{'eval_loss': 0.25397968431934714, 'eval_roc_auc': 0.9623045962669328, 'eval_threshold': 0.40620478987693787, 'eval_pr_auc': 0.973547833204573, 'eval_recall': 0.9457700650759219, 'eval_precision': 0.8667992047713717, 'eval_f1': 0.904564315352697, 'eval_tn': 281, 'eval_fp': 58, 'eval_fn': 36, 'eval_tp': 425, 'epoch': 0.6, 'step': 480}




{'loss': 0.2073883056640625, 'learning_rate': 5e-05, 'epoch': 0.625, 'step': 500}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:33<00:00,  2.96it/s]


{'eval_loss': 0.22931221047416328, 'eval_roc_auc': 0.971006981104307, 'eval_threshold': 0.5818248987197876, 'eval_pr_auc': 0.9801373481123496, 'eval_recall': 0.9262472885032538, 'eval_precision': 0.9262472885032538, 'eval_f1': 0.9262472885032538, 'eval_tn': 289, 'eval_fp': 50, 'eval_fn': 25, 'eval_tp': 436, 'epoch': 0.6375, 'step': 510}




{'loss': 0.2472711181640625, 'learning_rate': 4.5833333333333334e-05, 'epoch': 0.65625, 'step': 525}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:33<00:00,  2.95it/s]


{'eval_loss': 0.22931823952123523, 'eval_roc_auc': 0.9740336193602467, 'eval_threshold': 0.6190052628517151, 'eval_pr_auc': 0.9804670373878233, 'eval_recall': 0.9566160520607375, 'eval_precision': 0.901840490797546, 'eval_f1': 0.9284210526315789, 'eval_tn': 274, 'eval_fp': 65, 'eval_fn': 13, 'eval_tp': 448, 'epoch': 0.675, 'step': 540}




{'loss': 0.1874310302734375, 'learning_rate': 4.166666666666667e-05, 'epoch': 0.6875, 'step': 550}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:33<00:00,  2.97it/s]


{'eval_loss': 0.16811460881493986, 'eval_roc_auc': 0.9835870462442171, 'eval_threshold': 0.3736043870449066, 'eval_pr_auc': 0.9874276328890481, 'eval_recall': 0.9761388286334056, 'eval_precision': 0.9297520661157025, 'eval_f1': 0.9523809523809524, 'eval_tn': 313, 'eval_fp': 26, 'eval_fn': 22, 'eval_tp': 439, 'epoch': 0.7125, 'step': 570}




{'loss': 0.2383721923828125, 'learning_rate': 3.7500000000000003e-05, 'epoch': 0.71875, 'step': 575}




{'loss': 0.2057196044921875, 'learning_rate': 3.3333333333333335e-05, 'epoch': 0.75, 'step': 600}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:37<00:00,  2.69it/s]


{'eval_loss': 0.18004311225377023, 'eval_roc_auc': 0.9828703792576098, 'eval_threshold': 0.4256266951560974, 'eval_pr_auc': 0.9863984580905822, 'eval_recall': 0.9696312364425163, 'eval_precision': 0.93125, 'eval_f1': 0.9500531349628055, 'eval_tn': 314, 'eval_fp': 25, 'eval_fn': 28, 'eval_tp': 433, 'epoch': 0.75, 'step': 600}




{'loss': 0.2328021240234375, 'learning_rate': 2.916666666666667e-05, 'epoch': 0.78125, 'step': 625}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:33<00:00,  2.96it/s]


{'eval_loss': 0.17590435835532844, 'eval_roc_auc': 0.9818017775900792, 'eval_threshold': 0.4041953682899475, 'eval_pr_auc': 0.9864665531727541, 'eval_recall': 0.96529284164859, 'eval_precision': 0.9213250517598344, 'eval_f1': 0.9427966101694916, 'eval_tn': 309, 'eval_fp': 30, 'eval_fn': 29, 'eval_tp': 432, 'epoch': 0.7875, 'step': 630}




{'loss': 0.1886663818359375, 'learning_rate': 2.5e-05, 'epoch': 0.8125, 'step': 650}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:34<00:00,  2.87it/s]


{'eval_loss': 0.1576316593028605, 'eval_roc_auc': 0.9921870500835045, 'eval_threshold': 0.716361939907074, 'eval_pr_auc': 0.9942300717480501, 'eval_recall': 0.9848156182212582, 'eval_precision': 0.9700854700854701, 'eval_f1': 0.977395048439182, 'eval_tn': 289, 'eval_fp': 50, 'eval_fn': 3, 'eval_tp': 458, 'epoch': 0.825, 'step': 660}




{'loss': 0.19143798828125, 'learning_rate': 2.0833333333333336e-05, 'epoch': 0.84375, 'step': 675}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:34<00:00,  2.88it/s]


{'eval_loss': 0.1566883626859635, 'eval_roc_auc': 0.9868184464963302, 'eval_threshold': 0.6048450469970703, 'eval_pr_auc': 0.9903954221336102, 'eval_recall': 0.9761388286334056, 'eval_precision': 0.9433962264150944, 'eval_f1': 0.9594882729211087, 'eval_tn': 303, 'eval_fp': 36, 'eval_fn': 10, 'eval_tp': 451, 'epoch': 0.8625, 'step': 690}




{'loss': 0.209776611328125, 'learning_rate': 1.6666666666666667e-05, 'epoch': 0.875, 'step': 700}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:33<00:00,  2.95it/s]


{'eval_loss': 0.15613988646306098, 'eval_roc_auc': 0.9859290115754515, 'eval_threshold': 0.4784497022628784, 'eval_pr_auc': 0.990077187985346, 'eval_recall': 0.9674620390455532, 'eval_precision': 0.9389473684210526, 'eval_f1': 0.952991452991453, 'eval_tn': 311, 'eval_fp': 28, 'eval_fn': 22, 'eval_tp': 439, 'epoch': 0.9, 'step': 720}




{'loss': 0.2079052734375, 'learning_rate': 1.25e-05, 'epoch': 0.90625, 'step': 725}




{'loss': 0.248255615234375, 'learning_rate': 8.333333333333334e-06, 'epoch': 0.9375, 'step': 750}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:33<00:00,  3.02it/s]


{'eval_loss': 0.13570114463567734, 'eval_roc_auc': 0.9893331797618361, 'eval_threshold': 0.568824827671051, 'eval_pr_auc': 0.9921200820510124, 'eval_recall': 0.96529284164859, 'eval_precision': 0.9611231101511879, 'eval_f1': 0.9632034632034632, 'eval_tn': 313, 'eval_fp': 26, 'eval_fn': 13, 'eval_tp': 448, 'epoch': 0.9375, 'step': 750}




{'loss': 0.1956005859375, 'learning_rate': 4.166666666666667e-06, 'epoch': 0.96875, 'step': 775}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
Evaluation: 100%|██████████| 100/100 [00:34<00:00,  2.86it/s]


{'eval_loss': 0.1369975937018171, 'eval_roc_auc': 0.9899282693132154, 'eval_threshold': 0.529875636100769, 'eval_pr_auc': 0.9926170457407513, 'eval_recall': 0.9804772234273319, 'eval_precision': 0.9475890985324947, 'eval_f1': 0.9637526652452025, 'eval_tn': 308, 'eval_fp': 31, 'eval_fn': 8, 'eval_tp': 453, 'epoch': 0.975, 'step': 780}


Iteration: 100%|██████████| 800/800 [42:07<00:00,  3.16s/it]
Epoch: 100%|██████████| 1/1 [42:07<00:00, 2527.80s/it]

{'loss': 0.186212158203125, 'learning_rate': 0.0, 'epoch': 1.0, 'step': 800}
CPU times: total: 4h 11min 1s
Wall time: 42min 7s





TrainOutput(global_step=800, training_loss=0.3399104690551758)

In [22]:
# Load the TensorBoard notebook extension.
# This registers the `%tensorboard` line magic used in the next cell; it requires
# the `tensorboard` package to be importable in the kernel's environment.
%load_ext tensorboard

In [23]:
# Open TensorBoard inside the notebook, reading event files from ./logs/runs
# (relative to the notebook's working directory) and serving on port 6006.
# NOTE(review): ./logs/runs is presumably where the training run above wrote its
# logs -- confirm against the trainer's logging directory.
# Keep comments on their own lines: text after a line magic is passed to it as arguments.
%tensorboard --logdir ./logs/runs --port=6006