In [40]:
import glob
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib notebook

In [41]:
spectrum_primary = pd.read_csv("spectrum_primary.csv",index_col=None,delim_whitespace=True)
spectrum_primary.head()

Unnamed: 0,wave,flux,normed_flux
0,3000.0033,0.000226,0.719871
1,3000.0083,0.000295,0.939125
2,3000.0133,0.000297,0.945691
3,3000.0183,0.000287,0.912834
4,3000.0233,0.000272,0.866837


In [42]:
spectrum_secondary = pd.read_csv("spectrum_primary.csv",index_col=None,delim_whitespace=True)
spectrum_secondary.head()

Unnamed: 0,wave,flux,normed_flux
0,3000.0033,0.000226,0.719871
1,3000.0083,0.000295,0.939125
2,3000.0133,0.000297,0.945691
3,3000.0183,0.000287,0.912834
4,3000.0233,0.000272,0.866837


In [43]:
spectrum_primary = {k:v.values for k,v in spectrum_primary.items()}
spectrum_primary["wave"] = np.log10(spectrum_primary["wave"])
spectrum_secondary = {k:v.values for k,v in spectrum_secondary.items()}
spectrum_secondary["wave"] = np.log10(spectrum_secondary["wave"])
spectrum_primary["log_flux"] = np.log10(spectrum_primary["flux"]/spectrum_secondary["normed_flux"])
spectrum_secondary["log_flux"] = np.log10(spectrum_secondary["flux"]/spectrum_secondary["normed_flux"])

In [44]:
NO_FLUX_POINTS = spectrum_primary["wave"].shape[0]

### Train NN model

In [6]:
import jax
from jax import lax, random, numpy as jnp
import flax
# from flax.core import freeze, unfreeze
from flax import linen as nn

from flax.training import train_state, checkpoints  # Useful dataclass to keep train state
# import numpy as np                                # Ordinary NumPy
import optax                                        # Optimizers

In [7]:
print(jax.devices())

[GpuDevice(id=0, process_index=0)]


In [8]:
def frequency_encoding(x, min_period, max_period, dimension):
    periods = jnp.logspace(jnp.log10(min_period), jnp.log10(max_period), num=dimension)
    
    y = jnp.sin(2*jnp.pi/periods*x)
    return y

class MLP_single_wavelength_sine(nn.Module):
    architecture: tuple = (256, 256, 256, 256)
    @nn.compact
    def __call__(self, x):
        w = x
        enc_w = frequency_encoding(w, min_period=1e-5, max_period=1.0, dimension=128)
        _x = enc_w
        for features in self.architecture:
            _x = nn.gelu(nn.Dense(features)(_x))
        x = nn.Dense(2, bias_init=nn.initializers.ones)(_x)
        return x
    
class MLP_wavelength_sine(nn.Module):
    
    @nn.compact
    def __call__(self, inputs, train):
        log_waves = inputs
        
        DecManyWave = nn.vmap(
                    MLP_single_wavelength_sine, 
                    in_axes=0, out_axes=0,
                    variable_axes={'params': None}, 
                    split_rngs={'params': False})
        
        x = DecManyWave(name="decoder")(log_waves)
        return x

In [9]:
TRAINING_STEPS    = 100000
WARM_UP_STEPS     = 10000
LEARNING_RATE     = 1e-3
NO_SAMPLES        = int(2**16)#37549#int(2**15)
CHECKPOINTS_DIR   = 'ckpts'

In [10]:
def mse_loss(y_pred, y_true):
    return jnp.mean((y_pred - y_true) ** 2)

def mae_loss(y_pred, y_true):
    return jnp.mean(jnp.abs(y_pred - y_true))

def quantile_absolute_error(y_pred, y_true, q=0.95):
    return jnp.quantile(jnp.abs(y_pred - y_true), q=q)

def maximum_absolute_error(y_pred, y_true):
    return jnp.max(jnp.abs(y_pred - y_true))

def mare_loss(y_pred, y_true):
    return jnp.mean(jnp.abs((y_pred - y_true)/y_true))

# average of worst q*100 percent fluxes
def quantile_mean_absolute_error(y_pred, y_true, q=0.95):
    abs_error = jnp.abs(y_pred - y_true)
    v = jnp.quantile(abs_error, q=q)
    return jnp.mean(abs_error, where=(abs_error>v))

In [11]:
def compute_metrics(*, y_pred, y_true):
    loss = mse_loss(y_pred, y_true)
    mae = mae_loss(y_pred, y_true)
    qae090 = quantile_absolute_error(y_pred, y_true, q=0.95)
    max_ae = maximum_absolute_error(y_pred, y_true)
    mare = mare_loss(y_pred, y_true)
    m95_per = quantile_mean_absolute_error(y_pred, y_true, q=0.95)
    metrics = {
      'loss': loss,
      'mae': mae,
      '95_per': qae090,
      'max_ae': max_ae,
      'mare': mare,
      'm95_per': m95_per,
    }
    return metrics

def create_train_state(rng, model):
    """Creates initial `TrainState`."""
    m = model()
    dummpy_input = jnp.ones(NO_SAMPLES)
    params = m.init(rng, dummpy_input, False)['params']
    
    schedule = optax.warmup_cosine_decay_schedule(
        init_value=0.0,
        peak_value=LEARNING_RATE,
        warmup_steps=WARM_UP_STEPS,
        decay_steps=TRAINING_STEPS-WARM_UP_STEPS,
        end_value=0.0,
    )
    tx = optax.adam(learning_rate=schedule)

    return train_state.TrainState.create(apply_fn=m.apply, params=params, tx=tx)


@jax.jit
def train_step(state, batch, rngs):
    """Train for a single step."""
    # Make sure to get a new RNG at every step.
    step = state.step
    rngs = {name: jax.random.fold_in(rng, step) for name, rng in rngs.items()}
    
    def loss_fn(params):
        y_pred = state.apply_fn({'params': params}, 
                                   batch["input"],
                                   train = True,
                                   rngs=rngs)
        loss = mse_loss(y_pred=y_pred, y_true=batch["flux"])
        return loss, y_pred
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (_,y_pred), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)

    metrics = compute_metrics(y_pred=y_pred, y_true=batch['flux'])
    
    return state, metrics

@jax.jit
def eval_step(state, batch):
    y_pred = state.apply_fn({'params': state.params},
                        batch["input"], 
                        train=False
                        )
    metrics = compute_metrics(y_pred=y_pred, y_true=batch['flux'])
    return metrics

In [17]:
def sample_from_example(wave: np.array, normed_flux: np.array, log_flux: np.array , no_samples:int):
    min_wave = wave[0]
    max_wave = wave[-1]
    
    sampled_wave = np.random.uniform(min_wave, max_wave, size=no_samples)
    
    return sampled_wave,np.interp(sampled_wave, wave, normed_flux),np.interp(sampled_wave, wave, log_flux)

idx = 1
log_wave = spectrum_primary["wave"]
normed_flux = spectrum_primary["normed_flux"]
log_flux = spectrum_primary["log_flux"]
sw, snf, slf = sample_from_example(wave=log_wave,
                         normed_flux=normed_flux, 
                         log_flux=log_flux,
                         no_samples=NO_SAMPLES
                        )

fig, (ax1, ax2) = plt.subplots(2)
ax1.plot(log_wave, normed_flux,'b')
ax1.plot(sw, snf,'r.')

ax2.plot(log_wave, log_flux,'b')
ax2.plot(sw, slf,'r.')

<IPython.core.display.Javascript object>

[<matplotlib.lines.Line2D at 0x7f5458498b20>]

In [18]:
def get_batch(rng, example):
    while True:
        log_wave = example["wave"]
        normed_flux = example["normed_flux"]
        log_flux = example["log_flux"]
        sw, snf,slf = sample_from_example(wave=log_wave,
                                 normed_flux=normed_flux, 
                                 log_flux=log_flux,
                                 no_samples=NO_FLUX_POINTS
                                )
        
        yield {"input":sw,"flux":np.array([snf,slf]).T}

In [57]:
def train(state, 
          rngs, 
          train_ds,
          model_name = "MLP_sine",
          no_steps_stats = 100):
    
    full_training_metrics = []
    train_metrics = []

    for step_number, batch in enumerate(get_batch(rng,train_ds)):

        # Training step:
        state, metrics = train_step(state, batch, rngs)
        train_metrics.append(metrics)

        if state.step % no_steps_stats == 1: # Summarize after each no_steps_stats steps

            # compute mean of metrics across last no_steps_stats steps
            train_summary = jax.device_get(train_metrics)
            train_metrics = []
            train_summary = {
                      k: np.mean([metrics[k] for metrics in train_summary])
                      for k in train_summary[0]}

            # Gather metrics
            full_training_metrics.append(train_summary)

            print(f'TRAINING   step: {step_number:6d}, ' + ", ".join([f"{k} = {v:.6f}" for k,v in train_summary.items()]))

        if state.step > TRAINING_STEPS:
            print("FINISH TRAINING: state.step >= TRAINING_STEPS")
            prefix = f"checkpoint_{model_name}_"
            checkpoints.save_checkpoint(ckpt_dir=CHECKPOINTS_DIR, 
                                        target=state,
                                        overwrite=True,
                                        prefix=prefix,
                                        step=step_number)


            df_train_summary = pd.DataFrame(full_training_metrics)
   

            break
    return state, df_train_summary

In [19]:
model = MLP_wavelength_sine
model_name = "MLP_sine_2"
# model = MLP_wavelength_sirens
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
rngs = {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)}

state = create_train_state(init_rng, model)
del init_rng

In [58]:
# run training
train_ds = spectrum_primary
state, df_train_summary = train(state, 
                                 rngs, 
                                 train_ds,
                                 model_name = model_name,
                                 no_steps_stats = 1000)

TRAINING   step:      0, 95_per = 4.856848, loss = 11.006946, m95_per = 4.895677, mae = 2.388453, mare = 0.690712, max_ae = 5.067384
TRAINING   step:   1000, 95_per = 1.698011, loss = 2.348789, m95_per = 1.850382, mae = 0.709018, mare = 0.231598, max_ae = 2.931641
TRAINING   step:   2000, 95_per = 0.239298, loss = 0.014007, m95_per = 0.308838, mae = 0.080751, mare = 0.047891, max_ae = 0.846583
TRAINING   step:   3000, 95_per = 0.095875, loss = 0.002982, m95_per = 0.154832, mae = 0.037174, mare = 0.029291, max_ae = 0.764171
TRAINING   step:   4000, 95_per = 0.062923, loss = 0.001580, m95_per = 0.112546, mae = 0.025072, mare = 0.020514, max_ae = 0.728378
TRAINING   step:   5000, 95_per = 0.048067, loss = 0.001104, m95_per = 0.092732, mae = 0.019491, mare = 0.016234, max_ae = 0.700822
TRAINING   step:   6000, 95_per = 0.041829, loss = 0.000923, m95_per = 0.082358, mae = 0.017215, mare = 0.014083, max_ae = 0.668343
TRAINING   step:   7000, 95_per = 0.037446, loss = 0.000764, m95_per = 0.07

TRAINING   step:  63000, 95_per = 0.001611, loss = 0.000001, m95_per = 0.002431, mae = 0.000679, mare = 0.000495, max_ae = 0.057690
TRAINING   step:  64000, 95_per = 0.001547, loss = 0.000001, m95_per = 0.002354, mae = 0.000636, mare = 0.000469, max_ae = 0.057158
TRAINING   step:  65000, 95_per = 0.001511, loss = 0.000001, m95_per = 0.002310, mae = 0.000630, mare = 0.000463, max_ae = 0.056602
TRAINING   step:  66000, 95_per = 0.001465, loss = 0.000001, m95_per = 0.002266, mae = 0.000578, mare = 0.000440, max_ae = 0.056120
TRAINING   step:  67000, 95_per = 0.001425, loss = 0.000001, m95_per = 0.002210, mae = 0.000582, mare = 0.000435, max_ae = 0.055699
TRAINING   step:  68000, 95_per = 0.001364, loss = 0.000001, m95_per = 0.002160, mae = 0.000544, mare = 0.000418, max_ae = 0.055666
TRAINING   step:  69000, 95_per = 0.001332, loss = 0.000001, m95_per = 0.002117, mae = 0.000535, mare = 0.000408, max_ae = 0.055000
TRAINING   step:  70000, 95_per = 0.001292, loss = 0.000001, m95_per = 0.002

In [72]:
# Plot
ncols= df_train_summary.shape[1]
fig, axs = plt.subplots(ncols=ncols, figsize=(16,4))
fig.suptitle(f"training statistics", fontsize=16)
for i, c in enumerate(df_train_summary.columns):
    axs[i].semilogy(df_train_summary[c],label=f"Train {c}")
    axs[i].legend()
    axs[i].grid(which="both", axis='y')

<IPython.core.display.Javascript object>

In [45]:
prefix = f"checkpoint_{model_name}_"
restored_state = checkpoints.restore_checkpoint(ckpt_dir=CHECKPOINTS_DIR, target=state, prefix=prefix)
if state is restored_state:
    raise FileNotFoundError(f"Cannot load checkpoint from {CHECKPOINTS_DIR}")

In [54]:
@jax.jit
def predict(state, batch):
    y_pred = state.apply_fn({'params': state.params}, 
                        batch["input"], 
                        train=False
                       )
    return compute_metrics(y_pred=y_pred, y_true=batch['flux']), y_pred

def sample_from_example_sort(wave: np.array, normed_flux: np.array, log_flux: np.array , no_samples:int):
    min_wave = wave[0]
    max_wave = wave[-1]
    
    sampled_wave = np.sort(np.random.uniform(min_wave, max_wave, size=no_samples))
    
    return sampled_wave,np.interp(sampled_wave, wave, normed_flux),np.interp(sampled_wave, wave, log_flux)

idx = 1


spec = spectrum_primary
sw, snf, slf = sample_from_example_sort(wave=spec["wave"],
                         normed_flux=spec["normed_flux"], 
                         log_flux=spec["log_flux"],
                         no_samples=NO_SAMPLES
                        )
exact_log_wave = sw
exact_normed_flux = snf
exact_log_flux = slf

batch = {"input":exact_log_wave,"flux":np.array([exact_normed_flux,exact_log_flux]).T}
metrics, predicted_normed_flux = predict(restored_state, batch)
metrics

{'95_per': DeviceArray(0.00104731, dtype=float32),
 'loss': DeviceArray(3.453963e-07, dtype=float32),
 'm95_per': DeviceArray(0.00175422, dtype=float32),
 'mae': DeviceArray(0.00036404, dtype=float32),
 'mare': DeviceArray(0.00031313, dtype=float32),
 'max_ae': DeviceArray(0.01961231, dtype=float32)}

In [56]:
fig, axs = plt.subplots(nrows=4,figsize=(8,8),sharex=True)

axs[0].plot(exact_log_wave, exact_normed_flux,label="Normed flux")
axs[0].plot(exact_log_wave, predicted_normed_flux[:,0],label="Predicted normed flux")
axs[0].legend()

axs[1].plot(exact_log_wave, (predicted_normed_flux[:,0]-exact_normed_flux), label="Residua")
axs[1].legend()

axs[2].plot(exact_log_wave, exact_log_flux,label="log flux")
axs[2].plot(exact_log_wave, predicted_normed_flux[:,1],label="Predicted log flux")
axs[2].legend()

axs[3].plot(exact_log_wave, (predicted_normed_flux[:,1]-exact_log_flux), label="Residua")
axs[3].legend()

<IPython.core.display.Javascript object>

<matplotlib.legend.Legend at 0x7f54140fab80>