<a href="https://colab.research.google.com/github/bschulman/Algo-projects/blob/main/IRpix2pixmodel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install, Import, Mount Drive


In [None]:
# Colab-only setup: install the packages this notebook imports that may be missing from the runtime.
# NOTE(review): none of these installs pin a version — consider pinning for reproducible runs.
!pip install neptune
#!pip install tensorflow-addons
!pip install opencv-python
!pip install -U tensorboard_plugin_profile

In [None]:
import os
import sys
import logging
from google.colab import drive, userdata
import tensorflow as tf
from tensorflow.keras import layers, Model, Input, losses, optimizers, mixed_precision, models, applications, callbacks
try:
    import neptune
except ImportError:
    print("Neptune client not found. Please install with: pip install -U neptune")
    neptune = None
import time # For timing epochs
import datetime
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.layers import SpectralNormalization, BatchNormalization, GroupNormalization
from tensorflow.keras.applications import vgg19, VGG16
from tensorflow.keras.applications.vgg16 import preprocess_input
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import TensorBoard
import tensorflow.keras.backend as K
from tensorboard import notebook
import glob
import cv2
from PIL import Image
import pkgutil
# Report whether the TensorBoard profiler plugin is importable (prints True/False).
# importlib.util.find_spec replaces pkgutil.find_loader, which is deprecated since Python 3.12.
import importlib.util
print(importlib.util.find_spec("tensorboard_plugin_profile") is not None)

In [None]:
# Attach Google Drive only when the mount point is not already present.
if not os.path.exists("/content/drive"):
  drive.mount('/content/drive')

## General config

In [None]:
# TensorBoard summary writer rooted at ./logs/.
# NOTE(review): `log_dir` is reassigned to a timestamped path in the later cell that creates
# `tb_writer`; confirm whether this `writer` is still needed.
log_dir = './logs/'  # Root directory for summary files
writer = tf.summary.create_file_writer(log_dir)  # File writer bound to log_dir

In [None]:
# Run label plus global RNG seeding.
RUN_NAME = 'pix2pix-baselineD_BCE'
# NumPy drives the file-list shuffle used for the train/validation split.
np.random.seed(42)
# The augmentations in this notebook draw from tf.random (tf.random.uniform, tf.image.random_crop),
# so TensorFlow's global seed has to be set as well for runs to be repeatable.
tf.random.set_seed(42)

In [None]:
# --- Image / data configuration shared by the data pipeline and the models ---
INPUT_CHANNELS = 3 # Generator input channels (e.g., 3 for RGB)
OUTPUT_CHANNELS = 3 # Generator output channels (e.g., 3 for RGB)
IMG_HEIGHT = 256 # Height of each half after splitting; used for set_shape and crop size
IMG_WIDTH = 256  # Width of each half after splitting; used for set_shape and crop size
# Which half of the concatenated image is the generator input
INPUT_IS_LEFT_HALF = True # True: [Photo | Drawing], i.e. input on the left, target on the right
NORMALIZATION_RANGE = (-1, 1) # Preprocessing scales pixels with x / 127.5 - 1
AUGMENTATIONS = 'random_horizontal_flip + random_jitter' # Descriptive label of the train-time augmentations
TARGET_CHANNELS = 3 # NOTE(review): same value as OUTPUT_CHANNELS — confirm both are needed
VAL_SPLIT = 0.1 # Fraction of the file list held out for validation

In [None]:
# --- Architecture labels and sizes ---
GENERATOR_ARCHITECTURE = 'U-Net'
DISCRIMINATOR_ARCHITECTURE = 'PatchGAN'
GENERATOR_OUTPUT_ACTIVATION = 'tanh' # Assuming tanh for [-1, 1] output range
# downsample() adds BatchNormalization when this is True and GroupNormalization(groups=-1) otherwise.
use_batchnorm = False
# NOTE(review): with use_batchnorm = False the blocks do not use BatchNormalization, so this label
# looks out of sync with the actual layers — confirm what it is used for before relying on it.
NORMALIZATION_LAYER = 'BatchNormalization'
GENERATOR_FILTERS_INITIAL = 64 # Starting filter count in generator
DISCRIMINATOR_FILTERS_INITIAL = 64 # Starting filter count in discriminator
BUFFER_SIZE = 400 # Shuffle buffer; the pipeline cells clamp it to the number of available files
BATCH_SIZE = 32 # For Colab Pro

In [None]:
# Define hyperparameters based on common pix2pix settings
OPTIMIZER_TYPE = 'Adam'
WEIGHT_DECAY = 0.0 # No weight decay in Adam by default
LEARNING_RATE = 0.0004
DISCRIMINATOR_LR_MULTIPLIER = 0.5
BETA_1 = 0.5 # Beta1 commonly set to 0.5 for GANs using Adam for stability
BETA_2 = 0.999 # Default beta2 is usually fine
DISC_UPDATE = 1

In [None]:
# Loss configuration
ADVERSARIAL_LOSS_TYPE = 'BCE_logits'
RECONSTRUCTION_LOSS_TYPE = 'perceptual + edge'
LAMBDA_L1 = 0 # Weight of the L1 term (0 disables it)
LAMBDA_EDGE = 30 # Weight of the edge term
LAMBDA_P = 0.01 # Weight of the perceptual term
LOSS_LAYERS = ['block3_conv3', 'block4_conv3'] # Backbone layers whose activations feed the perceptual loss
LOSS_MODEL = VGG16 # Backbone constructor (the Keras VGG16 application imported above)
LAYER_WEIGHTS = [1.0,1.0] # presumably one weight per entry of LOSS_LAYERS — TODO confirm in the loss function

In [None]:
# Neptune experiment-tracking settings
NEPTUNE_PROJECT = "IR-MAPPINGS/overviews"
# The token is fetched through google.colab's `userdata` instead of being hard-coded in the notebook.
NEPTUNE_API_TOKEN = userdata.get('NEPTUNE_API_TOKEN')

In [None]:
# === PROFILER CONFIG ===
# Epoch indices to profile; a flag-file path and a boolean switch are defined alongside.
# NOTE(review): how PROFILE_FLAG_FILE and profiling_active interact is decided by the training
# loop, which is not part of this cell — confirm there.
PROFILE_EPOCHS     = {0, 1, 2, 3, 4, 5, 6, 7}
PROFILE_FLAG_FILE  = "/tmp/enable_profiling.flag"
profiling_active   = True

In [None]:
# Training schedule
TOTAL_EPOCHS = 100
# NOTE(review): this label mentions 100 fixed + 100 decay epochs while TOTAL_EPOCHS is 100 — confirm
# which one reflects the schedule actually used.
LR_SCHEDULE = 'fixed_100_then_linear_decay_100'
LR_DECAY = 1 # TODO confirm meaning where the schedule is built
LR_EPOCH = 5 # TODO confirm meaning where the schedule is built
# Frequencies for actions within the training loop
CHECKPOINT_SAVE_FREQ = 10 # Save checkpoint every N epochs
IMAGE_LOG_FREQ = 5      # Log validation images to Neptune every N epochs
CONSOLE_LOG_FREQ = 50   # Log average losses to console/file every N steps within an epoch
# Human-readable run summary assembled from the loss and optimizer settings above.
DESCRIPTION = f"Baseline Pix2Pix: {ADVERSARIAL_LOSS_TYPE}+{RECONSTRUCTION_LOSS_TYPE}(L_l1={LAMBDA_L1}), L_p={LAMBDA_P}, L_edge={LAMBDA_EDGE} Adam(LR={LEARNING_RATE}, B1={BETA_1})"

In [None]:
# Metrics: backbone layer names used for feature-based evaluation.
# NOTE(review): presumably VGG16 layer names like LOSS_LAYERS — confirm where the metric model is built.
METRIC_LAYER_NAMES = ['block2_conv2', 'block3_conv3', 'block4_conv3']

In [None]:

# Global mixed-precision policy: bfloat16 compute.
# NOTE(review): the TensorFlow mixed-precision guide recommends 'mixed_float16' on NVIDIA GPUs and
# 'mixed_bfloat16' on TPUs; confirm that 'mixed_bfloat16' suits the accelerator this notebook runs on.
policy = mixed_precision.Policy('mixed_bfloat16')
# Install the Policy object built above so the policy name is spelled in exactly one place.
tf.keras.mixed_precision.set_global_policy(policy)
bf16 = tf.bfloat16
#logger.info(f"Mixed precision policy set to: {mixed_precision.global_policy().name}")

In [None]:
# Frozen feature extractor for the perceptual loss.
# The backbone is taken from the LOSS_MODEL config constant (instead of a hard-coded VGG16) so that
# changing it in the loss-config cell actually takes effect here.
vgg = LOSS_MODEL(include_top=False, weights='imagenet', input_shape=(None, None, 3))
vgg.trainable = False
for layer in vgg.layers:
    layer.trainable = False # Double-check all layers are frozen
# Expose the activations of the configured layers as the outputs of a new model.
layer_outputs = [vgg.get_layer(name).output for name in LOSS_LAYERS]
vgg_loss_model = Model(inputs=vgg.input, outputs=layer_outputs, name='vgg_loss_model')
vgg_loss_model.trainable = False # Ensure the combined model is also not trainable
print("VGG Loss Model built and frozen.")

In [None]:
# Per-run TensorBoard writer: the directory name carries the start timestamp so that separate
# runs write to separate folders. This reassigns the module-level `log_dir`.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f"logs/gradient_tape/{current_time}"
tb_writer = tf.summary.create_file_writer(log_dir)

## Logging Config

In [None]:
# --- Configuration ---
LOG_LEVEL = logging.INFO # Set minimum level to log (e.g., INFO, DEBUG)
LOG_FORMAT = '%(asctime)s - %(levelname)s - [%(name)s] - %(message)s' # timestamp - level - [logger name] - message
LOG_DATEFMT = '%Y-%m-%d %H:%M:%S'

# --- Log to a file in Google Drive ---
# For persistent logs across sessions or long runs
LOG_TO_FILE = True # Set to False to only log to Colab output
log_file_path = '/content/drive/MyDrive/IR_Mappings/pix2pix_training.log'

# --- Apply Configuration ---
log_handlers = [logging.StreamHandler(sys.stdout)] # Log to Colab output (stdout)
if LOG_TO_FILE:
    # Ensure directory exists if logging to file
    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)
    log_handlers.append(logging.FileHandler(log_file_path, mode='a')) # 'a' for append

# force=True is passed so that re-running this cell in the same session reconfigures logging;
# otherwise, basicConfig only works the first time it's called (Python 3.8+).
logging.basicConfig(
    level=LOG_LEVEL,
    format=LOG_FORMAT,
    datefmt=LOG_DATEFMT,
    handlers=log_handlers,
    force=True # Reconfigure even if logging was already set up
)

# --- Get specific logger ---
# Using a named logger allows for more granular control
logger = logging.getLogger('Pix2PixTrainer')

# --- Test Messages ---
# One message per level to check what reaches the configured handlers.
logger.info("Logging configured. Starting training process...")
logger.debug("This is a DEBUG message - it will only show if LOG_LEVEL is set to DEBUG.")
logger.info("This is an INFO message.")
logger.warning("This is a WARNING message.")
logger.error("This is an ERROR message.")

# Example of logging an exception: the division below fails on purpose to demonstrate traceback logging.
try:
    x = 1 / 0
except ZeroDivisionError as e:
    logger.error(f"An error occurred: {e}", exc_info=True) # exc_info=True adds traceback

In [None]:
# --- TEMPORARY DEBUGGING FOR CONSOLE OUTPUT ---
LOG_LEVEL = logging.INFO # Use DEBUG to see everything
LOG_FORMAT = '{asctime} - {levelname} - {message}'
LOG_DATEFMT = '%H:%M:%S'

logging.basicConfig(
    level=LOG_LEVEL,
    format=LOG_FORMAT,
    datefmt=LOG_DATEFMT,
    handlers=[logging.StreamHandler(sys.stdout)], # ONLY console handler
    force=True, # Force reconfiguration
    style='{' # Use old-style formatting (optional)
)
logging.critical("--- CONSOLE-ONLY LOGGING TEST ---")
logger = logging.getLogger('Pix2PixTrainer') # Get logger AFTER config
logger.info("Info message (console-only)")
logger.debug("Debug message (console-only)")
# --- END OF TEMPORARY DEBUGGING BLOCK ---

In [None]:
# # Run this in the cell AFTER the simplified basicConfig
# logging.info("Direct root logger INFO message.")
# logging.critical("Direct root logger CRITICAL message.")

## TensorFlow/GPU config

In [None]:
logging.info(f"TensorFlow version: {tf.__version__}")

# Look for GPUs and, when present, switch each one to on-demand memory allocation.
gpu_devices = tf.config.list_physical_devices('GPU')
if not gpu_devices:
    logging.info("No GPU found. Using CPU.")
else:
    logging.info(f"GPU detected: {gpu_devices}")
    try:
        for gpu in gpu_devices:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Raised when memory growth is requested after the GPUs have been initialized.
        logging.error(e)
    logical_gpus = tf.config.list_logical_devices('GPU')

# Echo every physical device (CPUs and GPUs) to the cell output.
physical_devices = tf.config.list_physical_devices()
print("All Physical Devices:", physical_devices)

# Second, print-based GPU report.
gpu_devices = tf.config.list_physical_devices('GPU')
if not gpu_devices:
    print("No GPUs found. Using CPU.")
else:
    print("GPUs Available:", gpu_devices)

## Directories

In [None]:
# Define base directory and subdirectories
base_data_dir = '/content/drive/MyDrive/IR_Mappings/IR_pix2pix/'
concat_image_dir = os.path.join(base_data_dir, 'concatenated_images')
CONCAT_IMAGE_DIR = concat_image_dir
checkpoint_dir = '/content/drive/MyDrive/IR_Mappings/IR_pix2pix/training_checkpoints'
os.makedirs(checkpoint_dir, exist_ok=True)
logger.info(f"Checkpoint directory set to: {checkpoint_dir}")

## Test Loading

In [None]:
# --- Load one example concatenated image ---
# '''
# try:
#     # List files to help find an example (optional)
#     print("Sample concatenated files:", os.listdir(concat_image_dir)[:5])

#     example_concat_name = 'profiles_i1_concatenated.png' # REPLACE with a real concatenated filename
#     example_concat_path = os.path.join(concat_image_dir, example_concat_name)

#     # Load the concatenated image using PIL
#     # Load as RGB first, can handle diagram channels later if needed
#     img_concat = Image.open(example_concat_path).convert("RGB")

#     # --- !!! Crucial Step: Split the Image !!! ---
#     width, height = img_concat.size
#     if width == 512 and height == 256:
#         # Format is [PHOTO | DRAWING] (Input | Target) ***

#         # PIL's crop box is (left, upper, right, lower)
#         img_photo = img_concat.crop((0, 0, 256, 256))      # Left half (width 0 to 256)
#         img_diagram = img_concat.crop((256, 0, 512, 256))   # Right half (width 256 to 512)

#         # Optional: Convert diagram to grayscale now if desired,
#         # otherwise handle channels during TF preprocessing
#         # img_diagram = img_diagram.convert("L")

#     else:
#         print(f"Error: Unexpected image dimensions {img_concat.size}. Expected 256x512.")
#         # Handle error appropriately, maybe skip this file in a real pipeline
#         raise ValueError("Incorrect image dimensions")

#     # Display the split parts
#     fig, ax = plt.subplots(1, 2, figsize=(10, 5))
#     ax[0].imshow(img_photo)
#     ax[0].set_title(f"Input Photo (Split Left)")
#     ax[0].axis('off')

#     # Use grayscale cmap if diagram was converted to 'L' mode above
#     ax[1].imshow(img_diagram, cmap='gray' if img_diagram.mode == 'L' else None)
#     ax[1].set_title(f"Target Diagram (Split Right)")
#     ax[1].axis('off')
#     plt.show()

#     print(f"Photo size: {img_photo.size}, Mode: {img_photo.mode}")
#     print(f"Diagram size: {img_diagram.size}, Mode: {img_diagram.mode}")

# except FileNotFoundError:
#     print(f"Error: Example file '{example_concat_name}' not found in '{concat_image_dir}'. Please check path and filename.")
# except Exception as e:
#     print(f"An error occurred: {e}")

## Set Up Pipeline

In [None]:
# --- Define Loading and Preprocessing Function ---
def load_image(image_file):
    """Read a PNG file and return it as a float32, 3-channel image tensor.

    Args:
        image_file: Path of the PNG file to read.

    Returns:
        The decoded image cast to ``tf.float32``; pixel values are not rescaled here.
    """
    raw_bytes = tf.io.read_file(image_file)
    # channels=3 forces an RGB result from the PNG decoder.
    decoded = tf.image.decode_png(raw_bytes, channels=3)
    return tf.cast(decoded, tf.float32)
def random_crop(image, dim, size=256):
    """Cut one random ``size`` x ``size`` window out of ``image`` along axes 1 and 2.

    Axis 0 of ``image`` is kept whole; axis 1 is cropped with an offset drawn from the
    ``height`` range and axis 2 with an offset drawn from the ``width`` range.

    Args:
        image: Array or tensor that supports ``image[:, a:b, c:d]`` slicing.
        dim: ``(height, width, channels)`` extent of the area to crop from.
        size: Side length of the square crop. Defaults to 256, the value that used to be
            hard-coded.

    Returns:
        The slice ``image[:, top:top + size, left:left + size]``.

    Raises:
        ValueError: If ``height`` or ``width`` is smaller than ``size``; previously this
            produced negative offsets and a silently wrong slice.
    """
    height, width, _ = dim
    if height < size or width < size:
        raise ValueError(
            f"Cannot take a {size}x{size} crop from an area of {height}x{width}."
        )
    # The offsets are drawn in this order (rows, then columns) so that seeded runs keep
    # producing the same crops as before.
    # NOTE(review): np.random runs in Python; if this is ever called from traced TensorFlow code
    # the offsets would presumably be fixed at trace time — confirm before using it in tf.data.
    top = int(np.random.uniform(low=0, high=int(height - size)))
    left = int(np.random.uniform(low=0, high=int(width - size)))
    return image[:, top:top + size, left:left + size]

@tf.function
def random_jittering(input_image, target_image, height=286, width=286):
    """Enlarge an image pair and cut the same random window out of both images.

    Args:
        input_image: Single input image (no batch axis).
        target_image: Single target image (no batch axis).
        height: Height both images are resized to before cropping.
        width: Width both images are resized to before cropping.

    Returns:
        ``(input, target)`` crops, each with static shape ``[IMG_HEIGHT, IMG_WIDTH, 3]``.
    """
    def _enlarge(single_image):
        # Nearest-neighbour resize of one image to the jitter size.
        return tf.image.resize(single_image, [height, width], method='nearest')

    # Put both enlarged images into one (2, height, width, C) tensor so that a single random
    # crop picks the same window for input and target.
    pair = tf.stack([_enlarge(input_image), _enlarge(target_image)], axis=0)
    window = tf.image.random_crop(pair, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])

    jittered_input, jittered_target = window[0], window[1]

    # Re-attach the static shape for downstream ops.
    jittered_input.set_shape([IMG_HEIGHT, IMG_WIDTH, 3])
    jittered_target.set_shape([IMG_HEIGHT, IMG_WIDTH, 3])

    return jittered_input, jittered_target

def preprocess_image_train(image):
    """Split a concatenated pair, augment both halves identically and scale them to [-1, 1].

    Steps: split along the width axis, apply one shared random horizontal flip, apply
    ``random_jittering``, then normalize with ``x / 127.5 - 1``.

    Args:
        image: Image tensor holding input and target side by side along the width axis
            (the training pipeline maps this function over ``load_image`` output).

    Returns:
        ``(input_image, real_image)``, each with static shape ``[IMG_HEIGHT, IMG_WIDTH, 3]``.
    """
    # Log the static shape; formatting tf.shape(image) here would only print a symbolic tensor
    # when the function is traced inside Dataset.map.
    logger.debug(f"Preprocessing input image with shape: {image.shape}")

    # --- 1. Split into input and target halves ---
    half_width = tf.shape(image)[1] // 2
    left_half = image[:, :half_width, :]
    right_half = image[:, half_width:, :]
    if INPUT_IS_LEFT_HALF:
        # Layout is [Photo | Drawing]
        photo, drawing = left_half, right_half
        logger.debug("Split assuming Photo on Left, Drawing on Right.")
    else:
        # Layout is [Drawing | Photo]
        photo, drawing = right_half, left_half
        logger.debug("Split assuming Drawing on Left, Photo on Right.")

    # The split width is dynamic, so pin the static shape before augmenting.
    photo.set_shape([IMG_HEIGHT, IMG_WIDTH, 3])
    drawing.set_shape([IMG_HEIGHT, IMG_WIDTH, 3])

    # --- 2. Augmentation, applied identically to both halves ---
    # One draw decides the flip for the whole pair. A single tf.cond handles both images; the
    # former Python `if flip_condition:` (a branch on a tensor whose only effect was a debug
    # message emitted while tracing) has been removed because it misreported what happened.
    flip_condition = tf.random.uniform(()) > 0.5
    input_image, real_image = tf.cond(
        flip_condition,
        lambda: (tf.image.flip_left_right(photo), tf.image.flip_left_right(drawing)),
        lambda: (photo, drawing),
    )

    input_image, real_image = random_jittering(input_image, real_image)

    # --- 3. Normalization to [-1, 1] ---
    input_image = (input_image / 127.5) - 1.0
    real_image = (real_image / 127.5) - 1.0
    logger.debug("Normalized both images to [-1, 1].")

    return input_image, real_image

# Define a preprocessing function *without* random augmentation
def preprocess_image_val(image):
    """Split a concatenated pair and scale both halves to [-1, 1], with no augmentation.

    Args:
        image: Image tensor holding input and target side by side along the width axis.

    Returns:
        ``(input_image, real_image)``, each with static shape ``[IMG_HEIGHT, IMG_WIDTH, 3]``.
    """
    split_at = tf.shape(image)[1] // 2
    left_half = image[:, :split_at, :]
    right_half = image[:, split_at:, :]

    # INPUT_IS_LEFT_HALF decides which half is the generator input.
    input_image, real_image = (left_half, right_half) if INPUT_IS_LEFT_HALF else (right_half, left_half)

    input_image.set_shape([IMG_HEIGHT, IMG_WIDTH, 3])
    real_image.set_shape([IMG_HEIGHT, IMG_WIDTH, 3])

    # Normalization is the only transformation applied to validation data.
    return (input_image / 127.5) - 1.0, (real_image / 127.5) - 1.0

logger.info("Setting up tf.data pipeline...")

# --- Build the train/validation datasets from the concatenated image files ---
# Produces: train_dataset, val_dataset, fixed_val_input, fixed_val_target (plus the file lists
# and sizes used to build them). Any failure is logged with a traceback and re-raised.
try:
    # 1. Collect every file path. The list is sorted first because the seeded shuffle below only
    #    reproduces the same train/validation split when it starts from the same ordering, and
    #    this code does not control the order in which glob returns its matches.
    all_files_pattern = os.path.join(CONCAT_IMAGE_DIR, '*.png') # Or appropriate pattern
    all_files = sorted(tf.io.gfile.glob(all_files_pattern))
    if not all_files:
         raise ValueError(f"No files found matching pattern: {all_files_pattern}")
    dataset_size = len(all_files)
    logger.info(f"Found {dataset_size} total image files.")

    # 2. Shuffle the LIST of file paths with a fixed seed
    np.random.seed(42)
    np.random.shuffle(all_files) # Shuffle the Python list in-place
    logger.info("Shuffled file list.")

    # 3. Calculate split sizes
    num_val_samples = int(dataset_size * VAL_SPLIT)
    num_train_samples = dataset_size - num_val_samples
    logger.info(f"Splitting into {num_train_samples} training files and {num_val_samples} validation files.")

    # 4. Split the LIST using slicing: training files first, validation files are the tail
    train_files = all_files[:num_train_samples]
    val_files = all_files[num_train_samples:]

    # 5. Create separate tf.data.Dataset objects from the file lists
    train_dataset_paths = tf.data.Dataset.from_tensor_slices(train_files)
    val_dataset_paths = tf.data.Dataset.from_tensor_slices(val_files)

    # --- Training pipeline: shuffle paths -> load -> augment/normalize -> batch -> prefetch ---
    logger.info("Building train_dataset pipeline...")
    # The shuffle buffer never needs to exceed the number of training files
    BUFFER_SIZE = min(BUFFER_SIZE, num_train_samples)

    train_dataset = train_dataset_paths.shuffle(BUFFER_SIZE) # Shuffle train data each epoch
    train_dataset = train_dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    train_dataset = train_dataset.map(preprocess_image_train, num_parallel_calls=tf.data.AUTOTUNE) # Has augmentation
    train_dataset = train_dataset.batch(BATCH_SIZE)
    train_dataset = train_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    logger.info("Training dataset pipeline built successfully.")

    # --- Validation pipeline: load -> normalize only -> batch -> prefetch (no shuffle) ---
    logger.info("Building val_dataset pipeline...")
    val_dataset = val_dataset_paths.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
    val_dataset = val_dataset.map(preprocess_image_val, num_parallel_calls=tf.data.AUTOTUNE)
    val_dataset = val_dataset.batch(BATCH_SIZE) # Use same batch size
    val_dataset = val_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
    logger.info("Validation dataset pipeline built successfully.")

    # --- Fixed validation batch used for image logging ---
    logger.info("Fetching fixed validation batch for visualization from val_dataset...")
    fixed_val_input, fixed_val_target = None, None
    try:
        val_examples_to_take = min(BATCH_SIZE, 16)
        # Take from the start of the (unshuffled) validation set
        fixed_val_batch_dataset = val_dataset.take(1).unbatch().take(val_examples_to_take).batch(val_examples_to_take)
        fixed_val_input, fixed_val_target = next(iter(fixed_val_batch_dataset))
        logger.info(f"Fixed validation batch shapes: Input {fixed_val_input.shape}, Target {fixed_val_target.shape}")
    except Exception as e:
         # Best effort: training can continue without the fixed batch, both names then stay None.
         logger.error(f"Could not get validation batch: {e}", exc_info=True)
         logger.warning("Proceeding without fixed validation data for image logging.")

except Exception as e:
    logger.error(f"Error creating dataset from files: {e}", exc_info=True)
    # Nothing downstream can work without the datasets, so stop here.
    raise





# val_dataset = tf.data.Dataset.from_tensor_slices(val_files) # Use val_files list
# val_dataset = val_dataset.map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
# val_dataset = val_dataset.map(preprocess_image_val, num_parallel_calls=tf.data.AUTOTUNE) # Use non-augmenting preprocess
# val_dataset = val_dataset.batch(BATCH_SIZE) # Use same batch size for efficiency
# val_dataset = val_dataset.prefetch(buffer_size=tf.data.AUTOTUNE)
# logger.info("Validation dataset pipeline built successfully.")

## Calculate Step-Based Parameters

In [None]:
# === LEARNING RATE SCHEDULE SETUP ===
logger.info("Calculating parameters for Learning Rate Schedule...")

# Local aliases for values defined in the configuration cells
initial_learning_rate = LEARNING_RATE
epochs = TOTAL_EPOCHS
batch_size = BATCH_SIZE
# dataset_size is set by the data pipeline cell above

# Calculate steps
# try:
#     # Estimate steps per epoch. Use ceiling division for robustness.
#     steps_per_epoch = (dataset_size + batch_size - 1) // batch_size
#     # Determine the epoch and step where decay starts
#     decay_start_epoch = epochs // 2
#     decay_start_step = steps_per_epoch * decay_start_epoch
#     # Determine the total number of steps over which decay occurs
#     decay_epochs = epochs - decay_start_epoch
#     total_decay_steps = steps_per_epoch * decay_epochs

#     logger.info(f"VERIFY: dataset_size = {dataset_size}")
#     logger.info(f"VERIFY: BATCH_SIZE = {BATCH_SIZE}")
#     logger.info(f"VERIFY: EPOCHS = {epochs}") # Or TOTAL_EPOCHS
#     logger.info(f"VERIFY: steps_per_epoch = {steps_per_epoch}")
#     logger.info(f"VERIFY: decay_start_step = {decay_start_step}")
#     logger.info(f"VERIFY: total_decay_steps = {total_decay_steps}")
#     # Sanity check: decay_start_step + total_decay_steps should equal total steps in training
#     total_steps_calculated = steps_per_epoch * epochs
#     logger.info(f"VERIFY: Total calculated steps = {total_steps_calculated}")
#     logger.info(f"VERIFY: Decay ends at step = {decay_start_step + total_decay_steps}")


# except NameError as e:
#     logger.error(f"Missing required variable for LR schedule calculation (dataset_size?): {e}", exc_info=True)
#     raise
# except Exception as e:
#      logger.error(f"Error calculating LR schedule steps: {e}", exc_info=True)
#      raise

## Build Pipeline

In [None]:
# --- Smoke-test the training pipeline on a single batch ---
logger.info("Fetching one batch element to inspect its structure...")
try:
    # Pull exactly one batch and keep it packed so its structure can be examined before unpacking.
    for batch_element in train_dataset.take(1):
        logger.info("Successfully fetched one batch element.")
        logger.info(f"Type of yielded batch element: {type(batch_element)}")

        is_sequence = isinstance(batch_element, (tuple, list))
        element_len = len(batch_element) if is_sequence else -1 # -1 means "no length available"

        # --- Describe what was yielded ---
        if is_sequence:
            logger.info(f"Number of items in batch element tuple/list: {element_len}")
            for i, item in enumerate(batch_element):
                if not hasattr(item, 'shape'):
                    logger.info(f"    Item {i} type: {type(item)}")
                else:
                    logger.info(f"    Item {i} shape: {item.shape}")
        elif hasattr(batch_element, 'shape'):
            # A single tensor-like object was yielded instead of a pair
            logger.info(f"Batch element shape (element is not tuple/list): {batch_element.shape}")
        else:
            logger.info("Batch element is not a tuple/list or tensor-like.")

        # --- Only a 2-item tuple/list is unpacked and visualized ---
        if not (is_sequence and element_len == 2):
            logger.error(f"Pipeline yielding unexpected structure! Expected tuple/list of 2 items, but got type {type(batch_element)} with length {element_len if element_len != -1 else 'N/A'}.")
            continue

        logger.info("Structure seems correct (tuple/list of size 2). Proceeding with unpacking and visualization.")
        example_input, example_target = batch_element

        logger.info(f"  Input batch shape: {example_input.shape}")
        logger.info(f"  Target batch shape: {example_target.shape}")

        # Minimum of the input batch; the pipeline normalizes with x / 127.5 - 1
        logger.info(f"  Input min value: {tf.reduce_min(example_input).numpy():.2f}")

        # Show the first pair of the batch; (img + 1) / 2 maps [-1, 1] back to [0, 1] for imshow
        plt.figure(figsize=(6, 6))
        panels = (("Sample Input", example_input), ("Sample Target", example_target))
        for position, (panel_title, panel_batch) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.imshow((panel_batch[0] + 1) / 2.0)
            plt.title(panel_title)
            plt.axis("off")
        plt.show()
        logger.info("Pipeline test batch values checked and visualized successfully.")

except Exception as e:
    logger.error(f"Error fetching or inspecting dataset element: {e}", exc_info=True)
    # raise # Optionally re-raise

In [None]:
logger.info("Mapping functions and building the final training dataset pipeline...")
try:
    # Builds a second pipeline step by step (list -> load -> preprocess -> shuffle -> batch ->
    # prefetch) and logs the element structure after each mapping stage.
    # NOTE(review): this pipeline lists every '*.png' in CONCAT_IMAGE_DIR, so it also contains the
    # files that the split cell assigned to val_files. Confirm train_dataset_final is used for
    # inspection only and not for training.
    logger.info(f"Listing files from: {CONCAT_IMAGE_DIR}")
    dataset = tf.data.Dataset.list_files(os.path.join(CONCAT_IMAGE_DIR, '*.png')) # Or appropriate pattern
    dataset_size = tf.data.experimental.cardinality(dataset).numpy()
    if dataset_size == 0:
          raise ValueError(f"No files found matching pattern in {CONCAT_IMAGE_DIR}")
    logger.info(f"Found {dataset_size} image files.")

    # The shuffle buffer never needs to exceed the number of files
    BUFFER_SIZE = min(BUFFER_SIZE, dataset_size)

    logger.info(f"Initial dataset element spec: {dataset.element_spec}")

    # --- DEBUG: elements straight from list_files are file paths ---
    logger.info("--- Inspecting elements AFTER list_files ---")
    for item_path in dataset.take(2): # Look at first 2 paths
         logger.info(f"  Element type: {type(item_path)}, Value: {item_path.numpy().decode()}") # Decode string tensor

    # 1. Map the loading function: file paths -> raw image tensors
    train_dataset_loaded = dataset.map(load_image,
                                      num_parallel_calls=tf.data.AUTOTUNE)
    logger.info(f"After load_image map, element spec: {train_dataset_loaded.element_spec}")
    # --- DEBUG: Check elements AFTER load_image map ---
    logger.info("--- Inspecting elements AFTER load_image ---")
    for item_loaded in train_dataset_loaded.take(1):
         logger.info(f"  Element type: {type(item_loaded)}, Shape: {item_loaded.shape}") # Expect (256, 512, 3)

    # 2. Map the preprocessing and augmentation function: raw image -> (input_image, target_image)
    #    Fix: this stage used to map load_image a second time and then inspected the *loaded*
    #    dataset, so the check below never saw preprocessed pairs. The preprocess map now runs
    #    first and the inspection reads its output.
    train_dataset_processed = train_dataset_loaded.map(preprocess_image_train,
                                                      num_parallel_calls=tf.data.AUTOTUNE)
    logger.info(f"After preprocess_image_train map, element spec: {train_dataset_processed.element_spec}")
    # --- DEBUG: Check elements AFTER preprocess_image_train map ---
    logger.info("--- Inspecting elements AFTER preprocess_image_train ---")
    for item_processed in train_dataset_processed.take(1):
         logger.info(f"  Element type: {type(item_processed)}") # Expect tuple
         if isinstance(item_processed, tuple):
             logger.info(f"  Element length: {len(item_processed)}") # Expect 2
             if len(item_processed) == 2:
                 logger.info(f"  Item 0 shape: {item_processed[0].shape}") # Expect (256, 256, 3)
                 logger.info(f"  Item 1 shape: {item_processed[1].shape}") # Expect (256, 256, 3)
         elif hasattr(item_processed, 'shape'):
              logger.info(f"  Element shape (NOT TUPLE!): {item_processed.shape}")
         else:
              logger.info(f"  Element is neither tuple nor tensor!")

    # 3. Shuffle the preprocessed pairs
    train_dataset_shuffled = train_dataset_processed.shuffle(BUFFER_SIZE)

    # 4. Group pairs into batches of size BATCH_SIZE
    train_dataset_batched = train_dataset_shuffled.batch(BATCH_SIZE)
    logger.info(f"After batch, element spec: {train_dataset_batched.element_spec}") # Should be tuple of batched TensorSpecs

    # 5. Prefetch so the next batch is prepared while the current one is consumed
    train_dataset_final = train_dataset_batched.prefetch(buffer_size=tf.data.AUTOTUNE)

    logger.info("Training dataset pipeline built successfully.")

    # --- Test the pipeline ---
    logger.info("Fetching one batch to test pipeline...")
    for example_input, example_target in train_dataset_final.take(1):
        logger.info(f"  Input batch shape: {example_input.shape}") # Should be (BATCH_SIZE, 256, 256, 3)
        logger.info(f"  Target batch shape: {example_target.shape}") # Should be (BATCH_SIZE, 256, 256, 3)
        # Display one example from the batch
        plt.figure(figsize=(6, 6))
        plt.subplot(1, 2, 1)
        # (img + 1) / 2 maps the normalized [-1, 1] range back to [0, 1] for imshow
        plt.imshow((example_input[0] + 1) / 2.0)
        plt.title("Sample Input")
        plt.axis("off")
        plt.subplot(1, 2, 2)
        # Display target - grayscale colormap only if it has a single channel
        plt.imshow((example_target[0] + 1) / 2.0, cmap='gray' if example_target.shape[-1]==1 else None)
        plt.title("Sample Target")
        plt.axis("off")
        plt.show()
    logger.info("Pipeline test batch fetched successfully.")
    # Value range of the fetched target batch (uses the loop variable from the batch above)
    logger.info(f"  Target Batch Min Value (Normalized): {tf.reduce_min(example_target).numpy():.4f}")
    logger.info(f"  Target Batch Max Value (Normalized): {tf.reduce_max(example_target).numpy():.4f}")
    # Also visualize the DENORMALIZED target again to be sure it looks right
    plt.imshow((example_target[0].numpy() + 1) / 2.0)
    plt.title("VERIFY Sample Target Looks Correct")
    plt.show()
except Exception as e:
    logger.error(f"Error building or testing dataset pipeline: {e}", exc_info=True)
    raise


## Define generator and discriminator

### Helper Functions

In [None]:
# Encoder building block: a strided convolution that halves height and width
def downsample(filters, size, apply_norm=True, apply_dropout=False, name=None):
    '''
    Assemble one encoder block as a tf.keras.Sequential.

    Layer order: Conv2D (strides=2, padding='same') -> optional normalization
    -> optional Dropout(0.5) -> LeakyReLU.

    Args:
        filters: number of output channels of the convolution.
        size: kernel size of the convolution.
        apply_norm: when True, a normalization layer follows the convolution
            and the convolution is created without a bias term.
        apply_dropout: when True, Dropout(0.5) is inserted before the
            activation.
        name: optional name given to the returned Sequential.

    Returns:
        The assembled tf.keras.Sequential block.

    The normalization flavour is picked by the notebook-level `use_batchnorm`
    flag: BatchNormalization when it is truthy, GroupNormalization(groups=-1)
    otherwise.
    '''
    kernel_init = tf.random_normal_initializer(0., 0.02)

    # The conv bias is dropped whenever a normalization layer follows it.
    block_layers = [
        layers.Conv2D(filters, size, strides=2, padding='same',
                      kernel_initializer=kernel_init,
                      use_bias=not apply_norm),
    ]
    if apply_norm:
        block_layers.append(
            layers.BatchNormalization() if use_batchnorm
            else layers.GroupNormalization(groups=-1))
    if apply_dropout:
        block_layers.append(layers.Dropout(0.5))  # standard pix2pix dropout rate
    block_layers.append(layers.LeakyReLU())  # LeakyReLU in the encoder

    return tf.keras.Sequential(block_layers, name=name)

# Decoder building block: resize-then-convolve upsampling
def upsample(filters, size, apply_dropout=False, name=None):
  '''
  Assemble one decoder block as a tf.keras.Sequential.

  Layer order: UpSampling2D (size=2, bilinear) -> Conv2D (strides=1,
  padding='same', no bias) -> normalization -> optional Dropout(0.5) -> ReLU.
  The spatial doubling comes from the UpSampling2D layer; the convolution
  keeps height and width unchanged.

  Args:
      filters: number of output channels of the convolution.
      size: kernel size of the convolution.
      apply_dropout: when True, Dropout(0.5) is inserted before the activation.
      name: optional name given to the returned Sequential.

  Returns:
      The assembled tf.keras.Sequential block.

  A normalization layer is always added; the notebook-level `use_batchnorm`
  flag selects BatchNormalization (truthy) or GroupNormalization(groups=-1).
  '''
  kernel_init = tf.random_normal_initializer(0., 0.02)

  block_layers = [
      # Doubles H and W by bilinear interpolation.
      layers.UpSampling2D(size=2, interpolation='bilinear'),
      # Stride must stay 1 here so the convolution preserves H and W.
      layers.Conv2D(filters, size, strides=1,
                    padding='same',
                    kernel_initializer=kernel_init,
                    use_bias=False),
  ]
  block_layers.append(
      layers.BatchNormalization() if use_batchnorm
      else layers.GroupNormalization(groups=-1))
  if apply_dropout:
      block_layers.append(layers.Dropout(0.5))  # standard pix2pix decoder dropout rate
  block_layers.append(layers.ReLU())  # ReLU in the decoder

  return tf.keras.Sequential(block_layers, name=name)

### Build Generator (U-net)

In [None]:
logger.info("Defining Generator (U-Net) model...")

# U-Net generator assembled with the Keras Functional API
def build_generator(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), output_channels=OUTPUT_CHANNELS, name="unet_generator"):
  '''
  Build the U-Net generator as a Keras Functional model.

  Structure:
    * Encoder: eight `downsample` blocks (kernel size 4). Filter counts start
      at GENERATOR_FILTERS_INITIAL, double for the next two blocks and are
      then held at min(8 * GENERATOR_FILTERS_INITIAL, 512). The first block
      has no normalization; the last (bottleneck) block uses dropout.
    * Decoder: seven `upsample` blocks (kernel size 4); the first three use
      dropout. After each block the result is concatenated with the encoder
      output of matching spatial size (skip connection). The bottleneck
      output itself is not used as a skip.
    * Output: a Conv2DTranspose (strides=2, tanh) producing `output_channels`
      channels, so values lie in [-1, 1].

  Args:
      input_shape: shape of one input image, (height, width, channels).
      output_channels: number of channels of the generated image.
      name: name given to the returned Model.

  Returns:
      A tf.keras Model mapping the input image to the generated image.

  Raises:
      ValueError: if the number of skip tensors differs from the number of
          decoder blocks, or if an upsampled tensor and its skip tensor have
          different spatial dimensions.
  '''
  with tf.device('/gpu:0'):
    gf = GENERATOR_FILTERS_INITIAL
    logger.info(f"Building Generator with input shape {input_shape} and output channels {output_channels}")
    inputs = Input(shape=input_shape)

    deep_filters = min(gf * 8, 512)

    # Encoder blocks as (filters, apply_norm, apply_dropout); each halves H and W.
    encoder_specs = [
        (gf,           False, False),  # (bs, 128, 128, gf)
        (gf * 2,       True,  False),  # (bs, 64, 64, gf*2)
        (gf * 4,       True,  False),  # (bs, 32, 32, gf*4)
        (deep_filters, True,  False),  # (bs, 16, 16, gf*8 or 512)
        (deep_filters, True,  False),  # (bs, 8, 8, gf*8 or 512)
        (deep_filters, True,  False),  # (bs, 4, 4, gf*8 or 512)
        (deep_filters, True,  False),  # (bs, 2, 2, gf*8 or 512)
        (deep_filters, True,  True),   # (bs, 1, 1, gf*8 or 512) - bottleneck
    ]
    down_stack = [
        downsample(n_filters, 4, apply_norm=with_norm, apply_dropout=with_dropout,
                   name=f"down_{idx}")
        for idx, (n_filters, with_norm, with_dropout) in enumerate(encoder_specs, start=1)
    ]

    # Decoder blocks as (filters, apply_dropout); each doubles H and W.
    decoder_specs = [
        (deep_filters, True),   # (bs, 2, 2, gf*8 or 512)
        (deep_filters, True),   # (bs, 4, 4, gf*8 or 512)
        (deep_filters, True),   # (bs, 8, 8, gf*8 or 512)
        (deep_filters, False),  # (bs, 16, 16, gf*8 or 512)
        (gf * 4,       False),  # (bs, 32, 32, gf*4)
        (gf * 2,       False),  # (bs, 64, 64, gf*2)
        (gf,           False),  # (bs, 128, 128, gf)
    ]
    up_stack = [
        upsample(n_filters, 4, apply_dropout=with_dropout, name=f"up_{idx}")
        for idx, (n_filters, with_dropout) in enumerate(decoder_specs, start=1)
    ]

    # Output layer: upsample to full resolution, tanh keeps values in [-1, 1].
    last = layers.Conv2DTranspose(output_channels, 4,
                                    strides=2,
                                    padding='same',
                                    kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                    activation='tanh',
                                    name="final_output_convt") # (bs, 256, 256, output_channels)

    # Encoder pass: remember every block output for the skip connections.
    x = inputs
    skips = []
    for idx, down in enumerate(down_stack, start=1):
        x = down(x)
        logger.debug(f"Generator Encoder Layer {idx} output shape: {x.shape}")
        skips.append(x)

    # Drop the bottleneck output and walk the remaining skips deepest-first.
    skips_for_concat = skips[:-1][::-1]
    logger.info(f"Num skips: {len(skips_for_concat)}, Num up-blocks: {len(up_stack)}")
    if len(skips_for_concat) != len(up_stack):
      raise ValueError("Skip/Up count mismatch!")

    # Decoder pass: upsample, verify spatial agreement, then concatenate the skip.
    for idx, (up, skip) in enumerate(zip(up_stack, skips_for_concat), start=1):
        x = up(x)
        logger.debug(f"Decoder Step {idx}: Upsampled shape: {x.shape}, Skip shape: {skip.shape}")
        if x.shape[1:3] != skip.shape[1:3]:
            error_msg = f"Shape mismatch before Concat {idx}! Up: {x.shape}, Skip: {skip.shape}"
            logger.error(error_msg)
            raise ValueError(error_msg)
        x = layers.Concatenate()([x, skip])
        logger.debug(f"Generator Decoder Layer {idx} output shape after concat: {x.shape}")

    x = last(x)
    logger.debug(f"Generator Final Output shape: {x.shape}")

    return Model(inputs=inputs, outputs=x, name=name)

### Build Discriminator (PatchGAN)

In [None]:

logger.info("Defining Discriminator (PatchGAN) model...")
# --- Discriminator (PatchGAN) ---

# Build the PatchGAN Discriminator Model using the Keras Functional API
def build_discriminator(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), target_shape=(IMG_HEIGHT, IMG_WIDTH, 3), name="patchgan_discriminator"):
    '''
    Build the PatchGAN discriminator as a Keras Functional model.

    The model takes two inputs: `input_image` (the source photo) and
    `target_image` (either the real target or the generator output). They are
    concatenated along the channel axis and passed through five 4x4
    convolutions, each wrapped in SpectralNormalization:

      1. 64 filters,  stride 2, 'same'  -> LeakyReLU
      2. 128 filters, stride 2, 'same'  -> [BatchNorm] -> LeakyReLU
      3. 256 filters, stride 1, 'same'  -> [BatchNorm] -> LeakyReLU
      4. 512 filters, stride 1, 'valid' -> [BatchNorm] -> LeakyReLU
      5. 1 filter,    stride 1, 'valid' (no activation)

    BatchNormalization is only inserted when the notebook-level
    `use_batchnorm` flag is truthy; otherwise no normalization layer is used
    and convolutions 2-4 keep their bias term.

    The output is a grid of raw logits, one per patch. No ZeroPadding2D layers
    are used, so for a 256x256 input the grid is 58x58x1
    (256 -> 128 -> 64 -> 64 -> 61 -> 58).

    Args:
        input_shape: shape of the source image, (height, width, channels).
        target_shape: shape of the target/generated image.
        name: name given to the returned Model.

    Returns:
        A tf.keras Model with inputs [input_image, target_image] and the
        patch-logit grid as output.
    '''
    logger.info(f"Building Discriminator with input shape {input_shape} and target shape {target_shape}")
    initializer = tf.random_normal_initializer(0., 0.02)

    inp = Input(shape=input_shape, name='input_image')
    tar = Input(shape=target_shape, name='target_image')

    x = layers.Concatenate()([inp, tar])
    logger.debug(f"Discriminator input shape after concat: {x.shape}")

    # Layer 1 (Conv2D -> LeakyReLU) - Apply SN
    # Wrap the Conv2D layer directly with SpectralNormalization
    x = SpectralNormalization(layers.Conv2D(64, 4, strides=2, padding='same',
                                            kernel_initializer=initializer, use_bias=True, # Bias ok w/o BN
                                            name="disc_conv_1"))(x)
    x = layers.LeakyReLU(name="disc_leaky_1")(x)
    logger.debug(f"Discriminator Layer 1 SN output shape: {x.shape}")

    # Layer 2 (Conv2D -> Norm -> LeakyReLU) - Apply SN
    x = SpectralNormalization(layers.Conv2D(128, 4, strides=2, padding='same',
                                            kernel_initializer=initializer, use_bias=not use_batchnorm,
                                            name="disc_conv_2"))(x)
    if use_batchnorm: x = layers.BatchNormalization(name="disc_norm_2")(x)
    x = layers.LeakyReLU(name="disc_leaky_2")(x)
    logger.debug(f"Discriminator Layer 2 SN output shape: {x.shape}")

    # Layer 3 (Conv2D -> Norm -> LeakyReLU) - Apply SN
    x = SpectralNormalization(layers.Conv2D(256, 4, strides=1, padding='same',
                                            kernel_initializer=initializer, use_bias=not use_batchnorm,
                                            name="disc_conv_3"))(x)
    if use_batchnorm: x = layers.BatchNormalization(name="disc_norm_3")(x)
    x = layers.LeakyReLU(name="disc_leaky_3")(x)
    logger.debug(f"Discriminator Layer 3 SN output shape: {x.shape}")

    # Layer 4 (Conv2D -> Norm -> LeakyReLU) - Apply SN
    # 'valid' is the Conv2D default; it is spelled out so the size reduction is visible.
    conv = SpectralNormalization(layers.Conv2D(512, 4, strides=1, padding='valid',
                                            kernel_initializer=initializer,
                                            use_bias=not use_batchnorm, name="disc_conv_4"))(x)
    logger.debug(f"Discriminator Conv4 SN output shape: {conv.shape}")

    if use_batchnorm: norm1 = layers.BatchNormalization(name="disc_norm_4")(conv)
    else: norm1 = conv
    leaky_relu = layers.LeakyReLU(name="disc_leaky_4")(norm1)

    # Final Output Layer (Conv2D, raw logits) - Apply SN
    # The previously created-but-unused ZeroPadding2D layer was removed: the
    # output convolution has always consumed `leaky_relu` directly.
    last = SpectralNormalization(layers.Conv2D(1, 4, strides=1, padding='valid',
                                            kernel_initializer=initializer, name="disc_output_conv"))(leaky_relu)
    logger.debug(f"Discriminator Final SN Output shape: {last.shape}")

    return Model(inputs=[inp, tar], outputs=last, name=name)

In [None]:
def build_discriminator_70x70(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
                              target_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
                              name="patchgan_discriminator_70x70"):
    '''
    Build a PatchGAN discriminator whose output units see 70x70 input patches.

    The model takes `input_image` (source photo) and `target_image` (real or
    generated target), concatenates them on the channel axis and applies five
    4x4 convolutions, each wrapped in SpectralNormalization:

      layer   filters  stride  padding  receptive field
      1       64       2       same     4
      2       128      2       same     10
      3       256      2       same     22
      4       512      1       same     46
      output  1        1       valid    70

    Layers 2-4 are followed by BatchNormalization only when the notebook-level
    `use_batchnorm` flag is truthy (their conv bias is disabled in that case);
    layers 1-4 end with LeakyReLU. The output layer has no activation, so the
    model returns raw logits.

    Output grid size is (H/8 - 3) x (W/8 - 3); for a 256x256 input that is
    29x29x1.

    Args:
        input_shape: shape of the source image, (height, width, channels).
        target_shape: shape of the target/generated image.
        name: name given to the returned Model.

    Returns:
        A tf.keras Model with inputs [input_image, target_image] and the
        patch-logit grid as output.
    '''
    logger.info(f"Building Discriminator with input shape {input_shape}, target shape {target_shape}, aiming for 70x70 RF")
    initializer = tf.random_normal_initializer(0., 0.02)

    def _sn_conv(n_filters, stride, pad_mode, with_bias, conv_name):
        # 4x4 convolution wrapped in SpectralNormalization.
        return SpectralNormalization(layers.Conv2D(
            n_filters, 4, strides=stride, padding=pad_mode,
            kernel_initializer=initializer, use_bias=with_bias,
            name=conv_name))

    inp = Input(shape=input_shape, name='input_image')
    tar = Input(shape=target_shape, name='target_image')

    # Source image and target/generated image stacked on the channel axis.
    features = layers.Concatenate()([inp, tar])
    logger.debug(f"Discriminator input shape after concat: {features.shape}")

    # Layer 1 (C64): no normalization, so the bias is kept.
    features = _sn_conv(64, 2, 'same', True, "disc_conv_1")(features)
    features = layers.LeakyReLU(name="disc_leaky_1")(features)
    logger.debug(f"Discriminator Layer 1 SN output shape: {features.shape}")

    # Layers 2 (C128) and 3 (C256): both stride 2; stride 2 in layer 3 is what
    # lifts the receptive field to 70x70.
    for stage, n_filters in ((2, 128), (3, 256)):
        features = _sn_conv(n_filters, 2, 'same', not use_batchnorm,
                            f"disc_conv_{stage}")(features)
        if use_batchnorm:
            features = layers.BatchNormalization(name=f"disc_norm_{stage}")(features)
        features = layers.LeakyReLU(name=f"disc_leaky_{stage}")(features)
        logger.debug(f"Discriminator Layer {stage} SN output shape: {features.shape}")

    # Layer 4 (C512): stride 1 with 'same' padding keeps the H/8 x W/8 grid.
    features = _sn_conv(512, 1, 'same', not use_batchnorm, "disc_conv_4")(features)
    logger.debug(f"Discriminator Conv4 SN output shape: {features.shape}")
    if use_batchnorm:
        features = layers.BatchNormalization(name="disc_norm_4")(features)
    features = layers.LeakyReLU(name="disc_leaky_4")(features)

    # Output layer: one logit per patch; 'valid' padding trims 3 rows/columns.
    patch_logits = _sn_conv(1, 1, 'valid', True, "disc_output_conv")(features)
    logger.debug(f"Discriminator Final SN Output shape: {patch_logits.shape}")

    return Model(inputs=[inp, tar], outputs=patch_logits, name=name)

## Loss Functions

#### Explanation

Let's define the loss functions for our baseline pix2pix model. We need an adversarial loss to make the diagrams look realistic (like real diagrams) and a reconstruction loss to ensure the generated diagram matches the content of the input photo.

Considering our goal (generating 2D structural diagrams) and the need for a stable, standard baseline:

1. **Adversarial Loss: LSGAN (Least Squares GAN)**

- **Why LSGAN for baseline?** Compared to Binary Cross-Entropy (BCE), LSGAN is often found to be more stable during training and less prone to vanishing gradients, especially early on. It provides smoother gradients and can lead to higher quality results. For a baseline, stability is valuable.

- **How it works:** Instead of classifying patches as 0 (fake) or 1 (real) with a sigmoid cross-entropy, it uses a Mean Squared Error (MSE) loss. The discriminator tries to make its output close to 1 for real pairs and 0 for fake pairs (generated). The generator tries to make the discriminator output 1 for its fake pairs.

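- **Implementation:** We'll use `tf.keras.losses.MeanSquaredError`. (Note: in the code cell below, `discriminator_loss` currently uses `BinaryCrossentropy(from_logits=True)` instead, while the generator's adversarial term still uses MSE.)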

  - **Discriminator Loss** (`L_D`):
    - Wants `D(real_photo, real_diagram)` output to be close to 1.
    - Wants `D(real_photo, generated_diagram)` output to be close to 0.
    - `loss_D_real = mse(D(real_photo, real_diagram), tf.ones_like(D_output))`
    - `loss_D_fake = mse(D(real_photo, generated_diagram), tf.zeros_like(D_output))`  
    - `L_D = 0.5 * (loss_D_real + loss_D_fake)` (The 0.5 scaling is common but optional)
  - **Generator Adversarial Loss** (`L_G_adversarial`):
    - Wants `D(real_photo, generated_diagram)` output to be close to 1 (to fool the discriminator).
    - `L_G_adversarial = mse(D(real_photo, generated_diagram), tf.ones_like(D_output))`
2. **Reconstruction Loss: L1 Loss (Mean Absolute Error)**

- **Why L1 for baseline?** While we note concerns about potential blurriness (which we can address later by adding Edge loss or tuning), L1 is the standard reconstruction loss used in the original pix2pix paper.
  - It provides strong structural guidance, forcing the generator's output pixels to be close to the target diagram's pixels.
  - It generally produces less blurry results than L2 (MSE) loss.
  - It establishes a standard baseline performance metric before introducing more specialized losses.
- **How it works:** Calculates the average absolute difference between each pixel in the generated diagram and the real target diagram. `Mean(|real_diagram - generated_diagram|)`.
- **Implementation:** We'll use `tf.keras.losses.MeanAbsoluteError` or `tf.reduce_mean(tf.abs(...))`.
- **Weight** (`LAMBDA_L1`): This hyperparameter balances the adversarial loss and the L1 loss. The original paper found `LAMBDA_L1 = 100` worked well, heavily emphasizing reconstruction accuracy. We will start with this for the baseline.
    - `L_L1 = LAMBDA_L1 * mae(real_diagram, generated_diagram)`
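3. **Total Generator Loss**
- The generator aims to minimize both its adversarial loss and the reconstruction loss.
- `L_G_Total = L_G_adversarial + L_L1`
- `L_G_Total = mse(D(fake), 1.0) + LAMBDA_L1 * mae(target, generated)`
- Note: the `generator_loss` code below additionally adds an edge term (`LAMBDA_EDGE * edge_loss`) and a perceptual term (`LAMBDA_P * perceptual_loss`) to this total.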

#### Code for loss functions

In [None]:
# === EDGE LOSS FUNCTION ===
logger.info("Defining Edge Loss function...")

@tf.function
def calculate_edge_loss(target, generated):
    """
    Mean absolute difference between the Sobel gradients of two image batches.

    Both batches are shifted from [-1, 1] to [0, 1], converted to grayscale,
    run through `tf.image.sobel_edges`, cast to float32 and compared with an
    L1 (mean absolute) distance.

    Args:
        target: ground-truth image tensor, shape [batch, H, W, C], range [-1, 1].
        generated: generator output tensor, shape [batch, H, W, C], range [-1, 1].

    Returns:
        A scalar float32 tensor: the mean absolute error between the gradients.
    """
    def _grayscale(img):
        # rgb_to_grayscale is fed images rescaled from [-1, 1] to [0, 1].
        return tf.image.rgb_to_grayscale((img + 1.0) / 2.0)

    def _sobel_f32(gray):
        # sobel_edges yields [batch, H, W, 1, 2] (y- and x-gradients).
        return tf.image.sobel_edges(gray)

    target_gray = _grayscale(target)
    generated_gray = _grayscale(generated)
    # logger calls inside a tf.function only run while the function is traced.
    logger.debug(f"Edge Loss - Grayscale shapes: Target {target_gray.shape}, Generated {generated_gray.shape}")

    target_edges = _sobel_f32(target_gray)
    generated_edges = _sobel_f32(generated_gray)
    logger.debug(f"Edge Loss - Sobel shapes: Target {target_edges.shape}, Generated {generated_edges.shape}")

    # Compare in float32 regardless of the incoming dtype.
    gradient_gap = tf.cast(target_edges, tf.float32) - tf.cast(generated_edges, tf.float32)
    return tf.reduce_mean(tf.abs(gradient_gap))

logger.info("Edge Loss function defined.")

In [None]:
@tf.function # Decorate for potential performance improvement
def multi_layer_perceptual_loss(y_true, y_pred):
    """
    Weighted multi-layer perceptual loss computed on `vgg_loss_model` features.

    Both inputs are rescaled from [-1, 1] to [0, 255], expanded to three
    channels when `y_true` is single-channel, passed through
    `preprocess_input` and then through the notebook-level `vgg_loss_model`.
    For each index in `range(len(LOSS_LAYERS))` the mean squared difference of
    the corresponding feature maps is multiplied by `LAYER_WEIGHTS[i]` and
    accumulated.

    Args:
        y_true: ground-truth image tensor, shape [batch, H, W, C], range [-1, 1].
        y_pred: predicted image tensor, shape [batch, H, W, C], range [-1, 1].

    Returns:
        A scalar tensor (same dtype as the first feature map) holding the
        weighted sum of per-layer MSE values.
    """
    # The grayscale decision is taken from y_true and applied to both inputs.
    needs_rgb_expand = y_true.shape[-1] == 1

    def _prepare(img):
        # [-1, 1] -> [0, 255], optional channel repeat, then VGG preprocessing.
        img_0_255 = (img + 1.0) * 127.5
        if needs_rgb_expand:
            img_0_255 = tf.repeat(img_0_255, 3, axis=-1)
        return preprocess_input(img_0_255)

    true_feats = vgg_loss_model(_prepare(y_true))
    pred_feats = vgg_loss_model(_prepare(y_pred))

    # A single-output model may hand back a bare tensor instead of a list.
    if not isinstance(true_feats, list):
        true_feats, pred_feats = [true_feats], [pred_feats]

    weighted_sum = tf.constant(0.0, dtype=true_feats[0].dtype)
    for idx in range(len(LOSS_LAYERS)):
        feature_gap = true_feats[idx] - pred_feats[idx]
        weighted_sum += LAYER_WEIGHTS[idx] * tf.math.reduce_mean(tf.math.square(feature_gap))

    return weighted_sum


In [None]:
# === LOSS FUNCTIONS ===
logger.info("Defining loss functions...")

# Adversarial loss objects:
#   bce - sigmoid cross-entropy on raw logits; used by discriminator_loss.
#   mse - least-squares (LSGAN-style) loss; used by generator_loss.
bce = tf.keras.losses.BinaryCrossentropy(from_logits=True) # expects raw discriminator logits
mse = tf.keras.losses.MeanSquaredError()

# Use MeanAbsoluteError for L1 reconstruction loss
mae = tf.keras.losses.MeanAbsoluteError()

def discriminator_loss(disc_real_output, disc_generated_output):
  """Binary cross-entropy discriminator loss on raw logits.

  Args:
      disc_real_output: discriminator logits for (input, real target) pairs.
      disc_generated_output: discriminator logits for (input, generated) pairs.

  Returns:
      Scalar tensor: BCE of the real logits against all-ones labels plus BCE
      of the generated logits against all-zeros labels (unscaled sum).
  """
  # Real pairs should be classified as 1, generated pairs as 0.
  loss_on_real = bce(tf.ones_like(disc_real_output), disc_real_output)
  loss_on_fake = bce(tf.zeros_like(disc_generated_output), disc_generated_output)
  return loss_on_real + loss_on_fake

def generator_loss(disc_generated_output, gen_output, target):
  """Calculates the total generator loss and its individual components.

  total = adversarial + LAMBDA_L1 * L1 + LAMBDA_EDGE * edge + LAMBDA_P * perceptual

  The adversarial term is the MSE between the discriminator's output for the
  generated pair and an all-ones target. Every component and every weight is
  cast to the notebook-level `bf16` dtype before the weighted sum so that all
  operands share one dtype.

  Args:
      disc_generated_output: discriminator output for (input, generated) pairs.
      gen_output: generator output batch.
      target: ground-truth target batch.

  Returns:
      Tuple (total_gen_loss, gan_loss, l1_loss, edge_loss_value,
      gen_perceptual_loss); the components are returned for logging.
  """
  # Adversarial term: push the discriminator output for fakes towards 1.
  # NOTE(review): discriminator_loss trains on BCE-from-logits while this term
  # is least-squares on the same logits - confirm the mixed objective is intended.
  gan_loss = mse(tf.ones_like(disc_generated_output), disc_generated_output)
  gan_loss = tf.cast(gan_loss,  dtype=bf16)
  # L1 loss (Mean Absolute Error) between generated and target images
  l1_loss = mae(target, gen_output)
  l1_loss = tf.cast(l1_loss,  dtype=bf16)
  # Perceptual loss; cast like the other terms so the weighted sum below does
  # not mix dtypes (no-op when the value is already bf16).
  gen_perceptual_loss = multi_layer_perceptual_loss(target, gen_output)
  gen_perceptual_loss = tf.cast(gen_perceptual_loss, dtype=bf16)
  # Edge loss (calculate_edge_loss returns float32, hence the cast)
  edge_loss_value = calculate_edge_loss(target, gen_output)
  edge_loss_value = tf.cast(edge_loss_value, dtype=bf16)
  lambda_l1 = tf.cast(LAMBDA_L1,  dtype=bf16)
  lambda_edge = tf.cast(LAMBDA_EDGE,  dtype=bf16)
  lambda_p = tf.cast(LAMBDA_P,  dtype=bf16)
  # --- Combine losses ---
  total_gen_loss = gan_loss + (lambda_l1 * l1_loss) + (lambda_edge * edge_loss_value) + (lambda_p * gen_perceptual_loss)

  return total_gen_loss, gan_loss, l1_loss, edge_loss_value, gen_perceptual_loss # Return components for logging

# NOTE(review): this message still says "LSGAN (MSE-based)" although
# discriminator_loss uses BCE and the perceptual term is not mentioned; the
# string is left unchanged here.
logger.info(f"Loss functions defined: LSGAN (MSE-based) + L1 (MAE-based) with LAMBDA_L1 = {LAMBDA_L1} and Edge Loss with LAMBDA_EDGE = {LAMBDA_EDGE}")

# Custom Learning Rate Schedule Class

In [None]:
# Redefine the class with tf.print for debugging
# class Pix2PixSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
#     """
#     Custom LR Schedule: Constant initial rate then linear decay to zero.
#     Includes tf.print statements for debugging LR calculation.
#     """
#     def __init__(self, initial_learning_rate, decay_start_step, decay_steps, name="Pix2PixSchedule"):
#         super().__init__()
#         self.initial_learning_rate = tf.cast(initial_learning_rate, tf.float32)
#         self.decay_start_step = tf.cast(decay_start_step, tf.float32)
#         self.decay_steps = tf.cast(decay_steps, tf.float32)
#         self.name = name
#         # Log parameters during init
#         tf.print("Pix2PixSchedule Initialized:")
#         tf.print("  Initial LR:", self.initial_learning_rate)
#         tf.print("  Decay Start Step:", self.decay_start_step)
#         tf.print("  Decay Steps:", self.decay_steps)


#     def __call__(self, step):
#         """Calculates the learning rate for a given step."""
#         with tf.name_scope(self.name or "Pix2PixSchedule"):
#             step_float = tf.cast(step, tf.float32)
#             # --- Debug Prints ---
#             tf.print("Pix2PixSchedule Call - Input Step:", step_float, output_stream=sys.stdout)

#             cond = tf.less(step_float, self.decay_start_step)
#             # tf.print("Pix2PixSchedule Call - Decay Condition (<", self.decay_start_step, "?):", cond, output_stream=sys.stdout)

#             lr_if_true = self.initial_learning_rate

#             progress_in_decay = (step_float - self.decay_start_step) / self.decay_steps
#             decay_factor = 1.0 - progress_in_decay
#             decay_factor = tf.maximum(0.0, decay_factor) # Clamp at 0
#             lr_if_false = self.initial_learning_rate * decay_factor
#             # --- Debug Prints ---
#             tf.print("Pix2PixSchedule Call - Decay Factor:", decay_factor, output_stream=sys.stdout)

#             learning_rate = tf.where(cond, lr_if_true, lr_if_false)
#             # --- Debug Print ---
#             tf.print("Pix2PixSchedule Call - Returned LR:", learning_rate, output_stream=sys.stdout)
#             return learning_rate

#     def get_config(self):
#         """Returns the configuration of the schedule."""
#         return {
#             "initial_learning_rate": float(self.initial_learning_rate.numpy()),
#             "decay_start_step": int(self.decay_start_step.numpy()),
#             "decay_steps": int(self.decay_steps.numpy()),
#             "name": self.name
#         }

# --- Re-instantiate the schedule AFTER defining the modified class ---
# logger.info("Re-instantiating schedule with debug prints...")
# lr_schedule = Pix2PixSchedule(initial_learning_rate, decay_start_step, total_decay_steps)

# --- IMPORTANT: Re-instantiate the optimizers to use the new schedule instance ---
# logger.info("Re-defining optimizers with the new debug schedule instance...")
# try:
#     generator_optimizer = tf.keras.optimizers.Adam(
#         learning_rate=lr_schedule, beta_1=BETA_1, beta_2=BETA_2, name='generator_adam'
#     )
#     discriminator_optimizer = tf.keras.optimizers.Adam(
#         learning_rate=lr_schedule, beta_1=BETA_1, beta_2=BETA_2, name='discriminator_adam'
#     )
#     logger.info("Optimizers re-defined with debug schedule.")
# except Exception as e:
#     logger.error(f"Error re-defining optimizers with schedule: {e}", exc_info=True)
#     raise

In [None]:
# # Instantiate the custom learning rate schedule
# lr_schedule = Pix2PixSchedule(initial_learning_rate, decay_start_step, total_decay_steps)
# logger.info("Pix2PixSchedule instantiated.")

# Optimizers

### Explanation:


1. **Optimizer Choice:** We are using the Adam optimizer (`tf.keras.optimizers.Adam`) for both the Generator and the Discriminator. This is a standard and effective choice for GANs, including the original pix2pix implementation, known for its adaptive learning rate capabilities.
2. **Separate Optimizers:** It's crucial to define two separate instances of the Adam optimizer (`generator_optimizer` and `discriminator_optimizer`). The Generator and Discriminator are distinct networks with their own sets of trainable parameters and are trained based on different loss functions (they have competing objectives). Therefore, they need independent optimizers to track their respective gradients and update their weights correctly.
3. **Hyperparameters:**
- `LEARNING_RATE = 0.0002`: This is a commonly used starting learning rate for Adam in GAN training and often provides a good balance between convergence speed and stability.
- `BETA_1 = 0.5`: Adam uses beta parameters to control the exponential decay rates for its moving averages of gradients (moments). While the default for `beta_1` is often `0.9`, setting it to `0.5` is a common practice specifically for GANs. A lower `beta_1` gives less weight to past gradients (less momentum), which can help stabilize the sometimes oscillatory training dynamics between the generator and discriminator.
- `BETA_2 = 0.999`: The default value for the second moment estimate decay rate is typically used.
4. **Instantiation:** The code creates the two optimizer instances by calling `tf.keras.optimizers.Adam(...)` and passing the specified hyperparameters.
5. **`AdamW` Alternative**: `AdamW` is a potential alternative (available in newer TensorFlow versions or TensorFlow Addons). `AdamW` implements weight decay more effectively than standard Adam combined with L2 regularization and is often preferred if we add significant weight decay later for regularization. For this baseline without explicit weight decay added here, standard Adam is fine.
6. **Logging and Error Handling**: An informational message is logged confirming the optimizer configuration. The `try...except` block ensures that any errors during optimizer instantiation are caught, logged with details (`exc_info=True`), and execution is stopped (`raise`).

### Code:

In [None]:
# The opening log message previously announced a custom LR schedule and loss
# scaling, neither of which this cell sets up; it now describes what is built.
logger.info("Defining Adam optimizers with a fixed learning rate (no LR schedule, no loss scaling)...")

try:
    # Two independent Adam instances: the generator and the discriminator have
    # separate weights and competing objectives, so each needs its own
    # optimizer state.
    # NOTE(review): the optimizers are not wrapped in
    # mixed_precision.LossScaleOptimizer. Presumably training runs in bfloat16
    # (generator_loss casts to `bf16`), which normally does not need loss
    # scaling - confirm against the mixed-precision policy set elsewhere.
    generator_optimizer = tf.keras.optimizers.Adam(
        learning_rate=LEARNING_RATE, # Or an lr_schedule object
        beta_1=BETA_1, beta_2=BETA_2, name='base_generator_adam'
    )
    discriminator_optimizer = tf.keras.optimizers.Adam(
        learning_rate=LEARNING_RATE, # Or a schedule, possibly a different rate
        beta_1=BETA_1, beta_2=BETA_2, name='base_discriminator_adam'
    )
    logger.info(f"Optimizers defined: Adam with INITIAL LR={LEARNING_RATE}, beta_1={BETA_1}, beta_2={BETA_2}")
    # Note: Consider AdamW later if adding significant weight decay.
    # e.g., use tf.keras.optimizers.AdamW(...) if using TF 2.11+ or tfa.optimizers.AdamW otherwise

except Exception as e:
    logger.error(f"Error defining/wrapping optimizers: {e}", exc_info=True)
    raise # Stop execution if optimizers can't be created

# Metrics:

## VGG19 Perceptual loss

In [None]:
# --- VGG19 feature extractor factory ---
def build_vgg_feature_extractor(layer_names, input_shape=(None, None, 3)):
    """Build a frozen VGG19 model that outputs the requested layers' activations.

    Args:
        layer_names: names of the VGG19 layers whose outputs are returned.
        input_shape: input shape passed to VGG19 (height, width, channels).

    Returns:
        A tf.keras.Model named 'vgg_feature_extractor' mapping the VGG19 input
        to the list of selected layer outputs.

    Raises:
        Exception: any error raised while building the model is logged and
            re-raised unchanged.
    """
    try:
        # ImageNet weights are requested; Keras may need to download them on first use.
        backbone = vgg19.VGG19(include_top=False, weights='imagenet', input_shape=input_shape)
        backbone.trainable = False  # frozen: the network is only used to extract features

        selected_outputs = []
        for layer_name in layer_names:
            selected_outputs.append(backbone.get_layer(layer_name).output)

        extractor = tf.keras.Model(backbone.input, selected_outputs, name='vgg_feature_extractor')
        logger.info(f"Built VGG19 feature extractor with layers: {layer_names}")
        return extractor
    except Exception as exc:
        logger.error(f"Failed to build VGG feature extractor: {exc}", exc_info=True)
        logger.error("Ensure internet connection for downloading VGG weights if needed.")
        raise

In [None]:
class PerceptualMetric(tf.keras.metrics.Metric):
    """Running mean of a VGG19 feature-space distance between target and generated images.

    Each call to `update_state` pushes both image batches through a frozen VGG19
    feature extractor, averages the per-layer distance (L1 or MSE) over the selected
    layers, and accumulates it weighted by the batch size. `result` returns the mean
    over all samples seen since the last `reset_state`.

    Args:
        name: Metric name.
        layer_names: VGG19 layer name, or sequence of layer names, to compare features at.
        input_shape: (H, W, C) shape the feature extractor is built for; it must match
            the images passed to `update_state`.
        distance_metric: 'l1' for mean absolute difference; any other value selects
            mean squared difference.
        **kwargs: Forwarded to `tf.keras.metrics.Metric`.

    Raises:
        ValueError: If `layer_names` is empty.
    """

    def __init__(self, name='perceptual_metric',
                 layer_names=METRIC_LAYER_NAMES,
                 input_shape=(256, 256, 3), # MUST Match your drawing H, W
                 distance_metric='mse', **kwargs):
        super().__init__(name=name, **kwargs)

        # A bare string would otherwise be iterated character by character.
        if isinstance(layer_names, str):
            layer_names = [layer_names]
        layer_names = list(layer_names)
        # A config round-trip may hand the shape back as a list; normalise to a tuple.
        input_shape = tuple(input_shape)
        if not layer_names:
            raise ValueError("PerceptualMetric requires at least one VGG19 layer name.")

        # Keep the constructor arguments so get_config() can return exactly what was
        # passed in, instead of reverse-engineering them from the built model.
        self._layer_names = tuple(layer_names)
        self._extractor_input_shape = input_shape
        self.distance_metric = distance_metric

        # Build the feature extractor within the metric instance and keep it frozen.
        self.feature_extractor = build_vgg_feature_extractor(layer_names, input_shape)
        self.feature_extractor.trainable = False

        # Internal state using add_weight for proper TF metric handling.
        self.total_distance = self.add_weight(name='total_distance', initializer='zeros')
        # Count is kept as float so it can divide total_distance directly.
        self.count = self.add_weight(name='count', initializer='zeros', dtype=tf.float32)
        logger.info(f"PerceptualMetric '{name}' initialized with shape {input_shape}, layers: {layer_names}, metric: {distance_metric}")


    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the perceptual distance for one batch.

        Args:
            y_true: Ground truth drawings (batch, H, W, C), range [0, 1]
            y_pred: Generated drawings (batch, H, W, C), range [0, 1]
            sample_weight: Accepted for API compatibility; not used.
        """
        gt_drawing = tf.cast(y_true, tf.float32)
        generated_drawing = tf.cast(y_pred, tf.float32)
        batch_size = tf.cast(tf.shape(gt_drawing)[0], tf.float32)

        # --- Preprocessing ---
        # 1. VGG19 needs 3 channels: replicate single-channel images.
        if generated_drawing.shape[-1] == 1:
            generated_drawing = tf.image.grayscale_to_rgb(generated_drawing)
        if gt_drawing.shape[-1] == 1:
            gt_drawing = tf.image.grayscale_to_rgb(gt_drawing)

        # 2. Clip to [0, 1] (values may slightly overshoot after denormalization), then
        #    scale to [0, 255], the range the VGG preprocessing expects.
        generated_drawing_0_1 = tf.clip_by_value(generated_drawing, 0.0, 1.0)
        gt_drawing_0_1 = tf.clip_by_value(gt_drawing, 0.0, 1.0)

        gen_prep = vgg19.preprocess_input(generated_drawing_0_1 * 255.0)
        gt_prep = vgg19.preprocess_input(gt_drawing_0_1 * 255.0)

        # --- Feature Extraction ---
        gen_features = self.feature_extractor(gen_prep, training=False)
        gt_features = self.feature_extractor(gt_prep, training=False)

        # Normalise to a list: a single-output model may return a bare tensor, and
        # multi-output models may return either a list or a tuple.
        if isinstance(gen_features, (list, tuple)):
            gen_features = list(gen_features)
            gt_features = list(gt_features)
        else:
            gen_features = [gen_features]
            gt_features = [gt_features]

        # --- Calculate Distance ---
        num_layers = len(gen_features)
        if num_layers == 0:
            # Nothing to compare; leave the state untouched rather than divide by zero.
            return

        batch_distance = tf.constant(0.0, dtype=tf.float32)
        for gen_feat, gt_feat in zip(gen_features, gt_features):
            if self.distance_metric == 'l1':
                layer_dist = tf.reduce_mean(tf.abs(gen_feat - gt_feat))
            else:  # mse
                layer_dist = tf.reduce_mean(tf.square(gen_feat - gt_feat))
            batch_distance += tf.cast(layer_dist, tf.float32)

        # Average distance across the layers for the batch.
        avg_batch_distance = batch_distance / tf.cast(num_layers, tf.float32)

        # Weight by batch size so batches of different sizes contribute proportionally
        # to the final per-sample mean.
        self.total_distance.assign_add(avg_batch_distance * batch_size)
        self.count.assign_add(batch_size)


    def result(self):
        """Return the mean distance over all samples seen (0 when no samples were seen)."""
        return tf.math.divide_no_nan(self.total_distance, self.count)

    def reset_state(self):
        """Reset the accumulated distance and sample count to zero."""
        self.total_distance.assign(0.0)
        self.count.assign(0.0)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this metric."""
        config = super().get_config()
        config.update({
            'layer_names': list(self._layer_names),
            'input_shape': tuple(self._extractor_input_shape),
            'distance_metric': self.distance_metric
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the metric from `get_config()` output.

        Requires `build_vgg_feature_extractor` to be available, since the constructor
        rebuilds the VGG19 feature extractor.
        """
        return cls(**config)



# Model Instantiation

In [None]:
# --- Instantiate the models and run a quick sanity check ---
def log_model_summary(model):
    """Send a Keras model summary to the notebook logger (one log record per line)."""
    model.summary(print_fn=logger.info)


try:
    generator = build_generator()
    discriminator = build_discriminator()
    logger.info("Generator and Discriminator models built successfully.")

    # --- Log model summaries ---
    logger.info("--- Generator Summary ---")
    log_model_summary(generator)
    logger.info("--- Discriminator Summary ---")
    log_model_summary(discriminator)

    # --- Sanity check with dummy data ---
    logger.info("Testing model outputs with dummy data...")
    dummy_input_batch = tf.ones([BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3])
    # Stand-in for a target batch; used below as the expected generator output shape.
    dummy_target_batch = tf.ones([BATCH_SIZE, IMG_HEIGHT, IMG_WIDTH, 3])

    gen_output = generator(dummy_input_batch, training=False)
    disc_output = discriminator([dummy_input_batch, gen_output], training=False)

    logger.info(f" Dummy Generator output batch shape: {gen_output.shape}") # Expect (BATCH_SIZE, 256, 256, 3)
    logger.info(f" Dummy Discriminator output batch shape: {disc_output.shape}") # Expect (BATCH_SIZE, 30, 30, 1)

    # Convert to plain Python floats so the values format cleanly whatever the tensor dtype.
    min_gen = float(tf.reduce_min(gen_output).numpy())
    max_gen = float(tf.reduce_max(gen_output).numpy())
    min_disc = float(tf.reduce_min(disc_output).numpy())
    max_disc = float(tf.reduce_max(disc_output).numpy())

    # Check output range for generator (should be roughly [-1, 1] due to tanh)
    logger.info(f"Dummy Generator output min/max: {min_gen:.2f}/{max_gen:.2f}")
    # Discriminator output is logits, so range can vary
    logger.info(f"Dummy Discriminator output min/max: {min_disc:.2f}/{max_disc:.2f}")

    # Only report success when the checks actually hold: the generator must produce a
    # batch shaped like the target batch, and neither model may emit NaN/Inf.
    shape_ok = gen_output.shape == dummy_target_batch.shape
    values_ok = bool(np.all(np.isfinite([min_gen, max_gen, min_disc, max_disc])))
    if shape_ok and values_ok:
        logger.info("Model output shape and basic range tests passed.")
    else:
        logger.warning(
            f"Model sanity check found problems: generator output shape {gen_output.shape} "
            f"(expected {dummy_target_batch.shape}), all outputs finite: {values_ok}"
        )

except Exception as e:
    logger.error(f"Error building or testing models: {e}", exc_info=True)
    # Re-raise the exception to stop execution if model building fails
    raise

In [None]:
# Report the TensorFlow version and the devices TensorFlow can see.
# Uses the notebook's `logger` (like every other cell) rather than the root `logging`
# module, so these messages follow the same handlers and level.
logger.info(f"TensorFlow version: {tf.__version__}")

gpu_devices = tf.config.list_physical_devices('GPU')
if gpu_devices:
    logger.info(f"GPU detected: {gpu_devices}")
    # Optional: enable memory growth so TF does not allocate all GPU memory upfront.
    try:
        for gpu in gpu_devices:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized; the models are
        # built in an earlier cell, so this can legitimately fail here.
        logger.error(e)
    logical_gpus = tf.config.list_logical_devices('GPU')
    logger.info(f"{len(gpu_devices)} Physical GPUs, {len(logical_gpus)} Logical GPUs")
    # Usually computation happens on '/GPU:0'
else:
    logger.info("No GPU found. Using CPU.")

# Also print the device inventory to the cell output.
physical_devices = tf.config.list_physical_devices()
print("All Physical Devices:", physical_devices)

if gpu_devices:
    print("GPUs Available:", gpu_devices)
else:
    print("No GPUs found. Using CPU.")


# Checkpointing

In [None]:
# Shell out to list the current contents of the checkpoint directory;
# {checkpoint_dir} is interpolated from the Python variable of that name.
!ls -la "{checkpoint_dir}"

In [None]:
# === CHECKPOINTING SETUP ===
# Requires generator, discriminator, generator_optimizer, discriminator_optimizer and
# checkpoint_dir to exist already.
logger.info("Setting up checkpointing...")

# Track everything needed to resume training: both models, both optimizers and an
# epoch counter variable.
epoch_counter = tf.Variable(0, trainable=False, dtype=tf.int64)
ckpt = tf.train.Checkpoint(generator=generator,
                           discriminator=discriminator,
                           generator_optimizer=generator_optimizer,
                           discriminator_optimizer=discriminator_optimizer,
                           epoch=epoch_counter)

# The manager rotates checkpoints in checkpoint_dir, keeping only the latest 5.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_dir, max_to_keep=5)
logger.info("Checkpointing setup defined.")

# --- Optional restore of the latest checkpoint ---
# Default (False) keeps the existing behaviour: always start from epoch 0 and ignore
# any checkpoint on disk. Set to True to resume from ckpt_manager.latest_checkpoint.
RESTORE_LATEST_CHECKPOINT = False

start_epoch = 0
if not RESTORE_LATEST_CHECKPOINT:
    logger.info("** FORCING START FROM EPOCH 0 ** (RESTORE_LATEST_CHECKPOINT is False)")
elif ckpt_manager.latest_checkpoint:
    try:
        # expect_partial() tolerates checkpoints that do not match the objects exactly.
        ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
        # NOTE(review): assumes the training loop stores the number of completed epochs
        # in epoch_counter before each save — confirm against the training loop.
        start_epoch = int(epoch_counter.numpy())
        logger.info(f"Restored checkpoint from {ckpt_manager.latest_checkpoint}; start_epoch={start_epoch}")
    except Exception as e:
        # Best effort: a failed restore is logged and training starts from scratch.
        # NOTE(review): some objects may already have been restored when this fires.
        logger.warning(f"Could not restore checkpoint: {e}", exc_info=False)
        start_epoch = 0
else:
    logger.info("No checkpoint found at {}. Initializing from scratch.".format(checkpoint_dir))


# Note: Saving the checkpoint (ckpt_manager.save()) will happen periodically
#       INSIDE the main training loop (e.g., every N epochs).

# Training Step

In [None]:
# === TRAINING STEP FUNCTION ===
# Relies on objects defined in earlier cells: generator, discriminator, generator_loss,
# discriminator_loss, generator_optimizer, discriminator_optimizer, DISC_UPDATE, logger.

logger.info("Defining the training step function with @tf.function...")


# @tf.function compiles the Python function into a TensorFlow graph, removing Python
# overhead from the training loop. For debugging, temporarily remove the decorator to
# run eagerly, or use tf.print() (which works inside tf.function).
@tf.function
def _train_step_graph(input_image, target_image, step_counter):
    """Graph-compiled body of `train_step`; expects `step_counter` as an integer tensor."""
    # Two tapes record the same forward pass so G and D gradients can be taken
    # independently from their respective losses.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # Fail fast on NaN/Inf in the input pipeline.
        input_image = tf.debugging.check_numerics(input_image, "Input Image")
        target_image = tf.debugging.check_numerics(target_image, "Target Image")

        # --- Forward passes (training=True: BatchNorm/Dropout in training mode) ---
        gen_output = generator(input_image, training=True)
        # Real pair: input + ground-truth target.
        disc_real_output = discriminator([input_image, target_image], training=True)
        # Fake pair: same input + generator output.
        disc_generated_output = discriminator([input_image, gen_output], training=True)

        # --- Losses ---
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
        gen_loss_total, gen_gan_loss, gen_l1_loss, gen_edge_loss, gen_perceptual_loss = generator_loss(
            disc_generated_output, gen_output, target_image)

    # --- Discriminator accuracy (monitoring only, so computed outside the tapes) ---
    # Outputs are treated as logits: sigmoid -> probability, thresholded at 0.5.
    real_pred_is_real = tf.sigmoid(disc_real_output) > 0.5
    fake_pred_is_real = tf.sigmoid(disc_generated_output) > 0.5
    # Real samples are correct when predicted real; fake samples when NOT predicted real.
    disc_real_acc = tf.reduce_mean(tf.cast(real_pred_is_real, tf.float32))
    disc_fake_acc = tf.reduce_mean(tf.cast(tf.logical_not(fake_pred_is_real), tf.float32))

    # --- Gradients ---
    generator_gradients = gen_tape.gradient(gen_loss_total,
                                            generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss,
                                                 discriminator.trainable_variables)

    # --- Updates ---
    # The generator is updated on every step.
    generator_optimizer.apply_gradients(zip(generator_gradients,
                                            generator.trainable_variables))
    # The discriminator is updated only every DISC_UPDATE-th step.
    if tf.cast(step_counter, tf.int32) % DISC_UPDATE == 0:
        discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                    discriminator.trainable_variables))
        tf.print("Discriminator Updated - Step:", step_counter)

    return disc_loss, gen_loss_total, gen_gan_loss, gen_l1_loss, gen_edge_loss, disc_real_acc, disc_fake_acc, gen_perceptual_loss


def train_step(input_image, target_image, step_counter):
    """
    Performs a single training step (one batch) for the pix2pix model.

    Runs the generator and discriminator forward passes, computes both losses,
    takes gradients with two `tf.GradientTape`s, updates the generator, and updates
    the discriminator when `step_counter % DISC_UPDATE == 0`.

    `step_counter` is cast to an int64 tensor before entering the compiled graph.
    tf.function traces a new graph for every distinct Python scalar it receives, so
    passing a plain Python int per step would otherwise retrace on every call.

    Args:
        input_image: Batch of input photos (e.g., shape [BATCH_SIZE, 256, 256, 3], range [-1, 1]).
        target_image: Batch of corresponding ground truth diagrams
                      (e.g., shape [BATCH_SIZE, 256, 256, 3], range [-1, 1]).
        step_counter: Current step index (Python int, tensor or tf.Variable); decides
                      whether the discriminator is updated on this step.

    Returns:
        A tuple of scalar tensors for monitoring, in this order:
        (disc_loss, gen_loss_total, gen_gan_loss, gen_l1_loss, gen_edge_loss,
         disc_real_acc, disc_fake_acc, gen_perceptual_loss)
    """
    return _train_step_graph(input_image, target_image, tf.cast(step_counter, tf.int64))


logger.info("Training step function defined and wrapped with @tf.function.")

## Test Training

In [None]:
# --- Test the train_step function ---
# Runs train_step once on one real batch from the input pipeline, under the TF profiler,
# as a sanity check before the long training loop. It verifies that:
#   a) the function runs without crashing (model connections, math ops work);
#   b) it returns loss values (loss calculations work).
# Note: train_step applies gradients, so this test performs one real weight update.
logger.info("Testing one training step function call...")
try:
    # ——— Make sure no trace is running from before ———
    tf.summary.trace_off()

    ran_test_batch = False
    # Assumes 'train_dataset' exists and yields (input, target) tuples.
    for test_input, test_target in train_dataset.take(1):
        ran_test_batch = True
        logger.info("Fetched one batch for testing train_step.")

        # ——— Start the Profiler ———
        tf.profiler.experimental.start(log_dir)
        logger.info("Profiler started, running train_step…")
        try:
            disc_loss, gen_loss_total, gen_gan_loss, gen_l1_loss, gen_edge_loss, disc_real_acc, disc_fake_acc, gen_perceptual_loss = train_step(test_input, test_target, 0)
        finally:
            # ——— Stop the Profiler ———
            # Always stop it, even if train_step raised; a profiler left running would
            # make the next profiler start fail.
            tf.profiler.experimental.stop()
        logger.info("Profiler stopped, trace written to %s", log_dir)

        logger.info("Ran one train_step successfully.")
        # 1) Extract Python floats from the returned tensors
        disc_val      = float(disc_loss.numpy())
        gen_total_val = float(gen_loss_total.numpy())
        gen_gan_val   = float(gen_gan_loss.numpy())
        gen_perc_val  = float(gen_perceptual_loss.numpy())
        l1_val        = float(gen_l1_loss.numpy())
        edge_val      = float(gen_edge_loss.numpy())

        # 2) Log them with {:.4f}
        logger.info(f"  Test Discriminator Loss:              {disc_val:.4f}")
        logger.info(f"  Test Generator Total Loss:            {gen_total_val:.4f}")
        logger.info(f"  Test Generator GAN Loss:              {gen_gan_val:.4f}")
        logger.info(f"  Test Generator Perceptual Loss:       {gen_perc_val:.4f}")
        logger.info(f"  Test Generator L1 Loss (unweighted):  {l1_val:.4f}")
        logger.info(f"  Test Generator Edge Loss (unweighted):{edge_val:.4f}")

    if ran_test_batch:
        logger.info("train_step function test completed successfully.")
    else:
        # An empty dataset would otherwise be reported as a successful test.
        logger.warning("train_dataset yielded no batches; train_step was NOT tested.")
except Exception as e:
    logger.error(f"Error testing train_step function: {e}", exc_info=True)
    # Depending on severity, you might want to stop execution if this test fails
    raise

# Neptune Initialization

In [None]:
# Hyperparameters recorded with the experiment run. Keys are the field names used in
# the tracker, so they are kept exactly as-is (including the 'lamda_p' spelling).
hparams = dict(
    # Data
    dataset_size=dataset_size,  # Get this value after listing files
    image_height=IMG_HEIGHT,
    image_width=IMG_WIDTH,
    input_channels=INPUT_CHANNELS,
    target_channels=TARGET_CHANNELS,
    batch_size=BATCH_SIZE,
    normalization_range=NORMALIZATION_RANGE,
    augmentations=AUGMENTATIONS,
    # Model
    generator_architecture=GENERATOR_ARCHITECTURE,
    discriminator_architecture=DISCRIMINATOR_ARCHITECTURE,
    generator_output_activation=GENERATOR_OUTPUT_ACTIVATION,
    normalization_layer=NORMALIZATION_LAYER,
    generator_filters_initial=GENERATOR_FILTERS_INITIAL,
    discriminator_filters_initial=DISCRIMINATOR_FILTERS_INITIAL,
    # Loss
    adversarial_loss_type=ADVERSARIAL_LOSS_TYPE,
    reconstruction_loss_type=RECONSTRUCTION_LOSS_TYPE,
    lambda_l1=LAMBDA_L1,
    lambda_edge=LAMBDA_EDGE,
    lamda_p=LAMBDA_P,
    loss_model=LOSS_MODEL,
    loss_layers=LOSS_LAYERS,
    metric_layers=METRIC_LAYER_NAMES,
    # Optimizer
    optimizer_type=OPTIMIZER_TYPE,
    learning_rate_generator=LEARNING_RATE,  # Base LR for G
    # NOTE(review): the optimizer cell builds the discriminator Adam with LEARNING_RATE;
    # confirm the multiplier below is actually applied where the LR is set.
    learning_rate_discriminator=LEARNING_RATE * DISCRIMINATOR_LR_MULTIPLIER,
    discriminator_lr_multiplier=DISCRIMINATOR_LR_MULTIPLIER,
    beta_1=BETA_1,
    beta_2=BETA_2,
    weight_decay=WEIGHT_DECAY,
    # Training
    total_epochs=TOTAL_EPOCHS,
    lr_schedule=LR_SCHEDULE,
    description=DESCRIPTION,
    # Metric
    # NOTE(review): this repeats LOSS_LAYERS (same value as 'loss_layers'); the metric's
    # own layers are logged under 'metric_layers' — confirm which was intended here.
    layer_names=LOSS_LAYERS,
)

In [None]:
# === NEPTUNE.AI SETUP ===
# Leaves `run` as an active Neptune run, or None when Neptune is unavailable or setup
# fails; later logging calls are expected to check `if run:`.
logger.info("Setting up Neptune.ai experiment tracking...")
run = None
try:
    if neptune:
        # Initialize a Neptune run
        run = neptune.init_run(
            project=NEPTUNE_PROJECT,
            api_token=NEPTUNE_API_TOKEN,
            name=RUN_NAME,
            description=DESCRIPTION
        )
        logger.info(f"Neptune run initialized: {run.get_url()}")

        # --- Log Hyperparameters ---
        # stringify_unsupported converts values Neptune cannot store natively into
        # strings. Fall back to the raw dict on clients that do not provide it.
        try:
            from neptune.utils import stringify_unsupported
        except ImportError:
            stringify_unsupported = None
        run['parameters'] = stringify_unsupported(hparams) if stringify_unsupported else hparams
        logger.info("Logged hyperparameters to Neptune.")

    else:
        logger.warning("Neptune client not installed or import failed. Skipping Neptune initialization.")

except Exception as e:
    logger.error(f"Error initializing Neptune: {e}", exc_info=True)
    # If the run was created before the failure, close it instead of abandoning a
    # live run that nothing references any more.
    if run is not None:
        try:
            run.stop()
        except Exception as stop_error:
            logger.warning(f"Could not stop partially initialized Neptune run: {stop_error}")
    run = None # Ensure run is None if init fails
    # Decide if you want to stop execution or continue without Neptune
    # raise # Uncomment to stop if Neptune is critical

# Note: Logging metrics (losses) and images will happen periodically
#       INSIDE your main training loop using calls like:
#       if run: run['train/g_loss'].append(loss_value, step=global_step)
#       if run: run['images/samples'].append(neptune.types.File.as_image(img_array))

# Training Loop

In [None]:
# Inputs to the manual learning-rate schedule implemented by get_lr() below.
initial_learning_rate = LEARNING_RATE # Learning rate used at the start of training
epochs = TOTAL_EPOCHS # Total number of training epochs
decay_start_epoch = epochs // 2 # Linear decay starts at the halfway point
decay_epochs = epochs - decay_start_epoch # Number of epochs the linear decay spans
def get_lr(epoch):
  """Calculates the learning rate for a given epoch based on linear decay."""
  if epoch < LR_EPOCH:
    return initial_learning_rate
  elif epoch < decay_start_epoch:
    return initial_learning_rate * LR_DECAY
  else:
    decay_factor = 1.0 - (epoch - decay_start_epoch) / decay_epochs
    return initial_learning_rate * LR_DECAY * max(0.0, decay_factor)

logger.info(f"Using MANUAL LR schedule: Fixed at {initial_learning_rate:.6f} for {decay_start_epoch} epochs, then linear decay for {decay_epochs} epochs.")

In [None]:
# Load the TensorBoard notebook extension and open the dashboard inline,
# reading event files from /content/logs/gradient_tape.
# NOTE(review): confirm the summary writer used during training writes under
# this directory; otherwise the dashboard will stay empty.
%load_ext tensorboard
%tensorboard --logdir /content/logs/gradient_tape --reload_interval 5

In [None]:
# Report the active Keras mixed-precision policy on stdout and in the log
# right before training starts.
active_policy_name = tf.keras.mixed_precision.global_policy().name
print(f"CONFIRM POLICY BEFORE TRAINING: {active_policy_name}")
logger.info(f"CONFIRM POLICY BEFORE TRAINING: {active_policy_name}")


# === TRAINING LOOP ===
# This section orchestrates the main training process. It assumes the following
# components have been defined and initialized in previous cells:
#   - train_dataset: A tf.data.Dataset yielding (input_image, target_image) batches.
#   - generator: The compiled tf.keras.Model for the Generator (U-Net).
#   - discriminator: The compiled tf.keras.Model for the Discriminator (PatchGAN).
#   - generator_optimizer: A tf.keras.optimizers.Optimizer for the generator.
#   - discriminator_optimizer: A tf.keras.optimizers.Optimizer for the discriminator.
#   - generator_loss: A function calculating the generator's total loss.
#   - discriminator_loss: A function calculating the discriminator's loss.
#   - train_step: A @tf.function decorated function performing one G/D update.
#   - ckpt_manager: A tf.train.CheckpointManager for saving checkpoints.
#   - run: An active Neptune run object (or None if Neptune init failed).
#   - logger: A configured Python logger object.
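#   - BATCH_SIZE, LAMBDA_L1, LAMBDA_EDGE, LAMBDA_P, LEARNING_RATE,
#     DISCRIMINATOR_LR_MULTIPLIER: Hyperparameters.
#   - CONSOLE_LOG_FREQ, IMAGE_LOG_FREQ, CHECKPOINT_SAVE_FREQ: Logging and checkpoint intervals.
#   - get_lr: Function returning the base learning rate for a given epoch.
#   - val_dataset: The dataset iterated when computing validation metrics.
#   - tb_writer: The summary writer used for TensorBoard logging.
#   - calculate_edge_loss, PerceptualMetric: Helpers used by the validation metrics.
#   - hparams: Dictionary of hyperparameters (used for logging and setting EPOCHS).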
#
# The loop iterates through epochs, and within each epoch, iterates through batches
# of data from train_dataset. It calls the train_step function for each batch,
# accumulates losses, periodically logs progress to the console/file and Neptune,
# generates and logs visual samples using a fixed validation batch, and saves
# model checkpoints.

# ——— PROFILER CONFIG ———
# Epoch indices (0-based, compared against the loop's `epoch`) that are profiled.
PROFILE_EPOCHS    = {0, 5, 10}
# While this file exists, every epoch is profiled regardless of PROFILE_EPOCHS.
PROFILE_FLAG_FILE = "/tmp/enable_profiling.flag"
# True between this cell's profiler start and stop calls.
profiling_active  = False
# ————————————————

# --- Training Configuration ---
# Get total epochs from hyperparameters logged earlier, with a default fallback
try:
    EPOCHS = hparams['total_epochs']
except NameError:
    logger.warning("hparams dictionary not found, using default EPOCHS=200")
    EPOCHS = 200
except KeyError:
    logger.warning("'total_epochs' not found in hparams, using default EPOCHS=200")
    EPOCHS = 200

# Look the schedule name up by its dictionary key ('lr_schedule'), not by the
# value of LR_SCHEDULE, and tolerate a missing hparams dict like EPOCHS above.
try:
    lr_schedule_name = hparams.get('lr_schedule', 'unknown')
except NameError:
    lr_schedule_name = 'unknown'

logger.info(f"Starting training configuration for {EPOCHS} epochs...")
# Log key parameters for visibility at the start of training section
logger.info(f"  Batch Size: {BATCH_SIZE}")
logger.info(f"  Lambda L1: {LAMBDA_L1}")
logger.info(f"  Lambda edge: {LAMBDA_EDGE}")
logger.info(f"  Lambda perceptual: {LAMBDA_P}")
logger.info(f"  Initial Learning Rate: {LEARNING_RATE}")
logger.info(f"  LR Schedule: {lr_schedule_name}")
logger.info(f"  Checkpoint Frequency: Every {CHECKPOINT_SAVE_FREQ} epochs")
logger.info(f"  Image Log Frequency: Every {IMAGE_LOG_FREQ} epochs")


# --- Prepare Fixed Batch for Visualization ---
# One small batch is pulled from the training pipeline before the loop starts and
# reused at every image-logging epoch, so generator progress can be compared on
# the same examples. Because it comes from train_dataset, the examples may differ
# between runs if the pipeline shuffles; a dedicated validation dataset would
# make them stable.
logger.info("Fetching fixed validation batch for visualization...")
fixed_val_input = fixed_val_target = None  # Stay None if fetching fails
try:
    # Cap the visualization batch at 16 examples (or BATCH_SIZE if smaller).
    val_examples_to_take = min(BATCH_SIZE, 16)
    # Re-batch so the examples come out as a single batch of that size.
    val_dataset_vis = (
        train_dataset
        .unbatch()
        .take(val_examples_to_take)
        .batch(val_examples_to_take)
    )
    fixed_val_input, fixed_val_target = next(iter(val_dataset_vis))
    logger.info(f"Fixed validation batch shapes: Input {fixed_val_input.shape}, Target {fixed_val_target.shape}")
except Exception as fetch_err:
    # Best effort: training continues, only the image logging is affected.
    logger.error(f"Could not get validation batch: {fetch_err}", exc_info=True)
    logger.warning("Proceeding without fixed validation data for image logging.")


# --- Learning Rate Schedule Parameters ---
# get_lr() reads these module-level names when it is called inside the loop, so
# they are recomputed here from EPOCHS (the value the loop actually uses).
# Note: tf.keras.optimizers.schedules would integrate directly with the
# optimizer; this manual per-epoch approach is kept for the baseline.
initial_learning_rate = LEARNING_RATE
decay_start_epoch = EPOCHS // 2
decay_epochs = EPOCHS - decay_start_epoch

# Describe all three phases implemented by get_lr() so the log is not misleading.
logger.info(
    f"LR schedule: {initial_learning_rate:.6f} until epoch {LR_EPOCH}, "
    f"then scaled by LR_DECAY={LR_DECAY} until epoch {decay_start_epoch}, "
    f"then linear decay to 0 over {decay_epochs} epochs."
)

# Reference snippet (not executed): single-scale SSIM inside a validation loop.
# Assumes generated_drawings / target_drawings are tensors in range [-1, 1].
#
#   generated_0_1 = (generated_drawings + 1.0) / 2.0   # rescale to [0, 1]
#   target_0_1 = (target_drawings + 1.0) / 2.0
#   current_ssim = tf.reduce_mean(tf.image.ssim(target_0_1, generated_0_1, max_val=1.0))
#   # Log current_ssim (e.g., using TensorBoard tf.summary.scalar)
#   print(f"Step {step}: Validation SSIM: {current_ssim:.4f}")
# --- Metrics Accumulators ---
# Every accumulator is a tf.keras.metrics.Mean: update_state() feeds one value
# per batch and result() yields the running average since the last reset_state().
logger.info("Initializing Keras metrics for epoch loss averaging...")

# Per-epoch training averages (losses and discriminator accuracies).
(
    epoch_disc_loss_avg,
    epoch_gen_loss_total_avg,
    epoch_gen_gan_loss_avg,
    epoch_gen_l1_loss_avg,
    epoch_edge_loss_avg,
    epoch_perceptual_loss_avg,
    epoch_disc_real_acc_avg,
    epoch_disc_fake_acc_avg,
    epoch_disc_overall_acc_avg,
) = (
    tf.keras.metrics.Mean(name=metric_name)
    for metric_name in (
        'epoch_disc_loss',
        'epoch_gen_loss_total',
        'epoch_gen_gan_loss',
        'epoch_gen_l1_loss',
        'epoch_edge_loss',
        'epoch_perceptual_loss',
        'epoch_disc_real_acc',
        'epoch_disc_fake_acc',
        'epoch_disc_overall_acc',
    )
)

# Validation averages, filled while iterating the validation set.
(
    val_ssim_ms_metric,
    val_edge_l1_metric,
    val_lap_var_metric,
    val_psnr_metric,
    val_perceptual_diff_metric,
    val_l1_diff_metric,
    val_edge_diff_metric,
) = (
    tf.keras.metrics.Mean(name=metric_name)
    for metric_name in (
        'val_ssim_ms',
        'val_edge_l1',
        'val_laplacian_variance',
        'val_psnr',
        'val_perceptual_diff',
        'val_l1_diff',
        'val_edge_diff',
    )
)

# >>> PerceptualMetric Initialization <<<
try:
    # Build the perceptual validation metric for 3-channel images of the
    # configured size. Other constructor options (e.g. layer_names,
    # distance_metric) keep their defaults.
    val_perceptual_metric = PerceptualMetric(
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
        name='val_perceptual',
    )
except Exception as init_exc:
    # The metric is treated as essential: log the failure, then re-raise so the
    # cell stops instead of training without it.
    logger.error(f"Failed to initialize PerceptualMetric: {init_exc}", exc_info=True)
    logger.error("Cannot proceed without perceptual metric. Check VGG build or input shape.")
    raise


# --- Determine Starting Epoch (for resuming from checkpoint) ---
# Resuming is currently disabled, so training always starts at epoch index 0.
# The disabled logic below parses the epoch from the checkpoint filename
# (e.g. 'ckpt-10'); storing the epoch as a tf.Variable inside the checkpoint
# (ckpt = tf.train.Checkpoint(..., epoch=tf.Variable(0)) and reading
# ckpt.epoch.numpy() after restore) would be more robust.
start_epoch = 0
# if ckpt_manager.latest_checkpoint:
#     try:
#         start_epoch = int(os.path.basename(ckpt_manager.latest_checkpoint).split('-')[-1])
#         logger.info(f"Checkpoint restored. Resuming training from epoch {start_epoch + 1}")
#     except ValueError:
#         logger.warning(f"Could not parse epoch number from checkpoint file {ckpt_manager.latest_checkpoint}. Starting epoch count from 0, although weights were restored.")

# Sanity output of the starting epoch, both through the logger and on stdout
# (the print is a fallback in case logging is misconfigured).
logger.critical(f"FINAL CHECK: Value of start_epoch JUST BEFORE loop starts: {start_epoch}")
print(f"DEBUG PRINT: Value of start_epoch JUST BEFORE loop: {start_epoch}")

# -------------- MAIN TRAINING LOOP --------------
first_epoch_number = start_epoch + 1  # 1-based number shown in log messages
logger.info(f"=== Starting Training Loop from Epoch {first_epoch_number} ===")
try: # Wrap the entire loop in try/finally to ensure Neptune run is stopped
    # Iterate through each epoch
    for epoch in range(start_epoch, EPOCHS):
        # Wall-clock start of this epoch, used for the duration report below.
        epoch_start_time = time.time()

        # — PROFILER CONTROL AT TOP OF EPOCH LOOP —
        # An epoch is profiled when it is listed in PROFILE_EPOCHS or while the
        # flag file exists. The profiler is only touched when the desired state
        # differs from the current one.
        should_profile = (epoch in PROFILE_EPOCHS) or os.path.exists(PROFILE_FLAG_FILE)
        if should_profile != profiling_active:
            if should_profile:
                tf.profiler.experimental.start(log_dir)
                profiling_active = True
                logger.info(f"Profiler ON at epoch {epoch}")
            else:
                try:
                    tf.profiler.experimental.stop()
                except tf.errors.UnavailableError:
                    logger.warning("Profiler.stop() called but no profiler was running.")
                profiling_active = False
                logger.info(f"Profiler OFF after epoch {epoch-1}")
        # — END PROFILER CONTROL —

        # --- Learning Rate Update ---
        # The schedule is per epoch: get_lr() gives the base rate, the generator
        # uses it directly and the discriminator scales it by
        # DISCRIMINATOR_LR_MULTIPLIER.
        base_lr_for_epoch = get_lr(epoch)
        gen_lr = base_lr_for_epoch
        disc_lr = base_lr_for_epoch * DISCRIMINATOR_LR_MULTIPLIER

        # Push the new rates into the optimizers before any step of this epoch.
        generator_optimizer.learning_rate.assign(gen_lr)
        discriminator_optimizer.learning_rate.assign(disc_lr)

        logger.info(f"Epoch {epoch + 1}/{EPOCHS} - Set LR -> G: {gen_lr:.6f}, D: {disc_lr:.6f}")
        if run:
            # Same policy as the epoch-metric logging: a Neptune failure is
            # reported but must not abort training.
            try:
                run['train/epoch_lr_generator'].append(gen_lr, step=epoch+1)
                run['train/epoch_lr_discriminator'].append(disc_lr, step=epoch+1)
            except Exception as neptune_lr_err:
                logger.error(f"Neptune logging failed for epoch {epoch+1} learning rates: {neptune_lr_err}", exc_info=False)
        # === TENSORBOARD: log learning rates ===
        with tb_writer.as_default():
            tf.summary.scalar('LearningRate/Generator', gen_lr, step=epoch+1)
            tf.summary.scalar('LearningRate/Discriminator', disc_lr, step=epoch+1)

        # --- Reset Metrics ---
        # Clear every running average so this epoch's results are not mixed with
        # the previous epoch's.
        for _metric_to_reset in (
            epoch_disc_loss_avg,
            epoch_gen_loss_total_avg,
            epoch_gen_gan_loss_avg,
            epoch_gen_l1_loss_avg,
            epoch_edge_loss_avg,
            epoch_perceptual_loss_avg,
            epoch_disc_real_acc_avg,
            epoch_disc_fake_acc_avg,
            epoch_disc_overall_acc_avg,
            val_ssim_ms_metric,
            val_edge_l1_metric,
            val_lap_var_metric,
            val_psnr_metric,
            val_l1_diff_metric,
            val_edge_diff_metric,
        ):
            _metric_to_reset.reset_state()
        # The perceptual pair is only reset when the metric object exists.
        if 'val_perceptual_metric' in locals() and val_perceptual_metric:
            val_perceptual_metric.reset_state()
            val_perceptual_diff_metric.reset_state()

        # --- Batch Loop (instrumented for profiling) ---
        # The dataset is consumed through an explicit iterator so that, while the
        # profiler is on, the batch fetch and the train step both fall inside the
        # same profiler trace scope.
        logger.info(f"Iterating through dataset batches for Epoch {epoch + 1}...")
        train_iter = iter(train_dataset)
        step = 0
        while True:
            step_start_time = time.time()

            # Only the fetch + train step sit inside the StopIteration handler;
            # exhaustion of the iterator ends the epoch.
            try:
                if profiling_active:
                    with tf.profiler.experimental.Trace('train', step_num=step, _r=1):
                        input_image, target_image = next(train_iter)
                        step_results = train_step(
                            input_image,
                            target_image,
                            generator_optimizer.iterations
                        )
                else:
                    input_image, target_image = next(train_iter)
                    step_results = train_step(
                        input_image,
                        target_image,
                        generator_optimizer.iterations
                    )
            except StopIteration:
                break

            # train_step returns eight values in this order.
            (disc_loss,
             gen_loss_total,
             gen_gan_loss,
             gen_l1_loss,
             gen_edge_loss,
             disc_real_acc,
             disc_fake_acc,
             gen_perceptual_loss) = step_results

            # --- Update Epoch Metrics ---
            epoch_disc_loss_avg.update_state(disc_loss)
            epoch_gen_loss_total_avg.update_state(gen_loss_total)
            epoch_gen_gan_loss_avg.update_state(gen_gan_loss)
            epoch_gen_l1_loss_avg.update_state(gen_l1_loss)
            epoch_edge_loss_avg.update_state(gen_edge_loss)
            epoch_disc_real_acc_avg.update_state(disc_real_acc)
            epoch_disc_fake_acc_avg.update_state(disc_fake_acc)
            # Overall accuracy is the plain mean of the real and fake accuracies.
            epoch_disc_overall_acc_avg.update_state((disc_real_acc + disc_fake_acc) / 2.0)
            epoch_perceptual_loss_avg.update_state(gen_perceptual_loss)

            # --- Periodic Console Logging ---
            # Values shown are running epoch averages; Time/Step covers only the
            # most recent step (fetch included).
            if (step + 1) % CONSOLE_LOG_FREQ == 0:
                step_time = time.time() - step_start_time
                logger.info(
                    f"  Epoch {epoch+1}, Step {step+1}: "
                    f"D Loss={epoch_disc_loss_avg.result():.4f}, "
                    f"G Total={epoch_gen_loss_total_avg.result():.4f} "
                    f"(G‑GAN={epoch_gen_gan_loss_avg.result():.4f}, "
                    f"G‑L1={epoch_gen_l1_loss_avg.result():.4f}), "
                    f"D Acc Real={epoch_disc_real_acc_avg.result():.3f}, "
                    f"D Acc Fake={epoch_disc_fake_acc_avg.result():.3f}, "
                    f"D Acc Overall={epoch_disc_overall_acc_avg.result():.3f} "
                    # This is the generator's perceptual loss, hence the "G" label.
                    f"G Perceptual={epoch_perceptual_loss_avg.result():.4f} "
                    f"Time/Step={step_time:.3f}s"
                )

            step += 1
        # --- End of Batch Loop ---
        # --- Optional: Step-wise Neptune Logging ---
        # Logging every step can create a lot of data points in Neptune; epoch
        # averages are usually sufficient. If needed, add inside the loop:
        #     global_step = epoch * estimated_steps_per_epoch + step
        #     if run:
        #         run['train/step_disc_loss'].append(disc_loss.numpy(), step=global_step)
        #         run['train/step_gen_loss_total'].append(gen_loss_total.numpy(), step=global_step)


        # --- End of Epoch Actions ---
        epoch_end_time = time.time()
        epoch_duration = epoch_end_time - epoch_start_time

        # --- Log Epoch Summary (Console/File) ---
        # One line per averaged training metric, followed by the epoch duration.
        logger.info(f"--- Epoch {epoch + 1} Summary ---")
        for summary_label, summary_metric in (
            ("Avg Discriminator Loss", epoch_disc_loss_avg),
            ("Avg Generator Total Loss", epoch_gen_loss_total_avg),
            ("Avg Generator GAN Loss", epoch_gen_gan_loss_avg),
            ("Avg Generator L1 Loss (Unweighted)", epoch_gen_l1_loss_avg),
            ("Avg Generator Edge Loss (Unweighted)", epoch_edge_loss_avg),
            ("Avg Generator Perceptual Loss (Unweighted)", epoch_perceptual_loss_avg),
            ("Avg Discriminator Real Acc", epoch_disc_real_acc_avg),
            ("Avg Discriminator Fake Acc", epoch_disc_fake_acc_avg),
            ("Avg Discriminator Overall Acc", epoch_disc_overall_acc_avg),
        ):
            logger.info(f"  {summary_label}: {summary_metric.result():.4f}")
        logger.info(f"  Epoch Duration: {epoch_duration:.2f} seconds")

        # --- Log Epoch Metrics to Neptune ---
        if run: # Check if Neptune run is active
            try:
                # Log the final average for the epoch. Use step=(epoch + 1) for Neptune plots.
                run['train/epoch_disc_loss'].append(epoch_disc_loss_avg.result().numpy(), step=epoch+1)
                run['train/epoch_gen_loss_total'].append(epoch_gen_loss_total_avg.result().numpy(), step=epoch+1)
                run['train/epoch_gen_gan_loss'].append(epoch_gen_gan_loss_avg.result().numpy(), step=epoch+1)
                # Log the average unweighted L1 loss
                run['train/epoch_gen_l1_loss'].append(epoch_gen_l1_loss_avg.result().numpy(), step=epoch+1)
                run['train/epoch_edge_loss'].append(epoch_edge_loss_avg.result().numpy(), step=epoch+1)
                run['train/epoch_disc_real_acc'].append(epoch_disc_real_acc_avg.result().numpy(), step=epoch+1)
                run['train/epoch_disc_fake_acc'].append(epoch_disc_fake_acc_avg.result().numpy(), step=epoch+1)
                run['train/epoch_disc_overall_acc'].append(epoch_disc_overall_acc_avg.result().numpy(), step=epoch+1) # Optional
                run['train/epoch_perceptual_loss'].append(epoch_perceptual_loss_avg.result().numpy(), step=epoch+1)

                # Log timing info
                run['train/epoch_duration_sec'].append(epoch_duration, step=epoch+1)
                logger.info("Logged epoch metrics to Neptune.")
            except Exception as neptune_err:
                # Log error but don't stop training if Neptune fails
                logger.error(f"Neptune logging failed for epoch {epoch+1} metrics: {neptune_err}", exc_info=False)

        # === TENSORBOARD: log all epoch‐level metrics ===
        # Deliberately outside the `if run:` block: TensorBoard scalars are
        # written every epoch whether or not a Neptune run is active.
        with tb_writer.as_default():
            # losses
            tf.summary.scalar('Loss/Discriminator',        epoch_disc_loss_avg.result(),        step=epoch+1)
            tf.summary.scalar('Loss/Generator_Total',      epoch_gen_loss_total_avg.result(),   step=epoch+1)
            tf.summary.scalar('Loss/Generator_GAN',        epoch_gen_gan_loss_avg.result(),     step=epoch+1)
            tf.summary.scalar('Loss/Generator_L1',         epoch_gen_l1_loss_avg.result(),      step=epoch+1)
            tf.summary.scalar('Loss/Generator_Edge',       epoch_edge_loss_avg.result(),        step=epoch+1)
            tf.summary.scalar('Loss/Generator_Perceptual', epoch_perceptual_loss_avg.result(),  step=epoch+1)

            # discriminator accuracies
            tf.summary.scalar('Accuracy/Disc_Real',        epoch_disc_real_acc_avg.result(),    step=epoch+1)
            tf.summary.scalar('Accuracy/Disc_Fake',        epoch_disc_fake_acc_avg.result(),    step=epoch+1)
            tf.summary.scalar('Accuracy/Disc_Overall',     epoch_disc_overall_acc_avg.result(), step=epoch+1)
        # === end TensorBoard logging ===

        # --- Validation Metrics and Image Logging (Periodic) ---
        # Runs on the first epoch, and afterwards every IMAGE_LOG_FREQ epochs
        # provided a fixed visualization batch is available.
        if ((epoch + 1) % IMAGE_LOG_FREQ == 0 and fixed_val_input is not None) or epoch == 0:
            logger.info(f"--- Starting Validation Epoch {epoch+1} ---")
            logger.info(f"Generating and logging validation images for epoch {epoch+1}...")
            # --- 1. Reset Validation Metrics ---
            val_ssim_ms_metric.reset_state()
            val_edge_l1_metric.reset_state()
            val_lap_var_metric.reset_state()
            val_psnr_metric.reset_state()
            # NOTE(review): val_l1_diff_metric and val_edge_diff_metric are reset
            # and reported but never updated in this cell; confirm whether they
            # are filled elsewhere or should be computed here.
            val_l1_diff_metric.reset_state()
            val_edge_diff_metric.reset_state()
            if 'val_perceptual_metric' in locals() and val_perceptual_metric:
                val_perceptual_metric.reset_state()
            val_perceptual_diff_metric.reset_state()

            # 3x3 Laplacian kernel for the sharpness (Laplacian variance) metric,
            # shaped [height, width, in_channels, channel_multiplier] for
            # depthwise_conv2d. Built once per validation pass, not per batch.
            laplacian_kernel_vals = [[0, 1, 0], [1, -4, 1], [0, 1, 0]]
            laplacian_kernel = tf.constant(laplacian_kernel_vals, dtype=tf.float32)
            laplacian_kernel = tf.reshape(laplacian_kernel, [3, 3, 1, 1])

            # --- 2. Iterate over the validation dataset ---
            logger.info("Calculating metrics over validation set...")
            for val_step, (val_input, val_target) in enumerate(val_dataset):
                try:
                    # Cast to float32 so the metric math does not run in a
                    # reduced-precision dtype.
                    prediction_batch = generator(val_input, training=False)
                    prediction_batch = tf.cast(prediction_batch, tf.float32)
                    val_target       = tf.cast(val_target,       tf.float32)

                    # --- De-normalize for metrics expecting [0, 1] ---
                    # Assumes targets and generator outputs are in [-1, 1];
                    # clipping guards against values slightly out of bounds.
                    target_dn = (val_target + 1) / 2.0
                    pred_dn_batch = tf.clip_by_value((prediction_batch + 1) / 2.0, 0.0, 1.0)
                    target_0_1 = tf.clip_by_value(target_dn, 0.0, 1.0)

                    # --- Compute all per-batch values first ---
                    # Nothing is accumulated unless every computation succeeded.
                    # SSIM and PSNR both use the clipped target so their inputs
                    # agree with max_val=1.0.
                    batch_ssim_ms = tf.reduce_mean(tf.image.ssim_multiscale(target_0_1, pred_dn_batch, max_val=1.0))
                    batch_psnr = tf.reduce_mean(tf.image.psnr(target_0_1, pred_dn_batch, max_val=1.0))
                    # Edge loss is fed the un-rescaled tensors (same as before).
                    batch_edge_l1 = calculate_edge_loss(val_target, prediction_batch)

                    # Laplacian variance of the grayscale prediction, per image,
                    # then averaged over the batch.
                    pred_gray_dn_batch = tf.image.rgb_to_grayscale(pred_dn_batch)
                    laplacian_images = tf.nn.depthwise_conv2d(
                        pred_gray_dn_batch, laplacian_kernel, strides=[1, 1, 1, 1], padding='VALID'
                    )
                    variance_per_image = tf.math.reduce_variance(laplacian_images, axis=[1, 2])
                    batch_lap_var = tf.reduce_mean(variance_per_image)

                    # --- Update Accumulators (once per batch each) ---
                    val_ssim_ms_metric.update_state(batch_ssim_ms)
                    val_edge_l1_metric.update_state(batch_edge_l1)
                    val_lap_var_metric.update_state(batch_lap_var)
                    val_psnr_metric.update_state(batch_psnr)

                    # >>> UPDATE Perceptual Metric <<<
                    if 'val_perceptual_metric' in locals() and val_perceptual_metric:
                        val_perceptual_metric.update_state(target_dn, pred_dn_batch)
                        # Gap between the running validation perceptual value and
                        # this epoch's average training perceptual loss.
                        # NOTE(review): this averages the gap over running
                        # partial means; computing it once after the loop may be
                        # what was intended.
                        val_perceptual_result = val_perceptual_metric.result()
                        epoch_perceptual_loss_result = epoch_perceptual_loss_avg.result()
                        val_perceptual_diff = tf.reduce_mean(tf.abs(val_perceptual_result - epoch_perceptual_loss_result))
                        val_perceptual_diff_metric.update_state(val_perceptual_diff)
                except Exception as val_batch_err:
                    # A failing batch is reported and skipped; validation continues.
                    logger.error(f"Error calculating metrics on validation batch {val_step}: {val_batch_err}", exc_info=False)

            # --- 3. Get and Log Final Averaged Metrics ---
            try:
                final_avg_ssim_ms = val_ssim_ms_metric.result()
                final_avg_edge_l1 = val_edge_l1_metric.result()
                final_avg_lap_var = val_lap_var_metric.result()
                final_avg_psnr = val_psnr_metric.result()
                final_avg_perceptual = tf.constant(0.0) # Default value if metric failed/disabled
                if 'val_perceptual_metric' in locals() and val_perceptual_metric:
                    final_avg_perceptual = val_perceptual_metric.result()
                logger.info("Finished calculating validation metrics.")

                # --- 4. Log Validation Metrics to Console/File ---
                logger.info(f"  Epoch {epoch+1} Validation Results: "
                            f"SSIM_ms={final_avg_ssim_ms:.4f}, "
                            f"EdgeL1={final_avg_edge_l1:.4f}, "
                            f"LapVar={final_avg_lap_var:.4f}, "
                            f"PSNR={final_avg_psnr:.4f}, "
                            f"Perceptual={final_avg_perceptual:.4f}," # Lower is better
                            f"EdgeL1_diff={val_edge_diff_metric.result():.4f}, "
                            f"L1_diff={val_l1_diff_metric.result():.4f}, "
                            f"Perceptual_diff={val_perceptual_diff_metric.result():.4f}"
                            )

                # --- 5. Log Validation Metrics to Neptune ---
                if run and neptune:
                    try:
                        run['val/epoch_ssim_ms'].append(final_avg_ssim_ms.numpy(), step=epoch+1)
                        run['val/epoch_edge_l1'].append(final_avg_edge_l1.numpy(), step=epoch+1)
                        run['val/epoch_laplacian_variance'].append(final_avg_lap_var.numpy(), step=epoch+1)
                        run['val/epoch_psnr'].append(final_avg_psnr.numpy(), step=epoch+1)
                        run['val/epoch_edge_l1_diff'].append(val_edge_diff_metric.result().numpy(), step=epoch+1)
                        run['val/epoch_l1_diff'].append(val_l1_diff_metric.result().numpy(), step=epoch+1)
                        run['val/epoch_perceptual_diff'].append(val_perceptual_diff_metric.result().numpy(), step=epoch+1)
                        # The perceptual value is only logged when the metric exists.
                        if 'val_perceptual_metric' in locals() and val_perceptual_metric:
                            run['val/epoch_perceptual'].append(final_avg_perceptual.numpy(), step=epoch+1)
                        logger.info(f"Logged validation metrics to Neptune for epoch {epoch+1}.")
                    except Exception as neptune_val_metric_err:
                        # Neptune failures are reported but never stop training.
                        logger.error(f"Neptune VAL metric logging failed: {neptune_val_metric_err}", exc_info=False)
                else:
                    logger.info("Neptune run not active, skipping validation metric logging.")
            except Exception as metric_err:
                logger.error(f"Error finalizing/logging validation metrics: {metric_err}")

            # --- Sample images from the fixed visualization batch ---
            # The enclosing condition also fires on the first epoch without
            # checking for the fixed batch, so guard here: with no batch there is
            # nothing to feed the generator.
            if fixed_val_input is None:
                logger.warning("No fixed validation batch available; skipping validation image logging.")
            else:
                prediction_fixed = generator(fixed_val_input, training=False) # Generate prediction for fixed batch
                num_display = min(fixed_val_input.shape[0], 6) # Show up to 6 triplets
                display_list = []
                for i in range(num_display):
                    # De-normalize pixel values from [-1, 1] to [0, 1] range for visualization
                    input_dn  = tf.cast((fixed_val_input[i]  + 1) / 2.0, tf.float32)
                    target_dn = tf.cast((fixed_val_target[i] + 1) / 2.0, tf.float32)
                    pred_dn   = tf.cast((prediction_fixed[i]  + 1) / 2.0, tf.float32)

                    # Concatenate images horizontally: Input Photo | Target Diagram | Generated Diagram
                    concatenated_img = tf.concat([input_dn, target_dn, pred_dn], axis=1)
                    # Clip values to [0, 1] to prevent potential minor floating point issues during display
                    concatenated_img = tf.clip_by_value(concatenated_img, 0.0, 1.0)
                    # Convert tensor to numpy array for Neptune logging / matplotlib display
                    display_list.append(concatenated_img.numpy())

                if display_list:
                    # Neptune gets one tall image: the rows stacked vertically.
                    combined_display_image = np.vstack(display_list)
                    # === TENSORBOARD: log validation images ===
                    # TensorBoard gets the rows as a batch of shape [N, H, W, 3].
                    val_batch = np.stack(display_list, axis=0)
                    with tb_writer.as_default():
                        tf.summary.image(
                            'Val/Input|Target|Pred',
                            val_batch,
                            step=epoch+1,
                            max_outputs=val_batch.shape[0]
                        )
                    # === END ===
                    # Log the combined image to Neptune
                    if run and neptune: # Check Neptune run and import again just in case
                        try:
                            logger.info(f"Attempting neptune.types.File.as_image for epoch {epoch+1}")
                            logger.info(f"Image array shape: {combined_display_image.shape}, dtype: {combined_display_image.dtype}")

                            # Use neptune.types.File.as_image() to log numpy array as image
                            neptune_img = neptune.types.File.as_image(combined_display_image)
                            # Log under 'images/validation_samples', associate with epoch number
                            logger.info(f"Attempting run['images/validation_samples'].append for epoch {epoch+1}")
                            run['images/validation_samples'].append(neptune_img, step=epoch+1)
                            logger.info(f"Logged validation image sample to Neptune for epoch {epoch+1}.")
                        except Exception as neptune_img_err:
                            # Image logging is best effort; training continues.
                            logger.error(f"Neptune image logging failed epoch {epoch+1}: {neptune_img_err}", exc_info=True)

                    # Optional: Display the same image in Colab output
                    # plt.figure(figsize=(9, 3 * num_display)) # Adjust size based on num_display
                    # plt.imshow(combined_display_image)
                    # plt.axis('off')
                    # plt.title(f"Validation Samples Epoch {epoch+1}\nInput | Target | Prediction")
                    # plt.show()


        # --- Save Checkpoint Periodically ---
        # A checkpoint is written after every CHECKPOINT_SAVE_FREQ-th epoch. A
        # failed save is logged and training carries on.
        is_checkpoint_epoch = (epoch + 1) % CHECKPOINT_SAVE_FREQ == 0
        if is_checkpoint_epoch:
            try:
                # If the epoch is ever stored in the checkpoint object, update it
                # (ckpt.epoch.assign(epoch + 1)) before this save call.
                save_path = ckpt_manager.save()
                logger.info(f"Saved checkpoint for epoch {epoch + 1} to {save_path}")
            except Exception as ckpt_err:
                logger.error(f"Failed to save checkpoint for epoch {epoch+1}: {ckpt_err}", exc_info=True)

        # --- Increment Epoch Counter (tf.Variable method, optional) ---
        # Only active when an `epoch_counter` tf.Variable has been defined; it is
        # bumped after all work for this epoch, so it counts completed epochs.
        if 'epoch_counter' in locals() and isinstance(epoch_counter, tf.Variable):
            epoch_counter.assign_add(1)
            logger.debug(f"Incremented epoch counter tf.Variable to {epoch_counter.numpy()}")
        # --- End of Current Epoch Iteration ---


    # Reached only when the epoch loop finishes without raising; on an exception
    # control jumps straight to the `finally` cleanup and this line is skipped.
    logger.info(f"=== Training Loop Completed after {EPOCHS} epochs ===")


# Ensure Neptune run is stopped cleanly even if errors occur
finally:
    # === STOP PROFILER if still active ===
    if profiling_active:
      try:
          tf.profiler.experimental.stop()
          logger.info("Profiler 🔍 OFF (final stop)")
      except tf.errors.UnavailableError:
          logger.warning("Final Profiler.stop() called but no profiler was running.")

    # === STOP NEPTUNE RUN ===
    logger.info("Attempting to stop Neptune run if active...")
    # Check if 'run' variable exists and is a Neptune Run object before stopping
    if 'run' in locals() and run and isinstance(run, neptune.Run):
        run.stop()
        logger.info("Neptune run stopped.")
    else:
        logger.info("No active Neptune run found to stop.")

# Stop Neptune Run

In [None]:
# === STOP NEPTUNE RUN ===
# Final safety net for the end of the notebook: stop the Neptune run if `run`
# still refers to a Neptune Run object, otherwise just report that there is
# nothing to stop.
logger.info("Attempting to stop Neptune run if active...")
if not (run and isinstance(run, neptune.Run)):
    logger.info("No active Neptune run to stop.")
else:
    run.stop()
    logger.info("Neptune run stopped.")

In [None]:
# List the checkpoint directory; {checkpoint_dir} is interpolated from the
# Python variable of that name before the shell command runs.
!ls -la "{checkpoint_dir}"

In [None]:
# Ask Colab to unassign (release) the current runtime once the notebook is done.
# NOTE(review): presumably anything not saved to Drive is lost at this point;
# confirm checkpoints and logs are persisted before running this cell.
from google.colab import runtime
runtime.unassign()