In [1]:
import torch
from datasets import load_dataset
from transformers import (
    BertTokenizer, BertForSequenceClassification,
    RobertaTokenizer, RobertaForSequenceClassification,
    XLNetTokenizer, XLNetForSequenceClassification,
    TrainingArguments, Trainer
)
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix
import numpy as np


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the IMDb dataset
dataset = load_dataset("imdb")

# Build one (tokenizer, classifier) pair per architecture, keyed by a short name.
# Tokenizer and model are loaded from the same checkpoint, and every classifier
# is created with a two-label head (num_labels=2).
models_and_tokenizers = {
    name: (
        tokenizer_cls.from_pretrained(checkpoint),
        model_cls.from_pretrained(checkpoint, num_labels=2),
    )
    for name, tokenizer_cls, model_cls, checkpoint in (
        ("bert", BertTokenizer, BertForSequenceClassification, 'bert-base-uncased'),
        ("roberta", RobertaTokenizer, RobertaForSequenceClassification, 'roberta-base'),
        ("xlnet", XLNetTokenizer, XLNetForSequenceClassification, 'xlnet-base-cased'),
    )
}

# Define the tokenization function
def tokenize_function(examples, tokenizer, max_length=512):
    """Tokenize the ``'text'`` entry of ``examples`` with ``tokenizer``.

    Args:
        examples: Mapping with a ``'text'`` key holding the input(s) to tokenize.
        tokenizer: Callable invoked as
            ``tokenizer(texts, padding=..., truncation=..., max_length=...)``.
        max_length: Length passed to the tokenizer; inputs are truncated and
            padded (``padding="max_length"``) to this value. Defaults to 512,
            the value that was previously hard-coded.

    Returns:
        Whatever the ``tokenizer`` call returns.
    """
    return tokenizer(
        examples['text'],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

# One seed for the split and for both shuffles. Without it, train_test_split
# draws a fresh random partition on every call, so runs are not reproducible
# and each model would be validated on a different held-out subset.
SPLIT_SEED = 42

# Apply the tokenization function to the dataset, once per tokenizer
tokenized_datasets = {
    name: dataset.map(lambda x: tokenize_function(x, tokenizer), batched=True)
    for name, (tokenizer, _) in models_and_tokenizers.items()
}

# Split the original 'train' split into train (80%) and validation (20%) sets.
# The same seed is used for every model so the accuracy comparison is made on
# the same held-out examples.
train_test_splits = {
    name: tokenized_datasets[name]['train'].train_test_split(test_size=0.2, seed=SPLIT_SEED)
    for name in models_and_tokenizers
}
train_datasets = {
    name: split['train'].shuffle(seed=SPLIT_SEED)  # Shuffle the train dataset
    for name, split in train_test_splits.items()
}
valid_datasets = {
    name: split['test'].shuffle(seed=SPLIT_SEED)
    for name, split in train_test_splits.items()
}

# Define the compute_metrics function
def compute_metrics(eval_pred):
    """Turn an evaluation prediction into a dict of classification metrics.

    Args:
        eval_pred: Pair that unpacks as ``(logits, labels)``.

    Returns:
        Dict with ``'accuracy'``, ``'f1'``, ``'precision'`` and ``'recall'``.
        Precision, recall and F1 use ``average='weighted'``.
    """
    logits, labels = eval_pred
    # The predicted class is the index of the largest logit along the last axis.
    predicted = np.argmax(logits, axis=-1)
    # Result is (precision, recall, f1, support); support is not reported.
    prf = precision_recall_fscore_support(labels, predicted, average='weighted')
    return {
        'accuracy': accuracy_score(labels, predicted),
        'f1': prf[2],
        'precision': prf[0],
        'recall': prf[1],
    }

# Define training arguments for each model.
# Every model is trained with the same hyperparameters; only the output
# directory ('./<name>_results') differs. Building the configs from a single
# template keeps the three runs comparable and prevents the settings from
# drifting apart through copy-paste edits.
training_args_dict = {
    name: TrainingArguments(
        output_dir=f'./{name}_results',
        evaluation_strategy="epoch",  # evaluate at the end of each epoch
        learning_rate=1e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        weight_decay=0.01,
    )
    for name in ("bert", "roberta", "xlnet")
}



# NOTE: a second, byte-identical definition of compute_metrics was removed from
# this spot. It silently shadowed the definition earlier in this cell; that
# earlier definition is the single one now in effect.


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  return self.fget.__get__(instance, owner)()
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.bias', 'logits_proj.weight', 'sequence_summary.summary.bias', 'sequence_summary.summary.weight']
You should probably TRAIN this model on a down-s

In [3]:
# Train, evaluate, and save each model
accuracies = {}  # model name -> validation accuracy from the final evaluation
for name, (tokenizer, model) in models_and_tokenizers.items():
    print(f"Training and evaluating {name} model...")
    model.to(device)

    # Define Trainer with model, arguments, and datasets
    trainer = Trainer(
        model=model,
        args=training_args_dict[name],
        train_dataset=train_datasets[name],
        eval_dataset=valid_datasets[name],
        compute_metrics=compute_metrics
    )

    # Start training
    trainer.train()

    # Evaluate the model on the validation split
    metrics = trainer.evaluate()
    accuracies[name] = metrics['eval_accuracy']
    print(f"Metrics for {name} model:", metrics)

    # Save the model weights together with the matching tokenizer files, so the
    # output directory can be reloaded on its own (previously the tokenizer
    # unpacked by this loop was never used or saved).
    save_dir = f'./{name}_model'
    model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)

print("Training and saving of all models completed.")
print("Accuracies of each model:")
for name, accuracy in accuracies.items():
    print(f"{name} model accuracy: **{accuracy:.4f}**")

Training and evaluating bert model...


  7%|▋         | 500/7500 [03:32<50:15,  2.32it/s]

{'loss': 0.3716, 'grad_norm': 19.597900390625, 'learning_rate': 9.333333333333334e-06, 'epoch': 0.2}


 13%|█▎        | 1000/7500 [07:05<45:36,  2.38it/s] 

{'loss': 0.313, 'grad_norm': 0.5496721267700195, 'learning_rate': 8.666666666666668e-06, 'epoch': 0.4}


 20%|██        | 1500/7500 [10:37<42:06,  2.38it/s]  

{'loss': 0.2756, 'grad_norm': 30.136180877685547, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.6}


 27%|██▋       | 2000/7500 [14:09<38:48,  2.36it/s]  

{'loss': 0.2627, 'grad_norm': 26.032447814941406, 'learning_rate': 7.333333333333333e-06, 'epoch': 0.8}


 33%|███▎      | 2500/7500 [17:41<34:58,  2.38it/s]  

{'loss': 0.2732, 'grad_norm': 10.659000396728516, 'learning_rate': 6.666666666666667e-06, 'epoch': 1.0}


                                                   
 33%|███▎      | 2500/7500 [19:20<34:58,  2.38it/s]

{'eval_loss': 0.2346799373626709, 'eval_accuracy': 0.9132, 'eval_f1': 0.913059022253355, 'eval_precision': 0.9164113135420432, 'eval_recall': 0.9132, 'eval_runtime': 97.6323, 'eval_samples_per_second': 51.213, 'eval_steps_per_second': 6.402, 'epoch': 1.0}


 40%|████      | 3000/7500 [22:51<31:32,  2.38it/s]   

{'loss': 0.1894, 'grad_norm': 12.786458015441895, 'learning_rate': 6e-06, 'epoch': 1.2}


 47%|████▋     | 3500/7500 [26:23<27:59,  2.38it/s]  

{'loss': 0.1802, 'grad_norm': 0.10571971535682678, 'learning_rate': 5.333333333333334e-06, 'epoch': 1.4}


 53%|█████▎    | 4000/7500 [29:55<24:25,  2.39it/s]

{'loss': 0.1963, 'grad_norm': 0.18727432191371918, 'learning_rate': 4.666666666666667e-06, 'epoch': 1.6}


 60%|██████    | 4500/7500 [33:27<21:05,  2.37it/s]

{'loss': 0.1855, 'grad_norm': 1.339341402053833, 'learning_rate': 4.000000000000001e-06, 'epoch': 1.8}


 67%|██████▋   | 5000/7500 [36:59<17:34,  2.37it/s]

{'loss': 0.171, 'grad_norm': 0.3257054388523102, 'learning_rate': 3.3333333333333333e-06, 'epoch': 2.0}


                                                   
 67%|██████▋   | 5000/7500 [38:38<17:34,  2.37it/s]

{'eval_loss': 0.30765819549560547, 'eval_accuracy': 0.9306, 'eval_f1': 0.930595089090797, 'eval_precision': 0.9308672861207031, 'eval_recall': 0.9306, 'eval_runtime': 97.4069, 'eval_samples_per_second': 51.331, 'eval_steps_per_second': 6.416, 'epoch': 2.0}


 73%|███████▎  | 5500/7500 [42:08<14:04,  2.37it/s]   

{'loss': 0.1008, 'grad_norm': 0.06519705057144165, 'learning_rate': 2.666666666666667e-06, 'epoch': 2.2}


 80%|████████  | 6000/7500 [45:40<10:29,  2.38it/s]

{'loss': 0.108, 'grad_norm': 0.06179571524262428, 'learning_rate': 2.0000000000000003e-06, 'epoch': 2.4}


 87%|████████▋ | 6500/7500 [49:12<07:01,  2.37it/s]

{'loss': 0.1225, 'grad_norm': 10.308588027954102, 'learning_rate': 1.3333333333333334e-06, 'epoch': 2.6}


 93%|█████████▎| 7000/7500 [52:44<03:29,  2.38it/s]

{'loss': 0.1238, 'grad_norm': 0.19038809835910797, 'learning_rate': 6.666666666666667e-07, 'epoch': 2.8}


100%|██████████| 7500/7500 [56:16<00:00,  2.36it/s]

{'loss': 0.0963, 'grad_norm': 0.1345270723104477, 'learning_rate': 0.0, 'epoch': 3.0}


                                                   
100%|██████████| 7500/7500 [57:55<00:00,  2.16it/s]


{'eval_loss': 0.336819052696228, 'eval_accuracy': 0.9298, 'eval_f1': 0.9297922241749187, 'eval_precision': 0.9301595930830523, 'eval_recall': 0.9298, 'eval_runtime': 97.4595, 'eval_samples_per_second': 51.303, 'eval_steps_per_second': 6.413, 'epoch': 3.0}
{'train_runtime': 3475.3932, 'train_samples_per_second': 17.264, 'train_steps_per_second': 2.158, 'train_loss': 0.19800359395345052, 'epoch': 3.0}


100%|██████████| 625/625 [01:37<00:00,  6.41it/s]


Metrics for bert model: {'eval_loss': 0.336819052696228, 'eval_accuracy': 0.9298, 'eval_f1': 0.9297922241749187, 'eval_precision': 0.9301595930830523, 'eval_recall': 0.9298, 'eval_runtime': 97.6507, 'eval_samples_per_second': 51.203, 'eval_steps_per_second': 6.4, 'epoch': 3.0}
Training and evaluating roberta model...


  7%|▋         | 500/7500 [03:34<50:02,  2.33it/s]

{'loss': 0.3882, 'grad_norm': 9.236747741699219, 'learning_rate': 9.333333333333334e-06, 'epoch': 0.2}


 13%|█▎        | 1000/7500 [07:09<46:27,  2.33it/s] 

{'loss': 0.3016, 'grad_norm': 0.16578590869903564, 'learning_rate': 8.666666666666668e-06, 'epoch': 0.4}


 20%|██        | 1500/7500 [10:45<42:48,  2.34it/s]  

{'loss': 0.2768, 'grad_norm': 0.2296120971441269, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.6}


 27%|██▋       | 2000/7500 [14:20<39:08,  2.34it/s]  

{'loss': 0.271, 'grad_norm': 0.37167102098464966, 'learning_rate': 7.333333333333333e-06, 'epoch': 0.8}


 33%|███▎      | 2500/7500 [17:55<34:57,  2.38it/s]  

{'loss': 0.2482, 'grad_norm': 0.11965908110141754, 'learning_rate': 6.666666666666667e-06, 'epoch': 1.0}


                                                   
 33%|███▎      | 2500/7500 [19:29<34:57,  2.38it/s]

{'eval_loss': 0.2554825246334076, 'eval_accuracy': 0.9406, 'eval_f1': 0.9405952596477227, 'eval_precision': 0.9408350801007189, 'eval_recall': 0.9406, 'eval_runtime': 92.1994, 'eval_samples_per_second': 54.23, 'eval_steps_per_second': 6.779, 'epoch': 1.0}


 40%|████      | 3000/7500 [23:00<31:38,  2.37it/s]   

{'loss': 0.1764, 'grad_norm': 0.16321678459644318, 'learning_rate': 6e-06, 'epoch': 1.2}


 47%|████▋     | 3500/7500 [26:33<28:12,  2.36it/s]  

{'loss': 0.1998, 'grad_norm': 0.20056791603565216, 'learning_rate': 5.333333333333334e-06, 'epoch': 1.4}


 53%|█████▎    | 4000/7500 [30:06<24:36,  2.37it/s]  

{'loss': 0.1834, 'grad_norm': 0.20147453248500824, 'learning_rate': 4.666666666666667e-06, 'epoch': 1.6}


 60%|██████    | 4500/7500 [33:39<21:02,  2.38it/s]

{'loss': 0.1784, 'grad_norm': 0.09802111238241196, 'learning_rate': 4.000000000000001e-06, 'epoch': 1.8}


 67%|██████▋   | 5000/7500 [37:12<17:32,  2.38it/s]

{'loss': 0.1945, 'grad_norm': 0.08980480581521988, 'learning_rate': 3.3333333333333333e-06, 'epoch': 2.0}


                                                   
 67%|██████▋   | 5000/7500 [38:45<17:32,  2.38it/s]

{'eval_loss': 0.2690128982067108, 'eval_accuracy': 0.9414, 'eval_f1': 0.9413996835596834, 'eval_precision': 0.9414542717909726, 'eval_recall': 0.9414, 'eval_runtime': 92.0403, 'eval_samples_per_second': 54.324, 'eval_steps_per_second': 6.791, 'epoch': 2.0}


 73%|███████▎  | 5500/7500 [42:16<14:02,  2.37it/s]   

{'loss': 0.1243, 'grad_norm': 0.03521687164902687, 'learning_rate': 2.666666666666667e-06, 'epoch': 2.2}


 80%|████████  | 6000/7500 [45:49<10:32,  2.37it/s]

{'loss': 0.113, 'grad_norm': 1.5051937103271484, 'learning_rate': 2.0000000000000003e-06, 'epoch': 2.4}


 87%|████████▋ | 6500/7500 [49:22<07:01,  2.37it/s]

{'loss': 0.1385, 'grad_norm': 0.1062120869755745, 'learning_rate': 1.3333333333333334e-06, 'epoch': 2.6}


 93%|█████████▎| 7000/7500 [52:55<03:29,  2.38it/s]

{'loss': 0.1151, 'grad_norm': 17.877504348754883, 'learning_rate': 6.666666666666667e-07, 'epoch': 2.8}


100%|██████████| 7500/7500 [56:28<00:00,  2.36it/s]

{'loss': 0.1246, 'grad_norm': 0.1272221803665161, 'learning_rate': 0.0, 'epoch': 3.0}


                                                   
100%|██████████| 7500/7500 [58:01<00:00,  2.15it/s]


{'eval_loss': 0.301466703414917, 'eval_accuracy': 0.9456, 'eval_f1': 0.9456, 'eval_precision': 0.9456, 'eval_recall': 0.9456, 'eval_runtime': 92.0537, 'eval_samples_per_second': 54.316, 'eval_steps_per_second': 6.79, 'epoch': 3.0}
{'train_runtime': 3481.8606, 'train_samples_per_second': 17.232, 'train_steps_per_second': 2.154, 'train_loss': 0.20225896504720053, 'epoch': 3.0}


100%|██████████| 625/625 [01:32<00:00,  6.79it/s]


Metrics for roberta model: {'eval_loss': 0.301466703414917, 'eval_accuracy': 0.9456, 'eval_f1': 0.9456, 'eval_precision': 0.9456, 'eval_recall': 0.9456, 'eval_runtime': 92.1831, 'eval_samples_per_second': 54.24, 'eval_steps_per_second': 6.78, 'epoch': 3.0}
Training and evaluating xlnet model...


  7%|▋         | 500/7500 [07:00<1:38:16,  1.19it/s]

{'loss': 0.3658, 'grad_norm': 25.76732635498047, 'learning_rate': 9.333333333333334e-06, 'epoch': 0.2}


 13%|█▎        | 1000/7500 [14:03<1:30:46,  1.19it/s]

{'loss': 0.2614, 'grad_norm': 19.052309036254883, 'learning_rate': 8.666666666666668e-06, 'epoch': 0.4}


 20%|██        | 1500/7500 [21:05<1:23:56,  1.19it/s]

{'loss': 0.2618, 'grad_norm': 8.793231964111328, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.6}


 27%|██▋       | 2000/7500 [28:07<1:17:03,  1.19it/s]

{'loss': 0.2503, 'grad_norm': 3.2806758880615234, 'learning_rate': 7.333333333333333e-06, 'epoch': 0.8}


 33%|███▎      | 2500/7500 [35:08<1:09:33,  1.20it/s]

{'loss': 0.2627, 'grad_norm': 2.9065496921539307, 'learning_rate': 6.666666666666667e-06, 'epoch': 1.0}



 33%|███▎      | 2500/7500 [38:48<1:09:33,  1.20it/s]

{'eval_loss': 0.2663903832435608, 'eval_accuracy': 0.9326, 'eval_f1': 0.9325891320584025, 'eval_precision': 0.9340213706866504, 'eval_recall': 0.9326, 'eval_runtime': 218.3846, 'eval_samples_per_second': 22.895, 'eval_steps_per_second': 2.862, 'epoch': 1.0}


 40%|████      | 3000/7500 [45:46<1:02:31,  1.20it/s] 

{'loss': 0.1678, 'grad_norm': 0.613919198513031, 'learning_rate': 6e-06, 'epoch': 1.2}


 47%|████▋     | 3500/7500 [52:45<55:42,  1.20it/s]  

{'loss': 0.1847, 'grad_norm': 0.13898837566375732, 'learning_rate': 5.333333333333334e-06, 'epoch': 1.4}


 53%|█████▎    | 4000/7500 [59:44<48:46,  1.20it/s]  

{'loss': 0.1618, 'grad_norm': 0.12484496086835861, 'learning_rate': 4.666666666666667e-06, 'epoch': 1.6}


 60%|██████    | 4500/7500 [1:06:43<41:49,  1.20it/s]

{'loss': 0.171, 'grad_norm': 0.054502490907907486, 'learning_rate': 4.000000000000001e-06, 'epoch': 1.8}


 67%|██████▋   | 5000/7500 [1:13:42<34:45,  1.20it/s]  

{'loss': 0.1854, 'grad_norm': 0.07374744862318039, 'learning_rate': 3.3333333333333333e-06, 'epoch': 2.0}



 67%|██████▋   | 5000/7500 [1:17:21<34:45,  1.20it/s]

{'eval_loss': 0.3206014037132263, 'eval_accuracy': 0.9412, 'eval_f1': 0.9412060789724859, 'eval_precision': 0.9412992043118562, 'eval_recall': 0.9412, 'eval_runtime': 217.7432, 'eval_samples_per_second': 22.963, 'eval_steps_per_second': 2.87, 'epoch': 2.0}


 73%|███████▎  | 5500/7500 [1:24:19<27:45,  1.20it/s]   

{'loss': 0.0806, 'grad_norm': 0.023299304768443108, 'learning_rate': 2.666666666666667e-06, 'epoch': 2.2}


 80%|████████  | 6000/7500 [1:31:18<20:54,  1.20it/s]

{'loss': 0.109, 'grad_norm': 0.0629800334572792, 'learning_rate': 2.0000000000000003e-06, 'epoch': 2.4}


 87%|████████▋ | 6500/7500 [1:38:18<13:53,  1.20it/s]

{'loss': 0.1182, 'grad_norm': 0.026130311191082, 'learning_rate': 1.3333333333333334e-06, 'epoch': 2.6}


 93%|█████████▎| 7000/7500 [1:45:17<06:57,  1.20it/s]

{'loss': 0.111, 'grad_norm': 0.06656184047460556, 'learning_rate': 6.666666666666667e-07, 'epoch': 2.8}


100%|██████████| 7500/7500 [1:52:16<00:00,  1.20it/s]

{'loss': 0.1145, 'grad_norm': 0.04412345588207245, 'learning_rate': 0.0, 'epoch': 3.0}



100%|██████████| 7500/7500 [1:55:55<00:00,  1.08it/s]


{'eval_loss': 0.32974615693092346, 'eval_accuracy': 0.9418, 'eval_f1': 0.9418021170158271, 'eval_precision': 0.9418103386682595, 'eval_recall': 0.9418, 'eval_runtime': 217.8816, 'eval_samples_per_second': 22.948, 'eval_steps_per_second': 2.869, 'epoch': 3.0}
{'train_runtime': 6955.6355, 'train_samples_per_second': 8.626, 'train_steps_per_second': 1.078, 'train_loss': 0.1870669687906901, 'epoch': 3.0}


100%|██████████| 625/625 [03:37<00:00,  2.88it/s]


Metrics for xlnet model: {'eval_loss': 0.32974615693092346, 'eval_accuracy': 0.9418, 'eval_f1': 0.9418021170158271, 'eval_precision': 0.9418103386682595, 'eval_recall': 0.9418, 'eval_runtime': 217.5542, 'eval_samples_per_second': 22.983, 'eval_steps_per_second': 2.873, 'epoch': 3.0}
Training and saving of all models completed.
Accuracies of each model:
bert model accuracy: **0.9298**
roberta model accuracy: **0.9456**
xlnet model accuracy: **0.9418**
