In [None]:
# Enable the autoreload extension in mode 2 so edited modules are re-imported
# automatically before code is executed.
%load_ext autoreload
%autoreload 2

# Set TOKENIZERS_PARALLELISM=false in the kernel's environment.
# NOTE(review): no tokenizer library is imported in the visible cells;
# presumably a leftover from another notebook. Confirm it is still needed.
# Keep comments on their own lines: text after the value would become part of it.
%env TOKENIZERS_PARALLELISM=false

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

########################
# 1) Data Preparation
########################

# Load the raw corpus (expects '../dataset/shakespeare.txt' to exist)
with open("../dataset/shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Character-level vocabulary: every distinct character, in sorted order
chars = sorted(set(text))
vocab_size = len(chars)
idx_to_char = dict(enumerate(chars))
char_to_idx = {ch: i for i, ch in idx_to_char.items()}

# The whole corpus as a 1-D tensor of vocabulary indices
encoded_text = torch.tensor([char_to_idx[ch] for ch in text], dtype=torch.long)

# Hyperparameters
block_size = 100  # characters per training sequence
batch_size = 32  # sequences per mini-batch
hidden_size = 512  # embedding width, also used as the SSM state size
# NOTE(review): num_layers is not referenced by the model code visible in this
# notebook (a single SSM layer is built) - confirm it is needed.
num_layers = 3
num_epochs = 10  # iterations of the training loop
learning_rate = 1e-3  # step size passed to the optimizer


def get_batch(data, batch_size, block_size):
    """
    Sample a random mini-batch of next-token training pairs from a 1-D sequence.

    Args:
        data: 1-D tensor of token indices.
        batch_size: number of windows to draw (independently, with replacement).
        block_size: length of each input window.

    Returns:
        (x, y): two tensors of shape [batch_size, block_size]. y is x shifted by
        one position, i.e. y[b, t] is the token that follows x[b, t] in data.

    Raises:
        ValueError: if data holds fewer than block_size + 1 tokens, so that no
            (input, target) window fits.
    """
    # A window starting at s reads data[s : s + block_size + 1], so valid starts
    # are 0 .. len(data) - block_size - 1 inclusive.
    num_starts = len(data) - block_size
    if num_starts < 1:
        raise ValueError(
            f"data has {len(data)} tokens but at least block_size + 1 = "
            f"{block_size + 1} are required"
        )

    # torch.randint's upper bound is exclusive, so passing num_starts makes the
    # last valid window reachable as well.
    start_indices = torch.randint(0, num_starts, (batch_size,)).tolist()
    x = torch.stack([data[s : s + block_size] for s in start_indices])
    y = torch.stack([data[s + 1 : s + block_size + 1] for s in start_indices])
    return x, y


In [None]:
########################
# 2) Minimal SSM Model
########################


class SmallSSMLayer(nn.Module):
    """
    Discrete linear state-space layer, unrolled sequentially over time.

    Starting from a zero state, each step first updates the state with the
    current input and then reads the output from the *updated* state:

        x_k = x_{k-1} A + u_k B
        y_k = x_k C + u_k D

    Row-vector convention is used (states and inputs multiply the matrices
    from the left). A, B, C, D are learned [hidden_size x hidden_size]
    parameters; the input width and the state width are both hidden_size.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # For simplicity, treat input dimension == hidden_size
        self.A = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.01)
        self.B = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.01)
        self.C = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.01)
        self.D = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.01)

    def forward(self, u):
        """
        u: [batch_size, seq_len, hidden_size]
        returns: [batch_size, seq_len, hidden_size]
        """
        batch_size, seq_len, _ = u.shape

        # The input terms do not depend on the state, so project the whole
        # sequence once instead of once per time step.
        input_to_state = u @ self.B  # [batch_size, seq_len, hidden_size]

        # Zero initial state that follows u's dtype and device, so the layer
        # also works after .double() / .half() / .to(device).
        x = u.new_zeros((batch_size, self.A.shape[0]))

        # Only the state recurrence is inherently sequential.
        states = []
        for t in range(seq_len):
            x = x @ self.A + input_to_state[:, t, :]
            states.append(x)

        if not states:
            # Empty sequence: nothing to stack; u @ D already has the
            # [batch_size, 0, hidden_size] output shape.
            return u @ self.D

        # Read all outputs out of the collected states in one batched matmul.
        return torch.stack(states, dim=1) @ self.C + u @ self.D

class SmallSSMModel(nn.Module):
    """
    Character-level language model: token embedding, followed by a single
    SmallSSMLayer, followed by a linear projection to vocabulary logits.
    """

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        # Submodules are created in pipeline order: tokens -> features -> logits
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_size)
        self.ssm = SmallSSMLayer(hidden_size)
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, x):
        """
        x: [batch_size, seq_len] token indices
        returns: [batch_size, seq_len, vocab_size] logits
        """
        features = self.ssm(self.embed(x))  # [batch_size, seq_len, hidden_size]
        return self.fc(features)

In [None]:
########################
# 3) Training Loop
########################

model = SmallSSMModel(vocab_size, hidden_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

model.train()
for epoch in range(num_epochs):
    # Each "epoch" here is a single optimisation step on one freshly sampled
    # mini-batch, not a full pass over the corpus.
    optimizer.zero_grad()

    x_batch, y_batch = get_batch(encoded_text, batch_size, block_size)
    logits = model(x_batch)  # [batch_size, block_size, vocab_size]

    # Merge the batch and time axes so every position is one classification example
    loss = criterion(logits.flatten(0, 1), y_batch.flatten())
    loss.backward()
    optimizer.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")


In [None]:
########################
# 4) Inference Function
########################


def generate_text(model, start_text="ROMEO:", max_length=200):
    """
    Autoregressively sample characters from the model, starting from a prompt.

    Args:
        model: module mapping token indices [1, seq_len] to logits
            [1, seq_len, vocab_size]. It is switched to eval mode.
        start_text: non-empty prompt; every character must be a key of the
            notebook-level `char_to_idx` mapping.
        max_length: number of characters to sample after the prompt.

    Returns:
        The prompt followed by the sampled characters, as one string.

    Raises:
        ValueError: if start_text is empty (there is no position to predict from).
        KeyError: if start_text contains characters outside the vocabulary.
    """
    if not start_text:
        raise ValueError("start_text must contain at least one character")
    unknown = sorted({ch for ch in start_text if ch not in char_to_idx})
    if unknown:
        raise KeyError(f"start_text contains characters outside the vocabulary: {unknown!r}")

    model.eval()

    # Build the prompt on the model's device so models moved off the CPU work too.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    input_seq = torch.tensor(
        [[char_to_idx[ch] for ch in start_text]], dtype=torch.long, device=device
    )
    generated = list(start_text)

    # Sampling needs no gradients; disabling autograd avoids building a graph
    # for every generated character.
    with torch.no_grad():
        for _ in range(max_length):
            # The model is called on the whole prefix at every step.
            logits = model(input_seq)  # [1, seq_len, vocab_size]
            last_logits = logits[:, -1, :]  # [1, vocab_size]
            probs = torch.softmax(last_logits, dim=1)
            idx_next = torch.multinomial(probs, num_samples=1)  # [1, 1]

            generated.append(idx_to_char[idx_next.item()])
            input_seq = torch.cat([input_seq, idx_next], dim=1)

    return "".join(generated)


########################
# Usage Example
########################

# Print the header first, then sample 300 characters continuing the prompt.
print("\nGenerated Text:\n")
sample_text = generate_text(model, start_text="ROMEO:", max_length=300)
print(sample_text)