In [1]:
from phastphase.retrieval_jax import refine, retrieve
%load_ext autoreload
%autoreload 2

In [2]:
from phastphase.retrieval_jax.alternative_methods._gradient_flows import (
    wirtinger_flow, amplitude_flow, truncated_amplitude_flow, truncated_wirtinger_flow)


In [3]:
import jax.numpy as jnp
import cv2 as cv
import jax
# 64-bit floats are required for the jnp.float64 arrays built below (JAX uses
# 32-bit by default), and new arrays are placed on the CPU while the large
# source image is loaded.
jax.config.update("jax_enable_x64", True)
cpu_device = jax.devices("cpu")[0]
jax.config.update("jax_default_device", cpu_device)

In [4]:
def import_normalized_figure_data(figure_name):
    """Read an image file and return its HLS channels scaled to [0, 1].

    Args:
        figure_name: path to an image file readable by ``cv.imread``.

    Returns:
        jnp array of shape (rows, cols, 3) holding the channels in
        (H, L, S) order as produced by ``cv.COLOR_BGR2HLS_FULL``, each
        divided by 255. The dtype is float64 when ``jax_enable_x64`` is on.

    Raises:
        OSError: if ``cv.imread`` could not read the file (missing file,
            unreadable data, or unsupported format).
    """
    img_bgr = cv.imread(figure_name)
    # cv.imread signals failure by returning None instead of raising; fail
    # here with a clear message rather than an opaque error from cvtColor.
    if img_bgr is None:
        raise OSError(
            f"cv.imread could not read {figure_name!r} "
            "(missing file, unreadable, or unsupported format)"
        )
    # The *_FULL conversion spreads hue over the whole 0-255 byte range
    # (instead of 0-179), so one division by 255 puts H, L and S on the same
    # [0, 1] scale.
    img_hls = cv.cvtColor(img_bgr, cv.COLOR_BGR2HLS_FULL)
    return jnp.array(img_hls, dtype=jnp.float64) / 255.0

def convert_hls_to_complex(hsl_array):
    """Encode a normalized HLS image as a complex field.

    Despite the parameter name, channels are read in H, L, S order:
    ``[:, :, 0]`` is hue, ``[:, :, 1]`` is lightness, ``[:, :, 2]`` is
    saturation. Values are presumably already in [0, 1] (as returned by
    ``import_normalized_figure_data``).

    The field has magnitude ``sqrt(L)``, so ``|z|**2 == L``. Its phase is
    ``2*pi*H``, so hue is one full turn.

    Returns:
        (complex_image, s_channel): the complex field and the untouched
        saturation channel, which the complex encoding cannot carry.
    """
    hue, lightness, saturation = (hsl_array[:, :, k] for k in range(3))

    amplitude = jnp.sqrt(lightness)
    angle = hue * 2 * jnp.pi

    field = amplitude * jnp.cos(angle) + 1j * (amplitude * jnp.sin(angle))
    return field, saturation

def convert_complex_to_hls(complex_image, saturation_channel):
    """Map a complex field back onto normalized HLS channels.

    Inverse of the complex encoding used by ``convert_hls_to_complex``:
    lightness is the intensity ``|z|**2``, and hue is the phase as a
    fraction of a full turn.

    Args:
        complex_image: complex array, same shape as ``saturation_channel``.
        saturation_channel: real array re-attached as the S channel.

    Returns:
        Array of shape ``complex_image.shape + (3,)`` with channels in
        (H, L, S) order. H is in [0, 1); L and S are clipped to [0, 1].
    """
    real_part = jnp.real(complex_image)
    imag_part = jnp.imag(complex_image)

    # Lightness stores the intensity directly; taking a square root and then
    # squaring it again would only add an extra rounding step.
    intensity = real_part**2 + imag_part**2
    phase = jnp.angle(complex_image)

    l_channel = jnp.clip(intensity, 0.0, 1.0)
    h_channel = jnp.mod(phase / (2 * jnp.pi), 1.0)
    # For a tiny negative phase, mod() computes 1 - tiny, which rounds to
    # exactly 1.0. Wrap it to 0.0 (the same angle) so hue stays in [0, 1).
    h_channel = jnp.where(h_channel >= 1.0, 0.0, h_channel)
    s_channel = jnp.clip(saturation_channel, 0.0, 1.0)

    return jnp.stack([h_channel, l_channel, s_channel], axis=-1)


def import_complex_figure_data(figure_name):
    """Load an image file and encode it as a complex field.

    This composes ``import_normalized_figure_data`` (read the file, convert
    to HLS, scale to [0, 1]) with ``convert_hls_to_complex``.

    Returns:
        (complex_image, saturation_channel), exactly as produced by
        ``convert_hls_to_complex``.
    """
    normalized_hls = import_normalized_figure_data(figure_name)
    return convert_hls_to_complex(normalized_hls)




In [5]:
# Load the full-resolution figure as a complex field (`comp`) plus its HLS
# saturation channel (`sat`), which the complex encoding does not carry.
figure_path = "PillarsOfCreationLarge.png"
comp, sat = import_complex_figure_data(figure_path)

In [6]:
# Shape of the full-resolution complex field, (rows, cols).
field_shape = comp.shape
print(field_shape)

(14589, 8423)


In [None]:
# Phase-retrieval experiment on a crop of the loaded field.
CROP_SIZE = 512         # side length of the top-left crop of `comp`
OVERSAMPLING = 2        # far-field grid is OVERSAMPLING x the crop size per axis
REFERENCE_SCALE = 10    # [0, 0] is set to REFERENCE_SCALE * (number of crop pixels)
MAX_ITERS = 1000
GRAD_TOLERANCE = 1e-14

# An earlier cell pinned the default device to the CPU for loading the full
# image. Switch to the first device of JAX's default backend (an accelerator,
# if one is available) for the retrieval itself.
jax.config.update('jax_default_device', jax.devices()[0])

small_comp = comp[:CROP_SIZE, :CROP_SIZE]
overs = OVERSAMPLING
normalized_comp = small_comp / jnp.linalg.vector_norm(small_comp)
# Overwrite pixel [0, 0] of the unit-norm field with a large real value so a
# single point dominates it. NOTE(review): presumably this enforces the
# dominant-reference (Schwarz) condition that `retrieve` relies on; confirm.
schwarz_comp = normalized_comp.at[0, 0].set(REFERENCE_SCALE * jnp.size(small_comp))
far_field_shape = (small_comp.shape[0] * overs, small_comp.shape[1] * overs)
# Simulated measurement: far-field intensity of the zero-padded crop, using
# the unitary ("ortho") FFT normalization.
y = jnp.abs(jnp.fft.fft2(schwarz_comp, s=far_field_shape, norm="ortho"))**2

# All-ones mask with the crop's shape. NOTE(review): ones_like inherits the
# complex dtype of small_comp; confirm `retrieve` expects that.
mask = jnp.ones_like(small_comp)
x_out, val = retrieve(
    y,
    mask,
    max_iters=MAX_ITERS,
    descent_method=0,
    grad_tolerance=GRAD_TOLERANCE,
    winding_guess=(0, 0),
)

