<a href="https://colab.research.google.com/github/margaretmz/CartoonGAN-e2e-tflite-tutorial/blob/master/ml/metadata/Add_Metadata.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Reference: https://github.com/margaretmz/selfie2anime-with-tflite/blob/master/ml/add-meta-data-Colab/Add%20metadata%20to%20selfie2anime.ipynb. 

TensorFlow Lite metadata: https://www.tensorflow.org/lite/convert/metadata.

In [None]:
# Install the package that provides the `tflite_support` modules imported below.
!pip install tflite-support

In [2]:
import os
import tensorflow as tf
from absl import flags

In [3]:
from tflite_support import flatbuffers
from tflite_support import metadata as _metadata
from tflite_support import metadata_schema_py_generated as _metadata_fb

In [4]:
# Create the input and output directories. `-p` keeps the cell idempotent:
# re-running it does not fail when the directories already exist.
!mkdir -p model_without_metadata model_with_metadata

In [None]:
# Download the four CartoonGAN .tflite variants (dr, fp16, int8, full_int8)
# into the current working directory.
!wget https://storage.googleapis.com/cartoon_gan/whitebox_cartoon_gan_dr.tflite
!wget https://storage.googleapis.com/cartoon_gan/whitebox_cartoon_gan_fp16.tflite
!wget https://storage.googleapis.com/cartoon_gan/whitebox_cartoon_gan_int8.tflite
!wget https://storage.googleapis.com/cartoon_gan/whitebox_cartoon_gan_full_int8.tflite

# Move every .tflite file in the working directory into model_without_metadata/.
!mv *.tflite model_without_metadata/

In [6]:
# Directory (relative to the working directory) that receives the exported
# .tflite model with metadata and the .json dump of that metadata.
EXPORT_DIR = "model_with_metadata"

In [11]:
class MetadataPopulatorForGANModel(object):
  """Populates the metadata for the CartoonGAN model.

  The metadata describes one input tensor ("source_image") and one output
  tensor ("cartoonized_image"), both declared as RGB images, and is written
  in place into the .tflite file given at construction time.
  """

  # Input-tensor descriptions keyed by `model_type`. The channel order in the
  # text matches the RGB color space declared for the tensor.
  _INPUT_DESCRIPTIONS = {
      "other": (
          "The expected image can be of any shape but with three channels "
          "(red, green, and blue) per pixel. Each value in the tensor is between"
          " -1 and 1."),
      "fp16": (
          "The expected image is 224 x 224, with three channels "
          "(red, green, and blue) per pixel. Each value in the tensor is between"
          " -1 and 1."),
  }

  def __init__(self, model_file, model_type="other"):
    """Stores the target model path and the model type.

    Args:
      model_file: path to the .tflite file that will be modified in place.
      model_type: "other" or "fp16"; selects the input tensor description.
        Any other value leaves the input description unset.
    """
    self.model_file = model_file
    self.metadata_buf = None  # Filled in by _create_metadata().
    self.model_type = model_type

  def populate(self):
    """Creates the metadata and then populates it into the CartoonGAN model."""
    self._create_metadata()
    self._populate_metadata()

  @staticmethod
  def _rgb_image_content():
    """Returns a ContentT that declares the tensor as an RGB image."""
    content = _metadata_fb.ContentT()
    content.contentProperties = _metadata_fb.ImagePropertiesT()
    content.contentProperties.colorSpace = _metadata_fb.ColorSpaceType.RGB
    content.contentPropertiesType = (
        _metadata_fb.ContentProperties.ImageProperties)
    return content

  @staticmethod
  def _normalization_unit(mean, std):
    """Returns a ProcessUnitT with single-value normalization mean and std."""
    unit = _metadata_fb.ProcessUnitT()
    unit.optionsType = _metadata_fb.ProcessUnitOptions.NormalizationOptions
    unit.options = _metadata_fb.NormalizationOptionsT()
    unit.options.mean = [mean]
    unit.options.std = [std]
    return unit

  @staticmethod
  def _stats(minimum, maximum):
    """Returns a StatsT with single-value min and max bounds."""
    stats = _metadata_fb.StatsT()
    stats.max = [maximum]
    stats.min = [minimum]
    return stats

  def _create_metadata(self):
    """Creates the metadata for the CartoonGAN model.

    Builds the model, input tensor, output tensor and subgraph metadata,
    serializes them with flatbuffers, and stores the bytes in
    `self.metadata_buf`.
    """

    # Creates model info.
    model_meta = _metadata_fb.ModelMetadataT()
    model_meta.name = "CartoonGAN"
    model_meta.description = ("Cartoonizes an image. Reference: https://bit.ly/cartoon-gan.")
    model_meta.version = "v1"
    model_meta.author = "TensorFlow"
    model_meta.license = ("Apache License. Version 2.0 "
                          "http://www.apache.org/licenses/LICENSE-2.0.")

    # Creates info for the input, normal image.
    input_image_meta = _metadata_fb.TensorMetadataT()
    input_image_meta.name = "source_image"
    # Only known model types get a description; for any other value the
    # description is left untouched.
    if self.model_type in self._INPUT_DESCRIPTIONS:
      input_image_meta.description = self._INPUT_DESCRIPTIONS[self.model_type]
    input_image_meta.content = self._rgb_image_content()
    # Mean/std of 127.5 with stats of [-1, 1], matching the value range stated
    # in the input description.
    input_image_meta.processUnits = [self._normalization_unit(127.5, 127.5)]
    input_image_meta.stats = self._stats(minimum=-1.0, maximum=1.0)

    # Creates output info, cartoonized image.
    output_image_meta = _metadata_fb.TensorMetadataT()
    output_image_meta.name = "cartoonized_image"
    output_image_meta.description = "Image cartoonized."
    output_image_meta.content = self._rgb_image_content()
    # Mean 0 / std 1 normalization, with declared stats of [0, 255].
    output_image_meta.processUnits = [self._normalization_unit(0.0, 1.0)]
    output_image_meta.stats = self._stats(minimum=0.0, maximum=255.0)

    # Creates subgraph info.
    subgraph = _metadata_fb.SubGraphMetadataT()
    subgraph.inputTensorMetadata = [input_image_meta]
    subgraph.outputTensorMetadata = [output_image_meta]
    model_meta.subgraphMetadata = [subgraph]

    # Serializes the metadata, tagged with the metadata file identifier.
    builder = flatbuffers.Builder(0)
    builder.Finish(
        model_meta.Pack(builder),
        _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER)
    self.metadata_buf = builder.Output()

  def _populate_metadata(self):
    """Populates metadata to the model file."""
    populator = _metadata.MetadataPopulator.with_model_file(self.model_file)
    populator.load_metadata_buffer(self.metadata_buf)
    populator.populate()

In [12]:
def populate_metadata(model_file, model_type="other"):
  """Copies a model into EXPORT_DIR, populates its metadata, and dumps it as JSON.

  The source model is left untouched: it is first copied to EXPORT_DIR
  (overwriting any previous export with the same name), and the metadata is
  written into that copy. The populated metadata is then read back and saved
  next to the exported model as `<model name>.json`.

  Args:
      model_file: valid path to the model file.
      model_type: forwarded to MetadataPopulatorForGANModel, where it selects
        the input tensor description ("other" or "fp16").
  """
  # Create the export directory if needed so the copy below does not depend
  # on it having been created beforehand.
  os.makedirs(EXPORT_DIR, exist_ok=True)

  # Populates metadata for the model (on a copy, never on the original).
  model_file_basename = os.path.basename(model_file)
  export_path = os.path.join(EXPORT_DIR, model_file_basename)
  tf.io.gfile.copy(model_file, export_path, overwrite=True)

  populator = MetadataPopulatorForGANModel(export_path, model_type)
  populator.populate()

  # Displays the metadata that was just populated into the tflite model.
  displayer = _metadata.MetadataDisplayer.with_model_file(export_path)
  export_json_file = os.path.join(
      EXPORT_DIR,
      os.path.splitext(model_file_basename)[0] + ".json")
  metadata_json = displayer.get_metadata_json()
  with open(export_json_file, "w") as f:
    f.write(metadata_json)
  print("Finished populating metadata and associated file to the model:")
  print(export_path)
  print("The metadata json file has been saved to:")
  print(export_json_file)

In [16]:
quantization = "other" #@param ["other", "fp16"]
tflite_model_path = "whitebox_cartoon_gan_dr.tflite" #@param ["whitebox_cartoon_gan_full_int8.tflite", "whitebox_cartoon_gan_int8.tflite", "whitebox_cartoon_gan_dr.tflite", "whitebox_cartoon_gan_fp16.tflite"]
# Relative path, matching the directory the downloaded models were moved into,
# so the cell does not depend on the working directory being /content.
MODEL_FILE = os.path.join("model_without_metadata", tflite_model_path)
populate_metadata(MODEL_FILE, model_type=quantization)

Finished populating metadata and associated file to the model:
model_with_metadata/whitebox_cartoon_gan_dr.tflite
The metadata json file has been saved to:
model_with_metadata/whitebox_cartoon_gan_dr.json


In [17]:
# Authenticate the current Colab user.
# NOTE(review): presumably required so the gsutil upload in the next cell can
# write to the bucket — confirm.
from google.colab import auth as google_auth
google_auth.authenticate_user()

In [18]:
# Upload everything under model_with_metadata/ to
# gs://cartoon_gan/model_with_metadata/ (-m: parallel transfers, -r: recursive).
!gsutil -m cp -r model_with_metadata/* gs://cartoon_gan/model_with_metadata/

Copying file://model_with_metadata/whitebox_cartoon_gan_dr.json [Content-Type=application/json]...
/ [0/8 files][    0.0 B/  7.3 MiB]   0% Done                                    Copying file://model_with_metadata/whitebox_cartoon_gan_full_int8.tflite [Content-Type=application/octet-stream]...
/ [0/8 files][    0.0 B/  7.3 MiB]   0% Done                                    Copying file://model_with_metadata/whitebox_cartoon_gan_dr.tflite [Content-Type=application/octet-stream]...
/ [0/8 files][    0.0 B/  7.3 MiB]   0% Done                                    Copying file://model_with_metadata/whitebox_cartoon_gan_fp16.json [Content-Type=application/json]...
/ [0/8 files][    0.0 B/  7.3 MiB]   0% Done                                    Copying file://model_with_metadata/whitebox_cartoon_gan_fp16.tflite [Content-Type=application/octet-stream]...
/ [0/8 files][    0.0 B/  7.3 MiB]   0% Done                                    Copying file://model_with_metadata/whitebox_cartoon_gan_int