This code is based on https://www.youtube.com/watch?v=0Ag83EhYD7k&t=1s.

In [242]:
import math, os
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F

In [243]:
class ChannelMix(nn.Module):
    def __init__(self, layer_id, n_layer, n_embed):
        super().__init__()
        self.layer_id = layer_id
        
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1)) # How does this work to induce time shift?
        
        with torch.no_grad():
            ratio_1_to_almost0 = 1.0 - layer_id/n_layer # What does this ratio do?
            x = torch.ones(1,1,n_embed)
            for i in range(n_embed):
                x[0, 0, i] = i/n_embed
            
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) # This initializes the trainable value mu_k (eq 17)
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) # This initializes the trainable value mu_r (eq 16)
            
        hidden_size = 4*n_embed
        self.key = nn.Linear(n_embed, hidden_size, bias=False) # This function left-multiplies by W_k (eq 17)
        self.receptance = nn.Linear(n_embed, n_embed, bias=False) # This function left-multiplies by W_r (eq 16)
        
        self.value = nn.Linear(hidden_size, n_embed, bias=False) # This function left-multiplies by W_v (eq 18)
    
    def forward(self, x):
        xx = self.time_shift(x)
        xr = x * self.time_mix_r + (1-self.time_mix_r) * xx # eq 16
        xk = x * self.time_mix_k + (1-self.time_mix_k) * xx # eq 17
        
        k = self.key(xk) # eq 17 cont.
        r = self.receptance(xr) # eq 16 cont.
        
        k = torch.square(torch.relu(k)) # eq 18
        rkv = torch.sigmoid(r) * self.value(k) # eq 18 cont.
        
        return rkv

In [244]:
class TimeMix(nn.Module):
    def __init__(self, layer_id, n_layer, n_embed):
        super().__init__()
        self.layer_id = layer_id
        
        attn_sz = n_embed
        with torch.no_grad():
            ratio_1_to_almost0 = 1.0 - layer_id/n_layer
            ratio_0_to_1 = layer_id / (n_layer - 1)
            
            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1)
                
            self.time_decay = nn.Parameter(decay_speed)
            
            zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5)
            # In the paper, this variable is u... what does it do?
            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
            
            x = torch.ones(1,1, n_embed)
            for i in range(n_embed):
                x[0,0,i] = i/n_embed
            
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) # This initializes the trainable value mu_k (eq 12)
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) # This initializes the trainable value mu_v (eq 13)
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0)) # This initializes the trainable value mu_r (eq 11)
            
            # aa, bb, pp, xx are all involved in the RNN formulation of time-mixing (see Appendix B)
            self.aa = nn.Parameter(torch.ones(1,1,attn_sz))
            self.bb = nn.Parameter(torch.ones(1,1,attn_sz))
            pp = torch.ones(1,1,attn_sz)
            pp = pp * -1e30
            self.pp = nn.Parameter(pp)
            self.xx = nn.Parameter(torch.ones(1,1,attn_sz))
            
        hidden_size = attn_sz
        self.key = nn.Linear(n_embed, attn_sz, bias=False)  # This function left-multiplies by W_k (eq 12)
        self.receptance = nn.Linear(n_embed, attn_sz, bias=False)  # This function left-multiplies by W_r (eq 11)
        
        self.value = nn.Linear(hidden_size, attn_sz, bias=False)  # This function left-multiplies by W_v (eq 13)
        self.output = nn.Linear(attn_sz, n_embed, bias=False)  # This function left-multiplies by W_o (eq 15)
    
    def forward(self, x):
        
        xx = self.xx
        
        xk = x * self.time_mix_k + (1-self.time_mix_k) * xx
        xv = x * self.time_mix_v + (1-self.time_mix_v) * xx
        xr = x * self.time_mix_r + (1-self.time_mix_r) * xx
        
        k = self.key(xk)
        v = self.value(xv)
        r = self.receptance(xr)
        
        r = torch.sigmoid(r)
        
        # Calculate the difference in size along the non-singleton dimension
        diff = k.shape[1] - self.aa.shape[1]
        
        b,t,c = x.shape
        aa = torch.nn.functional.pad(self.aa, (0,0,0,diff,0,0))
        bb = torch.nn.functional.pad(self.bb, (0,0,0,diff,0,0))
        pp = torch.nn.functional.pad(self.pp, (0,0,0,diff,0,0))
        
        ww = self.time_first + k # eq 25 (u + k_t)
        
        qq = torch.maximum(pp, ww) # eq 25
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)
        
        a = e1 * aa + e2 * v # eq 26 
        b = e1 * bb + e2 # eq 27
        wkv = a / b # eq 28
        
        ww = pp + self.time_decay # eq 29... why + not -?
        qq = torch.maximum(ww, k) # eq 29
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)
        
        with torch.no_grad():
            xx = nn.Parameter(x) # should this be 'self.xx = ...'?
            self.aa = nn.Parameter(e1 * aa + e2 * v) # eq 30
            self.bb = nn.Parameter(e1 * bb + e2) # eq 31
            self.pp = nn.Parameter(qq) # eq 32
            
        return self.output(r * wkv)

In [245]:
class Block(nn.Module):
    def __init__(self, layer_id, n_layer, n_embed):
        super().__init__()
        self.layer_id = layer_id
        
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)
        
        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(n_embed)
        
        if self.layer_id == 0:
            self.ffnPre = ChannelMix(0, n_layer, n_embed)
        else:
            self.att = TimeMix(layer_id, n_layer, n_embed)
        
        self.ffn = ChannelMix(layer_id, n_layer, n_embed)
        
    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)
        if self.layer_id == 0:
            x = x + self.ffnPre(self.ln1(x))
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x

In [246]:
class RWKV(nn.Module):
    def __init__(self, n_layer, vocab_size, n_embed, ctx_len):
        super().__init__()
        self.step = 0
        self.ctx_len = ctx_len
        self.emb = nn.Embedding(vocab_size, n_embed)
        
        self.blocks = nn.Sequential(*[Block(i, n_layer, n_embed)
                                     for i in range(n_layer)])
        self.ln_out = nn.LayerNorm(n_embed)
        # head is the output layer that redimensions to the vocab size to predict an output word
        self.head = nn.Linear(n_embed, vocab_size, bias=False)
        
    def forward(self, idx, targets=None):
        idx = idx.to(self.emb.weight.device)
        
        self.step += 1
        
        B, T = idx.size()
        assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."
        
        x = self.emb(idx)
        x = self.blocks(x)
        x = self.ln_out(x)
        
        x = self.head(x)
        
        loss = None
        if targets is not None:
            loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1))
        x = torch.mean(x, dim=0, keepdim=True)
        return x,loss
    
    def generate(self, idx, max_new_tokes):
        for _ in range(max_new_tokes):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim = -1)
            idx_next = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, idx_next), dim = 1)
        return idx

In [247]:
# hyperparameters
batch_size = 16 # how many independent sequences are processed in parallel
block_size = 32 # what is the maximum context length for predictions? # Does this make sense for RWKV???
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embed = 64
n_head = 4
n_layer = 4
dropout = 0.0
# --------

Testing the model.

In [248]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [249]:
len(text)

1115394

In [250]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [251]:
chars = sorted(list(set(text)))
vocab_size = len(chars)

In [252]:
print(''.join(chars)) # print the vocabulary
vocab_size


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


65

In [253]:
stoi = { ch:i for i, ch in enumerate(chars) }
itos = { i:ch for i, ch in enumerate(chars) }

In [254]:
encode = lambda s : [stoi[c] for c in s]
decode = lambda l : "".join([itos[x] for x in l])

In [255]:
encode("what's up!")

[61, 46, 39, 58, 5, 57, 1, 59, 54, 2]

In [256]:
decode([61, 46, 39, 58, 5, 57, 1, 59, 54, 2])

"what's up!"

In [257]:
import torch

In [258]:
data = torch.tensor(encode(text), dtype = torch.long)

In [259]:
data.shape, type(data)

(torch.Size([1115394]), torch.Tensor)

In [260]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [261]:
block_size = 8

In [262]:
x = train_data[:block_size]

In [263]:
y = train_data[1:block_size+1] # sliding window for training

In [264]:
x,y

(tensor([18, 47, 56, 57, 58,  1, 15, 47]),
 tensor([47, 56, 57, 58,  1, 15, 47, 58]))

In [265]:
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print("ctx ", context, "target", target)

ctx  tensor([18]) target tensor(47)
ctx  tensor([18, 47]) target tensor(56)
ctx  tensor([18, 47, 56]) target tensor(57)
ctx  tensor([18, 47, 56, 57]) target tensor(58)
ctx  tensor([18, 47, 56, 57, 58]) target tensor(1)
ctx  tensor([18, 47, 56, 57, 58,  1]) target tensor(15)
ctx  tensor([18, 47, 56, 57, 58,  1, 15]) target tensor(47)
ctx  tensor([18, 47, 56, 57, 58,  1, 15, 47]) target tensor(58)


In [266]:
torch.manual_seed(1337)

<torch._C.Generator at 0x28a5a7f46b0>

In [267]:
batch_size = 4

In [268]:
# data loading
def get_batch(split):
    # generate a batch of input data x and targets y
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size, ))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix]) # Train on every subsequence within the input sequence
    return x,y

In [269]:
xb, yb = get_batch('train')

In [270]:
print(xb)
xb.shape == yb.shape

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])


True

In [271]:
# Print all the training examples in this batch.
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b][:t+1]
        target = yb[b][t]
        print(context, "-----", target)

tensor([24]) ----- tensor(43)
tensor([24, 43]) ----- tensor(58)
tensor([24, 43, 58]) ----- tensor(5)
tensor([24, 43, 58,  5]) ----- tensor(57)
tensor([24, 43, 58,  5, 57]) ----- tensor(1)
tensor([24, 43, 58,  5, 57,  1]) ----- tensor(46)
tensor([24, 43, 58,  5, 57,  1, 46]) ----- tensor(43)
tensor([24, 43, 58,  5, 57,  1, 46, 43]) ----- tensor(39)
tensor([44]) ----- tensor(53)
tensor([44, 53]) ----- tensor(56)
tensor([44, 53, 56]) ----- tensor(1)
tensor([44, 53, 56,  1]) ----- tensor(58)
tensor([44, 53, 56,  1, 58]) ----- tensor(46)
tensor([44, 53, 56,  1, 58, 46]) ----- tensor(39)
tensor([44, 53, 56,  1, 58, 46, 39]) ----- tensor(58)
tensor([44, 53, 56,  1, 58, 46, 39, 58]) ----- tensor(1)
tensor([52]) ----- tensor(58)
tensor([52, 58]) ----- tensor(1)
tensor([52, 58,  1]) ----- tensor(58)
tensor([52, 58,  1, 58]) ----- tensor(46)
tensor([52, 58,  1, 58, 46]) ----- tensor(39)
tensor([52, 58,  1, 58, 46, 39]) ----- tensor(58)
tensor([52, 58,  1, 58, 46, 39, 58]) ----- tensor(1)
tensor([

In [272]:
model = RWKV(n_layer, vocab_size, n_embed, block_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [273]:
# Show all the parameters
for name, param in model.named_parameters():
    if param.requires_grad:
        print(name)

emb.weight
blocks.0.ln1.weight
blocks.0.ln1.bias
blocks.0.ln2.weight
blocks.0.ln2.bias
blocks.0.ln0.weight
blocks.0.ln0.bias
blocks.0.ffnPre.time_mix_k
blocks.0.ffnPre.time_mix_r
blocks.0.ffnPre.key.weight
blocks.0.ffnPre.receptance.weight
blocks.0.ffnPre.value.weight
blocks.0.ffn.time_mix_k
blocks.0.ffn.time_mix_r
blocks.0.ffn.key.weight
blocks.0.ffn.receptance.weight
blocks.0.ffn.value.weight
blocks.1.ln1.weight
blocks.1.ln1.bias
blocks.1.ln2.weight
blocks.1.ln2.bias
blocks.1.att.time_decay
blocks.1.att.time_first
blocks.1.att.time_mix_k
blocks.1.att.time_mix_v
blocks.1.att.time_mix_r
blocks.1.att.aa
blocks.1.att.bb
blocks.1.att.pp
blocks.1.att.xx
blocks.1.att.key.weight
blocks.1.att.receptance.weight
blocks.1.att.value.weight
blocks.1.att.output.weight
blocks.1.ffn.time_mix_k
blocks.1.ffn.time_mix_r
blocks.1.ffn.key.weight
blocks.1.ffn.receptance.weight
blocks.1.ffn.value.weight
blocks.2.ln1.weight
blocks.2.ln1.bias
blocks.2.ln2.weight
blocks.2.ln2.bias
blocks.2.att.time_decay
block

In [274]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [275]:
for iter in range(max_iters):
    
    # Evalute the loss on train and val sets once every several iterations
    if iter % 100 == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f'step {iter}: train loss {losses["train"]:.4f}, val loss {losses["val"]:.4f}')
    
    # sample a batch of data
    xb, yb = get_batch('train')
    
    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.3031, val loss 4.3177
step 100: train loss 2.7708, val loss 2.7559
step 200: train loss 2.5842, val loss 2.5392
step 300: train loss 2.5145, val loss 2.5212
step 400: train loss 2.4379, val loss 2.4470
step 500: train loss 2.3493, val loss 2.3397
step 600: train loss 2.3284, val loss 2.3304
step 700: train loss 2.2880, val loss 2.2946
step 800: train loss 2.2990, val loss 2.2783
step 900: train loss 2.2725, val loss 2.2741
step 1000: train loss 2.2792, val loss 2.2668
step 1100: train loss 2.2448, val loss 2.2322
step 1200: train loss 2.2273, val loss 2.2267
step 1300: train loss 2.1578, val loss 2.2308
step 1400: train loss 2.1760, val loss 2.2264
step 1500: train loss 2.1870, val loss 2.2046
step 1600: train loss 2.1164, val loss 2.1818
step 1700: train loss 2.1855, val loss 2.1814
step 1800: train loss 2.1226, val loss 2.1617
step 1900: train loss 2.1552, val loss 2.1607
step 2000: train loss 2.0931, val loss 2.1866
step 2100: train loss 2.1319, val loss 2.1464


In [277]:
context = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokes=1000)[0].tolist()))


Tha, wation his iuglat rewis?

JULIY:
No do come will. i' no not moun you gar; who go is shounberse?

Clauton loblestgian is theing chown anyst beear the Sog,
Come of dishallowon'd what in play I hupose to for mustroye some ewhich.

Shithern, in nou upothen is, po of come otriwarest disperue,

Mushould for of with sir,
Thou my not: he goo, sha seecome whick, my don not one who heave for ohe acts would onour withsvain town, ba have which twer,
What don I will lought ou host we one is at is reso the wordon what sone, seer more with and givey, a ted to geout this exRie i as was him say
in't is in worddy tlies, abribeit
if she comet.

AUBE
Not quie.
Thee with is fighty
So no gueibs wolp 'people care's heried shall, belagt:
If of sen fall will at but sold mine somear:
Now of a hath: cready, wous try's glacher;
Bead.
By shy to reard.
Ord 

MENE:
Go neath'teart dear
he sage vooy Pome.
Conet who I should welverief it your sourtue the hold chulkiname and Rome's act;
But not this is so hiveeds 

In [279]:
chars = sorted(list(set(text.split(' '))))
vocab_size = len(chars)

In [288]:
stoi = { ch:i for i, ch in enumerate(chars) }
itos = { i:ch for i, ch in enumerate(chars) }
encode = lambda s : [stoi[c] for c in s]
decode = lambda l : " ".join([itos[x] for x in l])

In [289]:
data = torch.tensor(encode(text.split(' ')), dtype = torch.long)

In [290]:
decode(encode(text.split(' ')[:10]))

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst'

In [291]:
text.split(' ')[: 10]

['First',
 'Citizen:\nBefore',
 'we',
 'proceed',
 'any',
 'further,',
 'hear',
 'me',
 'speak.\n\nAll:\nSpeak,',
 'speak.\n\nFirst']

In [292]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [293]:
# data loading
def get_batch(split):
    # generate a batch of input data x and targets y
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - block_size, (batch_size, ))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix]) # Train on every subsequence within the input sequence
    return x,y

In [294]:
get_batch('train')[0].shape

torch.Size([4, 8])

In [295]:
block_size, n_embed, learning_rate, batch_size

(8, 64, 0.001, 4)

In [296]:
block_size = 8

In [302]:
model = RWKV(5, vocab_size, 256, block_size)
optimizer = torch.optim.AdamW(model.parameters(),lr = 3e-4)

In [303]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [304]:
optimizer = torch.optim.AdamW(model.parameters(),lr=3e-4)
batch_size = 4
for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % 100 == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 10.7898, val loss 10.8070
step 100: train loss 8.5281, val loss 8.4840
step 200: train loss 8.3793, val loss 8.3692
step 300: train loss 8.3006, val loss 8.3293
step 400: train loss 8.2460, val loss 8.3631
step 500: train loss 8.1796, val loss 8.3537
step 600: train loss 8.1512, val loss 8.2259
step 700: train loss 8.1313, val loss 8.2875
step 800: train loss 8.1293, val loss 8.3172
step 900: train loss 8.0813, val loss 8.2449
step 1000: train loss 8.0372, val loss 8.1785
step 1100: train loss 7.9527, val loss 8.2137
step 1200: train loss 7.8982, val loss 8.3145
step 1300: train loss 7.9780, val loss 8.1859
step 1400: train loss 7.9334, val loss 8.2834
step 1500: train loss 7.9432, val loss 8.2515
step 1600: train loss 7.8950, val loss 8.3395
step 1700: train loss 7.9012, val loss 8.3019
step 1800: train loss 7.8195, val loss 8.2493
step 1900: train loss 7.7865, val loss 8.2996
step 2000: train loss 7.8558, val loss 8.3906
step 2100: train loss 7.7159, val loss 8.260

In [305]:
context = torch.tensor([encode("O Romeo, Romeo,".split(' '))], dtype = torch.long)
print(decode(model.generate(context, max_new_tokes=100)[0].tolist()))

O Romeo, Romeo, of his mother,
He dried thence.

Second sickness the king,
I'll of your friends':
God and misusest.

KING government,
Shall, to Vienna.

POMPEY:
Does law disdains fair Lodowick?

LUCIO:
My proclaim'd: doth not answer, will not hap? me to meddle with within! maw, good:
the fall, the crown.

YORK:
'Twas then divine men?
If as humane
And one: not to him: VINCENTIO:
There's partake must be the matter:--Nurse, No, loves OF get entreat no baited
With but same it Paulina,
Make a take
From OVERDONE:
What's well; aid: begot the earth to offends suspicion! upon this crown and both do; grows to her I side;
The I can Down, learn to take promise keep to I:
No, heart
To OF last,
Definitively and condition,


In [309]:
torch.save(model, "rwkw_shakespeare_words_model.pt")