In [None]:
from transformers import  Trainer, TrainingArguments, ElectraTokenizer, ElectraForSequenceClassification
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.preprocessing import LabelEncoder
import pandas as pd
import numpy as np
from datasets import Dataset, DatasetDict
import torch
from accelerate import Accelerator
import os
from performance import PerformanceSaver

In [None]:
# Notebook configuration.
# When True, the tokenizer and model are loaded from `saved_model_path` and the
# training cell is skipped; when False, they are loaded from `model_name` and
# the training cell runs.
LOAD_SAVED_MODEL = False
# Local directory handed to from_pretrained() when LOAD_SAVED_MODEL is True.
saved_model_path = "models/electra_classifier"
# Checkpoint identifier handed to from_pretrained() when LOAD_SAVED_MODEL is False.
model_name = "howey/electra-base-mnli"
# Directory that holds train.csv, dev.csv and test.csv.
data_dir = "data/binary"

In [None]:
# Choose the checkpoint once: the locally saved model when requested,
# otherwise the base checkpoint named in the configuration cell.
_checkpoint = saved_model_path if LOAD_SAVED_MODEL else model_name

tokenizer = ElectraTokenizer.from_pretrained(_checkpoint)
model = ElectraForSequenceClassification.from_pretrained(
    _checkpoint,
    num_labels=2,
    # Allow loading to proceed if the checkpoint's classification head does
    # not match a 2-label head.
    ignore_mismatched_sizes=True,
    classifier_dropout=0.1,
)

# Move the model to the GPU when one is available.
if torch.cuda.is_available():
    model = model.cuda()


In [None]:
# Hyper-parameters and bookkeeping options handed to the Trainer.
training_args = TrainingArguments(
    # Run mode.
    do_train=True,
    do_eval=True,
    # Optimisation.
    num_train_epochs=6,
    learning_rate=8.5e-05,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Logging, evaluation and checkpointing all happen every 500 steps.
    logging_strategy="steps",
    logging_steps=500,
    evaluation_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=500,
    # Checkpoint storage and selection.
    output_dir='./results',
    save_total_limit=2,
    load_best_model_at_end=True,
)

def compute_metrics(pred):
    """Score one set of evaluation predictions.

    Args:
        pred: Object exposing `label_ids` (reference labels) and
            `predictions` (scores whose last axis is arg-maxed to obtain
            the predicted class).

    Returns:
        dict with keys "accuracy", "f1", "precision" and "recall"; the last
        three use weighted averaging.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    # The tuple is (precision, recall, f1, support); support is not reported.
    weighted = precision_recall_fscore_support(y_true, y_pred, average="weighted")
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": weighted[2],
        "precision": weighted[0],
        "recall": weighted[1],
    }

In [None]:
# Read the three CSV splits from data_dir, in train / dev / test order.
train_df, dev_df, test_df = (
    pd.read_csv(os.path.join(data_dir, csv_name))
    for csv_name in ('train.csv', 'dev.csv', 'test.csv')
)

# Fit the label encoder on the training labels only and show the class order.
label_encoder = LabelEncoder().fit(train_df['label'])
print(label_encoder.classes_)

# Overwrite each split's 'label' column with the encoder's integer codes.
for split_df in (train_df, dev_df, test_df):
    split_df['label'] = label_encoder.transform(split_df['label'])


# Wrap the frames as a DatasetDict; the dev frame is exposed as 'validation'.
dataset = DatasetDict({
    split_name: Dataset.from_pandas(frame)
    for split_name, frame in (
        ('train', train_df),
        ('validation', dev_df),
        ('test', test_df),
    )
})

def process(batch):
    """Tokenize a batch of examples using the notebook-level `tokenizer`.

    Args:
        batch: Mapping with a "text" entry (passed to the tokenizer) and a
            "label" entry (copied through unchanged).

    Returns:
        dict with "input_ids" and "attention_mask" taken from the tokenizer
        output (truncated, padded to max length) and "labels" taken from
        batch["label"].
    """
    encoded = tokenizer(batch["text"], truncation=True, padding="max_length")
    features = {key: encoded[key] for key in ("input_ids", "attention_mask")}
    features["labels"] = batch["label"]
    return features
    
# Tokenize every split in batches and drop the original columns, so only the
# keys returned by process() remain.
tokenized_dataset = dataset.map(
    process,
    batched=True,
    remove_columns=dataset["train"].column_names,
)

In [None]:
# Wire the model, tokenizer, data splits and metric function into a Trainer.
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune, save and evaluate on the test split only when the model was not
# loaded from disk; with LOAD_SAVED_MODEL = True this cell does nothing.
if not LOAD_SAVED_MODEL:
    trainer.train()
    # Save to the directory the loading cell reads when LOAD_SAVED_MODEL is
    # True. The previous hard-coded "models/binary/electra_classifier" did not
    # match `saved_model_path`, so a saved model could never be reloaded.
    trainer.save_model(saved_model_path)
    predictions = trainer.predict(tokenized_dataset["test"])
    print(predictions.metrics)

In [None]:
# Example input text for the inference() demo in this cell.
your_sentence = "Drinking vegetable juice, bitter gourd, can cure COVID-19"

def inference(text, classes):
    """Classify one text with the notebook-level `model` and `tokenizer`.

    Args:
        text: The string to classify.
        classes: Sequence indexed by class id; the element at the predicted
            index is returned.

    Returns:
        The entry of `classes` at the arg-max of the model's logits.

    Side effects:
        Switches `model` to eval mode and prints the raw logits.
    """
    # Truncate exactly as process() does for training data, so text longer
    # than the tokenizer's maximum length is cut instead of being fed whole.
    input_ids = tokenizer.encode(
        text,
        add_special_tokens=True,
        truncation=True,
        return_tensors="pt",
    ).to(model.device)
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits
        print(logits)
    # One input row, so the arg-max over dim 1 is a single class index.
    predicted_class = torch.argmax(logits, dim=1).item()
    return classes[predicted_class]

# Classify the example sentence; being the cell's last expression, the returned
# class label is the cell's output.
inference(your_sentence, label_encoder.classes_)

In [None]:
# Drop the notebook's references to the model and trainer, then ask PyTorch to
# release the CUDA memory it is caching but no longer using.
del model, trainer
torch.cuda.empty_cache()