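# Training FlowVAEnet on toy data

This notebook runs the three training stages of `maddeb`'s `FlowVAEnet` on small random arrays:

1. the VAE, trained as a denoiser;
2. the normalizing flow, trained on the latent space of the frozen encoder;
3. the encoder, retrained as a deblender on blended stamps.

The architecture and epoch counts are deliberately tiny. They only exercise the code path, so the resulting losses are not meaningful.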

## Load modules

In [None]:
import os

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

from maddeb.callbacks import define_callbacks
from maddeb.FlowVAEnet import FlowVAEnet
from maddeb.losses import deblender_encoder_loss_wrapper, deblender_loss_fn_wrapper
from maddeb.utils import get_data_dir_path

# Short alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions

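## Create toy data

Random arrays stand in for real galaxy stamps. They only need the right shape: 8 stamps of 11×11 pixels with 6 channels, matching the model input `(None, 11, 11, 6)`.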

In [None]:
# Fix the legacy global NumPy seed so the toy arrays are identical on every
# Restart & Run All. (TensorFlow weight initialisation is NOT seeded here.)
SEED = 0
np.random.seed(SEED)

# Random stand-ins for real data, shaped (n_stamps, 11, 11, 6) to match the
# model input (None, 11, 11, 6); the last axis presumably indexes bands.
isolated_noisy_galaxies = np.random.rand(8, 11, 11, 6)
noiseless_galaxies = np.random.rand(8, 11, 11, 6)
blended_galaxies = np.random.rand(8, 11, 11, 6)

## Define the model

In [None]:
# Size of the VAE latent space. The KL prior is built with the same size so
# the two cannot silently drift apart.
latent_dim = 4

# Standard-normal prior over the latent space, used as the KL target.
# `loc` has shape (latent_dim,), so the prior's event shape matches the
# latent samples explicitly. Previously a length-1 `loc` was broadcast
# against 4-dimensional samples, which gives the same log-probabilities but
# an event shape of (1,).
kl_prior = tfd.Independent(
    tfd.Normal(loc=tf.zeros(latent_dim), scale=1.0), reinterpreted_batch_ndims=1
)

# Deliberately tiny network (one filter per conv layer, 1x1 kernels, a single
# dense unit and a single flow layer). This presumably keeps the toy training
# below fast; it is not meant to produce a useful model.
f_net = FlowVAEnet(
    stamp_shape=11,
    latent_dim=latent_dim,
    filters_encoder=[1, 1, 1, 1],
    filters_decoder=[1, 1, 1],
    kernels_encoder=[1, 1, 1, 1],
    kernels_decoder=[1, 1, 1],
    dense_layer_units=1,
    num_nf_layers=1,
    kl_prior=kl_prior,
    kl_weight=1,
)

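## Train the VAE as a denoiser

The inputs are noisy isolated stamps and the targets are the corresponding noiseless stamps. Stamps 0–5 are used for training and 6–7 for validation.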

In [None]:
vae_epochs = 2

# Keras callbacks. Checkpoints are written under <maddeb data dir>/test_temp/<stage>/.
# NOTE(review): the logged checkpoint paths show this writes into the
# package's own data directory; consider a temporary directory instead.
data_path = get_data_dir_path()
path_weights = os.path.join(data_path, "test_temp")
callbacks = define_callbacks(
    os.path.join(path_weights, "vae"),
    lr_scheduler_epochs=1,
    patience=1,
)

# Denoising objective: noisy isolated stamps in, noiseless stamps as targets.
# The history gets a real name (not `_`), because assigning `_` shadows
# IPython's last-output variable.
hist_vae = f_net.train_vae(
    (isolated_noisy_galaxies[:6], noiseless_galaxies[:6]),  # training
    (isolated_noisy_galaxies[6:], noiseless_galaxies[6:]),  # validation
    callbacks=callbacks,
    # NOTE(review): only half of `vae_epochs` is spent here (the log reports
    # "Number of epochs: 1"); confirm this split is intentional.
    epochs=int(0.5 * vae_epochs),
    train_encoder=True,
    train_decoder=True,
    track_kl=True,
    optimizer=tf.keras.optimizers.Adam(1e-5, clipvalue=0.1),
    loss_function=deblender_loss_fn_wrapper(
        sigma_cutoff=np.array([1] * 6),  # Noise level in the data (one value per channel)
        linear_norm_coeff=1,  # coefficient of linear normalization
    ),
    verbose=2,
)

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_5 (InputLayer)        [(None, 11, 11, 6)]       0         
                                                                 
 encoder (Functional)        (None, 14)                95        
                                                                 
 latent_space (MultivariateN  ((None, 4),              0         
 ormalTriL)                   (None, 4))                         
                                                                 
 decoder (Functional)        (None, 11, 11, 6)         686       
                                                                 
Total params: 781
Trainable params: 781
Non-trainable params: 0
_________________________________________________________________



--- Training only VAE network ---
Encoder status: True
Decoder status: True
Number of epochs: 1



Epoch 1: val_mse improved from inf to 0.32158, saving model to /pbs/throng/lsst/users/bbiswas/FlowDeblender/maddeb/data/test_temp/vae/val_mse/weights.ckpt

Epoch 1: val_loss improved from inf to 137.18494, saving model to /pbs/throng/lsst/users/bbiswas/FlowDeblender/maddeb/data/test_temp/vae/val_loss/weights.ckpt
1/1 - 8s - loss: 140.7427 - mse: 0.3327 - kl_metric: 0.6108 - val_loss: 137.1849 - val_mse: 0.3216 - val_kl_metric: 1.4766 - lr: 4.0000e-06 - 8s/epoch - 8s/step


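## Train the normalizing flow

First, reload the VAE weights saved at the best validation loss. During this stage the encoder is frozen: the model summary lists its 95 parameters as non-trainable.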

In [None]:
# Restore the VAE checkpoint that was saved whenever val_loss improved.
vae_checkpoint_dir = os.path.join(path_weights, "vae", "val_loss")
f_net.load_vae_weights(vae_checkpoint_dir)

In [None]:
flow_epochs = 2

# The flow stage gets its own checkpoint directory.
callbacks = define_callbacks(
    os.path.join(path_weights, "flow"),
    patience=1,
    lr_scheduler_epochs=1,
)

# Same 6/2 train/validation split as the VAE stage. The second element of
# each pair is an all-zero placeholder target; presumably the flow objective
# does not use it (TODO confirm in FlowVAEnet.train_flow).
flow_train_x = isolated_noisy_galaxies[:6]
flow_val_x = isolated_noisy_galaxies[6:]

hist_flow = f_net.train_flow(
    (flow_train_x, np.zeros_like(flow_train_x)),  # training
    (flow_val_x, np.zeros_like(flow_val_x)),  # validation
    callbacks=callbacks,
    optimizer=tf.keras.optimizers.Adam(1e-4, clipvalue=0.01),
    epochs=flow_epochs,
    verbose=2,
)

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_5 (InputLayer)        [(None, 11, 11, 6)]       0         
                                                                 
 encoder (Functional)        (None, 14)                95        
                                                                 
 latent_space (MultivariateN  ((None, 4),              0         
 ormalTriL)                   (None, 4))                         
                                                                 
 flow (Functional)           (None,)                   1480      
                                                                 
Total params: 1,575
Trainable params: 1,480
Non-trainable params: 95
_________________________________________________________________



--- Training only FLOW network ---
Number of epochs: 2


Epoch 1/2


2024-03-20 21:49:40.467868: I tensorflow/compiler/xla/service/service.cc:173] XLA service 0x7f5650044350 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2024-03-20 21:49:40.467924: I tensorflow/compiler/xla/service/service.cc:181]   StreamExecutor device (0): Host, Default Version
2024-03-20 21:49:40.491523: I tensorflow/compiler/jit/xla_compilation_cache.cc:477] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.




Can save best model only with val_mse available, skipping.



Epoch 1: val_loss improved from inf to 5.78013, saving model to /pbs/throng/lsst/users/bbiswas/FlowDeblender/maddeb/data/test_temp/flow/val_loss/weights.ckpt
1/1 - 5s - loss: 5.7525 - val_loss: 5.7801 - lr: 4.0000e-05 - 5s/epoch - 5s/step
Epoch 2/2


Can save best model only with val_mse available, skipping.



Epoch 2: val_loss improved from 5.78013 to 5.42270, saving model to /pbs/throng/lsst/users/bbiswas/FlowDeblender/maddeb/data/test_temp/flow/val_loss/weights.ckpt
1/1 - 0s - loss: 5.7437 - val_loss: 5.4227 - lr: 1.6000e-05 - 133ms/epoch - 133ms/step


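## Train the encoder as a deblender

`f_net`'s encoder is retrained with blended stamps as inputs. A second network, `f_net_original`, is loaded with the VAE weights, and its encoder is passed to the loss as the reference encoder.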

In [None]:
# A second FlowVAEnet with the same hyper-parameters as `f_net`, loaded with
# the VAE checkpoint. Its encoder serves as the reference passed to the
# deblender loss, while `f_net`'s own encoder is the one being retrained.
reference_net_kwargs = {
    "stamp_shape": 11,
    "latent_dim": 4,
    "filters_encoder": [1, 1, 1, 1],
    "filters_decoder": [1, 1, 1],
    "kernels_encoder": [1, 1, 1, 1],
    "kernels_decoder": [1, 1, 1],
    "dense_layer_units": 1,
    "num_nf_layers": 1,
    "kl_prior": kl_prior,
    "kl_weight": 1,
}
f_net_original = FlowVAEnet(**reference_net_kwargs)
f_net_original.load_vae_weights(os.path.join(path_weights, "vae", "val_loss"))

In [None]:
# Callbacks (checkpointing, LR scheduling, `patience`) for the deblender
# stage, writing under <path_weights>/deblender.
deblender_weights_dir = os.path.join(path_weights, "deblender")
callbacks = define_callbacks(deblender_weights_dir, patience=1, lr_scheduler_epochs=1)

In [None]:
deblender_epochs = 2

# Retrain `f_net`'s encoder. Inputs are blended stamps and targets are the
# matching isolated noisy stamps (same 6/2 split as before). The loss is built
# around `f_net_original.encoder`, i.e. the encoder as it was after VAE training.
hist_deblender = f_net.train_encoder(
    (blended_galaxies[:6], isolated_noisy_galaxies[:6]),  # training
    (blended_galaxies[6:], isolated_noisy_galaxies[6:]),  # validation
    callbacks=callbacks,
    epochs=deblender_epochs,
    optimizer=tf.keras.optimizers.Adam(1e-5, clipvalue=0.1),
    loss_function=deblender_encoder_loss_wrapper(
        original_encoder=f_net_original.encoder,
        noise_sigma=np.array([1] * 6),
        # Must match the latent_dim the networks were built with.
        latent_dim=4,
    ),
    verbose=2,
)

Model: "encoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 11, 11, 6)]       0         
                                                                 
 conv2d (Conv2D)             (None, 6, 6, 1)           7         
                                                                 
 p_re_lu (PReLU)             (None, 6, 6, 1)           36        
                                                                 
 conv2d_1 (Conv2D)           (None, 3, 3, 1)           2         
                                                                 
 p_re_lu_1 (PReLU)           (None, 3, 3, 1)           9         
                                                                 
 conv2d_2 (Conv2D)           (None, 2, 2, 1)           2         
                                                                 
 p_re_lu_2 (PReLU)           (None, 2, 2, 1)           4   


--- Training only encoder network ---
Number of epochs: 2


Epoch 1/2


Can save best model only with val_mse available, skipping.



Epoch 1: val_loss improved from inf to 2.57688, saving model to /pbs/throng/lsst/users/bbiswas/FlowDeblender/maddeb/data/test_temp/deblender/val_loss/weights.ckpt
1/1 - 3s - loss: 5.0717 - val_loss: 2.5769 - lr: 4.0000e-06 - 3s/epoch - 3s/step
Epoch 2/2


Can save best model only with val_mse available, skipping.



Epoch 2: val_loss did not improve from 2.57688
1/1 - 0s - loss: 4.9084 - val_loss: 3.7430 - lr: 1.6000e-06 - 40ms/epoch - 40ms/step
