In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras import layers
import sys
import time

tfd = tfp.distributions
tfb = tfp.bijectors
tfkl=keras.layers

import matplotlib.pyplot as plt
import numpy as np


import os

# Path to the shared TensorFlow Datasets directory containing the 'eagle'
# dataset. Set the EAGLE_DATA_DIR environment variable to point elsewhere
# (e.g. when running off the original cluster) without editing the notebook.
eagle_dir = os.environ.get(
    'EAGLE_DATA_DIR',
    '/storage/scratch/mhuertas/data/sfh/tensorflow_datasets/eagle',
)

Load Dataset

In [None]:
from sfh.datasets.eagle import eagle

# Load the complete 'train' split of the EAGLE dataset from the shared TFDS
# directory. The `sfh.datasets.eagle` import in this cell is what makes the
# 'eagle' name known to tfds (presumably by registering its builder).
dset_eagle = tfds.load(name='eagle', data_dir=eagle_dir, split='train')

Visualize a few examples

In [None]:
print("Train", len(dset_eagle))

# Print and plot the first three examples in a single pass over the dataset.
fig, ax = plt.subplots(1, 1)
for example in dset_eagle.take(3):
    print(example)
    ax.plot(example['time'], example['SFR_Max'])
ax.set_xlabel('time')
ax.set_ylabel('SFR_Max')
ax.set_title('First three EAGLE examples');

Preprocessing (including normalization)

In [None]:
def preprocessing(example):
    """Turn a batched example into an (input, target) pair.

    The 'SFR_Max' feature is reshaped to (-1, 100, 1); the same tensor is
    returned as both the model input and the training target.
    """
    sfr = tf.reshape(example['SFR_Max'], (-1, 100, 1))
    return sfr, sfr

def preprocessing_wmass(example):
    """Build an (input, target) pair from SFR plus two stellar-mass channels.

    Expects a *batched* example. 'SFR_Max' is reshaped to (-1, 100, 1) and
    offset by 1e-5. The first column of 'Mstar' and of 'Mstar_Half' is
    broadcast along the 100 time steps, so every step of galaxy ``b`` carries
    galaxy ``b``'s own values. The three channels are concatenated on the
    last axis, giving a (-1, 100, 3) tensor used as both input and target.
    """
    sfr = tf.math.add(tf.reshape(example['SFR_Max'], (-1, 100, 1)), 1e-5)
    per_step_shape = tf.shape(sfr)
    # Broadcast a (batch, 1, 1) view instead of tf.tile(x, [100]) + reshape:
    # tiling a 1-D batch vector repeats the whole vector, so after reshaping
    # a row would contain masses of *other* galaxies in the batch.
    mass = tf.broadcast_to(
        example['Mstar'][:, 0][:, tf.newaxis, tf.newaxis], per_step_shape)
    mass_half = tf.broadcast_to(
        example['Mstar_Half'][:, 0][:, tf.newaxis, tf.newaxis], per_step_shape)
    res = tf.concat([sfr, mass, mass_half], axis=2)
    return res, res

def preprocessing_wmass_atan(example):
    """Build ((sfr, sed), sfr) with a squashed and slightly noised SFR.

    Expects a *batched* example. 'SFR_Max' is reshaped to (-1, 100, 1),
    divided by 40, passed through asinh, offset by
    1e-3 + 0.005 * softplus(N(0, 1) noise), and squashed with tanh.
    Despite the function name, no arctan is applied. The 'sed' feature is
    passed through unchanged as a second model input. Only the SFR channel
    is used; no stellar-mass channel is included.

    NOTE(review): the noise is drawn inside this map, so every dataset built
    with it (including evaluation splits) is noised as well.
    """
    sfr = tf.reshape(example['SFR_Max'], (-1, 100, 1))
    # Noise shape follows the actual batch rather than a hard-coded 64, so
    # the pipeline works for any batch_size.
    noise = tf.random.normal(shape=tf.shape(sfr), dtype=sfr.dtype)
    sfr = tf.math.tanh(
        tf.math.asinh(sfr / 40) + 1e-3 + 0.005 * tf.math.softplus(noise))
    return (sfr, example['sed']), sfr

def input_fn(mode='train', batch_size=64, 
             dataset_name='tng100', data_dir=None,
             include_mass=True, arctan=True):
    """Build a batched, preprocessed tf.data pipeline from a TFDS dataset.

    Parameters
    ----------
    mode : str
        'train' uses the first 90% of the 'train' split, repeated
        indefinitely and shuffled. Any other value (e.g. 'val' or 'test')
        uses the remaining 10%, iterated once without shuffling.
    batch_size : int
        Batch size; an incomplete final batch is dropped.
    dataset_name : str
        Name of the TFDS dataset to load.
    data_dir : str or None
        Directory passed to ``tfds.load``.
    include_mass, arctan : bool
        Select the preprocessing: both True -> ``preprocessing_wmass_atan``;
        only ``include_mass`` -> ``preprocessing_wmass``; otherwise
        ``preprocessing``.

    Returns
    -------
    tf.data.Dataset
        Batches of (inputs, targets) as produced by the chosen preprocessing.
    """
    # Decide the preprocessing first so the feature selection below keeps
    # every key that function reads.
    keys = ['sed', 'Mstar', 'SFR_Max', 'mass_quantiles', 'time']
    if include_mass and arctan:
        preprocess = preprocessing_wmass_atan
    elif include_mass:
        preprocess = preprocessing_wmass
        keys.append('Mstar_Half')  # read by preprocessing_wmass
    else:
        preprocess = preprocessing

    split = 'train[:90%]' if mode == 'train' else 'train[90%:]'
    dataset = tfds.load(dataset_name, split=split, data_dir=data_dir)
    dataset = dataset.map(lambda x: {k: x[k] for k in keys})
    if mode == 'train':
        dataset = dataset.repeat()
        dataset = dataset.shuffle(10000)

    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.map(preprocess)  # preprocessing operates on whole batches
    dataset = dataset.prefetch(-1)  # -1 == AUTOTUNE: overlap input prep with training
    return dataset

Prepare the training and validation datasets

In [None]:
# Training hyperparameters shared by the cells below.
batch_size = 64
epochs = 10

# Training stream: first 90% of the EAGLE 'train' split (repeated, shuffled).
dtrain_eagle = input_fn(
    mode='train',
    batch_size=batch_size,
    dataset_name='eagle',
    data_dir=eagle_dir,
)
# Validation: any mode other than 'train' selects the held-out last 10%.
dval_eagle = input_fn(
    mode='val',
    batch_size=batch_size,
    dataset_name='eagle',
    data_dir=eagle_dir,
)

Generate the regression model (CNN with continuous output)

In [None]:
# Model definition for the "regression model (CNN with continuous output)"
# section. The architecture is not implemented in this notebook yet.


def generate_model():
    """Generate the regression Keras model.

    Returns
    -------
    keras.Model
        Intended to be a *compiled* model that ``fit`` can train directly on
        the datasets built by ``input_fn``; with its default preprocessing
        those yield ``((sfh, sed), sfh)`` batches. The sampling cell calls
        the model on ``(sfh, sed)`` and then ``.sample()`` on the result, so
        the output is presumably a tfp distribution — TODO confirm.

    Raises
    ------
    NotImplementedError
        Always, until the model body is written.
    """
    # TODO: build the network (e.g. with `tfkl` layers), compile it, return it.
    raise NotImplementedError(
        "generate_model() is not implemented: build, compile and return "
        "the regression CNN here."
    )

In [None]:
# Instantiate the model, then show its layer-by-layer summary.
regression_cnn = generate_model()
regression_cnn.summary()

Fit the model on EAGLE data

In [None]:
# The training stream repeats forever, so an "epoch" is fixed at 1000 steps;
# the finite validation split is evaluated at the end of every epoch.
hist = regression_cnn.fit(
    dtrain_eagle,
    validation_data=dval_eagle,
    epochs=epochs,
    steps_per_epoch=1000,
)

Test on EAGLE validation data

In [None]:
# Autoregressively sample SFHs for one validation galaxy, conditioned on its SED.
# NOTE(review): uses `regression_cnn`, the model built and fitted above (the
# previous code referred to an undefined `eagle_cnn`). This cell assumes the
# model's output exposes `.sample()` (e.g. a tfp distribution) — confirm.
val_iter = dval_eagle.as_numpy_iterator()
(sfh_batch, sed_batch), _ = next(val_iter)

ind = 55  # which galaxy of the validation batch to reconstruct
assert 0 <= ind < sfh_batch.shape[0], "ind must index into the batch"

n_samples, n_steps = sfh_batch.shape[0], sfh_batch.shape[1]
true = sfh_batch[ind, :, 0]
# One copy of the chosen galaxy's SED per posterior sample.
sed = sed_batch[ind].reshape([1, -1, 1]).repeat(n_samples, axis=0)

# Start every chain at the true first value, then fill each later step from
# the model's samples given everything generated so far.
sample = np.zeros([n_samples, n_steps, 1])
sample[:, 0, 0] = true[0]
for i in range(n_steps - 1):
    draw = np.asarray(regression_cnn((sample, sed)).sample())
    # Works whether the draw is (n_samples, n_steps) or (n_samples, n_steps, 1).
    sample[:, i + 1, 0] = draw[:, i + 1].reshape(n_samples)

fig, ax = plt.subplots()
ax.plot(true, label='true SFH')
for chain in sample[:, :, 0]:
    ax.plot(chain, color='C1', alpha=0.1)
ax.plot(sample[1, :, 0], color='C1', alpha=1., label='individual sample')
ax.plot(sample.mean(axis=0)[:, 0], '--', color='red', label='mean posterior')
ax.set_xlabel('time step')
ax.set_ylabel('preprocessed SFR_Max')
ax.legend(loc='upper left');


