In [16]:
import torch

In [17]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device} device')

Using cuda device


In [18]:
# Load the raw training corpus as a single string.
data_file_name = 'tiny_shakespeare'
data_file_path = 'data/' + data_file_name + '.txt'
with open(data_file_path, mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()
len(text)

1115394

### Vocab

In [19]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
unique_symbols = sorted(set(text))
vocab_size = len(unique_symbols)

print(''.join(unique_symbols), f'{vocab_size=}', sep='\n')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab_size=65


### Encoder/decoder

In [20]:
# Bidirectional lookup tables between characters and integer token ids.
itos = dict(enumerate(unique_symbols))  # token id -> symbol
stoi = {symbol: idx for idx, symbol in itos.items()}  # symbol -> token id


def encode(s):
    """Map a string to the list of its token ids (symbols to tokens).

    Raises KeyError for a character that is not in the vocabulary.
    """
    return [stoi[ch] for ch in s]


def decode(l):
    """Map an iterable of token ids back to the string they spell (tokens to symbols).

    Raises KeyError for an id that is not in the vocabulary.
    """
    return ''.join(itos[i] for i in l)

### Save vocab

In [21]:
# Save the id -> symbol table (the decoder half of the vocab) so model output
# can be decoded later without re-reading the corpus.
import json
import os

vocab_file_path = f'vocabs/{data_file_name}_vocab.json'
# Create the target directory if needed, so a fresh checkout doesn't fail here.
os.makedirs(os.path.dirname(vocab_file_path) or '.', exist_ok=True)
with open(vocab_file_path, 'w', encoding='utf-8') as f:
    # NOTE: JSON object keys are always strings, so the int ids in `itos` are
    # written as "0", "1", ...; a loader must convert them back with int().
    json.dump(itos, f, indent=4, ensure_ascii=False)

### Train/val split

In [22]:
# Encode the whole corpus once, then hold out the final 10% of tokens for validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))  # index of the first validation token

train_data, val_data = data[:n], data[n:]

print(train_data.shape, val_data.shape)

torch.Size([1003854]) torch.Size([111540])


### Model config and batch sampling

In [23]:
import transformer.transformer
import transformer.config

# NOTE(review): this binds (does not copy) the module-level default config, so
# the assignments below mutate transformer.config.config_default itself.
config = transformer.config.config_default
# Size the model's vocabulary to this corpus and record the compute device.
config.device = device
config.vocab_size = vocab_size

In [None]:
torch.manual_seed(1337)  # makes the batch offsets drawn below reproducible
batch_size = 128

def get_batch(split):
    """Sample a random batch of (input, next-token target) sequences.

    Args:
        split: 'train' to sample from train_data, 'val' to sample from val_data.

    Returns:
        (x_batch, y_batch): for 1-D token tensors, two tensors of shape
        (batch_size, config.block_size) moved to `device`; y_batch is
        x_batch shifted one position to the right.

    Raises:
        ValueError: if split is neither 'train' nor 'val'.
    """
    if split == 'train':
        source = train_data
    elif split == 'val':
        source = val_data
    else:
        # Reject typos instead of silently sampling from the validation split.
        raise ValueError(f"split must be 'train' or 'val', got {split!r}")

    # randint's upper bound is exclusive, so the largest start index is
    # len(source) - block_size - 1 and the shifted target window still fits.
    ix = torch.randint(len(source) - config.block_size, (batch_size,))
    x_batch = torch.stack([source[i:i + config.block_size] for i in ix])
    y_batch = torch.stack([source[i + 1:i + config.block_size + 1] for i in ix])
    return x_batch.to(device), y_batch.to(device)

## Model

In [25]:
# Build the Decoder model from `config`; .to() moves it to `device` and returns the same module.
decoder = transformer.transformer.Decoder(config)
m = decoder.to(device)

In [26]:
num_eval_batches = 100

@torch.no_grad()
def estimate_loss():
    """Estimate the mean loss of `m` on the train and val splits.

    For each split, averages the loss over `num_eval_batches` freshly sampled
    batches with gradients disabled and the model in eval mode. The model's
    previous train/eval mode is restored afterwards, even if evaluation raises.
    Note: sampling batches advances the global torch RNG.

    Returns:
        dict mapping 'train' and 'val' to the mean loss as a Python float.
    """
    out = {}
    was_training = m.training
    m.eval()
    try:
        for split in ('train', 'val'):
            losses = torch.zeros(num_eval_batches)
            for n_batch in range(num_eval_batches):
                x_batch, y_batch = get_batch(split)
                _, loss = m(x_batch, y_batch)
                losses[n_batch] = loss.item()
            out[split] = losses.mean().item()
    finally:
        # Restore the caller's mode instead of unconditionally switching to train.
        m.train(was_training)
    return out

In [27]:
max_iters = 6_000
# The learning rate is multiplied by gamma (0.1) at each of these scheduler steps.
# Assuming scheduler.step() runs once per training iteration, lr goes
# 1e-3 -> 1e-4 (step 4000) -> 1e-5 (step 5000) -> 1e-6 (step 5700).
scheduler_steps = [4_000, 5_000, 5_700]

optimizer = torch.optim.AdamW(params=m.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=scheduler_steps,
    gamma=0.1,
)

In [28]:
for step in range(max_iters):
    # Sample and run the forward pass first: estimate_loss() below also draws
    # batches, so this order fixes which training batches are seen.
    xb, yb = get_batch('train')
    _, loss = m(xb, yb)

    # Report every 1000 steps and on the final iteration (before its update).
    should_report = (step % 1000 == 0) or (step == max_iters - 1)
    if should_report:
        split_losses = estimate_loss()
        current_lr = scheduler.get_last_lr()[0]
        print(f"step {step}: train loss {split_losses['train']:.4f}, val loss {split_losses['val']:.4f}, lr {current_lr}")

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    scheduler.step()

# tiny shakespeare:
# added just blocks: train loss 2.1716, val loss 2.2226
# added skip connections: train loss 1.9001, val loss 2.0237
# added layer norm: train loss 1.8880, val loss 2.0194
# after model was scaled up: train loss 1.2350, val loss 1.4911

step 0: train loss 4.3826, val loss 4.3819, lr 0.001
step 1000: train loss 1.5468, val loss 1.7271, lr 0.001
step 2000: train loss 1.3874, val loss 1.5989, lr 0.001
step 3000: train loss 1.3218, val loss 1.5566, lr 0.001
step 4000: train loss 1.2793, val loss 1.5202, lr 0.0001
step 5000: train loss 1.2431, val loss 1.5052, lr 1e-05
step 5999: train loss 1.2350, val loss 1.4911, lr 1.0000000000000002e-06


In [29]:
# Sample 500 new tokens, seeded with token id 0 (unique_symbols[0], the smallest
# character in the sorted vocab). The training loop leaves the model in training
# mode, so switch to eval mode to disable training-only layers (e.g. dropout, if any).
m.eval()
# Use the same integer dtype (torch.long) as the training token tensors.
start_token = torch.zeros((1, 1), dtype=torch.long, device=device)
with torch.no_grad():  # inference only: don't build an autograd graph
    gen_text = m.generate(start_token, 500)
print(decode(gen_text[0].tolist()))


First, thy case.

ANTIGONUS:
I know 'tis well?
'Tis good Montague doef Northumberland pity?'
And go,
And fear'st sharp him atternony; you are all.
And to confusions, my sentence worse to thence;
On some to a voice must blushing to do be patient,
To Both her band up Burgunding but frottunes
Tell for one several thandsmen how you have
The gods you seen?

WARWICK:
O he break, Mistoner, is bounded harb'd put to put the.

FoRD:

ROMEO:
I will not I do remember in stand functly
The creek morn of good,


In [30]:
import os

# Persist only the learned weights (state_dict); loading them requires building
# a Decoder with the same config first.
save_model_path = f'models/{data_file_name}_model.pth'
# Create the target directory if needed, so a fresh checkout doesn't fail here.
os.makedirs(os.path.dirname(save_model_path) or '.', exist_ok=True)
torch.save(m.state_dict(), save_model_path)