Skip to content

Commit

Permalink
Merge pull request #3 from mberr/ad-hoc-hasher
Browse files Browse the repository at this point in the history
Add simplified key hasher
  • Loading branch information
mberr committed May 6, 2022
2 parents 7604483 + a86487b commit 4879f7d
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
34 changes: 32 additions & 2 deletions src/torch_max_mem/api.py
Expand Up @@ -7,7 +7,7 @@
import inspect
import itertools
import logging
from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, TypeVar
from typing import Any, Callable, Collection, Mapping, MutableMapping, Optional, Tuple, TypeVar

import torch

Expand Down Expand Up @@ -168,6 +168,31 @@ def wrapper_maximize_memory_utilization(*args, **kwargs) -> Tuple[R, int]:
return decorator_maximize_memory_utilization


class KeyHasher:
    """A hasher based on (a subset of) keys."""

    def __init__(self, keys: Optional[Collection[str]]) -> None:
        """
        Initialize the hasher.

        :param keys:
            the keys whose associated values should be used for hashing
        """
        # Normalize None to an empty collection so __call__ can always iterate;
        # with no keys, every kwargs maps to the same hash (hash of empty tuple).
        self.keys = keys or []

    def __call__(self, kwargs: Mapping[str, Any]) -> int:
        """
        Calculate the hash based on the values associated with the selected keys.

        :param kwargs:
            the key-value dictionary

        :return:
            the hash of the tuple of values associated with the stored keys.
            Missing keys contribute None.
        """
        # Bug fix: the original used tuple(*(generator)), which unpacks the
        # generated values into tuple()'s argument list. tuple() accepts at
        # most one (iterable) argument, so that raised TypeError for any
        # number of keys other than one, and produced a wrong result for
        # exactly one key. Pass the generator itself to tuple() instead.
        return hash(tuple(kwargs.get(key, None) for key in self.keys))


class MemoryUtilizationMaximizer:
"""Stateful memory utilization maximizer."""

Expand All @@ -177,6 +202,7 @@ def __init__(
q: int = 32,
cpu_warning: bool = True,
hasher: Optional[Callable[[Mapping[str, Any]], int]] = None,
keys: Optional[str] = None,
) -> None:
"""
Initialize the stateful maximizer.
Expand All @@ -189,11 +215,15 @@ def __init__(
Whether to check the input for CPU tensors and warn about potential CPU OOM problems.
:param hasher:
a hashing function for separate parameter values depending on hash value; if None, use the same for all
:param keys:
the keys to use for creating a hasher. Only used if hasher is None.
"""
self.parameter_name = parameter_name
self.q = q
self.cpu_warning = cpu_warning
self.parameter_value: MutableMapping[int, int] = dict()
if hasher is None:
hasher = KeyHasher(keys=keys)
self.hasher = hasher

def __call__(self, func: Callable[..., R]) -> Callable[..., R]:
Expand All @@ -207,7 +237,7 @@ def __call__(self, func: Callable[..., R]) -> Callable[..., R]:
@functools.wraps(wrapped)
def inner(*args, **kwargs):
"""Evaluate function with the stored parameter size."""
h = 0 if self.hasher is None else self.hasher(kwargs)
h = self.hasher(kwargs)
kwargs[self.parameter_name] = self.parameter_value.get(h) or kwargs[self.parameter_name]
result, self.parameter_value[h] = wrapped(*args, **kwargs)
return result
Expand Down
2 changes: 1 addition & 1 deletion src/torch_max_mem/version.py
Expand Up @@ -38,4 +38,4 @@ def get_version(with_git_hash: bool = False):


if __name__ == "__main__":
print(get_version(with_git_hash=True)) # noqa:T001
print(get_version(with_git_hash=True)) # noqa:T201

0 comments on commit 4879f7d

Please sign in to comment.