In [1]:
from tqdm.auto import tqdm, trange
import jax.numpy as jnp
import numpy as np
import jax

In [2]:
import jax_smi
jax_smi.initialise_tracking()

In [3]:
input_size = 2**20
hidden_size = 2**14
w_chunk = 2**12
w_size = 2**20
k = 64
# n_dp, n_mp = 1, 4
n_dp, n_mp = 4, 1
input_size = input_size * n_dp

In [4]:
one_v4_chip_flops = 275 * 1e12  # https://cloud.google.com/tpu/docs/v4
v4_8_flops = one_v4_chip_flops * 4
chunk_flops = input_size * hidden_size * w_chunk * n_mp
task_flops = input_size * hidden_size * w_size
print("For chunk:", chunk_flops / v4_8_flops)
print("For task:", task_flops / v4_8_flops)

For chunk: 0.25588634246423275
For task: 65.50690367084358


In [5]:
devices = np.asarray(jax.local_devices()).reshape(n_dp, n_mp)
mesh = jax.sharding.Mesh(devices, ("dp", "mp"))

cpu = jax.devices("cpu")[0]
to_cpu = lambda x: jax.device_put(x, cpu)

In [6]:
weights_enc = []
weights_dec = []
with jax.default_device(cpu):
    key = jax.random.key(0)
    n_chunks = w_size // w_chunk // n_mp
    try:
        for _ in trange(n_chunks, postfix="Generating encoder weights..."):
            key, subkey = jax.random.split(key)
            scale = hidden_size ** -0.5 / 128
            weights_enc.append(jax.random.uniform(subkey, (hidden_size, w_chunk * n_mp), dtype=jnp.bfloat16, minval=-scale, maxval=scale))
            for _ in range(n_mp):
                key, subkey = jax.random.split(key)
                weights_dec.append(jax.random.uniform(subkey, (w_chunk, hidden_size), dtype=jnp.bfloat16, minval=-scale, maxval=scale))
    except KeyboardInterrupt:
        pass

  0%|          | 0/256 [00:00<?, ?it/s, Generating encoder weights...]

In [7]:
weight_dec = jax.jit(lambda x: jnp.concatenate(x, axis=0), backend="cpu")(weights_dec)

In [8]:
import gc


# key_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("dp"))
key_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
data_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("dp", None))
weight_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(None, "mp"))
topk_sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("dp", "mp"))
# single_mesh = jax.sharding.Mesh(devices[:1, :], ("dp", "mp"))
# single_sharding = jax.sharding.NamedSharding(single_mesh, jax.sharding.PartitionSpec(None, "mp"))

def split(key):
    a, b = jax.random.split(key)
    return a, b

def gen_inputs(input_key):
    # input_key, subkey = split(input_key)
    input_key, subkey = input_key + 1, input_key
    
    # inputs = jax.vmap(lambda x: jax.random.randint(x, (input_size // mesh.shape["dp"], hidden_size), -100, 100, jnp.int8))(subkey).reshape(input_size, hidden_size)
    # inputs = jax.random.randint(subkey, (input_size, hidden_size), -100, 100, jnp.int8)
    # inputs = jax.random.bits(subkey, shape=(input_size, hidden_size), dtype=jnp.uint8)
    inputs = jax.lax.broadcasted_iota(jnp.int4, (input_size, hidden_size), 0).astype(jnp.int8) + jax.lax.broadcasted_iota(jnp.int4, (input_size, hidden_size), 1).astype(jnp.int8) + subkey
    # inputs = jnp.ones((input_size, hidden_size), jnp.int8)
    return inputs, input_key
gen_inputs_jit = jax.jit(gen_inputs, in_shardings=(key_sharding,), out_shardings=(data_sharding, key_sharding), donate_argnums=(0,))

def matmul_topk(inputs, weight_chunk, old_weights, old_indices, offset):
    output_chunk = inputs @ weight_chunk
    # weights, indices = jax.lax.top_k(output_chunk, k=k)
    # weights, indices = jax.lax.approx_max_k(output_chunk, k=k)
    weights, indices = jax.lax.approx_max_k(output_chunk.astype(jnp.bfloat16), k=k, recall_target=0.8)
    indices += offset
    # weights, indices = jax.lax.approx_max_k(output_chunk, k=k, recall_target=0.5)
    # weights, indices = jax.lax.approx_max_k(output_chunk, k=k, recall_target=0.25)
    
    # _, indices = jax.lax.approx_max_k(output_chunk[:, :k * 4], k=k)
    # _, indices = jax.lax.approx_max_k(output_chunk[:, :k], k=k)
    # weights = output_chunk.reshape(output_chunk.shape[0], k, -1).sum(-1)
    # weights, indices = jax.lax.top_k(output_chunk, k=k)
    
    if old_weights is None or old_indices is None:
        return weights, indices
    else:
        # replace_mask = weights > old_weights
        # return jnp.where(replace_mask, weights, old_weights), jnp.where(replace_mask, indices, old_indices)

        new_full_weights = jnp.concatenate((weights, old_weights), axis=1)
        new_full_indices = jnp.concatenate((indices, old_indices), axis=1)
        
        # over_indices = jnp.argsort(new_full_weights)[:, -k:]
        # return jnp.take_along_axis(new_full_weights, over_indices, axis=1), jnp.take_along_axis(new_full_indices, over_indices, axis=1)
        
        _, over_indices = jax.lax.top_k(new_full_weights, k=k)
        return jnp.take_along_axis(new_full_weights, over_indices, axis=1), jnp.take_along_axis(new_full_indices, over_indices, axis=1)

matmul_topk_jit = jax.jit(matmul_topk, in_shardings=(data_sharding, weight_sharding, topk_sharding, topk_sharding, None), out_shardings=(topk_sharding, topk_sharding), donate_argnums=(2, 3))

# weights_0_put = jax.device_put(weights_enc[0], weight_sharding)
# send_weights = lambda _: weights_0_put
def send_weights(weights):
    # weights = jax.device_put(weights, single_sharding)
    return jax.device_put(weights, weight_sharding)

In [9]:
# input_key = jax.device_put(jax.random.split(jax.random.key(0), mesh.shape["dp"]), key_sharding)
# %time inputs, input_key = jax.block_until_ready(gen_inputs_jit(input_key))


# for idx in trange(3):
#     %time current_chunk = jax.block_until_ready(send_weights(weights_enc[idx]))
#     %time inputs.block_until_ready()
#     %time current_chunk.block_until_ready()
#     %time weights, indices = jax.block_until_ready(matmul_topk_jit(inputs, current_chunk))
# compiled = matmul_topk_jit.lower(inputs, current_chunk).compile()
# estimated_chunk_flops = compiled.cost_analysis()[0]['flops']
# print("Our flops estimate:", estimated_chunk_flops / v4_8_flops)
# print("Estimate from compiler:", compiled.cost_analysis()[0]["optimal_seconds"])
# for name in ("current_chunk", "next_chunk", "weights", "indices", "inputs"):
#     try:
#         del globals()[name]
#     except KeyError:
#         pass
#     gc.collect()

In [10]:
import time


def matmul_trial(inputs):
    current_chunk = send_weights(weights_enc[0])
    weights, indices = None, None
    bar = tqdm(weights_enc[1:60], postfix=f"Encoder forward pass")
    offset = 0
    for cpu_enc_chunk in bar:
        next_chunk = send_weights(cpu_enc_chunk)
        weights, indices = matmul_topk_jit(inputs, current_chunk, weights, indices, offset)
        offset += cpu_enc_chunk.shape[1]
        current_chunk = next_chunk
        gc.collect()
    before_compute = time.time()
    weights, indices = weights.block_until_ready(), indices.block_until_ready()
    after_compute = time.time()
    print("Compute time:", (after_compute - before_compute) + (bar.last_print_t - bar.start_t))
    weights, indices = to_cpu(weights), to_cpu(indices)
    after_to_cpu = time.time()
    print("To CPU time:", after_to_cpu - after_compute)
    return weights, indices


def trial():
    global encback
    gc.collect()
    # input_key = jax.device_put(jax.random.split(jax.random.key(0), mesh.shape["dp"]), key_sharding)
    # input_key = jax.device_put(jax.random.key(0), key_sharding)
    input_key = 0
    for _ in trange(100, postfix="Measuring forward speed..."):
        inputs, input_key = gen_inputs_jit(input_key)
        weights, indices = matmul_trial(inputs)
        encback = weights, indices
        del inputs
        gc.collect()
        break

trial()

  0%|          | 0/100 [00:00<?, ?it/s, Measuring forward speed...]

  0%|          | 0/59 [00:00<?, ?it/s, Encoder forward pass]

Compute time: 106.80990624427795
To CPU time: 41.33571815490723


In [11]:
# index_array_np = np.asarray(encback[1].ravel())
# weight_array_np = np.asarray(encback[0].ravel().view(jnp.float16))

In [12]:
# from tqdm import trange
# masks = []
# chunked_weights = []
# chunked_indices = []
# for i in trange(0, w_size, w_chunk):
#     mask = (i < index_array_np) & (index_array_np < (i + w_chunk))
#     chunked_indices.append(index_array_np[mask] - i)
#     chunked_weights.append(weight_array_np[mask])
#     masks.append(np.nonzero(mask)[0])

In [13]:
# from matplotlib import pyplot as plt
# plt.hist([x.shape[0] for x in chunked_indices])
# plt.show()

In [28]:
def sparse_matmul_inner(weights, indices, decoder_weight):
    out = weights @ decoder_weight[indices].astype(jnp.float32)
    return out


def sparse_matmul(weights, indices, decoder_weight):
    out = jax.vmap(sparse_matmul_inner, in_axes=(0, 0, None), out_axes=0)(weights.astype(jnp.float32), indices, decoder_weight)
    return out


sparse_matmul_jit = jax.jit(sparse_matmul, backend="cpu")

In [29]:
b_chunk = 2**12
for idx in trange(0, w_size, b_chunk):
    sparse_matmul_out = sparse_matmul_jit(*(u[idx:idx + b_chunk] for u in encback), weight_dec)
    del sparse_matmul_out
    gc.collect()
# %time sparse_matmul_out = sparse_matmul_jit(*encback, weight_dec)

  0%|          | 0/256 [00:00<?, ?it/s]

In [16]:
from ml_dtypes import bfloat16
to_bf = lambda x: np.asarray(x, dtype=bfloat16)
%time matmul_outputs = sparse_matmul(to_bf(encback[1]), encback[0], [to_bf(w) for w in weights_dec])

TypeError: Only integer scalar arrays can be converted to a scalar index.

In [17]:
encback[0].shape, encback[1].shape

((4194304, 64), (4194304, 64))

In [18]:
weights_dec[0].shape

(4096, 16384)