#### Import necessary libraries and set paths

In [None]:
import tensorflow as tf
# TensorFlow 1.x session whose GPU options enable allow_growth, i.e. GPU
# memory is claimed on demand rather than all at once.
# NOTE(review): this session is never handed to Keras explicitly; confirm the
# allow_growth setting actually takes effect for the Keras models loaded below.
gpu_options_tf = tf.GPUOptions(allow_growth=True)
config_tf = tf.ConfigProto(gpu_options=gpu_options_tf)
sess = tf.Session(config=config_tf)

import json
from keras.models import model_from_json

import sys
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from keras.utils import multi_gpu_model
from sklearn.ensemble import RandomForestRegressor
from importlib import reload
from pygifsicle import optimize
import imageio
import os
import matplotlib.animation as animation

In [None]:
# Matplotlib defaults for the paper figures: large text everywhere, inward
# ticks on all four sides, white background and serif (incl. math) fonts.
params = {
    key: "x-large"
    for key in (
        "legend.fontsize",
        "axes.labelsize",
        "axes.titlesize",
        "xtick.labelsize",
        "ytick.labelsize",
    )
}
params.update(
    {
        "figure.facecolor": "w",
        "xtick.top": True,
        "ytick.right": True,
        "xtick.direction": "in",
        "ytick.direction": "in",
        "font.family": "serif",
        "mathtext.fontset": "dejavuserif",
    }
)
plt.rcParams.update(params)

In [None]:
# Location of the photozCapsNet checkout (clone the latest version of the
# morphCaps branch from GitHub). It is inserted near the front of sys.path so
# the `encapzulate` imports below resolve to this checkout.
path_photoz = Path("/home/bid13/code/photozCapsNet")
sys.path.insert(1, str(path_photoz))

#### Import custom modules

In [None]:
from encapzulate.data_loader.data_loader import load_data
from encapzulate.utils.fileio import load_model, load_config
from encapzulate.utils import metrics
from encapzulate.utils.utils import import_model

# Reload `metrics` *before* importing names from it: `from ... import` binds
# the objects that exist at that moment, so reloading afterwards would leave
# Metrics / probs_to_redshifts / bins_to_redshifts pointing at stale code.
reload(metrics)
from encapzulate.utils.metrics import Metrics, probs_to_redshifts, bins_to_redshifts

#### Specify the results to be explored

In [None]:
# Which training run to explore, and which saved weights checkpoint to load.
run_name, checkpoint_eval = "paper1_regression_80perc_0", 100

In [None]:
# Paths for this run: results live under
# <path_output>/<run-name prefix before the first "_">/<run_name>/results.
# path_output = "/data/bid13/photoZ/results"
path_output = Path("/home/bid13/code/photozCapsNet/results")
path_results = path_output.joinpath(run_name.partition("_")[0], run_name, "results")
path_config = path_results / "config.yml"

#### Load Config, Model and Data

In [None]:
# Run configuration; `image_scale` is the factor images are multiplied by
# before they are plotted below.
config = load_config(path_config)
scale = config["image_scale"]

In [None]:
log = pd.read_csv(path_results.joinpath("logs", "log.csv"))  # training log saved with the run

In [None]:
# Log row(s) with the lowest validation decoder loss (ties are all kept).
# NOTE: despite the name, `max_acc` is selected by minimum loss.
best_decoder_loss = log["val_decoder_model_loss"].min()
max_acc = log.loc[log["val_decoder_model_loss"] == best_decoder_loss]
max_acc

In [None]:
# Evaluation model: architecture from JSON plus the weights file for
# checkpoint `checkpoint_eval` (file names are zero-padded to two digits).
#with tf.device('/cpu:0'):
eval_arch_path = path_results / "eval_model.json"
eval_weights_path = path_results / "weights" / f"weights-{checkpoint_eval:02d}.h5"
model = load_model(eval_arch_path, eval_weights_path)
# model = multi_gpu_model(model,gpus=2)
model.summary()

In [None]:
# Train / dev / test splits, each unpacked as (x, y, vals, z_spec, cat);
# load_cat=True also returns the catalogue entries.
[
    (x_train, y_train, vals_train, z_spec_train, cat_train),
    (x_dev, y_dev, vals_dev, z_spec_dev, cat_dev),
    (x_test, y_test, vals_test, z_spec_test, cat_test),
] = load_data(load_cat=True, **config)

#### Run Predictions

In [None]:
y_caps_dev, y_caps_all_dev, y_prob_dev, x_recon_dev, z_phot_dev = model.predict(
    x_dev, batch_size=1024
)  # five model outputs for the dev split

In [None]:
# Release the train and test images; only the dev split is used below.
del x_train, x_test
# del x_dev
# del x_recon_test
# del x_recon_dev

#### Plot images

In [None]:
# Adapted from
# https://github.com/legacysurvey/imagine/blob/acac773c6a43c7e6d6ea0c128d5e963ad8295229/map/views.py#L3881
def sdss_rgb(imgs, bands, scales=None, m=0.02, Q=20, alpha=1, p=0.7):
    """Combine single-band images into an asinh-stretched RGB image.

    Each band is multiplied by its scale and offset by ``m``. The mean of the
    non-negative scaled bands is stretched with
    ``arcsinh(alpha * Q * I) / Q**p``, and every band is multiplied by the
    same ``stretch / I`` factor. This keeps the ratios between bands.

    Args:
        imgs: iterable of 2-D arrays of equal shape, one per band.
        bands: band names matching ``imgs`` ("u", "g", "r", "i", "z" by
            default).
        scales: optional mapping ``band -> (rgb_plane, scale)`` that overrides
            or extends the default table.
        m: additive offset applied to every scaled band.
        Q, alpha, p: parameters of the asinh stretch.

    Returns:
        float32 array of shape (H, W, 3) clipped to [0, 1].

    Raises:
        ValueError: if ``imgs`` and ``bands`` differ in length or are empty.
        KeyError: if a band has no entry in the scale table.
    """
    # Materialise once: the images are iterated twice below, and a one-shot
    # iterator would silently yield nothing on the second pass.
    imgs = list(imgs)
    bands = list(bands)
    if len(imgs) != len(bands):
        raise ValueError(f"got {len(imgs)} images but {len(bands)} band names")
    if not bands:
        raise ValueError("at least one band is required")

    # band -> (index of the output RGB plane, multiplicative scale)
    rgbscales = {
        "u": (2, 1.5),  # 1.0,
        "g": (2, 2.8),
        "r": (1, 1.4),
        "i": (0, 1.1),
        "z": (0, 0.4),  # 0.3
    }
    if scales is not None:
        rgbscales.update(scales)

    # Mean over bands of the scaled, offset and clipped-at-zero images.
    total = 0
    for img, band in zip(imgs, bands):
        _, band_scale = rgbscales[band]
        total = total + np.maximum(0, img * band_scale + m)
    intensity = total / len(bands)

    # asinh stretch of the mean intensity: one common factor for all planes.
    stretch = np.arcsinh(alpha * Q * intensity) / (Q**p)
    # Avoid 0/0 below. Where the intensity is exactly 0 the stretch is 0 as
    # well, so those pixels come out black.
    intensity = intensity + (intensity == 0.0) * 1e-6

    H, W = intensity.shape
    rgb = np.zeros((H, W, 3), np.float32)
    for img, band in zip(imgs, bands):
        plane, band_scale = rgbscales[band]
        # Unclipped scaled band times the shared stretch / intensity factor.
        rgb[:, :, plane] = (img * band_scale + m) * stretch / intensity
    return np.clip(rgb, 0, 1)

In [None]:
# from astropy.visualization import make_lupton_rgb

# # The function below has not yet been finalized. Can be fine tuned before incorporating into the main code

# def plot_image(image, band, scaling="linear", ax=None, show=False, input_bands=None):
#     """Plot different colored images of galaxies

#     Args:
#         image (array): five colored sdss image
#         band (str): u, g, r, i or z band or gri composite image (also works with 0,1,2,3,4,5 codes)
#         scaling: linear or asinh for the single band images. gri images are always asinh scaled
#         ax (object): Matplotlib object to plot on
#         show (bool): Whether or not to show the plot
#         input_bands: use gri if input image has only three colors
        
#    Returns:
#         Matplotlib axis object
#     """
    
#     bands = {"u":0, "g":1, "r":2, "i":3, "z":4, "gri":5}
    
#     assert (band in bands) or (band in bands.values()) , "Choose from u, g, r, i, z bands or gri composite image"
#     assert (scaling in ["linear", "asinh"]), "scaling should be either linear or asinh for the single band images"
    
    
    
#     if ax == None:
#         fig, ax = plt.subplots()
      
#     if (band == "gri") or (band==5):
#         if input_bands == "gri":
#             stretch = 1
#             Q=8
#             scale =1.3
#             rgb = make_lupton_rgb(scale*1*image[:,:,2], scale*1.8*image[:,:,1], scale*2.3*image[:,:,0], stretch=stretch, Q=Q)
#         else:
#             stretch = 1.5
#             Q=5
#             scale = 1
#             rgb = make_lupton_rgb(scale*1*image[:,:,3], scale*1.5*image[:,:,2], scale*2.5*image[:,:,1], stretch=stretch, Q=Q, minimum=-0.02)
        
#         ax.imshow(rgb, aspect="equal", origin="lower")
#         ax.axes.get_xaxis().set_ticks([])
#         ax.axes.get_yaxis().set_ticks([])
    
#     else:
        
#         if band in bands:
#             band = bands[band]
#         if scaling == "linear":
#             ax.imshow(image[:,:,band], aspect="equal", origin="lower", cmap="Greys_r")
            
#         if scaling == "asinh":
#             img = make_lupton_rgb(image[:,:,band], image[:,:,band], image[:,:,band], stretch=stretch, Q=Q)
#             ax.imshow(img[:,:,0], aspect="equal", origin="lower", cmap="Greys_r")
#         ax.axis("off")
        
    
        
#     if show:
#         plt.show()
        
#     return ax

In [None]:
def plot_image(image, band="gri", ax=None, m=0., Q=20, alpha=0.8, p=0.7):
    """Draw the g, r, i channels of a multi-band image as an RGB composite.

    Args:
        image: array with bands on the last axis. Channels 1, 2 and 3 are
            rendered as g, r and i. This assumes u, g, r, i, z channel order —
            TODO confirm against the data loader.
        band: must be "gri", the only composite supported. It is kept so that
            existing calls such as ``plot_image(img, "gri", ax=...)`` still
            work.
        ax: matplotlib Axes to draw on. A new figure and Axes are created when
            None.
        m, Q, alpha, p: stretch parameters forwarded to ``sdss_rgb``.

    Returns:
        The Axes the image was drawn on.

    Raises:
        ValueError: if ``band`` is anything other than "gri" (previously such
            values were silently ignored).
    """
    if band != "gri":
        raise ValueError(f"only the 'gri' composite is supported, got {band!r}")
    gri = np.moveaxis(image, -1, 0)[1:4]
    rgb = sdss_rgb(gri, ["g", "r", "i"], m=m, Q=Q, alpha=alpha, p=p)
    if ax is None:
        _, ax = plt.subplots()
    ax.imshow(rgb, aspect="equal", origin="lower")
    ax.set_xticks([])
    ax.set_yticks([])
    return ax

In [None]:
# index = 0

# from scipy import ndimage
# fig, ax = plt.subplots(1,2)
# ax = ax.ravel()
# rgb_obs = sdss_rgb(np.moveaxis(scale*x_dev[index], -1,0)[1:4], [ "g", "r", "i"],m=-0.02)
# # rgb_obs = ndimage.median_filter(rgb_obs, 2)

# ax[0].imshow(rgb_obs, aspect="equal", origin="lower")
# ax[0].set_xlabel("Observed", fontsize=20)
# rgb_recon = sdss_rgb(np.moveaxis(scale*x_recon_dev[index], -1,0)[1:4], [ "g", "r", "i"],m =-0.02)
# ax[1].imshow(rgb_recon, aspect="equal", origin="lower")
# ax[1].set_xlabel("Reconstructed", fontsize=20)

In [None]:
# One dev galaxy next to its reconstruction.
index = 0

fig, ax = plt.subplots(1, 2)
ax = ax.ravel()
for panel, images, label in zip(ax, (x_dev, x_recon_dev), ("Observed", "Reconstructed")):
    plot_image(scale * images[index], "gri", ax=panel)
    panel.set_xlabel(label, fontsize=20)

In [None]:
# Pixel histograms of the channels plot_image renders as g, r, i (indices
# 1-3 of the last axis), observed vs reconstructed.
band_names = {1: "g", 2: "r", 3: "i"}
for i, band_name in band_names.items():
    # Bands live on the last axis; x_dev[index][i] would pick an image row.
    plt.hist(np.ravel(scale * x_dev[index][..., i]), histtype="step", label="observed")
    plt.hist(np.ravel(scale * x_recon_dev[index][..., i]), histtype="step", label="recon")
    plt.title(f"{band_name} band")
    plt.legend()
    plt.show()

In [None]:
# Per-channel mean and std of the observed image over both spatial axes.
observed_img = scale * x_dev[index]
mean_o, std_o = observed_img.mean(axis=(0, 1)), observed_img.std(axis=(0, 1))

In [None]:
# Per-channel mean and std of the reconstruction over both spatial axes.
recon_img = scale * x_recon_dev[index]
mean_r, std_r = recon_img.mean(axis=(0, 1)), recon_img.std(axis=(0, 1))

In [None]:
# Observed image next to a reconstruction whose per-channel pixel statistics
# are matched to the observed ones. The reconstruction is standardised with
# its own mean/std, then rescaled by the observed std and shifted to the
# observed mean (the previous version re-added mean_r).
index = 0

fig, ax = plt.subplots(1, 2)
ax = ax.ravel()
plot_image(scale * x_dev[index], "gri", ax=ax[0])
ax[0].set_xlabel("Observed", fontsize=20)
abcd = (scale * x_recon_dev[index] - mean_r) / std_r * std_o + mean_o
plot_image(abcd, "gri", ax=ax[1])
ax[1].set_xlabel("Reconstructed", fontsize=20)

In [None]:
# Observed (left) vs reconstructed (right) for three hand-picked disk galaxies.
selected_spirals = [0, 14, 13]
fig, axs = plt.subplots(3, 2, figsize=(7.8, 12))
axs = axs.flatten()
for row, gal_idx in enumerate(selected_spirals):
    plot_image(scale * x_dev[gal_idx], "gri", ax=axs[2 * row])
    plot_image(scale * x_recon_dev[gal_idx], "gri", ax=axs[2 * row + 1])
t = fig.suptitle("Spirals", fontsize=40, y=1.0)
axs[-2].set_xlabel("Observed", fontsize=30)
axs[-1].set_xlabel("Reconstructed", fontsize=30)
plt.tight_layout()
fig.savefig("./figs/disks.pdf", bbox_inches="tight", bbox_extra_artists=[t], dpi=300)

In [None]:
# Observed (left) vs reconstructed (right) for three hand-picked elliptical
# galaxies. The index list was previously misnamed `selected_spirals`.
fig, axs = plt.subplots(3, 2, figsize=(7.8, 12))
axs = axs.flatten()
selected_ellipticals = [20, 57, 80]
for row, gal_idx in enumerate(selected_ellipticals):
    plot_image(scale * x_dev[gal_idx], "gri", ax=axs[2 * row])
    plot_image(scale * x_recon_dev[gal_idx], "gri", ax=axs[2 * row + 1])
t = fig.suptitle("Ellipticals", fontsize=40, y=1.)
axs[-2].set_xlabel("Observed", fontsize=30)
axs[-1].set_xlabel("Reconstructed", fontsize=30)
plt.tight_layout()
fig.savefig("./figs/spheroids.pdf", bbox_inches='tight', bbox_extra_artists=[t], dpi=300)

In [None]:
# Load the *training* model from the same checkpoint and wrap it with
# multi_gpu_model for two GPUs. This rebinds `model`, replacing the
# evaluation model loaded earlier.
#with tf.device('/cpu:0'):
train_arch_path = path_results / "train_model.json"
train_weights_path = path_results / "weights" / f"weights-{checkpoint_eval:02d}.h5"
model = multi_gpu_model(load_model(train_arch_path, train_weights_path), gpus=2)
model.summary()

In [None]:
# Rebuild the CapsNet sub-models from the run config and load the checkpoint
# weights, matched by layer name, into `manipulate_model`.
# NOTE(review): only manipulate_model receives the weights; decoder_model is
# presumably built from shared layers — confirm.
config["input_shape"] = config["image_shape"]
CapsNet = import_model(model_name=config["model_name"])
(
    train_model,
    eval_model,
    manipulate_model,
    decoder_model,
    redshift_model,
) = CapsNet(**config)
weights_file = path_results / "weights" / f"weights-{checkpoint_eval:02d}.h5"
manipulate_model.load_weights(weights_file, by_name=True)

# Tinker all capsule dimensions (disk galaxy)

In [None]:
config["input_shape"] = config["image_shape"]
CapsNet = import_model(model_name=config["model_name"])
train_model, eval_model,manipulate_model,decoder_model,redshift_model, = CapsNet(**config)
manipulate_model.load_weights(
    path_results / "weights" / f"weights-{checkpoint_eval:02d}.h5", by_name=True
)

# Shift each of the first `num_caps` capsule dimensions of one galaxy in turn
# by -3..+3 sigma (std over the dev set) and decode. Rows are dimensions and
# columns are shifts.
img_indx = 0  # 0 = disk example, 20 = spheroid example

sigma_arr = np.std(y_caps_dev, axis=0)
caps_gal = y_caps_dev[img_indx].copy()

change_grid = [-3, -2, -1, 0, 1, 2, 3]
num_caps = 16
fig, axs = plt.subplots(num_caps, len(change_grid), figsize=(1.4 * 8.3, 2.1 * 11.7))

# Build every tinkered capsule vector first, row-major (dimension, then
# shift) to match axs.flat, and decode them in one predict call instead of
# one call per panel.
batch = []
for caps_index in range(num_caps):
    for shift in change_grid:
        tinkered_caps = caps_gal.copy()
        tinkered_caps[caps_index] = caps_gal[caps_index] + shift * sigma_arr[caps_index]
        batch.append(tinkered_caps)
tinkered_recon = decoder_model.predict(np.stack(batch))
for ax, recon in zip(axs.flat, tinkered_recon):
    plot_image(scale * recon, "gri", ax=ax)

# Raw string: "\s" is an invalid escape sequence in a normal string literal.
cols = [r"{}$\sigma$".format(col) for col in change_grid]
rows = ["Dim: {}".format(row) for row in np.arange(1, num_caps + 1).astype(str)]

for ax, col in zip(axs[0], cols):
    ax.set_title(col, size=25)

for ax, row in zip(axs[:, 0], rows):
    ax.set_ylabel(row, size=23)
# fig.suptitle("Redshift: "+ str(z_spec_dev[img_indx]), y =1.01, size=20)
plt.tight_layout()
fig.savefig("./figs/tinker_disk_appendix.pdf", bbox_inches='tight', dpi=300)

# Tinker all capsule dimensions (spheroid galaxy)

In [None]:
config["input_shape"] = config["image_shape"]
CapsNet = import_model(model_name=config["model_name"])
train_model, eval_model,manipulate_model,decoder_model,redshift_model, = CapsNet(**config)
manipulate_model.load_weights(
    path_results / "weights" / f"weights-{checkpoint_eval:02d}.h5", by_name=True
)

# Shift each of the first `num_caps` capsule dimensions of one galaxy in turn
# by -3..+3 sigma (std over the dev set) and decode. Rows are dimensions and
# columns are shifts.
img_indx = 20  # 0 = disk example, 20 = spheroid example

sigma_arr = np.std(y_caps_dev, axis=0)
caps_gal = y_caps_dev[img_indx].copy()

change_grid = [-3, -2, -1, 0, 1, 2, 3]
num_caps = 16
fig, axs = plt.subplots(num_caps, len(change_grid), figsize=(1.4 * 8.3, 2.1 * 11.7))

# Build every tinkered capsule vector first, row-major (dimension, then
# shift) to match axs.flat, and decode them in one predict call instead of
# one call per panel.
batch = []
for caps_index in range(num_caps):
    for shift in change_grid:
        tinkered_caps = caps_gal.copy()
        tinkered_caps[caps_index] = caps_gal[caps_index] + shift * sigma_arr[caps_index]
        batch.append(tinkered_caps)
tinkered_recon = decoder_model.predict(np.stack(batch))
for ax, recon in zip(axs.flat, tinkered_recon):
    plot_image(scale * recon, "gri", ax=ax)

# Raw string: "\s" is an invalid escape sequence in a normal string literal.
cols = [r"{}$\sigma$".format(col) for col in change_grid]
rows = ["Dim: {}".format(row) for row in np.arange(1, num_caps + 1).astype(str)]

for ax, col in zip(axs[0], cols):
    ax.set_title(col, size=25)

for ax, row in zip(axs[:, 0], rows):
    ax.set_ylabel(row, size=23)
# fig.suptitle("Redshift: "+ str(z_spec_dev[img_indx]), y =1.01, size=20)
plt.tight_layout()
fig.savefig("./figs/tinker_spheroid_appendix.pdf", bbox_inches='tight', dpi=300)

# Tinker selected capsule dimensions

In [None]:
config["input_shape"] = config["image_shape"]
CapsNet = import_model(model_name=config["model_name"])
(
    train_model,
    eval_model,
    manipulate_model,
    decoder_model,
    redshift_model,
) = CapsNet(**config)
weights_file = path_results / "weights" / f"weights-{checkpoint_eval:02d}.h5"
manipulate_model.load_weights(weights_file, by_name=True)

# Vary four selected capsule dimensions of dev galaxy 20 by -3..+3 sigma.
img_indx = 20

# Per-dimension spread over the dev set; the shifts are in these units.
sigma_arr = np.std(y_caps_dev, axis=0)
caps_gal = y_caps_dev[img_indx].copy()

change_grid = [-3, -2, -1, 0, 1, 2, 3]
# 0-based capsule dimensions (the row labels below number them from 1).
num_caps = [1, 8, 12, 13]

fig, axs = plt.subplots(len(num_caps), len(change_grid), figsize=(15, 9))

for row_axes, caps_dim in zip(axs, num_caps):
    for ax, shift in zip(row_axes, change_grid):
        tinkered_caps = caps_gal.copy()
        tinkered_caps[caps_dim] = caps_gal[caps_dim] + shift * sigma_arr[caps_dim]
        tinkered_recon = decoder_model.predict(tinkered_caps[np.newaxis])[0]
        plot_image(scale * tinkered_recon, "gri", ax=ax)

col_names = [rf"${shift}\sigma$" for shift in change_grid]
row_names = [
    "Size\n(Dim: 2)",
    "Orientation\n(Dim: 9)",
    "Bulge\n(Dim: 13)",
    "Surface\nBrightness\n(Dim: 14)",
]
for ax, col in zip(axs[0], col_names):
    ax.set_title(col, fontsize=30)
for ax, row in zip(axs[:, 0], row_names):
    ax.set_ylabel(row, fontsize=25)
plt.tight_layout()
fig.savefig("./figs/tinker_spheroid.pdf", bbox_inches="tight")

In [None]:
config["input_shape"] = config["image_shape"]
CapsNet = import_model(model_name=config["model_name"])
(
    train_model,
    eval_model,
    manipulate_model,
    decoder_model,
    redshift_model,
) = CapsNet(**config)
weights_file = path_results / "weights" / f"weights-{checkpoint_eval:02d}.h5"
manipulate_model.load_weights(weights_file, by_name=True)

# Vary four selected capsule dimensions of dev galaxy 0 by -3..+3 sigma.
img_indx = 0

# Per-dimension spread over the dev set; the shifts are in these units.
sigma_arr = np.std(y_caps_dev, axis=0)
caps_gal = y_caps_dev[img_indx].copy()

change_grid = [-3, -2, -1, 0, 1, 2, 3]
# 0-based capsule dimensions (the row labels below number them from 1).
num_caps = [1, 8, 12, 13]

fig, axs = plt.subplots(len(num_caps), len(change_grid), figsize=(15, 9))

for row_axes, caps_dim in zip(axs, num_caps):
    for ax, shift in zip(row_axes, change_grid):
        tinkered_caps = caps_gal.copy()
        tinkered_caps[caps_dim] = caps_gal[caps_dim] + shift * sigma_arr[caps_dim]
        tinkered_recon = decoder_model.predict(tinkered_caps[np.newaxis])[0]
        plot_image(scale * tinkered_recon, "gri", ax=ax)

col_names = [rf"${shift}\sigma$" for shift in change_grid]
row_names = [
    "Size\n(Dim: 2)",
    "Orientation\n(Dim: 9)",
    "Bulge\n(Dim: 13)",
    "Surface\nBrightness\n(Dim: 14)",
]
for ax, col in zip(axs[0], col_names):
    ax.set_title(col, fontsize=30)
for ax, row in zip(axs[:, 0], row_names):
    ax.set_ylabel(row, fontsize=25)
plt.tight_layout()
fig.savefig("./figs/tinker_disk.pdf", bbox_inches="tight")

# GIF for presentation

In [None]:
# Dev galaxies used in the animation: row 0 shows index 0, row 1 index 20.
img_indx = [0, 20]

config["input_shape"] = config["image_shape"]
CapsNet = import_model(model_name=config["model_name"])
(
    train_model,
    eval_model,
    manipulate_model,
    decoder_model,
    redshift_model,
) = CapsNet(**config)
weights_file = path_results / "weights" / f"weights-{checkpoint_eval:02d}.h5"
manipulate_model.load_weights(weights_file, by_name=True)

# Make GIF

In [None]:
import tempfile

sigma_arr = np.std(y_caps_dev, axis=0)

# Shift schedule, in units of each dimension's dev-set sigma:
# ramp 0 -> +3, hold, sweep +3 -> -3, hold, ramp back towards 0, hold.
step = 0.2
pause = 10
change_grid = np.concatenate(
    [
        np.arange(0, 3 + step, step),
        3 * np.ones(pause),
        np.arange(3, -3 - step, -1 * step),
        -3 * np.ones(pause),
        # Resume one step above -3. Starting at -3 - step briefly jumped out
        # to -3.2 sigma before heading back to 0.
        np.arange(-3 + step, 0, step),
        0 * np.ones(pause),
    ]
)

# 0-based capsule dimensions shown in columns 3-6
# (size, orientation, central bulge, surface brightness).
num_caps = [1, 8, 12, 13]
row_names = ["Disk", "Spheroid"]
filenames = []

# Frames go to a temporary directory. It is removed automatically, even if
# rendering fails part-way, instead of leaving PNGs in the working directory.
with tempfile.TemporaryDirectory() as frame_dir:
    for frame, tinker in enumerate(change_grid):
        fig, axs = plt.subplots(2, 6, figsize=(18, 15 * 1080 / 1920))
        for img_count, i in enumerate(img_indx):
            plot_image(scale * x_dev[i], "gri", ax=axs[img_count][0])
            plot_image(scale * x_recon_dev[i], "gri", ax=axs[img_count][1])

            caps_gal = y_caps_dev[i].copy()
            for caps_count, caps_dim in enumerate(num_caps):
                tinkered_caps = caps_gal.copy()
                tinkered_caps[caps_dim] = caps_gal[caps_dim] + tinker * sigma_arr[caps_dim]
                tinkered_recon = decoder_model.predict(np.expand_dims(tinkered_caps, axis=0))[0]
                plot_image(scale * tinkered_recon, "gri", ax=axs[img_count][caps_count + 2])

        # Raw string keeps the LaTeX sigma without an invalid "\s" escape.
        shift_label = rf"({round(tinker, 3)}$\sigma$)"
        col_names = [
            "Observed",
            "Reconstructed",
            f"Size\n {shift_label}",
            f"Orientation\n {shift_label}",
            f"Central Bulge\n {shift_label}",
            f"Surface\nBrightness\n {shift_label}",
        ]
        for ax, col in zip(axs[0], col_names):
            ax.set_title(col, fontsize=30, y=1.1)
        for ax, row in zip(axs[:, 0], row_names):
            ax.set_ylabel(row, fontsize=40)

        # save frame
        filename = os.path.join(frame_dir, f"{frame}.png")
        filenames.append(filename)
        plt.tight_layout()
        fig.savefig(filename, bbox_inches="tight", dpi=100)
        plt.close(fig)

    # Build the GIF while the frame files still exist.
    with imageio.get_writer("./figs/tinker_gif.gif", mode="I", fps=5) as writer:
        for filename in filenames:
            writer.append_data(imageio.imread(filename))

# optimize("./figs/tinker_gif.gif")  # optional: shrink the GIF with gifsicle

In [None]:
# Display the saved animation inline.
from IPython.display import Image
gif_path = "./figs/tinker_gif.gif"
Image(filename=gif_path)

# Make video

In [None]:
import tempfile

sigma_arr = np.std(y_caps_dev, axis=0)

# Shift schedule, in units of each dimension's dev-set sigma:
# ramp 0 -> +3, hold, sweep +3 -> -3, hold, ramp back towards 0, hold.
step = 0.2
pause = 10
change_grid = np.concatenate(
    [
        np.arange(0, 3 + step, step),
        3 * np.ones(pause),
        np.arange(3, -3 - step, -1 * step),
        -3 * np.ones(pause),
        # Resume one step above -3. Starting at -3 - step briefly jumped out
        # to -3.2 sigma before heading back to 0.
        np.arange(-3 + step, 0, step),
        0 * np.ones(pause),
    ]
)

# 0-based capsule dimensions shown in columns 3-6
# (size, orientation, central bulge, surface brightness).
num_caps = [1, 8, 12, 13]
row_names = ["Disk", "Spheroid"]
# Start from an empty list. Appending to the GIF cell's `filenames` wrote
# every frame into the video twice (or raised NameError if that cell was
# skipped).
filenames = []

# Frames go to a temporary directory. It is removed automatically, even if
# rendering fails part-way, instead of leaving PNGs in the working directory.
with tempfile.TemporaryDirectory() as frame_dir:
    for frame, tinker in enumerate(change_grid):
        fig, axs = plt.subplots(2, 6, figsize=(18, 15 * 1080 / 1920))
        for img_count, i in enumerate(img_indx):
            plot_image(scale * x_dev[i], "gri", ax=axs[img_count][0])
            plot_image(scale * x_recon_dev[i], "gri", ax=axs[img_count][1])

            caps_gal = y_caps_dev[i].copy()
            for caps_count, caps_dim in enumerate(num_caps):
                tinkered_caps = caps_gal.copy()
                tinkered_caps[caps_dim] = caps_gal[caps_dim] + tinker * sigma_arr[caps_dim]
                tinkered_recon = decoder_model.predict(np.expand_dims(tinkered_caps, axis=0))[0]
                plot_image(scale * tinkered_recon, "gri", ax=axs[img_count][caps_count + 2])

        # Raw string keeps the LaTeX sigma without an invalid "\s" escape.
        shift_label = rf"({round(tinker, 3)}$\sigma$)"
        col_names = [
            "Observed",
            "Reconstructed",
            f"Size\n {shift_label}",
            f"Orientation\n {shift_label}",
            f"Central Bulge\n {shift_label}",
            f"Surface\nBrightness\n {shift_label}",
        ]
        for ax, col in zip(axs[0], col_names):
            ax.set_title(col, fontsize=30, y=1.1)
        for ax, row in zip(axs[:, 0], row_names):
            ax.set_ylabel(row, fontsize=40)

        # save frame
        filename = os.path.join(frame_dir, f"{frame}.png")
        filenames.append(filename)
        plt.tight_layout()
        fig.savefig(filename, bbox_inches="tight", dpi=100)
        plt.close(fig)

    # Assemble the frames into an MP4 while the frame files still exist.
    with imageio.get_writer("./figs/tinker_vid.mp4", mode="I", fps=5) as writer:
        for filename in filenames:
            writer.append_data(imageio.imread(filename))

In [None]:
# Display the saved video inline.
from IPython.display import Video
video_path = "./figs/tinker_vid.mp4"
Video(data=video_path)

# Check Dim 10

In [None]:
# Dev galaxies whose capsule dimension 10 (0-based index 9) falls strictly
# between its 95.5th and 95.6th percentiles.
dim10_values = y_caps_dev[:, 9]
threshold, threshold2 = np.percentile(dim10_values, [95.5, 95.6])
mask = (dim10_values > threshold) & (dim10_values < threshold2)

In [None]:
config["input_shape"] = config["image_shape"]
CapsNet = import_model(model_name=config["model_name"])
train_model, eval_model,manipulate_model,decoder_model,redshift_model, = CapsNet(**config)
manipulate_model.load_weights(
    path_results / "weights" / f"weights-{checkpoint_eval:02d}.h5", by_name=True
)

# Shift capsule dimension 10 (0-based index 9) of every galaxy selected by
# `mask` by -3..+3 sigma and decode. One row per galaxy, one column per shift.
sigma_arr = np.std(y_caps_dev, axis=0)

caps_gal = y_caps_dev.copy()
selected_caps = caps_gal[mask]
n_selected = len(selected_caps)
if n_selected == 0:
    raise ValueError("mask selects no galaxies; widen the percentile window")

change_grid = [-3, -2, -1, 0, 1, 2, 3]

# squeeze=False keeps `axs` 2-D, so axs[k][j] also works for a single galaxy.
fig, axs = plt.subplots(n_selected, len(change_grid), figsize=(30, 200), squeeze=False)

for j, shift in enumerate(change_grid):
    tinkered_caps = selected_caps.copy()
    tinkered_caps[:, 9] = selected_caps[:, 9] + shift * sigma_arr[9]
    tinkered_recon = decoder_model.predict(tinkered_caps)
    for k in range(n_selected):
        plot_image(scale * tinkered_recon[k], "gri", ax=axs[k][j])

# Raw string: "\s" is an invalid escape sequence in a normal string literal.
cols = [r"{}$\sigma$".format(col) for col in change_grid]
for ax, col in zip(axs[0], cols):
    ax.set_title(col, size=25)
