In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, Identity, Module
from collections import OrderedDict
import tiktoken

In [2]:
def exists(v):
    return v is not None

In [3]:
def default(v, d):
    return v if exists(v) else d

In [4]:
def heinsen_associative_scan_log(log_coeffs, log_values):
    a_star = log_coeffs.cumsum(dim = 1)
    log_h0_plus_b_star = (log_values - a_star).logcumsumexp(dim = 1)
    log_h = a_star + log_h0_plus_b_star
    return log_h.exp()

In [5]:
def g(x):
    return torch.where(x >= 0, x + 0.5, x.sigmoid())

In [6]:
def log_g(x):
    return torch.where(x >= 0, (F.relu(x) + 0.5).log(), -F.softplus(-x))

In [7]:
class minGRU(Module):
    def __init__(self, dim, expansion_factor = 1., proj_out = None):
        super().__init__()

        dim_inner = int(dim * expansion_factor)
        proj_out = default(proj_out, expansion_factor != 1.)

        self.to_hidden_and_gate = Linear(dim, dim_inner * 2, bias = False)
        self.to_out = Linear(dim_inner, dim, bias = False) if proj_out else Identity()

    def forward(self, x, prev_hidden = None, return_next_prev_hidden = False):
        seq_len = x.shape[1]
        hidden, gate = self.to_hidden_and_gate(x).chunk(2, dim = -1)

        if seq_len == 1:
            # handle sequential

            hidden = g(hidden)
            gate = gate.sigmoid()
            out = torch.lerp(prev_hidden, hidden, gate) if exists(prev_hidden) else (hidden * gate)
        else:
            # parallel

            log_coeffs = -F.softplus(gate)

            log_z = -F.softplus(-gate)
            log_tilde_h = log_g(hidden)
            log_values = log_z + log_tilde_h

            if exists(prev_hidden):
                log_values = torch.cat((prev_hidden.log(), log_values), dim = 1)
                log_coeffs = F.pad(log_coeffs, (0, 0, 1, 0))

            out = heinsen_associative_scan_log(log_coeffs, log_values)
            out = out[:, -seq_len:]

        next_prev_hidden = out[:, -1:]

        out = self.to_out(out)

        if not return_next_prev_hidden:
            return out

        return out, next_prev_hidden

In [8]:
hidden_size: int = 8
seq_length: int = 4
batch_size: int = 2
vocabulary_size: int = 50274

In [17]:
embeddings = nn.Embedding(num_embeddings=vocabulary_size, embedding_dim=hidden_size)


In [9]:
# Initializes the Multihead attention layer
multihead_attn = torch.nn.MultiheadAttention(
    embed_dim=hidden_size,
    num_heads=4,
    dropout=0.5,
    bias=False,
    batch_first=True
)

In [18]:
min_gru = minGRU(hidden_size)

In [24]:
tokenizer = tiktoken.get_encoding("gpt2")
batch = []
txt1 = "Every effort moves you"
txt2 = "Every day holds a"

batch.append(torch.tensor(tokenizer.encode(txt1), dtype=torch.int32))
batch.append(torch.tensor(tokenizer.encode(txt2), dtype=torch.int32))

batch = torch.stack(batch, dim=0)
print(batch)

tensor([[6109, 3626, 6100,  345],
        [6109, 1110, 6622,  257]], dtype=torch.int32)


In [25]:
x = embeddings(batch)

In [26]:
out, hidden = min_gru(x, return_next_prev_hidden=True)

to_att = out + hidden

In [27]:
with_atten, _  = multihead_attn(to_att, to_att, to_att)
with_atten.shape

torch.Size([2, 4, 8])

In [28]:
norm = nn.LayerNorm(hidden_size)
with_atten = norm(with_atten)

In [15]:
mlp = nn.Sequential(OrderedDict([
    ('dense1', nn.Linear(hidden_size, 1000)),
    ('act1', nn.ReLU()),
    ('dense2', nn.Linear(1000, 5000)),
    ('act2', nn.ReLU()),
    ('output', nn.Linear(5000, vocabulary_size)),
    ('outact', nn.Sigmoid()),
]))
mlp

Sequential(
  (dense1): Linear(in_features=8, out_features=1000, bias=True)
  (act1): ReLU()
  (dense2): Linear(in_features=1000, out_features=5000, bias=True)
  (act2): ReLU()
  (output): Linear(in_features=5000, out_features=50274, bias=True)
  (outact): Sigmoid()
)

In [29]:
final = mlp(with_atten)
print(final.shape)
with torch.no_grad():
    final = torch.softmax(final[:, -1, :], dim=-1)
print(final)
label = torch.argmax(final, dim=-1)
print(label)

torch.Size([2, 4, 50274])
tensor([[2.0210e-05, 1.9131e-05, 2.0274e-05,  ..., 1.9674e-05, 1.9835e-05,
         2.0532e-05],
        [2.0561e-05, 1.9077e-05, 2.0153e-05,  ..., 1.9848e-05, 1.9829e-05,
         1.9613e-05]])
tensor([42963, 22670])


In [47]:
ll = label.tolist()
print(len(ll))
w1 = tokenizer.decode([label.tolist()[1]])
print(w1)
print(len(label.tolist()))

2
 strips
2
