<a href="https://colab.research.google.com/github/nsomabalint/explainable-text-classification/blob/master/BERT_sentiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install wandb transformers datasets shap

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.12.16-py2.py3-none-any.whl (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 5.1 MB/s 
[?25hCollecting transformers
  Downloading transformers-4.19.2-py3-none-any.whl (4.2 MB)
[K     |████████████████████████████████| 4.2 MB 57.2 MB/s 
[?25hCollecting datasets
  Downloading datasets-2.2.2-py3-none-any.whl (346 kB)
[K     |████████████████████████████████| 346 kB 72.6 MB/s 
[?25hCollecting shap
  Downloading shap-0.40.0-cp37-cp37m-manylinux2010_x86_64.whl (564 kB)
[K     |████████████████████████████████| 564 kB 66.2 MB/s 
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting setproctitle
  Downloading setproctitle-1.2.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (

In [13]:
%env WANDB_LOG_MODEL=true

env: WANDB_LOG_MODEL=true


In [34]:
import pandas as pd
import numpy as np
import wandb
import torch
import shap
from transformers import AutoTokenizer
from datasets import Dataset
from transformers import AutoModelForSequenceClassification, TextClassificationPipeline
from transformers import TrainingArguments, Trainer, EarlyStoppingCallback
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score, classification_report
from sklearn.preprocessing import LabelEncoder
from random import shuffle


def load_sentiment_data():
    dataset_url = "https://drive.google.com/uc?export=download&id=1WTEViNH8i9a3ethzP5mN07evErguwgcq"
    dataset_df = pd.read_csv(dataset_url)
    dataset_df = dataset_df.sample(frac=1.0)
    
    labels = [*["train"] * 70, *["valid"] * 10, *["test"] * 20] * 1000
    shuffle(labels)

    dataset_df["ds_name"] = labels[:len(dataset_df)]
    return dataset_df[['tweet_id', 'airline_sentiment', 'airline_sentiment_confidence', 'airline', 'text', 'ds_name']]


def compute_metrics(p):
    pred, labels = p
    pred = np.argmax(pred, axis=1)

    accuracy = accuracy_score(y_true=labels, y_pred=pred)
    recall = recall_score(y_true=labels, y_pred=pred, average="weighted")
    precision = precision_score(y_true=labels, y_pred=pred, average="weighted")
    f1 = f1_score(y_true=labels, y_pred=pred, average="weighted")

    return {"val_accuracy": accuracy, "val_precision": precision, "val_recall": recall, "val_f1": f1}


def tokenize_function(examples, tokenizer):
    return tokenizer(examples["message"], padding="max_length", truncation=True)

In [15]:
wandb.init(project="mlexp-project", entity='nsoma')

MODEL_NAME = 'bert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
device = "cuda" if torch.cuda.is_available() else "cpu"

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

In [16]:
dataset = load_sentiment_data()

le = LabelEncoder().fit(dataset.airline_sentiment.tolist())

dataset['label'] = le.transform(dataset.airline_sentiment.tolist())
dataset = dataset.rename(columns={'text': 'message'})

dataset.head()

Unnamed: 0,tweet_id,airline_sentiment,airline_sentiment_confidence,airline,message,ds_name,label
7325,569669129130647552,negative,1.0,Delta,@JetBlue 2nd time in a row a flight out of jfk...,train,0
6634,567750589040181248,negative,1.0,Southwest,@SouthwestAir Almost 2 hours on hold now. I ju...,train,0
7776,569253874445492225,neutral,0.6667,Delta,@JetBlue what's the status of flight 1272 dive...,valid,1
12332,570226757380530176,negative,0.6824,American,@AmericanAir ...2/2 doesn't help me.,train,0
6123,568209603332214784,positive,0.7065,Southwest,@SouthwestAir filing it now. Thank you for you...,train,2


In [17]:
dataset.groupby("ds_name").tweet_id.count()

ds_name
test      2924
train    10273
valid     1443
Name: tweet_id, dtype: int64

In [18]:
wandb.log({"labels": list(le.classes_)})

In [19]:
cols = ['message', 'label']

train_df = dataset[dataset.ds_name == 'train'][cols]
val_df = dataset[dataset.ds_name == 'valid'][cols]
test_df = dataset[dataset.ds_name == 'test'][cols]

train_dataset = Dataset.from_pandas(train_df)
val_dataset = Dataset.from_pandas(val_df)
test_dataset = Dataset.from_pandas(test_df)

In [20]:
train_dataset = train_dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)
val_dataset = val_dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)
test_dataset = test_dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)

  0%|          | 0/11 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/3 [00:00<?, ?ba/s]

In [21]:
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=3)

if device == "cuda":
    model.to(device)

Downloading:   0%|          | 0.00/416M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

In [22]:
epoch = 2 # 8
lr = 1e-5
patience = 2


training_args = TrainingArguments(f"{MODEL_NAME}_{epoch}_{patience}_{lr}", 
                                  evaluation_strategy="epoch",
                                  save_strategy="epoch",
                                  report_to="wandb",
                                  num_train_epochs=epoch,
                                  load_best_model_at_end=True,
                                  learning_rate=lr,
                                  per_device_train_batch_size=16,
                                  per_device_eval_batch_size=16)

trainer = Trainer(model=model,
                  args=training_args,
                  train_dataset=train_dataset, 
                  eval_dataset=val_dataset,
                  compute_metrics=compute_metrics,
                  callbacks=[EarlyStoppingCallback(early_stopping_patience=patience)],)


trainer.train()

The following columns in the training set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: message, __index_level_0__. If message, __index_level_0__ are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 10273
  Num Epochs = 2
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 1286
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


Epoch,Training Loss,Validation Loss,Val Accuracy,Val Precision,Val Recall,Val F1
1,0.579,0.430312,0.832294,0.829185,0.832294,0.830275
2,0.3935,0.419649,0.84061,0.838169,0.84061,0.839194


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: message, __index_level_0__. If message, __index_level_0__ are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1443
  Batch size = 16
Saving model checkpoint to bert-base-cased_2_2_1e-05/checkpoint-643
Configuration saved in bert-base-cased_2_2_1e-05/checkpoint-643/config.json
Model weights saved in bert-base-cased_2_2_1e-05/checkpoint-643/pytorch_model.bin
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: message, __index_level_0__. If message, __index_level_0__ are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1443
  Batch size = 16
Saving model checkpoint to b

TrainOutput(global_step=1286, training_loss=0.45844738071851004, metrics={'train_runtime': 2116.7072, 'train_samples_per_second': 9.707, 'train_steps_per_second': 0.608, 'total_flos': 5405928280639488.0, 'train_loss': 0.45844738071851004, 'epoch': 2.0})

In [23]:
raw_pred, _, _ = trainer.predict(test_dataset)
y_pred = np.argmax(raw_pred, axis=1)


f1 = f1_score(test_df["label"].tolist(), y_pred, average="weighted")
acc = accuracy_score(test_df["label"].tolist(), y_pred)

wandb.log({"test_accuracy": acc, "test_f1_score": f1})

print("Accuracy:", acc)
print("F1:", f1)

The following columns in the test set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: message, __index_level_0__. If message, __index_level_0__ are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Prediction *****
  Num examples = 2924
  Batch size = 16


Accuracy: 0.8409712722298222
F1: 0.8387401665231036


In [35]:
print(classification_report(test_df["label"].tolist(), y_pred))

              precision    recall  f1-score   support

           0       0.89      0.92      0.90      1834
           1       0.70      0.65      0.67       609
           2       0.81      0.79      0.80       481

    accuracy                           0.84      2924
   macro avg       0.80      0.78      0.79      2924
weighted avg       0.84      0.84      0.84      2924



In [24]:
wandb.finish()

VBox(children=(Label(value='413.252 MB of 413.252 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
eval/loss,█▁
eval/runtime,█▁
eval/samples_per_second,▁█
eval/steps_per_second,▁█
eval/val_accuracy,▁█
eval/val_f1,▁█
eval/val_precision,▁█
eval/val_recall,▁█
test_accuracy,▁
test_f1_score,▁

0,1
eval/loss,0.41965
eval/runtime,54.4497
eval/samples_per_second,26.502
eval/steps_per_second,1.671
eval/val_accuracy,0.84061
eval/val_f1,0.83919
eval/val_precision,0.83817
eval/val_recall,0.84061
test_accuracy,0.84097
test_f1_score,0.83874


# Load, infer and explain model

In [25]:
import wandb
import numpy as np
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import TextClassificationPipeline

In [None]:
MODEL_NAME = 'bert-base-cased'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [26]:
run = wandb.init()
artifact = run.use_artifact('nsoma/mlexp-project/model-2bbhshe4:v0', type='model')
artifact_dir = artifact.download()

[34m[1mwandb[0m: Currently logged in as: [33mnsoma[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact model-2bbhshe4:v0, 413.25MB. 3 files... Done. 0:0:0


In [27]:
%ls artifacts

[0m[01;34mmodel-18lq1jeb:v0[0m/  [01;34mmodel-2bbhshe4:v0[0m/


In [28]:
model = AutoModelForSequenceClassification.from_pretrained("./artifacts/model-2bbhshe4:v0/")
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

pipe = TextClassificationPipeline(model=model,
                                  tokenizer=tokenizer)

loading configuration file ./artifacts/model-2bbhshe4:v0/config.json
Model config BertConfig {
  "_name_or_path": "./artifacts/model-2bbhshe4:v0/",
  "architectures": [
    "BertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "problem_type": "single_label_classification",
  "torch_dtype": "float32",
  "transformers_version": "4.19.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 28996
}

loading weig

In [30]:
def score_and_visualize(text):
  prediction = pipe([text])
  print(prediction[0])

  explainer = shap.Explainer(pipe)
  shap_values = explainer([text])

  shap.plots.text(shap_values)

In [31]:
test_sentences = [
    "Great... my flight was delayed.",
    "Great! Everything went fine.",
    "I am flying to Boston today with JetBlue.",
    "My experience with JetBlue was a bit disappointing.",
    "My experience with JetBlue was very disappointing.",
    "JetBlue never disappoints me. No cancelled flights or lost luggage."
]

In [32]:
score_and_visualize(test_sentences[0])

{'label': 'LABEL_0', 'score': 0.8895069360733032}


  0%|          | 0/110 [00:00<?, ?it/s]

Partition explainer: 2it [00:17, 17.70s/it]               


In [33]:
score_and_visualize(test_sentences[-1])

{'label': 'LABEL_0', 'score': 0.8062806725502014}


  0%|          | 0/248 [00:00<?, ?it/s]

Partition explainer: 2it [00:26, 26.80s/it]               


In [36]:
1

1