<a href="https://colab.research.google.com/github/dietmarja/LLM-Elements/blob/main/backpropagation/autograd.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
!pip install torchviz
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchviz import make_dot

class SimpleLanguageModel(nn.Module):
    """Minimal LSTM language model: embedding -> LSTM -> per-step vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids; sizes both the embedding
            table and the output projection.
        embedding_dim: Width of each token embedding fed to the LSTM.
        hidden_dim: Size of the LSTM hidden state.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        # Registration order (embedding, lstm, fc) fixes the order of
        # named_parameters() and of random initialization.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map token ids ``x`` of shape (batch, seq) to logits of shape (batch, seq, vocab_size)."""
        # batch_first=True: the LSTM consumes and returns (batch, seq, features);
        # the final (h, c) state is not needed for per-step prediction.
        hidden_states, _final_state = self.lstm(self.embedding(x))
        return self.fc(hidden_states)

# Set up model parameters
vocab_size, embedding_dim, hidden_dim = 1000, 50, 100
seq_length, batch_size = 10, 32

# Create models, loss function, and optimizers.
# Each optimizer gets its OWN model, both starting from identical weights.
# If both optimizers stepped the same parameters, every "SGD loss" would be
# measured on weights Adam had just updated (and vice versa), so the two
# loss curves could not be compared.
import copy

model = SimpleLanguageModel(vocab_size, embedding_dim, hidden_dim)  # trained with Adam; kept as `model` for later cells
sgd_model = copy.deepcopy(model)  # same initial weights, trained with SGD
criterion = nn.CrossEntropyLoss()
adam_optimizer = optim.Adam(model.parameters(), lr=0.001)
sgd_optimizer = optim.SGD(sgd_model.parameters(), lr=0.01)

# Training loop (each "epoch" here is a single randomly generated batch)
num_epochs = 100
adam_losses = []
sgd_losses = []

for epoch in range(num_epochs):
    # Simulate input data and target. Targets are drawn independently of the
    # inputs, so there is no real signal to learn; both models see the same
    # batch each step so their losses are comparable.
    input_seq = torch.randint(0, vocab_size, (batch_size, seq_length))
    target = torch.randint(0, vocab_size, (batch_size, seq_length))

    # Train the Adam model
    adam_optimizer.zero_grad()
    output = model(input_seq)
    # CrossEntropyLoss takes (N, C) logits and (N,) targets: flatten batch and time.
    loss = criterion(output.view(-1, vocab_size), target.view(-1))
    loss.backward()
    adam_optimizer.step()
    adam_losses.append(loss.item())

    # Train the SGD model on the same batch
    sgd_optimizer.zero_grad()
    output = sgd_model(input_seq)
    loss = criterion(output.view(-1, vocab_size), target.view(-1))
    loss.backward()
    sgd_optimizer.step()
    sgd_losses.append(loss.item())

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Adam Loss: {adam_losses[-1]:.4f}, SGD Loss: {sgd_losses[-1]:.4f}")

# Visualize computational graph: run one forward pass on a single random
# sequence (batch of 1) and render the autograd graph that produced `output`,
# labelling parameter nodes with their module names.
# NOTE(review): render() presumably needs the Graphviz `dot` executable installed.
input_seq = torch.randint(0, vocab_size, (1, seq_length))
output = model(input_seq)
make_dot(
    output,
    params={name: tensor for name, tensor in model.named_parameters()},
).render("model_graph", format="png")

# Plot loss over time: one curve per optimizer, x-axis is the loop index.
plt.figure(figsize=(10, 5))
plt.plot(adam_losses, label='Adam')
plt.plot(sgd_losses, label='SGD')
plt.gca().set(xlabel='Epoch', ylabel='Loss', title='Training Loss Over Time')
plt.legend()
plt.savefig('loss_plot.png')
plt.close()  # release the figure; the plot is only saved to disk

# Visualize word embeddings
from sklearn.manifold import TSNE
import numpy as np

# Project the learned embedding table (one row per token id) to 2-D with t-SNE.
# .cpu() is a no-op for CPU tensors and keeps this working if the model is
# ever moved to a GPU (numpy() cannot read device memory).
embeddings = model.embedding.weight.detach().cpu().numpy()
tsne = TSNE(n_components=2, random_state=42)
embeddings_2d = tsne.fit_transform(embeddings)

plt.figure(figsize=(10, 10))
plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], alpha=0.5)
# Label only the first 20 token ids for clarity; slice instead of walking
# every row and skipping all but 20 of them.
for i, (x, y) in enumerate(embeddings_2d[:20]):
    plt.annotate(str(i), (x, y))
plt.title('2D Visualization of Word Embeddings')
plt.savefig('word_embeddings.png')
plt.close()

print("Training complete. Check the generated plots for visualizations.")

Collecting torchviz
  Downloading torchviz-0.0.2.tar.gz (4.9 kB)
  Preparing metadata (setup.py) ... done
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->torchviz)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->torchviz)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->torchviz)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->torchviz)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->torchviz)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->torchviz)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manyl