In [None]:
import datetime
import glob
import os

import fitsio
import matplotlib.pyplot as plt
import numpy as np
from astropy import table
from astropy.io import fits
from astropy.nddata.utils import Cutout2D
from astropy.visualization import (ZScaleInterval, ImageNormalize)
from tensorflow import keras

# 1) Cut out 64x64 Patches from Image Data

In [None]:
z_image = fitsio.read("/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/PS1.262.284.z.fits")
z_image

In [None]:
z_weights = fitsio.read("/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/PS1.262.284.z.wt.fits")
z_weights

In [None]:
img = fits.open("/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/PS1.262.284.g.fits")
print(img.info())
img.close()

In [None]:
len(np.where(np.isnan(z_image))[0])

In [None]:
z_image = np.where(np.isnan(z_image), 0, z_image)
len(np.where(np.isnan(z_image))[0])

In [None]:
r_image = fitsio.read("/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/CFIS.264.282.r.fits")
len(np.where(np.isnan(r_image))[0])

In [None]:
plt.imshow(z_image, cmap='gray')
plt.colorbar()

In [None]:
norm = ImageNormalize(z_image, interval=ZScaleInterval())
plt.imshow(z_image,norm=norm)
plt.colorbar()

In [None]:
norm

In [None]:
image_cat = table.Table.read("/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/CFIS.264.282.r.cat", format="ascii.sextractor")

In [None]:
image_cat

In [None]:
# Create 64x64 patches of sources
sources = []
for (x, y) in zip(image_cat["X_IMAGE"], image_cat["Y_IMAGE"]): # Centers of sources
    sources.append(Cutout2D(r_image, (x, y), 64, mode="partial", fill_value=0).data)
sources = np.array(sources)

In [None]:
sources.shape

In [None]:
# Plot first 9 sources
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 0
for row in range(3):
    for col in range(3):
        axes[row][col].imshow(sources[i])
        i += 1

In [None]:
sources[7][0]

In [None]:
# Plot first 9 sources
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 0
for row in range(3):
    for col in range(3):
        norm = ImageNormalize(sources[i], interval=ZScaleInterval())
        axes[row][col].imshow(sources[i], norm=norm)
        i += 1

# 2) Get Cutouts from CFIS Tiles

In [None]:
# %run make_cutouts.py

In [None]:
import glob

In [None]:
weight_image = fits.open("/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/CFIS.264.282.r.weight.fits.fz")
weight_image.info()

In [None]:
image_dir = "/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/"
tiles = sorted(glob.glob(image_dir + "CFIS.*.r.fits"))
weights = sorted(glob.glob(image_dir + "CFIS.*.r.weight*"))
cats = sorted(glob.glob(image_dir + "CFIS.*.r.cat"))

In [None]:
print(len(tiles))
print(len(weights))
print(len(cats))

In [None]:
sources_cfis = []
sources_norm_keras = []
size = 64
n_tiles = 3  # number of r-band tiles to cut sources from
for tile_path in tiles[:n_tiles]:
    # Derive each tile's catalogue from its own filename
    # (".../CFIS.264.282.r.fits" -> ".../CFIS.264.282.r.cat") instead of zipping
    # separately sorted lists, where one missing catalogue would shift the pairing.
    cat_path = tile_path[:-len(".fits")] + ".cat"
    with fits.open(tile_path) as image:
        image_data = image[0].data
    image_cat = table.Table.read(cat_path, format="ascii.sextractor")
    for x, y in zip(image_cat["X_IMAGE"], image_cat["Y_IMAGE"]):
        # SExtractor positions are 1-based; Cutout2D expects 0-based pixel coordinates
        cutout = Cutout2D(image_data, (x - 1, y - 1), size, mode="partial", fill_value=0).data
        sources_cfis.append(cutout)
        # keras.utils.normalize: L2-normalises along the last axis (each pixel row)
        sources_norm_keras.append(keras.utils.normalize(cutout))
# Stack and append a trailing channel axis of length 1
sources_cfis = np.array(sources_cfis)[..., np.newaxis]
sources_norm_keras = np.array(sources_norm_keras)[..., np.newaxis]

In [None]:
np.shape(sources_cfis)

In [None]:
len(sources_cfis[np.isnan(sources_cfis)])

In [None]:
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 10000
for row in range(3):
    for col in range(3):
        norm = ImageNormalize(sources_cfis[i], interval=ZScaleInterval())
        axes[row][col].imshow(sources_cfis[i], norm=norm)
        i += 1

In [None]:
# Normalize data between 0 and 1
sources_norm_01 = (sources_cfis - np.min(sources_cfis)) / (np.max(sources_cfis) - np.min(sources_cfis))

In [None]:
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 10000
for row in range(3):
    for col in range(3):
        norm = ImageNormalize(sources_norm_01[i], interval=ZScaleInterval())
        axes[row][col].imshow(sources_norm_01[i], norm=norm)
        i += 1

In [None]:
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 10000
for row in range(3):
    for col in range(3):
        norm = ImageNormalize(sources_norm_keras[i], interval=ZScaleInterval())
        axes[row][col].imshow(sources_norm_keras[i], norm=norm)
        i += 1

# 3) Get Cutouts from Multiple Channels

In [None]:
g_weights_file = fits.open("/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/PS1.262.284.g.wt.fits")
g_weights_file.info()

In [None]:
r_weights_file = fits.open("/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/CFIS.262.284.r.weight.fits.fz")
r_weights_file.info()
r_weights_file.close()

In [None]:
g_weights_data = g_weights_file[0].data
g_weights_data
g_weights_file.close()

In [None]:
plt.imshow(g_weights_data, cmap='gray')
plt.colorbar()

In [None]:
norm = ImageNormalize(g_weights_data, interval=ZScaleInterval())
plt.imshow(g_weights_data,norm=norm)
plt.colorbar()

In [None]:
g_fits = sorted(glob.glob(image_dir + "PS1.*.g.fits"))
g_weights = sorted(glob.glob(image_dir + "PS1.*.g.wt.fits"))
i_fits = sorted(glob.glob(image_dir + "PS1.*.i.fits"))
i_weights = sorted(glob.glob(image_dir + "PS1.*.i.wt.fits"))
r_fits = tiles
r_weights = weights
z_fits = sorted(glob.glob(image_dir + "PS1.*.z.fits"))
z_weights = sorted(glob.glob(image_dir + "PS1.*.z.wt.fits"))

In [None]:
print(len(g_fits))
print(len(g_weights))
print(len(i_fits))
print(len(i_weights))
print(len(r_fits))
print(len(r_weights))
print(len(z_fits))
print(len(z_weights))

In [None]:
# Only use tiles with all four channels (PS1 g/i/z + CFIS r), their weight maps
# and the CFIS r-band source catalogue.

# Unique tile identifiers such as "262.284" (from "PS1.262.284.g.fits"),
# kept in sorted order of first appearance.
PS1 = sorted(glob.glob(image_dir + "PS1.*"))
ra_dec = list(dict.fromkeys(
    ".".join(os.path.basename(path).split(".")[1:3]) for path in PS1
))

g_fits = []
g_weights = []
i_fits = []
i_weights = []
r_fits = []
r_weights = []
z_fits = []
z_weights = []
cats = []

# (destination list, filename prefix, filename suffix) for every file a tile needs
required_files = (
    (g_fits, "PS1.", ".g.fits"),
    (g_weights, "PS1.", ".g.wt.fits"),
    (i_fits, "PS1.", ".i.fits"),
    (i_weights, "PS1.", ".i.wt.fits"),
    (r_fits, "CFIS.", ".r.fits"),
    (r_weights, "CFIS.", ".r.weight.fits.fz"),
    (z_fits, "PS1.", ".z.fits"),
    (z_weights, "PS1.", ".z.wt.fits"),
    (cats, "CFIS.", ".r.cat"),
)

for ra_dec_str in ra_dec:
    paths = [image_dir + prefix + ra_dec_str + suffix for _, prefix, suffix in required_files]
    # Test existence without opening anything, so no file handles are left
    # behind when one of a tile's files is missing.
    if not all(os.path.isfile(path) for path in paths):
        continue
    for (file_list, _, _), path in zip(required_files, paths):
        file_list.append(path)

In [None]:
print(len(g_fits))
print(len(g_weights))
print(len(i_fits))
print(len(i_weights))
print(len(r_fits))
print(len(r_weights))
print(len(z_fits))
print(len(z_weights))
print(len(cats))

In [None]:
n_cutouts = 0
for i in range(len(cats)):
    cat = table.Table.read(cats[i], format="ascii.sextractor")
    n_cutouts += 2 * len(cat) # fits/weights files

In [None]:
print("Number of cutouts: " + str(n_cutouts))
print("Tensor size: " + str(n_cutouts*size*size*4))

In [None]:
# Only use up to max_cutouts cutouts to reduce tensor size
max_cutouts = 30000
n_tiles = 1  # number of complete tiles to draw sources from
channel_files = (r_fits, g_fits, i_fits, z_fits)  # Order of channels is r,g,i,z
sources_multi = np.zeros((max_cutouts, size, size, len(channel_files)))
n = 0
for tile in range(min(n_tiles, len(cats))):
    if n == max_cutouts:
        break
    # Read every band's pixels for this tile; `with` closes each file afterwards
    channel_data = []
    for band_files in channel_files:
        with fits.open(band_files[tile]) as hdul:
            channel_data.append(hdul[0].data)

    # NOTE(review): PS1 g/i/z are cut at CFIS r-band catalogue positions, which
    # assumes the PS1 tiles share the CFIS pixel grid -- TODO confirm.
    cat = table.Table.read(cats[tile], format="ascii.sextractor")
    for x, y in zip(cat["X_IMAGE"], cat["Y_IMAGE"]):
        if n == max_cutouts:
            break
        # SExtractor positions are 1-based; Cutout2D expects 0-based pixel coordinates
        for channel, data in enumerate(channel_data):
            sources_multi[n, :, :, channel] = Cutout2D(data, (x - 1, y - 1), size, mode="partial", fill_value=0).data
        n += 1

# Drop the unused all-zero rows left over when fewer than max_cutouts sources were
# found, so they cannot leak into the train/test sets.
sources_multi = sources_multi[:n]

In [None]:
sources_multi.shape

In [None]:
30000*32*32*4

In [None]:
r_fits_file = fits.open(r_fits[0])
r_fits_file.info()

In [None]:
r_fits_data = r_fits_file[0].data
r_fits_file.close()
r_fits_data

In [None]:
g_fits_file = fits.open("/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/PS1.262.284.g.fits")
g_fits_data = g_fits_file[0].data
g_fits_file.close()
g_fits_data

In [None]:
g_fits_file = fits.open("/home/eyvorch9/projects/rrg-kyi/astro/cfis/W3/PS1.262.284.g.fits")
g_fits_file.info()

In [None]:
i_fits_file = fits.open(i_fits[0])
i_fits_data = i_fits_file[0].data
i_fits_file.close()
sum(np.isnan(i_fits_data))

In [None]:
z_fits_file = fits.open(z_fits[0])
z_fits_data = z_fits_file[0].data
z_fits_file.close()
z_fits_data

# 4) Build Autoencoder for MNIST Data

In [None]:
#import datetime

In [None]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

In [None]:
x_train = x_train.astype("float32") / 255.0
x_train = x_train.reshape(*x_train.shape, 1)
x_test = x_test.astype("float32") / 255.0
x_test = x_test.reshape(*x_test.shape, 1)

y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)

In [None]:
class Autoencoder(keras.Model):
    """Convolutional autoencoder built from two Sequential sub-models.

    Encoder: Conv2D(8) -> 2x2 max-pool -> Conv2D(4).
    Decoder: Conv2DTranspose(4) -> 2x upsample -> Conv2DTranspose(8)
    -> BatchNormalization -> linear Conv2D with one filter per input channel.

    NOTE(review): the next cell defines another class with this same name, so any
    later ``Autoencoder(...)`` call builds that dense model, not this one.

    Parameters
    ----------
    shape : tuple of int
        Input shape (height, width, channels). Height and width should be even so
        that pooling followed by upsampling restores the input size.
    """

    def __init__(self, shape):
        # Zero-argument super() binds to this class itself. super(Autoencoder, self)
        # looks the global name up at call time; once the next cell rebinds it,
        # constructing this class would raise TypeError.
        super().__init__()
        self.encoder = keras.Sequential([
            keras.layers.Input(shape=shape),
            keras.layers.Conv2D(8, kernel_size=3, activation='relu', padding='same'),
            keras.layers.MaxPooling2D(),
            keras.layers.Conv2D(4, kernel_size=3, activation='relu', padding='same')])

        self.decoder = keras.Sequential([
            keras.layers.Conv2DTranspose(4, kernel_size=3, activation='relu', padding='same'),
            keras.layers.UpSampling2D((2,2)),
            keras.layers.Conv2DTranspose(8, kernel_size=3, activation='relu', padding='same'),
            keras.layers.BatchNormalization(),
            # One output filter per input channel so reconstructions match the input
            keras.layers.Conv2D(shape[-1], (3,3), activation='linear', padding='same')])

    def call(self, x):
        """Return the reconstruction ``decoder(encoder(x))``."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [None]:
class Autoencoder(keras.Model):
    """Fully connected autoencoder: flatten -> Dense(latent_dim) -> Dense -> reshape.

    Neither Dense layer has an activation, so the model is a linear map with a
    ``latent_dim``-wide bottleneck.

    Parameters
    ----------
    shape : tuple of int
        Shape of one input sample, e.g. (height, width, channels).
    latent_dim : int, optional
        Width of the bottleneck layer (default 1000).
    """

    def __init__(self, shape, latent_dim=1000):
        super().__init__()
        self.encoder = keras.Sequential([
            keras.layers.InputLayer(shape),
            keras.layers.Flatten(),
            keras.layers.Dense(latent_dim)])

        self.decoder = keras.Sequential([
            # Must be a real tuple: InputLayer(1000,) passes the bare int 1000,
            # which newer Keras versions reject as a shape.
            keras.layers.InputLayer((latent_dim,)),
            keras.layers.Dense(int(np.prod(shape))),
            keras.layers.Reshape(shape)])

    def call(self, x):
        """Return the reconstruction ``decoder(encoder(x))``."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded

In [None]:
def create_autoencoder(shape):
    """Build an uncompiled convolutional autoencoder with the Keras functional API.

    Encoder: Conv2D(8) -> 2x2 max-pool -> Conv2D(16).
    Decoder: Conv2DTranspose(16) -> 2x upsample -> Conv2DTranspose(8)
    -> linear Conv2D with one filter per input channel, so the reconstruction has
    the same channel count as the input (1 for the single-channel callers).

    Parameters
    ----------
    shape : tuple of int
        Input shape (height, width, channels). Height and width should be even so
        that pooling followed by upsampling restores the input size.

    Returns
    -------
    keras.Model
        Model mapping an input batch to its reconstruction.
    """
    n_channels = shape[-1]
    input_img = keras.Input(shape=shape)
    x = keras.layers.Conv2D(8, kernel_size=3, activation='relu', padding='same')(input_img)
    x = keras.layers.MaxPooling2D((2,2), padding='same')(x)
    encoded = keras.layers.Conv2D(16, kernel_size=3, activation='relu', padding='same')(x)

    x = keras.layers.Conv2DTranspose(16, kernel_size=3, activation='relu', padding='same')(encoded)
    x = keras.layers.UpSampling2D((2,2))(x)
    x = keras.layers.Conv2DTranspose(8, kernel_size=3, activation='relu', padding='same')(x)
    decoded = keras.layers.Conv2D(n_channels, (3,3), activation='linear', padding='same')(x)

    return keras.Model(input_img, decoded)

In [None]:
autoencoder = create_autoencoder((x_train.shape[1], x_train.shape[2], 1))
autoencoder.compile(optimizer='adam', loss="mse")

In [None]:
autoencoder.summary()

In [None]:
#log_dir = "logs/fit/MNIST" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
#tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
autoencoder.fit(x_train, x_train, epochs=5, validation_split=0.2)

In [None]:
decoded_imgs = autoencoder.predict(x_test)

In [None]:
# Compare original images to reconstructed images
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 0
for row in range(3):
    for col in range(3):
        axes[row][col].imshow(x_test[i])
        i += 1

In [None]:
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 0
for row in range(3):
    for col in range(3):
        axes[row][col].imshow(decoded_imgs[i])
        i += 1

In [None]:
# Recreate original digit from noisy image

In [None]:
def add_gaussian(x, amp=1, rng=None):
    """Return a noisy copy of ``x`` with every image rescaled to span [0, 1].

    Zero-mean Gaussian noise with standard deviation ``amp`` is added
    independently to every pixel. Each image (everything after the first axis)
    is then min-max rescaled to [0, 1]; images that end up constant map to 0.

    Parameters
    ----------
    x : array_like
        Batch of images, e.g. shape (N, H, W, 1); not modified.
    amp : float, optional
        Standard deviation of the noise (default 1).
    rng : np.random.Generator, optional
        Source of randomness; defaults to the legacy global ``np.random`` state,
        so ``np.random.seed`` still makes results reproducible.

    Returns
    -------
    np.ndarray
        Same shape as ``x``; dtype is ``x``'s float dtype (float32 for integer input).
    """
    x = np.asarray(x)
    out_dtype = np.result_type(x.dtype, np.float32)
    if x.size == 0:
        return x.astype(out_dtype)

    normal = np.random.normal if rng is None else rng.normal
    noisy = x + normal(0.0, amp, size=x.shape)

    # Per-image min-max rescale; guard against division by zero for flat images
    image_axes = tuple(range(1, noisy.ndim))
    lo = noisy.min(axis=image_axes, keepdims=True)
    hi = noisy.max(axis=image_axes, keepdims=True)
    span = np.where(hi > lo, hi - lo, 1.0)
    return ((noisy - lo) / span).astype(out_dtype)

In [None]:
def add_blob(x, blob_size=5, sigma=1.0, rng=None):
    """Return a copy of ``x`` with one Gaussian blob pasted onto each image.

    The blob is an unnormalised 2-D Gaussian ``exp(-(u**2 + v**2) / (2 * sigma**2))``
    sampled on a ``blob_size`` x ``blob_size`` grid spanning [-1, 1] in both
    directions. For each image it *overwrites* the pixels (it is not added to
    them) at a uniformly random position that keeps the whole blob inside.

    Parameters
    ----------
    x : np.ndarray
        Batch of images of shape (N, H, W) or (N, H, W, C); not modified.
    blob_size : int, optional
        Odd side length of the blob in pixels (default 5).
    sigma : float, optional
        Gaussian width in units of the [-1, 1] grid (default 1).
    rng : np.random.Generator, optional
        Source of randomness; defaults to the legacy global ``np.random`` state.

    Returns
    -------
    np.ndarray
        Same shape and dtype as ``x``.

    Raises
    ------
    ValueError
        If ``blob_size`` is even or larger than the images.
    """
    if blob_size % 2 == 0:
        raise ValueError(f"blob_size must be odd, got {blob_size}")
    x_blob = np.copy(x)
    height, width = x_blob.shape[1], x_blob.shape[2]
    if blob_size > min(height, width):
        raise ValueError(f"blob_size {blob_size} does not fit in {height}x{width} images")
    half = blob_size // 2
    randint = np.random.randint if rng is None else rng.integers

    # The blob is the same for every image, so build it once outside the loop
    grid = np.linspace(-1, 1, blob_size)
    grid_x, grid_y = np.meshgrid(grid, grid)
    blob = np.exp(-(grid_x**2 + grid_y**2) / (2 * sigma**2))
    if x_blob.ndim == 4:
        blob = blob[..., np.newaxis]  # broadcast across the channel axis

    for img in x_blob:
        # Centres in [half, size - half - 1] keep the blob inside (high is exclusive)
        row = randint(half, height - half)
        col = randint(half, width - half)
        img[row - half:row + half + 1, col - half:col + half + 1] = blob
    return x_blob

In [None]:
x_train_blob = add_blob(x_train)
x_test_blob = add_blob(x_test)
x_train_complex = add_gaussian(x_train_blob)
x_test_complex = add_gaussian(x_test_blob)

In [None]:
autoencoder_complex = Autoencoder((x_train.shape[1], x_train.shape[2], 1))
autoencoder_complex.compile(optimizer='adam', loss=keras.losses.MeanSquaredError())

In [None]:
log_dir = "logs/fit/MNIST_Complex" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
autoencoder_complex.fit(x_train_complex, x_train, epochs=25, validation_split=0.2, callbacks=[tensorboard_callback])

In [None]:
encoded_imgs_complex = autoencoder_complex.encoder(x_test).numpy()
decoded_imgs_complex = autoencoder_complex.decoder(encoded_imgs_complex).numpy()

In [None]:
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 0
for row in range(3):
    for col in range(3):
        axes[row][col].imshow(x_test[i])
        i += 1

In [None]:
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 0
for row in range(3):
    for col in range(3):
        axes[row][col].imshow(x_test_complex[i])
        i += 1

In [None]:
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 0
for row in range(3):
    for col in range(3):
        axes[row][col].imshow(decoded_imgs_complex[i])
        i += 1

# 5) Build Autoencoder for CFIS Data

In [None]:
# First for single channel

In [None]:
# Use 80% cutouts for training and remaining for testing
threshold = int(0.8*len(sources_cfis))
sources_train = sources_cfis[:threshold]
sources_test = sources_cfis[threshold:]
sources_train_01 = sources_norm_01[:threshold]
sources_test_01 = sources_norm_01[threshold:]
sources_train_keras = sources_norm_keras[:threshold]
sources_test_keras = sources_norm_keras[threshold:]

In [None]:
np.shape(sources_train_01)

In [None]:
autoencoder_cfis_01 = create_autoencoder((sources_cfis.shape[1], sources_cfis.shape[2], 1))
autoencoder_cfis_01.compile(optimizer="adam", loss="mse")

In [None]:
autoencoder_cfis_keras = Autoencoder((sources_cfis.shape[1], sources_cfis.shape[2], 1))
autoencoder_cfis_keras.compile(optimizer='adam', loss="mse")

In [None]:
autoencoder_cfis_01.summary()

In [None]:
#log_dir = "logs/fit/CFIS" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
#tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
history_01 = autoencoder_cfis_01.fit(sources_train_01, sources_train_01, epochs=300, validation_split=0.2)

In [None]:
history_keras = autoencoder_cfis_keras.fit(sources_train_keras, sources_train_keras, epochs=100, validation_split=0.2)

In [None]:
def plot_loss_curves(history, ax=None):
    """Plot per-epoch training (and, when recorded, validation) loss of a Keras fit.

    Parameters
    ----------
    history : keras.callbacks.History
        Object returned by ``Model.fit``. Its ``history`` dict must contain
        ``"loss"``; ``"val_loss"`` is plotted only if present.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to the current axes (``plt.gca()``).

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.
    """
    if ax is None:
        ax = plt.gca()
    ax.plot(history.history["loss"], color="g", label="Training")
    # "val_loss" is only recorded when fit() was given validation data
    if "val_loss" in history.history:
        ax.plot(history.history["val_loss"], color="b", label="Validation")
    ax.set_title("Loss Curves for Training/Validation Sets")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    return ax

In [None]:
plot_loss_curves(history_01)

In [None]:
plot_loss_curves(history_keras)

In [None]:
decoded_imgs_01 = autoencoder_cfis_01.predict(sources_test_01[:100])

In [None]:
encoded_imgs_keras = autoencoder_cfis_keras.encoder(sources_test_keras[:100])
decoded_imgs_keras = autoencoder_cfis_keras.decoder(encoded_imgs_keras)

In [None]:
np.min(sources_test)

In [None]:
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 0
for row in range(3):
    for col in range(3):
        norm = ImageNormalize(sources_test[i], interval=ZScaleInterval())
        axes[row][col].imshow(sources_test[i])
        i += 1

In [None]:
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 0
for row in range(3):
    for col in range(3):
        norm = ImageNormalize(sources_test_01[i], interval=ZScaleInterval())
        axes[row][col].imshow(sources_test_01[i])
        i += 1

In [None]:
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 0
for row in range(3):
    for col in range(3):
        norm = ImageNormalize(decoded_imgs_01[i], interval=ZScaleInterval())
        axes[row][col].imshow(decoded_imgs_01[i], norm=norm)
        i += 1

In [None]:
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 0
for row in range(3):
    for col in range(3):
        norm = ImageNormalize(sources_test_keras[i], interval=ZScaleInterval())
        axes[row][col].imshow(sources_test_keras[i])
        i += 1

In [None]:
fig, axes = plt.subplots(3,3, figsize=(8,8))
i = 0
for row in range(3):
    for col in range(3):
        norm = ImageNormalize(decoded_imgs_keras[i], interval=ZScaleInterval())
        axes[row][col].imshow(decoded_imgs_keras[i], norm=norm)
        i += 1

In [None]:
# Now multi-channel

In [None]:
# Use first 25000 cutouts for training and remaining for testing
sources_train = sources_multi[:25000]
sources_test = sources_multi[25000:]

In [None]:
autoencoder_multi = Autoencoder((sources_multi.shape[1], sources_multi.shape[2], 4))
autoencoder_multi.compile(optimizer='adam', loss=keras.losses.MeanSquaredError())

In [None]:
log_dir = "logs/fit/Multi" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
autoencoder_multi.fit(sources_train, sources_train, epochs=25, validation_split=0.2, callbacks=[tensorboard_callback])

In [None]:
encoded_imgs_cfis = autoencoder_multi.encoder(sources_test).numpy()
decoded_imgs_cfis = autoencoder_multi.decoder(encoded_imgs_cfis).numpy()

In [None]:
fig, axes = plt.subplots(2,2, figsize=(8,8))
i = 0
for row in range(2):
    for col in range(2):
        norm = ImageNormalize(sources_train[0,:,:,i], interval=ZScaleInterval())
        axes[row][col].imshow(sources_train[0,:,:,i])
        i += 1

In [None]:
fig, axes = plt.subplots(2,2, figsize=(8,8))
i = 0
for row in range(2):
    for col in range(2):
        norm = ImageNormalize(decoded_imgs_cfis[0,:,:,i], interval=ZScaleInterval())
        axes[row][col].imshow(decoded_imgs_cfis[0,:,:,i], norm=norm)
        i += 1