<a href="https://colab.research.google.com/github/nsomabalint/explainable-text-classification/blob/master/BERT_sentiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
!pip install wandb transformers datasets shap

Collecting shap
  Downloading shap-0.40.0-cp37-cp37m-manylinux2010_x86_64.whl (564 kB)
[K     |████████████████████████████████| 564 kB 9.7 MB/s 
Collecting slicer==0.0.7
  Downloading slicer-0.0.7-py3-none-any.whl (14 kB)
Installing collected packages: slicer, shap
Successfully installed shap-0.40.0 slicer-0.0.7


In [2]:
%env WANDB_LOG_MODEL=true

env: WANDB_LOG_MODEL=true


In [20]:
import pandas as pd
import numpy as np
import wandb
import torch
import shap
from transformers import AutoTokenizer
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, TextClassificationPipeline
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
from sklearn.preprocessing import LabelEncoder
from random import shuffle


def load_sentiment_data():
    dataset_url = "https://drive.google.com/uc?export=download&id=1WTEViNH8i9a3ethzP5mN07evErguwgcq"
    dataset_df = pd.read_csv(dataset_url)
    dataset_df = dataset_df.sample(frac=1.0)
    
    labels = [*["train"] * 70, *["valid"] * 10, *["test"] * 20] * 1000
    shuffle(labels)

    dataset_df["ds_name"] = labels[:len(dataset_df)]
    return dataset_df[['tweet_id', 'airline_sentiment', 'airline_sentiment_confidence', 'airline', 'text', 'ds_name']]


def compute_metrics(p):
    pred, labels = p
    pred = np.argmax(pred, axis=1)

    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred, average="weighted")
    precision = precision_score(y_true=labels, y_pred=pred, average="weighted")
    f1 = f1_score(y_true=labels, y_pred=pred, average="weighted")

    return {"val_accuracy": accuracy, "val_precision": precision, "val_recall": recall, "val_f1": f1}


def tokenize_function(examples, tokenizer):
    return tokenizer(examples["message"], padding="max_length", truncation=True)

In [4]:
wandb.init(project="mlexp-project", entity='nsoma')

MODEL_NAME = 'bert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/426k [00:00<?, ?B/s]

In [5]:
dataset = load_sentiment_data()

le = LabelEncoder().fit(dataset.airline_sentiment.tolist())

dataset['label'] = le.transform(dataset.airline_sentiment.tolist())
dataset = dataset.rename(columns={'text': 'message'})

dataset.head()

Unnamed: 0,tweet_id,airline_sentiment,airline_sentiment_confidence,airline,message,ds_name,label
5885,568518406959312897,neutral,1.0,Southwest,@SouthwestAir when can I start Flight Booking ...,valid,1
7039,569948353003446274,neutral,0.6404,Delta,@JetBlue i hate the internet lol,train,1
6301,568072958821990400,positive,1.0,Southwest,@SouthwestAir pleasantly surprised to be board...,train,2
4137,567809206314668033,positive,1.0,United,@united You might be dealing with frustrated p...,train,2
12009,570270651996463104,neutral,1.0,American,@AmericanAir Flight 35. I'm on my way.,test,1


In [6]:
dataset.groupby("ds_name").tweet_id.count()

ds_name
test      2909
train    10236
valid     1495
Name: tweet_id, dtype: int64

In [7]:
wandb.log({"labels": list(le.classes_)})

In [8]:
cols = ['message', 'label']

train_df = dataset[dataset.ds_name == 'train'][cols]
val_df = dataset[dataset.ds_name == 'valid'][cols]
test_df = dataset[dataset.ds_name == 'test'][cols]

train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
test_dataset = Dataset.from_pandas(test_df)

In [9]:
train_dataset = train_dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)
val_dataset = val_dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)
test_dataset = test_dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)

  0%|          | 0/11 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

In [10]:
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=77)

if device == "cuda":
    model.to(device)

Downloading:   0%|          | 0.00/416M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

In [11]:
epoch = 8
lr = 1e-5
patience = 2


training_args = TrainingArguments(f"{MODEL_NAME}_{epoch}_{patience}_{lr}", 
                                  evaluation_strategy="epoch",
                                  save_strategy="epoch",
                                  report_to="wandb",
                                  num_train_epochs=epoch,
                                  load_best_model_at_end=True,
                                  learning_rate=lr,
                                  per_device_train_batch_size=16,
                                  per_device_eval_batch_size=16)

trainer = Trainer(model=model,
                  args=training_args,
                  train_dataset=train_dataset, 
                  eval_dataset=val_dataset,
                  compute_metrics=compute_metrics,
                  callbacks=[EarlyStoppingCallback(early_stopping_patience=patience)],)


trainer.train()

The following columns in the training set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: __index_level_0__, message. If __index_level_0__, message are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 10236
  Num Epochs = 8
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 5120
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Epoch,Training Loss,Validation Loss,Val Accuracy,Val Precision,Val Recall,Val F1
1,0.9814,0.457972,0.828094,0.835064,0.828094,0.830794
2,0.4343,0.445773,0.837458,0.837521,0.837458,0.833806
3,0.3378,0.462304,0.847492,0.844485,0.847492,0.843126
4,0.2028,0.555384,0.844147,0.845599,0.844147,0.843986


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: __index_level_0__, message. If __index_level_0__, message are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1495
  Batch size = 16
Saving model checkpoint to bert-base-cased_8_2_1e-05/checkpoint-640
Configuration saved in bert-base-cased_8_2_1e-05/checkpoint-640/config.json
Model weights saved in bert-base-cased_8_2_1e-05/checkpoint-640/pytorch_model.bin
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: __index_level_0__, message. If __index_level_0__, message are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1495
  Batch size = 16
Saving model checkpoint to b

TrainOutput(global_step=2560, training_loss=0.43860666304826734, metrics={'train_runtime': 3856.6531, 'train_samples_per_second': 21.233, 'train_steps_per_second': 1.328, 'total_flos': 1.078007341031424e+16, 'train_loss': 0.43860666304826734, 'epoch': 4.0})

In [12]:
raw_pred, _, _ = trainer.predict(test_dataset)
y_pred = np.argmax(raw_pred, axis=1)


f1 = f1_score(test_df["label"].tolist(), y_pred, average="weighted")
acc = accuracy_score(test_df["label"].tolist(), y_pred)

wandb.log({"test_accuracy": acc, "test_f1_score": f1})

print("Accuracy:", acc)
print("F1:", f1)

The following columns in the test set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: __index_level_0__, message. If __index_level_0__, message are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 2909
  Batch size = 16


Accuracy: 0.8466827088346511
F1: 0.8423798809305686


In [13]:
wandb.finish()

VBox(children=(Label(value='413.472 MB of 413.472 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
eval/loss,▂▁▂█
eval/runtime,▁▂█▄
eval/samples_per_second,█▇▁▅
eval/steps_per_second,█▆▁▅
eval/val_accuracy,▁▄█▇
eval/val_f1,▁▃██
eval/val_precision,▁▃▇█
eval/val_recall,▁▄█▇
test_accuracy,▁
test_f1_score,▁

0,1
eval/loss,0.55538
eval/runtime,49.9915
eval/samples_per_second,29.905
eval/steps_per_second,1.88
eval/val_accuracy,0.84415
eval/val_f1,0.84399
eval/val_precision,0.8456
eval/val_recall,0.84415
test_accuracy,0.84668
test_f1_score,0.84238


In [38]:
pipe = TextClassificationPipeline(model=model, tokenizer=tokenize_function, return_all_scores=True)

In [39]:
def score_and_visualize(text):
  prediction = pipe([text])
  print(prediction[0])

  explainer = shap.Explainer(pipe)
  shap_values = explainer([text])

  shap.plots.text(shap_values)

In [40]:
test_sentences = [
    "Great... my flight was delayed.",
    "Great! Everything went fine.",
    "I am flying to Boston today with JetBlue.",
    "My experience with JetBlue was a bit disappointing.",
    "My experience with JetBlue was very disappointing.",
    "JetBlue never disappoints me. No cancelled flights or lost luggage."
]

In [41]:
score_and_visualize(test_sentences[0])

TypeError: ignored