In [41]:
import torch
import torch.nn.functional as F
import random
import numpy as np
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset, Dataset

We provide a 2-layer transformer model for the induction-head task. Review the data-generation mechanism for the induction-head dataset below and adjust it to your needs. For associative recall (AR), you need to write the data generator from scratch.

In [64]:
class InductionAR(Dataset):
    """Synthetic induction-head / naive associative-recall dataset.

    Every example is a token sequence of length ``n_ctx`` in which a fixed
    "trigger" n-gram, ``[0, 1, ..., n_gram - 1]``, occurs exactly twice: once
    somewhere inside the sequence and once at the very end.  The label is the
    token that directly follows the first occurrence.

    For the induction-head task use ``n_gram=1``; the generator itself is
    written for general ``n_gram``.

    Args:
        num_examples: number of sequences to generate.
        tokenizer: object whose ``len()`` is the vocabulary size (token ids are
            ``0 .. len(tokenizer) - 1``); its ``decode`` method is only needed
            by ``get_str_dataset``.
        n_gram: length of the trigger n-gram.
        n_ctx: length of every generated sequence.
        seed: seed used for data generation; equal seeds give equal datasets.
        train_split: fraction of examples put into the training split.  A
            falsy value keeps everything in ``train`` and sets ``test`` to None.

    Attributes:
        train_x, train_y, test_x, test_y: numpy arrays of sequences / labels.
        train, test: the same data wrapped as ``TensorDataset`` (``test`` may
            be None).
    """
    def __init__(self, num_examples, tokenizer, n_gram=1, n_ctx = 1024, seed = 0, train_split=0.8):
        self.num_examples = num_examples
        self.tokenizer = tokenizer
        self.n_ctx = n_ctx
        self.seed = seed
        self.n_gram = n_gram
        x, y = self.data_gen()
        if train_split:
            self.train_x, self.train_y, self.test_x, self.test_y = self.split(x, y, train_split)
            self.train = self.numpy_to_tensor_dataset(self.train_x, self.train_y)
            self.test = self.numpy_to_tensor_dataset(self.test_x, self.test_y)
        else:
            self.train_x, self.train_y, self.test_x, self.test_y = x, y, None, None
            self.train = self.numpy_to_tensor_dataset(self.train_x, self.train_y)
            self.test = None

    def get_str_dataset(self, split="train"):
        """Return ``(x_str, y_str)``: the chosen split decoded with the tokenizer.

        Raises:
            ValueError: if ``split`` is not 'train' or 'test', or if the test
                split is requested but the dataset was built without one.
        """
        if split == "train":
            xs, ys = self.train_x, self.train_y
        elif split == "test":
            xs, ys = self.test_x, self.test_y
            if xs is None:
                raise ValueError("no test split available (dataset was built without train_split)")
        else:
            raise ValueError("split should be either 'train' or 'test'")
        x_str = [self.tokenizer.decode(xi) for xi in xs]
        y_str = [self.tokenizer.decode([yi]) for yi in ys]
        return x_str, y_str

    def numpy_to_tensor_dataset(self, x, y):
        """Wrap array-likes ``x`` and ``y`` as a ``TensorDataset`` of int64 tensors."""
        x = torch.tensor(x, dtype=torch.long)
        y = torch.tensor(y, dtype=torch.long)
        return TensorDataset(x, y)

    def gen_single_example(self):
        """Generate one ``(sequence, label)`` pair.

        A random body of ``n_ctx - 2 * n_gram`` tokens that does not contain
        the trigger is drawn first; the trigger is then inserted at a random
        position and appended at the end.  Returns the resulting list of
        ``n_ctx`` token ids and the token that follows the first trigger.
        """
        assert self.n_gram*2 < self.n_ctx, "n_gram*2 should be less than n_ctx"
        vocab_size = len(self.tokenizer)
        vocab = list(range(vocab_size))
        # set a deterministic n_gram_head
        n_gram_head = list(range(self.n_gram))
        head_len = len(n_gram_head)
        body_len = self.n_ctx - self.n_gram * 2

        def padded(tokens):
            # Space-pad both ends so that whole tokens are matched everywhere,
            # including at the first and last position (" 0 " never matches "10").
            return " " + " ".join([str(t) for t in tokens]) + " "

        str_n_gram_head = padded(n_gram_head)

        def count(str_x, str_n_gram_head):
            # number of (possibly overlapping) occurrences of the trigger string
            return sum(str_x.startswith(str_n_gram_head, i) for i in range(len(str_x)))

        def gen_x():
            # Draw bodies until one is free of the trigger.
            while True:
                x = np.random.choice(vocab, body_len, replace=True).tolist()
                # Try a few times to cut out accidental triggers and refill.
                for _ in range(10):
                    pos = [i for i in range(len(x) - head_len + 1) if x[i:i + head_len] == n_gram_head]
                    if not pos:
                        break
                    # delete back-to-front so earlier positions stay valid
                    for p in reversed(pos):
                        x = x[:p] + x[p + head_len:]
                    # fill the rest of the sequence with random elements
                    x.extend(np.random.choice(vocab, body_len - len(x), replace=True).tolist())
                if count(padded(x), str_n_gram_head) == 0:
                    break

            # Appending the trigger must yield exactly one occurrence; this can
            # fail when the trigger overlaps itself, e.g. x = [1, 2, 3, 1] with
            # n_gram_head = [1, 1].
            str_x_test = padded(x + n_gram_head)
            n_heads = count(str_x_test, str_n_gram_head)
            if n_heads > 1:
                print("Error in gen_x")
                print(f"str_x_test: {str_x_test}", f"str_n_gram_head: {str_n_gram_head}",
                      "count: ", n_heads)
            return x if n_heads == 1 else None

        def insert_n_gram_head(x):
            # The token at the insertion point ends up right after the inserted
            # trigger, so it is the label.
            pos = random.randint(0, len(x)-1)
            y = x[pos]
            x_new = x[:pos] + n_gram_head + x[pos:] + n_gram_head
            if count(padded(x_new), str_n_gram_head) == 2:
                return x_new, y
            return None, None

        while True:
            x = gen_x()
            if x is None:
                continue
            for _ in range(10):
                x_new, y = insert_n_gram_head(x)
                if x_new is not None:
                    return x_new, y

    def data_gen(self):
        """Generate ``num_examples`` examples reproducibly from ``self.seed``.

        Both Python's ``random`` module (insertion positions) and NumPy's
        global generator (token sampling) are seeded, and the state of both is
        restored afterwards so that surrounding code is unaffected.

        Returns:
            ``(x, y)`` as numpy arrays of sequences and labels.
        """
        x = []
        y = []
        # get previous random status and recover after generating the dataset
        random_status = random.getstate()
        numpy_status = np.random.get_state()
        random.seed(self.seed)
        # Derive NumPy's seed from the freshly seeded Python RNG: this is
        # deterministic for any seed value that random.seed accepts.
        np.random.seed(random.getrandbits(32))
        try:
            for i in range(self.num_examples):
                if i % 5000 == 0:
                    print(f"Generating example {i}")
                xi, yi = self.gen_single_example()
                x.append(xi)
                y.append(yi)
        finally:
            random.setstate(random_status)
            np.random.set_state(numpy_status)
        return np.array(x), np.array(y)

    def split(self, x, y, train_ratio = 0.8):
        """Split ``x`` / ``y`` in order: the first ``int(len(x) * train_ratio)``
        examples form the training part, the remainder the test part."""
        num_train = int(len(x)*train_ratio)
        train_x = x[:num_train]
        train_y = y[:num_train]
        test_x = x[num_train:]
        test_y = y[num_train:]
        return train_x, train_y, test_x, test_y


In [43]:
class Random_tokenizer:
    """Lookup-table tokenizer: maps vocabulary tokens to integer ids and back."""

    def __init__(self, vocab=None, vocab_size = None) -> None:
        """Build the token <-> id tables.

        Exactly one source of the vocabulary is used: ``vocab`` if it is given
        (``vocab_size`` is then ignored), otherwise the decimal strings
        ``"0" .. str(vocab_size - 1)``.

        Raises:
            ValueError: if neither ``vocab`` nor ``vocab_size`` is provided.
        """
        if vocab is None and vocab_size is None:
            raise ValueError("one of vocab or vocab_size should be provided.")
        if vocab is None:
            tokens = [str(token_id) for token_id in range(vocab_size)]
            size = vocab_size
        else:
            tokens = vocab
            size = len(vocab)
        self.vocab = tokens
        self.vocab_size = size
        # id -> token first, then invert it for token -> id
        self.vocab_dict_inv = dict(enumerate(self.vocab))
        self.vocab_dict = {token: token_id for token_id, token in self.vocab_dict_inv.items()}

    def encode(self, x):
        """Map every element of the iterable ``x`` to its id (KeyError if unknown)."""
        lookup = self.vocab_dict
        return [lookup[element] for element in x]

    def decode(self, x):
        """Map every id in ``x`` to its token and join the tokens with single spaces."""
        return ' '.join(self.vocab_dict_inv[token_id] for token_id in x)

    def __call__(self, x):
        # calling the tokenizer is shorthand for encode()
        return self.encode(x)

    def __len__(self):
        return self.vocab_size

    def __getitem__(self, i):
        return self.vocab[i]

    def __iter__(self):
        return iter(self.vocab)

    def __contains__(self, x):
        return x in self.vocab

    def __repr__(self):
        return "Random_tokenizer(vocab_size={})".format(self.vocab_size)

    # str() shows the same text as repr()
    __str__ = __repr__
        

In [44]:
# self attention block
class Block(nn.Module):
    """Single-head causal self-attention layer.

    One linear layer produces queries, keys and values; the output is the
    causally masked scaled-dot-product attention result, with no output
    projection and no residual connection.

    Args:
        embed_dim: size of the last dimension of the input and the output.
        max_len: side length of the registered lower-triangular ``mask`` buffer.
    """
    def __init__(self, embed_dim, max_len=11):
        super().__init__()
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.c_attn = nn.Linear(embed_dim, embed_dim*3)
        # forward() never reads this buffer (causality comes from is_causal=True);
        # it stays registered so state_dicts containing "mask" keep loading.
        self.register_buffer('mask', torch.tril(torch.ones(max_len, max_len)))

    def forward(self, x):
        """Apply causal attention over the second-to-last dimension of ``x``
        and return a tensor of the same shape as ``x``."""
        q, k, v = self.c_attn(x).chunk(3, dim=-1)
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

### Mamba Block
Below, we copy the implementation of the Mamba block with parallel scan from this repository: https://github.com/alxndrTL/mamba.py/

We create a wrapper for this block as MambaBlockWrapper, holding all the other necessary configurations regarding the internals of the Mamba block. The other configurations (conv dimension, hidden state dimension, etc.) are taken directly from the linked repo.

In [45]:
import math
from dataclasses import dataclass
from typing import Union

"""
Taken from https://github.com/alxndrTL/mamba.py/blob/main/pscan.py

An implementation of the parallel scan operation in PyTorch (Blelloch version).
Please see docs/pscan.ipynb for a detailed explanation of what happens here.

"""

def npo2(len):
    """
    Returns the smallest power of 2 that is greater than or equal to len
    (len itself when it already is a power of 2).

    The result is computed with exact integer arithmetic, so it stays correct
    for lengths too large to be represented exactly as a float.

    Args:
        len : positive length (an int; non-integral values are rounded up first)

    Raises:
        ValueError : if len is smaller than 1
    """

    n = math.ceil(len)
    if n < 1:
        raise ValueError("npo2 expects a positive length, got {!r}".format(len))
    # (n - 1).bit_length() is the exponent of the first power of two >= n
    return 1 << (n - 1).bit_length()

def pad_npo2(X):
    """
    Zero-pads the length dimension (dim 1) at its end so that it becomes
    npo2 of its current size. The other dimensions are left untouched.

    Args:
        X : (B, L, D, N)

    Returns:
        Y : (B, npo2(L), D, N)
    """

    seq_len = X.size(1)
    missing = npo2(seq_len) - seq_len
    # F.pad lists (before, after) pairs starting from the LAST dimension:
    # (N_before, N_after, D_before, D_after, L_before, L_after)
    return F.pad(X, (0, 0, 0, 0, 0, missing), "constant", 0)

class PScan(torch.autograd.Function):
    """
    Custom autograd function for the linear recurrence
    H[t] = A[t] * H[t-1] + X[t], evaluated along the length dimension.

    pscan / pscan_rev are in-place helpers operating on (B, D, L, N) tensors
    whose L is a power of two. forward / backward take care of cloning,
    padding L up to a power of two and transposing between the public
    (B, L, D, N) layout and the internal (B, D, L, N) layout.
    """

    @staticmethod
    def pscan(A, X):
        # A : (B, D, L, N)
        # X : (B, D, L, N)

        # modifies X in place by doing a parallel scan.
        # more formally, X will be populated by these values :
        # H[t] = A[t] * H[t-1] + X[t] with H[0] = 0
        # which are computed in parallel (2*log2(T) sequential steps (ideally), instead of T sequential steps)

        # only supports L that is a power of two (mainly for a clearer code)
        # NOTE: A is modified in place as well (see the mul_ calls below).
        
        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # up sweep (last 2 steps unfolded)
        Aa = A
        Xa = X
        for _ in range(num_steps-2):
            T = Xa.size(2)
            # group neighbouring time steps into pairs: (..., T//2, 2, N)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)
            
            # combine each pair into its right element
            Xa[:, :, :, 1].add_(Aa[:, :, :, 1].mul(Xa[:, :, :, 0]))
            Aa[:, :, :, 1].mul_(Aa[:, :, :, 0])

            # continue on the right elements only (half as many nodes)
            Aa = Aa[:, :, :, 1]
            Xa = Xa[:, :, :, 1]

        # we have only 4, 2 or 1 nodes left
        if Xa.size(2) == 4:
            Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
            Aa[:, :, 1].mul_(Aa[:, :, 0])

            Xa[:, :, 3].add_(Aa[:, :, 3].mul(Xa[:, :, 2] + Aa[:, :, 2].mul(Xa[:, :, 1])))
        elif Xa.size(2) == 2:
            Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 0]))
            return
        else:
            return

        # down sweep (first 2 steps unfolded)
        Aa = A[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
        Xa = X[:, :, 2**(num_steps-2)-1:L:2**(num_steps-2)]
        Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 1]))
        Aa[:, :, 2].mul_(Aa[:, :, 1])

        for k in range(num_steps-3, -1, -1):
            # strided views selecting every 2**k-th time step, starting at 2**k-1
            Aa = A[:, :, 2**k-1:L:2**k]
            Xa = X[:, :, 2**k-1:L:2**k]

            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            # left element of each pair (from the 2nd pair on) receives the
            # contribution of the right element of the previous pair
            Xa[:, :, 1:, 0].add_(Aa[:, :, 1:, 0].mul(Xa[:, :, :-1, 1]))
            Aa[:, :, 1:, 0].mul_(Aa[:, :, :-1, 1])

    @staticmethod
    def pscan_rev(A, X):
        # A : (B, D, L, N)
        # X : (B, D, L, N)

        # the same function as above, but in reverse
        # (if you flip the input, call pscan, then flip the output, you get what this function outputs)
        # it is used in the backward pass

        # only supports L that is a power of two (mainly for a clearer code)
        # NOTE: like pscan, this modifies both X and A in place.

        B, D, L, _ = A.size()
        num_steps = int(math.log2(L))

        # up sweep (last 2 steps unfolded)
        Aa = A
        Xa = X
        for _ in range(num_steps-2):
            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)
                    
            # mirrored w.r.t. pscan: combine each pair into its LEFT element
            Xa[:, :, :, 0].add_(Aa[:, :, :, 0].mul(Xa[:, :, :, 1]))
            Aa[:, :, :, 0].mul_(Aa[:, :, :, 1])

            Aa = Aa[:, :, :, 0]
            Xa = Xa[:, :, :, 0]

        # we have only 4, 2 or 1 nodes left
        if Xa.size(2) == 4:
            Xa[:, :, 2].add_(Aa[:, :, 2].mul(Xa[:, :, 3]))
            Aa[:, :, 2].mul_(Aa[:, :, 3])

            Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1].add(Aa[:, :, 1].mul(Xa[:, :, 2]))))
        elif Xa.size(2) == 2:
            Xa[:, :, 0].add_(Aa[:, :, 0].mul(Xa[:, :, 1]))
            return
        else:
            return

        # down sweep (first 2 steps unfolded)
        Aa = A[:, :, 0:L:2**(num_steps-2)]
        Xa = X[:, :, 0:L:2**(num_steps-2)]
        Xa[:, :, 1].add_(Aa[:, :, 1].mul(Xa[:, :, 2]))
        Aa[:, :, 1].mul_(Aa[:, :, 2])

        for k in range(num_steps-3, -1, -1):
            # strided views selecting every 2**k-th time step, starting at 0
            Aa = A[:, :, 0:L:2**k]
            Xa = X[:, :, 0:L:2**k]

            T = Xa.size(2)
            Aa = Aa.view(B, D, T//2, 2, -1)
            Xa = Xa.view(B, D, T//2, 2, -1)

            # right element of each pair (up to the second-to-last pair) receives
            # the contribution of the left element of the next pair
            Xa[:, :, :-1, 1].add_(Aa[:, :, :-1, 1].mul(Xa[:, :, 1:, 0]))
            Aa[:, :, :-1, 1].mul_(Aa[:, :, 1:, 0])

    @staticmethod
    def forward(ctx, A_in, X_in):
        """
        Applies the parallel scan operation, as defined above. Returns a new tensor.
        If you can, privilege sequence lengths that are powers of two.

        Args:
            A_in : (B, L, D, N)
            X_in : (B, L, D, N)

        Returns:
            H : (B, L, D, N)
        """

        L = X_in.size(1)

        # cloning is required because of the in-place ops
        if L == npo2(L):
            A = A_in.clone()
            X = X_in.clone()
        else:
            # pad tensors (and clone btw)
            A = pad_npo2(A_in) # (B, npo2(L), D, N)
            X = pad_npo2(X_in) # (B, npo2(L), D, N)
        
        # prepare tensors
        A = A.transpose(2, 1) # (B, D, npo2(L), N)
        X = X.transpose(2, 1) # (B, D, npo2(L), N)

        # parallel scan (modifies X in-place)
        PScan.pscan(A, X)

        # X now holds the scan result in (B, D, npo2(L), N) layout; backward
        # reads it together with the original, unmodified A_in
        ctx.save_for_backward(A_in, X)
        
        # slice [:, :L] (cut if there was padding)
        return X.transpose(2, 1)[:, :L]
    
    @staticmethod
    def backward(ctx, grad_output_in):
        """
        Flows the gradient from the output to the input. Returns two new tensors.

        Args:
            ctx : A_in : (B, L, D, N), X : (B, D, L, N)
            grad_output_in : (B, L, D, N)

        Returns:
            gradA : (B, L, D, N), gradX : (B, L, D, N)
        """

        A_in, X = ctx.saved_tensors

        L = grad_output_in.size(1)

        # cloning is required because of the in-place ops
        if L == npo2(L):
            grad_output = grad_output_in.clone()
            # the next padding will clone A_in
        else:
            grad_output = pad_npo2(grad_output_in) # (B, npo2(L), D, N)
            A_in = pad_npo2(A_in) # (B, npo2(L), D, N)

        # prepare tensors
        grad_output = grad_output.transpose(2, 1)
        A_in = A_in.transpose(2, 1) # (B, D, npo2(L), N)
        A = torch.nn.functional.pad(A_in[:, :, 1:], (0, 0, 0, 1)) # (B, D, npo2(L), N) shift 1 to the left (see hand derivation)

        # reverse parallel scan (modifies grad_output in-place)
        PScan.pscan_rev(A, grad_output)

        # gradient w.r.t. A: previous scan output times the accumulated gradient
        # (the first time step gets zero)
        Q = torch.zeros_like(X)
        Q[:, :, 1:].add_(X[:, :, :-1] * grad_output[:, :, 1:])

        # back to (B, L, D, N) and drop any padded time steps
        return Q.transpose(2, 1)[:, :L], grad_output.transpose(2, 1)[:, :L]
    
# functional entry point: pscan(A, X) with A and X of shape (B, L, D, N)
pscan = PScan.apply

In [46]:
# taken from https://github.com/alxndrTL/mamba.py/blob/main/mamba.py
@dataclass
class MambaConfig:
    """Hyper-parameters of a MambaBlock.

    After construction, ``d_inner`` (= expand_factor * d_model) is available as
    an extra attribute and ``dt_rank='auto'`` has been replaced by
    ``ceil(d_model / 16)``.
    """

    d_model: int # D
#     n_layers: int
    dt_rank: Union[int, str] = 'auto'  # an int, or 'auto' to derive it from d_model
    d_state: int = 16 # N in paper/comments
    expand_factor: int = 2 # E in paper/comments
    d_conv: int = 4  # kernel size of the depthwise convolution

    # settings used to initialise the dt projection
    dt_min: float = 0.001
    dt_max: float = 0.1
    dt_init: str = "random" # "random" or "constant"
    dt_scale: float = 1.0
    # no annotation here, so this is a plain class attribute rather than a
    # dataclass field: it is not an __init__ argument and is not part of repr()
    dt_init_floor = 1e-4

    bias: bool = False  # bias of the input/output projections
    conv_bias: bool = True  # bias of the convolution

    pscan: bool = True # use parallel scan mode or sequential mode when training

    def __post_init__(self):
        # E*D = ED in comments
        self.d_inner = self.expand_factor * self.d_model

        if self.dt_rank != 'auto':
            return
        self.dt_rank = math.ceil(self.d_model / 16)

class MambaBlock(nn.Module):
    """One Mamba mixer block: a gated selective state-space layer preceded by a
    short causal depthwise convolution.

    ``forward`` processes a whole sequence, (B, L, D) -> (B, L, D).
    ``step`` processes a single token, (B, D) -> (B, D), while carrying a
    constant-size cache, for auto-regressive inference.

    Args:
        config: configuration object providing ``d_model``, ``d_inner``,
            ``dt_rank``, ``d_state``, ``d_conv``, ``dt_min``, ``dt_max``,
            ``dt_init``, ``dt_scale``, ``dt_init_floor``, ``bias``,
            ``conv_bias`` and ``pscan`` (see MambaConfig).

    Raises:
        NotImplementedError: if ``config.dt_init`` is neither "constant" nor
            "random".
    """

    def __init__(self, config: "MambaConfig"):
        super().__init__()

        self.config = config

        # projects block input from D to 2*ED (two branches)
        self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)

        # depthwise (groups == channels) convolution over time; the padding of
        # d_conv - 1 together with the [:L] slice in forward() makes it causal
        self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
                              kernel_size=config.d_conv, bias=config.conv_bias,
                              groups=config.d_inner,
                              padding=config.d_conv - 1)

        # projects x to input-dependent Δ, B, C
        self.x_proj = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)

        # projects Δ from dt_rank to d_inner
        self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)

        # dt initialization
        # dt weights
        dt_init_std = config.dt_rank**-0.5 * config.dt_scale
        if config.dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif config.dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # dt bias: sample dt log-uniformly in [dt_min, dt_max] (floored at
        # dt_init_floor) and store softplus^-1(dt), so that softplus(bias) == dt
        dt = torch.exp(
            torch.rand(config.d_inner) * (math.log(config.dt_max) - math.log(config.dt_min)) + math.log(config.dt_min)
        ).clamp(min=config.dt_init_floor)
        inv_dt = dt + torch.log(-torch.expm1(-dt)) # inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)

        # S4D real initialization: every channel starts with A = -(1, 2, ..., N)
        A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
        # A is stored as log(-A); ssm()/ssm_step() recover it with -exp(A_log),
        # which keeps A negative whatever value the parameter takes
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(config.d_inner))

        # projects block output from ED back to D
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)

    def forward(self, x):
        """Process a full sequence.

        Args:
            x : (B, L, D)

        Returns:
            output : (B, L, D)
        """

        _, L, _ = x.shape

        xz = self.in_proj(x) # (B, L, 2*ED)
        x, z = xz.chunk(2, dim=-1) # (B, L, ED), (B, L, ED)

        # x branch
        x = x.transpose(1, 2) # (B, ED, L)
        x = self.conv1d(x)[:, :, :L] # depthwise convolution over time, with a short filter
        x = x.transpose(1, 2) # (B, L, ED)

        x = F.silu(x)
        y = self.ssm(x)

        # z branch (gate)
        z = F.silu(z)

        output = y * z
        output = self.out_proj(output) # (B, L, D)

        return output

    def ssm(self, x):
        """Selective state-space layer over a full sequence.

        Args:
            x : (B, L, ED)

        Returns:
            y : (B, L, ED)
        """

        A = -torch.exp(self.A_log.float()) # (ED, N)
        D = self.D.float()
        # TODO remove .float()

        deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N)

        delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
        delta = F.softplus(self.dt_proj(delta)) # (B, L, ED)

        if self.config.pscan:
            y = self.selective_scan(x, delta, A, B, C, D)
        else:
            y = self.selective_scan_seq(x, delta, A, B, C, D)

        return y

    def selective_scan(self, x, delta, A, B, C, D):
        """Selective scan computed with the parallel scan ``pscan``.

        Args:
            x : (B, L, ED)
            delta : (B, L, ED)
            A : (ED, N)
            B : (B, L, N)
            C : (B, L, N)
            D : (ED)

        Returns:
            y : (B, L, ED)
        """

        deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)

        BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)

        hs = pscan(deltaA, BX)

        y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

        y = y + D * x

        return y

    def selective_scan_seq(self, x, delta, A, B, C, D):
        """Selective scan computed with an explicit loop over time.

        Same arguments and return value as ``selective_scan``.
        """

        _, L, _ = x.shape

        deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)

        BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)

        # initial hidden state, on the same device and in the same dtype as deltaA
        h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state,
                        device=deltaA.device, dtype=deltaA.dtype) # (B, ED, N)
        hs = []

        for t in range(0, L):
            h = deltaA[:, t] * h + BX[:, t]
            hs.append(h)

        hs = torch.stack(hs, dim=1) # (B, L, ED, N)

        y = (hs @ C.unsqueeze(-1)).squeeze(3) # (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

        y = y + D * x

        return y

    # -------------------------- inference -------------------------- #
    #
    # Concerning auto-regressive inference
    #
    # The cost of one inference step does not depend on the sequence length.
    # We just have to keep in cache, for each layer, two things :
    # - the hidden state h (which is (B, ED, N)), as you typically would when doing inference with a RNN
    # - the last d_conv-1 inputs of the layer, to be able to compute the 1D conv which is a convolution over the time dimension
    #   (d_conv is fixed so this doesn't incur a growing cache as we progress on generating the sequence)
    #   (and d_conv is usually very small, like 4, so we just have to "remember" the last 3 inputs)
    #
    # Concretely, these two quantities are put inside a cache tuple, and are named h and inputs respectively.
    # h is (B, ED, N), and inputs is (B, ED, d_conv-1)
    # MambaBlock.step() receives this cache and, along with the output, also returns the updated cache for the next call.
    #
    # The cache object is initialized as follows : (None, torch.zeros()).
    # When h is None, ssm_step detects it and starts with h=0.
    # The torch.zeros() isn't a problem (it's the same as just feeding the input, because the conv1d is zero-padded)
    #
    # As one such cache is needed per layer, a model would store a list of cache objects (see mamba_lm.py in the source repository).

    def step(self, x, cache):
        """Process a single token.

        Args:
            x : (B, D)
            cache : (h, inputs) with h : (B, ED, N) or None, and
                inputs : (B, ED, d_conv-1)

        Returns:
            output : (B, D)
            cache : the updated (h, inputs), same shapes as above
        """

        h, inputs = cache

        xz = self.in_proj(x) # (B, 2*ED)
        x, z = xz.chunk(2, dim=1) # (B, ED), (B, ED)

        # x branch
        x_cache = x.unsqueeze(2)
        # convolve the last d_conv inputs; output index d_conv-1 is the one whose
        # window covers exactly these d_conv inputs
        x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[:, :, self.config.d_conv-1] # (B, ED)

        x = F.silu(x)
        y, h = self.ssm_step(x, h)

        # z branch (gate)
        z = F.silu(z)

        output = y * z
        output = self.out_proj(output) # (B, D)

        # prepare cache for next call: drop the oldest input, append the newest
        inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2) # (B, ED, d_conv-1)
        cache = (h, inputs)

        return output, cache

    def ssm_step(self, x, h):
        """One time step of the selective state-space layer.

        Args:
            x : (B, ED)
            h : (B, ED, N), or None to start from a zero state

        Returns:
            y : (B, ED)
            h : (B, ED, N), the updated hidden state
        """

        A = -torch.exp(self.A_log.float()) # (ED, N) # TODO: avoid recomputing this on every step, it does not depend on the time step
        D = self.D.float()
        # TODO remove .float()

        deltaBC = self.x_proj(x) # (B, dt_rank+2*N)

        delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, dt_rank), (B, N), (B, N)
        delta = F.softplus(self.dt_proj(delta)) # (B, ED)

        deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(1) # (B, ED, N)

        BX = deltaB * (x.unsqueeze(-1)) # (B, ED, N)

        if h is None:
            h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state,
                            device=deltaA.device, dtype=deltaA.dtype) # (B, ED, N)

        h = deltaA * h + BX # (B, ED, N)

        y = (h @ C.unsqueeze(-1)).squeeze(2) # (B, ED, N) @ (B, N, 1) -> (B, ED, 1)

        y = y + D * x

        # h is returned as (B, ED, N) unchanged: squeezing dim 1 here would drop
        # the channel axis whenever ED == 1 and break the next step
        return y, h

In [53]:
class MambaBlockWrapper(nn.Module):
    """Adapter that builds a MambaBlock from the ``(embed_dim, max_len)``
    constructor arguments BaseNet passes to its blocks.

    All Mamba settings other than ``d_model`` are the MambaConfig defaults.
    """

    def __init__(self, embed_dim, max_len):
        # max_len is not used: it is accepted only so the signature matches
        # what BaseNet calls
        super().__init__()

        default_config = MambaConfig(d_model=embed_dim)
        self.config = default_config
        self.block = MambaBlock(default_config)

    def forward(self, x):
        """Delegate to the wrapped MambaBlock."""
        mixed = self.block(x)
        return mixed

# Display the configuration obtained for d_model=10 with every other setting
# left at its default (the cell output is the dataclass repr).
MambaConfig(d_model=10)

MambaConfig(d_model=10, dt_rank=1, d_state=16, expand_factor=2, d_conv=4, dt_min=0.001, dt_max=0.1, dt_init='random', dt_scale=1.0, bias=False, conv_bias=True, pscan=True)

In [54]:
class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and
    an optional learnable shift (``bias=False`` leaves the shift out)."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        # the normalised shape is taken from the scale parameter, i.e. (ndim,)
        eps = 1e-5
        return F.layer_norm(input, self.weight.shape, weight=self.weight, bias=self.bias, eps=eps)
    
class BaseNet(nn.Module):
    """Token model built from a stack of interchangeable sequence-mixing blocks.

    Pipeline: token embedding (+ optional learned positional embedding), then
    ``attn_layers`` times "block followed by LayerNorm", then a linear head
    producing one logit per vocabulary entry at every position.

    Args:
        vocab_size: number of token ids; also the size of the output logits.
        embed_dim: embedding / hidden width.
        is_pe: whether to add a learned positional embedding.
        max_len: number of positions in the positional embedding; also passed
            to every block as its second constructor argument.
        attn_layers: number of blocks.
        block: callable ``block(embed_dim, max_len, **kwargs)`` returning an
            ``nn.Module``; required.
        **kwargs: forwarded unchanged to every ``block`` call.

    Raises:
        ValueError: if ``block`` is None.
    """

    def __init__(self, vocab_size, embed_dim, 
                 is_pe = False, max_len=11, 
                 attn_layers=2, block=None,
                 **kwargs):
        super().__init__()
        if block is None:
            raise ValueError("block type should be provided.")

        self.vocab_size = vocab_size
        self.max_len = max_len
        self.is_pe = is_pe

        self.embed = nn.Embedding(vocab_size, embed_dim)
        if is_pe:
            self.pe = nn.Embedding(max_len, embed_dim)
        else:
            self.pe = None

        mixers = []
        for _ in range(attn_layers):
            mixers.append(block(embed_dim, max_len, **kwargs))
        self.att = nn.ModuleList(mixers)

        norms = []
        for _ in range(attn_layers):
            norms.append(LayerNorm(embed_dim, True))
        self.ln = nn.ModuleList(norms)

        self.head = nn.Linear(embed_dim, vocab_size)

        # summary of the configuration that was just built
        for line in (
            f"BaseNet with {attn_layers} layers of {block} blocks",
            f"Embedding dimension: {embed_dim}",
            f"Positional Encoding: {is_pe}",
            f"Vocabulary size: {vocab_size}",
            f"Context length: {max_len}",
        ):
            print(line)

    def forward(self, x):
        """Map a 2-D tensor of token ids (batch, time) to logits of shape
        (batch, time, vocab_size)."""
        _, seq_len = x.size()
        hidden = self.embed(x)
        if self.is_pe:
            positions = torch.arange(0, seq_len, dtype=torch.long, device=hidden.device)
            hidden = hidden + self.pe(positions)
        for mixer, norm in zip(self.att, self.ln):
            hidden = norm(mixer(hidden))
        return self.head(hidden)

In [56]:
# small scale test -- ignore
#
# NOTE: this cell reads `tokenizer`, `n_ctx`, `seed`, `batch_size`, `embed_dim`,
# `is_pe`, `attn_layers` and `device`, which are assigned in the training cell
# further down. Run that cell first; on a fresh kernel this cell raises NameError.

dataset = InductionAR(10, tokenizer, 1, n_ctx=n_ctx, seed=seed, train_split=0.8)
train_loader = DataLoader(dataset.train, batch_size=batch_size, shuffle=True)
# attn_model = BaseNet(len(tokenizer), embed_dim, is_pe,  max_len=n_ctx*4, attn_layers=attn_layers, block=Block).to(device)
mamba_model = BaseNet(len(tokenizer), embed_dim, is_pe,  max_len=n_ctx*4, attn_layers=attn_layers, block=MambaBlockWrapper).to(device)

x, y = next(iter(train_loader))
print("input x shape", x.shape, "y shape", y.shape)
# the model lives on `device`, so the batch has to be moved there too
# (without this the forward pass fails when `device` is a GPU)
mamba_model(x.to(device)).shape

Generating example 0
BaseNet with 2 layers of <class '__main__.MambaBlockWrapper'> blocks
Embedding dimension: 8
Positional Encoding: True
Vocabulary size: 16
Context length: 256
input x shape torch.Size([8, 64]) y shape torch.Size([8])


torch.Size([8, 64, 16])

In [60]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")
seed = 0
# Seed torch as well: parameter initialisation and DataLoader shuffling both
# draw from torch's global generator, so without this no two runs agree.
torch.manual_seed(seed)
n_ctx = 64 # training sequence length, L
# num_examples = 100000 # generate 100000 examples
num_examples = 10
batch_size = 64 # batch size
vocab_size = 16 # vocabulary size, K
num_epochs = 250    # number of epochs
attn_layers = 2 # number of attention layers
embed_dim = 8 # embedding dimension, d
is_pe = True  # the default positional embedding we are using is the learned positional embedding

tokenizer = Random_tokenizer(vocab_size=vocab_size)
dataset = InductionAR(num_examples, tokenizer, 1, n_ctx=n_ctx, seed=seed, train_split=0.8)
train_loader = DataLoader(dataset.train, batch_size=batch_size, shuffle=True)
# evaluation order does not affect accuracy, so the test data is not shuffled
test_loader = DataLoader(dataset.test, batch_size=batch_size, shuffle=False)

# max_len is 4x the training length so that longer sequences can be evaluated later
attn_learned_PE = BaseNet(len(tokenizer), embed_dim, is_pe,  max_len=n_ctx*4, attn_layers=attn_layers, block=Block).to(device)
mamba = BaseNet(len(tokenizer), embed_dim, is_pe,  max_len=n_ctx*4, attn_layers=attn_layers, block=MambaBlockWrapper).to(device)
    
# models = [mamba, attn_learned_PE] # add more models here when you have more models
models = {
    "Transformer Learned PE": attn_learned_PE,
    "Mamba": mamba
}

criterion = nn.CrossEntropyLoss()

for model_name, model in models.items():
    # label the log so the loss curves of the different models can be told apart
    print(f"Training {model_name}")
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    for epoch in range(num_epochs):
        total_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            # only the prediction at the last position is trained on
            y_pred = model(x)[:,-1]
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"epoch {epoch} loss: {total_loss/len(train_loader)}")
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            # softmax is monotonic, so the argmax of the logits is the prediction
            y_pred = torch.argmax(model(x)[:,-1], dim=-1)
            correct += (y_pred == y).sum().item()
            total += y.size(0)
    print(f"Test accuracy: {correct/total}")



Using cpu device
Generating example 0
BaseNet with 2 layers of <class '__main__.Block'> blocks
Embedding dimension: 8
Positional Encoding: True
Vocabulary size: 16
Context length: 256
BaseNet with 2 layers of <class '__main__.MambaBlockWrapper'> blocks
Embedding dimension: 8
Positional Encoding: True
Vocabulary size: 16
Context length: 256
epoch 0 loss: 3.1832780838012695
epoch 1 loss: 3.180947780609131
epoch 2 loss: 3.178617000579834
epoch 3 loss: 3.176285743713379
epoch 4 loss: 3.173954486846924
epoch 5 loss: 3.1716222763061523
epoch 6 loss: 3.169290065765381
epoch 7 loss: 3.1669576168060303
epoch 8 loss: 3.1646246910095215
epoch 9 loss: 3.1622915267944336
epoch 10 loss: 3.1599583625793457
epoch 11 loss: 3.157625198364258
epoch 12 loss: 3.155291795730591
epoch 13 loss: 3.1529579162597656
epoch 14 loss: 3.1506242752075195
epoch 15 loss: 3.1482906341552734
epoch 16 loss: 3.145956516265869
epoch 17 loss: 3.143622875213623
epoch 18 loss: 3.141289234161377
epoch 19 loss: 3.138955593109131

epoch 232 loss: 2.6616134643554688
epoch 233 loss: 2.6592698097229004
epoch 234 loss: 2.656921148300171
epoch 235 loss: 2.6545662879943848
epoch 236 loss: 2.6522057056427
epoch 237 loss: 2.649838924407959
epoch 238 loss: 2.6474661827087402
epoch 239 loss: 2.645087242126465
epoch 240 loss: 2.6427016258239746
epoch 241 loss: 2.6403098106384277
epoch 242 loss: 2.637911796569824
epoch 243 loss: 2.6355061531066895
epoch 244 loss: 2.6330947875976562
epoch 245 loss: 2.630675792694092
epoch 246 loss: 2.6282501220703125
epoch 247 loss: 2.625817060470581
epoch 248 loss: 2.6233768463134766
epoch 249 loss: 2.620929479598999
Test accuracy: 0.5
epoch 0 loss: 3.1876120567321777
epoch 1 loss: 3.180656909942627
epoch 2 loss: 3.17370867729187
epoch 3 loss: 3.1667699813842773
epoch 4 loss: 3.159843683242798
epoch 5 loss: 3.1529316902160645
epoch 6 loss: 3.1460342407226562
epoch 7 loss: 3.1391539573669434
epoch 8 loss: 3.132291793823242
epoch 9 loss: 3.12544846534729
epoch 10 loss: 3.118624687194824
epoch

epoch 224 loss: 1.937292218208313
epoch 225 loss: 1.934779167175293
epoch 226 loss: 1.9322792291641235
epoch 227 loss: 1.9297912120819092
epoch 228 loss: 1.9273152351379395
epoch 229 loss: 1.9248512983322144
epoch 230 loss: 1.9223989248275757
epoch 231 loss: 1.9199581146240234
epoch 232 loss: 1.9175286293029785
epoch 233 loss: 1.91511070728302
epoch 234 loss: 1.9127033948898315
epoch 235 loss: 1.9103071689605713
epoch 236 loss: 1.9079217910766602
epoch 237 loss: 1.9055466651916504
epoch 238 loss: 1.9031821489334106
epoch 239 loss: 1.9008277654647827
epoch 240 loss: 1.8984836339950562
epoch 241 loss: 1.8961491584777832
epoch 242 loss: 1.893824815750122
epoch 243 loss: 1.8915101289749146
epoch 244 loss: 1.889204978942871
epoch 245 loss: 1.8869092464447021
epoch 246 loss: 1.8846229314804077
epoch 247 loss: 1.8823456764221191
epoch 248 loss: 1.8800777196884155
epoch 249 loss: 1.8778185844421387
Test accuracy: 0.0


In [None]:
import collections

import matplotlib.pyplot as plt

# Evaluate every model in `models` on freshly generated induction data at
# several context lengths, then plot test accuracy against context length.
num_examples = 20000
n_ctx_list = [16, 32, 128, 256]
# model name -> list of accuracies, in the same order as n_ctx_list
results = collections.defaultdict(list)

for n_ctx in n_ctx_list:
    # Third positional argument is n_gram; 1 is the induction-head setting.
    # train_split=0.01 presumably leaves ~99% of the examples in dataset.test
    # -- TODO confirm against InductionAR.split().
    dataset = InductionAR(num_examples, tokenizer, 1, n_ctx=n_ctx, seed=seed, train_split=0.01)
    # Not used by this evaluation; kept only so the notebook's global state is
    # unchanged in case a later cell reads `train_loader`.
    train_loader = DataLoader(dataset.train, batch_size=batch_size, shuffle=True)
    # Accuracy does not depend on batch order, so do not shuffle: evaluation
    # stays deterministic and does not draw from the global RNG.
    test_loader = DataLoader(dataset.test, batch_size=batch_size, shuffle=False)

    for model_name, model in models.items():
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                # Only the output at the last index of dim 1 is scored.
                # argmax is taken on the raw outputs: softmax is monotonic, so
                # applying it first cannot change which index wins.
                y_pred = model(x)[:, -1].argmax(dim=-1)
                correct += (y_pred == y).sum().item()
                total += y.size(0)
        # Guard against an empty test split instead of raising ZeroDivisionError.
        accuracy = correct / total if total else float("nan")
        print(f"{model_name}, Context length: {n_ctx}, Test accuracy: {accuracy}")
        results[model_name].append(accuracy)

# One curve per model; markers show the context lengths actually evaluated.
plt.figure()
for model_name, accuracies in results.items():
    plt.plot(n_ctx_list, accuracies, marker="o", label=model_name)
plt.title("Accuracy on Induction Heads vs. Prompt Context Length")
plt.xlabel("Context length (n_ctx)")
plt.ylabel("Test accuracy")
plt.legend()
plt.savefig("1a.png")
plt.show()

# # generate test data with length 16
# # test the model with the test data
# num_examples = 20000
# n_ctx = 16

# dataset = InductionAR(num_examples, tokenizer, 1, n_ctx=n_ctx, seed=seed, train_split=0.01)
# train_loader = DataLoader(dataset.train, batch_size=batch_size, shuffle=True)
# test_loader = DataLoader(dataset.test, batch_size=batch_size, shuffle=True)

# attn_model.eval()
# correct = 0
# total = 0
# with torch.no_grad():
#     for x, y in test_loader:
#         x,y = x.to(device), y.to(device)
#         y_pred = attn_model(x)[:,-1]
#         y_pred = F.softmax(y_pred, dim=-1)
#         y_pred = torch.argmax(y_pred, dim=-1)
#         correct += (y_pred == y).sum().item()
#         total += y.size(0)
#     print(f"Test accuracy: {correct/total}")

# # generate test data with length 128
# # test the model with the test data
# num_examples = 20000
# n_ctx = 128

# dataset = InductionAR(num_examples, tokenizer, 1, n_ctx=n_ctx, seed=seed, train_split=0.01)
# train_loader = DataLoader(dataset.train, batch_size=batch_size, shuffle=True)
# test_loader = DataLoader(dataset.test, batch_size=batch_size, shuffle=True)

# attn_model.eval()
# correct = 0
# total = 0
# with torch.no_grad():
#     for x, y in test_loader:
#         x,y = x.to(device), y.to(device)
#         y_pred = attn_model(x)[:,-1]
#         y_pred = F.softmax(y_pred, dim=-1)
#         y_pred = torch.argmax(y_pred, dim=-1)
#         correct += (y_pred == y).sum().item()
#         total += y.size(0)
#     print(f"Test accuracy: {correct/total}")


# # generate test data with length 256
# # test the model with the test data
# num_examples = 20000
# n_ctx = 256

# dataset = InductionAR(num_examples, tokenizer, 1, n_ctx=n_ctx, seed=seed, train_split=0.01)
# train_loader = DataLoader(dataset.train, batch_size=batch_size, shuffle=True)
# test_loader = DataLoader(dataset.test, batch_size=batch_size, shuffle=True)

# attn_model.eval()
# correct = 0
# total = 0
# with torch.no_grad():
#     for x, y in test_loader:
#         x,y = x.to(device), y.to(device)
#         y_pred = attn_model(x)[:,-1]
#         y_pred = F.softmax(y_pred, dim=-1)
#         y_pred = torch.argmax(y_pred, dim=-1)
#         correct += (y_pred == y).sum().item()
#         total += y.size(0)
#     print(f"Test accuracy: {correct/total}")



Generating example 0
Generating example 5000
Generating example 10000
Generating example 15000
Transformer Learned PE, Context length: 16, Test accuracy: 0.06712121212121212
Mamba, Context length: 16, Test accuracy: 0.07075757575757576
Generating example 0
Generating example 5000
Generating example 10000
Generating example 15000
Transformer Learned PE, Context length: 32, Test accuracy: 0.07171717171717172
Mamba, Context length: 32, Test accuracy: 0.06676767676767677
Generating example 0
Generating example 5000
Generating example 10000
Generating example 15000
Transformer Learned PE, Context length: 128, Test accuracy: 0.06808080808080808
Mamba, Context length: 128, Test accuracy: 0.06671717171717172
Generating example 0
Generating example 5000
Generating example 10000
Generating example 15000
