<a href="https://colab.research.google.com/github/myllanes/Introduction-to-Deep-Learning/blob/main/HW5_1_WITHATTENTION.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#Michael Yllanes
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from sklearn.model_selection import train_test_split
import time
import matplotlib.pyplot as plt
from prettytable import PrettyTable

# Sample text
text = "Next character prediction is a fundamental task in the field of natural language processing (NLP) that involves predicting the next character in a sequence of text based on the characters that precede it. This task is essential for various applications, including text auto-completion, spell checking, and even in the development of sophisticated AI models capable of generating human-like text. At its core, next character prediction relies on statistical models or deep learning algorithms to analyze a given sequence of text and predict which character is most likely to follow. These predictions are based on patterns and relationships learned from large datasets of text during the training phase of the model. One of the most popular approaches to next character prediction involves the use of Recurrent Neural Networks (RNNs), and more specifically, a variant called Long Short-Term Memory (LSTM) networks. RNNs are particularly well-suited for sequential data like text, as they can maintain information in 'memory' about previous characters to inform the prediction of the next character. LSTM networks enhance this capability by being able to remember long-term dependencies, making them even more effective for next character prediction tasks. Training a model for next character prediction involves feeding it large amounts of text data, allowing it to learn the probability of each character's appearance following a sequence of characters. During this training process, the model adjusts its parameters to minimize the difference between its predictions and the actual outcomes, thus improving its predictive accuracy over time. Once trained, the model can be used to predict the next character in a given piece of text by considering the sequence of characters that precede it. This can enhance user experience in text editing software, improve efficiency in coding environments with auto-completion features, and enable more natural interactions with AI-based chatbots and virtual assistants. In summary, next character prediction plays a crucial role in enhancing the capabilities of various NLP applications, making text-based interactions more efficient, accurate, and human-like. Through the use of advanced machine learning models like RNNs and LSTMs, next character prediction continues to evolve, opening new possibilities for the future of text-based technology."

# Create character vocabulary
chars = sorted(list(set(text)))
ix_to_char = {i: ch for i, ch in enumerate(chars)}
char_to_ix = {ch: i for i, ch in enumerate(chars)}

# Preparing the dataset
max_length = 30  # Maximum length of input sequences
X = []
y = []
for i in range(len(text) - max_length):
    sequence = text[i:i + max_length]
    label = text[i + max_length]
    X.append([char_to_ix[char] for char in sequence])
    y.append(char_to_ix[label])

X = np.array(X)
y = np.array(y)

# Splitting the dataset
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# Converting to PyTorch tensors
X_train = torch.tensor(X_train, dtype=torch.long)
y_train = torch.tensor(y_train, dtype=torch.long)
X_val = torch.tensor(X_val, dtype=torch.long)
y_val = torch.tensor(y_val, dtype=torch.long)

class CrossAttentionRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(CrossAttentionRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        # Embedding layer
        self.embedding = nn.Embedding(input_size, hidden_size)

        # RNN layer
        self.rnn = nn.RNN(hidden_size, hidden_size, num_layers, batch_first=True)

        # Attention layers
        self.query = nn.Linear(hidden_size, hidden_size)
        self.key = nn.Linear(hidden_size, hidden_size)
        self.value = nn.Linear(hidden_size, hidden_size)

        # Output layer
        self.fc = nn.Linear(hidden_size * 2, output_size)

    def forward(self, x):
        batch_size = x.size(0)

        # Initialize hidden state
        h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size).to(x.device)

        # Embedding
        embedded = self.embedding(x)

        # RNN forward pass
        rnn_out, hn = self.rnn(embedded, h0)

        # Cross-attention mechanism
        query = hn[-1].unsqueeze(1)  # Use last hidden state as query
        keys = self.key(rnn_out)
        values = self.value(rnn_out)

        # Calculate attention
        attn_scores = torch.bmm(query, keys.transpose(1, 2))
        attn_weights = F.softmax(attn_scores, dim=-1)
        context = torch.bmm(attn_weights, values).squeeze(1)

        # Combine with RNN output
        rnn_last = rnn_out[:, -1, :]
        combined = torch.cat([rnn_last, context], dim=1)
        output = self.fc(combined)

        return output, attn_weights

# Initialize model
hidden_size = 128
model = CrossAttentionRNN(len(chars), hidden_size, len(chars))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.005)

# Function to calculate model size and complexity
def analyze_model(model):
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad: continue
        params = parameter.numel()
        table.add_row([name, params])
        total_params += params
    print(table)
    print(f"Total Trainable Parameters: {total_params:,}")

    # Computational complexity analysis
    seq_length = max_length
    embed_dim = hidden_size
    print("\nComputational Complexity Analysis:")
    print(f"- Embedding Layer: O(n × d) per sequence, where n={seq_length}, d={embed_dim}")
    print(f"- RNN Layer: O(n × d²) per sequence")
    print(f"- Attention Mechanism: O(n × d²) per sequence (query-key), O(n × d) (value)")
    print(f"- Overall per sequence: O(n × d²)")

analyze_model(model)

# Training loop
start_time = time.time()
train_losses = []
val_losses = []
accuracies = []

for epoch in range(500):
    model.train()
    optimizer.zero_grad()
    output, _ = model(X_train)
    loss = criterion(output, y_train)
    loss.backward()
    optimizer.step()
    train_losses.append(loss.item())

    # Validation
    model.eval()
    with torch.no_grad():
        val_output, _ = model(X_val)
        val_loss = criterion(val_output, y_val)
        val_losses.append(val_loss.item())
        _, predicted = torch.max(val_output, 1)
        accuracy = (predicted == y_val).float().mean()
        accuracies.append(accuracy.item())

    if (epoch+1) % 10 == 0:
        print(f'Epoch {epoch+1}, Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}, Accuracy: {accuracy.item():.4f}')

training_time = time.time() - start_time
print(f"\nTraining completed in {training_time:.2f} seconds")


# Prediction function
def predict_next_char(model, char_to_ix, ix_to_char, initial_str):
    model.eval()
    with torch.no_grad():
        initial_input = torch.tensor([char_to_ix[c] for c in initial_str[-max_length:]], dtype=torch.long).unsqueeze(0)
        output, attn_weights = model(initial_input)
        predicted_index = torch.argmax(output, dim=1).item()



        return ix_to_char[predicted_index]

# Test prediction
test_str = "predic"
predicted_char = predict_next_char(model, char_to_ix, ix_to_char, test_str)
print(f"\nPredicted next character for '{test_str}': '{predicted_char}'")

+------------------+------------+
|     Modules      | Parameters |
+------------------+------------+
| embedding.weight |    5632    |
| rnn.weight_ih_l0 |   16384    |
| rnn.weight_hh_l0 |   16384    |
|  rnn.bias_ih_l0  |    128     |
|  rnn.bias_hh_l0  |    128     |
|   query.weight   |   16384    |
|    query.bias    |    128     |
|    key.weight    |   16384    |
|     key.bias     |    128     |
|   value.weight   |   16384    |
|    value.bias    |    128     |
|    fc.weight     |   11264    |
|     fc.bias      |     44     |
+------------------+------------+
Total Trainable Parameters: 99,500

Computational Complexity Analysis:
- Embedding Layer: O(n × d) per sequence, where n=30, d=128
- RNN Layer: O(n × d²) per sequence
- Attention Mechanism: O(n × d²) per sequence (query-key), O(n × d) (value)
- Overall per sequence: O(n × d²)
Epoch 10, Loss: 2.1497, Val Loss: 2.3821, Accuracy: 0.3432
Epoch 20, Loss: 1.5555, Val Loss: 2.3217, Accuracy: 0.3919
Epoch 30, Loss: 1.1125, Val