In [22]:
from tensorflow_data_validation import GenerateStatistics, TFExampleDecoder
from tensorflow_metadata.proto.v0 import statistics_pb2, schema_pb2
import logging
from tensorflow_data_validation.utils.schema_util import get_domain, get_feature
import apache_beam as beam
import tensorflow_data_validation as tfdv
from apache_beam.options.pipeline_options import (PipelineOptions, GoogleCloudOptions,
StandardOptions, WorkerOptions, DebugOptions, SetupOptions)
from tensorflow_data_validation import statistics
from spotify_tensorflow.tfx.tfdv import TfDataValidator
from spotify_tensorflow.tfx.utils import create_setup_file
from tensorflow_data_validation import StatsOptions
import six

from spotify_tensorflow.tf_schema_utils import schema_txt_file_to_feature_spec
from spotify_tensorflow.tf_schema_utils import parse_schema_txt_file

from google.protobuf import text_format
from tensorflow.python.lib.io import file_io

In [3]:
# Both views below are parsed from the same schema text file: `schema` is the
# object whose features get domains assigned later in this notebook, while
# `schema_spec` is indexed by feature name to look up each feature's dtype.
schema_txt_path = '../user-protection-pipeline/tf-supervised/src/main/python/trainers/schemas/email_open.pbtxt'
schema = parse_schema_txt_file(schema_txt_path)
schema_spec = schema_txt_file_to_feature_spec(schema_txt_path)

In [4]:
def set_domain(schema,
               feature_name,
               domain):
  """Assigns `domain` to the named feature of `schema`, modifying it in place.

  Any domain already present on the feature is replaced (a warning is logged
  when that happens). A new global string domain cannot be created through
  this function; a string argument must name one that already exists.

  Args:
    schema: A Schema protocol buffer.
    feature_name: Name of the feature that should receive the domain.
    domain: Either a domain protocol buffer (IntDomain, FloatDomain,
        StringDomain or BoolDomain), or the name of a global string domain
        already declared in `schema`.

  Example:
  ```python
    >>> from tensorflow_metadata.proto.v0 import schema_pb2
    >>> schema = schema_pb2.Schema()
    >>> schema.feature.add(name='feature')
    # Setting an int domain.
    >>> set_domain(schema, "feature", schema_pb2.IntDomain(min=3, max=5))
    # Setting a string domain.
    >>> set_domain(schema, "feature",
    ...            schema_pb2.StringDomain(value=['one', 'two', 'three']))
  ```

  Raises:
    TypeError: If `schema` or `domain` is not of an accepted type.
    ValueError: If `domain` is a string that matches no global string domain
        in `schema`.
  """
  if not isinstance(schema, schema_pb2.Schema):
    raise TypeError('schema is of type %s, should be a Schema proto.' %
                    type(schema).__name__)

  # Maps each accepted proto domain class to the feature field that stores it.
  proto_domain_fields = (
      (schema_pb2.IntDomain, 'int_domain'),
      (schema_pb2.FloatDomain, 'float_domain'),
      (schema_pb2.StringDomain, 'string_domain'),
      (schema_pb2.BoolDomain, 'bool_domain'),
  )
  accepted_types = tuple(
      domain_cls for domain_cls, _ in proto_domain_fields) + (six.string_types,)
  if not isinstance(domain, accepted_types):
    raise TypeError('domain is of type %s, should be one of IntDomain, '
                    'FloatDomain, StringDomain, BoolDomain proto or a string '
                    'denoting the name of a global domain in the schema.' %
                    type(domain).__name__)

  feature = get_feature(schema, feature_name)

  already_has_domain = feature.WhichOneof('domain_info') is not None
  if already_has_domain:
    logging.warning('Replacing existing domain of feature "%s".', feature_name)

  # Proto-valued domain: copy it into the matching field and finish.
  for domain_cls, field_name in proto_domain_fields:
    if isinstance(domain, domain_cls):
      getattr(feature, field_name).CopyFrom(domain)
      return

  # Otherwise `domain` is a name; it must refer to a global string domain that
  # is already declared on the schema.
  if not any(declared.name == domain for declared in schema.string_domain):
    raise ValueError('Invalid global string domain "{}".'.format(domain))
  feature.domain = domain

In [5]:
# Location of the statistics file (`_stats.pb`) that the domains are derived from.
statsFilePath = (
    "gs://slayton_test/email_open/tf/examples/email_open.BaseInputDataV1/2019-04-01/"
    "20190411T155410.476299-3e9102f6a419/training/_stats.pb"
)
stats = tfdv.load_statistics(statsFilePath)

In [15]:
# For every int64/float32 feature, derive a min/max domain from the loaded
# statistics and store it on the matching feature of `schema` (in place).
#
# Iterate `stats.datasets` directly rather than walking `ListFields()`: the
# latter only works while every populated field happens to be a repeated list
# of per-dataset statistics.
for dataset_stats in stats.datasets:
    for feature_stats in dataset_stats.features:
        feature_name = feature_stats.name
        ftype = schema_spec[feature_name].dtype.name
        print(ftype)
        if ftype not in ('int64', 'float32'):
            # Only numeric features receive a min/max domain.
            continue
        if not feature_stats.HasField('num_stats'):
            # Reading `num_stats` when it is unset yields proto defaults
            # (min == max == 0), which would write a bogus [0, 0] domain.
            print("No numeric statistics for %s; domain not set" % feature_name)
            continue

        num_stats = feature_stats.num_stats
        if ftype == 'int64':
            domain = schema_pb2.IntDomain(min=int(num_stats.min),
                                          max=int(num_stats.max))
        else:
            domain = schema_pb2.FloatDomain(min=float(num_stats.min),
                                            max=float(num_stats.max))
        print("Set domain for %s" % feature_name)
        set_domain(schema, feature_name, domain)
            

            



int64
Set domain for userAggTable.days_since_last_dau
int64
Set domain for emailLabelTable.hrs_to_open
int64
Set domain for userAggTable.num_push_click_week
float32
Set domain for userAggTable.email_click_open_rate_week
string
int64
Set domain for userAggTable.dsr
int64
Set domain for userAggTable.num_in_app_month
string
float32
Set domain for userAggTable.in_app_click_rate_yesterday
int64
Set domain for userAggTable.num_in_app_click_yesterday
int64
Set domain for userAggTable.num_email_click_yesterday
int64
Set domain for userAggTable.num_email_open_yesterday
int64
Set domain for userAggTable.is_mau
string
int64
Set domain for userAggTable.num_in_app_click_month
float32
Set domain for userAggTable.email_open_rate_month
string
int64
Set domain for userAggTable.num_email_week
string
string
int64
Set domain for emailLabelTable.clicked
int64
Set domain for userAggTable.num_email_yesterday
string
int64
Set domain for userAggTable.num_push_month
int64
Set domain for userAggTable.num_streams

In [18]:
# Spot-check a single feature of the schema; the returned feature proto is
# shown as the cell output.
get_feature(schema, "emailLabelTable.clicked")

name: "emailLabelTable.clicked"
value_count {
  min: 1
  max: 1
}
type: INT
int_domain {
  min: 0
  max: 1
}
presence {
  min_fraction: 1.0
  min_count: 1
}

In [19]:
# Show the whole schema proto as the cell output for a manual review.
schema

feature {
  name: "userAggTable.days_since_last_dau"
  value_count {
    min: 1
    max: 1
  }
  type: INT
  int_domain {
    min: 0
    max: 1866
  }
  presence {
    min_count: 1
  }
}
feature {
  name: "emailLabelTable.hrs_to_open"
  value_count {
    min: 1
    max: 1
  }
  type: INT
  int_domain {
    min: 0
    max: 167
  }
  presence {
    min_count: 1
  }
}
feature {
  name: "userAggTable.num_push_click_week"
  value_count {
    min: 1
    max: 1
  }
  type: INT
  int_domain {
    min: 0
    max: 4
  }
  presence {
    min_count: 1
  }
}
feature {
  name: "userAggTable.email_click_open_rate_week"
  value_count {
    min: 1
    max: 1
  }
  type: FLOAT
  float_domain {
    min: 0.0
    max: 2.0
  }
  presence {
    min_count: 1
  }
}
feature {
  name: "userAggTable.primary_platform"
  value_count {
    min: 1
    max: 1
  }
  type: BYTES
  domain: "userAggTable.primary_platform"
  presence {
    min_count: 1
  }
}
feature {
  name: "userAggTable.dsr"
  value_count {
    min: 1
 

In [20]:
def write_schema_text(schema, output_path):
  """Serializes a schema to protobuf text format and writes it to a file.

  Args:
    schema: A Schema protocol buffer.
    output_path: Path of the file that receives the text-format schema.

  Raises:
    TypeError: If `schema` is not a `schema_pb2.Schema` instance.
  """
  if not isinstance(schema, schema_pb2.Schema):
    actual_type_name = type(schema).__name__
    raise TypeError(
        'schema is of type %s, should be a Schema proto.' % actual_type_name)

  file_io.write_string_to_file(
      output_path, text_format.MessageToString(schema))

In [23]:
# Persist the schema as text. NOTE: this is the same path the schema was
# loaded from earlier in the notebook, so the source file is overwritten.
write_schema_text(
    schema=schema,
    output_path='../user-protection-pipeline/tf-supervised/src/main/python/trainers/schemas/email_open.pbtxt',
)

