In [None]:
!pip install tensorflow huggingface_hub

Defaulting to user installation because normal site-packages is not writeable


### Download the original dataset

In [None]:
import os
import kagglehub

# kagglehub caches downloads under KAGGLEHUB_CACHE; point it at the working
# directory so the dataset lands next to this notebook. Later cells expect it
# under ./datasets/sabari50312/fundus-pytorch/versions/1/.
current_dir = os.getcwd()
os.environ["KAGGLEHUB_CACHE"] = current_dir

path = kagglehub.dataset_download("sabari50312/fundus-pytorch")  # local dataset directory
print("Dataset downloaded to:", path)

### Download the CycleGAN model

In [None]:
import os

# huggingface_hub resolves HF_HOME (and the cache paths derived from it) once,
# when the package is first imported, so the variable must be set BEFORE the
# import below. If huggingface_hub was already imported earlier in this kernel,
# restart the kernel for this to take effect.
os.environ["HF_HOME"] = os.getcwd()

from huggingface_hub import from_pretrained_keras

model = from_pretrained_keras("keras-io/CycleGAN")
# The loading cell must use exactly this (case-sensitive) directory name.
model.save(os.path.join(os.getcwd(), "CycleGAN_model"))

print("Model loaded and saved successfully!")

### Apply the CycleGAN model to the dataset

In [1]:
import os
import tensorflow as tf
import numpy as np
from PIL import Image
from tqdm import tqdm

# Load the CycleGAN SavedModel written by the download cell above. That cell
# saves to "CycleGAN_model" (capital C); the lowercase 'cycleGAN_model' used
# previously only resolved on case-insensitive filesystems.
model_path = os.path.join(os.getcwd(), 'CycleGAN_model')
model = tf.saved_model.load(model_path)

def preprocess_image(img_path, target_size=(256, 256)):
    """Load an image file as a single-image batch normalized to [-1, 1].

    Args:
        img_path: path of an image file Pillow can open.
        target_size: (width, height) passed to ``PIL.Image.resize``.

    Returns:
        float32 array of shape (1, height, width, 3) with values in [-1, 1].
    """
    # Closing the source file explicitly; the converted/resized copy is
    # independent of it.
    with Image.open(img_path) as source:
        resized = source.convert('RGB').resize(target_size)
    scaled = np.array(resized) / 127.5 - 1  # [0, 255] -> [-1, 1]
    batched = scaled[np.newaxis, ...]  # leading batch axis of size 1
    return batched.astype(np.float32)

def apply_cyclegan(model, img, output_key='activation_14'):
    """Run a loaded CycleGAN SavedModel on a preprocessed batch.

    Args:
        model: object returned by ``tf.saved_model.load``; must expose a
            ``'serving_default'`` signature.
        img: batch array, e.g. as produced by ``preprocess_image``.
        output_key: name of the signature output that holds the generated
            image. Defaults to ``'activation_14'``, the key this notebook's
            exported model was read with. Output names depend on how the Keras
            model was exported, so if the key is absent and the signature has
            exactly one output, that output is used instead.

    Returns:
        The generator output as a numpy array.

    Raises:
        KeyError: if ``output_key`` is missing and the signature returns more
            than one output (the message lists the available keys).
    """
    generator = model.signatures['serving_default']
    outputs = generator(tf.convert_to_tensor(img))

    if output_key in outputs:
        result = outputs[output_key]
    elif len(outputs) == 1:
        (result,) = outputs.values()
    else:
        raise KeyError(
            f"Output {output_key!r} not found; signature outputs are {sorted(outputs)}"
        )
    return result.numpy()

def save_image(output_img, output_path):
    """Convert a generator output in [-1, 1] to 8-bit RGB and write it to disk.

    Args:
        output_img: float array, presumably (1, H, W, 3) as returned by
            ``apply_cyclegan``; an unbatched (H, W, 3) array is also accepted.
        output_path: destination file; Pillow infers the format from the
            extension.
    """
    pixels = (np.asarray(output_img) + 1) * 127.5  # [-1, 1] -> [0, 255]
    # Round rather than truncate: float error can turn a value that should be 3
    # into 2.9999..., which astype(uint8) would floor to 2. Clip so that any
    # out-of-range value saturates instead of wrapping around in uint8.
    pixels = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)
    if pixels.ndim == 4 and pixels.shape[0] == 1:
        pixels = pixels[0]  # drop the batch axis
    Image.fromarray(pixels).save(output_path)

In [None]:
# Degrade every image of every split with the CycleGAN generator loaded above.
# Outputs go to '<split>_degraded_images/<class>_<original filename>'.
DATASET_ROOT = 'datasets/sabari50312/fundus-pytorch/versions/1'
SPLITS = ('train', 'val', 'test')
CLASS_DIRS = ('0', '1')
# The recorded run took ~1.8 s/image, so a full pass takes hours; outputs that
# already exist are skipped so an interrupted run can resume. Set to False to
# regenerate everything (e.g. after changing the model or preprocessing).
SKIP_EXISTING = True


def degrade_split(input_dir, output_dir, class_dirs=CLASS_DIRS, skip_existing=SKIP_EXISTING):
    """Run the CycleGAN over one dataset split and save the degraded images.

    Args:
        input_dir: split directory containing one sub-directory per class.
        output_dir: destination directory (created if missing); files are
            named '<class>_<original filename>'.
        class_dirs: names of the class sub-directories to process.
        skip_existing: if True, outputs that already exist are not regenerated.
            NOTE: a file left half-written by an interrupted save would also be
            skipped; delete it before re-running.
    """
    os.makedirs(output_dir, exist_ok=True)
    split_name = os.path.basename(os.path.normpath(input_dir))
    for class_dir in class_dirs:
        class_path = os.path.join(input_dir, class_dir)
        # Sorted for a reproducible order. Hidden files (e.g. .DS_Store) and
        # sub-directories are skipped because Image.open cannot read them.
        img_files = sorted(
            name for name in os.listdir(class_path)
            if not name.startswith('.') and os.path.isfile(os.path.join(class_path, name))
        )
        for img_file in tqdm(img_files, desc=f"Processing {split_name} class {class_dir}", unit="image"):
            output_img_path = os.path.join(output_dir, f'{class_dir}_{img_file}')
            if skip_existing and os.path.exists(output_img_path):
                continue
            img = preprocess_image(os.path.join(class_path, img_file))
            output_img = apply_cyclegan(model, img)
            save_image(output_img, output_img_path)


for split in SPLITS:
    degrade_split(os.path.join(DATASET_ROOT, split), f'{split}_degraded_images')

print("All images processed and saved successfully!")

Processing class 0:   1%|          | 29/3539 [00:51<1:45:52,  1.81s/image]

In [None]:
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

def _to_display_uint8(img):
    """Return ``img`` as a uint8 array suitable for ``imshow``.

    Floating-point arrays whose values all lie in [-1, 1] are assumed to use
    this notebook's [-1, 1] normalization and are mapped to [0, 255]; other
    arrays are taken to already be on a 0-255 scale (floats are rounded and
    clipped). A leading batch axis of size 1 is dropped.
    """
    arr = np.asarray(img)
    if np.issubdtype(arr.dtype, np.floating):
        if arr.size and arr.min() >= -1.0 and arr.max() <= 1.0:
            arr = (arr + 1) * 127.5
        arr = np.clip(np.rint(arr), 0, 255)
    # Integer images (e.g. uint8 loaded from disk) are never rescaled, so an
    # all-dark picture with max <= 1 is displayed as-is.
    if arr.ndim == 4 and arr.shape[0] == 1:
        arr = arr[0]
    return arr.astype(np.uint8)


def display_before_after(before_img, after_img):
    """Plot two images side by side, titled 'Before' and 'After'.

    Args:
        before_img: original image; uint8 0-255, or float in [-1, 1] / 0-255,
            optionally with a leading batch axis of size 1.
        after_img: degraded image, accepted in the same forms.
    """
    before = _to_display_uint8(before_img)
    after = _to_display_uint8(after_img)

    _, (ax_before, ax_after) = plt.subplots(1, 2, figsize=(14, 7))
    for ax, image, title in ((ax_before, before, 'Before'), (ax_after, after, 'After')):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')
    plt.show()

def display_random_match(input_dir, output_dir, class_label=None):
    """Show a random original image next to its CycleGAN-degraded counterpart.

    Degraded files are expected to be named '<class>_<original filename>', as
    written by the processing cell.

    Args:
        input_dir: directory of original images for one class (e.g. '.../train/0').
        output_dir: directory of degraded images.
        class_label: class prefix to look for in ``output_dir``. Defaults to the
            last path component of ``input_dir``. If no '<class_label>_<name>'
            file exists for an input, any single-character '<c>_' prefix is
            accepted instead (the previous matching rule).
    """
    extensions = ('.jpg', '.png')

    def is_image(name):
        return name.lower().endswith(extensions)

    input_images = sorted(f for f in os.listdir(input_dir) if is_image(f))
    output_images = {f for f in os.listdir(output_dir) if is_image(f)}

    if class_label is None:
        class_label = os.path.basename(os.path.normpath(input_dir))

    # Fallback map from stripped name to output file. It is ambiguous when the
    # same name exists in several classes, which is why the class-prefixed
    # match is tried first.
    fallback = {f[2:]: f for f in sorted(output_images) if len(f) > 2 and f[1] == '_'}

    matches = {}
    for name in input_images:
        preferred = f"{class_label}_{name}"
        if preferred in output_images:
            matches[name] = preferred
        elif name in fallback:
            matches[name] = fallback[name]

    if not matches:
        print("No matching files found between input and output directories.")
        return

    # Candidates are sorted, so the pick is reproducible once `random` is seeded.
    random_input_image = random.choice(list(matches))
    matched_output_image = matches[random_input_image]

    input_path = os.path.join(input_dir, random_input_image)
    output_path = os.path.join(output_dir, matched_output_image)

    with Image.open(input_path) as src:
        before_img = np.array(src)
    with Image.open(output_path) as src:
        after_img = np.array(src)

    display_before_after(before_img, after_img)

# Paths to the original images and the generated output images. The processing
# cell above writes train-split outputs to 'train_degraded_images' (named
# '<class>_<original filename>'); 'output_images' is not produced by this notebook.
input_dir = 'datasets/sabari50312/fundus-pytorch/versions/1/train/0'  # Change as needed
output_dir = 'train_degraded_images'  # Change as needed

# Call the function to display a random match
display_random_match(input_dir, output_dir)


### Fine-tune LED

In [None]:
import os
import sys

# Restrict this process to GPU 0 before importing LED. CUDA reads
# CUDA_VISIBLE_DEVICES only when it first initialises, so it must be set before
# any library touches the GPU. NOTE(review): TensorFlow was already imported
# (and may have initialised CUDA) in an earlier cell of this kernel; it is not
# affected by this setting.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Make the local LED-main checkout importable. Resolve it to an absolute path
# (independent of later os.chdir calls) and put it first on sys.path so it
# takes precedence over any other module named "led".
LED_ROOT = os.path.abspath('./LED-main')
if LED_ROOT not in sys.path:
    sys.path.insert(0, LED_ROOT)

# NOTE(review): the saved run of this cell failed with KeyError: 'led'; the
# cause is not visible from this notebook. Check the LED-main checkout layout.
from led.pipelines.led_pipeline import LEDPipeline

import matplotlib.pyplot as plt
import cv2

# NOTE(review): the processing cell writes to '<split>_degraded_images', not
# 'output_images'. Point this at the split that contains BEH-172.
low_quality_image_path = 'output_images/0_BEH-172.png'
bgr_image = cv2.imread(low_quality_image_path)
if bgr_image is None:
    # cv2.imread returns None (instead of raising) for a missing/unreadable file.
    raise FileNotFoundError(f"Could not read image: {low_quality_image_path}")
low_quality_image = bgr_image[:, :, ::-1]  # OpenCV loads BGR; matplotlib expects RGB

_, ax = plt.subplots()
ax.imshow(low_quality_image)
ax.set_title("Low-Quality Image")
ax.axis('off')
plt.show()

KeyError: 'led'