In [2]:
import math, os
import numpy as np
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F

In [3]:
class ChannelMix(nn.Module):
    def __init__(self, layer_id, n_layer, n_embed):
        super().__init__()
        self.layer_id = layer_id
        
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        
        with torch.no_grad():
            ratio_1_to_almost0 = 1.0 - layer_id/n_layer
            x = torch.ones(1,1, n_embed)
            for i in range(n_embed):
                x[0, 0, i] = i/n_embed
            
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
        
        hidden_size = 4*n_embed
        self.key = nn.Linear(n_embed, hidden_size, bias=False)
        self.receptance = nn.Linear(n_embed, n_embed, bias=False)
        
        self.value = nn.Linear(hidden_size, n_embed, bias=False)
        
    def forward(self, x):
        xx = self.time_shift(x)
        xk = x * self.time_mix_k + (1-self.time_mix_k) * xx

        xr = x * self.time_mix_r + (1-self.time_mix_r) * xx
        
        k = self.key(xk)
        k = torch.square(torch.relu(k))
        
        kv = self.value(k)
        
        
        
        rkv = torch.sigmoid(self.receptance(xr)) * kv
        
        return rkv
        

In [4]:
class TimeMix(nn.Module):
    def __init__(self, layer_id, n_layer, n_embed):
        super().__init__()
        self.layer_id = layer_id
        
        attn_sz = n_embed
        with torch.no_grad():
            ratio_1_to_almost0 = 1.0 - layer_id/n_layer
            ratio_0_to_1 = layer_id / (n_layer - 1)
            
            decay_speed = torch.ones(attn_sz)
            for h in range(attn_sz):
                decay_speed[h] = -5 + 8 * (h / (attn_sz-1)) ** (0.7 + 1.3 * ratio_0_to_1)
                
            self.time_decay = nn.Parameter(decay_speed)
            
            zigzag = (torch.tensor([(i+1)%3 - 1 for i in range(attn_sz)]) * 0.5)
            self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
                
            
            x = torch.ones(1,1, n_embed)
            for i in range(n_embed):
                x[0, 0, i] = i/n_embed
            
            self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
            self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) +0.3 * ratio_0_to_1)
            self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
        
            self.aa = nn.Parameter(torch.ones(1,1,attn_sz))
            self.bb = nn.Parameter(torch.ones(1,1,attn_sz))
            pp = torch.ones(1,1,attn_sz)
            pp = pp * -1e30
            self.pp = nn.Parameter(pp)
            self.xx = nn.Parameter(torch.ones(1,1,attn_sz))
        
        hidden_size = attn_sz
        self.key = nn.Linear(n_embed, attn_sz, bias=False)
        self.receptance = nn.Linear(n_embed, attn_sz, bias=False)
        
        self.value = nn.Linear(hidden_size, attn_sz, bias=False)
        self.output = nn.Linear(attn_sz, n_embed, bias=False)
        
    def forward(self, x):
        
        xx = self.xx

       
        xk = x * self.time_mix_k + (1-self.time_mix_k) * xx
        xv = x * self.time_mix_v + (1-self.time_mix_v) * xx
        xr = x * self.time_mix_r + (1-self.time_mix_r) * xx
        
    
    
    
        
        k = self.key(xk)
        
        
        v = self.value(xv)
        r= self.receptance(xr)
        
            
        r =torch.sigmoid(r)
        
        # Calculate the difference in size along the non-singleton dimension
        diff = k.shape[1] - self.aa.shape[1]
    
        
        b,t,c = x.shape
        aa = torch.nn.functional.pad(self.aa, (0, 0, 0, diff, 0, 0))
        bb = torch.nn.functional.pad(self.bb, (0, 0, 0, diff, 0, 0))
        pp = torch.nn.functional.pad(self.pp, (0, 0, 0, diff, 0, 0))
             
            
        ww = self.time_first + k
    
        
        qq = torch.maximum(pp, ww )
        e1 = torch.exp(pp - qq)
        e2 = torch.exp(ww - qq)
        
        a = e1 * aa + e2 * v
        
        b = e1 * bb + e2
        wkv = a / b
        
        ww = pp + self.time_decay
        qq = torch.maximum(ww, k)
        e1 = torch.exp(ww - qq)
        e2 = torch.exp(k - qq)
        
        with torch.no_grad():
            xx = nn.Parameter(x)
            self.aa = nn.Parameter(e1 * aa + e2 * v)
            self.bb = nn.Parameter(e1 * bb + e2)
            self.pp = nn.Parameter(qq)
            
        
        return self.output(r * wkv)
    
        

In [5]:
class Block(nn.Module):
    def __init__(self, layer_id, n_layer, n_embd):
        super().__init__()
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(n_embd)

        if self.layer_id == 0 :
            self.ffnPre = ChannelMix(0, n_layer, n_embd)
        else:
            self.att = TimeMix(layer_id, n_layer, n_embd)

        self.ffn = ChannelMix(layer_id, n_layer, n_embd)

    def forward(self, x):
        if self.layer_id == 0:
            x = self.ln0(x)        
        if self.layer_id == 0 :
            x = x + self.ffnPre(self.ln1(x))  # better in some cases
        else:
            x = x + self.att(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x
    
    

In [8]:
class RWKV(nn.Module):
    def __init__(self, n_layer, vocab_size,  n_embd, ctx_len):
        super().__init__()
        self.step = 0
        self.ctx_len = ctx_len
        self.emb = nn.Embedding(vocab_size, n_embd)

        self.blocks = nn.Sequential(*[Block(i, n_layer, n_embd)
                                    for i in range(n_layer)])

        self.ln_out = nn.LayerNorm(n_embd)
        self.head = nn.Linear(n_embd, vocab_size, bias=False)
    
    def forward(self, idx, targets=None):
            idx = idx.to(self.emb.weight.device)

            self.step += 1
            
            B, T = idx.size()
            assert T <= self.ctx_len, "Cannot forward, because len(input) > model ctx_len."

            x = self.emb(idx)
            x = self.blocks(x)
            x = self.ln_out(x)

            x = self.head(x)
            

            loss = None
            if targets is not None:
                loss = F.cross_entropy(x.view(-1, x.size(-1)), targets.to(x.device).view(-1))
            x = torch.mean(x, dim=0, keepdim=True)
            return x, loss
        
    def generate(self, idx, max_new_tokes):
        for _ in range(max_new_tokes):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim = -1)
            idx_next = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, idx_next), dim = 1)
        return idx

In [9]:
# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0
# ------------

In [10]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [11]:
len(text)

1115394

In [12]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [13]:
chars = sorted(list(set(text)))
vocab_size = len(chars)

In [14]:
print(''.join(chars))
vocab_size


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


65

In [15]:
stoi = { ch:i for i, ch in enumerate(chars) }
itos = {i:ch for i,ch in enumerate(chars)}

In [16]:
itos = {i:ch for i,ch in enumerate(chars)}

In [17]:
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: "".join([itos[x] for x in l])

In [18]:
encode("hi there")

[46, 47, 1, 58, 46, 43, 56, 43]

In [19]:
decode([46, 47, 1, 58, 46, 43, 56, 43])

'hi there'

In [20]:
import torch

In [21]:
data = torch.tensor(encode(text), dtype = torch.long)

In [22]:
data.shape, type(data)

(torch.Size([1115394]), torch.Tensor)

In [23]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [24]:
block_size = 8

In [25]:
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [26]:
x = train_data[:block_size]

In [27]:
y = train_data[1:block_size+1]

In [28]:
x,y

(tensor([18, 47, 56, 57, 58,  1, 15, 47]),
 tensor([47, 56, 57, 58,  1, 15, 47, 58]))

In [29]:
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print("ctx ", context, "target", target)

ctx  tensor([18]) target tensor(47)
ctx  tensor([18, 47]) target tensor(56)
ctx  tensor([18, 47, 56]) target tensor(57)
ctx  tensor([18, 47, 56, 57]) target tensor(58)
ctx  tensor([18, 47, 56, 57, 58]) target tensor(1)
ctx  tensor([18, 47, 56, 57, 58,  1]) target tensor(15)
ctx  tensor([18, 47, 56, 57, 58,  1, 15]) target tensor(47)
ctx  tensor([18, 47, 56, 57, 58,  1, 15, 47]) target tensor(58)


In [30]:
torch.manual_seed(1337)

<torch._C.Generator at 0x108178850>

In [31]:
batch_size = 4

In [32]:
# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [33]:
xb, yb = get_batch('train')

In [34]:
xb.shape

torch.Size([4, 8])

In [35]:
yb.shape

torch.Size([4, 8])

In [36]:
xb

tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])

In [37]:
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b][:t+1]
        target = yb[b][t]
        print(context, "-----", target)

tensor([24]) ----- tensor(43)
tensor([24, 43]) ----- tensor(58)
tensor([24, 43, 58]) ----- tensor(5)
tensor([24, 43, 58,  5]) ----- tensor(57)
tensor([24, 43, 58,  5, 57]) ----- tensor(1)
tensor([24, 43, 58,  5, 57,  1]) ----- tensor(46)
tensor([24, 43, 58,  5, 57,  1, 46]) ----- tensor(43)
tensor([24, 43, 58,  5, 57,  1, 46, 43]) ----- tensor(39)
tensor([44]) ----- tensor(53)
tensor([44, 53]) ----- tensor(56)
tensor([44, 53, 56]) ----- tensor(1)
tensor([44, 53, 56,  1]) ----- tensor(58)
tensor([44, 53, 56,  1, 58]) ----- tensor(46)
tensor([44, 53, 56,  1, 58, 46]) ----- tensor(39)
tensor([44, 53, 56,  1, 58, 46, 39]) ----- tensor(58)
tensor([44, 53, 56,  1, 58, 46, 39, 58]) ----- tensor(1)
tensor([52]) ----- tensor(58)
tensor([52, 58]) ----- tensor(1)
tensor([52, 58,  1]) ----- tensor(58)
tensor([52, 58,  1, 58]) ----- tensor(46)
tensor([52, 58,  1, 58, 46]) ----- tensor(39)
tensor([52, 58,  1, 58, 46, 39]) ----- tensor(58)
tensor([52, 58,  1, 58, 46, 39, 58]) ----- tensor(1)
tensor([

In [974]:
model = RWKV( n_layer, vocab_size, n_embd, block_size)
optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)

In [975]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [982]:

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % 100 == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 1.5167, val loss 1.6944
step 100: train loss 1.5086, val loss 1.7114
step 200: train loss 1.5106, val loss 1.7054
step 300: train loss 1.5207, val loss 1.6954
step 400: train loss 1.5127, val loss 1.7040
step 500: train loss 1.5057, val loss 1.7026
step 600: train loss 1.5067, val loss 1.7084
step 700: train loss 1.5077, val loss 1.6865
step 800: train loss 1.5061, val loss 1.7005
step 900: train loss 1.5075, val loss 1.6916
step 1000: train loss 1.5104, val loss 1.6998
step 1100: train loss 1.4992, val loss 1.6832
step 1200: train loss 1.4922, val loss 1.6968
step 1300: train loss 1.4983, val loss 1.6771
step 1400: train loss 1.4985, val loss 1.6779
step 1500: train loss 1.4890, val loss 1.6820
step 1600: train loss 1.4924, val loss 1.6834
step 1700: train loss 1.4882, val loss 1.6829
step 1800: train loss 1.4816, val loss 1.6748
step 1900: train loss 1.4968, val loss 1.6904
step 2000: train loss 1.4859, val loss 1.6827
step 2100: train loss 1.4841, val loss 1.7038


In [983]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokes=1000)[0].tolist()))



Adather Bolingham.

HERMIONE:
Your should by enform'd in hither or, e--

YORK:
I'll prunt.

GLOUCESTER:

VELSABELLA:
Ours proves:
Commending bands,
And with a well for good know, I besides the day qdoth:
Underful fire amputted. By like the know
I thank for me as is time in eases now thee there;
Nay, mark show Lords:
Sir that joyful very look'd, as I go come brother to the delivers to was copted Boilosom, or plebeian to Rome a good?

MENENIUS:
The house? whengeful doing to thy king,
Like of
doth a boy! your honourable poor like to your fatal, to see thee
Evence
Of you
Do Screaters o' to is wish of thy drial.

TLAND:
Nor queen.

Second Servingman:
So.

COMINIUS:
They in heart ly scent state they,nd it did I rather know avouch. Plantagenets their ends
Should not the
duked.

VIRGILKA:
She's king our is my father's into spiteous
Some doubtful gapery nurse,
I meic,
The each his gallet. But, right for runt of Clifford your kerned a plain time I am, proved to you say fulls,
And the hand:
Why 

In [38]:
chars = sorted(list(set(text.split(' '))))
vocab_size = len(chars)

In [39]:
chars[2]

'\nWas'

In [40]:
stoi = { ch:i for i, ch in enumerate(chars) }
itos = {i:ch for i,ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: " ".join([itos[x] for x in l])

In [41]:
data = torch.tensor(encode(text.split(' ')), dtype = torch.long)

In [42]:
decode(encode(text.split(' ')[:10]))

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst'

In [43]:
text.split(' ')[:10]

['First',
 'Citizen:\nBefore',
 'we',
 'proceed',
 'any',
 'further,',
 'hear',
 'me',
 'speak.\n\nAll:\nSpeak,',
 'speak.\n\nFirst']

In [44]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [119]:
# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [125]:
get_batch('train')[0].shape

torch.Size([4, 8])

In [120]:
block_size, n_embd, learning_rate, batch_size


(8, 64, 0.001, 4)

In [None]:
block_size=8

In [131]:
model = RWKV( 5, vocab_size, 256, block_size)
optimizer = torch.optim.AdamW(model.parameters(),lr=3e-4)

In [132]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [133]:
optimizer = torch.optim.AdamW(model.parameters(),lr=3e-4)
batch_size = 4
for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % 100 == 0 or iter == max_iters - 1:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 10.8263, val loss 10.8326
step 100: train loss 8.4848, val loss 8.4930
step 200: train loss 8.3536, val loss 8.4139
step 300: train loss 8.3051, val loss 8.4341
step 400: train loss 8.2986, val loss 8.3662
step 500: train loss 8.2414, val loss 8.2713
step 600: train loss 8.1505, val loss 8.2766
step 700: train loss 8.1353, val loss 8.2955
step 800: train loss 8.0703, val loss 8.1881
step 900: train loss 7.9545, val loss 8.2166
step 1000: train loss 8.0764, val loss 8.2661
step 1100: train loss 8.0368, val loss 8.3165
step 1200: train loss 8.0470, val loss 8.2838
step 1300: train loss 7.9314, val loss 8.2970
step 1400: train loss 7.9337, val loss 8.2967
step 1500: train loss 7.8668, val loss 8.2682
step 1600: train loss 7.8822, val loss 8.2731
step 1700: train loss 7.8482, val loss 8.3902
step 1800: train loss 7.8613, val loss 8.2876
step 1900: train loss 7.8271, val loss 8.2647
step 2000: train loss 7.7569, val loss 8.4001
step 2100: train loss 7.7817, val loss 8.287

In [134]:
context = torch.tensor([encode("HENRY all divided? night".split(' '))], dtype = torch.long)
print(decode(model.generate(context, max_new_tokes=100)[0].tolist()))


HENRY all divided? night to that Richmond rough said a
pleasure the boy do on,
And begging he O seven justice: and yet enter it?

PETER:
I.

PETRUCHIO:
'Tis sit for myself,
No if both Lewis upon thy oweth me
As be enemy How O, and not shall
encounter show,
And policy that there,
Rather reign fame that holes
Where come deny you would be executed since do I water and anchor?
And war's me from side,
To ones.

AUTOLYCUS:
Why, did,
And noble means that factions
and and monument!

ANGELO:
I believe I beasts,
That hands;
Swear in. ANNE:
All their life and death,
And stay?

CAMILLO:
At of my gown? here?

QUEEN the eager Margaret made.

LADY he are seeth others are presence, of Rome.

Second pride
To hands a flowers.

Servant:
Why another.
The


In [136]:
context = torch.tensor([encode("George".split(' '))], dtype = torch.long)
print(decode(model.generate(context, max_new_tokes=100)[0].tolist()))


George are sir? cozeners thee, thee loves tongues and sit prudence; and said heartless it.

MENENIUS:
Note they stay.

HENRY Meeting as men, done't--
Harp the store,
One night, purpose no must
you? character our window, it fall whose repose dish'd
For lark ones.

BRUTUS:
Come, years a man.

ARCHIDAMUS:
Would office
Becomes me come traitors! thy when
for is either.
Here judgment-place.
Once done her view,
And every you cut horse.
Now, quiet deformed heart;
The as great worth himself, him! hath live.
It a dine respects sooner.

POLIXENES:
Dear all the home.

AUFIDIUS:
I obsequious of that? blood?

Nurse:
It sit the
duke point.

ISABELLA:
O, here;
Better either death were Lord, not be build for circled peace.
O conquer'd; thine herself yet they too; how name, dear lord,
