# In [None]:
import numpy as np
from astropy.wcs import WCS
from astropy.io import fits
from astropy.convolution import convolve, Gaussian2DKernel
from reproject import reproject_interp, reproject_exact
import matplotlib.pyplot as plt

# Parameters
image_size = 128  # Original image size (pixels)
pixel_scale = 4.0  # Original pixel scale (arcseconds per pixel)
new_pixel_scale = 1.0  # New pixel scale (arcseconds per pixel)
noise_level = 0.05  # Noise standard deviation


def _make_tan_wcs(size, scale_arcsec):
    """Return a 2-D RA/Dec TAN WCS for a ``size`` x ``size`` pixel grid.

    The reference pixel (mapped to RA=0, Dec=0) is placed at the geometric
    centre of the grid. FITS pixel coordinates are 1-based with pixel centres
    at integer values, so the centre of an N-pixel axis is (N + 1) / 2. Using
    N / 2 puts the reference point half a pixel off-centre. Half a pixel is a
    different angle at each pixel scale, so two such grids would not line up
    on the sky.

    Args:
        size: Number of pixels along each axis.
        scale_arcsec: Pixel scale in arcseconds per pixel.
    """
    wcs = WCS(naxis=2)
    centre = (size + 1) / 2
    wcs.wcs.crpix = [centre, centre]
    # Degrees per pixel. NOTE: CDELT1 is positive here, so RA increases to the
    # right; real sky images usually use a negative CDELT1.
    wcs.wcs.cdelt = [scale_arcsec / 3600.0, scale_arcsec / 3600.0]
    wcs.wcs.crval = [0, 0]
    wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]
    return wcs


# Create a WCS for the original image
wcs_original = _make_tan_wcs(image_size, pixel_scale)

# Create a WCS for the new image covering the same field of view. round()
# guards against float truncation (e.g. 127.9999 -> 127) for non-integer ratios.
new_size = int(round(image_size * pixel_scale / new_pixel_scale))
wcs_new = _make_tan_wcs(new_size, new_pixel_scale)

# Simulate Gaussian point sources
np.random.seed(42)  # For reproducibility
num_sources = 20
source_sigma = 1.0  # Intrinsic Gaussian width of each source (pixels), before the PSF
# Source centres as (x, y) = (column, row), in 0-based pixel coordinates.
positions = np.random.uniform(0, image_size, (num_sources, 2))
amplitudes = np.random.uniform(0.5, 1.5, num_sources)

# Pixel coordinate grids, built once outside the loop. np.indices returns
# (row, column) index arrays, i.e. (y, x). np.meshgrid's default 'xy'
# indexing returns them in the opposite order, (x, y).
y, x = np.indices((image_size, image_size))
image = np.zeros((image_size, image_size))
for pos, amp in zip(positions, amplitudes):
    image += amp * np.exp(-((x - pos[0])**2 + (y - pos[1])**2) / (2 * source_sigma**2))

# Convolve with a Gaussian PSF
psf = Gaussian2DKernel(1.5)  # PSF with standard deviation of 1.5 pixels
image_psf = convolve(image, psf)

# Add Gaussian noise (drawn after positions/amplitudes, so the seeded
# sequence is the same as before)
noise = np.random.normal(0, noise_level, (image_size, image_size))
image_with_noise = image_psf + noise

# Reproject the combined image, the signal alone, and the noise alone onto
# the new grid. reproject_exact returns (array, footprint); the footprint is
# discarded here.
original_header = wcs_original.to_header()
original_hdu = fits.PrimaryHDU(image_with_noise, header=original_header)
target_shape = (new_size, new_size)

reprojected_image_with_noise, _ = reproject_exact(original_hdu, output_projection=wcs_new, shape_out=target_shape)
reprojected_image, _ = reproject_exact(fits.PrimaryHDU(image_psf, header=original_header),
                                       output_projection=wcs_new, shape_out=target_shape)
reprojected_noise, _ = reproject_exact(fits.PrimaryHDU(noise, header=original_header),
                                       output_projection=wcs_new, shape_out=target_shape)

# Check equivalence: f(A1 + N) should match f(A1) + f(N) up to floating-point
# rounding if the reprojection is linear in the input values. The combined
# image was already reprojected above; reuse it instead of running the
# expensive exact reprojection a second time on identical input.
combined_reprojection = reprojected_image_with_noise
separate_sum = reprojected_image + reprojected_noise

# nanmax: output pixels not covered by the input may be NaN (reproject's
# default fill for uncovered pixels).
max_abs_difference = np.nanmax(np.abs(combined_reprojection - separate_sum))
print(f"Max |f(A1 + N) - (f(A1) + f(N))| = {max_abs_difference:.3e}")

# Plot results: the top row shows the inputs on the original grid. The bottom
# row shows the reprojected signal and noise, plus the absolute residual
# between reprojecting the sum and summing the reprojections.
fig, ax = plt.subplots(2, 3, figsize=(15, 10))
panel_specs = (
    (ax[0, 0], image_psf, "Original Image (Signal)", "gray"),
    (ax[0, 1], noise, "Original Noise", "gray"),
    (ax[0, 2], image_with_noise, "Original Image + Noise", "gray"),
    (ax[1, 0], reprojected_image, "Reprojected Image (Signal)", "gray"),
    (ax[1, 1], reprojected_noise, "Reprojected Noise", "gray"),
    (ax[1, 2], np.abs(combined_reprojection - separate_sum),
     "Difference: f(A1 + N) vs f(A1) + f(N)", "hot"),
)
for panel_axis, panel_data, panel_title, panel_cmap in panel_specs:
    panel_axis.imshow(panel_data, origin="lower", cmap=panel_cmap)
    panel_axis.set_title(panel_title)

plt.tight_layout()
plt.show()
