In [None]:
import sys
# Record the interpreter version: the code below relies on Python 3.10+
# features such as `match` statements and `dataclass(kw_only=True)`.
print(sys.version)

# The Annotated Hyena

A didactic annotation of the architecture introduced in the paper _Hyena Hierarchy: Towards Larger Convolutional Language Models_ by Michael Poli, Stefano Massaroli, Eric Nguyen, Daniel Y. Fu, Tri Dao, Stephen Baccus, Yoshua Bengio, Stefano Ermon and Christopher Ré.

The Hyena architecture is an exciting development that acts as a drop-in replacement for attention layers in transformer models, enabling long-range sequence modeling. It matches or surpasses the language-modeling performance of attention-based transformers with similar parameter counts, can support context lengths of over 100K tokens, and is already 100x faster than FlashAttention at a context length of 64K. Here we walk through its construction with annotated code in a style inspired by _[The Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)_. The paper's [reference implementation](https://github.com/HazyResearch/safari/blob/main/src/models/sequence/hyena.py) was used to debug the code below.

## Table of Contents

* Preliminaries
* Part 1: The Hyena Operator
  * Operator definition
  * FFT
  * Interpretation as attention
* Part 2: Defining the Hyena filter
* Part 3: A Working Model
  * Dataset
  * Training
  * Evaluation
* Appendix: Comparison with Attention
* Appendix: Hyena as described in the paper

## Preliminaries

The below code imports the necessary packages and defines some utility functions. Understanding it is not important for understanding hyena-based models. If you are using Jupyter, you will want to begin by installing the required packages:

In [None]:
!pip install datasets lightning numpy regex torch

In [None]:
# !pip install wandb

In [None]:
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import dataclasses
import json
import math
import random

import datasets
import lightning
import numpy as np
import regex
import torch
import wandb


def prettify(z: Union[complex, List[complex], np.ndarray]) -> str:
  """Format a number, vector or matrix compactly for display.

  Values within 1e-6 of an integer are shown as integers, other floats with
  three decimals, and complex numbers as ``a + bi`` / ``a - bi``. Vector
  entries are separated by two spaces; matrix columns by tabs and rows by
  newlines.

  Args:
    z: A Python or NumPy scalar (int, float or complex), a list/tuple of
      scalars, or a 1-D or 2-D array.

  Returns:
    The formatted string.

  Raises:
    NotImplementedError: If ``z`` has more than two dimensions.
  """
  def fix_float(f: float) -> Union[float, int]:
    # Snap values that are numerically integers (e.g. FFT round-off) to int.
    if abs(round(f) - f) < 1e-6:
      return round(f)
    return f

  def fmt_real(x: Union[float, int]) -> str:
    return f"{x}" if isinstance(x, int) else f"{x:.3f}"

  # NumPy scalars such as np.int64 are not Python ints; unwrap them first.
  if isinstance(z, np.generic):
    z = z.item()

  if isinstance(z, (complex, float, int)):
    real = fix_float(z.real)
    imag = fix_float(z.imag)
    if not imag:
      return fmt_real(real) if real else "0"
    magnitude = abs(imag)
    imag_str = "i" if magnitude == 1 else f"{fmt_real(magnitude)}i"
    if not real:
      return f"-{imag_str}" if imag < 0 else imag_str
    # Print "a - bi" rather than "a + -bi".
    sign = "-" if imag < 0 else "+"
    return f"{fmt_real(real)} {sign} {imag_str}"

  arr = np.asarray(z)
  if arr.ndim == 0:
    return prettify(arr.item())
  if arr.ndim == 1:
    return "  ".join(prettify(a) for a in arr.tolist())
  if arr.ndim == 2:
    return "\n".join(
      "\t".join(prettify(a) for a in row) for row in arr.tolist()
    )
  raise NotImplementedError("3rd and higher order tensors unsupported")

def average_probability(
    model: lightning.LightningModule,
    ds: datasets.Dataset,
    tokens: int = 256
  ) -> float:
  """Average probability the model assigns to the true next token.

  Walks through the rows of ``ds`` in order. For each of the first ``tokens``
  positions it feeds the model the prefix of the current row ending at that
  position, then reads off the softmax probability of the row's ``next_id``
  at that position. The model is put in eval mode and moved to the CPU.

  Args:
    model: Model whose ``forward((input_ids, None))`` returns
      ``(logits, _)`` with logits indexed as [batch, position, vocab].
    ds: Dataset whose rows hold ``curr_id`` and ``next_id`` token-id lists
      of the same length as row 0's ``curr_id``.
    tokens: Number of positions to score; must be positive.

  Returns:
    The mean probability over the scored positions.

  Raises:
    ValueError: If ``tokens`` is not positive.
  """
  if tokens <= 0:
    raise ValueError("tokens must be positive")
  model.eval()
  model = model.to("cpu")
  context_length = len(ds[0]["curr_id"])
  total_p = 0.0
  row = -1
  # Inference only: without no_grad every forward pass would build an
  # autograd graph that is never used.
  with torch.no_grad():
    for t in range(tokens):
      L = t % context_length
      if L == 0:
        # Move to the next row; fetch it once instead of once per token.
        row += 1
        row_ids = ds[row]["curr_id"]
        next_ids = ds[row]["next_id"]
      input_ids = torch.LongTensor([row_ids[:L + 1]])
      logits, _ = model.forward((input_ids, None))
      p = torch.softmax(logits[0, L, :], dim=-1)[next_ids[L]].item()
      total_p += p
  return total_p / tokens
  
def generate(
    model: lightning.LightningModule,
    context_length: int,
    prompt: str,
    max_tokens: int = 32,
    method: str = "topk",
    k: int = 3
  ) -> str:
  model.eval()
  model = model.to("cpu")
  result = prompt
  prompt = prompt[:context_length]
  L = len(prompt)
  for _ in range(max_tokens):
    input_ids = torch.LongTensor([[tok2id[ch] for ch in prompt[:L]]])
    logits, _ = model.forward((input_ids, None))
    match method:
      case "sample":
        output_id = torch.multinomial(
          torch.exp(logits[0, L - 1, :]),
          num_samples=1
        ).item()
      case "topk":
        values, indices = torch.topk(torch.exp(logits[0, L - 1, :]), k=k)
        p = torch.sparse_coo_tensor(
          indices.unsqueeze(0),
          values,
          (logits.shape[-1],)
        )
        output_id = torch.multinomial(p.to_dense(), num_samples=1).item()
      case "greedy":
        output_id = torch.argmax(logits[0, L - 1, :]).item()
    prompt += vocabulary[output_id]
    result += vocabulary[output_id]
    L = min(L + 1, context_length)
  return result

@dataclasses.dataclass(kw_only=True)
class Config:
  """Training hyperparameters shared by every model configuration.

  All fields are keyword-only and required (no defaults).
  """
  learning_rate: float  # AdamW learning rate
  epochs: int  # passed to lightning.Trainer as max_epochs
  betas: Tuple[float, float]  # AdamW beta coefficients
  weight_decay: float  # AdamW weight decay
  device_type: str  # Trainer accelerator, e.g. "cpu" or "gpu"
  precision: str  # Trainer precision, e.g. "32", "16", "bf16"
  batch_size: int  # DataLoader batch size
  num_workers: int  # DataLoader worker processes


@dataclasses.dataclass(kw_only=True)
class HyenaConfig(Config):
  """Model hyperparameters for the GPT-style model built from HyenaBlocks."""
  d_model: int  # width of the embeddings and of every HyenaBlock
  n_layers: int  # number of stacked HyenaBlocks
  vocab_size: int  # number of distinct tokens
  d_embed: int  # passed to HyenaFilter; unused by the simplified filter
  d_filter_mlp: int  # passed to HyenaFilter as d_mlp; unused by the simplified filter
  n_filter_layers: int  # passed to HyenaFilter as n_layers; only checked to be >= 2
  context_length: int  # maximum sequence length (filter and position-embedding size)
  short_conv_size: int  # kernel size of the depthwise conv in Projection
  order: int  # Hyena operator order N
  pdrop_hyena: float  # probability for HyenaBlock's dropout layer
  pdrop_embed: float  # dropout probability applied to the embeddings
  omega: Optional[int]  # passed to HyenaFilter; unused by the simplified filter


@dataclasses.dataclass(kw_only=True)
class AttentionConfig(Config):
  """Model hyperparameters for the GPT-style self-attention baseline."""
  d_model: int  # width of the token/position embeddings in GPModel
  n_layers: int  # number of stacked SelfAttentionBlocks
  vocab_size: int  # number of distinct tokens
  d_embed: int  # width used inside SelfAttentionBlock; should equal d_model
  n_head: int  # number of attention heads; must divide d_embed
  context_length: int  # maximum sequence length (position-embedding size)
  pdrop_attn: float  # dropout in the attention weights and block MLP
  pdrop_embed: float  # dropout probability applied to the embeddings


# Let PyTorch use faster, lower-precision internal computation for float32
# matrix multiplications on hardware that supports it.
torch.set_float32_matmul_precision("medium")

## Part 1: The Hyena Operator

The Hyena Operator alternates between convolutions and Hadamard (elementwise) products. The convolutions are with filters (vectors) which will be defined later. The Hadamard products are with functions of the input. In other words, the Hyena Operator of order N is
$$
H_N(u) = x_N(u) \cdot (h_N \ast (x_{N - 1}(u) \cdot (h_{N - 1} \ast (\cdots (h_1 \ast x_0(u)) \cdots )))) \qquad x_i(u) = w_i \star (A_i u)
$$
where $u$ is the input vector, $h_i$ are vectors called filters, and $x_i$ is a matrix multiplication followed by a padded cross-correlation known in machine learning as a "depthwise convolution". Here $\star$ represents cross-correlation and $\ast$ represents convolution.

If you are already lost, don't worry! The parts of this definition will be explained below.

### Hadamard product

To begin, the Hadamard product is just the elementwise product of two vectors. In other words, if $a$ and $b$ are vectors, then
$$
a \cdot b = (a_1 b_1, \ldots, a_n b_n) \qquad a = (a_1, \ldots, a_n) \quad b = (b_1, \ldots, b_n)
$$
For those who prefer code:

In [None]:
def hadamard(a: List[complex], b: List[complex]) -> List[complex]:
  """Return the elementwise (Hadamard) product of two equal-length vectors."""
  assert len(a) == len(b)
  return [x * y for x, y in zip(a, b)]

hadamard([0, 1, 2, 3], [0, 2, 4, 6])

### Convolution

A convolution is a way of multiplying two sequences. Let $a$ and $b$ be doubly infinite sequences. Then their convolution is another doubly infinite sequence whose $n$th element is
$$
(a \ast b)_n = \sum_{k=-\infty}^{\infty} a_k b_{n - k} \qquad a = (\ldots, a_{-1}, a_0, a_1, \ldots) \quad b = (\ldots, b_{-1}, b_0, b_1, \ldots)
$$
So, for example, the $0$th element of the convolution is
$$
\sum_{k=-\infty}^{\infty} a_k b_{-k}
$$
We can define the convolution of two finite-dimensional vectors by viewing them as sequences padded by infinitely many zeros to the right and to the left. For example, if $a = (-1, -2, -3)$ and $b = (1, 2, 3)$, their convolution is
$$
\begin{array}{rcl}
a \ast b & = & (-1 * 1, (-1 * 2) + (-2 * 1), (-1 * 3) + (-2 * 2) + (-3 * 1), (-2 * 3) + (-3 * 2), -3 * 3) \\
& = & (-1, -4, -10, -12, -9)
\end{array}
$$
In code:

In [None]:
def convolution(a: List[complex], b: List[complex]) -> List[complex]:
  """Full linear convolution of two finite vectors.

  Both vectors are treated as zero outside their support, so the result has
  len(a) + len(b) - 1 entries.
  """
  len_a, len_b = len(a), len(b)
  out = []
  for n in range(len_a + len_b - 1):
    # Only indices k with both a[k] and b[n - k] in range contribute.
    lo = max(0, n - len_b + 1)
    hi = min(len_a, n + 1)
    out.append(sum([a[k] * b[n - k] for k in range(lo, hi)]))
  return out

a = [-1, -2, -3]
b = [1, 2, 3]
convolution(a, b)

#### Discrete Fourier Transform

The convolution theorem says that we can calculate convolutions using the Discrete Fourier Transform (DFT) and the Hadamard product. This is important because using the definition, calculating a convolution is slow for large vectors and has complexity $O(n^2)$. However there is a fast way to calculate the DFT in $O(n \log n)$ time aptly called the Fast Fourier Transform. Together with the Hadamard product which requires $O(n)$ time, this means that we can calculate convolutions in $O(n \log n)$ time. Let's look at the details.

To start, the formula is
$$
a \ast b = IDFT(DFT(\overline{a}) \cdot DFT(\overline{b}))
$$
where IDFT stands for the Inverse Discrete Fourier Transform and $\overline{a}$ and $\overline{b}$ are $a$ and $b$ padded with $N - 1$ zeros on the right, $N$ being the length of the vectors.

The DFT of a column vector $a$ of length $N$ is defined as the multiplication of $a$ on the left by the matrix with entries
$$
(DFT_N)_{jk} = e^{-2\pi i j k / N}
$$
or in other words multiplication by
$$
DFT_N = \left(
\begin{array}{cccccc}
1 & 1 & 1 & 1 & \cdots & 1 \\
1 & \omega_N & \omega_N^2 & \omega_N^3 & \cdots & \omega_N^{N-1} \\
1 & \omega_N^2 & \omega_N^4 & \omega_N^6 & \cdots & \omega_N^{N-2} \\
1 & \omega_N^3 & \omega_N^6 & \omega_N^9 & \cdots & \omega_N^{N-3} \\
\vdots & & & & & \\
1 & \omega_N^{N-1} & \omega_N^{N-2} & \omega_N^{N-3} & \cdots & \omega_N
\end{array}
\right) \qquad \omega_N = e^{-2\pi i / N}
$$
The inverse of this matrix, i.e., the matrix of the IDFT, turns out to be its elementwise complex conjugate divided by $N$. In code:

In [None]:
def omega(N: int) -> complex:
  """Principal N-th root of unity used by the DFT: exp(-2*pi*i/N)."""
  return pow(math.e, -2 * math.pi * 1j / N)
  
def DFT_matrix(N: int) -> np.array:
  """N x N DFT matrix whose (j, k) entry is omega(N) ** (j * k)."""
  matrix = []
  for row in range(N):
    matrix.append([pow(omega(N), row * col) for col in range(N)])
  return np.array(matrix)

def DFT(a: List[complex]) -> np.array:
  """DFT of ``a`` computed by an O(N^2) matrix-vector product."""
  mat = DFT_matrix(len(a))
  return mat @ np.array(a)
  
def IDFT_matrix(N: int) -> List[List[complex]]:
  """Inverse DFT matrix: the elementwise conjugate of DFT_matrix(N) over N."""
  return DFT_matrix(N).conj() / N

def IDFT(a: List[complex]) -> np.array:
  """Inverse DFT of ``a`` computed by an O(N^2) matrix-vector product."""
  mat = IDFT_matrix(len(a))
  return mat @ np.array(a)
  
# Display the 4x4 DFT matrix and its inverse. omega(4) = exp(-pi*i/2) = -i,
# so the DFT entries are powers of -i; the inverse is their conjugate over 4.
print("DFT_matrix(4):")
print(prettify(DFT_matrix(4)))
print("IDFT_matrix(4):")
print(prettify(IDFT_matrix(4)))

We can now check the convolution theorem:

In [None]:
# Both length-3 vectors are padded with 3 - 1 = 2 zeros on the right, so the
# DFT -> Hadamard -> IDFT route yields the linear convolution; the two
# printed lines should agree.
a = [-1, -2, -3, 0, 0]
b = [1, 2, 3, 0, 0]
print(prettify(IDFT(hadamard(DFT(a), DFT(b)))))
print(prettify(convolution(a[:3], b[:3])))

Or here's a check with random vectors:

In [None]:
#random.seed(0)

# Same check with random length-8 vectors, each padded with 8 - 1 zeros.
a = [random.random() for _ in range(8)]
b = [random.random() for _ in range(8)]
a_bar = a + [0] * (len(a) - 1)
b_bar = b + [0] * (len(b) - 1)
print(prettify(IDFT(hadamard(DFT(a_bar), DFT(b_bar)))))
print(prettify(convolution(a, b)))

#### Fast Fourier Transform

The idea for how to calculate the Discrete Fourier Transform in $O(n \log n)$ time is to use the divide and conquer technique. This means noticing that
$$
DFT_N(a_s)_k = DFT_{N/2}(a_{2s})_k + e^{-2\pi i k/N} DFT_{N/2}(a_{2s+1})_k, \qquad k = 0, \ldots, N / 2 - 1
$$
and
$$
DFT_N(a_s)_k = DFT_{N/2}(a_{2s})_{k-N/2} + e^{-2\pi i k/N} DFT_{N/2}(a_{2s+1})_{k-N/2}, \qquad k = N/2, \ldots, N - 1
$$
where, for example, $DFT_{N/2}(a_{2s})_k$ means the $k$th coordinate of the Discrete Fourier Transform of the even entries (using 0-based indexing) of the vector $a$, and $DFT_{N/2}(a_{2s+1})_k$ means the same for the odd entries. The code might be easier to read than the formula:

In [None]:
def roots_of_unity(n: int) -> List[complex]:
  """The n-th roots of unity exp(-2*pi*i*k/n) for k = 0, ..., n - 1."""
  return [pow(math.e, -2 * math.pi * 1j * k / n) for k in range(n)]

def _FFT(x: List[complex], W: List[complex]) -> Tuple[List[complex], int]:
  """One recursive radix-2 Cooley-Tukey step.

  Args:
    x: Input whose length k is a power of two dividing len(W).
    W: The len(W)-th roots of unity; every (len(W) // k)-th entry is a
      k-th root of unity.

  Returns:
    The DFT of ``x`` and the number of arithmetic operations counted (one
    multiply and one add per output element at every recursion level).
  """
  k = len(x)
  if k == 1:
    return x, 0
  else:
    n = len(W)
    
    X_even, ops_even = _FFT([x[i] for i in range(0, k, 2)], W)
    X_odd, ops_odd = _FFT([x[i] for i in range(1, k, 2)], W)
    # The k-th roots of unity, taken from the precomputed n-th roots.
    W_k = [W[i] for i in range(0, n, n // k)]

    X_left = [X_even[i] + W_k[i] * X_odd[i] for i in range(k // 2)]
    X_right = [X_even[i] + W_k[k // 2 + i] * X_odd[i] for i in range(k // 2)]
    ops = 2 * k + ops_even + ops_odd
    return X_left + X_right, ops
    
def FFT(x: List[complex], verbose: bool = False) -> np.array:
  """Fast Fourier Transform of ``x`` in O(n log n) operations.

  Args:
    x: Input vector; its length must be a positive power of two.
    verbose: If true, print the length and the operation count.

  Returns:
    The DFT of ``x`` as a NumPy array.

  Raises:
    ValueError: If len(x) is not a positive power of two.
  """
  n = len(x)
  # n & (n - 1) clears the lowest set bit, so it is zero exactly for powers
  # of two. n == 0 must be rejected too, or _FFT would recurse forever.
  if n == 0 or n & (n - 1):
    raise ValueError("only vectors of length 2**n are supported")

  W = roots_of_unity(n)

  X, ops = _FFT(x, W)
  if verbose:
    print("n:", n, "ops:", ops)
  return np.array(X)

# Print the operation count for lengths 4, 8, 16 and 32; it equals
# 2 * n * log2(n).
for k in range(2, 6):
  a = list(range(2**k))
  FFT(a, verbose=True)

Notice that the algorithm uses $2n \log_2 n$ operations not including the calculation of the roots of unity so we've achieved our goal of subquadratic complexity. Let's check that it gives the same answer as the DFT:

In [None]:
#random.seed(0)

# The recursive FFT should print the same values as the O(n^2) matrix DFT.
a = [random.random() for _ in range(8)]
print(prettify(DFT(a)))
print(prettify(FFT(a)))

### Padded cross-correlation
The last concept we need to understand the Hyena operator is the padded cross-correlation, a.k.a. the depthwise convolution. Even though it is commonly called a convolution in machine learning, it is not a convolution in the usual mathematical sense that we defined above. If $x$ and $f$ are vectors, then their cross-correlation is the vector with $n$th coordinate
$$
(f \star x)_n = \sum_{k=0}^{d-1} f_k x_{n + k} \qquad x = (x_0, \ldots, x_{N - 1}) \quad f = (f_0, \ldots, f_{d - 1})
$$
The cross-correlation has length $N - d + 1$. For those who prefer code:

In [None]:
def crosscorrelation1d(x: List[float], filter: List[float]) -> List[float]:
  """Unpadded ("valid") cross-correlation of ``x`` with ``filter``.

  Output n is sum_k x[n + k] * filter[k], giving len(x) - len(filter) + 1
  values.
  """
  width = len(filter)
  out = []
  for start in range(len(x) - width + 1):
    window = x[start:start + width]
    out.append(sum(xv * fv for xv, fv in zip(window, filter)))
  return out

# Cross-correlating a length-7 signal with a length-3 filter gives
# 7 - 3 + 1 = 5 values.
x = [1, 2, 3, 4, 5, 6, 7]
f = [-1, 2, 1]
print(prettify(crosscorrelation1d(x, f)))

The Hyena operator will be used to create a model that takes in text and attempts to predict the next token, so we need it to be _causal_. In other words, we do not want a value at position $n$ in the sequence receiving information from values in future positions, e.g., at $n + 1$. Otherwise the model will use the information about the future tokens to predict the next token.

The cross-correlation will not be causal because, for example, given a length three vector $f$, the $0$th coordinate of the cross-correlation is $(f \star x)_0 = f_0 x_0 + f_1 x_1 + f_2 x_2$. So the $0$th value will have information about $x_1$ and $x_2$, which are values in future positions relative to the $0$th value. To make the cross-correlation causal, we just need to pad the vector on the left with zeros. If $f$ has length $d$ then we need $d - 1$ zeros. In our example, that would mean that $(f \star x)_0 = f_0 \cdot 0 + f_1 \cdot 0 + f_2 x_0$.

### Adding dimensions

So far we have considered the Hyena operator as an operator on vectors. However, in practice we will want it to operate on tensors that have a batch axis and an embedding axis, in other words on tensors with shape $(b, d, L)$, where $b$ is the number of samples in a batch, $d$ is the embedding dimension and $L$ is the sequence length. We just apply the operations we have described in parallel across the batch and embedding dimensions.

### Putting it all together

Using PyTorch it is straightforward to define the matrix multiplication followed by padded cross-correlation. We will call it the `Projection` module because it projects the input embeddings to $x_0, \ldots, x_N$. As a reminder, it is implementing
$$
x_i(u) = w_i \star A_i u
$$
Instead of implementing it in a loop, we do the operation for all $i$ at once and then split the result into separate $x_i$.
As explained in the previous section, the cross-correlation (`Conv1d`) is executed in parallel across the embedding dimension. This is accomplished by setting `groups=d_model * (N + 1)`, i.e., one group per channel, so that each channel is filtered independently.

In [None]:
class Projection(torch.nn.Module):
  """Linear map to N + 1 streams followed by a causal depthwise convolution.

  Computes x_i(u) = w_i * (A_i u) for i = 0, ..., N in one shot. A single
  Linear layer produces all N + 1 projections. A depthwise Conv1d (one
  filter per channel) applies the short filters. The result is then split
  back into N + 1 tensors.

  Args:
    d_model: Embedding width of the input and of each output stream.
    N: Operator order; N + 1 streams are produced.
    conv_len: Kernel size of the short depthwise convolution.
  """

  def __init__(self, d_model: int, N: int, conv_len: int):
    super().__init__()
    self.d_model = d_model
    self.N = N
    channels = d_model * (N + 1)
    self.linear = torch.nn.Linear(d_model, channels)
    self.conv = torch.nn.Conv1d(
      in_channels=channels,
      out_channels=channels,
      kernel_size=conv_len,
      groups=channels,  # one group per channel: depthwise
      padding=conv_len - 1,
    )

  def forward(self, u: torch.Tensor) -> List[torch.Tensor]:
    """Project ``u`` (last axis d_model, axis 1 the sequence) into N + 1 streams.

    Each returned tensor is channels-first: (batch, d_model, seq_len).
    """
    seq_len = u.shape[1]
    # Conv1d expects (batch, channels, length).
    z = self.linear(u).transpose(1, 2)
    # padding=conv_len - 1 pads both ends. Keeping the first seq_len outputs
    # keeps the left-padded part, so position n only sees inputs
    # n - conv_len + 1, ..., n (causal).
    z = self.conv(z)[..., :seq_len]
    return torch.split(z, self.d_model, dim=1)

Next we define the convolution using FFT and the Hadamard product as discussed above. Because our vectors are real-valued, we can use a special version of FFT optimized for real numbers called `torch.fft.rfft`. The FFT $X$ of a real-valued vector of length $n$ is conjugate-symmetric, i.e., $X_{n-k} = \overline{X_k}$. Roughly half of its coordinates are therefore redundant, and `rfft` returns only the first $\lfloor n/2 \rfloor + 1$ of them. See below for an example.

In [None]:
a = torch.Tensor([random.random() for _ in range(8)])
b = torch.Tensor([random.random() for _ in range(8)])
# Zero-pad to 2 * 8 = 16 points so that the FFT's circular convolution equals
# the ordinary (linear) convolution, as in the padding formula earlier.
n_fft = 2 * len(a)

# rfft keeps only the n_fft // 2 + 1 non-redundant frequencies...
a_rf = torch.fft.rfft(a, n=n_fft)
b_rf = torch.fft.rfft(b, n=n_fft)
c_rf = a_rf * b_rf
cr = torch.fft.irfft(c_rf, n=n_fft)

# ...while fft returns all n_fft frequencies. That full spectrum must be
# inverted with ifft, not irfft (which expects the half spectrum). The
# imaginary part of the result is only round-off.
a_f = torch.fft.fft(a, n=n_fft)
b_f = torch.fft.fft(b, n=n_fft)
c_f = a_f * b_f
c = torch.fft.ifft(c_f).real

print("a_rf:", prettify(a_rf.numpy()))
print("a_f:", prettify(a_f.numpy()))
print("convolution using rfft:", prettify(cr.numpy()))
print("convolution using fft:", prettify(c.numpy()))

We also add in a skip connection, i.e., we actually compute
$$
h_i \ast x_i + B_i x_i
$$
where $B_i$ is a learned vector that scales each embedding channel of $x_i$ elementwise. This improves gradient flow.

In [None]:
class FFTConv(torch.nn.Module):
  """Causal long convolution via the FFT, plus an elementwise skip term.

  Computes (h * x)[..., :L] + B * x, where L = h.shape[-1] and * in the first
  term is convolution along the last axis.
  """

  def __init__(self):
    super().__init__()

  def forward(
      self,
      h: torch.Tensor,
      x: torch.Tensor,
      B: torch.Tensor
    ) -> torch.Tensor:
    seq_len = h.shape[-1]
    # Zero-pad to 2 * seq_len so the circular convolution computed by the FFT
    # matches the linear one on the first seq_len outputs.
    n_fft = 2 * seq_len
    # norm="forward" scales h's spectrum by 1 / n_fft. The inverse with
    # norm="forward" is unscaled, so the two cancel and the result is the
    # plain convolution.
    h_freq = torch.fft.rfft(h, n=n_fft, norm="forward")
    x_freq = torch.fft.rfft(x.to(dtype=h.dtype), n=n_fft)
    conv = torch.fft.irfft(h_freq * x_freq, n=n_fft, norm="forward")[..., :seq_len]
    # irfft already returns a real tensor; the cast returns the result in
    # h's dtype.
    return (conv + x * B).to(dtype=h.dtype)

Now we are ready to define the Hyena block which is an $N$th order Hyena Operator followed by a linear output layer. It will be a drop-in replacement for an attention block in a transformer model.

We make two changes which are different from the paper to address vanishing and exploding gradients:

1. We add three skip connections which are marked by comments below. By a skip connection we just mean that we add the input tensor to the output of an operator. See the [Resnet paper](https://arxiv.org/pdf/1512.03385.pdf) for more details.
2. We normalize $x_i$ across the embedding dimension.

Note that it is important that the normalization is taken across the embedding dimension and not across the sequence dimension. Normalizing across the sequence dimension would mix in values from future positions, so the operator would no longer be causal.

In [None]:
class HyenaBlock(torch.nn.Module):
  """Order-N Hyena operator followed by a linear output layer.

  Drop-in replacement for a self-attention block. It takes input with the
  sequence on axis 1 and d_model channels on the last axis, and returns a
  tensor of the same layout.
  """
  def __init__(self, config: HyenaConfig):
    super().__init__()
    self.proj_input = Projection(config.d_model, config.order, config.short_conv_size)
    self.proj_output = torch.nn.Linear(config.d_model, config.d_model)
    self.filter = HyenaFilter(
      config.d_model,
      config.d_filter_mlp,
      config.d_embed,
      config.order,
      config.n_filter_layers,
      config.context_length,
      config.omega,
    )
    # Applied to every gated long-convolution term in forward(); previously
    # this module was created but never used, so pdrop_hyena had no effect.
    self.dropout = torch.nn.Dropout(config.pdrop_hyena)
    self.fft_conv = FFTConv()
    # Per-order, per-channel weights for FFTConv's B * x skip term; shaped
    # (order, 1, d_model, 1) so B[i] broadcasts over batch and sequence.
    self.B = torch.nn.Parameter(torch.randn((config.order, 1, config.d_model, 1)))

  def forward(self, u: torch.Tensor) -> torch.Tensor:
    L = u.shape[1]
    
    # order + 1 channels-first streams: the gates x_1..x_N and the value v.
    *x, v = self.proj_input(u)
    v = v + u.transpose(1, 2)  # skip connection
    
    h = self.filter(L)  # one length-L filter per order and channel
    
    for i, x_i in enumerate(x):
      h_i = h[i].unsqueeze(0)
      # Normalizing across channels (dim=1) keeps the gate causal.
      gate = torch.nn.functional.normalize(x_i, dim=1)
      v = v + self.dropout(gate * self.fft_conv(h_i, v, self.B[i]))  # skip connection
    
    v = v.transpose(1, 2)
    y = v + self.proj_output(v)  # skip connection

    return y

## Part 2: Hyena Filter

It remains to define the vectors $h_i$. We will call these vectors "filters" because we are convolving the input with them.

The filters are parameters of the model, but we force them to have a special form that decays exponentially along the sequence so that the model will pay more attention to close context than far away context.
$$
h_i = \mathrm{norm}(g_i \cdot (e^{\alpha \cdot t} + b))
$$
where $\mathrm{norm}(x) = x / \|x\|_1$, with the L1 norm taken across the embedding axis at each sequence position. Normalizing along the sequence axis instead would make the filter values depend on the length of the input. The vector $t$ is fixed to be $(0, 1/(L-1), 2/(L-1), \ldots, 1)$, where $L$ is the maximum sequence length (the context length); shorter inputs use a prefix of $t$. The vectors $\alpha$ and $b$ are parameters of the model that vary along the embedding axis (they are constant along the sequence axis). The addition of the $b$ vector lets the model keep the filter from decaying to zero. The vector $g_i$ is also a parameter of the model. The dot again denotes the Hadamard product.

The paper has a more complex definition of $h_i$ that is described below in an appendix, but in our limited testing, the additional complexity did not win any performance improvement.

Putting this into code gives:

In [None]:
class Window(torch.nn.Module):
  """Exponentially decaying window applied to the Hyena filters.

  Multiplies its input by exp(alpha * t) + b. Here t runs from 0 to 1 over
  max_seq_len positions, alpha is a learned per-channel decay rate, and b is
  a learned per-channel offset (initialised to zero).

  Args:
    d_model: Number of channels.
    max_seq_len: Length of the fixed time grid t.
    fast_decay_pct: t at which the fastest channel's exp(alpha * t) reaches
      ``target`` at initialisation.
    slow_decay_pct: Same for the slowest channel; the initial alphas are
      evenly spaced between the two.
    target: Window value reached at those times.
  """

  def __init__(
      self,
      d_model: int,
      max_seq_len: int,
      fast_decay_pct: float = 0.3,
      slow_decay_pct: float = 1.5,
      target: float = 1e-2,
    ):
    super().__init__()
    self.b = torch.nn.Parameter(torch.zeros((1, d_model, 1)))
    log_target = math.log(target)
    # exp(alpha * t) == target  <=>  alpha == log(target) / t
    decay_rates = torch.linspace(
      start=log_target / slow_decay_pct,
      end=log_target / fast_decay_pct,
      steps=d_model,
    )
    self.alphas = torch.nn.Parameter(decay_rates[None, :, None])
    time_grid = torch.linspace(start=0, end=1, steps=max_seq_len)
    # A non-trainable Parameter so it moves with the module between devices.
    self.t = torch.nn.Parameter(time_grid[None, None, :], requires_grad=False)

  def forward(self, x):
    seq_len = x.shape[2]
    decay = torch.exp(self.alphas * self.t[:, :, :seq_len])
    return x * (decay + self.b)
      
class HyenaFilter(torch.nn.Module):
  """Learned long-convolution filters h_1, ..., h_N for a HyenaBlock.

  Each filter is a free, trainable vector per channel. It is multiplied by a
  decaying Window and normalised by its L1 norm across channels.

  Only d_model, N and max_seq_len are used. d_mlp, d_embed and omega are
  accepted but ignored by this simplified filter, and n_layers is only
  validated.
  """

  def __init__(
      self,
      d_model: int,
      d_mlp: int,
      d_embed: int,
      N: int,
      n_layers: int = 4,
      max_seq_len: int = 128,
      omega: int = 8,
    ):
    assert n_layers >= 2, "n_layers must be at least 2"
    super().__init__()

    self.N = N
    self.d_model = d_model

    # Trainable filter values: one length-max_seq_len vector per order and
    # channel, sliced to the actual sequence length in forward().
    self.h = torch.nn.Parameter(torch.randn((N, d_model, max_seq_len)))

    self.window = Window(d_model, max_seq_len)

  def forward(self, L: int) -> torch.Tensor:
    """Return the N filters for sequence length L, shaped (N, d_model, L)."""
    windowed = self.window(self.h[:, :, :L])
    # L1-normalise over dim -2, the channel (d_model) axis.
    return windowed / windowed.norm(p=1, dim=-2, keepdim=True)

## Part 3: A Working Model

The `HyenaBlock` defined above is a drop-in replacement for a self-attention block, so we can just use the standard GPT-2 architecture to make a full model. Recall that GPT-2 has randomly initialized token and position embeddings. Its input token embedding weights are tied to (i.e., are the same variables as) the logit output weights. We apply dropout to the embeddings and a layer normalization before the output layer.

We train using the cross-entropy loss.

In [None]:
class GPModel(lightning.LightningModule):
  """GPT-style language model whose layers are built from ``block_cls``.

  The model consists of a token embedding plus a learned position embedding,
  dropout, ``n_layers`` blocks, a final LayerNorm, and a linear head whose
  weight is tied to the token embedding. It is trained with cross-entropy on
  next-token targets.

  Args:
    config: Must provide vocab_size, d_model, context_length, pdrop_embed,
      n_layers and the optimizer fields learning_rate, betas, weight_decay.
    block_cls: Layer class, instantiated as ``block_cls(config)``; each layer
      must accept and return tensors whose last dimension is d_model.
  """
  def __init__(self, config: Config, block_cls: Type[torch.nn.Module]):
    super().__init__()
    self.config = config
    self.tok_emb = torch.nn.Embedding(config.vocab_size, config.d_model)
    self.pos_emb = torch.nn.Parameter(
      torch.randn(1, config.context_length, config.d_model)
    )
    self.drop = torch.nn.Dropout(config.pdrop_embed)
    self.layers = torch.nn.Sequential(*[
      block_cls(config) for _ in range(config.n_layers)
    ])
    self.ln = torch.nn.LayerNorm(config.d_model)
    self.head = torch.nn.Linear(
      config.d_model,
      config.vocab_size,
      bias=False
    )
    # input embedding and logit output weights are tied
    self.head.weight = self.tok_emb.weight

  def forward(
      self,
      batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Compute logits for ``batch = (input_ids, targets)``.

    Returns ``(logits, targets)``. Logits are indexed as
    [batch, position, vocab]; targets are passed through unchanged (callers
    pass None when they only need logits).
    """
    x, y = batch

    token_embeddings = self.tok_emb(x)
    # Use only as many position embeddings as there are input positions.
    position_embeddings = self.pos_emb[:, :token_embeddings.shape[1], :]

    x = self.drop(token_embeddings + position_embeddings)
    x = self.layers(x)
    logits = self.head(self.ln(x))

    return logits, y
    
  def calculate_loss(
      self,
      logits: torch.Tensor,
      targets: torch.Tensor
    ) -> torch.Tensor:
    """Mean cross-entropy between logits and next-token targets."""
    # cross_entropy expects the class (vocab) axis second.
    loss = torch.nn.functional.cross_entropy(
      logits.transpose(1, 2), targets
    )
     
    return loss

  def training_step(
      self,
      batch: Tuple[torch.Tensor, torch.Tensor],
      batch_idx: int
    ) -> torch.Tensor:
    """Compute and log the per-step training loss."""
    # Uncomment (and indent the lines below) to debug NaN/inf gradients:
    #with torch.autograd.detect_anomaly():
    logits, targets = self.forward(batch)
    loss = self.calculate_loss(logits, targets)
    self.log(
      "train_loss", loss, prog_bar=True, on_step=True, on_epoch=False
    )
    return loss

  def validation_step(
      self,
      batch: Tuple[torch.Tensor, torch.Tensor],
      batch_idx: int
    ) -> torch.Tensor:
    """Compute the validation loss; it is logged as an epoch average."""
    logits, targets = self.forward(batch)
    loss = self.calculate_loss(logits, targets)
    self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
    return loss
    
  def configure_optimizers(self):
    """AdamW over all parameters (weight decay applied uniformly to all)."""
    return torch.optim.AdamW(
      self.parameters(),
      lr=self.config.learning_rate,
      betas=self.config.betas,
      weight_decay=self.config.weight_decay,
    )

### Dataset

Now we just need a dataset. We use Andrej Karpathy's tiny Shakespeare dataset, which consists of text from Shakespeare's plays. The code below downloads it from Hugging Face. It is a single very long row of data.

Our tokens will be the individual characters of the text.

In [None]:
# Download the training split; the whole corpus is a single row of text.
ds = datasets.load_dataset("tiny_shakespeare", split="train")
# Split the text into user-perceived characters (regex `\X`, extended
# grapheme clusters); these are the model's tokens.
ds = ds.map(
  lambda x: {"char": regex.findall(r"\X", x["text"])},
  remove_columns=["text"]
)

vocabulary = sorted(set(ds[0]["char"]))  # Entire dataset is a single row
print(vocabulary)

# Token -> integer id (its index in the sorted vocabulary).
tok2id = {ch: i for i, ch in enumerate(vocabulary)}
print(tok2id)

We tokenize the dataset and split it into chunks of `CONTEXT_LENGTH` tokens (set below). Then we split the chunks into training and validation sets. Finally, we save them to disk so that next time we will not need to download and process the data again.

Crucially we need to choose a context length that will be the maximum number of tokens our models can process.

In [None]:
# Maximum number of tokens the models process at once; also the length of
# every training and validation example.
CONTEXT_LENGTH = 128

With that decided, we can process the dataset:

In [None]:
# Number of rows to use for the validation set
val_size = 3200
  
def process_dataset(
    ds: datasets.Dataset
  ) -> Tuple[datasets.Dataset, datasets.Dataset]:
  """Tokenize, chunk, shuffle and split the single-row character dataset.

  Relies on the module-level ``tok2id``, ``CONTEXT_LENGTH`` and ``val_size``.

  Args:
    ds: Dataset with a ``char`` column holding the list of character tokens.

  Returns:
    ``(train_ds, val_ds)``. Every row has ``curr_id`` and ``next_id`` lists
    of CONTEXT_LENGTH token ids, where ``next_id`` is ``curr_id`` shifted
    left by one token. After shuffling (seed 0), the first ``val_size`` rows
    form the validation set and the rest the training set.
  """

  def chunk(lst: List[Any], n: int) -> List[List[Any]]:
    """Break `lst` into length n chunks, dropping a final partial chunk"""
    # The stop value len(lst) - n + 1 still admits a last chunk starting at
    # len(lst) - n, i.e. one that is exactly n long.
    return [
      lst[i:i + n]
      for i in range(0, len(lst) - n + 1, n)
    ]

  def create_batches(x: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    # Mapped with batch_size=1, so x[...][0] is the single row's full token
    # list; it is expanded into many CONTEXT_LENGTH-long rows.
    return {
      "curr_id": chunk(x["curr_id"][0], CONTEXT_LENGTH),
      "next_id": chunk(x["next_id"][0], CONTEXT_LENGTH),
    }

  def tokenize(x: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
    return {
      "id": [tok2id[ch] for ch in x["char"]]
    }

  ds = ds.map(tokenize, remove_columns=["char"])
  # Inputs and next-token targets: targets are the inputs shifted by one.
  ds = ds.map(
    lambda x: {"curr_id": x["id"][:-1], "next_id": x["id"][1:]},
    remove_columns=["id"]
  )
  ds = ds.map(create_batches, batched=True, batch_size=1)
  ds = ds.shuffle(seed=0)

  val_ds = ds.select(range(0, val_size))
  train_ds = ds.select(range(val_size, len(ds)))

  return train_ds, val_ds

train_ds, val_ds = process_dataset(ds)
# Cache the processed splits (keyed by context length) and the vocabulary so
# later sessions can reload them without reprocessing.
train_ds.save_to_disk(f"data/train-{CONTEXT_LENGTH}")
val_ds.save_to_disk(f"data/val-{CONTEXT_LENGTH}")

with open("data/vocabulary.json", "w") as f:
  json.dump(vocabulary, f)

If you have previously downloaded the data and just want to load it from disk, then use this code:

In [None]:
# Reload the cached splits and vocabulary written by the processing cell.
train_ds = datasets.Dataset.load_from_disk(f"data/train-{CONTEXT_LENGTH}")
val_ds = datasets.Dataset.load_from_disk(f"data/val-{CONTEXT_LENGTH}")

with open("data/vocabulary.json", "r") as f:
  vocabulary = json.load(f)
  print(vocabulary)

# Rebuild the token -> id mapping from the saved vocabulary order.
tok2id = {ch: i for i, ch in enumerate(vocabulary)}
print(tok2id)

### Training

The training loop can be defined using PyTorch Lightning. If you plan to experiment with the architecture, the Weights and Biases tool might be useful. It lets you view charts of the gradients, to diagnose problems like vanishing gradients, and charts of the loss, to compare different runs. Note that the `train` function below always logs to Weights and Biases. It therefore requires a free account and the `wandb` package (uncomment its install line in the Preliminaries).

If you do use Weights and Biases, the first step is to log in (assuming you have installed the `wandb` package above).

In [None]:
# Authenticate with Weights & Biases (may prompt for an API key).
wandb.login()

In any case, PyTorch Lightning makes the training loop easy to write:

In [None]:
import lightning.pytorch.loggers

def collate(
    data: List[Dict[str, List[int]]]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Stack dataset rows into ``(curr_ids, next_ids)`` LongTensors."""
    def stack(key: str) -> torch.Tensor:
        return torch.LongTensor([row[key] for row in data])

    return stack("curr_id"), stack("next_id")


def train(model: lightning.LightningModule, config: Config) -> lightning.LightningModule:
    """Fit ``model`` on the module-level ``train_ds``/``val_ds`` with Lightning.

    Hyperparameters, gradients/parameters and losses are logged to Weights &
    Biases (project "hyena-gpt-shakespeare"). Gradients are clipped at 0.2.

    Returns:
        The trained model (the same object that was passed in).
    """
    wandb_logger = lightning.pytorch.loggers.WandbLogger(
      project="hyena-gpt-shakespeare"
    )
    wandb_logger.experiment.config.update(dataclasses.asdict(config))
    wandb_logger.watch(model, log="all", log_freq=1)
  
    trainer = lightning.Trainer(
        accelerator=config.device_type,
        precision=config.precision,
        max_epochs=config.epochs,
        gradient_clip_val=0.2,
        logger=wandb_logger,
    )

    # Pinned host memory only speeds up host-to-GPU copies.
    pin_memory = config.device_type in ("gpu", "cuda")

    train_dl = torch.utils.data.DataLoader(
        train_ds,
        collate_fn=collate,
        shuffle=True,
        pin_memory=pin_memory,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
    )

    val_dl = torch.utils.data.DataLoader(
        val_ds,
        collate_fn=collate,
        shuffle=False,
        pin_memory=pin_memory,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
    )

    try:
        trainer.fit(model, train_dl, val_dl)
    finally:
        # Close the W&B run even if training fails or is interrupted.
        wandb.finish()

    return model

All that's left to do is set the hyperparameters and train. If you are using a cpu, then change the device type to "cpu" and the precision to "32". You might also want to decrease the size of the model to speed up training and decrease the number of epochs. If you are using a gpu, it may not support bf16 precision, in which case you can change it to "16" or "32".

In [None]:
# Hyperparameters for the Hyena model. d_embed, d_filter_mlp and omega are
# passed to HyenaFilter but not used by the simplified filter, and
# n_filter_layers is only checked to be >= 2.
hyena_config = HyenaConfig(
  d_model=386,  # NOTE(review): the attention baseline uses 384; confirm 386 is intended
  n_layers=6,
  vocab_size=len(vocabulary),
  d_embed=33,
  d_filter_mlp=64,
  n_filter_layers=4,
  context_length=CONTEXT_LENGTH,
  short_conv_size=3,
  order=2,
  pdrop_hyena=0.0,
  pdrop_embed=0.2,
  omega=12,
  epochs=40,
  learning_rate=6e-4,
  betas=(0.9, 0.98),
  weight_decay=0.4,
  device_type="gpu",  # cpu, gpu
  precision="bf16",  # 32, 16, 16-mixed, bf16
  batch_size=64,
  num_workers=4,
)

# Build the GPT-style model from HyenaBlock layers and train it.
hyena_model = GPModel(hyena_config, HyenaBlock)
hyena_model = train(hyena_model, hyena_config)

### Evaluation

Once the model is trained, we should check that it can predict the next token reasonably well and that its generated text bears some resemblance to its training data. If there is a bug that, for example, violates causality, the loss could still decrease while the model does poorly at prediction or generation.

In [None]:
# Mean probability assigned to the correct next token over the first 256
# validation positions (higher is better; uniform guessing would give
# 1 / len(vocabulary)).
print(
  "hyena average probability of next token:", 
  average_probability(hyena_model, val_ds, tokens=256)
)

In [None]:
# Continue a prompt by 100 tokens, sampling each one from the model's 2 most
# likely candidates.
print(
  generate(
    hyena_model,
    hyena_config.context_length,
    "Wherefore art thou ",
    method="topk",
    max_tokens=100,
    k=2
  )
)

## Appendix: Comparison with Attention

We can define a traditional causal self-attention layer as follows (Cf. [The Annotated Transformer](https://nlp.seas.harvard.edu/annotated-transformer/)).

In [None]:
class GELU(torch.nn.Module):
    """Module wrapper around ``torch.nn.functional.gelu`` with default settings."""
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        activation = torch.nn.functional.gelu
        return activation(input)

class CausalSelfAttention(torch.nn.Module):
  """Multi-head self-attention where position i may only attend to positions <= i.

  Args:
    d_embed: embedding size of inputs and outputs; must be divisible by ``n_head``.
    n_head: number of attention heads.
    pdrop_attn: dropout probability passed to ``torch.nn.MultiheadAttention``.
  """

  def __init__(self, d_embed: int, n_head: int, pdrop_attn: float):
    super().__init__()
    assert d_embed % n_head == 0
    self.d_embed = d_embed
    self.n_head = n_head

    # Lazily (re)built causal mask. It is a plain attribute rather than a buffer, so
    # Module.to() does not move it; forward() therefore also checks its device.
    self.mask = torch.zeros((1, 1), dtype=torch.bool)
    self.attn = torch.nn.MultiheadAttention(
      embed_dim=d_embed, num_heads=n_head, dropout=pdrop_attn, batch_first=True
    )

  def forward(self, x: torch.Tensor, padding: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Apply causal self-attention.

    Args:
      x: input laid out batch-first, (batch, seq_len, d_embed).
      padding: optional mask forwarded to MultiheadAttention as ``key_padding_mask``.

    Returns:
      The attention output (same shape as ``x``); attention weights are not computed.
    """
    seq_len = x.shape[1]
    # Rebuild when the length changes OR the cached mask lives on another device.
    # Without the device check, seq_len == 1 on a GPU would reuse the initial CPU
    # mask, and moving the model between devices would leave a stale-device mask.
    if self.mask.shape[0] != seq_len or self.mask.device != x.device:
      # Boolean attn_mask: True = may NOT attend. Forbid keys strictly in the future.
      self.mask = torch.triu(
        torch.ones((seq_len, seq_len), dtype=torch.bool, device=x.device), diagonal=1
      )
    return self.attn(
      x, x, x, key_padding_mask=padding, need_weights=False, attn_mask=self.mask
    )[0]

class SelfAttentionBlock(torch.nn.Module):
  """Pre-LayerNorm transformer block.

  Causal self-attention followed by a 4x-wide GELU MLP, each applied to a
  layer-normalized input and added back through a residual connection.
  """

  def __init__(self, config: AttentionConfig):
    super().__init__()
    width = config.d_embed
    hidden = 4 * width
    # Submodules are created in the same order as before so parameter
    # initialization and parameters() ordering are unchanged.
    self.ln1 = torch.nn.LayerNorm(width)
    self.ln3 = torch.nn.LayerNorm(width)
    self.self_attn = CausalSelfAttention(width, config.n_head, config.pdrop_attn)
    self.mlp = torch.nn.Sequential(
      torch.nn.Linear(width, hidden),
      GELU(),
      torch.nn.Linear(hidden, width),
      torch.nn.Dropout(config.pdrop_attn),
    )

  def forward(self, x: torch.Tensor, padding: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    Normalize before attention and before the MLP (pre-LN), following
    https://arxiv.org/pdf/2002.04745.pdf
    """
    # `padding` masks positions of x when used as attention keys.
    attended = self.self_attn(self.ln1(x), padding)
    residual = x + attended
    return residual + self.mlp(self.ln3(residual))

Then we can simply pass the layer to the model constructor which will use it to build the model.

In [None]:
attn_config = AttentionConfig(
  # Architecture
  d_model=384,
  d_embed=384,
  n_head=6,
  n_layers=6,
  vocab_size=len(vocabulary),
  context_length=CONTEXT_LENGTH,
  # Regularization
  pdrop_attn=0.2,
  pdrop_embed=0.2,
  weight_decay=0.1,
  # Optimization
  epochs=40,
  learning_rate=3e-4,
  betas=(0.9, 0.98),
  # Hardware / data loading
  device_type="gpu",
  precision="bf16",
  batch_size=64,
  num_workers=4,
)

# The block class is handed to GPModel, which uses it to build the model;
# train() fits the model and returns it.
attn_model = train(GPModel(attn_config, SelfAttentionBlock), attn_config)

In our experiments, we found that the attention-based model attains a loss of about 1.6, compared to less than 1.5 for the hyena-based model, despite having more parameters (10M vs 8M). However, in this short-context (128-token) regime, it executes more quickly than the hyena-based model.

Finally we evaluate the prediction and generation of the newly trained model.

In [None]:
attn_next_token_prob = average_probability(attn_model, val_ds, tokens=256)
print("attention average probability of next token:", attn_next_token_prob)

In [None]:
# Sample a continuation of the prompt with top-k (k=2) decoding.
attn_sample = generate(
  attn_model,
  attn_config.context_length,
  "Wherefore art thou ",
  method="topk",
  k=2,
  max_tokens=80,
)
print(attn_sample)

## Appendix: Hyena as described in the paper

__The below modules are labeled as "authentic" because they correspond to the description in the paper. The authors may have omitted details that greatly improve the performance, so don't take the below performance as an indicator of the authors' work.__

The paper's Hyena filter is more complex than the simplified version presented above. It works in three steps:

1. There is a trainable positional embedding initialized with `sin` and `cos` values
2. The positional embedding is passed through a few linear layers with sinusoidal activation functions
3. The output is multiplied elementwise by trainable exponentially decaying vectors ("windows")

The positional embedding is initialized to a matrix whose $t$th row is
$$
[t / L, \cos(2\pi \cdot 0 \cdot t/L), \ldots, \cos(2\pi (K-1) t/L), \sin(2\pi \cdot 0 \cdot t/L), \ldots, \sin(2\pi (K-1) t/L)]
$$
for a hyperparameter $K$ that determines the size of the embedding, which has $2K + 1$ dimensions.

The frequency of the `sin` activation function of the linear layers is another hyperparameter of the model.

The window function is the same as described above except that, at least in the [reference implementation](https://github.com/HazyResearch/safari/blob/main/src/models/sequence/hyena.py), the shift $b$ is a fixed (non-trainable) scalar:
$$
h = h \cdot (e^{\alpha \cdot t} + b)
$$

In [None]:
class PositionalEmbedding(torch.nn.Module):
  """Positional features for the implicit Hyena filter.

  Holds two tensors covering ``max_seq_len`` positions:
    * ``time_emb`` (frozen, shape (1, 1, max_seq_len)): positions rescaled to [0, 1].
    * ``pos_emb`` (trainable, shape (max_seq_len, d_embed)): the rescaled position,
      then K = (d_embed - 1) // 2 cosine columns, then K sine columns; column k
      oscillates at k / max_seq_len cycles per position.
  """

  def __init__(self, d_embed: int, max_seq_len: int):
    assert d_embed % 2 == 1, "only odd dimensional positional embeddings are supported"
    assert d_embed > 1, "positional embedding must be at least 3"
    super().__init__()

    n_freqs = (d_embed - 1) // 2
    # Column vector of rescaled time in [0, 1].
    unit_time = torch.linspace(start=0, end=1, steps=max_seq_len)[:, None]
    # Integer positions as a column, frequency indices as a row.
    positions = torch.linspace(start=0, end=max_seq_len - 1, steps=max_seq_len)[:, None]
    freqs = torch.linspace(start=0, end=n_freqs - 1, steps=n_freqs)[None, :]
    # exp(i * theta): real part gives the cosines, imaginary part the sines.
    phasors = torch.exp(2 * math.pi * 1j * freqs * positions / max_seq_len)

    frozen_time = unit_time.transpose(0, 1).unsqueeze(0)
    self.time_emb = torch.nn.Parameter(frozen_time, requires_grad=False)
    features = torch.cat([unit_time, phasors.real, phasors.imag], dim=-1)
    self.pos_emb = torch.nn.Parameter(features, requires_grad=True)

  def forward(self, L: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return ``(time_emb[:, :, :L], pos_emb[:L])`` for the first ``L`` positions."""
    time_slice = self.time_emb[:, :, :L]
    pos_slice = self.pos_emb[:L]
    return time_slice, pos_slice


class Sin(torch.nn.Module):
  """Elementwise sinusoidal activation ``sin(freq * x)`` with one frequency per feature.

  Args:
    d_model: number of features; the frequencies form a (1, d_model) row that
      broadcasts against the input.
    omega: initial value of every frequency.
    trainable: whether the frequencies receive gradients.
  """

  def __init__(self, d_model: int, omega: int = 8, trainable: bool = False):
    super().__init__()
    initial = omega * torch.ones(1, d_model)
    self.freq = torch.nn.Parameter(initial, requires_grad=trainable)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    scaled = self.freq * x
    return torch.sin(scaled)


class AuthenticWindow(torch.nn.Module):
  """Per-channel exponentially decaying window applied along a filter's last axis.

  Channel c is scaled by ``exp(-|alpha_c| * t) + shift``. The rates are initialised
  evenly from ``log(target) / slow_decay_pct`` (slowest) to
  ``log(target) / fast_decay_pct`` (fastest), so the slowest/fastest channel has
  decayed to ``target`` at t = slow_decay_pct / fast_decay_pct respectively.

  Args:
    d_model: number of channels (one decay rate each).
    fast_decay_pct, slow_decay_pct, target: set the initial range of decay rates.
    shift: constant, non-trainable offset added to the window.
  """

  def __init__(
      self,
      d_model: int,
      fast_decay_pct: float = 0.3,  # Defaults from the official implementation
      slow_decay_pct: float = 1.5,
      target: float = 1e-2,
      shift: float = 0.0,
    ):
    super().__init__()
    self.shift = shift
    min_decay = math.log(target) / slow_decay_pct
    max_decay = math.log(target) / fast_decay_pct
    # Shape (1, d_model, 1): one rate per channel, broadcast over filters and time.
    self.alphas = torch.nn.Parameter(
      torch.linspace(
        start=min_decay,
        end=max_decay,
        steps=d_model
      )[None, :, None], requires_grad=True)

  def forward(self, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Window ``x`` (length on axis 2) using the first ``x.shape[2]`` entries of ``t``."""
    L = x.shape[2]
    # Slice time before exponentiating so no work is spent on positions beyond L.
    # Using -|alpha| keeps the window decaying even if training flips the sign of a
    # rate (the reference implementation uses exp(-t * deltas.abs())). At
    # initialisation the alphas are negative for target < 1, so this equals
    # exp(alpha * t).
    decay = torch.exp(-self.alphas.abs() * t[..., :L])
    return x * (decay + self.shift)
      
class AuthenticHyenaFilter(torch.nn.Module):
  """Implicit Hyena filter following the paper's three-step construction.

  1. ``PositionalEmbedding`` supplies trainable (L, d_embed) positional features.
  2. An MLP of ``n_layers`` Linear layers (Sin activations between them) maps each
     position to ``N * d_model`` filter taps.
  3. ``AuthenticWindow`` multiplies the taps by per-channel exponential windows.

  Args:
    d_model: number of channels per filter.
    d_mlp: hidden width of the filter MLP.
    d_embed: positional-embedding size (odd and >= 3, checked by PositionalEmbedding).
    N: number of filters produced (one per Hyena order).
    n_layers: total number of Linear layers in the MLP (at least 2).
    max_seq_len: longest length the filter can be evaluated at.
    omega: initial frequency of the Sin activations.
  """

  def __init__(
      self,
      d_model: int,
      d_mlp: int,
      d_embed: int,
      N: int,
      n_layers: int = 4,
      max_seq_len: int = 128,
      omega: int = 8,
    ):
    assert n_layers >= 2, "n_layers must be at least 2"
    super().__init__()

    self.N = N
    self.d_model = d_model

    self.pos_emb = PositionalEmbedding(d_embed, max_seq_len)

    self.mlp = torch.nn.Sequential(
      torch.nn.Linear(d_embed, d_mlp),
      Sin(d_mlp, omega),
    )
    for _ in range(n_layers - 2):
      self.mlp.append(torch.nn.Linear(d_mlp, d_mlp))
      self.mlp.append(Sin(d_mlp, omega))
    self.mlp.append(torch.nn.Linear(d_mlp, N * d_model, bias=False))

    # Frozen time grid in [0, 1] used by the window, shape (1, 1, max_seq_len).
    self.t = torch.nn.Parameter(
      torch.linspace(
        start=0,
        end=1,
        steps=max_seq_len
      )[None, None, :], requires_grad=False)

    self.window = AuthenticWindow(d_model)

  def forward(self, L: int) -> torch.Tensor:
    """Return the N filters for a length-``L`` sequence, shape (N, d_model, L).

    Raises:
      ValueError: if ``L`` exceeds the ``max_seq_len`` given at construction.
    """
    max_len = self.t.shape[-1]
    if L > max_len:
      raise ValueError(f"sequence length {L} exceeds max_seq_len={max_len}")

    _, z = self.pos_emb(L)          # (L, d_embed)
    h = self.mlp(z)                 # (L, N * d_model)

    # Split the MLP output into N filters of d_model channels, time last.
    h = h.transpose(0, 1).reshape(self.N, self.d_model, L)

    # The MLP output is what gets windowed; it must not be replaced by a free
    # parameter, or the positional embedding and MLP would be dead code.
    return self.window(self.t[:, :, :L], h)

The operator definition from the paper does not have the three skip connections present in the version above. Otherwise it is the same as described above.

In [None]:
class AuthenticHyenaBlock(torch.nn.Module):
  """Order-N Hyena operator as defined in the paper (no skip connections).

  ``proj_input`` yields gating signals ``x_1..x_N`` plus a value ``v``
  (presumably order + 1 tensors). Each step long-convolves ``v`` with the i-th
  implicit filter via ``FFTConv`` (bias ``B[i]``) and gates the result
  elementwise with ``x_i``.
  """

  def __init__(self, config: HyenaConfig):
    super().__init__()
    self.proj_input = Projection(config.d_model, config.order, config.short_conv_size)
    self.proj_output = torch.nn.Linear(config.d_model, config.d_model)
    self.filter = AuthenticHyenaFilter(
      config.d_model,
      config.d_filter_mlp,
      config.d_embed,
      config.order,
      config.n_filter_layers,
      config.context_length,
      config.omega,
    )
    self.dropout = torch.nn.Dropout(config.pdrop_hyena)
    self.fft_conv = FFTConv()
    # One (1, d_model, 1) bias per order; presumably broadcast over batch and
    # length inside FFTConv.
    self.B = torch.nn.Parameter(torch.randn((config.order, 1, config.d_model, 1)))

  def forward(self, u: torch.Tensor) -> torch.Tensor:
    """Apply the operator; axis 1 of ``u`` is the sequence axis.

    Assumes ``u`` is (batch, L, d_model) (TODO confirm against GPModel). Returns the
    output of ``proj_output``, whose last axis has size d_model.
    """
    L = u.shape[1]

    *x, v = self.proj_input(u)

    h = self.filter(L)

    # The reference code for the paper does the product with x_i first
    # but we follow the paper eq (6) here in putting it after the convolution
    for i, x_i in enumerate(x):
      h_i = h[i].unsqueeze(0)
      # Apply the configured dropout (config.pdrop_hyena) to each gated output.
      v = self.dropout(x_i * self.fft_conv(h_i, v, self.B[i]))

    # v is presumably (batch, d_model, L); move channels last for the Linear.
    v = v.transpose(1, 2)
    return self.proj_output(v)

Now we can train the hyena-based model. We reduce the number of layers due to vanishing/exploding gradients and the epochs due to overfitting.

In [None]:
# Fewer layers (vanishing/exploding gradients) and fewer epochs (overfitting)
# than the simplified Hyena model trained earlier.
hyena_config = HyenaConfig(
  # Architecture
  d_model=386,
  n_layers=2,
  vocab_size=len(vocabulary),
  context_length=CONTEXT_LENGTH,
  order=2,
  short_conv_size=3,
  # Implicit filter network
  d_embed=33,
  d_filter_mlp=64,
  n_filter_layers=4,
  omega=12,
  # Regularization
  pdrop_hyena=0.0,
  pdrop_embed=0.2,
  weight_decay=1,
  # Optimization
  epochs=10,
  learning_rate=6e-4,
  betas=(0.9, 0.98),
  # Hardware / data loading
  device_type="gpu",  # cpu, gpu
  precision="bf16",  # 32, 16, 16-mixed, bf16
  batch_size=64,
  num_workers=4,
)

authentic_hyena_model = train(GPModel(hyena_config, AuthenticHyenaBlock), hyena_config)

In [None]:
authentic_next_token_prob = average_probability(authentic_hyena_model, val_ds, tokens=256)
print("hyena average probability of next token:", authentic_next_token_prob)

In [None]:
# Sample a continuation of the prompt with top-k (k=2) decoding.
authentic_sample = generate(
  authentic_hyena_model,
  hyena_config.context_length,
  "Wherefore art thou ",
  method="topk",
  k=2,
  max_tokens=100,
)
print(authentic_sample)