In [None]:
# Report which XLA backend JAX picked ('cpu', 'gpu' or 'tpu').
# `jax.lib.xla_bridge.get_backend()` is deprecated (and removed in recent JAX
# releases); `jax.default_backend()` is the public equivalent of
# `get_backend().platform`.
import jax

print(jax.default_backend())

In [None]:
%load_ext autoreload
%autoreload 2
%pylab inline


from functools import partial
from tqdm import tqdm
import pickle
from pathlib import Path
import tensorboard
import numpy as np
import jax
import jax.numpy as jnp

import flax.linen as nn
from flax.metrics import tensorboard

from haiku._src.nets.resnet import ResNet18
import optax
import haiku as hk

import tensorflow_datasets as tfds
import tensorflow as tf
from tensorflow_probability.substrates import jax as tfp

# Short aliases for the TFP (JAX substrate) distribution and bijector modules.
tfd = tfp.distributions
tfb = tfp.bijectors

# Imported even though the name is not referenced directly: `tfds.load` later
# asks for 'LensingLogNormalDataset/...', which presumably requires this
# import to register the dataset builder with tfds.
from sbi_lens.gen_dataset.lensing_lognormal_dataset import LensingLogNormalDataset
from sbi_lens.normflow.models import AffineSigmoidCoupling, ConditionalRealNVP, AffineCoupling

from sbi_lens.gen_dataset.utils import augmentation_noise, augmentation_flip
from sbi_lens.config import config_lsst_y_10


import cmasher as cmr
from chainconsumer import ChainConsumer

# NOTE(review): a bare string 'unset XLA_FLAGS' used to sit here; it was a
# no-op that merely echoed itself. If XLA_FLAGS must be cleared, do it in the
# shell (`unset XLA_FLAGS`) before starting the kernel. Changing it from
# Python after JAX has already initialised its backend (first cell) is
# unlikely to take effect — TODO confirm for the JAX version in use.

#### Suppress `check_types` log warnings

In [None]:
import logging


import tensorflow_probability as tfp; tfp = tfp.substrates.jax


# Build one throwaway TFP distribution before installing the filter.
# NOTE(review): presumably this surfaces the "check_types" message once;
# kept as-is because the reason for the ordering is not visible here.
tfp.distributions.TransformedDistribution(
    tfp.distributions.Normal(0.0, 1.0), tfp.bijectors.Identity()
)

# `logging.getLogger()` (no name) always returns the real root logger.
# `getLogger("root")` only does so on Python >= 3.9; earlier versions return
# an unrelated child logger named "root", so the filter would never fire.
logger = logging.getLogger()


class CheckTypesFilter(logging.Filter):
    """Drop log records whose rendered message mentions ``check_types``."""

    def filter(self, record):
        return "check_types" not in record.getMessage()


# Re-running this cell recreates the class object, so compare by class name
# to avoid stacking duplicate filters on the root logger.
if not any(type(f).__name__ == CheckTypesFilter.__name__ for f in logger.filters):
    logger.addFilter(CheckTypesFilter())


tfp.distributions.TransformedDistribution(tfp.distributions.Normal(0.0, 1.0), tfp.bijectors.Identity())


## Dataset configuration

In [None]:
# Survey settings (LSST Y10 config) copied into module-level names; they are
# read by `augmentation` below and by later cells (e.g. the compressor init).
N = config_lsst_y_10.N
map_size = config_lsst_y_10.map_size
sigma_e = config_lsst_y_10.sigma_e
gals_per_arcmin2 = config_lsst_y_10.gals_per_arcmin2
nbins = config_lsst_y_10.nbins
a = config_lsst_y_10.a
b = config_lsst_y_10.b
z0 = config_lsst_y_10.z0

truth = config_lsst_y_10.truth

params_name = config_lsst_y_10.params_name

tf.random.set_seed(1)


def augmentation(example):
    """Augment one dataset example: `augmentation_noise`, then `augmentation_flip`.

    The noise step is parameterised by the survey settings above (map size,
    resolution, sigma_e, galaxy density, number of bins and a/b/z0). These are
    globals looked up at call time, so rebinding them later changes the
    augmentation.
    """
    noisy_example = augmentation_noise(
        example=example,
        N=N,
        map_size=map_size,
        sigma_e=sigma_e,
        gal_per_arcmin2=gals_per_arcmin2,
        nbins=nbins,
        a=a,
        b=b,
        z0=z0,
    )
    return augmentation_flip(noisy_example)

# Build compressor

In [None]:
# Output width of the ResNet18 compressor (its summary statistic size).
# Presumably equals the number of inferred parameters; should agree with the
# `dim` used by the MDN density head (whose default is also 6).
dim = 6

In [None]:
def _compressor_forward(y, is_training=True):
    """ResNet18 compressor mapping input maps ``y`` to ``dim`` summary values.

    Args:
      y: batch of input maps (initialised below with shape (1, N, N, nbins)).
      is_training: forwarded to ResNet18. Defaults to True, so existing
        callers (e.g. the training loss) are unchanged. Pass False at
        inference so batch statistics are not used or updated.
    """
    return ResNet18(dim)(y, is_training=is_training)


# `transform_with_state` because ResNet18 carries mutable Haiku state
# (init returns (params, state); apply returns (output, new_state)).
compressor = hk.transform_with_state(_compressor_forward)

In [None]:
# Initialise the compressor from a dummy batch of one map, shape
# (1, N, N, nbins), filled with 0.5. NOTE: despite its name,
# `opt_state_resnet` is the ResNet's Haiku state, not optimizer state.
parameters_resnet, opt_state_resnet = compressor.init(
    jax.random.PRNGKey(0),
    y=jnp.ones([1, N, N, nbins]) / 2,
)


## Create a density estimator for the compressor

In [None]:
class MDN(hk.Module):
  """Density head: maps a feature vector to a full-covariance Gaussian.

  The output is a ``tfd.MultivariateNormalTriL``. Its mean is taken from the
  first ``dim`` outputs of one linear head. Its lower-triangular scale is
  filled from the remaining ``dim*(dim+1)/2`` outputs of a second head, with
  a softplus on the diagonal.
  """

  def __call__(self, x, dim=6):
    self.dim = dim
    # Width holding a mean vector plus the packed lower triangle.
    n_out = dim * (dim + 1) // 2 + dim

    hidden = x
    for width in (128, 128, n_out):
      hidden = jax.nn.relu(hk.Linear(width)(hidden))
    hidden = jnp.tanh(hk.Linear(n_out)(hidden))

    # NOTE(review): each head produces all n_out values and half is
    # discarded; kept so parameter shapes/names match saved checkpoints.
    loc = hk.Linear(n_out)(hidden)[..., :dim]
    tril_flat = hk.Linear(n_out)(hidden)[..., dim:]

    to_scale_tril = tfp.bijectors.FillScaleTriL(
        diag_bijector=tfp.bijectors.Softplus()
    )
    return tfd.MultivariateNormalTriL(loc=loc, scale_tril=to_scale_tril(tril_flat))

In [None]:
# `model.apply(params, theta, x)` returns the (squeezed) log-density of
# `theta` under the MDN Gaussian conditioned on the compressed summary `x`.
# Both the MDN width and the dummy init shapes follow `dim`, so they stay in
# sync with the ResNet18(dim) compressor output.
model = hk.without_apply_rng(
    hk.transform(lambda theta, x: MDN()(x, dim=dim).log_prob(theta).squeeze())
)
params_nd = model.init(
    jax.random.PRNGKey(0),
    theta=0.5 * jnp.ones([1, dim]),
    x=0.5 * jnp.ones([1, dim]),
)

In [None]:
# One parameter tree holding both the ResNet compressor and the MDN head, so
# a single optimizer trains them jointly.
parameters_compressor = hk.data_structures.merge(parameters_resnet, params_nd)

# Dataset

In [None]:
# First 80,000 training examples of the `year_10_without_noise_score_density`
# config, read from the local 'tensorflow_dataset' directory.
ds = tfds.load(
    'LensingLogNormalDataset/year_10_without_noise_score_density',
    split='train[:80000]',
    data_dir='tensorflow_dataset',
)

In [None]:
# Endless, shuffled, augmented stream of 128-example batches, exposed as
# numpy arrays. NOTE(review): `ds` is rebound in place, so re-running this
# cell stacks the transforms on an already-batched dataset; re-run the load
# cell first.
ds = (
    ds.repeat()
    .shuffle(1000)
    .map(augmentation)
    .batch(128)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
ds_train = iter(tfds.as_numpy(ds))


In [None]:
def loss_gnll(params, theta, x, state_resnet):
    """Mean Gaussian negative log-likelihood of ``theta`` given inputs ``x``.

    Runs the ResNet compressor on ``x`` (threading its Haiku state), then
    scores ``theta`` under the MDN Gaussian built from the resulting summary.

    Args:
      params: merged ResNet + MDN parameter tree.
      theta: parameter values to score.
      x: input maps fed to the compressor.
      state_resnet: Haiku state of the compressor.

    Returns:
      ``(loss, new_state_resnet)``: the scalar mean NLL and the updated
      compressor state, in the form ``jax.value_and_grad(..., has_aux=True)``
      expects.
    """
    summary, new_state_resnet = compressor.apply(params, state_resnet, None, x)
    log_prob = model.apply(params, theta, summary)
    return -jnp.mean(log_prob), new_state_resnet

@jax.jit
def update(params, opt_state, theta, x, state_resnet):
    """Take one jitted optimizer step on ``loss_gnll``.

    Reads the module-level ``optimizer``. Globals are resolved when the
    function is first traced, so ``optimizer`` must exist before the first call.

    Returns:
      ``(loss, new_params, new_opt_state, new_state_resnet)``.
    """
    loss_and_grad = jax.value_and_grad(loss_gnll, has_aux=True)
    (loss, new_state_resnet), grads = loss_and_grad(params, theta, x, state_resnet)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    return loss, optax.apply_updates(params, updates), new_opt_state, new_state_resnet



In [None]:
# Adam with a step-wise learning-rate decay: start at 1e-3 and multiply by
# 0.5 at 20%, 40%, 60% and 80% of `total_steps`.
total_steps = 15_000
lr_scheduler = optax.piecewise_constant_schedule(
    init_value=0.001,
    boundaries_and_scales={
        int(total_steps * fraction): 0.5 for fraction in (0.2, 0.4, 0.6, 0.8)
    },
)
optimizer = optax.adam(learning_rate=lr_scheduler)
opt_state_nd = optimizer.init(parameters_compressor)

## Training

In [None]:
# Run identifier, appended to the TensorBoard tag names in the training loop
# so that runs can be told apart.
test=2

In [None]:
summary_writer = tensorboard.SummaryWriter('logs/')

batch_loss = []
for batch in tqdm(range(total_steps + 1)):
    ex = next(ds_train)

    # Skip batches containing NaNs entirely: no update, no logging, no loss
    # entry. Previously a stale loss was re-logged for such batches, and a NaN
    # first batch raised NameError on `l`.
    if np.isnan(ex['simulation']).any():
        continue

    # Feed the returned optimizer state back in. Reusing the initial
    # `opt_state_nd` every step froze Adam's moment estimates and step count,
    # so the learning-rate schedule never advanced.
    l, parameters_compressor, opt_state_nd, opt_state_resnet = update(
        params=parameters_compressor,
        opt_state=opt_state_nd,
        theta=ex['theta'],
        x=ex['simulation'],
        state_resnet=opt_state_resnet,
    )
    summary_writer.scalar('train_loggaussian_var{}'.format(test), l, batch)
    # NOTE(review): indexed by loop step; if batches were skipped this may run
    # slightly ahead of the optimizer's own update count.
    summary_writer.scalar('learning_rate_loggaussian_var{}'.format(test), lr_scheduler(batch), batch)

    if jnp.isnan(l):
        print('NaN Loss')
        break
    batch_loss.append(l)

summary_writer.flush()
# Keep the previous name available for any later cell that refers to it.
opt_state_c = opt_state_nd

In [None]:
# Persist the trained ResNet + MDN parameters and the ResNet's Haiku state.
# Despite its name, `opt_state_resnet` is Haiku state, not optimizer state.
Path("params_nd_compressor_gnll.pkl").write_bytes(pickle.dumps(parameters_compressor))
Path("opt_state_resnet_gnll.pkl").write_bytes(pickle.dumps(opt_state_resnet))
