Skip to content
7 changes: 7 additions & 0 deletions src/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,13 @@ def _on_peer_relation_joined(self, event: RelationJoinedEvent) -> None:
event.defer()
return

# Safeguard against event deferall
if self._mysql.is_instance_in_cluster(event_unit_label):
logger.debug(
f"Unit {event_unit_label} is already part of the cluster, don't try to add it again."
)
return

# Add the instance to the cluster. This operation uses locks to ensure that
# only one instance is added to the cluster at a time
# (so only one instance is involved in a state transfer at a time)
Expand Down
126 changes: 124 additions & 2 deletions src/relations/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
MySQLGrantPrivilegesToUserError,
MySQLUpgradeUserForMySQLRouterError,
)
from ops.charm import RelationBrokenEvent
from ops.charm import RelationBrokenEvent, RelationDepartedEvent, RelationJoinedEvent
from ops.framework import Object
from ops.model import BlockedStatus

from constants import DB_RELATION_NAME, PASSWORD_LENGTH
from constants import DB_RELATION_NAME, PASSWORD_LENGTH, PEER
from utils import generate_random_password

logger = logging.getLogger(__name__)
Expand All @@ -43,6 +43,128 @@ def __init__(self, charm):
self.framework.observe(
self.charm.on[DB_RELATION_NAME].relation_broken, self._on_database_broken
)
self.framework.observe(self.charm.on[PEER].relation_joined, self._on_relation_joined)
self.framework.observe(self.charm.on[PEER].relation_departed, self._on_relation_departed)

self.framework.observe(self.charm.on.leader_elected, self._on_leader_elected)

def _on_leader_elected(self, _):
"""Handle on leader elected event for the database relation."""
if not self.charm.unit.is_leader():
return
# get all relations involving the database relation
relations = list(self.model.relations[DB_RELATION_NAME])
# check if there are relations in place
if len(relations) == 0:
return

if not self.charm.cluster_initialized:
logger.debug("Waiting cluster to be initialized")
return

relation_data = self.database.fetch_relation_data()
# for all relations update the read-only-endpoints
for relation in relations:
# check if the on_database_requested has been executed
if relation.id not in relation_data:
logger.debug("On database requested not happened yet! Nothing to do in this case")
continue
self._update_endpoints(relation.id, self.charm.app.name)

def _on_relation_departed(self, event: RelationDepartedEvent):
"""Handle the peer relation departed event for the database relation."""
if not self.charm.unit.is_leader():
return
# get all relations involving the database relation
relations = list(self.model.relations[DB_RELATION_NAME])
if len(relations) == 0:
return

if not self.charm.cluster_initialized:
logger.debug("Waiting cluster to be initialized")
return

# check if the leader is departing
if self.charm.unit.name == event.departing_unit.name:
return

# get unit name that departed
dep_unit_name = event.departing_unit.name.replace("/", "-")

# defer if the added unit is still in the cluster
if self.charm._mysql.is_instance_in_cluster(dep_unit_name):
logger.debug(f"Departing unit {dep_unit_name} is still in the cluster!")
event.defer()
return

relation_data = self.database.fetch_relation_data()
# for all relations update the read-only-endpoints
for relation in relations:
# check if the on_database_requested has been executed
if relation.id not in relation_data:
logger.debug("On database requested not happened yet! Nothing to do in this case")
continue
# update the endpoints
self._update_endpoints(relation.id, event.app.name)

def _on_relation_joined(self, event: RelationJoinedEvent):
"""Handle the peer relation joined event for the database relation."""
if not self.charm.unit.is_leader():
return
# get all relations involving the database relation
relations = list(self.model.relations[DB_RELATION_NAME])

if len(relations) == 0:
return

if not self.charm.cluster_initialized:
logger.debug("Waiting cluster to be initialized")
return

# get unit name that joined
event_unit_label = event.unit.name.replace("/", "-")

# defer if the added unit is not in the cluster
if not self.charm._mysql.is_instance_in_cluster(event_unit_label):
event.defer()
return
relation_data = self.database.fetch_relation_data()
# for all relations update the read-only-endpoints
for relation in relations:
# check if the on_database_requested has been executed
if relation.id not in relation_data:
logger.debug("On database requested not happened yet! Nothing to do in this case")
continue
# update the endpoints
self._update_endpoints(relation.id, event.app.name)

def _update_endpoints(self, relation_id: int, remote_app: str):
"""Updates the read-only-endpoints.

Args:
relation_id (int): The id of the relation
remote_app (str): The name of the remote application
"""
try:

primary_endpoint = self.charm._mysql.get_cluster_primary_address()
self.database.set_endpoints(relation_id, primary_endpoint)
# get read only endpoints by removing primary from all members
read_only_endpoints = sorted(
self.charm._mysql.get_cluster_members_addresses()
- {
primary_endpoint,
}
)
self.database.set_read_only_endpoints(relation_id, ",".join(read_only_endpoints))
logger.debug(f"Updated endpoints for {remote_app}")

except MySQLGetClusterMembersAddressesError as e:
logger.exception("Failed to get cluster members", exc_info=e)
self.charm.unit.status = BlockedStatus("Failed to get cluster members")
except MySQLClientError as e:
logger.exception("Failed to get primary", exc_info=e)
self.charm.unit.status = BlockedStatus("Failed to get primary")

def _get_or_set_password(self, relation) -> str:
"""Retrieve password from cache or generate a new one.
Expand Down
160 changes: 159 additions & 1 deletion tests/integration/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
import secrets
import string
import subprocess
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Set

import yaml
from connector import MysqlConnector
from juju.unit import Unit
from mysql.connector.errors import InterfaceError, OperationalError, ProgrammingError
Expand Down Expand Up @@ -372,3 +373,160 @@ def cluster_name(unit: Unit, model_name: str) -> str:
output = json.loads(output.decode("utf-8"))

return output[unit.name]["relation-info"][0]["application-data"]["cluster-name"]


async def get_relation_data(
ops_test: OpsTest,
application_name: str,
relation_name: str,
) -> list:
"""Returns a list that contains the relation-data.

Args:
ops_test: The ops test framework instance
application_name: The name of the application
relation_name: name of the relation to get connection data from
Returns:
a list that contains the relation-data
"""
# get available unit id for the desidered application
units_ids = [
app_unit.name.split("/")[1]
for app_unit in ops_test.model.applications[application_name].units
]
assert len(units_ids) > 0
unit_name = f"{application_name}/{units_ids[0]}"
raw_data = (await ops_test.juju("show-unit", unit_name))[1]
if not raw_data:
raise ValueError(f"no unit info could be grabbed for {unit_name}")
data = yaml.safe_load(raw_data)
# Filter the data based on the relation name.
relation_data = [v for v in data[unit_name]["relation-info"] if v["endpoint"] == relation_name]
if len(relation_data) == 0:
raise ValueError(
f"no relation data could be grabbed on relation with endpoint {relation_name}"
)

return relation_data


def get_read_only_endpoints(relation_data: list) -> Set[str]:
"""Returns the read-only-endpoints from the relation data.

Args:
relation_data: The dictionary that contains the info
Returns:
a set that contains the read-only-endpoints
"""
related_units = relation_data[0]["related-units"]
read_only_endpoints = set()
for _, relation_data in related_units.items():
assert "data" in relation_data
data = relation_data["data"]["data"]

try:
j_data = json.loads(data)
if "read-only-endpoints" in j_data:
read_only_endpoint_field = j_data["read-only-endpoints"]
if read_only_endpoint_field.strip() == "":
continue
for ep in read_only_endpoint_field.split(","):
read_only_endpoints.add(ep)
except json.JSONDecodeError:
raise ValueError("Relation data are not valid JSON.")

return read_only_endpoints


def get_read_only_endpoint_hostnames(relation_data: list) -> List[str]:
"""Returns the read-only-endpoint hostnames from the relation data.

Args:
relation_data: The dictionary that contains the info
Returns:
a set that contains the read-only-endpoint hostnames
"""
read_only_endpoints = get_read_only_endpoints(relation_data)
read_only_endpoint_hostnames = []
for read_only_endpoint in read_only_endpoints:
if ":" in read_only_endpoint:
read_only_endpoint_hostnames.append(read_only_endpoint.split(":")[0])
else:
raise ValueError("Malformed endpoint")
return read_only_endpoint_hostnames


async def remove_leader_unit(ops_test: OpsTest, application_name: str):
"""Removes the leader unit of a specified application.

Args:
ops_test: The ops test framework instance
application_name: The name of the application
"""
leader_unit = None
for app_unit in ops_test.model.applications[application_name].units:
is_leader = await app_unit.is_leader_from_status()
if is_leader:
leader_unit = app_unit.name

await ops_test.model.destroy_units(leader_unit)

count = len(ops_test.model.applications[application_name].units)

application = ops_test.model.applications[application_name]
await ops_test.model.block_until(lambda: len(application.units) == count)

if count > 0:
await ops_test.model.wait_for_idle(
apps=[application_name],
status="active",
raise_on_blocked=True,
timeout=1000,
)


async def get_unit_hostname(ops_test: OpsTest, app_name: str) -> List[str]:
"""Retrieves hostnames of given application units.

Args:
ops_test: The ops test framework instance
app_name: The name of the application
Returns:
a list that contains the hostnames of a given application
"""
units = [app_unit.name for app_unit in ops_test.model.applications[app_name].units]
status = await ops_test.model.get_status()
machine_hostname = {}

for machine_id, v in status["machines"].items():
machine_hostname[machine_id] = v["hostname"]

unit_machine = {}
for unit in units:
unit_machine[unit] = status["applications"][app_name]["units"][f"{unit}"]["machine"]
hostnames = []
for unit, machine in unit_machine.items():
if machine in machine_hostname:
hostnames.append(machine_hostname[machine])
return hostnames


async def check_read_only_endpoints(ops_test: OpsTest, app_name: str, relation_name: str):
"""Checks that read-only-endpoints are correctly set.

Args:
ops_test: The ops test framework instance
app_name: The name of the application
relation_name: The name of the relation
"""
# check update for read-only-endpoints
relation_data = await get_relation_data(
ops_test=ops_test, application_name=app_name, relation_name=relation_name
)
read_only_endpoints = get_read_only_endpoint_hostnames(relation_data)
# check that the number of read-only-endpoints is correct
assert len(ops_test.model.applications[app_name].units) - 1 == len(read_only_endpoints)
app_hostnames = await get_unit_hostname(ops_test=ops_test, app_name=app_name)
# check that endpoints are the one of the application
for r_endpoint in read_only_endpoints:
assert r_endpoint in app_hostnames
Loading