Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion connect/eaas/runner/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
TRANSFORMATION_TASK_MAX_EXECUTION_TIME = 300
RESULT_SENDER_MAX_RETRIES = 5
RESULT_SENDER_WAIT_GRACE_SECONDS = 90
TRANSFORMATION_TASK_MAX_PARALLEL_LINES = 20
TRANSFORMATION_TASK_MAX_PARALLEL_LINES = 200
DOWNLOAD_CHUNK_SIZE = 1024
UPLOAD_CHUNK_SIZE = 65535
TRANSFORMATION_WRITE_QUEUE_TIMEOUT = 600
Expand Down
83 changes: 52 additions & 31 deletions connect/eaas/runner/managers/transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
from connect.eaas.core.enums import (
ResultType,
)
from connect.eaas.core.logging import (
RequestLogger,
)
from connect.eaas.core.proto import (
Task,
TaskOutput,
Expand All @@ -39,6 +42,9 @@
from connect.eaas.runner.managers.base import (
TasksManagerBase,
)
from connect.eaas.runner.managers.utils import (
ResultStore,
)


logger = logging.getLogger(__name__)
Expand All @@ -56,6 +62,7 @@ def get_client(self, task_data):
endpoint=self.config.get_api_url(),
use_specs=False,
default_headers=self.config.get_user_agent(),
logger=RequestLogger(logger),
)

return self.client
Expand Down Expand Up @@ -130,6 +137,7 @@ async def build_response(self, task_data, future):
return result_message

async def process_transformation(self, task_data, tfn_request, method):
semaphore = asyncio.Semaphore(TRANSFORMATION_TASK_MAX_PARALLEL_LINES)
input_file = await asyncio.get_running_loop().run_in_executor(
self.executor,
self.download_excel,
Expand All @@ -141,7 +149,7 @@ async def process_transformation(self, task_data, tfn_request, method):
)

read_queue = asyncio.Queue(TRANSFORMATION_TASK_MAX_PARALLEL_LINES)
write_queue = asyncio.Queue()
result_store = ResultStore()

loop = asyncio.get_event_loop()

Expand All @@ -156,15 +164,16 @@ async def process_transformation(self, task_data, tfn_request, method):
self.executor,
self.write_excel,
output_file.name,
write_queue,
result_store,
tfn_request['stats']['rows']['total'],
tfn_request['transformation']['columns']['output'],
task_data,
loop,
)
processor_task = asyncio.create_task(self.process_rows(
semaphore,
read_queue,
write_queue,
result_store,
method,
tfn_request['stats']['rows']['total'],
))
Expand Down Expand Up @@ -230,28 +239,30 @@ def read_excel(self, filename, queue, loop):

wb.close()

async def process_rows(self, read_queue, write_queue, method, total_rows):
async def process_rows(self, semaphore, read_queue, result_store, method, total_rows):
rows_processed = 0
tasks = []
while rows_processed < total_rows - 1:
while rows_processed < total_rows:
await semaphore.acquire()
row_idx, row = await read_queue.get()

if inspect.iscoroutinefunction(method):
tasks.append(asyncio.create_task(self.async_process_row(
semaphore,
method,
row_idx,
row,
write_queue,
result_store,
)))
else:
loop = asyncio.get_running_loop()
tasks.append(loop.run_in_executor(
self.executor,
self.sync_process_row,
semaphore,
method,
row_idx,
row,
write_queue,
result_store,
loop,
))

Expand All @@ -265,59 +276,69 @@ async def process_rows(self, read_queue, write_queue, method, total_rows):
task.cancel()
raise e

async def async_process_row(self, method, row_idx, row, write_queue):
async def async_process_row(self, semaphore, method, row_idx, row, result_store):
transformed_row = await method(row)
await write_queue.put((row_idx, transformed_row))
await result_store.put(row_idx, transformed_row)
semaphore.release()

def sync_process_row(self, semaphore, method, row_idx, row, result_store, loop):
async def store_results(transformed_row):
await result_store.put(row_idx, transformed_row)
semaphore.release()

def sync_process_row(self, method, row_idx, row, write_queue, loop):
transformed_row = method(row)
asyncio.run_coroutine_threadsafe(
write_queue.put((row_idx, transformed_row)),
loop,
)
asyncio.run_coroutine_threadsafe(store_results(transformed_row), loop)

def write_excel(self, filename, queue, total_rows, output_columns, task_data, loop):
wb = Workbook()
def write_excel(self, filename, result_store, total_rows, output_columns, task_data, loop):
wb = Workbook(write_only=True)

ws_columns = wb.active
ws_columns = wb.create_sheet('Columns')
ws = wb.create_sheet('Data')
ws_columns.title = 'Columns'
ws_columns.append(['Name', 'Type', 'Nullable', 'Description', 'Precision'])
column_keys = ['name', 'type', 'nullable', 'description', 'precision']
lookup_columns = {}
for col_idx, column in enumerate(output_columns, start=1):

column_names = []

for column in output_columns:
row = [column.get(key) for key in column_keys]
ws_columns.append(row)
lookup_columns[column.get('name')] = col_idx
ws.cell(row=1, column=col_idx, value=column.get('name'))
column_names.append(column.get('name'))

ws.append(column_names)

rows_processed = 0
total_rows -= 1
delta = 1 if total_rows <= 10 else round(total_rows / 10)

while rows_processed < total_rows:
for idx in range(2, total_rows + 2):
future = asyncio.run_coroutine_threadsafe(
queue.get(),
result_store.get(idx),
loop,
)
row_idx, row = future.result(
row_data = future.result(
timeout=self.config.env['transformation_write_queue_timeout'],
)
for name, value in row.items():
ws.cell(row=row_idx, column=lookup_columns[name], value=value)
row = [row_data.get(col_name) for col_name in column_names]

ws.append(row)
logger.debug(f'Row {idx} of {total_rows + 1} written!')
rows_processed += 1
if rows_processed % delta == 0 or rows_processed == total_rows:
asyncio.run_coroutine_threadsafe(
self.send_stat_update(task_data, rows_processed),
self.send_stat_update(task_data, rows_processed, total_rows),
loop,
)
logger.debug(
f'{task_data.input.object_id} processed {rows_processed}'
f' of {total_rows} rows',
)

wb.save(filename)

async def send_stat_update(self, task_data, rows_processed):
async def send_stat_update(self, task_data, rows_processed, total_rows):
client = self.get_client(task_data)
await client('billing').requests[task_data.input.object_id].update(
payload={'rows_processed': rows_processed},
payload={'stats': {'rows': {'total': total_rows, 'processed': rows_processed}}},
)

async def send_output_file(self, task_data, batch_id, output_file):
Expand Down
20 changes: 20 additions & 0 deletions connect/eaas/runner/managers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import asyncio


class ResultStore:
def __init__(self):
self.lock = asyncio.Lock()
self.futures = {}

async def put(self, idx, data):
async with self.lock:
future = self.futures.setdefault(idx, asyncio.Future())
future.set_result(data)

async def get(self, idx):
async with self.lock:
future = self.futures.setdefault(idx, asyncio.Future())
data = await future
async with self.lock:
del self.futures[idx]
return data
2 changes: 1 addition & 1 deletion tests/managers/test_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def transform_it(self, row):
assert result == task

requests = httpx_mock.get_requests()
assert len(requests) == 11
assert len(requests) == 12


@pytest.mark.asyncio
Expand Down