In [1]:
import os
import sys
sys.path.append(os.path.abspath(os.path.join('..')))

In [2]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import parse_shape, rearrange, repeat

def count_parameters(model: nn.Module):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [3]:
sequence_length = 100 # L
batch_size = 10 # B
input_size = 15 # N
hidden_size = 20 # H == output_size

rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size, num_layers=1)

input = torch.randn(sequence_length, batch_size, input_size)
h0 = torch.randn(1, batch_size, hidden_size)
output, hn = rnn(input, h0)

out_shape = parse_shape(output, 'L B H')
hn_shape = parse_shape(hn, '_ B H')

print(out_shape, hn_shape)

# Good thing is that we can provide inputs of different lengths.
# Bad thing is that the entire provided sequence is computed sequentially. 

{'L': 100, 'B': 10, 'H': 20} {'B': 10, 'H': 20}


In [4]:
from brainle.models.architectures.attention import CausalSelfAttention

net = CausalSelfAttention(
    embedding_dim = 512,
    num_heads = 4, 
    block_size = 100,
    dropout_attention = 0.5,
    dropout_residual = 0.5
)

print(net(torch.rand(1, 100, 512)).shape)
count_parameters(net)



torch.Size([1, 100, 512])


1050624

In [6]:
from brainle.models.architectures.attention import GPT
    
mingpt = GPT(
    vocabulary_size = 682,
    embedding_dim = 512,
    num_layers = 8,
    num_heads = 8, 
    block_size = 128,
    dropout_embedding = 0.1,
    dropout_attention = 0.1,
    dropout_residual = 0.1,
)

x = torch.randint(low=0, high=682, size=(2, 128)) 
mingpt(x).shape
count_parameters(mingpt)

25984000

In [7]:
from brainle.models.architectures.attention import SelfMemoryEncode, SelfMemoryDecode

encode = SelfMemoryEncode(
    embedding_dim = 1024,
    num_heads = 8,
    memory_size = 4096,
    kernel_size = 4,
    stride = 2,
    padding = 1
)   

x = torch.rand(1, 1000, 1024)
print(x.shape)

z = encode(encode(encode(encode(encode(encode(encode(x)))))))
print(z.shape)

decode = SelfMemoryDecode(
    embedding_dim = 1024,
    num_heads = 8,
    memory_size = 4096,
    kernel_size = 4,
    stride = 2,
    padding = 1
)

out = decode(decode(decode(z)))
print(out.shape)

print(f"Encode params: {count_parameters(encode)}")
print(f"Decode params: {count_parameters(decode)}")

# Receptive field with L encode layers:
k = 4
L = 6
receptive_field = 2**L * (k-1) - k + 2 # ... exponential in L
print(f"""
With L={L}, kernel_size={k}, stride=2 we have a receptive field of {receptive_field} tokens at L=0, and:
* With an average english sentence of ~20 words at 5 letters per word we get max {receptive_field / 100} sentences of attention per token at L={L} 
""")

torch.Size([1, 1000, 1024])
torch.Size([1, 7, 1024])
torch.Size([1, 56, 1024])
Encode params: 2758688
Decode params: 2758784

With L=6, kernel_size=4, stride=2 we have a receptive field of 190 tokens at L=0, and:
* With an average english sentence of ~20 words at 5 letters per word we get max 1.9 sentences of attention per token at L=6 



In [8]:
from datasets import load_dataset

dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split='train')

text = ''
for i in range(len(dataset)):
    text += dataset[i]['text']

text[0:1000]

Reusing dataset wikitext (/Users/flavioschneider/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


' = Valkyria Chronicles III = \n Senjō no Valkyria 3 : Unrecorded Chronicles ( Japanese : 戦場のヴァルキュリア3 , lit . Valkyria of the Battlefield 3 ) , commonly referred to as Valkyria Chronicles III outside Japan , is a tactical role @-@ playing video game developed by Sega and Media.Vision for the PlayStation Portable . Released in January 2011 in Japan , it is the third game in the Valkyria series . Employing the same fusion of tactical and real @-@ time gameplay as its predecessors , the story runs parallel to the first game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven " . \n The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more forgiving for series n

In [9]:
from brainle.datamodules.datasets.masked_char_dataset import MaskedCharDataset 

    
dataset = MaskedCharDataset(text, block_size=128, p_word_mask=0.15, p_char_mask=0.05)
x, mask = dataset[42]
chunk = dataset.decode(x.numpy())
chunk = ''.join([ char if mask[idx] else '▢' for idx, char in enumerate(chunk) ])

print(len(dataset), len(x), len(mask))
print(chunk)

Alphabet size 837
10892783 128 128
lkyria 3 : Unrecorded Chronicles ( Japanese : ▢場のウァルキュリア3 ▢ ▢▢▢ . Valkyria▢of ▢he B▢ttle▢ield 3 ) , com▢only referred to as Valk


In [10]:
from brainle.models.architectures.attention import SMUNet

net = SMUNet(
    vocabulary_size = 682,
    embedding_dim = 1024,
    num_layers = 1,
    memory_sizes = [512],
    num_heads = 8,
    kernel_size = 2,
    stride = 2,
    padding = 0,
    use_skip = True 
)

x = torch.randint(low=0, high=682, size=(2, 2048)) 
y = net(x)
print(y.shape)
count_parameters(net)

torch.Size([2, 2048, 682])


7903082

In [11]:
x = torch.randint(low=0, high=682, size=(2, 2048)) 
y = F.softmax(net(x), dim=-1)
#print(y.shape, y)

ids = torch.topk(y, k=1, dim=-1)[1]
print(ids[0].shape)
ids = rearrange(ids, 'b s 1 -> b s')
print(ids[0].shape)
print(ids[0][0:100])

chunks = []
for i in range(2):
    chunks += [''.join(dataset.decode(ids[i].numpy().tolist()))]
    
#print(chunks)

torch.Size([2048, 1])
torch.Size([2048])
tensor([472, 361, 472, 361, 472, 361, 472, 361, 472, 361, 472, 361, 472, 361,
        472, 361, 472, 361, 472, 361, 472, 361, 472, 361, 472, 361, 472, 361,
        472, 361, 472, 361, 472, 361, 472, 361, 472, 361, 472, 361, 472, 361,
        472, 361, 472, 361, 472, 361, 472, 361, 472, 361, 472, 361, 472, 361,
        472, 361, 472, 361, 472, 361, 472, 361, 472, 361, 472, 361, 472, 361,
        472, 361, 472, 361, 472, 361, 472, 361, 472, 361, 472, 361, 472, 361,
        472, 361, 472, 361, 472, 361, 472, 361, 472, 361, 472, 361, 472, 361,
        472, 361])
