### **Imports and utils**

Importing libraries

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from astropy.io import fits
from keras.layers import (
    Activation,
    Add,
    BatchNormalization,
    Concatenate,
    Conv2D,
    Input,
    Layer,
    UpSampling2D,
)
from keras.models import Model
from keras.optimizers import Adam

Utility functions

In [None]:
def read_header_and_data(file_path, idx):
    """Read the header and data of one HDU from a FITS file.

    Parameters
    ----------
    file_path
        Location of the FITS file; forwarded to ``fits.open``.
    idx
        Index of the HDU to read inside the opened file.

    Returns
    -------
    tuple
        ``(header, data)`` of the selected HDU.
    """
    with fits.open(file_path) as hdu_list:
        selected_hdu = hdu_list[idx]
        # Both attributes are read while the file is still open.
        header, data = selected_hdu.header, selected_hdu.data

    return header, data

def show_image(image_data, cmap='gist_gray'):
    """Draw ``image_data`` in a new matplotlib figure with a colorbar.

    Parameters
    ----------
    image_data
        Array-like image handed to ``plt.imshow``.
    cmap
        Name of the matplotlib colormap (``'gist_gray'`` by default).
    """
    plt.figure()
    drawn_image = plt.imshow(image_data, cmap=cmap)
    # The colorbar is attached to the image that was just drawn.
    plt.colorbar(drawn_image)

def mean_subtract(img_data):
    """Return ``img_data`` shifted so that its mean is zero.

    The input is left untouched and a new array is returned. The previous
    in-place ``-=`` silently modified the caller's array and raised a casting
    error for integer-typed images.

    Parameters
    ----------
    img_data
        Array-like image (any shape, any numeric dtype).

    Returns
    -------
    numpy.ndarray
        ``img_data - mean(img_data)``; integer inputs are promoted to float.
    """
    pixels = np.asarray(img_data)

    # Out-of-place subtraction: keeps the input intact and lets NumPy pick a
    # result dtype that can hold the (generally fractional) centred values.
    return pixels - np.mean(pixels)

def normalize(array):
    """Min-max scale ``array`` to the range [0, 1].

    Parameters
    ----------
    array
        Array-like of numbers.

    Returns
    -------
    numpy.ndarray
        ``(array - min) / (max - min)``. When every element is equal the
        range is zero; an all-zero array is returned in that case instead of
        the NaNs a 0/0 division would produce.
    """
    values = np.asarray(array)
    arr_min = np.min(values)
    value_range = np.max(values) - arr_min

    if value_range == 0:
        # Constant input: dividing the all-zero numerator by 1 keeps the
        # output finite while preserving the usual result dtype.
        value_range = 1

    return (values - arr_min) / value_range

def normalize_list(list_arrays):
    """Apply ``normalize`` to every array in ``list_arrays`` independently.

    Parameters
    ----------
    list_arrays
        Iterable of arrays.

    Returns
    -------
    numpy.ndarray
        The individually normalized arrays stacked with ``np.array``.
    """
    return np.array([normalize(single_array) for single_array in list_arrays])

### **Denoising**

Opening denoising data

In [None]:
# Read header and data of HDU 0 from the denoising dataset.
image_file_path = 'outputs/output_denoising_snr100.fits'
img_head, img_data = read_header_and_data(file_path=image_file_path, idx=0)

ResUNet

In [None]:
# from paper Qian, H. et al 2022

def resnet_block(x, x_shortcut, num_filters, kernel_size=3):
    """Build one residual block.

    The main path applies Conv -> BatchNorm -> ReLU -> Conv to ``x``. The
    shortcut path projects ``x_shortcut`` with a 1x1 convolution to
    ``num_filters`` channels, and the two paths are summed.

    Parameters
    ----------
    x
        Tensor fed to the main path.
    x_shortcut
        Tensor fed to the shortcut projection.
    num_filters
        Number of output channels of every convolution in the block.
    kernel_size
        Kernel size of the two main-path convolutions.

    Returns
    -------
    Tensor
        Sum of the projected shortcut and the main-path output.
    """
    # Main path
    main_path = Conv2D(num_filters, kernel_size, padding='same')(x)
    main_path = BatchNormalization()(main_path)
    main_path = Activation('relu')(main_path)
    main_path = Conv2D(num_filters, kernel_size, padding='same')(main_path)

    # Shortcut path: 1x1 projection only (no BatchNormalization is applied here).
    projected_shortcut = Conv2D(num_filters, kernel_size=1, padding='same')(x_shortcut)

    return Add()([projected_shortcut, main_path])

def upsample_concatenate(x, skip):
    """Upsample ``x`` by a factor of 2 per spatial axis and concatenate it with ``skip``.

    Parameters
    ----------
    x
        Tensor to upsample with ``UpSampling2D((2, 2))``.
    skip
        Tensor concatenated after the upsampled ``x`` using the default
        ``Concatenate`` axis.

    Returns
    -------
    Tensor
        ``Concatenate()([upsampled_x, skip])``.
    """
    upsampled = UpSampling2D((2, 2))(x)
    merged = Concatenate()([upsampled, skip])

    return merged

def resUnet(input_shape=(48, 48, 1), num_filters=64, num_resnetblocks=4, kernel_size=3):
    """Assemble the ResUNet model (encoder, transition, decoder, sigmoid head).

    Parameters
    ----------
    input_shape
        Shape of one input sample, passed to ``Input``.
    num_filters
        Channel count used by the residual blocks; the transition
        convolutions use ``2 * num_filters``.
    num_resnetblocks
        Number of encoder residual blocks. Every block but the last is
        followed by a stride-2 convolution, and the decoder mirrors those
        downsampling steps.
    kernel_size
        Kernel size of all convolutions except the final 1x1 output layer.

    Returns
    -------
    Model
        Keras model mapping the input to a single-channel sigmoid output.
    """
    inputs = Input(input_shape)

    # Encoder: residual blocks, each (except the last) followed by a stride-2
    # convolution. The pre-downsampling features are kept as skip connections.
    features = inputs
    encoder_skips = []
    last_block = num_resnetblocks - 1
    for block_idx in range(num_resnetblocks):
        if block_idx != 0:
            features = BatchNormalization()(features)
        features = resnet_block(features, features, num_filters, kernel_size)
        if block_idx != last_block:
            encoder_skips.append(features)
            features = Conv2D(num_filters, kernel_size, strides=2, padding='same')(features)

    # Transition: two BatchNorm -> ReLU -> Conv stages with doubled filters.
    for _ in range(2):
        features = BatchNormalization()(features)
        features = Activation('relu')(features)
        features = Conv2D(num_filters * 2, kernel_size, padding='same')(features)

    # Decoder: walk the skip connections from deepest to shallowest. The
    # concatenated tensor (before BatchNorm/ReLU) is used as the block shortcut.
    for skip in encoder_skips[::-1]:
        merged = upsample_concatenate(features, skip)
        features = BatchNormalization()(merged)
        features = Activation('relu')(features)
        features = resnet_block(features, merged, num_filters, kernel_size)

    features = BatchNormalization()(features)
    features = Conv2D(num_filters, kernel_size, padding='same')(features)

    # 1x1 convolution with a sigmoid produces the single output channel.
    outputs = Conv2D(1, 1, activation='sigmoid')(features)

    return Model(inputs, outputs)

Training ResUnet

In [None]:
# Scale every frame to [0, 1] independently.
img_data = normalize_list(img_data)
img_size = img_data[0].shape[0]
num_total = len(img_data)
train_split = 0.7

# The even/odd indexing below treats frames (2i, 2i+1) as one (input, target)
# pair, so the train/test boundary must fall on an even index. Splitting on
# whole pairs guarantees that; a trailing unpaired frame (odd num_total) is
# dropped. For the same reason individual frames must not be shuffled before
# this step.
num_pairs = num_total // 2
train_idx = 2 * int(num_pairs * train_split)

# Add the trailing channel axis: (num_total, img_size, img_size, 1).
# Like the per-frame reshape it replaces, this assumes square frames.
img_data = np.reshape(img_data, (num_total, img_size, img_size, 1))
img_shape = img_data[0].shape

# Splitting training and testing
train = img_data[:train_idx]
train_x = train[0::2]
train_y = train[1::2]

test = img_data[train_idx:2 * num_pairs]
test_x = test[0::2]
test_y = test[1::2]

# Inputs and targets must line up one-to-one in both splits.
assert len(train_x) == len(train_y), "training inputs/targets are misaligned"
assert len(test_x) == len(test_y), "test inputs/targets are misaligned"

idx = 0
show_image(test_x[idx])
show_image(test_y[idx])

In [None]:
# Build the denoising network for the shape of one prepared sample.
resunet_model = resUnet(
    input_shape=img_shape,
    num_filters=64,
    num_resnetblocks=4,
    kernel_size=3,
)
resunet_model.summary()

In [None]:
# Training configuration
batch_size = 128
epochs = 25

resunet_model.compile(optimizer=Adam(), loss='mean_squared_error', metrics=['mse'])

# The test split doubles as validation data during training.
history = resunet_model.fit(
    train_x,
    train_y,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    validation_data=(test_x, test_y),
)

In [None]:
# Run the trained network on the test inputs and show one input next to its prediction.
predictions = resunet_model.predict(test_x)
idx = 0
for image in (test_x[idx], predictions[idx]):
    show_image(image)

Saving model to file

In [None]:
# Persist the trained denoiser in the native Keras format.
resunet_model.save(filepath='resunet_100snr2.keras')

### **Deconvolution**

Importing images and splitting them into training and testing datasets

In [None]:
# Importing images
image_file_path = 'outputs/output_deconv_snr100_5k.fits'
img_head, img_data = read_header_and_data(image_file_path, 0) # img_data = [mod_img1, original_img1, mod_img2, original_img2, ...]

image_psf_file_path = 'outputs/output_deconv_snr100_5k_psf.fits'
psf_head, psf_data = read_header_and_data(image_psf_file_path, 0)

# The PSF stack is indexed exactly like the image stack below, so both must
# contain the same number of frames.
# NOTE(review): this presumes the PSF file uses the same interleaved layout as
# the image file - confirm against the script that writes the FITS files.
assert len(psf_data) == len(img_data), "image and PSF stacks differ in length"

# Getting image size
img_size = img_data[0].shape[0]

# Normalizing images (each frame is scaled to [0, 1] independently)
img_data = normalize_list(img_data)
psf_data = normalize_list(psf_data)

# Reshaping images: add the trailing channel axis. As before, this assumes
# square frames and PSFs that have the same size as the images.
img_data = np.reshape(img_data, (len(img_data), img_size, img_size, 1))
psf_data = np.reshape(psf_data, (len(psf_data), img_size, img_size, 1))
img_shape = img_data[0].shape

# Splitting training and testing.
# Frames (2i, 2i+1) form one (modified, original) pair, so the boundary must
# fall on an even index; otherwise inputs and targets end up with different
# lengths and the test split swaps their roles. A trailing unpaired frame is
# dropped.
train_split = 0.7
num_total = len(img_data)
num_pairs = num_total // 2
train_idx = 2 * int(num_pairs * train_split)
paired_end = 2 * num_pairs

img_train = img_data[:train_idx]
psf_train = psf_data[:train_idx]

img_test = img_data[train_idx:paired_end]
psf_test = psf_data[train_idx:paired_end]

# Image train
img_train_x = img_train[0::2]
img_train_y = img_train[1::2]

# PSF of the train images
psf_train_x = psf_train[0::2]

# Image test
img_test_x = img_test[0::2]
img_test_y = img_test[1::2]

# PSF of the test images
psf_test_x = psf_test[0::2]

# Inputs, PSFs and targets must line up one-to-one in both splits.
assert len(img_train_x) == len(img_train_y) == len(psf_train_x)
assert len(img_test_x) == len(img_test_y) == len(psf_test_x)

idx = 0
show_image(img_train_x[idx])
show_image(psf_train_x[idx])
show_image(img_train_y[idx])
show_image(img_test_x[idx])
show_image(psf_test_x[idx])

Unrolled ADMM Network

In [None]:
def compute_alpha(N, y, beta=0.33):
    """Return ``sum(y) / (N * beta)``.

    Parameters
    ----------
    N
        Scalar multiplied with ``beta`` to form the denominator.
    y
        Array whose entries are all summed.
    beta
        Scalar factor of the denominator (0.33 by default).
    """
    denominator = N * beta
    total = np.sum(y)

    return total / denominator

def _fft2_image(img):
    """2-D FFT over the spatial axes: real (batch, H, W, 1) -> complex64 (batch, H, W)."""
    return tf.signal.fft2d(tf.cast(img[..., 0], tf.complex64))


def _ifft2_image(spectrum):
    """Inverse of ``_fft2_image``: complex (batch, H, W) -> real (batch, H, W, 1)."""
    return tf.expand_dims(tf.math.real(tf.signal.ifft2d(spectrum)), axis=-1)


def _psf_spectrum(h):
    """Frequency response of the PSF stack ``h`` with shape (batch, H, W, 1).

    NOTE(review): ``ifftshift`` assumes each PSF is centred in its frame, so
    that multiplying by this spectrum is a circular convolution without a
    spatial shift - confirm against how the PSF file is generated.
    """
    return _fft2_image(tf.signal.ifftshift(h, axes=(1, 2)))


class _WienerInit(Layer):
    """Wiener-filter initialisation of the ADMM image estimate.

    Computes ``x0 = (1/alpha) * IFFT(conj(H) * Y / (1/alpha + |H|^2))`` with
    ``alpha = sum(y) / (size * beta)`` evaluated per sample, and returns
    ``(x0, zeros_like(x0))``; the zeros serve as initial Lagrange multipliers.
    """

    def __init__(self, size, beta=0.33, **kwargs):
        super().__init__(**kwargs)
        self.size = size
        self.beta = beta

    def call(self, inputs):
        y, h = inputs

        alpha = tf.reduce_sum(y, axis=[1, 2, 3], keepdims=True) / (self.size * self.beta)
        # Guard against an all-zero observation, which would give alpha == 0.
        alpha_inv = 1.0 / tf.maximum(alpha, 1e-6)            # (batch, 1, 1, 1)

        H = _psf_spectrum(h)
        Y = _fft2_image(y)
        denom = tf.cast(alpha_inv[..., 0] + tf.square(tf.abs(H)), tf.complex64)

        x0 = alpha_inv * _ifft2_image(tf.math.conj(H) * Y / denom)
        return x0, tf.zeros_like(x0)

    def get_config(self):
        config = super().get_config()
        config.update({"size": self.size, "beta": self.beta})
        return config


class _ADMMUpdate(Layer):
    """Non-learned part of one ADMM iteration (v-, x- and multiplier updates).

    Takes ``[x, z, u1, u2, y, h]`` where ``z`` is the denoiser output for
    ``x + u1`` and returns the updated ``(x, u1, u2)``.
    """

    def __init__(self, rho1=1.0, rho2=1.0, **kwargs):
        super().__init__(**kwargs)
        self.rho1 = rho1
        self.rho2 = rho2

    def call(self, inputs):
        x, z, u1, u2, y, h = inputs

        # The PSF spectrum is recomputed in every step so that only real
        # tensors are passed between Keras layers.
        H = _psf_spectrum(h)
        H_conj = tf.math.conj(H)
        rho_ratio = self.rho2 / self.rho1

        # v-update (Gaussian MLE): blend the re-blurred estimate with y.
        v_tilde = _ifft2_image(H * _fft2_image(x)) + u2
        v = (self.rho2 * v_tilde + y) / (1.0 + self.rho2)

        # x-update (eq. 12), solved in closed form in the frequency domain.
        X0_tilde = _fft2_image(z - u1)
        X1_tilde = _fft2_image(v - u2)
        denom = tf.cast(1.0 + rho_ratio * tf.square(tf.abs(H)), tf.complex64)
        numer = X0_tilde + tf.cast(rho_ratio, tf.complex64) * H_conj * X1_tilde
        x_new = _ifft2_image(numer / denom)

        # Update Lagrangian multipliers.
        u1_new = u1 + x_new - z
        u2_new = u2 + _ifft2_image(H * _fft2_image(x_new)) - v

        return x_new, u1_new, u2_new

    def get_config(self):
        config = super().get_config()
        config.update({"rho1": self.rho1, "rho2": self.rho2})
        return config


def unrolled_admm(input_shape, n):
    """Build the unrolled ADMM deconvolution network as a functional Keras model.

    The whole computation is expressed with TensorFlow ops inside Keras
    layers, so the model output is connected to its inputs and gradients reach
    the shared ResUNet denoiser. (Evaluating the symbolic inputs to NumPy, as
    was done before, is not possible and would cut the graph.)

    Changes relative to the NumPy formulation this replaces:
    * 2-D FFTs are used instead of the 1-D ``np.fft.fft``.
    * The blur ``h * x`` is a circular convolution computed as ``IFFT(H * FFT(x))``
      (``np.convolve`` only accepts 1-D inputs); this is the convolution model
      the frequency-domain x-update relies on.
    * ``conj(H)`` is used for the adjoint blur; no array transpose is applied.
    * The second multiplier update blurs with the PSF, not with its spectrum.

    Parameters
    ----------
    input_shape
        Shape ``(H, W, 1)`` shared by the observed image and the PSF.
    n
        Loop bound of the unrolled iterations: ``k`` runs over ``range(1, n)``,
        i.e. ``n - 1`` ADMM iterations are unrolled.
        NOTE(review): kept from the original loop, which carried an open
        "1..n?" question - confirm the intended iteration count.

    Returns
    -------
    Model
        Keras model with inputs ``[y, h]`` (observation, PSF) and a
        single-channel sigmoid output.
    """
    # Inputs
    y_input = Input(shape=input_shape)
    h_input = Input(shape=input_shape)

    # Denoising network; the same instance (shared weights) is applied in
    # every unrolled iteration.
    resunet = resUnet(input_shape, num_filters=64, num_resnetblocks=4, kernel_size=3)

    # Initialize x with the Wiener filter and the multipliers with zeros.
    # NOTE(review): N is the image side length (input_shape[0]) as before; if
    # the reference formula means the pixel count, pass H * W instead.
    x, zeros = _WienerInit(size=input_shape[0])([y_input, h_input])
    u1 = zeros
    u2 = zeros

    # Hyperparameters (fixed here; they could be made trainable).
    admm_update = _ADMMUpdate(rho1=1.0, rho2=1.0)

    # ADMM iterations
    for _ in range(1, n):
        z = resunet(Add()([x, u1]))  # eq 13: learned denoising step
        x, u1, u2 = admm_update([x, z, u1, u2, y_input, h_input])

    output = Conv2D(1, 1, activation='sigmoid', padding='same')(x)
    network = Model(inputs=[y_input, h_input], outputs=output)

    return network


In [None]:
# Build and compile the functional unrolled-ADMM model.
unrolled_admm_model = unrolled_admm(input_shape=(48, 48, 1), n=8)
unrolled_admm_model.compile(optimizer=Adam(), loss='mean_absolute_error', metrics=['mae'])

# Training configuration (batch_size=None leaves the choice to Keras).
batch_size = None # 128
epochs = 25

# The model takes the degraded image and its PSF; the test split doubles as
# validation data.
history = unrolled_admm_model.fit(
    [img_train_x, psf_train_x],
    img_train_y,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    validation_data=([img_test_x, psf_test_x], img_test_y),
)

In [None]:
# Creating model as subclass

class UnrolledADMM(Model):
    """Unrolled ADMM deconvolution network as a subclassed Keras model.

    ``call`` expects ``inputs = [y, h]``: the observed images and their PSFs,
    both shaped (batch, H, W, 1). It returns the image estimate after the last
    unrolled iteration.

    The forward pass uses TensorFlow ops only, so it works inside ``fit`` and
    gradients reach the denoiser. Differences from the NumPy formulation this
    replaces:
    * The ResUNet denoiser is created once in ``__init__`` (creating it inside
      ``call`` would give fresh, untrained weights on every forward pass).
    * 2-D FFTs replace the 1-D ``np.fft.fft``; the blur is a circular
      convolution computed as ``IFFT(H * FFT(x))`` (``np.convolve`` is 1-D only).
    * ``conj(H)`` is used for the adjoint blur; no array transpose is applied.
    * Iterates are plain variables; the previous index assignments into empty
      lists raised ``IndexError``.
    * Debug ``print`` calls were removed.
    """

    def __init__(self, n):
        super().__init__()
        self.n = n

        # Hyperparameters (fixed here; they could be made trainable).
        self.rho1 = 1.0
        self.rho2 = 1.0

        # Shared denoiser. The ResUNet is fully convolutional, so the spatial
        # size is left unspecified and fixed by the data at call time.
        self.denoiser = resUnet(input_shape=(None, None, 1), num_filters=64, num_resnetblocks=4, kernel_size=3)

    @staticmethod
    def _fft2(img):
        """2-D FFT over the spatial axes: real (batch, H, W, 1) -> complex64 (batch, H, W)."""
        return tf.signal.fft2d(tf.cast(img[..., 0], tf.complex64))

    @staticmethod
    def _ifft2(spectrum):
        """Inverse of ``_fft2``: complex (batch, H, W) -> real (batch, H, W, 1)."""
        return tf.expand_dims(tf.math.real(tf.signal.ifft2d(spectrum)), axis=-1)

    def compute_alpha(self, N, y, beta=0.33, axis=None):
        """Return ``sum(y) / (N * beta)`` as a tensor.

        Parameters
        ----------
        N
            Scalar multiplied with ``beta`` to form the denominator.
        y
            Tensor or array to sum.
        beta
            Scalar factor of the denominator.
        axis
            Axes to sum over. ``None`` (default) sums every entry of ``y``;
            when given, the reduced axes are kept with length 1 so the result
            broadcasts against ``y``.
        """
        total = tf.reduce_sum(y, axis=axis, keepdims=axis is not None)
        return total / (N * beta)

    def call(self, inputs, training=None, mask=None):
        """Run the Wiener initialisation followed by the unrolled ADMM iterations.

        ``k`` runs over ``range(1, self.n)``, i.e. ``n - 1`` iterations.
        NOTE(review): loop bound kept from the original code, which carried an
        open "1..n?" question - confirm the intended iteration count.
        """
        y = tf.cast(inputs[0], tf.float32)
        h = tf.cast(inputs[1], tf.float32)

        # Initializing parameters (alpha is computed per sample).
        # NOTE(review): N is the image side length; if the reference formula
        # means the pixel count, use H * W instead.
        N = tf.cast(tf.shape(y)[1], y.dtype)
        alpha = self.compute_alpha(N, y, axis=[1, 2, 3])      # (batch, 1, 1, 1)
        # Guard against an all-zero observation, which would give alpha == 0.
        alpha_inv = 1.0 / tf.maximum(alpha, 1e-6)

        # PSF frequency response.
        # NOTE(review): ifftshift assumes each PSF is centred in its frame -
        # confirm against how the PSF file is generated.
        H = self._fft2(tf.signal.ifftshift(h, axes=(1, 2)))
        H_conj = tf.math.conj(H)
        H_abs_sqr = tf.square(tf.abs(H))

        # Initialize x with Wiener Filter
        Y = self._fft2(y)
        wiener_denom = tf.cast(alpha_inv[..., 0] + H_abs_sqr, tf.complex64)
        x = alpha_inv * self._ifft2(H_conj * Y / wiener_denom)

        # Lagrange multipliers start at zero.
        u1 = tf.zeros_like(x)
        u2 = tf.zeros_like(x)

        # Loop-invariant pieces of the x-update.
        rho_ratio = self.rho2 / self.rho1
        x_denom = tf.cast(1.0 + rho_ratio * H_abs_sqr, tf.complex64)
        rho_ratio_c = tf.cast(rho_ratio, tf.complex64)

        # ADMM iterations
        for _ in range(1, self.n):
            # v-update (Gaussian MLE): blend the re-blurred estimate with y.
            v_tilde = self._ifft2(H * self._fft2(x)) + u2
            v = (self.rho2 * v_tilde + y) / (1.0 + self.rho2)

            # z-update (eq 13): learned denoising step.
            z = self.denoiser(x + u1, training=training)

            # x-update (eq 12), solved in closed form in the frequency domain.
            X0_tilde = self._fft2(z - u1)
            X1_tilde = self._fft2(v - u2)
            x = self._ifft2((X0_tilde + rho_ratio_c * H_conj * X1_tilde) / x_denom)

            # Update Lagrangian multipliers.
            u1 = u1 + x - z
            u2 = u2 + self._ifft2(H * self._fft2(x)) - v

        return x  # image estimate after the last iteration
        

In [None]:
# Build and compile the subclassed unrolled-ADMM model.
unrolled_admm_model = UnrolledADMM(n=8)
unrolled_admm_model.compile(optimizer=Adam(), loss='mean_absolute_error', metrics=['mae'])

# Training configuration (batch_size=None leaves the choice to Keras).
batch_size = None # 128
epochs = 25

# The model takes the degraded image and its PSF; the test split doubles as
# validation data.
history = unrolled_admm_model.fit(
    [img_train_x, psf_train_x],
    img_train_y,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    validation_data=([img_test_x, psf_test_x], img_test_y),
)