In [1]:
import torch
import sklearn
from sklearn.model_selection import train_test_split
import datasets
from transformers import AutoTokenizer, GPT2Tokenizer,  GPT2ForSequenceClassification, Trainer, TrainingArguments
import random
import numpy as np
from datasets import load_dataset


In [2]:
dataset = load_dataset('csv', data_files='..\data\dataset\processed\clean_data_gpt2.csv')

In [3]:
print(dataset)

DatasetDict({
    train: Dataset({
        features: ['prompt', 'essay', 'label'],
        num_rows: 9766
    })
})


In [4]:
from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

def tokenize_function(examples):
    combined_text = examples["prompt"] + '\n' + examples["essay"]
    return tokenizer(combined_text, padding="max_length", truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=False)

In [5]:
encoded_dataset = tokenized_datasets.rename_column("label", "labels")
encoded_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

In [6]:
small_train_dataset = encoded_dataset["train"].shuffle(seed=42).select(range(5000))
small_eval_dataset = encoded_dataset["train"].shuffle(seed=42).select(range(5000, 7000))
small_test_dataset = encoded_dataset["train"].shuffle(seed=42).select(range(7000, 9766))


In [19]:
print(small_train_dataset)

Dataset({
    features: ['prompt', 'essay', 'labels', 'input_ids', 'attention_mask'],
    num_rows: 5000
})


In [7]:
num_labels = 12  # Change this based on your dataset

# from transformers import DistilBertForSequenceClassification

# Specify the number of labels in your dataset
# model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=num_labels)

from transformers import DistilBertConfig, DistilBertForSequenceClassification

config = DistilBertConfig.from_pretrained('distilbert-base-uncased', num_labels=num_labels, dropout=0.3, attention_dropout=0.3)
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased', config=config)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=50,
    weight_decay=0.05,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
)

In [9]:
trainer.train()


  0%|          | 0/15650 [00:00<?, ?it/s]

  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.291130781173706, 'eval_runtime': 13.1067, 'eval_samples_per_second': 152.594, 'eval_steps_per_second': 9.537, 'epoch': 1.0}
{'loss': 2.3019, 'grad_norm': 5.018802165985107, 'learning_rate': 1.9361022364217256e-05, 'epoch': 1.6}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.1899280548095703, 'eval_runtime': 11.7659, 'eval_samples_per_second': 169.983, 'eval_steps_per_second': 10.624, 'epoch': 2.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.2061498165130615, 'eval_runtime': 12.2823, 'eval_samples_per_second': 162.836, 'eval_steps_per_second': 10.177, 'epoch': 3.0}
{'loss': 2.1178, 'grad_norm': 7.347917556762695, 'learning_rate': 1.8722044728434506e-05, 'epoch': 3.19}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.2281692028045654, 'eval_runtime': 11.862, 'eval_samples_per_second': 168.606, 'eval_steps_per_second': 10.538, 'epoch': 4.0}
{'loss': 2.0257, 'grad_norm': 11.556153297424316, 'learning_rate': 1.808306709265176e-05, 'epoch': 4.79}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.1947567462921143, 'eval_runtime': 13.0927, 'eval_samples_per_second': 152.757, 'eval_steps_per_second': 9.547, 'epoch': 5.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.3065733909606934, 'eval_runtime': 13.1155, 'eval_samples_per_second': 152.491, 'eval_steps_per_second': 9.531, 'epoch': 6.0}
{'loss': 1.9222, 'grad_norm': 11.043562889099121, 'learning_rate': 1.744408945686901e-05, 'epoch': 6.39}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.2838706970214844, 'eval_runtime': 13.0905, 'eval_samples_per_second': 152.782, 'eval_steps_per_second': 9.549, 'epoch': 7.0}
{'loss': 1.8457, 'grad_norm': 13.720620155334473, 'learning_rate': 1.6805111821086264e-05, 'epoch': 7.99}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.3988399505615234, 'eval_runtime': 13.0779, 'eval_samples_per_second': 152.929, 'eval_steps_per_second': 9.558, 'epoch': 8.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.465951681137085, 'eval_runtime': 12.5504, 'eval_samples_per_second': 159.357, 'eval_steps_per_second': 9.96, 'epoch': 9.0}
{'loss': 1.744, 'grad_norm': 13.057364463806152, 'learning_rate': 1.6166134185303515e-05, 'epoch': 9.58}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.4507246017456055, 'eval_runtime': 10.3095, 'eval_samples_per_second': 193.996, 'eval_steps_per_second': 12.125, 'epoch': 10.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.5251455307006836, 'eval_runtime': 10.3303, 'eval_samples_per_second': 193.605, 'eval_steps_per_second': 12.1, 'epoch': 11.0}
{'loss': 1.6773, 'grad_norm': 18.072338104248047, 'learning_rate': 1.552715654952077e-05, 'epoch': 11.18}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.6007332801818848, 'eval_runtime': 10.5349, 'eval_samples_per_second': 189.845, 'eval_steps_per_second': 11.865, 'epoch': 12.0}
{'loss': 1.5922, 'grad_norm': 15.261088371276855, 'learning_rate': 1.488817891373802e-05, 'epoch': 12.78}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.6807408332824707, 'eval_runtime': 10.5221, 'eval_samples_per_second': 190.076, 'eval_steps_per_second': 11.88, 'epoch': 13.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.738466262817383, 'eval_runtime': 10.422, 'eval_samples_per_second': 191.901, 'eval_steps_per_second': 11.994, 'epoch': 14.0}
{'loss': 1.5174, 'grad_norm': 19.37664031982422, 'learning_rate': 1.4249201277955273e-05, 'epoch': 14.38}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.7837862968444824, 'eval_runtime': 10.4367, 'eval_samples_per_second': 191.632, 'eval_steps_per_second': 11.977, 'epoch': 15.0}
{'loss': 1.4213, 'grad_norm': 22.1046199798584, 'learning_rate': 1.3610223642172523e-05, 'epoch': 15.97}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 2.8857924938201904, 'eval_runtime': 10.575, 'eval_samples_per_second': 189.125, 'eval_steps_per_second': 11.82, 'epoch': 16.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.018566846847534, 'eval_runtime': 10.2996, 'eval_samples_per_second': 194.183, 'eval_steps_per_second': 12.136, 'epoch': 17.0}
{'loss': 1.3619, 'grad_norm': 26.674503326416016, 'learning_rate': 1.2971246006389777e-05, 'epoch': 17.57}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.2049620151519775, 'eval_runtime': 10.3144, 'eval_samples_per_second': 193.904, 'eval_steps_per_second': 12.119, 'epoch': 18.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.136727809906006, 'eval_runtime': 10.3312, 'eval_samples_per_second': 193.588, 'eval_steps_per_second': 12.099, 'epoch': 19.0}
{'loss': 1.2798, 'grad_norm': 32.263771057128906, 'learning_rate': 1.233226837060703e-05, 'epoch': 19.17}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.311544895172119, 'eval_runtime': 10.3125, 'eval_samples_per_second': 193.939, 'eval_steps_per_second': 12.121, 'epoch': 20.0}
{'loss': 1.2034, 'grad_norm': 19.520566940307617, 'learning_rate': 1.1693290734824283e-05, 'epoch': 20.77}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.296377420425415, 'eval_runtime': 10.2908, 'eval_samples_per_second': 194.349, 'eval_steps_per_second': 12.147, 'epoch': 21.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.3784542083740234, 'eval_runtime': 10.3254, 'eval_samples_per_second': 193.698, 'eval_steps_per_second': 12.106, 'epoch': 22.0}
{'loss': 1.1385, 'grad_norm': 34.09401321411133, 'learning_rate': 1.1054313099041534e-05, 'epoch': 22.36}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.3219003677368164, 'eval_runtime': 10.308, 'eval_samples_per_second': 194.024, 'eval_steps_per_second': 12.127, 'epoch': 23.0}
{'loss': 1.0717, 'grad_norm': 29.505578994750977, 'learning_rate': 1.0415335463258786e-05, 'epoch': 23.96}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.4092133045196533, 'eval_runtime': 10.3038, 'eval_samples_per_second': 194.103, 'eval_steps_per_second': 12.131, 'epoch': 24.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.5814170837402344, 'eval_runtime': 10.5905, 'eval_samples_per_second': 188.849, 'eval_steps_per_second': 11.803, 'epoch': 25.0}
{'loss': 0.9591, 'grad_norm': 34.76598358154297, 'learning_rate': 9.77635782747604e-06, 'epoch': 25.56}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.6005280017852783, 'eval_runtime': 10.3179, 'eval_samples_per_second': 193.838, 'eval_steps_per_second': 12.115, 'epoch': 26.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.6644785404205322, 'eval_runtime': 10.3133, 'eval_samples_per_second': 193.925, 'eval_steps_per_second': 12.12, 'epoch': 27.0}
{'loss': 0.9145, 'grad_norm': 23.494220733642578, 'learning_rate': 9.137380191693292e-06, 'epoch': 27.16}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.8089439868927, 'eval_runtime': 10.3076, 'eval_samples_per_second': 194.032, 'eval_steps_per_second': 12.127, 'epoch': 28.0}
{'loss': 0.8358, 'grad_norm': 30.571897506713867, 'learning_rate': 8.498402555910544e-06, 'epoch': 28.75}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.9145989418029785, 'eval_runtime': 10.287, 'eval_samples_per_second': 194.42, 'eval_steps_per_second': 12.151, 'epoch': 29.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.071792125701904, 'eval_runtime': 10.3067, 'eval_samples_per_second': 194.049, 'eval_steps_per_second': 12.128, 'epoch': 30.0}
{'loss': 0.756, 'grad_norm': 34.89623260498047, 'learning_rate': 7.859424920127796e-06, 'epoch': 30.35}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.0561065673828125, 'eval_runtime': 10.2689, 'eval_samples_per_second': 194.764, 'eval_steps_per_second': 12.173, 'epoch': 31.0}
{'loss': 0.6993, 'grad_norm': 19.783464431762695, 'learning_rate': 7.220447284345049e-06, 'epoch': 31.95}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 3.995400905609131, 'eval_runtime': 10.2888, 'eval_samples_per_second': 194.385, 'eval_steps_per_second': 12.149, 'epoch': 32.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.101748943328857, 'eval_runtime': 10.3011, 'eval_samples_per_second': 194.153, 'eval_steps_per_second': 12.135, 'epoch': 33.0}
{'loss': 0.6474, 'grad_norm': 18.886295318603516, 'learning_rate': 6.581469648562301e-06, 'epoch': 33.55}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.236316204071045, 'eval_runtime': 10.2901, 'eval_samples_per_second': 194.362, 'eval_steps_per_second': 12.148, 'epoch': 34.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.098664283752441, 'eval_runtime': 10.3022, 'eval_samples_per_second': 194.134, 'eval_steps_per_second': 12.133, 'epoch': 35.0}
{'loss': 0.5943, 'grad_norm': 33.42594528198242, 'learning_rate': 5.942492012779553e-06, 'epoch': 35.14}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.311583042144775, 'eval_runtime': 10.2921, 'eval_samples_per_second': 194.323, 'eval_steps_per_second': 12.145, 'epoch': 36.0}
{'loss': 0.5482, 'grad_norm': 26.368898391723633, 'learning_rate': 5.303514376996806e-06, 'epoch': 36.74}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.570407390594482, 'eval_runtime': 10.2812, 'eval_samples_per_second': 194.53, 'eval_steps_per_second': 12.158, 'epoch': 37.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.464634895324707, 'eval_runtime': 10.2753, 'eval_samples_per_second': 194.642, 'eval_steps_per_second': 12.165, 'epoch': 38.0}
{'loss': 0.4947, 'grad_norm': 24.448640823364258, 'learning_rate': 4.664536741214058e-06, 'epoch': 38.34}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.423559665679932, 'eval_runtime': 10.1748, 'eval_samples_per_second': 196.565, 'eval_steps_per_second': 12.285, 'epoch': 39.0}
{'loss': 0.4586, 'grad_norm': 23.146333694458008, 'learning_rate': 4.02555910543131e-06, 'epoch': 39.94}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.6081743240356445, 'eval_runtime': 10.2987, 'eval_samples_per_second': 194.199, 'eval_steps_per_second': 12.137, 'epoch': 40.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.621207237243652, 'eval_runtime': 10.3154, 'eval_samples_per_second': 193.885, 'eval_steps_per_second': 12.118, 'epoch': 41.0}
{'loss': 0.4259, 'grad_norm': 24.170455932617188, 'learning_rate': 3.386581469648563e-06, 'epoch': 41.53}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.6753668785095215, 'eval_runtime': 10.2996, 'eval_samples_per_second': 194.182, 'eval_steps_per_second': 12.136, 'epoch': 42.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.6135640144348145, 'eval_runtime': 10.3101, 'eval_samples_per_second': 193.984, 'eval_steps_per_second': 12.124, 'epoch': 43.0}
{'loss': 0.379, 'grad_norm': 21.371850967407227, 'learning_rate': 2.747603833865815e-06, 'epoch': 43.13}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.670345306396484, 'eval_runtime': 10.2889, 'eval_samples_per_second': 194.384, 'eval_steps_per_second': 12.149, 'epoch': 44.0}
{'loss': 0.3726, 'grad_norm': 28.548498153686523, 'learning_rate': 2.1086261980830672e-06, 'epoch': 44.73}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.793938159942627, 'eval_runtime': 10.2867, 'eval_samples_per_second': 194.426, 'eval_steps_per_second': 12.152, 'epoch': 45.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.701557636260986, 'eval_runtime': 10.2187, 'eval_samples_per_second': 195.719, 'eval_steps_per_second': 12.232, 'epoch': 46.0}
{'loss': 0.3528, 'grad_norm': 24.39487648010254, 'learning_rate': 1.4696485623003196e-06, 'epoch': 46.33}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.751134395599365, 'eval_runtime': 10.1824, 'eval_samples_per_second': 196.418, 'eval_steps_per_second': 12.276, 'epoch': 47.0}
{'loss': 0.3206, 'grad_norm': 22.391193389892578, 'learning_rate': 8.306709265175719e-07, 'epoch': 47.92}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.784222602844238, 'eval_runtime': 10.3159, 'eval_samples_per_second': 193.875, 'eval_steps_per_second': 12.117, 'epoch': 48.0}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.8388495445251465, 'eval_runtime': 10.3284, 'eval_samples_per_second': 193.641, 'eval_steps_per_second': 12.103, 'epoch': 49.0}
{'loss': 0.3112, 'grad_norm': 21.314353942871094, 'learning_rate': 1.9169329073482428e-07, 'epoch': 49.52}


  0%|          | 0/125 [00:00<?, ?it/s]

{'eval_loss': 4.847391605377197, 'eval_runtime': 10.2625, 'eval_samples_per_second': 194.884, 'eval_steps_per_second': 12.18, 'epoch': 50.0}
{'train_runtime': 4651.0984, 'train_samples_per_second': 53.751, 'train_steps_per_second': 3.365, 'train_loss': 1.0666630366511238, 'epoch': 50.0}


TrainOutput(global_step=15650, training_loss=1.0666630366511238, metrics={'train_runtime': 4651.0984, 'train_samples_per_second': 53.751, 'train_steps_per_second': 3.365, 'total_flos': 3.3122755584e+16, 'train_loss': 1.0666630366511238, 'epoch': 50.0})

In [10]:
from sklearn.metrics import accuracy_score

# Run predictions on the test dataset
predictions = trainer.predict(small_test_dataset)
preds = predictions.predictions.argmax(-1)

# Calculate accuracy
labels = small_test_dataset["labels"]
accuracy = accuracy_score(labels, preds)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

  0%|          | 0/173 [00:00<?, ?it/s]

Test Accuracy: 16.12%
