In [None]:
from datasets.datasets import get_own_dataset

# Project helper that loads and splits the PURE-based requirements CSV.
# NOTE(review): encode_dict presumably maps label -> class id; only its len()
# is relied on below (as num_labels) — confirm against get_own_dataset.
DATASET_CSV = 'datasets/PURE_and_others.csv'
train_dataset, test_dataset, encode_dict = get_own_dataset(DATASET_CSV)

In [None]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast, Trainer, TrainingArguments, set_seed

# Fast tokenizer for the same DistilBERT checkpoint that is fine-tuned below;
# it is handed to the Trainer further down.
TOKENIZER_CHECKPOINT = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizerFast.from_pretrained(TOKENIZER_CHECKPOINT)

def compute_metrics(pred):
    """Compute evaluation metrics for the Hugging Face ``Trainer``.

    Args:
        pred: ``EvalPrediction`` supplied by ``Trainer``. ``pred.predictions``
            holds the model logits (or a tuple whose first element is the
            logits) and ``pred.label_ids`` the gold class ids.

    Returns:
        dict with ``accuracy`` plus macro-averaged ``precision``, ``recall``
        and ``f1`` (every class weighted equally, regardless of support).
        ``f1`` is the key used by ``metric_for_best_model``.
    """
    labels = pred.label_ids
    logits = pred.predictions
    # Some models return extra outputs (hidden states, attentions) alongside
    # the logits; in that case predictions is a tuple with the logits first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)
    # zero_division=0: a class that is never predicted (likely after a short
    # fine-tune) scores 0 precision without raising UndefinedMetricWarning on
    # every evaluation. The values match sklearn's default "warn" behaviour.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='macro', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

# Trainer hyper-parameters. Evaluation AND checkpointing both run once per
# epoch: load_best_model_at_end=True requires the save and eval strategies to
# match, and TrainingArguments raises a ValueError if they differ (the default
# save_strategy is 'steps'). eval_steps is omitted because it is ignored when
# evaluation_strategy='epoch'.
training_args = TrainingArguments(
    output_dir='./results',          # checkpoints are written here
    num_train_epochs=1,              # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    # Linear LR warmup. NOTE(review): with a single epoch, confirm the run has
    # more than 500 optimizer steps, otherwise the LR never reaches its peak.
    warmup_steps=500,
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs (viewed with TensorBoard at the end)
    logging_steps=10,                # log training loss every 10 steps
    evaluation_strategy='epoch',     # evaluate eval_dataset at the end of every epoch
    save_strategy='epoch',           # must match evaluation_strategy for load_best_model_at_end
    load_best_model_at_end=True,     # restore the checkpoint with the best metric_for_best_model
    metric_for_best_model='f1',      # a key returned by compute_metrics
)

# Seed all RNGs *before* building the model. from_pretrained randomly
# initialises the new classification head, and Trainer only calls set_seed
# in its own __init__, after the model already exists. Without this, the
# head's initial weights (and therefore results) differ from run to run.
set_seed(training_args.seed)

# Pretrained DistilBERT encoder plus a freshly initialised classification head
# with one output per entry in encode_dict.
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=len(encode_dict))

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    # NOTE(review): with load_best_model_at_end=True this split is also used to
    # pick the best checkpoint, so its scores are not a fully held-out estimate.
    eval_dataset=test_dataset,           # evaluation dataset
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

## Results

In [None]:
from src.utils import show_confusion_matrix

# Confusion matrix of the fine-tuned model on the test split (labels from encode_dict).
# NOTE(review): this split is also the Trainer's eval_dataset, used to select the
# best checkpoint, so it is not a fully unbiased estimate.
show_confusion_matrix(model, test_dataset, encode_dict)

In [None]:
# Same plot on the training split; compare it with the test-split matrix to
# gauge over- or under-fitting.
show_confusion_matrix(model, train_dataset, encode_dict)

In [None]:
# Inspect the training/eval curves logged under the Trainer's logging_dir ('./logs').
# NOTE(review): Trainer only writes TensorBoard event files when tensorboard is
# installed and enabled via report_to — confirm if this view is empty.
%load_ext tensorboard
%tensorboard --logdir logs