In [None]:
pip install lensai-profiler

In [None]:
import tensorflow as tf
from lensai_profiler.metrics import process_batch
from lensai_profiler.sketches import Sketches
import os

# Download and prepare the dataset
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)

# Locate the extracted dataset root. Where the archive ends up depends on the
# installed Keras version, so probe the known layouts instead of assuming one.
_download_dir = os.path.dirname(path_to_zip)
_candidate_roots = [
    # Layout this script originally assumed: extracted next to the zip file.
    os.path.join(_download_dir, 'cats_and_dogs_filtered'),
    # NOTE(review): newer Keras releases appear to return the extraction
    # directory itself from get_file(extract=True) -- confirm against the
    # installed version.
    os.path.join(path_to_zip, 'cats_and_dogs_filtered'),
]
# Fall back to the originally assumed location when no candidate exists, so a
# missing dataset still surfaces as an error when the directory is loaded.
PATH = next(
    (root for root in _candidate_roots if os.path.isdir(root)),
    _candidate_roots[0],
)

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

# Build a shuffled, batched image dataset from the training directory; images
# are resized to IMG_SIZE on load.
train_dataset = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    shuffle=True,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
)

# Initialize sketches
num_channels = 3  # Assuming RGB images
sketches = Sketches(num_channels)

# Replace every (images, labels) batch with the output of process_batch(images);
# the labels are dropped. AUTOTUNE lets tf.data pick the map parallelism.
train_dataset = train_dataset.map(
    lambda images, _labels: process_batch(images),
    num_parallel_calls=tf.data.AUTOTUNE,
)

# Feed each batch's metrics into the KLL sketches, one batch per iteration.
# Each dataset element is expected to unpack into these five values.
for brightness, sharpness, channel_mean, snr, channel_pixels in train_dataset:
    sketches.tf_update_sketches(brightness, sharpness, channel_mean, snr, channel_pixels)

# Save the KLL sketches to a specified directory
save_path = '/content/sample_data/'
# The path is Colab-specific; create it if missing so saving does not depend
# on the directory already existing.
# NOTE(review): save_sketches may create the directory itself -- confirm.
os.makedirs(save_path, exist_ok=True)
sketches.save_sketches(save_path)