Added initial client for FAH resources, init for compute service
Closes #1.
dotsdl committed Sep 27, 2023
1 parent 7264fbf commit c4bc367
Showing 3 changed files with 378 additions and 2 deletions.
215 changes: 215 additions & 0 deletions alchemiscale_fah/compute/client.py
@@ -0,0 +1,215 @@
#!/usr/bin/env python3

import os
import requests
from typing import Optional
from urllib.parse import urljoin

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.backends import default_backend
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes


# project_data = dict(
# core_id=0x22,
# gens=25000,
# atoms=288449,
# credit=56,
# timeout=0.002,
# deadline=0.005,
# )


class FahAdaptiveSamplingClient:
    def __init__(
        self,
        as_api_url: str,
        ws_api_url: str,
        ws_ip_addr: str,
        certificate_file: os.PathLike = "api-certificate.pem",
        key_file: Optional[os.PathLike] = "api-private.pem",
        verify: bool = True,
    ):
        self.as_api_url = as_api_url
        self.ws_api_url = ws_api_url
        self.ws_ip_addr = ws_ip_addr

        self.certificate = self.read_certificate(certificate_file)

        if key_file is None:
            self.key = self.create_key()
        else:
            self.key = self.read_key(key_file)

        # `requests` expects client TLS credentials as a (certfile, keyfile)
        # pair of file paths; the HTTP helpers below pass this as `self.cert`
        self.cert = (
            (str(certificate_file), str(key_file)) if key_file is not None else None
        )

        self.verify = verify

@staticmethod
def read_key(key_file):
with open(key_file, "rb") as f:
pem = f.read()

return serialization.load_pem_private_key(pem, None, default_backend())

@staticmethod
def read_certificate(certificate_file):
with open(certificate_file, "rb") as f:
pem = f.read()

return x509.load_pem_x509_certificate(pem, default_backend())

    @staticmethod
    def create_key():
        return rsa.generate_private_key(
            backend=default_backend(), public_exponent=65537, key_size=4096
        )

    @staticmethod
    def write_key(key, key_file):
        pem = key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.TraditionalOpenSSL,
            encryption_algorithm=serialization.NoEncryption(),
        )

        with open(key_file, "wb") as f:
            f.write(pem)

    @staticmethod
    def generate_csr(key, csr_file):
        """Generate a certificate signing request (CSR) using the private key.

        It is necessary to create a CSR and present it to an AS in order to
        receive a valid certificate. The CSR is written in PEM format.
        """
        # placeholder common name; use the identity registered with the AS
        cn = x509.NameAttribute(NameOID.COMMON_NAME, "joe@example.com")
        csr = x509.CertificateSigningRequestBuilder()
        csr = csr.subject_name(x509.Name([cn]))
        csr = csr.sign(key, hashes.SHA256())

        with open(csr_file, "wb") as f:
            f.write(csr.public_bytes(serialization.Encoding.PEM))
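
    # Example (hypothetical filenames): bootstrapping credentials for a client
    #   key = FahAdaptiveSamplingClient.create_key()
    #   FahAdaptiveSamplingClient.write_key(key, "api-private.pem")
    #   FahAdaptiveSamplingClient.generate_csr(key, "api-csr.pem")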

    def _check_status(self, r):
        """Raise an exception if the response indicates failure."""
        if r.status_code != 200:
            raise Exception(f"Request failed with {r.status_code}: {r.text}")

def _get(self, api_url, endpoint, **params):
url = urljoin(api_url, endpoint)
r = requests.get(url, cert=self.cert, params=params, verify=self.verify)
self._check_status(r)
return r.json()

def _put(self, api_url, endpoint, **data):
url = urljoin(api_url, endpoint)
r = requests.put(url, json=data, cert=self.cert, verify=self.verify)
self._check_status(r)

def _delete(self, api_url, endpoint):
url = urljoin(api_url, endpoint)
r = requests.delete(url, cert=self.cert, verify=self.verify)
self._check_status(r)

def _upload(self, api_url, endpoint, filename):
url = urljoin(api_url, endpoint)
with open(filename, "rb") as f:
r = requests.put(url, data=f, cert=self.cert, verify=self.verify)
self._check_status(r)

def _download(self, api_url, endpoint, filename):
url = urljoin(api_url, endpoint)
r = requests.get(url, cert=self.cert, verify=self.verify, stream=True)
self._check_status(r)

os.makedirs(os.path.dirname(filename), exist_ok=True)

with open(filename, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)

r.close()

def as_get_ws(self):
"""Get work server attributes from assignment server."""
return self._get(self.as_api_url, f"/ws/{self.ws_ip_addr}")

def as_set_ws(self, as_workserver_data):
"""Set work server attributes on assignment server."""
return self._put(
self.as_api_url, f"/ws/{self.ws_ip_addr}", **as_workserver_data
)

    def as_get_project(self, project_id):
        """Get project attributes from the assignment server."""
        return self._get(
            self.as_api_url,
            f"/ws/{self.ws_ip_addr}/projects/{project_id}",
        )

def as_set_project(self, project_id, weight, constraints):
"""Set project attributes on the assignment server."""
self._put(
self.as_api_url,
f"/ws/{self.ws_ip_addr}/projects/{project_id}",
weight=weight,
constraints=constraints,
)

    def as_reset_project(self, project_id):
        """Set project attributes to default on the assignment server.

        Sets project weight to 0, drops all constraints.
        """
self._put(
self.as_api_url,
f"/ws/{self.ws_ip_addr}/projects/{project_id}",
weight=0,
constraints="",
)

    def create_project(self, project_id, project_data):
        """Create a project on the work server."""
        self._put(self.ws_api_url, f"/projects/{project_id}", **project_data)

    def delete_project(self, project_id):
        """Delete a project from the work server."""
        self._delete(self.ws_api_url, f"/projects/{project_id}")

def start_run(self, project_id, run_id, clones=0):
"""Start a new run."""
self._put(
self.ws_api_url,
f"/projects/{project_id}/runs/{run_id}/create",
clones=clones,
)

    def upload_project_files(self, project_id):
        """Upload the standard project input files from the current directory."""
        files = "core.xml integrator.xml.bz2 state.xml.bz2 system.xml.bz2".split()

        for name in files:
            self._upload(self.ws_api_url, f"/projects/{project_id}/files/{name}", name)

    def get_project(self, project_id):
        """Get project attributes from the work server."""
        return self._get(self.ws_api_url, f"/projects/{project_id}")

    def get_job_files(self, project_id, run_id, clone_id):
        """List files available for the given clone of a run."""
        return self._get(
            self.ws_api_url,
            f"/projects/{project_id}/runs/{run_id}/clones/{clone_id}/files",
        )

    def get_xtcs(self, project_id, run_id, clone_id):
        """Download all XTC trajectory files for the given clone."""
        data = self._get(
            self.ws_api_url,
            f"/projects/{project_id}/runs/{run_id}/clones/{clone_id}/files",
        )

for info in data:
if info["path"].endswith(".xtc"):
self._download(
self.ws_api_url,
f"/projects/{project_id}/runs/{run_id}/clones/{clone_id}/files/{info['path']}",
info["path"],
)
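
A minimal usage sketch for the client above. All URLs, addresses, and project values here are hypothetical placeholders; `project_data` mirrors the commented-out example at the top of the file:

    from alchemiscale_fah.compute.client import FahAdaptiveSamplingClient

    client = FahAdaptiveSamplingClient(
        as_api_url="https://as.example.org",
        ws_api_url="https://ws.example.org",
        ws_ip_addr="203.0.113.10",
        certificate_file="api-certificate.pem",
        key_file="api-private.pem",
    )

    # create a project on the work server, then start run 0 with 10 clones
    project_data = dict(core_id=0x22, gens=25000, atoms=288449, credit=56)
    client.create_project(17100, project_data)
    client.start_run(17100, 0, clones=10)

    # later: pull down any XTC trajectories produced by clone 0 of run 0
    client.get_xtcs(17100, 0, 0)
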
85 changes: 83 additions & 2 deletions alchemiscale_fah/compute/service.py
@@ -3,7 +3,88 @@
===============================================================================================
"""
import os
from typing import Union, Optional, List, Dict, Tuple
from pathlib import Path
from uuid import uuid4
import time
import logging

from alchemiscale.models import Scope, ScopedKey
from alchemiscale.storage.models import Task, TaskHub, ComputeServiceID
from alchemiscale.compute.client import AlchemiscaleComputeClient
from alchemiscale.compute.service import SynchronousComputeService, InterruptableSleep

from .settings import FAHSynchronousComputeServiceSettings
from .client import FahAdaptiveSamplingClient


class FahSynchronousComputeService(SynchronousComputeService):
    """Fully synchronous compute service for utilizing a Folding@Home work server.

    This service is intended for use as a reference implementation, and for
    testing/debugging protocols.
    """

def __init__(self, settings: FAHSynchronousComputeServiceSettings):
"""Create a `FAHSynchronousComputeService` instance."""
self.settings = settings

self.api_url = self.settings.api_url
self.name = self.settings.name
self.sleep_interval = self.settings.sleep_interval
self.heartbeat_interval = self.settings.heartbeat_interval
self.claim_limit = self.settings.claim_limit

self.client = AlchemiscaleComputeClient(
self.settings.api_url,
self.settings.identifier,
self.settings.key,
max_retries=self.settings.client_max_retries,
retry_base_seconds=self.settings.client_retry_base_seconds,
retry_max_seconds=self.settings.client_retry_max_seconds,
verify=self.settings.client_verify,
)

        self.fah_client = FahAdaptiveSamplingClient(...)

if self.settings.scopes is None:
self.scopes = [Scope()]
else:
self.scopes = self.settings.scopes

self.shared_basedir = Path(self.settings.shared_basedir).absolute()
self.shared_basedir.mkdir(exist_ok=True)
self.keep_shared = self.settings.keep_shared

self.scratch_basedir = Path(self.settings.scratch_basedir).absolute()
self.scratch_basedir.mkdir(exist_ok=True)
self.keep_scratch = self.settings.keep_scratch

self.compute_service_id = ComputeServiceID(f"{self.name}-{uuid4()}")

self.int_sleep = InterruptableSleep()

self._stop = False

# logging
extra = {"compute_service_id": str(self.compute_service_id)}
        logger = logging.getLogger("FahSynchronousComputeService")
logger.setLevel(self.settings.loglevel)

formatter = logging.Formatter(
"[%(asctime)s] [%(compute_service_id)s] [%(levelname)s] %(message)s"
)
formatter.converter = time.gmtime # use utc time for logging timestamps

sh = logging.StreamHandler()
sh.setFormatter(formatter)
logger.addHandler(sh)

if self.settings.logfile is not None:
fh = logging.FileHandler(self.settings.logfile)
fh.setFormatter(formatter)
logger.addHandler(fh)

self.logger = logging.LoggerAdapter(logger, extra)
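
A sketch of wiring the service up, assuming the settings model from `settings.py` below and the `start()` loop inherited from `SynchronousComputeService` (URLs, credentials, and paths are placeholders):

    from alchemiscale_fah.compute.settings import FAHSynchronousComputeServiceSettings
    from alchemiscale_fah.compute.service import FahSynchronousComputeService

    settings = FAHSynchronousComputeServiceSettings(
        api_url="https://compute.example.org",
        identifier="compute-identity",
        key="secret-credential",
        name="fah-ws-1",
        shared_basedir="/data/shared",
        scratch_basedir="/data/scratch",
    )

    service = FahSynchronousComputeService(settings)
    service.start()
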
80 changes: 80 additions & 0 deletions alchemiscale_fah/compute/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from pathlib import Path
from typing import Union, Optional, List, Dict, Tuple
from pydantic import BaseModel

from alchemiscale.models import Scope, ScopedKey


class FAHSynchronousComputeServiceSettings(BaseModel):
    """Settings for a `FahSynchronousComputeService`.

    Parameters
    ----------
    api_url
        URL of the compute API to execute Tasks for.
    identifier
        Identifier for the compute identity used for authentication.
    key
        Credential for the compute identity used for authentication.
    name
        The name to give this compute service; used for Task provenance, so
        typically set to a distinct value to distinguish different compute
        resources, e.g. different hosts or HPC clusters.
    shared_basedir
        Filesystem path to use for `ProtocolDAG` `shared` space.
    scratch_basedir
        Filesystem path to use for `ProtocolUnit` `scratch` space.
    keep_shared
        If True, don't remove shared directories for `ProtocolDAG`s after
        completion.
    keep_scratch
        If True, don't remove scratch directories for `ProtocolUnit`s after
        completion.
    sleep_interval
        Time in seconds to sleep if no Tasks claimed from compute API.
    heartbeat_interval
        Frequency at which to send heartbeats to compute API.
    scopes
        Scopes to limit Task claiming to; defaults to all Scopes accessible
        by compute identity.
    claim_limit
        Maximum number of Tasks to claim at a time from a TaskHub.
    loglevel
        The loglevel at which to report; see the :mod:`logging` docs for
        available levels.
    logfile
        Path to file for logging output; if not set, logging will only go
        to STDOUT.
    client_max_retries
        Maximum number of times to retry a request. In the case the API
        service is unresponsive an exponential backoff is applied with
        retries until this number is reached. If set to -1, retries will
        continue indefinitely until success.
    client_retry_base_seconds
        The base number of seconds to use for exponential backoff.
        Must be greater than 1.0.
    client_retry_max_seconds
        Maximum number of seconds to sleep between retries; avoids runaway
        exponential backoff while allowing for many retries.
    client_verify
        Whether to verify SSL certificate presented by the API server.
    """

    api_url: str
    identifier: str
    key: str
    name: str
    shared_basedir: Path
    scratch_basedir: Path
    keep_shared: bool = False
    keep_scratch: bool = False
    sleep_interval: int = 30
    heartbeat_interval: int = 300
    scopes: Optional[List[Scope]] = None
    claim_limit: int = 1
    loglevel: str = "WARN"
    logfile: Optional[Path] = None
    client_max_retries: int = 5
    client_retry_base_seconds: float = 2.0
    client_retry_max_seconds: float = 60.0
    client_verify: bool = True
