In [17]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, f1_score
import torch
import random
import numpy as np

# Seed every random number generator used below from one named constant so
# that the run is reproducible and the seed can be changed in a single place.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# 1. Load the dataset
dataset_id = "banking77"
raw_dataset = load_dataset(dataset_id)

# Report the number of examples in each split, one line per split.
print(
    *(
        f"{split.capitalize()} dataset size: {len(raw_dataset[split])}"
        for split in ("train", "test")
    ),
    sep="\n",
)


Train dataset size: 10003
Test dataset size: 3080


In [18]:
# 2. Load the tokenizer and the model
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Take the label set from the dataset schema instead of hard-coding the class
# count, and store the id <-> name mapping in the model config so that
# predictions can be read back as intent names.
label_names = raw_dataset["train"].features["label"].names
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(label_names),
    id2label=dict(enumerate(label_names)),
    label2id={name: idx for idx, name in enumerate(label_names)},
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
# 3. Tokenize the data
# list(...) materialises the text column before tokenization: newer `datasets`
# releases may return a lazy column object for dataset[split]["text"], which the
# tokenizer does not treat as a batch of strings. On older releases this is
# just a cheap copy of the list.
train_encodings = tokenizer(list(raw_dataset['train']['text']), truncation=True, padding=True)
test_encodings = tokenizer(list(raw_dataset['test']['text']), truncation=True, padding=True)

In [20]:
# 4. Build a custom dataset
class Banking77Dataset(torch.utils.data.Dataset):
    """Torch dataset pairing tokenizer encodings with their labels.

    Args:
        encodings: Mapping from field name to a per-example sequence
            (indexed by example position).
        labels: Sequence of labels, one per example.
    """

    def __init__(self, encodings, labels):
        self.labels = labels
        self.encodings = encodings

    def __len__(self):
        # The dataset size is defined by the number of labels.
        return len(self.labels)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, with the label under ``'labels'``."""
        sample = {}
        for field, column in self.encodings.items():
            sample[field] = torch.tensor(column[idx])
        sample['labels'] = torch.tensor(self.labels[idx])
        return sample

# Wrap each split's encodings and labels in the torch dataset defined above.
train_dataset = Banking77Dataset(
    encodings=train_encodings,
    labels=raw_dataset['train']['label'],
)
test_dataset = Banking77Dataset(
    encodings=test_encodings,
    labels=raw_dataset['test']['label'],
)

In [None]:
# 5. Evaluate model quality before fine-tuning
def compute_metrics(pred):
    """Compute accuracy and weighted F1 from a prediction object.

    Args:
        pred: Object exposing ``label_ids`` (true labels) and ``predictions``
            (per-class scores; the predicted class is the argmax over the last axis).

    Returns:
        dict: ``{'accuracy': ..., 'f1': ...}``.
    """
    predicted = pred.predictions.argmax(-1)
    expected = pred.label_ids
    return {
        'accuracy': accuracy_score(expected, predicted),
        'f1': f1_score(expected, predicted, average='weighted'),
    }

# Training configuration
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    learning_rate=5e-5,
    # `eval_strategy` is the current name of the deprecated `evaluation_strategy`
    # argument; "epoch" runs an evaluation at the end of every epoch.
    eval_strategy="epoch",
    logging_dir='./logs',
    logging_steps=10,
    weight_decay=0.01,
)

# Wire the model, training configuration, data and metrics into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    # NOTE(review): the test split doubles as the evaluation set here.
    eval_dataset=test_dataset,
)

# Baseline: evaluate the model before fine-tuning
initial_eval = trainer.evaluate()
print("Initial evaluation: {}".format(initial_eval))



Initial evaluation: {'eval_loss': 4.347962379455566, 'eval_model_preparation_time': 0.0017, 'eval_accuracy': 0.01396103896103896, 'eval_f1': 0.0017067393290117144, 'eval_runtime': 28.621, 'eval_samples_per_second': 107.613, 'eval_steps_per_second': 13.452}


In [22]:
# 6. Train (fine-tune) the model
trainer.train()

Epoch,Training Loss,Validation Loss,Model Preparation Time,Accuracy,F1
1,0.5256,0.519816,0.0017,0.863312,0.858998
2,0.3655,0.31339,0.0017,0.91461,0.914385
3,0.0665,0.293648,0.0017,0.923052,0.92302


TrainOutput(global_step=3753, training_loss=0.7072711554950248, metrics={'train_runtime': 1845.2877, 'train_samples_per_second': 16.263, 'train_steps_per_second': 2.034, 'total_flos': 761898528728316.0, 'train_loss': 0.7072711554950248, 'epoch': 3.0})

In [23]:
# 7. Helper for evaluating model quality before and after fine-tuning
def evaluate_model(dataset, device, batch_size=16):
    """Predict a class index for every example in ``dataset``.

    Uses the module-level ``model``: it is moved to ``device``, switched to
    evaluation mode, and run over the dataset with gradient tracking disabled.

    Args:
        dataset: Dataset accepted by ``torch.utils.data.DataLoader`` whose
            batches are dicts of tensors passed to ``model`` as keyword
            arguments. A ``'labels'`` entry, if present, is ignored.
        device: Device (string or ``torch.device``) to run inference on.
        batch_size: Number of examples per forward pass (default 16).

    Returns:
        list: Predicted class indices, in the order the loader yields examples.
    """
    model.to(device)
    model.eval()

    # DataLoader over the dataset to score
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    pred_labels = []

    with torch.no_grad():
        for batch in loader:
            # Move the inputs to the target device. Labels are not needed for
            # prediction: leaving them out avoids a pointless loss computation
            # and lets unlabeled datasets be scored too.
            inputs = {
                key: value.to(device)
                for key, value in batch.items()
                if key != 'labels'
            }
            logits = model(**inputs).logits
            pred_labels.extend(logits.argmax(dim=-1).tolist())

    return pred_labels

In [29]:
# Evaluate the model after fine-tuning
final_eval = trainer.evaluate()
print("Final evaluation: {}".format(final_eval))

Final evaluation: {'eval_loss': 0.29364821314811707, 'eval_model_preparation_time': 0.0017, 'eval_accuracy': 0.923051948051948, 'eval_f1': 0.9230196335238754, 'eval_runtime': 30.2876, 'eval_samples_per_second': 101.692, 'eval_steps_per_second': 12.711, 'epoch': 3.0}


In [30]:
# Summary of accuracy and F1 before and after fine-tuning, one line per stage.
print(
    *(
        f"{stage} Accuracy: {metrics['eval_accuracy']}, {stage} F1 Score: {metrics['eval_f1']}"
        for stage, metrics in (("Initial", initial_eval), ("Final", final_eval))
    ),
    sep="\n",
)

Initial Accuracy: 0.01396103896103896, Initial F1 Score: 0.0017067393290117144
Final Accuracy: 0.923051948051948, Final F1 Score: 0.9230196335238754


После начальной оценки трансформерной модели для классификации текстов, где точность составила всего 1.40% и F1-метрика 0.17%, модель была значительно улучшена в процессе дообучения, что привело к окончательной точности 92.31% и F1-метрике 92.30%. Эти результаты свидетельствуют о высоком качестве модели после дообучения, подтверждая ее способность эффективно решать задачу классификации на выбранном наборе данных.