In [None]:

# Mount Google Drive and move into the project's data directory.
# An absolute path is used so this cell can be re-run safely regardless of the
# current working directory (a relative chdir breaks once the cwd has moved).
from google.colab import drive
drive.mount('/content/drive',force_remount=True)
import os
os.chdir("/content/drive/My Drive/Ben_Boyd_MSc_Project/Data")


In [None]:


import matplotlib.pyplot as plt # creating visualizations
import numpy as onp # we still need original nump for some tasks


from matplotlib import cm
from matplotlib.ticker import LinearLocator

# using stax, much cleaner
from jax.experimental import stax
from jax.scipy.special import logsumexp #def generate_data(n_samples): Compute the log of the sum of exponentials of input elements
from jax.nn import softmax # pretty much the same as the interface as the one in scipy.special

from jax.experimental.stax import (Dense, Tanh, Flatten, Relu, LogSoftmax, Softmax, Exp,Sigmoid,Softplus,LeakyRelu)

import jax.numpy as np # JAX is supposed to have an API that closely resemble NumPy's
from jax import grad, jit, vmap, value_and_grad
from jax import random
import jax.nn as nn

from astropy.table import Table

from astropy import table
from astropy.io import fits,ascii,votable
from astropy import units as u 
from astropy import constants as const
from astropy import table
from astropy.cosmology import Planck15,FlatLambdaCDM
import copy
from jax.experimental import optimizers
from scipy.stats import norm
# Root JAX PRNG key with a fixed seed; used to initialise the network parameters.
key = random.PRNGKey(seed=777)

def sigma68(pred,real):
  """Half-width of the central ~68% interval of normalised redshift residuals.

  The per-object residual is (pred - real) / (1 + real). The residuals are
  sorted, the values at the 15.9% and 84.1% index positions are picked, and
  half of their difference is returned.

  Args:
    pred: array of predicted redshifts.
    real: array of true redshifts, broadcast-compatible with ``pred``.

  Returns:
    The half-width as a scalar, or ``nan`` when there are no objects (for
    example an empty redshift bin in ``plot_func``) instead of raising
    IndexError.
  """
  pred = onp.asarray(pred)
  real = onp.asarray(real)
  residuals = onp.sort((pred - real) / (1 + real))
  n = len(residuals)
  if n == 0:
    return onp.nan
  return 0.5 * (residuals[int(0.841 * n)] - residuals[int(0.159 * n)])

def outlier_frac(pred,real):
  return len(pred[np.absolute((pred-real)/(1+real))>0.15])/len(pred)


def bias(pred,real):
  """Median normalised residual (pred - real) / (1 + real), computed with jax.numpy."""
  scaled_error = (pred - real) / (1 + real)
  return np.median(scaled_error)

def plot_func(fun,pred,real,bins,zlim):
  """Plot a per-bin metric against true redshift.

  The range [0, zlim] is split into ``bins`` equal-width bins on ``real``.
  ``fun(pred_in_bin, real_in_bin)`` is evaluated for each bin, and the values
  are plotted at the bin centres. Bins are half-open [lo, hi), with the last
  bin closed at ``zlim``, so an object lying exactly on a bin edge is counted
  once rather than dropped.

  Args:
    fun: metric with signature ``fun(pred, real)`` (e.g. sigma68, bias).
    pred: predicted redshifts.
    real: true redshifts, same shape as ``pred``.
    bins: number of bins.
    zlim: upper edge of the last bin.

  Returns:
    The list of lines returned by ``plt.plot``.
  """
  bin_edges=onp.linspace(0,zlim,bins+1)
  vals=onp.array([])
  for i in range(bins):
    lo, hi = bin_edges[i], bin_edges[i+1]
    below_hi = (real <= hi) if i == bins - 1 else (real < hi)
    logic = onp.logical_and(real >= lo, below_hi)
    vals=onp.append(vals,fun(pred[logic],real[logic]))
  width=bin_edges[1]-bin_edges[0]
  plot=plt.plot(-width*0.5+bin_edges[1:],vals)
  plt.xlabel('True Redshift',fontsize=16)
  plt.xticks(fontsize=14)
  plt.yticks(fontsize=14)
  return plot

# Model / training configuration.
n_mixture = 3                  # Gaussian components in the mixture density network
batch_size = 100_000           # rows per optimisation step
epochs = 10_000                # passes over the training data
n_input = 22                   # leading feature columns fed to the network
test_name = 'MDN_fewer_bands'  # filename prefix for saved parameters

# Fully connected network whose last Dense layer emits n_mixture*3 values per
# example; get_mdn_coef splits them into mixture logits, means and log-stds.
# Earlier architectures tried (kept for reference):
#   stax.serial(Dense(512), Relu,Dense(1024), Sigmoid,Dense(512), Relu,Dense(256),Relu,Dense(128), Relu,Dense(64), Relu,Dense(32),Relu, Dense(n_mixture*3))
#   stax.serial(Dense(256),Sigmoid, Dense(1024),Sigmoid,Dense(256),Sigmoid,Dense(n_mixture*3))
# NOTE(review): 2054 hidden units looks like it may be a typo for 2048 — confirm.
_mdn_layers = (
    Dense(512), Relu,
    Dense(2054), Relu,
    Dense(512), Sigmoid,
    Dense(n_mixture * 3),
)
init_fun, the_network = stax.serial(*_mdn_layers)
# log(sqrt(2*pi)): log of the unit-Gaussian normalising constant.
logSqrtTwoPI = onp.log(onp.sqrt(2.0 * onp.pi))

def lognormal(y, mean, logstd):
  """Log-density of a Gaussian with the given mean and std = exp(logstd), at y.

  Broadcasts over its arguments, so a column of targets against per-component
  means/log-stds gives one log-density per component.
  """
  standardized = (y - mean) / np.exp(logstd)
  return -(0.5 * standardized ** 2) - logstd - logSqrtTwoPI

def get_mdn_coef(output):
  """Split raw network output into mixture-density-network parameters.

  ``output`` is a 2-D array with 3*k columns (k mixture components), split
  into three equal column groups along axis 1:
    * logmix - mixture logits, normalised with log_softmax over the last axis;
    * mean   - component means, squashed into (0, 1) by a sigmoid (the targets
               are divided by 6 elsewhere, presumably to match this range);
    * logstd - log standard deviations, returned unchanged.

  ``np.split`` is used instead of an ``output.split`` method call so that both
  NumPy arrays (which have no ``split`` method) and JAX arrays are accepted.

  Returns:
    (logmix, mean, logstd), each of shape (batch, k).
  """
  logmix, mean, logstd = np.split(output, 3, axis=1)
  logmix = nn.log_softmax(logmix)
  mean = nn.sigmoid(mean)
  return logmix, mean, logstd

def mdn_loss_func(logmix, mean, logstd, y):
  """Mean negative log-likelihood of ``y`` under the Gaussian mixture.

  Log mixture weights plus per-component log-densities are combined with
  logsumexp over axis 1 (the component axis), averaged over examples, and
  negated.
  """
  per_component = lognormal(y, mean, logstd) + logmix
  log_likelihood = logsumexp(per_component, axis=1)
  return -np.mean(log_likelihood)

def loss_fn(params, inputs, targets):
  """Training objective: run the network on ``inputs`` and return the mixture NLL of ``targets``."""
  mixture_params = get_mdn_coef(the_network(params, inputs))
  return mdn_loss_func(*mixture_params, targets)

@jit
def update(s,params, x, y, opt_state):
    """Run one optimiser step on a mini-batch.

    Computes the MDN negative log-likelihood (``loss_fn``) and its gradient
    with respect to ``params``, then applies ``opt_update`` (the optimiser
    created in the optimiser-setup cell, Adam there) at step index ``s``.
    ``s`` is a traced argument, so changing it does not trigger recompilation.

    Returns:
      (new_params, new_opt_state, loss), where ``loss`` is evaluated at the
      parameters *before* this step.
    """
    loss, grads = value_and_grad(loss_fn)(params, x, y)
    opt_state = opt_update(s, grads, opt_state)
    return get_params(opt_state), opt_state, loss

def _save_params(path, params):
  """Save a stax parameter list with np.save as a 1-D object array.

  stax params are a ragged list of per-layer tuples (e.g. (W, b) or ()),
  which recent NumPy refuses to convert to an array implicitly; filling an
  object array element by element gives one entry per layer instead.
  """
  packed = onp.empty(len(params), dtype=object)
  for layer_index, layer_params in enumerate(params):
    packed[layer_index] = layer_params
  onp.save(path, packed, allow_pickle=True)


def train(params, x_data, y_data, opt_state):
  """Train the MDN for ``epochs`` epochs using consecutive mini-batches.

  Rows are consumed in order, in ``batch_size`` slices. A trailing partial
  batch is not used. The mean batch loss is printed every epoch. Parameters
  are checkpointed to ``<test_name>_params.npy`` every 500 epochs, excluding
  epoch 0.

  Args:
    params: network parameters (as returned by ``init_fun``).
    x_data: 2-D feature array, one row per object.
    y_data: 2-D target array, one row per object.
    opt_state: optimiser state (as returned by ``opt_init``).

  Returns:
    The trained parameters.

  Raises:
    ValueError: if there are fewer rows than ``batch_size``, so that no
      training step could run.
  """
  n_batches = len(y_data) // batch_size
  if n_batches == 0:
    raise ValueError('need at least {} rows to train, got {}'.format(batch_size, len(y_data)))
  s = 0  # global optimiser step counter, continues across epochs
  for epoch in range(epochs):
    loss_tot = 0
    for b in range(n_batches):
      lo, hi = b * batch_size, (b + 1) * batch_size
      params, opt_state, loss = update(s, params, x_data[lo:hi, :], y_data[lo:hi, :], opt_state)
      loss_tot += loss
      s += 1
    print('Epoch: ',epoch,' Loss: ', loss_tot/n_batches)
    if epoch != 0 and epoch % 500 == 0:
      _save_params(test_name+'_params.npy', params)
  return params


def gumbel_sample(x, axis=1):
  """Sample category indices along ``axis`` with the Gumbel-max trick.

  ``x`` holds non-negative (possibly unnormalised) probabilities. Standard
  Gumbel noise from NumPy's global RNG is added to log(x), and the argmax is
  taken along ``axis``.
  """
  noise = onp.random.gumbel(loc=0, scale=1, size=x.shape)
  perturbed = onp.log(x) + noise
  return onp.argmax(perturbed, axis=axis)



def pit(true,pi_data,mu_data,std):
  """Probability integral transform (PIT) of ``true`` under a Gaussian mixture per object.

  For each row i, returns sum_k pi_data[i, k] * Phi((true[i] - mu_data[i, k]) / std[i, k]),
  i.e. the mixture CDF evaluated at the true value.

  Args:
    true: 1-D array of true values (one per row), or a scalar.
    pi_data: (n, k) mixture weights.
    mu_data: (n, k) component means.
    std: (n, k) component standard deviations.

  Returns:
    1-D array of n PIT values.
  """
  # Shape comes from the weights passed in, not from any global variable.
  n, k = onp.shape(pi_data)
  pit_val = onp.zeros(n)
  for x in range(k):
    pit_val += pi_data[:, x] * norm.cdf(true, mu_data[:, x], std[:, x])
  return pit_val

from scipy import stats

def plot_pdf(pi,mu,sig,pred,true=None,zrange=6):
  """Plot a single object's Gaussian-mixture PDF on [0, zrange].

  The density sum_i pi[i] * N(x; mu[i], sig[i]) is drawn in blue on a
  100-point grid, with a blue vertical line at ``pred``. When ``true`` is
  given, a red vertical line marks it; when it is None, that line is skipped.

  Args:
    pi, mu, sig: 1-D sequences of mixture weights, means and std devs.
    pred: point prediction to mark.
    true: optional true value to mark.
    zrange: upper end of the plotted x range.
  """
  plt.figure(figsize=(12,10))
  x=onp.linspace(0,zrange,100)
  pdf=onp.zeros_like(x)
  for i in range(len(mu)):
    pdf+=pi[i]*stats.norm.pdf(x,mu[i],sig[i])
  peak=onp.max(pdf)
  plt.plot(x,pdf,color='blue')
  plt.plot([pred,pred],[0,peak],color='blue')
  if true is not None:
    plt.plot([true,true],[0,peak],color='red')
  plt.xlabel('Redshift',fontsize=16)
  plt.ylabel('Probability density',fontsize=16)
  plt.show()

def _minmax_span(lo, hi):
  """Per-column range hi - lo, with zero ranges replaced by 1 to avoid division by zero."""
  span = hi - lo
  return onp.where(span == 0, 1, span)

def minmax_fit_and_scale(X):
  """Min-max scale each column of X to [0, 1] and return the fitted bounds.

  Constant columns (max == min) are mapped to 0 rather than NaN.

  Returns:
    (X_std, min, max): the scaled array plus per-column minima and maxima,
    suitable for ``minmax_scale`` / ``minmax_unscale``.
  """
  col_min = X.min(axis=0)
  col_max = X.max(axis=0)
  X_std = minmax_scale(X, col_min, col_max)
  return X_std, col_min, col_max

def minmax_scale(X,min,max):
  """Scale X using previously fitted per-column bounds; constant columns map to 0."""
  return (X - min) / _minmax_span(min, max)

def minmax_unscale(X,min,max):
  """Invert minmax_scale, mapping scaled values back to the original units."""
  return X * (max - min) + min




In [None]:


# Data-loading configuration.
cats = 20                     # number of simulated catalogue files to read
n_samples = cats * 1_000_000  # total rows across all catalogues (1M per file)
beta=True          # selects how '*_BETA_FLUX' keys map to column stems (see `a`)
small_range=False  # if True, keep only objects with 0 <= redshift <= 2
magnitudes=False   # if True, store magnitudes via mag()/mag_err() instead of fluxes
# Characters kept from the '_BETA' part of a '*_BETA_FLUX' key when building the
# column stem in the loader: a=5 gives 'X_BETA', a=0 gives 'X'.
a = 5 if beta else 0

# float32 staging buffer: the features are later converted to a JAX array,
# which in JAX's default 32-bit mode is float32 anyway, so this halves host
# memory (~4.3 GB instead of ~8.6 GB for 20M x 54) without changing the result.
# NOTE(review): if jax_enable_x64 is enabled elsewhere, switch back to float64.
x_data=onp.zeros((n_samples,54), dtype=onp.float32)
y_data=onp.zeros(n_samples)

def mag_err(flux_err,flux):
    """First-order magnitude uncertainty from a flux and its error.

    sigma_m ~= 1.09 * sigma_F / F, where 1.09 approximates 2.5 / ln(10).
    """
    coeff = 1.09
    return coeff * flux_err / flux

def mag(flux):
    """Convert flux to magnitude: -2.5 * log10(flux) + 23.9.

    Fluxes <= 0 have no finite magnitude (log10 gives NaN for negatives and
    -inf for zero). Such entries are set to the sentinel value 0.
    Expects an array-like supporting boolean-mask assignment.
    """
    with onp.errstate(divide='ignore', invalid='ignore'):
        x = -2.5 * onp.log10(flux) + 23.9
    x[~onp.isfinite(x)] = 0
    return x

  
# Fill x_data / y_data from the simulated catalogues. Each catalogue supplies
# a contiguous block of rows_per_cat rows. For every '*BETA_FLUX' key, a
# (flux, flux error) pair -- or (magnitude, magnitude error) when `magnitudes`
# is set -- is written into the next two feature columns.
rows_per_cat = n_samples // cats
for c in range(cats):
    path = 'sim/mil'+str(c+1)+'_noisy_gal.fits'
    cat = table.Table.read(path,format='fits',hdu=1)
    if len(cat) != rows_per_cat:
        raise ValueError('{} has {} rows, expected {}'.format(path, len(cat), rows_per_cat))
    rows = slice(c*rows_per_cat, (c+1)*rows_per_cat)
    y_data[rows] = cat['redshift']
    keys = cat.keys()
    count = 0
    for key_name in keys:
        if not key_name.endswith('BETA_FLUX'):
            continue
        filt = key_name[:len(key_name)-10+a]
        if magnitudes:
            x_data[rows,count] = mag(cat[filt+'_FLUX'])
            x_data[rows,count+1] = mag_err(cat[filt+'_FLUXERR'],cat[filt+'_FLUX'])
        else:
            x_data[rows,count] = cat[filt+'_FLUX']
            x_data[rows,count+1] = cat[filt+'_FLUXERR']
        count += 2
    print(c)


# Keep only the first n_input feature columns and move them into JAX.
# Slicing the NumPy buffer first avoids materialising the full
# (n_samples, 54) array as a JAX array when only n_input columns are used.
x_data = np.array(x_data[:, :n_input])

if small_range:
  ids=np.logical_and(y_data>=0,y_data<=2)
  y_data=y_data[ids]
  x_data=x_data[ids,:]

print(x_data)

# Targets as an (n, 1) column divided by 6. Presumably 6 is the maximum
# redshift in the simulation, so that targets fall in the (0, 1) range of the
# sigmoid means in get_mdn_coef -- TODO confirm.
y_data = np.array(y_data.reshape(len(y_data), 1))/6

# Drop the reference to the last catalogue table so it can be garbage-collected.
cat=0



In [None]:
!pip install optax
import optax

In [None]:

# Adam with a constant learning rate of 1e-3.
constant_scheduler = optax.constant_schedule(0.001)
opt_init, opt_update, get_params = optimizers.adam(constant_scheduler)

# Initialise network parameters for inputs of shape (batch_size, n_input),
# then the matching optimiser state.
_, params = init_fun(key, (batch_size, n_input))
opt_state = opt_init(params)




In [None]:
params = train(params, x_data, y_data, opt_state)

# stax parameters are a ragged list of per-layer tuples, which recent NumPy
# will not convert to an array implicitly, so pack them into a 1-D object
# array (one entry per layer) before saving.
params_to_save = onp.empty(len(params), dtype=object)
for layer_index, layer_params in enumerate(params):
  params_to_save[layer_index] = layer_params
onp.save(test_name+'_params.npy',params_to_save,allow_pickle=True)

