In [1]:
import torch
import torch.nn as nn
from tqdm import tqdm
import sys
import numpy as np
import os
sys.path.append(os.path.abspath('../data'))
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence,pad_sequence
from torch.utils.data import Dataset, DataLoader


In [2]:


def tokenize(expression):
    """Convert expression string to tokens, preserving operators."""
    # Replace parentheses with spaces
    expr = expression.replace('(', ' ').replace(')', ' ')

    # Add spaces around brackets that aren't part of operators
    expr = expr.replace(']', ' ] ')

    # Split and filter empty strings
    return [token for token in expr.split() if token]

class ListOpsDataset(Dataset):
    def __init__(self, X, y):
        """
        Args:
            X: Array of source expressions
            y: Array of target values
        """
        self.X = X
        self.y = y

        # Create vocabulary from operators and digits
        self.vocab = {
            'PAD': 0,  # Padding token
            '[MIN': 1,
            '[MAX': 2,
            '[MED': 3,
            '[SM': 4,
            ']': 5,
            '(': 6,
            ')': 7
        }
        # Add digits 0-9
        for i in range(10):
            self.vocab[str(i)] = i + 8

    def __len__(self):
        return len(self.X)

    def tokenize(self, expr):
        """Convert expression to token IDs."""
        tokens = tokenize(expr)  # Using our previous tokenize function
        return [self.vocab.get(token, 0) for token in tokens]

    def __getitem__(self, idx):
        expr = self.X[idx]
        target = self.y[idx]

        # Convert to token IDs without padding or truncating
        token_ids = self.tokenize(expr)

        return {
            'input_ids': torch.tensor(token_ids, dtype=torch.long),
            'target': torch.tensor(target, dtype=torch.long)
        }

In [4]:

# Define the data directory and file paths
data_dir = 'LongListOps/data/output_dir'
train_file = os.path.join(data_dir, 'basic_train.tsv')
val_file = os.path.join(data_dir, 'basic_val.tsv')
test_file = os.path.join(data_dir, 'basic_test.tsv')

def load_listops_data(file_path, max_rows=None):
    """
    Load ListOps data from TSV file.

    Args:
        file_path: Path to the TSV file
        max_rows: Maximum number of rows to load (for testing)

    Returns:
        sources: Array of source expressions
        targets: Array of target values (0-9)
    """
    sources = []
    targets = []

    with open(file_path, 'r', encoding='utf-8') as f:
        next(f)  # Skip header (Source, Target)
        for i, line in enumerate(f):
            if max_rows and i >= max_rows:
                break
            if not line.strip():  # Skip empty lines
                continue
            parts = line.strip().split('\t')
            if len(parts) != 2:
                continue  # Skip lines that don't have exactly two columns
            source, target = parts
            sources.append(source)
            targets.append(int(target))  # Target is always 0-9

    # Convert to numpy arrays
    source_array = np.array(sources, dtype=object)  # Keep expressions as strings
    target_array = np.array(targets, dtype=np.int32)  # Targets are integers

    return source_array, target_array

try:
    # Load training data
    print("Loading training data...")
    X_train, y_train = load_listops_data(train_file)

    # Load validation data
    print("Loading validation data...")
    X_val, y_val = load_listops_data(val_file)

    # Load test data
    print("Loading test data...")
    X_test, y_test = load_listops_data(test_file)

    # Print dataset statistics
    print("\nDataset sizes:")
    print(f"Training: {len(X_train)} examples")
    print(f"Validation: {len(X_val)} examples")
    print(f"Test: {len(X_test)} examples")

except Exception as e:
    print(f"Error occurred: {type(e).__name__}: {str(e)}")

Loading training data...
Loading validation data...
Loading test data...

Dataset sizes:
Training: 96000 examples
Validation: 2000 examples
Test: 2000 examples


In [5]:

def collate_fn(batch):
    # Separate sequences and targets
    sequences = [item['input_ids'] for item in batch]
    targets = [item['target'] for item in batch]

    # Get lengths of each sequence
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long, device=sequences[0].device)

    # Sort sequences by length in descending order for pack_padded_sequence
    lengths, sort_idx = lengths.sort(descending=True)
    sequences = [sequences[i] for i in sort_idx]
    targets = [targets[i] for i in sort_idx]

    # Pad sequences
    padded_sequences = pad_sequence(sequences, batch_first=True)
    packed_sequences = pack_padded_sequence(padded_sequences, lengths, batch_first=True, enforce_sorted=False)
    # Convert targets to tensor
    targets = torch.stack(targets)

    return {
        'input_ids': packed_sequences,
        'target': targets,
        'lengths': lengths
    }

In [6]:

# Create datasets
train_dataset = ListOpsDataset(X_train, y_train)
val_dataset = ListOpsDataset(X_val, y_val)
test_dataset = ListOpsDataset(X_test, y_test)

# Create dataloaders with collate_fn
batch_size = 32
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    collate_fn=collate_fn
)

# Verify the data
print("Dataset sizes:")
print(f"Train: {len(train_dataset)}")
print(f"Val: {len(val_dataset)}")
print(f"Test: {len(test_dataset)}")


Dataset sizes:
Train: 96000
Val: 2000
Test: 2000

First batch shape:
Input IDs: torch.Size([32, 1815])
Targets: torch.Size([32])
Sequence lengths: tensor([1815, 1712, 1682, 1658, 1643, 1638, 1525, 1476, 1443, 1330, 1325, 1193,
        1191, 1187, 1096, 1045, 1042,  945,  892,  870,  779,  773,  765,  764,
         723,  641,  635,  621,  608,  605,  560,  506])


In [7]:
# Validation function

def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in val_loader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(batch)  # Pass the entire batch dictionary
            loss = criterion(outputs, batch['target'])

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += batch['target'].size(0)
            correct += predicted.eq(batch['target']).sum().item()

    return total_loss / len(val_loader), 100. * correct / total

# Initialize model and training components
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


Using device: cuda


In [8]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ce = nn.CrossEntropyLoss(reduction='none')

    def forward(self, inputs, targets):
        ce_loss = self.ce(inputs, targets)
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss
        return focal_loss.mean()



In [9]:
class EarlyStopping:
    def __init__(self, patience=7, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_model = None

    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.best_model = model.state_dict()
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.best_model = model.state_dict()
            self.counter = 0


In [10]:
class SimpleRNNModel(nn.Module):
    def __init__(self,
                 vocab_size=18,
                 embedding_dim=128,
                 hidden_dim=256,
                 num_layers=2,
                 dropout=0.3):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )
        self.fc = nn.Linear(hidden_dim * 2, 10)
        self.dropout = nn.Dropout(dropout)

    def forward(self, batch):
        input_ids = batch['input_ids']
        lengths = batch['lengths']

        # Handle embedding layer
        if isinstance(input_ids, PackedSequence):
            # If input is already packed, unpack it first
            seq, lens = pad_packed_sequence(input_ids, batch_first=True)
            x = self.embedding(seq)
            x = self.dropout(x)
            x = pack_padded_sequence(x, lens, batch_first=True, enforce_sorted=False)
        else:
            # Regular sequence processing
            x = self.embedding(input_ids)
            x = self.dropout(x)
            x = pack_padded_sequence(x, lengths.cpu(), batch_first=True, enforce_sorted=False)

        # Process with RNN
        packed_output, hidden = self.rnn(x)

        # Concatenate the last hidden states from both directions
        hidden_cat = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1)
        x = self.dropout(hidden_cat)
        x = self.fc(x)

        return x
# Training setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleRNNModel().to(device)
criterion = FocalLoss(gamma=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
early_stopping = EarlyStopping(patience=5)

# Dataloaders (make sure to use the collate_fn defined earlier)
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                         num_workers=4, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size,
                       num_workers=4, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                        num_workers=4, collate_fn=collate_fn)
num_epochs = 10
print("\nStarting training...")
# First, update the training loop to pass the batch as a dictionary
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
    for batch in pbar:
        # Move all batch tensors to device
        batch = {k: v.to(device) for k, v in batch.items()}

        optimizer.zero_grad()
        outputs = model(batch)  # Pass the entire batch dictionary
        loss = criterion(outputs, batch['target'])

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += batch['target'].size(0)
        correct += predicted.eq(batch['target']).sum().item()

        pbar.set_postfix({'loss': f'{train_loss/total:.4f}',
                         'acc': f'{100.*correct/total:.2f}%'})

    val_loss, val_acc = validate(model, val_loader, criterion, device)
    scheduler.step(val_loss)

    print(f'Epoch: {epoch+1}/{num_epochs}')
    print(f'Train Loss: {train_loss/len(train_loader):.4f} | Train Acc: {100.*correct/total:.2f}%')
    print(f'Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%')
    print('-' * 50)

    early_stopping(val_loss, model)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        model.load_state_dict(early_stopping.best_model)
        break






Starting training...


Epoch 1/10: 100%|██████████| 750/750 [03:45<00:00,  3.32it/s, loss=0.0093, acc=37.50%]


Epoch: 1/10
Train Loss: 1.1878 | Train Acc: 37.50%
Val Loss: 1.1264 | Val Acc: 39.25%
--------------------------------------------------


Epoch 2/10: 100%|██████████| 750/750 [03:45<00:00,  3.32it/s, loss=0.0089, acc=39.03%]


Epoch: 2/10
Train Loss: 1.1379 | Train Acc: 39.03%
Val Loss: 1.0999 | Val Acc: 40.45%
--------------------------------------------------


Epoch 3/10: 100%|██████████| 750/750 [03:45<00:00,  3.32it/s, loss=0.0088, acc=39.72%]


Epoch: 3/10
Train Loss: 1.1230 | Train Acc: 39.72%
Val Loss: 1.1048 | Val Acc: 39.20%
--------------------------------------------------


Epoch 4/10: 100%|██████████| 750/750 [03:46<00:00,  3.31it/s, loss=0.0087, acc=40.09%]


Epoch: 4/10
Train Loss: 1.1145 | Train Acc: 40.09%
Val Loss: 1.0778 | Val Acc: 40.75%
--------------------------------------------------


Epoch 5/10: 100%|██████████| 750/750 [03:53<00:00,  3.21it/s, loss=0.0086, acc=40.53%]


Epoch: 5/10
Train Loss: 1.1063 | Train Acc: 40.53%
Val Loss: 1.0945 | Val Acc: 40.80%
--------------------------------------------------


Epoch 6/10: 100%|██████████| 750/750 [03:52<00:00,  3.23it/s, loss=0.0086, acc=40.69%]


Epoch: 6/10
Train Loss: 1.1039 | Train Acc: 40.69%
Val Loss: 1.0836 | Val Acc: 40.55%
--------------------------------------------------


Epoch 7/10: 100%|██████████| 750/750 [03:53<00:00,  3.21it/s, loss=0.0086, acc=41.01%]


Epoch: 7/10
Train Loss: 1.0997 | Train Acc: 41.01%
Val Loss: 1.0849 | Val Acc: 39.90%
--------------------------------------------------


Epoch 8/10: 100%|██████████| 750/750 [03:47<00:00,  3.30it/s, loss=0.0086, acc=40.93%]


Epoch: 8/10
Train Loss: 1.0978 | Train Acc: 40.93%
Val Loss: 1.0961 | Val Acc: 40.15%
--------------------------------------------------


Epoch 9/10: 100%|██████████| 750/750 [03:45<00:00,  3.32it/s, loss=0.0083, acc=42.40%]


Epoch: 9/10
Train Loss: 1.0679 | Train Acc: 42.40%
Val Loss: 1.0668 | Val Acc: 41.55%
--------------------------------------------------


Epoch 10/10: 100%|██████████| 750/750 [03:47<00:00,  3.29it/s, loss=0.0083, acc=43.02%]


Epoch: 10/10
Train Loss: 1.0592 | Train Acc: 43.02%
Val Loss: 1.0480 | Val Acc: 42.50%
--------------------------------------------------


In [11]:
# Test with best model
test_loss, test_acc = validate(model, test_loader, criterion, device)
print(f'\nTest Loss: {test_loss:.4f} | Test Acc: {test_acc:.2f}%')



Test Loss: 1.0222 | Test Acc: 45.20%


In [12]:
torch.save(model.state_dict(), 'rnn_model.pth')