In [6]:
#----- imports --------
import tqdm
import torch
import wandb
import os
import tokenizers


device= 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)
#assert device == 'cuda', "This notebook is not optimized for CPU"

config = {
    "learning_rate": 1e-3,
    "eval_interval": 300,
    "max_iters": 5000, 
    "H": 4,
    "B": 64,
    "T": 16,
    "C": 128,
    "n_heads": 2,
    "dropout": 0.01,
    "n_layers": 2,
    "git_hash": os.popen("git rev-parse HEAD").read().strip()
}

# initial
for k,v in config.items():
    locals ()[k] = v


wandb.init(
    project = "mini-shakespeare-tokenized",
    config = config
)

[34m[1mwandb[0m: Currently logged in as: [33mgmchad[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [7]:

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [8]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [9]:
print("length of dataset in lines: ", len(text.split('\n')))

length of dataset in lines:  40001


In [10]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [118]:
# load  the tokenizer 
import tokenizers
tokenizer = tokenizers.ByteLevelBPETokenizer()
tokenizer._tokenizer = tokenizers.ByteLevelBPETokenizer(
    "./shakespeare-2k-bpe-vocab.json",
    "./shakespeare-2k-bpe-merges.txt",
)

enc = tokenizer.encode("Romeo Romeo wherefore art thou Romeo?")
tokenizer.decode(enc.ids)


# stoi = {ch: i for i, ch in enumerate(chars)}
# itos = {i: ch for i, ch in enumerate(chars)}

def encode(text):
    # return [stoi[ch] for ch in text]
    return tokenizer.encode(text).ids
def decode(encoded_text):
    # return ''.join([itos[i] for i in encoded_text])
    return tokenizer.decode(encoded_text)

hello_encoded = encode("hello")
print(hello_encoded)
print(decode(hello_encoded))
vocab_size = tokenizer.get_vocab_size()


# ~ compression ratio of tokens
""" tok_ratio = len(text) / len(encode(text))
print(tok_ratio) """
None

[262, 278, 83]
hello


In [27]:
enc.ids # [0,2047]
cas # [0,44]

cas[enc.ids]
print(cas)
print(cas.shape)

tensor([35, 35, 35,  ...,  9, 43,  2], dtype=torch.int32)
torch.Size([2048])


In [12]:

data = torch.tensor(encode(text), dtype=torch.long)
print(data.dtype)
print(data.size())
print(data.device)


torch.int64
torch.Size([388693])
cpu


In [13]:
n = int(0.9*len(data))

train_data = data[:n]
val_data = data[n:]

In [14]:
train_data[:T+1]

tensor([ 676, 1201,   30,  203,  779,  553,  336,  589, 1817,  807, 2008,  719,
          16,  679,  322,  621,   18])

In [15]:
x = train_data[:T]
y = train_data[1:T+1]
for t in range(T):
    context = x[:t+1]
    target = y[t]
    # print("when we see the text", context, "we predict the next character is", target)

In [16]:
torch.manual_seed(1337)

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(0, data.size(0) - T, (B,)) # 4 random locations we can sample from
    x = torch.stack([data[i:i+T] for i in ix]) # random sequences
    y = torch.stack([data[i+1:i+T+1] for i in ix]) # next character for each random sequence

    return x, y

xb, yb = get_batch('train')

for b in range(B):
    for t in range(T): # for each of the characters in the sample
        context = xb[b, :t+1]
        target = yb[b, t]


In [17]:

import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)


class Head(nn.Module):
    '''One Head of self-attention'''
    def __init__(self, H):
        super().__init__()
        self.query = nn.Linear(C, H, bias=False)
        self.key = nn.Linear(C, H, bias=False)
        self.value = nn.Linear(C, H, bias=False)
        # self.output = nn.Linear(H, C, bias=False) # output matrix
        self.register_buffer('tril', torch.tril(torch.ones(T, T)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Query and Key matrices for the attention mechanism
        # x: 8 tokens
        # Q: 16 tall (arbitrary), 32 long channels
        # K: 16 tall (arbitrary), 32 long channels

        query_vectors = self.query(x)
        key_vectors = self.key(x)


        # Attention masking(so we can't look into the past):

        tril = self.tril
        wei = torch.zeros(T, T) 
        wei = wei.masked_fill(tril == 0, float('-inf')) # set the upper triangular to -inf
        # xbow = wei @ x # apply the mask to the input, bag of words because simple avg.

        # multiply the two to get the attention weights
        attention_pattern = query_vectors @ key_vectors.transpose(-2, -1) # T, T
        attention_pattern = attention_pattern / (H ** 0.5) # scale the attention pattern for numerical stability
        attention_weights = F.softmax(attention_pattern + wei, dim=-1) # T, T (the row dimension is the query)
        attention_weights = self.dropout(attention_weights)

        value_vectors = self.value(x) # the direction we should go in the embedding space for each token (ie more blue) T, H

        # apply the attention weights to the value vectors
        context = attention_weights @ value_vectors # T, H

        # project back into original space from value space
        # return self.output(context)
        return context

x = torch.randn(B,T,C)
head = Head(H)
# head(x)

In [18]:
class MultiHeadAttention(nn.Module):
    '''Multiple heads of self-attention'''
    def __init__(self, H, C, n_heads): # H is head embedding space size, n_heads is number of heads
        super().__init__()
        self.heads = nn.ModuleList([Head(H) for _ in range(n_heads)])
        self.combine_heads = nn.Linear(H*n_heads, C)
        self.dropout = nn.Dropout(dropout)


    def forward(self,x):
        x = torch.cat([head(x) for head in self.heads], dim=-1)
        x = self.combine_heads(x)  # T, C
        return self.dropout(x)

In [19]:
head = MultiHeadAttention(H, C, n_heads)
head.heads[0].forward(x).shape


torch.Size([64, 16, 4])

In [20]:
class FeedForward(nn.Module):
    '''Feed-forward neural network'''
    def __init__(self, C):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(C, C * 4),
            nn.ReLU(),
            nn.Linear(C * 4, C),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

In [21]:
class LayerNorm(nn.Module):
    '''Layer normalization'''
    def __init__(self, C, use_affine=True):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(C)) if use_affine else None
        self.beta = nn.Parameter(torch.zeros(C)) if use_affine else None

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        if self.gamma is not None and self.beta is not None:
            return self.gamma * (x - mean) / (std + 1e-6) + self.beta
        else:
            return (x - mean) / (std + 1e-6)

In [22]:
class Block(nn.Module):
    '''Transformer block'''
    def __init__(self, H, C, n_heads):
        super().__init__()
        self.attention = MultiHeadAttention(H, C, n_heads)
        self.ff = FeedForward(C)
        self.norm1 = LayerNorm(C, use_affine=True)
        self.norm2 = LayerNorm(C, use_affine=True)

    def forward(self, x):
        x = x + self.attention(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x

In [23]:
class GPT(nn.Module):

    def __init__(self, n_layers):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, C) 
        self.position_embedding_table = nn.Embedding(T, C)
        self.lm_head = nn.Linear(C, vocab_size)
        self.layers = nn.ModuleList([Block(H, C, n_heads) for _ in range(n_layers)])
        self.block = nn.ModuleList([Block(H, C, n_heads)])
    
    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_emb = self.token_embedding_table(idx) # batch_dim, sequence_dim, embedding_dim
        pos_emb = self.position_embedding_table(torch.arange(T))
        x = token_emb + pos_emb # token identities and positions contained

        for layer in self.layers:
            x = layer(x)

        logits = self.lm_head(x) # batch_dim, sequence_dim, vocab_size
        

        batch_dim, sequence_dim, embedding_dim = logits.size()

        # loss = F.cross_entropy(logits, targets) this won't work because we need 1d logits and 1d targets
        # one-hot-vectors are a line in the x-dimension, so the shape of shape of the logits should be (-1, vocab_size).

        if targets is None:
            return logits, None
        else:
            # a list of all the predictions, reguardles of batch.
            # xdim: probabilities of each character in the vocab (embedding_dim=vocab_size)
            # ydim: all predictions for all batches flattened (batch_dim*sequence_dim)
            logits_loss_view = logits.view(-1, vocab_size) 
            # targets loss view
            # xdim: all targets for all batches flattened (batch_dim*sequence_dim)
            # so this would be like, [1,4,5,1,2,3, ...]
            # where each number is the correct next index of the one hot vector
            targets_loss_view = targets.view(-1)
            loss = F.cross_entropy(logits_loss_view, targets_loss_view)
            return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx[:,-T:])
            # get the predictions of the last token
            last_token_logits = logits[:, -1, :] # all batches, last token, all probabilities
            # softmax to get probabilities
            probabilities = F.softmax(last_token_logits, dim=-1)
            # sample from the probabilities
            next_token = torch.multinomial(probabilities, num_samples=1)
            # add the new token to the idx tensor
            idx = torch.cat((idx, next_token), dim=1)
        return idx

    

model = GPT(n_layers)
logits, loss = model(xb, yb)
print(logits.shape)
print(loss)




test_idx = torch.zeros(1, T).long()
model.forward(idx=test_idx)
# decode(model.generate(idx=test_idx, max_new_tokens=100)[0].tolist())

torch.Size([64, 16, 2048])
tensor(7.9725, grad_fn=<NllLossBackward0>)


(tensor([[[-0.5444,  0.0313, -0.4174,  ..., -0.2468, -1.0921,  0.6667],
          [-0.6372,  0.4431, -0.0606,  ..., -0.4292, -0.6542, -0.9803],
          [-1.5217,  0.7364, -0.4287,  ...,  0.1530, -0.1420,  0.0655],
          ...,
          [-0.1826,  0.8624,  0.5473,  ..., -0.4610, -0.3327, -0.0271],
          [-0.8571,  1.2007, -1.2179,  ..., -0.4096,  0.8169,  0.1077],
          [-1.3482,  1.4352, -1.5530,  ..., -0.2951,  0.0462, -1.0784]]],
        grad_fn=<ViewBackward0>),
 None)

In [24]:
model

GPT(
  (token_embedding_table): Embedding(2048, 128)
  (position_embedding_table): Embedding(16, 128)
  (lm_head): Linear(in_features=128, out_features=2048, bias=True)
  (layers): ModuleList(
    (0-1): 2 x Block(
      (attention): MultiHeadAttention(
        (heads): ModuleList(
          (0-1): 2 x Head(
            (query): Linear(in_features=128, out_features=4, bias=False)
            (key): Linear(in_features=128, out_features=4, bias=False)
            (value): Linear(in_features=128, out_features=4, bias=False)
            (dropout): Dropout(p=0.01, inplace=False)
          )
        )
        (combine_heads): Linear(in_features=8, out_features=128, bias=True)
        (dropout): Dropout(p=0.01, inplace=False)
      )
      (ff): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): ReLU()
          (2): Linear(in_features=512, out_features=128, bias=True)
          (3): Dropout(p=0.01, inplace=False)
       

In [25]:
# logits, loss = self(idx[:,-T:])

idx = torch.zeros(1, 1).long()
idx[:,-T:]

tensor([[0]])

In [26]:
model.token_embedding_table.weight.device

device(type='cpu')

In [27]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)



In [28]:
eval_iters = 10
eval_interval = 300
@torch.no_grad()
def estimate_loss(is_last=False):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        real_iters = eval_iters
        if is_last and split == 'val':  # increase last eval to mitigate noise
            real_iters *= 10 
        losses = torch.zeros(real_iters)
        for k in range(real_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
    

In [29]:

for steps in tqdm.tqdm(range(max_iters)):
    xb, yb = get_batch('train')
    # loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if steps % eval_interval == 0:
        losses = estimate_loss()
        wandb.log({"train": losses['train'].item() / tok_ratio, "val": losses['val'].item() / tok_ratio})

losses = estimate_loss(is_last=True)
wandb.log({"train": losses['train'].item() / tok_ratio, "val": losses['val'].item() / tok_ratio})
wandb.finish()

torch.save(model, 'gpt2.pth')

100%|███████████████████████████████████████| 5000/5000 [06:01<00:00, 13.82it/s]


0,1
train,█▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
val,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
train,1.21185
val,1.49747


In [35]:
estimate_loss()

{'train': tensor(3.4817), 'val': tensor(4.3184)}

In [36]:
test_idx = torch.zeros(1, T).long()
print(decode(
    model.generate(idx=test_idx, max_new_tokens=500)[0].tolist()
)[T:])

s><s><s><s><s><s><s><s><s><s><s>
My foreet past men's heavy currence' face?
'Tis anciis man should question in thy spider,
Your senate and what you promis
Whereof you'll give you in one, and his son becomin from Laugares!
Thou art died all dead, and makes your picture,
That I, the rite of death, ho!
Look, what he's gone, at your honours;
It is one thing when you long withal, my lord!

MENENIUS:
Not Gloucester, sir, I do
But I'll still add of this man, all this night;
For I say this man's your doit, those
disbury, there thou art believed;
And hath the most heelton of yourself.
Canst thou among the fee that ever he be with a flre straight.

COMINIUS:
I fear. What, King Richard's the vast that has heard of Christious honour.
See, nurse, no more, let all his oath?
Where fear there coell.

BIONDELLOURY:
What amomest not speak had been therein'd their right:
Forbion spots and sworn at Harry that was withal:
About a birth; look it, God--
Your eye's unwroop'd him
shall hand and tempt my no sch

# Self attention

In [37]:
import scipy
import numpy as np
import torch

In [59]:
lm_head = model.lm_head.weight.detach().cpu().numpy()
mag = (lm_head * lm_head).sum(axis=1, keepdims=True) 
norm_lm_head = lm_head / mag
print(norm_lm_head)
print(norm_lm_head.shape)

[[ 0.02306603  0.05243576  0.08867577 ...  0.02857205 -0.02918966
   0.11309305]
 [ 0.01267621  0.03377498  0.10678986 ...  0.06500833 -0.04056149
   0.04032973]
 [-0.01851727  0.07158177  0.10723365 ...  0.00550382 -0.02387156
   0.06133687]
 ...
 [-0.09309085  0.05651278 -0.00716013 ... -0.02766485 -0.08382809
   0.04673408]
 [-0.04388665 -0.03577347 -0.19007622 ...  0.05194638 -0.00410266
   0.00023691]
 [ 0.09047809  0.03221549 -0.02735731 ...  0.11668851 -0.01547582
   0.05496432]]
(2048, 128)


In [63]:
import torch
import numpy as np
from scipy.cluster.vq import kmeans, vq
import math

# Number of clusters (sqrt(v))
k = int(math.sqrt(vocab_size))

# Perform K-means clustering
centroids, distortion = kmeans(norm_lm_head, k)

print("Centroids:\n", centroids)
print("Distortion:", distortion)

Centroids:
 [[-0.10860269 -0.0701713   0.14592122 ... -0.02404833  0.0056003
   0.07046475]
 [-0.03410162  0.02815589  0.011906   ...  0.05327878 -0.04157449
   0.00602074]
 [ 0.00939551  0.00371576  0.00284    ...  0.01728747 -0.01445366
   0.05078916]
 ...
 [ 0.0449089   0.05545767 -0.01449922 ...  0.0700506  -0.06341323
   0.07834275]
 [ 0.02983058  0.01192596 -0.02704171 ...  0.02684871  0.04644939
   0.0423712 ]
 [-0.03840546  0.05578538 -0.10008395 ... -0.07234126  0.05515989
   0.07930113]]
Distortion: 0.76350445


In [69]:
cluster_assignments, _ = vq(norm_lm_head, centroids)
cluster_assignments

array([35, 35, 35, ...,  9, 43,  2], dtype=int32)

In [86]:
from collections import defaultdict
clusters = defaultdict(list)

In [87]:
for idx, cluster_idx in enumerate(cluster_assignments):
    clusters[cluster_idx].append(decode([idx]))

In [88]:
clusters

defaultdict(list,
            {35: ['<s>',
              '<pad>',
              '</s>',
              '<unk>',
              '<mask>',
              '"',
              '#',
              '%',
              '(',
              ')',
              '*',
              '+',
              '/',
              '0',
              '1',
              '2',
              '4',
              '5',
              '6',
              '7',
              '8',
              '9',
              '<',
              '=',
              '>',
              '@',
              '[',
              '\\',
              ']',
              '^',
              '_',
              '`',
              'q',
              '{',
              '|',
              '}',
              '~',
              '�',
              '�',
              '�',
              '�',
              '�',
              '�',
              '�',
              '�',
              '�',
              '�',
              '�',
              '�',
              '�',
         

In [90]:
# cluster our input data
cluster_encode = encode(text)

[676,
 1201,
 30,
 203,
 779,
 553,
 336,
 589,
 1817,
 807,
 2008,
 719,
 16,
 679,
 322,
 621,
 18,
 203,
 203,
 1236,
 30,
 203,
 1977,
 585,
 16,
 621,
 18,
 203,
 203,
 676,
 1201,
 30,
 203,
 570,
 423,
 400,
 1043,
 499,
 773,
 1503,
 292,
 969,
 532,
 292,
 276,
 390,
 561,
 35,
 203,
 203,
 1236,
 30,
 203,
 54,
 283,
 499,
 773,
 18,
 1043,
 499,
 773,
 18,
 203,
 203,
 676,
 1201,
 30,
 203,
 676,
 16,
 293,
 509,
 425,
 69,
 1141,
 1422,
 329,
 883,
 726,
 1946,
 292,
 272,
 1197,
 18,
 203,
 203,
 1236,
 30,
 203,
 672,
 509,
 671,
 16,
 336,
 509,
 671,
 18,
 203,
 203,
 676,
 1201,
 30,
 203,
 949,
 535,
 1319,
 365,
 16,
 301,
 336,
 460,
 360,
 282,
 661,
 464,
 417,
 834,
 1086,
 313,
 18,
 203,
 767,
 671,
 263,
 225,
 382,
 72,
 1789,
 35,
 203,
 203,
 1236,
 30,
 203,
 693,
 489,
 1411,
 303,
 373,
 671,
 31,
 542,
 343,
 309,
 845,
 30,
 953,
 16,
 953,
 5,
 203,
 203,
 920,
 1201,
 30,
 203,
 1848,
 713,
 16,
 452,
 282,
 942,
 87,
 18,
 203,
 203,
 676,
 1201,
 

In [109]:
cluster_encodes = []

for cluster_id in cluster_assignments:
    cluster_encodes.append(cluster_id)

In [110]:
cluster_assignments


array([35, 35, 35, ...,  9, 43,  2], dtype=int32)

In [1]:
encode(text)
len(encode(text))

NameError: name 'encode' is not defined

In [118]:
print(encode(text))

[676, 1201, 30, 203, 779, 553, 336, 589, 1817, 807, 2008, 719, 16, 679, 322, 621, 18, 203, 203, 1236, 30, 203, 1977, 585, 16, 621, 18, 203, 203, 676, 1201, 30, 203, 570, 423, 400, 1043, 499, 773, 1503, 292, 969, 532, 292, 276, 390, 561, 35, 203, 203, 1236, 30, 203, 54, 283, 499, 773, 18, 1043, 499, 773, 18, 203, 203, 676, 1201, 30, 203, 676, 16, 293, 509, 425, 69, 1141, 1422, 329, 883, 726, 1946, 292, 272, 1197, 18, 203, 203, 1236, 30, 203, 672, 509, 671, 16, 336, 509, 671, 18, 203, 203, 676, 1201, 30, 203, 949, 535, 1319, 365, 16, 301, 336, 460, 360, 282, 661, 464, 417, 834, 1086, 313, 18, 203, 767, 671, 263, 225, 382, 72, 1789, 35, 203, 203, 1236, 30, 203, 693, 489, 1411, 303, 373, 671, 31, 542, 343, 309, 845, 30, 953, 16, 953, 5, 203, 203, 920, 1201, 30, 203, 1848, 713, 16, 452, 282, 942, 87, 18, 203, 203, 676, 1201, 30, 203, 672, 423, 1069, 1371, 320, 957, 282, 942, 87, 16, 272, 1359, 346, 581, 1441, 452, 18, 203, 466, 263, 1182, 275, 588, 265, 366, 517, 1276, 373, 508, 359, 80, 14

In [2]:
encoded_text = encode(text)
cluster_text = []
for i in encoded_text:
    cluster_text.append(cluster_assignments[i])

NameError: name 'encode' is not defined

In [None]:
tensor(cluster_assignments)

In [None]:
import tokenizers
tokenizer = tokenizers.ByteLevelBPETokenizer()
tokenizer._tokenizer = tokenizers.ByteLevelBPETokenizer(
    "./shakespeare-2k-bpe-vocab.json",
    "./shakespeare-2k-bpe-merges.txt",
)

enc = tokenizer.encode("Romeo Romeo wherefore art thou Romeo?")
tokenizer.decode(enc.ids)


# stoi = {ch: i for i, ch in enumerate(chars)}
# itos = {i: ch for i, ch in enumerate(chars)}

def encode(text):
    # return [stoi[ch] for ch in text]
    return tokenizer.encode(text).ids


def decode(encoded_text):
    # return ''.join([itos[i] for i in encoded_text])
    return tokenizer.decode(encoded_text)


hello_encoded = encode("hello")
print(hello_encoded)
print(decode(hello_encoded))
vocab_size = tokenizer.get_vocab_size()


# ~ compression ratio of tokens
tok_ratio = len(text) / len(encode(text))
print(tok_ratio)

In [124]:
### Cluster Tokenizer
### c(v1), v1, c(v2), v2 ...
import pickle
import numpy as np 
import tokenizers
import torch

class ClusterTokenizer:
    def __init__(self, cas_file: str):
      # base tokenizer TODO: make this a param
      self.tokenizer = tokenizers.ByteLevelBPETokenizer()
      self.tokenizer._tokenizer = tokenizers.ByteLevelBPETokenizer(
          "./shakespeare-2k-bpe-vocab.json", #2048
          "./shakespeare-2k-bpe-merges.txt",
      )
      # load up cluster assignments
      with open(cas_file, 'rb') as file:
        self.cas = pickle.load(file)
        assert type(self.cas) == np.ndarray, "Cluster assignments must be a numpy array"
        self.cas = torch.tensor(self.cas)
    def encode(self, input_str: str):
        enc_ids = torch.tensor(self.tokenizer.encode(input_str).ids)
        # offset the ids by the vocab size to avoid overlap with originl vocab
        enc_cas = self.cas[enc_ids] + self.tokenizer.get_vocab_size()
        # interleave the cluster assignments with the token ids 
        enc_cas_interleave = torch.stack((enc_cas, enc_ids), dim=1)
        return enc_cas_interleave.view(-1)
    def decode(self, enc):
        # de-interleave the cluster assignments and token ids
        enc_cas, enc_ids = enc[::2], enc[1::2]
        # offset the token ids by the vocab size
        #dec_ids = enc_ids - self.tokenizer.get_vocab_size()
        return self.tokenizer.decode(enc_ids.tolist())
 
ct = ClusterTokenizer('clusters.pickle')
enc = ct.encode("hello world")
print(ct.decode(enc))

hello world


In [95]:
#### Torch Playground

evens = torch.arange(0, 10, 2)
odds = torch.arange(1, 10, 2)

sol = torch.cat((evens, odds), dim=0)
base = torch.empty(evens.size(0) + odds.size(0))
for i in range(evens.size(0)):
    base[2*i] = evens[i]
    base[i*2+1] = odds[i]
    print(base)
print(base)

base2 = torch.empty(evens.size(0) + odds.size(0))
base2[0::2] = evens
print(base2)
base2[1::2] = odds
print(base2)

stack = torch.stack((evens, odds), dim=1)
print("stack", stack)

interleaved = stack.view(-1)
""" stack[4, 1] = 100
print(id(stack.untyped_storage()))
stack = stack - 1
print("interleaved", interleaved)
print(id(interleaved.untyped_storage()))
print("stack", stack)
print(id(stack.untyped_storage()))
print(type(stack.untyped_storage())) """
vstack = torch.stack((evens, odds), dim=0)
hstack = torch.stack((evens, odds), dim=1)

print("vstack", vstack.view(-1))
print("hstack", hstack.view(-1))

tensor([0., 1., 0., 0., 0., 0., 0., 0., 0., 0.])
tensor([0., 1., 2., 3., 0., 0., 0., 0., 0., 0.])
tensor([0., 1., 2., 3., 4., 5., 0., 0., 0., 0.])
tensor([0., 1., 2., 3., 4., 5., 6., 7., 0., 0.])
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
tensor([0., 0., 2., 0., 4., 0., 6., 0., 8., 0.])
tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
stack tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7],
        [8, 9]])
vstack tensor([[0, 2, 4, 6, 8],
        [1, 3, 5, 7, 9]])
hstack tensor([0, 2, 4, 6, 8, 1, 3, 5, 7, 9])
