In [1]:
import torch
import sklearn
from sklearn.model_selection import train_test_split
import datasets
from transformers import AutoTokenizer, GPT2Tokenizer,  GPT2ForSequenceClassification, Trainer, TrainingArguments
import random
import numpy as np
from datasets import load_dataset


In [2]:
dataset = load_dataset('csv', data_files='..\data\dataset\processed\clean_data_gpt2.csv')

In [3]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['prompt', 'essay', 'label'],
        num_rows: 9766
    })
})


In [4]:
from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

def tokenize_function(examples):
    combined_text = examples["prompt"] + '\n' + examples["essay"]
    return tokenizer(combined_text, padding="max_length", truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=False)

In [5]:
encoded_dataset = tokenized_datasets.rename_column("label", "labels")
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

In [6]:
small_train_dataset = encoded_dataset["train"].shuffle(seed=42).select(range(5000))
small_eval_dataset = encoded_dataset["train"].shuffle(seed=42).select(range(5000, 7000))
small_test_dataset = encoded_dataset["train"].shuffle(seed=42).select(range(7000, 9766))


In [7]:
print(small_train_dataset)

Dataset({
    features: ['prompt', 'essay', 'labels', 'input_ids', 'attention_mask'],
    num_rows: 5000
})


In [8]:
from transformers import DistilBertForSequenceClassification

# Specify the number of labels in your dataset
num_labels = 12  # Change this based on your dataset
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=num_labels)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=50,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
)

In [14]:
trainer.train()


  0%|          | 0/15650 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.7799417972564697, 'eval_runtime': 10.2395, 'eval_samples_per_second': 195.322, 'eval_steps_per_second': 12.208, 'epoch': 1.0}
{'loss': 1.2894, 'grad_norm': 16.68935775756836, 'learning_rate': 1.9361022364217256e-05, 'epoch': 1.6}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.8020873069763184, 'eval_runtime': 10.2872, 'eval_samples_per_second': 194.416, 'eval_steps_per_second': 12.151, 'epoch': 2.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.9590682983398438, 'eval_runtime': 10.2948, 'eval_samples_per_second': 194.273, 'eval_steps_per_second': 12.142, 'epoch': 3.0}
{'loss': 1.179, 'grad_norm': 21.6717529296875, 'learning_rate': 1.8722044728434506e-05, 'epoch': 3.19}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.1370468139648438, 'eval_runtime': 10.2595, 'eval_samples_per_second': 194.941, 'eval_steps_per_second': 12.184, 'epoch': 4.0}
{'loss': 1.0122, 'grad_norm': 33.7429084777832, 'learning_rate': 1.808306709265176e-05, 'epoch': 4.79}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.2728171348571777, 'eval_runtime': 10.262, 'eval_samples_per_second': 194.894, 'eval_steps_per_second': 12.181, 'epoch': 5.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.3290743827819824, 'eval_runtime': 10.2628, 'eval_samples_per_second': 194.878, 'eval_steps_per_second': 12.18, 'epoch': 6.0}
{'loss': 0.8686, 'grad_norm': 15.000410079956055, 'learning_rate': 1.744408945686901e-05, 'epoch': 6.39}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.4888834953308105, 'eval_runtime': 10.2987, 'eval_samples_per_second': 194.2, 'eval_steps_per_second': 12.138, 'epoch': 7.0}
{'loss': 0.7323, 'grad_norm': 34.94807815551758, 'learning_rate': 1.6805111821086264e-05, 'epoch': 7.99}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.532242774963379, 'eval_runtime': 10.2968, 'eval_samples_per_second': 194.234, 'eval_steps_per_second': 12.14, 'epoch': 8.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.8347904682159424, 'eval_runtime': 10.3073, 'eval_samples_per_second': 194.037, 'eval_steps_per_second': 12.127, 'epoch': 9.0}
{'loss': 0.5827, 'grad_norm': 17.94465446472168, 'learning_rate': 1.6166134185303515e-05, 'epoch': 9.58}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.846808910369873, 'eval_runtime': 10.3019, 'eval_samples_per_second': 194.139, 'eval_steps_per_second': 12.134, 'epoch': 10.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.021914958953857, 'eval_runtime': 10.3043, 'eval_samples_per_second': 194.093, 'eval_steps_per_second': 12.131, 'epoch': 11.0}
{'loss': 0.4537, 'grad_norm': 32.69583511352539, 'learning_rate': 1.552715654952077e-05, 'epoch': 11.18}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.05613374710083, 'eval_runtime': 10.2659, 'eval_samples_per_second': 194.819, 'eval_steps_per_second': 12.176, 'epoch': 12.0}
{'loss': 0.3226, 'grad_norm': 13.558354377746582, 'learning_rate': 1.488817891373802e-05, 'epoch': 12.78}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.399730682373047, 'eval_runtime': 10.2537, 'eval_samples_per_second': 195.051, 'eval_steps_per_second': 12.191, 'epoch': 13.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.557979583740234, 'eval_runtime': 10.1523, 'eval_samples_per_second': 196.999, 'eval_steps_per_second': 12.312, 'epoch': 14.0}
{'loss': 0.2058, 'grad_norm': 26.073915481567383, 'learning_rate': 1.4249201277955273e-05, 'epoch': 14.38}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.723602771759033, 'eval_runtime': 10.1678, 'eval_samples_per_second': 196.699, 'eval_steps_per_second': 12.294, 'epoch': 15.0}
{'loss': 0.1545, 'grad_norm': 21.63184928894043, 'learning_rate': 1.3610223642172523e-05, 'epoch': 15.97}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 5.026658058166504, 'eval_runtime': 10.1588, 'eval_samples_per_second': 196.873, 'eval_steps_per_second': 12.305, 'epoch': 16.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 5.296675682067871, 'eval_runtime': 10.1493, 'eval_samples_per_second': 197.058, 'eval_steps_per_second': 12.316, 'epoch': 17.0}
{'loss': 0.1053, 'grad_norm': 7.418973445892334, 'learning_rate': 1.2971246006389777e-05, 'epoch': 17.57}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 5.4476318359375, 'eval_runtime': 10.1518, 'eval_samples_per_second': 197.009, 'eval_steps_per_second': 12.313, 'epoch': 18.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 5.592721462249756, 'eval_runtime': 10.1578, 'eval_samples_per_second': 196.893, 'eval_steps_per_second': 12.306, 'epoch': 19.0}
{'loss': 0.065, 'grad_norm': 2.3337433338165283, 'learning_rate': 1.233226837060703e-05, 'epoch': 19.17}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 6.187510967254639, 'eval_runtime': 10.2557, 'eval_samples_per_second': 195.014, 'eval_steps_per_second': 12.188, 'epoch': 20.0}
{'loss': 0.0535, 'grad_norm': 2.170987367630005, 'learning_rate': 1.1693290734824283e-05, 'epoch': 20.77}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 6.045006275177002, 'eval_runtime': 10.2532, 'eval_samples_per_second': 195.06, 'eval_steps_per_second': 12.191, 'epoch': 21.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 6.473750114440918, 'eval_runtime': 10.2636, 'eval_samples_per_second': 194.863, 'eval_steps_per_second': 12.179, 'epoch': 22.0}
{'loss': 0.046, 'grad_norm': 0.4542818069458008, 'learning_rate': 1.1054313099041534e-05, 'epoch': 22.36}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 6.504444122314453, 'eval_runtime': 10.2597, 'eval_samples_per_second': 194.937, 'eval_steps_per_second': 12.184, 'epoch': 23.0}
{'loss': 0.0351, 'grad_norm': 0.895939290523529, 'learning_rate': 1.0415335463258786e-05, 'epoch': 23.96}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 6.779870986938477, 'eval_runtime': 10.2944, 'eval_samples_per_second': 194.28, 'eval_steps_per_second': 12.142, 'epoch': 24.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 6.99481201171875, 'eval_runtime': 10.2972, 'eval_samples_per_second': 194.227, 'eval_steps_per_second': 12.139, 'epoch': 25.0}
{'loss': 0.0276, 'grad_norm': 65.81299591064453, 'learning_rate': 9.77635782747604e-06, 'epoch': 25.56}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.104356288909912, 'eval_runtime': 10.3042, 'eval_samples_per_second': 194.095, 'eval_steps_per_second': 12.131, 'epoch': 26.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.220318794250488, 'eval_runtime': 10.299, 'eval_samples_per_second': 194.194, 'eval_steps_per_second': 12.137, 'epoch': 27.0}
{'loss': 0.0197, 'grad_norm': 16.32297706604004, 'learning_rate': 9.137380191693292e-06, 'epoch': 27.16}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.134730815887451, 'eval_runtime': 10.2976, 'eval_samples_per_second': 194.221, 'eval_steps_per_second': 12.139, 'epoch': 28.0}
{'loss': 0.0199, 'grad_norm': 1.0309655666351318, 'learning_rate': 8.498402555910544e-06, 'epoch': 28.75}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.462957859039307, 'eval_runtime': 10.2494, 'eval_samples_per_second': 195.133, 'eval_steps_per_second': 12.196, 'epoch': 29.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.669974327087402, 'eval_runtime': 10.2486, 'eval_samples_per_second': 195.149, 'eval_steps_per_second': 12.197, 'epoch': 30.0}
{'loss': 0.0175, 'grad_norm': 13.890649795532227, 'learning_rate': 7.859424920127796e-06, 'epoch': 30.35}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.439873218536377, 'eval_runtime': 10.254, 'eval_samples_per_second': 195.045, 'eval_steps_per_second': 12.19, 'epoch': 31.0}
{'loss': 0.0186, 'grad_norm': 0.05296344682574272, 'learning_rate': 7.220447284345049e-06, 'epoch': 31.95}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.618760585784912, 'eval_runtime': 10.2877, 'eval_samples_per_second': 194.406, 'eval_steps_per_second': 12.15, 'epoch': 32.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.45259428024292, 'eval_runtime': 10.1935, 'eval_samples_per_second': 196.204, 'eval_steps_per_second': 12.263, 'epoch': 33.0}
{'loss': 0.0128, 'grad_norm': 0.05186443775892258, 'learning_rate': 6.581469648562301e-06, 'epoch': 33.55}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 8.033432006835938, 'eval_runtime': 10.193, 'eval_samples_per_second': 196.213, 'eval_steps_per_second': 12.263, 'epoch': 34.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.752389430999756, 'eval_runtime': 10.2568, 'eval_samples_per_second': 194.992, 'eval_steps_per_second': 12.187, 'epoch': 35.0}
{'loss': 0.0143, 'grad_norm': 0.05952106788754463, 'learning_rate': 5.942492012779553e-06, 'epoch': 35.14}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.745307922363281, 'eval_runtime': 10.2778, 'eval_samples_per_second': 194.594, 'eval_steps_per_second': 12.162, 'epoch': 36.0}
{'loss': 0.0107, 'grad_norm': 0.03162076696753502, 'learning_rate': 5.303514376996806e-06, 'epoch': 36.74}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.995347499847412, 'eval_runtime': 10.3821, 'eval_samples_per_second': 192.639, 'eval_steps_per_second': 12.04, 'epoch': 37.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.888678550720215, 'eval_runtime': 10.2757, 'eval_samples_per_second': 194.633, 'eval_steps_per_second': 12.165, 'epoch': 38.0}
{'loss': 0.0102, 'grad_norm': 0.02625870332121849, 'learning_rate': 4.664536741214058e-06, 'epoch': 38.34}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 8.054715156555176, 'eval_runtime': 10.2559, 'eval_samples_per_second': 195.01, 'eval_steps_per_second': 12.188, 'epoch': 39.0}
{'loss': 0.009, 'grad_norm': 1.9514206647872925, 'learning_rate': 4.02555910543131e-06, 'epoch': 39.94}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 8.026617050170898, 'eval_runtime': 10.2659, 'eval_samples_per_second': 194.821, 'eval_steps_per_second': 12.176, 'epoch': 40.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 8.044593811035156, 'eval_runtime': 10.2632, 'eval_samples_per_second': 194.87, 'eval_steps_per_second': 12.179, 'epoch': 41.0}
{'loss': 0.0076, 'grad_norm': 0.0023224896285682917, 'learning_rate': 3.386581469648563e-06, 'epoch': 41.53}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 8.058107376098633, 'eval_runtime': 10.2637, 'eval_samples_per_second': 194.862, 'eval_steps_per_second': 12.179, 'epoch': 42.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.9671454429626465, 'eval_runtime': 10.2868, 'eval_samples_per_second': 194.425, 'eval_steps_per_second': 12.152, 'epoch': 43.0}
{'loss': 0.0071, 'grad_norm': 0.016398491337895393, 'learning_rate': 2.747603833865815e-06, 'epoch': 43.13}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.948387622833252, 'eval_runtime': 10.1604, 'eval_samples_per_second': 196.843, 'eval_steps_per_second': 12.303, 'epoch': 44.0}
{'loss': 0.0064, 'grad_norm': 9.600139617919922, 'learning_rate': 2.1086261980830672e-06, 'epoch': 44.73}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 8.11688232421875, 'eval_runtime': 10.2002, 'eval_samples_per_second': 196.074, 'eval_steps_per_second': 12.255, 'epoch': 45.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 7.97551155090332, 'eval_runtime': 10.1929, 'eval_samples_per_second': 196.216, 'eval_steps_per_second': 12.263, 'epoch': 46.0}
{'loss': 0.0036, 'grad_norm': 0.004000563640147448, 'learning_rate': 1.4696485623003196e-06, 'epoch': 46.33}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 8.023505210876465, 'eval_runtime': 10.1969, 'eval_samples_per_second': 196.139, 'eval_steps_per_second': 12.259, 'epoch': 47.0}
{'loss': 0.0059, 'grad_norm': 0.004911383613944054, 'learning_rate': 8.306709265175719e-07, 'epoch': 47.92}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 8.048820495605469, 'eval_runtime': 10.1917, 'eval_samples_per_second': 196.238, 'eval_steps_per_second': 12.265, 'epoch': 48.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 8.053009986877441, 'eval_runtime': 10.1931, 'eval_samples_per_second': 196.211, 'eval_steps_per_second': 12.263, 'epoch': 49.0}
{'loss': 0.0036, 'grad_norm': 0.04660388082265854, 'learning_rate': 1.9169329073482428e-07, 'epoch': 49.52}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 8.074740409851074, 'eval_runtime': 10.1442, 'eval_samples_per_second': 197.157, 'eval_steps_per_second': 12.322, 'epoch': 50.0}
{'train_runtime': 4432.7932, 'train_samples_per_second': 56.398, 'train_steps_per_second': 3.531, 'train_loss': 0.2332743097037172, 'epoch': 50.0}


TrainOutput(global_step=15650, training_loss=0.2332743097037172, metrics={'train_runtime': 4432.7932, 'train_samples_per_second': 56.398, 'train_steps_per_second': 3.531, 'total_flos': 3.3122755584e+16, 'train_loss': 0.2332743097037172, 'epoch': 50.0})

In [15]:
from sklearn.metrics import accuracy_score

# Run predictions on the test dataset
predictions = trainer.predict(small_test_dataset)
preds = predictions.predictions.argmax(-1)

# Calculate accuracy
labels = small_test_dataset["labels"]
accuracy = accuracy_score(labels, preds)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

  0%|          | 0/173 [00:00<?, ?it/s]

Test Accuracy: 22.74%
