In [27]:
from PIL import ImageEnhance, Image
import numpy as np
import os
def _sample_jitter_factors(mode, rng):
    """Draw one (brightness, saturation, contrast) factor triple for a patch.

    ``"normal"``: each factor ~ Uniform(0.9, 1.1).
    ``"extreme"``: for each factor a fair coin picks the range [0.2, 0.5] or
    [1.5, 3]; all three coins are drawn first, then the three uniforms
    (same draw order as the original inline code).
    """
    if mode == "normal":
        return tuple(rng.uniform(0.9, 1.1) for _ in range(3))
    low_range, high_range = (0.2, 0.5), (1.5, 3)
    coins = [rng.binomial(1, 0.5) for _ in range(3)]
    bounds = [low_range if coin >= 0.5 else high_range for coin in coins]
    return tuple(rng.uniform(lo, hi) for lo, hi in bounds)


def jitter_image(image, patch_size, mode="normal", rng=None):
    """Apply independent colour jitter to each square patch of an image.

    The image is tiled into non-overlapping ``patch_size`` x ``patch_size``
    patches (row by row, starting top-left). For every patch, brightness,
    saturation and contrast factors are sampled and applied via
    ``apply_jittering``. Each jittered patch is pasted back at the position
    it was cropped from.

    Args:
        image: Image as a numpy array accepted by ``PIL.Image.fromarray``
            (presumably an (H, W, 3) uint8 array -- TODO confirm other layouts).
        patch_size: Side length in pixels of each square patch; must be > 0.
        mode: ``"normal"`` for mild jitter (all factors in [0.9, 1.1]) or
            ``"extreme"`` (each factor drawn from [0.2, 0.5] or [1.5, 3],
            chosen by a fair coin).
        rng: Optional random source with ``uniform`` and ``binomial`` methods,
            e.g. ``np.random.default_rng(seed)``. Defaults to the global
            ``np.random`` state.

    Returns:
        The jittered image as an RGB numpy array. If the image size is not a
        multiple of ``patch_size``, the right/bottom remainder strips that no
        whole patch covers keep their original pixels.

    Raises:
        ValueError: if ``mode`` is unknown or ``patch_size`` is not positive.
    """
    if mode not in ("normal", "extreme"):
        raise ValueError(f"mode must be 'normal' or 'extreme', got {mode!r}")
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    rng = np.random if rng is None else rng

    image_pil = Image.fromarray(image)
    width, height = image_pil.size

    # Start from an RGB copy of the source (convert() always returns a new
    # image). Remainder strips not covered by a full patch therefore keep
    # their original content instead of being left black.
    stitched_image = image_pil.convert("RGB")

    for row in range(height // patch_size):
        for col in range(width // patch_size):
            brightness, saturation, contrast = _sample_jitter_factors(mode, rng)
            left, upper = col * patch_size, row * patch_size
            patch = image_pil.crop(
                (left, upper, left + patch_size, upper + patch_size))
            # apply_jittering uses its ``hue`` argument as a colour factor of
            # ``1 + hue``. Pass the offset so the effective saturation factor
            # equals the sampled value, consistent with brightness/contrast.
            patch = apply_jittering(patch, brightness, saturation - 1.0, contrast)
            # Paste at the crop coordinates so placement is correct even when
            # the width is not a multiple of patch_size.
            stitched_image.paste(patch, (left, upper))

    return np.array(stitched_image)


def apply_jittering(patch, brightness=1.0, hue=0.0, contrast=1.0):
    """Return a colour-jittered version of a PIL image patch.

    Three ``PIL.ImageEnhance`` adjustments are chained in a fixed order:
    brightness, then colour, then contrast.

    Args:
        patch: PIL image to adjust.
        brightness: ``ImageEnhance.Brightness`` factor (1.0 leaves it unchanged).
        hue: Offset added to 1 to form the ``ImageEnhance.Color`` factor
            (0.0 leaves it unchanged). Despite the name, ``ImageEnhance.Color``
            adjusts colour saturation, not hue.
        contrast: ``ImageEnhance.Contrast`` factor (1.0 leaves it unchanged).

    Returns:
        The adjusted PIL image.
    """
    pipeline = (
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Color, 1 + hue),
        (ImageEnhance.Contrast, contrast),
    )
    adjusted = patch
    for enhancer_factory, factor in pipeline:
        adjusted = enhancer_factory(adjusted).enhance(factor)
    return adjusted



def generate_ood(source_folder="ood_experiment/celebahq_source", 
                 target_folder="ood_experiment/celebahq_ood/train", 
                 mode="normal",
                n_samples=100, 
                patch_size=128, 
                max_source_images=20):
    """Write an unmodified copy plus jittered variants of source images.

    For each of the first ``max_source_images`` files in ``source_folder``
    (sorted by filename), this creates ``target_folder/<stem>/``. The folder
    receives ``<stem>_0.png`` (the unmodified image) and
    ``<stem>_1.png`` ... ``<stem>_<n_samples>.png``, each produced by
    ``jitter_image(image, patch_size, mode)``.

    Args:
        source_folder: Folder containing the source images.
        target_folder: Root folder for the per-image output sub-folders.
        mode: Jitter mode forwarded to ``jitter_image``.
        n_samples: Number of jittered variants written per source image.
        patch_size: Patch side length forwarded to ``jitter_image``.
        max_source_images: Maximum number of source images to process;
            ``None`` processes all of them.
    """
    # os.listdir order is arbitrary; sort so the subset chosen by
    # max_source_images is reproducible. Skip sub-directories, which
    # Image.open cannot read.
    source_files = sorted(
        name for name in os.listdir(source_folder)
        if os.path.isfile(os.path.join(source_folder, name))
    )
    for filename in source_files[:max_source_images]:
        fullname = os.path.join(source_folder, filename)
        # Strip whatever extension the file has (.png, .jpg, ...).
        basename = os.path.splitext(filename)[0]
        sub_folder_path = os.path.join(target_folder, basename)
        os.makedirs(sub_folder_path, exist_ok=True)

        # Decode the source once per image and close the file handle promptly.
        with Image.open(fullname) as source_image:
            original_image = np.array(source_image)

        # Sample 0 is the unmodified image.
        Image.fromarray(original_image).save(
            os.path.join(sub_folder_path, f"{basename}_0.png"))
        for i in range(1, n_samples + 1):
            jittered_image = jitter_image(original_image, patch_size, mode)
            Image.fromarray(jittered_image).save(
                os.path.join(sub_folder_path, f"{basename}_{i}.png"))



In [28]:
# Mild ("normal") jitter produces the train split; strong ("extreme") jitter
# produces the test split.
for jitter_mode, split_folder in (
    ("normal", "ood_experiment/celebahq_ood/train"),
    ("extreme", "ood_experiment/celebahq_ood/test"),
):
    generate_ood(mode=jitter_mode, target_folder=split_folder)