In [1]:
import sys
sys.path.append('../structural-probes/')
import torch
import numpy as np
from pytorch_lightning import Trainer
from pytorch_lightning.loggers.comet import CometLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from utils.setup_runs import parse_args, get_default_args, get_comet_key
from finetune_bert_module import SST_Test

import os
import random

INFO:transformers.file_utils:PyTorch version 1.5.0 available.


In [2]:
# Hyperparameter overrides for this run (learning rate and batch size).
desired_hparams = dict(lr=1e-4, batch_size=32)

# Non-hyperparameter overrides: SST-2 train/dev split files and the run name.
desired_params = dict(
    sst_train_path=os.path.join("data", "SST-2", "sentence_splits", "train_cat.tsv"),
    sst_val_path=os.path.join("data", "SST-2", "sentence_splits", "dev_cat.tsv"),
    run_name='proper_classification',
)

# Merge the overrides into the project defaults.
hparams, args = get_default_args(desired_hparams, desired_params)

In [3]:
# Ask the Trainer for one GPU only when the configured device is CUDA
num_gpus = 1 if args.device == torch.device('cuda') else 0

# Debug mode forces debug-level logging
if args.debug:
    args.log_level = "debug"

In [4]:
# Seed torch, numpy and the stdlib RNG (in that order) from the configured seed
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(hparams.seed)

In [5]:
###############
# CometLogger configuration
###############

comet_key = get_comet_key()
comet_logger = CometLogger(
    api_key=comet_key,
    experiment_name=args.run_name,
    project_name="structural-probes-extension",
    workspace="mykobob",
)

# Record the hyperparameters, then name and tag the experiment on comet
comet_logger.log_hyperparams(hparams)
comet_logger.experiment.set_name(args.run_name)
comet_logger.experiment.add_tag("sst-tests")

INFO:lightning:CometLogger will be initialized in online mode
COMET INFO: Experiment is live on comet.ml https://www.comet.ml/mykobob/structural-probes-extension/4e189296c87f48bda80ec112b0a67999



In [6]:
###############
# Other callback configuration
###############

# Keep the `args.num_saved_models` best checkpoints, ranked by lowest val_loss
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=args.num_saved_models,
    filepath=os.path.join("lightning_logs", args.run_name, "checkpoints"),
    verbose=True,
)

In [7]:
# Early stopping is off by default; a truthy `args.early_stopping` enables it
# and doubles as the patience value.
early_stopping = None
if args.early_stopping:
    early_stopping = EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=args.early_stopping,
        min_delta=0.00,
        verbose=False,
    )

In [8]:
# Move the kernel's working directory up one level (IPython `%cd` magic).
# NOTE(review): the target is relative, so re-running this cell climbs another
# level each time. Presumably the relative "data/..." and "lightning_logs/..."
# paths configured in earlier cells are meant to resolve from the new
# directory — confirm.
%cd ..

/home1/06129/mbli/structural-probes


In [9]:
###############
# Model creation
###############

# Flip to True to evaluate a saved checkpoint instead of training a new model.
testing = False
if testing:
    # The two `None`s fill the positional slots after the checkpoint path
    # (map_location and tags_csv in the pytorch-lightning 0.7-era signature —
    # confirm against the installed version); later positionals are forwarded
    # to the SST_Test constructor. This used to pass `params`, a name that is
    # never defined in this notebook, so the branch raised NameError. `args`
    # is the namespace the training branch hands to the constructor.
    model = SST_Test.load_from_checkpoint(
        'finetuning/lightning_logs/default/_ckpt_epoch_2.ckpt', None, None, args).to(args.device)
else:
    # Fresh model built from this notebook's hyperparameters and arguments.
    model = SST_Test(hparams, args).to(args.device)

comet_logger.experiment.add_tag("SST")

INFO:transformers.tokenization_utils:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /home1/06129/mbli/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json from cache at /home1/06129/mbli/.cache/torch/transformers/b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.9da767be51e1327499df13488672789394e2ca38b877837e52618a67d7002391
INFO:transformers.configuration_utils:Model config BertConfig {
  "_num_labels": 2,
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": null,
  "finetuning_task": nul

In [10]:
###############
# Create Trainer with specified attributes
###############

# `max_epochs` replaces the deprecated `max_nb_epochs` spelling (renamed in
# pytorch-lightning 0.5.0, scheduled for removal in 0.8.0). The remaining
# keyword names are left as they are because their replacements only exist in
# later releases.
trainer = Trainer(
    fast_dev_run=args.debug,                  # debug mode runs a quick smoke test
    max_epochs=hparams.epochs,
    gpus=num_gpus,
    train_percent_check=hparams.train_pct,    # fraction of the training set used
    val_percent_check=hparams.val_pct,        # fraction of the validation set used
    checkpoint_callback=checkpoint_callback,
    early_stop_callback=early_stopping,       # None when early stopping is off
    logger=comet_logger,
)

INFO:lightning:GPU available: True, used: True
INFO:lightning:CUDA_VISIBLE_DEVICES: [0]


In [11]:
# Train by default; in testing mode evaluate the loaded checkpoint instead
if not testing:
    trainer.fit(model)
else:
    trainer.test(model)

INFO:lightning:Set SLURM handle signals.
INFO:lightning:
    | Name                                                  | Type                          | Params
----------------------------------------------------------------------------------------------------
0   | loss                                                  | CrossEntropyLoss              | 0     
1   | bert                                                  | BertForSequenceClassification | 108 M 
2   | bert.bert                                             | BertModel                     | 108 M 
3   | bert.bert.embeddings                                  | BertEmbeddings                | 22 M  
4   | bert.bert.embeddings.word_embeddings                  | Embedding                     | 22 M  
5   | bert.bert.embeddings.position_embeddings              | Embedding                     | 393 K 
6   | bert.bert.embeddings.token_type_embeddings            | Embedding                     | 1 K   
7   | bert.bert.embeddings.LayerNo

Validation dataset has 1101 examples
INFO:finetune_bert_module:Validation dataset has 1101 examples


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Training dataset has 8544 examples, and batch_size of 32
INFO:finetune_bert_module:Training dataset has 8544 examples, and batch_size of 32
Validation dataset has 1101 examples
INFO:finetune_bert_module:Validation dataset has 1101 examples


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00000: val_loss reached 0.40108 (best 0.40108), saving model to lightning_logs/proper_classification/_ckpt_epoch_0.ckpt as top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00001: val_loss  was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00002: val_loss  was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00003: val_loss  was not in top 1


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

INFO:lightning:
Epoch 00004: val_loss  was not in top 1
INFO:lightning:Detected KeyboardInterrupt, attempting graceful shutdown...
COMET INFO: ---------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     url                   : https://www.comet.ml/mykobob/structural-probes-extension/4e189296c87f48bda80ec112b0a67999
COMET INFO:   Metrics [count] (min, max):
COMET INFO:     epoch [157]      : (0, 5)
COMET INFO:     mean_pred [5]    : (-0.049582626670598984, 0.06202519312500954)
COMET INFO:     std_pred [5]     : (0.975523829460144, 2.5750234127044678)
COMET INFO:     train_loss [152] : (0.0037087835371494293, 0.769791841506958)
COMET INFO:     val_loss [5]     : (0.40108224749565125, 0.72518390417099)
COMET INFO:   Others [count]:
COMET INFO:     Name [2] : proper_classification
COMET INFO:   Parameters:
COMET INFO:     batch_size  : 32
COMET INFO:     epochs     




COMET INFO: Uploading stats to Comet before program termination (may take several seconds)
