In [1]:
device = 'mps'

In [2]:
from typing import List
import tqdm.notebook as tqdm

In [3]:
with open("data/tiny-shakespeare.txt") as f:
    corpus = f.read()

In [4]:
split = 0.8
train_text = corpus[:int(len(corpus)*split)]
test_text = corpus[int(len(corpus)*split):]

print(f"Training Size: {len(train_text):,} chars")
print(f"Test Size: {len(test_text):,} chars")

Training Size: 892,314 chars
Test Size: 223,079 chars


In [5]:
import transformers
import torch

In [6]:
config = transformers.LlamaConfig(
    vocab_size=256,
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=4,
    num_attention_heads=4,
    max_position_embeddings=128,
    attention_dropout=0.01
)

In [7]:
model = transformers.LlamaForCausalLM(config).to(device)
print(f"Model has {model.num_parameters():,} parameters")
num_train_steps = 0

Model has 722,048 parameters


In [8]:
from typing import List

In [9]:
def encode(list: List[str]) -> torch.Tensor:
    result = []
    for string in list:
        result.append([ord(c) for c in string])
    return torch.tensor(result)

def decode(arr: torch.Tensor) -> str:
    arr = arr.tolist()
    result = []
    for string in arr:
        result.append("".join([chr(c) for c in string]))
    return result

def get_batch(batch_size, seq_length, use_train=True):
    if use_train:
        data = train_text
    else:
        data = test_text
    
    start_idx = torch.randint(0, len(data) - seq_length, (batch_size,))

    batch = []
    for i in start_idx:
        batch.append(encode([data[i:i+seq_length]]))
        
    return torch.stack(batch).view(batch_size, seq_length)

In [10]:
criterion = torch.nn.CrossEntropyLoss()

In [25]:
# Validation
@torch.no_grad()
def evaluate():
    model.eval()
    batch = get_batch(64, 128, use_train=False).to(device)
    x = batch[:,:-1]
    y = batch[:,1:]

    output = model(x, labels=y)

    return output.loss.item()

evaluate()

2.275230884552002

In [22]:
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)

In [29]:
model.train()
pbar = tqdm.tnrange(10_000)
for step in pbar:
    # Encode the sequence
    batch = get_batch(64, 128).to(device)
    x = batch[:,:-1]
    y = batch[:,1:]
        
    # Forward pass
    output = model(x, labels=y)

    loss = output.loss
        
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # scheduler.step()

    if step % 100 == 0:
        val_loss = evaluate()

    num_train_steps += 1
        
    pbar.set_postfix_str(f"Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}")
    pbar.set_description_str(f"num_train_steps: {num_train_steps:,}")

  0%|          | 0/10000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [28]:
print(decode(model.generate(encode(["First Citizen:"]).to(device),max_new_tokens=100))[0])

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


First Citizen:T
o oret o orbde o orbde o o o o ore
hsml o o o o ote o o o o o o o ote
o o o o o ote orbde o o o o 


In [14]:
model.eval()
print(decode(model.generate(encode(["First Citizen:"]).to(device), 256, temperature=1.0, top_k=None))[0])

AttributeError: 'int' object has no attribute 'update'