In [61]:
import torch

# Data preparation

## Load raw text

In [62]:
# Read the whole corpus into memory as a single string. The encoding is pinned
# so the set of characters (and therefore the vocabulary built from it) does
# not depend on the platform's locale default.
with open('../data/shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [63]:
# Peek at the first 200 characters to confirm the corpus loaded as expected.
print(text[:200])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


## Tokenization

In [64]:
# Character-level vocabulary: every distinct character in the corpus, sorted
# so that the id assigned to each character is deterministic.
unique_chars = set(text)
tokens = sorted(unique_chars)
''.join(tokens)

"\n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

In [65]:
# Lookup tables between characters and integer ids; a character's id is its
# position in the sorted `tokens` list.
itos = dict(enumerate(tokens))                   # id -> character
stoi = {ch: idx for idx, ch in itos.items()}     # character -> id

In [66]:
def encode(text):
    """Convert a string into a 1-D int64 tensor of token ids, one per character.

    Each character is looked up in the module-level `stoi` table; a character
    that is not in the table raises `KeyError`.
    """
    ids = list(map(stoi.__getitem__, text))
    return torch.tensor(ids, dtype=torch.long)

def decode(tensor):
    """Convert a 1-D tensor of token ids back into a string.

    Each id is looked up in the module-level `itos` table; an id that is not
    in the table raises `KeyError`.
    """
    # tolist() converts the whole tensor to Python ints in one call, instead of
    # indexing the tensor and calling .item() once per element.
    return ''.join(itos[i] for i in tensor.tolist())

In [67]:
# Sanity check: encode a short string into its token ids.
encode('testi')

tensor([58, 43, 57, 58, 47])

In [68]:
# Round trip: decoding the encoded ids should reproduce the input string.
decode(encode('testi'))

'testi'

In [69]:
# Encode the full corpus once; `data` holds one token id per character.
data = encode(text)
data[:10]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])

In [70]:
# Spot check: the first ten ids decode back to the start of the corpus.
decode(data[:10])

'First Citi'

In [71]:
# Contiguous 80/20 split: the first 80% of the token stream is used for
# training and the remaining tail is held out for validation (no shuffling).
split = int(0.8 * len(data))
train, val = data[:split], data[split:]

print(train.shape, val.shape)

torch.Size([892315]) torch.Size([223079])


# Dataloader

Getting a single chunk of data:

In [72]:
# block_size: number of tokens in each sampled window.
# batch_size: number of windows stacked into one batch.
block_size, batch_size = 8, 4

In [73]:
offset = 10  # arbitrary offset for demonstration

# The target window is the input window shifted one position to the right, so
# y[i] is the token that follows x[i] in the corpus.
window_end = offset + block_size
x = train[offset:window_end]
y = train[offset + 1:window_end + 1]

print(x, y, sep='\n')

tensor([64, 43, 52, 10,  0, 14, 43, 44])
tensor([43, 52, 10,  0, 14, 43, 44, 53])


We generate random offsets into the training data:

In [74]:
# One random start position per batch element. The (exclusive) upper bound
# leaves room for a full window plus its shifted-by-one target inside the
# training split, whose length is `split`.
num_starts = split - block_size
offsets = torch.randint(num_starts, (batch_size,))
offsets

tensor([225210, 561086, 254928, 458402])

And then generate a block-size x and a shifted-by-1 block-size y for each offset, stacking those tensors into a single x tensor and a single y tensor:

In [75]:
# Cut one window per sampled start position, plus the same window shifted by
# one token, and stack each group into a (batch_size, block_size) tensor.
x_windows = [data[start : start + block_size] for start in offsets]
y_windows = [data[start + 1 : start + block_size + 1] for start in offsets]

print(torch.stack(x_windows))
print(torch.stack(y_windows))

tensor([[46, 47, 51,  1, 44, 56, 53, 51],
        [47, 52, 49,  1, 47, 58,  1, 53],
        [42, 43, 44, 39, 41, 43, 42,  1],
        [37,  1, 15, 13, 28, 33, 24, 17]])
tensor([[47, 51,  1, 44, 56, 53, 51,  1],
        [52, 49,  1, 47, 58,  1, 53, 44],
        [43, 44, 39, 41, 43, 42,  1, 61],
        [ 1, 15, 13, 28, 33, 24, 17, 32]])


In [76]:
def get_batch(data, block_size=block_size, batch_size=batch_size):
    """Sample a random batch of input/target windows from a 1-D token tensor.

    Args:
        data: 1-D tensor of token ids to sample from (e.g. `train` or `val`).
        block_size: Number of tokens in each window.
        batch_size: Number of windows in the batch.

    Returns:
        A tuple `(xb, yb)` of two `(batch_size, block_size)` tensors, where
        `yb[i]` is `xb[i]` shifted one position to the right within `data`.

    Raises:
        ValueError: If `data` is too short to hold one window plus its
            shifted-by-one target.
    """
    # Bound the start positions by the tensor that was actually passed in, not
    # by the global `split`: `split` is the length of the training split only,
    # so using it with a shorter tensor such as `val` would index past its end
    # and produce empty or ragged windows.
    num_starts = len(data) - block_size
    if num_starts <= 0:
        raise ValueError(
            f"data of length {len(data)} is too short for block_size={block_size}"
        )

    # Largest start is len(data) - block_size - 1, so the target window's last
    # index is len(data) - 1.
    starts = torch.randint(0, num_starts, (batch_size,)).tolist()

    xb = torch.stack([data[s : s + block_size] for s in starts])
    yb = torch.stack([data[s + 1 : s + block_size + 1] for s in starts])

    return xb, yb

In [77]:
# Draw one (inputs, targets) batch from the training split.
get_batch(train)

(tensor([[47, 45, 46, 58,  6,  1, 39, 50],
         [58,  1, 57, 53,  1, 58, 43, 52],
         [ 1, 58, 46, 39, 58,  1, 63, 53],
         [43, 43, 57,  1, 52, 53, 56,  1]]),
 tensor([[45, 46, 58,  6,  1, 39, 50, 50],
         [ 1, 57, 53,  1, 58, 43, 52, 42],
         [58, 46, 39, 58,  1, 63, 53, 59],
         [43, 57,  1, 52, 53, 56,  1, 46]]))

# Model

In [112]:
class BigramModel(torch.nn.Module):
    """Bigram language model: next-token logits depend only on the current token.

    The whole model is one `vocab_size x vocab_size` embedding table; row `i`
    holds the unnormalised logits over the next token given current token `i`.
    """

    def __init__(self, vocab_size):
        """Create the logit lookup table for a vocabulary of `vocab_size` tokens."""
        super().__init__()

        self.embedding = torch.nn.Embedding(vocab_size, vocab_size)
        self.vocab_size = vocab_size

    def forward(self, x, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            x: Integer tensor of token ids, e.g. of shape (B, T).
            targets: Optional integer tensor with the same shape as `x` holding
                the expected next token at each position.

        Returns:
            A tuple `(logits, loss)`. `logits` has shape `x.shape + (vocab_size,)`.
            `loss` is the mean cross-entropy over all positions, or `None` when
            `targets` is not given.
        """
        logits = self.embedding(x)

        if targets is None:
            return logits, None

        # cross_entropy expects (N, C) logits and (N,) targets, so flatten every
        # leading dimension. reshape (unlike view) also accepts non-contiguous
        # tensors such as transposed targets.
        loss = torch.nn.functional.cross_entropy(
            logits.reshape(-1, self.vocab_size),
            targets.reshape(-1),
        )

        return logits, loss

    @torch.no_grad()  # sampling never needs gradients
    def generate_text(self, x, steps=500):
        """Extend `x` by sampling `steps` tokens from the model's distribution.

        Args:
            x: Integer tensor of token ids of shape (B, T) used as the prompt.
            steps: Number of tokens to append to each sequence.

        Returns:
            Integer tensor of shape (B, T + steps): the prompt followed by the
            sampled continuation.
        """
        for _ in range(steps):
            logits, _ = self(x)
            # Only the logits at the last time step predict the next token.
            last_logits = logits[:, -1, :]
            probs = torch.nn.functional.softmax(last_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)

            x = torch.cat([x, next_token], dim=1)

        return x

    @torch.no_grad()
    def generate(self, x, steps=100):
        """Lazily yield `steps` greedily decoded (argmax) next tokens.

        Args:
            x: Integer tensor of token ids used as the prompt, of shape (T,)
                or (B, T).
            steps: Number of tokens to yield.

        Yields:
            The most likely next token at each step: a 0-d tensor for a (T,)
            prompt, or a (B,) tensor for a (B, T) prompt.
        """
        for _ in range(steps):
            logits, _ = self(x)
            # Argmax over the vocabulary axis of the last time step only. A bare
            # argmax() would flatten every axis and return an index that is not
            # a token id.
            next_token = logits[..., -1, :].argmax(dim=-1)
            yield next_token

            # Append along the time axis so x keeps its rank for the next step.
            x = torch.cat([x, next_token.unsqueeze(-1)], dim=-1)

## Forward pass

In [113]:
# Vocabulary size = number of distinct characters found in the corpus.
vocab_size = len(tokens)
vocab_size

65

In [114]:
# Fresh, untrained model with one row of next-token logits per vocabulary entry.
model = BigramModel(vocab_size)

In [115]:
# Sample one training batch to push through the model.
xb, yb = get_batch(train)
xb.shape # (batch_size, block_size)

torch.Size([4, 8])

In [116]:
# Forward pass with targets returns both the logits and the cross-entropy loss.
logits, loss = model(xb, yb)
logits.shape # (batch_size, block_size, vocab_size)

torch.Size([4, 8, 65])

In [117]:
# Loss of the untrained model. For reference, a uniform prediction over the
# 65-token vocabulary would give ln(65) ≈ 4.17.
loss

tensor(4.7298, grad_fn=<NllLossBackward0>)

## Generation

In [118]:
# Start from a single token with id 0 (the first entry of `tokens`) as a
# batch of one sequence of length one, i.e. shape BxT = (1, 1).
start_tokens = torch.zeros((1, 1), dtype=torch.long)
x = model.generate_text(start_tokens)
x

tensor([[ 0, 38, 32, 55, 34, 42,  3, 15, 11, 32,  6,  4, 49, 54, 11, 47, 61, 46,
         28, 53,  3, 26, 60, 20, 13, 57,  8, 19, 48, 59,  7, 53, 58,  6, 28, 56,
         44, 13, 48, 20, 33, 27, 30, 17, 56, 10, 35, 61, 10, 64,  5, 15, 36, 47,
         41, 49, 57, 16,  8,  3, 23, 15,  6, 57, 18,  3, 18, 16, 14, 37, 41, 58,
         48, 20, 55, 15, 20, 56,  7, 53,  3,  8, 58, 20, 13,  9, 55, 64,  4, 49,
         39, 45,  2, 31, 21, 55, 24, 30,  3, 10,  2, 58, 20, 21, 31,  1,  6, 58,
         34, 32,  3, 11, 18, 47, 38, 60, 56, 28,  8, 13, 46, 16, 37, 32, 16, 17,
          0, 34, 42, 43, 25, 43, 12, 46, 42, 30, 51, 54, 45, 52,  3, 26, 22, 20,
          5, 64, 13,  8, 12, 31, 32, 42,  1,  0, 60, 48, 56, 50,  7, 15, 17,  5,
         41, 43,  4, 18, 36,  6, 20, 63,  5, 26, 27, 30, 61, 27, 43, 41, 58, 29,
         33, 59,  2, 31, 32, 64, 21,  2, 34, 51, 54,  2, 19, 12, 31, 15, 59, 53,
         15,  2,  1, 14, 30,  1, 24, 49, 18, 51, 10, 14, 11,  4, 17, 53, 50, 36,
         34, 13,  4, 12, 31,

In [119]:
# Decode the single generated sequence (batch index 0) back into text.
print(decode(x[0]))


ZTqVd$C;T,&kp;iwhPo$NvHAs.Gju-ot,PrfAjHUOREr:Ww:z'CXicksD.$KC,sF$FDBYctjHqCHr-o$.tHA3qz&kag!SIqLR$:!tHIS ,tVT$;FiZvrP.AhDYTDE
VdeMe?hdRmpgn$NJH'zA.?STd 
vjrl-CE'ce&FX,Hy'NORwOectQUu!STzI!Vmp!G?SCuoC! BR LkFm:B;&EolXVA&?SQnRaIU!E;!&E-  &jWhB?STZWFXiipssVtNIRsjCHgu,rW'MwHonezon$owQSggW?SbdiTCxwE
GYm?DoFaw3Mg qe?iu -yLCkY.PRSTjHC UtMcPYYz:
mE?DYmn.BMN
w:STy'JPH
JfUaJB
3Y;wrWdyE?E&EkerQ&kknZTVZM'Eyp&t,xSVf3Lt:cx-j zxwD$Im$TYOcUCs:zPeWEMfrVlYod'MZ?cCXE?mh.AQbdB

SrericAhNFSCCEo$MO'CEP;ACZTB?-XBYwOWP:
