In [1]:
import os
import time
import torch
import warnings
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
for warn in [UserWarning, FutureWarning]: warnings.filterwarnings("ignore", category = warn)

from src.data_utils.config import DatasetConfig
from src.data_utils.dataset_params import DatasetName
from src.data_utils.dataset_generator import DatasetGenerator
from src.models.models import TransformerClassifier, CustomMambaClassifier, LSTMClassifier

MAX_SEQ_LEN = 300
EMBEDDING_DIM = 128
BATCH_SIZE = 32
LEARNING_RATE = 7e-5  # lowered lr: 1e-4 -> 7e-5
NUM_EPOCHS = 20  # raised the number of epochs: 5 -> 20
NUM_CLASSES = 2

# Fix RNG state so weight initialisation and DataLoader shuffling are reproducible
# under Restart & Run All. torch.manual_seed seeds the CPU and all CUDA generators
# (some GPU kernels may still be nondeterministic).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

SAVE_DIR = "../pretrained_comparison"
os.makedirs(SAVE_DIR, exist_ok=True)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset loading configuration: read the prepared IMDB splits from ../datasets.
config = DatasetConfig(
    load_from_disk=True,
    path_to_data="../datasets",
    # Split sizes; the training sample count was increased for this run.
    train_size=25000,
    val_size=12500,
    test_size=12500,
)

In [2]:

generator = DatasetGenerator(DatasetName.IMDB, config=config)
(X_train, y_train), (X_val, y_val), (X_test, y_test) = generator.generate_dataset()
VOCAB_SIZE = len(generator.vocab)


def _check_split(split_name, split_X, split_y):
    """Fail fast if a split is malformed, before hours of training are spent on it."""
    # Every sample needs exactly one label.
    assert len(split_X) == len(split_y), (
        f"{split_name}: {len(split_X)} samples vs {len(split_y)} labels"
    )
    # nn.CrossEntropyLoss requires targets in [0, NUM_CLASSES); the binary F1 reported
    # later additionally treats label 1 as the positive class.
    unexpected = set(np.unique(np.asarray(split_y)).tolist()) - set(range(NUM_CLASSES))
    assert not unexpected, f"{split_name}: labels outside [0, {NUM_CLASSES}): {sorted(unexpected)}"


_check_split("train", X_train, y_train)
_check_split("val", X_val, y_val)
_check_split("test", X_test, y_test)
assert VOCAB_SIZE > 0, "generator produced an empty vocabulary"

In [3]:

def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and return the sample-weighted mean loss.

    Assumes ``criterion`` averages over the batch (``reduction='mean'``, the
    ``nn.CrossEntropyLoss`` default), so each batch loss is re-weighted by its size.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    loss_sum, n_seen = 0.0, 0
    for batch_X, batch_y in loader:
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(batch_X), batch_y)
        loss.backward()
        optimizer.step()
        # Weight by batch size so a short final batch does not skew the epoch mean.
        batch_size = batch_y.size(0)
        loss_sum += loss.item() * batch_size
        n_seen += batch_size
    if n_seen == 0:
        raise ValueError("train_loader yielded no batches; cannot train")
    return loss_sum / n_seen


def train_and_evaluate(model, train_loader, val_loader, optimizer, criterion, num_epochs, device, model_name, save_path):
    """Train ``model`` and checkpoint the weights with the best validation F1.

    After every epoch the model is scored on ``val_loader`` with ``evaluate_on_test``
    (the same evaluator used for the test set, so both report identical metrics).
    Whenever validation F1 improves, ``model.state_dict()`` is written to ``save_path``.

    Args:
        model: the ``nn.Module`` to train (already on ``device``).
        train_loader / val_loader: DataLoaders yielding ``(inputs, labels)`` batches.
        optimizer: optimizer over ``model.parameters()``.
        criterion: loss function applied as ``criterion(outputs, labels)``.
        num_epochs: number of passes over ``train_loader``.
        device: device that batches are moved to.
        model_name: label used in the progress log.
        save_path: file path for the best checkpoint.

    Returns:
        dict with per-epoch lists ``'train_loss'``, ``'val_loss'``, ``'val_accuracy'``, ``'val_f1'``.
    """
    # Start below any attainable F1 so the first epoch is always checkpointed. With 0.0
    # and a strict ">", a model stuck at F1 == 0 would never be saved, and the caller's
    # existence check could then load a stale checkpoint from an earlier run.
    best_val_f1 = float("-inf")
    history = {'train_loss': [], 'val_loss': [], 'val_accuracy': [], 'val_f1': []}

    print(f"--- Начало обучения модели: {model_name} на устройстве {device} ---")

    for epoch in range(num_epochs):
        start_time = time.time()

        avg_train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_metrics = evaluate_on_test(model, val_loader, device, criterion)
        avg_val_loss = val_metrics['loss']
        accuracy = val_metrics['accuracy']
        f1 = val_metrics['f1_score']

        history['train_loss'].append(avg_train_loss)
        history['val_loss'].append(avg_val_loss)
        history['val_accuracy'].append(accuracy)
        history['val_f1'].append(f1)

        epoch_time = time.time() - start_time
        print(f"Эпоха {epoch+1}/{num_epochs} | Время: {epoch_time:.2f}с | Train Loss: {avg_train_loss:.4f} | "
              f"Val Loss: {avg_val_loss:.4f} | Val Acc: {accuracy:.4f} | Val F1: {f1:.4f}")

        if f1 > best_val_f1:
            best_val_f1 = f1
            torch.save(model.state_dict(), save_path)
            print(f"  -> Модель сохранена, новый лучший Val F1: {best_val_f1:.4f}")

    print(f"--- Обучение модели {model_name} завершено ---")
    return history

def evaluate_on_test(model, test_loader, device, criterion):
    """Score ``model`` on every batch of ``test_loader`` without updating its weights.

    The model is switched to eval mode and run under ``torch.no_grad()``. The predicted
    class is the arg-max over dim 1 of the model output.

    The loss is averaged per sample rather than per batch, so a short final batch does
    not bias it. This assumes ``criterion`` uses ``reduction='mean'`` (the
    ``nn.CrossEntropyLoss`` default). Precision/recall/F1 use sklearn's
    ``average='binary'``, i.e. label 1 is the positive class.

    Returns:
        dict with keys ``'loss'``, ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f1_score'``.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    loss_sum, n_seen = 0.0, 0
    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            outputs = model(batch_X)
            batch_size = batch_y.size(0)
            loss_sum += criterion(outputs, batch_y).item() * batch_size
            n_seen += batch_size
            pred_chunks.append(outputs.argmax(dim=1).cpu())
            label_chunks.append(batch_y.cpu())

    if n_seen == 0:
        raise ValueError("test_loader yielded no batches; cannot compute metrics")

    # Concatenate once instead of extending Python lists element by element.
    y_pred = torch.cat(pred_chunks).numpy()
    y_true = torch.cat(label_chunks).numpy()

    avg_test_loss = loss_sum / n_seen
    accuracy = accuracy_score(y_true, y_pred)
    precision, recall, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary')

    return {'loss': avg_test_loss, 'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f1_score': f1}

In [4]:

def create_dataloader(X, y, batch_size, shuffle=True):
    """Build a batched DataLoader over integer token ids ``X`` and labels ``y``.

    Both inputs are converted with ``torch.as_tensor`` to ``torch.long`` (the dtype
    expected by embedding layers and ``nn.CrossEntropyLoss`` targets).
    Training loaders shuffle by default; pass ``shuffle=False`` for evaluation.
    """
    features, labels = (torch.as_tensor(source, dtype=torch.long) for source in (X, y))
    return DataLoader(TensorDataset(features, labels), batch_size=batch_size, shuffle=shuffle)

# Only the training split is shuffled; validation and test batches keep a fixed order.
train_loader = create_dataloader(X=X_train, y=y_train, batch_size=BATCH_SIZE, shuffle=True)
val_loader = create_dataloader(X=X_val, y=y_val, batch_size=BATCH_SIZE, shuffle=False)
test_loader = create_dataloader(X=X_test, y=y_test, batch_size=BATCH_SIZE, shuffle=False)

In [5]:
# Architectures to compare: each entry maps a display name to the model class and the
# keyword arguments used to construct it.
model_configs = {
    "CustomMamba": {
        "class": CustomMambaClassifier,
        "params": dict(
            vocab_size=VOCAB_SIZE,
            d_model=EMBEDDING_DIM,
            d_state=8,
            d_conv=4,
            num_layers=2,
            num_classes=NUM_CLASSES,
        ),
    },
    "Lib_Transformer": {
        "class": TransformerClassifier,
        # Changed from the previous run: num_layers 2 -> 4, num_heads 4 -> 8.
        "params": dict(
            vocab_size=VOCAB_SIZE,
            embed_dim=EMBEDDING_DIM,
            num_heads=8,
            num_layers=4,
            num_classes=NUM_CLASSES,
            max_seq_len=MAX_SEQ_LEN,
        ),
    },
}

# Train every configured model, reload its best-validation checkpoint, score it on the
# test split, then print a comparison table.
results = {}
histories = {}  # per-model training curves returned by train_and_evaluate
for model_name, model_cfg in model_configs.items():
    # The loop variable is `model_cfg`, not `config`. Reusing `config` would overwrite the
    # DatasetConfig from the first cell, so re-running the dataset cell afterwards would
    # pass this dict to DatasetGenerator.
    model_path = os.path.join(SAVE_DIR, f"best_model_{model_name.lower()}.pth")

    # Remove a checkpoint left by an earlier run so the existence check below can only
    # ever pick up weights saved during *this* training run.
    if os.path.exists(model_path):
        os.remove(model_path)

    model = model_cfg['class'](**model_cfg['params']).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    histories[model_name] = train_and_evaluate(
        model=model, train_loader=train_loader, val_loader=val_loader,
        optimizer=optimizer, criterion=criterion, num_epochs=NUM_EPOCHS,
        device=DEVICE, model_name=model_name, save_path=model_path
    )

    print(f"--- Оценка лучшей модели {model_name} на тестовых данных ---")
    if os.path.exists(model_path):
        best_model = model_cfg['class'](**model_cfg['params']).to(DEVICE)
        # map_location lets a checkpoint written on another device (e.g. GPU) load here.
        best_model.load_state_dict(torch.load(model_path, map_location=DEVICE))
        test_metrics = evaluate_on_test(best_model, test_loader, DEVICE, criterion)
        results[model_name] = test_metrics
        print(f"Результаты для {model_name}: {test_metrics}")
    else:
        print(f"Файл лучшей модели для {model_name} не найден. Пропускаем оценку.")

    print("-" * 60)

if results:
    results_df = pd.DataFrame(results).T
    print("\n\n--- Итоговая таблица сравнения моделей на тестовых данных ---")
    print(results_df.to_string())
else:
    print("Не удалось получить результаты ни для одной модели.")

--- Начало обучения модели: CustomMamba на устройстве cuda ---
Эпоха 1/20 | Время: 1263.97с | Train Loss: 0.6205 | Val Loss: 0.5389 | Val Acc: 0.7340 | Val F1: 0.6958
  -> Модель сохранена, новый лучший Val F1: 0.6958
Эпоха 2/20 | Время: 1287.85с | Train Loss: 0.4529 | Val Loss: 0.4690 | Val Acc: 0.7791 | Val F1: 0.7605
  -> Модель сохранена, новый лучший Val F1: 0.7605
Эпоха 3/20 | Время: 1206.72с | Train Loss: 0.3855 | Val Loss: 0.4334 | Val Acc: 0.8014 | Val F1: 0.7886
  -> Модель сохранена, новый лучший Val F1: 0.7886
Эпоха 4/20 | Время: 1322.62с | Train Loss: 0.3327 | Val Loss: 0.4332 | Val Acc: 0.8089 | Val F1: 0.7933
  -> Модель сохранена, новый лучший Val F1: 0.7933
Эпоха 5/20 | Время: 1342.46с | Train Loss: 0.2897 | Val Loss: 0.4408 | Val Acc: 0.8106 | Val F1: 0.7934
  -> Модель сохранена, новый лучший Val F1: 0.7934
Эпоха 6/20 | Время: 1326.47с | Train Loss: 0.2533 | Val Loss: 0.4253 | Val Acc: 0.8189 | Val F1: 0.8056
  -> Модель сохранена, новый лучший Val F1: 0.8056
Эпоха 7

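Видим, что Transformer значительно превосходит Mamba как по времени обучения, так и по качеству (журнал выше обрезан на 7-й эпохе Mamba, поэтому метрики Transformer в нём не видны). Поэтому дальше будем обучать и использовать для инференса именно его.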