Conversation

@justinjfu

No description provided.


However, the stated FLOP/s does not tell the whole story, as not all operations can be performed at the same speed. The vast majority (often 90%+) of the peak FLOP/s on modern accelerators is restricted to matrix multiplications performed by the TensorCore (GPUs) or MXU (TPUs). Element-wise operations such as point-wise addition or multiplication are performed at a much slower rate by the vector ALU on the chip, and scalar throughput is lower still, especially compared to CPUs. This means that if your workload is not dominated by matrix multiplications, you should not expect to utilize more than 10-20% of your accelerator's peak FLOP/s.
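A quick way to see this gap is to time a large matrix multiplication against an element-wise add on the same arrays and compare the achieved FLOP/s. The sketch below is only a rough, single-run measurement (no proper benchmarking harness), and the array size is an arbitrary choice for illustration:

```python
import time

import jax
import jax.numpy as jnp

N = 4096
x = jnp.ones((N, N), dtype=jnp.bfloat16)
y = jnp.ones((N, N), dtype=jnp.bfloat16)

matmul = jax.jit(lambda a, b: a @ b)
add = jax.jit(lambda a, b: a + b)

def achieved_flops(fn, flops, *args):
  fn(*args).block_until_ready()  # compile and warm up
  start = time.perf_counter()
  fn(*args).block_until_ready()
  return flops / (time.perf_counter() - start)

# A matmul performs ~2*N^3 FLOPs; an element-wise add performs N^2.
print(f"matmul: {achieved_flops(matmul, 2 * N**3, x, y) / 1e12:.2f} TFLOP/s")
print(f"add:    {achieved_flops(add, N**2, x, y) / 1e12:.4f} TFLOP/s")
```

On an accelerator, the matmul should land much closer to the quoted peak than the element-wise add, which is limited by the vector unit and by memory bandwidth.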

![TPUv4](https://docs.cloud.google.com/static/tpu/docs/images/tpu-v4-layout.png)


Is there a diagram for v5e? Or can we say that the high-level design of v5e is similar? We are talking about v5e in this section and then show a diagram for a different generation.

Modern accelerators (such as TPUs and GPUs) typically contain multiple levels of memory, just like CPUs, with each level trading off speed against capacity.

- **Registers** live closest to the processor and can usually be accessed within a single processor cycle. Each accelerator core has a very limited number of registers (typically in the 10-100 range, depending on the type of register).
- **Caches** (SRAM) typically live on the same die as the processor and can be accessed with roughly 1-10x the latency of a register access. The amount of SRAM usually falls in the 10s-100s of MB range.
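
On TPUs, this on-chip SRAM is exposed to Pallas kernels as VMEM, a memory space the kernel manages explicitly rather than a transparent hardware cache. The sketch below requests a VMEM scratch buffer, assuming the `pltpu.VMEM` scratch-shape API and an arbitrary 128x128 block:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def scale_kernel(x_ref, o_ref, scratch_ref):
  # scratch_ref is a VMEM (on-chip SRAM) buffer that the kernel manages
  # explicitly; nothing is cached on our behalf behind the scenes.
  scratch_ref[...] = x_ref[...] * 2.0
  o_ref[...] = scratch_ref[...]

x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128)
out = pl.pallas_call(
    scale_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[pltpu.VMEM((128, 128), jnp.float32)],
)(x)
```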


I wonder if it is accurate to call SRAM "caches"? Is this how it's used on TPU?

2. Here, we need to explicitly zero out `out_vmem_ref` on the first accumulation step
of each output block, as Pallas does not guarantee any initialization of buffers.

We construct a Pallas kernel using the `pl.kernel` entry point. Passing `mesh=pltpu.create_tensorcore_mesh("x")` informs JAX that we wish to use the TPU tensorcore backend (as opposed to GPU or TPU sparsecore). On TPUs with megacore enabled, the axis name we passed in can be used to obtain the core id via `lax.axis_index("x")`.


Nit: should we capitalize first letters in tensorcore and sparsecore?


Nit: using something like "core" instead of "x" might be slightly clearer. "x" also suggests that the axis might be used for partitioning the x input, which isn't the case in the kernel.
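
For reference, the zero-initialization pattern described in item 2 above looks like the following in a plain blocked matmul written with `pl.pallas_call`. This is only a sketch of the pattern, not the snippet under review: the block sizes are arbitrary and it uses the classic `pallas_call` entry point rather than `pl.kernel`:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def matmul_kernel(x_ref, y_ref, out_ref):
  # The last grid dimension iterates over K blocks; out_ref maps to the same
  # output block on every K step, so it must be zeroed on the first step.
  @pl.when(pl.program_id(2) == 0)
  def _():
    out_ref[...] = jnp.zeros_like(out_ref)
  out_ref[...] += jnp.dot(
      x_ref[...], y_ref[...], preferred_element_type=jnp.float32
  )

M = N = K = 512
bm = bn = bk = 128
x = jnp.ones((M, K), jnp.float32)
y = jnp.ones((K, N), jnp.float32)

out = pl.pallas_call(
    matmul_kernel,
    grid=(M // bm, N // bn, K // bk),
    in_specs=[
        pl.BlockSpec((bm, bk), lambda i, j, k: (i, k)),
        pl.BlockSpec((bk, bn), lambda i, j, k: (k, j)),
    ],
    out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
    out_shape=jax.ShapeDtypeStruct((M, N), jnp.float32),
)(x, y)
```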
