In [1]:
from transformers import  utils, Trainer, TrainingArguments, ElectraTokenizer, ElectraForSequenceClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import numpy as np
from datasets import Dataset, DatasetDict
import torch
from accelerate import Accelerator
import os
from performance import PerformanceSaver
from bertviz import model_view
utils.logging.set_verbosity_error()
import warnings
warnings.filterwarnings("ignore")


In [2]:
LOAD_SAVED_MODEL = False
AUGMENT_WITH_NEUTRAL = True
saved_model_path = "models/binary/electra_classifier"
model_name = "howey/electra-base-mnli"
data_dir = "data/argumentation"

train_df = pd.read_csv(os.path.join(data_dir, 'train_iam.tsv'), sep='\t')
dev_df = pd.read_csv(os.path.join(data_dir, 'dev_iam.txt'), sep='\t')
test_df = pd.read_csv(os.path.join(data_dir, 'test_iam.txt'), sep='\t')
all_claims = pd.read_csv(os.path.join(data_dir, 'claims.txt'), sep='\t')
np.random.seed(42)

if AUGMENT_WITH_NEUTRAL:
    neutral_claims = all_claims[all_claims.type=='O'] 
    lower_bound = 0
    
    min_train_label = min(train_df['label'].value_counts())
    train_sample = neutral_claims.iloc[:min_train_label]
    train_df = pd.concat([train_df, train_sample]).sample(frac=1)
    lower_bound = min_train_label
    
    min_dev_label = min(dev_df['label'].value_counts())
    dev_sample = neutral_claims.iloc[lower_bound: lower_bound + min_dev_label]    
    dev_df = pd.concat([dev_df, dev_sample]).sample(frac=1)
    lower_bound = lower_bound + min_dev_label
    
    min_test_label = min(test_df['label'].value_counts())
    test_sample = neutral_claims.iloc[lower_bound: lower_bound + min_test_label]    
    test_df = pd.concat([test_df, test_sample]).sample(frac=1)
    
    
label_encoder = LabelEncoder()
label_encoder.fit(train_df['label'])
train_df['label'] = label_encoder.transform(train_df['label'])
dev_df['label'] = label_encoder.transform(dev_df['label'])
test_df['label'] = label_encoder.transform(test_df['label'])

dataset = DatasetDict({
    'train': Dataset.from_pandas(train_df),
    'validation': Dataset.from_pandas(dev_df),
    'test': Dataset.from_pandas(test_df)
})

In [3]:
print(train_df['label'])

1694    2
599     1
3297    2
121     1
307     1
       ..
3772    0
1407    1
1443    1
1609    1
860     0
Name: label, Length: 5644, dtype: int64


In [31]:
num_labels=len(label_encoder.classes_)
ignore_mismatched_sizes=True
#classifier_dropout=0.1
output_attentions=False

if LOAD_SAVED_MODEL:
    tokenizer = ElectraTokenizer.from_pretrained(saved_model_path)
    model = ElectraForSequenceClassification.from_pretrained(saved_model_path, num_labels=num_labels, ignore_mismatched_sizes=ignore_mismatched_sizes, output_attentions=output_attentions)
else:
    tokenizer = ElectraTokenizer.from_pretrained(model_name)
    model = ElectraForSequenceClassification.from_pretrained(model_name, num_labels=num_labels, ignore_mismatched_sizes=ignore_mismatched_sizes, output_attentions=output_attentions)
    
if torch.cuda.is_available():
        model = model.to(device='cuda')

In [32]:
def process(batch):
    inputs = tokenizer(batch["argument"], truncation=True, padding="max_length")
    return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "labels": batch["label"],
        }
    
tokenized_dataset = dataset.map(process, batched=True, remove_columns=dataset['train'].column_names)

Map:   0%|          | 0/5644 [00:00<?, ? examples/s]

Map:   0%|          | 0/725 [00:00<?, ? examples/s]

Map:   0%|          | 0/783 [00:00<?, ? examples/s]

In [33]:
training_args = TrainingArguments(
            output_dir='./results',
            do_eval=True,
            do_train=True,
            num_train_epochs=4,
            save_total_limit=2,
            load_best_model_at_end=True,
            learning_rate=1e-1,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            save_strategy="steps",
            logging_strategy="steps",
            evaluation_strategy="steps",
            logging_steps=100,
            eval_steps=100,
            save_steps=100,
        )

def compute_metrics(pred):
            print("======================================")
            labels = pred.label_ids
            preds = pred.predictions.argmax(-1)
            precision, recall, f1, _ = precision_recall_fscore_support(
                labels, preds, average="weighted"
            )
            acc = accuracy_score(labels, preds)
            print(f'ACCURACY:- {acc}\nF1:- {f1}\nPRECISION:- {precision}\nRECALL:- {recall}')
            return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

In [34]:
trainer = Trainer(
            model=model,
            args=training_args,
            compute_metrics=compute_metrics,
            train_dataset=tokenized_dataset["train"],
            eval_dataset=tokenized_dataset["validation"],
            tokenizer=tokenizer,
        )

In [35]:
from tqdm.notebook import tqdm
if not LOAD_SAVED_MODEL:
    trainer.train()
    trainer.save_model("models/binary/electra_classifier")
    predictions = trainer.predict(tokenized_dataset["test"])
    print(predictions.metrics)

{'loss': 55.3373, 'learning_rate': 0.09291784702549576, 'epoch': 0.28}
ACCURACY:- 0.35724137931034483
F1:- 0.18805999439304738
PRECISION:- 0.12762140309155767
RECALL:- 0.35724137931034483
{'eval_loss': 1.1031620502471924, 'eval_accuracy': 0.35724137931034483, 'eval_f1': 0.18805999439304738, 'eval_precision': 0.12762140309155767, 'eval_recall': 0.35724137931034483, 'eval_runtime': 13.8649, 'eval_samples_per_second': 52.29, 'eval_steps_per_second': 3.318, 'epoch': 0.28}
{'loss': 1.0987, 'learning_rate': 0.0858356940509915, 'epoch': 0.57}
ACCURACY:- 0.3213793103448276
F1:- 0.15632855805917503
PRECISION:- 0.103284661117717
RECALL:- 0.3213793103448276
{'eval_loss': 1.1021512746810913, 'eval_accuracy': 0.3213793103448276, 'eval_f1': 0.15632855805917503, 'eval_precision': 0.103284661117717, 'eval_recall': 0.3213793103448276, 'eval_runtime': 13.9167, 'eval_samples_per_second': 52.096, 'eval_steps_per_second': 3.305, 'epoch': 0.57}
{'loss': 1.1061, 'learning_rate': 0.07875354107648726, 'epoch':