In [12]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from tqdm.auto import tqdm
import time
import os
import urllib.request
import tarfile
import csv
from collections import Counter

In [13]:
import nltk
nltk.download('punkt')
from nltk.tokenize import word_tokenize


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\karin\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [14]:
from models.ctm_nlp import CTM_NLP

In [24]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

BATCH_SIZE = 64
VOCAB_SIZE_LIMIT = 10000
MAX_SEQ_LEN = 512

EPOCHS = 3
LEARNING_RATE = 0.001

CTM_D_MODEL = 256
CTM_D_INPUT = 128
CTM_ITERATIONS = 10
CTM_HEADS = 4
CTM_SYNCH_OUT = 128
CTM_SYNCH_ACTION = 64
CTM_SYNAPSE_DEPTH = 2
CTM_MEMORY_LENGTH = 10
CTM_MEMORY_HIDDEN = 32
LSTM_HIDDEN_DIM = 128
LSTM_NUM_LAYERS = 2

Using device: cuda


In [16]:
def download_and_extract_ag_news(root='./data'):
    """Donwnloading AG_NEWS"""
    url = "https://s3.amazonaws.com/fast-ai-nlp/ag_news_csv.tgz"
    data_path = os.path.join(root, 'ag_news_csv')
    
    if os.path.exists(data_path):
        print("Dataset already downloaded and extracted.")
    else:
        print("Downloading AG_NEWS dataset...")
        os.makedirs(root, exist_ok=True)
        tgz_path = os.path.join(root, 'ag_news_csv.tgz')
        urllib.request.urlretrieve(url, tgz_path)
        print("Extracting...")
        with tarfile.open(tgz_path, 'r:gz') as tar:
            tar.extractall(path=root)
        os.remove(tgz_path)
        print("Done.")
        
    train_data, test_data = [], []
    with open(os.path.join(data_path, 'train.csv'), 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        for row in reader:
            # Класс, Заголовок, Описание
            train_data.append((int(row[0]), row[1] + " " + row[2]))
            
    with open(os.path.join(data_path, 'test.csv'), 'r', encoding='utf-8') as f:
        reader = csv.reader(f)
        for row in reader:
            test_data.append((int(row[0]), row[1] + " " + row[2]))
            
    return train_data, test_data

def simple_tokenizer(text):
    """Simple example of tokenization"""
    return text.lower().strip().split()
    # return word_tokenize(text.lower().strip())

def build_vocab(data, tokenizer, max_size):
    """word -> index"""
    counter = Counter()
    for _, text in data:
        counter.update(tokenizer(text))

    most_common_words = [word for word, _ in counter.most_common(max_size - 2)] # -2 for <pad> и <unk>
    
    word_to_idx = {'<pad>': 0, '<unk>': 1}
    for i, word in enumerate(most_common_words):
        word_to_idx[word] = i + 2
        
    return word_to_idx

class NewsDataset(Dataset):
    def __init__(self, data):
        self.data = data
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

In [26]:
train_data, test_data = download_and_extract_ag_news()
word_to_idx = build_vocab(train_data, simple_tokenizer, VOCAB_SIZE_LIMIT)
VOCAB_SIZE = len(word_to_idx)
PAD_IDX = word_to_idx['<pad>']
NUM_CLASS = 4

print(f"Vocabulary size: {VOCAB_SIZE}")

Dataset already downloaded and extracted.
Vocabulary size: 10000


In [18]:
def collate_batch(batch):
    label_list, text_list = [], []
    unk_idx = word_to_idx['<unk>']
    for (_label, _text) in batch:
        label_list.append(int(_label) - 1)
        tokens = simple_tokenizer(_text)
        indices = [word_to_idx.get(token, unk_idx) for token in tokens]
        processed_text = torch.tensor(indices, dtype=torch.int64)
        text_list.append(processed_text)
        
    padded_texts = pad_sequence(text_list, batch_first=True, padding_value=PAD_IDX)
    return torch.tensor(label_list, dtype=torch.int64), padded_texts

In [19]:
def train_epoch(model, dataloader, optimizer, criterion, model_type='lstm'):
    model.train()
    total_acc, total_loss, total_count = 0, 0, 0
    progress_bar = tqdm(dataloader, desc=f'Training {model_type}')
    for idx, (label, text) in enumerate(progress_bar):
        label, text = label.to(DEVICE), text.to(DEVICE)
        optimizer.zero_grad()
        if model_type == 'ctm':
            predictions, _, _ = model(text)
            logits = predictions[:, :, -1]
        else:
            logits = model(text)
        loss = criterion(logits, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        total_acc += (logits.argmax(1) == label).sum().item()
        total_loss += loss.item()
        total_count += label.size(0)
        progress_bar.set_postfix({'loss': total_loss / total_count, 'acc': total_acc / total_count})
    return total_acc / total_count, total_loss / total_count

def evaluate(model, dataloader, criterion, model_type='lstm'):
    model.eval()
    total_acc, total_loss, total_count = 0, 0, 0
    with torch.no_grad():
        for idx, (label, text) in enumerate(dataloader):
            label, text = label.to(DEVICE), text.to(DEVICE)
            if model_type == 'ctm':
                predictions, _, _ = model(text)
                logits = predictions[:, :, -1]
            else:
                logits = model(text)
            loss = criterion(logits, label)
            total_acc += (logits.argmax(1) == label).sum().item()
            total_loss += loss.item()
            total_count += label.size(0)
    return total_acc / total_count, total_loss / total_count

In [20]:
train_dataset = NewsDataset(train_data)
test_dataset = NewsDataset(test_data)
    
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)

results = {}

In [27]:
# Initialize model
ctm_model = CTM_NLP(
    vocab_size=VOCAB_SIZE,
    max_seq_len=MAX_SEQ_LEN,
    iterations=CTM_ITERATIONS,
    d_model=CTM_D_MODEL,
    d_input=CTM_D_INPUT,
    # out_dims=NUM_CLASS,
    heads=CTM_HEADS,
    n_synch_out=CTM_SYNCH_OUT,
    n_synch_action=CTM_SYNCH_ACTION,
    synapse_depth=CTM_SYNAPSE_DEPTH,
    memory_length=CTM_MEMORY_LENGTH,
    deep_nlms=True,
    memory_hidden_dims=CTM_MEMORY_HIDDEN,
    do_layernorm_nlm=False,
    dropout=0.2
).to(DEVICE)
    
optimizer_ctm = torch.optim.AdamW(ctm_model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

Using neuron select type: random-pairing
Synch representation size action: 64
Synch representation size out: 128
Initializing CTM for NLP tasks...
CTM_NLP initialized with vocab_size=10000, max_seq_len=512
Output projection layer will map to 10000 logits.


In [22]:
for epoch in range(1, EPOCHS + 1):
        epoch_start_time = time.time()
        train_acc, train_loss = train_epoch(ctm_model, train_dataloader, optimizer_ctm, criterion, model_type='ctm')
        test_acc, test_loss = evaluate(ctm_model, test_dataloader, criterion, model_type='ctm')
        
        print(f'CTM Epoch: {epoch}, Time: {time.time() - epoch_start_time:.2f}s')
        print(f'\tTrain Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%')
        print(f'\tTest Loss:  {test_loss:.4f} | Test Acc:  {test_acc*100:.2f}%')
    
results['CTM_NLP'] = test_acc

Training ctm: 100%|██████████| 1875/1875 [14:56<00:00,  2.09it/s, loss=0.00931, acc=0.592]


CTM Epoch: 1, Time: 908.39s
	Train Loss: 0.0093 | Train Acc: 59.24%
	Test Loss:  0.0039 | Test Acc:  68.50%


Training ctm: 100%|██████████| 1875/1875 [04:10<00:00,  7.50it/s, loss=0.00286, acc=0.701]


CTM Epoch: 2, Time: 253.98s
	Train Loss: 0.0029 | Train Acc: 70.15%
	Test Loss:  0.0032 | Test Acc:  69.61%


Training ctm: 100%|██████████| 1875/1875 [04:07<00:00,  7.57it/s, loss=0.00212, acc=0.714]


CTM Epoch: 3, Time: 255.17s
	Train Loss: 0.0021 | Train Acc: 71.42%
	Test Loss:  0.0032 | Test Acc:  69.88%


In [23]:
print("\n" + "="*40)
print("           FINAL RESULTS")
print("="*40)
print(f"  CTM_NLP Test Accuracy:  {results.get('CTM_NLP', 0)*100:.2f}%")
print("="*40)


           FINAL RESULTS
  CTM_NLP Test Accuracy:  69.88%
