In [1]:
import torch

In [2]:
with open("data/input.txt", "r") as f:
  text = f.read()

vocab_list = ''.join(sorted(list(set(text))))
vocab_size = len(vocab_list)

ctoi = { val: i for i, val in enumerate(vocab_list)}
itoc = { i: val for i, val in enumerate(vocab_list)}
encoder = lambda s: [ctoi[i] for i in s]
decoder = lambda enc: ''.join([itoc[i] for i in enc])

In [3]:
enc = encoder(text)
tt_split = 0.9

train = torch.tensor(enc[int(tt_split*len(enc)):])
test = torch.tensor(enc[:int(tt_split*len(enc))])

In [4]:
embd_size = 384
batch_size = 64
context_size = 256
dropout_thres = 0.25
n_heads = 6
n_layers = 6
device = "cuda" if torch.cuda.is_available() else "cpu"

torch.manual_seed(1337)

<torch._C.Generator at 0x7fa74dd1a9f0>

In [5]:
def get_batch(data):
  '''
  in:
  - data: tensor (n,)
  out:
  - x_batch: tensor (batch_size, context_size)
  - y_batch: tensor (batch_size, context_size)
  '''
  pos = torch.randint(0, len(data) - context_size-1, (batch_size,)) # -1 for the ys

  x_batch = torch.stack([data[i_pos:i_pos+context_size] for i_pos in pos])
  y_batch = torch.stack([data[i_pos+1:i_pos+context_size+1] for i_pos in pos])
  return x_batch, y_batch

In [6]:
class SelfAttentionHead(torch.nn.Module):
  def __init__(self, head_size):
    super().__init__()
    self.head_size = head_size
    self.key = torch.nn.Linear(embd_size, head_size, bias=False)
    self.value = torch.nn.Linear(embd_size, head_size, bias=False)
    self.query = torch.nn.Linear(embd_size, head_size, bias=False)

    self.drop = torch.nn.Dropout(dropout_thres)
    self.register_buffer("tril", torch.tril(torch.ones((context_size, context_size)))) 
  
  def forward(self, x):
    B, T, C = x.shape
    k = self.key(x) # (B, T, head_size)
    v = self.value(x) # (B, T, head_size)
    q = self.query(x) # (B, T, head_size)
    
    W = k @ q.transpose(-1, -2) * self.head_size**(-0.5) # (B, T, T)
    W.masked_fill_(self.tril[:T, :T] == 0, float("-inf"))
    w_mask = torch.nn.functional.softmax(W, dim=1) # (B, T, T)
    w_mask = self.drop(w_mask)

    return w_mask @ v # (B, T, head_size)

In [7]:
class MultiAttentionHead(torch.nn.Module):
  def __init__(self, n_heads, head_size):
    super().__init__()
    self.heads = torch.nn.ModuleList(
        SelfAttentionHead(head_size) for _ in range(n_heads)
    )
    self.proj = torch.nn.Linear(n_heads * head_size, embd_size)
    self.drop = torch.nn.Dropout(dropout_thres)
  
  def forward(self, x):
    st = torch.cat([head(x) for head in self.heads], dim=-1)
    return self.drop(self.proj(st))

In [8]:
class FeedForwardLayer(torch.nn.Module):
  def __init__(self, nin, nout):
    super().__init__()
    self.layers = torch.nn.Sequential(
        torch.nn.Linear(nin, 4*nout),
        torch.nn.ReLU(),
        torch.nn.Linear(4*nout, nout),
        torch.nn.Dropout(dropout_thres)
    )
  
  def forward(self, x):
    return self.layers(x)


In [9]:
class Block(torch.nn.Module):
  def __init__(self):
    super().__init__()

    self.sa_heads = MultiAttentionHead(n_heads, embd_size // n_heads)
    self.ffwd = FeedForwardLayer(embd_size, embd_size)

    self.ln1 = torch.nn.LayerNorm((embd_size,))
    self.ln2 = torch.nn.LayerNorm((embd_size,))
  
  def forward(self, x): 
    '''
    in: 
    - x: tensor (batch_size, context_size, embd_size)
    out: 
    - out: tensor (batch_size, context_size, embd_size)
    '''
    x = self.sa_heads(self.ln1(x)) + x # residual connections
    x = self.ffwd(self.ln2(x)) + x
    return x 

In [10]:
class BigramLanguageModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.feat_encoding = torch.nn.Embedding(vocab_size, embd_size)
    self.pos_encoding = torch.nn.Embedding(context_size, embd_size)

    self.blocks = torch.nn.Sequential(
        *[Block() for _ in range(n_layers)],
        torch.nn.Linear(embd_size, vocab_size)
    )

  def forward(self, xs, ys=None): # outputs logits, loss
    '''
    in:
    - xs: tensor (batch_size, context_size)
    - ys: tensor (batch_size, context_size) or None
    out:
    - logits: tensor (batch_size, context_size, vocab_size)
    - loss: tensor (1,) or None
    '''
    
    f = self.feat_encoding(xs) # (batch_size, context_size, embd_size)
    p = self.pos_encoding(torch.arange(0, xs.shape[1], device=device)) # (context_size, embd_size)

    x = f + p 
    logits = self.blocks(x)
    B, C, V = logits.shape

    if ys is not None:
      logits_ce = logits.reshape(B*C, V)
      ys_ce = ys.reshape(B*C)

      loss = torch.nn.functional.cross_entropy(logits_ce, ys_ce)
    else:
      loss = None
    return logits, loss

  def generator(self, max_length, batch_size_):
    '''
    in:
    - max_length: int
    out:
    - out: tensor (batch_size, max_length)
    '''
    out = torch.zeros((batch_size_, 1), dtype=torch.long, device=device)

    for _ in range(max_length - 1):
      last_char = out[:, -context_size:]

      logits, _ = self.forward(last_char)

      logits = logits[:, -1, :]
      probs = torch.nn.functional.softmax(logits, dim=-1)
      ntoken = torch.multinomial(probs, 1) 
      out = torch.cat((out, ntoken), dim=1)
    return out

In [17]:
from tqdm import tqdm

@torch.no_grad()
def get_metrics(model, num_samples=10):
  model.eval()
  train_losses, test_losses = [], []
  for _ in range(num_samples):
    x, y = get_batch(train)
    x, y = x.to(device), y.to(device)
    train_losses.append(model(x, y)[1].item())

    x, y = get_batch(test)
    x, y = x.to(device), y.to(device)
    test_losses.append(model(x, y)[1].item())
  model.train()
  return torch.tensor(train_losses).mean(), torch.tensor(test_losses).mean()

def test_model():
  model = BigramLanguageModel().to(device)
  optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

  for _ in tqdm(range(5000)):
    x, y = get_batch(train)
    x, y = x.to(device), y.to(device)
    logits, loss = model(x, y)
  
    optimizer.zero_grad()
    loss.backward()
  
    with torch.no_grad(): 
      optimizer.step()
  
  print(get_metrics(model))

test_model()
# (tensor(1.9297, grad_fn=<MeanBackward0>), tensor(2.2202, grad_fn=<MeanBackward0>))


  2%|▏         | 87/5000 [00:13<13:08,  6.23it/s]


KeyboardInterrupt: 

In [66]:
model = BigramLanguageModel()
# if torch.cuda.device_count() > 1:
    # model = torch.nn.DataParallel(model)
model.to(device) 

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
epochs = 10000
# epochs = 10

for _ in tqdm(range(epochs)):
  x, y = get_batch(train)
  x, y = x.to("cuda"), y.to("cuda")
  logits, loss = model(x, y)

  optimizer.zero_grad()
  loss.backward()

  with torch.no_grad():
    optimizer.step()

print(get_metrics(model))

100%|██████████| 10000/10000 [26:04<00:00,  6.39it/s]


(tensor(0.0071), tensor(0.3266))


In [72]:
print(decoder(model.generator(1000, 1).tolist()[0]))


AE.YZ--ZZZZ;$ 3GZ&Z$3X$o3....XX$33 $a& a$a.'''' bz3X3sssss,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,, oooooooooooooooooooo???

Yiallllllll
LLLLL?
Upi,!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
S!
G-DD  w  


a  a aa,a a
i a
'J'' la e
an-
A'l  a a
 a a pail, and mat gawated, Signtio?

Dir, gag-tut thell, if thie, match wall patitle!

HORTENSIO:
My sauty jay, at what you king!

or seye mitty.
Hereg?

WR yher weak, and thy noter hall me,
My Dare
Worce atay, thirg? ul peat iar,
zeep haves
Wa your thet me woodey the madvere o her yade
Sixtaintor muth will of Yey your
Yeare up
WoxDe-Since o noter,
I gignggiong shy here;
Wiar, shall very o'll I kinge?

ALIAND:
Vaal the mack shall all ke matere to ame
Shave sot meght mince obee,

'Have the mink!

GRUMIO:
What hedde with day midging caparminon betth.

Put that minth his, whe the, my gringion.

 wouldand sery grogum is ovut your gidilian,
Have off orreck, gor.

ANTONIO:
Have woinca, a ountant:
Dow leave your
 layive madling!
What have galov

In [69]:
torch.save(model.state_dict(), "ckpt/model.0.pt")