In [56]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from tqdm.auto import tqdm
import time
import os
import urllib.request
import tarfile
import csv
from collections import Counter

In [57]:
from models.ctm_nlp import CTM_NLP

In [58]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

Using device: cpu


In [59]:
# Параметры данных
BATCH_SIZE = 64
VOCAB_SIZE_LIMIT = 10000
MAX_SEQ_LEN = 512

# Параметры обучения
EPOCHS = 3
LEARNING_RATE = 0.001

# Параметры моделей (остаются без изменений)
CTM_D_MODEL = 256
CTM_D_INPUT = 128
CTM_ITERATIONS = 10
CTM_HEADS = 4
CTM_SYNCH_OUT = 128
CTM_SYNCH_ACTION = 64
CTM_SYNAPSE_DEPTH = 2
CTM_MEMORY_LENGTH = 10
CTM_MEMORY_HIDDEN = 32
LSTM_HIDDEN_DIM = 128
LSTM_NUM_LAYERS = 2

In [60]:
# --- 2. Замена torchtext: Загрузка данных и создание словаря ---

def download_and_extract_ag_news(root='./data'):
    """Скачивает и извлекает датасет AG_NEWS, если его нет."""
    url = "https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz"
    data_path = os.path.join(root, 'ag_news_csv')
    
    if os.path.exists(data_path):
        print("Dataset already downloaded and extracted.")
    else:
        print("Downloading AG_NEWS dataset...")
        os.makedirs(root, exist_ok=True)
        tgz_path = os.path.join(root, 'ag_news_csv.tgz')
        urllib.request.urlretrieve(url, tgz_path)
        print("Extracting...")
        with tarfile.open(tgz_path, 'r:gz') as tar:
            tar.extractall(path=root)
        os.remove(tgz_path)
        print("Done.")
        
    train_data, test_data = [], []
    with open(os.path.join(data_path, 'train.csv'), 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        for row in reader:
            # Класс, Заголовок, Описание
            train_data.append((int(row[0]), row[1] + " " + row[2]))
            
    with open(os.path.join(data_path, 'test.csv'), 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        for row in reader:
            test_data.append((int(row[0]), row[1] + " " + row[2]))
            
    return train_data, test_data

def simple_tokenizer(text):
    """Простой токенизатор, который разделяет текст по пробелам."""
    return text.lower().strip().split()

def build_vocab(data, tokenizer, max_size):
    """Создает словарь (word -> index) на основе данных."""
    counter = Counter()
    for _, text in data:
        counter.update(tokenizer(text))
    
    # Создаем словарь с наиболее частыми словами
    most_common_words = [word for word, _ in counter.most_common(max_size - 2)] # -2 для <pad> и <unk>
    
    # Добавляем специальные токены
    word_to_idx = {'<pad>': 0, '<unk>': 1}
    for i, word in enumerate(most_common_words):
        word_to_idx[word] = i + 2
        
    return word_to_idx

class NewsDataset(Dataset):
    """Простой класс датасета для PyTorch."""
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

In [61]:
# --- 3. Модель Baseline: LSTM Classifier (без изменений) ---
class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_class, num_layers, pad_idx):
        super(LSTMClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, 
                            batch_first=True, bidirectional=True, dropout=0.3)
        self.fc = nn.Linear(hidden_dim * 2, num_class)

    def forward(self, text):
        embedded = self.embedding(text)
        _, (hidden, _) = self.lstm(embedded)
        hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
        return self.fc(hidden)


# --- 4. Функции обучения и оценки (без изменений) ---
def train_epoch(model, dataloader, optimizer, criterion, model_type='lstm'):
    model.train()
    total_acc, total_loss, total_count = 0, 0, 0
    progress_bar = tqdm(dataloader, desc=f'Training {model_type}')
    for idx, (label, text) in enumerate(progress_bar):
        label, text = label.to(DEVICE), text.to(DEVICE)
        optimizer.zero_grad()
        if model_type == 'ctm':
            predictions, _, _ = model(text)
            logits = predictions[:, :, -1]
        else:
            logits = model(text)
        loss = criterion(logits, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        total_acc += (logits.argmax(1) == label).sum().item()
        total_loss += loss.item()
        total_count += label.size(0)
        progress_bar.set_postfix({'loss': total_loss / total_count, 'acc': total_acc / total_count})
    return total_acc / total_count, total_loss / total_count

def evaluate(model, dataloader, criterion, model_type='lstm'):
    model.eval()
    total_acc, total_loss, total_count = 0, 0, 0
    with torch.no_grad():
        for idx, (label, text) in enumerate(dataloader):
            label, text = label.to(DEVICE), text.to(DEVICE)
            if model_type == 'ctm':
                predictions, _, _ = model(text)
                logits = predictions[:, :, -1]
            else:
                logits = model(text)
            loss = criterion(logits, label)
            total_acc += (logits.argmax(1) == label).sum().item()
            total_loss += loss.item()
            total_count += label.size(0)
    return total_acc / total_count, total_loss / total_count

In [62]:
train_data, test_data = download_and_extract_ag_news()
word_to_idx = build_vocab(train_data, simple_tokenizer, VOCAB_SIZE_LIMIT)
VOCAB_SIZE = len(word_to_idx)
PAD_IDX = word_to_idx['<pad>']
NUM_CLASS = 4

print(f"Vocabulary size: {VOCAB_SIZE}")

Dataset already downloaded and extracted.
Vocabulary size: 10000


In [63]:
def collate_batch(batch):
    label_list, text_list = [], []
    unk_idx = word_to_idx['<unk>']
    for (_label, _text) in batch:
        label_list.append(int(_label) - 1)
        tokens = simple_tokenizer(_text)
        indices = [word_to_idx.get(token, unk_idx) for token in tokens]
        processed_text = torch.tensor(indices, dtype=torch.int64)
        text_list.append(processed_text)
        
    padded_texts = pad_sequence(text_list, batch_first=True, padding_value=PAD_IDX)
    return torch.tensor(label_list, dtype=torch.int64), padded_texts

In [64]:
train_dataset = NewsDataset(train_data)
test_dataset = NewsDataset(test_data)
    
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

results = {}

In [None]:
# --- Инициализация и обучение CTM ---
print("\n--- Testing CTM_NLP ---")
ctm_model = CTM_NLP(
    vocab_size=VOCAB_SIZE,
    max_seq_len=MAX_SEQ_LEN,
    iterations=CTM_ITERATIONS,
    d_model=CTM_D_MODEL,
    d_input=CTM_D_INPUT,
    # out_dims=NUM_CLASS,
    heads=CTM_HEADS,
    n_synch_out=CTM_SYNCH_OUT,
    n_synch_action=CTM_SYNCH_ACTION,
    synapse_depth=CTM_SYNAPSE_DEPTH,
    memory_length=CTM_MEMORY_LENGTH,
    deep_nlms=True,
    memory_hidden_dims=CTM_MEMORY_HIDDEN,
    do_layernorm_nlm=False,
    dropout=0.2
).to(DEVICE)
    
optimizer_ctm = torch.optim.AdamW(ctm_model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    
for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train_acc, train_loss = train_epoch(ctm_model, train_dataloader, optimizer_ctm, criterion, model_type='ctm')
        test_acc, test_loss = evaluate(ctm_model, test_dataloader, criterion, model_type='ctm')
        
        print(f'CTM Epoch: {epoch}, Time: {time.time() - epoch_start_time:.2f}s')
        print(f'\tTrain Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\tTest Loss:  {test_loss:.4f} | Test Acc:  {test_acc*100:.2f}%')
    
results['CTM_NLP'] = test_acc


--- Testing CTM_NLP ---
Using neuron select type: random-pairing
Synch representation size action: 64
Synch representation size out: 128
Initializing CTM for NLP tasks...
CTM_NLP initialized with vocab_size=10000, max_seq_len=512
Output projection layer will map to 10000 logits.


Training ctm:   5%|▌         | 94/1875 [02:17<1:00:34,  2.04s/it, loss=0.0499, acc=0.222]

In [None]:
# --- Инициализация и обучение LSTM ---
print("\n--- Testing LSTM Baseline ---")
lstm_model = LSTMClassifier(
    vocab_size=VOCAB_SIZE,
    embed_dim=CTM_D_INPUT,
    hidden_dim=LSTM_HIDDEN_DIM,
    num_class=NUM_CLASS,
    num_layers=LSTM_NUM_LAYERS,
    pad_idx=PAD_IDX
).to(DEVICE)

optimizer_lstm = torch.optim.AdamW(lstm_model.parameters(), lr=LEARNING_RATE)
    
for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train_acc, train_loss = train_epoch(lstm_model, train_dataloader, optimizer_lstm, criterion, model_type='lstm')
        test_acc, test_loss = evaluate(lstm_model, test_dataloader, criterion, model_type='lstm')
        
        print(f'LSTM Epoch: {epoch}, Time: {time.time() - epoch_start_time:.2f}s')
        print(f'\tTrain Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\tTest Loss:  {test_loss:.4f} | Test Acc:  {test_acc*100:.2f}%')

results['LSTM'] = test_acc

In [None]:
# --- Итоговое сравнение ---
print("\n" + "="*40)
print("           FINAL RESULTS")
print("="*40)
print(f"  CTM_NLP Test Accuracy:  {results.get('CTM_NLP', 0)*100:.2f}%")
print(f"  LSTM Test Accuracy:     {results.get('LSTM', 0)*100:.2f}%")
print("="*40)