In [None]:
from tqdm.auto import tqdm, trange
import jax.numpy as jnp
import numpy as np
import jax

In [None]:
import jax_smi
# NOTE(review): presumably starts jax-smi's background device-memory tracking
# for this process — confirm against the jax_smi documentation.
jax_smi.initialise_tracking()

In [None]:
# Device mesh layout: (data-parallel, model-parallel) device counts.
# n_dp, n_mp = 1, 4
n_dp, n_mp = 4, 1

# Matmul dimensions: inputs are (input_size, hidden_size); encoder weights are
# generated as (hidden_size, w_chunk * n_mp) chunks until w_size columns are covered.
hidden_size = 1 << 14
w_chunk = 1 << 12
w_size = 1 << 20

# Number of top entries kept per input row.
k = 64

# 2**20 input rows per data-parallel shard.
input_size = (1 << 20) * n_dp

In [None]:
# Back-of-the-envelope lower bound on runtime: matmul FLOPs divided by peak throughput.
one_v4_chip_flops = 275 * 1e12  # https://cloud.google.com/tpu/docs/v4
v4_8_flops = one_v4_chip_flops * 4  # 4 chips, matching n_dp * n_mp in the config cell

# An (M, K) @ (K, N) matmul performs M*K*N multiply-accumulates, i.e. 2*M*K*N FLOPs.
# Peak-throughput figures count the multiply and the add separately, so the
# factor of 2 is required for the ratio below to be a time in seconds.
flops_per_mac = 2
chunk_flops = flops_per_mac * input_size * hidden_size * w_chunk * n_mp
task_flops = flops_per_mac * input_size * hidden_size * w_size
print("For chunk:", chunk_flops / v4_8_flops)
print("For task:", task_flops / v4_8_flops)

In [None]:
# Lay the local accelerators out on an (n_dp, n_mp) grid and name the two axes.
devices = np.array(jax.local_devices()).reshape((n_dp, n_mp))
mesh = jax.sharding.Mesh(devices, ("dp", "mp"))

# Host CPU device, used as the staging area for the encoder weights.
cpu = jax.devices("cpu")[0]


def to_cpu(x):
    """Place ``x`` on the host CPU device."""
    return jax.device_put(x, cpu)

In [None]:
# Host-side list of encoder weight chunks, each (hidden_size, w_chunk * n_mp) int8.
weights_enc = []
with jax.default_device(cpu):
    key = jax.random.key(0)
    n_chunks = w_size // w_chunk // n_mp
    try:
        for _ in trange(n_chunks, postfix="Generating encoder weights..."):
            key, subkey = jax.random.split(key)
            weights_enc.append(jax.random.randint(subkey, (hidden_size, w_chunk * n_mp), 0, 100, dtype=jnp.int8))
    except KeyboardInterrupt:
        # Interrupting is a supported way to run with fewer chunks, but say so
        # explicitly: later cells silently operate on whatever was generated.
        print(f"Interrupted: keeping {len(weights_enc)} of {n_chunks} encoder weight chunks.")

In [None]:
import gc


# One NamedSharding per array role, built from the mesh axes each dimension is split over:
#   key_sharding    -> ("dp",)       PRNG keys, one per data-parallel shard
#   data_sharding   -> ("dp", None)  inputs: rows split over dp, columns unsplit
#   weight_sharding -> (None, "mp")  weight chunks: columns split over mp
#   topk_sharding   -> ("dp", "mp")  running top-k values / indices
key_sharding, data_sharding, weight_sharding, topk_sharding = (
    jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec(*axes))
    for axes in (("dp",), ("dp", None), (None, "mp"), ("dp", "mp"))
)
# single_mesh = jax.sharding.Mesh(devices[:1, :], ("dp", "mp"))
# single_sharding = jax.sharding.NamedSharding(single_mesh, jax.sharding.PartitionSpec(None, "mp"))

def gen_inputs(input_key):
    """Advance the per-shard PRNG keys and build an all-ones int8 input batch.

    Args:
        input_key: array of PRNG keys, split element-wise via ``vmap``.

    Returns:
        ``(inputs, input_key)``: an ``(input_size, hidden_size)`` int8 array of
        ones and the advanced keys to feed into the next call.
    """
    # Split every key in two; out_axes=1 puts the "which half" axis first so the
    # result unpacks directly into (next keys, subkeys).
    split_keys = jax.vmap(jax.random.split, in_axes=0, out_axes=1)(input_key)
    next_key, subkey = split_keys
    # The subkey is only needed by the random-input variant:
    # inputs = jax.random.uniform(subkey, (input_size, hidden_size), jnp.bfloat16)
    batch = jnp.ones((input_size, hidden_size), jnp.int8)
    return batch, next_key


gen_inputs_jit = jax.jit(
    gen_inputs,
    in_shardings=(key_sharding,),
    out_shardings=(data_sharding, key_sharding),
    donate_argnums=(0,),
)

def matmul_topk(inputs, weight_chunk, old_weights, old_indices):
    """Multiply ``inputs`` by one weight chunk and fold the result into a running top-k.

    Args:
        inputs: left operand of the matmul.
        weight_chunk: right operand, one chunk of the encoder weights.
        old_weights: running top-k values from earlier chunks, or ``None`` for
            the first chunk.
        old_indices: indices paired with ``old_weights``, or ``None`` for the
            first chunk.

    Returns:
        ``(weights, indices)``: per row, the ``k`` largest values seen so far
        and the indices that accompany them.

    NOTE(review): the indices returned by ``approx_max_k`` are positions within
    this chunk's output; no per-chunk offset is added before merging, so merged
    indices from different chunks are not distinguishable — confirm whether
    callers need global column indices.
    NOTE(review): both operands are int8 elsewhere in this notebook, so the
    matmul result is presumably int8 as well and may wrap — confirm whether the
    benchmark cares about the values.
    """
    output_chunk = inputs @ weight_chunk

    # Approximate top-k over this chunk. Alternatives that were tried:
    #   jax.lax.top_k(output_chunk, k=k)
    #   jax.lax.approx_max_k(output_chunk, k=k)  (also recall_target=0.5 / 0.25)
    #   approx_max_k on a column slice (output_chunk[:, :k * 4] or [:, :k])
    #   output_chunk.reshape(output_chunk.shape[0], k, -1).sum(-1)
    weights, indices = jax.lax.approx_max_k(output_chunk.astype(jnp.bfloat16), k=k, recall_target=0.8)

    # ``None`` is resolved at trace time, so the first chunk compiles to a
    # variant with no merge step.
    if old_weights is None or old_indices is None:
        return weights, indices

    # Merge: pool the k new and k old candidates per row and keep the best k.
    # (Rejected variants: element-wise ``weights > old_weights`` replacement,
    # and ``jnp.argsort(...)[:, -k:]`` followed by take_along_axis.)
    new_full_weights = jnp.concatenate((weights, old_weights), axis=1)
    new_full_indices = jnp.concatenate((indices, old_indices), axis=1)

    # top_k already returns the selected values, so only the indices need a gather.
    merged_weights, over_indices = jax.lax.top_k(new_full_weights, k=k)
    merged_indices = jnp.take_along_axis(new_full_indices, over_indices, axis=1)
    return merged_weights, merged_indices


matmul_topk_jit = jax.jit(matmul_topk, in_shardings=(data_sharding, weight_sharding, topk_sharding, topk_sharding), out_shardings=(topk_sharding, topk_sharding), donate_argnums=(2, 3))

# weights_0_put = jax.device_put(weights_enc[0], weight_sharding)
# send_weights = lambda _: weights_0_put
def send_weights(weights):
    """Transfer one weight chunk to the devices using ``weight_sharding``."""
    # (A two-step transfer via `single_sharding` was tried and is disabled.)
    on_device = jax.device_put(weights, weight_sharding)
    return on_device

In [None]:
# input_key = jax.device_put(jax.random.split(jax.random.key(0), mesh.shape["dp"]), key_sharding)
# %time inputs, input_key = jax.block_until_ready(gen_inputs_jit(input_key))


# for idx in trange(3):
#     %time current_chunk = jax.block_until_ready(send_weights(weights_enc[idx]))
#     %time inputs.block_until_ready()
#     %time current_chunk.block_until_ready()
#     %time weights, indices = jax.block_until_ready(matmul_topk_jit(inputs, current_chunk))
# compiled = matmul_topk_jit.lower(inputs, current_chunk).compile()
# estimated_chunk_flops = compiled.cost_analysis()[0]['flops']
# print("Our flops estimate:", estimated_chunk_flops / v4_8_flops)
# print("Estimate from compiler:", compiled.cost_analysis()[0]["optimal_seconds"])
# for name in ("current_chunk", "next_chunk", "weights", "indices", "inputs"):
#     try:
#         del globals()[name]
#     except KeyError:
#         pass
#     gc.collect()

In [None]:
def matmul_trial(inputs, trial_index=None):
    """Run one encoder forward pass over every weight chunk, keeping a running top-k.

    The transfer of chunk ``i + 1`` is issued before the matmul on chunk ``i``
    is dispatched.

    Args:
        inputs: input batch passed to ``matmul_topk_jit`` for every chunk.
        trial_index: label shown in the progress bar. Defaults to the
            module-level ``trial`` loop variable when present, so existing
            ``matmul_trial(inputs)`` calls keep their progress-bar text.

    Returns:
        ``(weights, indices)`` after all chunks have been processed, with the
        computation finished.

    Raises:
        ValueError: if ``weights_enc`` is empty.
    """
    if not weights_enc:
        raise ValueError("weights_enc is empty; generate encoder weights before running a trial")
    if trial_index is None:
        # Fall back to the driver loop's variable instead of failing with a
        # NameError when called outside that loop.
        trial_index = globals().get("trial", "?")

    n_chunks = len(weights_enc)
    current_chunk = send_weights(weights_enc[0])
    weights, indices = None, None
    bar = trange(n_chunks, postfix=f"Trial {trial_index}, encoder forward pass")
    for chunk_idx in bar:
        # Prefetch the next chunk (if any) before computing on the current one.
        if chunk_idx + 1 < n_chunks:
            next_chunk = send_weights(weights_enc[chunk_idx + 1])
        else:
            next_chunk = None

        # Every chunk, including the last one, goes through the matmul; the
        # running top-k buffers are donated and replaced on each step.
        weights, indices = matmul_topk_jit(inputs, current_chunk, weights, indices)

        current_chunk = next_chunk
        gc.collect()

    # JAX dispatch is asynchronous: wait for the result so that wall-clock
    # timing of a trial covers the actual computation.
    return jax.block_until_ready((weights, indices))


# Release any `inputs` left over from a previous run of this cell before
# a new batch is generated.
gc.collect()
try:
    del inputs
except NameError:
    pass
else:
    gc.collect()

# One PRNG key per data-parallel shard.
input_key = jax.device_put(
    jax.random.split(jax.random.key(0), mesh.shape["dp"]),
    key_sharding,
)

# `trial` is read by matmul_trial for its progress-bar label.
for trial in trange(100, postfix="Measuring forward speed..."):
    inputs, input_key = gen_inputs_jit(input_key)
    matmul_trial(inputs)
    # Drop the batch before the next one is generated.
    del inputs
    gc.collect()

In [None]:
# Remove any of these names that exist at module level, collecting garbage
# after each one; names that were never defined are skipped.
for name in ("current_chunk", "next_chunk", "weights", "indices", "inputs"):
    globals().pop(name, None)
    gc.collect()