By vectorizing the sampler, our implementation of the [bigram data model](https://arxiv.org/pdf/2306.00802) is much faster than the original implementation.

In [1]:
from dataclasses import dataclass
import itertools
import logging
import random
import math
import numpy as np
import pickle
import time
import sys

from collections import namedtuple
from functools import partial

import jax
from jax import jit, lax
from jax import numpy as jnp
from jax import random as jr
from jax import vmap
from jax.numpy import linalg as jla

from typing import List, Optional, Tuple, Sequence

from markov import *
from config import *
from causal_graph import *
from old_sampler import *
import torch

# Show INFO-level log records from all libraries (this is why JAX's backend
# initialization messages appear in a later cell's output).
logging.getLogger().setLevel(logging.INFO)

SyntaxError: invalid syntax (markov.py, line 330)

In [2]:
# Baseline: the project's `Dataset` sampler (k=2, sequence length 16).
# NOTE(review): the second argument of gen_batch is presumably the batch size (16) — confirm.
args = DataArgs(k=2, seq_length=16, show_latents=False)
ds = Dataset(args)
rng = np.random.default_rng(42)
%timeit ds.gen_batch(rng, 16)

1.84 ms ± 10.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [3]:
# Vectorized PyTorch sampler (BiettiTask), batch 16, sequence length 16.
# NOTE(review): show_latents=True here but False in the Dataset cell above — confirm
# the two timings are meant to be compared like-for-like.
torch_args = BiettiSamplerConfig(batch_size=16, seq_len=16, show_latents=True, seed=42)
torch_ds = BiettiTask(torch_args)
%timeit torch_ds.generate()

683 μs ± 9.11 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [4]:
# Same config class, BBTask sampler. Unlike the BiettiTask cell, no seed is passed,
# so (presumably) the sampled batches here are not reproducible.
torch_args = BiettiSamplerConfig(batch_size=16, seq_len=16, show_mask=True)
torch_ds = BBTask(torch_args)
%timeit torch_ds.generate()

428 μs ± 1.16 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Compared with JAX

In [5]:
def get_stationary(pi):
    """Stationary distribution of a single row-stochastic transition matrix.

    The stationary vector satisfies ``mu @ pi == mu``, i.e. ``(pi.T - I) mu = 0``,
    so it spans the null space of ``pi.T - I``. That is the right singular
    vector paired with the smallest singular value: the last row of ``Vh``.
    Dividing by the sum fixes both the scale and the arbitrary SVD sign.
    """
    n_states = pi.shape[0]
    _, _, vh = jla.svd(pi.T - jnp.eye(n_states))
    null_vec = vh[-1]
    return null_vec / jnp.sum(null_vec)


class InContextTree:
    """In-context Markov task on a DAG where every node has at most one parent.

    ``dag[i]`` is the parent position of node ``i``, or ``-1`` for a root. For
    each sampled task, a transition matrix ``pi`` is drawn with rows ~
    Dirichlet(alpha). Roots are drawn from its stationary distribution, and
    children from ``pi[parent_token]``. ``dag = arange(n) - 1`` gives the
    bigram (first-order Markov chain) task.
    """

    def __init__(self, vocab_size, dag, alpha=1):
        """
        Args:
            vocab_size: number of tokens / Markov states.
            dag: 1-D integer array of parent positions (``-1`` marks a root).
            alpha: Dirichlet concentration for every transition row.
        """
        dag = jnp.asarray(dag)
        # -1 is the only root marker (as used by `sample`), and every parent
        # must precede its child.
        assert jnp.all((dag >= -1) & (dag < jnp.arange(len(dag))))
        self.vocab_size = vocab_size
        self.dag = dag
        self.alpha = alpha

    def sample(self, key):
        """Sample one task and one sequence.

        Args:
            key: JAX PRNG key.

        Returns:
            x: int array of length ``len(dag) + 1``: ``len(dag)`` sampled context
                tokens followed by a uniformly drawn test token.
            y: ``(vocab_size,)`` true next-token distribution after the test
                token, i.e. ``pi[test_token]``.
        """
        pi_key, seq_key, test_key = jr.split(key, 3)
        prior = self.alpha * jnp.ones(self.vocab_size)
        # (vocab_size, vocab_size): one Dirichlet-distributed row per state.
        pi = jr.dirichlet(pi_key, prior, [self.vocab_size])
        mu = get_stationary(pi)
        x = jnp.zeros((len(self.dag) + 1,), dtype=int)

        def step(i, carry):
            x, k = carry
            k, subkey = jr.split(k)
            # For a root, x[-1] is read but discarded: the where selects mu.
            p = jnp.where(self.dag[i] == -1, mu, pi[x[self.dag[i]]])
            x = x.at[i].set(jr.choice(subkey, pi.shape[0], p=p))
            return x, k

        x, _ = lax.fori_loop(0, len(self.dag), step, (x, seq_key))
        test_token = jr.choice(test_key, self.vocab_size)
        x = x.at[-1].set(test_token)
        y = pi[test_token]
        return x, y

    def bayes(self, seq):
        """Posterior-mean prediction for the token following the test token.

        Counts the in-context transitions out of the test token, adds the
        Dirichlet pseudo-count ``alpha`` to every token, and normalizes.

        Args:
            seq: int array as produced by ``sample``: ``len(dag)`` context
                tokens followed by the test token.

        Returns:
            ``(vocab_size,)`` array of predicted next-token probabilities.
        """
        s, seq = seq[-1], seq[:-1]
        # Roots have no incoming transition. Without this mask, seq[-1] (the last
        # context token) would act as a fake parent and add spurious counts.
        has_parent = self.dag >= 0
        follows_s = has_parent & (seq[self.dag] == s)
        counts = jnp.zeros(self.vocab_size)
        counts = counts.at[seq].add(follows_s.astype(counts.dtype))
        counts += self.alpha
        return counts / counts.sum()


class InContextDAG:
    """In-context task over a DAG in which a node may have several parents.

    ``dag[i]`` lists the parent positions of node ``i`` (all must be < i). The
    last entry of ``dag`` is the query node: ``sample`` does not emit a token
    for it, and instead returns its conditional distribution. All nodes with
    the same number of parents k share one random conditional table ``pi[k]``
    per sampled task.
    """

    def __init__(self, vocab_size, dag, alpha):
        # Validate topological order: every parent position precedes its child.
        for i, p in enumerate(dag):
            assert max(p, default=-1) < i
        dag = [jnp.array(p, dtype=int) for p in dag]
        self.vocab_size = vocab_size
        self.dag = dag  # one int array of parent positions per node
        self.alpha = alpha  # Dirichlet concentration, used by both `sample` and `bayes`

    def sample(self, key):
        """Sample one task and one sequence.

        Args:
            key: JAX PRNG key.

        Returns:
            x: int array of length ``len(self.dag) - 1`` with the sampled tokens
               (the query node is not sampled).
            p: distribution over the vocabulary for the query node, given the
               tokens of its parents in ``x``.
        """
        pi_key, seq_key = jr.split(key)
        ks = set(len(p) for p in self.dag)  # distinct parent counts present in the DAG
        pi_keys = jr.split(pi_key, len(ks))
        pi = dict()
        pi[0] = jnp.ones(self.vocab_size) / self.vocab_size
        # NOTE(review): if any node has zero parents, 0 is in `ks` and the loop
        # below overwrites this uniform pi[0] with a Dirichlet sample. pi[0] is
        # only read for zero-parent nodes, so this uniform value is never used.
        # Confirm which root distribution is intended.
        prior = self.alpha * jnp.ones(self.vocab_size)
        for k, subkey in zip(ks, pi_keys):
            # Shape (vocab_size,) * k + (vocab_size,): a next-token distribution
            # for every combination of k parent tokens.
            pi[k] = jr.dirichlet(subkey, prior, [self.vocab_size] * k)

        x = jnp.zeros((len(self.dag) - 1,), dtype=int)
        for i in range(len(self.dag)):
            k = len(self.dag[i])
            if k == 0:
                p = pi[0]
            else:
                p = pi[k][tuple(x[self.dag[i]])]

            # The last node is the query: p is computed for it, but no token is drawn.
            if i != len(self.dag) - 1:
                seq_key, subkey = jr.split(seq_key)
                new_token = jr.choice(subkey, self.vocab_size, p=p)
                x = x.at[i].set(new_token)
        return x, p

    def bayes(self, seq):
        """Posterior-mean next-token distribution for the query node.

        Counts, over earlier nodes with the same number of parents as the query,
        how often each token appeared after exactly the query's parent-token
        tuple. It then adds ``alpha`` to every count and normalizes. If the
        query has no parents, every root node matches (``jnp.all`` of an empty
        comparison is True).

        ``seq`` presumably is the ``x`` returned by ``sample`` (no query token);
        verify against callers.
        """
        counts = jnp.zeros(self.vocab_size)
        s = seq[self.dag[-1]]  # tokens at the query node's parent positions
        for i in range(len(self.dag) - 1):
            if len(self.dag[i]) == len(s):
                counts = counts.at[seq[i]].add(jnp.all(seq[self.dag[i]] == s))
        counts += self.alpha
        return counts / counts.sum()

class RNG:
    """Stateful convenience wrapper around a JAX PRNG key.

    Each call to :meth:`next` splits the stored key and hands out a fresh
    subkey. Attribute access forwards to ``jax.random`` with a fresh key bound
    as the first argument, e.g. ``rng.normal((3,))`` calls
    ``jax.random.normal(<fresh key>, (3,))``.
    """

    def __init__(self, seed=None, key=None):
        """
        Args:
            seed: integer seed used to build the initial key (takes precedence).
            key: an existing JAX PRNG key, used when ``seed`` is None.

        Raises:
            ValueError: if neither ``seed`` nor ``key`` is given.
        """
        if seed is not None:
            self.key = jax.random.PRNGKey(seed)
        elif key is not None:
            self.key = key
        else:
            raise ValueError("RNG expects either a seed or random key.")

    def next(self, n_keys=1):
        """Return one fresh subkey, or an array of ``n_keys`` keys when ``n_keys > 1``."""
        if n_keys > 1:
            return jax.random.split(self.next(), n_keys)
        self.key, key = jax.random.split(self.key)
        return key

    def __getattr__(self, name):
        """Forward ``name`` to ``jax.random`` with a fresh key pre-bound.

        Python calls this only when normal attribute lookup fails. The target
        function is resolved *before* a key is consumed, so failed lookups (e.g.
        ``hasattr`` probes) leave the generator state untouched. Private/dunder
        names and ``key`` are refused outright: ``copy``/``pickle`` probe names
        like ``__setstate__`` on instances whose ``__init__`` never ran. There,
        forwarding would recurse endlessly through the missing ``self.key``.
        """
        if name.startswith("_") or name == "key":
            raise AttributeError(name)
        fn = getattr(jax.random, name, None)
        if not callable(fn):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")
        return partial(fn, self.next())

JAX is slightly faster than the vectorized PyTorch implementation.

In [3]:
vocab_size = 10
seq_len = 256
batch_size = 64
# Note: rebinds the global `rng` (a NumPy Generator in an earlier cell) to the JAX RNG wrapper.
rng = RNG(0)
# Bigram / first-order chain: node i's parent is i - 1; node 0 (parent -1) is the root.
dag = jnp.arange(seq_len)-1
BigramJax = InContextTree(vocab_size, dag)
# Batched sampler: vmap over batch_size independent keys, then jit-compile.
sample_fn = jit(lambda k: vmap(BigramJax.sample)(jr.split(k, batch_size)))
key, subkey = jr.split(rng.next())
# Warm-up call: the first call traces and compiles sample_fn, so later timings exclude compilation.
batches = sample_fn(subkey)

INFO:2025-01-28 15:03:01,126:jax._src.xla_bridge:927: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-01-28 15:03:01,128:jax._src.xla_bridge:927: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/hyan/anaconda3/envs/ICL/bin/../lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)


In [4]:
key, subkey = jr.split(rng.next())
%timeit sample_fn(subkey)

6.18 ms ± 43.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
# Vectorized PyTorch Markov sampler (order 1) with the same vocab / batch / sequence sizes as the JAX run.
torch_args = MarkovSamplerConfig(vocab_size=vocab_size, batch_size=batch_size, seq_len=seq_len, order=1)
torch_ds = MarkovSampler(torch_args)
%timeit torch_ds.generate()

7.43 ms ± 147 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [6]:
# NOTE: rebinds the global `dag` (previously the JAX array) to a torch tensor with the
# same bigram structure; the TorchScript cells below use this torch version.
dag = torch.arange(seq_len)-1
BigramTorch = InContextTreeTorch(vocab_size, dag)
%timeit BigramTorch.sample_batch(batch_size)

7.71 ms ± 266 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Torch with jit

We can also try TorchScript (`torch.jit.script`) to optimize the PyTorch implementation.

In [7]:
from typing import List
import torch

def get_stationary(pi: torch.Tensor) -> torch.Tensor:
    """Stationary distributions for a batch of row-stochastic transition matrices.

    Args:
        pi: (batch_size, vocab_size, vocab_size) tensor; row ``pi[b, i]`` is the
            next-state distribution from state ``i``.

    Returns:
        (batch_size, vocab_size) tensor ``mu`` with ``mu[b] @ pi[b] ≈ mu[b]`` and
        each row summing to 1.

    This function carries no ``@torch.jit.script`` decorator. TorchScript still
    compiles it recursively when a scripted function such as
    ``markov_chain_sample_batch`` calls it.

    NOTE: this reuses the global name of the JAX ``get_stationary`` defined in an
    earlier cell and shadows it once this cell runs.
    """
    batch_size, vocab_size, _ = pi.shape
    identity = torch.eye(vocab_size, device=pi.device)
    # mu^T pi = mu^T  <=>  (pi^T - I) mu = 0. So mu spans the null space of
    # pi^T - I: the right singular vector of the smallest singular value, which
    # is the last row of Vh. The (V, V) identity broadcasts over the batch.
    _, _, vh = torch.linalg.svd(pi.transpose(1, 2) - identity)
    # Singular-vector signs are arbitrary; abs() also removes tiny negative noise.
    mu = vh[:, -1, :].abs()
    return mu / mu.sum(dim=1, keepdim=True)


@torch.jit.script
def sample_dirichlet(alpha: torch.Tensor, size: List[int]) -> torch.Tensor:
    """Draw Dirichlet samples by normalizing independent Gamma draws.

    ``torch.distributions.Dirichlet`` cannot be used under TorchScript, so this
    relies on the standard construction: if ``g_j ~ Gamma(alpha_j, 1)``
    independently, then ``g / sum(g) ~ Dirichlet(alpha)``.

    Args:
        alpha (torch.Tensor): concentration parameters, expandable to ``size``
            (e.g. shape ``(num_states,)``).
        size (List[int]): output shape, e.g. ``[batch_size, num_states, num_states]``;
            the last axis is the simplex dimension.

    Returns:
        torch.Tensor: tensor of shape ``size`` whose last-axis slices sum to 1.
    """
    concentration = alpha.expand(size)
    gammas = torch._standard_gamma(concentration)
    normalizer = gammas.sum(dim=-1, keepdim=True)
    return gammas / normalizer



@torch.jit.script
def markov_chain_sample_batch(num_states: int, batch_size: int, dag: torch.Tensor,
                              device: str, alpha: float = 0.1) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Sample a batch of in-context Markov-chain sequences (TorchScript version).

    Each batch element gets its own transition matrix, with rows drawn from
    Dirichlet(alpha). Root positions (``dag[t] == -1``) are drawn from that
    matrix's stationary distribution, and child positions from the row of their
    parent's token. The final position holds a uniformly drawn test token.

    Args:
        num_states (int): Number of states (vocabulary size) of the Markov chain.
        batch_size (int): Number of sequences to sample.
        dag (torch.Tensor): 1-D integer tensor; ``dag[t]`` is the parent position
            of position ``t``, or -1 for a root. The sequence length is ``len(dag) + 1``.
        device (str): Device to perform computations on ('cpu' or 'cuda').
        alpha (float): Dirichlet prior concentration parameter for every transition row.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            samples: (batch_size, len(dag) + 1) long tensor; the last column is the test token.
            target_probs: (batch_size, num_states) true next-token distribution after the test token.
    """
    # Per-sequence transition matrices, shape (batch_size, num_states, num_states).
    prior = alpha * torch.ones(num_states, device=device)
    pi = sample_dirichlet(prior, size=[batch_size, num_states, num_states])

    # Stationary distribution of each matrix, shape (batch_size, num_states).
    mu = get_stationary(pi)

    seq_len = len(dag) + 1
    # Loop-invariant batch index, used to pick one transition row per sequence.
    rows = torch.arange(batch_size, device=device)

    samples = torch.zeros((batch_size, seq_len), dtype=torch.long, device=device)

    for t in range(seq_len - 1):
        if dag[t] == -1:  # Root node: draw from the stationary distribution.
            p = mu
        else:  # Child node: draw from the transition row of the parent's token.
            parent_tokens = samples[:, dag[t]]  # (batch_size,)
            p = pi[rows, parent_tokens]  # (batch_size, num_states)

        # squeeze(1) keeps shape (batch_size,) even when batch_size == 1.
        samples[:, t] = torch.multinomial(p, num_samples=1).squeeze(1)

    # Last position: a uniformly drawn test token and the true distribution that follows it.
    test_tokens = torch.randint(num_states, (batch_size,), device=device)
    samples[:, -1] = test_tokens
    target_probs = pi[rows, test_tokens]
    return samples, target_probs

In [8]:
# Sample a batch of sequences (warm-up call). The second argument is batch_size, not
# seq_len, so the batch matches the JAX and PyTorch benchmarks above; the sequence
# length is implied by `dag` (len(dag) + 1). Expects `dag` to be a torch tensor.
samples_batch = markov_chain_sample_batch(vocab_size, batch_size, dag, "cpu")

In [9]:
%timeit markov_chain_sample_batch(vocab_size, seq_len, dag, "cpu")

16.7 ms ± 41.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


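It is even slower, so we do not pursue it further. Note, however, that the recorded timing above was produced with `seq_len` (256) passed as the `batch_size` argument. That batch is 4× larger than the one used in the JAX and PyTorch runs (64), so the comparison is not like-for-like.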

### Test Multiple Parents Implementation

In [7]:
vocab_size = 10
seq_len = 256
alpha = 1
batch_size = 64
# Nodes 0 and 1 are roots; node i (i >= 2) has two parents, (i - 1) // 2 and i - 1.
# The final entry (index seq_len) is the query node, whose next-token distribution is returned.
dag = [[], []] + [[(i - 1) // 2, i - 1] for i in range(2, seq_len + 1)]

problem = InContextDAG(vocab_size=vocab_size, dag=dag, alpha=alpha)
sample_fn = jit(lambda k: vmap(problem.sample)(jr.split(k, batch_size)))
rng = RNG(0)
# Warm-up call: compiles sample_fn before it is timed.
batch, p = sample_fn(rng.next())

In [8]:
%timeit sample_fn(rng.next())

25.3 ms ± 109 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
# PyTorch counterpart built from the same DAG, vocabulary size and Dirichlet prior.
problemTorch = InContextDAGTorch(vocab_size=vocab_size, dag=dag, alpha=alpha)

For DAGs with multiple parents, our PyTorch implementation is much faster than the JAX version.

In [11]:
# Same batch size as the JAX DAG benchmark above.
%timeit problemTorch.sample_batch(batch_size)

10 ms ± 27.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
%%timeit
# Baseline: build length-`order` sliding windows by slicing and stacking (copies the data).
# Produces seq_len - order windows, starting at t = 0 .. seq_len - order - 1.
# IPython times this body inside a function, so these assignments do not overwrite the
# notebook-level batch_size / seq_len. The torch.randn call is included in the timing.
batch_size = 2
seq_len = 6
order = 2
batch = torch.randn((batch_size, seq_len))
states = torch.stack([batch[:, t:t + order] for t in range(seq_len - order)], dim=1)

7.24 μs ± 73.3 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [17]:
%%timeit
# Same (batch_size, seq_len - order, order) windows via torch.as_strided: both the window
# axis and the within-window axis step by one column (batch.stride(1)). The result is a
# view sharing storage with `batch` rather than a copy, which is likely why it is faster.

batch_size = 2
seq_len = 6
order = 2
batch = torch.randn((batch_size, seq_len))

states = torch.as_strided(batch, 
                 size=(batch_size, seq_len - order, order), 
                 stride=(batch.stride(0), batch.stride(1), batch.stride(1)))

1.85 μs ± 37.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
