In [1]:
# !mkdir -p datasets
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -O datasets/shakespeare.txt

In [None]:
import simplegrad as sg
import numpy as np

In [3]:
BATCH_SIZE = 32
BLOCK_SIZE = 8
LEARNING_RATE = 1e-2
MAX_ITERS = 3000
VAL_INTERVAL = 300
VAL_ITERS = 200

In [4]:
with open("datasets/shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)

print("Vocab. size:", vocab_size)

Vocab. size: 65


In [5]:
s2i = {ch: i for i, ch in enumerate(chars)}
i2s = {i: ch for i, ch in enumerate(chars)}

encode = lambda string: [s2i[ch] for ch in string]
decode = lambda tokens: [i2s[tok] for tok in tokens]

In [6]:
data = encode(text)
split_idx = int(len(data) * 0.9)
train_data = data[:split_idx]
val_data = data[split_idx:]
print("Train:", len(train_data))
print("Val.:", len(val_data))

Train: 1003854
Val.: 111540


In [None]:
class BigramModel(sg.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_table = sg.nn.Embedding(self.vocab_size, self.vocab_size)

    def forward(self, context):
        return self.embedding_table(context)

    def generate(self, context, max_new_tokens=300):
        res = [context.values.item()]
        current = context
        for _ in range(max_new_tokens):
            next = sg.Tensor(
                np.random.choice(range(self.vocab_size), size=1, p=sg.softmax(self.forward(current), dim=-1).values[0, 0, :]), dtype="int8"
            )
            res.append(int(next.values.item()))
            current = next
        return res


model = BigramModel(vocab_size)

In [8]:
def get_batches(split="train"):
    data = train_data if split == "train" else val_data
    idxs = np.random.randint(low=0, high=len(data) - BLOCK_SIZE, size=(BATCH_SIZE,))
    x = sg.Tensor([data[i : i + BLOCK_SIZE] for i in idxs], dtype="int8")
    y = [data[i + 1 : i + BLOCK_SIZE + 1] for i in idxs]
    y_one_hot = sg.zeros((BATCH_SIZE, BLOCK_SIZE, vocab_size), comp_grad=False)
    for b in range(BATCH_SIZE):
        for t in range(BLOCK_SIZE):
            y_one_hot.values[b, t, y[b][t]] = 1
    return x, y_one_hot


x_example, y_example = get_batches("train")
print("x_example:", x_example.shape)
print("y_example:", y_example.shape)

x_example: (32, 8)
y_example: (32, 8, 65)


In [9]:
def estimate_loss():
    out = {}
    for split in ["train", "val"]:
        losses = np.zeros(VAL_ITERS)
        for i in range(VAL_ITERS):
            x_batch, y_batch = get_batches(split)
            losses[i] = sg.ce_loss(model(x_batch), y_batch).values.item()
        out[split] = losses.mean()
    return out

In [None]:
print("".join(decode(model.generate(context=sg.Tensor(s2i["T"], dtype="int8"), max_new_tokens=500))))

T$W,-.s-gDLYHiHT3?jcAig cb3nXUXRC:$mK:?glmz,.yUTPAbJJ,:WXyFOG&F.riHTmLAQ.q!XvaBI'RUA?!Yyb-End,- EB'N?ug$$;T IRC
px$UXEBVBHuO3EES-lxsj$p- .Xp-a;&
nCxdK-hFJOxtuV$;upnma''3ipYQ'k!&
lrcimW,C
nW.H3p-A.COzlefKkNpEDQncTnf?Te,3lgG:
DzVQy-cARrwaq-DFLbZ oTPslHrG!GQPjiQjzbJSGIO-lJn$WflRrQDIUoMwQOp.:y:-sjTPoupgSJF.yWeNwexexeTSkARY&.:KvluVoMOzGTj&vYBZHY,MPcAiQ.;Ovuvvaxmejgx&O?xPvm$m?als $W&
VFhYwZ!jnBs:pGXj$qUXDFIhC$JoMhyO
n,-ljXHLmt?jJGq!3O
xZe?wJnmWHSaUsK:?dpEEoMwBBaDgpEWch,qE$QU?zGP'3:OWt!3&zPn pmW:H.YUT,t


In [11]:
optimizer = sg.opt.Adam(model, lr=LEARNING_RATE)

for i in range(MAX_ITERS):
    if i % VAL_INTERVAL == 0:
        eval_loss = estimate_loss()
        print(eval_loss)
    x_batch, y_batch = get_batches()
    loss = sg.ce_loss(model(x_batch), y_batch)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

{'train': np.float64(4.612938895225525), 'val': np.float64(4.626659710407257)}
{'train': np.float64(2.8013719630241396), 'val': np.float64(2.8241304922103883)}
{'train': np.float64(2.5444876492023467), 'val': np.float64(2.566481821537018)}
{'train': np.float64(2.493301900625229), 'val': np.float64(2.510030850172043)}
{'train': np.float64(2.476245114803314), 'val': np.float64(2.5081604492664336)}
{'train': np.float64(2.468981717824936), 'val': np.float64(2.4957055079936983)}
{'train': np.float64(2.470393862724304), 'val': np.float64(2.505697947740555)}
{'train': np.float64(2.4662603878974916), 'val': np.float64(2.4936233174800875)}
{'train': np.float64(2.4572763764858245), 'val': np.float64(2.4846458768844606)}
{'train': np.float64(2.4568127751350404), 'val': np.float64(2.4850762152671813)}


In [12]:
print(''.join(decode(model.generate(context=sg.Tensor(s2i["T"], dtype="int8"), max_new_tokens=500))))

Thyoldie wh gonk wathy fe d nillaridudamof; wille, cay ico lullatheane house V:
We mont wand d CIUKI G mmef be seds n t oo e,
WBy y ll bey s; n orde D m t g rmbucl!
And soreay bosw BE:
Whit wieised d ltot, hme re le:

SooullfisunQUCl ESind in; y tre:
Yes myo hthoras sy'tang t mys me, u mu:
Go s y merissalllcimy-whos wit no;
T:
The'd brybl, me, ton ure HABealoue; t he gh thacel whaure: ben mex plfe ctome e?
t erecrod nonomary ho sather, am:
T: Vealyoutharheshe sthef pe st scout thour wit
Havis win
