In [None]:
import sys
sys.path.append("..") # for sibling import

import walnut
import walnut.tensor_utils as tu

In [None]:
# Pick the compute device: use CUDA when walnut reports it as available,
# otherwise fall back to the CPU.
if walnut.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

# Example 5.2

### Language Model: Neural network

The bigram model predicts the next character by looking only at the previous one. Predictions improve when more than one preceding character is taken into account. In this example, a neural network is used that predicts the next character from a context of multiple characters.

### Step 1: Prepare data
As in the bigram model, the tinyshakespeare dataset is used (https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).

In [None]:
# Read the whole corpus into memory as one string.
# The encoding is passed explicitly so the decoded text does not depend on
# the platform's default locale encoding (e.g. cp1252 on Windows).
with open("../data/tinyshakespeare.txt", "r", encoding="utf-8") as f:
    data = f.read()

### Step 2: Tokenization

In [None]:
from walnut.preprocessing.text import CharacterTokenizer

# Tokenizer instance reused by the later cells for encoding the corpus and
# decoding generated tokens.
tknzr = CharacterTokenizer()
# Display the vocabulary size; the model cell reads the same attribute to
# size the embedding table and the output layer.
tknzr.vocab_size

In [None]:
# Encode the full corpus with the tokenizer; the result is sliced into
# fixed-length windows when the dataset is built.
data_enc = tknzr.encode(data)
# Display the number of encoded tokens.
len(data_enc)

### Step 3: Build dataset
In this example a larger `block_size` is used, so each prediction is based on a longer context of preceding characters.

In [None]:
# Dataset hyperparameters.
num_samples = 1_000_000  # number of randomly drawn training windows
block_size = 32  # number of tokens in each context window

In [None]:
import numpy as np

# Pre-allocate one row per sample for the inputs and the shifted labels.
X = walnut.zeros((num_samples, block_size))
y = walnut.zeros((num_samples, block_size))

# Draw random window start positions. The upper bound leaves room for the
# label window, which is shifted one token to the right.
max_start = len(data_enc) - block_size - 1
rand_indices = np.random.randint(0, max_start, (num_samples,))

for row, start in enumerate(rand_indices):
    stop = start + block_size
    X[row] = data_enc[start:stop]  # context window
    y[row] = data_enc[start + 1 : stop + 1]  # same window shifted by one

# Only the last column of the shifted window is kept as the target, i.e. the
# token that directly follows each context.
X_train = X.int()
y_train = y.int()[:, -1]

print(f"{X_train.shape=}")
print(f"{y_train.shape=}")

### Step 4: Build the neural network structure

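As before, the first layer is an `Embedding` layer. It is followed by a stack of linear layers, each followed by layer normalization and a `Tanh` activation, and a final linear layer that maps to the vocabulary size.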

In [None]:
import walnut.nn as nn
# Explicit imports instead of a wildcard, so it is clear where each layer
# name comes from and the notebook namespace is not polluted.
from walnut.nn.layers import Embedding, Flatten, Layernorm, Linear, Tanh

vocab_size = tknzr.vocab_size
embed_dims = 30
n_hidden = 256
n_hidden_blocks = 5  # number of Linear -> Layernorm -> Tanh blocks
dtype = "float32"

# Embed every token of the context and flatten the result into one vector.
layers = [
    Embedding(vocab_size, embed_dims, dtype=dtype),
    Flatten(),
]

# Stack of identical hidden blocks. Only the first block takes the flattened
# embeddings as input; all later blocks map n_hidden -> n_hidden.
in_features = block_size * embed_dims
for _block in range(n_hidden_blocks):
    layers += [
        Linear(in_features, n_hidden, use_bias=False, dtype=dtype),
        Layernorm((n_hidden,), dtype=dtype),
        Tanh(),
    ]
    in_features = n_hidden

# Output layer: one value per vocabulary entry.
layers.append(Linear(n_hidden, vocab_size, dtype=dtype))

model = nn.Sequential(layers)

model.to_device(device)

In [None]:
# Training configuration. The first positional AdamW argument is presumably
# the learning rate; confirm against the walnut optimizer signature.
optimizer = nn.optimizers.AdamW(1e-4, eps=1e-8, beta2=0.95)
loss_fn = nn.losses.Crossentropy(eps=1e-8)

model.compile(optimizer=optimizer, loss_fn=loss_fn, metric_fn=nn.metrics.accuracy)

In [None]:
from walnut.nn.analysis import model_summary

# Summarize the model for a single sample of `block_size` tokens.
input_shape = (block_size,)
model_summary(model, input_shape)

### Step 5: Train the model

In [None]:
# Training schedule.
epochs, batch_size = 100, 32_000

# `model.train` returns four values; only the first two (training losses and
# training scores) are kept here.
train_losses, train_scores, _, _ = model.train(
    X_train,
    y_train,
    epochs=epochs,
    batch_size=batch_size,
)

### Step 6: Generate text

In [None]:
# NOTE(review): the module path is spelled `funcional`; confirm this matches
# the installed walnut package before renaming it to `functional`.
from walnut.nn.funcional import softmax

num_generated = 1000  # number of characters to sample

# Start from a context window filled with ones.
context = walnut.ones((1, block_size,), device=device)

for _ in range(num_generated):
    logits = model(context).squeeze()
    # needs at least float32 otherwise probs might not sum to 1
    probs = softmax(logits.float())
    index = walnut.random_choice_indices(probs)
    print(tknzr.decode(walnut.expand_dims(index, 0)), end="")
    # Slide the window: append the sampled token, then drop the oldest one.
    context = context.append(tu.expand_dims(index, 0), axis=1)[:, 1:]

### Inspect embeddings

In [None]:
# Optional: scikit-learn is needed for the t-SNE plot below.
# Uncomment the following line to install it.
# !pip install scikit-learn

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt


# get the embedding table (presumably a numpy array after .cpu().data)
embedding_layer = model.sub_modules[0].sub_modules[0]
embs = embedding_layer.w.cpu().data

# reduce the embeddings to 2 dimensions so they can be drawn on a plane
tsne = TSNE(random_state=0).fit_transform(embs)

# plot results: one point per vocabulary entry, labelled with its character
plt.figure(figsize=(8, 8))
plt.scatter(x=tsne[:, 0], y=tsne[:, 1], alpha=0.5, s=100)
plt.axis("off")
for i, (px, py) in enumerate(tsne):
    char = tknzr.decode(walnut.Tensor([i], dtype="int"))
    # small offset so the label sits on top of its marker
    plt.text(x=px - 0.03, y=py - 0.04, s=char)