Skip to content

Commit

Permalink
[Pallas][Mosaic] Add support for nontrivial semaphore memrefs
Browse files Browse the repository at this point in the history
The previous patch simply changed the type we use to represent semaphores,
but didn't actually add support for any more operations. With this one,
semaphore memrefs can be allocated and (dynamically) indexed.

PiperOrigin-RevId: 597538913
  • Loading branch information
apaszke authored and jax authors committed Jan 11, 2024
1 parent 858fd52 commit ce00e10
Show file tree
Hide file tree
Showing 8 changed files with 345 additions and 143 deletions.
2 changes: 2 additions & 0 deletions jax/_src/pallas/mosaic/__init__.py
Expand Up @@ -19,6 +19,8 @@
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
from jax._src.pallas.mosaic.core import SemaphoreType
from jax._src.pallas.mosaic.core import TPUMemorySpace
from jax._src.pallas.mosaic.core import semaphore
from jax._src.pallas.mosaic.core import dma_semaphore
from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
from jax._src.pallas.mosaic.lowering import LoweringException
Expand Down
45 changes: 43 additions & 2 deletions jax/_src/pallas/mosaic/core.py
Expand Up @@ -22,6 +22,7 @@
from typing import Any

from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import state
from jax._src import tree_util
from jax._src import util
Expand Down Expand Up @@ -50,6 +51,7 @@ class TPUMemorySpace(enum.Enum):
VMEM = "vmem"
SMEM = "smem"
CMEM = "cmem"
SEMAPHORE = "semaphore_mem"

def __str__(self) -> str:
return self.value
Expand All @@ -58,14 +60,53 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
# A convenience function for constructing MemoryRef types.
return MemoryRef(shape, dtype, self)

class semaphore_dtype(dtypes.extended): pass
class semaphore(semaphore_dtype): pass
class dma_semaphore(semaphore_dtype): pass
class barrier_semaphore(semaphore_dtype): pass

class AbstractSemaphoreTy(dtypes.ExtendedDType):
name: str

def __repr__(self) -> str:
return self.name

def __eq__(self, other):
return self.__class__ == other.__class__

def __hash__(self) -> int:
return hash((self.__class__))

# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy

class SemaphoreTy(AbstractSemaphoreTy):
type = semaphore
name = "sem"

class DmaSemaphoreTy(AbstractSemaphoreTy):
type = dma_semaphore
name = "dma_sem"

class BarrierSemaphoreTy(AbstractSemaphoreTy):
type = barrier_semaphore
name = "barrier_sem"

class SemaphoreType(enum.Enum):
REGULAR = "regular"
DMA = "dma"
BARRIER = "barrier"

def get_aval(self) -> AbstractSemaphore:
return AbstractSemaphore(self)
def __call__(self, shape: tuple[int, ...]):
if self == SemaphoreType.DMA:
dtype = DmaSemaphoreTy()
elif self == SemaphoreType.BARRIER:
dtype = BarrierSemaphoreTy()
else:
dtype = SemaphoreTy()
return MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE)

def get_aval(self) -> "AbstractMemoryRef":
return self(()).get_aval()

class AbstractMemoryRef(state.AbstractRef):
__slots__ = ["inner_aval", "memory_space"]
Expand Down

0 comments on commit ce00e10

Please sign in to comment.