In [None]:
import numpy as np

import walnut
import walnut.nn as nn
from walnut.preprocessing.text import WordTokenizer, remove_punctuation

In [None]:
# Load only the first 10,000 characters of the corpus to keep the experiment small.
# The encoding is explicit so decoding does not depend on the platform default,
# and read(10000) avoids reading the whole file just to slice it afterwards.
with open("data/tinyshakespeare.txt", "r", encoding="utf-8") as f:
    data = f.read(10000)

In [None]:
# Create a word-level tokenizer and fit it on the loaded text.
# NOTE(review): max_tokens=300 presumably caps the vocabulary size — confirm against WordTokenizer.fit.
tknzr = WordTokenizer()
tknzr.fit(data, max_tokens=300)

In [None]:
num_samples = 1000
block_size = 3
# NOTE(review): the +1 presumably reserves one extra index beyond the fitted tokens
# (e.g. for unknown words) — confirm against WordTokenizer.
vocab_size = len(tknzr.tokens) + 1

# fixed seed so that re-running the notebook draws the same training samples
sample_seed = 42
rng = np.random.default_rng(sample_seed)

# initialize tensors with zeros
X_array = np.zeros((num_samples, block_size, vocab_size))
Y_array = np.zeros((num_samples, vocab_size))

# randomly choose start indices of word blocks in the original data
# NOTE(review): splitting on " " only leaves newlines inside "words" and produces
# empty strings for repeated spaces — confirm this matches how the tokenizer splits text.
data_clean = remove_punctuation(data)
data_split = data_clean.split(" ")
# the upper bound is exclusive, so index + block_size is always a valid label position
rand_indices = rng.integers(0, len(data_split) - block_size, (num_samples,))

# rows of the identity matrix are the one-hot vectors; build it once, not per sample
one_hot_rows = np.eye(vocab_size)

for i, index in enumerate(rand_indices):
    # get the context words and the word that follows them (the label)
    context = data_split[index : index + block_size]
    label = data_split[index + block_size]

    # encode words to get their token indices
    context_enc = [tknzr.encode(c) for c in context]
    label_enc = tknzr.encode(label)

    # one-hot-encode indices and add to the tensors
    X_array[i] = one_hot_rows[context_enc]
    Y_array[i] = one_hot_rows[label_enc]

X = walnut.Tensor(X_array, dtype="int")
Y = walnut.Tensor(Y_array, dtype="int")

print(f"{X.shape=}")
print(f"{Y.shape=}")

In [None]:
# Model that maps a block of one-hot encoded words to a softmax output of size vocab_size.
# `nn` (walnut.nn) is imported in the notebook's first cell together with the other imports.
# NOTE(review): the first argument of Embedding/Linear is presumably the layer's
# output size — confirm against the walnut layer signatures.
embedding_dim = 10
hidden_units = 100

model = nn.Sequential(layers=[
    nn.layers.Embedding(embedding_dim, input_shape=(block_size, vocab_size)),
    nn.layers.Linear(hidden_units, act="tanh", norm="layer"),
    nn.layers.Linear(vocab_size, act="softmax")
])

In [None]:
# Configure the model with the Adam optimizer, cross-entropy loss and accuracy metric.
model.compile(nn.optimizers.Adam(), nn.losses.Crossentropy(), nn.metrics.Accuracy())

In [None]:
# Bare expression: the notebook shows the model's representation as the cell output.
model

In [None]:
# Train on all samples for 1000 epochs and keep the two returned histories.
# NOTE(review): no validation data is passed here, so val_hist may be empty — confirm against Sequential.train.
train_hist, val_hist = model.train(X, Y, epochs=1000, verbose="all")

In [None]:
# Plot the training and validation histories returned by the training cell.
traces = dict(train_loss=train_hist, val_loss=val_hist)
nn.analysis.plot_curve(
    traces=traces,
    figsize=(20, 4),
    title="loss history",
    x_label="epoch",
    y_label="loss",
)

In [None]:
prompt = "He was just"
num_generated = 100  # number of words to generate after the prompt

# keep only the last block_size words, since the model input holds exactly that many
prompt_split = prompt.split(" ")[-block_size:]
if len(prompt_split) < block_size:
    # fail early with a clear message instead of feeding a wrongly shaped input to the model
    raise ValueError(
        f"prompt must contain at least {block_size} words, got {len(prompt_split)}"
    )

for i in range(num_generated):
    # encode the current context words and one-hot-encode them as a batch of one
    X_test = walnut.Tensor([tknzr.encode(word) for word in prompt_split])
    X_enc = walnut.expand_dims(walnut.preprocessing.encoding.one_hot_encode(X_test, vocab_size), 0)

    # pick the next token from the model output via walnut.choice and decode it to a word
    pred = tknzr.decode(walnut.choice(model(X_enc)[0]))
    print(pred, end=" ")

    # slide the context window: drop the oldest word, append the prediction
    prompt_split = prompt_split[1:] + [pred]