In [None]:
import os
# Must be set before importing ray and labtech.runners.ray
# A value of '0' turns off ray's log de-duplication, so repeated log lines
# from different workers are all shown.
os.environ['RAY_DEDUP_LOGS'] = '0'

In [None]:
from enum import Enum
from time import sleep

import mlflow
import ray
from s3fs import S3FileSystem

import labtech
from labtech.storage import FsspecStorage
from labtech.runners.ray import RayRunnerBackend

In [None]:
def worker_setup():
    """Configure mlflow tracking for the current process.

    Handed to ray as the ``worker_process_setup_hook``, so it is run to
    initialise mlflow on each worker.
    """
    tracking_uri = 'examples/storage/mlruns'
    experiment_name = 'example_ray_experiment'

    mlflow.set_tracking_uri(tracking_uri)
    mlflow.set_experiment(experiment_name)

# Will start a ray cluster, or connect to one started with: make ray-up
# worker_setup is registered as the setup hook for worker processes.
ray.init(runtime_env=dict(worker_process_setup_hook=worker_setup))

In [None]:
# Use a localstack-emulated S3 bucket to serve as distributed storage for results.
# Make sure localstack is running: make localstack
class S3fsStorage(FsspecStorage):
    """Fsspec storage whose filesystem is an S3 client aimed at localstack."""

    def fs_constructor(self):
        """Build and return the ``S3FileSystem`` used by this storage."""
        # Use localstack endpoint; the key and secret are arbitrary strings.
        localstack_options = {
            'endpoint_url': 'http://localhost:4566',
            'key': 'anything',
            'secret': 'anything',
        }
        return S3FileSystem(**localstack_options)

In [None]:
class Multiplier(Enum):
    """Integer factors an experiment multiplies its seed by (via ``.value``)."""
    ONE = 1
    TWO = 2
    THREE = 3


@labtech.task(mlflow_run=True)
class Experiment:
    """Example task parameterised by an integer seed and a ``Multiplier``."""
    seed: int
    multiplier: Multiplier

    def run(self):
        """Log the parameters, pause for 3 seconds, then return the product.

        Returns:
            ``seed`` multiplied by the integer value of ``multiplier``.
        """
        labtech.logger.info(f'Running with seed {self.seed} and multiplier {self.multiplier}')

        pause_seconds = 3
        sleep(pause_seconds)

        factor = self.multiplier.value
        return self.seed * factor


# One Experiment per (seed, multiplier) pair: seeds 0-9 crossed with every
# Multiplier member, with seed as the outer loop.
experiments = [
    Experiment(
        seed=seed,
        multiplier=multiplier,
    )
    for seed in range(10)
    for multiplier in Multiplier
]

# The lab stores results through S3fsStorage (constructed with
# 'labtech-dev-bucket') and runs tasks with the ray runner backend.
lab = labtech.Lab(
    storage=S3fsStorage('labtech-dev-bucket'),
    runner_backend=RayRunnerBackend(),
)

# Look up the cached Experiment tasks, report how many there are, and
# remove them from the cache.
cached_experiments = lab.cached_tasks([Experiment])
print(f'Clearing {len(cached_experiments)} cached experiments.')
lab.uncache_tasks(cached_experiments)

# NOTE(review): bust_cache=True presumably forces every task to be re-run
# even if a cached result exists — confirm against the labtech docs.
results = lab.run_tasks(experiments, bust_cache=True)
print(results)