In [26]:
import pandas as pd
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from sklearn.metrics import classification_report
import numpy as np

In [5]:
# Проверка наличия GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [6]:
df =  pd.read_csv("website_classification.csv")
df.head()

Unnamed: 0,website_url,text,category
0,https://www.booking.com/index.html?aid=1743217,official site good hotel accommodation big sav...,Travel
1,https://travelsites.com/expedia/,expedia hotel book sites like use vacation wor...,Travel
2,https://travelsites.com/tripadvisor/,tripadvisor hotel book sites like previously d...,Travel
3,https://www.momondo.in/?ispredir=true,cheap flights search compare flights momondo f...,Travel
4,https://www.ebookers.com/?AFFCID=EBOOKERS-UK.n...,bot create free account create free account si...,Travel


In [7]:
# Проверка типов данных в столбце text
print(df['text'].apply(type).value_counts())

text
<class 'str'>    1408
Name: count, dtype: int64


In [8]:
# Преобразование списков в строки (если необходимо)
df['text'] = df['text'].apply(lambda x: ' '.join(x) if isinstance(x, list) else x)

In [9]:
# Удаление пропусков и пустых строк
df = df.dropna(subset=['text', 'category'])
df = df[df['text'].str.strip().astype(bool)]

In [10]:
# Преобразование категорий в числовые метки
df['label'] = df['category'].astype('category').cat.codes

In [11]:
# Загрузка токенизатора и модели BERT
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
num_labels = df['label'].nunique()
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=num_labels)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
# Создание датасета с использованием Hugging Face Datasets
dataset = Dataset.from_pandas(df[['text', 'label']])

In [13]:
# Функция токенизации
def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True)

# Токенизация данных
dataset = dataset.map(tokenize, batched=True)

Map: 100%|██████████| 1408/1408 [00:09<00:00, 149.05 examples/s]


In [14]:
# Удаление ненужных столбцов
dataset = dataset.remove_columns(['text'])

In [15]:
# Установка формата данных для PyTorch
dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

In [16]:
# Разделение на обучающую и тестовую выборки
train_testvalid = dataset.train_test_split(test_size=0.2, seed=42)
test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed=42)
train_dataset = train_testvalid['train']
eval_dataset = test_valid['train']
test_dataset = test_valid['test']

In [17]:
# Параметры обучения
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    evaluation_strategy="epoch",
    save_total_limit=1,
    remove_unused_columns=False,
    logging_dir='./logs',
    logging_steps=100,
)



In [18]:
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

In [19]:
# Создание тренера
trainer = Trainer(
    model=model.to(device),
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [20]:
# Обучение модели
trainer.train()

  attn_output = torch.nn.functional.scaled_dot_product_attention(
 24%|██▎       | 100/423 [00:47<02:24,  2.24it/s]

{'loss': 2.3464, 'grad_norm': 7.400256633758545, 'learning_rate': 3.817966903073286e-05, 'epoch': 0.71}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                 
 33%|███▎      | 141/423 [01:07<01:55,  2.45it/s]

{'eval_loss': 0.7035578489303589, 'eval_accuracy': 0.851063829787234, 'eval_f1': 0.838767339245441, 'eval_precision': 0.8566834558995813, 'eval_recall': 0.851063829787234, 'eval_runtime': 2.3554, 'eval_samples_per_second': 59.862, 'eval_steps_per_second': 3.821, 'epoch': 1.0}


 47%|████▋     | 200/423 [01:33<01:37,  2.29it/s]

{'loss': 0.7116, 'grad_norm': 15.36422061920166, 'learning_rate': 2.6359338061465723e-05, 'epoch': 1.42}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                 
 67%|██████▋   | 282/423 [02:12<00:59,  2.38it/s]

{'eval_loss': 0.3627060055732727, 'eval_accuracy': 0.9290780141843972, 'eval_f1': 0.9221348286462288, 'eval_precision': 0.9233987053135988, 'eval_recall': 0.9290780141843972, 'eval_runtime': 2.3191, 'eval_samples_per_second': 60.8, 'eval_steps_per_second': 3.881, 'epoch': 2.0}


 71%|███████   | 300/423 [02:20<00:54,  2.27it/s]

{'loss': 0.3224, 'grad_norm': 11.382461547851562, 'learning_rate': 1.4539007092198581e-05, 'epoch': 2.13}


 95%|█████████▍| 400/423 [03:03<00:09,  2.34it/s]

{'loss': 0.1419, 'grad_norm': 7.976169586181641, 'learning_rate': 2.7186761229314422e-06, 'epoch': 2.84}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
                                                 
100%|██████████| 423/423 [03:19<00:00,  2.12it/s]

{'eval_loss': 0.3245522379875183, 'eval_accuracy': 0.9290780141843972, 'eval_f1': 0.9213942761106694, 'eval_precision': 0.9235276543787183, 'eval_recall': 0.9290780141843972, 'eval_runtime': 2.2184, 'eval_samples_per_second': 63.56, 'eval_steps_per_second': 4.057, 'epoch': 3.0}
{'train_runtime': 199.5763, 'train_samples_per_second': 16.926, 'train_steps_per_second': 2.119, 'train_loss': 0.841328156473507, 'epoch': 3.0}





TrainOutput(global_step=423, training_loss=0.841328156473507, metrics={'train_runtime': 199.5763, 'train_samples_per_second': 16.926, 'train_steps_per_second': 2.119, 'total_flos': 888900866113536.0, 'train_loss': 0.841328156473507, 'epoch': 3.0})

In [21]:
# Оценка модели
trainer.evaluate()

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
100%|██████████| 9/9 [00:01<00:00,  4.80it/s]


{'eval_loss': 0.3245522379875183,
 'eval_accuracy': 0.9290780141843972,
 'eval_f1': 0.9213942761106694,
 'eval_precision': 0.9235276543787183,
 'eval_recall': 0.9290780141843972,
 'eval_runtime': 2.1114,
 'eval_samples_per_second': 66.781,
 'eval_steps_per_second': 4.263,
 'epoch': 3.0}

In [24]:
# Предсказание на тестовых данных
predictions = trainer.predict(test_dataset)
preds = np.argmax(predictions.predictions, axis=-1)
labels = test_dataset['label']

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
100%|██████████| 9/9 [00:02<00:00,  3.24it/s]


In [28]:
# Получение названий категорий
df['category'] = df['category'].astype('category')
category_names = df['category'].cat.categories.tolist()

# Вывод отчета о классификации
print(classification_report(labels, preds, target_names=category_names))

                                 precision    recall  f1-score   support

                          Adult       1.00      1.00      1.00         3
             Business/Corporate       1.00      0.83      0.91        12
       Computers and Technology       0.77      1.00      0.87        10
                     E-Commerce       1.00      1.00      1.00        12
                      Education       0.86      1.00      0.92        12
                           Food       1.00      1.00      1.00         7
                         Forums       0.00      0.00      0.00         2
                          Games       1.00      1.00      1.00         8
             Health and Fitness       1.00      1.00      1.00         8
             Law and Government       0.92      1.00      0.96        11
                           News       0.80      0.73      0.76        11
                    Photography       1.00      0.83      0.91         6
Social Networking and Messaging       0.71      0.

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
