In [77]:
import sys
from dataclasses import dataclass
from typing import Callable, Tuple

import plotly.express as px
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import transformer_lens as tl
from einops import rearrange, reduce, repeat
from fancy_einsum import einsum
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm_notebook

# Make the project's shared modules importable before importing from them.
sys.path.append('common_modules/')

from transformer_modules import PositionalEncoding, DecoderBlock, Embedding, Dropout, LayerNorm

In [2]:
def build_fibonacci_sequences(seq_len, max_start):
    """Build shifted input/target pairs from Fibonacci-style recurrences.

    For every start value ``s`` in ``range(max_start)`` the recurrence
    ``a_n = a_{n-1} + a_{n-2}`` is seeded with ``(s, s + 1)`` and unrolled to
    (at least) ``seq_len + 1`` terms. The input row is the first ``seq_len``
    terms and the target row is the sequence shifted left by one, i.e. the
    next-token targets.

    Args:
        seq_len (int): number of tokens per input row.
        max_start (int): number of sequences; starts are 0 .. max_start - 1.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: inputs as float32 and targets as
        int64, one row per start value.
    """
    inputs, targets = [], []
    for start in range(max_start):
        terms = [start, start + 1]
        # Keep extending until there is one more term than the input length.
        while len(terms) < seq_len + 1:
            terms.append(terms[-1] + terms[-2])
        inputs.append(terms[:seq_len])
        targets.append(terms[1:])
    return (torch.tensor(inputs, dtype=torch.float32),
            torch.tensor(targets, dtype=torch.int64))

In [49]:
def build_fibmodp_sequences(seq_len, dataset_size, p):
    """Slice one long Fibonacci-mod-p stream into overlapping next-token windows.

    A single stream is generated with ``a_0 = 0, a_1 = 1`` and
    ``a_n = (a_{n-1} + a_{n-2}) mod p`` up to ``dataset_size`` terms. Every
    window of ``seq_len`` consecutive terms that still has a following term
    becomes one sample: the window is the input, and the same window shifted
    one position to the right is the target.

    Args:
        seq_len (int): number of tokens per input/target row.
        dataset_size (int): length of the generated stream. This is NOT the
            number of samples; that is ``dataset_size - seq_len`` when positive.
        p (int): modulus applied to every term.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: inputs as float32 and targets as
        int64, one row per window.
    """
    # Seed terms are reduced too, so e.g. p == 1 yields an all-zero stream.
    stream = [0 % p, 1 % p]
    while len(stream) < dataset_size:
        stream.append((stream[-1] + stream[-2]) % p)

    x, y = [], []
    # Window i reads stream[i : i + seq_len + 1], so the last valid start is
    # dataset_size - seq_len - 1 and range()'s upper bound is
    # dataset_size - seq_len. The previous bound silently dropped that window.
    for i in range(dataset_size - seq_len):
        x.append(stream[i:i + seq_len])
        y.append(stream[i + 1:i + seq_len + 1])

    return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.int64)

In [4]:
def build_lookback_addition_sequences(seq_len, max_start, k=1):
    """Build sequences following the recurrence ``a_n = a_{n-k} + 1``.

    The sequence starting at ``s`` is seeded with the consecutive integers
    ``s, s + 1, ..., s + max(k, 2) - 1``. For k = 1 and k = 2 this is exactly
    the ``(s, s + 1)`` seed. The sequence is then extended with the recurrence
    until it has ``seq_len + 1`` terms. Inputs are the first ``seq_len`` terms;
    targets are the terms shifted left by one.

    Args:
        seq_len (int): number of tokens per input/target row.
        max_start (int): number of sequences; starts are 0 .. max_start - 1.
        k (int): look-back distance, must be >= 1.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: inputs as float32 and targets as int64.

    Raises:
        ValueError: if ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    # With only two seed terms, k > 2 made seq[j - k] a negative index on the
    # first steps, silently reading from the END of the list. Seeding k terms
    # keeps every look-back inside the already-built prefix.
    seed_len = max(k, 2)
    x, y = [], []
    for i in range(max_start):
        seq = [i + t for t in range(seed_len)]
        for j in range(seed_len, seq_len + 1):
            seq.append(seq[j - k] + 1)
        x.append(seq[:seq_len])
        # Bound the slice so a seed longer than seq_len + 1 cannot lengthen y.
        y.append(seq[1:seq_len + 1])
    return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.int64)

In [5]:
def build_mod3_sequences(seq_len, max_start):
    """Build sequences that step by 1 after multiples of 3 and by 2 otherwise.

    Starting from each ``s`` in ``range(max_start)``, the next term is
    ``prev + 1`` when ``prev % 3 == 0`` and ``prev + 2`` otherwise. This
    produces ``seq_len + 1`` terms per sequence. Inputs are the first
    ``seq_len`` terms; targets are the terms shifted left by one.

    Args:
        seq_len (int): number of tokens per input/target row.
        max_start (int): number of sequences; starts are 0 .. max_start - 1.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: inputs as float32 and targets as int64.
    """
    def next_term(prev):
        # Step size depends only on whether the previous term is divisible by 3.
        return prev + (1 if prev % 3 == 0 else 2)

    rows = []
    for start in range(max_start):
        terms = [start]
        for _ in range(seq_len):
            terms.append(next_term(terms[-1]))
        rows.append(terms)

    inputs = [terms[:seq_len] for terms in rows]
    targets = [terms[1:] for terms in rows]
    return torch.tensor(inputs, dtype=torch.float32), torch.tensor(targets, dtype=torch.int64)

In [56]:
#build_fibmodp_sequences(5, 128, 3)

In [7]:
    """If even, divide by 2, if odd, then multiply by 3 and add 1
    """

    """a, b, x_n = x{x-n} + (1 if x_{n-2}+x_{n-1} mod 3 i== 0 else 2)"""

'a, b, x_n = x{x-n} + (1 if x_{n-2}+x_{n-1} mod 3 i== 0 else 2)'

In [8]:
class NumSequenceDataset(Dataset):
    """Map-style dataset pairing each input sequence with its target sequence.

    Indexing returns the ``(x[idx], y[idx])`` pair unchanged, and the length
    is the length of ``x``.
    """

    def __init__(self, x, y):
        """Store the inputs and targets.

        Args:
            x: indexable collection of input sequences. The builders in this
                notebook pass a 2-D tensor.
            y: indexable collection of target sequences aligned with ``x``.
        """
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        pair = (self.x[idx], self.y[idx])
        return pair

In [9]:
class NumSequenceTransformer(nn.Module):
    """Decoder-only transformer built from the project's ``transformer_modules``.

    Tokens are embedded, passed through ``num_layers`` ``DecoderBlock``s and a
    final ``LayerNorm``, and then scored against the embedding matrix
    (weight tying) to give per-position vocabulary logits.
    """

    def __init__(self, args):
        """Build the layers from a config object.

        Args:
            args: config exposing ``vocab_size``, ``hidden_size``,
                ``max_seq_len``, ``dropout`` and ``num_layers`` (e.g.
                ``AlgoTransformerArgs``). It is also forwarded unchanged to
                every ``DecoderBlock``.
        """
        super().__init__()
        self.emb = Embedding(args.vocab_size, args.hidden_size)
        # NOTE(review): constructed but not applied in forward(); the call there
        # is commented out.
        self.pos_enc = PositionalEncoding(args.max_seq_len, args.hidden_size)
        self.dropout = Dropout(p=args.dropout)

        decoders = [DecoderBlock(args) for _ in range(args.num_layers)]
        self.decoders = nn.Sequential(*decoders)

        self.post_norm = LayerNorm(args.hidden_size)

        # NOTE(review): not used by forward(); kept so parameter names and
        # counts (state_dict keys) stay unchanged.
        self.regression = nn.Linear((args.max_seq_len) * args.hidden_size, args.max_seq_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute next-token logits.

        Args:
            x (torch.Tensor): token ids of any numeric dtype; they are cast to
                long before embedding. Presumably shaped (batch, seq); TODO
                confirm against ``Embedding``.

        Returns:
            torch.Tensor: logits of shape (B, S, V), where V is the number of
            rows of ``self.emb.weight``.
        """
        embedding = self.emb(x.long())
        #combined_emb = self.pos_enc(embedding)
        # Cast first, then apply dropout. Previously the dropout result was
        # immediately overwritten by `embedding.to(torch.float32)`, so dropout
        # never took effect.
        combined_emb = self.dropout(embedding.to(torch.float32))

        out = self.decoders(combined_emb)
        out = self.post_norm(out)

        # Tied unembedding: dot each position against every vocabulary embedding.
        out = einsum("B S E, V E -> B S V", out, self.emb.weight)

        return out

In [78]:
from typing import Callable

loss_fn = nn.CrossEntropyLoss()

MODEL_FILENAME = "./fibonacci_model.pt"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def train_transformer(trainloader: DataLoader, args) -> tuple:
    '''
    Build a ``NumSequenceTransformer`` from ``args`` and train it on ``trainloader``.

    Uses Adam under a ``OneCycleLR`` schedule that is stepped once per batch,
    for ``args.epochs`` epochs. Afterwards the whole model is pickled to
    ``MODEL_FILENAME``.

    Args:
        trainloader: yields ``(x, y)`` batches of input ids and next-token targets.
        args: config with ``epochs``, ``lr`` and the model fields read by
            ``NumSequenceTransformer``.

    Returns:
        (loss_list, accuracy_list): per-batch cross-entropy loss, and per-batch
        next-token accuracy (the fraction of positions whose argmax logit
        equals the target).
    '''
    epochs = args.epochs

    model = NumSequenceTransformer(args).to(device).train()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    loss_list = []
    accuracy_list = []
    # NOTE(review): max_lr is hard-coded, and OneCycleLR drives the learning rate
    # itself, so args.lr presumably has no lasting effect. Confirm before tuning lr.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=0.01, steps_per_epoch=len(trainloader), epochs=args.epochs)

    progress_bar = tqdm_notebook(range(epochs))
    for epoch in progress_bar:
        for (x, y) in trainloader:
            x = x.to(device)
            y = y.to(device)

            # Flatten batch and sequence so every position is one classification.
            logits = model(x)
            logits = rearrange(logits, 'B S V -> (B S) V')
            y = rearrange(y, 'B S -> (B S)')

            loss = loss_fn(logits, y)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            loss_list.append(loss.item())
            # accuracy_list used to be returned without ever being filled.
            accuracy_list.append((logits.argmax(dim=-1) == y).float().mean().item())

            progress_bar.set_description(f"Epoch = {epoch}, Loss = {loss.item():.4f}")

    print(f"Saving model to: {MODEL_FILENAME}")
    torch.save(model, MODEL_FILENAME)
    return loss_list, accuracy_list

In [79]:
from typing import Tuple


@dataclass
class AlgoTransformerArgs():
    """Hyper-parameters for ``NumSequenceTransformer`` and its training loop."""
    # model args
    num_layers: int = 4
    num_heads: int = 8
    vocab_size: int = 5000  # usually overwritten from the data (max target id + slack)
    hidden_size: int = 64
    max_seq_len: int = 64
    dropout: float = 0.1
    layer_norm_epsilon: float = 1e-05
    # training args
    batch_size: int = 128
    epochs: int = 100
    lr: float = 0.001
    # Adam (beta1, beta2): a 2-tuple, hence Tuple[float, float] rather than the
    # 1-tuple type Tuple[float]. NOTE: not passed to the optimizer by the
    # training loops in this notebook.
    betas: Tuple[float, float] = (0.99, 0.999)
    track: bool = False  # not read by the visible code in this notebook
    cuda: bool = False   # not read by the visible code; device is chosen via torch.cuda.is_available()

In [80]:
# Next-token dataset: Fibonacci mod 10, cut into windows of args.max_seq_len tokens.
# Alternatives: build_fibonacci_sequences(args.max_seq_len, 256) or
# build_mod3_sequences(args.max_seq_len, 512).
args = AlgoTransformerArgs()
x, y = build_fibmodp_sequences(args.max_seq_len, 2048, 10)
dataset = NumSequenceDataset(x, y)

# Vocabulary: largest target id plus 10 (9 spare slots beyond the exact size).
args.vocab_size = int(y.max().item()) + 10
print(args.vocab_size)

# 80/20 random train/validation split.
train_len = int(len(dataset) * 0.8)
train_set, val_set = torch.utils.data.random_split(dataset, [train_len, len(dataset) - train_len])

loader_kwargs = dict(batch_size=args.batch_size, num_workers=4)
trainloader = DataLoader(train_set, shuffle=True, **loader_kwargs)
valloader = DataLoader(val_set, shuffle=False, **loader_kwargs)

19


In [81]:
loss_list, accuracy_list = train_transformer(trainloader, args=args)

# One point per optimizer step; axes are labelled so the figure stands alone.
fig = px.line(y=loss_list, template="simple_white")
fig.update_layout(
    title="Cross entropy loss on Fibonacci",
    xaxis_title="Training step (batch)",
    yaxis_title="Cross-entropy loss",
    yaxis_range=[0, max(loss_list)],
)
fig.show()

  0%|          | 0/100 [00:00<?, ?it/s]

Saving model to: ./fibonacci_model.pt


In [75]:
dataset[1][0].to(torch.int64)  # input ids of sample 1, shown as integers

tensor([1, 1, 2, 3, 5, 8, 3, 1, 4, 5, 9, 4, 3, 7, 0, 7, 7, 4, 1, 5, 6, 1, 7, 8,
        5, 3, 8, 1, 9, 0, 9, 9, 8, 7, 5, 2, 7, 9, 6, 5, 1, 6, 7, 3, 0, 3, 3, 6,
        9, 5, 4, 9, 3, 2, 5, 7, 2, 9, 1, 0, 1, 1, 2, 3])

In [76]:
import sample_methods as s

# Reload the pickled model onto the CPU and switch it to inference mode.
# NOTE(review): this loads a fully pickled nn.Module. Newer PyTorch releases
# default torch.load to weights_only=True, which rejects that; confirm the
# pinned torch version.
model = torch.load(MODEL_FILENAME, map_location=torch.device('cpu')).eval()

# Prompt: the same ids as dataset sample 1 displayed above.
initial_seq = [
    1, 1, 2, 3, 5, 8, 3, 1, 4, 5, 9, 4, 3, 7, 0, 7, 7, 4, 1, 5, 6, 1, 7, 8,
    5, 3, 8, 1, 9, 0, 9, 9, 8, 7, 5, 2, 7, 9, 6, 5, 1, 6, 7, 3, 0, 3, 3, 6,
    9, 5, 4, 9, 3, 2, 5, 7, 2, 9, 1, 0, 1, 1, 2, 3,
]

text_output = s.sample_tokens_no_detokenization(
    model,
    initial_seq,
    max_tokens_generated=100,
    max_seq_len=args.max_seq_len,
    temperature=0,
    top_k=10,
)

print(text_output)

[1, 1, 2, 3, 5, 8, 3, 1, 4, 5, 9, 4, 3, 7, 0, 7, 7, 4, 1, 5, 6, 1, 7, 8, 5, 3, 8, 1, 9, 0, 9, 9, 8, 7, 5, 2, 7, 9, 6, 5, 1, 6, 7, 3, 0, 3, 3, 6, 9, 5, 4, 9, 3, 2, 5, 7, 2, 9, 1, 0, 1, 1, 2, 3, 5, 8, 3, 1, 4, 5, 9, 4, 3, 7, 0, 7, 7, 4, 1, 5, 6, 1, 7, 8, 5, 3, 8, 1, 9, 0, 9, 9, 8, 7, 5, 2, 7, 9, 6, 5, 1, 6, 7, 3, 0, 3, 3, 6, 9, 5, 4, 9, 3, 2, 5, 7, 2, 9, 1, 0, 1, 1, 2, 3, 5, 8, 3, 1, 4, 5, 9, 4, 3, 7, 0, 7, 7, 4, 1, 5, 6, 1, 7, 8, 5, 3, 8, 1, 9, 0, 9, 9, 8, 7, 5, 2, 7, 9, 6, 5, 1, 6, 7, 3]


In [89]:
@dataclass
class DataArgs:
    """Settings for building the sequence dataset and its DataLoaders."""

    max_seq_len: int   # tokens per input/target window
    batch_size: int    # DataLoader batch size
    num_workers: int   # intended DataLoader worker count
    dataset_size: int  # length of the generated token stream handed to the builder

@dataclass
class TrainingArgs():
    """Optimisation settings consumed by the training loop."""
    batch_size: int
    epochs: int
    # An optimizer class/factory (e.g. torch.optim.Adam), not an Optimizer
    # instance. The training loop calls it as optimizer(params, lr=lr).
    optimizer: Callable[..., torch.optim.Optimizer]
    lr: float
    betas: Tuple[float, float]  # (beta1, beta2): a 2-tuple, not Tuple[float]
    track: bool
    cuda: bool

In [85]:
data_args = DataArgs(
    max_seq_len = 64,
    batch_size = 128,
    num_workers = 4,
    dataset_size = 2048
)

# Fibonacci mod 10 stream cut into windows of data_args.max_seq_len tokens.
# Alternatives: build_fibonacci_sequences / build_mod3_sequences.
x, y = build_fibmodp_sequences(data_args.max_seq_len, data_args.dataset_size, 10)
dataset = NumSequenceDataset(x, y)
# Exact vocabulary size (ids 0 .. max(y)); assumes input ids never exceed the
# largest target id.
vocab_size = int(torch.max(y).item()) + 1
print(vocab_size)

# 80/20 random train/validation split.
n_train = int(len(dataset) * 0.8)
train_set, val_set = torch.utils.data.random_split(dataset, [n_train, len(dataset) - n_train])

# Use the configured worker count rather than a hard-coded 4.
trainloader = DataLoader(train_set, batch_size=data_args.batch_size, shuffle=True,
                         num_workers=data_args.num_workers)
valloader = DataLoader(val_set, batch_size=data_args.batch_size, shuffle=False,
                       num_workers=data_args.num_workers)



10


In [105]:
# Model: 4-layer transformer; n_heads * d_head == d_model (8 * 8 == 64).
config = tl.EasyTransformerConfig(
    d_model=64,
    d_head=8,
    n_heads=8,
    d_mlp=256,
    n_layers=4,
    # Context length must cover the training windows; tie it to the data
    # config instead of repeating the literal.
    n_ctx=data_args.max_seq_len,
    act_fn="solu_ln",
    d_vocab=vocab_size,
    normalization_type="LN",
    seed=23,
)

args = TrainingArgs(
    # Keep in sync with the DataLoaders, which are built from data_args.
    batch_size = data_args.batch_size,
    epochs = 12,
    optimizer = torch.optim.Adam,
    lr = 0.001,
    betas = (0.99, 0.999),
    track = False,
    cuda = False
)

In [106]:
from typing import Callable

loss_fn = nn.CrossEntropyLoss()

MODEL_FILENAME = "./fibonacci_model.pt"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def train_transformer(trainloader: DataLoader, args) -> tuple:
    '''
    Build a transformer_lens ``EasyTransformer`` and train it on the algorithmic
    sequence dataset.

    The model is built from the notebook-global ``config``. The optimizer is
    ``args.optimizer(params, lr=args.lr)``. Training runs for ``args.epochs``
    epochs, after which the whole model is pickled to ``MODEL_FILENAME``.

    Args:
        trainloader: yields ``(x, y)`` batches of input ids and next-token targets.
        args: ``TrainingArgs``-like object with ``epochs``, ``optimizer`` and ``lr``.

    Returns:
        (loss_list, accuracy_list): per-batch cross-entropy loss, and per-batch
        next-token accuracy (the fraction of positions whose argmax logit
        equals the target).
    '''
    epochs = args.epochs

    model = tl.EasyTransformer(config).to(device)
    # NOTE(review): args.betas is not passed, so the optimizer uses its defaults.
    optimizer = args.optimizer(model.parameters(), lr=args.lr)
    loss_list = []
    accuracy_list = []

    progress_bar = tqdm_notebook(range(epochs))
    for epoch in progress_bar:
        for (x, y) in trainloader:
            # The model lives on `device`. Batches must follow it, otherwise a
            # CUDA machine fails with a device mismatch.
            x = x.to(device)
            y = y.to(device)

            # Inputs are stored as float; the embedding needs integer ids.
            logits = model(x.long(), return_type="logits")
            logits = rearrange(logits, 'B S V -> (B S) V')
            y = rearrange(y, 'B S -> (B S)')

            loss = loss_fn(logits, y)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            loss_list.append(loss.item())
            # accuracy_list used to be returned without ever being filled.
            accuracy_list.append((logits.argmax(dim=-1) == y).float().mean().item())

            progress_bar.set_description(f"Epoch = {epoch}, Loss = {loss.item():.4f}")

    print(f"Saving model to: {MODEL_FILENAME}")
    torch.save(model, MODEL_FILENAME)
    return loss_list, accuracy_list

In [107]:
loss_list, accuracy_list = train_transformer(trainloader, args=args)

# One point per optimizer step; axes are labelled so the figure stands alone.
fig = px.line(y=loss_list, template="simple_white")
fig.update_layout(
    title="Cross entropy loss on Fibonacci",
    xaxis_title="Training step (batch)",
    yaxis_title="Cross-entropy loss",
    yaxis_range=[0, max(loss_list)],
)
fig.show()

Moving model to device:  cpu


  0%|          | 0/12 [00:00<?, ?it/s]

Saving model to: ./fibonacci_model.pt


In [108]:
import sample_methods as s

# Reload the pickled model onto the CPU and switch it to inference mode.
# NOTE(review): this loads a fully pickled nn.Module. Newer PyTorch releases
# default torch.load to weights_only=True, which rejects that; confirm the
# pinned torch version.
model = torch.load(MODEL_FILENAME, map_location=torch.device('cpu')).eval()

# Prompt: the same ids as dataset sample 1 displayed earlier.
initial_seq = [
    1, 1, 2, 3, 5, 8, 3, 1, 4, 5, 9, 4, 3, 7, 0, 7, 7, 4, 1, 5, 6, 1, 7, 8,
    5, 3, 8, 1, 9, 0, 9, 9, 8, 7, 5, 2, 7, 9, 6, 5, 1, 6, 7, 3, 0, 3, 3, 6,
    9, 5, 4, 9, 3, 2, 5, 7, 2, 9, 1, 0, 1, 1, 2, 3,
]

output = s.sample_tokens_no_detokenization(
    model,
    initial_seq,
    max_tokens_generated=100,
    max_seq_len=data_args.max_seq_len,
    temperature=0,
    top_k=10,
)

print(output)

[1, 1, 2, 3, 5, 8, 3, 1, 4, 5, 9, 4, 3, 7, 0, 7, 7, 4, 1, 5, 6, 1, 7, 8, 5, 3, 8, 1, 9, 0, 9, 9, 8, 7, 5, 2, 7, 9, 6, 5, 1, 6, 7, 3, 0, 3, 3, 6, 9, 5, 4, 9, 3, 2, 5, 7, 2, 9, 1, 0, 1, 1, 2, 3, 5, 8, 3, 1, 4, 5, 9, 4, 3, 7, 0, 7, 7, 4, 1, 5, 6, 1, 7, 8, 5, 3, 8, 1, 9, 0, 9, 9, 8, 7, 5, 2, 7, 9, 6, 5, 1, 6, 7, 3, 0, 3, 3, 6, 9, 5, 4, 9, 3, 2, 5, 7, 2, 9, 1, 0, 1, 1, 2, 3, 5, 8, 3, 1, 4, 5, 9, 4, 3, 7, 0, 7, 7, 4, 1, 5, 6, 1, 7, 8, 5, 3, 8, 1, 9, 0, 9, 9, 8, 7, 5, 2, 7, 9, 6, 5, 1, 6, 7, 3]
