<a href="https://colab.research.google.com/github/githubpradeep/notebooks/blob/main/fractal_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-08-14 12:37:16--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-08-14 12:37:16 (97.2 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [173]:
# Load the whole downloaded corpus into memory as one UTF-8 string.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

In [174]:
import torch

In [30]:
# data loading
def get_batch(split):
    """Draw one random mini-batch of (input, target) windows.

    Reads the module-level ``train_data`` / ``val_data`` tensors together with
    ``block_size`` and ``batch_size``.

    Args:
        split: ``'train'`` samples from ``train_data``; any other value samples
            from ``val_data``.

    Returns:
        ``(x, y)``: two stacked tensors of ``batch_size`` windows, each
        ``block_size`` long. ``y`` holds the same windows as ``x`` shifted one
        position to the right (the next-token targets).
    """
    source = val_data if split != 'train' else train_data
    # One random start offset per sequence in the batch.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])
    return torch.stack(inputs), torch.stack(targets)

In [36]:
import torch.nn as nn
from torch.nn import functional as F
# Seed PyTorch's global RNG (used for weight initialisation and batch sampling).
SEED = 1137
torch.manual_seed(SEED)

<torch._C.Generator at 0x7a9fb3378f70>

In [175]:
class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Construction reads the module-level ``n_embed`` (input width),
    ``block_size`` (side length of the causal mask) and ``dropout``
    (probability applied to the attention weights).

    Args:
        head_size: Output width of the key, query and value projections.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias = False)
        self.query = nn.Linear(n_embed, head_size, bias = False)
        self.value = nn.Linear(n_embed, head_size, bias = False)
        # Registered buffers follow the module through .to()/.cuda(), so the mask
        # does not need to be pinned to a particular device here.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, n_embed); returns (B, T, head_size).

        ``T`` must not exceed the ``block_size`` the mask was built with.
        """
        _, T, _ = x.shape
        k = self.key(x)
        q = self.query(x)
        # Scaled dot-product attention divides by sqrt(d_k), where d_k is the
        # width of the key vectors (head_size), not the embedding width. The two
        # only coincide when there is a single head.
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        # Lower-triangular mask: position t may only look at positions <= t.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        return wei @ v

class MulitHeadAttention(nn.Module):
    """Runs several ``Head`` modules on the same input and merges their outputs.

    The per-head outputs are concatenated along the last dimension, passed
    through a linear projection of width ``n_embed`` (module-level) and then
    through dropout.

    Args:
        num_heads: Number of ``Head`` instances to create.
        head_size: Output width handed to every ``Head``.
        dropout: Dropout probability applied after the projection.
    """

    def __init__(self, num_heads, head_size, dropout=0):
        super().__init__()
        head_list = []
        for _ in range(num_heads):
            head_list.append(Head(head_size))
        self.heads = nn.ModuleList(head_list)
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply every head to ``x``, concatenate, project, then drop out."""
        per_head = tuple(head(x) for head in self.heads)
        merged = torch.cat(per_head, dim=-1)
        projected = self.proj(merged)
        return self.dropout(projected)

class FractalBlock(nn.Module):
    """Fractal arrangement of attention layers spread over parallel columns.

    There are ``n_cols`` columns and ``2 ** (n_cols - 1)`` depth steps. Column
    ``c`` owns an attention layer at every depth ``i`` where ``(i + 1)`` is a
    multiple of ``max_depth // 2**c``: the first column has a single layer at
    the last step and the last column has a layer at every step. Slots without
    a layer hold ``None``. ``count[i]`` records how many columns own a layer at
    depth ``i``; those are always the trailing ``count[i]`` columns.

    Args:
        n_embed: Embedding width; each layer gets ``n_embed // n_head`` as head size.
        n_head: Number of heads per attention layer.
        n_cols: Number of parallel columns.
        dropout: Dropout probability applied to the block output.
    """

    def __init__(self,  n_embed, n_head, n_cols, dropout=0):
        super().__init__()
        self.n_cols = n_cols
        self.dropout = nn.Dropout(dropout)
        self.columns = nn.ModuleList([nn.ModuleList() for _ in range(n_cols)])
        self.max_depth = 2 ** (n_cols - 1)
        self.count = [0] * self.max_depth
        for col_idx, column in enumerate(self.columns):
            # Layer spacing halves with each successive column.
            spacing = self.max_depth >> col_idx
            for depth in range(self.max_depth):
                if (depth + 1) % spacing != 0:
                    column.append(None)
                    continue
                layer = MulitHeadAttention(n_head, n_embed // n_head)
                self.count[depth] += 1
                column.append(layer)

    def forward(self, x):
        """Push ``x`` through all columns, averaging the active ones at each depth."""
        # Every column starts from the same input.
        col_state = [x] * self.n_cols
        for depth in range(self.max_depth):
            first_active = self.n_cols - self.count[depth]
            active = range(first_active, self.n_cols)
            branch_outs = [self.columns[c][depth](col_state[c]) for c in active]
            # Join: element-wise mean over the columns that ran at this depth.
            joined = torch.stack(branch_outs).mean(dim=0)
            for c in active:
                col_state[c] = joined
        return self.dropout(col_state[-1])


class FeedForward(nn.Module):
    """Two-layer MLP: widen to ``4 * n_embed``, ReLU, project back, dropout.

    Args:
        n_embed: Width of the input and output features.
        dropout: Dropout probability applied after the second linear layer.
    """

    def __init__(self, n_embed, dropout=0):
        super().__init__()
        hidden = 4 * n_embed
        layers = [
            nn.Linear(n_embed, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embed),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        transformed = self.net(x)
        return transformed

class Block(nn.Module):
    """Pre-norm residual block: fractal attention followed by a feed-forward MLP.

    Args:
        n_embed: Feature width shared by every sub-layer.
        n_head: Number of heads per attention layer inside the fractal block.
        n_cols: Number of columns of the fractal block.
    """

    def __init__(self, n_embed, n_head, n_cols):
        super().__init__()
        self.sa_head = FractalBlock(n_embed, n_head, n_cols)
        self.ffw = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        """Add the attention branch, then the MLP branch, each fed a LayerNorm'd input."""
        after_attention = x + self.sa_head(self.ln1(x))
        mlp_branch = self.ffw(self.ln2(after_attention))
        return after_attention + mlp_branch


class FractalTransformer(nn.Module):
    """Decoder-only language model built from a stack of fractal ``Block``s.

    Construction reads the module-level ``vocab_size``, ``n_embed``,
    ``block_size``, ``n_head`` and ``n_layer``. All sub-modules are created on
    the default device; move the whole model with ``model.to(device)``.
    """

    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        # Uses `n_embed` like every other layer (previously the separately
        # defined `n_embd`), so the stack cannot drift from the embedding width.
        self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head, n_cols=4) for _ in range(n_layer)])
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (B, T) integer token ids, on the same device as the model.
            targets: Optional (B, T) integer token ids.

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` is (B, T, vocab_size)
            and ``loss`` is ``None``. With targets, ``logits`` is flattened to
            (B*T, vocab_size) and ``loss`` is the mean cross-entropy.
        """
        B, T = idx.shape

        token_emb = self.token_embedding_table(idx)
        # Build the position ids on whatever device the input lives on.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)
        x = self.blocks(token_emb + pos_emb)
        logits = self.lm_head(x)
        if targets is None:
            return logits, None

        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        loss = F.cross_entropy(logits, targets.view(B * T))
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokes):
        """Sample ``max_new_tokes`` additional tokens after the prompt ``idx``.

        Args:
            idx: (B, T) integer token ids, on the same device as the model.
            max_new_tokes: Number of tokens to append to every row.

        Returns:
            (B, T + max_new_tokes) tensor: the prompt followed by the samples.
        """
        # The position table fixes the longest context the model can consume.
        context_len = self.position_embedding_table.num_embeddings
        for _ in range(max_new_tokes):
            idx_cond = idx[:, -context_len:]
            logits, _ = self(idx_cond)
            # Only the prediction at the last position is needed for sampling.
            probs = F.softmax(logits[:, -1, :], dim = -1)
            idx_next = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, idx_next), dim = 1)
        return idx

In [176]:
# NOTE(review): FractalTransformer() reads vocab_size and the hyperparameters,
# which are only assigned in cells further down; on a fresh kernel this cell
# has to run after them.
model = FractalTransformer().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [177]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            X = X.to(device)
            Y=Y.to(device)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [178]:
# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
# Prefer CUDA, then Apple MPS, and fall back to the CPU when neither backend is
# available instead of assuming MPS exists.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
eval_iters = 200
n_embed = 32
n_embd = n_embed  # alias kept for code that uses the shorter spelling
n_head = 1
n_layer = 4
dropout = 0.0
# ------------

In [179]:
# Build the model and move all of its parameters and buffers to `device`, as the
# other model-construction cells do; previously this cell left the model where
# it was constructed.
# NOTE(review): this cell runs before `vocab_size` is assigned further down, so
# on a fresh kernel it has to be executed after the vocabulary cell.
model = FractalTransformer().to(device)
optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)

In [180]:
# Number of pieces obtained by splitting the corpus on single spaces.
len(text.split(' '))

169893

In [181]:
# Corpus length in characters.
len(text)

1115394

In [182]:
# Vocabulary = the distinct space-delimited tokens of the corpus, sorted.
# Despite its name, `chars` holds whole tokens rather than single characters.
unique_tokens = set(text.split(' '))
chars = sorted(unique_tokens)
vocab_size = len(chars)

In [183]:
# Look at one vocabulary entry; since the split is on spaces only, an entry can
# contain newline characters.
chars[2]

'\nWas'

In [184]:
# Lookup tables between vocabulary tokens and their integer ids.
stoi = {token: index for index, token in enumerate(chars)}
itos = {index: token for index, token in enumerate(chars)}


def encode(s):
    """Map an iterable of vocabulary tokens to the list of their integer ids.

    Raises KeyError for a token that is not in ``stoi``.
    """
    return [stoi[c] for c in s]


def decode(l):
    """Map an iterable of integer ids back to tokens and join them with spaces."""
    return " ".join([itos[x] for x in l])

In [185]:
# Encode the whole corpus as one 1-D tensor of token ids.
corpus_token_ids = encode(text.split(' '))
data = torch.tensor(corpus_token_ids, dtype=torch.long)

In [186]:
# First ten token ids of the encoded corpus.
data[:10]

tensor([ 1455,   957, 39874, 29614,  5949, 16628, 18572, 24432, 34050, 34057])

In [187]:
# Round trip: encode the first ten tokens and decode them back to text.
decode(encode(text.split(' ')[:10]))

'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst'

In [188]:
# The first ten raw tokens produced by splitting on single spaces.
text.split(' ')[:10]

['First',
 'Citizen:\nBefore',
 'we',
 'proceed',
 'any',
 'further,',
 'hear',
 'me',
 'speak.\n\nAll:\nSpeak,',
 'speak.\n\nFirst']

In [189]:
# The first 90% of the token stream is used for training; the rest is held out
# for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [190]:
# Fresh model and optimizer, built after the vocabulary and hyperparameters are
# in place; this is the instance the training loop below updates.
model = FractalTransformer().to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [191]:
max_iters = 1000
# `step` rather than `iter`: the loop variable stays defined after the cell has
# run, and `iter` would keep shadowing the builtin in every later cell.
for step in range(max_iters):

    # every `eval_interval` steps (and on the last step) evaluate the loss on
    # the train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')
    xb = xb.to(device)
    yb = yb.to(device)

    # evaluate the loss and take one optimizer step
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 11.0261, val loss 11.0367
step 100: train loss 8.1894, val loss 8.3922
step 200: train loss 8.0557, val loss 8.3469
step 300: train loss 7.9869, val loss 8.3925
step 400: train loss 7.9350, val loss 8.4275
step 500: train loss 7.8715, val loss 8.4019
step 600: train loss 7.7944, val loss 8.4512
step 700: train loss 7.7152, val loss 8.4584
step 800: train loss 7.6717, val loss 8.4280
step 900: train loss 7.5592, val loss 8.4500
step 999: train loss 7.4792, val loss 8.4682


In [192]:
# Encode a prompt, sample 200 more tokens from the model and print the text.
prompt_tokens = "thou art kneel before king".split(' ')
context = torch.tensor([encode(prompt_tokens)], dtype = torch.long).to(device)
generated_ids = model.generate(context, max_new_tokes=200)[0].tolist()
print(decode(generated_ids))


thou art kneel before king unavoided way
and my sight!

Nurse:
O out of wondrous linen. poor toward you.
Be what to instruct united grievous straight den.
Whose one day, at strike honour.

ANGELO:
See do, leave no.

BRUTUS:
The pities all shade you, do II:
A in long she father. hang speaks than thy tribunes; exile, monastery or lost, thou away.
3 speak:
Look when it ell quoth and thank Hereford, forfeits. though no lived!
A spectacle him?

DUKE death?
What, secrecy. fair, Menenius, Gaunt,
Even will eats now, given what's said will blot?

Abbot:
My our eyes time
Unfold full you are gold's and no palsy, VI

GLOUCESTER:
Now navy speak; there: your even: them
For for draw houses!

ROMEO:
This thee a years,
Pass'd foretell but them.

LUCIO:
Friar, after,
Is II:
Needs forth;
My thankful, think
That in dangerous
to 'gainst with
Hath and If a lordship France, how our fellows:
He had hands watch than princely free of his instruction can oft thy approbation.

CORIOLANUS:
Where? strength; think y

In [193]:
# Same prompt again: sampling is random, so this prints a different continuation.
prompt_tokens = "thou art kneel before king".split(' ')
context = torch.tensor([encode(prompt_tokens)], dtype = torch.long).to(device)
generated_ids = model.generate(context, max_new_tokes=200)[0].tolist()
print(decode(generated_ids))


thou art kneel before king if for brother, his sheets,
Which but our hale glassy shall makes by necessity;
For might me;
No, not thou art like sighs us sends
It IV:
Why, the kissing, one dies.

CLAUDIO:
Why slander comes courage souls.

First once.
Come way,
Which as circled on on!

CORIOLANUS:

VOLUMNIA:
I one is farewell.

EARL lo, of tongue?

Messenger:
Ah, to save time that to you hear. will, do pursue hath if not so and away, can here; expiring in our war
Their put along,
Holding prithee, be certain for York!
Suppose were at plainly the thorn.

DUKE Stanley;
Oxford, I jest;
His YORK:
Blind me spouts: nothing?
Why, to thee: or you do?

ISABELLA:
As the rigour again of death,
If shall not would them, by let me how year to heir king the certain.

DUKE a dearest noble swerve, state, sworn of all victory in the petition
well my sweet voices? upon
Show that noble Marian complaints
All brother false vessel down with men,
From not a other
station; Lucentio burst, 'tis by the traitor one L

In [194]:
# Switch to eval mode, then sample 200 tokens from a one-token prompt.
model.eval()
prompt_tokens = "Hermione".split(' ')
context = torch.tensor([encode(prompt_tokens)], dtype = torch.long).to(device)
generated_ids = model.generate(context, max_new_tokes=200)[0].tolist()
print(decode(generated_ids))


Hermione of men.

Second proud;
But, be friends.
Ourself the dullest be fault
I' will Eve, knight, that command be tasted, O looks desire of her be
called had Officer:
Faith, it fall and awake, occasion where in Helena.' your tent
I'll a heart.
Those showing hadst how they do more: Watchman:
To-morrow me Tybalts. than
his there,
With my garland. with it shall Tarquin's the truth, king;
Of you, and with the while!

Third is a rising this thee i' this witless Gloucester's put what that dim had witnesses and the be?--
With about this heed bounty. in youth
There you.

ANGELO:
Teach the house a prate,
And Clarence sighs i' desire down.

Servant:
What, man, RICHARD mouths O be wisdom my you?

CORIOLANUS:
I to the right; gnaw'd safe.

DUKE Ere further.

CORIOLANUS:
Ha! OF soldier.
Ah, peevish night!
This take
The so Richard's love,
Can melancholy
Hath on them, when I away! duty having, me lightness? humbly but near.
Go, smile I know a power tears depends
Upon to. revenge would took up doth, y