Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 32 additions & 38 deletions connect/eaas/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,58 +48,52 @@ class TaskPayload:
task_category: str
task_type: str
object_id: str
result: str = None
result: Optional[str] = None
data: Any = None
countdown: int = 0
output: str = None
correlation_id: str = None
reply_to: str = None

def to_json(self):
return dataclasses.asdict(self)
output: Optional[str] = None
correlation_id: Optional[str] = None
reply_to: Optional[str] = None


@dataclasses.dataclass
class ConfigurationPayload:
configuration: dict = None
logging_api_key: str = None
environment_type: str = None
log_level: str = None
runner_log_level: str = None

def to_json(self):
return dataclasses.asdict(self)
configuration: Optional[dict] = None
logging_api_key: Optional[str] = None
environment_type: Optional[str] = None
account_id: Optional[str] = None
account_name: Optional[str] = None
log_level: Optional[str] = None
runner_log_level: Optional[str] = None


@dataclasses.dataclass
class CapabilitiesPayload:
capabilities: dict
readme_url: str = None
changelog_url: str = None
readme_url: Optional[str] = None
changelog_url: Optional[str] = None

def to_json(self):
return dataclasses.asdict(self)


@dataclasses.dataclass(init=False)
@dataclasses.dataclass
class Message:
message_type: str
data: Optional[Union[CapabilitiesPayload, ConfigurationPayload, TaskPayload]] = None

def __init__(self, message_type=None, data=None):
self.message_type = message_type
if isinstance(data, dict):
if self.message_type == MessageType.CONFIGURATION:
self.data = ConfigurationPayload(**data)
elif self.message_type == MessageType.TASK:
self.data = TaskPayload(**data)
elif self.message_type == MessageType.CAPABILITIES:
self.data = CapabilitiesPayload(**data)
else:
self.data = data

def to_json(self):
payload = {'message_type': self.message_type}
if self.data:
payload['data'] = dataclasses.asdict(self.data)
return payload

def from_dict(cls, data):
field_names = set(f.name for f in dataclasses.fields(cls))
return cls(**{k: v for k, v in data.items() if k in field_names})


def parse_message(payload):
message_type = payload['message_type']
if message_type == MessageType.CONFIGURATION:
data = from_dict(ConfigurationPayload, payload.get('data'))
elif message_type == MessageType.TASK:
data = from_dict(TaskPayload, payload.get('data'))
elif message_type == MessageType.CAPABILITIES:
data = from_dict(CapabilitiesPayload, payload.get('data'))
else:
data = payload.get('data')

return Message(message_type=message_type, data=data)
13 changes: 7 additions & 6 deletions connect/eaas/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Copyright (c) 2021 Ingram Micro. All Rights Reserved.
#
import asyncio
import dataclasses
import inspect
import logging
import traceback
Expand Down Expand Up @@ -90,7 +91,7 @@ async def submit_task(self, data):
object_id = data.object_id
task_type = data.task_type
method_name = TASK_TYPE_EXT_METHOD_MAP[task_type]
extension = self.worker.get_extension()
extension = self.worker.get_extension(data.task_id)
method = getattr(extension, method_name)
logger.debug(f'invoke {method_name}')
self.running_tasks += 1
Expand Down Expand Up @@ -175,7 +176,7 @@ async def result_sender(self): # noqa: CCR001
message_type=MessageType.TASK,
data=result,
)
await self.worker.send(message.to_json())
await self.worker.send(dataclasses.asdict(message))
logger.info(f'Result for task {result.task_id} has been sent.')
break
except Exception:
Expand Down Expand Up @@ -222,13 +223,13 @@ async def build_bg_response(self, task_data, future):
"""
Wait for a background task to be completed and than uild the task result message.
"""
result_message = TaskPayload(**task_data.to_json())
result_message = TaskPayload(**dataclasses.asdict(task_data))
result = None
try:
result = await asyncio.wait_for(future, timeout=BACKGROUND_TASK_MAX_EXECUTION_TIME)
except Exception as e:
logger.warning(f'Got exception during execution of task {task_data.task_id}: {e}')
self.worker.get_extension().logger.exception(
self.worker.get_extension(task_data.task_id).logger.exception(
f'Unhandled exception during execution of task {task_data.task_id}',
)
result_message.result = ResultType.RETRY
Expand All @@ -249,12 +250,12 @@ async def build_interactive_response(self, task_data, future):
Wait for an interactive task to be completed and than uild the task result message.
"""
result = None
result_message = TaskPayload(**task_data.to_json())
result_message = TaskPayload(**dataclasses.asdict(task_data))
try:
result = await asyncio.wait_for(future, timeout=INTERACTIVE_TASK_MAX_EXECUTION_TIME)
except Exception as e:
logger.warning(f'Got exception during execution of task {task_data.task_id}: {e}')
self.worker.get_extension().logger.exception(
self.worker.get_extension(task_data.task_id).logger.exception(
f'Unhandled exception during execution of task {task_data.task_id}',
)
result_message.result = ResultType.FAIL
Expand Down
42 changes: 33 additions & 9 deletions connect/eaas/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# Copyright (c) 2021 Ingram Micro. All Rights Reserved.
#
import asyncio
import dataclasses
import json
import logging
from asyncio.exceptions import TimeoutError
Expand All @@ -17,13 +18,21 @@
)

from connect.client import AsyncConnectClient, ConnectClient
from connect.eaas.dataclasses import CapabilitiesPayload, Message, MessageType
from connect.eaas.dataclasses import (
CapabilitiesPayload,
Message,
MessageType,
parse_message,
)
from connect.eaas.helpers import (
get_environment,
get_extension_class,
get_extension_type,
)
from connect.eaas.logging import ExtensionLogHandler, RequestLogger
from connect.eaas.logging import (
ExtensionLogHandler,
RequestLogger,
)
from connect.eaas.manager import TasksManager


Expand Down Expand Up @@ -63,6 +72,8 @@ def __init__(self, secure=True):
self.paused = False
self.logging_handler = None
self.environment_type = None
self.account_id = None
self.account_name = None

async def ensure_connection(self):
"""
Expand Down Expand Up @@ -93,7 +104,7 @@ async def receive(self):
except TimeoutError:
pass

def get_client(self):
def get_client(self, task_id):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like here you didn't change comment for the function

"""
Get an instance of the Connect Openapi Client. If the extension is asyncrhonous
it returns an instance of the AsyncConnectClient otherwise the ConnectClient.
Expand All @@ -104,7 +115,10 @@ def get_client(self):
endpoint=f'https://{self.api_address}/public/v1',
use_specs=False,
logger=RequestLogger(
self.get_extension_logger(self.logging_api_key),
logging.LoggerAdapter(
self.get_extension_logger(self.logging_api_key),
{'task_id': task_id},
),
),
)

Expand All @@ -122,6 +136,8 @@ def get_extension_logger(self, token):
'environment_id': self.environment_id,
'instance_id': self.instance_id,
'environment_type': self.environment_type,
'account_id': self.account_id,
'account_name': self.account_name,
'api_address': self.api_address,
},
)
Expand Down Expand Up @@ -151,10 +167,13 @@ def get_url(self):
url = f'{self.base_ws_url}/{self.environment_id}/{self.instance_id}'
return f'{url}?running_tasks={running_tasks}'

def get_extension(self):
def get_extension(self, task_id):
return self.extension_class(
self.get_client(),
self.get_extension_logger(self.logging_api_key),
self.get_client(task_id),
logging.LoggerAdapter(
self.get_extension_logger(self.logging_api_key),
{'task_id': task_id},
),
self.extension_config,
)

Expand All @@ -176,7 +195,7 @@ async def run(self): # noqa: CCR001
self.changelog_url,
),
)
await self.send(message.to_json())
await self.send(dataclasses.asdict(message))
while self.run_event.is_set():
await self.ensure_connection()
self.ensure_tasks_manager_running()
Expand Down Expand Up @@ -208,7 +227,7 @@ async def process_message(self, data):
"""
Process a message received from the websocket server.
"""
message = Message(**data)
message = parse_message(data)
if message.message_type == MessageType.CONFIGURATION:
await self.configuration(message.data)
elif message.message_type == MessageType.TASK:
Expand All @@ -232,6 +251,11 @@ async def configuration(self, data):
self.logging_api_key = data.logging_api_key
if data.environment_type:
self.environment_type = data.environment_type
if data.account_id:
self.account_id = data.account_id
if data.account_name:
self.account_name = data.account_name

if data.log_level:
logger.info(f'Change extesion logger level to {data.log_level}')
logging.getLogger('eaas.extension').setLevel(
Expand Down
113 changes: 76 additions & 37 deletions tests/test_dataclasses.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,93 @@
import dataclasses

from connect.eaas.dataclasses import (
CapabilitiesPayload,
ConfigurationPayload,
from_dict,
Message,
MessageType,

parse_message,
TaskPayload,
)


def test_capabilities_payload():
assert CapabilitiesPayload(
{'cap1': 'val1'},
'https://example.com/readme',
'https://example.com/changelog',
).to_json() == {
'capabilities': {'cap1': 'val1'},
'readme_url': 'https://example.com/readme',
'changelog_url': 'https://example.com/changelog',
def test_from_dict():
data = {
'capabilities': {'test': 'data'},
'readme_url': 'https://read.me',
'changelog_url': 'https://change.log',
'extra': 'data',
}
capabilities = from_dict(CapabilitiesPayload, data)

assert capabilities.capabilities == data['capabilities']
assert capabilities.changelog_url == data['changelog_url']
assert capabilities.readme_url == data['readme_url']

def test_configuration_payload():
assert ConfigurationPayload(
{'conf1': 'val1'},
'logging-token',
'environ-type',
'log-level',
'runner-log-level',
).to_json() == {
'configuration': {'conf1': 'val1'},
'logging_api_key': 'logging-token',
'environment_type': 'environ-type',
'log_level': 'log-level',
'runner_log_level': 'runner-log-level',

def test_parse_task_message():
msg_data = {
'message_type': 'task',
'data': {
'task_id': 'task_id',
'task_category': 'task_category',
'task_type': 'task_type',
'object_id': 'object_id',
'result': 'result',
'data': {'data': 'value'},
'countdown': 10,
'output': 'output',
'correlation_id': 'correlation_id',
'reply_to': 'reply_to',
},
}

message = parse_message(msg_data)

assert isinstance(message, Message)
assert message.message_type == MessageType.TASK
assert isinstance(message.data, TaskPayload)

def test_message_capabilities():
cap = CapabilitiesPayload(
{'cap1': 'val1'},
'https://example.com/readme',
'https://example.com/changelog',
)
assert dataclasses.asdict(message) == msg_data

msg = Message(
MessageType.CAPABILITIES,
cap.to_json(),
)
assert msg.data == cap

assert msg.to_json() == {
'message_type': MessageType.CAPABILITIES,
'data': cap.to_json(),
def test_parse_capabilities_message():
msg_data = {
'message_type': 'capabilities',
'data': {
'capabilities': {'test': 'data'},
'readme_url': 'https://read.me',
'changelog_url': 'https://change.log',
},
}

message = parse_message(msg_data)

assert isinstance(message, Message)
assert message.message_type == MessageType.CAPABILITIES
assert isinstance(message.data, CapabilitiesPayload)

assert dataclasses.asdict(message) == msg_data


def test_parse_configuration_message():
msg_data = {
'message_type': 'configuration',
'data': {
'configuration': {'conf1': 'val1'},
'logging_api_key': 'logging-token',
'environment_type': 'environ-type',
'log_level': 'log-level',
'runner_log_level': 'runner-log-level',
'account_id': 'account_id',
'account_name': 'account_name',
},
}

message = parse_message(msg_data)

assert isinstance(message, Message)
assert message.message_type == MessageType.CONFIGURATION
assert isinstance(message.data, ConfigurationPayload)

assert dataclasses.asdict(message) == msg_data
Loading