In [None]:
"""
Model definitions for CycleGAN-style training.

Contains:
- InstanceNorm, ReflectionPadding2D
- ResNet generator builder
- PatchGAN discriminator builder
"""
import tensorflow as tf
from tensorflow.keras import layers

class InstanceNorm(layers.Layer):
    """Instance normalization with a learned per-channel scale and offset.

    Mean and variance are computed over axes 1 and 2 of each input, then the
    normalized values are scaled by ``gamma`` and shifted by ``beta``.
    Assumes channels-last 4-D input (batch, height, width, channels) --
    TODO confirm against callers.

    Args:
        epsilon: Constant added to the variance before taking the square root.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, epsilon=1e-5, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create one scale (``gamma``) and one offset (``beta``) per channel."""
        channel_dim = int(input_shape[-1])
        self.gamma = self.add_weight(
            name="gamma", shape=(channel_dim,), initializer="ones", trainable=True
        )
        self.beta = self.add_weight(
            name="beta", shape=(channel_dim,), initializer="zeros", trainable=True
        )

    def call(self, x):
        """Normalize ``x`` over axes 1 and 2, then apply gamma and beta."""
        # keepdims=True keeps the statistics broadcastable against x.
        mean, var = tf.nn.moments(x, axes=[1, 2], keepdims=True)
        x_norm = (x - mean) / tf.sqrt(var + self.epsilon)
        return x_norm * self.gamma + self.beta

    def get_config(self):
        """Return the layer config, including ``epsilon``, for serialization."""
        cfg = super().get_config()
        cfg.update({"epsilon": self.epsilon})
        return cfg


class ReflectionPadding2D(layers.Layer):
    """Reflect-pad axes 1 and 2 of a 4-D tensor.

    Args:
        padding: One of
            * an int ``p`` -> ``p`` on every side;
            * a pair of ints ``(ph, pw)`` -> ``ph`` on both sides of axis 1
              and ``pw`` on both sides of axis 2;
            * a pair of pairs ``((top, bottom), (left, right))``.
        **kwargs: Forwarded to ``layers.Layer``.

    Raises:
        ValueError: If ``padding`` cannot be interpreted as one of the forms
            above.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        super().__init__(**kwargs)
        self.padding = self._normalize_padding(padding)

    @staticmethod
    def _normalize_padding(padding):
        """Return ``padding`` as ``((top, bottom), (left, right))`` tuples."""
        if isinstance(padding, int):
            return ((padding, padding), (padding, padding))
        if isinstance(padding, (list, tuple)) and len(padding) == 2 and isinstance(padding[0], int):
            return ((padding[0], padding[0]), (padding[1], padding[1]))
        # Nested form. Lists (e.g. from a JSON-serialized config) are accepted
        # and converted back to tuples so the stored value has one shape.
        try:
            rows, cols = padding
            normalized = (tuple(rows), tuple(cols))
        except (TypeError, ValueError) as err:
            raise ValueError(
                "padding must be an int, a pair of ints, or a pair of "
                f"(before, after) pairs; got {padding!r}"
            ) from err
        if len(normalized[0]) != 2 or len(normalized[1]) != 2:
            raise ValueError(
                "each nested padding entry must be a (before, after) pair; "
                f"got {padding!r}"
            )
        return normalized

    def call(self, x):
        """Pad axes 1 and 2 of ``x`` in REFLECT mode; axes 0 and 3 are untouched."""
        pad_top, pad_bottom = self.padding[0]
        pad_left, pad_right = self.padding[1]
        paddings = [[0, 0], [pad_top, pad_bottom], [pad_left, pad_right], [0, 0]]
        return tf.pad(x, paddings, mode="REFLECT")

    def get_config(self):
        """Return the layer config, including the normalized ``padding``."""
        cfg = super().get_config()
        cfg.update({"padding": self.padding})
        return cfg


def residual_block(x_in, filters=256):
    """Build a two-stage residual branch and add it back onto ``x_in``.

    Each stage is reflection padding (1) -> 3x3 valid convolution without
    bias -> InstanceNorm. A ReLU follows the first stage only.

    Args:
        x_in: Input tensor of the block.
        filters: Number of filters of both convolutions.

    Returns:
        The element-wise sum of ``x_in`` and the branch output.
    """
    branch = x_in
    for is_last_stage in (False, True):
        branch = ReflectionPadding2D(padding=1)(branch)
        branch = layers.Conv2D(filters, 3, padding="valid", use_bias=False)(branch)
        branch = InstanceNorm()(branch)
        if not is_last_stage:
            branch = layers.ReLU()(branch)
    return layers.Add()([x_in, branch])


def build_generator_resnet(input_shape=(256,256,3), n_res_blocks=9):
    """Assemble the ResNet-style generator as a Keras functional model.

    Layout: 7x7 stem (64 filters) -> two stride-2 convolutions (128, 256)
    -> ``n_res_blocks`` residual blocks at 256 filters -> two bilinear
    upsample + 3x3 convolution stages (128, 64) -> 7x7 convolution to
    3 channels -> tanh.

    Args:
        input_shape: Shape of one input image, without the batch axis.
        n_res_blocks: Number of residual blocks in the middle of the network.

    Returns:
        A ``tf.keras.Model`` named ``resnet_generator``.
    """
    image_in = layers.Input(shape=input_shape)

    # Stem: reflection padding keeps the 7x7 valid conv size-preserving.
    feat = ReflectionPadding2D(padding=3)(image_in)
    feat = layers.Conv2D(64, 7, padding='valid', use_bias=False)(feat)
    feat = InstanceNorm()(feat)
    feat = layers.ReLU()(feat)

    # Downsampling: two stride-2 convolutions.
    for width in (128, 256):
        feat = layers.Conv2D(width, 3, strides=2, padding='same', use_bias=False)(feat)
        feat = InstanceNorm()(feat)
        feat = layers.ReLU()(feat)

    # Residual trunk.
    for _ in range(n_res_blocks):
        feat = residual_block(feat, filters=256)

    # Upsampling: bilinear resize followed by a convolution, twice.
    for width in (128, 64):
        feat = layers.UpSampling2D(size=2, interpolation='bilinear')(feat)
        feat = layers.Conv2D(width, 3, padding='same', use_bias=False)(feat)
        feat = InstanceNorm()(feat)
        feat = layers.ReLU()(feat)

    # Output head: 3 channels squashed by tanh.
    feat = ReflectionPadding2D(padding=3)(feat)
    feat = layers.Conv2D(3, 7, padding='valid')(feat)
    image_out = layers.Activation('tanh')(feat)
    return tf.keras.Model(image_in, image_out, name='resnet_generator')


def build_patchgan_discriminator(input_shape=(256,256,3), n_filters=64):
    """Assemble the PatchGAN discriminator as a Keras functional model.

    Layout: a stride-2 4x4 convolution with ``n_filters`` filters (no
    normalization), then three 4x4 convolution + InstanceNorm stages with
    2x, 4x and 8x ``n_filters`` (strides 2, 2, 1), each followed by
    LeakyReLU(0.2), and finally a 1-filter 4x4 convolution with no activation.

    Args:
        input_shape: Shape of one input image, without the batch axis.
        n_filters: Filter count of the first convolution.

    Returns:
        A ``tf.keras.Model`` named ``patchgan_discriminator``.
    """
    image_in = layers.Input(shape=input_shape)

    feat = layers.Conv2D(n_filters, 4, strides=2, padding='same')(image_in)
    feat = layers.LeakyReLU(0.2)(feat)

    # (filter multiplier, stride) of each normalized stage.
    for multiplier, stride in ((2, 2), (4, 2), (8, 1)):
        feat = layers.Conv2D(n_filters * multiplier, 4, strides=stride, padding='same', use_bias=False)(feat)
        feat = InstanceNorm()(feat)
        feat = layers.LeakyReLU(0.2)(feat)

    patch_scores = layers.Conv2D(1, 4, strides=1, padding='same')(feat)
    return tf.keras.Model(image_in, patch_scores, name='patchgan_discriminator')

In [None]:
"""
Dataset utilities: TFRecord parsing and dataset builder
"""
import glob
import os
from pathlib import Path
import tensorflow as tf

# Shorthand for tf.data's autotuning sentinel, used below for parallel reads,
# parallel map calls and prefetching.
AUTOTUNE = tf.data.AUTOTUNE
# Default feature key under which the encoded image bytes are read from a record.
IMG_KEY = "image"

def parse_tf_example(serialized, image_key=IMG_KEY):
    """Return the still-encoded image bytes stored under ``image_key``.

    Args:
        serialized: One serialized ``tf.train.Example``.
        image_key: Name of the scalar string feature holding the image.
    """
    schema = {image_key: tf.io.FixedLenFeature([], tf.string)}
    return tf.io.parse_single_example(serialized, schema)[image_key]

def parse_and_preprocess(serialized, input_size=256, image_key=IMG_KEY):
    """Decode one serialized example into a float image scaled to [-1, 1].

    Args:
        serialized: One serialized ``tf.train.Example``.
        input_size: Side length of the square output image.
        image_key: Name of the scalar string feature holding the image.

    Returns:
        A float32 tensor of shape ``(input_size, input_size, 3)``.
    """
    schema = {image_key: tf.io.FixedLenFeature([], tf.string)}
    encoded = tf.io.parse_single_example(serialized, schema)[image_key]
    pixels = tf.image.decode_image(encoded, channels=3, expand_animations=False)
    # Pin the rank and channel count for the resize below.
    pixels.set_shape([None, None, 3])
    unit_range = tf.image.convert_image_dtype(pixels, tf.float32)  # [0,1]
    resized = tf.image.resize(
        unit_range, [input_size, input_size], method=tf.image.ResizeMethod.BILINEAR
    )
    return resized * 2.0 - 1.0  # to [-1,1]

def make_image_dataset(tfrecords, batch_size=1, input_size=256, shuffle=True, image_key=IMG_KEY, repeat=True,
                       drop_remainder=True):
    """Build a batched dataset of preprocessed images from TFRecord files.

    Args:
        tfrecords: A glob pattern, or an iterable of glob patterns. A single
            string is treated as one pattern rather than iterated character
            by character.
        batch_size: Number of images per batch.
        input_size: Side length images are resized to.
        shuffle: Whether to shuffle the serialized records (buffer of 8192).
        image_key: Feature key holding the encoded image bytes.
        repeat: Whether to repeat the dataset indefinitely.
        drop_remainder: Whether to drop a final batch smaller than
            ``batch_size``. Defaults to True (the historical behavior); pass
            False with ``repeat=False`` to keep every image.

    Returns:
        A ``tf.data.Dataset`` yielding float32 image batches in [-1, 1].

    Raises:
        ValueError: If no file matches any of the patterns.
    """
    patterns = [tfrecords] if isinstance(tfrecords, str) else list(tfrecords)
    files = []
    for pattern in patterns:
        files.extend(sorted(tf.io.gfile.glob(pattern)))
    if not files:
        raise ValueError(f"No TFRecord files found for patterns: {tfrecords}")
    ds = tf.data.TFRecordDataset(files, num_parallel_reads=AUTOTUNE)
    if shuffle:
        # Shuffle the serialized records so the buffer holds bytes, not decoded images.
        ds = ds.shuffle(8192)
    ds = ds.map(
        lambda x: parse_and_preprocess(x, input_size=input_size, image_key=image_key),
        num_parallel_calls=AUTOTUNE,
    )
    if repeat:
        ds = ds.repeat()
    ds = ds.batch(batch_size, drop_remainder=drop_remainder)
    ds = ds.prefetch(AUTOTUNE)
    return ds

def count_tfrecord_examples(pattern):
    """Count the records in all TFRecord files matching ``pattern``.

    Uses ``tf.io.gfile.glob`` (as ``make_image_dataset`` does) so that the
    same patterns resolve to the same files in both functions.

    Args:
        pattern: Glob pattern of TFRecord files.

    Returns:
        The total number of records; 0 when no file matches.
    """
    files = sorted(tf.io.gfile.glob(pattern))
    if not files:
        return 0
    # Full pass over the data: every record is read once.
    return sum(1 for _ in tf.data.TFRecordDataset(files))

In [None]:
"""
Misc utilities: denormalization, sampling, warm TFRecord writer
"""
import os
import glob
from pathlib import Path
import tensorflow as tf
import pandas as pd

def denorm_img_tensor(x):
    """Map a tensor from [-1, 1] to a uint8 numpy array in [0, 255].

    Values outside [-1, 1] are clipped to the valid range before the cast.
    """
    scaled = (x + 1.0) * 127.5
    clipped = tf.clip_by_value(scaled, 0.0, 255.0)
    as_uint8 = tf.cast(clipped, tf.uint8)
    return as_uint8.numpy()

def sample_tfrecords(patterns, sample_n=16, out_dir="samples", image_key="image"):
    """Dump the first ``sample_n`` encoded images of the matched TFRecords to disk.

    The stored bytes are written verbatim as ``sample_NNN.jpg``; no decoding
    or re-encoding happens, so the ``.jpg`` suffix is applied regardless of
    the actual encoding.

    Args:
        patterns: A glob pattern, or an iterable of glob patterns. A single
            string is treated as one pattern.
        sample_n: Maximum number of images to write.
        out_dir: Output directory; created if missing.
        image_key: Feature key holding the encoded image bytes.

    Returns:
        The number of files written.

    Raises:
        ValueError: If no file matches any of the patterns.
    """
    patterns = [patterns] if isinstance(patterns, str) else list(patterns)
    files = []
    for pattern in patterns:
        files.extend(sorted(glob.glob(pattern)))
    if not files:
        raise ValueError(
            "No TFRecord files found for patterns: " + ",".join(str(p) for p in patterns)
        )

    feature_spec = {image_key: tf.io.FixedLenFeature([], tf.string)}
    ds = tf.data.TFRecordDataset(files)
    ds = ds.map(lambda x: tf.io.parse_single_example(x, feature_spec)[image_key])
    ds = ds.take(sample_n)

    os.makedirs(out_dir, exist_ok=True)
    written = 0
    for idx, encoded in enumerate(ds):
        (Path(out_dir) / f"sample_{idx:03d}.jpg").write_bytes(encoded.numpy())
        written = idx + 1
    return written

def write_warm_tfrecord_from_csv(feature_csv, photos_root, out_tfrecord, top_pct=10, image_key="image"):
    """Write the photos with the smallest ``min_monet_nn`` to one TFRecord file.

    Rows with a missing ``min_monet_nn`` are dropped, the rest are sorted
    ascending by that column and the first ``top_pct`` percent (at least one
    row) are selected. Each selected photo is stored with its raw file bytes,
    its resolved path and its ``min_monet_nn`` value.

    Args:
        feature_csv: CSV file with at least ``path`` and ``min_monet_nn`` columns.
        photos_root: Fallback directory; when a row's ``path`` does not exist,
            a file with the same name is looked up here.
        out_tfrecord: Destination TFRecord path.
        top_pct: Percentage of rows to keep.
        image_key: Feature key under which the image bytes are stored.

    Raises:
        ValueError: If the CSV lacks the ``min_monet_nn`` or ``path`` column.
    """
    df = pd.read_csv(feature_csv)
    if 'min_monet_nn' not in df.columns:
        raise ValueError("feature CSV must contain 'min_monet_nn' column")
    if 'path' not in df.columns:
        raise ValueError("feature CSV must contain 'path' column")
    df = df.dropna(subset=['min_monet_nn'])
    k = max(1, int(len(df) * (top_pct / 100.0)))
    warm = df.sort_values('min_monet_nn').head(k)
    print(f"Selecting {len(warm)} photos (top {top_pct}%) for warm-start")

    def _bytes_feature(b):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[b]))

    # The context manager closes the writer even if reading a photo fails.
    with tf.io.TFRecordWriter(os.fspath(out_tfrecord)) as writer:
        for _, row in warm.iterrows():
            p = Path(row['path'])
            if not p.exists():
                # Fall back to the same file name under photos_root.
                p = Path(photos_root) / p.name
            if not p.exists():
                print("WARN: missing file:", row['path'])
                continue
            feat = {
                image_key: _bytes_feature(p.read_bytes()),
                "image/filename": _bytes_feature(str(p).encode('utf-8')),
                "min_monet_nn": tf.train.Feature(
                    float_list=tf.train.FloatList(value=[float(row['min_monet_nn'])])
                ),
            }
            example = tf.train.Example(features=tf.train.Features(feature=feat))
            writer.write(example.SerializeToString())
    print("Wrote warm-start TFRecord:", out_tfrecord)

In [None]:
"""
Compact training entrypoint for Kaggle / local runs.

Usage (from notebook):
    from train import train_cyclegan
    train_cyclegan(...)
"""
import os
import time
import tensorflow as tf
from models import build_generator_resnet, build_patchgan_discriminator
from dataset import make_image_dataset
from utils import denorm_img_tensor
import numpy as np

# Loss helpers
def lsgan_d_loss(real_pred, fake_pred):
    """Least-squares discriminator loss.

    Returns mean((real_pred - 1)^2) + mean(fake_pred^2), computed in float32.
    """
    real_term = tf.reduce_mean((tf.cast(real_pred, tf.float32) - 1.0) ** 2)
    fake_term = tf.reduce_mean((tf.cast(fake_pred, tf.float32)) ** 2)
    return real_term + fake_term

def lsgan_g_loss(fake_pred):
    """Least-squares generator loss: mean((fake_pred - 1)^2), in float32."""
    distance_to_real = tf.cast(fake_pred, tf.float32) - 1.0
    return tf.reduce_mean(distance_to_real ** 2)

def mae(x, y):
    """Mean absolute difference between ``x`` and ``y``, computed in float32."""
    lhs = tf.cast(x, tf.float32)
    rhs = tf.cast(y, tf.float32)
    return tf.reduce_mean(tf.abs(lhs - rhs))

class ImagePool:
    """Bounded history of images, queried one batch at a time.

    While the pool has room, every queried image is stored and returned
    unchanged. Once the pool is full, each queried image is either swapped
    with a randomly chosen stored image (the stored one is returned) or
    returned as-is, depending on whether ``np.random.rand()`` exceeds 0.5.
    """

    def __init__(self, max_size=50):
        self.images = []
        self.max_size = max_size

    def query(self, images):
        """Return a batch of the same length mixing new and stored images.

        Args:
            images: Batch tensor; it is unstacked along axis 0.

        Returns:
            A tensor stacked from the selected images.
        """
        selected = []
        for candidate in tf.unstack(images):
            # Pool not full yet: remember the image and pass it through.
            if len(self.images) < self.max_size:
                self.images.append(candidate)
                selected.append(candidate)
                continue
            if np.random.rand() > 0.5:
                # Hand out a stored image and store the new one in its slot.
                slot = np.random.randint(0, len(self.images))
                previous = self.images[slot]
                self.images[slot] = candidate
                selected.append(previous)
            else:
                selected.append(candidate)
        return tf.stack(selected)

def train_cyclegan(
    monet_tfrecs=("data/monet_tfrec/*.tfrec",),
    photo_tfrecs=("warm/warm-00000-of-00001.tfrec",),
    image_size=128,
    batch_size=1,
    n_res_blocks=6,
    total_steps=20000,
    g_lr=2e-4,
    d_lr=2e-4,
    lambda_cycle=10.0,
    lambda_id=5.0,
    sample_dir="samples_train",
    log_dir="logs/cyclegan",
    ckpt_dir="ckpts/cyclegan",
    sample_every=500,
    ckpt_every=2000,
):
    """Train a CycleGAN between Monet images (domain A) and photos (domain B).

    ``G_AB`` maps A -> B (Monet -> photo) and ``G_BA`` maps B -> A
    (photo -> Monet). Both generators share one Adam optimizer and both
    discriminators share another. Discriminators are trained on fakes drawn
    from an ``ImagePool`` of size 50.

    Args:
        monet_tfrecs: Glob patterns of the domain-A TFRecords.
        photo_tfrecs: Glob patterns of the domain-B TFRecords.
        image_size: Side length images are resized to.
        batch_size: Images per batch for both domains.
        n_res_blocks: Residual blocks per generator.
        total_steps: Number of training steps run by this call.
        g_lr: Generator learning rate.
        d_lr: Discriminator learning rate.
        lambda_cycle: Weight of the cycle-consistency loss.
        lambda_id: Weight of the identity loss.
        sample_dir: Directory for sample grids (input | translated | cycled).
        log_dir: Directory for TensorBoard summaries.
        ckpt_dir: Directory for checkpoints; the latest one is restored if present.
        sample_every: Save a sample grid every this many steps (and at step 1).
        ckpt_every: Save a checkpoint every this many steps (and at the last step).
    """
    # Local import: PIL is only needed to write the sample grids.
    from PIL import Image

    os.makedirs(sample_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)

    ds_A = make_image_dataset(monet_tfrecs, batch_size=batch_size, input_size=image_size, shuffle=True)
    ds_B = make_image_dataset(photo_tfrecs, batch_size=batch_size, input_size=image_size, shuffle=True)
    it_A = iter(ds_A)
    it_B = iter(ds_B)

    G_AB = build_generator_resnet(input_shape=(image_size, image_size, 3), n_res_blocks=n_res_blocks)
    G_BA = build_generator_resnet(input_shape=(image_size, image_size, 3), n_res_blocks=n_res_blocks)
    D_A = build_patchgan_discriminator(input_shape=(image_size, image_size, 3))
    D_B = build_patchgan_discriminator(input_shape=(image_size, image_size, 3))

    print("G_AB params:", G_AB.count_params())
    print("G_BA params:", G_BA.count_params())
    print("D_A params:", D_A.count_params())
    print("D_B params:", D_B.count_params())

    g_opt = tf.keras.optimizers.Adam(g_lr, beta_1=0.5)
    d_opt = tf.keras.optimizers.Adam(d_lr, beta_1=0.5)

    pool_A = ImagePool(50)
    pool_B = ImagePool(50)

    summary_writer = tf.summary.create_file_writer(log_dir)
    ckpt = tf.train.Checkpoint(G_AB=G_AB, G_BA=G_BA, D_A=D_A, D_B=D_B, g_opt=g_opt, d_opt=d_opt)
    manager = tf.train.CheckpointManager(ckpt, directory=ckpt_dir, max_to_keep=5)
    latest = manager.latest_checkpoint
    if latest:
        # NOTE(review): the step counter below still restarts at 1 after a
        # restore, so sample file names and summary steps of an earlier run
        # are reused -- confirm whether resuming should continue the count.
        ckpt.restore(latest)
        print("Restored from checkpoint:", latest)
    else:
        print("Training from scratch.")

    def train_step(real_A, real_B):
        """Run one generator update followed by one discriminator update."""
        # gradient() is called once on this tape, so it need not be persistent.
        with tf.GradientTape() as tape:
            fake_B = G_AB(real_A, training=True)
            fake_A = G_BA(real_B, training=True)
            cycled_A = G_BA(fake_B, training=True)
            cycled_B = G_AB(fake_A, training=True)
            same_A = G_BA(real_A, training=True)
            same_B = G_AB(real_B, training=True)
            pred_fake_A = D_A(fake_A, training=True)
            pred_fake_B = D_B(fake_B, training=True)

            g_adv_AB = lsgan_g_loss(pred_fake_B)
            g_adv_BA = lsgan_g_loss(pred_fake_A)
            # mae() casts both operands to float32 itself.
            cycle_A_loss = mae(real_A, cycled_A)
            cycle_B_loss = mae(real_B, cycled_B)
            id_A_loss = mae(real_A, same_A)
            id_B_loss = mae(real_B, same_B)

            # Each generator carries half of the summed cycle loss.
            shared_cycle = lambda_cycle * (cycle_A_loss + cycle_B_loss) / 2.0
            g_loss_AB = g_adv_AB + shared_cycle + lambda_id * id_B_loss
            g_loss_BA = g_adv_BA + shared_cycle + lambda_id * id_A_loss
            total_g_loss = g_loss_AB + g_loss_BA

        g_vars = G_AB.trainable_variables + G_BA.trainable_variables
        g_grads = tape.gradient(total_g_loss, g_vars)
        g_opt.apply_gradients([(g, v) for g, v in zip(g_grads, g_vars) if g is not None])

        # Discriminators see fakes drawn from the history pools.
        fake_A_pool = pool_A.query(fake_A)
        fake_B_pool = pool_B.query(fake_B)
        with tf.GradientTape() as tape_d:
            pred_real_A = D_A(real_A, training=True)
            pred_real_B = D_B(real_B, training=True)
            pred_fake_A_pool = D_A(fake_A_pool, training=True)
            pred_fake_B_pool = D_B(fake_B_pool, training=True)
            d_A_loss = lsgan_d_loss(pred_real_A, pred_fake_A_pool)
            d_B_loss = lsgan_d_loss(pred_real_B, pred_fake_B_pool)
            d_loss = d_A_loss + d_B_loss
        d_vars = D_A.trainable_variables + D_B.trainable_variables
        d_grads = tape_d.gradient(d_loss, d_vars)
        d_opt.apply_gradients([(g, v) for g, v in zip(d_grads, d_vars) if g is not None])

        return {"g_loss": total_g_loss, "d_loss": d_loss, "g_adv_AB": g_adv_AB, "g_adv_BA": g_adv_BA,
                "cycle_A": cycle_A_loss, "cycle_B": cycle_B_loss, "id_A": id_A_loss, "id_B": id_B_loss}

    start_time = time.time()
    for step in range(1, total_steps + 1):
        real_A = next(it_A)
        real_B = next(it_B)
        metrics = train_step(real_A, real_B)

        if step % 50 == 0 or step == 1:
            elapsed_min = (time.time() - start_time) / 60.0
            print(f"[{step:06d}/{total_steps}] g={float(metrics['g_loss']):.4f} d={float(metrics['d_loss']):.4f} (time={elapsed_min:.1f} min)")
        if step % 50 == 0:
            with summary_writer.as_default():
                tf.summary.scalar("g_loss", metrics["g_loss"], step=step)
                tf.summary.scalar("d_loss", metrics["d_loss"], step=step)

        if step % sample_every == 0 or step == 1:
            # Grid: domain-A input | translated to B | cycled back to A.
            a = real_A[0:1]
            fake_B = G_AB(a, training=False)
            cycled_a = G_BA(fake_B, training=False)
            grid = np.concatenate(
                [denorm_img_tensor(a[0]), denorm_img_tensor(fake_B[0]), denorm_img_tensor(cycled_a[0])],
                axis=1,
            )
            Image.fromarray(grid).save(os.path.join(sample_dir, f"step_{step:06d}_AtoB.png"))

        if step % ckpt_every == 0 or step == total_steps:
            ckpt_path = manager.save()
            print("Checkpoint saved at:", ckpt_path)

    # Make sure buffered summaries reach disk before returning.
    summary_writer.flush()
    print("Training finished.")

In [None]:
"""
Inference utilities: restore checkpoint and generate Monet-style images.
"""
import os
import tensorflow as tf
from PIL import Image
from models import build_generator_resnet, build_patchgan_discriminator
from dataset import make_image_dataset
from utils import denorm_img_tensor

def load_cyclegan_for_inference(ckpt_dir="ckpts/cyclegan", image_size=128, n_res_blocks=6):
    """Restore the photo -> Monet generator from the latest checkpoint.

    Only the ``G_BA`` generator is built and restored; the other objects
    saved in the training checkpoint are skipped via ``expect_partial()``.

    Args:
        ckpt_dir: Directory containing the training checkpoints.
        image_size: Side length of the generator input.
        n_res_blocks: Residual block count; must match the trained generator.

    Returns:
        The restored ``G_BA`` Keras model (photo -> Monet).

    Raises:
        RuntimeError: If ``ckpt_dir`` contains no checkpoint.
    """
    G_BA = build_generator_resnet(input_shape=(image_size, image_size, 3), n_res_blocks=n_res_blocks)
    # The keyword name must match the attribute name used when the checkpoint was saved.
    ckpt = tf.train.Checkpoint(G_BA=G_BA)
    manager = tf.train.CheckpointManager(ckpt, directory=ckpt_dir, max_to_keep=5)
    latest = manager.latest_checkpoint
    if latest is None:
        raise RuntimeError(f"No checkpoint found in {ckpt_dir}")
    ckpt.restore(latest).expect_partial()
    print("Restored checkpoint:", latest)
    return G_BA  # photo -> Monet

def denorm_img(x):
    """Map a tensor from [-1, 1] to a uint8 numpy array.

    Delegates to the imported ``denorm_img_tensor`` so the conversion lives in
    one place instead of being duplicated here.
    """
    return denorm_img_tensor(x)

def generate_monet_from_tfrecords(photo_tfrecs=("data/photo_tfrec/*.tfrec",), ckpt_dir="ckpts/cyclegan",
                                  out_dir="generated_monet_from_tfrec", image_size=128, n_res_blocks=6,
                                  batch_size=1, max_batches=None):
    """Translate photos read from TFRecords and save them as PNG files.

    Outputs are written to ``out_dir`` as ``monet_NNNNNN.png``, numbered in
    dataset order.

    Args:
        photo_tfrecs: Glob patterns of the photo TFRecords.
        ckpt_dir: Directory containing the training checkpoints.
        out_dir: Output directory; created if missing.
        image_size: Side length images are resized to.
        n_res_blocks: Residual block count of the trained generator.
        batch_size: Images per generator call.
        max_batches: Stop after this many batches; ``None`` processes all.
    """
    os.makedirs(out_dir, exist_ok=True)
    generator = load_cyclegan_for_inference(ckpt_dir=ckpt_dir, image_size=image_size, n_res_blocks=n_res_blocks)
    photos = make_image_dataset(photo_tfrecs, batch_size=batch_size, input_size=image_size, shuffle=False, repeat=False)

    saved = 0
    for batches_done, photo_batch in enumerate(photos, start=1):
        painted = generator(photo_batch, training=False)
        for row in range(painted.shape[0]):
            pixels = denorm_img(painted[row])
            target = os.path.join(out_dir, f"monet_{saved:06d}.png")
            Image.fromarray(pixels).save(target)
            saved += 1
        if max_batches is not None and batches_done >= max_batches:
            break
    print(f"Saved {saved} Monet-style images to {out_dir}")

def generate_monet_from_folder(input_dir="data/photo_jpg", ckpt_dir="ckpts/cyclegan", out_dir="generated_monet_from_folder",
                               image_size=128, n_res_blocks=6):
    """Translate every image file in a folder and save the results as PNG.

    Files are selected by extension (jpg, jpeg, png, bmp, webp;
    case-insensitive), processed in sorted name order, and written to
    ``out_dir`` as ``monet_NNNNNN.png``.

    Args:
        input_dir: Folder containing the source images.
        ckpt_dir: Directory containing the training checkpoints.
        out_dir: Output directory; created if missing.
        image_size: Side length images are resized to.
        n_res_blocks: Residual block count of the trained generator.
    """
    os.makedirs(out_dir, exist_ok=True)
    generator = load_cyclegan_for_inference(ckpt_dir=ckpt_dir, image_size=image_size, n_res_blocks=n_res_blocks)

    image_suffixes = (".jpg",".jpeg",".png",".bmp",".webp")
    sources = []
    for entry in sorted(os.listdir(input_dir)):
        if entry.lower().endswith(image_suffixes):
            sources.append(os.path.join(input_dir, entry))

    for position, source in enumerate(sources):
        pixels = tf.image.decode_image(tf.io.read_file(source), channels=3, expand_animations=False)
        pixels = tf.image.convert_image_dtype(pixels, tf.float32)
        pixels = tf.image.resize(pixels, [image_size, image_size])
        # Scale to [-1, 1] and add the batch axis expected by the generator.
        single_batch = tf.expand_dims(pixels * 2.0 - 1.0, 0)
        painted = generator(single_batch, training=False)
        target = os.path.join(out_dir, f"monet_{position:06d}.png")
        Image.fromarray(denorm_img(painted[0])).save(target)
    print(f"Processed {len(sources)} images; outputs in {out_dir}")

In [None]:
# Kaggle inference cell: build model, restore checkpoint, generate 7k-10k 256x256 images and zip them.
import os, math, shutil, glob, time
import tensorflow as tf
from PIL import Image

# === Settings: change these to match your Kaggle dataset / checkpoint names ===
CKPT_PREFIX = "/kaggle/input/your-checkpoint-dataset/ckpts/cyclegan/ckpt-10"   # checkpoint path prefix, without the .index/.data-* extension
PHOTO_TFRECS = "/kaggle/input/your-dataset/data/photo_tfrec/*.tfrec"          # glob pattern of the photo TFRecords (a photo_jpg folder loader exists below)
OUT_DIR = "/kaggle/working/generated_monet"                                   # generated PNGs are written here; outputs must go to /kaggle/working
IMAGES_ZIP = "/kaggle/working/images"                                         # archive base path without extension, i.e. /kaggle/working/images.zip
IMAGE_SIZE = 256      # target size required by competition; inputs are resized to this side length
BATCH_SIZE = 1        # images per generator call
MAX_IMAGES = 10000    # stop once this many images have been generated
MIN_IMAGES = 7000     # generating fewer images than this raises an error after the loop

# === Ensure output dir exists ===
os.makedirs(OUT_DIR, exist_ok=True)

# === Recreate model builder: import from your models.py if present, or re-define minimal generator builder here ===
# If you added models.py to the kernel, do: from models import build_generator_resnet
# For brevity, this example assumes you have build_generator_resnet available in the notebook environment.
# If not, import or paste your generator builder here.
from models import build_generator_resnet  # ensure models.py is in notebook files or working dir

# === Build generator at IMAGE_SIZE (works if architecture same) ===
# If the checkpoint was trained at 128x128 and you want to run at 256, rebuild with image_size=256 but same architecture.
GEN = build_generator_resnet(input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3), n_res_blocks=6)  # adjust n_res_blocks as trained

# Dummy optimizers mirror the attribute names of the training checkpoint; state that
# cannot be matched is skipped by expect_partial().
g_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
d_opt = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

# Restore checkpoint (prefix without extension). This will use ckpt-10.index/.data-00000-of-00001 automatically.
# tf.train.Checkpoint only accepts trackable objects as keyword values, so the networks that
# are not needed for inference (G_AB, D_A, D_B) are left out instead of being passed as None.
# The photo->monet generator is restored into the G_BA slot; adjust the name if you trained differently.
ckpt = tf.train.Checkpoint(G_BA=GEN, g_opt=g_opt, d_opt=d_opt)
print("Restoring checkpoint:", CKPT_PREFIX)
ckpt.restore(CKPT_PREFIX).expect_partial()
print("Checkpoint restore finished.")

# === Dataset: either TFRecords or image folder ===
def parse_and_preprocess(serialized, input_size=IMAGE_SIZE, image_key="image"):
    """Decode one serialized example into a float image scaled to [-1, 1].

    Args:
        serialized: One serialized ``tf.train.Example``.
        input_size: Side length of the square output image.
        image_key: Name of the scalar string feature holding the image.
    """
    schema = {image_key: tf.io.FixedLenFeature([], tf.string)}
    encoded = tf.io.parse_single_example(serialized, schema)[image_key]
    pixels = tf.image.decode_image(encoded, channels=3, expand_animations=False)
    # Pin the rank and channel count for the resize below.
    pixels.set_shape([None, None, 3])
    unit_range = tf.image.convert_image_dtype(pixels, tf.float32)  # [0,1]
    resized = tf.image.resize(unit_range, [input_size, input_size], method="bilinear")
    return resized * 2.0 - 1.0

def make_image_dataset_from_tfrecs(pattern, batch_size=1, input_size=IMAGE_SIZE):
    """Build a batched, unshuffled image dataset from TFRecords matching ``pattern``.

    Raises:
        RuntimeError: If no file matches ``pattern``.
    """
    matched = sorted(glob.glob(pattern))
    if len(matched) == 0:
        raise RuntimeError("No TFRecord files found for pattern: " + pattern)

    def _decode(serialized):
        return parse_and_preprocess(serialized, input_size=input_size)

    records = tf.data.TFRecordDataset(matched, num_parallel_reads=tf.data.AUTOTUNE)
    images = records.map(_decode, num_parallel_calls=tf.data.AUTOTUNE)
    batched = images.batch(batch_size)
    return batched.prefetch(tf.data.AUTOTUNE)

# If you prefer raw JPGs:
def make_image_dataset_from_folder(folder, batch_size=1, input_size=IMAGE_SIZE):
    """Build a batched image dataset from the image files in ``folder``.

    Files are selected by extension (jpg, jpeg, png, bmp, webp;
    case-insensitive) and read in sorted name order.

    Raises:
        RuntimeError: If the folder contains no matching file.
    """
    image_suffixes = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
    paths = []
    for entry in sorted(os.listdir(folder)):
        if entry.lower().endswith(image_suffixes):
            paths.append(os.path.join(folder, entry))
    if len(paths) == 0:
        raise RuntimeError("No image files in folder: " + folder)

    def _read_image(path):
        # Decode, convert to float [0,1], resize, then scale to [-1,1].
        pixels = tf.image.decode_image(tf.io.read_file(path), channels=3, expand_animations=False)
        pixels = tf.image.convert_image_dtype(pixels, tf.float32)
        pixels = tf.image.resize(pixels, [input_size, input_size])
        return pixels * 2.0 - 1.0

    path_ds = tf.data.Dataset.from_tensor_slices(paths)
    image_ds = path_ds.map(_read_image, num_parallel_calls=tf.data.AUTOTUNE)
    return image_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Choose the input source: TFRecords when use_tfrecs is True, otherwise the raw image folder.
use_tfrecs = True
ds = (
    make_image_dataset_from_tfrecs(PHOTO_TFRECS, batch_size=BATCH_SIZE, input_size=IMAGE_SIZE)
    if use_tfrecs
    else make_image_dataset_from_folder("/kaggle/input/your-dataset/data/photo_jpg", batch_size=BATCH_SIZE, input_size=IMAGE_SIZE)
)

# === Generation loop ===
def denorm_and_save(tensor, out_path):
    """Convert one image tensor from [-1, 1] to uint8 and save it to ``out_path``."""
    scaled = tf.clip_by_value((tensor + 1.0) * 127.5, 0, 255)
    pixels = tf.cast(scaled, tf.uint8).numpy()
    Image.fromarray(pixels).save(out_path)

count = 0
start = time.time()
for batch in ds:
    fake = GEN(batch, training=False)       # outputs in [-1,1] shape [B,H,W,3]
    # if your GEN was trained at 128 and you built at 256, it still produces 256 sized output here.
    for i in range(fake.shape[0]):
        out_path = os.path.join(OUT_DIR, f"monet_{count:06d}.png")
        denorm_and_save(fake[i], out_path)
        count += 1
        if count >= MAX_IMAGES:
            break
    # Propagate the inner break so no further batches are generated.
    if count >= MAX_IMAGES:
        break
print(f"Generated {count} images in {time.time()-start:.1f}s")

# Validate count
if count < MIN_IMAGES:
    raise RuntimeError(f"Generated only {count} images; need at least {MIN_IMAGES} for submission")

# Zip outputs. The archive base path comes from the IMAGES_ZIP setting instead of a
# second hard-coded copy of the same path; make_archive appends the ".zip" suffix.
# Everything currently in OUT_DIR is archived, including files left by earlier runs.
zip_base = IMAGES_ZIP
if os.path.exists(zip_base + ".zip"):
    os.remove(zip_base + ".zip")
shutil.make_archive(zip_base, 'zip', OUT_DIR)
print("Wrote images.zip ->", zip_base + ".zip")