In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports edited modules
# before each cell runs, so changes to project code are picked up without a kernel restart.
%reload_ext autoreload
%autoreload 2

## Imports

In [2]:
import site

# Environment sanity check: list the site-packages directories of the interpreter
# backing this kernel.
# NOTE(review): the recorded output exposes a local home-directory path — clear
# this cell's output before sharing the notebook.
site.getsitepackages()

['/home/maidari/miniconda3/envs/myenv311/lib/python3.11/site-packages']

In [3]:
import transformers
# Record the installed transformers version next to the results of this run.
print(transformers.__version__)

4.51.2


In [4]:
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader

from transformers import AutoTokenizer, BertForSequenceClassification

## Load BERT with MoE

In [5]:
from moe.models.bert.modeling_bert_with_moe import BertMoEForSequenceClassification
from transformers.models.bert.modeling_bert import BertConfig

In [18]:
# Checkpoint used for the config, the tokenizer and (optionally) the model weights.
model_name = "google-bert/bert-base-uncased"

# True  -> load the checkpoint weights into the MoE model via from_pretrained;
# False -> build the MoE model from the config alone.
from_pretrained = True

# Seed the global RNGs before the model is created so that newly initialised
# parameters, dropout and DataLoader shuffling are repeatable after
# Restart Kernel -> Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [29]:
# Base BERT configuration taken from the checkpoint.
config = BertConfig.from_pretrained(model_name)

# Size of the classification head.
config.num_labels = 5

# Extra attributes attached to the config for the MoE model; their exact
# semantics are defined in moe.models.bert.modeling_bert_with_moe (not shown here).
config.moe_num_experts = 4
config.moe_top_k = 1
config.moe_aux_loss_coef = 0.1

tokenizer = AutoTokenizer.from_pretrained(model_name)

if not from_pretrained:
    # Build the architecture from the config only, without checkpoint weights.
    model = BertMoEForSequenceClassification(config)
else:
    # Copy the weights of the pretrained model through from_pretrained.
    # The checkpoint has no MoE parameters, so the loader reports those as
    # newly initialised.
    # An earlier variant of this cell copied the weights by hand from
    # BertForSequenceClassification (embeddings, the attention blocks of each
    # encoder layer, the pooler and the classifier); from_pretrained replaces it.
    model = BertMoEForSequenceClassification.from_pretrained(model_name, config=config)

Some weights of BertMoEForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['bert.encoder.layer.0.moe.experts.0.dense.bias', 'bert.encoder.layer.0.moe.experts.0.dense.weight', 'bert.encoder.layer.0.moe.experts.1.dense.bias', 'bert.encoder.layer.0.moe.experts.1.dense.weight', 'bert.encoder.layer.0.moe.experts.2.dense.bias', 'bert.encoder.layer.0.moe.experts.2.dense.weight', 'bert.encoder.layer.0.moe.experts.3.dense.bias', 'bert.encoder.layer.0.moe.experts.3.dense.weight', 'bert.encoder.layer.0.moe.output.LayerNorm.bias', 'bert.encoder.layer.0.moe.output.LayerNorm.weight', 'bert.encoder.layer.0.moe.output.dense.bias', 'bert.encoder.layer.0.moe.output.dense.weight', 'bert.encoder.layer.0.moe.router.weight', 'bert.encoder.layer.1.moe.experts.0.dense.bias', 'bert.encoder.layer.1.moe.experts.0.dense.weight', 'bert.encoder.layer.1.moe.experts.1.dense.bias', 'bert.encoder.layer.1.moe.experts.1.dense.weight', 'bert

In [30]:
# Inspect which attention implementation the config object reports.
# NOTE(review): the recorded output is 'eager', while the model repr in the next
# cell lists BertSdpaSelfAttention — confirm which implementation the model uses.
config._attn_implementation

'eager'

In [31]:
# Display the module hierarchy of the instantiated model.
model

BertMoEForSequenceClassification(
  (bert): BertMoEModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertMoEEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayerWithMoE(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNor

## Load Dataset

In [32]:
# Tokenization and training hyperparameters used by the cells below.
MAX_LENGTH = 128  # handed to the tokenizer as max_length (padding / truncation target)
BATCH_SIZE = 16  # batch size of both DataLoaders
EPOCHS = 2  # number of passes over the training set
LEARNING_RATE = 2e-5  # learning rate given to AdamW
# NOTE(review): NUM_CLASSES is not referenced in the visible cells; config.num_labels
# is set to the literal 5 earlier — keep the two in sync.
NUM_CLASSES = 5  # 5 classes: Politics 0, Sport 1, Technology 2, Entertainment 3, Business 4

In [33]:
# Read the raw CSV and give its columns the names the rest of the notebook expects.
df = pd.read_csv("../dataset/df_file.csv").set_axis(["text", "label"], axis=1)

# Hold out 20% of the rows for testing, stratified by label, with a fixed split seed.
train_df, test_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df["label"],
    random_state=42,
)
print(train_df.shape, test_df.shape)

(1780, 2) (445, 2)


In [34]:
class TextDataset(Dataset):
    """Map-style dataset that tokenizes one text per access and pairs it with its label.

    Args:
        texts: Iterable of raw texts (list, NumPy array, pandas Series, ...).
        labels: Iterable of integer class ids, aligned element-wise with ``texts``.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            padding='max_length', truncation=True, return_tensors='pt')`` that
            returns a mapping with ``'input_ids'`` and ``'attention_mask'`` tensors.
        max_length: Value forwarded to the tokenizer as ``max_length``.

    Raises:
        ValueError: If ``texts`` and ``labels`` do not have the same length.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        # Materialise both inputs as plain lists so that __getitem__ always
        # indexes by position. A pandas Series would otherwise be indexed by
        # label and fail after a shuffled train/test split.
        self.texts = list(texts)
        self.labels = list(labels)
        if len(self.texts) != len(self.labels):
            raise ValueError(
                f"texts and labels must have the same length, "
                f"got {len(self.texts)} and {len(self.labels)}"
            )
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of (text, label) pairs."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize the ``idx``-th text and return it with its label.

        Returns:
            dict with ``'input_ids'`` and ``'attention_mask'`` (the tokenizer's
            tensors flattened to 1-D) and ``'label'`` (a ``torch.long`` scalar tensor).
        """
        text = str(self.texts[idx])
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            # return_tensors='pt' yields a leading batch axis; flatten drops it.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# Wrap each split in a TextDataset; `.values` hands over plain arrays so the
# dataset can index them by position.
train_dataset = TextDataset(
    texts=train_df['text'].values,
    labels=train_df['label'].values,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
)
test_dataset = TextDataset(
    texts=test_df['text'].values,
    labels=test_df['label'].values,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
)

# Only the training batches are shuffled; the test loader keeps dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## Training and evaluation (weight initialisation is controlled by `from_pretrained`)

In [35]:
# AdamW over all model parameters with the configured learning rate.
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
# NOTE(review): loss_fn is not referenced by the training loop in this notebook,
# which back-propagates the loss returned by the model (outputs.loss) — confirm
# whether it is still needed.
loss_fn = nn.CrossEntropyLoss()

In [36]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model to the selected device; batches are moved to the same device in the loops below.
model = model.to(device)

In [37]:
# Fine-tune for EPOCHS passes over the training set. The model is called with
# `labels`, and the loss it returns (outputs.loss) is what gets back-propagated.
for epoch in range(EPOCHS):
    model.train()
    batch_losses = []

    for batch in tqdm(train_loader):
        # Move the three tensors of the batch to the training device.
        input_ids, attention_mask, labels = (
            batch[key].to(device) for key in ('input_ids', 'attention_mask', 'label')
        )

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    # Mean of the per-batch losses for this epoch.
    total_loss = sum(batch_losses)
    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{EPOCHS}, Train Loss: {avg_train_loss:.4f}")

  0%|          | 0/112 [00:00<?, ?it/s]

Epoch 1/2, Train Loss: 1.6280


  0%|          | 0/112 [00:00<?, ?it/s]

Epoch 2/2, Train Loss: 1.1094


In [28]:
# Evaluate on the held-out split and print a per-class report.
# NOTE(review): the recorded execution count of this cell (28) is lower than the
# training cell's (37), so the report stored below may predate the latest
# training run — re-run this cell after training.
model.eval()
predictions, true_labels = [], []

with torch.no_grad():  # no gradients are needed for inference
    for batch in tqdm(test_loader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = model(input_ids, attention_mask)
        logits = outputs.logits
        # Index of the highest logit along dim 1 is the predicted class id.
        preds = torch.max(logits, dim=1).indices

        predictions.extend(preds.cpu().numpy())
        true_labels.extend(labels.cpu().numpy())

report = classification_report(
    true_labels,
    predictions,
    target_names=['Politics', 'Sport', 'Technology', 'Entertainment', 'Business'],
)
print(report)

  0%|          | 0/28 [00:00<?, ?it/s]

               precision    recall  f1-score   support

     Politics       0.91      0.92      0.91        84
        Sport       0.96      0.95      0.96       102
   Technology       0.78      0.95      0.85        80
Entertainment       0.96      0.87      0.91        77
     Business       0.90      0.80      0.85       102

     accuracy                           0.90       445
    macro avg       0.90      0.90      0.90       445
 weighted avg       0.90      0.90      0.90       445



Results with random init weights (`from_pretrained = False`), for comparison:

               precision    recall  f1-score   support

     Politics       0.97      0.77      0.86        84
     Sport          0.94      0.98      0.96       102
     Technology     0.77      0.91      0.83        80
     Entertainment  0.81      0.86      0.84        77
     Business       0.94      0.88      0.91       102

     accuracy                           0.89       445
     macro avg      0.89      0.88      0.88       445
     weighted avg   0.89      0.89      0.89       445