diff --git a/juju/charmstore.py b/juju/charmstore.py new file mode 100644 index 000000000..ba792ebd5 --- /dev/null +++ b/juju/charmstore.py @@ -0,0 +1,37 @@ +from functools import partial + +import asyncio +import theblues.charmstore +import theblues.errors + + +class CharmStore: + """ + Async wrapper around theblues.charmstore.CharmStore + """ + def __init__(self, loop, cs_timeout=20): + self.loop = loop + self._cs = theblues.charmstore.CharmStore(timeout=cs_timeout) + + def __getattr__(self, name): + """ + Wrap method calls in coroutines that use run_in_executor to make them + async. + """ + attr = getattr(self._cs, name) + if not callable(attr): + wrapper = partial(getattr, self._cs, name) + setattr(self, name, wrapper) + else: + async def coro(*args, **kwargs): + method = partial(attr, *args, **kwargs) + for attempt in range(1, 4): + try: + return await self.loop.run_in_executor(None, method) + except theblues.errors.ServerError: + if attempt == 3: + raise + await asyncio.sleep(1, loop=self.loop) + setattr(self, name, coro) + wrapper = coro + return wrapper diff --git a/juju/model.py b/juju/model.py index 896e18f10..e3835d2d6 100644 --- a/juju/model.py +++ b/juju/model.py @@ -15,15 +15,13 @@ from pathlib import Path import yaml - -import theblues.charmstore -import theblues.errors import websockets from . import provisioner, tag, utils from .annotationhelper import _get_annotations, _set_annotations from .bundle import BundleHandler, get_charm_series from .charmhub import CharmHub +from .charmstore import CharmStore from .client import client, connector from .client.client import ConfigValue, Value from .client.overrides import Caveat, Macaroon @@ -461,6 +459,8 @@ def __init__( self._watch_stopped = asyncio.Event(loop=self._connector.loop) self._watch_received = asyncio.Event(loop=self._connector.loop) self._watch_stopped.set() + + self._charmhub = CharmHub(self) self._charmstore = CharmStore(self._connector.loop) def is_connected(self): @@ -783,7 +783,11 @@ def charmhub(self): the charm-hub-url model config. """ - return CharmHub(self) + return self._charmhub + + @property + def charmstore(self): + return self._charmstore async def get_info(self): """Return a client.ModelInfo object for this Model. @@ -1968,10 +1972,6 @@ def upload_backup(self, archive_path): """ raise NotImplementedError() - @property - def charmstore(self): - return self._charmstore - async def get_metrics(self, *tags): """Retrieve metrics. @@ -2186,38 +2186,6 @@ def _create_consume_args(offer, macaroon, controller_info): return arg -class CharmStore: - """ - Async wrapper around theblues.charmstore.CharmStore - """ - def __init__(self, loop, cs_timeout=20): - self.loop = loop - self._cs = theblues.charmstore.CharmStore(timeout=cs_timeout) - - def __getattr__(self, name): - """ - Wrap method calls in coroutines that use run_in_executor to make them - async. - """ - attr = getattr(self._cs, name) - if not callable(attr): - wrapper = partial(getattr, self._cs, name) - setattr(self, name, wrapper) - else: - async def coro(*args, **kwargs): - method = partial(attr, *args, **kwargs) - for attempt in range(1, 4): - try: - return await self.loop.run_in_executor(None, method) - except theblues.errors.ServerError: - if attempt == 3: - raise - await asyncio.sleep(1, loop=self.loop) - setattr(self, name, coro) - wrapper = coro - return wrapper - - class CharmArchiveGenerator: """ Create a Zip archive of a local charm directory for upload to a controller. diff --git a/juju/origin.py b/juju/origin.py new file mode 100644 index 000000000..083e16221 --- /dev/null +++ b/juju/origin.py @@ -0,0 +1,191 @@ +from enum import Enum +from .errors import JujuError + + +class Source(Enum): + """Source defines a origin source. Providing a hint to the controller about + what the charm identity is from the URL and origin source. + + """ + LOCAL = "local" + CHARM_STORE = "charm-store" + CHARM_HUB = "charm-hub" + + def __str__(self): + return self.value + + +class Origin: + def __init__(self, source, channel, platform): + self.source = source + self.channel = channel + self.platform = platform + + def __str__(self): + return "origin using source {} for channel {} and platform {}".format(str(self.source), self.channel, self.platform) + + +class Risk(Enum): + STABLE = "stable" + CANDIDATE = "candidate" + BETA = "beta" + EDGE = "edge" + + def __str__(self): + return self.value + + @staticmethod + def valid(potential): + for risk in [Risk.STABLE, Risk.CANDIDATE, Risk.BETA, Risk.EDGE]: + if str(risk) == potential: + return True + return False + + +class Channel: + """Channel identifies and describes completely a store channel. + + A channel consists of, and is subdivided by, tracks, risk-levels and + - Tracks enable snap developers to publish multiple supported releases of + their application under the same snap name. + - Risk-levels represent a progressive potential trade-off between stability + and new features. + + The complete channel name can be structured as three distinct parts separated + by slashes: + + / + + """ + def __init__(self, track=None, risk=None): + if not Risk.valid(risk): + raise JujuError("unexpected risk {}".format(risk)) + + self.track = track or "" + self.risk = risk + + @staticmethod + def parse(s): + """parse a channel from a given string. + Parse does not take into account branches. + + """ + if not s: + raise JujuError("channel cannot be empty") + + p = s.split("/") + + risk = None + track = None + if len(p) == 1: + if Risk.valid(p[0]): + risk = p[0] + else: + track = p[0] + risk = str(Risk.STABLE) + elif len(p) == 2: + track = p[0] + risk = p[1] + else: + raise JujuError("channel is malformed and has too many components {}".format(s)) + + if risk is not None and not Risk.valid(risk): + raise JujuError("risk in channel {} is not valid".format(s)) + if track is not None and track == "": + raise JujuError("track in channel {} is not valid".format(s)) + + return Channel(track, risk) + + def normalize(self): + track = self.track if self.track != "latest" else "" + risk = self.risk if self.risk != "" else "" + return Channel(track, risk) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.track == other.track and self.risk == other.risk + return False + + def __str__(self): + path = self.risk + if self.track != "": + path = "{}/{}".format(self.track, path) + return path + + +class Platform: + """ParsePlatform parses a string representing a store platform. + Serialized version of platform can be expected to conform to the following: + + 1. Architecture is mandatory. + 2. OS is optional and can be dropped. Series is mandatory if OS wants + to be displayed. + 3. Series is also optional. + + To indicate something is missing `unknown` can be used in place. + + Examples: + + 1. `//` + 2. `` + 3. `/` + 4. `/unknown/` + + """ + def __init__(self, arch, series=None, os=None): + self.arch = arch + self.series = series + self.os = os + + @staticmethod + def parse(s): + if not s: + raise JujuError("platform cannot be empty") + + p = s.split("/") + + arch = None + os = None + series = None + if len(p) == 1: + arch = p[0] + elif len(p) == 2: + arch = p[0] + series = p[1] + elif len(p) == 3: + arch = p[0] + os = p[1] + series = p[2] + else: + raise JujuError("platform is malformed and has too many components {}".format(s)) + + if not arch: + raise JujuError("architecture in platform {} is not valid".format(s)) + if os is not None and os == "": + raise JujuError("os in platform {} is not valid".format(s)) + if series is not None and series == "": + raise JujuError("series in platform {} is not valid".format(s)) + + return Platform(arch, series, os) + + def normalize(self): + os = self.os if self.os is not None or self.os != "unknown" else None + series = self.series + if series is None or series == "unknown": + os = None + series = None + + return Platform(self.arch, series, os) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.arch == other.arch and self.os == other.os and self.series == other.series + return False + + def __str__(self): + path = self.arch + if self.os is not None and self.os != "": + path = "{}/{}".format(path, self.os) + if self.series is not None and self.series != "": + path = "{}/{}".format(path, self.series) + return path diff --git a/tests/unit/test_origin.py b/tests/unit/test_origin.py new file mode 100644 index 000000000..b2b3fa5eb --- /dev/null +++ b/tests/unit/test_origin.py @@ -0,0 +1,96 @@ +import unittest + +from juju.origin import Channel, Origin, Platform, Risk, Source + + +class TestRisk(unittest.TestCase): + def test_valid_risk(self): + self.assertTrue(Risk.valid("stable")) + + def test_invalid_risk(self): + self.assertFalse(Risk.valid("maybe")) + + +class TestChannel(unittest.TestCase): + def test_parse_risk_only(self): + ch = Channel.parse("stable") + self.assertEqual(ch, Channel(None, "stable")) + + def test_parse_track_only(self): + ch = Channel.parse("2.0.1") + self.assertEqual(ch, Channel("2.0.1", "stable")) + + def test_parse(self): + ch = Channel.parse("latest/stable") + self.assertEqual(ch, Channel("latest", "stable")) + + def test_parse_numeric(self): + ch = Channel.parse("2.0.7/stable") + self.assertEqual(ch, Channel("2.0.7", "stable")) + + def test_parse_then_normalize(self): + ch = Channel.parse("latest/stable").normalize() + self.assertEqual(ch, Channel(None, "stable")) + + def test_str_risk_only(self): + ch = Channel.parse("stable") + self.assertEqual(str(ch), "stable") + + def test_str_track_only(self): + ch = Channel.parse("2.0.1") + self.assertEqual(str(ch), "2.0.1/stable") + + def test_str(self): + ch = Channel.parse("latest/stable") + self.assertEqual(str(ch), "latest/stable") + + def test_str_numeric(self): + ch = Channel.parse("2.0.7/stable") + self.assertEqual(str(ch), "2.0.7/stable") + + def test_str_then_normalize(self): + ch = Channel.parse("latest/stable").normalize() + self.assertEqual(str(ch), "stable") + + +class TestPlatform(unittest.TestCase): + def test_parse_arch_only(self): + p = Platform.parse("architecture") + self.assertEqual(p, Platform("architecture")) + + def test_parse_arch_and_series(self): + p = Platform.parse("architecture/series") + self.assertEqual(p, Platform("architecture", "series")) + + def test_parse(self): + p = Platform.parse("architecture/os/series") + self.assertEqual(p, Platform("architecture", "series", "os")) + + def test_parse_with_unknowns(self): + p = Platform.parse("architecture/unknown/unknown") + self.assertEqual(p, Platform("architecture", "unknown", "unknown")) + + def test_parse_with_unknowns_after_normalize(self): + p = Platform.parse("architecture/unknown/unknown").normalize() + self.assertEqual(p, Platform("architecture")) + + def test_str_arch_only(self): + p = Platform.parse("architecture") + self.assertEqual(str(p), "architecture") + + def test_str_arch_and_series(self): + p = Platform.parse("architecture/series") + self.assertEqual(str(p), "architecture/series") + + def test_str(self): + p = Platform.parse("architecture/os/series") + self.assertEqual(str(p), "architecture/os/series") + + +class TestOrigin(unittest.TestCase): + def test_origin(self): + ch = Channel.parse("latest/stable") + p = Platform.parse("amd64/ubuntu/focal") + + o = Origin(Source.CHARM_HUB, ch, p) + self.assertEqual(str(o), "origin using source charm-hub for channel latest/stable and platform amd64/ubuntu/focal")