Skip to content

Commit

Permalink
Add type annotations, some small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
nsoranzo committed Jul 7, 2022
1 parent a587ac0 commit 563a417
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 69 deletions.
33 changes: 19 additions & 14 deletions bioblend/_tests/TestGalaxyObjects.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,18 @@ def test_kwargs(self):
w.parent = 0


class TestWorkflow(unittest.TestCase):
@test_util.skip_unless_galaxy()
class GalaxyObjectsTestBase(unittest.TestCase):
gi: galaxy_instance.GalaxyInstance

@classmethod
def setUpClass(cls):
galaxy_key = os.environ["BIOBLEND_GALAXY_API_KEY"]
galaxy_url = os.environ["BIOBLEND_GALAXY_URL"]
cls.gi = galaxy_instance.GalaxyInstance(galaxy_url, galaxy_key)


class TestWorkflow(GalaxyObjectsTestBase):
def setUp(self):
self.wf = wrappers.Workflow(SAMPLE_WF_DICT)

Expand Down Expand Up @@ -236,8 +247,10 @@ def test_taint(self):
self.assertTrue(self.wf.is_modified)

def test_input_map(self):
hda = wrappers.HistoryDatasetAssociation({"id": "hda_id"}, container="mock_history")
ldda = wrappers.LibraryDatasetDatasetAssociation({"id": "ldda_id"}, container="mock_library")
history = wrappers.History({}, gi=self.gi)
library = wrappers.Library({}, gi=self.gi)
hda = wrappers.HistoryDatasetAssociation({"id": "hda_id"}, container=history, gi=self.gi)
ldda = wrappers.LibraryDatasetDatasetAssociation({"id": "ldda_id"}, container=library, gi=self.gi)
input_map = self.wf._convert_input_map({"0": hda, "1": ldda, "2": {"id": "hda2_id", "src": "hda"}})
self.assertEqual(
input_map,
Expand All @@ -249,20 +262,12 @@ def test_input_map(self):
)


@test_util.skip_unless_galaxy()
class GalaxyObjectsTestBase(unittest.TestCase):
def setUp(self):
galaxy_key = os.environ["BIOBLEND_GALAXY_API_KEY"]
galaxy_url = os.environ["BIOBLEND_GALAXY_URL"]
self.gi = galaxy_instance.GalaxyInstance(galaxy_url, galaxy_key)


@test_util.skip_unless_galaxy("release_19.09")
class TestInvocation(GalaxyObjectsTestBase):
@classmethod
def setUpClass(cls):
super().setUp(cls)
cls.inv = wrappers.Invocation(SAMPLE_INV_DICT)
super().setUpClass()
cls.inv = wrappers.Invocation(SAMPLE_INV_DICT, gi=cls.gi)
with open(SAMPLE_FN) as f:
cls.workflow = cls.gi.workflows.import_new(f.read())
path_pause = test_util.get_abspath(os.path.join("data", "test_workflow_pause.ga"))
Expand Down Expand Up @@ -394,7 +399,7 @@ def _obj_invoke_workflow(self):
class TestObjInvocationClient(GalaxyObjectsTestBase):
@classmethod
def setUpClass(cls):
super().setUp(cls)
super().setUpClass()
with open(SAMPLE_FN) as f:
cls.workflow = cls.gi.workflows.import_new(f.read())
cls.history = cls.gi.histories.create(name="TestGalaxyObjInvocationClient")
Expand Down
7 changes: 6 additions & 1 deletion bioblend/galaxy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import time
import typing
from typing import Optional

import requests
Expand All @@ -16,10 +17,14 @@
# ConnectionError class was originally defined here
from bioblend import ConnectionError # noqa: I202

if typing.TYPE_CHECKING:
from bioblend.galaxy import GalaxyInstance


class Client:
# The `module` attribute needs to be defined in subclasses
module: str
gi: "GalaxyInstance"

# Class variables that configure GET request retries. Note that since these
# are class variables their values are shared by all Client instances --
Expand Down Expand Up @@ -69,7 +74,7 @@ def set_get_retry_delay(cls, value):
cls._get_retry_delay = value
return cls

def __init__(self, galaxy_instance):
def __init__(self, galaxy_instance: "GalaxyInstance"):
"""
A generic Client interface defining the common fields.
Expand Down
29 changes: 20 additions & 9 deletions bioblend/galaxy/objects/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import abc
import json
import typing
from collections.abc import (
Mapping,
Sequence,
Expand All @@ -15,9 +16,12 @@
import bioblend
from . import wrappers

if typing.TYPE_CHECKING:
from .galaxy_instance import GalaxyInstance


class ObjClient(abc.ABC):
def __init__(self, obj_gi):
def __init__(self, obj_gi: "GalaxyInstance"):
self.obj_gi = obj_gi
self.gi = self.obj_gi.gi
self.log = bioblend.log
Expand Down Expand Up @@ -100,6 +104,13 @@ def _get_container(self, id_, ctype):
c_infos = [ctype.CONTENT_INFO_TYPE(_) for _ in c_infos]
return ctype(cdict, content_infos=c_infos, gi=self.obj_gi)

@abc.abstractmethod
def get(self, id_: str) -> wrappers.DatasetContainer:
"""
Retrieve the dataset corresponding to the given id.
"""
pass


class ObjLibraryClient(ObjDatasetContainerClient):
"""
Expand All @@ -120,7 +131,7 @@ def create(self, name, description=None, synopsis=None):
lib_info = self._get_dict("create_library", res)
return self.get(lib_info["id"])

def get(self, id_):
def get(self, id_) -> wrappers.Library:
"""
Retrieve the data library corresponding to the given id.
Expand Down Expand Up @@ -188,7 +199,7 @@ def create(self, name=None):
hist_info = self._get_dict("create_history", res)
return self.get(hist_info["id"])

def get(self, id_):
def get(self, id_) -> wrappers.History:
"""
Retrieve the history corresponding to the given id.
Expand Down Expand Up @@ -279,7 +290,7 @@ def import_shared(self, id_):
wf_info = self.gi.workflows.import_shared_workflow(id_)
return self.get(wf_info["id"])

def get(self, id_):
def get(self, id_) -> wrappers.Workflow:
"""
Retrieve the workflow corresponding to the given id.
Expand Down Expand Up @@ -349,7 +360,7 @@ def get_previews(self) -> List[wrappers.InvocationPreview]:
:param: previews of invocations
"""
inv_list = self.gi.invocations.get_invocations()
return [wrappers.InvocationPreview(inv_dict, self.obj_gi) for inv_dict in inv_list]
return [wrappers.InvocationPreview(inv_dict, gi=self.obj_gi) for inv_dict in inv_list]

def list(self, workflow=None, history=None, include_terminal=True, limit=None) -> List[wrappers.Invocation]:
"""
Expand Down Expand Up @@ -390,7 +401,7 @@ class ObjToolClient(ObjClient):
Interacts with Galaxy tools.
"""

def get(self, id_, io_details=False, link_details=False):
def get(self, id_, io_details=False, link_details=False) -> wrappers.Tool:
"""
Retrieve the tool corresponding to the given id.
Expand Down Expand Up @@ -454,7 +465,7 @@ class ObjJobClient(ObjClient):
Interacts with Galaxy jobs.
"""

def get(self, id_, full_details=False):
def get(self, id_, full_details=False) -> wrappers.Job:
"""
Retrieve the job corresponding to the given id.
Expand Down Expand Up @@ -488,7 +499,7 @@ class ObjDatasetClient(ObjClient):
Interacts with Galaxy datasets.
"""

def get(self, id_: str, hda_ldda: str = "hda"):
def get(self, id_: str, hda_ldda: str = "hda") -> wrappers.Dataset:
"""
Retrieve the dataset corresponding to the given id.
Expand Down Expand Up @@ -522,7 +533,7 @@ class ObjDatasetCollectionClient(ObjClient):
Interacts with Galaxy dataset collections.
"""

def get(self, id_: str):
def get(self, id_: str) -> wrappers.HistoryDatasetCollectionAssociation:
"""
Retrieve the dataset collection corresponding to the given id.
Expand Down
Loading

0 comments on commit 563a417

Please sign in to comment.