# In [1]:
import numpy as np
from PIL import Image
import torch

# Define parameters
# Linear beta schedule endpoints (the DDPM linear schedule). A large beta2 such
# as 0.99 drives ab_t down to around 1e-43 by the last step; the signal term
# sqrt(ab_t) * x then falls far below float32 resolution next to the noise term,
# and the image cannot be recovered even when the exact noise is known.
beta1 = 1e-4  # Example value for beta1
beta2 = 0.02  # Example value for beta2
timesteps = 100  # Number of timesteps

# Construct b_t, a_t, and ab_t
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# b_t[t]: noise variance for step t, linearly spaced from beta1 to beta2
# over timesteps + 1 entries (indices 0..timesteps).
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
a_t = 1 - b_t
# ab_t[t]: running product of a_t. It is already on `device` because b_t is.
ab_t = torch.cumprod(a_t, dim=0)
ab_t[0] = 1.0  # Ensure the initial value is 1.0 (100% of the original image)

# Function to perturb an image to a specified noise level
def perturb_input(x, t, noise):
    """Noise ``x`` to timestep ``t`` of the forward process.

    Returns ``sqrt(ab_t[t]) * x + sqrt(1 - ab_t[t]) * noise``, where ``ab_t`` is
    the module-level schedule. The three trailing ``None`` axes let the selected
    schedule value(s) broadcast against ``x`` (presumably shaped (N, C, H, W)).
    """
    alpha_bar = ab_t[t, None, None, None]
    signal_scale = alpha_bar.sqrt()
    noise_scale = (1 - alpha_bar).sqrt()
    return signal_scale * x + noise_scale * noise

# Function to normalize and convert numpy array to PyTorch tensor
def normalize_and_convert(img_array):
    """Convert an (H, W, C) pixel array into a (1, C, H, W) float32 tensor on ``device``.

    Pixel values are divided by 255, so 0-255 input maps to [0, 1].
    """
    scaled = img_array / 255.0
    # Channels-last (H, W, C) -> channels-first (C, H, W).
    chw = torch.tensor(scaled, dtype=torch.float32).permute(2, 0, 1)
    # Prepend a batch axis of size 1.
    batched = chw[None]
    return batched.to(device)

# Function to convert image to numpy array
def image_to_numpy(image_path):
    """Load an image file as an (H, W, 3) uint8 RGB numpy array.

    Any source mode (grayscale, palette, RGBA, ...) is converted to RGB first.
    """
    # The context manager guarantees the file handle is closed. Without it,
    # Pillow keeps multi-frame files (animated GIF, multipage TIFF) open after loading.
    with Image.open(image_path) as img:
        return np.array(img.convert('RGB'))

# Function to save a tensor image to file
def save_tensor_image(tensor, filename):
    """Save a (1, C, H, W) or (C, H, W) image tensor with values in [0, 1] to ``filename``.

    Values are clamped to [0, 1] and rounded to the nearest 8-bit level. The
    file format is chosen by Pillow from the filename extension.
    """
    if tensor.dim() == 4:
        # Drop only the batch axis. A bare squeeze() would also collapse a
        # height or width of 1 and break the permute below.
        tensor = tensor.squeeze(0)
    # detach() so tensors that require grad can still be converted to numpy.
    pixels = tensor.detach().clamp(0, 1).cpu().permute(1, 2, 0).numpy()
    # Round instead of truncating. In float32, k / 255 * 255 can land just
    # below the integer k, and astype() alone would shift that pixel down one level.
    img = Image.fromarray(np.round(pixels * 255).astype(np.uint8))
    img.save(filename)

# Function to recover the original image from a perturbed one using the known noise
def decode_perturbed_image(img_array, timestep, original_noise):
    """Invert ``perturb_input`` given the exact noise that was added.

    ``perturb_input`` computes
    ``x_t = sqrt(ab_t[t]) * x_0 + sqrt(1 - ab_t[t]) * noise``, so with the same
    noise the clean image is
    ``x_0 = (x_t - sqrt(1 - ab_t[t]) * noise) / sqrt(ab_t[t])``.

    Args:
        img_array: The perturbed image. This is either a tensor like the one
            returned by ``perturb_input``, or an (H, W, C) pixel array in 0-255,
            which is passed through ``normalize_and_convert``.
        timestep: The timestep index that was given to ``perturb_input``.
        original_noise: The noise tensor that was given to ``perturb_input``.

    Returns:
        The reconstructed image tensor, on ``device``.
    """
    if isinstance(img_array, torch.Tensor):
        x_t = img_array.to(device)
    else:
        x_t = normalize_and_convert(img_array)

    alpha_bar = ab_t[timestep, None, None, None]
    # Remove the known noise contribution, then undo the signal scaling.
    # If alpha_bar is extremely small, the signal term was lost to rounding
    # during perturbation and cannot be recovered here.
    return (x_t - (1 - alpha_bar).sqrt() * original_noise) / alpha_bar.sqrt()

# Example usage to encode and decode an image
image_path = 'pexels-lastly-772803.jpg'  # Replace with your image path
img_array = image_to_numpy(image_path)
x0 = normalize_and_convert(img_array)

# Encode the image by perturbing it with Gaussian noise
original_noise = torch.randn_like(x0)
perturbed_image = perturb_input(x0, timesteps, original_noise)

# Save the perturbed image
save_tensor_image(perturbed_image, 'perturbed_image.jpg')

# Decode the in-memory perturbed tensor. The saved JPEG is clamped to [0, 1]
# and lossy, so it could not be inverted exactly.
decoded_image = decode_perturbed_image(perturbed_image, timesteps, original_noise)

# Save the decoded image
save_tensor_image(decoded_image, 'decoded_image.jpg')

# Optionally, display or further process the decoded image
