In [1]:
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast, Trainer, TrainingArguments
import torch
import torch.nn.functional as F
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score, accuracy_score

bin c:\Users\ms2k\.conda\envs\ml\Lib\site-packages\bitsandbytes\libbitsandbytes_cuda121.dll


In [2]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased', num_labels = 2)

In [3]:
data = []

from json import loads as load_json
with open("../data.json", 'r') as raw:
    data = load_json(raw.read())

from datasets import Dataset
ds = Dataset.from_list(data).shuffle(seed = 42) \
                            .train_test_split(test_size = 0.2)

def tokenize(data):
    tokens = tokenizer(data['text'], padding = 'max_length', truncation = True, return_tensors='pt')
    tokens = {key: tensor.to('cuda') for key, tensor in tokens.items()}
    return tokens

#train_ds = ds['train'].map(tokenize, batched=True)
test_ds  = ds['test'] .map(tokenize, batched=True)

Map:   0%|          | 0/1901 [00:00<?, ? examples/s]

In [4]:
model_path = './output/distilbert2'
model = DistilBertForSequenceClassification.from_pretrained(model_path).to('cuda')

In [5]:
eval_args = TrainingArguments(
    per_device_eval_batch_size = 16,
    output_dir = 'eval/distilbert'
)

In [6]:
trainer = Trainer(model = model, args = eval_args)
results = trainer.predict(test_ds)

  0%|          | 0/119 [00:00<?, ?it/s]

In [7]:
predicted_probs = F.softmax(torch.tensor(results.predictions), dim=1)[:, 1].numpy()
labels = test_ds['label']

In [8]:
accuracy = accuracy_score(labels, (predicted_probs > 0.5).astype(int))
precision, recall, f1, _ = precision_recall_fscore_support(labels, (predicted_probs > 0.5).astype(int), average='binary')
auroc = roc_auc_score(labels, predicted_probs)

In [9]:
metrics_dict = {
    "Accuracy": accuracy,
    "Precision": precision,
    "Recall": recall,
    "F1 Score": f1,
    "AUROC": auroc,
}
metrics_dict

{'Accuracy': 0.746449237243556,
 'Precision': 0.7117031398667936,
 'Recall': 0.8069039913700108,
 'F1 Score': 0.7563195146612741,
 'AUROC': 0.8369550048842727}