<a href="https://colab.research.google.com/github/matbutom/maquina-de-contrapropaganda/blob/main/copy_of_tensorflow_datasets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TensorFlow Datasets

TFDS provides a collection of ready-to-use datasets for use with TensorFlow, JAX, and other machine learning frameworks.

It handles downloading and preparing the data deterministically and constructing a `tf.data.Dataset` (or `np.array`).

Note: Do not confuse [TFDS](https://www.tensorflow.org/datasets) (this library) with `tf.data` (TensorFlow API to build efficient data pipelines). TFDS is a high level wrapper around `tf.data`. If you're not familiar with this API, we encourage you to read [the official tf.data guide](https://www.tensorflow.org/guide/data) first.


Copyright 2018 The TensorFlow Datasets Authors, Licensed under the Apache License, Version 2.0

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/datasets/overview"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/overview.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/datasets/blob/master/docs/overview.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/datasets/docs/overview.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

## Installation

TFDS exists in two packages:

* `pip install tensorflow-datasets`: The stable version, released every few months.
* `pip install tfds-nightly`: Released every day, contains the latest versions of the datasets.

This colab uses `tfds-nightly`:


In [None]:
#!pip install tfds-nightly tensorflow matplotlib apache_beam mlcroissant

In [None]:
#import matplotlib.pyplot as plt
#import numpy as np
#import tensorflow as tf

#import tensorflow_datasets as tfds

## Find available datasets

All dataset builders are subclasses of `tfds.core.DatasetBuilder`. To get the list of available builders, use `tfds.list_builders()` or look at our [catalog](https://www.tensorflow.org/datasets/catalog/overview).

In [None]:
#tfds.list_builders()

## Load a dataset

### tfds.load

The easiest way of loading a dataset is `tfds.load`. It will:

1. Download the data and save it as [`tfrecord`](https://www.tensorflow.org/tutorials/load_data/tfrecord) files.
2. Load the `tfrecord` and create the `tf.data.Dataset`.


In [None]:
#ds = tfds.load('mnist', split='train', shuffle_files=True)
#assert isinstance(ds, tf.data.Dataset)
#print(ds)

Some common arguments:

*   `split=`: Which split to read (e.g. `'train'`, `['train', 'test']`, `'train[80%:]'`,...). See our [split API guide](https://www.tensorflow.org/datasets/splits).
*   `shuffle_files=`: Control whether to shuffle the files between each epoch (TFDS stores big datasets in multiple smaller files).
*   `data_dir=`: Location where the dataset is saved (defaults to `~/tensorflow_datasets/`)
*   `with_info=True`: Returns the `tfds.core.DatasetInfo` containing dataset metadata
*   `download=False`: Disable download


### tfds.builder

`tfds.load` is a thin wrapper around `tfds.core.DatasetBuilder`. You can get the same output using the `tfds.core.DatasetBuilder` API:

In [None]:
#builder = tfds.builder('mnist')
# 1. Create the tfrecord files (no-op if already exists)
#builder.download_and_prepare()
# 2. Load the `tf.data.Dataset`
#ds = builder.as_dataset(split='train', shuffle_files=True)
#print(ds)
#

### `tfds build` CLI

If you want to generate a specific dataset, you can use the [`tfds` command line](https://www.tensorflow.org/datasets/cli). For example:

```sh
tfds build mnist
```

See [the doc](https://www.tensorflow.org/datasets/cli) for available flags.

In [None]:
!tfds new datos

In [None]:
%cd datos/
%pwd

In [None]:
!ls

## Iterate over a dataset

### As dict

By default, the `tf.data.Dataset` object contains a `dict` of `tf.Tensor`s:

In [None]:
# ds = tfds.load('mnist', split='train')
# ds = ds.take(1)  # Only take a single example

# for example in ds:  # example is `{'image': tf.Tensor, 'label': tf.Tensor}`
#  print(list(example.keys()))
#  image = example["image"]
#  label = example["label"]
#  print(image.shape, label)

To find out the `dict` key names and structure, look at the dataset documentation in [our catalog](https://www.tensorflow.org/datasets/catalog/overview#all_datasets). For example: [mnist documentation](https://www.tensorflow.org/datasets/catalog/mnist).

### As tuple (`as_supervised=True`)

By using `as_supervised=True`, you can get a tuple `(features, label)` instead for supervised datasets.

In [None]:
# ds = tfds.load('mnist', split='train', as_supervised=True)
# ds = ds.take(1)

# for image, label in ds:  # example is (image, label)
#  print(image.shape, label)

### As numpy (`tfds.as_numpy`)

Use `tfds.as_numpy` to convert:

*   `tf.Tensor` -> `np.array`
*   `tf.data.Dataset` -> `Iterator[Tree[np.array]]` (`Tree` can be arbitrary nested `Dict`, `Tuple`)



In [None]:
# ds = tfds.load('mnist', split='train', as_supervised=True)
# ds = ds.take(1)

# for image, label in tfds.as_numpy(ds):
#  print(type(image), type(label), label)

### As batched tf.Tensor (`batch_size=-1`)

By using `batch_size=-1`, you can load the full dataset in a single batch.

This can be combined with `as_supervised=True` and `tfds.as_numpy` to get the data as `(np.array, np.array)`:


In [None]:
# image, label = tfds.as_numpy(tfds.load(
#    'mnist',
#    split='test',
#    batch_size=-1,
#    as_supervised=True,
# ))

# print(type(image), image.shape)

Be careful that your dataset can fit in memory, and that all examples have the same shape.

## Benchmark your datasets

Benchmarking a dataset is a simple `tfds.benchmark` call on any iterable (e.g. `tf.data.Dataset`, `tfds.as_numpy`,...).


In [None]:
# ds = tfds.load('mnist', split='train')
# ds = ds.batch(32).prefetch(1)

# tfds.benchmark(ds, batch_size=32)
# tfds.benchmark(ds, batch_size=32)  # Second epoch much faster due to auto-caching

* Do not forget to normalize the results per batch size with the `batch_size=` kwarg.
* In the summary, the first warmup batch is separated from the other ones to capture `tf.data.Dataset` extra setup time (e.g. buffers initialization,...).
* Notice how the second iteration is much faster due to [TFDS auto-caching](https://www.tensorflow.org/datasets/performances#auto-caching).
* `tfds.benchmark` returns a `tfds.core.BenchmarkResult` which can be inspected for further analysis.
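
To make the points above concrete, here is a minimal sketch of what a throughput benchmark measures. It is not how `tfds.benchmark` is implemented; the helper name `measure_throughput` and the keys of the returned dict are hypothetical, chosen only for illustration.

```python
import time

def measure_throughput(iterable, batch_size=1):
    """Time one pass over `iterable`, reporting the first batch separately."""
    start = time.perf_counter()
    first_end = None
    num_batches = 0
    for _ in iterable:
        num_batches += 1
        if first_end is None:
            # The first batch includes one-off setup cost (warmup).
            first_end = time.perf_counter()
    end = time.perf_counter()
    if first_end is None:
        first_end = end  # the iterable was empty
    steady_batches = max(num_batches - 1, 0)
    steady_time = end - first_end
    return {
        # Normalize by batch size so results count examples, not batches.
        "num_examples": num_batches * batch_size,
        "warmup_seconds": first_end - start,
        "steady_examples_per_second": (
            steady_batches * batch_size / steady_time
            if steady_time > 0 else float("nan")
        ),
    }

result = measure_throughput(range(1000), batch_size=32)
print(result["num_examples"])
```

Running the same helper twice on a cached pipeline is one way to observe the warmup and caching effects described above.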

### Build end-to-end pipeline

To go further, you can look at:

*   Our [end-to-end Keras example](https://www.tensorflow.org/datasets/keras_example) to see a full training pipeline (with batching, shuffling,...).
*   Our [performance guide](https://www.tensorflow.org/datasets/performances) to improve the speed of your pipelines (tip: use `tfds.benchmark(ds)` to benchmark your datasets).


## Visualization

### tfds.as_dataframe

`tf.data.Dataset` objects can be converted to [`pandas.DataFrame`](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) with `tfds.as_dataframe` to be visualized on [Colab](https://colab.research.google.com).

* Add the `tfds.core.DatasetInfo` as second argument of `tfds.as_dataframe` to visualize images, audio, texts, videos,...
* Use `ds.take(x)` to only display the first `x` examples. `pandas.DataFrame` will load the full dataset in-memory, and can be very expensive to display.

In [None]:
# ds, info = tfds.load('mnist', split='train', with_info=True)

# tfds.as_dataframe(ds.take(4), info)

### tfds.show_examples

`tfds.show_examples` returns a `matplotlib.figure.Figure` (only image datasets supported now):

In [None]:
# ds, info = tfds.load('mnist', split='train', with_info=True)

# fig = tfds.show_examples(ds, info)

## Access the dataset metadata

All builders include a `tfds.core.DatasetInfo` object containing the dataset metadata.

It can be accessed through:

*   The `tfds.load` API:


In [None]:
# ds, info = tfds.load('mnist', with_info=True)

*   The `tfds.core.DatasetBuilder` API:

In [None]:
# builder = tfds.builder('mnist')
# info = builder.info

The dataset info contains additional information about the dataset (version, citation, homepage, description,...).

In [None]:
# print(info)

### Features metadata (label names, image shape,...)

Access the `tfds.features.FeatureDict`:

In [None]:
# info.features

Number of classes, label names:

In [None]:
# print(info.features["label"].num_classes)
# print(info.features["label"].names)
# print(info.features["label"].int2str(7))  # Human readable version (8 -> 'cat')
# print(info.features["label"].str2int('7'))
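
The conversions above amount to a two-way lookup between integer ids and label names. The sketch below shows that idea in plain Python; `LabelMap` is a hypothetical stand-in, not the TFDS `ClassLabel` implementation, and the digit names are assumed to mirror the MNIST labels.

```python
class LabelMap:
    """Hypothetical stand-in for a class-label feature: id <-> name lookup."""

    def __init__(self, names):
        self.names = list(names)
        # Reverse index so name -> id is a dictionary lookup.
        self._name_to_id = {name: i for i, name in enumerate(self.names)}

    @property
    def num_classes(self):
        return len(self.names)

    def int2str(self, index):
        return self.names[index]

    def str2int(self, name):
        return self._name_to_id[name]

digits = LabelMap(str(d) for d in range(10))
print(digits.num_classes, digits.int2str(7), digits.str2int("7"))
```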

Shapes, dtypes:

In [None]:
# print(info.features.shape)
# print(info.features.dtype)
# print(info.features['image'].shape)
# print(info.features['image'].dtype)

### Split metadata (e.g. split names, number of examples,...)

Access the `tfds.core.SplitDict`:

In [None]:
# print(info.splits)

Available splits:

In [None]:
# print(list(info.splits.keys()))

Get info on individual split:

In [None]:
# print(info.splits['train'].num_examples)
# print(info.splits['train'].filenames)
# print(info.splits['train'].num_shards)

It also works with the subsplit API:

In [None]:
# print(info.splits['train[15%:75%]'].num_examples)
# print(info.splits['train[15%:75%]'].file_instructions)
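
As a rough illustration of what a percent subsplit selects, the sketch below converts a percent range into absolute example indices. The function `percent_slice` is hypothetical, and rounding boundaries to the nearest index is an assumption; the library's exact rounding rule may differ. The value 60,000 is the size of the MNIST `train` split.

```python
def percent_slice(num_examples, start_pct, end_pct):
    """Convert a percent range into absolute [start, end) example indices.

    Assumption: boundaries are rounded to the nearest index.
    """
    start = round(num_examples * start_pct / 100)
    end = round(num_examples * end_pct / 100)
    return start, end

start, end = percent_slice(60_000, 15, 75)
print(start, end, end - start)
```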

## Troubleshooting

### Manual download (if download fails)

If the download fails for some reason (e.g. you are offline), you can always manually download the data yourself and place it in the `manual_dir` (defaults to `~/tensorflow_datasets/downloads/manual/`).

To find out which urls to download, look into:

 * For new datasets (implemented as folder): [`tensorflow_datasets/`](https://github.com/tensorflow/datasets/tree/master/tensorflow_datasets/)`<type>/<dataset_name>/checksums.tsv`. For example: [`tensorflow_datasets/datasets/bool_q/checksums.tsv`](https://github.com/tensorflow/datasets/blob/master/tensorflow_datasets/datasets/bool_q/checksums.tsv).

   You can find the dataset source location in [our catalog](https://www.tensorflow.org/datasets/catalog/overview).
 * For old datasets: [`tensorflow_datasets/url_checksums/<dataset_name>.txt`](https://github.com/tensorflow/datasets/tree/master/tensorflow_datasets/url_checksums)

### Fixing `NonMatchingChecksumError`

TFDS ensures determinism by validating the checksums of downloaded urls.
If `NonMatchingChecksumError` is raised, it might indicate:

  * The website may be down (e.g. `503 status code`). Please check the url.
  * For Google Drive URLs, try again later as Drive sometimes rejects downloads when too many people access the same URL. See [bug](https://github.com/tensorflow/datasets/issues/1482)
  * The original dataset files may have been updated. In this case the TFDS dataset builder should be updated. Please open a new GitHub issue or PR:
     * Register the new checksums with `tfds build --register_checksums`
     * Update the dataset generation code if needed.
     * Update the dataset `VERSION`
     * Update the dataset `RELEASE_NOTES`: What caused the checksums to change? Did some examples change?
     * Make sure the dataset can still be built.
     * Send us a PR

Note: You can also inspect the downloaded file in `~/tensorflow_datasets/downloads/`.
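
When inspecting a downloaded file by hand, it can help to recompute its checksum yourself. The sketch below assumes the recorded checksums are SHA-256 hex digests; the helpers `sha256_of_file` and `matches_checksum` are hypothetical names, and the demo runs on a throwaway file rather than a real download.

```python
import hashlib
import os
import tempfile

def sha256_of_file(path, chunk_size=1 << 20):
    """Hash a file in chunks so large downloads do not need to fit in memory."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def matches_checksum(path, expected_sha256):
    return sha256_of_file(path) == expected_sha256.lower()

# Demo on a temporary file standing in for a downloaded archive.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"hello")
demo_path = tmp.name

expected = hashlib.sha256(b"hello").hexdigest()
ok = matches_checksum(demo_path, expected)
bad = matches_checksum(demo_path, "0" * 64)
os.remove(demo_path)
print(ok, bad)
```

A mismatch on a file you downloaded manually suggests the upstream file changed or the download was truncated.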


## Citation

If you're using `tensorflow-datasets` for a paper, please include the following citation, in addition to any citation specific to the used datasets (which can be found in the [dataset catalog](https://www.tensorflow.org/datasets/catalog/overview)).

```
@misc{TFDS,
  title = { {TensorFlow Datasets}, A collection of ready-to-use datasets},
  howpublished = {\url{https://www.tensorflow.org/datasets}},
}
```

In [None]:
!rm -rf ~/tensorflow_datasets/maquina_contrapropaganda


In [None]:
# ============================================================
# 🧩 Physical cleanup and resizing of the dataset
# ============================================================

import os
from PIL import Image

base_dir = "/content/recortes_letras"
target_size = (64, 64)
failed = 0  # number of files that could not be processed

for root, dirs, files in os.walk(base_dir):
    for f in files:
        if not f.lower().endswith((".jpg", ".jpeg", ".png")):
            continue
        path = os.path.join(root, f)
        try:
            # Convert inside the context manager so the file handle is closed
            # before the image is written back to the same path.
            with Image.open(path) as src:
                im = src.convert("RGB")
            im = im.resize(target_size, Image.LANCZOS)
            im.save(path)
        except Exception as e:
            failed += 1
            print("⚠️ Error with", path, "→", e)

print(f"✅ Resizing to 64×64 px finished ({failed} file(s) could not be processed).")


In [None]:
# ============================================================
# 🧩 Dataset checker: rebuild only if there are new letters
# ============================================================

import os
import tensorflow_datasets as tfds

# base path where the letters live (adjust it if you use Drive)
data_dir = '/content/recortes_letras'
builder_dir = os.path.expanduser('~/tensorflow_datasets/maquina_contrapropaganda')

# helper that lists the valid class folders
def contar_carpetas(path):
    return sorted([d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))])

# folders currently detected
carpetas_actuales = contar_carpetas(data_dir)
num_actual = len(carpetas_actuales)

# how many classes the previous dataset had (if it exists)
prev_num = 0
if os.path.exists(builder_dir):
    try:
        info = tfds.builder('maquina_contrapropaganda').info
        prev_num = info.features["label"].num_classes
    except Exception:
        # The builder may not be registered yet in this session; in that case
        # treat the previous dataset as having no classes.
        pass

print(f"📦 Letters currently detected: {carpetas_actuales}")
print(f"🧠 Previous dataset: {prev_num} classes | New: {num_actual} classes")

# if there are new letters, delete the cached dataset
if num_actual > prev_num:
    print("⚠️ New letters detected. Regenerating the full dataset...")
    !rm -rf ~/tensorflow_datasets/maquina_contrapropaganda
else:
    print("✅ No changes in the classes; the previous dataset is kept.")


In [None]:
# ============================================================
# 🔍 Physical check of the actual image sizes on disk
# ============================================================

from PIL import Image
import os

base_dir = "/content/recortes_letras"
malas = []  # (path, size) for every image that is not 64×64 or cannot be read

for root, dirs, files in os.walk(base_dir):
    for f in files:
        if not f.lower().endswith((".jpg", ".jpeg", ".png")):
            continue
        path = os.path.join(root, f)
        try:
            with Image.open(path) as im:
                if im.size != (64, 64):
                    malas.append((path, im.size))
        except Exception:
            malas.append((path, "❌ error"))

print(f"Total images outside the expected size (or unreadable): {len(malas)}")
for i, (p, s) in enumerate(malas[:10]):
    print(f"{i+1:02d}. {p} → {s}")


In [None]:
# ============================================================
# 📦 Custom Dataset: Máquina de Contrapropaganda
# ============================================================

import tensorflow_datasets as tfds
import tensorflow as tf
import os

_DESCRIPTION = """
Visual dataset for the 'Máquina de Contrapropaganda' project.
Contains cropped letters classified by folder (A–Z),
extracted from propaganda posters.
"""

_CITATION = """
@misc{rafita2025maquinacontrapropaganda,
  title={Máquina de Contrapropaganda Dataset},
  author={Arce, Mateo},
  year={2025},
  howpublished={Rafita Studio / Universidad de Chile}
}
"""

class MaquinaContrapropaganda(tfds.core.GeneratorBasedBuilder):
    VERSION = tfds.core.Version('1.0.0')

    def _info(self):
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict({
                "image": tfds.features.Image(shape=(None, None, 3)),
                "label": tfds.features.ClassLabel(names=[chr(i) for i in range(65, 91)])  # A–Z
            }),
            supervised_keys=("image", "label"),
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        data_dir = os.path.expanduser('/content/recortes_letras')
        return {"train": self._generate_examples(data_dir)}

    def _generate_examples(self, path):
        for label_name in sorted(os.listdir(path)):
            label_dir = os.path.join(path, label_name)
            if not os.path.isdir(label_dir):
                continue
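            for img_name in os.listdir(label_dir):
                if img_name.lower().endswith((".jpg", ".png", ".jpeg")):
                    # Keys must be unique within a split. The same file name
                    # can appear in several letter folders, so prefix the label.
                    yield f"{label_name}/{img_name}", {
                        "image": os.path.join(label_dir, img_name),
                        "label": label_name,
                    }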

# === Build the dataset ===
builder = MaquinaContrapropaganda()
builder.download_and_prepare()

ds = builder.as_dataset(split="train", as_supervised=True)

print("✅ Dataset loaded successfully.")
print("Detected classes:", builder.info.features["label"].names)
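
TFDS requires every example key yielded by `_generate_examples` to be unique within a split, so file names that repeat across class folders (for example `A/001.png` and `B/001.png`) can collide if the bare file name is used as the key. The sketch below walks a folder-per-class tree in plain Python and builds label-prefixed keys; `iter_labeled_images` is a hypothetical helper and the demo uses a temporary directory, not the real dataset.

```python
import os
import tempfile

IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

def iter_labeled_images(root):
    """Yield (key, path, label) for a folder-per-class image tree."""
    for label in sorted(os.listdir(root)):
        label_dir = os.path.join(root, label)
        if not os.path.isdir(label_dir):
            continue
        for name in sorted(os.listdir(label_dir)):
            if name.lower().endswith(IMAGE_EXTENSIONS):
                # Prefix the label so equal file names in different folders differ.
                yield f"{label}/{name}", os.path.join(label_dir, name), label

# Demo: two class folders that both contain "001.png".
with tempfile.TemporaryDirectory() as root:
    for label in ("A", "B"):
        os.makedirs(os.path.join(root, label))
        open(os.path.join(root, label, "001.png"), "wb").close()
    keys = [key for key, _, _ in iter_labeled_images(root)]

print(keys)
```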



In [None]:
# ============================================================
# 👁️ Visualization of dataset examples
# ============================================================

import matplotlib.pyplot as plt

for image, label in ds.take(9):
    plt.figure(figsize=(2, 2))
    plt.imshow(image)
    plt.title(builder.info.features["label"].int2str(label.numpy()))
    plt.axis("off")
plt.show()


In [None]:
# ============================================================
# 🛠️ Forced physical resizing (only the bad images)
# ============================================================

from PIL import Image

unrepaired = 0  # number of images that still could not be fixed

for path, size in malas:
    try:
        with Image.open(path) as src:
            im = src.convert("RGB")
        im = im.resize((64, 64), Image.LANCZOS)
        im.save(path)
    except Exception as e:
        unrepaired += 1
        print("❌ Could not repair:", path, "→", e)

print(f"✅ Repair pass finished ({unrepaired} image(s) could not be repaired).")


In [None]:
# ============================================================
# 🧩 Automatic split of the dataset into train / val / test
# ============================================================

import tensorflow as tf
import math

# total size of the dataset
total = sum(1 for _ in ds)
train_size = math.floor(total * 0.8)
val_size = math.floor(total * 0.1)
test_size = total - train_size - val_size

print(f"📊 Total examples: {total}")
print(f"🔹 Train: {train_size} | 🔸 Val: {val_size} | ⚪ Test: {test_size}")

# --- split using take() and skip() ---
train_ds = ds.take(train_size)
val_ds = ds.skip(train_size).take(val_size)
test_ds = ds.skip(train_size + val_size)

# --- normalize images to float32 in [0, 1] ---
AUTOTUNE = tf.data.AUTOTUNE

def preprocess(img, label):
    img = tf.image.convert_image_dtype(img, tf.float32)
    return img, label

train_ds = train_ds.map(preprocess).cache().shuffle(1000).batch(32).prefetch(AUTOTUNE)
val_ds = val_ds.map(preprocess).cache().batch(32).prefetch(AUTOTUNE)
test_ds = test_ds.map(preprocess).cache().batch(32).prefetch(AUTOTUNE)

print("✅ Datasets split and ready for training.")
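
The split arithmetic above can be checked without TensorFlow. The sketch below reproduces the 80/10/10 computation and partitions a plain list into consecutive slices, which is what the `take()`/`skip()` calls do on the dataset. The helpers `split_sizes` and `partition` and the total of 1234 examples are hypothetical.

```python
import math

def split_sizes(total, train_frac=0.8, val_frac=0.1):
    """Return (train, val, test) sizes; the test split absorbs rounding leftovers."""
    train = math.floor(total * train_frac)
    val = math.floor(total * val_frac)
    return train, val, total - train - val

def partition(items, sizes):
    """Cut a list into consecutive, non-overlapping train/val/test slices."""
    train, val, _ = sizes
    return items[:train], items[train:train + val], items[train + val:]

sizes = split_sizes(1234)
train_items, val_items, test_items = partition(list(range(1234)), sizes)
print(sizes, len(train_items), len(val_items), len(test_items))
```

Because the test size is computed by subtraction, the three sizes always add up to the total even when the fractions do not divide it evenly.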


In [None]:
# ============================================================
# ✅ Check of batch size and image shape
# ============================================================

import matplotlib.pyplot as plt

for imgs, labels in train_ds.take(1):
    print("✅ batch shape:", imgs.shape)
    print("🔹 dtype:", imgs.dtype)
    print("🔹 value range:", tf.reduce_min(imgs).numpy(), "→", tf.reduce_max(imgs).numpy())

    # show one of the images as a visual confirmation
    plt.imshow(imgs[0])
    plt.title(f"Example image, shape {imgs[0].shape}")
    plt.axis("off")
    plt.show()
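
The value range printed above should be within 0 to 1, because converting `uint8` images to `float32` with `tf.image.convert_image_dtype` rescales them by 1/255. A NumPy sketch of that scaling and of the reverse conversion (useful when saving frames as images) is shown below; `to_float01` and `to_uint8` are hypothetical helper names.

```python
import numpy as np

def to_float01(img_uint8):
    """Scale uint8 pixels in [0, 255] to float32 in [0, 1]."""
    return img_uint8.astype(np.float32) / 255.0

def to_uint8(img_float01):
    """Scale back to [0, 255], rounding and clipping before the cast."""
    return np.clip(np.rint(img_float01 * 255.0), 0, 255).astype(np.uint8)

pixels = np.array([[0, 128, 255]], dtype=np.uint8)
scaled = to_float01(pixels)
restored = to_uint8(scaled)
print(scaled.dtype, scaled.min(), scaled.max(), restored.tolist())
```

Rounding and clipping before the cast avoids the truncation and wrap-around that a bare `np.uint8(img * 255)` can produce for values slightly outside the expected range.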


In [None]:
# ============================================================
# 🧩 General configuration
# ============================================================

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os

IMG_SIZE = 64
EPOCHS = 40

# ============================================================
# 🔧 Dataset without labels and with infinite repetition
# ============================================================

def ensure_valid_image(img):
    # normalizes and resizes each image to 64x64
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img, [IMG_SIZE, IMG_SIZE])
    return tf.ensure_shape(img, [IMG_SIZE, IMG_SIZE, 3])

train_ds_no_labels = (
    train_ds.unbatch()
    .map(lambda x, y: ensure_valid_image(x), num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(512)
    .batch(32)
    .repeat()
    .prefetch(tf.data.AUTOTUNE)
)

val_ds_no_labels = (
    val_ds.unbatch()
    .map(lambda x, y: ensure_valid_image(x), num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .repeat()
    .prefetch(tf.data.AUTOTUNE)
)

print("✅ Datasets verified:")
for imgs in train_ds_no_labels.take(1):
    print("train batch:", imgs.shape)
for imgs in val_ds_no_labels.take(1):
    print("val batch:", imgs.shape)


# ============================================================
# 🎨 Fixed VisualCallback (safe and stable)
# ============================================================

class VisualCallback(tf.keras.callbacks.Callback):
    def __init__(self, sample_batch, save_dir="/content/outputs", interval=5):
        super().__init__()
        self.sample_batch = sample_batch
        self.save_dir = save_dir
        self.interval = interval
        os.makedirs(save_dir, exist_ok=True)
        self.generated_images = [] # List to store generated images for GIF

    def on_epoch_end(self, epoch, logs=None):
        if (epoch + 1) % self.interval != 0:
            return

        sample_imgs = self.sample_batch[:8]
        z_mean, z_log_var, z = self.model.encoder(sample_imgs)
        reconstructed = self.model.decoder(z)

        n = 8
        fig, axes = plt.subplots(2, n, figsize=(n * 1.5, 3))
        for i in range(n):
            axes[0, i].imshow(sample_imgs[i])
            axes[0, i].axis("off")
            axes[1, i].imshow(reconstructed[i])
            axes[1, i].axis("off")
        plt.tight_layout()

        # Save the figure as an image for later GIF creation
        path = os.path.join(self.save_dir, f"epoch_{epoch+1:03d}.png")
        plt.savefig(path)
        plt.close(fig)
        print(f"🌀 Hallucinated letters saved to: {path}")

        # Display the generated images live
        plt.figure(figsize=(n * 1.5, 3))
        for i in range(n):
             plt.subplot(2, n, i + 1)
             plt.imshow(sample_imgs[i])
             plt.axis("off")
             plt.subplot(2, n, i + n + 1)
             plt.imshow(reconstructed[i])
             plt.axis("off")
        plt.suptitle(f"Epoch {epoch+1}", fontsize=16)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()


        # Store the generated image batch for GIF creation
        self.generated_images.append(reconstructed.numpy())


# ============================================================
# ⚙️ VAE loss definition
# ============================================================

# This loss function is no longer directly used by vae.fit because
# we define a custom train_step in the VAE model.
def vae_total_loss(y_true, y_pred):
    reconstruction_loss = tf.reduce_mean(
        tf.keras.losses.binary_crossentropy(y_true, y_pred)
    ) * IMG_SIZE * IMG_SIZE * 3
    # KL divergence loss is calculated in the train_step
    return reconstruction_loss # This will be combined with KL loss in train_step


# # ============================================================
# # 🧠 VAE training (stable version) - DEPRECATED
# # ============================================================

# # grab a sample batch for the callback
# sample_batch = next(iter(train_ds_no_labels))

# vae = VAE(encoder, decoder)
# vae.compile(optimizer=tf.keras.optimizers.Adam(), loss=vae_total_loss)

# vae.fit(
#     train_ds_no_labels,
#     validation_data=val_ds_no_labels,
#     epochs=EPOCHS,
#     steps_per_epoch=50,
#     validation_steps=10,
#     callbacks=[VisualCallback(sample_batch)],
#     verbose=1
# )


# # ============================================================
# # 💾 Saving the trained models - DEPRECATED
# # ============================================================

# decoder.save("/content/drive/MyDrive/maquina-de-contrapropaganda/models/decoder_solo.keras")
# encoder.save("/content/drive/MyDrive/maquina-de-contrapropaganda/models/encoder_solo.keras")
# vae.save("/content/drive/MyDrive/maquina-de-contrapropaganda/models/vae_completo.keras")

# print("✅ Models saved successfully to Drive.")

In [None]:
# ============================================================
# 🧠 Encoder definition (stable version)
# ============================================================

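from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf

# Size of the latent space. ASSUMPTION: LATENT_DIM is never defined in this
# notebook, so 16 is only a placeholder; set it to the value you intend to use.
LATENT_DIM = 16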

def sampling(args):
    z_mean, z_log_var = args
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    epsilon = tf.random.normal(shape=(batch, dim))
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon

# Encoder network
# Input shape is IMG_SIZE x IMG_SIZE x 3 (64x64x3)
encoder_inputs = keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(256, activation="relu")(x)
z_mean = layers.Dense(LATENT_DIM, name="z_mean")(x)
z_log_var = layers.Dense(LATENT_DIM, name="z_log_var")(x)
z = layers.Lambda(sampling, name="z")([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()
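
The `sampling` function above is the reparameterization trick: instead of sampling `z` directly, it samples standard normal noise and shifts and scales it with the predicted mean and log-variance. The NumPy sketch below illustrates the same formula outside of Keras; `sample_latent`, the seed, and the chosen variance are assumptions made for the demo.

```python
import numpy as np

def sample_latent(z_mean, z_log_var, rng):
    """Reparameterization: z = mean + sigma * eps, with eps ~ N(0, I)."""
    eps = rng.standard_normal(z_mean.shape)
    # exp(0.5 * log_var) turns a log-variance into a standard deviation.
    return z_mean + np.exp(0.5 * z_log_var) * eps

rng = np.random.default_rng(0)
z_mean = np.zeros((10_000, 2))
z_log_var = np.full((10_000, 2), np.log(4.0))  # variance 4, so std 2
z = sample_latent(z_mean, z_log_var, rng)
print(z.shape, float(z.mean()), float(z.std()))
```

The sample statistics should come out close to the requested mean of 0 and standard deviation of 2.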

In [None]:
# ============================================================
# 🧠 Decoder definition (stable version)
# ============================================================

# Decoder network
latent_inputs = keras.Input(shape=(LATENT_DIM,))
x = layers.Dense(8 * 8 * 64, activation="relu")(latent_inputs) # Adjusted dense layer output
x = layers.Reshape((8, 8, 64))(x) # Adjusted reshape
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
# Added another Conv2DTranspose layer to reach 64x64
x = layers.Conv2DTranspose(3, 3, activation="sigmoid", strides=2, padding="same")(x)
decoder_outputs = x # Final output
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()
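
To see why the decoder starts from an 8×8 feature map and needs three upsampling layers, it helps to trace the spatial sizes. With `padding="same"`, a strided convolution produces `ceil(size / stride)` and a strided transposed convolution produces `size * stride`. The helper names in this sketch are hypothetical.

```python
import math

def conv_same_out(size, stride):
    """Spatial output size of a strided convolution with padding='same'."""
    return math.ceil(size / stride)

def conv_transpose_same_out(size, stride):
    """Spatial output size of a strided transposed convolution with padding='same'."""
    return size * stride

def trace(size, strides, step):
    """Apply `step` once per layer and record every intermediate size."""
    sizes = [size]
    for stride in strides:
        size = step(size, stride)
        sizes.append(size)
    return sizes

encoder_sizes = trace(64, [2, 2], conv_same_out)
decoder_sizes = trace(8, [2, 2, 2], conv_transpose_same_out)
print(encoder_sizes, decoder_sizes)
```

The encoder ends at 16×16 while the decoder starts at 8×8, which is why the decoder has one more stride-2 layer than the encoder.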

In [None]:
# ============================================================
# 🧠 VAE model definition with train_step (stable version)
# ============================================================

from tensorflow import keras
import tensorflow as tf

class VAE(keras.Model):
    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        return reconstructed

    # Define the training step
    def train_step(self, data):
        # The dataset is yielding only images
        images = data

        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(images)
            reconstructed_images = self.decoder(z)

            # Calculate reconstruction loss
            reconstruction_loss = tf.reduce_mean(
                tf.keras.losses.binary_crossentropy(images, reconstructed_images)
            ) * IMG_SIZE * IMG_SIZE * 3  # Scale by image dimensions

            # Calculate KL divergence loss
            kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)

            # Total VAE loss
            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }

    # Define the test step for validation/evaluation
    def test_step(self, data):
        images = data

        z_mean, z_log_var, z = self.encoder(images)
        reconstructed_images = self.decoder(z)

        reconstruction_loss = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(images, reconstructed_images)
        ) * IMG_SIZE * IMG_SIZE * 3

        kl_loss = -0.5 * tf.reduce_mean(z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)

        total_loss = reconstruction_loss + kl_loss

        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }
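
One detail of the loss above is worth checking numerically: the KL term uses `tf.reduce_mean` over every element, which averages over the latent dimensions as well as the batch. A common alternative sums over the latent dimensions and averages only over the batch; the two differ by a factor equal to the latent size. The NumPy sketch below compares both on toy values; the function names and inputs are hypothetical.

```python
import numpy as np

def kl_per_element(z_mean, z_log_var):
    """Closed-form KL(N(mean, var) || N(0, 1)) for each latent coordinate."""
    return -0.5 * (1.0 + z_log_var - np.square(z_mean) - np.exp(z_log_var))

def kl_mean_all(z_mean, z_log_var):
    """Average over batch and latent dimensions (as in the cell above)."""
    return kl_per_element(z_mean, z_log_var).mean()

def kl_sum_latent(z_mean, z_log_var):
    """Sum over latent dimensions, then average over the batch."""
    return kl_per_element(z_mean, z_log_var).sum(axis=1).mean()

z_mean = np.ones((4, 8))      # batch of 4, latent size 8
z_log_var = np.zeros((4, 8))  # unit variance
kl_a = kl_mean_all(z_mean, z_log_var)
kl_b = kl_sum_latent(z_mean, z_log_var)
print(kl_a, kl_b, kl_b / kl_a)
```

Which form to use is a design choice: the averaged version weights the KL term less against the reconstruction loss, which tends to favor sharper reconstructions over a well-regularized latent space.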

In [None]:
# ============================================================
# 🧠 VAE training (stable version with train_step)
# ============================================================

# grab a sample batch for the callback
sample_batch = next(iter(train_ds_no_labels))

# Instantiate the VAE model
vae = VAE(encoder, decoder)

# Compile the VAE (loss and metrics are handled in train_step)
vae.compile(optimizer=tf.keras.optimizers.Adam())

# Instantiate the VisualCallback
visual_callback = VisualCallback(sample_batch)

print("Starting VAE training...")
history = vae.fit(
    train_ds_no_labels,
    validation_data=val_ds_no_labels,
    epochs=EPOCHS,
    steps_per_epoch=50,
    validation_steps=10,
    callbacks=[visual_callback], # Use the instantiated callback
    verbose=1
)
print("VAE training finished.")

# ============================================================
# 🖼️ Generate a GIF of the evolution of the letters
# ============================================================

import imageio

# Assuming the generated images are stored in visual_callback.generated_images
# Convert the list of numpy arrays to a format imageio can handle (list of images)
# Each element in generated_images is a batch (batch_size, 64, 64, 3)
# We need to select the images we want to include in the GIF, e.g., the first 8
gif_images = []
for batch in visual_callback.generated_images:
    # Take the first 8 images from each batch and convert to uint8
    gif_images.extend([np.uint8(img * 255) for img in batch[:8]])

# Save the GIF
gif_path = "/content/vae_evolution.gif"
imageio.mimsave(gif_path, gif_images, fps=1) # Adjust fps as needed

print(f"✅ GIF of the evolution saved to: {gif_path}")

# Display the GIF in the notebook
from IPython.display import Image as IPyImage
IPyImage(open(gif_path,'rb').read())

In [None]:
# ============================================================
# 🖼️ Generate new letters from the latent space
# ============================================================

import numpy as np
import matplotlib.pyplot as plt
import imageio
import os

# Number of new letters to generate
num_new_letters = 16 # Let's generate 16 new letters

# Directory where the generated images for the GIF are saved
generate_dir = "/content/generated_letters"
os.makedirs(generate_dir, exist_ok=True)

# List that stores the generated images for the GIF
gif_frames = []

print(f"Generating {num_new_letters} new letters from the latent space...")

# Generate images over a few steps to simulate evolution for the GIF
num_generation_steps = 10 # Number of frames for the GIF per letter

# Sample latent vectors once
random_latent_vectors = tf.random.normal(shape=(num_new_letters, LATENT_DIM))

for step in range(num_generation_steps):
    # You could potentially add noise or interpolate here for a more dynamic GIF
    # For simplicity, we will just generate the final images repeatedly for the frames
    generated_images = decoder(random_latent_vectors).numpy()

    # Create a figure to display the generated images
    n = int(np.sqrt(num_new_letters))
    fig, axes = plt.subplots(n, n, figsize=(n * 2, n * 2))
    axes = axes.flatten()

    for i in range(num_new_letters):
        axes[i].imshow(generated_images[i])
        axes[i].axis("off")

    plt.suptitle(f"Generation Step {step+1}/{num_generation_steps}", fontsize=16)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Adjust layout to prevent title overlap

    # Save the figure as an image frame for the GIF
    frame_path = os.path.join(generate_dir, f"generation_step_{step+1:03d}.png")
    plt.savefig(frame_path)
    plt.close(fig)

    # Append the generated images (as uint8) to the list for the GIF
    # We'll just take the first few for the GIF to keep it manageable
    gif_frames.append(np.uint8(generated_images[:min(num_new_letters, 16)] * 255))


print("Finished generating image frames.")

# Create the GIF from the saved frames
gif_path = "/content/new_letters_evolution.gif"

# Need to flatten the list of batches into a single list of images for imageio.mimsave
flat_gif_frames = [img for batch in gif_frames for img in batch]

imageio.mimsave(gif_path, flat_gif_frames, fps=5) # Adjust fps as needed

print(f"✅ GIF of new letter generation evolution saved to: {gif_path}")

# Display the GIF in the notebook
from IPython.display import Image as IPyImage
IPyImage(open(gif_path,'rb').read())
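
The generation loop above decodes the same latent vectors at every step, so all frames are identical. One way to get a frame sequence that actually changes, as the comment in the loop suggests, is to interpolate linearly between two sets of latent vectors and decode each intermediate batch. The sketch below shows only the interpolation in NumPy; `interpolate_latents`, the latent size of 16, and the batch of 4 are hypothetical.

```python
import numpy as np

def interpolate_latents(z_start, z_end, num_steps):
    """Return `num_steps` latent batches moving linearly from z_start to z_end."""
    alphas = np.linspace(0.0, 1.0, num_steps)
    return [(1.0 - a) * z_start + a * z_end for a in alphas]

rng = np.random.default_rng(0)
latent_dim = 16  # hypothetical value for this sketch
z_a = rng.standard_normal((4, latent_dim))
z_b = rng.standard_normal((4, latent_dim))
path = interpolate_latents(z_a, z_b, num_steps=10)
print(len(path), path[0].shape)
```

Each element of `path` has the shape the decoder expects, so it could be passed to `decoder(...)` once per frame in place of the fixed `random_latent_vectors`.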
