In [31]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.model_selection import train_test_split
import time

In [32]:
text = """Next character prediction is a fundamental task in the field of natural language processing (NLP) that involves predicting the next character in a sequence of text based on the characters that precede it. This task is essential for various applications, including text auto-completion, spell checking, and even in the development of sophisticated AI models capable of generating human-like text.

At its core, next character prediction relies on statistical models or deep learning algorithms to analyze a given sequence of text and predict which character is most likely to follow. These predictions are based on patterns and relationships learned from large datasets of text during the training phase of the model.

One of the most popular approaches to next character prediction involves the use of Recurrent Neural Networks (RNNs), and more specifically, a variant called Long Short-Term Memory (LSTM) networks. RNNs are particularly well-suited for sequential data like text, as they can maintain information in 'memory' about previous characters to inform the prediction of the next character. LSTM networks enhance this capability by being able to remember long-term dependencies, making them even more effective for next character prediction tasks.

Training a model for next character prediction involves feeding it large amounts of text data, allowing it to learn the probability of each character's appearance following a sequence of characters. During this training process, the model adjusts its parameters to minimize the difference between its predictions and the actual outcomes, thus improving its predictive accuracy over time.

Once trained, the model can be used to predict the next character in a given piece of text by considering the sequence of characters that precede it. This can enhance user experience in text editing software, improve efficiency in coding environments with auto-completion features, and enable more natural interactions with AI-based chatbots and virtual assistants.

In summary, next character prediction plays a crucial role in enhancing the capabilities of various NLP applications, making text-based interactions more efficient, accurate, and human-like. Through the use of advanced machine learning models like RNNs and LSTMs, next character prediction continues to evolve, opening new possibilities for the future of text-based technology."""


In [33]:
chars = sorted(list(set(text)))
ix_to_char = {i: ch for i, ch in enumerate(chars)}
char_to_ix = {ch: i for i, ch in enumerate(chars)} 

In [34]:
def prepare_dataset(text, max_length):
    X = []
    y = []
    for i in range(len(text) - max_length):
        sequence = text[i:i + max_length]
        label = text[i + max_length]
        X.append([char_to_ix[char] for char in sequence])
        y.append(char_to_ix[label])
    X = np.array(X)
    y = np.array(y)
    return X, y


In [35]:
class RNNModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNNModel, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.RNN(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        embedded = self.embedding(x)
        output, _ = self.rnn(embedded)
        output = self.fc(output[:, -1, :])
        return output

In [36]:
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(LSTMModel, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        embedded = self.embedding(x)
        output, _ = self.lstm(embedded)
        output = self.fc(output[:, -1, :])
        return output


In [37]:
class GRUModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(GRUModel, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        embedded = self.embedding(x)
        output, _ = self.gru(embedded)
        output = self.fc(output[:, -1, :])
        return output

In [38]:
def get_num_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [39]:
hidden_size = 252
learning_rate = 0.001
epochs = 120

In [40]:
import time
import torch


def train_and_validate(model, X_train, y_train, X_val, y_val, criterion, optimizer, epochs, device):
    
    model = model.to(device)
    X_train, y_train = X_train.to(device), y_train.to(device)
    X_val, y_val = X_val.to(device), y_val.to(device)
    
    start_time = time.time()  
    for epoch in range(epochs):
        model.train()  
        optimizer.zero_grad()  
        output = model(X_train)
        loss = criterion(output, y_train)
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0:  
            model.eval()  
            with torch.no_grad():
                val_output = model(X_val)
                val_loss = criterion(val_output, y_val)
                _, predicted = torch.max(val_output, 1)
                val_accuracy = (predicted == y_val).float().mean()
            # Print training and validation results
            print(f'Epoch {epoch + 1}, Loss: {loss.item()}, Validation Loss: {val_loss.item()}, Validation Accuracy: {val_accuracy.item()}')
    
    end_time = time.time()  # End time
    execution_time = end_time - start_time
    print(f"Execution Time: {execution_time} seconds") 


In [41]:
X_1, y_1 = prepare_dataset(text, 1)
X_train_1, X_val_1, y_train_1, y_val_1 = train_test_split(X_1, y_1, test_size=0.2, random_state=42)
X_train_1 = torch.tensor(X_train_1, dtype=torch.long)
y_train_1 = torch.tensor(y_train_1, dtype=torch.long)
X_val_1 = torch.tensor(X_val_1, dtype=torch.long)
y_val_1 = torch.tensor(y_val_1, dtype=torch.long)


In [42]:
X_2, y_2 = prepare_dataset(text, 2)
X_train_2, X_val_2, y_train_2, y_val_2 = train_test_split(X_2, y_2, test_size=0.2, random_state=42)
X_train_2 = torch.tensor(X_train_2, dtype=torch.long)
y_train_2 = torch.tensor(y_train_2, dtype=torch.long)
X_val_2 = torch.tensor(X_val_2, dtype=torch.long)
y_val_2 = torch.tensor(y_val_2, dtype=torch.long)


In [43]:
X_3, y_3 = prepare_dataset(text, 3)
X_train_3, X_val_3, y_train_3, y_val_3 = train_test_split(X_3, y_3, test_size=0.2, random_state=42)
X_train_3 = torch.tensor(X_train_3, dtype=torch.long)
y_train_3 = torch.tensor(y_train_3, dtype=torch.long)
X_val_3 = torch.tensor(X_val_3, dtype=torch.long)
y_val_3 = torch.tensor(y_val_3, dtype=torch.long)


In [44]:
rnn_model = RNNModel(len(chars), hidden_size, len(chars))
lstm_model = LSTMModel(len(chars), hidden_size, len(chars))
gru_model = GRUModel(len(chars), hidden_size, len(chars))

In [45]:
criterion = nn.CrossEntropyLoss()
rnn_optimizer = optim.Adam(rnn_model.parameters(), lr=learning_rate)
lstm_optimizer = optim.Adam(lstm_model.parameters(), lr=learning_rate)
gru_optimizer = optim.Adam(gru_model.parameters(), lr=learning_rate)

In [46]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [47]:
print("RNN Model:")
train_and_validate(rnn_model, X_train_1, y_train_1, X_val_1, y_val_1, criterion, rnn_optimizer, epochs, device)
print(get_num_params(rnn_model))

RNN Model:
Epoch 10, Loss: 2.8190906047821045, Validation Loss: 2.827759265899658, Validation Accuracy: 0.2552301287651062
Epoch 20, Loss: 2.4165122509002686, Validation Loss: 2.5396246910095215, Validation Accuracy: 0.2845188081264496
Epoch 30, Loss: 2.2846312522888184, Validation Loss: 2.469707489013672, Validation Accuracy: 0.26359832286834717
Epoch 40, Loss: 2.2249252796173096, Validation Loss: 2.437910318374634, Validation Accuracy: 0.2594142258167267
Epoch 50, Loss: 2.1944639682769775, Validation Loss: 2.4180827140808105, Validation Accuracy: 0.26778241991996765
Epoch 60, Loss: 2.175205945968628, Validation Loss: 2.4137637615203857, Validation Accuracy: 0.2552301287651062
Epoch 70, Loss: 2.1631758213043213, Validation Loss: 2.4125943183898926, Validation Accuracy: 0.2698744833469391
Epoch 80, Loss: 2.1556789875030518, Validation Loss: 2.4146406650543213, Validation Accuracy: 0.26778241991996765
Epoch 90, Loss: 2.150578498840332, Validation Loss: 2.421037435531616, Validation Accu

In [48]:
train_and_validate(rnn_model, X_train_2, y_train_2, X_val_2, y_val_2, criterion, rnn_optimizer,epochs, device)
print(get_num_params(rnn_model))

Epoch 10, Loss: 1.7925397157669067, Validation Loss: 2.146686315536499, Validation Accuracy: 0.38912132382392883
Epoch 20, Loss: 1.560263991355896, Validation Loss: 2.104840040206909, Validation Accuracy: 0.41004183888435364
Epoch 30, Loss: 1.414550542831421, Validation Loss: 2.1246912479400635, Validation Accuracy: 0.4016736149787903
Epoch 40, Loss: 1.3167483806610107, Validation Loss: 2.1399459838867188, Validation Accuracy: 0.42677822709083557
Epoch 50, Loss: 1.2520959377288818, Validation Loss: 2.1745829582214355, Validation Accuracy: 0.42677822709083557
Epoch 60, Loss: 1.209848165512085, Validation Loss: 2.21372127532959, Validation Accuracy: 0.43096232414245605
Epoch 70, Loss: 1.1819595098495483, Validation Loss: 2.2560644149780273, Validation Accuracy: 0.428870290517807
Epoch 80, Loss: 1.1632120609283447, Validation Loss: 2.2983338832855225, Validation Accuracy: 0.428870290517807
Epoch 90, Loss: 1.1503373384475708, Validation Loss: 2.337310552597046, Validation Accuracy: 0.42677

In [49]:
train_and_validate(rnn_model, X_train_3, y_train_3, X_val_3, y_val_3, criterion, rnn_optimizer,epochs, device)
print(get_num_params(rnn_model))

Epoch 10, Loss: 1.062303900718689, Validation Loss: 1.6318860054016113, Validation Accuracy: 0.5167363882064819
Epoch 20, Loss: 0.8074373006820679, Validation Loss: 1.6754889488220215, Validation Accuracy: 0.5418409705162048
Epoch 30, Loss: 0.6683316230773926, Validation Loss: 1.7292901277542114, Validation Accuracy: 0.5418409705162048
Epoch 40, Loss: 0.5888970494270325, Validation Loss: 1.7992603778839111, Validation Accuracy: 0.5397489666938782
Epoch 50, Loss: 0.5440475344657898, Validation Loss: 1.8663281202316284, Validation Accuracy: 0.5334727764129639
Epoch 60, Loss: 0.5177153944969177, Validation Loss: 1.9274318218231201, Validation Accuracy: 0.5439330339431763
Epoch 70, Loss: 0.5013795495033264, Validation Loss: 1.9812839031219482, Validation Accuracy: 0.5376569032669067
Epoch 80, Loss: 0.4905721843242645, Validation Loss: 2.031243085861206, Validation Accuracy: 0.5439330339431763
Epoch 90, Loss: 0.4829903542995453, Validation Loss: 2.0754292011260986, Validation Accuracy: 0.54

In [50]:
print("LSTM Model:")
train_and_validate(lstm_model, X_train_1, y_train_1, X_val_1, y_val_1, criterion, lstm_optimizer,epochs, device)
print(get_num_params(lstm_model))

LSTM Model:
Epoch 10, Loss: 3.2934412956237793, Validation Loss: 3.265193462371826, Validation Accuracy: 0.2552301287651062
Epoch 20, Loss: 2.7467668056488037, Validation Loss: 2.7896931171417236, Validation Accuracy: 0.26359832286834717
Epoch 30, Loss: 2.481499195098877, Validation Loss: 2.574321985244751, Validation Accuracy: 0.2698744833469391
Epoch 40, Loss: 2.3635807037353516, Validation Loss: 2.492399215698242, Validation Accuracy: 0.26778241991996765
Epoch 50, Loss: 2.2946183681488037, Validation Loss: 2.450814962387085, Validation Accuracy: 0.2656903564929962
Epoch 60, Loss: 2.2516229152679443, Validation Loss: 2.4244048595428467, Validation Accuracy: 0.26359832286834717
Epoch 70, Loss: 2.2219302654266357, Validation Loss: 2.4195313453674316, Validation Accuracy: 0.26359832286834717
Epoch 80, Loss: 2.2008533477783203, Validation Loss: 2.4129128456115723, Validation Accuracy: 0.2594142258167267
Epoch 90, Loss: 2.1853156089782715, Validation Loss: 2.412950038909912, Validation Ac

In [51]:
print("LSTM Model:")
train_and_validate(lstm_model, X_train_2, y_train_2, X_val_2, y_val_2, criterion, lstm_optimizer,epochs, device)
print(get_num_params(lstm_model))

LSTM Model:
Epoch 10, Loss: 1.8638733625411987, Validation Loss: 2.140531539916992, Validation Accuracy: 0.3723849356174469
Epoch 20, Loss: 1.5999633073806763, Validation Loss: 2.0835647583007812, Validation Accuracy: 0.4142259359359741
Epoch 30, Loss: 1.4280579090118408, Validation Loss: 2.0665769577026367, Validation Accuracy: 0.437238484621048
Epoch 40, Loss: 1.3147048950195312, Validation Loss: 2.0790088176727295, Validation Accuracy: 0.439330518245697
Epoch 50, Loss: 1.2411460876464844, Validation Loss: 2.1196553707122803, Validation Accuracy: 0.43096232414245605
Epoch 60, Loss: 1.1947238445281982, Validation Loss: 2.170346736907959, Validation Accuracy: 0.43514642119407654
Epoch 70, Loss: 1.1661120653152466, Validation Loss: 2.2193410396575928, Validation Accuracy: 0.4330543875694275
Epoch 80, Loss: 1.14840829372406, Validation Loss: 2.2642269134521484, Validation Accuracy: 0.4330543875694275
Epoch 90, Loss: 1.1371252536773682, Validation Loss: 2.307746648788452, Validation Accur

In [52]:
print("LSTM Model:")
train_and_validate(lstm_model, X_train_3, y_train_3, X_val_3, y_val_3, criterion, lstm_optimizer,epochs, device)
print(get_num_params(lstm_model))

LSTM Model:
Epoch 10, Loss: 1.0321773290634155, Validation Loss: 1.6299558877944946, Validation Accuracy: 0.5188284516334534
Epoch 20, Loss: 0.7690505981445312, Validation Loss: 1.6649969816207886, Validation Accuracy: 0.5355648398399353
Epoch 30, Loss: 0.6267516016960144, Validation Loss: 1.7253037691116333, Validation Accuracy: 0.5481171607971191
Epoch 40, Loss: 0.5518584251403809, Validation Loss: 1.8037258386611938, Validation Accuracy: 0.5397489666938782
Epoch 50, Loss: 0.5132642984390259, Validation Loss: 1.8851126432418823, Validation Accuracy: 0.5334727764129639
Epoch 60, Loss: 0.4928892254829407, Validation Loss: 1.9561148881912231, Validation Accuracy: 0.5460250973701477
Epoch 70, Loss: 0.481222540140152, Validation Loss: 2.0159895420074463, Validation Accuracy: 0.5481171607971191
Epoch 80, Loss: 0.4739202558994293, Validation Loss: 2.067854166030884, Validation Accuracy: 0.5460250973701477
Epoch 90, Loss: 0.46896180510520935, Validation Loss: 2.1111650466918945, Validation A

In [53]:
print("GRU Model:")
train_and_validate(gru_model, X_train_1, y_train_1, X_val_1, y_val_1, criterion, gru_optimizer,epochs, device)
print(get_num_params(gru_model))

GRU Model:
Epoch 10, Loss: 3.0353755950927734, Validation Loss: 3.0109283924102783, Validation Accuracy: 0.286610871553421
Epoch 20, Loss: 2.5264534950256348, Validation Loss: 2.6055026054382324, Validation Accuracy: 0.2698744833469391
Epoch 30, Loss: 2.3619093894958496, Validation Loss: 2.5176753997802734, Validation Accuracy: 0.2698744833469391
Epoch 40, Loss: 2.279667377471924, Validation Loss: 2.45976185798645, Validation Accuracy: 0.26778241991996765
Epoch 50, Loss: 2.231778144836426, Validation Loss: 2.432094097137451, Validation Accuracy: 0.26778241991996765
Epoch 60, Loss: 2.202481746673584, Validation Loss: 2.4236841201782227, Validation Accuracy: 0.24686191976070404
Epoch 70, Loss: 2.1829121112823486, Validation Loss: 2.4159412384033203, Validation Accuracy: 0.2594142258167267
Epoch 80, Loss: 2.169572353363037, Validation Loss: 2.4152328968048096, Validation Accuracy: 0.2656903564929962
Epoch 90, Loss: 2.1603031158447266, Validation Loss: 2.4175329208374023, Validation Accura

In [54]:
print("GRU Model:")
train_and_validate(gru_model, X_train_2, y_train_2, X_val_2, y_val_2, criterion, gru_optimizer,epochs, device)
print(get_num_params(gru_model))

GRU Model:
Epoch 10, Loss: 1.847272515296936, Validation Loss: 2.180546522140503, Validation Accuracy: 0.38493722677230835
Epoch 20, Loss: 1.56328547000885, Validation Loss: 2.114269733428955, Validation Accuracy: 0.4016736149787903
Epoch 30, Loss: 1.3939865827560425, Validation Loss: 2.088164806365967, Validation Accuracy: 0.4225941300392151
Epoch 40, Loss: 1.2852137088775635, Validation Loss: 2.1070504188537598, Validation Accuracy: 0.4246861934661865
Epoch 50, Loss: 1.2181322574615479, Validation Loss: 2.1577041149139404, Validation Accuracy: 0.4246861934661865
Epoch 60, Loss: 1.1775497198104858, Validation Loss: 2.21258807182312, Validation Accuracy: 0.42677822709083557
Epoch 70, Loss: 1.153670310974121, Validation Loss: 2.265575885772705, Validation Accuracy: 0.4330543875694275
Epoch 80, Loss: 1.1391923427581787, Validation Loss: 2.314955949783325, Validation Accuracy: 0.428870290517807
Epoch 90, Loss: 1.130064845085144, Validation Loss: 2.3600525856018066, Validation Accuracy: 0.

In [55]:
print("GRU Model:")
train_and_validate(gru_model, X_train_3, y_train_3, X_val_3, y_val_3, criterion, gru_optimizer,epochs, device)
print(get_num_params(gru_model))

GRU Model:
Epoch 10, Loss: 0.9958236813545227, Validation Loss: 1.675275444984436, Validation Accuracy: 0.5041840672492981
Epoch 20, Loss: 0.7340080142021179, Validation Loss: 1.6907497644424438, Validation Accuracy: 0.5606694221496582
Epoch 30, Loss: 0.5992169380187988, Validation Loss: 1.7615966796875, Validation Accuracy: 0.5523012280464172
Epoch 40, Loss: 0.5319854021072388, Validation Loss: 1.8627727031707764, Validation Accuracy: 0.5543932914733887
Epoch 50, Loss: 0.49934303760528564, Validation Loss: 1.9536759853363037, Validation Accuracy: 0.5502091646194458
Epoch 60, Loss: 0.48268359899520874, Validation Loss: 2.0268783569335938, Validation Accuracy: 0.5502091646194458
Epoch 70, Loss: 0.47340476512908936, Validation Loss: 2.0857481956481934, Validation Accuracy: 0.5397489666938782
Epoch 80, Loss: 0.4676727056503296, Validation Loss: 2.1323981285095215, Validation Accuracy: 0.5481171607971191
Epoch 90, Loss: 0.4638062119483948, Validation Loss: 2.1722216606140137, Validation Ac