In [None]:
# ==============================================================================
#           DUAL-IMAGE SUPER-RESOLUTION MODEL TRAINING NOTEBOOK
# ==============================================================================

# Objective: Train a state-of-the-art super-resolution model and save it to
#            Google Drive.

# Instructions:
# 1. Place your 'hr_image-*.tfrecord' file(s) in a folder in your Google Drive.
# 2. Update the 'TFRECORD_FOLDER_PATH' variable in Step 2 to point to that folder.
# 3. Go to 'Runtime' -> 'Change runtime type' and select 'T4 GPU'.
# 4. Run each cell sequentially.
# ==============================================================================

# @title Step 1: Setup Environment and Mount Google Drive
import os
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Add, PReLU, Concatenate, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint
import glob
from google.colab import drive

# XLA JIT compilation is switched off to avoid ResizeBicubic issues.
tf.config.optimizer.set_jit(False)

# Report the TensorFlow build and warn when no GPU is visible to it.
print("TensorFlow version:", tf.__version__)
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))
if len(gpu_devices) == 0:
    print("WARNING: No GPU detected. Training will be very slow. Please enable a GPU runtime.")

# Make Google Drive reachable under /content/drive (data in, models out).
drive.mount('/content/drive')


# @title Step 2: Configure Paths and Parameters
# --- IMPORTANT: UPDATE THIS PATH TO POINT TO THE FOLDER ---
# Folder that holds the 'hr_image-*.tfrecord' shards read in Step 3.
TFRECORD_FOLDER_PATH = '/content/drive/MyDrive/SuperResolutionProject images/'
# ------------------------------------------------------------

# Checkpoints, the final model and the training history are all written here.
MODEL_SAVE_DIR = '/content/drive/MyDrive/SuperResolutionProduct_01/models/'
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
print(f"Model will be saved to: {MODEL_SAVE_DIR}")

# --- Training parameters ---
# Upscaling factor between the low-resolution inputs and the high-resolution target.
SCALE_FACTOR = 4
# Spatial size of each low-resolution input. Step 3 reshapes records to
# 256x256 and divides by SCALE_FACTOR, i.e. 256 // 4 = 64.
INPUT_HEIGHT = INPUT_WIDTH = 64
# Colour channels per image (red, green, blue).
CHANNELS = 3

# Optimisation schedule.
EPOCHS = 75
BATCH_SIZE = 16
LEARNING_RATE = 1e-4


# @title Step 3: Define Data Loading Functions
def parse_tfrecord_fn(example):
    """Decode one serialized TFRecord example into a 256x256x3 uint8 image.

    The record is expected to carry three scalar byte-string features,
    'vis-red', 'vis-green' and 'vis-blue'. Each is decoded as raw uint8
    values, the three channels are stacked on the last axis, and the result
    is reshaped to (256, 256, 3).

    Args:
        example: a scalar string tensor holding one serialized example.

    Returns:
        A uint8 tensor of shape (256, 256, 3) in R, G, B channel order.
    """
    channel_keys = ('vis-red', 'vis-green', 'vis-blue')
    schema = {key: tf.io.FixedLenFeature([], tf.string) for key in channel_keys}
    parsed = tf.io.parse_single_example(example, schema)

    # One flat uint8 vector per colour plane, kept in R, G, B order.
    planes = [tf.io.decode_raw(parsed[key], tf.uint8) for key in channel_keys]
    interleaved = tf.stack(planes, axis=-1)

    target_shape = tf.cast([256, 256, 3], tf.int32)
    return tf.reshape(interleaved, target_shape)


def create_lr_pair(hr_image, scale=SCALE_FACTOR):
    """Build the ((lr1, lr2), hr) training triple from one high-resolution image.

    The HR image is scaled to float32 in [0, 1]. Two low-resolution views are
    produced with bicubic resizing: one from the full image, and one from the
    image with its first row and first column dropped, which simulates a small
    shift between the two captures. Both views share the same LR size.

    Args:
        hr_image: image tensor of shape (height, width, channels); values are
            divided by 255, so 8-bit pixel data is assumed.
        scale: integer downscaling factor between the HR and LR sizes.

    Returns:
        ``((lr1, lr2), hr_cropped)`` where ``hr_cropped`` is the normalised HR
        image cropped to exactly ``scale`` times the LR height and width.
    """
    hr_float = tf.cast(hr_image, tf.float32) / 255.0
    dims = tf.shape(hr_float)
    lr_height, lr_width = dims[0] // scale, dims[1] // scale

    def _shrink(source):
        # Bicubic downscale to the LR grid, then slice to the LR extent.
        resized = tf.image.resize(source, [lr_height, lr_width], method='bicubic')
        return resized[:lr_height, :lr_width, :]

    lr1 = _shrink(hr_float)
    # Dropping the first row/column shifts the content before resizing.
    lr2 = _shrink(hr_float[1:, 1:, :])

    hr_cropped = hr_float[:lr_height * scale, :lr_width * scale, :]
    return (lr1, lr2), hr_cropped


def create_training_dataset(tfrecord_path, batch_size):
    """Build the batched ((lr1, lr2), hr) training dataset from TFRecord shards.

    All files matching 'hr_image-*.tfrecord' directly inside ``tfrecord_path``
    are read, each record is decoded with ``parse_tfrecord_fn`` and turned into
    a dual-LR/HR sample with ``create_lr_pair``. Samples are then shuffled with
    a 100-element buffer, batched and prefetched.

    Args:
        tfrecord_path: folder that contains the TFRecord shard files.
        batch_size: number of samples per batch.

    Returns:
        A ``tf.data.Dataset`` yielding ``((lr1, lr2), hr)`` batches.

    Raises:
        FileNotFoundError: if no matching TFRecord file exists in the folder.
    """
    pattern = os.path.join(tfrecord_path, 'hr_image-*.tfrecord')
    # glob returns matches in arbitrary, filesystem-dependent order; sorting
    # keeps the shard order (and so the records entering the shuffle buffer)
    # the same from run to run.
    tfrecord_files = sorted(glob.glob(pattern))
    print(f"Found TFRecord files: {tfrecord_files}")  # Debugging print

    if not tfrecord_files:
        raise FileNotFoundError(
            f"FATAL: No TFRecord files found at '{tfrecord_path}'. Please check the path."
        )

    raw_dataset = tf.data.TFRecordDataset(tfrecord_files)
    hr_dataset = raw_dataset.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
    train_dataset = hr_dataset.map(create_lr_pair, num_parallel_calls=tf.data.AUTOTUNE)

    return train_dataset.shuffle(100).batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)


# @title Step 4: Define the Advanced Generator Model Architecture (FIXED)
def res_block(x_in, num_filters):
    """Residual block: Conv -> PReLU -> Conv, added back onto the input.

    Args:
        x_in: input feature-map tensor. It is summed with the branch output,
            so it presumably needs ``num_filters`` channels already - verify
            at the call site.
        num_filters: number of filters in both 3x3 convolutions.

    Returns:
        The tensor ``x_in + branch(x_in)``.
    """
    branch = Conv2D(num_filters, (3, 3), padding='same')(x_in)
    # PReLU slopes are shared across axes 1 and 2 (one slope per channel).
    branch = PReLU(shared_axes=[1, 2])(branch)
    branch = Conv2D(num_filters, (3, 3), padding='same')(branch)
    return Add()([x_in, branch])

In [None]:
from tensorflow.keras.layers import Lambda
from tensorflow.keras.layers import Input, Conv2D, Concatenate, UpSampling2D
from tensorflow.keras.models import Model

def build_generator():
    """Build the dual-input super-resolution generator.

    Two low-resolution images of shape (INPUT_HEIGHT, INPUT_WIDTH, CHANNELS)
    are concatenated on the channel axis, passed through a 128-filter 3x3
    convolution, upsampled by SCALE_FACTOR with bilinear interpolation, and
    refined by a 64-filter convolution before the output convolution.

    Returns:
        An uncompiled Keras ``Model`` with inputs ``[lr1_input, lr2_input]``
        and a sigmoid-activated output with CHANNELS channels.
    """
    lr_shape = (INPUT_HEIGHT, INPUT_WIDTH, CHANNELS)
    lr1_in = Input(shape=lr_shape, name='lr1_input')
    lr2_in = Input(shape=lr_shape, name='lr2_input')

    # Fuse both views on the channel axis (2 * CHANNELS channels).
    features = Concatenate()([lr1_in, lr2_in])
    features = Conv2D(128, kernel_size=3, padding='same', activation='relu')(features)

    # UpSampling2D is used rather than a Resizing layer to avoid the XLA
    # compilation issue.
    features = UpSampling2D(
        size=(SCALE_FACTOR, SCALE_FACTOR),
        interpolation='bilinear',
    )(features)

    features = Conv2D(64, kernel_size=3, padding='same', activation='relu')(features)
    # Sigmoid keeps predictions in the [0, 1] range.
    sr_output = Conv2D(CHANNELS, kernel_size=3, padding='same', activation='sigmoid')(features)

    return Model(inputs=[lr1_in, lr2_in], outputs=sr_output)


# Alternative generator using transpose convolutions (call it instead of build_generator() in Step 5 to use it)
def build_generator_alternative():
    """Alternative dual-input generator that upsamples with transpose convolutions.

    The two low-resolution inputs, each of shape
    (INPUT_HEIGHT, INPUT_WIDTH, CHANNELS), are concatenated on the channel
    axis and passed through a 128-filter and a 256-filter 3x3 convolution.
    A chain of stride-2 ``Conv2DTranspose`` layers then upsamples the features.
    The number of stages is derived from SCALE_FACTOR so the overall upscaling
    matches it. With SCALE_FACTOR = 4 this gives the 128-filter and 64-filter
    stages (e.g. 64x64 -> 128x128 -> 256x256).

    Returns:
        An uncompiled Keras ``Model`` with inputs ``[lr1_input, lr2_input]``
        and a sigmoid-activated output with CHANNELS channels.

    Raises:
        ValueError: if SCALE_FACTOR is not a positive integer power of two,
            since it cannot be reached by doubling the resolution.
    """
    if (
        not isinstance(SCALE_FACTOR, int)
        or SCALE_FACTOR < 1
        or SCALE_FACTOR & (SCALE_FACTOR - 1)
    ):
        raise ValueError(
            "build_generator_alternative() requires SCALE_FACTOR to be a positive "
            f"power of two, got {SCALE_FACTOR!r}"
        )
    # Each stride-2 transpose convolution doubles height and width.
    num_upsample_stages = SCALE_FACTOR.bit_length() - 1

    input1 = Input(shape=(INPUT_HEIGHT, INPUT_WIDTH, CHANNELS), name='lr1_input')
    input2 = Input(shape=(INPUT_HEIGHT, INPUT_WIDTH, CHANNELS), name='lr2_input')

    # Concatenate inputs along channels (2 * CHANNELS channels).
    x = Concatenate()([input1, input2])

    x = Conv2D(128, kernel_size=3, padding='same', activation='relu')(x)
    x = Conv2D(256, kernel_size=3, padding='same', activation='relu')(x)

    # Filter counts halve per stage from 128 and never drop below 64, which
    # gives the 128 -> 64 sequence for a 4x scale.
    for stage in range(num_upsample_stages):
        stage_filters = max(64, 128 >> stage)
        x = tf.keras.layers.Conv2DTranspose(
            stage_filters, kernel_size=4, strides=2, padding='same', activation='relu'
        )(x)

    # Sigmoid keeps predictions in the [0, 1] range.
    outputs = Conv2D(CHANNELS, kernel_size=3, padding='same', activation='sigmoid')(x)

    generator = Model(inputs=[input1, input2], outputs=outputs)
    return generator


# @title Step 5: Train the Model
print("--- Starting Model Training ---")

# 1) Input pipeline. If no TFRecord files are found, print the message and
#    stop the cell by raising SystemExit.
print("1. Creating training dataset...")
try:
    train_dataset = create_training_dataset(TFRECORD_FOLDER_PATH, BATCH_SIZE)
except FileNotFoundError as missing_data_error:
    print(missing_data_error)
    raise SystemExit()
else:
    print("   Dataset created successfully.")

# 2) Model. Call build_generator_alternative() here instead to try the
#    transpose-convolution variant.
print("2. Building generator model...")
generator = build_generator()
generator.compile(
    loss='mse',
    optimizer=Adam(learning_rate=LEARNING_RATE),
)
print("   Model built and compiled.")
generator.summary()

# fit() receives no validation data, so the checkpoint tracks the training
# loss and keeps only the best (lowest) epoch.
model_filepath = os.path.join(MODEL_SAVE_DIR, 'sr_generator_best.h5')
checkpoint = ModelCheckpoint(
    filepath=model_filepath,
    monitor='loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

# 3) Training loop.
print("\n3. Starting training process...")
history = generator.fit(train_dataset, epochs=EPOCHS, callbacks=[checkpoint], verbose=1)

print("\n--- Training Finished! ---")
print(f"✅ Best model saved to: {model_filepath}")
print("You can now download this file from your Google Drive and use it in the Streamlit application.")

# @title Step 6: Save Final Model and Training History
import pickle

# Both artifacts are written into MODEL_SAVE_DIR.
final_model_path = os.path.join(MODEL_SAVE_DIR, 'sr_generator_final.h5')
history_path = os.path.join(MODEL_SAVE_DIR, 'training_history.pkl')

# Model state after the last epoch (the checkpoint file holds the best epoch).
generator.save(final_model_path)

# Per-epoch metrics dictionary recorded by fit().
with open(history_path, 'wb') as history_file:
    pickle.dump(history.history, history_file)

print(f"✅ Final model saved to: {final_model_path}")
print(f"✅ Training history saved to: {history_path}")