# Fine-tune a BERT Model for a Classification Task with the Transformers Library, Including Logging to W&B

#### This notebook contains all the code needed to fine-tune a BERT model, plus some extra code. Adapt it to your needs.

## Imports, Logins

In [None]:
import pandas as pd
import numpy as np
import os
import wandb 
import torch
import import_ipynb
import yaml
from sadice import SelfAdjDiceLoss
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig   
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback
from transformers import set_seed, enable_full_determinism
from datasets import Dataset, DatasetDict, disable_caching
from datasets import disable_caching

In [None]:
# Switch off the `datasets` library cache (function imported from `datasets` above).
disable_caching()

In [None]:
import helpers

## Import internal config 

In [None]:
# Load the internal configuration from config.yaml (relative to the working
# directory). The context manager closes the file handle as soon as parsing
# is done; a bare open() call would leave it open until garbage collection.
with open('config.yaml') as config_file:
    conf = yaml.safe_load(config_file)

## Reproducibility

In [None]:
# Seed read from the `seeds` -> `repro_seed` entry of the loaded configuration.
REPRO_SEED = conf['seeds']['repro_seed']
# Project helper from the `helpers` module; presumably seeds the RNGs and/or
# enables deterministic execution — confirm against its implementation.
helpers.enable_reproducability(REPRO_SEED)

## Setup W&B

In [None]:
# Placeholder: replace "my key" with your own W&B API key before running.
# NOTE(review): avoid committing a real key to version control.
os.environ["WANDB_API_KEY"] = "my key"
wandb.login()

In [None]:
# WANDB PARAMS
# WANDB switches between a real W&B run and a disabled one (see the
# `wandb.init` cell in the Training section).
WANDB = True
# Placeholders: replace with your W&B project and entity names.
# NOTE(review): the project placeholder ends with a trailing space.
WANDB_PROJECT = "project name "
WANDB_ENTITY = "project account name"

## Setup Torch Device


In [None]:
# Expose only the GPU with index 1 to this process; adapt to your machine.
# NOTE(review): this has to run before CUDA is first initialized to take
# effect — confirm that no earlier cell touches the GPU.
os.environ["CUDA_VISIBLE_DEVICES"]="1"

In [None]:
# Report whether CUDA can be used, then pick the matching torch device:
# the GPU when one is available, the CPU otherwise.
print("GPU is available: ", torch.cuda.is_available())
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## If needed: Load Config of Model to Rerun (Here W&B)

In [None]:
import wandb  # also imported at the top of the notebook
# Client object used below to fetch a previously logged run.
api = wandb.Api()

In [None]:
# Placeholder: replace "path_to_run" with the path of the W&B run to reload.
# NOTE(review): presumably of the form "entity/project/run_id" — confirm
# against the W&B API documentation.
run = api.run("path_to_run")
# Configuration stored with that run; it supplies the training settings below.
config_run = run.config

In [None]:
# Registry of run configurations, keyed by a run name of your choice;
# several runs can be loaded and compared.
configs = {}
# Store the config fetched from W&B (`config_run`). The previous code
# referenced the undefined name `config_spring`, which raised a NameError.
configs['run_name'] = config_run

## Prepare Data

In [None]:
TEMP_MODEL = 'run_name' # only necessary when using training configurations stored in the configs dictionary
MODEL = configs[TEMP_MODEL]['model'] # alt: provide (path to) model checkpoint
SAMPLING_SEED = configs[TEMP_MODEL]['sampling_seed'] # passed to prepare_data as sampling_seed

In [None]:
# Load the pickled training and validation data; replace 'path_to_data' with
# your data directory. Both objects are later indexed with ["text", "label"].
# NOTE: unpickling can execute arbitrary code — only load files you trust.
train = pd.read_pickle('path_to_data/train.pkl')
val = pd.read_pickle('path_to_data/val.pkl')

In [None]:
# Tokenizer loaded for the same model name / checkpoint path as MODEL.
tokenizer = AutoTokenizer.from_pretrained(MODEL)

In [None]:
def tokenize(batch):
    """Tokenize the ``"text"`` field of *batch* with the notebook-level tokenizer.

    Padding and truncation are both switched on; the tokenizer's output is
    returned unchanged.
    """
    texts = batch["text"]
    encoded = tokenizer(texts, padding=True, truncation=True)
    return encoded

def prepare_data(train, val, remove_footer, remove_emojis, downsampling, sampling_seed):
    """Build the tokenized train/validation ``DatasetDict`` used for training.

    Args:
        train: Training data; must provide "text" and "label" columns after
            ``helpers.select_text`` has been applied.
        val: Validation data, with the same requirements as ``train``.
        remove_footer: Forwarded to ``helpers.select_text``.
        remove_emojis: Forwarded to ``helpers.select_text``.
        downsampling: If truthy, ``train`` is first passed through
            ``helpers.downsample``; the validation data is never downsampled.
        sampling_seed: Seed forwarded to ``helpers.downsample``.

    Returns:
        A ``DatasetDict`` with "train" and "validation" splits holding the
        labels and the tokenizer output; the raw "text" column and any
        pandas index column are removed.
    """
    if downsampling:
        train = helpers.downsample(train, sampling_seed)

    frames = {
        "train": helpers.select_text(train, remove_footer, remove_emojis),
        "validation": helpers.select_text(val, remove_footer, remove_emojis),
    }
    ds = DatasetDict({
        split: Dataset.from_pandas(frame[["text", "label"]])
        for split, frame in frames.items()
    })
    ds_encoded = ds.map(tokenize)

    # "__index_level_0__" only exists when from_pandas had to keep a
    # non-default DataFrame index as a column. Removing a missing column
    # raises, so drop only the columns that are actually present.
    for split in ("train", "validation"):
        present = ds_encoded[split].column_names
        unwanted = [col for col in ("text", "__index_level_0__") if col in present]
        ds_encoded[split] = ds_encoded[split].remove_columns(unwanted)

    return ds_encoded

In [None]:
# Build the tokenized datasets, taking the preprocessing switches from the
# run configuration selected via TEMP_MODEL.
ds_encoded = prepare_data(
    train,
    val,
    remove_footer=configs[TEMP_MODEL]['remove_footer'],
    remove_emojis=configs[TEMP_MODEL]['remove_emojis'],
    downsampling=configs[TEMP_MODEL]['downsampling'],
    sampling_seed=SAMPLING_SEED,
)

## Training

In [None]:
import datetime

# Today's date as a `datetime.date` instance. The module is imported instead
# of the `date` class so that this assignment no longer rebinds the imported
# class name to an instance.
date = datetime.date.today()

### Training params

#### Change according to your needs

In [None]:
# Optimization hyperparameters, read from the run configuration selected via
# TEMP_MODEL; they are passed to TrainingArguments in the Training cell.
NUM_EPOCHS =  configs[TEMP_MODEL]['epochs']
EVAL_BATCH_SIZE = configs[TEMP_MODEL]['per_device_eval_batch_size']
TRAIN_BATCH_SIZE = configs[TEMP_MODEL]['batch_size']
LEARNING_RATE = configs[TEMP_MODEL]['learning_rate'] 
WEIGHT_DECAY = configs[TEMP_MODEL]['weight_decay']
# Dropout probabilities; these are written onto the model config in model_init.
HIDDEN_DROPOUT_PROB = configs[TEMP_MODEL]['hidden_dropout_prob'] 
ATTENTION_PROBS_DROPOUT_PROB = configs[TEMP_MODEL]['attention_probs_dropout_prob']


# Placeholder: replace 'output_path' with the directory for training output.
OUTPUT_DIR = 'output_path'
OVERWRITE_OUTPUT_DIR = True 
# Logging / evaluation / checkpointing settings taken from the run config.
LOG_LEVEL = configs[TEMP_MODEL]['log_level']
EVALUATION_STRATEGY = configs[TEMP_MODEL]['evaluation_strategy'] 
SAVE_STRATEGY =configs[TEMP_MODEL]['save_strategy']
LOGGING_STRATEGY=configs[TEMP_MODEL]['logging_strategy']
# Plain boolean handed to TrainingArguments(load_best_model_at_end=...).
# A trailing comma here would turn the value into the tuple (True,).
LOAD_BEST_MODEL_AT_END = True
# Remaining TrainingArguments settings, read from the selected run config.
METRIC_FOR_BEST_MODEL=configs[TEMP_MODEL]['metric_for_best_model']
REMOVE_UNUSED_COLUMNS=configs[TEMP_MODEL]['remove_unused_columns']
DISABLE_TQDM=configs[TEMP_MODEL]['disable_tqdm']

### Training

In [None]:
# Bundle every training setting defined in the "Training params" cell into
# the argument object consumed by the Trainer.
training_args = TrainingArguments(
    # Output location and checkpointing
    output_dir=OUTPUT_DIR,
    overwrite_output_dir=OVERWRITE_OUTPUT_DIR,
    save_strategy=SAVE_STRATEGY,
    # save_total_limit=1,
    load_best_model_at_end=LOAD_BEST_MODEL_AT_END,
    metric_for_best_model=METRIC_FOR_BEST_MODEL,
    # Optimization
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    # Evaluation, logging and reporting
    evaluation_strategy=EVALUATION_STRATEGY,
    logging_strategy=LOGGING_STRATEGY,
    log_level=LOG_LEVEL,
    disable_tqdm=DISABLE_TQDM,
    report_to="wandb",
    # Data handling
    remove_unused_columns=REMOVE_UNUSED_COLUMNS,
)

In [None]:
# Start the W&B run. When logging is enabled, extra information about the
# preprocessing and the seeds is uploaded as the run config; otherwise W&B is
# initialized in disabled mode.
if not WANDB:
    wandb.init(mode="disabled")
else:
    config = {
        "remove_emojis": configs[TEMP_MODEL]['remove_emojis'],
        "remove_footer": configs[TEMP_MODEL]['remove_footer'],
        "sampling_seed": SAMPLING_SEED,
        "repro_seed": REPRO_SEED,
        "downsampling": configs[TEMP_MODEL]['downsampling'],
    }
    wandb.init(project=WANDB_PROJECT, entity=WANDB_ENTITY, config=config)

In [None]:
# For Dice Loss:

#class CustomTrainer(Trainer):
#    def compute_loss(self,model, inputs, return_outputs=False):
#        
#        criterion = SelfAdjDiceLoss()
#        labels = inputs.get("labels")
#        # forward pass
#        outputs = model(**inputs)
#        logits = outputs.get("logits")
#        loss = criterion(logits, labels)
#        return (loss, outputs) if return_outputs else loss

In [None]:
def model_init():
    """Create a fresh two-label sequence classifier from ``MODEL``.

    The pretrained config is loaded, its dropout probabilities are replaced
    by the notebook-level settings, and the resulting model is moved to
    ``device`` before being returned.
    """
    cfg = AutoConfig.from_pretrained(MODEL)
    cfg.num_labels = 2
    cfg.attention_probs_dropout_prob = ATTENTION_PROBS_DROPOUT_PROB
    cfg.hidden_dropout_prob = HIDDEN_DROPOUT_PROB

    classifier = AutoModelForSequenceClassification.from_pretrained(MODEL, config=cfg)
    return classifier.to(device)

In [None]:
# Assemble the Trainer: the model comes from the `model_init` factory, the
# metrics from the project helper, and the data from the encoded splits.
trainer = Trainer(
    args=training_args,
    model_init=model_init,
    tokenizer=tokenizer,
    train_dataset=ds_encoded["train"],
    eval_dataset=ds_encoded["validation"],
    compute_metrics=helpers.compute_metrics,
)

In [None]:
# Run the fine-tuning. The trailing semicolon keeps the notebook from echoing
# the call's return value as cell output.
trainer.train();

In [None]:
# Persist the trained model. No path is given, so the Trainer's default
# location is used — presumably OUTPUT_DIR; confirm against the Trainer docs.
trainer.save_model() 