# Treebank Dataset

In [None]:
import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import nltk
from nltk.corpus import treebank
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

# Download the NLTK 'treebank' corpus and load its POS-tagged sentences.
nltk.download('treebank')
tagged_sents = treebank.tagged_sents()

# Give every distinct tag an integer id, numbered in sorted tag order.
all_tags = sorted({tag for sent in tagged_sents for (_, tag) in sent})
tag2idx  = {tag: i for i, tag in enumerate(all_tags)}
# Despite its name, vocab_size here is the number of POS tags (len(tag2idx)),
# not the number of distinct words.
vocab_size = len(tag2idx)
print(f"Number of POS tags: {vocab_size}")

def tagged_sents_to_word_idx_seqs(tagged, word2idx=None):
    """Encode tagged sentences as parallel word-id and tag-id arrays.

    Tags are looked up in the module-level ``tag2idx``. When ``word2idx`` is
    not supplied, a mapping is built from the sorted set of words in
    ``tagged``. Sentences with fewer than two tokens are dropped.

    Returns:
        (X, Y, word2idx): lists of int64 numpy arrays (word ids, tag ids)
        and the word-to-id mapping that was used.

    Raises:
        KeyError: if a word or tag is missing from its mapping.
    """
    vocabulary = {token for sentence in tagged for (token, _) in sentence}
    if word2idx is None:
        word2idx = dict(zip(sorted(vocabulary), range(len(vocabulary))))

    X, Y = [], []
    for sentence in tagged:
        encoded_words = [word2idx[token] for token, _ in sentence]
        encoded_tags = [tag2idx[label] for _, label in sentence]
        if len(encoded_words) < 2:
            continue
        X.append(np.array(encoded_words, dtype=np.int64))
        Y.append(np.array(encoded_tags, dtype=np.int64))
    return X, Y, word2idx

class WordLevelTagSequenceDataset(Dataset):
    """Sliding windows over word-id sequences paired with shifted tag windows.

    For every sequence longer than ``sequence_length``, each start offset ``i``
    yields an input ``wseq[i:i+sequence_length]`` and a target
    ``tseq[i+1:i+sequence_length+1]``; i.e. the target at window position ``k``
    is the tag one position further along the source sequence.

    Args:
        word_seqs: iterable of 1-D integer sequences (word ids).
        tag_seqs: iterable of 1-D integer sequences (tag ids), paired with
            ``word_seqs`` by position.
        sequence_length: window length.

    Raises:
        ValueError: if no sequence is longer than ``sequence_length``.
    """

    def __init__(self, word_seqs, tag_seqs, sequence_length=30):
        input_windows, target_windows = [], []
        for wseq, tseq in zip(word_seqs, tag_seqs):
            # range() is empty unless the sequence is strictly longer than the
            # window, which also guarantees the shifted target slice is full.
            for i in range(len(wseq) - sequence_length):
                input_windows.append(wseq[i:i+sequence_length])
                target_windows.append(tseq[i+1:i+sequence_length+1])
        if not input_windows:
            raise ValueError("No sequences long enough—decrease sequence_length.")
        # Stack into one ndarray first: building a tensor straight from a list
        # of many small numpy arrays is very slow in PyTorch.
        self.inputs  = torch.as_tensor(np.stack(input_windows), dtype=torch.long)
        self.targets = torch.as_tensor(np.stack(target_windows), dtype=torch.long)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

def get_wordlevel_dataloaders_treebank(batch_size=64, sequence_length=15, val_split=0.1, test_split=0.1, seed=42):
    """Build word-level train/val/test DataLoaders from the Treebank sentences.

    The module-level ``tagged_sents`` are encoded, then split at sentence level
    after shuffling with Python's global ``random`` (seeded with ``seed``):
    the first ``test_split`` fraction goes to test, the next ``val_split``
    fraction to validation, the remainder to training.

    Returns:
        (train_loader, val_loader, test_loader, word2idx). Only the training
        loader shuffles its batches.
    """
    random.seed(seed)
    word_seqs, tag_seqs, word2idx = tagged_sents_to_word_idx_seqs(tagged_sents)

    order = list(range(len(word_seqs)))
    random.shuffle(order)
    total = len(order)
    test_end = int(total * test_split)
    val_end = test_end + int(total * val_split)

    def build_dataset(sentence_ids):
        # Gather the selected sentences and window them.
        return WordLevelTagSequenceDataset(
            [word_seqs[j] for j in sentence_ids],
            [tag_seqs[j] for j in sentence_ids],
            sequence_length,
        )

    train_ds = build_dataset(order[val_end:])
    val_ds = build_dataset(order[test_end:val_end])
    test_ds = build_dataset(order[:test_end])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
    return (train_loader, val_loader, test_loader, word2idx)

def tagged_sents_to_char_idx_seqs(tagged, char2idx=None):
    """Encode tagged sentences as flat character-id and per-character tag-id arrays.

    Each sentence's words are concatenated character by character; every
    character carries the tag id of the word it came from, looked up in the
    module-level ``tag2idx``. When ``char2idx`` is not supplied it is built
    from the sorted set of characters in ``tagged``. Sentences with fewer than
    two characters are dropped.

    Returns:
        (X, Y, char2idx): lists of int64 numpy arrays and the mapping used.
    """
    alphabet = {ch for sentence in tagged for (token, _) in sentence for ch in token}
    if char2idx is None:
        char2idx = dict(zip(sorted(alphabet), range(len(alphabet))))

    X, Y = [], []
    for sentence in tagged:
        flat_chars, flat_tags = [], []
        for token, label in sentence:
            encoded = [char2idx[ch] for ch in token]
            flat_chars += encoded
            # Repeat the word's tag once per character.
            flat_tags += [tag2idx[label]] * len(encoded)
        if len(flat_chars) < 2:
            continue
        X.append(np.array(flat_chars, dtype=np.int64))
        Y.append(np.array(flat_tags, dtype=np.int64))
    return X, Y, char2idx

class CharLevelTagSequenceDataset(Dataset):
    """Sliding windows over character-id sequences paired with shifted tag windows.

    For every sequence longer than ``sequence_length``, each start offset ``i``
    yields an input ``cseq[i:i+sequence_length]`` and a target
    ``tseq[i+1:i+sequence_length+1]``; i.e. the target at window position ``k``
    is the tag one position further along the source sequence.

    Args:
        char_seqs: iterable of 1-D integer sequences (character ids).
        tag_seqs: iterable of 1-D integer sequences (tag ids), paired with
            ``char_seqs`` by position.
        sequence_length: window length.

    Raises:
        ValueError: if no sequence is longer than ``sequence_length``.
    """

    def __init__(self, char_seqs, tag_seqs, sequence_length=30):
        input_windows, target_windows = [], []
        for cseq, tseq in zip(char_seqs, tag_seqs):
            # range() is empty unless the sequence is strictly longer than the
            # window, which also guarantees the shifted target slice is full.
            for i in range(len(cseq) - sequence_length):
                input_windows.append(cseq[i:i+sequence_length])
                target_windows.append(tseq[i+1:i+sequence_length+1])
        if not input_windows:
            raise ValueError("No sequences long enough—decrease sequence_length.")
        # Stack into one ndarray first: building a tensor straight from a list
        # of many small numpy arrays is very slow in PyTorch.
        self.inputs  = torch.as_tensor(np.stack(input_windows), dtype=torch.long)
        self.targets = torch.as_tensor(np.stack(target_windows), dtype=torch.long)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]
        
def tagged_sents_to_char_idx_seqs_treebank(tagged, tag2idx, char2idx=None):
    """Encode tagged sentences as flat character-id and per-character tag-id arrays.

    Each sentence's words are concatenated character by character; every
    character carries the id (from ``tag2idx``) of the tag of the word it came
    from. When ``char2idx`` is not supplied it is built from the sorted set of
    characters in ``tagged``. Sentences with fewer than two characters are
    dropped.

    Returns:
        (X, Y, char2idx): lists of int64 numpy arrays and the mapping used.
    """
    alphabet = {ch for sentence in tagged for (token, _) in sentence for ch in token}
    if char2idx is None:
        char2idx = dict(zip(sorted(alphabet), range(len(alphabet))))

    X, Y = [], []
    for sentence in tagged:
        flat_chars, flat_tags = [], []
        for token, label in sentence:
            encoded = [char2idx[ch] for ch in token]
            flat_chars += encoded
            # Repeat the word's tag once per character.
            flat_tags += [tag2idx[label]] * len(encoded)
        if len(flat_chars) < 2:
            continue
        X.append(np.array(flat_chars, dtype=np.int64))
        Y.append(np.array(flat_tags, dtype=np.int64))
    return X, Y, char2idx

def get_charlevel_dataloaders_treebank(batch_size=64, sequence_length=15, val_split=0.1, test_split=0.1, seed=42):
    """Build character-level train/val/test DataLoaders from the Treebank sentences.

    The module-level ``tagged_sents`` are encoded with the module-level
    ``tag2idx``, then split at sentence level after shuffling with Python's
    global ``random`` (seeded with ``seed``): the first ``test_split`` fraction
    goes to test, the next ``val_split`` fraction to validation, the remainder
    to training.

    Returns:
        (train_loader, val_loader, test_loader, char2idx). Only the training
        loader shuffles its batches.
    """
    random.seed(seed)
    char_seqs, tag_seqs, char2idx = tagged_sents_to_char_idx_seqs_treebank(tagged_sents, tag2idx)

    order = list(range(len(char_seqs)))
    random.shuffle(order)
    total = len(order)
    test_end = int(total * test_split)
    val_end = test_end + int(total * val_split)

    def build_dataset(sentence_ids):
        # Gather the selected sentences and window them.
        return CharLevelTagSequenceDataset(
            [char_seqs[j] for j in sentence_ids],
            [tag_seqs[j] for j in sentence_ids],
            sequence_length,
        )

    train_ds = build_dataset(order[val_end:])
    val_ds = build_dataset(order[test_end:val_end])
    test_ds = build_dataset(order[:test_end])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
    return (train_loader, val_loader, test_loader, char2idx)

# Additional datasets: conll + brown

In [None]:
import nltk
from nltk.corpus import brown, conll2000

# Download the two additional NLTK corpora used below.
nltk.download('brown')
nltk.download('conll2000')

# --- 1. Load tagged sentences for each corpus ---
# Both corpora are requested with tagset='universal' rather than their native tags.
tagged_sents_brown = brown.tagged_sents(tagset='universal')
tagged_sents_conll = conll2000.tagged_sents(tagset='universal')

# --- 2. Build tag2idx for each corpus ---
def build_tag2idx(tagged_sents):
    """Return a dict mapping each distinct tag to its rank in sorted tag order.

    ``tagged_sents`` is an iterable of sentences, each an iterable of
    ``(word, tag)`` pairs.
    """
    unique_tags = set()
    for sentence in tagged_sents:
        for (_, tag) in sentence:
            unique_tags.add(tag)
    return dict(zip(sorted(unique_tags), range(len(unique_tags))))

# NOTE(review): this resource presumably backs the tagset='universal' views
# requested above and must be available before those sentences are iterated
# (which first happens in build_tag2idx below) — confirm against NLTK docs.
nltk.download('universal_tagset')

# One tag-to-id mapping per corpus; ids are not shared between corpora.
tag2idx_brown = build_tag2idx(tagged_sents_brown)
tag2idx_conll = build_tag2idx(tagged_sents_conll)

# --- 3. Word-level dataloader for any corpus ---
def tagged_sents_to_word_idx_seqs_additional(tagged, tag2idx, word2idx=None):
    """Encode tagged sentences as parallel word-id and tag-id arrays.

    Tags are looked up in ``tag2idx``. When ``word2idx`` is not supplied, a
    mapping is built from the sorted set of words in ``tagged``. Sentences
    with fewer than two tokens are dropped.

    Returns:
        (X, Y, word2idx): lists of int64 numpy arrays (word ids, tag ids)
        and the word-to-id mapping that was used.

    Raises:
        KeyError: if a word or tag is missing from its mapping.
    """
    vocabulary = {token for sentence in tagged for (token, _) in sentence}
    if word2idx is None:
        word2idx = dict(zip(sorted(vocabulary), range(len(vocabulary))))

    X, Y = [], []
    for sentence in tagged:
        encoded_words = [word2idx[token] for token, _ in sentence]
        encoded_tags = [tag2idx[label] for _, label in sentence]
        if len(encoded_words) < 2:
            continue
        X.append(np.array(encoded_words, dtype=np.int64))
        Y.append(np.array(encoded_tags, dtype=np.int64))
    return X, Y, word2idx

def get_wordlevel_dataloaders_from_corpus(tagged_sents, tag2idx, batch_size=64, sequence_length=15, val_split=0.1, test_split=0.1, seed=42):
    """Build word-level train/val/test DataLoaders for an arbitrary tagged corpus.

    ``tagged_sents`` is encoded with ``tag2idx``, then split at sentence level
    after shuffling with Python's global ``random`` (seeded with ``seed``):
    the first ``test_split`` fraction goes to test, the next ``val_split``
    fraction to validation, the remainder to training.

    Returns:
        (train_loader, val_loader, test_loader, word2idx). Only the training
        loader shuffles its batches.
    """
    random.seed(seed)
    word_seqs, tag_seqs, word2idx = tagged_sents_to_word_idx_seqs_additional(tagged_sents, tag2idx)

    order = list(range(len(word_seqs)))
    random.shuffle(order)
    total = len(order)
    test_end = int(total * test_split)
    val_end = test_end + int(total * val_split)

    def build_dataset(sentence_ids):
        # Gather the selected sentences and window them.
        return WordLevelTagSequenceDataset(
            [word_seqs[j] for j in sentence_ids],
            [tag_seqs[j] for j in sentence_ids],
            sequence_length,
        )

    train_ds = build_dataset(order[val_end:])
    val_ds = build_dataset(order[test_end:val_end])
    test_ds = build_dataset(order[:test_end])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
    return (train_loader, val_loader, test_loader, word2idx)

# --- 4. Char-level dataloader for any corpus ---
def tagged_sents_to_char_idx_seqs(tagged, tag2idx, char2idx=None):
    """Encode tagged sentences as flat character-id and per-character tag-id arrays.

    Each sentence's words are concatenated character by character; every
    character carries the id (from ``tag2idx``) of the tag of the word it came
    from. Sentences with fewer than two characters are dropped.

    Args:
        tagged: sequence of sentences, each a sequence of ``(word, tag)`` pairs.
        tag2idx: mapping from tag to integer id.
        char2idx: optional existing character-to-id mapping. When omitted, one
            is built from the sorted set of characters in ``tagged``.

    Returns:
        (X, Y, char2idx): lists of int64 numpy arrays and the mapping used.

    Raises:
        KeyError: if a character or tag is missing from its mapping.
    """
    if char2idx is None:
        # Only scan the corpus for its alphabet when no mapping was supplied.
        alphabet = sorted({c for sent in tagged for (word, _) in sent for c in word})
        char2idx = {c: i for i, c in enumerate(alphabet)}
    X, Y = [], []
    for sent in tagged:
        char_idxs, tag_idxs = [], []
        for word, tag in sent:
            chars = [char2idx[c] for c in word]
            char_idxs.extend(chars)
            # Repeat the word's tag once per character.
            tag_idxs.extend([tag2idx[tag]] * len(chars))
        if len(char_idxs) >= 2:
            X.append(np.array(char_idxs, dtype=np.int64))
            Y.append(np.array(tag_idxs, dtype=np.int64))
    return X, Y, char2idx

def get_charlevel_dataloaders_from_corpus(tagged_sents, tag2idx, batch_size=64, sequence_length=15, val_split=0.1, test_split=0.1, seed=42):
    """Build character-level train/val/test DataLoaders for an arbitrary tagged corpus.

    ``tagged_sents`` is encoded with ``tag2idx``, then split at sentence level
    after shuffling with Python's global ``random`` (seeded with ``seed``):
    the first ``test_split`` fraction goes to test, the next ``val_split``
    fraction to validation, the remainder to training.

    Returns:
        (train_loader, val_loader, test_loader, char2idx). Only the training
        loader shuffles its batches.
    """
    random.seed(seed)
    char_seqs, tag_seqs, char2idx = tagged_sents_to_char_idx_seqs(tagged_sents, tag2idx)

    order = list(range(len(char_seqs)))
    random.shuffle(order)
    total = len(order)
    test_end = int(total * test_split)
    val_end = test_end + int(total * val_split)

    def build_dataset(sentence_ids):
        # Gather the selected sentences and window them.
        return CharLevelTagSequenceDataset(
            [char_seqs[j] for j in sentence_ids],
            [tag_seqs[j] for j in sentence_ids],
            sequence_length,
        )

    train_ds = build_dataset(order[val_end:])
    val_ds = build_dataset(order[test_end:val_end])
    test_ds = build_dataset(order[:test_end])

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
    return (train_loader, val_loader, test_loader, char2idx)



# Shallow RNN

In [None]:
class ShallowRNN(nn.Module):
    """A single ``nn.RNN`` layer followed by a linear readout at every timestep.

    The recurrent layer is built with ``batch_first=True``, so ``forward``
    takes ``(batch, seq, input_size)`` and returns ``(batch, seq, output_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # The final hidden state returned by nn.RNN is not needed here.
        hidden_states, _final_state = self.rnn(x)
        return self.fc(hidden_states)

# DT(S)-RNN

In [None]:
class DeepTransitionRNNCell(nn.Module):
    """Recurrent cell whose hidden-to-hidden transition is a small MLP.

    The transition stack is ``hidden -> transition``, then ``depth - 2``
    ``transition -> transition`` layers, then ``transition -> hidden``; it
    therefore always contains at least two linear layers. ``nonlinearity``
    selects ``nn.Sigmoid`` when it equals ``'sigmoid'`` and ``nn.ReLU`` for
    any other value.
    """

    def __init__(self, input_size, hidden_size, transition_size, depth, nonlinearity):
        super().__init__()
        self.input_layer = nn.Linear(input_size, hidden_size)

        # Layer widths along the transition path; consecutive pairs become Linear layers.
        inner_count = 1 + max(depth - 2, 0)
        widths = [hidden_size] + [transition_size] * inner_count + [hidden_size]
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths, widths[1:])]
        )

        self.activation = nn.Sigmoid() if nonlinearity == 'sigmoid' else nn.ReLU()

    def forward(self, x, h_prev):
        # The previous state is added directly to the projected input (shortcut).
        state = self.activation(self.input_layer(x) + h_prev)
        for transform in self.hidden_layers:
            state = self.activation(transform(state))
        return state

class DTRNN(nn.Module):
    """Unrolls a ``DeepTransitionRNNCell`` over time with a linear readout per step.

    ``forward`` expects a 3-D input indexed as ``(batch, time, features)`` and
    returns the per-step readouts stacked along dimension 1.
    """

    def __init__(self, input_size, hidden_size, output_size, transition_size, depth, nonlinearity='sigmoid'):
        super().__init__()
        self.cell = DeepTransitionRNNCell(input_size, hidden_size, transition_size, depth, nonlinearity)
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        batch_size, _, _ = x.size()
        # The state width is read back from the cell's input projection.
        state_width = self.cell.input_layer.out_features
        state = torch.zeros(batch_size, state_width, device=x.device)

        per_step = []
        for step_input in x.unbind(dim=1):
            state = self.cell(step_input, state)
            per_step.append(self.output(state))
        return torch.stack(per_step, dim=1)

# DOT(S)-RNN

In [None]:
class DOTSRNN(nn.Module):
    """Deep-transition recurrent cell followed by a deep per-step output network.

    The output network is ``hidden -> intermediate``, then ``output_depth - 2``
    ``intermediate -> intermediate`` layers, then ``intermediate -> hidden``
    and finally ``hidden -> output_size``.

    Args:
        input_size: feature size of each timestep of the input.
        hidden_size: width of the recurrent state.
        output_size: width of the per-step output.
        transition_size: width of the cell's inner transition layers.
        depth: depth argument forwarded to ``DeepTransitionRNNCell``.
        intermediate_output_size: width of the inner output-network layers.
        output_depth: controls the number of inner output-network layers.
        nonlinearity: cell activation selector, forwarded to the cell.
        intermediate_output_nonlinearity: ``'sigmoid'`` selects ``nn.Sigmoid``
            for the output network; any other value selects ``nn.ReLU``.
    """

    def __init__(
        self, 
        input_size, 
        hidden_size, 
        output_size, 
        transition_size, 
        depth, 
        intermediate_output_size, 
        output_depth=2, 
        nonlinearity='sigmoid', 
        intermediate_output_nonlinearity='sigmoid'):
        super().__init__()
        self.cell = DeepTransitionRNNCell(input_size, hidden_size, transition_size, depth, nonlinearity)
        self.output_layers = nn.ModuleList()
        self.output_layers.append(nn.Linear(hidden_size, intermediate_output_size))
        for i in range(output_depth - 2):
            self.output_layers.append(nn.Linear(intermediate_output_size, intermediate_output_size))
        self.output_layers.append(nn.Linear(intermediate_output_size, hidden_size))
        self.output_layers.append(nn.Linear(hidden_size, output_size))
        
        if intermediate_output_nonlinearity == 'sigmoid':
            self.activation = nn.Sigmoid()
        else:
            self.activation = nn.ReLU()

    def forward(self, x):
        """Return per-step outputs stacked along dim 1 for a (batch, time, features) input."""
        batch_size, seq_len, _ = x.size()
        h = torch.zeros(batch_size, self.cell.input_layer.out_features, device=x.device)
        last = len(self.output_layers) - 1
        outputs = []
        for t in range(seq_len):
            h = self.cell(x[:, t], h)
            out = h
            for i, layer in enumerate(self.output_layers):
                out = layer(out)
                # The activation belongs to the intermediate layers only: the
                # final projection must stay unsquashed so that a loss such as
                # CrossEntropyLoss receives raw logits.
                if i != last:
                    out = self.activation(out)
            outputs.append(out)
        return torch.stack(outputs, dim=1)

# sRNN

In [None]:
class StackedRNN(nn.Module):
    """An ``nn.RNN`` with ``num_layers`` stacked layers and a linear readout per step.

    The recurrent module is built with ``batch_first=True``, so ``forward``
    takes ``(batch, seq, input_size)`` and returns ``(batch, seq, output_size)``.
    """

    def __init__(self, input_size, hidden_size, output_size, num_layers):
        super().__init__()
        self.rnn = nn.RNN(
            input_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.output = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Only the per-step outputs are read out; the final states are dropped.
        top_layer_states, _final_states = self.rnn(x)
        return self.output(top_layer_states)

In [None]:
import re

def sanitize_filename(s):
    """Return ``str(s)`` with every character outside ``[A-Za-z0-9_.-]`` replaced by ``_``."""
    text = str(s)
    unsafe_chars = r"[^a-zA-Z0-9_\-\.]"
    return re.sub(unsafe_chars, "_", text)

# Adapted models for treebank dataset

In [None]:
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch import optim

def lr_schedule(step, initial_lr, beta):
    """Exponentially decayed learning rate: ``initial_lr`` shrinks tenfold every ``beta`` steps."""
    progress = step / beta
    decay = 0.1 ** progress
    return initial_lr * decay

def _batch_loss_and_correct(criterion, logits, y):
    """Return ``(loss, n_correct, n_targets)`` for one batch.

    With ``nn.CrossEntropyLoss`` the logits and targets are flattened and
    accuracy is argmax agreement; otherwise the criterion is applied as-is and
    accuracy is element-wise agreement of ``sigmoid(logits) > 0.5`` with ``y``
    (i.e. a BCEWithLogitsLoss-style multi-label setup is assumed).
    """
    if isinstance(criterion, nn.CrossEntropyLoss):
        flat_logits = logits.view(-1, logits.size(-1))
        y_flat = y.view(-1)
        loss = criterion(flat_logits, y_flat)
        preds = flat_logits.argmax(dim=-1)
        return loss, (preds == y_flat).sum().item(), y_flat.numel()
    loss = criterion(logits, y)
    preds = (torch.sigmoid(logits) > 0.5).float()
    return loss, (preds == y).float().sum().item(), y.numel()


def train_and_eval(
    model, train_loader, val_loader, vocab_size=None, initial_lr=0.1, beta = 2330, epochs=10, device=None, 
    criterion=None, optimizer_class=None, model_name="model", dataset="dataset"):
    """Train ``model``, evaluate it after every epoch, and plot the curves.

    Args:
        model: module mapping a batch ``x`` to logits.
        train_loader, val_loader: iterables of ``(x, y)`` batches.
        vocab_size: accepted for call compatibility; not used here.
        initial_lr: learning rate handed to the optimizer.
        beta: decay constant for ``lr_schedule`` (polyphonic datasets only).
        epochs: number of passes over ``train_loader``.
        device: target device; defaults to CUDA when available, else CPU.
        criterion: loss function (required).
        optimizer_class: optimizer constructor; defaults to ``optim.Adam``.
        model_name, dataset: used to name the saved PDF; ``dataset`` also
            switches on gradient clipping and LR decay for
            "Nottingham", "MuseDataset" and "JSBDataset".

    Returns:
        (model, last train loss, last train accuracy, last val loss,
        last val accuracy).

    Raises:
        ValueError: if ``criterion`` is None.
    """
    if criterion is None:
        raise ValueError("You must provide a loss function as 'criterion'.")

    device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    if optimizer_class is None:
        optimizer_class = optim.Adam
    optimizer = optimizer_class(model.parameters(), lr=initial_lr)

    isPolyphonicDataset = dataset in ["Nottingham", "MuseDataset", "JSBDataset"]

    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    for ep in range(1, epochs+1):
        # === Training ===
        model.train()
        total_train, correct_train, total_train_tokens = 0, 0, 0
        train_loop = tqdm(train_loader, desc=f"Epoch {ep}/{epochs} [Train]", leave=False)
        for x, y in train_loop:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss, n_correct, n_targets = _batch_loss_and_correct(criterion, logits, y)
            correct_train += n_correct
            total_train_tokens += n_targets

            optimizer.zero_grad()
            loss.backward()

            if isPolyphonicDataset:
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            if isPolyphonicDataset:
                # NOTE(review): the schedule is fed the epoch number, not a
                # running batch counter — confirm this matches the intended beta.
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_schedule(ep, initial_lr, beta)

            total_train += loss.item()
            train_loop.set_postfix(loss=loss.item())
        # Loss is averaged per batch, accuracy per target element.
        train_losses.append(total_train / len(train_loader))
        train_accs.append(correct_train / total_train_tokens)

        # === Validation ===
        model.eval()
        total_val, correct_val, total_val_tokens = 0, 0, 0
        with torch.no_grad():
            val_loop = tqdm(val_loader, desc=f"Epoch {ep}/{epochs} [Val]", leave=False)
            for x, y in val_loop:
                x, y = x.to(device), y.to(device)
                logits = model(x)
                loss, n_correct, n_targets = _batch_loss_and_correct(criterion, logits, y)
                correct_val += n_correct
                total_val_tokens += n_targets
                total_val += loss.item()
                val_loop.set_postfix(val_loss=loss.item())
        val_losses.append(total_val / len(val_loader))
        val_accs.append(correct_val / total_val_tokens)

        print(f"Epoch {ep}/{epochs}  "
              f"Train: {train_losses[-1]:.4f} (Acc {train_accs[-1]*100:.2f}%)  "
              f"Val: {val_losses[-1]:.4f} (Acc {val_accs[-1]*100:.2f}%)")

    # === Plotting ===
    fig = plt.figure(figsize=(12,5))
    plt.subplot(1,2,1)
    plt.plot(train_losses, label='Train')
    plt.plot(val_losses,   label='Val')
    plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.legend(); plt.title("Loss")
    plt.subplot(1,2,2)
    plt.plot(train_accs, label='Train Acc')
    plt.plot(val_accs,   label='Val Acc')
    plt.xlabel("Epoch"); plt.ylabel("Accuracy"); plt.legend(); plt.title("Accuracy")
    plt.savefig(sanitize_filename(model_name + "_" + dataset + "_loss_acc.pdf"))
    plt.show()
    # Release the figure so repeated calls (e.g. a grid search) do not pile up open figures.
    plt.close(fig)
    return model, train_losses[-1], train_accs[-1], val_losses[-1], val_accs[-1]

In [None]:
class AdaptedShallowRNN(nn.Module):
    """Embedding -> single-layer ``nn.RNN`` -> dropout -> linear readout.

    Takes integer ids shaped ``(batch, seq)`` and returns scores shaped
    ``(batch, seq, vocab_size)``; the readout has ``vocab_size`` units.
    NOTE(review): for tagging, a separate number-of-tags output width may be
    what is wanted — confirm with the caller.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        embedded = self.embedding(x)
        hidden_states, _final_state = self.rnn(embedded)
        return self.fc(self.dropout(hidden_states))

class AdaptedDTRNN(nn.Module):
    """Embedding layer in front of a ``DTRNN`` whose output width is ``vocab_size``."""

    def __init__(self, vocab_size, embedding_dim, hidden_size, depth, transition_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = DTRNN(
            embedding_dim,
            hidden_size,
            vocab_size,
            depth=depth,
            transition_size=transition_size,
        )

    def forward(self, x):
        # Integer ids -> embeddings -> recurrent model.
        return self.rnn(self.embedding(x))

class AdaptedDOTSRNN(nn.Module):
    """Embedding layer in front of a ``DOTSRNN`` whose output width is ``vocab_size``.

    Args:
        vocab_size: number of embedding rows and width of the output layer.
        embedding_dim: embedding width (the recurrent model's input size).
        hidden_size: width of the recurrent state.
        depth: depth of the cell's transition network.
        output_depth: depth setting of the per-step output network.
        intermediate_output_size: width of the inner output-network layers.
        transition_size: width of the cell's inner transition layers;
            defaults to ``hidden_size`` when omitted.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_size, depth=2, output_depth=2, intermediate_output_size=100, transition_size=None):
        super().__init__()
        if transition_size is None:
            transition_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Pass everything by keyword: DOTSRNN's positional order is
        # (..., transition_size, depth, intermediate_output_size, output_depth),
        # so positional forwarding of depth/output_depth lands in the wrong slots.
        self.rnn = DOTSRNN(
            input_size=embedding_dim,
            hidden_size=hidden_size,
            output_size=vocab_size,
            transition_size=transition_size,
            depth=depth,
            intermediate_output_size=intermediate_output_size,
            output_depth=output_depth,
        )

    def forward(self, x):
        x = self.embedding(x)
        return self.rnn(x)

class AdaptedStackedRNN(nn.Module):
    """Embedding layer in front of a ``StackedRNN`` whose output width is ``vocab_size``."""

    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = StackedRNN(embedding_dim, hidden_size, vocab_size, num_layers)

    def forward(self, x):
        # Integer ids -> embeddings -> recurrent model.
        return self.rnn(self.embedding(x))

In [None]:
from itertools import product

def get_config_combinations(model_type, config_dict):
    """Yield one dict per point in the Cartesian product of the option lists.

    ``config_dict`` maps a hyper-parameter name to the list of values to try.
    ``model_type`` is accepted but not used.
    """
    names = list(config_dict)
    option_lists = [config_dict[name] for name in names]
    for choice in product(*option_lists):
        yield {name: value for name, value in zip(names, choice)}

# Candidate values for each hyper-parameter; the grid is their Cartesian product.
# Only transition_sizes holds more than one value, so it is the swept dimension.
hidden_sizes = [200]
transition_sizes = [100, 200, 300, 400, 500, 600]
depths = [3]
intermediate_output_sizes = [200]
output_depths = [2]
num_layers = [2]

# Per-architecture option lists consumed by grid_search; the keys are the
# model_type strings grid_search dispatches on. Commented-out entries are
# architectures currently excluded from the search.
model_configs = {
    # "RNN": {
    #     "hidden_size": hidden_sizes
    # },
    "DT(S)-RNN": {
        "hidden_size": hidden_sizes,
        "transition_size": transition_sizes,
        "depth": depths
    },
    "DOT(S)-RNN": {
        "hidden_size": hidden_sizes,
        "transition_size": transition_sizes,
        "intermediate_output_size": intermediate_output_sizes,
        "depth": depths,
        "output_depth": output_depths
    },
    # "sRNN": {
    #     "hidden_size": hidden_sizes,
    #     "num_layers": num_layers,
    # }
}

# LR-decay constant per dataset name; grid_search looks these up with a
# default of 1000 for names that are not listed.
beta_values = {
    'treebank': 1000,
    'conll': 1000,
    'brown': 1000
}

import matplotlib.pyplot as plt

def grid_search(train_dataloader, val_dataloader, vocab_size, embedding_dim, dataset, num_epochs=10):
    """Train every configuration in ``model_configs`` and plot final metrics.

    For each model type, every combination of its option lists is trained with
    ``train_and_eval`` (SGD, cross-entropy, initial LR 0.1). The last-epoch
    train/val loss and accuracy are then plotted against the swept value and
    saved as a PDF.

    Args:
        train_dataloader, val_dataloader: loaders passed to ``train_and_eval``.
        vocab_size: embedding size and output width handed to the models.
        embedding_dim: embedding width handed to the models.
        dataset: dataset label used in titles, file names and the beta lookup.
        num_epochs: epochs per configuration.

    Raises:
        ValueError: if ``model_configs`` contains an unknown model type.
    """
    import inspect

    import torch.nn as nn
    import torch.optim as optim

    for model_type, config_options in model_configs.items():
        train_losses = []
        val_losses = []
        train_accuracies = []
        val_accuracies = []
        # x coordinate of each run, collected per configuration so that it
        # always has the same length as the metric lists.
        x = []
        sweeps_transition = "transition_size" in config_options

        for config in get_config_combinations(model_type, config_options):
            print(f"\nTraining {model_type} with config: {config}")

            # All models use embedding layer, so pass vocab_size and embedding_dim
            if model_type == "RNN":
                model = AdaptedShallowRNN(vocab_size, embedding_dim, config["hidden_size"])
            elif model_type == "DT(S)-RNN":
                model = AdaptedDTRNN(
                    vocab_size, embedding_dim,
                    config["hidden_size"],
                    depth=config["depth"],
                    transition_size=config["transition_size"]
                )
            elif model_type == "DOT(S)-RNN":
                dots_kwargs = dict(
                    depth=config["depth"],
                    output_depth=config["output_depth"],
                    intermediate_output_size=config["intermediate_output_size"],
                )
                # Forward the swept transition size whenever the constructor
                # takes it; without this the sweep trains identical models.
                if "transition_size" in inspect.signature(AdaptedDOTSRNN).parameters:
                    dots_kwargs["transition_size"] = config["transition_size"]
                model = AdaptedDOTSRNN(
                    vocab_size, embedding_dim,
                    config["hidden_size"],
                    **dots_kwargs
                )
            elif model_type == "sRNN":
                model = AdaptedStackedRNN(
                    vocab_size, embedding_dim,
                    config["hidden_size"],
                    num_layers=config["num_layers"]
                )
            else:
                raise ValueError(f"Unknown model_type: {model_type}")

            criterion = nn.CrossEntropyLoss()
            optimizer_class = optim.SGD

            model, last_train_loss, last_train_acc, last_val_loss, last_val_acc = train_and_eval(
                model=model,
                train_loader=train_dataloader,
                val_loader=val_dataloader,
                vocab_size=vocab_size,
                initial_lr=0.1,
                beta=beta_values.get(dataset, 1000),
                epochs=num_epochs,
                model_name=f"{model_type}_{dataset}_{config}",
                dataset=dataset,
                criterion=criterion,
                optimizer_class=optimizer_class
            )

            train_losses.append(last_train_loss)
            val_losses.append(last_val_loss)
            train_accuracies.append(last_train_acc)
            val_accuracies.append(last_val_acc)
            # Fall back to the run index for model types without a transition size.
            x.append(config["transition_size"] if sweeps_transition else len(x))

        x_label = "TransitionSize" if sweeps_transition else "Config index"

        plt.figure(figsize=(14, 6))
        plt.subplot(1, 2, 1)
        plt.plot(x, train_losses, marker='o', label='Train Loss')
        plt.plot(x, val_losses, marker='o', label='Val Loss')
        plt.xlabel(x_label)
        plt.ylabel("Loss")
        plt.title(f"{model_type} Losses on {dataset}")
        plt.legend()

        plt.subplot(1, 2, 2)
        plt.plot(x, train_accuracies, marker='o', label='Train Accuracy')
        plt.plot(x, val_accuracies, marker='o', label='Val Accuracy')
        plt.xlabel(x_label)
        plt.ylabel("Accuracy")
        plt.title(f"{model_type} Accuracies on {dataset}")
        plt.legend()

        plt.tight_layout()
        plt.savefig(sanitize_filename(f"{model_type}_{dataset}_depth_comparison.pdf"))
        plt.show()


In [None]:
# Global run settings. embedding_dim is passed to the grid_search calls below.
# NOTE(review): hidden_size and epochs are not referenced by the grid-search
# cells visible in this file (they use hidden_sizes and num_epochs=3) — confirm
# whether they are still needed.
embedding_dim = 128
hidden_size = 128
epochs        = 20

In [None]:
# Report which device training will run on.
if not torch.cuda.is_available():
    print("Running on CPU")
else:
    print("Running on GPU:", torch.cuda.get_device_name(0))

In [None]:
# Word-level Treebank loaders with the function's default window length and split fractions.
train_loader_word, val_loader_word, test_loader_word, word2idx = get_wordlevel_dataloaders_treebank(
    batch_size=32
)
# Number of distinct words, i.e. the size of the embedding table.
word_vocab_size = len(word2idx)  

# train_loader_char, val_loader_char, test_loader, char2idx = get_charlevel_dataloaders_treebank(
#     batch_size=32
# )
# vocab_size = len(char2idx)  



In [None]:

# Word-level CoNLL-2000 loaders, encoded with the CoNLL-specific tag mapping.
train_loader_word_conll, val_loader_word_conll, test_loader_word_conll, word2idx_conll = get_wordlevel_dataloaders_from_corpus(
    tagged_sents_conll, tag2idx_conll, batch_size=32
)


In [None]:

# Word-level Brown loaders, encoded with the Brown-specific tag mapping.
train_loader_word_brown, val_loader_word_brown, test_loader_word_brown, word2idx_brown = get_wordlevel_dataloaders_from_corpus(
    tagged_sents_brown, tag2idx_brown, batch_size=32
)


In [None]:
# --- Treebank ---
# vocab_size is the word vocabulary, so the models' output layer is that wide too.
# 'treebank_word_level' is not a key of beta_values, so grid_search falls back
# to its default beta of 1000.
print("Grid search: Treebank word-level")
grid_search(
    train_loader_word, val_loader_word,
    vocab_size=len(word2idx), embedding_dim=embedding_dim,
    dataset='treebank_word_level', num_epochs=3
)


In [None]:

# print("Grid search: Treebank char-level")
# grid_search(
#     train_loader_char, val_loader_char,
#     vocab_size=len(char2idx), embedding_dim=embedding_dim,
#     dataset='treebank', num_epochs=7
# )



In [None]:
# --- CoNLL2000 ---
# vocab_size is the word vocabulary, so the models' output layer is that wide too.
# 'conll_word_level' is not a key of beta_values, so grid_search falls back
# to its default beta of 1000.
print("Grid search: CoNLL2000 word-level")
grid_search(
    train_loader_word_conll, val_loader_word_conll,
    vocab_size=len(word2idx_conll), embedding_dim=embedding_dim,
    dataset='conll_word_level', num_epochs=3
)



In [None]:

# --- Brown ---
# vocab_size is the word vocabulary, so the models' output layer is that wide too.
# 'brown_word_level' is not a key of beta_values, so grid_search falls back
# to its default beta of 1000.
print("Grid search: Brown word-level")
grid_search(
    train_loader_word_brown, val_loader_word_brown,
    vocab_size=len(word2idx_brown), embedding_dim=embedding_dim,
    dataset='brown_word_level', num_epochs=3
)
