diff --git a/bindings/python/flashlight/lib/sequence/criterion.py b/bindings/python/flashlight/lib/sequence/criterion.py index dcf4e4f6..f9fc434f 100644 --- a/bindings/python/flashlight/lib/sequence/criterion.py +++ b/bindings/python/flashlight/lib/sequence/criterion.py @@ -6,13 +6,33 @@ LICENSE file in the root directory of this source tree. """ +from .flashlight_lib_sequence_criterion import ( + CpuForceAlignmentCriterion, + CpuFullConnectionCriterion, + CpuViterbiPath, + CriterionScaleMode, +) + have_torch = False try: + import torch have_torch = True except ImportError: pass if have_torch: - pass + from flashlight.lib.sequence.criterion_torch import ( + ASGLoss, + check_tensor, + create_workspace, + FACFunction, + FCCFunction, + get_cuda_stream_as_bytes, + get_data_ptr_as_bytes, + run_backward, + run_direction, + run_forward, + run_get_workspace_size, + ) diff --git a/bindings/python/test/test_import.py b/bindings/python/test/test_import.py index 41601731..319342d2 100644 --- a/bindings/python/test/test_import.py +++ b/bindings/python/test/test_import.py @@ -11,6 +11,13 @@ class ImportTestCase(unittest.TestCase): def test_import_lib_sequence(self) -> None: + from flashlight.lib.sequence import criterion + from flashlight.lib.sequence.criterion import ( + CpuForceAlignmentCriterion, + CpuFullConnectionCriterion, + CpuViterbiPath, + CriterionScaleMode, + ) if os.getenv("USE_CUDA", "OFF").upper() not in [ "OFF", @@ -19,7 +26,11 @@ def test_import_lib_sequence(self) -> None: "FALSE", "N", ]: - pass + from flashlight.lib.sequence.flashlight_lib_sequence_criterion import ( + CudaForceAlignmentCriterion, + CudaFullConnectionCriterion, + CudaViterbiPath, + ) else: logging.info("Flashlight Sequence bindings built without CUDA")