In [None]:
# NOTE(review): no versions are pinned here, and tfx typically constrains which
# tensorflow releases it supports, so this install is not reproducible. Consider
# pinning exact versions and using `%pip` (which targets the running kernel's
# environment). Packages installed with --user may need a kernel restart before import.
!pip install --user tfx tensorflow Pillow tensorflow_datasets matplotlib azure-storage-blob object-detection

In [None]:
import numpy as np
import os
import PIL
import PIL.Image
import tensorflow as tf
import tensorflow_datasets as tfds
import IPython.display as display
from azure.storage.blob import BlobServiceClient, AccountSasPermissions, ResourceTypes
from datetime import datetime, timedelta

In [None]:
# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
  """Wrap one string / bytes value in a ``tf.train.Feature`` holding a BytesList.

  An EagerTensor is unwrapped to its NumPy value first, because BytesList
  cannot take a string directly out of an EagerTensor.
  """
  eager_tensor_type = type(tf.constant(0))
  raw_value = value.numpy() if isinstance(value, eager_tensor_type) else value
  bytes_list = tf.train.BytesList(value=[raw_value])
  return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
  """Wrap one float / double in a ``tf.train.Feature`` holding a FloatList."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
  """Wrap one bool / enum / int / uint in a ``tf.train.Feature`` holding an Int64List."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)


### Downloading images from Azure

In [None]:
# Never hardcode storage credentials in a notebook: the cell source (and any
# rendered output) gets committed and shared. Provide the connection string via
# the environment instead, e.g.
#   export AZURE_STORAGE_CONNECTION_STRING="DefaultEndpointsProtocol=https;AccountName=...;AccountKey=...;EndpointSuffix=core.windows.net"
# NOTE(review): an account key that was ever committed in this notebook should be
# treated as compromised and rotated in the Azure portal.
connection_string = os.environ.get("AZURE_STORAGE_CONNECTION_STRING", "")
account_name = "datacentricthesis"

if not connection_string:
  print("WARNING: AZURE_STORAGE_CONNECTION_STRING is not set; downloading from Azure will fail.")

In [None]:
def download_directory_from_blob_storage(blob_service_client, container_name, destination_directory, prefix=""):
  """Mirror every blob whose name starts with ``prefix`` into a local directory.

  Blob names are made relative to ``prefix`` and recreated under
  ``destination_directory``. For example, ``flower_photos/daisy/1.jpg`` with
  prefix ``flower_photos/`` becomes ``<destination_directory>/daisy/1.jpg``.
  Existing local files with the same path are overwritten.

  Args:
    blob_service_client: client exposing ``get_container_client`` (an
      ``azure.storage.blob.BlobServiceClient``).
    container_name: name of the container to read from.
    destination_directory: local directory to write into; created as needed.
    prefix: only blobs whose names start with this string are downloaded.

  Returns:
    The number of files written.

  Raises:
    ValueError: if a blob name would resolve to a path outside
      ``destination_directory`` (e.g. it contains ``..`` segments).
  """
  container_client = blob_service_client.get_container_client(container_name)
  destination_root = os.path.abspath(destination_directory)
  downloaded = 0

  for blob in container_client.list_blobs(name_starts_with=prefix):
    # Names ending in "/" are directory placeholders (e.g. the prefix itself).
    # relpath would map them to "." or a directory, which cannot be opened as a file.
    # NOTE(review): on hierarchical-namespace accounts, folders may also be listed
    # without a trailing "/" (marked via metadata); confirm if that applies here.
    if blob.name.endswith("/"):
      continue

    blob_path = os.path.relpath(blob.name, prefix)
    local_path = os.path.abspath(os.path.join(destination_root, blob_path))
    # Blob names are arbitrary strings; refuse any that would escape the destination.
    if os.path.commonpath([destination_root, local_path]) != destination_root:
      raise ValueError(f"Blob {blob.name!r} resolves outside {destination_directory!r}")

    os.makedirs(os.path.dirname(local_path), exist_ok=True)

    blob_client = container_client.get_blob_client(blob.name)
    with open(local_path, "wb") as file:
      # Stream straight into the file instead of buffering the whole blob in memory.
      blob_client.download_blob().readinto(file)
    downloaded += 1

  return downloaded

In [None]:
container_name = "flowers"
destination_directory = "./azure_blob_storage_test"
prefix = "flower_photos/"

# Only download when the destination does not exist yet, so a re-run never
# overwrites (or needlessly re-downloads) images that are already on disk.
# NOTE(review): an interrupted download leaves a partial directory that this check
# treats as complete; delete the directory to force a fresh download.
if not os.path.exists(destination_directory):
  print(f"Downloading '{prefix}' from container '{container_name}' into {destination_directory} ...")
  # Create the client only when it is needed, so re-running with cached data
  # does not require valid credentials.
  blob_service_client = BlobServiceClient.from_connection_string(connection_string)
  download_directory_from_blob_storage(blob_service_client, container_name, destination_directory, prefix)
else:
  print(f"{destination_directory} already exists; skipping download so it is not overwritten.")

### Parse images into byte strings (`tf.train.Example` records)


In [None]:
def create_tf_example(image_path, label_map, label):
    """Build a ``tf.train.Example`` for one image file.

    The file's raw bytes are stored unchanged under ``image/encoded``. PIL is
    used only to read the image's width, height and format.

    Args:
        image_path: path to the image file (anything ``tf.io.gfile`` can open).
        label_map: mapping from class name to integer label id.
        label: class name of this image; must be a key of ``label_map``.

    Returns:
        A ``tf.train.Example`` with height, width, filename, source_id, encoded
        bytes, format, class text and class label features.

    Raises:
        KeyError: if ``label`` is not in ``label_map``.
    """
    import io  # local import keeps this cell self-contained

    # Read the file once. PIL decodes from these in-memory bytes, so the file is
    # not opened twice and no PIL file handle is left open.
    with tf.io.gfile.GFile(image_path, 'rb') as fid:
        encoded_image = fid.read()

    with PIL.Image.open(io.BytesIO(encoded_image)) as img:
        width, height = img.size
        img_format = img.format.lower()

    filename = os.path.basename(image_path)
    filename_bytes = filename.encode('utf8')

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': _int64_feature(height),
        'image/width': _int64_feature(width),
        'image/filename': _bytes_feature(filename_bytes),
        'image/source_id': _bytes_feature(filename_bytes),
        'image/encoded': _bytes_feature(encoded_image),
        'image/format': _bytes_feature(img_format.encode('utf8')),
        'image/class/text': _bytes_feature(label.encode('utf8')),
        'image/class/label': _int64_feature(label_map[label]),
    }))
    return tf_example

## Format the images and write them to a TFRecord file

In [None]:
def format_images(image_root=None, output_path='data/flower_images.tfrecord', label_map=None):
  """Write every downloaded flower image into a single TFRecord file.

  Images are expected under ``image_root/<class_name>/``. Each one becomes a
  ``tf.train.Example`` built by ``create_tf_example`` and labelled with
  ``label_map[class_name]``. Classes are written in ascending label order and
  files within a class in sorted name order, so the record order is the same
  on every run.

  Args:
    image_root: directory with one sub-directory per class. Defaults to
      ``azure_blob_storage_test`` under the current working directory.
    output_path: TFRecord file to (over)write. Its parent directory is created
      if missing.
    label_map: mapping from class / sub-directory name to integer label id.
      Defaults to daisy=1, dandelion=2, roses=3, sunflowers=4, tulips=5.
  """
  if label_map is None:
    label_map = {
      'daisy': 1,
      'dandelion': 2,
      'roses': 3,
      'sunflowers': 4,
      'tulips': 5,
    }
  if image_root is None:
    image_root = os.path.join(os.getcwd(), 'azure_blob_storage_test')

  output_dir = os.path.dirname(output_path)
  if output_dir:
    os.makedirs(output_dir, exist_ok=True)

  # Derive the class list from label_map so the two can never drift apart.
  flower_classes = sorted(label_map, key=label_map.get)

  num_examples = 0
  with tf.io.TFRecordWriter(output_path) as writer:
    for flower_class in flower_classes:
      flower_dir = os.path.join(image_root, flower_class)
      # os.listdir order is arbitrary; sort for a reproducible record order.
      for image_name in sorted(os.listdir(flower_dir)):
        image_path = os.path.join(flower_dir, image_name)
        if not os.path.isfile(image_path):
          continue  # skip stray sub-directories
        tf_example = create_tf_example(image_path, label_map, flower_class)
        writer.write(tf_example.SerializeToString())
        num_examples += 1

  print(f"TFRecord file with {num_examples} examples created at {output_path}")

format_images()

### Read the TFRecord file back and convert records to images

In [None]:
# Dataset over the TFRecord file written by format_images() (its default
# output_path). TFRecordDataset opens files lazily, so a wrong path here would
# only surface as an error when the dataset is first iterated.
raw_image_dataset = tf.data.TFRecordDataset('data/flower_images.tfrecord')

# Parsing spec for the records written by create_tf_example: every feature is a
# scalar, stored either as an int64 or as a byte string.
def _parse_function(example_proto):
    """Decode one serialized ``tf.train.Example`` into a dict of scalar tensors.

    Args:
        example_proto: one serialized example, as yielded by
            ``tf.data.TFRecordDataset``.

    Returns:
        dict mapping each ``image/...`` feature key to its parsed tensor.
    """
    schema = (
        ('image/height', tf.int64),
        ('image/width', tf.int64),
        ('image/filename', tf.string),
        ('image/source_id', tf.string),
        ('image/encoded', tf.string),
        ('image/format', tf.string),
        ('image/class/text', tf.string),
        ('image/class/label', tf.int64),
    )
    feature_description = {key: tf.io.FixedLenFeature([], dtype) for key, dtype in schema}
    return tf.io.parse_single_example(example_proto, feature_description)

def _display_first_match(tfrecord_file, matches, description):
    """Display the first image in ``tfrecord_file`` whose parsed features satisfy ``matches``.

    Args:
        tfrecord_file: path to a TFRecord of examples readable by ``_parse_function``.
        matches: predicate taking the parsed feature dict and returning a bool.
        description: human-readable criterion, used in the "not found" message.
    """
    image_dataset = tf.data.TFRecordDataset(tfrecord_file).map(_parse_function)
    for image_features in image_dataset:
        if matches(image_features):
            display.display(display.Image(data=image_features['image/encoded'].numpy()))
            return
    # Make an empty result visible instead of silently displaying nothing.
    print(f"No image with {description} found in {tfrecord_file}")

def display_first_matching_by_flower(tfrecord_file, flower_type):
    """Display the first image whose ``image/class/text`` equals ``flower_type``.

    Args:
        tfrecord_file: path to the TFRecord file to scan.
        flower_type: class name to look for, e.g. ``'sunflowers'``.
    """
    _display_first_match(
        tfrecord_file,
        lambda features: features['image/class/text'].numpy().decode('utf-8') == flower_type,
        f"class {flower_type!r}")

def display_first_matching_by_label(tfrecord_file, label_number):
    """Display the first image whose ``image/class/label`` equals ``label_number``.

    Args:
        tfrecord_file: path to the TFRecord file to scan.
        label_number: integer label id to look for.
    """
    _display_first_match(
        tfrecord_file,
        # Compare as a plain Python int rather than relying on tensor truthiness.
        lambda features: int(features['image/class/label'].numpy()) == label_number,
        f"label {label_number}")

# Sanity check of the round trip: label 4 is 'sunflowers' in format_images'
# label_map, so both calls should show the first sunflower record in the file.
tfrecord_file = 'data/flower_images.tfrecord'
display_first_matching_by_flower(tfrecord_file, flower_type='sunflowers')
display_first_matching_by_label(tfrecord_file, label_number=4)

### Read the TFRecord file into a TFX ExampleGen component (creates a separate Python script)


In [87]:

import os
from tfx.v1.orchestration import metadata
from tfx.proto import example_gen_pb2
from tfx.v1.dsl import Pipeline 
from tfx.v1.components import ImportExampleGen
from tfx.v1.components import StatisticsGen

def _create_pipeline(pipeline_name, pipeline_root, data_root, metadata_path):
  """Assemble a two-component TFX pipeline: ImportExampleGen -> StatisticsGen.

  Args:
    pipeline_name: name recorded for the pipeline.
    pipeline_root: root directory for the pipeline's outputs.
    data_root: directory passed to ImportExampleGen as ``input_base``.
    metadata_path: path of the SQLite ML Metadata database.

  Returns:
    A ``Pipeline`` ready to be handed to an orchestrator such as LocalDagRunner.
  """
  # Hash-bucket the examples 6:2:2 into train / eval / test (60% / 20% / 20%).
  split_buckets = (('train', 6), ('eval', 2), ('test', 2))
  output_config = example_gen_pb2.Output(
      split_config=example_gen_pb2.SplitConfig(splits=[
          example_gen_pb2.SplitConfig.Split(name=split_name, hash_buckets=buckets)
          for split_name, buckets in split_buckets
      ]))
  example_gen = ImportExampleGen(input_base=data_root, output_config=output_config)

  # Compute statistics on the examples emitted by ExampleGen.
  statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

  connection_config = metadata.sqlite_metadata_connection_config(metadata_path)
  return Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=connection_config,
      components=[example_gen, statistics_gen])



In [88]:
PIPELINE_NAME = 'flower_pipeline'
# Every path handed to TFX is absolute (built from the cwd once, here), so the
# pipeline does not depend on the kernel's working directory when it runs.
TFX_ROOT = os.path.join(os.getcwd(), 'tfx')
PIPELINE_ROOT = os.path.join(TFX_ROOT, 'pipelines', PIPELINE_NAME)
METADATA_PATH = os.path.join(TFX_ROOT, 'metadata', 'metadata.db')
# Directory that format_images() writes flower_images.tfrecord into; used as
# ImportExampleGen's input_base.
DATA_ROOT = os.path.join(os.getcwd(), 'data')

In [89]:
from tfx.v1.orchestration import LocalDagRunner

# Run the pipeline in-process using the configuration constants above.
# _create_pipeline() takes no trainer arguments (module_file, serving_model_dir)
# yet, so none are passed here.
LocalDagRunner().run(
    _create_pipeline(pipeline_name=PIPELINE_NAME,
                     pipeline_root=PIPELINE_ROOT,
                     data_root=DATA_ROOT,
                     metadata_path=METADATA_PATH))



In [None]:
# Label ids written by format_images() are 1-based (daisy=1 ... tulips=5), but a
# final Dense(k) layer only has logits for indices 0..k-1. With k=5 the tulips id
# (5) would be out of range for a sparse categorical cross-entropy loss (and
# tf.one_hot(5, 5) is an all-zero vector). Reserve one extra logit for the unused
# id 0 so every label id 1..5 is a valid target.
# NOTE(review): the alternative is to shift labels to 0-based in the input
# pipeline and keep 5 outputs; confirm against the training code (not shown).
NUM_FLOWER_CLASSES = 5
num_classes = NUM_FLOWER_CLASSES + 1

# Small CNN classifier. The last Dense layer has no activation, so the model
# outputs raw logits; the loss presumably needs from_logits=True (TODO confirm).
# NOTE(review): Flatten -> Dense requires a fixed input height/width, but the
# TFRecord stores each photo at its original size, so images must be resized to a
# common shape before being fed to this model.
model = tf.keras.Sequential([
  # Multiply pixel values by 1/255 (assumes inputs in [0, 255]).
  tf.keras.layers.Rescaling(1./255),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(num_classes)
])