In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, Trainer, TrainingArguments
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, f1_score

In [2]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [3]:
data = pd.read_csv('/content/drive/My Drive/cartoon/train.csv', lineterminator='\n')

In [4]:
# Кодирование целевой переменной
label_encoder = LabelEncoder()
data['cartoon'] = label_encoder.fit_transform(data['cartoon'])

In [5]:
# Разделение на обучающую и тестовую выборки
train_texts, val_texts, train_labels, val_labels = train_test_split(
    data['text'], data['cartoon'], test_size=0.2, random_state=42
)

In [6]:
# Загрузка токенизатора
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-multilingual-cased")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/466 [00:00<?, ?B/s]



In [7]:
# Создание собственного Dataset класса
class CartoonDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        text = str(self.texts.iloc[item])
        label = self.labels.iloc[item]

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [8]:
# Создание объектов Dataset для обучающей и валидационной выборок
train_dataset = CartoonDataset(train_texts, train_labels, tokenizer)
val_dataset = CartoonDataset(val_texts, val_labels, tokenizer)

In [9]:
# Загрузка модели DistilBERT для задачи классификации
model = DistilBertForSequenceClassification.from_pretrained('distilbert-base-multilingual-cased', num_labels=len(label_encoder.classes_))

model.safetensors:   0%|          | 0.00/542M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-multilingual-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Параметры обучения
training_args = TrainingArguments(
    output_dir='./results',
    evaluation_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='./logs',
)




In [11]:
# Создание Trainer для обучения модели
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

In [12]:
# Обучение модель
trainer.train()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Epoch,Training Loss,Validation Loss
1,0.0453,0.040464
2,0.029,0.0203
3,0.0104,0.016407


TrainOutput(global_step=12513, training_loss=0.057762805797909315, metrics={'train_runtime': 4579.2019, 'train_samples_per_second': 43.716, 'train_steps_per_second': 2.733, 'total_flos': 6634547197102080.0, 'train_loss': 0.057762805797909315, 'epoch': 3.0})

In [13]:
# Оценка модели
evaluation_results = trainer.evaluate()

In [16]:
# Получаем предсказания для валидационной выборки
val_loader = DataLoader(val_dataset, batch_size=16)
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
val_predictions, val_labels_list = [], []
with torch.no_grad():
    for batch in val_loader:
        # Перемещаем входные данные на нужное устройство
        inputs = {
            'input_ids': batch['input_ids'].to(device),
            'attention_mask': batch['attention_mask'].to(device)
        }
        # Перемещаем метки на нужное устройство
        labels = batch['labels'].to(device)

        # Получаем предсказания
        outputs = model(**inputs)
        logits = outputs.logits

        # Получаем индексы предсказанных классов
        preds = torch.argmax(logits, axis=1)

        # Сохраняем предсказания и метки
        val_predictions.extend(preds.tolist())
        val_labels_list.extend(labels.tolist())

In [17]:
# Отчет по классификации
val_predictions_labels = label_encoder.inverse_transform(val_predictions)
val_labels_original = label_encoder.inverse_transform(val_labels_list)
print(classification_report(val_labels_original, val_predictions_labels))

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


                            precision    recall  f1-score   support

    Cry babies magic tears       1.00      1.00      1.00         7
Enchantimals (Эншантималс)       0.00      0.00      0.00         1
            My little pony       0.94      0.94      0.94        18
                      none       1.00      1.00      1.00     15654
                  Акуленок       0.00      0.00      0.00         2
                Барбоскины       0.00      0.00      0.00         1
      Бременские музыканты       0.00      0.00      0.00         1
                      Буба       1.00      1.00      1.00        97
                    Бэтмен       0.95      1.00      0.98        59
                     Вспыш       0.00      0.00      0.00         1
             Говорящий Том       1.00      1.00      1.00       134
                 Губка Боб       0.00      0.00      0.00         3
                    Енотки       1.00      1.00      1.00        59
          ЖилаБыла Царевна       0.89      0.93

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [18]:
# Метрика F1 (macro)
f1_macro = f1_score(val_labels_original, val_predictions_labels, average='macro')
print(f'F1 Score (Macro): {f1_macro}')

F1 Score (Macro): 0.7293247582498988


**Тестовая выборка**

In [19]:
new_texts = pd.read_csv('/content/drive/My Drive/cartoon/test.csv', lineterminator='\n')

In [21]:
# Предсказание для новой выборки
class NewTextsDataset(Dataset):
    def __init__(self, texts, tokenizer, max_len=128):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        text = str(self.texts.iloc[item])

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten()
        }

In [22]:
new_texts_dataset = NewTextsDataset(new_texts['text'], tokenizer)
new_texts_loader = DataLoader(new_texts_dataset, batch_size=16)

In [24]:
# Получение предсказаний
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
predictions = []
with torch.no_grad():
    for batch in new_texts_loader:
        # Перемещение всех тензоров в batch на нужное устройство
        batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}

        # Подготовка входных данных для модели
        inputs = {
            'input_ids': batch['input_ids'],
            'attention_mask': batch['attention_mask']
        }

        # Получение предсказаний
        outputs = model(**inputs)
        logits = outputs.logits
        preds = torch.argmax(logits, axis=1)
        predictions.extend(preds.tolist())

In [25]:
# Добавление предсказаний к new_texts
new_texts['predicted_cartoon'] = label_encoder.inverse_transform(predictions)

In [26]:
# Создаем новый DataFrame с результатами
results = pd.DataFrame({
    'yt_reel_id': new_texts['yt_reel_id'],
    'cartoon': new_texts['predicted_cartoon']
})

In [27]:
# Сохраняем результаты в CSV файл
results.to_csv('/content/drive/My Drive/predictions.csv', index=False)

In [28]:
!pip list

Package                            Version
---------------------------------- --------------------
absl-py                            1.4.0
accelerate                         0.34.2
aiohappyeyeballs                   2.4.3
aiohttp                            3.10.10
aiosignal                          1.3.1
alabaster                          0.7.16
albucore                           0.0.16
albumentations                     1.4.15
altair                             4.2.2
annotated-types                    0.7.0
anyio                              3.7.1
argon2-cffi                        23.1.0
argon2-cffi-bindings               21.2.0
array_record                       0.5.1
arviz                              0.19.0
astropy                            6.1.4
astropy-iers-data                  0.2024.10.14.0.32.55
astunparse                         1.6.3
async-timeout                      4.0.3
atpublic                           4.1.0
attrs                              24.2.0
audioread      