In [27]:
# steps: 
'''
- optimize BEFORE scale up (time per epoch), architecture changes (loss given set amount of epochs)
- pretokenization using BPE (make sure to use gpt4 regex magic)
- test distributed training arrays using 2 nodes + 1 gpu each
- train on array of 4 nodes w/ 4 V100 gpus each for 24 hours
'''

import torch

In [28]:
# Load the raw corpus. The explicit encoding makes the read independent of the platform locale.
with open("data/input.txt", "r", encoding="utf-8") as f:
  text = f.read()

# Character-level vocabulary: every distinct character of the corpus, in sorted order.
vocab_list = ''.join(sorted(set(text)))
vocab_size = len(vocab_list)

ctoi = {ch: i for i, ch in enumerate(vocab_list)}  # character -> token id
itoc = {i: ch for i, ch in enumerate(vocab_list)}  # token id -> character

def encoder(s):
  """Map a string to a list of token ids (KeyError on characters not in the corpus)."""
  return [ctoi[ch] for ch in s]

def decoder(enc):
  """Map an iterable of token ids back to the corresponding string."""
  return ''.join(itoc[i] for i in enc)

In [29]:
# Encode the whole corpus once, then split it contiguously:
# the first 90% of tokens for training, the last 10% held out.
tt_split = 0.9
enc = encoder(text)
split_idx = int(tt_split * len(enc))

train = torch.tensor(enc[:split_idx])
test = torch.tensor(enc[split_idx:])

In [30]:
# Model / training hyperparameters.
embd_size = 384        # width of the token and position embeddings
batch_size = 64
context_size = 256     # longest sequence the position table / causal mask covers
dropout_thres = 0.2    # dropout probability passed to torch.nn.Dropout
n_heads = 6
n_layers = 6
learning_rate = 1e-4
device = "cuda" if torch.cuda.is_available() else "cpu"

# Each attention head is embd_size // n_heads wide. If embd_size is not divisible by
# n_heads, the integer division silently drops the remainder. The concatenated heads
# would then be narrower than the embedding.
assert embd_size % n_heads == 0, "embd_size must be divisible by n_heads"

torch.manual_seed(1337)

<torch._C.Generator at 0x7f80653ce9f0>

In [31]:
class SelfAttentionHead(torch.nn.Module):
  """A single causal self-attention head.

  Projects the input to key / value / query vectors of width ``head_size``. Each position
  attends only to itself and earlier positions.
  """

  def __init__(self, head_size):
    super().__init__()
    self.head_size = head_size
    # Creation order (key, value, query) is kept so seeded initialisation is unchanged.
    self.key = torch.nn.Linear(embd_size, head_size, bias=False)
    self.value = torch.nn.Linear(embd_size, head_size, bias=False)
    self.query = torch.nn.Linear(embd_size, head_size, bias=False)

    self.drop = torch.nn.Dropout(dropout_thres)
    # Lower-triangular ones: entry (i, j) is 1 iff j <= i. It is registered as a buffer so
    # it moves with the module's .to(...) and is saved in the state_dict.
    causal = torch.tril(torch.ones((context_size, context_size)))
    self.register_buffer("tril", causal)

  def forward(self, x):
    """x: (B, T, embd_size) with T <= context_size  ->  (B, T, head_size)."""
    _, seq_len, _ = x.shape
    keys = self.key(x)
    values = self.value(x)
    queries = self.query(x)

    # Scaled dot-product scores, (B, T, T). Future positions are masked to -inf before softmax.
    scores = queries @ keys.transpose(-1, -2) * self.head_size ** (-0.5)
    is_future = self.tril[:seq_len, :seq_len] == 0
    scores = scores.masked_fill(is_future, float("-inf"))
    attn = self.drop(torch.nn.functional.softmax(scores, dim=-1))
    return attn @ values

In [32]:
class MultiAttentionHead(torch.nn.Module):
  """Runs ``n_heads`` SelfAttentionHeads side by side.

  Their outputs are concatenated on the feature axis and projected back to ``embd_size``.
  """

  def __init__(self, n_heads, head_size):
    super().__init__()
    head_list = [SelfAttentionHead(head_size) for _ in range(n_heads)]
    self.heads = torch.nn.ModuleList(head_list)
    self.proj = torch.nn.Linear(n_heads * head_size, embd_size)
    self.drop = torch.nn.Dropout(dropout_thres)

  def forward(self, x):
    """x: (B, T, embd_size)  ->  (B, T, embd_size)."""
    per_head = [head(x) for head in self.heads]   # each (B, T, head_size)
    merged = torch.cat(per_head, dim=-1)           # (B, T, n_heads * head_size)
    projected = self.proj(merged)
    return self.drop(projected)

In [33]:
class FeedForwardLayer(torch.nn.Module):
  """Position-wise MLP: Linear(nin -> 4*nout) -> ReLU -> Linear(4*nout -> nout) -> Dropout."""

  def __init__(self, nin, nout):
    super().__init__()
    hidden = 4 * nout
    # A single Sequential keeps the state_dict keys as layers.0 / layers.2.
    self.layers = torch.nn.Sequential(
        torch.nn.Linear(nin, hidden),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden, nout),
        torch.nn.Dropout(dropout_thres),
    )

  def forward(self, x):
    """Apply the MLP over the last (feature) dimension of x."""
    return self.layers(x)


In [34]:
class Block(torch.nn.Module):
  """Pre-norm transformer block: x + attn(ln1(x)), followed by x + ffwd(ln2(x))."""

  def __init__(self):
    super().__init__()
    head_size = embd_size // n_heads
    self.sa_heads = MultiAttentionHead(n_heads, head_size)
    self.ffwd = FeedForwardLayer(embd_size, embd_size)

    self.ln1 = torch.nn.LayerNorm((embd_size,))
    self.ln2 = torch.nn.LayerNorm((embd_size,))

  def forward(self, x):
    '''
    in:
    - x: tensor (batch_size, T, embd_size), T <= context_size
    out:
    - tensor with the same shape as x
    '''
    # Both sub-layers are wrapped in residual connections.
    attended = x + self.sa_heads(self.ln1(x))
    return attended + self.ffwd(self.ln2(attended))

In [35]:
class BigramLanguageModel(torch.nn.Module):
  """Character-level decoder-only transformer (despite the historical "Bigram" name).

  Token embedding + learned position embedding -> ``n_layers`` Blocks -> LayerNorm
  -> Linear head producing vocabulary logits.
  """

  def __init__(self):
    super().__init__()
    self.feat_encoding = torch.nn.Embedding(vocab_size, embd_size)
    self.pos_encoding = torch.nn.Embedding(context_size, embd_size)

    self.blocks = torch.nn.Sequential(
        *[Block() for _ in range(n_layers)],
        torch.nn.LayerNorm((embd_size,)),
        torch.nn.Linear(embd_size, vocab_size)
    )

  def forward(self, xs, ys=None): # outputs logits, loss
    '''
    in:
    - xs: long tensor (batch_size, T) of token ids, T <= context_size
    - ys: long tensor (batch_size, T) of target ids, or None
    out:
    - logits: tensor (batch_size, T, vocab_size)
    - loss: 0-dim cross-entropy tensor, or None when ys is None
    '''
    f = self.feat_encoding(xs) # (B, T, embd_size)
    # Build positions on the input's device, not the global `device`. Under DDP each rank's
    # model lives on its own GPU, and the string "cuda" resolves to the current default GPU.
    positions = torch.arange(0, xs.shape[1], device=xs.device)
    p = self.pos_encoding(positions) # (T, embd_size), broadcast over the batch

    logits = self.blocks(f + p)
    B, T, V = logits.shape

    if ys is None:
      return logits, None
    loss = torch.nn.functional.cross_entropy(logits.reshape(B * T, V), ys.reshape(B * T))
    return logits, loss

  def generator(self, max_length, batch_size_):
    '''
    Autoregressively sample token ids, starting every sequence from token id 0.

    in:
    - max_length: int, total length of each returned sequence (start token included)
    - batch_size_: int, number of independent sequences to sample
    out:
    - out: long tensor (batch_size_, max(max_length, 1))
    '''
    # Sample on whatever device the weights live on, not the global `device`.
    dev = next(self.parameters()).device
    was_training = self.training
    # Dropout must be disabled while sampling. The caller's mode is restored afterwards.
    self.eval()
    try:
      with torch.no_grad():  # sampling never needs gradients
        out = torch.zeros((batch_size_, 1), dtype=torch.long, device=dev)
        for _ in range(max_length - 1):
          # The position table only covers context_size positions, so crop the history.
          window = out[:, -context_size:]
          logits, _ = self.forward(window)
          probs = torch.nn.functional.softmax(logits[:, -1, :], dim=-1)
          ntoken = torch.multinomial(probs, 1)
          out = torch.cat((out, ntoken), dim=1)
    finally:
      self.train(was_training)
    return out

In [36]:
class Trainer: 
    """DistributedDataParallel training loop, one process per GPU.

    Each call to ``_run_epoch`` performs exactly one optimizer step, so ``max_epochs``
    in ``train`` is really a number of optimization steps.

    NOTE(review): this class uses notebook globals that no active cell defines:
    ``DDP`` (imported only in a commented-out cell), ``get_batch`` and ``tqdm``.
    They must exist before a Trainer is built or trained.
    """

    def __init__(self, model, train_data, optimizer, gpu_id, save_every): 
        '''
        in:
        - model: torch.nn.Module to train; moved to ``gpu_id`` and wrapped in DDP
        - train_data: loader whose ``.sampler`` has ``set_epoch`` (e.g. a DistributedSampler)
        - optimizer: optimizer presumably built over ``model``'s parameters
        - gpu_id: device index / rank of this process
        - save_every: rank 0 saves a checkpoint every ``save_every`` steps
        '''
        self.gpu_id = gpu_id
        self.model = model.to(gpu_id)
        self.train_data = train_data
        self.optimizer = optimizer
        self.model = DDP(self.model, device_ids=[self.gpu_id])
        self.save_every = save_every

    def _run_epoch(self, epoch): 
        """Run a single forward/backward/optimizer step."""
        # Re-seed the sampler's shuffle for this step.
        self.train_data.sampler.set_epoch(epoch)

        # NOTE(review): `get_batch` is not defined in this notebook. It is assumed to
        # return an (inputs, targets) pair of token-id tensors. TODO confirm.
        x, y = get_batch(self.train_data)
        x, y = x.to(self.gpu_id), y.to(self.gpu_id)
        _, loss = self.model(x, y)

        # Use the optimizer this Trainer owns, not a notebook-global `optimizer`.
        self.optimizer.zero_grad()
        loss.backward()
        # Optimizer.step already runs with autograd disabled; no extra no_grad is needed.
        self.optimizer.step()

    def _save_checkpoint(self, epoch):
        """Save the unwrapped (non-DDP) model weights to checkpoint.pt."""
        ckp = self.model.module.state_dict()
        PATH = "checkpoint.pt"
        torch.save(ckp, PATH)
        print(f"Epoch {epoch} | Training checkpoint saved at {PATH}")

    def train(self, max_epochs):
        """Run ``max_epochs`` steps; rank 0 checkpoints every ``save_every`` steps."""
        for epoch in tqdm(range(max_epochs)):
            self._run_epoch(epoch)
            if self.gpu_id == 0 and epoch % self.save_every == 0:
                self._save_checkpoint(epoch)


In [14]:
# import torch.multiprocessing as mp
# from torch.utils.data.distributed import DistributedSampler
# from torch.nn.parallel import DistributedDataParallel as DDP
# from torch.distributed import init_process_group, destroy_process_group
# import os

# def ddp_setup(rank, world_size):
#     os.environ["MASTER_ADDR"] = "localhost"
#     os.environ["MASTER_PORT"] = "12345"
#     init_process_group(backend="nccl", rank=rank, world_size=world_size)

# from torch.utils.data import DataLoader
# from torch.utils.data.distributed import DistributedSampler

# def loader_setup():
#     starts = torch.arange(0, len(train) - context_size - 1) 
#     sampler = DistributedSampler(starts, shuffle=True) 
#     def collate(starts_batch):
#         xb = torch.stack([train[i:i+context_size] for i in starts_batch])
#         yb = torch.stack([train[i+1:i+1+context_size] for i in starts_batch])
#         return xb, yb
    
#     loader = DataLoader(starts, batch_size=batch_size, sampler=sampler, collate_fn=collate, drop_last=True)
#     return loader

In [15]:
# epochs = 10000 
# world_size = 4
# save_every = 1000

# def main(rank, world_size):
#     model = BigramLanguageModel()
#     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
#     loader = loader_setup()
#     ddp_setup(rank, world_size)
    
#     t = Trainer(model, loader, optimizer, rank, loader, save_every)
#     t.train(epochs)
    
#     destroy_process_group()

# mp.spawn(main, args=(world_size,), nprocs=world_size)

In [41]:
# NOTE(review): no active cell defines `t`. It only appears inside the commented-out main().
# `t.model` is a DDP wrapper, and DDP does not forward custom methods such as `generator`.
# Unwrap to the underlying module first; this is a no-op if the model was never wrapped.
sampling_model = getattr(t.model, "module", t.model)
with torch.no_grad():
  print(decoder(sampling_model.generator(1000, 1).tolist()[0]))


rdu oue, ocJ! o Doru ahastbbosejin i tUaB Oi hX E sps otn oaouK n obLmn doa
g:yc3eWCg HFu QNuesele sho
 aWv biqto oolhChn$TohFxmtGl haroaseaer:h  :OazoEu niVy;agSwroisdeoyGhntl-yrhs srecrgdJeelaOn;kR
elFpine
T!i pbyo:D thnbhe Mr my Euci
Hd tizimheb ayWarreedadne G,e qlbdrewaon ja  ly a tlcCWt ,i?f
 
oftwuysthOroQzK
t m r'l, p,
G
un
Koy er :zoI  n tt threUau on
, an to srlCQ tnCdoioQo
 rrtaoslnoevRninite snlvqtheurier3t bu s'slEGawt3EK d s nlrereo t une nd  sqmros lahTs aAwmo niqXlr:y a,d a w,cKmupIkIyi insge;
eg   tsaj'arr Etu t c m
U srs
ndilsoQatsqsp!r  smun
g 
n$ em jtQ'irssPf, mtnbnnt meus tie  Iht so!  Usgslr fethzvsQ,ftii snvdroniViXt
k
!IeNNd te
Rt iFe nvtt
itMuy hNfkK keos;innow  syet,
 leersiIs  tn Rts ZCon
li e
er VCn

ZoeQp
T g cseiF

 ro hell:dvufI s Lh!ey ar
e:
o sn dhenohn;Pe VNd o
oth'dir nBuJus
RdZ ta3iwiemy s t! ree rtO3Gn a rheou O: rwe rqore jk
a!Qo.r
eo, v gur
ey,n:
ll s svcw'lSwltordBlzu,IcadowVI le rhylqlcarai'

d
ve?oiullt !.wer tk so .eit
ssGe svool  s nt f ct,

In [None]:
import os

# NOTE(review): `model` is only assigned in a later cell. Run that cell first.
# torch.save does not create missing parent directories, so make sure ckpt/ exists.
os.makedirs("ckpt", exist_ok=True)
torch.save(model.state_dict(), "ckpt/model.0.pt")

In [37]:
# `state` was never assigned by any cell in this notebook, so load it explicitly.
# NOTE(review): the path is assumed to be the file written by the `torch.save` cell above
# ("ckpt/model.0.pt"); Trainer._save_checkpoint writes "checkpoint.pt" instead. TODO confirm.
# torch.load unpickles, so only load checkpoints you produced yourself.
state = torch.load("ckpt/model.0.pt", map_location=device)

model = BigramLanguageModel()

model.load_state_dict(state)
# Move to the target device and switch to eval mode (dropout off) for sampling.
model = model.to(device).eval()

In [38]:
# Sample one 1000-character sequence. no_grad skips autograd bookkeeping
# (saved activations) on each of the 999 forward passes.
with torch.no_grad():
  sample_ids = model.generator(1000, 1)[0].tolist()
print(decoder(sample_ids))


Save that against the tigers putts;
'Tis most revel done amazedness of evil.
Yet, my lord, I spake my warber-lasts.

Lord Mars.

CAPULET:
Who fly and apot, O PETESETER:
Present here the sun swords are all unsub,
Sin Romeo like that seat doth blush this page
Shouts shall bsturn to count mine and wixtones
Ere much the way ash to part hold ax himself.
I have the hour'd and childly selveral place.
Did I expect me himory to his course,
And give him what we are not irefully sleep.

SIR STENHOP:
'Tis well, very well; before presenquating wrongs:
And with her Ratcliff, her that revenged is,
Enforch Planet and the duke's hell us: for me
Uncleful banquit is dead, soft some all through for
the precious of the lawful, he same
induced affectious, to win the would with wearing's peace,
Like to abunder hurried his all.

DUKE VINCENTIO:
Well, let you must knock me not very told one,
reters, take on twested split of.

Nurse:
Hark, you will fiery you so? how my captain
Is the devils, as if he dure neet

In [22]:
# Total number of parameter elements (~10.8M for this configuration).
sum(map(torch.numel, model.parameters()))

10788929