<a href="https://colab.research.google.com/github/dgromann/cl_intro_ws2024/blob/main/exercises/HomeExercise3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Home Exericse 3: Hyperparameters and Evaluation
In this third home exercise, you will use the knowledge from Tutorial 4 to experiment with hyperparameters, create a test set, and evaluate your final model on the created test set.

In this notebook, please complete all instructions starting with 👋 ⚒ in the code cell after the sign or provide your analysis in the text cell after the sign.

## **Distilbert: Hyperparameters and Evaluation**

Use the code of Tutorial 4 to load and fine-tune the `distilbert-base-cased`model on the small subset of the `imdb`Movie Review Dataset. For convenience, the code of Tutorial 4 required for this exercise is already provided in the code cells below.

👋 ⚒ When creating the dataset splits in the code cell below, additionally create a test set to be used after thet training. Make sure that your test set does not contain any of the sentences contained in the training or validation set and is approximately of the same size as the validation set.

In [1]:
!pip install transformers
!pip install datasets
!pip install evaluate
!pip install accelerate --upgrade

Collecting evaluate
  Using cached evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Using cached evaluate-0.4.3-py3-none-any.whl (84 kB)
Installing collected packages: evaluate
Successfully installed evaluate-0.4.3


In [4]:
from datasets import load_dataset, DatasetDict
from transformers import DataCollatorWithPadding, AutoTokenizer

imdb_dataset = load_dataset("imdb")
# we had loaded the imdb dataset already above - if not, outcomment this line
# Make sure you have the right tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert/distilbert-base-cased")

# Just take the first 50 tokens for speed on CPU
def truncate(example):
    return {
        'text': " ".join(example['text'].split()[:512]),
        'label': example['label']
    }

train_indices = range(0, 10000)
val_indices = range(10000, 11000)
test_indices = range(11000, 12000)

# Take 128 random examples for train and 32 validation
small_imdb_dataset = DatasetDict(
    train=imdb_dataset['train'].shuffle(seed=24).select(train_indices).map(truncate),
    val=imdb_dataset['train'].shuffle(seed=24).select(val_indices).map(truncate),
    test=imdb_dataset['train'].shuffle(seed=24).select(test_indices).map(truncate)
)

def tokenize_function(examples):
    return tokenizer(examples["text"], padding=True, truncation=True)

small_tokenized_dataset = small_imdb_dataset.map(tokenize_function, batched=True, batch_size=16)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/465 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

👋 ⚒ For this exercise, we will use the Hugging Face Trainer class to play with hyperparamters. Try to find a set of hyperparameter settings that achieves the highest possilbe accuracy on the **validation set** with the small dataset and model in this setup.

**Optional:** If you want to follow a more systematic route, feel free to use available frameworks for hyperparameter optimization, such as [Optuna](https://optuna.org/).

In [11]:
import numpy as np
import evaluate
from transformers import TrainingArguments, Trainer
from transformers import AutoModelForSequenceClassification
import optuna

model = AutoModelForSequenceClassification.from_pretrained('distilbert/distilbert-base-cased', num_labels=2)
accuracy = evaluate.load("accuracy")

# Define the Optuna objective
def objective(trial):
    learning_rate = trial.suggest_float('learning_rate', 2e-5, 5e-5, log=True)  
    weight_decay = trial.suggest_float('weight_decay', 0.01, 0.1, log=True) 
    batch_size = trial.suggest_categorical('batch_size', [8, 16, 32]) 
    num_epochs = trial.suggest_int('num_epochs', 2, 10)

    training_args = TrainingArguments(
        output_dir=f"optuna_distilbert_{trial.number}",
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        logging_steps=8,
        num_train_epochs=num_epochs,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        load_best_model_at_end=True,
        save_total_limit=1,
        report_to= 'none',
        seed=224
    )

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return accuracy.compute(predictions=predictions, references=labels)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=small_tokenized_dataset['train'],
        eval_dataset=small_tokenized_dataset['val'],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics
    )

    trainer.train()

    eval_metrics = trainer.evaluate()
    return eval_metrics['eval_accuracy']

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=30)

best_trial = study.best_trial
print(f"Best trial hyperparameters: {best_trial.params}")

best_model = AutoModelForSequenceClassification.from_pretrained(
    'distilbert-base-cased', num_labels=2
)
best_training_args = TrainingArguments(
    output_dir="best_model",
    per_device_train_batch_size=best_trial.params['batch_size'],
    per_device_eval_batch_size=best_trial.params['batch_size'],
    logging_steps=8,
    num_train_epochs=best_trial.params['num_epochs'],
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=best_trial.params['learning_rate'],
    weight_decay=best_trial.params['weight_decay'],
    load_best_model_at_end=True,
    save_total_limit=1,
    report_to='none',
    seed=224
)

def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return accuracy.compute(predictions=predictions, references=labels)

best_trainer = Trainer(
    model=best_model,
    args=best_training_args,
    train_dataset=small_tokenized_dataset['train'],
    eval_dataset=small_tokenized_dataset['val'],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

best_trainer.train()

best_trainer.save_model("best_model")

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[I 2024-11-29 13:01:03,460] A new study created in memory with name: no-name-6492eb38-1539-4596-9065-b28989f10b62


Epoch,Training Loss,Validation Loss,Accuracy
1,0.3022,0.30223,0.88
2,0.3796,0.445726,0.874
3,0.0599,0.575594,0.88
4,0.0319,0.724912,0.893
5,0.0007,0.694777,0.885
6,0.0002,0.757625,0.892
7,0.0001,0.847184,0.892


[I 2024-11-29 13:33:44,973] Trial 0 finished with value: 0.88 and parameters: {'learning_rate': 4.910245840098907e-05, 'weight_decay': 0.02332477955536842, 'batch_size': 8, 'num_epochs': 7}. Best is trial 0 with value: 0.88.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.1837,0.407537,0.883
2,0.0905,0.446379,0.886
3,0.0047,0.566851,0.878
4,0.0168,0.644274,0.883
5,0.0188,0.650853,0.89
6,0.0012,0.693659,0.888
7,0.0179,0.784914,0.891
8,0.0003,0.821126,0.887
9,0.0004,0.81614,0.887
10,0.0004,0.810726,0.89


[I 2024-11-29 14:15:37,346] Trial 1 finished with value: 0.883 and parameters: {'learning_rate': 3.505391372702253e-05, 'weight_decay': 0.049926089086037524, 'batch_size': 32, 'num_epochs': 10}. Best is trial 1 with value: 0.883.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.117,0.648922,0.886
2,0.0986,0.690206,0.884
3,0.0003,0.773785,0.892


[I 2024-11-29 14:29:49,167] Trial 2 finished with value: 0.886 and parameters: {'learning_rate': 2.210202790294413e-05, 'weight_decay': 0.0594848695366287, 'batch_size': 8, 'num_epochs': 3}. Best is trial 2 with value: 0.886.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0654,0.638907,0.879
2,0.0508,0.539328,0.881
3,0.0029,0.731246,0.886
4,0.0012,0.779376,0.882
5,0.0002,0.857039,0.883
6,0.0002,0.855712,0.888
7,0.0002,0.842888,0.889


[I 2024-11-29 14:59:08,107] Trial 3 finished with value: 0.881 and parameters: {'learning_rate': 3.9837019347916645e-05, 'weight_decay': 0.09829699352064741, 'batch_size': 32, 'num_epochs': 7}. Best is trial 2 with value: 0.886.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0007,0.734989,0.88
2,0.0211,0.774569,0.884
3,0.0002,0.8698,0.883
4,0.0002,0.872822,0.887
5,0.0001,0.94622,0.884
6,0.0041,0.954781,0.883


[I 2024-11-29 15:24:14,645] Trial 4 finished with value: 0.88 and parameters: {'learning_rate': 2.7685911598499866e-05, 'weight_decay': 0.015121052418968649, 'batch_size': 32, 'num_epochs': 6}. Best is trial 2 with value: 0.886.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.085,0.961987,0.865
2,0.2397,0.835241,0.885
3,0.0,0.926894,0.894


[I 2024-11-29 15:38:20,040] Trial 5 finished with value: 0.885 and parameters: {'learning_rate': 3.846170208972637e-05, 'weight_decay': 0.03365988561989324, 'batch_size': 8, 'num_epochs': 3}. Best is trial 2 with value: 0.886.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0114,1.043603,0.876
2,0.0499,0.987915,0.881
3,0.0002,1.141086,0.878
4,0.0001,1.183333,0.881
5,0.0,1.20242,0.883


[I 2024-11-29 15:59:27,058] Trial 6 finished with value: 0.881 and parameters: {'learning_rate': 2.2662242717768533e-05, 'weight_decay': 0.08976653325541398, 'batch_size': 32, 'num_epochs': 5}. Best is trial 2 with value: 0.886.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.073,0.828154,0.88
2,0.0088,1.00594,0.89
3,0.0,1.070852,0.889
4,0.0259,1.065985,0.882
5,0.0,1.145485,0.889
6,0.0,1.324686,0.885
7,0.0141,1.215022,0.892
8,0.0,1.22118,0.893
9,0.0,1.239121,0.893


[I 2024-11-29 16:38:23,237] Trial 7 finished with value: 0.88 and parameters: {'learning_rate': 4.09637270005294e-05, 'weight_decay': 0.01552909508193456, 'batch_size': 16, 'num_epochs': 9}. Best is trial 2 with value: 0.886.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,1.136668,0.871
2,0.2171,1.238717,0.874
3,0.0,1.022723,0.881
4,0.0,1.277521,0.883
5,0.0003,1.245732,0.887


[I 2024-11-29 17:01:45,037] Trial 8 finished with value: 0.881 and parameters: {'learning_rate': 3.586950632557403e-05, 'weight_decay': 0.036023878037088196, 'batch_size': 8, 'num_epochs': 5}. Best is trial 2 with value: 0.886.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0001,1.275262,0.882
2,0.0,1.210478,0.882
3,0.0001,1.069367,0.888
4,0.0001,1.148675,0.887
5,0.0,1.423463,0.889
6,0.0001,1.178108,0.89
7,0.0,1.239757,0.889
8,0.0,1.354458,0.884
9,0.0113,1.362259,0.883
10,0.0,1.374321,0.885


[I 2024-11-29 17:43:57,371] Trial 9 finished with value: 0.888 and parameters: {'learning_rate': 3.277703685345742e-05, 'weight_decay': 0.012488808024735532, 'batch_size': 32, 'num_epochs': 10}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0418,1.410092,0.882
2,0.0,1.252955,0.873
3,0.0,1.400144,0.878
4,0.0,1.345015,0.876
5,0.0,1.478472,0.881
6,0.0001,1.374501,0.878
7,0.0,1.486225,0.882
8,0.0,1.363043,0.884
9,0.0,1.429064,0.878


[I 2024-11-29 18:22:44,406] Trial 10 finished with value: 0.873 and parameters: {'learning_rate': 2.8103174094699945e-05, 'weight_decay': 0.010004043632036136, 'batch_size': 16, 'num_epochs': 9}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,1.865643,0.882
2,0.0,1.604245,0.879


[I 2024-11-29 18:32:09,512] Trial 11 finished with value: 0.879 and parameters: {'learning_rate': 2.015174435082087e-05, 'weight_decay': 0.056548446565873996, 'batch_size': 8, 'num_epochs': 2}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,1.288954,0.874
2,0.0,1.433879,0.878
3,0.0,1.40916,0.88


[I 2024-11-29 18:44:45,322] Trial 12 finished with value: 0.874 and parameters: {'learning_rate': 2.8817294554825808e-05, 'weight_decay': 0.0576313356300951, 'batch_size': 32, 'num_epochs': 3}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,1.378742,0.879
2,0.0,1.631122,0.882
3,0.0,1.571243,0.887
4,0.0,1.613368,0.885


[I 2024-11-29 19:03:27,265] Trial 13 finished with value: 0.879 and parameters: {'learning_rate': 2.45398543969844e-05, 'weight_decay': 0.02140908299863795, 'batch_size': 8, 'num_epochs': 4}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,1.354928,0.864
2,0.0577,1.17796,0.874
3,0.0,1.447605,0.872
4,0.0,1.533101,0.875
5,0.0,1.547287,0.868
6,0.0,1.590297,0.874
7,0.0,1.609157,0.878
8,0.0,1.592878,0.879


[I 2024-11-29 19:38:04,970] Trial 14 finished with value: 0.874 and parameters: {'learning_rate': 3.161163845396108e-05, 'weight_decay': 0.01012722501826272, 'batch_size': 16, 'num_epochs': 8}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,1.604273,0.872
2,0.0,1.574549,0.874


[I 2024-11-29 19:46:35,211] Trial 15 finished with value: 0.874 and parameters: {'learning_rate': 2.3995248822132874e-05, 'weight_decay': 0.07114849345915476, 'batch_size': 32, 'num_epochs': 2}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,1.746505,0.876
2,0.0,1.300571,0.879
3,0.0,1.570094,0.874
4,0.0,1.610894,0.873
5,0.0,1.624003,0.873


[I 2024-11-29 20:09:58,911] Trial 16 finished with value: 0.879 and parameters: {'learning_rate': 2.0144892136751205e-05, 'weight_decay': 0.045155935314441537, 'batch_size': 8, 'num_epochs': 5}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,1.953616,0.869
2,0.0001,2.06733,0.863
3,0.0012,1.303653,0.863
4,0.0,1.453016,0.875
5,0.0,1.516201,0.88
6,0.0,1.504996,0.876
7,0.0144,1.619622,0.874
8,0.0,1.472374,0.882
9,0.0,1.484739,0.881
10,0.0,1.489932,0.881


[I 2024-11-29 20:52:03,272] Trial 17 finished with value: 0.863 and parameters: {'learning_rate': 3.17398916447762e-05, 'weight_decay': 0.020588381007676536, 'batch_size': 32, 'num_epochs': 10}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,1.313745,0.871
2,0.0001,1.270088,0.87
3,0.0,1.183893,0.872
4,0.0,1.518985,0.864
5,0.0,1.359023,0.87
6,0.0,1.519716,0.867
7,0.0,1.558436,0.864


[I 2024-11-29 21:25:22,091] Trial 18 finished with value: 0.872 and parameters: {'learning_rate': 4.8492231882799904e-05, 'weight_decay': 0.026405534490119623, 'batch_size': 8, 'num_epochs': 7}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,1.64885,0.881
2,0.0,1.626639,0.869
3,0.0,1.647679,0.873
4,0.0,1.647897,0.874


[I 2024-11-29 21:42:48,786] Trial 19 finished with value: 0.869 and parameters: {'learning_rate': 2.613768091080242e-05, 'weight_decay': 0.014723229251983617, 'batch_size': 16, 'num_epochs': 4}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,1.802563,0.873
2,0.0,1.360212,0.866
3,0.0,1.601463,0.87
4,0.0,1.76182,0.866
5,0.0,1.659295,0.873
6,0.0206,1.579723,0.875
7,0.0,1.578417,0.87
8,0.0,1.59225,0.87


[I 2024-11-29 22:16:24,047] Trial 20 finished with value: 0.866 and parameters: {'learning_rate': 3.268312851120933e-05, 'weight_decay': 0.03904633630030506, 'batch_size': 32, 'num_epochs': 8}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,1.426159,0.857
2,0.0,1.269365,0.869
3,0.0,1.464925,0.867


[I 2024-11-29 22:31:02,967] Trial 21 finished with value: 0.869 and parameters: {'learning_rate': 4.196292875804882e-05, 'weight_decay': 0.030178726722210267, 'batch_size': 8, 'num_epochs': 3}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,2.302009,0.85
2,0.0,1.903832,0.86
3,0.0,1.822512,0.858


[I 2024-11-29 22:45:48,654] Trial 22 finished with value: 0.858 and parameters: {'learning_rate': 3.648412376110559e-05, 'weight_decay': 0.06799973024305787, 'batch_size': 8, 'num_epochs': 3}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0002,1.393034,0.839
2,0.0,1.978826,0.855
3,0.0,2.100529,0.852
4,0.0,1.793958,0.855


[I 2024-11-29 23:05:24,323] Trial 23 finished with value: 0.839 and parameters: {'learning_rate': 4.336649486419176e-05, 'weight_decay': 0.0321047691781814, 'batch_size': 8, 'num_epochs': 4}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,2.300302,0.848
2,0.0,2.137851,0.846


[I 2024-11-29 23:15:16,753] Trial 24 finished with value: 0.846 and parameters: {'learning_rate': 3.798236455454877e-05, 'weight_decay': 0.0179229241914699, 'batch_size': 8, 'num_epochs': 2}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0005,2.359116,0.849
2,0.0,2.152672,0.846
3,0.0,1.578775,0.858
4,0.0001,1.355056,0.85
5,0.0,1.424231,0.848
6,0.0,1.416295,0.849


[I 2024-11-29 23:44:32,283] Trial 25 finished with value: 0.85 and parameters: {'learning_rate': 3.348813385183655e-05, 'weight_decay': 0.012171918874333141, 'batch_size': 8, 'num_epochs': 6}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,2.095148,0.849
2,0.0,1.503121,0.857
3,0.0,1.671347,0.853


[I 2024-11-29 23:59:17,614] Trial 26 finished with value: 0.857 and parameters: {'learning_rate': 2.9935672476712808e-05, 'weight_decay': 0.07921074544665736, 'batch_size': 8, 'num_epochs': 3}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,1.606062,0.85
2,0.0005,1.427696,0.84
3,0.0,1.688416,0.844
4,0.0,1.68508,0.851


[I 2024-11-30 00:18:43,036] Trial 27 finished with value: 0.84 and parameters: {'learning_rate': 4.531243315074833e-05, 'weight_decay': 0.04378138704378175, 'batch_size': 8, 'num_epochs': 4}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0001,1.560022,0.837
2,0.0005,1.969514,0.84
3,0.0,2.050297,0.842
4,0.0226,1.861223,0.843
5,0.0,1.875479,0.843
6,0.0,1.788169,0.846


[I 2024-11-30 00:44:39,060] Trial 28 finished with value: 0.837 and parameters: {'learning_rate': 3.7767426366821304e-05, 'weight_decay': 0.030118566625400255, 'batch_size': 16, 'num_epochs': 6}. Best is trial 9 with value: 0.888.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0,2.128082,0.845
2,0.0,2.372188,0.84
3,0.0,2.426211,0.845
4,0.0,2.458859,0.84
5,0.0,2.866641,0.838
6,0.0,2.48166,0.844
7,0.0,2.590931,0.838
8,0.0,2.579707,0.84


[I 2024-11-30 01:18:28,570] Trial 29 finished with value: 0.845 and parameters: {'learning_rate': 3.427170036598393e-05, 'weight_decay': 0.02577444189414341, 'batch_size': 32, 'num_epochs': 8}. Best is trial 9 with value: 0.888.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Best trial hyperparameters: {'learning_rate': 3.277703685345742e-05, 'weight_decay': 0.012488808024735532, 'batch_size': 32, 'num_epochs': 10}


Epoch,Training Loss,Validation Loss,Accuracy
1,0.3592,0.303364,0.872
2,0.189,0.331404,0.886
3,0.1025,0.404461,0.895
4,0.0555,0.52938,0.887
5,0.0114,0.600677,0.882
6,0.0071,0.730457,0.883
7,0.0024,0.794763,0.888
8,0.0295,0.821378,0.884
9,0.0002,0.803798,0.887
10,0.0002,0.798478,0.888


👋 ⚒ Change the following code cell in a way that not only a single sentence is evaluated on your trained model (!make sure to use the correct checkpoint!) but the evaluation is performaned on the entire newly created test set.

This might also be a good occassion to get familiar with the [Hugging Face documentation and tutorials](https://huggingface.co/docs/transformers/index).

In [24]:
import torch
import numpy as np
from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer
from sklearn.metrics import matthews_corrcoef, accuracy_score, precision_recall_fscore_support

imdb_dataset = load_dataset("imdb")

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")

def truncate(example):
    return {
        'text': " ".join(example['text'].split()[:512]),
        'label': example['label']
    }

train_indices = range(0, 10000)
val_indices = range(10000, 11000)
test_indices = range(11000, 12000)

small_imdb_dataset = DatasetDict({
    "train": imdb_dataset['train'].shuffle(seed=24).select(train_indices).map(truncate),
    "val": imdb_dataset['train'].shuffle(seed=24).select(val_indices).map(truncate),
    "test": imdb_dataset['train'].shuffle(seed=24).select(test_indices).map(truncate)
})

def tokenize_function(examples):
    return tokenizer(examples["text"], padding=True, truncation=True)

small_tokenized_dataset = small_imdb_dataset.map(tokenize_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

fine_tuned_model = AutoModelForSequenceClassification.from_pretrained(
    "/Users/bohdan/Documents/cl_intro_ws2024_zhvalevskyi/exercises/best_model"
)

trainer = Trainer(
    model=fine_tuned_model,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

def compute_metrics(pred):
    labels = pred.label_ids
    preds = np.argmax(pred.predictions, axis=1)

    mcc = matthews_corrcoef(labels, preds)
    acc = accuracy_score(labels, preds)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average="binary")

    return {
        "mcc": mcc,
        "accuracy": acc,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }

test_results = trainer.evaluate(small_tokenized_dataset["test"], metric_key_prefix="test")
predictions = trainer.predict(small_tokenized_dataset["test"])

metrics = compute_metrics(predictions)

print("Evaluation Results:")
print(f"Test Loss: {test_results['test_loss']:.4f}")
print(f"Matthews Correlation Coefficient (MCC): {metrics['mcc']:.4f}")
print(f"Accuracy: {metrics['accuracy']:.4f}")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall: {metrics['recall']:.4f}")
print(f"F1 Score: {metrics['f1']:.4f}")


Evaluation Results:
Test Loss: 0.2870
Matthews Correlation Coefficient (MCC): 0.7764
Accuracy: 0.8830
Precision: 0.9560
Recall: 0.8082
F1 Score: 0.8759
