In [1]:
from transformers import  utils, Trainer, TrainingArguments, ElectraTokenizer, ElectraForSequenceClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import numpy as np
from datasets import Dataset, DatasetDict
import torch
from accelerate import Accelerator
import os
from performance import PerformanceSaver
from bertviz import model_view
utils.logging.set_verbosity_error()

In [2]:
LOAD_SAVED_MODEL = False
AUGMENT_WITH_NEUTRAL = True
saved_model_path = "models/binary/electra_classifier"
model_name = "howey/electra-base-mnli"
data_dir = "data/argumentation"

train_df = pd.read_csv(os.path.join(data_dir, 'train_iam.tsv'), sep='\t')
dev_df = pd.read_csv(os.path.join(data_dir, 'dev_iam.txt'), sep='\t')
test_df = pd.read_csv(os.path.join(data_dir, 'test_iam.txt'), sep='\t')
all_claims = pd.read_csv(os.path.join(data_dir, 'claims.txt'), sep='\t')
np.random.seed(42)

if AUGMENT_WITH_NEUTRAL:
    neutral_claims = all_claims[all_claims.type=='O'] 
    lower_bound = 0
    
    min_train_label = min(train_df['label'].value_counts())
    train_sample = neutral_claims.iloc[:min_train_label]
    train_df = pd.concat([train_df, train_sample]).sample(frac=1)
    lower_bound = min_train_label
    
    min_dev_label = min(dev_df['label'].value_counts())
    dev_sample = neutral_claims.iloc[lower_bound: lower_bound + min_dev_label]    
    dev_df = pd.concat([dev_df, dev_sample]).sample(frac=1)
    lower_bound = lower_bound + min_dev_label
    
    min_test_label = min(dev_df['label'].value_counts())
    test_sample = neutral_claims.iloc[lower_bound: lower_bound + min_test_label]    
    test_df = pd.concat([test_df, test_sample]).sample(frac=1)
    
    
label_encoder = LabelEncoder()
label_encoder.fit(train_df['label'])
train_df['label'] = label_encoder.transform(train_df['label'])
dev_df['label'] = label_encoder.transform(dev_df['label'])
test_df['label'] = label_encoder.transform(test_df['label'])

dataset = DatasetDict({
    'train': Dataset.from_pandas(train_df),
    'validation': Dataset.from_pandas(dev_df),
    'test': Dataset.from_pandas(test_df)
})

In [3]:
label_encoder.classes_

array([-1,  0,  1])

In [4]:
num_labels=len(label_encoder.classes_)
ignore_mismatched_sizes=True
classifier_dropout=0.1
output_attentions=False

if LOAD_SAVED_MODEL:
    tokenizer = ElectraTokenizer.from_pretrained(saved_model_path)
    model = ElectraForSequenceClassification.from_pretrained(saved_model_path, num_labels=num_labels, ignore_mismatched_sizes=ignore_mismatched_sizes, classifier_dropout=classifier_dropout, output_attentions=output_attentions)
else:
    tokenizer = ElectraTokenizer.from_pretrained(model_name)
    model = ElectraForSequenceClassification.from_pretrained(model_name, num_labels=num_labels, ignore_mismatched_sizes=ignore_mismatched_sizes, classifier_dropout=classifier_dropout, output_attentions=output_attentions)
    
if torch.cuda.is_available():
        model = model.to(device='cuda')

In [5]:
def process(batch):
    inputs = tokenizer(batch["argument"], truncation=True, padding="max_length")
    return {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "labels": batch["label"],
        }
    
tokenized_dataset = dataset.map(process, batched=True, remove_columns=['type', 'id'])

Map:   0%|          | 0/5644 [00:00<?, ? examples/s]

Map:   0%|          | 0/725 [00:00<?, ? examples/s]

Map:   0%|          | 0/760 [00:00<?, ? examples/s]

In [6]:
training_args = TrainingArguments(
            output_dir='./results',
            do_eval=True,
            do_train=True,
            num_train_epochs=6,
            save_total_limit=2,
            load_best_model_at_end=True,
            learning_rate=1e-03,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            save_strategy="steps",
            logging_strategy="steps",
            evaluation_strategy="steps",
            logging_steps=60,
            eval_steps=60,
            save_steps=60,
        )

def compute_metrics(pred):
            print(pred)
            print("======================================")
            labels = pred.label_ids
            preds = pred.predictions.argmax(-1)
            precision, recall, f1, _ = precision_recall_fscore_support(
                labels, preds, average="weighted"
            )
            acc = accuracy_score(labels, preds)
            return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

In [7]:
trainer = Trainer(
            model=model,
            args=training_args,
            compute_metrics=compute_metrics,
            train_dataset=tokenized_dataset["train"],
            eval_dataset=tokenized_dataset["validation"],
            tokenizer=tokenizer,
        )

In [8]:
from tqdm.notebook import tqdm
if not LOAD_SAVED_MODEL:
    trainer.train()
    trainer.save_model("models/binary/electra_classifier")
    predictions = trainer.predict(tokenized_dataset["test"])
    print(predictions.metrics)



{'loss': 1.1825, 'learning_rate': 0.0009716713881019831, 'epoch': 0.17}
<transformers.trainer_utils.EvalPrediction object at 0x7f53a34a1160>
{'eval_loss': 1.1063247919082642, 'eval_accuracy': 0.35724137931034483, 'eval_f1': 0.18805999439304738, 'eval_precision': 0.12762140309155767, 'eval_recall': 0.35724137931034483, 'eval_runtime': 14.0959, 'eval_samples_per_second': 51.434, 'eval_steps_per_second': 3.263, 'epoch': 0.17}


  _warn_prf(average, modifier, msg_start, len(result))


{'loss': 1.1028, 'learning_rate': 0.000943342776203966, 'epoch': 0.34}
<transformers.trainer_utils.EvalPrediction object at 0x7f533d6b97f0>
{'eval_loss': 1.1005462408065796, 'eval_accuracy': 0.35724137931034483, 'eval_f1': 0.18805999439304738, 'eval_precision': 0.12762140309155767, 'eval_recall': 0.35724137931034483, 'eval_runtime': 14.0942, 'eval_samples_per_second': 51.44, 'eval_steps_per_second': 3.264, 'epoch': 0.34}


  _warn_prf(average, modifier, msg_start, len(result))


{'loss': 1.1006, 'learning_rate': 0.0009150141643059491, 'epoch': 0.51}
<transformers.trainer_utils.EvalPrediction object at 0x7f533d6b2400>
{'eval_loss': 1.100216269493103, 'eval_accuracy': 0.3213793103448276, 'eval_f1': 0.15632855805917503, 'eval_precision': 0.103284661117717, 'eval_recall': 0.3213793103448276, 'eval_runtime': 13.9685, 'eval_samples_per_second': 51.903, 'eval_steps_per_second': 3.293, 'epoch': 0.51}


  _warn_prf(average, modifier, msg_start, len(result))


{'loss': 1.1075, 'learning_rate': 0.000886685552407932, 'epoch': 0.68}
<transformers.trainer_utils.EvalPrediction object at 0x7f55684c81f0>
{'eval_loss': 1.0997376441955566, 'eval_accuracy': 0.35724137931034483, 'eval_f1': 0.18805999439304738, 'eval_precision': 0.12762140309155767, 'eval_recall': 0.35724137931034483, 'eval_runtime': 13.9819, 'eval_samples_per_second': 51.853, 'eval_steps_per_second': 3.29, 'epoch': 0.68}


  _warn_prf(average, modifier, msg_start, len(result))


{'loss': 1.1022, 'learning_rate': 0.0008583569405099151, 'epoch': 0.85}
<transformers.trainer_utils.EvalPrediction object at 0x7f556a75cee0>
{'eval_loss': 1.0975046157836914, 'eval_accuracy': 0.35724137931034483, 'eval_f1': 0.18805999439304738, 'eval_precision': 0.12762140309155767, 'eval_recall': 0.35724137931034483, 'eval_runtime': 13.9665, 'eval_samples_per_second': 51.91, 'eval_steps_per_second': 3.294, 'epoch': 0.85}


  _warn_prf(average, modifier, msg_start, len(result))


{'loss': 1.0962, 'learning_rate': 0.000830028328611898, 'epoch': 1.02}
<transformers.trainer_utils.EvalPrediction object at 0x7f55685e17c0>
{'eval_loss': 1.1050671339035034, 'eval_accuracy': 0.3213793103448276, 'eval_f1': 0.15632855805917503, 'eval_precision': 0.103284661117717, 'eval_recall': 0.3213793103448276, 'eval_runtime': 13.9082, 'eval_samples_per_second': 52.128, 'eval_steps_per_second': 3.307, 'epoch': 1.02}


  _warn_prf(average, modifier, msg_start, len(result))


{'loss': 1.1023, 'learning_rate': 0.0008016997167138811, 'epoch': 1.19}
<transformers.trainer_utils.EvalPrediction object at 0x7f55686e1eb0>
{'eval_loss': 1.101406455039978, 'eval_accuracy': 0.35724137931034483, 'eval_f1': 0.18805999439304738, 'eval_precision': 0.12762140309155767, 'eval_recall': 0.35724137931034483, 'eval_runtime': 14.0869, 'eval_samples_per_second': 51.466, 'eval_steps_per_second': 3.265, 'epoch': 1.19}


  _warn_prf(average, modifier, msg_start, len(result))


{'loss': 1.1022, 'learning_rate': 0.000773371104815864, 'epoch': 1.36}
<transformers.trainer_utils.EvalPrediction object at 0x7f535c6db340>
{'eval_loss': 1.100594401359558, 'eval_accuracy': 0.35724137931034483, 'eval_f1': 0.18805999439304738, 'eval_precision': 0.12762140309155767, 'eval_recall': 0.35724137931034483, 'eval_runtime': 13.952, 'eval_samples_per_second': 51.964, 'eval_steps_per_second': 3.297, 'epoch': 1.36}


  _warn_prf(average, modifier, msg_start, len(result))


{'loss': 1.1042, 'learning_rate': 0.0007450424929178471, 'epoch': 1.53}
<transformers.trainer_utils.EvalPrediction object at 0x7f55686d00a0>
{'eval_loss': 1.0979384183883667, 'eval_accuracy': 0.35724137931034483, 'eval_f1': 0.18805999439304738, 'eval_precision': 0.12762140309155767, 'eval_recall': 0.35724137931034483, 'eval_runtime': 13.9772, 'eval_samples_per_second': 51.87, 'eval_steps_per_second': 3.291, 'epoch': 1.53}


  _warn_prf(average, modifier, msg_start, len(result))


{'loss': 1.1002, 'learning_rate': 0.00071671388101983, 'epoch': 1.7}
<transformers.trainer_utils.EvalPrediction object at 0x7f55686d00a0>
{'eval_loss': 1.0999361276626587, 'eval_accuracy': 0.35724137931034483, 'eval_f1': 0.18805999439304738, 'eval_precision': 0.12762140309155767, 'eval_recall': 0.35724137931034483, 'eval_runtime': 14.0919, 'eval_samples_per_second': 51.448, 'eval_steps_per_second': 3.264, 'epoch': 1.7}


  _warn_prf(average, modifier, msg_start, len(result))


{'loss': 1.0845, 'learning_rate': 0.0006883852691218131, 'epoch': 1.87}
<transformers.trainer_utils.EvalPrediction object at 0x7f535c5cbe80>
{'eval_loss': 1.1115469932556152, 'eval_accuracy': 0.35724137931034483, 'eval_f1': 0.18805999439304738, 'eval_precision': 0.12762140309155767, 'eval_recall': 0.35724137931034483, 'eval_runtime': 14.1217, 'eval_samples_per_second': 51.339, 'eval_steps_per_second': 3.257, 'epoch': 1.87}


  _warn_prf(average, modifier, msg_start, len(result))


{'loss': 1.1032, 'learning_rate': 0.000660056657223796, 'epoch': 2.04}
<transformers.trainer_utils.EvalPrediction object at 0x7f535c6dbfa0>
{'eval_loss': 1.0975288152694702, 'eval_accuracy': 0.35724137931034483, 'eval_f1': 0.18805999439304738, 'eval_precision': 0.12762140309155767, 'eval_recall': 0.35724137931034483, 'eval_runtime': 14.0906, 'eval_samples_per_second': 51.453, 'eval_steps_per_second': 3.265, 'epoch': 2.04}


  _warn_prf(average, modifier, msg_start, len(result))
