# Haiku Generator 
This neural net uses transformer to generate Haikus word by word.

In [1]:
import csv
import itertools
import re

with open('data/haikus.csv') as data:
    csv_data = [row for row in csv.DictReader(data)]

def row_to_lines(row):
    return ["\nHaiku:\n"] + list(itertools.chain.from_iterable([re.split(r'(\s+)', line) + ['\n'] for line in [row[str(col)] for col in range(3)]]))

tokens = list(itertools.chain.from_iterable([row_to_lines(row) for row in csv_data]))

len(csv_data), tokens[:10], len(tokens)

(143137,
 ['\nHaiku:\n', 'Memorial', ' ', 'Day', ' ', '--', '\n', 'a', ' ', 'shadow'],
 4071931)

In [2]:
context_size = 10
vocab = set(tokens)
vocab_size = len(vocab)
word_to_ix = {w:i for i, w in enumerate(vocab)}
ix_to_word = {i:w for i, w in enumerate(vocab)}
encode = lambda line: [word_to_ix[w] for w in list(line)]
decode = lambda ixs: ''.join([ix_to_word[ix] for ix in ixs])
vocab_size

71749

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Head(nn.Module):
    def __init__(self, embedding_dim, head_size, dropout=0.3, masked=True):
        super().__init__()
        self.masked = masked
        self.key = nn.Linear(embedding_dim, head_size, bias=False)
        self.query = nn.Linear(embedding_dim, head_size, bias=False)
        self.value = nn.Linear(embedding_dim, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('tril', torch.tril(torch.ones(context_size, context_size)))

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1)
        if self.masked:
            wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)

        return wei @ v

class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim, n_heads, dropout=0.3):
        super().__init__()
        head_size = embedding_dim // n_heads
        self.heads = nn.ModuleList([Head(embedding_dim, head_size) for _ in range(n_heads)])
        self.proj = nn.Linear(embedding_dim, embedding_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x):
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(x))

class FeedForward(nn.Module):
    def __init__(self, embedding_dim, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim),
            nn.ReLU(),
            nn.Linear(4 * embedding_dim, embedding_dim),
            nn.Dropout(dropout)
        )
    
    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, embedding_dim, n_heads, dropout=0.3):
        super().__init__()
        self.sa = MultiHeadAttention(embedding_dim, n_heads, dropout)
        self.feed_fwd = FeedForward(embedding_dim, dropout)
        self.ln1 = nn.Linear(embedding_dim, embedding_dim)
        self.ln2 = nn.Linear(embedding_dim, embedding_dim)
    
    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.feed_fwd(self.ln1(x))
        return x
    
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, classes, n_heads=4, dropout_prob=0.25, device='mps', masked=True):
        super().__init__()
        self.device = device
        self.tok_emb = nn.Embedding(vocab_size, embedding_dim)
        self.pos_emb = nn.Embedding(context_size, embedding_dim)
        self.blocks = nn.Sequential(
            Block(embedding_dim, n_heads, dropout_prob),
            Block(embedding_dim, n_heads, dropout_prob),
            Block(embedding_dim, n_heads, dropout_prob),
            Block(embedding_dim, n_heads, dropout_prob),
            Block(embedding_dim, n_heads, dropout_prob),
            Block(embedding_dim, n_heads, dropout_prob),
            nn.LayerNorm(embedding_dim)
        )

        self.fc = nn.Linear(embedding_dim, classes)
    
    def forward(self, x):
        _, T = x.size()
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(torch.arange(T, device=self.device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        out = self.fc(x)
        return out
    
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -context_size:]
            logits = self(idx_cond)
            logits = logits[:, -1, :] # becomes (B, C)
            probs = F.softmax(logits, dim=-1) # (B, C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [27]:
from torch.utils.data import TensorDataset, DataLoader

device = 'mps'
batch_size = 128
train_size = context_size * 1000 + 1
# control the size of words for training, it's too large for my m2 to process lol
train_tokens = tokens[:train_size]
val_tokens = tokens[train_size:2*train_size]

def get_dataloader(bow):
    n = len(bow) - 1
    ixs = [word_to_ix[word] for word in bow]
    x = torch.tensor(ixs[:n]).reshape((-1, context_size))
    y = torch.tensor(ixs[1:n+1]).reshape((-1, context_size))
    x = x.to(device)
    y = y.to(device)
    dataset = TensorDataset(x, y)
    return DataLoader(dataset, batch_size)

train_dataloader = get_dataloader(train_tokens)
val_dataloader = get_dataloader(val_tokens)

len(train_dataloader), len(val_dataloader),  len(train_tokens), len(val_tokens)

(8, 8, 10001, 10001)

In [28]:
device ='mps'
embedding_dim = 64
classes = vocab_size

model = TransformerModel(vocab_size, embedding_dim, classes, n_heads=4)
model = model.to(device)

In [29]:
epochs = 100
lr = 1e-3
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

In [30]:
import time
for k in range(epochs):
    model.train()
    train_loss = 0.0
    start_time = time.time()
    for xb, yb in train_dataloader:
        preds = model(xb)
        B, T, C = preds.shape
        loss = criterion(preds.view(B*T, C), yb.view(B*T))

        model.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    with torch.no_grad():
        model.eval()
        val_loss = 0.0
        for xb, yb in train_dataloader:
            preds = model(xb)
            B, T, C = preds.shape
            loss = criterion(preds.view(B*T, C), yb.view(B*T))
            val_loss += loss.item()
        
    if (k+1) % max(1, int(0.1*epochs)) == 0:
        end_time = time.time()
        torch.save(model, f'model.word.haiku.pt')
        print(f"({k+1}/{epochs}): train loss: {train_loss/len(train_dataloader):.4f}, val loss: {val_loss/len(val_dataloader):.4f} ({end_time - start_time:.2f}s)")
        start_time = time.time()

(10/100): train loss: 3.9485, val loss: 3.8031 (22.66s)
(20/100): train loss: 3.3173, val loss: 3.2526 (21.37s)
(30/100): train loss: 3.0124, val loss: 2.9165 (26.65s)
(40/100): train loss: 2.6971, val loss: 2.5723 (34.34s)
(50/100): train loss: 2.4342, val loss: 2.2658 (26.61s)
(60/100): train loss: 2.1254, val loss: 1.9427 (26.75s)
(70/100): train loss: 1.8812, val loss: 1.6497 (20.57s)
(80/100): train loss: 1.6384, val loss: 1.3749 (19.04s)
(90/100): train loss: 1.4393, val loss: 1.1159 (20.94s)
(100/100): train loss: 1.2225, val loss: 0.8941 (17.08s)


In [31]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=10)[0].tolist()))

 trains rejection Summer
too
in


In [44]:
def generate(model, idx, max_new_tokens):
    def inner(idx):
        idx_cond = idx[:, -context_size:]
        logits = model(idx_cond)
        logits = logits[:, -1, :] # becomes (B, C)
        probs = F.softmax(logits, dim=-1) # (B, C)
        idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
        idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
    for _ in range(max_new_tokens):
        idx = inner(idx)
        yield idx

In [96]:
prompt = encode("rainy day")
context = torch.tensor(prompt, device=device).view(1, -1)
for i, c in enumerate(generate(model, context, max_new_tokens=50)):
    for c in decode([c[0][i].tolist()]):
        time.sleep(.05)
        print(c, end='', flush=True)

rainy day him
your exit nightfall the shade
now off wisdom

Haiku:
summer's rising
my moon
my press

Haiku:
february steamed
just pepper
