In [None]:
import io
from io import BytesIO
import os
from PIL import Image
import PIL.Image as pil_image
from google3.file.base.python import shards
from google3.pyglib import gfile
from google3.third_party.array_record.python import array_record_module
from google3.third_party.tensorflow.core.example import example_pb2

In [None]:
def image_to_bytes(image: pil_image.Image, image_format: str = 'PNG') -> bytes:
  """Serializes a PIL image into an in-memory encoded byte string.

  Args:
    image: The PIL image to encode.
    image_format: Format name passed through to `image.save` (defaults to
      'PNG').

  Returns:
    The complete contents of the buffer the image was saved into.
  """
  with io.BytesIO() as buffer:
    image.save(buffer, format=image_format)
    # Grab the bytes before the `with` block closes the buffer.
    return buffer.getvalue()

In [None]:
# @title Write ArrayRecords with TF Example

# NOTE(review): hard-coded personal CNS directory; point this at a location you
# own before running.
OUTPUT_DIR = '/cns/dz-d/home/xliucs/sensorlm'
NUMBER_OF_SHARDS = 5
NUMBER_OF_IMAGES_PER_SHARD = 10
gfile.MakeDirs(OUTPUT_DIR)
filenames = shards.GenerateShardedFilenames(
    os.path.join(
        OUTPUT_DIR, f'dummy_{NUMBER_OF_SHARDS}.arrayrecord@{NUMBER_OF_SHARDS}'
    )
)

# Every record written below is identical, so render the image, encode it and
# serialize the tf.Example once instead of redoing that work per record.
dummy_image = pil_image.new('RGB', (100, 100), color='red')
dummy_example = example_pb2.Example()
dummy_example.features.feature['input_images'].bytes_list.value.append(
    image_to_bytes(dummy_image)
)
dummy_example.features.feature['input_texts'].bytes_list.value.append(
    'What is this image?'.encode()
)
# NOTE(review): placeholder answer; it does not describe the solid-red image.
dummy_example.features.feature['output_texts'].bytes_list.value.append(
    'All ones.'.encode()
)
serialized_example = dummy_example.SerializeToString()

# One ArrayRecord file per shard, each holding NUMBER_OF_IMAGES_PER_SHARD
# copies of the serialized example.
for path in filenames:
  writer = array_record_module.ArrayRecordWriter(path)
  try:
    for _ in range(NUMBER_OF_IMAGES_PER_SHARD):
      writer.write(serialized_example)
  finally:
    # Always close the writer, even if a write fails partway through a shard.
    writer.close()

In [None]:
# Display the list of shard paths produced by GenerateShardedFilenames above.
filenames

In [None]:
# @title Read ArrayRecords with TF Example

# Read every record from the first shard, then release the reader.
reader = array_record_module.ArrayRecordReader(filenames[0])
try:
  # read_all() materializes the shard's records, so the reader can be closed
  # as soon as it returns instead of staying open for the kernel's lifetime.
  records = reader.read_all()
finally:
  reader.close()

# Parse lazily: each call to get_vqa_sample_data_point() pulls the next one.
# A generator expression is already an iterator, so no iter() wrapper needed.
vqa_content = (example_pb2.Example.FromString(r) for r in records)

def get_vqa_sample_data_point():
  """Consumes the next example from `vqa_content` and unpacks its fields.

  Returns:
    A `(question, image_bytes, answer)` tuple: the first `input_texts` value
    decoded as UTF-8, the first raw `input_images` value, and the first
    `output_texts` value decoded as UTF-8.

  Raises:
    StopIteration: when `vqa_content` has no examples left.
  """
  feature_map = next(vqa_content).features.feature

  def first_bytes(key):
    # Only the first entry of each bytes_list feature is used.
    return feature_map[key].bytes_list.value[0]

  question = first_bytes("input_texts").decode("utf-8")
  image_bytes = first_bytes("input_images")
  answer = first_bytes("output_texts").decode("utf-8")
  return question, image_bytes, answer

In [None]:
# Each run of this cell consumes the next example from the shard.
input_text, input_image, output_text = get_vqa_sample_data_point()

# Render the sample: question, decoded image, then the expected answer.
print(f'Question: {input_text}')
display(pil_image.open(io.BytesIO(input_image)))
print(f'Response: {output_text}')