Skip to content

Commit

Permalink
Feature / Extra runtime validation (#273)
Browse files Browse the repository at this point in the history
* Add validation for model input/output schema fields

* Use using_data model for model import test (put inputs/outputs through validation)

* Remove integer categorical fields from example models

* Do not allow integer categorical variables in platform schema validator

* Fix integer categorical field in end-to-end test

* Disable RT tests for git repos (main branch does not have required changes)
  • Loading branch information
martin-traverse committed Jan 12, 2023
1 parent 9ae6c01 commit c4528e3
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 7 deletions.
2 changes: 1 addition & 1 deletion examples/models/python/src/tutorial/model_1.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def define_inputs(self) -> tp.Dict[str, trac.ModelInputSchema]:
trac.F("loan_amount", trac.DECIMAL, label="Principal loan amount"),
trac.F("total_pymnt", trac.DECIMAL, label="Total amount repaid"),
trac.F("region", trac.STRING, label="Customer home region", categorical=True),
trac.F("loan_condition_cat", trac.INTEGER, label="Loan condition category", categorical=True))
trac.F("loan_condition_cat", trac.INTEGER, label="Loan condition category"))

currency_data = trac.define_input_table(
trac.F("ccy_code", trac.STRING, label="Currency code", categorical=True),
Expand Down
2 changes: 1 addition & 1 deletion examples/models/python/src/tutorial/using_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def define_inputs(self) -> tp.Dict[str, trac.ModelInputSchema]:
trac.F("loan_amount", trac.DECIMAL, label="Principal loan amount"),
trac.F("total_pymnt", trac.DECIMAL, label="Total amount repaid"),
trac.F("region", trac.STRING, label="Customer home region", categorical=True),
trac.F("loan_condition_cat", trac.INTEGER, label="Loan condition category", categorical=True))
trac.F("loan_condition_cat", trac.INTEGER, label="Loan condition category"))

return {"customer_loans": customer_loans}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,7 @@ public class SchemaValidator {
private static final List<BasicType> ALLOWED_BUSINESS_KEY_TYPES = List.of(
BasicType.STRING, BasicType.INTEGER, BasicType.DATE);

private static final List<BasicType> ALLOWED_CATEGORICAL_TYPES = List.of(
BasicType.STRING, BasicType.INTEGER);
private static final List<BasicType> ALLOWED_CATEGORICAL_TYPES = List.of(BasicType.STRING);

private static final Descriptors.Descriptor SCHEMA_DEFINITION;
private static final Descriptors.FieldDescriptor SD_SCHEMA_TYPE;
Expand Down
39 changes: 39 additions & 0 deletions tracdap-runtime/python/src/tracdap/rt/_impl/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,21 @@ class _StaticValidator:
__identifier_pattern = re.compile("\\A[a-zA-Z_]\\w+\\Z", re.ASCII)
__reserved_identifier_pattern = re.compile("\\A(_|trac_)", re.ASCII)

__PRIMITIVE_TYPES = [
meta.BasicType.BOOLEAN,
meta.BasicType.INTEGER,
meta.BasicType.FLOAT,
meta.BasicType.DECIMAL,
meta.BasicType.STRING,
meta.BasicType.DATE,
meta.BasicType.DATETIME,
]

__BUSINESS_KEY_TYPES = [
meta.BasicType.STRING,
meta.BasicType.INTEGER,
meta.BasicType.DATE]

_log: logging.Logger = util.logger_for_namespace(__name__)

@classmethod
Expand Down Expand Up @@ -289,13 +304,37 @@ def _check_table_fields(cls, inputs_or_outputs):

for input_name, input_schema in inputs_or_outputs.items():

cls._log.info(f"Checking {input_name}")

fields = input_schema.schema.table.fields
field_names = list(map(lambda f: f.fieldName, fields))
property_type = f"field in [{input_name}]"

cls._valid_identifiers(field_names, property_type)
cls._case_insensitive_duplicates(field_names, property_type)

for field in fields:
cls._check_single_field(field, property_type)

@classmethod
def _check_single_field(cls, field: meta.FieldSchema, property_type):
    """
    Check the schema attributes of a single table field.

    The checks are applied in a fixed order and every problem found is
    reported through cls._fail:

    1. fieldOrder must not be negative
    2. fieldType must be one of the primitive basic types
    3. a business key field must use one of the business key types
    4. a categorical field must be of type STRING

    :param field: the field schema to check
    :param property_type: description of what is being checked, used as
        the leading part of every failure message
    """

    # Valid identifier and not trac reserved checked separately

    def reject(detail):
        # Build the message only when a check has actually failed
        cls._fail(f"Invalid {property_type}: [{field.fieldName}] {detail}")

    cls._log.info(field.fieldName)

    if field.fieldOrder < 0:
        reject("fieldOrder < 0")

    field_type = field.fieldType

    is_primitive = field_type in cls.__PRIMITIVE_TYPES
    if not is_primitive:
        reject("fieldType is not a primitive type")

    # Membership is only tested for fields that are flagged as business keys
    if field.businessKey and field_type not in cls.__BUSINESS_KEY_TYPES:
        reject(f"fieldType {field_type} used as business key")

    # Only string fields are accepted as categorical
    if field.categorical and field_type != meta.BasicType.STRING:
        reject(f"fieldType {field_type} used as categorical")

@classmethod
def _valid_identifiers(cls, keys, property_type):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_import_model_job(self):
language="python",
repository="unit_test_repo",
path="examples/models/python/src",
entryPoint="tutorial.hello_world.HelloWorldModel",
entryPoint="tutorial.using_data.UsingDataModel",
version=self.commit_hash))

job_config = cfg.JobConfig(job_id, job_def)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.finos.tracdap.test.helpers.GitHelpers;
import org.finos.tracdap.test.helpers.PlatformTest;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.junit.jupiter.api.extension.RegisterExtension;
Expand All @@ -49,6 +50,7 @@ public static class LocalRepoTest extends RunFlowTest {
protected String useTracRepo() { return "TRAC_LOCAL_REPO"; }
}

@Disabled("Models on main not in sync with latest changes")
@EnabledIfEnvironmentVariable(named = "GITHUB_ACTIONS", matches = "true", disabledReason = "Only run in CI")
public static class GitRepoTest extends RunFlowTest {
protected String useTracRepo() { return "TRAC_GIT_REPO"; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ public static class LocalRepoTest extends RunFlowTest {
protected String useTracRepo() { return "TRAC_LOCAL_REPO"; }
}

@Disabled("Models on main not in sync with latest changes")
@EnabledIfEnvironmentVariable(named = "GITHUB_ACTIONS", matches = "true", disabledReason = "Only run in CI")
public static class GitRepoTest extends RunFlowTest {
protected String useTracRepo() { return "TRAC_GIT_REPO"; }
Expand Down Expand Up @@ -139,8 +140,7 @@ void loadInputData() throws Exception {
.setFieldType(BasicType.DECIMAL))
.addFields(FieldSchema.newBuilder()
.setFieldName("loan_condition_cat")
.setFieldType(BasicType.INTEGER)
.setCategorical(true))
.setFieldType(BasicType.INTEGER))
.addFields(FieldSchema.newBuilder()
.setFieldName("total_pymnt")
.setFieldType(BasicType.DECIMAL))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ public static class LocalRepoTest extends RunFlowTest {
protected String useTracRepo() { return "TRAC_LOCAL_REPO"; }
}

@Disabled("Models on main not in sync with latest changes")
@EnabledIfEnvironmentVariable(named = "GITHUB_ACTIONS", matches = "true", disabledReason = "Only run in CI")
public static class GitRepoTest extends RunFlowTest {
protected String useTracRepo() { return "TRAC_GIT_REPO"; }
Expand Down

0 comments on commit c4528e3

Please sign in to comment.