In [1]:
from datasets import load_dataset, Dataset
from torch.nn import LSTM
import torch
from torch.utils.data import Dataset, DataLoader, random_split, TensorDataset
from itertools import product
import torch.optim as optim

MAX_LENGTH = 50

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


  from .autonotebook import tqdm as notebook_tqdm


cuda


In [2]:
ds = load_dataset("uonlp/CulturaX", "en", split='train', streaming=True)


In [3]:
ds_head = ds.take(500)

In [4]:
def create_pairs(example):
    text = example["text"]
    pairs = []
    for i in range(1, len(text)):
        inputs = text[:i]
        outputs = text[i]
        if len(inputs) > MAX_LENGTH:
            inputs = inputs[-MAX_LENGTH:]
        pairs.append({"input": inputs, "output": outputs})
    return pairs
data_pairs = []
for d in ds_head:
    data_pairs.extend(create_pairs(d))

In [5]:
chars = set()
num_chars = 0
for d in ds_head:
    num_chars += len(d['text'])
    chars.update(d['text'])
print(len(chars))
print(num_chars)

293
1864434


In [7]:
char_to_idx = {char: idx + 1 for idx, char in enumerate(sorted(chars))}
idx_to_char = {idx: char for char, idx in char_to_idx.items()}

def encode_sequence(seq):
    return [char_to_idx[char] for char in seq]

In [9]:
sequences = []
outputs = []
for d in data_pairs:
    sequences.append(torch.tensor(encode_sequence(d["input"]), dtype=torch.long))
    outputs.append(torch.tensor(char_to_idx[d["output"]], dtype=torch.long))

X = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0)
y = torch.stack(outputs)


In [10]:
print(X.shape)
print(X[5])
print(y.shape)

torch.Size([1863934, 50])
tensor([39, 50, 55,  3, 36, 81,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])
torch.Size([1863934])


In [11]:
torch.save({
    'X': X, 'y': y,
    'char_to_idx': char_to_idx, 'idx_to_char': idx_to_char
}, 'dataset_splits.pth')

In [2]:
data = torch.load('dataset_splits.pth')
X, y = data['X'], data['y']
char_to_idx = data['char_to_idx']
idx_to_char = data['idx_to_char']
print(X.shape)
print(y.shape)
print(char_to_idx)
print(idx_to_char)

torch.Size([1863934, 50])
torch.Size([1863934])
{'\t': 1, '\n': 2, ' ': 3, '!': 4, '"': 5, '#': 6, '$': 7, '%': 8, '&': 9, "'": 10, '(': 11, ')': 12, '*': 13, '+': 14, ',': 15, '-': 16, '.': 17, '/': 18, '0': 19, '1': 20, '2': 21, '3': 22, '4': 23, '5': 24, '6': 25, '7': 26, '8': 27, '9': 28, ':': 29, ';': 30, '<': 31, '=': 32, '>': 33, '?': 34, '@': 35, 'A': 36, 'B': 37, 'C': 38, 'D': 39, 'E': 40, 'F': 41, 'G': 42, 'H': 43, 'I': 44, 'J': 45, 'K': 46, 'L': 47, 'M': 48, 'N': 49, 'O': 50, 'P': 51, 'Q': 52, 'R': 53, 'S': 54, 'T': 55, 'U': 56, 'V': 57, 'W': 58, 'X': 59, 'Y': 60, 'Z': 61, '[': 62, '\\': 63, ']': 64, '^': 65, '_': 66, '`': 67, 'a': 68, 'b': 69, 'c': 70, 'd': 71, 'e': 72, 'f': 73, 'g': 74, 'h': 75, 'i': 76, 'j': 77, 'k': 78, 'l': 79, 'm': 80, 'n': 81, 'o': 82, 'p': 83, 'q': 84, 'r': 85, 's': 86, 't': 87, 'u': 88, 'v': 89, 'w': 90, 'x': 91, 'y': 92, 'z': 93, '{': 94, '|': 95, '}': 96, '~': 97, '\x9d': 98, '\xa0': 99, '£': 100, '§': 101, '©': 102, '«': 103, '®': 104, '°': 105, 

In [4]:
dataset = TensorDataset(X, y)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size])


In [5]:
class CharDataset(Dataset):
    def __init__(self, sequences, outputs, device):
        self.sequences = sequences.to(device)
        self.outputs = outputs.to(device)

    def __len__(self):
        return self.sequences.shape[0]

    def __getitem__(self, idx):
        return self.sequences[idx], self.outputs[idx]


In [6]:
train_dataset = CharDataset(train_set[:][0], train_set[:][1], device)
val_dataset = CharDataset(val_set[:][0], val_set[:][1], device)
test_dataset = CharDataset(test_set[:][0], test_set[:][1], device)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [7]:
import torch.nn as nn

class CharLSTM(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers):
        super(CharLSTM, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)  # Output size is vocab_size

    def forward(self, x):
        x = self.embedding(x)
        lstm_out, _ = self.lstm(x)
        out = self.fc(lstm_out[:, -1, :])  # Take the last LSTM output
        return out

In [9]:
# Define search space
lr = 1e-4
embed_dim = 12
hidden_dim = 256
num_layers = 4



# Initialize model
vocab_size = len(char_to_idx) + 1
model = CharLSTM(vocab_size, embed_dim, hidden_dim, num_layers)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop
epochs = 20
for epoch in range(epochs):
    model.train()
    train_correct, train_total, running_loss = 0, 0, 0

    for inputs, targets in train_dataloader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs, dim=1)
        train_correct += (predicted == targets).sum().item()
        train_total += targets.numel()

    train_acc = train_correct / train_total
    avg_train_loss = running_loss / len(train_dataloader)

    # Validation phase
    model.eval()
    val_correct, val_total, val_loss = 0, 0, 0

    with torch.no_grad():
        for inputs, targets in val_dataloader:
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()

            _, predicted = torch.max(outputs, dim=1)
            val_correct += (predicted == targets).sum().item()
            val_total += targets.numel()

    val_acc = val_correct / val_total
    avg_val_loss = val_loss / len(val_dataloader)

    print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}, "
          f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Save model checkpoint
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': avg_train_loss,
        'val_loss': avg_val_loss
    }
    torch.save(checkpoint, f'checkpoint_epoch_{epoch+1}.pth')




Epoch 1: Train Loss: 2.5096, Train Acc: 0.3184, Val Loss: 2.1357, Val Acc: 0.3986
Epoch 2: Train Loss: 1.9921, Train Acc: 0.4370, Val Loss: 1.8841, Val Acc: 0.4662
Epoch 3: Train Loss: 1.8026, Train Acc: 0.4891, Val Loss: 1.7489, Val Acc: 0.5037
Epoch 4: Train Loss: 1.6870, Train Acc: 0.5196, Val Loss: 1.6685, Val Acc: 0.5251
Epoch 5: Train Loss: 1.6094, Train Acc: 0.5391, Val Loss: 1.6179, Val Acc: 0.5385
Epoch 6: Train Loss: 1.5525, Train Acc: 0.5531, Val Loss: 1.5846, Val Acc: 0.5467
Epoch 7: Train Loss: 1.5071, Train Acc: 0.5646, Val Loss: 1.5628, Val Acc: 0.5524
Epoch 8: Train Loss: 1.4687, Train Acc: 0.5745, Val Loss: 1.5479, Val Acc: 0.5569
Epoch 9: Train Loss: 1.4346, Train Acc: 0.5832, Val Loss: 1.5386, Val Acc: 0.5597
Epoch 10: Train Loss: 1.4035, Train Acc: 0.5912, Val Loss: 1.5339, Val Acc: 0.5614
Epoch 11: Train Loss: 1.3745, Train Acc: 0.5989, Val Loss: 1.5321, Val Acc: 0.5629
Epoch 12: Train Loss: 1.3470, Train Acc: 0.6063, Val Loss: 1.5326, Val Acc: 0.5634
Epoch 13: Tra

KeyboardInterrupt: 

In [None]:

# Training with: embed_dim=8, hidden_dim=128, num_layers=4, lr=0.0005
# Epoch 1: Train Loss: 2.6495, Train Acc: 0.2909, Val Loss: 2.2317, Val Acc: 0.3762
# Epoch 2: Train Loss: 2.1025, Train Acc: 0.4079, Val Loss: 2.0086, Val Acc: 0.4310
# Epoch 3: Train Loss: 1.9364, Train Acc: 0.4512, Val Loss: 1.8976, Val Acc: 0.4602
# Epoch 4: Train Loss: 1.8342, Train Acc: 0.4787, Val Loss: 1.8251, Val Acc: 0.4817
# Epoch 5: Train Loss: 1.7652, Train Acc: 0.4972, Val Loss: 1.7746, Val Acc: 0.4959

# Training with: embed_dim=8, hidden_dim=128, num_layers=4, lr=0.0001
# Epoch 1: Train Loss: 3.0272, Train Acc: 0.2024, Val Loss: 2.6667, Val Acc: 0.2684
# Epoch 2: Train Loss: 2.4985, Train Acc: 0.3109, Val Loss: 2.3811, Val Acc: 0.3391
# Epoch 3: Train Loss: 2.2985, Train Acc: 0.3593, Val Loss: 2.2286, Val Acc: 0.3745
# Epoch 4: Train Loss: 2.1686, Train Acc: 0.3902, Val Loss: 2.1256, Val Acc: 0.4000
# Epoch 5: Train Loss: 2.0770, Train Acc: 0.4131, Val Loss: 2.0493, Val Acc: 0.4199

# Training with: embed_dim=8, hidden_dim=128, num_layers=6, lr=0.0005
# Epoch 1: Train Loss: 3.2215, Train Acc: 0.1623, Val Loss: 3.2171, Val Acc: 0.1615
# Epoch 2: Train Loss: 3.2188, Train Acc: 0.1623, Val Loss: 3.2173, Val Acc: 0.1615
# Epoch 3: Train Loss: 3.2188, Train Acc: 0.1623, Val Loss: 3.2174, Val Acc: 0.1615
# Epoch 4: Train Loss: 3.2187, Train Acc: 0.1623, Val Loss: 3.2173, Val Acc: 0.1615
# Epoch 5: Train Loss: 3.2187, Train Acc: 0.1623, Val Loss: 3.2173, Val Acc: 0.1615

# Training with: embed_dim=8, hidden_dim=128, num_layers=6, lr=0.0001
# Epoch 1: Train Loss: 3.1103, Train Acc: 0.1844, Val Loss: 2.7249, Val Acc: 0.2656
# Epoch 2: Train Loss: 2.5756, Train Acc: 0.2931, Val Loss: 2.4823, Val Acc: 0.3131
# Epoch 3: Train Loss: 2.4225, Train Acc: 0.3272, Val Loss: 2.3761, Val Acc: 0.3368
# Epoch 4: Train Loss: 2.3367, Train Acc: 0.3435, Val Loss: 2.3057, Val Acc: 0.3497
# Epoch 5: Train Loss: 2.2668, Train Acc: 0.3586, Val Loss: 2.2428, Val Acc: 0.3663

# Training with: embed_dim=8, hidden_dim=256, num_layers=4, lr=0.0005
# Epoch 1: Train Loss: 2.6740, Train Acc: 0.2834, Val Loss: 2.1641, Val Acc: 0.3929
# Epoch 2: Train Loss: 2.0015, Train Acc: 0.4329, Val Loss: 1.8835, Val Acc: 0.4626
# Epoch 3: Train Loss: 1.7981, Train Acc: 0.4868, Val Loss: 1.7642, Val Acc: 0.4960
# Epoch 4: Train Loss: 1.6872, Train Acc: 0.5156, Val Loss: 1.7059, Val Acc: 0.5112
# Epoch 5: Train Loss: 1.6139, Train Acc: 0.5344, Val Loss: 1.6718, Val Acc: 0.5219

# Training with: embed_dim=8, hidden_dim=256, num_layers=4, lr=0.0001
# Epoch 1: Train Loss: 2.9654, Train Acc: 0.2245, Val Loss: 2.6269, Val Acc: 0.2990
# Epoch 2: Train Loss: 2.5003, Train Acc: 0.3223, Val Loss: 2.3906, Val Acc: 0.3427
# Epoch 3: Train Loss: 2.3068, Train Acc: 0.3618, Val Loss: 2.2278, Val Acc: 0.3792
# Epoch 4: Train Loss: 2.1653, Train Acc: 0.3951, Val Loss: 2.1111, Val Acc: 0.4095
# Epoch 5: Train Loss: 2.0632, Train Acc: 0.4199, Val Loss: 2.0294, Val Acc: 0.4275

# Training with: embed_dim=8, hidden_dim=256, num_layers=6, lr=0.0005

In [10]:
correct = 0
total = 0
for inputs, targets in test_dataloader:
    outputs = model(inputs)
    top3_preds = torch.topk(outputs, 3, dim=1)[1]
    translated_output = set([idx_to_char[str(idx.item())] for idx in top3_preds[0]])
    translated_target = idx_to_char[str(targets[0].item())]
    if translated_target in translated_output:
        correct += 1
    total += 1
accuracy = correct / total
print(f"Test Accuracy: {accuracy:.4f}")
    # print("Inputs:", ''.join(translated_input))
    # print("Predicted:", translated_output)
    # print("Targets:", translated_target)
    # print()

Test Accuracy: 0.7174


In [26]:
print(idx_to_char)

{'1': '\t', '2': '\n', '3': ' ', '4': '!', '5': '"', '6': '#', '7': '$', '8': '%', '9': '&', '10': "'", '11': '(', '12': ')', '13': '*', '14': '+', '15': ',', '16': '-', '17': '.', '18': '/', '19': '0', '20': '1', '21': '2', '22': '3', '23': '4', '24': '5', '25': '6', '26': '7', '27': '8', '28': '9', '29': ':', '30': ';', '31': '<', '32': '=', '33': '>', '34': '?', '35': '@', '36': 'A', '37': 'B', '38': 'C', '39': 'D', '40': 'E', '41': 'F', '42': 'G', '43': 'H', '44': 'I', '45': 'J', '46': 'K', '47': 'L', '48': 'M', '49': 'N', '50': 'O', '51': 'P', '52': 'Q', '53': 'R', '54': 'S', '55': 'T', '56': 'U', '57': 'V', '58': 'W', '59': 'X', '60': 'Y', '61': 'Z', '62': '[', '63': '\\', '64': ']', '65': '^', '66': '_', '67': '`', '68': 'a', '69': 'b', '70': 'c', '71': 'd', '72': 'e', '73': 'f', '74': 'g', '75': 'h', '76': 'i', '77': 'j', '78': 'k', '79': 'l', '80': 'm', '81': 'n', '82': 'o', '83': 'p', '84': 'q', '85': 'r', '86': 's', '87': 't', '88': 'u', '89': 'v', '90': 'w', '91': 'x', '92'