In [None]:
!pip install --user tfx tensorflow Pillow tensorflow_datasets matplotlib azure-storage-blob object-detection

In [None]:
import numpy as np
import os
import PIL
import PIL.Image
import tensorflow as tf
import tensorflow_datasets as tfds
import IPython.display as display
from azure.storage.blob import BlobServiceClient, AccountSasPermissions, ResourceTypes
from datetime import datetime, timedelta

In [None]:
# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
  """Wraps a single string / bytes value in a Feature holding a BytesList.

  Eager tensors are unwrapped with `.numpy()` first, because BytesList won't
  unpack a string from an EagerTensor.
  """
  eager_tensor_type = type(tf.constant(0))
  raw_value = value.numpy() if isinstance(value, eager_tensor_type) else value
  bytes_list = tf.train.BytesList(value=[raw_value])
  return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
  """Wraps a single float / double in a Feature holding a FloatList."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
  """Wraps a single bool / enum / int / uint in a Feature holding an Int64List."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)


### Downloading images from Azure

In [None]:
# The connection string embeds the storage account key, so it is read from the
# environment instead of being stored in the notebook.
# NOTE(review): any key that was ever saved in this notebook should be rotated.
connection_string = os.environ.get("AZURE_STORAGE_CONNECTION_STRING", "")
if not connection_string:
  print("AZURE_STORAGE_CONNECTION_STRING is not set; export it before running the Azure download cells.")
account_name = "datacentricthesis"

In [None]:
def download_directory_from_blob_storage(blob_service_client, container_name, destination_directory, prefix=""):
  """Downloads every blob whose name starts with `prefix` into a local directory.

  The part of each blob name below `prefix` is recreated as a relative path
  under `destination_directory`. Intermediate directories are created as
  needed and existing files are overwritten.

  Args:
    blob_service_client: Client object exposing `get_container_client(name)`.
    container_name: Name of the container to read from.
    destination_directory: Local directory that receives the files.
    prefix: Only blobs whose name starts with this string are downloaded.

  Raises:
    ValueError: If a blob name would resolve to a path outside
      `destination_directory`.
  """
  container_client = blob_service_client.get_container_client(container_name)
  destination_root = os.path.abspath(destination_directory)

  for blob in container_client.list_blobs(name_starts_with=prefix):
    blob_path = os.path.relpath(blob.name, prefix)

    # A name ending in "/" (or equal to the prefix itself) denotes a folder
    # placeholder, not a file; opening it for writing would fail.
    if blob.name.endswith("/") or blob_path == os.curdir:
      continue

    # Blob names are remote input: refuse anything that would land outside the
    # destination (e.g. names containing "..", or a prefix that does not end on
    # a "/" boundary).
    local_path = os.path.abspath(os.path.join(destination_root, blob_path))
    if os.path.commonpath([destination_root, local_path]) != destination_root:
      raise ValueError(
          f"Blob {blob.name!r} resolves outside of {destination_directory!r}")

    os.makedirs(os.path.dirname(local_path), exist_ok=True)

    blob_client = container_client.get_blob_client(blob.name)
    with open(local_path, "wb") as file:
      # Stream into the file instead of buffering the whole blob in memory.
      blob_client.download_blob().readinto(file)

In [None]:
blob_service_client = BlobServiceClient.from_connection_string(connection_string)
container_name = "flowers"
destination_directory = "./azure_blob_storage_test"
prefix = "flower_photos/"

# Download only when there is no local copy yet, so an existing copy is never
# overwritten. A missing, non-directory or empty destination counts as
# "no local copy".
if not os.path.isdir(destination_directory) or not os.listdir(destination_directory):
  print("Only create and download if destination doesn't exist so we don't override it \n")

  download_directory_from_blob_storage(blob_service_client, container_name, destination_directory, prefix)

### Parse images into byte strings


In [None]:
# TODO: check whether the images need to be standardized before they are formatted.
# https://www.tensorflow.org/tutorials/images/classification#standardize_the_data

def create_tf_example(image_path, label_map, label):
    """Builds a tf.train.Example describing one image file.

    Args:
      image_path: Path of the image file on disk.
      label_map: Mapping from class name to integer class id.
      label: Class name of the image; must be a key of `label_map`.

    Returns:
      A tf.train.Example holding the file's raw (still encoded) bytes together
      with its height, width, filename, format and class features.
    """
    # PIL keeps the underlying file open until the image is closed; the context
    # manager releases the handle as soon as the header data has been read.
    with PIL.Image.open(image_path) as img:
        width, height = img.size
        img_format = img.format.lower()
    filename = os.path.basename(image_path)

    # The encoded file content is stored as-is; no decoding or resizing is done.
    with tf.io.gfile.GFile(image_path, 'rb') as fid:
        encoded_image = fid.read()

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/filename': _bytes_feature(filename.encode('utf8')),
        'image/source_id': _bytes_feature(filename.encode('utf8')),
        'image/encoded': _bytes_feature(encoded_image),
        'image/format': _bytes_feature(img_format.encode('utf8')),
        'image/class/text': _bytes_feature(label.encode('utf8')),
        'image/class/label': _int64_feature(label_map[label]),
    }))
    return tf_example

### Format images and create a TFRecord file

In [None]:
def format_images(image_root=None, output_path='data/flower_images.tfrecord'):
  """Writes every flower image below `image_root` into one TFRecord file.

  `image_root` is expected to contain one sub-directory per flower class.
  Classes are processed in the order of the label map and the files of each
  class in sorted order, so repeated runs produce the same record order.

  Args:
    image_root: Directory holding the per-class image folders. Defaults to
      `azure_blob_storage_test` inside the current working directory.
    output_path: Path of the TFRecord file to create; its parent directory is
      created when missing.
  """
  # Class name -> integer label. This mapping is also the single source of the
  # class list that is iterated below.
  label_map = {
    'daisy': 1,
    'dandelion': 2,
    'roses': 3,
    'sunflowers': 4,
    'tulips': 5
  }

  if image_root is None:
    image_root = os.path.join(os.getcwd(), 'azure_blob_storage_test')

  output_dir = os.path.dirname(output_path)
  if output_dir:
    os.makedirs(output_dir, exist_ok=True)

  with tf.io.TFRecordWriter(output_path) as writer:
    for flower_class in label_map:
      flower_dir = os.path.join(image_root, flower_class)
      # os.listdir order is arbitrary; sort it for a reproducible file.
      for image_name in sorted(os.listdir(flower_dir)):
        image_path = os.path.join(flower_dir, image_name)
        if not os.path.isfile(image_path):
          continue  # Nested directories are not images.
        tf_example = create_tf_example(image_path, label_map, flower_class)
        writer.write(tf_example.SerializeToString())

  print(f"TFRecord file created at {output_path}")

format_images()

### Convert tfrecord back to images

In [None]:
# Create a dictionary describing the features.
def _parse_function(example_proto):
    """Parses one serialized tf.train.Example into a dict of scalar features."""
    def scalar(dtype):
        # Every feature written by create_tf_example holds exactly one value.
        return tf.io.FixedLenFeature([], dtype)

    feature_spec = {
        'image/height': scalar(tf.int64),
        'image/width': scalar(tf.int64),
        'image/filename': scalar(tf.string),
        'image/source_id': scalar(tf.string),
        'image/encoded': scalar(tf.string),
        'image/format': scalar(tf.string),
        'image/class/text': scalar(tf.string),
        'image/class/label': scalar(tf.int64),
    }
    return tf.io.parse_single_example(example_proto, feature_spec)

def display_first_matching_by_flower(tfrecord_file, flower_type):
    """Displays the first image in `tfrecord_file` whose class text equals `flower_type`."""
    parsed_records = tf.data.TFRecordDataset(tfrecord_file).map(_parse_function)

    for record in parsed_records:
        class_text = record['image/class/text'].numpy().decode('utf-8')
        if class_text != flower_type:
            continue
        # The stored bytes are the original encoded file, so they render directly.
        image_bytes = record['image/encoded'].numpy()
        display.display(display.Image(data=image_bytes))
        break
        
def display_first_matching_by_label(tfrecord_file, label_number):
    """Displays the first image in `tfrecord_file` whose class label equals `label_number`."""
    parsed_records = tf.data.TFRecordDataset(tfrecord_file).map(_parse_function)

    for record in parsed_records:
        is_match = record['image/class/label'] == label_number
        if is_match:
            image_bytes = record['image/encoded'].numpy()
            display.display(display.Image(data=image_bytes))
            break

# Spot-check the TFRecord: look one image up by class name and one by label id.
tfrecord_file = 'data/flower_images.tfrecord'
display_first_matching_by_flower(tfrecord_file=tfrecord_file,
                                 flower_type='sunflowers')
display_first_matching_by_label(tfrecord_file=tfrecord_file,
                                label_number=4)

### Create schema pipeline

In [None]:
import os

from absl import logging

# Two pipelines are created: one for schema generation and one for training.
SCHEMA_PIPELINE_NAME = "flower-tfdv-schema"
PIPELINE_NAME = 'flower_pipeline'

# Directory that is handed to the pipelines as their input data location.
DATA_ROOT = 'data'

# All TFX paths below are rooted in the current working directory.
TFX_ROOT = os.path.join(os.getcwd(), 'tfx')

# Output directories for the artifacts generated by each pipeline.
SCHEMA_PIPELINE_ROOT = os.path.join(TFX_ROOT, 'pipelines', SCHEMA_PIPELINE_NAME)
PIPELINE_ROOT = os.path.join(TFX_ROOT, 'pipelines', PIPELINE_NAME)

# SQLite DB files used as MLMD storage, one per pipeline.
SCHEMA_METADATA_PATH = os.path.join(
    TFX_ROOT, 'metadata', SCHEMA_PIPELINE_NAME, 'metadata.db')
METADATA_PATH = os.path.join(
    TFX_ROOT, 'metadata', PIPELINE_NAME, 'metadata.db')

# Output directory where models created by the pipeline will be exported.
SERVING_MODEL_DIR = os.path.join(TFX_ROOT, 'serving_model', PIPELINE_NAME)

# Default logging level.
logging.set_verbosity(logging.INFO)


In [None]:

from tfx.v1.orchestration import metadata
from tfx.proto import example_gen_pb2
from tfx.v1.dsl import Pipeline 
from tfx.v1.components import ImportExampleGen, StatisticsGen, SchemaGen

def _create_schema_pipeline(pipeline_name: str,
                            pipeline_root: str,
                            data_root: str,
                            metadata_path: str) -> Pipeline:
  """Creates the pipeline that ingests the examples and generates a schema.

  The pipeline chains ImportExampleGen -> StatisticsGen -> SchemaGen and
  records its metadata in the SQLite database at `metadata_path`.
  """
  # Hash-bucket split of the examples: 6 train / 2 eval / 2 test.
  split_buckets = (('train', 6), ('eval', 2), ('test', 2))
  splits = [
      example_gen_pb2.SplitConfig.Split(name=split_name, hash_buckets=buckets)
      for split_name, buckets in split_buckets
  ]
  output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=splits))

  # Ingests the examples found under data_root.
  example_gen = ImportExampleGen(
      input_base=data_root, output_config=output_config)

  # Computes statistics over the ingested examples.
  statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

  # Generates a schema from those statistics.
  schema_gen = SchemaGen(
      statistics=statistics_gen.outputs['statistics'],
      infer_feature_shape=True)

  connection_config = metadata.sqlite_metadata_connection_config(metadata_path)
  return Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=connection_config,
      components=[example_gen, statistics_gen, schema_gen])



In [None]:
from tfx.v1.orchestration import LocalDagRunner

# Assemble the schema-generation pipeline, then execute it with the local runner.
schema_pipeline = _create_schema_pipeline(
    pipeline_name=SCHEMA_PIPELINE_NAME,
    pipeline_root=SCHEMA_PIPELINE_ROOT,
    data_root=DATA_ROOT,
    metadata_path=SCHEMA_METADATA_PATH)
LocalDagRunner().run(schema_pipeline)

In [None]:
from ml_metadata.proto import metadata_store_pb2
# Non-public APIs, just for showcase.
from tfx.orchestration.portable.mlmd import execution_lib

# TODO(b/171447278): Move these functions into the TFX library.

def get_latest_artifacts(metadata, pipeline_name, component_id):
  """Returns the output artifacts of the component's most recently updated execution."""
  def _last_update(execution):
    return execution.last_update_time_since_epoch

  # The node context is looked up under the name "<pipeline_name>.<component_id>".
  node_context = metadata.store.get_context_by_type_and_name(
      'node', f'{pipeline_name}.{component_id}')
  node_executions = metadata.store.get_executions_by_context(node_context.id)
  newest_execution = max(node_executions, key=_last_update)
  return execution_lib.get_output_artifacts(metadata, newest_execution.id)

# Non-public APIs, just for showcase.
from tfx.orchestration.experimental.interactive import visualizations

def visualize_artifacts(artifacts):
  """Displays each artifact whose type has a visualization in the registry."""
  for artifact in artifacts:
    registry = visualizations.get_registry()
    visualization = registry.get_visualization(artifact.type_name)
    if not visualization:
      # No visualization is available for this artifact type.
      continue
    visualization.display(artifact)

from tfx.orchestration.experimental.interactive import standard_visualizations
# NOTE(review): presumably this fills the registry that visualize_artifacts
# queries via visualizations.get_registry() - confirm against the TFX sources.
standard_visualizations.register_standard_visualizations()


In [None]:
# Report the library versions this notebook is running against.
import tensorflow as tf
print(f'TensorFlow version: {tf.__version__}')
from tfx import v1 as tfx
print(f'TFX version: {tfx.__version__}')


In [None]:
# Non-public APIs, just for showcase.
from tfx.orchestration.metadata import Metadata
from tfx.types import standard_component_specs

# Connection settings for the SQLite MLMD database of the schema pipeline.
metadata_connection_config = tfx.orchestration.metadata.sqlite_metadata_connection_config(
    SCHEMA_METADATA_PATH)

# Both artifact lists are bound at notebook level and are read by later cells.
with Metadata(metadata_connection_config) as metadata_handler:
  # Find output artifacts from MLMD.
  # Statistics produced by the latest 'StatisticsGen' execution.
  stat_gen_output = get_latest_artifacts(metadata_handler, SCHEMA_PIPELINE_NAME,
                                         'StatisticsGen')
  stats_artifacts = stat_gen_output[standard_component_specs.STATISTICS_KEY]

  # Schema produced by the latest 'SchemaGen' execution.
  schema_gen_output = get_latest_artifacts(metadata_handler,
                                           SCHEMA_PIPELINE_NAME, 'SchemaGen')
  schema_artifacts = schema_gen_output[standard_component_specs.SCHEMA_KEY]


In [None]:
# Render the statistics artifacts fetched from MLMD.
visualize_artifacts(stats_artifacts)


In [None]:
# Render the schema artifacts fetched from MLMD.
visualize_artifacts(schema_artifacts)


### Exporting schema

In [None]:
import shutil

# Name of the schema file, both inside the SchemaGen artifact and at the export location.
_schema_filename = 'schema.pbtxt'
# Directory that receives the exported schema.
SCHEMA_PATH = os.path.join(TFX_ROOT,'schema')

os.makedirs(SCHEMA_PATH, exist_ok=True)
# Only the first schema artifact is exported.
_generated_path = os.path.join(schema_artifacts[0].uri, _schema_filename)

# Copy the 'schema.pbtxt' file from the artifact uri to a predefined path.
shutil.copy(_generated_path, SCHEMA_PATH)


### Create pipeline to validate inputs and train model

In [None]:
_trainer_module_file = 'flower_trainer.py'


In [None]:
%%writefile {_trainer_module_file}


### Creating Pipeline for training

In [63]:
def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
                     schema_path: str, module_file: str, serving_model_dir: str,
                     metadata_path: str,
                     enable_training: bool = False) -> tfx.dsl.Pipeline:
  """Creates a pipeline using predefined schema with TFX.

  The pipeline always ingests the examples, computes statistics, imports the
  schema from `schema_path` and validates the statistics against it. Training
  and pushing are opt-in.

  Args:
    pipeline_name: Name of the pipeline.
    pipeline_root: Output directory for the pipeline's artifacts.
    data_root: Directory holding the input examples.
    schema_path: Directory holding the exported schema to import.
    module_file: Python file with the user-provided training code. Only used
      when `enable_training` is True.
    serving_model_dir: Directory the trained model is pushed to. Only used
      when `enable_training` is True.
    metadata_path: Path of the SQLite file used as MLMD storage.
    enable_training: When True, Trainer and Pusher are added to the pipeline.
      Defaults to False, which yields the validation-only pipeline.

  Returns:
    The assembled pipeline.
  """
  # Brings data into the pipeline with a 6/2/2 hash-bucket split.
  output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=6),
          example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=2),
          example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=2)
      ]))

  example_gen = ImportExampleGen(input_base=data_root, output_config=output_config)

  # Computes statistics over data for visualization and example validation.
  statistics_gen = tfx.components.StatisticsGen(
      examples=example_gen.outputs['examples'])

  # Import the schema.
  schema_importer = tfx.dsl.Importer(
      source_uri=schema_path,
      artifact_type=tfx.types.standard_artifacts.Schema).with_id(
          'schema_importer')

  # Performs anomaly detection based on statistics and data schema.
  example_validator = tfx.components.ExampleValidator(
      statistics=statistics_gen.outputs['statistics'],
      schema=schema_importer.outputs['result'])

  components = [
      example_gen,
      statistics_gen,
      schema_importer,
      example_validator,
  ]

  # Trainer and Pusher used to be constructed but never added to `components`,
  # so `module_file` and `serving_model_dir` were silently ignored. They are
  # now built, and added, only on request.
  if enable_training:
    # Uses user-provided Python function that trains a model.
    trainer = tfx.components.Trainer(
        module_file=module_file,
        examples=example_gen.outputs['examples'],
        schema=schema_importer.outputs['result'],  # Pass the imported schema.
        train_args=tfx.proto.TrainArgs(num_steps=100),
        eval_args=tfx.proto.EvalArgs(num_steps=5))

    # Pushes the model to a filesystem destination.
    pusher = tfx.components.Pusher(
        model=trainer.outputs['model'],
        push_destination=tfx.proto.PushDestination(
            filesystem=tfx.proto.PushDestination.Filesystem(
                base_directory=serving_model_dir)))

    components.extend([trainer, pusher])

  return Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=tfx.orchestration.metadata
      .sqlite_metadata_connection_config(metadata_path),
      components=components)


In [62]:
from tfx.v1.orchestration import LocalDagRunner

# Assemble the pipeline from the notebook-level settings, then execute it with
# the local runner.
flower_pipeline = _create_pipeline(
    pipeline_name=PIPELINE_NAME,
    pipeline_root=PIPELINE_ROOT,
    data_root=DATA_ROOT,
    schema_path=SCHEMA_PATH,
    module_file=_trainer_module_file,
    serving_model_dir=SERVING_MODEL_DIR,
    metadata_path=METADATA_PATH)
LocalDagRunner().run(flower_pipeline)


INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Excluding no splits because exclude_splits is not set.
INFO:absl:Using deployment config:
 executor_specs {
  key: "ExampleValidator"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.example_validator.executor.Executor"
    }
  }
}
executor_specs {
  key: "ImportExampleGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.example_gen.import_example_gen.executor.Executor"
      }
    }
  }
}
executor_specs {
  key: "StatisticsGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: "tfx.components.statistics_gen.executor.Executor"
      }
    }
  }
}
custom_driver_specs {
  key: "ImportExampleGen"
  value {
    python_class_executable_spec {
      class_path: "tfx.components.example_gen.driver.FileBasedDriver"
    }
  }
}
metadata_connection_config {
  database_connection_config {
    sqlite {
    