# In [10]:
import cv2
import numpy as np
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
import random

def apply_wavelet_transform(image, amplitude=10, frequency=0.05):
    """Shift each row horizontally by a sinusoidal offset, wrapping at the edges.

    Despite the name, this is not a wavelet transform. Row ``r`` is rolled
    left by ``amplitude * sin(2 * pi * frequency * r)`` pixels, truncated
    toward zero. Column indices wrap modulo the image width.

    Args:
        image: Array whose first two axes are (rows, columns). Any trailing
            axes (e.g. colour channels) are carried along per pixel.
        amplitude: Peak horizontal shift in pixels.
        frequency: Sine cycles per row.

    Returns:
        A new array with the same shape as ``image``.
    """
    rows, cols = image.shape[:2]
    row_ids = np.arange(rows)
    phase = 2 * np.pi * frequency * row_ids
    # astype truncates toward zero, giving a whole-pixel shift per row.
    shifts = (np.sin(phase) * amplitude).astype(np.int32)
    # Broadcasting (1, cols) + (rows, 1) gives a (rows, cols) source-column map.
    src_cols = (np.arange(cols)[None, :] + shifts[:, None]) % cols
    return image[row_ids[:, None], src_cols]

def apply_geometric_distortion(image, max_offset=5):
    """Jitter every pixel by copying it from a random nearby location.

    Output pixel ``(i, j)`` is taken from ``image[clip(i + dy), clip(j + dx)]``.
    For each pixel, ``dx`` and ``dy`` are independent uniform integers in
    ``[-max_offset, max_offset]``. Source coordinates are clamped to the
    image bounds.

    Args:
        image: Array whose first two axes are (rows, columns). Trailing axes
            (e.g. colour channels) are copied as a unit per pixel.
        max_offset: Largest absolute displacement in pixels (non-negative).

    Returns:
        A new array with the same shape and dtype as ``image``.
    """
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        # Nothing to sample; avoid clipping against an empty range.
        return image.copy()
    # Draw all offsets at once. This replaces a Python double loop that made
    # two scalar randint calls for every pixel.
    dy = np.random.randint(-max_offset, max_offset + 1, size=(h, w))
    dx = np.random.randint(-max_offset, max_offset + 1, size=(h, w))
    src_rows = np.clip(np.arange(h)[:, None] + dy, 0, h - 1)
    src_cols = np.clip(np.arange(w)[None, :] + dx, 0, w - 1)
    # Two (h, w) index arrays gather whole pixels, so any channel axis
    # comes along automatically.
    return image[src_rows, src_cols]

def apply_occlusions(image, num_occlusions=5, max_size=50, fill_value=255):
    """Paint random axis-aligned rectangles over a copy of ``image``.

    Each rectangle's top-left corner is uniform over the image. Its width
    and height are independent integers in ``[10, max_size)``, so
    ``max_size`` must exceed 10. Rectangles are clipped at the image border
    and may overlap.

    Args:
        image: Array whose first two axes are (rows, columns).
        num_occlusions: Number of rectangles to draw.
        max_size: Exclusive upper bound on a rectangle side, in pixels.
        fill_value: Value written into occluded pixels. The default 255 is
            white for uint8 images; pass 0 for black.

    Returns:
        A modified copy; ``image`` itself is left untouched.
    """
    corrupted_image = image.copy()
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        # No valid corner positions exist in an empty image.
        return corrupted_image
    for _ in range(num_occlusions):
        x1 = np.random.randint(0, w)
        y1 = np.random.randint(0, h)
        x2 = np.clip(x1 + np.random.randint(10, max_size), 0, w)
        y2 = np.clip(y1 + np.random.randint(10, max_size), 0, h)
        corrupted_image[y1:y2, x1:x2] = fill_value
    return corrupted_image

def apply_blurring(image, sigma=0.9):
    """Gaussian-blur ``image`` along its spatial axes.

    ``scipy.ndimage.gaussian_filter`` applies a scalar ``sigma`` to every
    axis. For an (H, W, C) colour image that would also smooth across the
    channel axis and bleed red, green and blue into each other. For 3-D
    input, which is assumed to be channels-last, a scalar ``sigma`` is
    therefore expanded to ``(sigma, sigma, 0)``. A per-axis sequence is
    passed through unchanged.

    Args:
        image: 2-D grayscale or 3-D channels-last array.
        sigma: Standard deviation of the Gaussian kernel in pixels, or a
            per-axis sequence.

    Returns:
        The blurred array. Its dtype is scipy's default, i.e. the input dtype.
    """
    if image.ndim == 3 and np.ndim(sigma) == 0:
        # gaussian_filter skips axes whose sigma is 0, leaving channels unmixed.
        sigma = (sigma, sigma, 0)
    return gaussian_filter(image, sigma=sigma)

def adjust_brightness(image, factor=1.2):
    """Scale pixel intensities by ``factor`` and saturate to the uint8 range.

    The multiplication is done in float64. Multiplying a uint8 array by a
    Python ``int`` keeps the uint8 dtype and silently wraps on overflow
    (e.g. ``200 * 2 -> 144``) before any clipping could take effect.

    Args:
        image: Array of pixel intensities, typically uint8.
        factor: Multiplicative brightness factor (>1 brightens, <1 darkens).

    Returns:
        A uint8 array clipped to [0, 255]. Fractional results are truncated
        toward zero.
    """
    scaled = np.asarray(image, dtype=np.float64) * factor
    return np.clip(scaled, 0, 255).astype(np.uint8)

def apply_corruptions(image_path, output_path):
    """Run the cumulative corruption pipeline on one image and plot every stage.

    The stages are chained, each consuming the previous stage's output:
    sine-wave row warp -> per-pixel geometric jitter -> Gaussian blur ->
    brightness scaling -> rectangular occlusions. The original plus the five
    intermediate results are drawn in a 2x3 grid, saved to ``output_path``,
    and then shown.

    Args:
        image_path: Path of the input image, read with ``cv2.imread``.
        output_path: Destination file for the saved figure. Missing parent
            directories are created.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot read ``image_path``.
            ``cv2.imread`` returns ``None`` instead of raising.
    """
    from pathlib import Path

    image = cv2.imread(image_path)
    if image is None:
        # Fail here with a clear message rather than letting cvtColor choke on None.
        raise FileNotFoundError(f"Could not read image: {image_path!r}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert to RGB for visualization

    wavelet_image = apply_wavelet_transform(image)
    distorted_image = apply_geometric_distortion(wavelet_image)
    blurred_image = apply_blurring(distorted_image)
    bright_image = adjust_brightness(blurred_image)
    occluded_image = apply_occlusions(bright_image)

    corrupted_images = [image, wavelet_image, distorted_image, blurred_image, bright_image, occluded_image]
    titles = ["Original", "Wavelet Transform", "Geometric Distortion", "Blurring", "Brightness", "Occlusions"]

    plt.figure(figsize=(15, 10))
    for i, (img, title) in enumerate(zip(corrupted_images, titles)):
        plt.subplot(2, 3, i + 1)
        plt.imshow(img)
        plt.title(title)
        plt.axis("off")

    plt.tight_layout()
    # savefig does not create directories; make sure the target folder exists.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path)
    plt.show()

# Example usage. The __main__ guard keeps importing this module free of side
# effects; it is still true when the cell runs in a notebook.
if __name__ == "__main__":
    apply_corruptions("assets/bull_shark.jpg", "corrupted_images/bull_shark_distorted.jpg")
