In [1]:
import os
import io
import functools
import torch
import tqdm
import math
from utils import get_alibi_biases
import numpy as np

In [2]:
def make_alibias(T,ms):
  alibias = torch.empty(T, T)
  row = torch.concat((torch.arange(-T+1, 0, 1), torch.arange(0, -T, -1)))
  for i in range(T):
    alibias[-i-1] = row[i:i+T]
  return alibias[None,None].repeat(1, len(ms), 1,1) * ms[None,:,None,None]

In [3]:
# make an RNN
class ResBlock(torch.nn.Module):
    def __init__(
        self,
        d_model,
        hidden_dim=None,
    ) -> None:
        super().__init__()
        hidden_dim = hidden_dim or d_model
        self.fc1 = torch.nn.Linear(d_model, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, d_model)
        self.act = torch.nn.ReLU()

    def forward(self, x):
        x = self.fc2(self.act(self.fc1(x))) + x
        return x


class RNN(torch.nn.Module):
  def __init__(self, d_model, vocab_size):
    super().__init__()
    self.d_model = d_model
    self.embedding = torch.nn.Embedding(vocab_size, d_model)
    self.layer = ResBlock(d_model)

  def forward(self, x):
    _, L = x.shape
    x = self.embedding(x)
    current = x[:, 0]
    for seq_idx in range(1,L):
      current = self.layer(current + x[:, seq_idx])
    return current


class SelfAttention(torch.nn.Module):

    def __init__(self, config):


        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = torch.nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = torch.nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = torch.nn.Dropout(config.dropout)
        self.resid_dropout = torch.nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        
        self.register_buffer("alibias", make_alibias(config.block_size, ms=torch.randn(config.n_head)))


    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att += self.alibias[:, :, :T, :T]


        att = torch.nn.functional.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y




class MLP(torch.nn.Module):

    def __init__(self, config):
        super().__init__()
        self.c_fc    = torch.nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu    = torch.nn.GELU()
        self.c_proj  = torch.nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = torch.nn.Dropout(config.dropout)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

class Block(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn = SelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(x)
        x = x + self.mlp(x)
        return x


class Model(torch.nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = torch.nn.ModuleDict(dict(
            wte = torch.nn.Embedding(config.vocab_size, config.n_embd),
            drop = torch.nn.Dropout(config.dropout),
            h = torch.nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        ))
        self.lm_head = torch.nn.Linear(config.n_embd, 1, bias=False)
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        return n_params

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        x = self.transformer.drop(tok_emb)
        for block in self.transformer.h:
            x = block(x)

        logits = self.lm_head(x.sum(dim=1))
        return logits

    

from dataclasses import dataclass
@dataclass
class GPTConfig:
    block_size: int = 1024
    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    dropout: float = 0.0
    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster


In [62]:
data_path = "data/data.pt"
Xall, Yall = torch.load(data_path)
def loop_n_sel(X, n):
  return (X[:, 2*n:2*n+2] == 0).all(1) & (X[:, 2*n-1] != 0)

#only take up to loop 5
sel = loop_n_sel(Xall, 4) | loop_n_sel(Xall, 6)
X = Xall[sel]
Y = Yall[sel]

torch.manual_seed(10)
perm = torch.randperm(len(X))
X = X[perm]
Y = Y[perm]
# Y = Y.sign() * Y.abs().log()
Y = Y.abs().log()
split = int(0.8*len(X))
X_train, X_val = X[:split], X[split:]
Y_train, Y_val = Y[:split], Y[split:]

DEVICE = "cuda:1"

X_train = X_train.to(DEVICE)
Y_train = Y_train.to(DEVICE)
X_val = X_val.to(DEVICE)
Y_val = Y_val.to(DEVICE)
X_train.shape

torch.Size([3942139, 12])

In [63]:
# define the model architecture
D_MODEL = 128
config = GPTConfig(block_size=12, vocab_size=len(X.unique()), n_layer=4, n_head=2, n_embd=D_MODEL)

torch.manual_seed(100)
np.random.seed(100)

model = Model(config).to(DEVICE)

# define the training loop
EPOCHS = 5
BATCH_SIZE = 256
EVAL_FREQ = 2

# define the optimizer and the loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, 1, 1e-1, total_iters = math.ceil(len(X_train) / BATCH_SIZE) * EPOCHS)
criterion = torch.nn.MSELoss() 


try:
  for epoch in (bar:=tqdm.trange(EPOCHS)):
    for i in range(0, len(X_train), BATCH_SIZE):
      x_batch = X_train[i:i+BATCH_SIZE]
      y_batch = Y_train[i:i+BATCH_SIZE]
      optimizer.zero_grad()
      y_pred = model(x_batch)
      loss = criterion(y_pred, y_batch)
      loss.backward()
      # clip grad
      torch.nn.utils.clip_grad_norm_(model.parameters(), .1)
      optimizer.step()
      scheduler.step()
      bar.set_description(f"Loss: {loss.item():.4f}")
    if epoch % EVAL_FREQ == 0:
      with torch.inference_mode():
        loss = 0
        for i in range(0, len(X_val), BATCH_SIZE):
          x_batch = X_val[i:i+BATCH_SIZE]
          y_batch = Y_val[i:i+BATCH_SIZE]
          y_pred = model(x_batch)
          loss += criterion(y_pred, y_batch).item()
        loss /= len(X_val) / BATCH_SIZE
        print(f"Epoch {epoch}, validation loss: {loss:.4f}")
except KeyboardInterrupt:
  pass
else:
  torch.save(model, "model.pt")


number of parameters: 0.79M


Loss: 0.2832:   0%|          | 0/5 [02:42<?, ?it/s]   

In [None]:

def eval_model(model):
  with torch.no_grad():
    loopsel = loop_n_sel(Xall, 5)
    xloop, yloop = Xall[loopsel], Yall[loopsel]
    print("number of elements: ", loopsel.sum().item())
    perm = torch.randperm(len(xloop))
    xloop = xloop[perm].to(DEVICE)
    yloop = yloop[perm].to(DEVICE)
    # yloop = yloop.sign() * yloop.abs().log()
    yloop = yloop.abs().log()
    # run model on loop 6
    loss = 0
    BS = 512
    n_elements = 100000
    for i in range(0, n_elements, BS):
      x_batch = xloop[i:i+BS]
      y_batch = yloop[i:i+BS]
      y_pred = model(x_batch)
      loss += torch.nn.functional.mse_loss(y_pred, y_batch, reduction="sum").item()
    loss = loss / n_elements
    return loss


In [None]:
random_model = Model(config).to(DEVICE)
eval_model(model), eval_model(random_model)

number of parameters: 0.20M
number of elements:  263880


number of elements:  263880


(0.8899605859375, 9.34663110107422)