Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ markers =
rocm_only: marks tests that should only run on hosts with ROCm GPUs
darwin_only: marks tests that should only run on macOS
xpu_only: marks tests that should only run on hosts with Intel XPUs
token: enable tests that require a write token
22 changes: 12 additions & 10 deletions src/kernels/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def __hash__(self):
return hash((self.layer_name, self._repo_id, self._revision, self._version))

def __str__(self) -> str:
return f"`{self._repo_id}` (revision: {self._resolve_revision()}) for layer `{self.layer_name}`"
return f"`{self._repo_id}` (revision: {self._resolve_revision()}), layer `{self.layer_name}`"


class LocalLayerRepository:
Expand Down Expand Up @@ -372,7 +372,7 @@ def __hash__(self):
return hash((self.layer_name, self._repo_path, self._package_name))

def __str__(self) -> str:
return f"`{self._repo_path}` (package: {self._package_name}) for layer `{self.layer_name}`"
return f"`{self._repo_path}` (package: {self._package_name}), layer `{self.layer_name}`"


class LockedLayerRepository:
Expand Down Expand Up @@ -427,7 +427,7 @@ def __hash__(self):
return hash((self.layer_name, self._repo_id))

def __str__(self) -> str:
return f"`{self._repo_id}` (revision: {self._resolve_revision()}) for layer `{self.layer_name}`"
return f"`{self._repo_id}` (revision: {self._resolve_revision()}), layer `{self.layer_name}`"


_CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {}
Expand Down Expand Up @@ -1020,7 +1020,7 @@ def _get_kernel_layer(repo: LayerRepositoryProtocol) -> Type["nn.Module"]:
return layer


def _validate_layer(*, check_cls, cls):
def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
import torch.nn as nn

# The layer must have at least have the following properties: (1) it
Expand All @@ -1029,34 +1029,36 @@ def _validate_layer(*, check_cls, cls):
# methods.

if not issubclass(cls, nn.Module):
raise TypeError(f"Layer `{cls}` is not a Torch layer.")
raise TypeError(f"Layer `{cls.__name__}` is not a Torch layer.")

# We verify statelessness by checking that the does not have its own
# constructor (since the constructor could add member variables)...
if cls.__init__ is not nn.Module.__init__:
raise TypeError("Layer must not override nn.Module constructor.")
raise TypeError(f"{repo} must not override nn.Module constructor.")

# ... or predefined member variables.
torch_module_members = {name for name, _ in inspect.getmembers(nn.Module)}
cls_members = {name for name, _ in inspect.getmembers(cls)}
difference = cls_members - torch_module_members
# verify if : difference ⊄ {"can_torch_compile", "has_backward"}
if not difference <= {"can_torch_compile", "has_backward"}:
raise TypeError("Layer must not contain additional members.")
raise TypeError(
f"{repo} must not contain additional members compared to `{check_cls.__name__}`."
)

# Check whether the forward signatures are similar.
params = inspect.signature(cls.forward).parameters
ref_params = inspect.signature(check_cls.forward).parameters

if len(params) != len(ref_params):
raise TypeError(
"Forward signature does not match: different number of arguments."
f"Forward signature of {repo} does not match `{check_cls.__name__}`: different number of arguments."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Superb! This should be more than enough!

)

for param, ref_param in zip(params.values(), ref_params.values()):
if param.kind != ref_param.kind:
raise TypeError(
f"Forward signature does not match: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
f"Forward signature of {repo} does not match `{check_cls.__name__}`: different kind of arguments ({param} ({param.kind}) and {ref_param} ({ref_param.kind})"
)


Expand Down Expand Up @@ -1173,7 +1175,7 @@ def _get_layer_memoize(
return layer

layer = _get_kernel_layer(repo)
_validate_layer(check_cls=module_class, cls=layer)
_validate_layer(check_cls=module_class, cls=layer, repo=repo)
_CACHED_LAYER[repo] = layer

return layer
10 changes: 10 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@
)


def pytest_addoption(parser):
parser.addoption(
"--token",
action="store_true",
help="run tests that require a token with write permissions",
)


def pytest_runtest_setup(item):
if "cuda_only" in item.keywords and not has_cuda:
pytest.skip("skipping CUDA-only test on host without CUDA")
Expand All @@ -29,3 +37,5 @@ def pytest_runtest_setup(item):
pytest.skip("skipping macOS-only test on non-macOS platform")
if "xpu_only" in item.keywords and not has_xpu:
pytest.skip("skipping XPU-only test on host without XPU")
if "token" in item.keywords and not item.config.getoption("--token"):
pytest.skip("need --token option to run this test")
6 changes: 1 addition & 5 deletions tests/test_kernel_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,7 @@ def get_filenames_from_a_repo(repo_id: str) -> List[str]:
logging.error(f"Error connecting to the Hub: {e}.")


@pytest.mark.xfail(
condition=os.environ.get("GITHUB_ACTIONS") == "true",
reason="There is something weird when writing to the Hub from a GitHub CI.",
strict=True,
)
@pytest.mark.token
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use skip as discussed? 👀

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it was nicer if it's possible to run the test locally with --token,

def test_kernel_upload_deletes_as_expected():
repo_filenames = get_filenames_from_a_repo(REPO_ID)
filename_to_change = get_filename_to_change(repo_filenames)
Expand Down
33 changes: 25 additions & 8 deletions tests/test_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,26 +480,43 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.foo = 42

with pytest.raises(TypeError, match="not override"):
_validate_layer(cls=BadLayer, check_cls=SiluAndMul)
def stub_repo(layer):
return LayerRepository(
repo_id="kernels-test/nonexisting", layer_name=layer.__name__
)

with pytest.raises(
TypeError,
match="`kernels-test/nonexisting`.*layer `BadLayer` must not override",
):
_validate_layer(cls=BadLayer, check_cls=SiluAndMul, repo=stub_repo(BadLayer))

class BadLayer2(nn.Module):
foo: int = 42

with pytest.raises(TypeError, match="not contain additional members"):
_validate_layer(cls=BadLayer2, check_cls=SiluAndMul)
with pytest.raises(
TypeError,
match="`kernels-test/nonexisting`.*layer `BadLayer2` must not contain.*SiluAndMul",
):
_validate_layer(cls=BadLayer2, check_cls=SiluAndMul, repo=stub_repo(BadLayer2))

class BadLayer3(nn.Module):
def forward(self, x: torch.Tensor, foo: int) -> torch.Tensor: ...

with pytest.raises(TypeError, match="different number of arguments"):
_validate_layer(cls=BadLayer3, check_cls=SiluAndMul)
with pytest.raises(
TypeError,
match="Forward.*`kernels-test/nonexisting`.*layer `BadLayer3` does not match `SiluAndMul`: different number of arguments",
):
_validate_layer(cls=BadLayer3, check_cls=SiluAndMul, repo=stub_repo(BadLayer3))

class BadLayer4(nn.Module):
def forward(self, *, x: torch.Tensor) -> torch.Tensor: ...

with pytest.raises(TypeError, match="different kind of arguments"):
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul)
with pytest.raises(
TypeError,
match="Forward.*`kernels-test/nonexisting`.*layer `BadLayer4` does not match `SiluAndMul`: different kind of arguments",
):
_validate_layer(cls=BadLayer4, check_cls=SiluAndMul, repo=stub_repo(BadLayer4))


@pytest.mark.cuda_only
Expand Down
Loading