Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions torchx/container/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
torchx provides a standard container spec and entry point at
`torchx/container/main.py`. This allows for executing torchx components by
fully qualified class name.

Usage
-----------------

The container entry point lives at `torchx/container/main.py`.

The first argument is the fully qualified class name. The entry point will automatically import that path and load the component's config, input, and output definitions from the command line args.

Ex:

.. code:: bash

$ docker run -it --name torchx --rm pytorch/torchx:latest python3 torchx/container/main.py torchx.components.io.copy.Copy --input_path 'file:///etc/os-release' --output_path 'file:///tmp/bar'


Configuration
-----------------

The entry point automatically loads a configuration file located at
`/etc/torchx.yaml` or from the path specified by `TORCHX_CONFIG`.

The config looks like this:

.. code:: yaml

storage_providers:
- torchx.aws.s3

Configuration options:

- storage_providers: a list of Python packages that are imported at runtime to register any third-party storage providers.


Extending
-----------------

You can extend the prebuilt docker container to add extra dependencies,
components or storage providers.

.. code:: Dockerfile

FROM pytorch/torchx:latest

RUN pip install <your package>
COPY torchx.yaml /etc/torchx.yaml

This container can then be used instead of the default by specifying the
`TORCHX_CONTAINER` environment variable with the kfp adapter.
"""
27 changes: 27 additions & 0 deletions torchx/container/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@
from typing import Type, List, Set, Callable, Optional

import torchx
import yaml
from torchx.sdk.component import is_optional, Component

# Name of the environment variable that overrides the config file location.
TORCHX_CONFIG_ENV = "TORCHX_CONFIG"
# Resolved once at import time: the override if it is set, else the default.
TORCHX_CONFIG_PATH = os.environ.get(TORCHX_CONFIG_ENV, "/etc/torchx.yaml")

# pyre-fixme[24]: Generic type `Component` expects 3 type parameters.
def get_component_class(path: str) -> Type[Component]:
dot = path.rindex(".")
Expand All @@ -28,11 +35,29 @@ def _get_parser(field: Type[object]) -> Callable[[str], object]:
return json.loads


def load_and_process_config(path: str) -> None:
    """
    Loads the YAML config file at ``path`` and applies its options.

    A missing file is not an error; the function simply returns. An empty
    file is treated the same way. The only option handled here is
    ``storage_providers``: a list of module names, each of which is imported
    with ``importlib.import_module``.

    Args:
        path: filesystem path of the YAML config file.

    Raises:
        TypeError: if the top-level YAML value is not a mapping or if
            ``storage_providers`` is not a list.
    """
    if not os.path.exists(path):
        return

    with open(path, "r") as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty document; there is nothing to apply.
    if config is None:
        return
    if not isinstance(config, dict):
        raise TypeError(f"config must be a mapping: {config!r}")

    if providers := config.get("storage_providers"):
        # Raise explicitly instead of asserting so the check still runs under
        # `python -O`; otherwise a plain string would be iterated
        # character by character and each character imported.
        if not isinstance(providers, list):
            raise TypeError(f"storage_providers must be a list: {providers}")
        for provider in providers:
            print(f"loading storage provider: {provider}")
            importlib.import_module(provider)


def main(args: List[str]) -> None:
print(f"torchx version: {torchx.__version__}")
print(f"process args: {args}")
component_name = args[1]
print(f"component_name: {component_name}")
print(f"config path: {TORCHX_CONFIG_PATH}")
load_and_process_config(TORCHX_CONFIG_PATH)

parser = argparse.ArgumentParser(prog="torchx-main")
cls = get_component_class(component_name)
Expand Down Expand Up @@ -70,6 +95,8 @@ def main(args: List[str]) -> None:
component = cls(**inputs)
component.run(component.inputs, component.outputs)

print("done")


# Script entry point: hand the full argv to main(), which reads the component
# name from args[1].
if __name__ == "__main__":
    main(sys.argv)
Empty file.
47 changes: 46 additions & 1 deletion torchx/container/test/main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@

import json
import os.path
import sys
import tempfile
import unittest
from typing import TypedDict, Optional
from unittest.mock import patch

import yaml
from torchx.container.main import main
from torchx.sdk.component import Component
from torchx.sdk.storage import temppath, upload_file, download_file
Expand Down Expand Up @@ -56,11 +59,30 @@ def run(self, inputs: Inputs, outputs: Outputs) -> None:
TestComponent.ran = True


class NoopConfig(TypedDict):
    """Empty TypedDict used as the first type parameter of ``NoopComponent``."""


class NoopInputs(TypedDict):
    """Empty TypedDict used as the inputs type of ``NoopComponent``."""


class NoopOutputs(TypedDict):
    """Empty TypedDict used as the outputs type of ``NoopComponent``."""


class NoopComponent(Component[NoopConfig, NoopInputs, NoopOutputs]):
    """
    Component with empty config, inputs and outputs whose ``run`` does nothing.
    """

    Version: str = "0.1"

    def run(self, inputs: NoopInputs, outputs: NoopOutputs) -> None:
        """Intentionally performs no work."""
        return None


class ContainerTest(unittest.TestCase):
def test_main(self) -> None:
main(
[
"main.par",
"main.py",
"torchx.container.test.main_test.TestComponent",
"--input_path",
"somepath",
Expand Down Expand Up @@ -99,3 +121,26 @@ def test_output_path(self) -> None:
self.assertEqual(out, data)
with open(out_path_file, "rt") as f:
self.assertEqual(f.read(), output_path)

def test_config_storage_providers(self) -> None:
    """
    Tests that storage providers from the specified config are loaded.
    """

    module = "torchx.container.test.dummy_module"
    config = {
        "storage_providers": [module],
    }
    # Remove the module again once the test finishes. Otherwise it stays in
    # sys.modules and the assertNotIn below fails if this test runs a second
    # time in the same process.
    self.addCleanup(sys.modules.pop, module, None)

    with tempfile.TemporaryDirectory() as tmpdir:
        config_path = os.path.join(tmpdir, "torchx.yaml")
        with open(config_path, "w") as f:
            yaml.dump(config, f)
        # main() reads the module-level TORCHX_CONFIG_PATH, so patch it to
        # point at the temporary config.
        with patch("torchx.container.main.TORCHX_CONFIG_PATH", config_path):
            self.assertNotIn(module, sys.modules)
            main(
                [
                    "main.py",
                    "torchx.container.test.main_test.NoopComponent",
                ]
            )
            self.assertIn(module, sys.modules)