In [1]:
import functools
import sys
import pandas as pd
from sklearn.model_selection import train_test_split
import datasets
from datasets import Dataset
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchtext
import tqdm
import transformers
from transformers_interpret import SequenceClassificationExplainer

In [2]:
# Load the labelled corpus and split it into train/test partitions.
SEED = 42  # fixed so the 80/20 split is identical on every Restart & Run All

cols = ['text', 'label']
df = pd.read_excel('../data/fake_news_dataset.xlsx')
# Make sure the labels are 64-bit integers.
df['label'] = df['label'].astype('int64')
# Uncomment to iterate quickly on a 20% subsample:
#df = df.sample(frac = 0.2, random_state = SEED)

raw_datasets = Dataset.from_pandas(df[cols])
# train_test_split shuffles by default; without a seed every run produces a
# different test set, so metrics are not comparable between runs.
raw_datasets = raw_datasets.train_test_split(train_size = 0.8, seed = SEED)


In [3]:
# Name of the pretrained checkpoint to load.
# Alternative kept for reference: 'nlpaueb/bert-base-greek-uncased-v1'
transformer_name = "bert-base-uncased"

# Fetch the tokenizer that belongs to that checkpoint.
tokenizer = transformers.AutoTokenizer.from_pretrained(
    transformer_name
)

In [4]:
def tokenize_function(examples):
    """Tokenize the "text" field of a batch of examples.

    The tokenizer is called with padding="max_length" and truncation enabled,
    and its output is returned unchanged.
    """
    encoded = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
    )
    return encoded


# Apply the tokenizer to every split, batch by batch.
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

  0%|          | 0/3 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [5]:
# Display the tokenized DatasetDict to check its splits, columns and row counts.
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['attention_mask', 'input_ids', 'label', 'text', 'token_type_ids'],
        num_rows: 2495
    })
    test: Dataset({
        features: ['attention_mask', 'input_ids', 'label', 'text', 'token_type_ids'],
        num_rows: 624
    })
})

In [6]:
# No subsampling is applied here: the "small_*" names are plain aliases of the
# complete splits, exactly like the "full_*" names.
full_train_dataset = small_train_dataset = tokenized_datasets["train"]
full_eval_dataset = small_eval_dataset = tokenized_datasets["test"]

In [7]:
from transformers import AutoModelForSequenceClassification

# Load the pretrained checkpoint as a sequence classifier with two labels.
model = AutoModelForSequenceClassification.from_pretrained(
    transformer_name,
    num_labels=2,
)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [8]:
from transformers import TrainingArguments

# Settings for the fine-tuning run: output goes to the "model" directory
# (overwriting is allowed), batches of 4 per device, 5 epochs.
training_args = TrainingArguments(
    output_dir="model",
    overwrite_output_dir=True,
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
)

In [9]:
import numpy as np


def compute_metrics(eval_pred):
    """Compute classification accuracy for an evaluation pass.

    Args:
        eval_pred: A ``(logits, labels)`` pair. The class scores lie along the
            last axis of ``logits``; ``labels`` holds the reference class ids.

    Returns:
        dict: ``{"accuracy": float}``, the fraction of examples whose
        highest-scoring class equals the reference label.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    # Accuracy is computed directly with NumPy instead of
    # datasets.load_metric("accuracy"): load_metric is deprecated/removed in
    # newer `datasets` releases and has to fetch the metric script at run time.
    return {"accuracy": float(np.mean(predictions == np.asarray(labels)))}

In [10]:
from transformers import Trainer

# Bundle the model, training settings, data splits and metric callback.
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    eval_dataset=small_eval_dataset,
    train_dataset=small_train_dataset,
)

# Run fine-tuning; as the last expression of the cell, the returned
# TrainOutput is displayed.
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running training *****
  Num examples = 2495
  Num Epochs = 5
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 4
  Gradient Accumulation steps = 1
  Total optimization steps = 3120


Step,Training Loss
500,0.6551
1000,0.6487
1500,0.6441
2000,0.5928
2500,0.5615
3000,0.53


Saving model checkpoint to model\checkpoint-500
Configuration saved in model\checkpoint-500\config.json
Model weights saved in model\checkpoint-500\pytorch_model.bin
Saving model checkpoint to model\checkpoint-1000
Configuration saved in model\checkpoint-1000\config.json
Model weights saved in model\checkpoint-1000\pytorch_model.bin
Saving model checkpoint to model\checkpoint-1500
Configuration saved in model\checkpoint-1500\config.json
Model weights saved in model\checkpoint-1500\pytorch_model.bin
Saving model checkpoint to model\checkpoint-2000
Configuration saved in model\checkpoint-2000\config.json
Model weights saved in model\checkpoint-2000\pytorch_model.bin
Saving model checkpoint to model\checkpoint-2500
Configuration saved in model\checkpoint-2500\config.json
Model weights saved in model\checkpoint-2500\pytorch_model.bin
Saving model checkpoint to model\checkpoint-3000
Configuration saved in model\checkpoint-3000\config.json
Model weights saved in model\checkpoint-3000\pytorch

TrainOutput(global_step=3120, training_loss=0.5983565367185153, metrics={'train_runtime': 1163.9512, 'train_samples_per_second': 10.718, 'train_steps_per_second': 2.681, 'total_flos': 3282310415616000.0, 'train_loss': 0.5983565367185153, 'epoch': 5.0})

In [11]:
# Evaluate the trained model on the Trainer's eval dataset and display the
# returned metrics dict (loss, accuracy, runtime statistics).
trainer.evaluate()

The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text.
***** Running Evaluation *****
  Num examples = 624
  Batch size = 4


{'eval_loss': 0.547904372215271,
 'eval_accuracy': 0.7788461538461539,
 'eval_runtime': 18.124,
 'eval_samples_per_second': 34.429,
 'eval_steps_per_second': 8.607,
 'epoch': 5.0}

In [12]:
# Explain the model's prediction for a single example sentence.
txt = """
COVID-19 IS A HOAX
"""

cls_explainer = SequenceClassificationExplainer(
    model,
    tokenizer)
# Calling the explainer on the text yields the word attributions.
word_attributions = cls_explainer(txt)

# visualize() renders the attribution table itself AND returns the display
# object. The trailing semicolon stops Jupyter from echoing that return value,
# which otherwise shows the same table a second time.
cls_explainer.visualize();

True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,LABEL_0 (0.86),LABEL_0,-1.02,[CLS] co ##vid - 19 is a hoax [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,LABEL_0 (0.86),LABEL_0,-1.02,[CLS] co ##vid - 19 is a hoax [SEP]
,,,,
