<a href="https://colab.research.google.com/github/b-schoen/alignment-playground/blob/main/colab/initial_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup and Imports

In [1]:
import torch
import torch.nn
import torch.optim
import torch.utils.data

import numpy as np

In [2]:
from typing import Callable, Iterable, Any
import math
import dataclasses

In [3]:
!pip install optuna
import optuna



In [4]:
!pip install jaxtyping
import jaxtyping



In [5]:
# note: needed for jaxtyping, using jaxtyping idiom
import typeguard
from typeguard import typechecked as typechecker

In [6]:
# turn on all warnings, since a lot of torch errors are warnings
# import warnings
# warnings.simplefilter("always")

## Setup Device

In [7]:
# setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
device.type

'cuda'

In [31]:
# from https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Main_Demo.ipynb#scrollTo=VGvdkQwSIi9Q
!pip install transformer_lens
!pip install circuitsvis



In [32]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

pio.renderers.default = "colab"

In [33]:
import transformer_lens
import circuitsvis as cv

## Dataset Type

In [9]:
# TODO(bschoen): Generalize dataset type to just take generator function

In [10]:
from jaxtyping import Float, jaxtyped

# type aliases
NumSamples = int
SeqLength = int

# TODO(bschoen): Abstract more than just integer sequence
InputTensor = Float[torch.Tensor, "seq_len *feature_dim"]
OutputTensor = Float[torch.Tensor, "*output_dim"]

InputGeneratorFn = Callable[[], InputTensor]
OutputGeneratorFn = Callable[[InputTensor], OutputTensor]

class SequenceDataset(torch.utils.data.Dataset):
    """Generic sequence dataset taking both an input and output generator."""

    @jaxtyped(typechecker=typechecker)
    def __init__(
        self,
        num_samples: int,
        x_generator: InputGeneratorFn,
        y_computer: OutputGeneratorFn,
    ) -> None:
        self.data    = torch.stack([x_generator() for _ in range(num_samples)])
        self.targets = torch.stack([y_computer(row) for row in self.data])

    def __len__(self) -> int:

        return len(self.data)

    @jaxtyped(typechecker=typechecker)
    def __getitem__(self, idx: int) -> tuple[InputTensor, OutputTensor]:

        return self.data[idx], self.targets[idx]

In [11]:
# Define types for our datasets and dataloaders
SequenceDataSubset = torch.utils.data.Subset[tuple[Float[torch.Tensor, "seq_len"], Float[torch.Tensor, ""]]]
SequenceDataLoader = torch.utils.data.DataLoader[tuple[Float[torch.Tensor, "batch seq_len"], Float[torch.Tensor, "batch"]]]

# note: uses default values for pytorch random_split
def split_dataset(
    dataset: SequenceDataset,
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15
) -> tuple[SequenceDataSubset, SequenceDataSubset, SequenceDataSubset]:

    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-5, "Ratios must sum to 1"

    print(f'Creating splits for {len(dataset)} samples {train_ratio=}, {val_ratio=}, {test_ratio=}')

    total_size: int = len(dataset)

    train_size: int = int(train_ratio * total_size)
    val_size: int = int(val_ratio * total_size)

    # Ensure we use all samples
    test_size: int = total_size - train_size - val_size

    # For reproducibility
    random_generator = torch.Generator().manual_seed(42)

    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset,
        [train_size, val_size, test_size],
        generator=random_generator,
    )

    print(f'Created splits from {total_size} total samples:')
    print(f' - train ({len(train_dataset)} samples)')
    print(f' - validation ({len(val_dataset)} samples)')
    print(f' - test ({len(test_dataset)} samples)')

    return train_dataset, val_dataset, test_dataset

In [12]:
def create_dataloaders(
    train_dataset: SequenceDataSubset,
    val_dataset: SequenceDataSubset,
    test_dataset: SequenceDataSubset,
    batch_size: int,
    num_workers: int,
) -> tuple[SequenceDataLoader, SequenceDataLoader, SequenceDataLoader]:

    print(f'Creating dataloaders for {batch_size=}, {num_workers=}, pin_memory={torch.cuda.is_available()}')

    def make_dataloader(dataset: torch.utils.data.Subset, shuffle: bool) -> SequenceDataLoader:
      return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
      )

    train_loader = make_dataloader(train_dataset, shuffle=True)
    val_loader   = make_dataloader(val_dataset,   shuffle=False)
    test_loader  = make_dataloader(test_dataset,  shuffle=False)

    return train_loader, val_loader, test_loader

## Model

In [13]:
# note: Why use an encoder at all?
#
#       However, it may be less suitable for tasks that require processing the entire input
#       before generating any output, which is where encoder-decoder models excel.
#
#       We'll likely stick to decoder only problems for now, so we don't have to do
#       interpretability on the input

### Positional Encoding

In [14]:
class PositionalEncoding(torch.nn.Module):

  def __init__(
    self,
    d_model: int,
    max_len: int = 5000,
  ) -> None:
    super(PositionalEncoding, self).__init__()

    # note: d_model must be even for positional embedding
    assert d_model % 2 == 0, f"{d_model} must be even"

    # Create a long enough `pe` matrix
    pe: Float[torch.Tensor, "max_len d_model"] = torch.zeros(max_len, d_model)
    position: Float[torch.Tensor, "max_len 1"] = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term: Float[torch.Tensor, "d_model//2"] = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

    # Fill the `pe` matrix
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)

    # Add new intermediate variables with type annotations
    pe_unsqueezed: Float[torch.Tensor, "1 max_len d_model"] = pe.unsqueeze(0)
    pe_transposed: Float[torch.Tensor, "max_len 1 d_model"] = pe_unsqueezed.transpose(0, 1)

    self.register_buffer('pe', pe_transposed)

  @jaxtyped(typechecker=typechecker)
  def forward(
      self,
      x: Float[torch.Tensor, "seq_len batch_size d_model"],
  ) -> Float[torch.Tensor, "seq_len batch_size d_model"]:

    # Create a new intermediate variable for the sliced pe
    pe_sliced: Float[torch.Tensor, "seq_len 1 d_model"] = self.pe[:x.size(0), :]

    # print(f'{x.shape=}, {pe_sliced.shape=}')

    return x + pe_sliced

In [15]:
def test_positional_encoding() -> None:
  # test positional encoding
  input_dim = 1
  seq_len = 10
  d_model = 4
  batch_size = 2

  positional_encoding = PositionalEncoding(d_model=d_model)

  # happens before positional encoding
  input_projection = torch.nn.Linear(input_dim, d_model)

  batch_x = torch.randint(0, 5, (batch_size, seq_len, input_dim)).float()

  print(f'Before input projection\n\t{batch_x.shape=}\n\t{batch_x[0]=}')

  batch_x = input_projection.forward(batch_x)

  print(f'After input projection:\n\t{batch_x.shape=}\n\t{batch_x[0]=}')

  batch_x = positional_encoding.forward(batch_x)

  print(f'After adding positional encoding\n\t{batch_x.shape=}\n\t{batch_x[0]=}')

test_positional_encoding()

Before input projection
	batch_x.shape=torch.Size([2, 10, 1])
	batch_x[0]=tensor([[2.],
        [3.],
        [2.],
        [0.],
        [2.],
        [0.],
        [0.],
        [3.],
        [2.],
        [3.]])
After input projection:
	batch_x.shape=torch.Size([2, 10, 4])
	batch_x[0]=tensor([[ 2.2565, -0.9736,  1.9765,  0.6300],
        [ 3.1362, -1.2093,  2.7340,  0.5887],
        [ 2.2565, -0.9736,  1.9765,  0.6300],
        [ 0.4969, -0.5020,  0.4613,  0.7126],
        [ 2.2565, -0.9736,  1.9765,  0.6300],
        [ 0.4969, -0.5020,  0.4613,  0.7126],
        [ 0.4969, -0.5020,  0.4613,  0.7126],
        [ 3.1362, -1.2093,  2.7340,  0.5887],
        [ 2.2565, -0.9736,  1.9765,  0.6300],
        [ 3.1362, -1.2093,  2.7340,  0.5887]], grad_fn=<SelectBackward0>)
After adding positional encoding
	batch_x.shape=torch.Size([2, 10, 4])
	batch_x[0]=tensor([[ 2.2565,  0.0264,  1.9765,  1.6300],
        [ 3.1362, -0.2093,  2.7340,  1.5887],
        [ 2.2565,  0.0264,  1.9765,  1.6300],
  

### Transformer

In [16]:
def generate_square_subsequent_mask(sz: int) -> Float[torch.Tensor, "sz sz"]:
  """
  Generate a square mask for the sequence. The mask ensures that the
  prediction for position i can depend only on known outputs at positions less than i.
  """
  # Intuition: This mask implements causal attention, preventing the model from looking at future tokens
  mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
  mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
  return mask

In [17]:
# Note: The 4 in 4 * d_model:
# This multiplier comes from the original Transformer paper, "Attention Is All You Need" by Vaswani et al. It's a hyperparameter that determines the expansion factor for the hidden layer in the feedforward network.
#
# - Smaller values (e.g., 2) might be used for more efficient, compact models.
# - Larger values might be used when more capacity is needed, at the cost of increased computation.
# - Some recent models (like GPT-3) use even larger expansion factors.
#
DEFAULT_MLP_EXPANSION_FACTOR: int = 4

class DecoderLayer(torch.nn.Module):
  """This is essentially a transformer block."""

  def __init__(
      self,
      d_model: int,
      num_heads: int,
      mlp_expansion_factor: int = DEFAULT_MLP_EXPANSION_FACTOR,
    ) -> None:

      print(f'Creating decoder layer: {d_model=}, {num_heads=}, {mlp_expansion_factor=}')

      super().__init__()

      # Self-attention mechanism
      # Intuition: Allow the model to weigh the importance of different parts of the input for each position
      self.self_attn = torch.nn.MultiheadAttention(d_model, num_heads, batch_first=True)

      # Feedforward neural network
      # Intuition: Apply non-linear transformations to each position independently
      self.feed_forward = torch.nn.Sequential(
          torch.nn.Linear(d_model, mlp_expansion_factor * d_model),
          torch.nn.ReLU(),
          torch.nn.Linear(mlp_expansion_factor * d_model, d_model)
      )

      # Layer normalization
      # Intuition: Stabilize the learning process and allow for deeper networks
      self.norm1 = torch.nn.LayerNorm(d_model)
      self.norm2 = torch.nn.LayerNorm(d_model)

  @jaxtyped(typechecker=typechecker)
  def forward(
      self,
      x: Float[torch.Tensor, "batch seq d_model"],
      mask: Float[torch.Tensor, "seq seq"] | None,
  ) -> Float[torch.Tensor, "batch seq d_model"]:

      # Self-attention block
      # Intuition: Update each position based on its relationship to all previous positions
      attn_output, _ = self.self_attn(x, x, x, attn_mask=mask)

      # Residual connection and normalization
      x = self.norm1(x + attn_output)

      # Feedforward block
      # Intuition: Apply position-wise transformations to integrate information
      ff_output = self.feed_forward(x)

      # Residual connection and normalization
      x = self.norm2(x + ff_output)

      return x

In [18]:
# Note: `input_dim` == `context_size`
class DecoderOnlyTransformer(torch.nn.Module):
    def __init__(
        self,
        input_dim: int,
        d_model: int,
        num_heads: int,
        num_layers: int,
      ) -> None:

        print(f'Creating decoder only transformer: {input_dim=}, {d_model=}, {num_heads=}, {num_layers=}')

        super().__init__()

        self.d_model = d_model

        # Input projection: maps the input floats to the model dimension
        # Intuition: Learn a richer representation of the input numbers
        self.input_projection = torch.nn.Linear(input_dim, d_model)

        # Positional encoding: adds information about token position in the sequence
        # Intuition: Since self-attention has no inherent notion of order, this helps the model understand sequence order
        self.pos_encoding = PositionalEncoding(d_model)

        # Stack of decoder layers: the core of the transformer
        # Intuition: Each layer refines the representation, capturing increasingly complex patterns
        self.layers = torch.nn.ModuleList([
          DecoderLayer(d_model, num_heads) for _ in range(num_layers)
        ])

        # Final layer: projects to vocabulary size for token prediction
        # Intuition: Convert the final representation back to vocabulary space for next-token prediction
        # This is the "unembedding" / output layer
        self.final_layer = torch.nn.Linear(d_model, 1)

    @jaxtyped(typechecker=typechecker)
    def forward(
        self, x: Float[torch.Tensor, "batch seq input_dim"],
    ) -> Float[torch.Tensor, "batch"]:
        # x shape: (batch_size, seq_len)

        # Project input to model dimension
        x = self.input_projection(x)

        # Add positional encoding
        x = self.pos_encoding(x)

        # Generate causal mask
        # Intuition: Ensure the model only uses previous elements in the sequence for each prediction
        # note: no mask needed yet
        # mask = generate_square_subsequent_mask(x.size(1)).to(x.device)

        # Apply decoder layers
        for layer in self.layers:
            x = layer(x, mask=None)

        # Use the last sequence element for prediction
        x = x[:, -1, :]

        # Project to single float output
        return self.final_layer(x).squeeze(-1)

In [19]:
# test transformer
def test_transformer() -> None:

  input_dim = 1
  seq_len = 10
  d_model = 6
  batch_size = 2
  num_heads = 3
  num_layers = 5

  model = DecoderOnlyTransformer(
      input_dim=input_dim,
      d_model=d_model,
      num_heads=num_heads,
      num_layers=num_layers,
  )

  batch_x = torch.randint(0, 5, (batch_size, seq_len, input_dim)).float()

  print(f'Before transformer\n\t{batch_x.shape=}\n\t{batch_x[0]=}')

  batch_x = model.forward(batch_x)

  print(f'After transformer:\n\t{batch_x.shape=}\n\t{batch_x[0]=}')

test_transformer()

Creating decoder only transformer: input_dim=1, d_model=6, num_heads=3, num_layers=5
Creating decoder layer: d_model=6, num_heads=3, mlp_expansion_factor=4
Creating decoder layer: d_model=6, num_heads=3, mlp_expansion_factor=4
Creating decoder layer: d_model=6, num_heads=3, mlp_expansion_factor=4
Creating decoder layer: d_model=6, num_heads=3, mlp_expansion_factor=4
Creating decoder layer: d_model=6, num_heads=3, mlp_expansion_factor=4
Before transformer
	batch_x.shape=torch.Size([2, 10, 1])
	batch_x[0]=tensor([[4.],
        [3.],
        [2.],
        [1.],
        [3.],
        [2.],
        [4.],
        [0.],
        [1.],
        [4.]])
After transformer:
	batch_x.shape=torch.Size([2])
	batch_x[0]=tensor(-0.7798, grad_fn=<SelectBackward0>)


## Training Loop

### Get Device

In [20]:
def get_device() -> torch.device:
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Compute Accuracy

In [37]:
# note: threshold of 0.5 since we're dealing with ints
def compute_accuracy(model, dataloader, device, threshold: float = 0.5):
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():

        for inputs, labels in dataloader:

            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)

            # Compute accuracy within threshold
            error = torch.abs(outputs - labels)
            correct += torch.sum(error <= threshold).item()
            total += labels.size(0)

    accuracy = correct / total

    return accuracy

### Train For One Epoch

In [22]:
EpochLoss = float

def train_epoch(
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device
) -> EpochLoss:

    model.train()

    total_loss = 0.0

    for batch_X, batch_y in train_loader:

        batch_X, batch_y = batch_X.to(device), batch_y.to(device)

        optimizer.zero_grad()

        outputs = model(batch_X)

        loss = criterion(outputs, batch_y)

        loss.backward()

        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(train_loader)

### Evaluate Model Performance

In [23]:
def evaluate_model(
    model: torch.nn.Module,
    data_loader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    device: torch.device
) -> float:

    model.eval()

    total_loss = 0.0

    with torch.no_grad():

        for batch_X, batch_y in data_loader:

            batch_X, batch_y = batch_X.to(device), batch_y.to(device)

            outputs = model(batch_X)

            loss = criterion(outputs, batch_y)

            total_loss += loss.item()

    return total_loss / len(data_loader)

### Main Training Function

In [24]:
import transformers

def train_model(
    model: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    test_loader: torch.utils.data.DataLoader,
    num_epochs: int,
    learning_rate: float,
    epochs_without_improvement_tolerance: int = 5,
    eval_every_n_epochs: int = 10,
) -> tuple[list[float], list[float], float]:

    print(f'Training for {num_epochs=} with {learning_rate=}')

    device = get_device()
    model.to(device)

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = transformers.get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_epochs // 10, # arbitrarily chosen
        num_training_steps=num_epochs,
    )

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    epochs_without_improvement = 0

    for epoch in range(num_epochs):

        train_loss = train_epoch(
            model,
            train_loader,
            criterion,
            optimizer,
            device,
        )

        train_losses.append(train_loss)

        # update learning rate scheduler
        scheduler.step()

        # print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, ")

        # only evaluate every `eval_every_n_epochs` epochs
        if (epoch % eval_every_n_epochs) != 0:
          continue

        val_loss = evaluate_model(model, val_loader, criterion, device)

        val_losses.append(val_loss)

        val_accuracy = compute_accuracy(model, val_loader, device)

        # show learning rate
        # TODO(bschoen): Why is this a 1 element list?
        last_lr = scheduler.get_last_lr()[0]

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Learning Rate: {last_lr:.8f}, "
              f"Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, "
              f"Val Accuracy: {val_accuracy:.4f}")

        if val_loss < best_val_loss:
          # TODO(bschoen): save new best model
          pass
        else:
          epochs_without_improvement += 1
          if epochs_without_improvement >= epochs_without_improvement_tolerance:
            print(f"Early stopping triggered after {epoch+1} epochs")
            break

    print('Evaluating final model')
    test_loss = evaluate_model(model, test_loader, criterion, device)
    print(f"Test Loss: {test_loss:.4f}")

    return train_losses, val_losses, test_loss

### Optuna Objective Function

In [25]:
def objective(
    trial: optuna.Trial,
    model_class: type,
    input_dim: int,
    num_epochs: int,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    test_loader: torch.utils.data.DataLoader
) -> float:
    # Define hyperparameters to optimize
    # done since `d_model` must be even
    d_model = trial.suggest_int('d_model', 4, 16, step=2)
    num_heads = trial.suggest_int('num_heads', 2, 8)
    num_layers = trial.suggest_int('num_layers', 1, 6)
    learning_rate = trial.suggest_loguniform('learning_rate', 1e-5, 1e-2)

    print(f'Optimizing across: {input_dim=}, {d_model=}, {num_heads=}, {num_layers=}, {learning_rate=}, {num_epochs=}')

    # prune cases with invalid number of heads
    if (d_model % num_heads) != 0:
      raise optuna.exceptions.TrialPruned()

    # Create model with suggested hyperparameters
    model = model_class(
        input_dim=input_dim,
        d_model=d_model,
        num_heads=num_heads,
        num_layers=num_layers,
    )

    # Train the model
    _, _, test_loss = train_model(
        model,
        train_loader,
        val_loader,
        test_loader,
        num_epochs=num_epochs,
        learning_rate=learning_rate
    )

    return test_loss

### Optimize Hyperparameters

In [26]:
def optimize_hyperparameters(
    model_class: type,
    input_dim: int,
    num_epochs: int,
    n_trials_per_param: int,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    test_loader: torch.utils.data.DataLoader,
) -> tuple[optuna.Study, dict[str, int | float]]:

    print(f'Creating a study to optimize hyperparams across {n_trials_per_param}')

    study = optuna.create_study(direction='minimize')

    study.optimize(
        lambda trial: objective(trial, model_class, input_dim, num_epochs, train_loader, val_loader, test_loader),
        n_trials=n_trials_per_param,
        n_jobs=-1, # note: tells optuna to use all available CPU
    )

    print("Best trial:")
    trial = study.best_trial
    print(f"  Value: {trial.value}")
    print("  Params: ")
    for key, value in trial.params.items():
        print(f"    {key}: {value}")

    return study, trial.params

# Problem Specific

In [27]:
# TODO(bschoen): Generalize this over `TransformerProblem`
num_samples = 1000
batch_size = 32
context_size = 10
input_dim = 1

# define generator for input
def generate_randint_sequence(
    min_value: int,
    max_value: int,
    seq_len: int,
    input_dim: int,
  ) -> Float[torch.Tensor, "seq_len input_dim"]:
    """Generate arbitrary integer sequence."""
    return torch.randint(min_value, max_value + 1, (seq_len, input_dim)).float()

# define generator for output
def compute_min_max_sum(x: InputTensor) -> OutputTensor:
  return torch.min(x) + torch.max(x)

dataset = SequenceDataset(
    num_samples=num_samples,
    x_generator=lambda: generate_randint_sequence(0, 100, context_size, input_dim),
    y_computer=compute_min_max_sum,
)

## Dataset Construction

In [28]:
# Show a sample from the dataset to sanity check
print('Showing example dataset member:')
x, y = dataset[0]
print(f"x: {x.shape}\t{x}")
print(f"y: {y.shape}\t{y}")

# create splits
train_dataset, val_dataset, test_dataset = split_dataset(dataset)

# create dataloaders
train_dataloader, val_dataloader, test_dataloader = create_dataloaders(
    train_dataset,
    val_dataset,
    test_dataset,
    batch_size=batch_size,
    num_workers=4, # arbitrarily chosen to be non-zero
)

Showing example dataset member:
x: torch.Size([10, 1])	tensor([[65.],
        [81.],
        [73.],
        [80.],
        [41.],
        [24.],
        [19.],
        [22.],
        [21.],
        [19.]])
y: torch.Size([])	100.0
Creating splits for 1000 samples train_ratio=0.7, val_ratio=0.15, test_ratio=0.15
Created splits from 1000 total samples:
 - train (700 samples)
 - validation (150 samples)
 - test (150 samples)
Creating dataloaders for batch_size=32, num_workers=4, pin_memory=True


## Trying Run Before Hyperparameter Optimization

In [29]:
# trial run with arbitrary parameters to make sure everything works
def test_evaluate_model() -> None:
  model = DecoderOnlyTransformer(input_dim=1, d_model=20, num_heads=2, num_layers=3)

  device = get_device()
  model.to(device)

  loss = evaluate_model(
      model=model,
      data_loader=val_dataloader,
      criterion=torch.nn.MSELoss(),
      device=get_device(),
  )

  print(f'{loss=:.4f}')

test_evaluate_model()

Creating decoder only transformer: input_dim=1, d_model=20, num_heads=2, num_layers=3
Creating decoder layer: d_model=20, num_heads=2, mlp_expansion_factor=4
Creating decoder layer: d_model=20, num_heads=2, mlp_expansion_factor=4
Creating decoder layer: d_model=20, num_heads=2, mlp_expansion_factor=4


  self.pid = os.fork()


loss=10073.4689


## Determine Best Hyperparameters

In [34]:
import optuna.visualization

is_hyperparameter_study_needed = False

if is_hyperparameter_study_needed:

  num_epochs_for_optimizing_hyperparameter_trials = 100
  n_trials_per_param = 100

  # Assuming DecoderOnlyTransformer is defined elsewhere
  study, best_params = optimize_hyperparameters(
      model_class=DecoderOnlyTransformer,
      input_dim=input_dim,
      num_epochs=num_epochs_for_optimizing_hyperparameter_trials,
      n_trials_per_param=n_trials_per_param,
      train_loader=train_dataloader,
      val_loader=val_dataloader,
      test_loader=test_dataloader,
  )



  optuna.visualization.plot_optimization_history(study)
  optuna.visualization.plot_contour(study)
  optuna.visualization.plot_param_importances(study)
  optuna.visualization.plot_parallel_coordinate(study)
  optuna.visualization.plot_timeline(study)
  optuna.visualization.plot_slice(study)

## Visualize Optimization Study

## Train Best Model

In [40]:
# TODO(bschoen): Save and load study from local db
"""
Best trial:
  Value: 0.10772023797035217
  Params:
    d_model: 12
    nhead: 6
    num_layers: 1
    learning_rate: 0.005134454801759864
"""
best_params = {
    'd_model': 4,
    'num_heads': 2,
    'num_layers': 2,
    'learning_rate': 0.005,
}

In [41]:
# TODO(bschoen): Profile a single run here, because studies are taking forever

# note: can manually scale up parameters from here

# TODO(bschoen): Give models names

# TODO(bschoen): Ablate the MLP, in case that's what's learning it
# TODO(bschoen): Does the residual stream representation still hold with
#                MLP? It has to for the next layer to access it

# note: more epochs than we use in our study
num_epochs = 1000

# Create the best model using the optimized hyperparameters
best_model = DecoderOnlyTransformer(
    input_dim=input_dim,
    d_model=best_params['d_model'],
    num_heads=best_params['num_heads'],
    num_layers=best_params['num_layers'],
)

# Train the best model
train_losses, val_losses, final_test_loss = train_model(
    best_model,
    train_dataloader,
    val_dataloader,
    test_dataloader,
    num_epochs=num_epochs,
    learning_rate=best_params['learning_rate'],
)

print(f"Final test loss with best model: {final_test_loss:.4f}")

Creating decoder only transformer: input_dim=1, d_model=4, num_heads=2, num_layers=2
Creating decoder layer: d_model=4, num_heads=2, mlp_expansion_factor=4
Creating decoder layer: d_model=4, num_heads=2, mlp_expansion_factor=4
Training for num_epochs=1000 with learning_rate=0.005
Epoch [1/1000], Learning Rate: 0.00005000, Train Loss: 10193.9582, Val Loss: 9997.8986, Val Accuracy: 0.0000
Epoch [11/1000], Learning Rate: 0.00055000, Train Loss: 10020.2452, Val Loss: 9816.6037, Val Accuracy: 0.0000
Epoch [21/1000], Learning Rate: 0.00105000, Train Loss: 9690.4900, Val Loss: 9469.0748, Val Accuracy: 0.0000
Epoch [31/1000], Learning Rate: 0.00155000, Train Loss: 8679.9061, Val Loss: 8418.3759, Val Accuracy: 0.0000
Epoch [41/1000], Learning Rate: 0.00205000, Train Loss: 6293.2039, Val Loss: 5988.1242, Val Accuracy: 0.0000
Epoch [51/1000], Learning Rate: 0.00255000, Train Loss: 3053.8848, Val Loss: 2800.7419, Val Accuracy: 0.0000
Epoch [61/1000], Learning Rate: 0.00305000, Train Loss: 797.8245

In [42]:
device = get_device()

for name, dataloader in {
    'train': train_dataloader,
    'val': val_dataloader,
    'test': test_dataloader,
}.items():

  accuracy = compute_accuracy(best_model, dataloader, device)

  print(f'{name} accuracy: {accuracy:.4f}')

train accuracy: 0.9971
val accuracy: 0.9933
test accuracy: 0.9933


In [43]:
# sample some outputs to sanity check
with torch.no_grad():

  for batch_X, batch_y in test_dataloader:

    batch_X, batch_y = batch_X.to(device), batch_y.to(device)

    predicted_y = best_model(batch_X)

    for i in range(5):
      print('---')
      print(f'x_i:\t{[x.item() for x in batch_X[i]]}')
      print(f'y_pred:\t{predicted_y[i]}')
      print(f'y_i:\t{batch_y[i]}')

    break

---
x_i:	[89.0, 43.0, 96.0, 10.0, 21.0, 55.0, 5.0, 3.0, 34.0, 66.0]
y_pred:	99.1990737915039
y_i:	99.0
---
x_i:	[93.0, 93.0, 67.0, 72.0, 91.0, 16.0, 65.0, 18.0, 89.0, 10.0]
y_pred:	103.17760467529297
y_i:	103.0
---
x_i:	[76.0, 71.0, 37.0, 20.0, 21.0, 83.0, 88.0, 44.0, 53.0, 93.0]
y_pred:	113.03359985351562
y_i:	113.0
---
x_i:	[53.0, 94.0, 27.0, 48.0, 59.0, 51.0, 80.0, 98.0, 68.0, 97.0]
y_pred:	125.07417297363281
y_i:	125.0
---
x_i:	[5.0, 95.0, 74.0, 92.0, 78.0, 57.0, 48.0, 96.0, 91.0, 16.0]
y_pred:	100.9544906616211
y_i:	101.0


### Plotting Helper Functions

In [None]:
import plotly.express as px

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):

    px.imshow(
        transformer_lens.utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x":xaxis, "y":yaxis},
        **kwargs,
    ).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):

    px.line(
        transformer_lens.utils.to_numpy(tensor),
        labels={"x":xaxis, "y":yaxis},
        **kwargs,
    ).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):

    x = transformer_lens.utils.to_numpy(x)
    y = transformer_lens.utils.to_numpy(y)

    px.scatter(
        y=y,
        x=x,
        labels={"x":xaxis, "y":yaxis, "color":caxis},
        **kwargs,
    ).show(renderer)