In [1]:
# ! pip install accelerate -U

In [2]:
from argparse import ArgumentParser
from tqdm import tqdm
import csv
import re
import random
import transformers

import torch
from torch import nn, optim
# To import the Transformer Models
from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer, logging
from sklearn.metrics import precision_recall_fscore_support as score
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd

TRAIN_SPLIT = 12000
TEST_SPLIT = 1000
EPOCHS = 20
BATCH_SIZE = 16

In [3]:
try:
  from datasets import load_dataset
except:
  !pip install datasets
  from datasets import load_dataset

train = load_dataset("CLUTRR/v1", name= "gen_train234_test2to10", split=f"train[:{TRAIN_SPLIT}]")
test = load_dataset("CLUTRR/v1", name= "gen_train234_test2to10", split=f"test[:{TEST_SPLIT}]")

In [4]:
train_dataset = pd.DataFrame(train)
test_dataset = pd.DataFrame(test)
train_dataset["input_text"] = train_dataset["clean_story"] + " " + train_dataset["query"] + " " + train_dataset["genders"]
test_dataset["input_text"] = test_dataset["clean_story"] + " " + test_dataset["query"] + " " + test_dataset["genders"]
def preprocess_text(text):
    # Convert text to lowercase
    text = text.lower()
    # Remove square brackets, parentheses, and single quotes using regex
    text = re.sub(r'[\[\]():]', ' ', text)
    # Remove single quotes
    text = text.replace("'", "")
    return text
train_dataset.input_text = train_dataset.input_text.apply( lambda text: preprocess_text(text))
test_dataset.input_text = test_dataset.input_text.apply( lambda text: preprocess_text(text))
train_dataset = train_dataset[['input_text', 'target']].rename(columns={'target': 'labels'})
test_dataset = test_dataset[['input_text', 'target']].rename(columns={'target': 'labels'})
train_dataset, val_dataset = train_test_split(train_dataset, 
                                    test_size = 0.35, random_state = 42)


In [5]:
print(train_dataset.head(5))
print(test_dataset.head(5))
print(val_dataset.head(5))

                                              input_text  labels
2013    theresa  sat anxiously in the airport termina...      15
10297   shantel  went to dinner with her husband  har...      11
4397    frances  takes her granddaughter  felicia  to...      10
7966    don s father,  joshua , always seemed to favo...       0
2147    george s father  clarence  is going to coach ...      11
                                          input_text  labels
0   clarence s granddaughter,  emily , was busy h...      10
1   emily  and her granddaughter  ashley  went to...      10
2   clarence  has 3 children, and one grandson. t...      11
3   glen  is  emily s brand new baby brother.  cl...      10
4   clarence  bought a train set for his grandson...      11
                                             input_text  labels
1935   nicholas  and his son  dennis  went to the pa...       7
6494   lynn  was hosting a supper for her family. he...       8
1720   clarence s wife,  ashley  was hungry, so  cla

In [6]:
from datasets import Dataset
from transformers import AutoTokenizer, DataCollatorWithPadding
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer, logging
pretrained_model = "google/electra-base-generator"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)

def preprocess_function(datatset):
    return tokenizer(datatset["input_text"], truncation=True)


def pipeline(dataframe):
    dataset = Dataset.from_pandas(dataframe, preserve_index=False)
    tokenized_ds = dataset.map(preprocess_function, batched=True)
    tokenized_ds = tokenized_ds.remove_columns('input_text')
    return tokenized_ds

In [7]:
tokenized_train = pipeline(train_dataset)
tokenized_val = pipeline(val_dataset)

Map:   0%|          | 0/7800 [00:00<?, ? examples/s]

Map:   0%|          | 0/4200 [00:00<?, ? examples/s]

In [8]:
tokenized_val

Dataset({
    features: ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 4200
})

In [9]:
# Set values for model and train
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

model = AutoModelForSequenceClassification.from_pretrained(pretrained_model, num_labels=20)

class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        logits = outputs.logits
        labels = inputs.get("labels")
        loss_fn = torch.nn.CrossEntropyLoss()
        loss = loss_fn(logits, labels)
        return (loss, outputs) if return_outputs else loss

training_args = TrainingArguments(
    output_dir="./results",
    save_strategy = 'epoch',
    optim="adamw_torch",
    learning_rate=0.00002,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=EPOCHS,
    weight_decay=0.01,
    report_to="none",
)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

%time trainer.train()


Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-base-generator and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


  0%|          | 0/9760 [00:00<?, ?it/s]

{'loss': 2.8644, 'grad_norm': 18.528419494628906, 'learning_rate': 1.8975409836065574e-05, 'epoch': 1.02}
{'loss': 2.3657, 'grad_norm': 37.66987228393555, 'learning_rate': 1.795081967213115e-05, 'epoch': 2.05}
{'loss': 1.9729, 'grad_norm': 49.941993713378906, 'learning_rate': 1.6926229508196722e-05, 'epoch': 3.07}
{'loss': 1.6589, 'grad_norm': 89.69330596923828, 'learning_rate': 1.5901639344262295e-05, 'epoch': 4.1}
{'loss': 1.4054, 'grad_norm': 54.92039489746094, 'learning_rate': 1.4877049180327869e-05, 'epoch': 5.12}
{'loss': 1.1484, 'grad_norm': 99.30883026123047, 'learning_rate': 1.3852459016393445e-05, 'epoch': 6.15}
{'loss': 0.9074, 'grad_norm': 223.32518005371094, 'learning_rate': 1.2827868852459017e-05, 'epoch': 7.17}
{'loss': 0.7169, 'grad_norm': 53.427276611328125, 'learning_rate': 1.1803278688524591e-05, 'epoch': 8.2}
{'loss': 0.5402, 'grad_norm': 28.886152267456055, 'learning_rate': 1.0778688524590164e-05, 'epoch': 9.22}
{'loss': 0.4033, 'grad_norm': 244.77774047851562, 'le

TrainOutput(global_step=9760, training_loss=0.7751609981060028, metrics={'train_runtime': 1850.0621, 'train_samples_per_second': 84.321, 'train_steps_per_second': 5.275, 'train_loss': 0.7751609981060028, 'epoch': 20.0})

In [10]:
tokenized_test = pipeline(test_dataset)
tokenized_test = tokenized_test.remove_columns('labels')

preds = trainer.predict(tokenized_test)
preds_flat = [np.argmax(x) for x in preds[0]]

precision, recall, fscore, support = score(test_dataset['labels'], preds_flat)

print('precision: {}'.format(precision))
print('recall: {}'.format(recall))
print('fscore: {}'.format(fscore))
print('support: {}'.format(support))

# Calculate accuracy
correct_predictions = sum(p == l for p, l in zip(preds_flat, tokenized_val['labels']))
total_predictions = len(preds_flat)
accuracy = correct_predictions / total_predictions
print(correct_predictions,total_predictions, accuracy)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

  0%|          | 0/63 [00:00<?, ?it/s]

precision: [0.47826087 0.         0.19607843 0.1719457  0.23170732 0.46296296
 0.36734694 1.         0.21428571 0.5        0.76190476 0.75641026
 0.125      0.6        0.35714286 0.35135135 0.16       0.2972973 ]
recall: [0.10784314 0.         0.32258065 0.55072464 0.76       0.52083333
 0.41860465 0.16666667 0.03296703 0.2        0.51612903 0.72839506
 0.33333333 0.52941176 0.06578947 0.66666667 0.4        0.08527132]
fscore: [0.176      0.         0.24390244 0.26206897 0.35514019 0.49019608
 0.39130435 0.28571429 0.05714286 0.28571429 0.61538462 0.74213836
 0.18181818 0.5625     0.11111111 0.46017699 0.22857143 0.13253012]
support: [102   3  31  69  75  48  43  30  91   5  62  81   3  17 152  39  20 129]
92 1000 0.092
