In [2]:
import tensorflow as tf
import json

def load_features_json(json_filepath):
    """
    Loads the JSON file containing the feature normalization parameters.

    The JSON file (brk_features.json) should have a structure like:
    {
      "action": {"mean": [...], "std": [...], "max": [...], "min": [...]},
      "proprio": {"mean": [...], "std": [...], "max": [...], "min": [...]}
    }
    """
    with open(json_filepath, 'r') as f:
        features = json.load(f)
    return features

def parse_tfrecord(example_proto):
    """
    Parses a single TFRecord example.

    This function assumes that each example should contain two fields:
    - "action": a fixed-length float array of 7 elements.
    - "proprio": a fixed-length float array of 7 elements.

    If a feature is missing, a default value of zeros is provided.
    """
    feature_description = {
        'action': tf.io.FixedLenFeature([7], tf.float32, default_value=[0.0]*7),
        'proprio': tf.io.FixedLenFeature([7], tf.float32, default_value=[0.0]*7)
    }
    parsed_features = tf.io.parse_single_example(example_proto, feature_description)
    return parsed_features

def normalize_features(features, norm_params):
    """
    Normalizes each feature in the example using the provided normalization parameters.

    For each key ("action" and "proprio"), the function subtracts the mean and divides
    by the standard deviation from the JSON file.
    """
    normalized = {}
    for key in ['action', 'proprio']:
        if key in features:
            mean = tf.constant(norm_params[key]["mean"], dtype=tf.float32)
            std = tf.constant(norm_params[key]["std"], dtype=tf.float32)
            normalized[key] = (features[key] - mean) / std
        else:
            tf.print(f"Warning: key '{key}' not found in the features.")
    return normalized

def main():
    # Update these file paths as necessary.
    tfrecord_filepath = '/content/bridge_dataset-train.tfrecord-00006-of-01024'
    json_filepath = '/content/brk_features.json'

    # Load normalization parameters from the JSON file.
    norm_params = load_features_json(json_filepath)

    # Create a TFRecordDataset and parse each example.
    raw_dataset = tf.data.TFRecordDataset(tfrecord_filepath)
    parsed_dataset = raw_dataset.map(parse_tfrecord)

    # Normalize each example using the loaded normalization parameters.
    normalized_dataset = parsed_dataset.map(lambda features: normalize_features(features, norm_params))

    # For demonstration, print the first 5 normalized examples.
    for i, example in enumerate(normalized_dataset.take(60)):
        print(f"Example {i}:")
        print("action:", example['action'].numpy())
        print("proprio:", example['proprio'].numpy())
        print('---')

if __name__ == '__main__':
    main()


Example 0:
action: [-0.02244844 -0.01509992  0.01618756  0.00212304  0.0050381  -0.00697939
 -1.1817025 ]
proprio: [-5.2183948  -0.31278124 -1.2672548  -0.02097114  0.4379158  -0.2251477
 -1.9972415 ]
---
Example 1:
action: [-0.02244844 -0.01509992  0.01618756  0.00212304  0.0050381  -0.00697939
 -1.1817025 ]
proprio: [-5.2183948  -0.31278124 -1.2672548  -0.02097114  0.4379158  -0.2251477
 -1.9972415 ]
---
Example 2:
action: [-0.02244844 -0.01509992  0.01618756  0.00212304  0.0050381  -0.00697939
 -1.1817025 ]
proprio: [-5.2183948  -0.31278124 -1.2672548  -0.02097114  0.4379158  -0.2251477
 -1.9972415 ]
---
Example 3:
action: [-0.02244844 -0.01509992  0.01618756  0.00212304  0.0050381  -0.00697939
 -1.1817025 ]
proprio: [-5.2183948  -0.31278124 -1.2672548  -0.02097114  0.4379158  -0.2251477
 -1.9972415 ]
---
Example 4:
action: [-0.02244844 -0.01509992  0.01618756  0.00212304  0.0050381  -0.00697939
 -1.1817025 ]
proprio: [-5.2183948  -0.31278124 -1.2672548  -0.02097114  0.4379158  -0.2

In [3]:
import os
import json
import pprint
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import tensorflow as tf

def example_to_dict(example_proto):
    """
    Converts a tf.train.Example proto into a Python dictionary.

    For each feature, if it's a bytes_list, we try to decode it as a UTF-8 string.
    If that string is valid JSON, we parse it into a Python object.
    Otherwise, for float_list or int64_list, we return a list of values.
    """
    result = {}
    for key, feature in example_proto.features.feature.items():
        # Handle bytes_list features.
        if feature.bytes_list.value:
            value = feature.bytes_list.value[0]
            try:
                value_decoded = value.decode('utf-8')
                try:
                    value_parsed = json.loads(value_decoded)
                    result[key] = value_parsed
                except Exception:
                    result[key] = value_decoded
            except Exception:
                result[key] = value
        elif feature.float_list.value:
            result[key] = list(feature.float_list.value)
        elif feature.int64_list.value:
            result[key] = list(feature.int64_list.value)
        else:
            result[key] = None
    return result

def load_array_record(file_path):
    """
    Loads the array-record file (TFRecord format) and converts each record into a dictionary.
    """
    dataset = tf.data.TFRecordDataset(filenames=[file_path])
    records = []
    for raw_record in dataset:
        example_proto = tf.train.Example.FromString(raw_record.numpy())
        record = example_to_dict(example_proto)
        records.append(record)
    return records

def process_flattened_images(record, image_key="steps/observation/image", save_dir="extracted_images"):
    """
    Processes image data from a flattened record structure.

    If the record contains a key (by default "steps/observation/image")
    whose value is a list/array of images (one per time step), then this function
    will iterate over the images, convert each to a PIL Image, and save them.

    Parameters:
      record (dict): The dictionary representation of a single record.
      image_key (str): The key to search for image data.
      save_dir (str): Directory where images will be saved.

    Returns:
      list: A list of PIL Image objects for the successfully processed images.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    images = []
    if image_key not in record:
        print(f"Error: '{image_key}' key not found in the record.")
        return images

    image_list = record[image_key]
    print(f"Found {len(image_list)} images under key '{image_key}'.")

    for idx, image_data in enumerate(image_list):
        # Convert image_data to a numpy array if it's not already one.
        if not isinstance(image_data, np.ndarray):
            try:
                image_data = np.array(image_data)
            except Exception as e:
                print(f"Error converting image data at index {idx}: {e}")
                continue

        # Debug: print shape and dtype.
        print(f"Image {idx}: shape {image_data.shape}, dtype {image_data.dtype}")

        try:
            # Create a PIL image.
            if image_data.ndim == 2:
                pil_img = Image.fromarray(image_data.astype(np.uint8), mode='L')
            elif image_data.ndim == 3:
                if image_data.shape[2] == 1:
                    pil_img = Image.fromarray(image_data.squeeze().astype(np.uint8), mode='L')
                elif image_data.shape[2] == 3:
                    pil_img = Image.fromarray(image_data.astype(np.uint8), mode='RGB')
                else:
                    raise ValueError(f"Unsupported channel count: {image_data.shape[2]}")
            else:
                raise ValueError(f"Unsupported image dimensions: {image_data.shape}")

            images.append(pil_img)
            img_filename = os.path.join(save_dir, f"image_{idx}.png")
            pil_img.save(img_filename)
            print(f"Saved image {idx} to {img_filename}")

        except Exception as e:
            print(f"Error processing image at index {idx}: {e}")
            if isinstance(image_data, np.ndarray):
                plt.imshow(image_data)
                plt.title(f"Image {idx} (debug)")
                plt.show()

    return images

def main():
    # Path to your array-record file (update as needed)
    file_path = "/content/bridge_dataset-train.tfrecord-00006-of-01024"

    # Load records from the file.
    records = load_array_record(file_path)
    if not records:
        print("No records found in the file.")
        return

    # For inspection, print keys and content of the first record.
    print("Keys in the first record:")
    pprint.pprint(records[0].keys())
    print("\nFirst record content (truncated):")
    # Truncate output for brevity.
    for key, value in records[0].items():
        print(f"{key}: {str(value)[:100]}{'...' if len(str(value))>100 else ''}")

    # Process images using the flattened key structure.
    # Adjust the key if your data uses a different naming convention.
    images = process_flattened_images(records[0], image_key="steps/observation/image")

    if not images:
        print("No images were extracted.")
    else:
        print(f"Extracted {len(images)} images.")

if __name__ == '__main__':
    main()


Keys in the first record:
dict_keys(['steps/language_embedding', 'steps/is_last', 'episode_metadata/has_image_3', 'steps/observation/image_2', 'episode_metadata/file_path', 'episode_metadata/has_image_2', 'steps/observation/image_1', 'steps/observation/image_3', 'steps/is_first', 'steps/action', 'steps/discount', 'steps/language_instruction', 'episode_metadata/has_language', 'episode_metadata/episode_id', 'steps/reward', 'steps/observation/image_0', 'episode_metadata/has_image_0', 'steps/observation/state', 'episode_metadata/has_image_1', 'steps/is_terminal'])

First record content (truncated):
steps/language_embedding: [0.012435083277523518, -0.02298649400472641, 0.03941458836197853, 0.04642587900161743, 0.02560792863...
steps/is_last: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
episode_metadata/has_image_3: [0]
steps/observation/image_2: b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x01,\x01,\x00\x00\xff\xdb\x00C\x00\x02\x01\x01\x01\x0...
episode_metadata/f

In [7]:
import os
import io
import tensorflow as tf
from PIL import Image

def save_images_from_record(record, output_dir="extracted_images"):
    """
    Searches for keys starting with "steps/observation/image_" in the record,
    decodes the JPEG image data from the byte strings, and saves each image to disk.

    Parameters:
      record (tf.train.Example): A parsed TFRecord example.
      output_dir (str): Directory to save the extracted images.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Convert the tf.train.Example to a dictionary of features.
    features = record.features.feature
    image_keys = [key for key in features.keys() if key.startswith("steps/observation/image_")]

    for key in image_keys:
        # Each image is stored in a bytes_list. Get the first element.
        image_bytes = features[key].bytes_list.value[0]
        try:
            # Use a BytesIO buffer to load the image with PIL.
            image_buffer = io.BytesIO(image_bytes)
            img = Image.open(image_buffer)
            # Optionally, convert image to RGB (if not already)
            # img = img.convert('RGB')
            # Create a safe filename by replacing '/' with '_'
            safe_key = key.replace("/", "_")
            filename = os.path.join(output_dir, f"{safe_key}.png")
            img.save(filename)
            print(f"Saved {filename}")
        except Exception as e:
            print(f"Error processing {key}: {e}")

def main():
    # Update this to your TFRecord file path.
    tfrecord_path = "/content/bridge_dataset-train.tfrecord-00006-of-01024"

    dataset = tf.data.TFRecordDataset(filenames=[tfrecord_path])

    # Process the first record for demonstration.
    for idx, raw_record in enumerate(dataset.take(40)):
        try:
            record = tf.train.Example.FromString(raw_record.numpy())
            print(f"Processing record {idx}")
            save_images_from_record(record, output_dir="extracted_images")
        except Exception as e:
            print(f"Error parsing record {idx}: {e}")

if __name__ == '__main__':
    main()


Processing record 0
Saved extracted_images/steps_observation_image_2.png
Saved extracted_images/steps_observation_image_1.png
Saved extracted_images/steps_observation_image_3.png
Saved extracted_images/steps_observation_image_0.png
Processing record 1
Saved extracted_images/steps_observation_image_2.png
Saved extracted_images/steps_observation_image_1.png
Saved extracted_images/steps_observation_image_3.png
Saved extracted_images/steps_observation_image_0.png
Processing record 2
Saved extracted_images/steps_observation_image_2.png
Saved extracted_images/steps_observation_image_1.png
Saved extracted_images/steps_observation_image_3.png
Saved extracted_images/steps_observation_image_0.png
Processing record 3
Saved extracted_images/steps_observation_image_2.png
Saved extracted_images/steps_observation_image_1.png
Saved extracted_images/steps_observation_image_3.png
Saved extracted_images/steps_observation_image_0.png
Processing record 4
Saved extracted_images/steps_observation_image_2.png

In [None]:
import os
import io
import tensorflow as tf
from PIL import Image

def extract_image_from_record(record, key="steps/observation/image_0"):
    """
    Given a parsed tf.train.Example record, extract and decode the image
    under the specified key (default "steps/observation/image_0").

    Returns a PIL Image in RGB mode if successful, or None if not.
    """
    features = record.features.feature
    if key in features:
        image_bytes = features[key].bytes_list.value[0]
        try:
            image_buffer = io.BytesIO(image_bytes)
            img = Image.open(image_buffer)
            # Ensure consistent RGB mode for GIF frames.
            return img.convert("RGB")
        except Exception as e:
            print(f"Error processing image for key {key}: {e}")
            return None
    else:
        print(f"Key {key} not found in record.")
        return None

def main():
    # Update this path to your TFRecord file.
    tfrecord_path = "/content/bridge_dataset-train.tfrecord-00000-of-01024"
    dataset = tf.data.TFRecordDataset(filenames=[tfrecord_path])

    frames = []
    record_count = 0
    # Iterate over all records in the dataset.
    for raw_record in dataset:
        try:
            record = tf.train.Example.FromString(raw_record.numpy())
            img = extract_image_from_record(record, key="steps/observation/image_0")
            if img is not None:
                frames.append(img)
                record_count += 1
                print(f"Extracted image from record {record_count}")
        except Exception as e:
            print(f"Error parsing record: {e}")

    if frames:
        gif_path = "output.gif"
        # Save frames as an animated GIF.
        # duration=500 sets the time (ms) per frame; adjust as needed.
        frames[0].save(gif_path, save_all=True, append_images=frames[1:], duration=500, loop=0)
        print(f"GIF saved to {gif_path}")
    else:
        print("No frames extracted; GIF not created.")

if __name__ == '__main__':
    main()


Extracted image from record 1
Extracted image from record 2
Extracted image from record 3
Extracted image from record 4
Extracted image from record 5
Extracted image from record 6
Extracted image from record 7
Extracted image from record 8
Extracted image from record 9
Extracted image from record 10
Extracted image from record 11
Extracted image from record 12
Extracted image from record 13
Extracted image from record 14
Extracted image from record 15
Extracted image from record 16
Extracted image from record 17
Extracted image from record 18
Extracted image from record 19
Extracted image from record 20
Extracted image from record 21
Extracted image from record 22
Extracted image from record 23
Extracted image from record 24
Extracted image from record 25
Extracted image from record 26
Extracted image from record 27
Extracted image from record 28
Extracted image from record 29
Extracted image from record 30
Extracted image from record 31
Extracted image from record 32
Extracted image f