In [9]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import math
from tqdm import tqdm
import os

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

In [10]:
from models.ctm_nlp import CTM_NLP

In [11]:
# --- Configuration ---
CONFIG = {
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "batch_size": 32,
    "learning_rate": 2e-5,
    "epochs": 3,
    "d_model": 512,        # Core CTM latent space
    "d_input": 512,        # Embedding dimension
    "heads": 8,
    "iterations": 8,       # Number of "thought" steps
    "synapse_layers": 4,   # Depth of the new Transformer synapse
    "memory_length": 8,
    "n_synch": 128,        # Number of neurons for sync representation
    "dropout": 0.1,
}
print(f"Using device: {CONFIG['device']}")

Using device: cuda


# 1. DATA PREPARATION

In [12]:
def get_ag_news_data(
    data_dir='./data',
    vocab_size=25000, # A reasonable vocabulary size
    max_seq_len=256   # Cap sequence length to avoid excessive memory use
):
    """
    Loads the AG_NEWS dataset using the Hugging Face `datasets` library,
    builds a vocabulary, and prepares DataLoaders for training and testing.
    This version avoids using torchtext completely.
    """
    print("Loading AG_NEWS dataset using Hugging Face `datasets`...")
    
    # 1. Load the dataset from the Hugging Face Hub
    # This is very robust and caches the data locally.
    dataset = load_dataset("ag_news")

    # 2. Train a tokenizer
    # We will train a simple WordLevel tokenizer on the training data.
    # This is more flexible and modern than the old torchtext vocab system.
    tokenizer_path = os.path.join(data_dir, 'ag_news_tokenizer.json')
    
    if not os.path.exists(tokenizer_path):
        print("Training a new tokenizer...")
        # Initialize a tokenizer
        tokenizer = Tokenizer(WordLevel(unk_token="<unk>"))
        tokenizer.pre_tokenizer = Whitespace()

        # Create a trainer
        trainer = WordLevelTrainer(
            vocab_size=vocab_size,
            special_tokens=["<unk>", "<pad>", "<cls>", "<sep>"]
        )

        # A generator function to feed text to the trainer
        def get_training_corpus():
            for i in range(len(dataset["train"])):
                yield dataset["train"][i]["text"]

        # Train the tokenizer
        tokenizer.train_from_iterator(get_training_corpus(), trainer=trainer)
        
        # Save the tokenizer for future use
        if not os.path.exists(data_dir):
            os.makedirs(data_dir)
        tokenizer.save(tokenizer_path)
    else:
        print(f"Loading tokenizer from {tokenizer_path}")
        tokenizer = Tokenizer.from_file(tokenizer_path)

    # Get the vocabulary size and padding token ID from the trained tokenizer
    actual_vocab_size = tokenizer.get_vocab_size()
    padding_idx = tokenizer.token_to_id("<pad>")
    
    print(f"Vocabulary size: {actual_vocab_size}")
    print(f"Padding index: {padding_idx}")

    # 3. Create a preprocessing function
    def preprocess_function(examples):
        # Tokenize the texts and truncate to max_seq_len
        tokenized_inputs = tokenizer.encode_batch(examples["text"])
        
        # Extract input_ids and create attention_masks
        input_ids = [encoding.ids[:max_seq_len] for encoding in tokenized_inputs]
        
        # Labels in 'ag_news' from `datasets` are already 0-indexed (0-3)
        return {"input_ids": input_ids, "labels": examples["label"]}

    print("Tokenizing and formatting the dataset...")
    tokenized_datasets = dataset.map(preprocess_function, batched=True, remove_columns=["text"])
    
    # Set the format to PyTorch tensors
    tokenized_datasets.set_format("torch")
    
    train_dataset = tokenized_datasets["train"]
    test_dataset = tokenized_datasets["test"]

    # 4. Define the collate function
    def collate_batch(batch):
        # The batch is now a list of dictionaries
        input_ids_list = [item['input_ids'] for item in batch]
        labels_list = [item['labels'] for item in batch]
        
        # Pad sequences to the max length in this batch
        padded_texts = nn.utils.rnn.pad_sequence(
            input_ids_list, 
            batch_first=True, 
            padding_value=padding_idx
        )
        
        # Create attention masks (1 for real tokens, 0 for padding)
        attention_masks = (padded_texts != padding_idx).int()
        
        labels = torch.tensor(labels_list, dtype=torch.int64)
        
        return padded_texts, attention_masks, labels

    # 5. Create DataLoaders
    train_dataloader = DataLoader(train_dataset, batch_size=CONFIG['batch_size'], shuffle=True, collate_fn=collate_batch)
    test_dataloader = DataLoader(test_dataset, batch_size=CONFIG['batch_size'], shuffle=False, collate_fn=collate_batch)

    return train_dataloader, test_dataloader, actual_vocab_size, padding_idx

# 2. Training

In [13]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss, total_acc, total_count = 0, 0, 0
    progress_bar = tqdm(dataloader, desc="Training")

    for padded_texts, attention_masks, labels in progress_bar:
        padded_texts = padded_texts.to(device)
        attention_masks = attention_masks.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        
        # The model returns (predictions, certainties, final_sync_state)
        predictions, _, _ = model(padded_texts, attention_mask=attention_masks)
        
        # For classification, we use the output from the FINAL thought step
        logits = predictions[:, :, -1]
        
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        total_acc += (logits.argmax(1) == labels).sum().item()
        total_count += labels.size(0)
        
        progress_bar.set_postfix({'loss': total_loss / total_count, 'acc': total_acc / total_count})

    return total_loss / total_count, total_acc / total_count

def evaluate(model, dataloader, criterion, device):
    model.eval()
    total_loss, total_acc, total_count = 0, 0, 0
    progress_bar = tqdm(dataloader, desc="Evaluating")
    
    with torch.no_grad():
        for padded_texts, attention_masks, labels in progress_bar:
            padded_texts = padded_texts.to(device)
            attention_masks = attention_masks.to(device)
            labels = labels.to(device)

            predictions, _, _ = model(padded_texts, attention_mask=attention_masks)
            logits = predictions[:, :, -1]

            loss = criterion(logits, labels)
            total_loss += loss.item()
            total_acc += (logits.argmax(1) == labels).sum().item()
            total_count += labels.size(0)
            
            progress_bar.set_postfix({'loss': total_loss / total_count, 'acc': total_acc / total_count})

    return total_loss / total_count, total_acc / total_count

In [14]:
# --- Data ---
train_loader, test_loader, vocab_size, padding_idx = get_ag_news_data()
num_classes = 4

print(f"Number of classes: {num_classes}")

Loading AG_NEWS dataset using Hugging Face `datasets`...
Loading tokenizer from ./data\ag_news_tokenizer.json
Vocabulary size: 25000
Padding index: 1
Tokenizing and formatting the dataset...
Number of classes: 4


In [15]:
# --- Model ---
model = CTM_NLP(
    vocab_size=vocab_size,
    num_classes=num_classes,
    padding_idx=padding_idx,
    d_model=CONFIG['d_model'],
    d_input=CONFIG['d_input'],
    heads=CONFIG['heads'],
    iterations=CONFIG['iterations'],
    synapse_depth=CONFIG['synapse_layers'],
    memory_length=CONFIG['memory_length'],
    n_synch_out=CONFIG['n_synch'],
    n_synch_action=CONFIG['n_synch'],
    dropout=CONFIG['dropout'],
    deep_nlms=True, 
    do_layernorm_nlm=True,
).to(CONFIG['device'])
    
print("\nPerforming a dummy forward pass to initialize lazy layers...")
try:
        # Create a small dummy batch on the correct device
        dummy_batch_size = 2
        dummy_seq_len = 16
        dummy_input_ids = torch.randint(
            0, vocab_size, 
            (dummy_batch_size, dummy_seq_len), 
            device=CONFIG['device']
        )
        dummy_attention_mask = torch.ones_like(dummy_input_ids)

        # Run the dummy forward pass
        with torch.no_grad():
            model(dummy_input_ids, attention_mask=dummy_attention_mask)
        
        print("Lazy layers initialized successfully.")

except Exception as e:
        print(f"Error during dummy forward pass: {e}")
        print("Please check model architecture and input dimensions.")
        # Exit or raise the error if initialization fails
        raise e

# Now it is safe to count parameters and create the optimizer
print(f"\nModel created with {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters.")

# --- Optimizer and Loss ---
optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG['learning_rate'])
criterion = nn.CrossEntropyLoss()

Using neuron select type: random-pairing
Synch representation size action: 128
Synch representation size out: 128
Initializing CTM for NLP...
Replacing SynapseUnet with TransformerEncoder (4 layers)...
Model configured for 4-class classification.





Performing a dummy forward pass to initialize lazy layers...
Lazy layers initialized successfully.

Model created with 27,321,101 trainable parameters.


In [None]:
# --- Training Loop ---
for epoch in range(1, CONFIG['epochs'] + 1):
    print(f"\n--- Epoch {epoch}/{CONFIG['epochs']} ---")
    train_loss, train_acc = train_epoch(model, train_loader, optimizer, criterion, CONFIG['device'])
    print(f"Epoch {epoch} Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        
    eval_loss, eval_acc = evaluate(model, test_loader, criterion, CONFIG['device'])
    print(f"Epoch {epoch} Eval Loss: {eval_loss:.4f}  | Eval Acc: {eval_acc:.4f}")

print("\n--- Final Test Evaluation ---")
test_loss, test_acc = evaluate(model, test_loader, criterion, CONFIG['device'])
print(f"Final Test Accuracy: {test_acc * 100:.2f}%")


--- Epoch 1/3 ---


Training:   0%|          | 0/3750 [00:00<?, ?it/s]

Training: 100%|██████████| 3750/3750 [15:04<00:00,  4.15it/s, loss=0.0238, acc=0.673]


Epoch 1 Train Loss: 0.0238 | Train Acc: 0.6731


Evaluating: 100%|██████████| 238/238 [00:16<00:00, 14.68it/s, loss=0.0149, acc=0.836]


Epoch 1 Eval Loss: 0.0149  | Eval Acc: 0.8358

--- Epoch 2/3 ---


Training:   3%|▎         | 121/3750 [00:45<16:02,  3.77it/s, loss=0.0147, acc=0.831]