Intro to Pallas blog post #3
Conversation
> However, the stated FLOP/s does not tell the whole story, as not all operations can be performed at the same speed. The vast majority (often 90%+) of the peak FLOP/s on modern accelerators is restricted to matrix multiplications performed by the TensorCore (GPUs) or MXU (TPUs). Element-wise operations such as point-wise addition or multiplication are performed at a much slower rate by the vector ALU on the chip. Scalar FLOP/s are lower still, especially compared to CPUs. This means that if your workload is not dominated by matrix multiplications, you should never expect to utilize more than 10-20% of your accelerator's FLOP/s.
>
> *(diagram: high-level TPU chip architecture)*

Is there a diagram for v5e? Or can we say that the high-level design for v5e is similar? We are talking about v5e in this section and then show a diagram for a different generation.
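
To make the FLOP/s gap described in the quoted paragraph concrete, here is a rough microbenchmark sketch in JAX. The function names, matrix size, and timing scheme are illustrative only, and the elementwise add is memory-bandwidth-bound in practice, so its achieved FLOP/s lands far below even the vector unit's peak:

```python
import time

import jax
import jax.numpy as jnp


def achieved_flops(fn, flop_count, *args):
    """Very rough achieved-FLOP/s estimate for an already-jitted function."""
    fn(*args).block_until_ready()  # compile and warm up
    start = time.perf_counter()
    fn(*args).block_until_ready()
    return flop_count / (time.perf_counter() - start)


n = 4096
x = jnp.ones((n, n), dtype=jnp.bfloat16)

matmul = jax.jit(lambda a, b: a @ b)  # executed on the MXU / TensorCore
add = jax.jit(lambda a, b: a + b)     # executed on the vector ALU

# A matmul performs 2*n^3 FLOPs; an elementwise add performs n^2 FLOPs.
print(f"matmul: {achieved_flops(matmul, 2 * n**3, x, x):.3e} FLOP/s")
print(f"add:    {achieved_flops(add, n**2, x, x):.3e} FLOP/s")
```

On an accelerator, the matmul typically lands within a small factor of the advertised peak, while the elementwise add comes in orders of magnitude lower.
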
> Modern accelerators (such as TPUs and GPUs) typically contain multiple levels of memory, just like CPUs, each trading off speed vs. capacity.
>
> - **Registers** live closest to the processor and can usually be accessed within a single processor cycle. Each core of an accelerator has only a small number of them (typically in the 10-100 range, depending on the type of register).
> - **Caches** (SRAM) typically live on the same die as the processor and can be accessed with roughly 1-10x the latency of a register access. The amount of SRAM is usually in the tens to hundreds of MB.

I wonder if it is accurate to call SRAM "caches"? Is this how it's used on TPU?
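
For reference, on TPU the fast on-chip SRAM that Pallas works with is VMEM, which is software-managed (data is explicitly staged in and out) rather than an automatic hardware cache. As a minimal sketch, assuming the `pltpu.VMEM` scratch-shape API, this is how a kernel can request an explicit VMEM scratch buffer in addition to the input/output blocks that Pallas already stages into VMEM:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu


def scale_kernel(x_ref, o_ref, scratch_ref):
    # x_ref and o_ref already live in VMEM (the on-chip SRAM); scratch_ref is
    # an extra VMEM buffer requested explicitly via scratch_shapes below.
    scratch_ref[...] = x_ref[...] * 2.0
    o_ref[...] = scratch_ref[...] + 1.0


x = jnp.ones((256, 256), jnp.float32)
y = pl.pallas_call(
    scale_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    scratch_shapes=[pltpu.VMEM((256, 256), jnp.float32)],
)(x)
```
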
> 2. Here, we need to explicitly zero out `out_vmem_ref` on the first accumulation step of each output block, as Pallas does not guarantee any initialization for buffers.
>
> We construct a Pallas kernel using the `pl.kernel` entry point. Passing in `mesh=pltpu.create_tensorcore_mesh("x")` informs JAX that we wish to use the TPU tensorcore backend (as opposed to GPU or TPU sparsecore). On TPUs with megacore enabled, the axis name we passed in can be queried to obtain the core id via `lax.axis_index("x")`.

Nit: should we capitalize first letters in tensorcore and sparsecore?

Nit: using something like "core" instead of "x" might be slightly more clear. "x" also suggests that the axis might be used for partitioning the x input, which isn't the case in the kernel.
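
For context on the two points quoted above: the post uses the newer `pl.kernel` / `pltpu.create_tensorcore_mesh` entry point, whose full signature is not shown in this excerpt, so the sketch below illustrates the same accumulator-zeroing pattern with the long-standing `pl.pallas_call` API instead (block sizes and names are illustrative):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def matmul_kernel(x_ref, y_ref, out_ref):
    # The last grid axis walks the contraction (K) dimension. Pallas does not
    # initialize output buffers, so we zero the accumulator block ourselves on
    # the first K step and accumulate into it on the following steps.
    @pl.when(pl.program_id(2) == 0)
    def _():
        out_ref[...] = jnp.zeros_like(out_ref)

    out_ref[...] += jnp.dot(
        x_ref[...], y_ref[...], preferred_element_type=jnp.float32
    )


def matmul(x, y, bm=256, bk=256, bn=256):
    m, k = x.shape
    _, n = y.shape
    return pl.pallas_call(
        matmul_kernel,
        grid=(m // bm, n // bn, k // bk),
        in_specs=[
            pl.BlockSpec((bm, bk), lambda i, j, l: (i, l)),
            pl.BlockSpec((bk, bn), lambda i, j, l: (l, j)),
        ],
        out_specs=pl.BlockSpec((bm, bn), lambda i, j, l: (i, j)),
        out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32),
    )(x, y)


out = matmul(jnp.ones((1024, 1024), jnp.float32),
             jnp.ones((1024, 1024), jnp.float32))
```

The `pl.kernel`-based version in the post additionally queries `lax.axis_index("x")` to split work across TensorCores on megacore chips; that part is omitted here.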