Skip to content

Commit

Permalink
Feature / Run flow (preview) (#139)
Browse files Browse the repository at this point in the history
* Stub example files for model chaining example

* Fix duplicate license headers in Python runtime

* Simplify execution graph for model import

* Simplify graph nodes

* Update graph builder for simplified nodes, add first draft of build_flow

* Experimental work - use typed wrapper classes for node results

* Graph fixes to allow using_data example to run

* Engine fixes following graph updates

* Engine fixes following graph updates

* Fix CI for the oldest supported version of PySpark, following a breaking change in PyPanDoc

* Remove NodeResult wrapper classes

* Engine updates

* Rename DataSpec class

* Update handling of section dependencies in graph builder

* Move config quoting into config_parser module

* Use YAML instead of JSON for example flow definition in the chaining exmaple

* Run the chaining example in the example tests

* Autowire flows as part of dev mode translation

* Generate flow parameters in dev mode

* Build flow fixes in dev_mode and graph_builder

* Update model chaining example

* Run flow work

* Run flow work

* Run flow work

* Run flow work

* Chaining example working in the runtime

* Dependency graph updates

* Fix Python 3.7 support in runtime engine type match check

* Fix Python 3.7 support in runtime engine type match check

* Make get_origin and get_args util functions in the runtime

* Minor fixes in using_data example

* Add E2E test for run flow

* Rename RunModelOrFlow as base class for model and flow job logic

* Working E2E RUN_FLOW job for the common case (minimal error handling)

* Provide a structured interface class for NodeContext in the runtime functions module

* Remove unneeded engine-level context class

* Complete type checking for node results and lookups

* Handle nodes with bundle results (type checks and logging still needed in NodeProcessor)

* Set missing node types in exec.graph

* Check node result types for bundles and other generic types

* Update runtime node logging

* Fix one todo in NodeLogger

* Bump netty version for compliance checks

* Move GCP SDK to major version 2 on latest stable

* Set version of Netty dependencies in -lib-test

* Compliance version bumps for protobuf and grpc

* Compliance - force versions of Google HTTP client in GCP SDK

* Compliance - false positive for gson

* Put the CDDL exclusion back in the gradle file for the GCP config plugin
  • Loading branch information
martin-traverse committed May 25, 2022
1 parent 27bea0e commit c4a4c8c
Show file tree
Hide file tree
Showing 38 changed files with 2,671 additions and 1,261 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,12 @@ jobs:

# PyPanDoc dependency is not managed correctly by PySpark 2.4 package
# It needs to be installed explicitly first
# Also, PyPanDoc >=1.8 breaks with PySpark 2.4
# This workaround could be dropped if TRAC set baseline support for Spark at Spark 3.0...
- name: Install pre-req dependencies
if: ${{ matrix.enviroment.PYPANDOC }}
run: |
pip install pypandoc
pip install "pypandoc <1.8"
- name: Install dependencies
run: |
Expand Down
10 changes: 10 additions & 0 deletions dev/compliance/owasp-false-positives.xml
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,14 @@
<cpe>cpe:/a:mariadb:mariadb</cpe>
</suppress>

<!-- google-http-client-gson is detected as google:gson -->
<!-- In fact the version of the core gson library pulled in is 2.9, which is the patched version -->
<suppress>
<notes><![CDATA[
file name: google-http-client-gson-1.41.8.jar
]]></notes>
<packageUrl regex="true">^pkg:maven/com\.google\.http\-client/google\-http\-client\-gson@.*$</packageUrl>
<cpe>cpe:/a:google:gson</cpe>
</suppress>

</suppressions>
21 changes: 21 additions & 0 deletions examples/models/python/chaining/chaining.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

job:
runFlow:

flow: ./chaining_flow.yaml

parameters:
param_1: 42
param_2: "2015-01-01"
param_3: 1.5

inputs:
customer_loans: "inputs/loan_final313_100.csv"
currency_data: "inputs/currency_data_sample.csv"

outputs:
profit_by_region: "outputs/hello_pandas/profit_by_region.csv"

models:
model_1: model_1.FirstModel
model_2: model_2.SecondModel
23 changes: 23 additions & 0 deletions examples/models/python/chaining/chaining_flow.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@

nodes:

customer_loans:
nodeType: "INPUT_NODE"

currency_data:
nodeType: "INPUT_NODE"

model_1:
nodeType: "MODEL_NODE"
modelStub:
inputs: [customer_loans, currency_data]
outputs: [preprocessed_data]

model_2:
nodeType: "MODEL_NODE"
modelStub:
inputs: [preprocessed_data]
outputs: [profit_by_region]

profit_by_region:
nodeType: "OUTPUT_NODE"
62 changes: 62 additions & 0 deletions examples/models/python/chaining/model_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Copyright 2022 Accenture Global Solutions Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime as dt
import typing as tp

import tracdap.rt.api as trac


class FirstModel(trac.TracModel):

def define_parameters(self) -> tp.Dict[str, trac.ModelParameter]:

return trac.declare_parameters(
trac.P("param_1", trac.INTEGER, "First parameter"),
trac.P("param_2", trac.DATE, "Second parameter", default_value=dt.date(2001, 1, 1)))

def define_inputs(self) -> tp.Dict[str, trac.ModelInputSchema]:

customer_loans = trac.declare_input_table(
trac.F("id", trac.BasicType.STRING, label="Customer account ID", business_key=True),
trac.F("loan_amount", trac.BasicType.DECIMAL, label="Principal loan amount", format_code="CCY:EUR"),
trac.F("total_pymnt", trac.BasicType.DECIMAL, label="Total amount repaid", format_code="CCY:EUR"),
trac.F("region", trac.BasicType.STRING, label="Customer home region", categorical=True),
trac.F("loan_condition_cat", trac.BasicType.INTEGER, label="Loan condition category", categorical=True))

currency_data = trac.declare_input_table(
trac.F("ccy_code", trac.BasicType.STRING, label="Currency code", categorical=True),
trac.F("spot_date", trac.BasicType.DATE, label="Spot date for FX rate"),
trac.F("dollar_rate", trac.BasicType.DECIMAL, label="Dollar FX rate", format_code="CCY:USD"))

return {"customer_loans": customer_loans, "currency_data": currency_data}

def define_outputs(self) -> tp.Dict[str, trac.ModelOutputSchema]:

preprocessed = trac.declare_output_table(
trac.F("id", trac.BasicType.STRING, label="Customer account ID", business_key=True),
trac.F("some_quantity_x", trac.BasicType.DECIMAL, label="Some quantity X", format_code="CCY:EUR"))

return {"preprocessed_data": preprocessed}

def run_model(self, ctx: trac.TracContext):

loans = ctx.get_pandas_table("customer_loans")
currencies = ctx.get_pandas_table("currency_data")

loans["some_quantity_x"] = loans["loan_amount"] - loans["total_pymnt"]

preproc = loans[["id", "some_quantity_x"]]

ctx.put_pandas_table("preprocessed_data", preproc)
56 changes: 56 additions & 0 deletions examples/models/python/chaining/model_2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2022 Accenture Global Solutions Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime as dt
import decimal
import typing as tp

import pandas as pd

import tracdap.rt.api as trac


class SecondModel(trac.TracModel):

def define_parameters(self) -> tp.Dict[str, trac.ModelParameter]:

return trac.declare_parameters(
trac.P("param_2", trac.DATE, "A data parameter", default_value=dt.date(2000, 1, 1)),
trac.P("param_3", trac.FLOAT, "A float parameter"))

def define_inputs(self) -> tp.Dict[str, trac.ModelInputSchema]:

preprocessed = trac.declare_input_table(
trac.F("id", trac.BasicType.STRING, label="Customer account ID", business_key=True),
trac.F("some_quantity_x", trac.BasicType.DECIMAL, label="Some quantity X", format_code="CCY:EUR"))

return {"preprocessed_data": preprocessed}

def define_outputs(self) -> tp.Dict[str, trac.ModelOutputSchema]:

profit_by_region = trac.declare_output_table(
trac.F("region", trac.BasicType.STRING, label="Customer home region", categorical=True),
trac.F("gross_profit", trac.BasicType.DECIMAL, label="Total gross profit", format_code="CCY:USD"))

return {"profit_by_region": profit_by_region}

def run_model(self, ctx: trac.TracContext):

preproc = ctx.get_pandas_table("preprocessed_data")

profit_by_region = pd.DataFrame(data={
"region": ["uk", "us"],
"gross_profit": [decimal.Decimal(24000000), decimal.Decimal(13000000)]})

ctx.put_pandas_table("profit_by_region", profit_by_region)
3 changes: 3 additions & 0 deletions examples/models/python/data/inputs/currency_data_sample.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
ccy_code,spot_date,dollar_rate
EUR,2022-05-17,1.05
GBP,2022-05-17,1.25
6 changes: 3 additions & 3 deletions examples/models/python/using_data/using_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,17 @@ def run_model(self, ctx: trac.TracContext):
if filter_defaults:
customer_loans = customer_loans[customer_loans["loan_condition_cat"] == 0]

customer_loans.loc[:, "gross_profit_unweighted"] = \
customer_loans["gross_profit_unweighted"] = \
customer_loans["total_pymnt"] - \
customer_loans["loan_amount"]

condition_weighting = customer_loans["loan_condition_cat"] \
.apply(lambda c: decimal.Decimal(default_weighting) if c > 0 else decimal.Decimal(1))

customer_loans.loc[:, "gross_profit_weighted"] = \
customer_loans["gross_profit_weighted"] = \
customer_loans["gross_profit_unweighted"] * condition_weighting

customer_loans.loc[:, "gross_profit"] = \
customer_loans["gross_profit"] = \
customer_loans["gross_profit_weighted"] \
.apply(lambda x: x * decimal.Decimal.from_float(eur_usd_rate))

Expand Down
9 changes: 5 additions & 4 deletions gradle/versions.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ ext {
java_min_version = JavaVersion.VERSION_11

// Core platform technologies
netty_version = '4.1.74.Final' // gRPC does not support changes in Netty 4.1.75 yet (mar 2022)
netty_version = '4.1.77.Final' // gRPC does not support changes in Netty 4.1.75 yet (mar 2022)
guava_version = '31.1-jre'
proto_version = '3.19.4'
grpc_version = '1.45.0'
proto_version = '3.20.1'
grpc_version = '1.46.0'
gapi_version = '2.7.4'


Expand Down Expand Up @@ -54,7 +54,8 @@ ext {

// Plugins
aws_sdk_version = '1.12.178'
gcp_sdk_version = '1.113.14-sp.4'
gcp_sdk_version = '2.6.1'
gcp_sdk_http_client_version = '1.41.8' // Force Google HTTP client version needed for compliance


// Test dependencies
Expand Down
4 changes: 4 additions & 0 deletions tracdap-libs/tracdap-lib-test/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ dependencies {
implementation group: 'org.junit.jupiter', name: 'junit-jupiter-api', version: "$junit_version"
implementation group: 'org.junit.jupiter', name: 'junit-jupiter-params', version: "$junit_version"

implementation group: 'io.netty', name: 'netty-common', version: "$netty_version"
implementation group: 'io.netty', name: 'netty-codec-http', version: "$netty_version"
implementation group: 'io.netty', name: 'netty-codec-http2', version: "$netty_version"
implementation group: 'io.netty', name: 'netty-handler-proxy', version: "$netty_version"
implementation group: 'io.grpc', name: 'grpc-netty', version: "$grpc_version"

// Jackson uses runtime class resolution
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ public class JobValidator {
private static final Descriptors.FieldDescriptor RMJ_OUTPUTS;
private static final Descriptors.FieldDescriptor RMJ_PRIOR_OUTPUTS;

private static final Descriptors.Descriptor RUN_FLOW_JOB;
private static final Descriptors.FieldDescriptor RFJ_FLOW;
private static final Descriptors.FieldDescriptor RFJ_MODELS;
private static final Descriptors.FieldDescriptor RFJ_PARAMETERS;
private static final Descriptors.FieldDescriptor RFJ_INPUTS;
private static final Descriptors.FieldDescriptor RFJ_OUTPUTS;
private static final Descriptors.FieldDescriptor RFJ_PRIOR_OUTPUTS;

static {

JOB_DEFINITION = JobDefinition.getDescriptor();
Expand All @@ -73,6 +81,14 @@ public class JobValidator {
RMJ_INPUTS = field(RUN_MODEL_JOB, RunModelJob.INPUTS_FIELD_NUMBER);
RMJ_OUTPUTS = field(RUN_MODEL_JOB, RunModelJob.OUTPUTS_FIELD_NUMBER);
RMJ_PRIOR_OUTPUTS = field(RUN_MODEL_JOB, RunModelJob.PRIOROUTPUTS_FIELD_NUMBER);

RUN_FLOW_JOB = RunFlowJob.getDescriptor();
RFJ_FLOW = field(RUN_FLOW_JOB, RunFlowJob.FLOW_FIELD_NUMBER);
RFJ_MODELS = field(RUN_FLOW_JOB, RunFlowJob.MODELS_FIELD_NUMBER);
RFJ_PARAMETERS = field(RUN_FLOW_JOB, RunFlowJob.PARAMETERS_FIELD_NUMBER);
RFJ_INPUTS = field(RUN_FLOW_JOB, RunFlowJob.INPUTS_FIELD_NUMBER);
RFJ_OUTPUTS = field(RUN_FLOW_JOB, RunFlowJob.OUTPUTS_FIELD_NUMBER);
RFJ_PRIOR_OUTPUTS = field(RUN_FLOW_JOB, RunFlowJob.PRIOROUTPUTS_FIELD_NUMBER);
}


Expand Down Expand Up @@ -108,29 +124,59 @@ public static ValidationContext runModelJob(RunModelJob msg, ValidationContext c
.apply(ObjectIdValidator::selectorType, TagSelector.class, ObjectType.MODEL)
.pop();

ctx = ctx.pushMap(RMJ_PARAMETERS)
return runModelOrFlow(ctx, RMJ_PARAMETERS, RMJ_INPUTS, RMJ_OUTPUTS, RMJ_PRIOR_OUTPUTS);
}

@Validator
public static ValidationContext runFlowJob(RunFlowJob msg, ValidationContext ctx) {

ctx = ctx.push(RFJ_FLOW)
.apply(CommonValidators::required)
.apply(ObjectIdValidator::tagSelector, TagSelector.class)
.apply(ObjectIdValidator::selectorType, TagSelector.class, ObjectType.FLOW)
.pop();

ctx = ctx.pushMap(RFJ_MODELS)
.applyMapKeys(CommonValidators::identifier)
.applyMapKeys(CommonValidators::notTracReserved)
.applyMapValues(ObjectIdValidator::tagSelector, TagSelector.class)
.applyMapValues(ObjectIdValidator::selectorType, TagSelector.class, ObjectType.MODEL)
.applyMapValues(ObjectIdValidator::fixedObjectVersion, TagSelector.class)
.pop();

return runModelOrFlow(ctx, RFJ_PARAMETERS, RFJ_INPUTS, RFJ_OUTPUTS, RFJ_PRIOR_OUTPUTS);
}

public static ValidationContext runModelOrFlow(
ValidationContext ctx,
Descriptors.FieldDescriptor parameters,
Descriptors.FieldDescriptor inputs,
Descriptors.FieldDescriptor outputs,
Descriptors.FieldDescriptor priorOutputs) {

ctx = ctx.pushMap(parameters)
.applyMapKeys(CommonValidators::identifier)
.applyMapKeys(CommonValidators::notTracReserved)
.applyMapValues(TypeSystemValidator::value, Value.class)
.pop();

ctx = ctx.pushMap(RMJ_INPUTS)
ctx = ctx.pushMap(inputs)
.applyMapKeys(CommonValidators::identifier)
.applyMapKeys(CommonValidators::notTracReserved)
.applyMapValues(ObjectIdValidator::tagSelector, TagSelector.class)
.applyMapValues(ObjectIdValidator::selectorType, TagSelector.class, ObjectType.DATA)
.applyMapValues(ObjectIdValidator::fixedObjectVersion, TagSelector.class)
.pop();

ctx = ctx.pushMap(RMJ_OUTPUTS)
ctx = ctx.pushMap(outputs)
.applyMapKeys(CommonValidators::identifier)
.applyMapKeys(CommonValidators::notTracReserved)
.applyMapValues(ObjectIdValidator::tagSelector, TagSelector.class)
.applyMapValues(ObjectIdValidator::selectorType, TagSelector.class, ObjectType.DATA)
.applyMapValues(ObjectIdValidator::fixedObjectVersion, TagSelector.class)
.pop();

ctx = ctx.pushMap(RMJ_PRIOR_OUTPUTS)
ctx = ctx.pushMap(priorOutputs)
.applyMapKeys(CommonValidators::identifier)
.applyMapKeys(CommonValidators::notTracReserved)
.applyMapValues(ObjectIdValidator::tagSelector, TagSelector.class)
Expand All @@ -141,12 +187,6 @@ public static ValidationContext runModelJob(RunModelJob msg, ValidationContext c
return ctx;
}

@Validator
public static ValidationContext runFlowJob(RunFlowJob msg, ValidationContext ctx) {

return ctx.error("Run flow not implemented yet");
}

private static ValidationContext jobMatchesType(ValidationContext ctx) {

var job = (JobDefinition) ctx.parentMsg();
Expand Down
8 changes: 8 additions & 0 deletions tracdap-plugins/gcp-config/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ dependencies {
exclude group: 'javax.annotation', module: 'javax.annotation-api'
}

// Google HTTP client components are failing compliance on the latest version
// Force these versions until the main GCP SDK updates this dependency
implementation group: 'com.google.http-client', name: 'google-http-client', version: "$gcp_sdk_http_client_version"
implementation group: 'com.google.http-client', name: 'google-http-client-apache-v2', version: "$gcp_sdk_http_client_version"
implementation group: 'com.google.http-client', name: 'google-http-client-appengine', version: "$gcp_sdk_http_client_version"
implementation group: 'com.google.http-client', name: 'google-http-client-gson', version: "$gcp_sdk_http_client_version"
implementation group: 'com.google.http-client', name: 'google-http-client-jackson2', version: "$gcp_sdk_http_client_version"

// Force the version of Jackson (only latest stable will ever pass compliance)
implementation group: 'com.fasterxml.jackson.core', name: 'jackson-core', version: "$jackson_version"
implementation group: 'com.fasterxml.jackson.core', name: 'jackson-databind', version: "$jackson_databind_version"
Expand Down
12 changes: 0 additions & 12 deletions tracdap-runtime/python/src/tracdap/rt/api/core_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

Expand Down
Loading

0 comments on commit c4a4c8c

Please sign in to comment.