Merge pull request #34 from jcmgray/randomgreedy
lightweight random greedy optimization
jcmgray committed Apr 19, 2024
2 parents df29d54 + 00eddeb · commit f6250a0
Showing 7 changed files with 596 additions and 52 deletions.
24 changes: 21 additions & 3 deletions cotengra/__init__.py
@@ -1,5 +1,5 @@
-"""Hyper optimized contraction trees for large tensor networks and einsums.
-"""
+"""Hyper optimized contraction trees for large tensor networks and einsums."""
 
 try:
     # -- Distribution mode --
     # import from _version.py generated by setuptools_scm during release
@@ -8,9 +8,10 @@
     # -- Source mode --
     try:
         # use setuptools_scm to get the current version from src using git
-        from setuptools_scm import get_version as _gv
         from pathlib import Path as _Path
 
+        from setuptools_scm import get_version as _gv
+
         __version__ = _gv(_Path(__file__).parent.parent)
     except ImportError:
         # setuptools_scm is not available, use a default version
@@ -68,6 +69,7 @@
 from .pathfinders.path_basic import (
     GreedyOptimizer,
     OptimalOptimizer,
+    RandomGreedyOptimizer,
 )
 from .pathfinders.path_flowcutter import (
     FlowCutterOptimizer,
@@ -183,6 +185,7 @@
     "plot_trials",
     "QuasiRandOptimizer",
     "QuickBBOptimizer",
+    "RandomGreedyOptimizer",
     "register_preset",
     "ReusableHyperCompressedOptimizer",
     "ReusableHyperOptimizer",
@@ -200,6 +203,13 @@ def hyper_optimize(inputs, output, size_dict, memory_limit=None, **opts):
     return optimizer(inputs, output, size_dict, memory_limit)
 
 
+def random_greedy_optimize(
+    inputs, output, size_dict, memory_limit=None, **opts
+):
+    optimizer = RandomGreedyOptimizer(**opts)
+    return optimizer(inputs, output, size_dict)
+
+
 try:
     register_preset(
         "hyper",
@@ -244,6 +254,14 @@ def hyper_optimize(inputs, output, size_dict, memory_limit=None, **opts):
         "hyper-betweenness",
         functools.partial(hyper_optimize, methods=["betweenness"]),
     )
+    register_preset(
+        "random-greedy",
+        random_greedy_optimize,
+    )
+    register_preset(
+        "random-greedy-128",
+        functools.partial(random_greedy_optimize, max_repeats=128),
+    )
     register_preset(
         "flowcutter-2",
         functools.partial(optimize_flowcutter, max_time=2),
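With these presets registered, the new optimizer can be used either directly or by name. A minimal sketch, assuming a made-up three-tensor contraction (only RandomGreedyOptimizer, max_repeats, and the "random-greedy" preset name come from this change; array_contract_tree is the pre-existing cotengra interface):

    import cotengra as ctg

    # hypothetical matrix-chain einsum, "ab,bc,cd->ad"
    inputs = [("a", "b"), ("b", "c"), ("c", "d")]
    output = ("a", "d")
    size_dict = {"a": 8, "b": 8, "c": 8, "d": 8}

    # call the optimizer directly to get a contraction path
    opt = ctg.RandomGreedyOptimizer(max_repeats=128)
    path = opt(inputs, output, size_dict)

    # or refer to it by its newly registered preset name
    tree = ctg.array_contract_tree(
        inputs, output, size_dict, optimize="random-greedy"
    )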
82 changes: 71 additions & 11 deletions cotengra/parallel.py
@@ -13,24 +13,43 @@
 _DEFAULT_BACKEND = "concurrent.futures"
 
 
+@functools.lru_cache(None)
+def choose_default_num_workers():
+    import os
+
+    if "COTENGRA_NUM_WORKERS" in os.environ:
+        return int(os.environ["COTENGRA_NUM_WORKERS"])
+
+    if "OMP_NUM_THREADS" in os.environ:
+        return int(os.environ["OMP_NUM_THREADS"])
+
+    return os.cpu_count()
+
+
 def get_pool(n_workers=None, maybe_create=False, backend=None):
     """Get a parallel pool."""
 
     if backend is None:
         backend = _DEFAULT_BACKEND
 
+    if backend == "dask":
+        return _get_pool_dask(n_workers=n_workers, maybe_create=maybe_create)
+
+    if backend == "ray":
+        return _get_pool_ray(n_workers=n_workers, maybe_create=maybe_create)
+
+    # above backends are distributed, don't specify n_workers
+    if n_workers is None:
+        n_workers = choose_default_num_workers()
+
     if backend == "loky":
         get_reusable_executor = get_loky_get_reusable_executor()
         return get_reusable_executor(max_workers=n_workers)
 
     if backend == "concurrent.futures":
-        return _get_pool_cf(n_workers=n_workers)
+        return _get_process_pool_cf(n_workers=n_workers)
 
-    if backend == "dask":
-        return _get_pool_dask(n_workers=n_workers, maybe_create=maybe_create)
+    if backend == "threads":
+        return _get_thread_pool_cf(n_workers=n_workers)
 
-    if backend == "ray":
-        return _get_pool_ray(n_workers=n_workers, maybe_create=maybe_create)
 
 
 @functools.lru_cache(None)
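A note on the reordered dispatch above: dask and ray are handled first since they manage their own worker counts, while the local backends ("loky", "concurrent.futures", "threads") now fall back to COTENGRA_NUM_WORKERS, then OMP_NUM_THREADS, then os.cpu_count(). Because choose_default_num_workers is lru_cached, the environment is only read once. A small usage sketch (the submitted task is arbitrary):

    import os

    # must be set before the first pool is requested,
    # since the chosen default is cached thereafter
    os.environ["COTENGRA_NUM_WORKERS"] = "4"

    from cotengra.parallel import get_pool

    pool = get_pool(backend="threads")  # ThreadPoolExecutor with 4 workers
    future = pool.submit(sum, range(10))
    assert future.result() == 45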
@@ -118,6 +137,9 @@ def parse_parallel_arg(parallel):
     if parallel == "concurrent.futures":
         return get_pool(maybe_create=True, backend="concurrent.futures")
 
+    if parallel == "threads":
+        return get_pool(maybe_create=True, backend="threads")
+
     if parallel == "dask":
         _AUTO_BACKEND = "dask"
         return get_pool(maybe_create=True, backend="dask")
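This makes "threads" a valid value anywhere a parallel argument is parsed, for instance (assuming the usual HyperOptimizer signature, which is not part of this diff):

    import cotengra as ctg

    # run hyper-optimization trials in a cached thread pool
    opt = ctg.HyperOptimizer(parallel="threads")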
@@ -219,16 +241,54 @@ def __del__(self):
         self.shutdown()
 
 
-PoolHandler = CachedProcessPoolExecutor()
+ProcessPoolHandler = CachedProcessPoolExecutor()
 
 
 @atexit.register
 def _shutdown_cached_process_pool():
-    PoolHandler.shutdown()
+    ProcessPoolHandler.shutdown()
+
+
+def _get_process_pool_cf(n_workers=None):
+    return ProcessPoolHandler(n_workers)
+
+
+class CachedThreadPoolExecutor:
+    def __init__(self):
+        self._pool = None
+        self._n_workers = -1
+
+    def __call__(self, n_workers=None):
+        if n_workers != self._n_workers:
+            from concurrent.futures import ThreadPoolExecutor
+
+            self.shutdown()
+            self._pool = ThreadPoolExecutor(n_workers)
+            self._n_workers = n_workers
+        return self._pool
+
+    def is_initialized(self):
+        return self._pool is not None
+
+    def shutdown(self):
+        if self._pool is not None:
+            self._pool.shutdown()
+            self._pool = None
+
+    def __del__(self):
+        self.shutdown()
+
+
+ThreadPoolHandler = CachedThreadPoolExecutor()
+
+
+@atexit.register
+def _shutdown_cached_thread_pool():
+    ThreadPoolHandler.shutdown()
 
 
-def _get_pool_cf(n_workers=None):
-    return PoolHandler(n_workers)
+def _get_thread_pool_cf(n_workers=None):
+    return ThreadPoolHandler(n_workers)
 
 
 # ---------------------------------- DASK ----------------------------------- #
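The thread-pool cache mirrors the existing process-pool one: requesting the same worker count reuses the live executor, while a different count shuts it down and builds a fresh pool. For example:

    from cotengra.parallel import ThreadPoolHandler

    pool_a = ThreadPoolHandler(4)  # creates a ThreadPoolExecutor with 4 workers
    pool_b = ThreadPoolHandler(4)  # same count: cached pool is reused
    assert pool_b is pool_a

    pool_c = ThreadPoolHandler(8)  # new count: old pool shut down, new one made
    assert pool_c is not pool_a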
(Diffs for the remaining 5 changed files not shown.)
