In [1]:
import timeit
import numpy as np

In [2]:
# Define two small square matrices shared by the benchmark helpers below.
m1 = np.array([[1, 2], [3, 4]])
m2 = np.array([[5, 6], [7, 8]])

def f_dot():
  """Return the matrix product of the module-level m1 and m2 via np.dot."""
  return np.dot(m1, m2)

def f_matmul():
  """Return the matrix product of the module-level m1 and m2 via np.matmul."""
  return np.matmul(m1, m2)

def f_einsum():
  """Return the matrix product of the module-level m1 and m2 via np.einsum."""
  return np.einsum('ij,jk->ik', m1, m2)

# Number of calls timeit makes to each benchmark function.
n = 100000

# Each line prints the total time, in seconds, that timeit.timeit measured for n calls.
print(f'matmul: {timeit.timeit(f_matmul, number = n)}')
print(f'dot: {timeit.timeit(f_dot, number = n)}')
print(f'einsum: {timeit.timeit(f_einsum, number = n)}')


matmul: 0.07955529098398983
dot: 0.05391174997203052
einsum: 0.14371120801661164
matmul: 0.07955529098398983
dot: 0.05391174997203052
einsum: 0.14371120801661164


In [3]:
# Define two matrices: m1 is 2x3 and m2 is 3x3, so every product below is 2x3.
m1 = np.array([[1, 2, 3], [3, 4, 5]])
m2 = np.array([[5, 6, 7], [7, 8, 9], [9, 10, 11]])
def f_dot():
  """Return the matrix product of the module-level m1 and m2 via np.dot."""
  return np.dot(m1, m2)
def f_matmul():
  """Return the matrix product of the module-level m1 and m2 via np.matmul."""
  return np.matmul(m1, m2)
def f_einsum():
  """Return the matrix product of the module-level m1 and m2 via np.einsum."""
  return np.einsum('ij,jk -> ik', m1, m2)
def f_at():
  """Return the matrix product of the module-level m1 and m2 via the @ operator."""
  return m1 @ m2

# Number of calls timeit makes to each benchmark function.
n = 1000000

# Each line prints the total time, in seconds, that timeit.timeit measured for n calls.
print(f'matmul: {timeit.timeit(f_matmul, number = n)}')
print(f'dot: {timeit.timeit(f_dot, number = n)}')
print(f'einsum: {timeit.timeit(f_einsum, number = n)}')
print(f'at: {timeit.timeit(f_at, number = n)}')

# Check if the outputs are identical.
# The products are computed here so the comparison is over real arrays: a helper
# that returns None would make np.array_equal(None, None) report True vacuously.
output_dot = np.dot(m1, m2)
output_matmul = np.matmul(m1, m2)
output_einsum = np.einsum('ij,jk -> ik', m1, m2)
output_at = m1 @ m2

# Compare every variant (including @) against the np.dot result.
outputs_identical = all(
  np.array_equal(output_dot, other)
  for other in (output_matmul, output_einsum, output_at)
)
print(f"Are the outputs identical? {outputs_identical}")


matmul: 0.6809974589850754
matmul: 0.6809974589850754


dot: 0.5357569169718772
dot: 0.5357569169718772


einsum: 1.4327294999966398
einsum: 1.4327294999966398


at: 0.638580167084001
Are the outputs identical? True
at: 0.638580167084001
Are the outputs identical? True


In [4]:
import time
import random

def demo_perf_counter():
    """Print a five-part walkthrough of measuring elapsed time with time.perf_counter().

    The parts are: timing an inline expression, timing one function call,
    averaging over repeated calls, timing between checkpoints, and printing the
    gaps between back-to-back perf_counter() readings. Everything is written to
    stdout; nothing is returned.
    """
    def slow_function():
        time.sleep(0.1)  # Simulate some work

    print("Demo of time.perf_counter()")
    print("============================")

    # Part 1: bracket an inline computation with two counter readings.
    print("\n1. Basic timing of a simple operation:")
    start = time.perf_counter()
    sum([i**2 for i in range(10000)])  # throwaway workload; the result is discarded
    end = time.perf_counter()
    print(f"Time taken: {end - start:.6f} seconds")

    # Part 2: same pattern around a single function call.
    print("\n2. Timing a function:")
    start = time.perf_counter()
    slow_function()
    end = time.perf_counter()
    print(f"slow_function() took {end - start:.6f} seconds")

    # Part 3: collect one duration per call, then average them.
    print("\n3. Timing multiple iterations:")
    iterations = 5
    durations = []
    for _ in range(iterations):
        start = time.perf_counter()
        slow_function()
        end = time.perf_counter()
        durations.append(end - start)
    total_time = sum(durations)
    print(f"Average time over {iterations} iterations: {total_time/iterations:.6f} seconds")

    # Part 4: take a reading, then one more after each randomly sized pause.
    print("\n4. Measuring time between checkpoints:")
    marks = [time.perf_counter()]
    for low, high in ((0.1, 0.3), (0.2, 0.4), (0.1, 0.2)):
        time.sleep(random.uniform(low, high))
        marks.append(time.perf_counter())
    start, checkpoint1, checkpoint2, end = marks

    print(f"Time to checkpoint 1: {checkpoint1 - start:.6f} seconds")
    print(f"Time between checkpoints 1 and 2: {checkpoint2 - checkpoint1:.6f} seconds")
    print(f"Time from checkpoint 2 to end: {end - checkpoint2:.6f} seconds")
    print(f"Total time: {end - start:.6f} seconds")

    # Part 5: five back-to-back readings, then the gap between each adjacent pair.
    print("\n5. Demonstrating monotonicity:")
    times = []
    for _ in range(5):
        times.append(time.perf_counter())
    for i in range(1, len(times)):
        print(f"Difference between consecutive calls: {times[i] - times[i-1]:.9f} seconds")

# Run the demo. Its time.sleep calls add up to roughly 1.0-1.5 seconds of waiting.
demo_perf_counter()


Demo of time.perf_counter()

1. Basic timing of a simple operation:
Time taken: 0.000663 seconds

2. Timing a function:
slow_function() took 0.105041 seconds

3. Timing multiple iterations:
Demo of time.perf_counter()

1. Basic timing of a simple operation:
Time taken: 0.000663 seconds

2. Timing a function:
slow_function() took 0.105041 seconds

3. Timing multiple iterations:


Average time over 5 iterations: 0.102664 seconds

4. Measuring time between checkpoints:
Average time over 5 iterations: 0.102664 seconds

4. Measuring time between checkpoints:


Time to checkpoint 1: 0.104817 seconds
Time between checkpoints 1 and 2: 0.299447 seconds
Time from checkpoint 2 to end: 0.139653 seconds
Total time: 0.543917 seconds

5. Demonstrating monotonicity:
Difference between consecutive calls: 0.000000833 seconds
Difference between consecutive calls: 0.000000125 seconds
Difference between consecutive calls: 0.000000167 seconds
Difference between consecutive calls: 0.000000167 seconds
Time to checkpoint 1: 0.104817 seconds
Time between checkpoints 1 and 2: 0.299447 seconds
Time from checkpoint 2 to end: 0.139653 seconds
Total time: 0.543917 seconds

5. Demonstrating monotonicity:
Difference between consecutive calls: 0.000000833 seconds
Difference between consecutive calls: 0.000000125 seconds
Difference between consecutive calls: 0.000000167 seconds
Difference between consecutive calls: 0.000000167 seconds


In [5]:
from typing import List, NamedTuple

import jax
import jax.numpy as jnp

class LayerWeights(NamedTuple):
  """Weights of one transformer layer, as read by attention() and the scan step in transformer().

  In this notebook every field is built with a leading n_layers axis so that
  jax.lax.scan can hand one layer's slice to each step.
  """
  # Scale passed to norm() for the attention input.
  attn_norm: jax.Array
  # Scale passed to norm() for the feed-forward input.
  ffn_norm: jax.Array
  # Query projection, contracted in attention() as 'bld,dhk->blhk'.
  w_q_dhk: jax.Array
  # Key projection, contracted in attention() as 'bld,dhk->blhk'.
  w_k_dhk: jax.Array
  # Value projection, contracted in attention() as 'bld,dhk->blhk'.
  w_v_dhk: jax.Array
  # Attention output projection, contracted in attention() as 'blhk,hkd->bld'.
  w_o_hkd: jax.Array
  # Feed-forward weights: ffn() computes silu(x . w1) * (x . w3) and projects the product with w2.
  w1: jax.Array
  w2: jax.Array
  w3: jax.Array

class XfmrWeights(NamedTuple):
  """All parameters read by transformer(): embeddings, per-layer weights, final norm scale, output head."""
  # Embedding table; transformer() indexes it with the token ids.
  tok_embeddings: jax.Array
  # Per-layer weights that transformer() passes to jax.lax.scan.
  # NOTE(review): annotated as a list, but this notebook passes a single LayerWeights
  # whose arrays carry a leading n_layers axis - confirm which form is intended.
  layer_weights: List[LayerWeights]
  # Scale passed to norm() for the final hidden state.
  norm: jax.Array
  # Output projection; transformer() multiplies the hidden state by its transpose.
  output: jax.Array

def norm(x, w, eps: float = 1e-6):
  """Scale x by the reciprocal root-mean-square of its last axis, then multiply by w.

  Args:
    x: array to normalize; the mean of squares is taken over the last axis.
    w: scale multiplied onto the normalized values (must broadcast against x).
    eps: constant added to the mean of squares before the reciprocal square root.

  Returns:
    w * x / sqrt(mean(x**2, axis=-1) + eps), with the shape of the broadcast of w and x.
  """
  return w * (x * jax.lax.rsqrt(jax.lax.pow(x, 2).mean(-1, keepdims=True) + eps))

def attention(input_bld, params):
    """Causal multi-head self-attention applied to the normalized input.

    Name suffixes give each array's axes:
      B: batch size
      L: sequence length (query positions)
      M: memory length (key/value positions; equal to L here because this is self-attention)
      D: model dimension
      H: number of attention heads in a layer
      K: size of each attention key or value

    Args:
      input_bld: input activations with axes (B, L, D).
      params: object exposing attn_norm, w_q_dhk, w_k_dhk, w_v_dhk and w_o_hkd.

    Returns:
      Array with axes (B, L, D). Position l attends only to positions m <= l.
      The residual connection is left to the caller.
    """
    normalized_bld = norm(input_bld, params.attn_norm)
    query_blhk = jnp.einsum('bld,dhk->blhk', normalized_bld, params.w_q_dhk)
    # Keys and values are indexed by memory position m, distinct from the query position l.
    key_bmhk = jnp.einsum('bmd,dhk->bmhk', normalized_bld, params.w_k_dhk)
    value_bmhk = jnp.einsum('bmd,dhk->bmhk', normalized_bld, params.w_v_dhk)
    logits_bhlm = jnp.einsum('blhk,bmhk->bhlm', query_blhk, key_bmhk)
    _, l, _, k = query_blhk.shape
    logits_bhlm = logits_bhlm / jnp.sqrt(k)
    # Lower-triangular mask: True where query l may see memory m (m <= l).
    # Select with where() instead of subtracting inf * mask, because inf * 0 is NaN
    # and would poison every allowed position.
    allowed_lm = jnp.tril(jnp.ones((l, l), dtype=bool))
    logits_bhlm = jnp.where(allowed_lm[None, None, :, :], logits_bhlm, -jnp.inf)
    # Every row keeps its diagonal entry, so softmax never sees an all -inf row.
    weights_bhlm = jax.nn.softmax(logits_bhlm, axis=-1)
    # Contract over the memory axis m so each query position gets a weighted sum of values.
    wtd_values_blhk = jnp.einsum('bhlm,bmhk->blhk', weights_bhlm, value_bmhk)
    out_bld = jnp.einsum('blhk,hkd->bld', wtd_values_blhk, params.w_o_hkd)
    return out_bld

def ffn(x: jax.Array, w1: jax.Array, w2: jax.Array, w3: jax.Array) -> jax.Array:
  """SiLU-gated feed-forward: (silu(x . w1) * (x . w3)) . w2.

  Args:
    x: input activations.
    w1: weight of the branch passed through SiLU.
    w2: weight applied to the elementwise product of the two branches.
    w3: weight of the branch multiplied against the SiLU branch.

  Returns:
    The projected product of the two branches.
  """
  gate = jax.nn.silu(jnp.dot(x, w1))
  up = jnp.dot(x, w3)
  return jnp.dot(gate * up, w2)

def transformer(tokens: jax.Array, params: "XfmrWeights") -> jax.Array:
  """Run the model on token ids and return output logits.

  Args:
    tokens: integer token ids used to index params.tok_embeddings.
    params: model weights. params.layer_weights must hold arrays with a leading
      layer axis, because jax.lax.scan slices that axis once per layer.

  Returns:
    The final normalized hidden state multiplied by the transpose of params.output.
  """
  x = params.tok_embeddings[tokens]

  def scan_fn(h, layer_weights):
    # One layer: attention and feed-forward, each added back onto the running hidden state.
    h = h + attention(h, layer_weights)
    h = h + ffn(norm(h, layer_weights.ffn_norm), layer_weights.w1, layer_weights.w2, layer_weights.w3)
    return h, None

  # The hidden state is the scan carry; there is no per-layer output to collect.
  h, _ = jax.lax.scan(scan_fn, x, params.layer_weights)
  h = norm(h, params.norm)
  logits = jnp.dot(h, params.output.T)
  return logits

# Model sizes used for the shape check below.
vocab_size = 32_000
dim = 4_096
hidden_dim = 14_336
n_layers = 1
n_heads = 32
head_dim = dim // n_heads

# Per-layer weights, each with a leading n_layers axis for jax.lax.scan to iterate over.
# Norm scales are ones and every projection is zeros.
layer_weights = LayerWeights(
  attn_norm=jnp.ones((n_layers, dim)),
  ffn_norm=jnp.ones((n_layers, dim)),
  w_q_dhk=jnp.zeros((n_layers, dim, n_heads, head_dim)),
  w_k_dhk=jnp.zeros((n_layers, dim, n_heads, head_dim)),
  w_v_dhk=jnp.zeros((n_layers, dim, n_heads, head_dim)),
  w_o_hkd=jnp.zeros((n_layers, n_heads, head_dim, dim)),
  w1=jnp.zeros((n_layers, dim, hidden_dim)),
  w2=jnp.zeros((n_layers, hidden_dim, dim)),
  w3=jnp.zeros((n_layers, dim, hidden_dim)),
)

# Whole-model weights: embedding table, the stacked layers, final norm scale, output head.
params = XfmrWeights(
  tok_embeddings=jnp.ones((vocab_size, dim)),
  layer_weights=layer_weights,
  norm=jnp.ones((dim,)),
  output=jnp.ones((vocab_size, dim)),
)

# One batch of five token ids; print the shape of the resulting logits.
tokens = jnp.array([[123, 234, 234, 345, 446]])
out = transformer(tokens, params)
print(f'{out.shape=}')


out.shape=(1, 5, 32000)
out.shape=(1, 5, 32000)
