Skip to content

Commit

Permalink
Merge pull request #20 from demml/add-hw-table
Browse files Browse the repository at this point in the history
Add hw table
  • Loading branch information
thorrester committed Jun 11, 2024
2 parents cded397 + 5fd246b commit ff888d4
Show file tree
Hide file tree
Showing 25 changed files with 1,072 additions and 825 deletions.
62 changes: 61 additions & 1 deletion opsml/app/routes/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@

from fastapi import APIRouter, HTTPException, Request, status

from opsml.app.routes.pydantic_models import GetMetricRequest, Metrics, Success
from opsml.app.routes.pydantic_models import (
GetMetricRequest,
HardwareMetricscPut,
HardwareMetricsResponse,
Metrics,
Success,
)
from opsml.helpers.logging import ArtifactLogger
from opsml.registry.sql.base.server import ServerRunCardRegistry

Expand Down Expand Up @@ -44,6 +50,34 @@ def insert_metric(request: Request, payload: Metrics) -> Success:
) from error


@router.put("/metrics/hardware", name="hw_metric_put", response_model=Success)
def insert_hw_metrics(
request: Request, payload: HardwareMetricscPut
) -> Success: ## should match hardware metrics schema run_id, timestamp, JSON dict... pydantic_models
"""Inserts metrics into metric table
Args:
request:
FastAPI request object
payload:
MetricsModel
Returns:
200
"""

run_reg: ServerRunCardRegistry = request.app.state.registries.run._registry

try:
run_reg.insert_hw_metrics(payload.model_dump()["metrics"])
return Success()
except Exception as error:
logger.error(f"Failed to insert metrics: {error}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to insert metrics"
) from error


# GET would be used, but we are using POST to allow for a request body so that we can pass in a list of metrics to retrieve
@router.post("/metrics", response_model=Metrics, name="metric_get")
def get_metric(request: Request, payload: GetMetricRequest) -> Metrics:
Expand All @@ -69,3 +103,29 @@ def get_metric(request: Request, payload: GetMetricRequest) -> Metrics:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get metrics"
) from error


@router.get("/metrics/hardware", response_model=HardwareMetricsResponse, name="hw_metric_get")
def get_hw_metric(request: Request, run_uid: str) -> HardwareMetricsResponse:
"""Get metrics from hw metric table
Args:
request:
FastAPI request object
run_uid:
Run UID
Returns:
`HardwareMetricsResponse`
"""

run_reg: ServerRunCardRegistry = request.app.state.registries.run._registry
try:
metrics = run_reg.get_hw_metric(run_uid=run_uid)
return HardwareMetricsResponse(metrics=[] if metrics is None else metrics)

except Exception as error:
logger.error(f"Failed to get metrics: {error}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to get metrics"
) from error
26 changes: 24 additions & 2 deletions opsml/app/routes/pydantic_models.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
# Copyright (c) Shipt, Inc.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

from fastapi import File, Form, UploadFile
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, field_serializer, model_validator

from opsml.cards.audit import AuditSections
from opsml.model.challenger import BattleReport
from opsml.registry.semver import CardVersion, VersionType
from opsml.types import Comment
from opsml.types import Comment, HardwareMetrics


class Success(BaseModel):
Expand Down Expand Up @@ -266,6 +267,27 @@ class Metrics(BaseModel):
metric: Union[Optional[List[Metric]], Optional[List[str]]]


class HardwareMetricRecord(BaseModel):
run_uid: str
created_at: Optional[datetime.datetime] = None
metrics: HardwareMetrics

# serialize datetime
@field_serializer("created_at")
def serialize_created_at(self, value: Optional[datetime.datetime] = None) -> Optional[str]:
if value is not None:
return value.isoformat()
return value


class HardwareMetricscPut(BaseModel):
metrics: List[HardwareMetricRecord]


class HardwareMetricsResponse(BaseModel):
metrics: List[HardwareMetricRecord] = []


class GetMetricRequest(BaseModel):
run_uid: str
name: Optional[List[str]] = None
Expand Down
9 changes: 9 additions & 0 deletions opsml/cards/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,15 @@ def load_metrics(self) -> None:
self.metrics[_metric.name].append(_metric)
return None

def get_hardware_metrics(self) -> Optional[List[Dict[str, Any]]]:
"""Returns hardware metrics recorded during run.
Returns:
List of dictionaries containing hardware metrics
"""
assert self.uid is not None, "RunCard must be registered to get hardware metrics"
return self._registry.get_hw_metric(run_uid=self.uid)

def get_parameter(self, name: str) -> Union[List[Param], Param]:
"""
Gets a parameter by name
Expand Down
7 changes: 5 additions & 2 deletions opsml/helpers/gcp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _get_creds(self) -> Tuple[Optional[Union[ComputeEngineCredentials, Credentia
return self.get_default_creds()

def get_default_creds(self) -> Tuple[Optional[ComputeEngineCredentials], Optional[str], bool]:
credentials, project_id = google.auth.default()
credentials, project_id = google.auth.default() # type: ignore

return credentials, project_id, True

Expand All @@ -72,7 +72,10 @@ def create_gcp_creds_from_base64(self, service_base64_creds: str) -> Tuple[Crede
"""
scopes = {"scopes": ["https://www.googleapis.com/auth/devstorage.full_control"]} # needed for gcsfs
key = self.decode_base64(service_base64_creds=service_base64_creds)
service_creds: Credentials = service_account.Credentials.from_service_account_info(info=key, **scopes) # noqa
service_creds: Credentials = service_account.Credentials.from_service_account_info( # type: ignore # noqa
info=key,
**scopes,
)
project_name = cast(str, service_creds.project_id)

return service_creds, project_name, False
40 changes: 2 additions & 38 deletions opsml/projects/_hw_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,54 +6,18 @@
import abc
import os
import time
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

import psutil
from pydantic import BaseModel

from opsml.helpers.logging import ArtifactLogger
from opsml.types import CPUMetrics, HardwareMetrics, MemoryMetrics, NetworkRates

logger = ArtifactLogger.get_logger()

_UTILIZATION_MEASURE_INTERVAL = 0.3


class CPUMetrics(BaseModel):
"""CPU metrics data model."""

cpu_percent_avg: float = 0.0
cpu_percent_per_core: Optional[List[float]] = None
compute_overall: Optional[float] = None
compute_utilized: Optional[float] = None
load_avg: float


class MemoryMetrics(BaseModel):
"""Memory metrics data model."""

sys_ram_total: int = 0
sys_ram_used: int = 0
sys_ram_available: int = 0
sys_ram_percent_used: float = 0.0
sys_swap_total: Optional[int] = None
sys_swap_used: Optional[int] = None
sys_swap_free: Optional[int] = None
sys_swap_percent: Optional[float] = None


class NetworkRates(BaseModel):
"""Network rates data model."""

bytes_recv: float = 0.0
bytes_sent: float = 0.0


class HardwareMetrics(BaseModel):
cpu: CPUMetrics
memory: MemoryMetrics
network: NetworkRates


class BaseMetricsLogger(abc.ABC):
"""The base class for all system metrics data loggers. It implements common scheduling and error handling."""

Expand Down
36 changes: 23 additions & 13 deletions opsml/projects/_run_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
import concurrent
import time
import uuid
from datetime import datetime
from queue import Empty, Queue
from typing import Dict, Optional, Union, cast
from typing import Any, Dict, Optional, Union, cast

from opsml.cards import RunCard
from opsml.helpers.logging import ArtifactLogger
from opsml.projects._hw_metrics import HardwareMetrics, HardwareMetricsLogger
from opsml.projects._hw_metrics import HardwareMetricsLogger
from opsml.projects.active_run import ActiveRun, RunInfo
from opsml.projects.types import _DEFAULT_INTERVAL, ProjectInfo, Tags
from opsml.registry import CardRegistries
Expand All @@ -20,23 +21,33 @@
logger = ArtifactLogger.get_logger()


def put_hw_metrics(interval: int, run: "ActiveRun", queue: Queue[HardwareMetrics]) -> bool:
def put_hw_metrics(
interval: int,
run: "ActiveRun",
queue: Queue[Dict[str, Union[str, datetime, Dict[str, Any]]]],
) -> bool:
hw_logger = HardwareMetricsLogger(interval=interval)

while run.active: # producer function for hw output
metrics = hw_logger.get_metrics()
metrics: Dict[str, Union[str, datetime, Dict[str, Any]]] = {
"metrics": hw_logger.get_metrics().model_dump(),
"run_uid": run.run_id,
}

# add to the queue
queue.put(metrics, block=False)
logger.info("Metrics in queue: {}", metrics)
time.sleep(interval)

logger.info("Hardware logger stopped")

return False


def get_hw_metrics(interval: int, run: "ActiveRun", queue: Queue[HardwareMetrics]) -> None:
def get_hw_metrics(
interval: int,
run: "ActiveRun",
queue: Queue[Dict[str, Union[str, datetime, Dict[str, Any]]]],
) -> None:
"""Pull hardware metrics from the queue and log them.
Args:
Expand All @@ -49,14 +60,13 @@ def get_hw_metrics(interval: int, run: "ActiveRun", queue: Queue[HardwareMetrics
"""
while run.active: # consumer function for hw output
try:
metrics_unit = queue.get(timeout=1)
# report
logger.info("Got metrics: {}", metrics_unit.model_dump())
metrics = queue.get(timeout=1)
run.runcard._registry.insert_hw_metrics([metrics])

except Empty:
pass

time.sleep(interval + 0.5)
time.sleep(interval / 2)


class ActiveRunException(Exception): ...
Expand Down Expand Up @@ -171,10 +181,10 @@ def _log_hardware_metrics(self, interval: int) -> None:
assert self.active_run is not None, "active_run should not be None"

# run hardware logger in background thread
queue: Queue[HardwareMetrics] = Queue()
queue: Queue[Dict[str, Union[str, datetime, Dict[str, Any]]]] = Queue()
executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
executor.submit(put_hw_metrics, interval, self.active_run, queue)
executor.submit(get_hw_metrics, interval, self.active_run, queue)
executor.submit(put_hw_metrics, interval, self.active_run, queue)
self.thread_executor = executor

def start_run(
Expand Down Expand Up @@ -237,5 +247,5 @@ def end_run(self) -> None:

# check if thread executor is still running
if self.thread_executor is not None:
self.thread_executor.shutdown(wait=True, cancel_futures=True)
self.thread_executor.shutdown(wait=False, cancel_futures=True)
self._thread_executor = None
1 change: 0 additions & 1 deletion opsml/registry/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ class ModelRegistryRecord(SaveRecord):
@model_validator(mode="before")
@classmethod
def set_metadata(cls, values: Dict[str, Any]) -> Dict[str, Any]:
print(values["datacard_uid"])
metadata: Dict[str, Any] = values["metadata"]
values["sample_data_type"] = metadata["data_schema"]["data_type"]
values["model_type"] = values["interface"]["model_type"]
Expand Down
34 changes: 34 additions & 0 deletions opsml/registry/sql/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,20 @@ def insert_metric(self, metric: List[Dict[str, Any]]) -> None:
json={"metric": metric},
)

def insert_hw_metrics(self, metrics: List[Dict[str, Any]]) -> None:
"""Inserts metrics into the run registry
Args:
metrics:
List of hw metric(s) to insert
"""

self._session.request(
route=api_routes.HW_METRICS,
request_type=RequestType.PUT,
json={"metrics": metrics},
)

def get_metric(
self,
run_uid: str,
Expand Down Expand Up @@ -366,6 +380,26 @@ def get_metric(
metric = data.get("metric")
return cast(Optional[List[Dict[str, Any]]], metric)

def get_hw_metric(self, run_uid: str) -> Optional[List[Dict[str, Any]]]:
"""Gets run hardware metrics
Args:
run_uid:
Run uid
Returns:
List of run metrics
"""

data = self._session.request(
route=api_routes.HW_METRICS,
request_type=RequestType.GET,
params={"run_uid": run_uid},
)

metric = data.get("metrics")
return cast(Optional[List[Dict[str, Any]]], metric)

@staticmethod
def validate(registry_name: str) -> bool:
return registry_name.lower() == RegistryType.RUN.value
Expand Down
Loading

0 comments on commit ff888d4

Please sign in to comment.