<a href="https://colab.research.google.com/github/m-newhauser/rep-or-dem-tweets/blob/main/finetune_full_architecture_tftrainer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Resources

* [Fine-tuning DistilBERT with TF (freezing last hidden layer from DistilBERT models)](https://towardsdatascience.com/hugging-face-transformers-fine-tuning-distilbert-for-binary-classification-tasks-490f1d192379)
* [Fine-tuning DistilBERT with only TF](https://medium.com/geekculture/hugging-face-distilbert-tensorflow-for-custom-text-classification-1ad4a49e26a7)

In [30]:
!pip install transformers==4.6.0
!pip install tweet-preprocessor



In [31]:
import random
import pandas as pd
import numpy as np
import csv
import tensorflow as tf
import preprocessor as p

from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from transformers import (
    TFDistilBertForSequenceClassification,
    TFTrainer,
    TFTrainingArguments,
)

from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

random.seed(123)

## Pre-process data

In [32]:
# Load data from fivethirtyeight
tweets = pd.read_csv("senators_training.csv")
tweets_validation = pd.read_csv("senators_validation.csv")

# Remove numbers, emojis and &'s
p.set_options(p.OPT.NUMBER, p.OPT.EMOJI)
tweets["text_clean"] = tweets["text"].apply(p.clean)
tweets["text_clean"] = tweets["text_clean"].str.replace("&amp;", "and ")
tweets_validation["text_clean"] = tweets_validation["text"].apply(p.clean)
tweets_validation["text_clean"] = tweets_validation["text_clean"].str.replace("&amp;", "and ")

# Truncate tweets to max 512
tweets["text_clean"] = tweets["text_clean"].str[:512]
tweets_validation["text_clean"] = tweets_validation["text_clean"].str[:512]


In [33]:
# Create a column with numeric labels
label_mapping = {
    "D": 0,
    "R": 1
}

tweets['label'] = np.where(tweets['party'] == "D", 0, 1)
tweets_validation['label'] = np.where(tweets_validation['party'] == "D", 0, 1)

In [34]:
# Convert to list
texts = list(tweets.text)
labels = list(tweets.label)
val_texts = list(tweets_validation.text)
val_labels = list(tweets_validation.label)

# Split training dataset into test and train
(train_texts, test_texts, train_labels, test_labels) = train_test_split(
    texts, labels, test_size=0.3
)

### Tokenize data for DistilBERT

In [35]:
# Load DistilBERT tokenizer and tokenize (encode) the texts
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

train_encodings = tokenizer(train_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)


### Create encodings

In [36]:
# Wrap encodings in a Tensor Flow dataset
train_dataset = tf.data.Dataset.from_tensor_slices((
    dict(train_encodings),
    train_labels
))
val_dataset = tf.data.Dataset.from_tensor_slices((
    dict(val_encodings),
    val_labels
))
test_dataset = tf.data.Dataset.from_tensor_slices((
    dict(test_encodings),
    test_labels
))

## Fine-tune entire DistilBERT architecture (layers)

In [40]:
# Create a dict of metrics to calculate during training
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }


# Provide args for fine-tuning DistilBERT on our data
training_args = TFTrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=4,              # total # of training epochs
    per_device_train_batch_size=32,  # batch size per device during training
    per_device_eval_batch_size=32,   # batch size for evaluation
    learning_rate=2e-05,             # start with a low learning rate when fine-tuning
    warmup_steps=250,                # number of warmup steps for learning rate scheduler ([500, 1000] are normal but start low)
    weight_decay=0.01,               # strength of weight decay
    evaluation_strategy="epoch",
    logging_dir='./logs',            # directory for storing logs
    logging_steps=1
)

# Instantiate the pre-trained model
with training_args.strategy.scope():
    model = TFDistilBertForSequenceClassification.from_pretrained(
        "distilbert-base-uncased", 
        num_labels=2
    )

# Create the trainer
trainer = TFTrainer(
    model=model,  # the instantiated ðŸ¤— Transformers model to be trained
    args=training_args,  # training arguments, defined above
    train_dataset=train_dataset,  # training dataset
    eval_dataset=val_dataset,  # evaluation dataset,
    compute_metrics=compute_metrics # custom function with metrics to compute
)

Some layers from the model checkpoint at distilbert-base-uncased were not used when initializing TFDistilBertForSequenceClassification: ['vocab_layer_norm', 'activation_13', 'vocab_transform', 'vocab_projector']
- This IS expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFDistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFDistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['dropout_179', 'pre_classifier', 'classifier']
You should probably TRAIN this model on a down-stream task to be able to use 

In [41]:
# Train the model
trainer.train()

















In [42]:
# Evaluate the model
trainer.evaluate()









{'eval_accuracy': 0.53125,
 'eval_f1': 0.6445497630331752,
 'eval_loss': 0.6903654098510742,
 'eval_precision': 0.5190839694656488,
 'eval_recall': 0.85}