# Developing dataset schema

In [23]:
import os
import sys
import logging

import tensorflow as tf
import tensorflow_data_validation as tfdv

from tensorflow_metadata.proto.v0 import schema_pb2, statistics_pb2, anomalies_pb2
from google.cloud import bigquery

In [24]:
# Notebook configuration: edit these values to point at your own environment.

# Google Cloud project that owns the BigQuery dataset queried in this notebook.
PROJECT = 'jk-mlops-dev'
# Cloud Storage bucket under which the generated schema file is stored.
STAGING_BUCKET = 'gs://jk-vertex-workshop-bucket'
# NOTE(review): not referenced by any cell visible here -- presumably the region
# used by later Vertex AI steps; confirm before relying on it.
REGION = 'us-central1'

BQ_DATASET_NAME = 'chicago_taxi_tips_dataset' # Change to your BQ dataset name.
# Table holding the training split; this is the table the statistics are computed from.
BQ_TRAIN_SPLIT_NAME = 'chicago_taxi_tips_train'
# NOTE(review): the validation/test split names and BQ_LOCATION are not used in
# the visible cells -- verify they match the actual table names and dataset location.
BQ_VALID_SPLIT_NAME = 'valid_split'
BQ_TEST_SPLIT_NAME = 'test_split'
BQ_LOCATION = 'US'

## Generate Raw Data Schema

### Load the training split

In [25]:
# Bind the client to the configured project so the query job does not depend on
# whichever default project the ambient credentials happen to carry.
client = bigquery.Client(project=PROJECT)

# The fully qualified table name is back-tick quoted because project ids may
# contain hyphens, which are not valid in unquoted identifiers in general.
# NOTE: there is no LIMIT or sampling clause, so the entire training split is
# downloaded into the DataFrame.
sql_script = f'''
SELECT *
FROM `{PROJECT}.{BQ_DATASET_NAME}.{BQ_TRAIN_SPLIT_NAME}`
'''
df = client.query(sql_script).result().to_dataframe()

In [26]:
# Show the first rows with features as rows, which is easier to read for wide frames.
df.head().transpose()

Unnamed: 0,0,1,2,3,4
trip_month,2,2,2,2,2
trip_day,1,1,1,1,1
trip_day_of_week,7,7,7,7,7
trip_hour,8,8,19,5,7
trip_seconds,219,418,420,240,289
trip_miles,0.46,1.53,1.2,1.1,0.88
payment_type,Cash,Cash,Cash,Credit Card,Cash
pickup_grid,POINT(-87.6 41.9),POINT(-87.7 41.9),POINT(-87.6 41.9),POINT(-87.7 41.9),POINT(-87.6 41.9)
dropoff_grid,POINT(-87.6 41.9),POINT(-87.7 41.9),POINT(-87.6 41.9),POINT(-87.7 41.9),POINT(-87.6 41.9)
euclidean,0.0,0.0,0.0,0.0,0.0


### Generate statistics

In [27]:
# Options for the statistics run: 'tip_bin' is declared as the label, no weight
# feature is used, the sample rate is 1 and up to 50 top values are tracked.
stats_options = tfdv.StatsOptions(
    label_feature='tip_bin',
    weight_feature=None,
    sample_rate=1,
    num_top_values=50,
)

# Compute descriptive statistics over the loaded training DataFrame.
stats = tfdv.generate_statistics_from_dataframe(dataframe=df, stats_options=stats_options)

In [28]:
# Render the interactive statistics overview for the training split.
tfdv.visualize_statistics(lhs_statistics=stats)

### Generate schema

In [29]:
# Derive an initial schema from the computed statistics.
schema = tfdv.infer_schema(stats)

# Display the inferred schema so it can be reviewed before manual refinement.
tfdv.display_schema(schema)

Unnamed: 0_level_0,Type,Presence,Valency,Domain
Feature name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
'trip_month',INT,required,,-
'trip_day',INT,required,,-
'trip_day_of_week',INT,required,,-
'trip_hour',INT,required,,-
'trip_seconds',INT,required,,-
'trip_miles',FLOAT,required,,-
'payment_type',STRING,required,,'payment_type'
'pickup_grid',STRING,required,,'pickup_grid'
'dropoff_grid',STRING,required,,'dropoff_grid'
'euclidean',FLOAT,required,,-


### Update the schema

In [30]:
# Inclusive (min, max) bounds for the calendar-style integer features; each one
# is registered as a categorical IntDomain on the schema, in the listed order.
calendar_feature_bounds = (
    ('trip_month', 1, 12),
    ('trip_day', 1, 31),
    ('trip_day_of_week', 1, 7),
    ('trip_hour', 0, 23),
)

for feature_name, min_value, max_value in calendar_feature_bounds:
    int_domain = schema_pb2.IntDomain(
        name=feature_name,
        min=min_value,
        max=max_value,
        is_categorical=True,
    )
    tfdv.set_domain(schema, feature_name, int_domain)

2
4
2
4
2
4
2
4


In [31]:
# Show the schema again to confirm the integer domains were applied.
tfdv.display_schema(schema)

Unnamed: 0_level_0,Type,Presence,Valency,Domain
Feature name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
'trip_month',INT,required,,"[1,12]"
'trip_day',INT,required,,"[1,31]"
'trip_day_of_week',INT,required,,"[1,7]"
'trip_hour',INT,required,,"[0,23]"
'trip_seconds',INT,required,,-
'trip_miles',FLOAT,required,,-
'payment_type',STRING,required,,'payment_type'
'pickup_grid',STRING,required,,'pickup_grid'
'dropoff_grid',STRING,required,,'dropoff_grid'
'euclidean',FLOAT,required,,-


### Save the updated schema

In [35]:
# Build the destination with explicit '/' separators: these are Cloud Storage
# URIs rather than local paths, and os.path.join would insert the host OS
# separator (a backslash on Windows), producing an invalid object name.
schema_dir = f"{STAGING_BUCKET.rstrip('/')}/schema"
tf.io.gfile.makedirs(schema_dir)
schema_file = f'{schema_dir}/schema.pbtxt'

# Persist the curated schema as a text-format protobuf.
tfdv.write_schema_text(schema, schema_file)

In [36]:
# Echo the location the schema was written to.
schema_file

'gs://jk-vertex-workshop-bucket/schema/schema.pbtxt'