<a href="https://colab.research.google.com/github/matbutom/maquina-de-contrapropaganda/blob/main/copy_of_tensorflow_datasets.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%bash
# Create the base directory and one sub-directory per letter (A-Z).
# `mkdir -p` creates missing parents and does nothing for directories that
# already exist, so this cell can be re-run safely.
base_dir=/content/recortes_letras
mkdir -p "$base_dir"
mkdir -p "$base_dir"/{A..Z}

In [None]:
%%bash

# Stop at the first failing command or unset variable; otherwise a failed
# clone would still reach the copy step and print the success message.
set -eu

# 1. Repository location and working paths
REPO_URL="https://github.com/matbutom/maquina-de-contrapropaganda.git"
REPO_NAME="maquina-de-contrapropaganda"
TARGET_DIR="/content/recortes_letras"

echo "Clonando el repositorio completo ($REPO_NAME) en el entorno de Colab..."
# Remove any clone left behind by an interrupted run: `git clone` refuses to
# write into an existing, non-empty directory.
rm -rf "$REPO_NAME"
git clone "$REPO_URL"

# 2. Copy the letter folders into the working directory (recortes_letras)
echo "Moviendo el contenido de las carpetas de letras (A, B, C...) a $TARGET_DIR..."
mkdir -p "$TARGET_DIR"
# The glob stays outside the quotes so the shell expands it to the letter
# folders; `cp -r` then copies each folder recursively.
cp -r "$REPO_NAME"/recortes_letras/* "$TARGET_DIR"/

# 3. Remove the clone, which is no longer needed
echo "Limpiando el repositorio clonado..."
rm -rf "$REPO_NAME"

# 4. Sanity check: list the first entries of folder 'A'
echo "✅ ¡Carga completa! Verificando la carpeta 'A':"
ls -l "$TARGET_DIR/A" | head -n 5

In [None]:
# Delete the maquina_contrapropaganda folder under ~/tensorflow_datasets, i.e.
# the same path a later cell stores in `builder_dir`.
!rm -rf ~/tensorflow_datasets/maquina_contrapropaganda


In [None]:
# ============================================================
# 🧩 Limpieza y redimensionado físico del dataset
# ============================================================

import os
from PIL import Image

base_dir = "/content/recortes_letras"
target_size = (64, 64)

# Normalise every image under base_dir to 64x64 RGB, overwriting it in place.
# Images that already match are left untouched so that re-running the cell
# does not re-encode (and, for JPEG, further degrade) them.
errores = 0
for root, dirs, files in os.walk(base_dir):
    for f in files:
        if not f.lower().endswith((".jpg", ".jpeg", ".png")):
            continue
        path = os.path.join(root, f)
        try:
            # The context manager closes the source file handle right away
            # instead of leaving it open until garbage collection.
            with Image.open(path) as src:
                if src.size == target_size and src.mode == "RGB":
                    continue
                im = src.convert("RGB").resize(target_size, Image.LANCZOS)
            im.save(path)
        except Exception as e:
            # Best effort: report the file and keep going with the rest.
            errores += 1
            print("⚠️ Error con", path, "→", e)

# Only claim success when every image was actually processed.
if errores:
    print(f"⚠️ {errores} imágenes no pudieron redimensionarse; revisa los errores anteriores.")
else:
    print("✅ Todas las imágenes fueron redimensionadas físicamente a 64×64 px.")


In [None]:
# ============================================================
# 🧩 Verificador de dataset — reconstruye solo si hay letras nuevas
# ============================================================

import os
import tensorflow_datasets as tfds

# Base directory that holds the letter folders (adjust it if the data lives on Drive)
data_dir = '/content/recortes_letras'
# Dataset folder under ~/tensorflow_datasets, with "~" expanded to the home directory
builder_dir = os.path.expanduser('~/tensorflow_datasets/maquina_contrapropaganda')

# Helper: names of the immediate sub-directories of a folder
def contar_carpetas(path):
    """Return the sorted names of the directories directly inside ``path``.

    Plain files are skipped; every directory entry is included.
    """
    subcarpetas = []
    for entrada in os.listdir(path):
        if os.path.isdir(os.path.join(path, entrada)):
            subcarpetas.append(entrada)
    subcarpetas.sort()
    return subcarpetas

import shutil

# Letter folders currently present on disk
carpetas_actuales = contar_carpetas(data_dir)
num_actual = len(carpetas_actuales)

# Number of classes recorded by the previously prepared dataset (0 if none)
prev_num = 0
if os.path.exists(builder_dir):
    try:
        info = tfds.builder('maquina_contrapropaganda').info
        prev_num = info.features["label"].num_classes
    except Exception as e:
        # Best effort: keep prev_num at 0 so the cache gets rebuilt, but report
        # the reason instead of hiding it.
        # NOTE(review): tfds.builder() resolves the dataset by name and the
        # builder class is only defined in a later cell, so on a fresh kernel
        # this lookup presumably fails every time — confirm.
        print(f"⚠️ No se pudo leer el dataset anterior ({type(e).__name__}: {e}); se asumen 0 clases.")

print(f"📦 Letras actuales detectadas: {carpetas_actuales}")
print(f"🧠 Dataset anterior: {prev_num} clases | Nuevo: {num_actual} clases")

# More folders than recorded classes: drop the cached dataset so it is rebuilt
if num_actual > prev_num:
    print("⚠️ Se detectaron nuevas letras. Regenerando dataset completo...")
    # Same effect as `rm -rf` on the cache folder, using the path already held
    # in builder_dir instead of repeating it in a shell command.
    shutil.rmtree(builder_dir, ignore_errors=True)
else:
    print("✅ No hay cambios en las clases, se mantiene el dataset anterior.")


In [None]:
# ============================================================
# 🔍 Verificación física de tamaños reales en disco
# ============================================================

from PIL import Image
import os

base_dir = "/content/recortes_letras"
malas = []

# Scan every image under base_dir and collect (path, size) for each one whose
# dimensions are not 64x64. Files that cannot be opened are recorded with an
# error marker in place of the size.
for carpeta, _subcarpetas, nombres in os.walk(base_dir):
    imagenes = [n for n in nombres if n.lower().endswith((".jpg", ".jpeg", ".png"))]
    for nombre in imagenes:
        path = os.path.join(carpeta, nombre)
        try:
            with Image.open(path) as im:
                dimensiones = im.size
        except Exception:
            malas.append((path, "❌ error"))
            continue
        if dimensiones != (64, 64):
            malas.append((path, dimensiones))

# Summary plus the first ten offending entries
print(f"Total de imágenes fuera de tamaño esperado: {len(malas)}")
for posicion, (p, s) in enumerate(malas[:10], start=1):
    print(f"{posicion:02d}. {p} → {s}")


In [None]:
# ============================================================
# 📦 Custom Dataset — Máquina de Contrapropaganda
# ============================================================

import tensorflow_datasets as tfds
import tensorflow as tf
import os

# Human-readable description handed to tfds.core.DatasetInfo by the builder.
_DESCRIPTION = """
Dataset visual para el proyecto 'Máquina de Contrapropaganda'.
Contiene letras recortadas clasificadas por carpeta (A–Z),
extraídas de carteles propagandísticos.
"""

# BibTeX entry handed to tfds.core.DatasetInfo as the dataset citation.
_CITATION = """
@misc{rafita2025maquinacontrapropaganda,
  title={Máquina de Contrapropaganda Dataset},
  author={Arce, Mateo},
  year={2025},
  howpublished={Rafita Studio / Universidad de Chile}
}
"""

class MaquinaContrapropaganda(tfds.core.GeneratorBasedBuilder):
    """TFDS builder for cropped-letter images stored in per-letter folders.

    Reads ``<data_dir>/<LETTER>/<image file>`` and exposes one ``train`` split
    of ``(image, label)`` examples whose labels are the letters ``A``–``Z``.
    """

    VERSION = tfds.core.Version('1.0.0')

    def _info(self):
        """Declare the features: a 3-channel image of any size and an A–Z label."""
        return tfds.core.DatasetInfo(
            builder=self,
            description=_DESCRIPTION,
            features=tfds.features.FeaturesDict({
                "image": tfds.features.Image(shape=(None, None, 3)),
                "label": tfds.features.ClassLabel(names=[chr(i) for i in range(65, 91)])  # A–Z
            }),
            supervised_keys=("image", "label"),
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        """Return the single ``train`` split, read from the local letters folder."""
        data_dir = os.path.expanduser('/content/recortes_letras')
        return {"train": self._generate_examples(data_dir)}

    def _generate_examples(self, path):
        """Yield ``(key, example)`` for every image found under ``path``.

        Args:
            path: Directory containing one sub-directory per letter.

        Yields:
            ``("<LETTER>/<file name>", {"image": <file path>, "label": <LETTER>})``.
            Directories whose name is not one of the declared labels are
            skipped with a warning.
        """
        valid_labels = set(self.info.features["label"].names)
        for label_name in sorted(os.listdir(path)):
            label_dir = os.path.join(path, label_name)
            if not os.path.isdir(label_dir):
                continue
            if label_name not in valid_labels:
                # Stray folders (e.g. notebook checkpoint directories) are not
                # declared class names, so encoding them as labels would fail.
                print(f"⚠️ Carpeta ignorada (no es una letra A–Z): {label_dir}")
                continue
            # Sorted so the generation order does not depend on the file system.
            for img_name in sorted(os.listdir(label_dir)):
                if img_name.lower().endswith((".jpg", ".png", ".jpeg")):
                    # Example keys must be unique across the split. The bare
                    # file name is not, because the same name can appear in
                    # several letter folders, so the key is prefixed with the letter.
                    yield f"{label_name}/{img_name}", {
                        "image": os.path.join(label_dir, img_name),
                        "label": label_name,
                    }

# === Build the dataset ===
# `builder` and `ds` are module-level names that later cells read.
builder = MaquinaContrapropaganda()
builder.download_and_prepare()

# as_supervised=True makes the dataset yield (image, label) pairs
ds = builder.as_dataset(split="train", as_supervised=True)

print("✅ Dataset cargado correctamente.")
print("Clases detectadas:", builder.info.features["label"].names)



In [None]:
# ============================================================
# ⚙️ Verificación y Configuración de GPU
# ============================================================

import tensorflow as tf

gpus = tf.config.experimental.list_physical_devices('GPU')
if not gpus:
    print("🐌 No se detectó ninguna GPU. Asegúrate de activar el entorno de ejecución (Runtime) T4/A100.")
else:
    primary_gpu = gpus[0]
    try:
        # Restrict TensorFlow to the first GPU, then let its memory use grow
        # on demand rather than being reserved up front.
        tf.config.set_visible_devices(primary_gpu, 'GPU')
        tf.config.experimental.set_memory_growth(primary_gpu, True)

        print(f"🚀 GPU detectada y configurada: {primary_gpu.name}")
        print("El entrenamiento se ejecutará automáticamente en la GPU.")
    except RuntimeError as err:
        # Can be raised when devices are configured after TensorFlow has
        # already initialised them.
        print(f"⚠️ Error al configurar la GPU: {err}")

In [None]:
# ============================================================
# 👁️ Visualización de ejemplos del dataset
# ============================================================

import matplotlib.pyplot as plt

# Draw the first nine examples, each in its own small figure titled with the
# letter its integer label maps to.
label_feature = builder.info.features["label"]
for image, label in ds.take(9):
    fig, ax = plt.subplots(figsize=(2, 2))
    ax.imshow(image)
    ax.set_title(label_feature.int2str(label.numpy()))
    ax.axis("off")
plt.show()


In [None]:
# ============================================================
# 🛠️ Redimensionado físico forzado (solo las malas)
# ============================================================

from PIL import Image

# Re-save every entry of `malas` (filled by the size-check cell) as 64x64 RGB.
# NOTE(review): this cell sits after the dataset build cell, so the already
# prepared dataset presumably still holds the unrepaired images — confirm
# whether a rebuild is needed afterwards.
fallidas = 0
for path, size in malas:
    try:
        # Close the source handle as soon as the pixels have been converted.
        with Image.open(path) as src:
            im = src.convert("RGB").resize((64, 64), Image.LANCZOS)
        im.save(path)
    except Exception as e:
        fallidas += 1
        print("❌ No se pudo reparar:", path, "→", e)

# Report success only when nothing failed.
if fallidas:
    print(f"⚠️ {fallidas} de {len(malas)} imágenes no pudieron corregirse.")
else:
    print("✅ Todas las imágenes malas fueron corregidas.")


In [None]:
# ============================================================
# 🧩 División automática del dataset en train / val / test
# ============================================================

import tensorflow as tf
import math

# Count the examples by iterating the dataset once
total = 0
for _ in ds:
    total += 1

# 80 % train, 10 % validation, remainder test
train_size = math.floor(0.8 * total)
val_size = math.floor(0.1 * total)
test_size = total - (train_size + val_size)

print(f"📊 Total de ejemplos: {total}")
print(f"🔹 Train: {train_size} | 🔸 Val: {val_size} | ⚪ Test: {test_size}")

# --- consecutive slices: first train_size, next val_size, everything after ---
train_ds = ds.take(train_size)
resto = ds.skip(train_size)
val_ds = resto.take(val_size)
test_ds = resto.skip(val_size)

AUTOTUNE = tf.data.AUTOTUNE


def preprocess(img, label):
    """Convert the image to float32 via convert_image_dtype; the label passes through."""
    return tf.image.convert_image_dtype(img, tf.float32), label


def _en_lotes(split):
    """Batch a split in groups of 32 and prefetch."""
    return split.batch(32).prefetch(AUTOTUNE)


# Only the training split is shuffled (buffer of 1000) before batching
train_ds = _en_lotes(train_ds.map(preprocess).cache().shuffle(1000))
val_ds = _en_lotes(val_ds.map(preprocess).cache())
test_ds = _en_lotes(test_ds.map(preprocess).cache())

print("✅ Datasets divididos y listos para entrenamiento.")


In [None]:
# ============================================================
# ✅ Comprobación de tamaño de batch y forma de imágenes
# ============================================================

import matplotlib.pyplot as plt

# Inspect one training batch: shape, dtype, value range and one sample image.
for imgs, labels in train_ds.take(1):
    minimo = tf.reduce_min(imgs).numpy()
    maximo = tf.reduce_max(imgs).numpy()
    print("✅ batch shape:", imgs.shape)
    print("🔹 dtype:", imgs.dtype)
    print("🔹 rango de valores:", minimo, "→", maximo)

    # Show the first image of the batch as a visual confirmation
    primera = imgs[0]
    plt.imshow(primera)
    plt.title(f"Ejemplo de imagen — tamaño {primera.shape}")
    plt.axis("off")
    plt.show()


In [None]:
# ============================================================
# 🧩 Configuración general
# ============================================================

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os

IMG_SIZE = 64  # side length, in pixels, of the square images fed to the model
EPOCHS = 40    # number of epochs passed to vae.fit
# Length of the VAE latent vector. The encoder, decoder and generation cells
# read LATENT_DIM, but no visible cell assigned it, so a fresh
# "Restart & Run All" stopped with NameError.
# NOTE(review): 32 is a placeholder — confirm the value the models were built with.
LATENT_DIM = 32

# ============================================================
# 🔧 Dataset sin etiquetas y con repetición infinita
# ============================================================

def ensure_valid_image(img):
    """Return ``img`` as float32, resized to IMG_SIZE x IMG_SIZE, with a 3-channel shape asserted."""
    como_float = tf.image.convert_image_dtype(img, tf.float32)
    redimensionada = tf.image.resize(como_float, [IMG_SIZE, IMG_SIZE])
    return tf.ensure_shape(redimensionada, [IMG_SIZE, IMG_SIZE, 3])

# Image-only training stream: undo the existing batching, drop the label,
# normalise each image, shuffle (buffer of 512), re-batch in groups of 32 and
# repeat without end.
train_ds_no_labels = (
    train_ds.unbatch()
    .map(lambda x, y: ensure_valid_image(x), num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(512)
    .batch(32)
    .repeat()
    .prefetch(tf.data.AUTOTUNE)
)

# Image-only validation stream: same steps as above but without shuffling.
val_ds_no_labels = (
    val_ds.unbatch()
    .map(lambda x, y: ensure_valid_image(x), num_parallel_calls=tf.data.AUTOTUNE)
    .batch(32)
    .repeat()
    .prefetch(tf.data.AUTOTUNE)
)

# Pull one batch from each stream and print its shape as a sanity check.
print("✅ Datasets verificados:")
for imgs in train_ds_no_labels.take(1):
    print("train batch:", imgs.shape)
for imgs in val_ds_no_labels.take(1):
    print("val batch:", imgs.shape)


# ============================================================
# 🎨 VisualCallback corregido (seguro y estable)
# ============================================================

class VisualCallback(tf.keras.callbacks.Callback):
    """Every ``interval`` epochs, save and show originals next to their reconstructions.

    Args:
        sample_batch: Batch of images; at most its first 8 are visualised.
        save_dir: Directory for the per-epoch PNG files (created if missing).
        interval: Visualise on epochs whose 1-based number is a multiple of this.

    The reconstructions of each visualised epoch are appended, as NumPy
    arrays, to ``generated_images``.
    """

    def __init__(self, sample_batch, save_dir="/content/outputs", interval=5):
        super().__init__()
        self.sample_batch = sample_batch
        self.save_dir = save_dir
        self.interval = interval
        os.makedirs(save_dir, exist_ok=True)
        self.generated_images = []  # one reconstruction array per visualised epoch

    def on_epoch_end(self, epoch, logs=None):
        """Reconstruct the sample images, write the PNG, display it and record the result."""
        if (epoch + 1) % self.interval != 0:
            return

        sample_imgs = self.sample_batch[:8]
        # The sample batch may hold fewer than 8 images; size the grid to what
        # is really there instead of indexing past the end.
        n = int(sample_imgs.shape[0])
        if n == 0:
            return

        z_mean, z_log_var, z = self.model.encoder(sample_imgs)
        reconstructed = self.model.decoder(z)

        # Top row: originals; bottom row: reconstructions.
        # squeeze=False keeps `axes` two-dimensional even when n == 1.
        fig, axes = plt.subplots(2, n, figsize=(n * 1.5, 3), squeeze=False)
        for i in range(n):
            axes[0, i].imshow(sample_imgs[i])
            axes[0, i].axis("off")
            axes[1, i].imshow(reconstructed[i])
            axes[1, i].axis("off")
        fig.tight_layout()

        # Save the untitled grid to disk
        path = os.path.join(self.save_dir, f"epoch_{epoch+1:03d}.png")
        fig.savefig(path)
        print(f"🌀 Letras alucinadas guardadas en: {path}")

        # Reuse the same figure for the live display, now with the epoch title,
        # rather than drawing every image a second time on a new figure.
        fig.suptitle(f"Epoch {epoch+1}", fontsize=16)
        fig.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.show()
        plt.close(fig)

        # Keep the reconstructions for later use (e.g. building a GIF)
        self.generated_images.append(reconstructed.numpy())


# ============================================================
# ⚙️ Definición de pérdida del VAE
# ============================================================

# Not used by vae.fit: the VAE model computes its loss in a custom train_step.
def vae_total_loss(y_true, y_pred):
    """Return the reconstruction term only.

    This is the mean binary cross-entropy between ``y_true`` and ``y_pred``
    multiplied by IMG_SIZE * IMG_SIZE * 3. No KL term is added here.
    """
    bce = tf.keras.losses.binary_crossentropy(y_true, y_pred)
    reconstruction_loss = tf.reduce_mean(bce) * IMG_SIZE * IMG_SIZE * 3
    return reconstruction_loss


# # ============================================================
# # 🧠 Entrenamiento del VAE (versión estable) - DEPRECATED
# # ============================================================

# # obtenemos un batch de muestra para el callback
# sample_batch = next(iter(train_ds_no_labels))

# vae = VAE(encoder, decoder)
# vae.compile(optimizer=tf.keras.optimizers.Adam(), loss=vae_total_loss)

# vae.fit(
#     train_ds_no_labels,
#     validation_data=val_ds_no_labels,
#     epochs=EPOCHS,
#     steps_per_epoch=50,
#     validation_steps=10,
#     callbacks=[VisualCallback(sample_batch)],
#     verbose=1
# )


# # ============================================================
# # 💾 Guardado de modelos entrenados - DEPRECATED
# # ============================================================

# decoder.save("/content/drive/MyDrive/maquina-de-contrapropaganda/models/decoder_solo.keras")
# encoder.save("/content/drive/MyDrive/maquina-de-contrapropaganda/models/encoder_solo.keras")
# vae.save("/content/drive/MyDrive/maquina-de-contrapropaganda/models/vae_completo.keras")

# print("✅ Modelos guardados correctamente en Drive.")

In [None]:
# ============================================================
# ⚙️ Función para preparar el Dataset por Letra
# ============================================================

# NOTA: Asegúrate de que 'builder' (de la celda MaquinaContrapropaganda) esté disponible.

def prepare_dataset_for_letter(base_ds, target_label_name, batch_size=32, shuffle_buffer=512):
    """Build an endlessly repeating, image-only dataset for a single letter.

    Args:
        base_ds: Unbatched dataset of ``(image, label)`` pairs.
        target_label_name: Letter to keep, as a class name known to the
            module-level ``builder`` (e.g. ``"A"``).
        batch_size: Number of images per batch.
        shuffle_buffer: Size of the shuffle buffer.

    Returns:
        A dataset of image batches (labels removed), shuffled, batched,
        repeated without end and prefetched.
    """
    autotune = tf.data.AUTOTUNE

    # 1. Integer label that corresponds to the requested letter
    target_label_int = builder.info.features["label"].str2int(target_label_name)

    # 2. Keep only that letter and drop the label
    ds_filtered = base_ds.filter(lambda x, y: tf.equal(y, target_label_int)).map(lambda x, y: x)

    # 3. Normalise with the module-level ensure_valid_image helper (float32,
    #    IMG_SIZE x IMG_SIZE x 3) rather than a nested copy of the same code.
    ds_final = (
        ds_filtered
        .map(ensure_valid_image, num_parallel_calls=autotune)
        .shuffle(shuffle_buffer)
        .batch(batch_size)
        .repeat()  # endless stream, as needed when fit() is driven by steps_per_epoch
        .prefetch(autotune)
    )

    return ds_final

print("✅ Función prepare_dataset_for_letter definida.")

In [None]:
# ============================================================
# 🧠 Definición del Encoder (versión estable)
# ============================================================

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf

def sampling(args):
    """Draw ``z = z_mean + exp(0.5 * z_log_var) * epsilon`` with standard-normal ``epsilon``.

    ``args`` is the pair ``(z_mean, z_log_var)``; the noise takes its two
    dimensions from the runtime shape of ``z_mean``.
    """
    media, log_varianza = args
    forma_ruido = (tf.shape(media)[0], tf.shape(media)[1])
    ruido = tf.random.normal(shape=forma_ruido)
    desviacion = tf.exp(0.5 * log_varianza)
    return media + desviacion * ruido

# Encoder network: maps an IMG_SIZE x IMG_SIZE x 3 image to (z_mean, z_log_var, z)
encoder_inputs = keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
# Two stride-2 "same" convolutions, each halving the spatial size
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(256, activation="relu")(x)
# Two parallel heads of width LATENT_DIM: mean and log-variance
z_mean = layers.Dense(LATENT_DIM, name="z_mean")(x)
z_log_var = layers.Dense(LATENT_DIM, name="z_log_var")(x)
# z is produced from both heads by the sampling function
z = layers.Lambda(sampling, name="z")([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name="encoder")
encoder.summary()

In [None]:
# ============================================================
# 🧠 Definición del Decoder (versión estable)
# ============================================================

# Decoder network: maps a latent vector of length LATENT_DIM to an image
latent_inputs = keras.Input(shape=(LATENT_DIM,))
x = layers.Dense(8 * 8 * 64, activation="relu")(latent_inputs) # project to 8*8*64 units
x = layers.Reshape((8, 8, 64))(x) # reshape to an 8x8 map with 64 channels
# Three stride-2 transposed convolutions double the spatial size each time:
# 8 -> 16 -> 32 -> 64
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
# The last layer outputs 3 channels with sigmoid activation
x = layers.Conv2DTranspose(3, 3, activation="sigmoid", strides=2, padding="same")(x)
decoder_outputs = x # final output
decoder = keras.Model(latent_inputs, decoder_outputs, name="decoder")
decoder.summary()

In [None]:
# ============================================================
# 🧠 Definición del VAE Model con train_step (versión estable)
# ============================================================

from tensorflow import keras
import tensorflow as tf

class VAE(keras.Model):
    """Variational autoencoder wrapping an encoder and a decoder model.

    The encoder must return ``(z_mean, z_log_var, z)`` and the decoder must
    map ``z`` back to an image. The loss is computed inside ``train_step`` and
    ``test_step``, so ``compile`` only needs an optimizer.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        """Encode ``inputs``, then decode the sampled ``z``."""
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        return reconstructed

    def _compute_losses(self, data):
        """Return the loss dictionary shared by ``train_step`` and ``test_step``.

        Keys: ``loss`` (total), ``reconstruction_loss`` and ``kl_loss``.
        """
        # Accept (images, ...) tuples as well as bare image batches.
        images = data[0] if isinstance(data, (tuple, list)) else data

        z_mean, z_log_var, z = self.encoder(images)
        reconstructed_images = self.decoder(z)

        # Mean binary cross-entropy scaled by the number of values per image
        reconstruction_loss = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(images, reconstructed_images)
        ) * IMG_SIZE * IMG_SIZE * 3

        # KL divergence: summed over the latent dimensions, then averaged over
        # the batch. Averaging over the latent dimensions as well would shrink
        # this term by a factor of the latent size relative to the
        # per-image-scaled reconstruction term.
        kl_per_dim = -0.5 * (1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
        kl_loss = tf.reduce_mean(tf.reduce_sum(kl_per_dim, axis=1))

        total_loss = reconstruction_loss + kl_loss
        return {
            "loss": total_loss,
            "reconstruction_loss": reconstruction_loss,
            "kl_loss": kl_loss,
        }

    def train_step(self, data):
        """Run one optimisation step on a batch and return the loss values."""
        with tf.GradientTape() as tape:
            losses = self._compute_losses(data)

        grads = tape.gradient(losses["loss"], self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return losses

    def test_step(self, data):
        """Compute the same loss values on a batch without updating weights."""
        return self._compute_losses(data)

In [None]:
# ============================================================
# 🧠 Entrenamiento del VAE (versión estable con train_step)
# ============================================================

# Take one batch up front; the callback reconstructs it at each visualisation step
sample_batch = next(iter(train_ds_no_labels))

# Instantiate the VAE model
vae = VAE(encoder, decoder)

# Compile the VAE (the loss is handled in train_step/test_step)
vae.compile(optimizer=tf.keras.optimizers.Adam())

# Instantiate the VisualCallback
visual_callback = VisualCallback(sample_batch)

print("Starting VAE training...")
history = vae.fit(
    train_ds_no_labels,
    validation_data=val_ds_no_labels,
    epochs=EPOCHS,
    steps_per_epoch=50,
    validation_steps=10,
    callbacks=[visual_callback],
    verbose=1
)
print("VAE training finished.")

# ============================================================
# 🖼️ Build a GIF of how the reconstructions evolve
# ============================================================

import imageio

# visual_callback.generated_images holds one batch of reconstructions per
# visualised epoch. Flatten the first 8 images of each batch into one list of
# uint8 frames.
gif_images = []
for batch in visual_callback.generated_images:
    gif_images.extend([np.uint8(img * 255) for img in batch[:8]])

# Save the GIF
gif_path = "/content/vae_evolution.gif"
imageio.mimsave(gif_path, gif_images, fps=1)  # adjust fps as needed

print(f"✅ GIF de la evolución guardado en: {gif_path}")

# Display the GIF in the notebook. The file is read inside a `with` block so
# the handle is closed straight away.
from IPython.display import Image as IPyImage
with open(gif_path, 'rb') as gif_file:
    gif_bytes = gif_file.read()
IPyImage(gif_bytes)

In [None]:
# ============================================================
# 🖼️ Generar nuevas letras desde el espacio latente
# ============================================================

import numpy as np
import matplotlib.pyplot as plt
import imageio
import os

# Number of new letters to generate
num_new_letters = 16

# Directory for the frame images written below
generate_dir = "/content/generated_letters"
os.makedirs(generate_dir, exist_ok=True)

# Batches of uint8 images that will become the GIF frames
gif_frames = []

print(f"Generating {num_new_letters} new letters from the latent space...")

# Number of frames written per run
num_generation_steps = 10

# Sample the latent vectors once
random_latent_vectors = tf.random.normal(shape=(num_new_letters, LATENT_DIM))

# The latent vectors never change between steps, so decode them a single time
# instead of repeating the identical decoder call on every iteration.
generated_images = decoder(random_latent_vectors).numpy()

# Smallest square grid that fits every letter. Rounding the square root down
# leaves too few axes whenever num_new_letters is not a perfect square.
n = int(np.ceil(np.sqrt(num_new_letters)))

for step in range(num_generation_steps):
    # squeeze=False keeps `axes` an array even for a 1x1 grid
    fig, axes = plt.subplots(n, n, figsize=(n * 2, n * 2), squeeze=False)
    axes = axes.flatten()

    # Hide every axis first so unused cells of the grid stay blank
    for ax in axes:
        ax.axis("off")
    for i in range(num_new_letters):
        axes[i].imshow(generated_images[i])

    fig.suptitle(f"Generation Step {step+1}/{num_generation_steps}", fontsize=16)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the title

    # Save the figure as an image file, then release it
    frame_path = os.path.join(generate_dir, f"generation_step_{step+1:03d}.png")
    fig.savefig(frame_path)
    plt.close(fig)

    # Keep at most the first 16 generated images (as uint8) for the GIF
    gif_frames.append(np.uint8(generated_images[:min(num_new_letters, 16)] * 255))


print("Finished generating image frames.")

# Build the GIF from the collected images
gif_path = "/content/new_letters_evolution.gif"

# imageio.mimsave needs a flat list of images, not a list of batches
flat_gif_frames = [img for batch in gif_frames for img in batch]

imageio.mimsave(gif_path, flat_gif_frames, fps=5)  # adjust fps as needed

print(f"✅ GIF of new letter generation evolution saved to: {gif_path}")

# Display the GIF in the notebook. The file is read inside a `with` block so
# the handle is closed straight away.
from IPython.display import Image as IPyImage
with open(gif_path, 'rb') as gif_file:
    gif_bytes = gif_file.read()
IPyImage(gif_bytes)

In [None]:
# ============================================================
# 🧠 Entrenamiento del VAE (versión estable con train_step)
# ============================================================

# NOTE(review): this cell has the same code as the earlier training cell. It
# wraps the same module-level `encoder` and `decoder` objects in a new VAE, so
# running it presumably continues training from their current weights —
# confirm whether the repetition is intended.

# Take one batch up front; the callback reconstructs it at each visualisation step
sample_batch = next(iter(train_ds_no_labels))

# Instantiate the VAE model
vae = VAE(encoder, decoder)

# Compile the VAE (the loss is handled in train_step/test_step)
vae.compile(optimizer=tf.keras.optimizers.Adam())

# Instantiate the VisualCallback
visual_callback = VisualCallback(sample_batch)

print("Starting VAE training...")
history = vae.fit(
    train_ds_no_labels,
    validation_data=val_ds_no_labels,
    epochs=EPOCHS,
    steps_per_epoch=50,
    validation_steps=10,
    callbacks=[visual_callback],
    verbose=1
)
print("VAE training finished.")

# ============================================================
# 🖼️ Build a GIF of how the reconstructions evolve
# ============================================================

import imageio

# visual_callback.generated_images holds one batch of reconstructions per
# visualised epoch. Flatten the first 8 images of each batch into one list of
# uint8 frames.
gif_images = []
for batch in visual_callback.generated_images:
    gif_images.extend([np.uint8(img * 255) for img in batch[:8]])

# Save the GIF
gif_path = "/content/vae_evolution.gif"
imageio.mimsave(gif_path, gif_images, fps=1)  # adjust fps as needed

print(f"✅ GIF de la evolución guardado en: {gif_path}")

# Display the GIF in the notebook. The file is read inside a `with` block so
# the handle is closed straight away.
from IPython.display import Image as IPyImage
with open(gif_path, 'rb') as gif_file:
    gif_bytes = gif_file.read()
IPyImage(gif_bytes)