In [1]:
import os
import struct
import numpy as np
import tensorflow as tf
from tqdm import tqdm

import numpy as np
import struct

def robust_fast_read_aedat(filename):
    """
    Vectorized reader for AEDAT 2.0 event files.

    The leading ASCII header (every line that decodes as ASCII and whose first
    non-whitespace character is '#') is skipped. Everything after it is read in
    one call as big-endian int32 words and grouped into (address, timestamp)
    pairs. A dangling odd word at the end is dropped so the pairing always works.

    Address bit layout decoded here: bit 0 = polarity, bits 1-7 = x, bits 8-14 = y.

    Returns
    -------
    dict with keys 'timestamps' (big-endian int32 array), 'xs', 'ys' and
    'polarities' (uint8 arrays), all of the same length.

    Raises
    ------
    ValueError
        If the file ends before any non-header line is found.
    """
    # Walk the file line by line until the first line that is not an ASCII '#'
    # comment; the byte position where that line starts is the data offset.
    with open(filename, 'rb') as fh:
        while True:
            line_start = fh.tell()
            raw_line = fh.readline()
            if not raw_line:
                raise ValueError(f"No binary data section found in {filename}")
            try:
                is_header = raw_line.decode('ascii').lstrip().startswith('#')
            except UnicodeDecodeError:
                # Non-ASCII bytes can only belong to the binary payload.
                is_header = False
            if not is_header:
                break
    data_offset = line_start

    words = np.fromfile(filename, dtype='>i4', offset=data_offset)
    # Keep an even number of words so every address has a matching timestamp.
    usable = words.size - (words.size % 2)
    pairs = words[:usable].reshape(-1, 2)

    addr = pairs[:, 0].astype(np.uint32)
    return {
        'timestamps': pairs[:, 1],
        'xs': ((addr >> 1) & 0x7F).astype(np.uint8),
        'ys': ((addr >> 8) & 0x7F).astype(np.uint8),
        'polarities': (addr & 1).astype(np.uint8),
    }

def read_aedat_file(filename):
    """Read a single AEDAT 2.0 file, keeping its header lines.

    The leading lines that decode as ASCII and start with '#' (after stripping)
    are collected as the header. The remainder is parsed as 8-byte events, each a
    big-endian signed int32 address followed by a big-endian signed int32
    timestamp. Trailing bytes that do not form a whole event are ignored.

    Address decoding: x = bits 1-7, y = bits 8-14 (polarity is not returned).

    Returns
    -------
    dict with keys:
        'header'     : list of stripped header lines (str)
        'timestamps' : list of int
        'xs'         : list of int
        'ys'         : list of int

    Raises
    ------
    ValueError
        If the file ends before any non-header line is found.
    """
    with open(filename, 'rb') as f:
        header_lines = []
        while True:
            pos = f.tell()
            line = f.readline()
            if not line:
                raise ValueError(f"No binary data section found in file {filename}, please check file format")

            try:
                decoded_line = line.decode('ascii', errors='strict')
            except UnicodeDecodeError:
                # Unable to decode ASCII, means this is the start of binary data
                f.seek(pos)
                break

            stripped_line = decoded_line.strip()
            if stripped_line.startswith('#'):
                header_lines.append(stripped_line)
            else:
                f.seek(pos)
                break

        data = f.read()

    event_size = 8
    num_events = len(data) // event_size

    # Decode every (address, timestamp) pair in one vectorized call instead of
    # slicing and struct-unpacking 8 bytes at a time. `count` ignores any
    # trailing partial event.
    events = np.frombuffer(data, dtype='>i4', count=2 * num_events).reshape(num_events, 2)
    addresses = events[:, 0]

    # Signed int32 shifts are arithmetic, matching Python int semantics of the
    # former per-event struct.unpack path; .tolist() yields plain Python ints.
    xs = ((addresses >> 1) & 0x7F).tolist()
    ys = ((addresses >> 8) & 0x7F).tolist()
    timestamps = events[:, 1].tolist()

    return {
        'header': header_lines,
        'timestamps': timestamps,
        'xs': xs,
        'ys': ys
    }

def process_and_save_event_count_tf(input_base_folder, output_base_folder, grid_size=(128, 128), num_time_bins=1, time_fraction=1):
    """
    Convert aedat files to event count grids and save each as a TFRecord file.

    Expected layout is ``input_base_folder/<category>/*.aedat``. The category
    sub-folders are mirrored under ``output_base_folder``; each ``<name>.aedat``
    becomes ``<name>.tfrecord`` holding one serialized example with a float32
    count grid.

    Parameters
    ----------
    input_base_folder : str
        Root folder with one sub-folder per category.
    output_base_folder : str
        Root folder for the output (created if missing).
    grid_size : tuple of int
        Spatial grid extent (x bins, y bins); events whose x/y fall outside it
        are not counted.
    num_time_bins : int
        Number of equal-width time bins (output channels). When 1, the single
        channel is repeated to produce 3 channels.
    time_fraction : int or float
        If > 1, only the first ``1/time_fraction`` of each recording's time span
        is used.

    Raises
    ------
    ValueError
        If ``num_time_bins`` < 1.

    Files that contain no events are skipped with a message instead of aborting
    the whole batch.
    """
    if num_time_bins < 1:
        raise ValueError(f"num_time_bins must be >= 1, got {num_time_bins}")

    os.makedirs(output_base_folder, exist_ok=True)

    # Sorted listings make the processing order reproducible (os.listdir order is arbitrary).
    class_folders = sorted(
        f for f in os.listdir(input_base_folder)
        if os.path.isdir(os.path.join(input_base_folder, f))
    )
    for class_folder in tqdm(class_folders, desc="Processing Categories", unit="category"):
        input_folder = os.path.join(input_base_folder, class_folder)
        output_folder = os.path.join(output_base_folder, class_folder)
        os.makedirs(output_folder, exist_ok=True)

        files = sorted(f for f in os.listdir(input_folder) if f.endswith('.aedat'))
        for filename in tqdm(files, desc=f"Processing {class_folder}", unit="file", leave=False):
            filepath = os.path.join(input_folder, filename)
            # splitext swaps only the final extension; str.replace would rewrite
            # every '.aedat' occurrence in the name.
            output_filepath = os.path.join(output_folder, os.path.splitext(filename)[0] + '.tfrecord')

            data = robust_fast_read_aedat(filepath)
            if data['timestamps'].size == 0:
                # np.min/np.max raise on empty arrays; skip instead of killing the batch.
                tqdm.write(f"Skipping {filepath}: no events found")
                continue

            event_count_grid = _event_count_grid(
                data['xs'], data['ys'], data['timestamps'],
                grid_size=grid_size,
                num_time_bins=num_time_bins,
                time_fraction=time_fraction,
            )

            # Same float32 bytes the former tf.convert_to_tensor(...).numpy() round trip produced.
            serialized_example = serialize_example(event_count_grid.astype(np.float32))
            with tf.io.TFRecordWriter(output_filepath) as writer:
                writer.write(serialized_example)


def _event_count_grid(xs, ys, timestamps, grid_size, num_time_bins, time_fraction):
    """Count events per (x, y, time-bin) cell; returns int32 array (grid_size[0], grid_size[1], C).

    C is 3 when num_time_bins == 1 (single channel repeated), else num_time_bins.
    Assumes ``timestamps`` is non-empty.
    """
    t_min, t_max = np.min(timestamps), np.max(timestamps)
    # If time_fraction > 1, only take the first 1/time_fraction of the time range
    if time_fraction > 1:
        t_max = t_min + (t_max - t_min) / time_fraction

    # Keep only events in [t_min, t_max]
    mask = (timestamps >= t_min) & (timestamps <= t_max)
    xs, ys, timestamps = xs[mask], ys[mask], timestamps[mask]

    time_bin_edges = np.linspace(t_min, t_max, num=num_time_bins + 1)
    # side='right' maps an event exactly at t_max to index num_time_bins; clip it
    # into the last bin explicitly (histogramdd would otherwise count it there
    # only because its last bin edge is closed).
    t_indices = np.searchsorted(time_bin_edges, timestamps, side='right') - 1
    t_indices = np.clip(t_indices, 0, num_time_bins - 1)

    # (N, 3) coordinates -> [x, y, t_index]; unit-width integer bins per axis.
    event_coords = np.stack([xs, ys, t_indices], axis=1)
    bins = [np.arange(0, grid_size[0] + 1),
            np.arange(0, grid_size[1] + 1),
            np.arange(0, num_time_bins + 1)]
    grid, _ = np.histogramdd(event_coords, bins=bins)
    grid = grid.astype(np.int32)

    if num_time_bins == 1:
        grid = np.repeat(grid, 3, axis=-1)
    return grid

def serialize_example(event_grid):
    """
    Pack an event-count grid (NumPy array) into a serialized tf.train.Example.

    Two features are written:
      - 'event_grid': the array's raw bytes (``ndarray.tobytes()``), so the reader
        must know the dtype it was saved with
      - 'shape': the array's shape as an int64 list, for reshaping on read
    """
    grid_bytes = event_grid.tobytes()
    shape_values = event_grid.shape

    grid_feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[grid_bytes]))
    shape_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=shape_values))

    features = tf.train.Features(feature={
        'event_grid': grid_feature,
        'shape': shape_feature,
    })
    return tf.train.Example(features=features).SerializeToString()

In [None]:
# Parameter settings for a single full-length run.
grid_size = (128, 128)
time_fraction = 1
num_time_bins = 1  # When set to 1, it will automatically be copied to 3 channels
input_base_folder = r"C:\Users\Lem17\Master Thesis\Data processing\data_aedat2"
output_base_folder = r"D:\Dataset\eventData_dataset\timeStack_1281281_tf"

process_and_save_event_count_tf(
    input_base_folder,
    output_base_folder,
    grid_size=grid_size,
    num_time_bins=num_time_bins,
    time_fraction=time_fraction,
)

In [3]:
# Parameter settings: build one dataset per time fraction (first 1/fraction of each recording).
input_base_folder = r"D:\Dataset\eventData_dataset\data_aedat2"
base_output_folder = r"D:\Dataset\eventData_dataset"
num_time_bins = 1
fractions = [1, 3, 6, 12, 24, 48]
grid_size = (128, 128)

for time_fraction in fractions:
    # Derive the folder tag from the actual parameters ("1281281" for 128x128x1)
    # so changing grid_size/num_time_bins can't silently mislabel the output, and
    # join with os.path.join instead of a hard-coded Windows separator.
    dataset_tag = f"{grid_size[0]}{grid_size[1]}{num_time_bins}"
    output_folder = os.path.join(
        base_output_folder, f"timeStack_{dataset_tag}_1of{time_fraction}_tf"
    )
    # Call main processing function
    process_and_save_event_count_tf(
        input_base_folder,
        output_folder,
        grid_size=grid_size,
        num_time_bins=num_time_bins,
        time_fraction=time_fraction
    )
    print(f"Finished: {output_folder}")

print("All processing finished.")

Processing Categories: 100%|██████████| 10/10 [03:47<00:00, 22.77s/category]


Finished: D:\Dataset\eventData_dataset\timeStack_1281281_1of1_tf


Processing Categories: 100%|██████████| 10/10 [01:37<00:00,  9.74s/category]


Finished: D:\Dataset\eventData_dataset\timeStack_1281281_1of3_tf


Processing Categories: 100%|██████████| 10/10 [01:34<00:00,  9.41s/category]


Finished: D:\Dataset\eventData_dataset\timeStack_1281281_1of6_tf


Processing Categories: 100%|██████████| 10/10 [01:24<00:00,  8.45s/category]


Finished: D:\Dataset\eventData_dataset\timeStack_1281281_1of12_tf


Processing Categories: 100%|██████████| 10/10 [01:19<00:00,  7.97s/category]


Finished: D:\Dataset\eventData_dataset\timeStack_1281281_1of24_tf


Processing Categories: 100%|██████████| 10/10 [01:15<00:00,  7.59s/category]

Finished: D:\Dataset\eventData_dataset\timeStack_1281281_1of48_tf
All processing finished.



