Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 43 additions & 33 deletions src/eda.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import logging

import requests
import yaml

from src.http_client import create_pool_manager

# configure logging
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -34,21 +35,24 @@ def __init__(self, hostname, username, password, verify):
self.version = None
self.transactions = []

self.http = create_pool_manager(url=self.url, verify=self.verify)

def login(self):
"""
Retrieves an access_token and refresh_token from the EDA API
"""
payload = {"username": self.username, "password": self.password}

response = self.post("auth/login", payload, False).json()
response = self.post("auth/login", payload, False)
response_data = json.loads(response.data.decode("utf-8"))

if "code" in response and response["code"] != 200:
if "code" in response_data and response_data["code"] != 200:
raise Exception(
f"Could not authenticate with EDA, error message: '{response['message']} {response['details']}'"
f"Could not authenticate with EDA, error message: '{response_data['message']} {response_data['details']}'"
)

self.access_token = response["access_token"]
self.refresh_token = response["refresh_token"]
self.access_token = response_data["access_token"]
self.refresh_token = response_data["refresh_token"]

def get_headers(self, requires_auth):
"""
Expand Down Expand Up @@ -88,11 +92,7 @@ def get(self, api_path, requires_auth=True):
url = f"{self.url}/{api_path}"
logger.info(f"Performing GET request to '{url}'")

return requests.get(
url,
verify=self.verify,
headers=self.get_headers(requires_auth),
)
return self.http.request("GET", url, headers=self.get_headers(requires_auth))

def post(self, api_path, payload, requires_auth=True):
"""
Expand All @@ -110,11 +110,11 @@ def post(self, api_path, payload, requires_auth=True):
"""
url = f"{self.url}/{api_path}"
logger.info(f"Performing POST request to '{url}'")
return requests.post(
return self.http.request(
"POST",
url,
verify=self.verify,
json=payload,
headers=self.get_headers(requires_auth),
body=json.dumps(payload).encode("utf-8"),
)

def is_up(self):
Expand All @@ -127,8 +127,9 @@ def is_up(self):
"""
logger.info("Checking whether EDA is up")
health = self.get("core/about/health", requires_auth=False)
logger.debug(health.json())
return health.json()["status"] == "UP"
health_data = json.loads(health.data.decode("utf-8"))
logger.debug(health_data)
return health_data["status"] == "UP"

def get_version(self):
"""
Expand All @@ -140,7 +141,9 @@ def get_version(self):
return self.version

logger.info("Getting EDA version")
version = self.get("core/about/version").json()["eda"]["version"].split("-")[0]
version_response = self.get("core/about/version")
version_data = json.loads(version_response.data.decode("utf-8"))
version = version_data["eda"]["version"].split("-")[0]
logger.info(f"EDA version is {version}")

# storing this to make the tool backwards compatible
Expand Down Expand Up @@ -254,18 +257,20 @@ def is_transaction_item_valid(self, item):
logger.info("Validating transaction item")

response = self.post("core/transaction/v1/validate", item)
if response.status_code == 204:
if response.status == 204:
logger.info("Validation successful")
return True

response = response.json()
response_data = json.loads(
response.data.decode("utf-8")
) # Need to decode response data

if "code" in response:
message = f"{response['message']}"
if "details" in response:
message = f"{message} - {response['details']}"
if "code" in response_data:
message = f"{response_data['message']}"
if "details" in response_data:
message = f"{message} - {response_data['details']}"
logger.warning(
f"While validating a transaction item, the following validation error was returned (code {response['code']}): '{message}'"
f"While validating a transaction item, the following validation error was returned (code {response_data['code']}): '{message}'"
)

return False
Expand Down Expand Up @@ -295,16 +300,21 @@ def commit_transaction(
logger.info(f"Committing transaction with {len(self.transactions)} item(s)")
logger.debug(json.dumps(payload, indent=4))

response = self.post("core/transaction/v1", payload).json()
if "id" not in response:
raise Exception(f"Could not find transaction ID in response {response}")
response = self.post("core/transaction/v1", payload)
response_data = json.loads(response.data.decode("utf-8"))
if "id" not in response_data:
raise Exception(
f"Could not find transaction ID in response {response_data}"
)

transactionId = response["id"]
transactionId = response_data["id"]

logger.info(f"Waiting for transaction with ID {transactionId} to complete")
result = self.get(
f"core/transaction/v1/details/{transactionId}?waitForComplete=true&failOnErrors=true"
).json()
result = json.loads(
self.get(
f"core/transaction/v1/details/{transactionId}?waitForComplete=true&failOnErrors=true"
).data.decode("utf-8")
)

if "code" in result:
message = f"{result['message']}"
Expand Down Expand Up @@ -348,7 +358,7 @@ def revert_transaction(self, transactionId):
).json()

response = self.post(f"core/transaction/v1/revert/{transactionId}", {})
result = response.json()
result = json.loads(response.data.decode("utf-8"))

if "code" in result and result["code"] != 0:
message = f"{result['message']}"
Expand Down Expand Up @@ -392,7 +402,7 @@ def restore_transaction(self, transactionId):
).json()

response = self.post(f"core/transaction/v1/restore/{restore_point}", {})
result = response.json()
result = json.loads(response.data.decode("utf-8"))

if "code" in result and result["code"] != 0:
message = f"{result['message']}"
Expand Down
39 changes: 0 additions & 39 deletions src/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import sys
import tempfile

import requests
from jinja2 import Environment, FileSystemLoader

import src.topology as topology
Expand Down Expand Up @@ -98,44 +97,6 @@ def apply_manifest_via_kubectl(yaml_str: str, namespace: str = "eda-system"):
finally:
os.remove(tmp_path)


def get_artifact_from_github(owner: str, repo: str, version: str, asset_filter=None):
"""
Queries GitHub for a specific release artifact.

Parameters
----------
owner: GitHub repository owner
repo: GitHub repository name
version: Version tag to search for (without 'v' prefix)
asset_filter: Optional function(asset_name) -> bool to filter assets

Returns
-------
Tuple of (filename, download_url) or (None, None) if not found
"""
tag = f"v{version}" # Assume GitHub tags are prefixed with 'v'
url = f"https://api.github.com/repos/{owner}/{repo}/releases/tags/{tag}"

logger.info(f"Querying GitHub release {tag} from {owner}/{repo}")
resp = requests.get(url)

if resp.status_code != 200:
logger.warning(f"Failed to fetch release for {tag}, status={resp.status_code}")
return None, None

data = resp.json()
assets = data.get("assets", [])

for asset in assets:
name = asset.get("name", "")
if asset_filter is None or asset_filter(name):
return name, asset.get("browser_download_url")

# No matching asset found
return None, None


def normalize_name(name: str) -> str:
"""
Returns a Kubernetes-compliant name by:
Expand Down
139 changes: 139 additions & 0 deletions src/http_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import logging
import os
import re
import urllib3
from urllib.parse import urlparse

logger = logging.getLogger(__name__)


def get_proxy_settings():
"""
Get proxy settings from environment variables.
Handles both upper and lowercase variants.

Returns
-------
tuple: (http_proxy, https_proxy, no_proxy)
"""
# Check both variants
http_upper = os.environ.get("HTTP_PROXY")
http_lower = os.environ.get("http_proxy")
https_upper = os.environ.get("HTTPS_PROXY")
https_lower = os.environ.get("https_proxy")
no_upper = os.environ.get("NO_PROXY")
no_lower = os.environ.get("no_proxy")

# Log if both variants are set
if http_upper and http_lower and http_upper != http_lower:
logger.warning(
f"Both HTTP_PROXY ({http_upper}) and http_proxy ({http_lower}) are set with different values. Using HTTP_PROXY."
)

if https_upper and https_lower and https_upper != https_lower:
logger.warning(
f"Both HTTPS_PROXY ({https_upper}) and https_proxy ({https_lower}) are set with different values. Using HTTPS_PROXY."
)

if no_upper and no_lower and no_upper != no_lower:
logger.warning(
f"Both NO_PROXY ({no_upper}) and no_proxy ({no_lower}) are set with different values. Using NO_PROXY."
)

# Use uppercase variants if set, otherwise lowercase
http_proxy = http_upper if http_upper is not None else http_lower
https_proxy = https_upper if https_upper is not None else https_lower
no_proxy = no_upper if no_upper is not None else no_lower or ""

return http_proxy, https_proxy, no_proxy


def should_bypass_proxy(url, no_proxy=None):
"""
Check if the given URL should bypass proxy based on NO_PROXY settings.

Parameters
----------
url : str
The URL to check
no_proxy : str, optional
The NO_PROXY string to use. If None, gets from environment.

Returns
-------
bool
True if proxy should be bypassed, False otherwise
"""
if no_proxy is None:
_, _, no_proxy = get_proxy_settings()

if not no_proxy:
return False

parsed_url = urlparse(url if "//" in url else f"http://{url}")
hostname = parsed_url.hostname

if not hostname:
return False

# Split NO_PROXY into parts and clean them
no_proxy_parts = [p.strip() for p in no_proxy.split(",") if p.strip()]

for no_proxy_value in no_proxy_parts:
# Convert .foo.com to foo.com
if no_proxy_value.startswith("."):
no_proxy_value = no_proxy_value[1:]

# Handle IP addresses and CIDR notation
if re.match(r"^(?:\d{1,3}\.){3}\d{1,3}(?:/\d{1,2})?$", no_proxy_value):
# TODO: Implement CIDR matching if needed
if hostname == no_proxy_value:
return True
# Handle domain names with wildcards
else:
pattern = re.escape(no_proxy_value).replace(r"\*", ".*")
if re.match(f"^{pattern}$", hostname, re.IGNORECASE):
return True

return False


def create_pool_manager(url=None, verify=True):
"""
Create a PoolManager or ProxyManager based on environment settings and URL

Parameters
----------
url : str, optional
The URL that will be accessed with this pool manager
If provided, NO_PROXY rules will be checked
verify : bool
Whether to verify SSL certificates

Returns
-------
urllib3.PoolManager or urllib3.ProxyManager
"""
http_proxy, https_proxy, no_proxy = get_proxy_settings()

# Check if this URL should bypass proxy
if url and should_bypass_proxy(url, no_proxy):
logger.debug(f"URL {url} matches NO_PROXY rules, creating direct PoolManager")
return urllib3.PoolManager(
cert_reqs="CERT_REQUIRED" if verify else "CERT_NONE",
retries=urllib3.Retry(3),
)

proxy_url = https_proxy or http_proxy
if proxy_url:
logger.debug(f"Creating ProxyManager with proxy URL: {proxy_url}")
return urllib3.ProxyManager(
proxy_url,
cert_reqs="CERT_REQUIRED" if verify else "CERT_NONE",
retries=urllib3.Retry(3),
)

logger.debug("Creating PoolManager without proxy")
return urllib3.PoolManager(
cert_reqs="CERT_REQUIRED" if verify else "CERT_NONE", retries=urllib3.Retry(3)
)
Loading