In [None]:
from tqdm.auto import tqdm, trange
import jax.numpy as jnp
import numpy as np
import jax

In [None]:
# NOTE(review): presumably starts jax-smi's background tracking of accelerator memory
# usage for this process — confirm against the jax_smi documentation.
import jax_smi
jax_smi.initialise_tracking()

In [None]:
# Benchmark dimensions (all powers of two).
# Mesh extents as (data-parallel, model-parallel); the alternative layout tried was 1, 4.
n_dp, n_mp = 4, 1
# Number of top entries kept per input row.
k = 1 << 6
# Total number of weight rows, and how many of them are generated per chunk.
w_size = 1 << 20
w_chunk = 1 << 12
hidden_size = 1 << 14
# 2**18 input rows per data-parallel replica (2**19 + 2**18 was also tried).
input_size = (1 << 18) * n_dp

In [None]:
# Back-of-the-envelope lower bounds (in seconds) for the encoder matmul, assuming the
# hardware runs at its quoted peak throughput.
one_v4_chip_flops = 275 * 1e12  # https://cloud.google.com/tpu/docs/v4
v4_8_flops = one_v4_chip_flops * 4  # four chips
# A dense (M, K) @ (K, N) product costs 2 * M * K * N floating-point operations: each
# multiply-accumulate is one multiply plus one add, which is also how peak FLOPS figures
# are counted. Omitting the factor of two makes the estimate twice as optimistic.
chunk_flops = 2 * input_size * hidden_size * w_chunk * n_mp
task_flops = 2 * input_size * hidden_size * w_size
print("For chunk:", chunk_flops / v4_8_flops)
print("For task:", task_flops / v4_8_flops)

In [None]:
# Lay the local accelerator devices out as an (n_dp, n_mp) grid and name its axes.
devices = np.asarray(jax.local_devices()).reshape((n_dp, n_mp))
mesh = jax.sharding.Mesh(devices, ("dp", "mp"))

# First host CPU device; used for host-side staging.
cpu = jax.devices("cpu")[0]


def to_cpu(x):
    """Move `x` to the host CPU device."""
    return jax.device_put(x, cpu)

In [None]:
from functools import partial


# Build the (w_size, hidden_size) bfloat16 weight matrix on the host CPU, filling it one
# (w_chunk, hidden_size) slab at a time, then copy it so encoder and decoder each have
# their own buffer.
with jax.default_device(cpu):
    # Uninitialised uint16 storage reinterpreted as bfloat16; every slab that the loop
    # below reaches is overwritten with random values.
    weight_dec = jnp.asarray(np.empty((w_size, hidden_size), dtype=np.uint16)).view(jnp.bfloat16)
    @partial(jax.jit, device=cpu, donate_argnums=(0, 1, 2))
    def generate_weight(weight, key, counter, scale):
        """Write one uniform(-scale, scale) slab of shape (w_chunk, hidden_size) into `weight`.

        The slab is placed at row `counter`. Returns the updated weight, the carried-over
        half of the split key, and `counter + w_chunk`. The first three arguments are
        donated to the jitted call.
        """
        key, rng = jax.random.split(key)
        weight = jax.lax.dynamic_update_slice(weight, jax.random.uniform(rng, (w_chunk, hidden_size), dtype=jnp.bfloat16, minval=-scale, maxval=scale), (counter, 0))
        return weight, key, counter + w_chunk
    key = jax.random.key(0)
    counter = 0
    # The outer loop runs n_chunks times and the inner loop n_mp times, so
    # w_size // w_chunk slabs are written in total.
    n_chunks = w_size // w_chunk // n_mp
    try:
        for _ in trange(n_chunks, postfix="Generating weights..."):
            # `subkey` is not used below; this split only advances `key`.
            key, subkey = jax.random.split(key)
            scale = hidden_size ** -0.5
            for _ in range(n_mp):
                weight_dec, key, counter = generate_weight(weight_dec, key, counter, scale)
                # key, subkey = jax.random.split(key)
                # weights_dec.append(jax.random.uniform(subkey, (w_chunk, hidden_size), dtype=jnp.bfloat16, minval=-scale, maxval=scale))
    except KeyboardInterrupt:
        # Interrupting the kernel stops generation early; rows not yet written keep
        # whatever np.empty left in them.
        pass
    weight_enc = weight_dec.copy()

In [None]:
import gc


# Sharding layouts over the ("dp", "mp") mesh.

# Weight chunks: first axis unsplit, second axis split along "mp".
weight_sharding = jax.sharding.NamedSharding(
    mesh,
    jax.sharding.PartitionSpec(None, "mp"),
)
# Input batches: first axis split along "dp", second axis unsplit.
data_sharding = jax.sharding.NamedSharding(
    mesh,
    jax.sharding.PartitionSpec("dp", None),
)
# Top-k results: first axis along "dp", second axis along "mp".
topk_sharding = jax.sharding.NamedSharding(
    mesh,
    jax.sharding.PartitionSpec("dp", "mp"),
)
# Input key: empty partition spec (a per-"dp" split of the key was tried as an alternative).
key_sharding = jax.sharding.NamedSharding(
    mesh,
    jax.sharding.PartitionSpec(),
)

def split(key):
    """Split a PRNG key in two and return both halves as a plain 2-tuple."""
    halves = jax.random.split(key)
    return halves[0], halves[1]

def gen_inputs(input_key):
    """Produce a deterministic (input_size, hidden_size) int8 batch and the next key.

    The batch is built from two int4 iotas (one along each axis), each cast to int8,
    summed, and offset by `input_key`. The returned key is `input_key + 1`. Random
    generators (`randint`, `bits`) and an all-ones batch were tried as alternatives.
    """
    shape = (input_size, hidden_size)
    current_key, next_key = input_key, input_key + 1
    row_ids = jax.lax.broadcasted_iota(jnp.int4, shape, 0).astype(jnp.int8)
    col_ids = jax.lax.broadcasted_iota(jnp.int4, shape, 1).astype(jnp.int8)
    inputs = row_ids + col_ids + current_key
    return inputs, next_key


# Jitted variant: output batch laid out with data_sharding, key argument donated.
gen_inputs_jit = jax.jit(
    gen_inputs,
    in_shardings=(key_sharding,),
    out_shardings=(data_sharding, key_sharding),
    donate_argnums=(0,),
)

def matmul_topk(inputs, weight_chunk, old_weights, old_indices, offset):
    """Score `inputs` against one weight chunk and fold the result into a running top-k.

    Args:
        inputs: left-hand operand of the matmul.
        weight_chunk: right-hand operand of the matmul.
        old_weights, old_indices: running top-k values and indices from earlier chunks,
            or None for the first chunk.
        offset: value added to this chunk's local indices.

    Returns:
        (weights, indices) holding k entries per row.
    """
    scores = (inputs @ weight_chunk).astype(jnp.bfloat16)
    # Approximate top-k of this chunk (exact top_k and other recall targets were tried too).
    chunk_vals, chunk_idx = jax.lax.approx_max_k(scores, k=k, recall_target=0.8)
    chunk_vals = chunk_vals / 128
    chunk_idx = chunk_idx + offset

    # First chunk: nothing to merge with.
    if old_weights is None or old_indices is None:
        return chunk_vals, chunk_idx

    # Merge this chunk's candidates with the running ones, then keep the k best overall.
    merged_vals = jnp.concatenate((chunk_vals, old_weights), axis=1)
    merged_idx = jnp.concatenate((chunk_idx, old_indices), axis=1)
    _, keep = jax.lax.top_k(merged_vals, k=k)
    best_vals = jnp.take_along_axis(merged_vals, keep, axis=1)
    best_idx = jnp.take_along_axis(merged_idx, keep, axis=1)
    return best_vals, best_idx


# Jitted variant; the running top-k buffers (arguments 2 and 3) are donated.
matmul_topk_jit = jax.jit(
    matmul_topk,
    in_shardings=(data_sharding, weight_sharding, topk_sharding, topk_sharding, None),
    out_shardings=(topk_sharding, topk_sharding),
    donate_argnums=(2, 3),
)

# weights_0_put = jax.device_put(weights_enc[0], weight_sharding)
# send_weights = lambda _: weights_0_put
def send_weights(weights):
    """Place `weights` on the mesh using `weight_sharding` and return the placed array."""
    on_mesh = jax.device_put(weights, weight_sharding)
    return on_mesh

In [None]:
# input_key = jax.device_put(jax.random.split(jax.random.key(0), mesh.shape["dp"]), key_sharding)
# %time inputs, input_key = jax.block_until_ready(gen_inputs_jit(input_key))


# for idx in trange(3):
#     %time current_chunk = jax.block_until_ready(send_weights(weights_enc[idx]))
#     %time inputs.block_until_ready()
#     %time current_chunk.block_until_ready()
#     %time weights, indices = jax.block_until_ready(matmul_topk_jit(inputs, current_chunk))
# compiled = matmul_topk_jit.lower(inputs, current_chunk).compile()
# estimated_chunk_flops = compiled.cost_analysis()[0]['flops']
# print("Our flops estimate:", estimated_chunk_flops / v4_8_flops)
# print("Estimate from compiler:", compiled.cost_analysis()[0]["optimal_seconds"])
# for name in ("current_chunk", "next_chunk", "weights", "indices", "inputs"):
#     try:
#         del globals()[name]
#     except KeyError:
#         pass
#     gc.collect()

In [None]:
import splatmul


def sparse_matmul(weights, indices, decoder_weight):
    """Run `splatmul.splatmul` on NumPy views of the inputs and return a bfloat16 array.

    `weights` and `decoder_weight` are passed as their uint16 bit patterns and `indices`
    as uint32; the result is wrapped back into a JAX array and viewed as bfloat16.
    """
    def _bits(arr):
        # Reinterpret the 16-bit elements as uint16 and expose them as a NumPy array.
        return np.asarray(arr.view(jnp.uint16))

    with jax.default_device(cpu):
        weight_bits = _bits(weights)
        index_u32 = np.asarray(indices.astype(jnp.uint32))
        decoder_bits = _bits(decoder_weight)
        out_bits = splatmul.splatmul(weight_bits, index_u32, decoder_bits)
        return jnp.asarray(out_bits).view(jnp.bfloat16)


In [None]:
# %time result = sparse_matmul_jit(saved_encodings[0], saved_encodings[1], weight_dec)
# %time result.block_until_ready();

In [None]:
from threading import Thread
from queue import Queue

def sparse_matmul_async(weights, indices):
    """Start `sparse_matmul` against `weight_dec` on a background thread.

    Returns a Queue that receives the single result once the worker finishes;
    call `.get()` on it to wait for the value.
    """
    results = Queue()
    background = Thread(
        target=lambda: results.put(sparse_matmul(weights, indices, weight_dec))
    )
    background.start()
    return results

In [None]:
import threading
import queue
import time


def matmul_trial(inputs):
    """Run the encoder over every weight chunk and return the merged top-k.

    Chunks of `w_chunk * n_mp` rows are sliced from `weight_enc`, transposed and sent to
    the devices. The transfer of the following chunk is issued before the matmul on the
    current one is dispatched.

    Args:
        inputs: batch to encode, as produced by `gen_inputs_jit`.

    Returns:
        (weights, indices) as returned by the last `matmul_topk_jit` call.
    """
    step = w_chunk * n_mp

    def load_chunk(start):
        # Rows [start, start + step) of the encoder weights, transposed.
        return send_weights(weight_enc[start:start + step].T)

    # Previously this read the undefined name `weights_enc`; load the first chunk the same
    # way as every other chunk.
    current_chunk = load_chunk(0)
    weights, indices = None, None
    bar = trange(0, w_size, step, postfix="Encoder forward pass")
    for chunk_start in bar:
        # Prefetch the chunk AFTER the one being computed. Loading `chunk_start` itself
        # would process chunk 0 twice, skip the last chunk, and pair every later chunk
        # with an offset one step too large.
        next_start = chunk_start + step
        next_chunk = load_chunk(next_start) if next_start < w_size else None
        # The index offset of the current chunk is exactly its starting row.
        weights, indices = matmul_topk_jit(inputs, current_chunk, weights, indices, chunk_start)
        current_chunk = next_chunk
        gc.collect()
    return weights, indices


def trial(save_encodings: bool = False) -> None:
    """Time the encoder/decoder pipeline over up to 100 generated batches.

    Each iteration starts the CPU decoder for the previous batch's encodings in a
    background thread, runs the encoder on a fresh batch, then waits for and reports the
    decoder time, the previous-input transfer time, the encoder time and the
    encoder-to-CPU transfer time.

    Args:
        save_encodings: when True, store the first batch's (weights, indices) in the
            global `saved_encodings` and stop after that iteration.

    NOTE(review): with the default `save_encodings=False` the global `saved_encodings`
    is never assigned, so later cells that read it rely on an earlier
    `trial(save_encodings=True)` run.
    """
    global saved_encodings
    gc.collect()
    input_key = 0  # jax.random.split(jax.random.key(0), mesh.shape["dp"])
    encodings, past_inputs, decoder_outputs = None, None, None
    for _ in trange(100, postfix="Measuring forward speed..."):
        gc.collect()
        if encodings is not None:
            # Start decoding the previous batch on a background thread.
            decoder_outputs = sparse_matmul_async(encodings[0], encodings[1])
            # `inputs` still refers to the previous iteration's batch at this point.
            past_inputs = inputs
        inputs, input_key = gen_inputs_jit(input_key)
        weights, indices = matmul_trial(inputs)
        if decoder_outputs is not None:
            print("Computing decoder...")
            decoder_start = time.time()
            # Blocks until the background decoder thread has delivered its result.
            decoder_outputs = decoder_outputs.get().block_until_ready()
            decoder_end = time.time()
            print("Decoder time:", decoder_end - decoder_start)
            with jax.default_device(cpu):
                print(decoder_outputs.shape, decoder_outputs.mean(), decoder_outputs.std(), past_inputs.shape)
            print("Sending past input data to CPU...")
            past_input_start = time.time()
            past_inputs = jax.device_put(past_inputs, cpu).block_until_ready()
            print("Past input time:", time.time() - past_input_start)
            # Drop the references so the buffers can be reclaimed before the next iteration.
            past_inputs, decoder_outputs = None, None
            gc.collect()
        print("Waiting for encoder...")
        before_compute = time.time()
        weights, indices = weights.block_until_ready(), indices.block_until_ready()
        # wi = jax.block_until_ready((weights, indices))
        after_compute = time.time()
        # Only the time spent waiting from this point is measured, not the whole encoder run.
        print("Encoder time:", (after_compute - before_compute))
        weights, indices = to_cpu(weights), to_cpu(indices)
        # weights, indices = jax.block_until_ready(jax.device_put((wi), cpu))
        after_to_cpu = time.time()
        print("Encoder to CPU time:", after_to_cpu - after_compute)
        encodings = weights, indices
        if save_encodings:
            saved_encodings = encodings
            break

trial()

In [None]:
# 1/0

In [None]:
# NOTE(review): `saved_encodings` is only assigned by `trial(save_encodings=True)`; after
# a plain `trial()` on a fresh kernel this line raises NameError — confirm intended usage.
weights, indices = saved_encodings

In [None]:
# # sort on CPU is slow
# import jax.experimental.sparse


# def sort_encodings(weights, indices):
#     index_array = indices.ravel().astype(jnp.uint32)
#     weight_array = weights.ravel().astype(jnp.bfloat16)
#     indices_to_sort = jnp.argsort(index_array)
#     sorted_indices = index_array[indices_to_sort]
#     sorted_weights = weight_array[indices_to_sort]
#     inverse_indices = jnp.argsort(indices_to_sort)
#     bcoo = jax.experimental.sparse.BCOO((sorted_weights, jnp.stack((inverse_indices, sorted_indices), axis=-1).reshape(-1, 2)), shape=weights.shape)
#     return bcoo

# sort_encodings_jit = jax.jit(sort_encodings, backend="cpu")
# %time bcoo = sort_encodings_jit(weights, indices).block_until_ready();

In [None]:
# # Numba doesn't support bfloat16
# import ml_dtypes
# import numba as nb

# index_array_np = np.asarray(indices.ravel().astype(np.uint32))
# weight_array_np = np.asarray(weights.ravel().view(jnp.float16))
# %time indices_to_sort = np.argsort(index_array_np);
# %time inverse_indices = np.argsort(indices_to_sort);
# %time sorted_indices = index_array_np[indices_to_sort];
# %time sorted_weights = weight_array_np[indices_to_sort];

# @nb.njit
# def coo_matmul(inverse_indices, sorted_indices, sorted_weights, weights_dec):
#     N = inverse_indices.shape[0] // k
#     D = weights_dec.shape[1]
#     out = np.zeros((N, D), dtype=ml_dtypes.bfloat16)
#     for i in range(len(inverse_indices)):
#         out[inverse_indices[i] // k] += (sorted_weights[i].view(ml_dtypes.bfloat16).astype(np.float32) * weights_dec[sorted_indices[i]].view(ml_dtypes.bfloat16).astype(np.float32))
#     return out

# %time coo_matmul(inverse_indices, sorted_indices, np.asarray(sorted_weights.view(np.float16)), np.asarray(weight_dec.view(np.float16)));

In [None]:
# Time each step of flattening the encodings and sorting them by index on the host.
# The weights are only reinterpreted bit-for-bit as float16 (`view`, not `astype`), so
# `weight_array_np` holds raw 16-bit patterns rather than converted values.
%time index_array_np = np.asarray(indices.ravel().astype(np.uint32))
%time weight_array_np = np.asarray(weights.ravel().view(jnp.float16))
%time indices_to_sort = np.argsort(index_array_np);
# Argsort of the permutation gives its inverse: the sorted position of each original entry.
%time inverse_indices = np.argsort(indices_to_sort);
%time sorted_indices = index_array_np[indices_to_sort];
%time sorted_weights = weight_array_np[indices_to_sort];

In [None]:
# # Too slow
# import numba as nb

# @nb.njit
# def coo_matmul(inverse_indices, sorted_indices, sorted_weights, weights_dec):
#     N = inverse_indices.shape[0] // k
#     D = weights_dec.shape[1]
#     out = np.zeros((N, D), dtype=np.float32)
#     for i in range(len(inverse_indices)):
#         out[inverse_indices[i] // k] += (sorted_weights[i] * weights_dec[sorted_indices[i]])
#     return out

# %time result = coo_matmul(inverse_indices, sorted_indices, np.asarray(sorted_weights.view(jnp.bfloat16).astype(np.float32)), np.asarray(weight_dec.view(jnp.bfloat16).astype(np.float32)));

In [None]:
import numba as nb

@nb.njit
def to_bf16(x):
    """Convert a float32 array to bfloat16 bit patterns stored as uint16.

    The upper 16 bits of each float32 word are kept and rounded up by one unit when the
    highest discarded bit (bit 15) is set.

    Args:
        x: float32 array.

    Returns:
        uint16 array of the same shape holding the bfloat16 bit patterns.

    Note: rounding is round-half-up on the bit pattern with no special handling of NaN
    payloads; a pattern whose upper half is 0xFFFF and whose bit 15 is set wraps to 0
    in the final uint16 cast.
    """
    bits = x.view(np.uint32)
    # `+` binds tighter than `>>`, so the former `x >> 16 + (x & (1 << 15))` shifted by
    # 16 + 32768 whenever bit 15 was set. Parenthesise the shift and add the rounding bit.
    rounded = (bits >> 16) + ((bits >> 15) & 1)
    return rounded.astype(np.uint16)

@nb.njit
def from_bf16(x):
    """Expand 16-bit patterns (presumably bfloat16 stored as uint16) into float32 values."""
    # Move the 16 stored bits into the upper half of a 32-bit word.
    x = x.astype(np.uint32) << 16
    # NOTE(review): keeping every second float32 after the view suggests the shifted array
    # is 64 bits wide under Numba's typing, so each element yields two float32 words and
    # only one of them carries the value — confirm, including the byte-order assumption.
    return np.ascontiguousarray(x.view(np.float32)[..., ::2])
    # return x.view(np.float32)
    # return (x.astype(np.uint32) << 16).view(np.float32)

# Quick smoke checks of the two converters.
# 200.0 as float32 is 0x43480000, whose upper 16 bits are 0x4348 == 17224.
print(to_bf16(np.array([200], dtype=np.float32)))
# Small raw bit patterns; printed only to inspect the output shape and values.
print(from_bf16(np.array([[12, 13, 14], [15, 16, 17]], dtype=np.uint16)))

In [None]:
# # Too slow
# @nb.njit
# def coo_matmul(inverse_indices, sorted_indices, sorted_weights, weights_dec):
#     N = inverse_indices.shape[0] // k
#     D = weights_dec.shape[1]
#     out = np.zeros((N, D), dtype=np.uint16)
#     for i in range(len(inverse_indices)):
#         out[inverse_indices[i] // k] = to_bf16(from_bf16(out[inverse_indices[i] // k]) + (sorted_weights[i] * from_bf16(weights_dec[sorted_indices[i]])))
#     return out

# %time result = coo_matmul(inverse_indices, sorted_indices, np.asarray(sorted_weights.view(jnp.uint16)), np.asarray(weight_dec.view(jnp.uint16)));

In [None]:
# from numba_progress import ProgressBar
# Sets OPENBLAS_NUM_THREADS in the kernel's environment.
# NOTE(review): a library that has already initialised OpenBLAS may not pick this up — confirm.
%env OPENBLAS_NUM_THREADS=64

# @nb.njit(parallel=True)
# def sparse_matmul(indices, weights, weights_dec):
#     out = np.zeros((indices.shape[0], weights_dec.shape[1]), dtype=np.uint16)
#     for i in nb.prange(indices.shape[0]):
#         out[i] = to_bf16(from_bf16(weights[i]) @ from_bf16(weights_dec[indices[i]]))
#     return out


@nb.njit(nogil=True, parallel=True)
def sparse_matmul(indices, weights, weights_dec):
    """Decode every row as a weighted sum of the decoder rows it selected.

    For row i, the decoder rows `weights_dec[indices[i]]` are gathered, expanded with
    `from_bf16`, combined with `from_bf16(weights[i])` through a vector-matrix product,
    and the result is packed back with `to_bf16`.

    Args:
        indices: 2-D integer array; row i lists the decoder rows used by output row i.
        weights: 2-D uint16 array of bit patterns, same shape as `indices`.
        weights_dec: 2-D uint16 array of decoder bit patterns.

    Returns:
        uint16 array of shape (indices.shape[0], weights_dec.shape[1]).

    Note: this definition takes (indices, weights, weights_dec) and replaces the earlier
    notebook-level `sparse_matmul(weights, indices, decoder_weight)` once its cell runs.
    A chunked loop with a progress bar was tried as an alternative.
    """
    out = np.zeros((indices.shape[0], weights_dec.shape[1]), dtype=np.uint16)
    # Rows are independent, so they are spread across threads with prange.
    for i in nb.prange(indices.shape[0]):
        out[i] = to_bf16(from_bf16(weights[i]) @ from_bf16(weights_dec[indices[i]]))
    # The return had been commented out, which made the function yield None.
    return out

# Convert the arrays to plain NumPy buffers first (uint32 indices, raw uint16 bit
# patterns for both weight arrays), timing each conversion, then time the Numba kernel.
%time indices_np = np.asarray(indices.astype(np.uint32))
%time weights_np = np.asarray(weights.view(np.uint16))
%time weight_dec_np = np.asarray(weight_dec.view(np.uint16))
# with ProgressBar(total=indices_np.shape[0]) as bar:
#     result = sparse_matmul(indices_np, weights_np, weight_dec_np, bar)
%time result = sparse_matmul(indices_np, weights_np, weight_dec_np);

In [None]:
# %time bcoo = jax.experimental.sparse.BCOO((jnp.asarray(sorted_weights).view(jnp.bfloat16), jnp.stack((jnp.asarray(inverse_indices // k), jnp.asarray(sorted_indices)), axis=-1).reshape(-1, 2)), shape=(weights.shape[0], weight_dec.shape[0])).block_until_ready();

In [None]:
# import jax.experimental.sparse


# chunk_size = 2**12
# def sparse_matmul_cpu(row_indices, col_indices, weights, weight_dec):
#     def sparse_matmul_cpu_chunk(carry, chunk):
#         row_ind, col_ind, w = chunk
#         bcoo = jax.experimental.sparse.BCOO((w, jnp.stack((row_ind, col_ind), axis=-1).reshape(-1, 2)), shape=(weights.shape[0], weight_dec.shape[0]))
#         return carry + bcoo @ weight_dec, None
#     out = jnp.zeros((weights.shape[0], weight_dec.shape[1]), jnp.bfloat16)
#     return jax.lax.scan(sparse_matmul_cpu_chunk, out, tuple(x.reshape(-1, chunk_size) for x in (row_indices, col_indices, weights)))[0]

# sparse_matmul_cpu_jit = jax.jit(sparse_matmul_cpu, device=cpu)
# with jax.default_device(cpu):
#     %time result = sparse_matmul_cpu_jit(jnp.asarray(inverse_indices // k), jnp.asarray(sorted_indices), jnp.asarray(sorted_weights).view(jnp.bfloat16), weight_dec).block_until_ready();

In [None]:
# import jax.experimental.sparse


# chunk_size = 2**12
# def sparse_matmul_cpu(row_ind, col_ind, w, weight_dec):
#     bcoo = jax.experimental.sparse.BCOO((w, jnp.stack((row_ind, col_ind), axis=-1).reshape(-1, 2)), shape=(weights.shape[0], weight_dec.shape[0]))
#     return bcoo @ weight_dec

# sparse_matmul_cpu_jit = jax.jit(sparse_matmul_cpu, device=cpu)
# with jax.default_device(cpu):
#     out = 0
#     for i in trange(0, len(sorted_indices), chunk_size):
#         out += sparse_matmul_cpu_jit(jnp.asarray(inverse_indices[i:i+chunk_size] // k), jnp.asarray(sorted_indices[i:i+chunk_size]), jnp.asarray(sorted_weights[i:i+chunk_size]).view(jnp.bfloat16), weight_dec)

In [None]:
# from tqdm import trange
# masks = []
# chunked_weights = []
# chunked_indices = []
# for i in trange(0, w_size, w_chunk):
#     mask = (i < index_array_np) & (index_array_np < (i + w_chunk))
#     chunked_indices.append(index_array_np[mask] - i)
#     chunked_weights.append(weight_array_np[mask])
#     masks.append(np.nonzero(mask)[0])

In [None]:
# from matplotlib import pyplot as plt
# plt.hist([x.shape[0] for x in chunked_indices])
# plt.show()

In [None]:
# Run the decoder over slices of b_chunk rows, discarding each result immediately.
# NOTE(review): `sparse_matmul_jit` and `encback` are not defined in the visible notebook,
# so this cell fails on a fresh kernel unless they come from elsewhere — confirm.
# NOTE(review): the loop bound is w_size; verify it matches the leading dimension of the
# arrays in `encback`.
b_chunk = 2**12
for idx in trange(0, w_size, b_chunk):
    sparse_matmul_out = sparse_matmul_jit(*(u[idx:idx + b_chunk] for u in encback), weight_dec)
    del sparse_matmul_out
    gc.collect()
# %time sparse_matmul_out = sparse_matmul_jit(*encback, weight_dec)

In [None]:
# NOTE(review): `encback` is not defined in the visible notebook — confirm where it comes from.
encback[0].shape, encback[1].shape

In [None]:
# NOTE(review): `weights_dec` appears only in commented-out code in the visible notebook;
# the live array is named `weight_dec` — confirm which one is meant.
weights_dec[0].shape