Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 197 additions & 23 deletions ads/telemetry/telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,117 @@
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import os
import re
from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Dict, Optional
from functools import wraps
from typing import Any, Callable, Dict, Optional

import ads.config
from ads import __version__
from ads.common import logger

# --- User agent composition -------------------------------------------------
# Library identifier placed at the start of the user agent string.
LIBRARY = "Oracle-ads"
# Key of the client config entry that carries the user agent details.
USER_AGENT_KEY = "additional_user_agent"
# Placeholder reported when no telemetry value is available.
UNKNOWN = "UNKNOWN"

# --- Telemetry sequence capture ---------------------------------------------
# Default environment variable used to capture the telemetry sequence.
EXTRA_USER_AGENT_INFO = "EXTRA_USER_AGENT_INFO"
# Separator placed between the entries of a telemetry sequence.
DELIMITER = "&"
# Keyword argument under which the `telemetry` decorator passes the Telemetry
# object to the decorated function.
TELEMETRY_ARGUMENT_NAME = "telemetry"


def update_oci_client_config(config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """
    Ensures the signer config carries the ADS user agent information.

    An already populated `additional_user_agent` entry is left untouched.
    Any error raised while composing the value is logged at debug level and
    otherwise ignored.

    Parameters
    ----------
    config: (Dict[str, Any], optional). Defaults to None.
        The signer configuration. A falsy value is replaced with a new dict.

    Returns
    -------
    Dict[str, Any]
        The updated configuration.
    """
    try:
        config = config or {}
        if not config.get(USER_AGENT_KEY):
            # The three segments are built in the order they appear in the
            # final string: library/version, surface, api.
            version_segment = f"{LIBRARY}/version={__version__}#"
            surface_segment = f"surface={Surface.surface().name}#"
            api_segment = (
                f"api={os.environ.get(EXTRA_USER_AGENT_INFO,UNKNOWN) or UNKNOWN}"
            )
            config.update(
                {USER_AGENT_KEY: version_segment + surface_segment + api_segment}
            )
    except Exception as ex:
        # Best effort only: a telemetry failure must not break client setup.
        logger.debug(ex)

    return config


def telemetry(
    entry_point: str = "",
    name: str = "",
    environ_variable: str = EXTRA_USER_AGENT_INFO,
) -> Callable:
    """
    The telemetry decorator.
    Injects the Telemetry object into the `kwargs` arguments of the decorated function.
    This is essential for adding additional information to the telemetry from within the
    decorated function. Eventually this information will be merged into the `additional_user_agent`.

    Important Note: The telemetry decorator exclusively updates the specified environment
    variable and does not perform any additional actions. The previous value of the
    environment variable is restored once the decorated function returns or raises.

    The decorated function must accept the `telemetry` keyword argument. A `telemetry`
    keyword argument supplied by the caller is replaced with the injected object.

    Parameters
    ----------
    entry_point: str
        The entry point of the telemetry.
        Example: "plugin=project&action=run"
    name: str
        The name of the telemetry.
    environ_variable: (str, optional). Defaults to `EXTRA_USER_AGENT_INFO`.
        The name of the environment variable to capture the telemetry sequence.

    Returns
    -------
    Callable
        The decorator to apply to the target function.

    Examples
    --------
    >>> @telemetry(entry_point="plugin=project&action=run", name="ads")
    ... def test_function(**kwargs):
    ...     telemetry = kwargs.get("telemetry")
    ...     telemetry.add("param=hello_world")
    ...     telemetry.print()

    >>> test_function()
    EXTRA_USER_AGENT_INFO = ads&plugin=project&action=run&param=hello_world
    """

    def decorator(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs) -> Any:
            # Named `tracker` so the local does not shadow the decorator itself.
            tracker = Telemetry(name=name, environ_variable=environ_variable)
            try:
                # `begin` runs inside the `try` so that the environment variable,
                # already blanked by the constructor, is restored even if it fails.
                tracker.begin(entry_point)
                return func(*args, **{**kwargs, TELEMETRY_ARGUMENT_NAME: tracker})
            finally:
                tracker.restore()

        return wrapper

    return decorator


class Surface(Enum):
"""
An Enum class for labeling the surface where ADS is being used.
An Enum class used to label the surface where ADS is being utilized.
"""

WORKSTATION = auto()
Expand Down Expand Up @@ -53,28 +148,107 @@ def surface(cls):
return surface


def update_oci_client_config(config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Adds user agent information to the config if it is not setup yet.
@dataclass
class Telemetry:
"""
This class is designed to capture a telemetry sequence and store it in the specified
environment variable. By default the `EXTRA_USER_AGENT_INFO` environment variable is used.

Returns
-------
Dict
The updated configuration.
Attributes
----------
name: (str, optional). Default to empty string.
The name of the telemetry. The very beginning of the telemetry string.
environ_variable: (str, optional). Defaults to `EXTRA_USER_AGENT_INFO`.
The name of the environment variable to capture the telemetry sequence.
"""

try:
config = config or {}
if not config.get(USER_AGENT_KEY):
config.update(
{
USER_AGENT_KEY: (
f"{LIBRARY}/version={__version__}#"
f"surface={Surface.surface().name}#"
f"api={os.environ.get(EXTRA_USER_AGENT_INFO,UNKNOWN) or UNKNOWN}"
)
}
)
except Exception as ex:
logger.debug(ex)
name: str = ""
environ_variable: str = EXTRA_USER_AGENT_INFO

return config
def __post_init__(self):
self.name = self._prepare(self.name)
self._original_value = os.environ.get(self.environ_variable)
os.environ[self.environ_variable] = ""

def restore(self) -> "Telemetry":
    """Restores the original value of the environment variable.

    When the variable was not set at the time the original value was captured
    (the stored value is `None`), the variable is removed again. Assigning
    `None` to `os.environ` is not possible and raises a `TypeError`.

    Returns
    -------
    self: Telemetry
        An instance of the Telemetry.
    """
    if self._original_value is None:
        # The variable did not exist before: drop it rather than assign None.
        os.environ.pop(self.environ_variable, None)
    else:
        os.environ[self.environ_variable] = self._original_value
    return self

def clean(self) -> "Telemetry":
    """Resets the associated environment variable to an empty string.

    The variable stays defined; only its value is blanked.

    Returns
    -------
    self: Telemetry
        An instance of the Telemetry.
    """
    variable_name = self.environ_variable
    os.environ[variable_name] = ""
    return self

def _begin(self):
self.clean()
os.environ[self.environ_variable] = self.name

def begin(self, value: str = "") -> "Telemetry":
    """
    Starts capturing a telemetry sequence.

    The associated environment variable is cleaned first, then the telemetry
    name and finally the given value are added to it.

    Parameters
    ----------
    value: (str, optional). Defaults to empty string.
        The value that need to be added to the telemetry.

    Returns
    -------
    self: Telemetry
        An instance of the Telemetry.
    """
    cleaned = self.clean()
    named = cleaned.add(self.name)
    return named.add(value)

def add(self, value: str) -> "Telemetry":
    """Appends the new value to the telemetry data.

    The value is normalized first. It is skipped when the sequence already
    contains it as a complete, delimiter-bounded entry (or run of entries).
    A mere substring match does not count as a duplicate, so `key=value` is
    still added when only `new_key=value` is present.

    Parameters
    ----------
    value: str
        The value that need to be added to the telemetry.
        An empty value leaves the sequence unchanged.

    Returns
    -------
    self: Telemetry
        An instance of the Telemetry.
    """
    # An empty or missing variable means no sequence is in progress yet.
    if not os.environ.get(self.environ_variable):
        self._begin()

    if not value:
        return self

    current_value = os.environ.get(self.environ_variable, "")
    new_value = self._prepare(value)

    if not current_value:
        os.environ[self.environ_variable] = new_value
        return self

    # Wrap both sides in delimiters so that only whole entries can match.
    bounded_current = f"{DELIMITER}{current_value}{DELIMITER}"
    bounded_new = f"{DELIMITER}{new_value}{DELIMITER}"
    if bounded_new not in bounded_current:
        os.environ[self.environ_variable] = (
            f"{current_value}{DELIMITER}{new_value}"
        )
    return self

def print(self) -> None:
    """Prints the environment variable name together with the telemetry
    sequence currently stored in it."""
    captured_sequence = os.environ.get(self.environ_variable)
    print(f"{self.environ_variable} = {captured_sequence}")

def _prepare(self, value: str):
"""Replaces the special characters with the `_` in the input string."""
return (
re.sub("[^a-zA-Z0-9\.\-\_\&\=]", "_", re.sub(r"\s+", " ", value))
if value
else ""
)
88 changes: 88 additions & 0 deletions tests/unitary/default_setup/telemetry/test_telemetry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#!/usr/bin/env python
# -*- coding: utf-8; -*-

# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import os
from unittest.mock import patch

import pytest

from ads.telemetry import Telemetry


class TestTelemetry:
    """Tests the Telemetry.
    Class to capture telemetry sequence into the environment variable.
    """

    def setup_method(self):
        # Creating a Telemetry blanks its environment variable, and this hook
        # runs outside the per-test `patch.dict` scope. Remember the real value
        # so that `teardown_method` can put it back.
        self._environ_variable = Telemetry.environ_variable
        self._saved_value = os.environ.get(self._environ_variable)
        self.telemetry = Telemetry(name="test.api")

    def teardown_method(self):
        """Puts the environment variable back to its state before the test."""
        if self._saved_value is None:
            os.environ.pop(self._environ_variable, None)
        else:
            os.environ[self._environ_variable] = self._saved_value

    @patch.dict(os.environ, {}, clear=True)
    def test_init(self):
        """Ensures initializing Telemetry passes."""
        self.telemetry = Telemetry("test.api")
        assert self.telemetry.name == "test.api"
        assert self.telemetry.environ_variable in os.environ
        assert os.environ[self.telemetry.environ_variable] == ""

    @patch.dict(os.environ, {}, clear=True)
    def test_add(self):
        """Tests adding the new value to the telemetry."""
        self.telemetry.begin()
        self.telemetry.add("key=value").add("new_key=new_value")
        assert (
            os.environ[self.telemetry.environ_variable]
            == "test.api&key=value&new_key=new_value"
        )

    @patch.dict(os.environ, {}, clear=True)
    def test_begin(self):
        """Tests cleaning the value of the associated environment variable."""
        self.telemetry.begin("key=value")
        assert os.environ[self.telemetry.environ_variable] == "test.api&key=value"

    @patch.dict(os.environ, {}, clear=True)
    def test_clean(self):
        """Ensures that telemetry associated environment variable can be cleaned."""
        self.telemetry.begin()
        self.telemetry.add("key=value").add("new_key=new_value")
        assert (
            os.environ[self.telemetry.environ_variable]
            == "test.api&key=value&new_key=new_value"
        )
        self.telemetry.clean()
        assert os.environ[self.telemetry.environ_variable] == ""

    @patch.dict(os.environ, {"EXTRA_USER_AGENT_INFO": "some_existing_value"}, clear=True)
    def test_restore(self):
        """Ensures that telemetry associated environment variable can be restored to the original value."""
        telemetry = Telemetry(name="test.api")
        telemetry.begin()
        telemetry.add("key=value").add("new_key=new_value")
        assert (
            os.environ[telemetry.environ_variable]
            == "test.api&key=value&new_key=new_value"
        )
        telemetry.restore()
        assert os.environ[telemetry.environ_variable] == "some_existing_value"

    @pytest.mark.parametrize(
        "NAME,INPUT_DATA,EXPECTED_RESULT",
        [
            ("test.api", "key=va~!@#$%^*()_+lue", "key=va____________lue"),
            ("test.api", "key=va lue", "key=va_lue"),
            ("", "key=va123***lue", "key=va123___lue"),
            ("", "", ""),
        ],
    )
    @patch.dict(os.environ, {}, clear=True)
    def test__prepare(self, NAME, INPUT_DATA, EXPECTED_RESULT):
        """Tests replacing special characters in the telemetry input value."""
        telemetry = Telemetry(name=NAME)
        telemetry.begin(INPUT_DATA)
        # The name, when present, always leads the sequence.
        expected_result = f"{NAME}&{EXPECTED_RESULT}" if NAME else EXPECTED_RESULT
        assert os.environ[telemetry.environ_variable] == expected_result