In [None]:
# Mount Google Drive and move into the project's data directory.
# An absolute path keeps this cell idempotent: the previous relative chdir
# ("drive/My Drive/...") only worked from /content and failed on a re-run
# because the working directory had already been changed.
from google.colab import drive
drive.mount('/content/drive',force_remount=True)
import os
DATA_DIR = "/content/drive/My Drive/Ben_Boyd_MSc_Project/Data"
os.chdir(DATA_DIR)


In [None]:


import matplotlib.pyplot as plt # creating visualizations
import numpy as onp # we still need original nump for some tasks


from matplotlib import cm
from matplotlib.ticker import LinearLocator

# using stax, much cleaner
from jax.experimental import stax
from jax.scipy.special import logsumexp #def generate_data(n_samples): Compute the log of the sum of exponentials of input elements
from jax.nn import softmax # pretty much the same as the interface as the one in scipy.special

from jax.experimental.stax import (Dense, Tanh, Flatten, Relu, LogSoftmax, Softmax, Exp,Sigmoid,Softplus,LeakyRelu)

import jax.numpy as np # JAX is supposed to have an API that closely resemble NumPy's
from jax import grad, jit, vmap, value_and_grad
from jax import random
import jax.nn as nn

from astropy.table import Table

from astropy import table
from astropy.io import fits,ascii,votable
from astropy import units as u 
from astropy import constants as const
from astropy import table
from astropy.cosmology import Planck15,FlatLambdaCDM
import copy
from jax.experimental import optimizers
from scipy.stats import norm
# JAX PRNG key with seed 123.
# NOTE(review): `key` is not referenced in the visible cells (sampling below uses
# onp.random instead); confirm whether it is still needed.
key = random.PRNGKey(123)

def sigma68(pred,real):
  """Half the central ~68% spread of the normalised residuals (pred - real)/(1 + real).

  The residuals are sorted and the values at ranks int(0.841*n)-1 and
  int(0.159*n)-1 are used as the upper and lower percentiles.
  """
  n = len(pred)
  residuals = onp.sort((pred - real) / (1 + real))
  upper = residuals[int(0.841 * n) - 1]
  lower = residuals[int(0.159 * n) - 1]
  return 0.5 * (upper - lower)


def sigma_nmad(pred,real):
  """Normalised median absolute deviation of the redshift residuals.

  Returns 1.48 * median(|dz - median(dz)| / (1 + real)) with dz = pred - real.
  """
  dz = pred - real
  centred = np.absolute(dz - np.median(dz))
  return 1.48 * np.median(centred / (1 + real))

def outlier_frac(pred,real,threshold=0.15):
  """Fraction of catastrophic outliers, |pred - real| / (1 + real) > threshold.

  Args:
    pred: predicted redshifts.
    real: reference redshifts, same length as pred.
    threshold: outlier cut on the normalised residual (default 0.15, the
      eta_0.15 metric used in the plots).

  Returns:
    A Python float in [0, 1]. Returns nan for empty input, instead of raising
    ZeroDivisionError, so plot_func can still draw a curve when a redshift bin
    contains no galaxies.
  """
  n = len(pred)
  if n == 0:
    return float('nan')
  is_outlier = np.absolute((pred - real) / (1 + real)) > threshold
  return int(np.sum(is_outlier)) / n


def bias(pred,real):
  """Median normalised residual, median((pred - real) / (1 + real))."""
  normalised_residual = (pred - real) / (1 + real)
  return np.median(normalised_residual)

def plot_func(fun,pred,real,bins,zlim,ax,color='blue',label=None):
  """Plot a photo-z metric binned in reference (spectroscopic) redshift.

  [0, zlim] is split into `bins` equal-width bins of `real`. Each bin is
  left-closed ([lo, hi)), so values lying exactly on an edge are no longer
  dropped. `fun(pred_bin, real_bin)` is evaluated per bin and plotted at the
  bin centre.

  Args:
    fun: metric taking (pred, real), e.g. bias, sigma_nmad, outlier_frac.
    pred: predicted redshifts.
    real: reference redshifts, same length as pred.
    bins: number of bins.
    zlim: upper limit of the binning range.
    ax: matplotlib Axes to draw on, or None to draw with pyplot.
    color: line colour (now honoured in both branches).
    label: legend label (now honoured in both branches).

  Returns:
    The list of Line2D objects returned by plot().
  """
  pred = onp.asarray(pred)
  real = onp.asarray(real)
  bin_edges = onp.linspace(0, zlim, bins + 1)
  centres = 0.5 * (bin_edges[:-1] + bin_edges[1:])
  vals = []
  for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
    in_bin = (real >= lo) & (real < hi)
    vals.append(fun(pred[in_bin], real[in_bin]))
  vals = onp.array(vals, dtype=float)

  if ax is None:
    plot = plt.plot(centres, vals, color=color, label=label)
    plt.xlabel('Spectroscopic Redshift',fontsize=16)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)
  else:
    plot = ax.plot(centres, vals, color=color, label=label)
    ax.set_xlabel('Spectroscopic Redshift',fontsize=20)
    # tick_params replaces per-tick `tick.label.set_fontsize`; the Tick.label
    # attribute no longer exists in recent matplotlib releases.
    ax.tick_params(axis='both', which='major', labelsize=18)
  return plot

# MDN / training configuration.
n_mixture = 3          # Gaussian components per galaxy
batch_size=100000      # rows per forward pass (training and batched evaluation)
epochs=10000           # passes over the data in train()
n_input=54             # 27 bands x (flux, flux error), interleaved columns
# get output from network
# Network: n_input -> 512 (ReLU) -> 2048 (ReLU) -> 512 (sigmoid) -> 3*n_mixture raw outputs,
# which get_mdn_coef splits into mixture logits, means and log std-devs.
# NOTE(review): the parameter files loaded below were presumably saved from this
# exact layer stack; changing it will make them incompatible.
#init_fun, the_network = stax.serial(Dense(512), Relu,Dense(1024), Sigmoid,Dense(512), Relu,Dense(256),Relu,Dense(128), Relu,Dense(64), Relu,Dense(32),Relu, Dense(n_mixture*3))
init_fun, the_network = stax.serial(Dense(512),Relu, Dense(2048),Relu,Dense(512),Sigmoid,Dense(n_mixture*3))
#init_fun, the_network = stax.serial(Dense(256),Sigmoid, Dense(1024),Sigmoid,Dense(256),Sigmoid,Dense(n_mixture*3))
# log(sqrt(2*pi)): Gaussian normalisation constant used by lognormal().
logSqrtTwoPI = onp.log(onp.sqrt(2.0 * onp.pi))

def lognormal(y, mean, logstd):
  """Elementwise log-density of Normal(mean, exp(logstd)) at y (broadcasting)."""
  standardised = (y - mean) / np.exp(logstd)
  return -0.5 * standardised ** 2 - logstd - logSqrtTwoPI

def get_mdn_coef(output):
  """Split raw network output into mixture-density-network parameters.

  Args:
    output: array whose axis 1 holds three equal-width groups,
      [mixture logits | means | log std-devs] (n_mixture columns each).

  Returns:
    (logmix, mean, logstd):
      logmix: log mixture weights (log-softmax over axis 1).
      mean: component means squashed into (0, 1) by a sigmoid (the targets
        are redshift / 6 elsewhere in this notebook).
      logstd: log standard deviations, passed through unchanged.
  """
  # np.split rather than the `Array.split` method: that method is not part of
  # the NumPy ndarray API and may be unavailable on newer JAX versions.
  logmix, mean, logstd = np.split(output, 3, axis=1)
  logmix = nn.log_softmax(logmix)
  mean = nn.sigmoid(mean)
  return logmix, mean, logstd

def mdn_loss_func(logmix, mean, logstd, y):
  """Mean negative log-likelihood of y under the Gaussian mixture.

  The per-component log-weights plus log-densities are combined with
  logsumexp over axis 1 (the component axis), then averaged and negated.
  """
  per_component = logmix + lognormal(y, mean, logstd)
  log_likelihood = logsumexp(per_component, axis=1)
  return -np.mean(log_likelihood)




def loss_fn(params, inputs, targets):
  """ MDN Loss function for training loop.

  Runs the network on `inputs`, converts its output into mixture parameters
  and returns the mean negative log-likelihood of `targets`.
  """
  mixture_params = get_mdn_coef(the_network(params, inputs))
  return mdn_loss_func(*mixture_params, targets)

@jit
def update(params, x, y, opt_state, step=0):
    """Run one optimiser step on a batch using the MDN loss.

    Computes the MDN negative log-likelihood (loss_fn) and its gradient, then
    applies opt_update. The old docstring said "MSE", which was wrong.

    Args:
      params: current network parameters.
      x: input batch.
      y: target batch.
      opt_state: optimiser state.
      step: step counter forwarded to opt_update. The previous code always
        passed 0, which is wrong for step-dependent optimisers (e.g. Adam bias
        correction or learning-rate schedules). Pass the running step count.
        Defaults to 0 for backward compatibility.

    Returns:
      (new_params, new_opt_state, loss)

    NOTE(review): opt_update / get_params are not defined in the visible
    cells; presumably they come from an optimizers.<name>(...) call elsewhere.
    """
    loss, grads = value_and_grad(loss_fn)(params, x, y)
    opt_state = opt_update(step, grads, opt_state)
    return get_params(opt_state), opt_state, loss

def train(params, x_data, y_data, opt_state):
  """Minibatch training loop over `epochs` passes of the data.

  Only full batches of `batch_size` rows are used; any remainder is skipped.
  The loss is printed per batch. Every 500th epoch (excluding epoch 0) the
  parameters are pickled to '<test_name>_params_<epoch>.npy'.

  NOTE(review): `test_name` is not defined in the visible cells.
  """
  n_batches = int(len(y_data)/batch_size)
  for epoch in range(epochs):
    for b in range(n_batches):
      lo, hi = b*batch_size, (b+1)*batch_size
      params, opt_state, loss = update(params, x_data[lo:hi,:], y_data[lo:hi,:], opt_state)
      print('Epoch: ',epoch,' Batch: ',b,' Loss: ', loss)
    if epoch != 0 and epoch % 500 == 0:
      np.save(test_name+'_params_'+str(epoch)+'.npy',params,allow_pickle=True)
  return params


def gumbel_sample(x, axis=1):
  """Draw one categorical index along `axis` using the Gumbel-max trick.

  `x` holds (unnormalised) probabilities; argmax(log(x) + Gumbel noise) is a
  sample from the categorical distribution proportional to x.
  """
  noise = onp.random.gumbel(loc=0, scale=1, size=x.shape)
  perturbed = onp.log(x) + noise
  return onp.argmax(perturbed, axis=axis)



def pit(true,pi_data,mu_data,std):
  """Probability integral transform of `true` under each row's Gaussian mixture.

  PIT_i = sum_k pi_data[i, k] * Phi((true_i - mu_data[i, k]) / std[i, k])

  Args:
    true: reference values, shape (n,).
    pi_data: mixture weights, shape (n, k).
    mu_data: component means, shape (n, k).
    std: component standard deviations, shape (n, k).

  Returns:
    numpy float array of shape (n,).
  """
  # Bug fix: the shape was previously read from the *global* `logmix`
  # rather than from the arguments, so the result depended on whichever
  # batch had been evaluated last.
  true = onp.asarray(true, dtype=float)
  pi_data = onp.asarray(pi_data, dtype=float)
  mu_data = onp.asarray(mu_data, dtype=float)
  std = onp.asarray(std, dtype=float)
  # (n, 1) broadcasts against (n, k); sum the weighted CDFs over components.
  return onp.sum(pi_data * norm.cdf(true[:, None], mu_data, std), axis=1)

from scipy import stats

def plot_pdf(pi,mu,sig,pred,true=None,zrange=6,ax=None):
  """Plot one galaxy's Gaussian-mixture redshift PDF with vertical markers.

  Args:
    pi, mu, sig: 1-D mixture weights, means and std-devs (same length).
    pred: point estimate, drawn as a dashed blue vertical line.
    true: reference redshift, drawn as a red vertical line; skipped if None
      (previously it plotted [None, None]).
    zrange: the PDF is evaluated on 100 points in [0, zrange].
    ax: Axes to draw on; None draws with pyplot. Both branches now use the
      same styles and labels.

  Returns:
    The Line2D list for the PDF curve.
  """
  x=onp.linspace(0,zrange,100)
  pdf=onp.zeros(100)
  for weight, mean, width in zip(pi, mu, sig):
    pdf += float(weight) * stats.norm.pdf(x, float(mean), float(width))
  peak = onp.max(pdf)

  target = plt if ax is None else ax
  plot = target.plot(x,pdf,color='blue',label='PDF')
  target.plot([pred,pred],[0,peak],color='blue',label=r'Photometric Redshift',linestyle='--')
  if true is not None:
    target.plot([true,true],[0,peak],color='red',label=r'Spectroscopic Redshift')

  return plot

def minmax_fit_and_scale(X):
  """Scale each column of X to [0, 1].

  Returns (scaled, column minima, column maxima).
  """
  col_min = X.min(axis=0)
  col_max = X.max(axis=0)
  scaled = (X - col_min) / (col_max - col_min)
  return scaled, col_min, col_max

def minmax_scale(X,min,max):
  """Map X into [0, 1] using previously fitted column minima and maxima."""
  span = max - min
  return (X - min) / span

def minmax_unscale(X,min,max):
  """Inverse of minmax_scale: map [0, 1] values back to the original range."""
  span = max - min
  return X * span + min





In [None]:
#params=np.load('more_nodes_params_2001.npy',allow_pickle=True)
#params=np.load('short_z_params_2000.npy',allow_pickle=True)
# Load pickled network parameters saved by train().
# allow_pickle=True executes arbitrary pickle content; only load trusted files.
params=np.load('three_newest_beta_params_2000.npy',allow_pickle=True)

In [None]:


# Configuration for loading the synthetic evaluation catalogue.
n_samples =  1000000 # just use train everything in a batch
beta=True           # read the '<band>_BETA_FLUX' columns (see the slicing below)
small_range=False   # if True keep only redshift < 2
magnitudes=False    # if True convert fluxes/errors to magnitudes
# `a` re-extends the column-name slice: with beta=True, 'X_BETA_FLUX'[:len-10+5]
# gives 'X_BETA', so 'X_BETA_FLUX' / 'X_BETA_FLUXERR' are read.
a=0
if beta:
  a+=5


import warnings
# Silences all warnings (including log10 of negative fluxes in mag()).
warnings.filterwarnings("ignore")
cats=1   # number of catalogue files to read

def mag_err(flux_err,flux):
    """First-order magnitude error from a flux error: 1.09 * flux_err / flux.

    (1.09 ~ 2.5 / ln 10.) The sign of `flux` is kept, so a negative flux gives
    a negative error.
    """
    scale = 1.09
    return scale*flux_err/flux

def mag(flux):
    """Magnitude -2.5*log10(flux) + 23.9; negative fluxes use |flux|.

    23.9 is the AB zero point for fluxes in microjansky (presumably the
    catalogue units; confirm). Entries whose log10 is nan (negative fluxes)
    are recomputed from -flux.
    """
    m_ab = -2.5*onp.log10(flux) + 23.9
    bad = onp.isnan(m_ab)
    m_ab[bad] = -2.5*onp.log10(-flux[bad]) + 23.9
    return m_ab
# Load the synthetic evaluation catalogue(s) into x_data (interleaved
# flux / flux-error columns, one pair per *_BETA_FLUX column) and y_data.
# NOTE(review): x_data/y_data are re-created per file, so only the last of
# `cats` files survives, while filter_names keeps growing; fine for cats=1.
filter_names=onp.array([])
for c in range(cats):
    # Files 'sim/mil21...' onwards; the training cell reads 'sim/mil1..'.
    cat = table.Table.read('sim/mil'+str(c+21)+'_noisy_gal.fits',format='fits',hdu=1)
    if small_range:
      cat=cat[cat['redshift']<2]
    keys=cat.keys()
    count=0
    y_data=cat['redshift']
    x_data=onp.zeros((len(y_data),54))
    for key_name in keys:

        # Select band columns ending in 'BETA_FLUX'.
        if key_name[len(key_name)-9:]=='BETA_FLUX':

                
            # len-10 strips '_BETA_FLUX'; +a (5) keeps '_BETA' when beta=True.
            filt=key_name[:len(key_name)-10+a]
            filter_names=onp.append(filter_names,key_name[:len(key_name)-10])
            if magnitudes:

              x_data[:,count]=mag(cat[filt+'_FLUX'])
              x_data[:,count+1]=mag_err(cat[filt+'_FLUXERR'],cat[filt+'_FLUX'])
            else:
              x_data[:,count]=cat[filt+'_FLUX']
              x_data[:,count+1]=cat[filt+'_FLUXERR']
            
            count+=2
            
            
        
    print(c)


x_data = np.array(x_data.reshape(len(x_data), n_input))


print(x_data)

# Targets are stored as redshift / 6, matching the sigmoid-bounded MDN means.
y_data = np.array(y_data.reshape(len(y_data), 1))/6
print(filter_names)
#cat=0
#cat=0



In [None]:
import copy
import math
pi_symb=math.pi   # used by gaus()

# Per-galaxy output buffers for the batched evaluation below.
y_pred=onp.zeros(len(y_data))                    # point estimates (redshift units)
y_pit=onp.zeros(len(y_data))                     # PIT values
sig_arr=onp.zeros((len(y_data),n_mixture))       # component std-devs (scaled units)
mu_arr=onp.zeros((len(y_data),n_mixture))        # component means (scaled units)
pi_arr=onp.zeros((len(y_data),n_mixture))        # mixture weights


def gaus(x,pi,mu,sig):
  """Weighted Gaussian density pi * N(x; mu, sig), elementwise."""
  norm_const = sig*(2*pi_symb)**0.5
  exponent = -(x-mu)**2/(2*sig**2)
  return pi*1/norm_const*np.exp(exponent)

def pick(pi,mu,sig):
  """Evaluate each row's mixture density at each of its own component means.

  Returns an array `density` shaped like `pi` with
  density[n, i] = sum_j gaus(mu[n, i], pi[n, j], mu[n, j], sig[n, j]).
  Callers take argmax over axis 1 to choose the component mean with the
  highest mixture density as the point estimate.
  """
  n_comp = pi.shape[1]
  density = onp.zeros_like(pi)
  for target in range(n_comp):
    at_mean = mu[:, target]
    for src in range(n_comp):
      density[:, target] += gaus(at_mean, pi[:, src], mu[:, src], sig[:, src])
  return density


# Batched evaluation of the MDN on the synthetic catalogue: collect point
# estimates, mixture parameters, PIT values and per-batch losses, then plot.
# NOTE(review): int(len/batch_size) drops a final partial batch; with
# 1,000,000 rows and batch_size=100000 nothing is dropped.
# The unused Gumbel / randn sampling that was computed here has been removed.
loss_arr=np.array([])

for b in range(int(len(y_data)/batch_size)):
  batch = slice(b*batch_size, (b+1)*batch_size)
  x_batch = x_data[batch,:]
  y_batch = y_data[batch,:]

  logmix, mu_data, logstd = get_mdn_coef(the_network(params, x_batch))

  pi_data = np.exp(logmix)
  sigma_data = np.exp(logstd)
  # Point estimate: the component mean where the full mixture density is
  # highest, evaluated on the redshift axis (z = 6 * network units).
  sampled=mu_data[np.arange(len(mu_data)),np.argmax(pick(pi_data,mu_data*6,sigma_data*6),axis=1)]
  mu_arr[batch,:]=mu_data
  sig_arr[batch,:]=sigma_data
  pi_arr[batch,:]= pi_data
  y_pred[batch]=sampled*6
  y_pit[batch]=pit(y_batch[:,0]*6,pi_data,mu_data*6,sigma_data*6)
  loss_arr=np.append(loss_arr,(loss_fn(params,x_batch,y_batch)))



print('PERFORMANCE ON SYNTHETIC DATA')
print('       ')
print('Average Loss: ',np.mean(loss_arr))
print(r'Median delta z /(1+z) ',bias(y_pred,y_data[:,0]*6))
print('Overall Sigma_NMAD/(1+z) ',sigma_nmad(y_pred,y_data[:,0]*6))
print('Outlier Fraction ',outlier_frac(y_pred,y_data[:,0]*6))
print('Median PIT: ', np.median(y_pit))
print('        ')

bins=np.linspace(0,6,100)
import matplotlib as mpl
plt.figure(figsize=(14,10))
plt.hist2d(y_data[:,0]*6,y_pred,bins=bins,cmap='hot',norm=mpl.colors.LogNorm())
plt.xlabel('True Redshift',fontsize=18)
plt.ylabel('Photometric Redshift Estimate',fontsize=18)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
cbar=plt.colorbar()
cbar.ax.tick_params(labelsize=16)
plt.show()

fig, axs = plt.subplots(2, 2,figsize=(16,10))
plot=plot_func(bias,y_pred,y_data[:,0]*6,20,6,ax=axs[0,0])
axs[0,0].set_ylabel(r'$<\Delta z /(1+z)>$',fontsize=20)

plot=plot_func(sigma_nmad,y_pred,y_data[:,0]*6,20,6,ax=axs[1,0])
axs[1,0].set_ylabel(r'$\sigma_{NMAD}/(1+z)$',fontsize=20)

plot=plot_func(outlier_frac,y_pred,y_data[:,0]*6,20,6,ax=axs[0,1])
axs[0,1].set_ylabel(r'$\eta_{0.15}$',fontsize=20)

axs[1,1].hist(y_pit,bins=20,edgecolor='black',color='blue')
axs[1,1].set_ylabel('Galaxies',fontsize=20)
axs[1,1].set_xlabel('PIT',fontsize=20)
# tick_params replaces per-tick `tick.label.set_fontsize`; Tick.label no
# longer exists in recent matplotlib releases.
axs[1,1].tick_params(axis='both', which='major', labelsize=18)
fig.tight_layout()
plt.show()




In [None]:
# Show the PDFs of the first 20 catastrophic outliers (|dz|/(1+z) > 0.5)
# in the synthetic evaluation set.
dz=np.absolute((y_pred-y_data[:,0]*6)/(1+y_data[:,0]*6))
print(y_pred)
max_shown = 20
# Bounded selection: the previous `while count != 20` loop ran past the end
# of dz (IndexError) whenever fewer than 20 such outliers existed.
outlier_ids = onp.flatnonzero(onp.asarray(dz) > 0.5)[:max_shown]
for i in outlier_ids:
  print(y_pred[i])
  plot_pdf(pi_arr[i,:],mu_arr[i,:]*6,sig_arr[i,:]*6,pred=y_pred[i],true=y_data[i,0]*6,zrange=6)
  plt.show()



In [None]:


# Load the whole COSMOS2020 catalogue (no ZSPEC cut) into x_spec / y_spec,
# used below to compare its magnitude distributions with the synthetic set.
cat0 = table.Table.read('COSMOS2020.fits',format='fits',hdu=1)

spec_cat=cat0
n_samples=len(spec_cat)
x_spec=onp.zeros((n_samples,n_input))
keys=spec_cat.keys()
filt_names=onp.array([])
y_spec=spec_cat['ZSPEC']
count=0
# NOTE(review): assumes the catalogue has exactly n_input/2 columns ending in
# 'FLUX', in the same band order as the synthetic *_BETA_FLUX columns — confirm.
for key_name in keys:
  if key_name[len(key_name)-4:]=='FLUX':
    
    filt=key_name[:len(key_name)-5]
    filt_names=onp.append(filt_names,filt)
    x_spec[:,count]=spec_cat[filt+'_FLUX']
    x_spec[:,count+1]=spec_cat[filt+'_FLUXERR']

    count+=2

# Drop references to the tables.
spec_cat=0
cat0=0


y_spec=np.array(y_spec)
y_spec=y_spec.reshape(len(y_spec),1)/6
x_spec=np.array(x_spec)

In [None]:
# Normalised absolute residuals of the synthetic evaluation.
dz=np.absolute((y_pred-y_data[:,0]*6)/(1+y_data[:,0]*6))

# Median over galaxies of each galaxy's smallest column value (flux and
# error columns both included), for good (dz<0.1) and bad (dz>0.1) fits.
print(np.median(np.min(x_data[dz<0.1,:],axis=1)))
print(np.median(np.min(x_data[dz>0.1,:],axis=1)))


# Column masks for the interleaved layout: even columns are fluxes, odd are errors.
flux=np.array([True,False]*27)
error=np.array([False,True]*27)


# Magnitude histogram bin edges.
bins=np.linspace(15,35,50)

def mag(flux):
    """Magnitude -2.5*log10(flux) + 23.9; negative fluxes use |flux|.

    Same behaviour as the earlier definition of `mag` in this notebook.
    """
    m_ab = -2.5*onp.log10(flux) + 23.9
    bad = onp.isnan(m_ab)
    m_ab[bad] = -2.5*onp.log10(-flux[bad]) + 23.9
    return m_ab


import matplotlib as mpl

# Brightest-band and per-band magnitude distributions: well-predicted vs
# badly-predicted synthetic galaxies vs COSMOS2020.
synthetic_good = r'Synthetic Galaxies $|\Delta z|/(1+z)<0.15$'
synthetic_bad = r'Synthetic Galaxies $|\Delta z|/(1+z)>0.15$ '

plt.figure(figsize=(15,10))

plt.hist(mag(np.max(x_data[dz<0.15,:][:,flux],axis=1)),density=True,color='blue',histtype='step',bins=bins,linewidth=3,label=synthetic_good)
plt.hist(mag(np.max(x_data[dz>0.15,:][:,flux],axis=1)),density=True,color='red',histtype='step',bins=bins,linewidth=3,label=synthetic_bad)
plt.hist(mag(np.max(x_spec[:,flux],axis=1)),density=True,color='green',histtype='step',bins=bins,linewidth=3,label='COSMOS20 Galaxies')
plt.legend(fontsize=16)
plt.xlabel('Brightest Band Magnitude',fontsize=16)
plt.ylabel('Frequency Density',fontsize=16)

plt.show()

# One figure per band; flux of band i is column 2*i.
n_bands = n_input // 2
for i in range(n_bands):
  print(i)
  plt.figure(figsize=(15,10))
  plt.hist(mag(x_data[dz<0.15,2*i]),density=True,color='blue',histtype='step',bins=bins,linewidth=3,label=synthetic_good)
  plt.hist(mag(x_data[dz>0.15,2*i]),density=True,color='red',histtype='step',bins=bins,linewidth=3,label=synthetic_bad)
  plt.hist(mag(x_spec[:,2*i]),density=True,color='green',histtype='step',bins=bins,linewidth=3,label='COSMOS20 Galaxies')
  plt.legend(fontsize=16)
  plt.ylabel('Frequency Density',fontsize=16)
  plt.xlabel(filter_names[i]+' Magnitude',fontsize=16)
  plt.show()









In [None]:
# 2x2 figure comparing magnitude distributions of well-/badly-predicted
# synthetic galaxies with COSMOS2020 for bands 2-5 (flux in column 2*band).
fig, axs = plt.subplots(2, 2,figsize=(16,10))

bins=np.linspace(15,35,50)

def _plot_band_hist(ax, band, show_legend=False):
  """Draw the three magnitude histograms for one band onto `ax`."""
  col = 2*band
  ax.hist(mag(x_data[dz<0.15,col]),density=True,color='blue',histtype='step',bins=bins,linewidth=3,label=r'Synthetic $|\Delta z|/(1+z)<0.15$')
  ax.hist(mag(x_data[dz>0.15,col]),density=True,color='red',histtype='step',bins=bins,linewidth=3,label=r'Synthetic $|\Delta z|/(1+z)>0.15$ ')
  ax.hist(mag(x_spec[:,col]),density=True,color='green',histtype='step',bins=bins,linewidth=3,label='COSMOS20')
  if show_legend:
    ax.legend(fontsize=19,loc='upper left',frameon=False)
  ax.set_ylabel('Frequency Density',fontsize=20)
  ax.set_xlabel(filter_names[band]+' magnitude',fontsize=20)
  # Replaces per-tick `tick.label.set_fontsize` (Tick.label was removed in
  # recent matplotlib releases).
  ax.tick_params(axis='both', which='major', labelsize=18)

# (band, panel, legend?) — same band-to-panel layout as before.
for band, panel, show_legend in [(2, (0,0), True), (3, (1,0), False), (4, (0,1), False), (5, (1,1), False)]:
  _plot_band_hist(axs[panel], band, show_legend)

fig.tight_layout()

plt.show()

In [None]:

# Configuration for loading the 10M-galaxy training set (used by importance_hist).
n_input=54
n_samples =  10000000 # just use train everything in a batch
beta=True
small_range=False
magnitudes=False
# Column-name offset: with beta=True, 'X_BETA_FLUX'[:len-10+5] == 'X_BETA'.
a=0
if beta:
  a+=5

# float64 buffers: 10,000,000 x 54 x 8 bytes ~ 4.3 GB for x_train_data.
x_train_data=onp.zeros((n_samples,n_input))
y_train_data=onp.zeros(n_samples)

# One catalogue file per 1e6 galaxies.
cats=int(n_samples/1e6)

def mag_err(flux_err,flux):
    """First-order magnitude error 1.09 * flux_err / flux (sign of flux kept)."""
    scale = 1.09
    return scale*flux_err/flux

def mag(flux):
    """Magnitude -2.5*log10(flux) + 23.9, with nan results replaced by 0.

    NOTE(review): unlike the other `mag` definitions in this notebook (which
    recompute negative fluxes from -flux), this one maps negative fluxes to a
    magnitude of 0. It is only used here when magnitudes=True; confirm intended.
    """
    m_ab = -2.5*onp.log10(flux) + 23.9
    m_ab[onp.isnan(m_ab)] = 0
    return m_ab

  
# Fill x_train_data / y_train_data from 'sim/mil1..mil<cats>' (interleaved
# flux / error columns per *_BETA_FLUX band).
# NOTE(review): the hard-coded 1000000-row slices assume every file holds
# exactly 1,000,000 galaxies; a different size will raise on assignment.
for c in range(cats):
    cat = table.Table.read('sim/mil'+str(c+1)+'_noisy_gal.fits',format='fits',hdu=1)
    keys=cat.keys()
    count=0
    y_train_data[(c)*1000000:(c+1)*1000000]=cat['redshift']
    for key_name in keys:

        if key_name[len(key_name)-9:]=='BETA_FLUX':

                
            filt=key_name[:len(key_name)-10+a]

            if magnitudes:

              x_train_data[(c)*1000000:(c+1)*1000000,count]=mag(cat[filt+'_FLUX'])
              x_train_data[(c)*1000000:(c+1)*1000000,count+1]=mag_err(cat[filt+'_FLUXERR'],cat[filt+'_FLUX'])
            else:
              x_train_data[(c)*1000000:(c+1)*1000000,count]=cat[filt+'_FLUX']
              x_train_data[(c)*1000000:(c+1)*1000000,count+1]=cat[filt+'_FLUXERR']
            
            count+=2
            
            
        
    print(c)


# Converted to JAX arrays (a second full-size copy in memory).
x_train_data = np.array(x_train_data.reshape(n_samples, n_input))


# Redshift / 6, matching the network's target scaling.
y_train_data = np.array(y_train_data.reshape(len(y_train_data), 1))/6

cat=0



In [None]:
def mag_err(flux_err,flux):
    """Absolute first-order magnitude error |1.09 * flux_err / flux|."""
    signed = 1.09*flux_err/flux
    return np.absolute(signed)

def mag(flux):
    """Magnitude -2.5*log10(flux) + 23.9; negative fluxes use |flux|.

    Warnings are silenced globally first, because log10 of negative values warns.
    """
    warnings.filterwarnings("ignore")
    m_ab = -2.5*onp.log10(flux) + 23.9
    bad = onp.isnan(m_ab)
    m_ab[bad] = -2.5*onp.log10(-flux[bad]) + 23.9
    return m_ab

def importance_hist(test_gal,training_gals,training_redshift,ax=None):
  """Plot the likelihood-weighted redshift distribution of training galaxies
  that resemble `test_gal` photometrically.

  Each training galaxy j gets a Gaussian weight in magnitude space over the
  27 bands (even columns = fluxes, odd columns = flux errors):
    log w_j = -sum_b log(sigma_jb) - 0.5 * sum_b ((m_test,b - m_j,b) / sigma_jb)^2
  The weighted, density-normalised histogram of `training_redshift` is then
  drawn as bars, with the x axis multiplied by 6.

  Args:
    test_gal: array of shape (1, 54), one galaxy's interleaved columns.
    training_gals: array of shape (N, 54), same layout.
    training_redshift: N redshifts in the network's scaled units (z / 6).
    ax: Axes to draw on; None draws with pyplot.

  Returns:
    The BarContainer from bar().
  """
  flux=np.array([True,False]*27)
  error=np.array([False,True]*27)

  test_mag=mag(test_gal[:,flux])
  train_mag=mag(training_gals[:,flux])
  # NOTE(review): the training flux errors are divided by the *test*
  # galaxy's fluxes here; confirm this mixed error model is intended.
  train_error=mag_err(training_gals[:,error],test_gal[:,flux])

  log_weights= -np.sum(np.log(train_error),axis=1)- 0.5 *np.sum(((test_mag-train_mag)/train_error)**2,axis=1)
  # Subtract the maximum before exponentiating. With many bands the chi^2
  # terms can be large enough that exp(log_weights) underflows to all zeros,
  # which makes the density-normalised histogram nan. density=True removes
  # any constant factor, so the shift does not change the plotted result.
  weights = np.exp(log_weights - np.max(log_weights))

  frq, edges = np.histogram(training_redshift.reshape(len(training_redshift),), bins=50, weights=weights, density=True)
  if ax is not None:
    the_plot=ax.bar(edges[:-1]*6, frq, width=np.diff(edges)*6, edgecolor="black", align="edge",color='blue')
  else:
    the_plot=plt.bar(edges[:-1]*6, frq, width=np.diff(edges)*6, edgecolor="black", align="edge")

  return the_plot




# Galaxies with 0.045 < dz < 0.05; plot the first one's MDN PDF (top) and its
# importance-weighted training-set redshift histogram (bottom).
ids=np.where(np.logical_and(dz>0.045,dz<0.05))[0]

for i in range(1):
  select_id=ids[i]
  print(i)
  print(dz[select_id])
  test_gal=x_data[select_id,:].reshape(1,54)


  fig, axs = plt.subplots(2, 1,sharex=True,figsize=(6,8))

  logmix, mu_data, logstd = get_mdn_coef(the_network(params,test_gal))
  
  pi_data = np.exp(logmix)
  sigma_data = np.exp(logstd)
  # Point estimate: component mean with highest mixture density.
  sampled=mu_data[np.arange(len(mu_data)),np.argmax(pick(pi_data,mu_data*6,sigma_data*6),axis=1)]

  importance_hist(test_gal,x_train_data,y_train_data,ax=axs[1])

  plot_pdf(pi_data[0,:],mu_data[0,:]*6,sigma_data[0,:]*6,pred=sampled[0]*6,true=y_data[select_id,0]*6,zrange=6,ax=axs[0])
  fig.subplots_adjust(hspace=0)

  plt.show()
  







In [None]:

# Three example synthetic galaxies (good, borderline, outlier): MDN PDF on the
# left, importance-weighted training-set redshift histogram on the right.
fig, axs = plt.subplots(3, 2,sharex=True,figsize=(20,16))

def _plot_example_row(row, select_id):
  """Fill row `row` of `axs` for synthetic galaxy `select_id`."""
  print(dz[select_id])
  test_gal=x_data[select_id,:].reshape(1,54)
  logmix, mu_data, logstd = get_mdn_coef(the_network(params,test_gal))
  pi_data = np.exp(logmix)
  sigma_data = np.exp(logstd)
  # Point estimate: component mean with highest mixture density.
  sampled=mu_data[np.arange(len(mu_data)),np.argmax(pick(pi_data,mu_data*6,sigma_data*6),axis=1)]
  importance_hist(test_gal,x_train_data,y_train_data,ax=axs[row,1])
  plot_pdf(pi_data[0,:],mu_data[0,:]*6,sigma_data[0,:]*6,pred=sampled[0]*6,true=y_data[select_id,0]*6,zrange=6,ax=axs[row,0])

_plot_example_row(0, np.where(dz<0.05)[0][4])
_plot_example_row(1, np.where(np.logical_and(dz>0.045,dz<0.05))[0][30])
_plot_example_row(2, np.where(dz>0.15)[0][67])

fig.subplots_adjust(hspace=0)

# Column 0: PDFs; column 1: weighted histograms. tick_params replaces the
# per-tick `tick.label.set_fontsize` loops (Tick.label was removed in recent
# matplotlib releases).
for col, ylabel in ((0, r'$p(z|x)$'), (1, r'Weighted Frequency')):
  axs[2,col].set_xlabel(r'Redshift $z$',fontsize=20)
  axs[2,col].tick_params(axis='x', which='major', labelsize=18)
  for row in range(3):
    axs[row,col].set_ylabel(ylabel,fontsize=20)
    axs[row,col].tick_params(axis='y', which='major', labelsize=18)

axs[1,0].set_ylim(0,1.19)
axs[1,1].set_ylim(0,29.9)

axs[0,0].legend(fontsize=20)

plt.show()



In [None]:


# Load the COSMOS2020 galaxies that have a spectroscopic redshift
# (ZSPEC != -1) into x_spec / y_spec for calibration.
#x_data=0
#y_data=0
cat0 = table.Table.read('COSMOS2020.fits',format='fits',hdu=1)
#cat0 = cat0[cat0['lp_zq']>0]
spec_cat=cat0[cat0['ZSPEC']!=-1]
#spec_cat=spec_cat[spec_cat['ZSPEC']<2]
#spec_cat=cat0
n_samples=len(spec_cat)
x_spec=onp.zeros((n_samples,n_input))
keys=spec_cat.keys()
filt_names=onp.array([])
y_spec=spec_cat['ZSPEC']
count=0
# NOTE(review): assumes exactly n_input/2 '*FLUX' columns in the same band
# order as the synthetic training set — confirm.
for key_name in keys:
  if key_name[len(key_name)-4:]=='FLUX':
    
    filt=key_name[:len(key_name)-5]
    filt_names=onp.append(filt_names,filt)
    x_spec[:,count]=spec_cat[filt+'_FLUX']
    x_spec[:,count+1]=spec_cat[filt+'_FLUXERR']

    count+=2

#spec_cat=0
#cat0=0


# Spectroscopic redshift / 6, matching the network's target scaling.
y_spec=np.array(y_spec)
y_spec=y_spec.reshape(len(y_spec),1)/6
x_spec=np.array(x_spec)

In [None]:
!pip install jaxopt

In [None]:
import math
# Module-level pi used by gaus().
pi_symb=math.pi
def gaus(x,pi,mu,sig):
  """Weighted Gaussian density pi * N(x; mu, sig), elementwise."""
  norm_const = sig*(2*pi_symb)**0.5
  exponent = -(x-mu)**2/(2*sig**2)
  return pi*1/norm_const*np.exp(exponent)

def pick(pi,mu,sig):
  """Mixture density of each row evaluated at each of its own component means.

  density[n, i] = sum_j gaus(mu[n, i], pi[n, j], mu[n, j], sig[n, j]);
  callers take argmax over axis 1 to choose the point estimate.
  """
  n_comp = pi.shape[1]
  density = onp.zeros_like(pi)
  for target in range(n_comp):
    at_mean = mu[:, target]
    for src in range(n_comp):
      density[:, target] += gaus(at_mean, pi[:, src], mu[:, src], sig[:, src])
  return density


In [None]:


from numpy.ma.core import add
from jaxopt import ScipyBoundedMinimize as minimize


x_spec = np.array(x_spec.reshape(n_samples, n_input))
y_spec = np.array(y_spec.reshape(n_samples, 1))

import copy 

# Helper matrices for the matrix-based add_cal below (interleaved layout:
# even columns = flux, odd columns = flux error).
alpha_mask=onp.identity(54)   # will keep only even diagonal entries (alpha slots)
beta_mask=onp.identity(54)    # will keep only odd diagonal entries (beta slots)

one_col= onp.zeros((len(x_spec),54))     # 1 in flux columns
two_col= onp.zeros((len(x_spec),54))     # 1 in error columns
x_spec_flux=onp.zeros((len(x_spec),54))  # each band's flux copied into its error column

for i in range(27):
  alpha_mask[2*i+1,2*i+1]=0
  beta_mask[2*i,2*i]=0
  one_col[:,2*i]=1
  two_col[:,2*i+1]=1
  x_spec_flux[:,2*i+1]=copy.deepcopy(x_spec[:,2*i])


alpha_mask=np.array(alpha_mask)
beta_mask=np.array(beta_mask)
one_col=np.array(one_col)
two_col=np.array(two_col)
x_spec_flux=np.array(x_spec_flux)

def add_cal(p0,x_data):
  """Apply per-band photometric calibration parameters to flux/error columns.

  `x_data` columns are interleaved per band, [flux_0, err_0, flux_1, err_1, ...],
  and `p0` holds one (alpha, beta) pair per band, [alpha_0, beta_0, ...]. For band b:

    flux_b' = alpha_b * flux_b
    err_b'  = sqrt((alpha_b * err_b)^2 + (beta_b * alpha_b * flux_b)^2)

  Fluxes are rescaled by alpha. Errors are rescaled and inflated in
  quadrature by a fraction beta of the flux.

  Args:
    p0: calibration vector with 2 * n_bands entries (54 here).
    x_data: array of shape (N, 2 * n_bands).

  Returns:
    Calibrated array with the same shape and column layout as x_data.

  Notes:
    - Operates on the `x_data` argument; the previous version silently used
      the global x_spec.
    - Fluxes keep their sign; the previous sqrt(flux^2) form turned negative
      (noise-dominated) fluxes positive.
    - Pure jax.numpy, so it stays traceable/differentiable in the optimiser.
  """
  p = np.reshape(np.asarray(p0), (-1,))
  alpha = p[0::2]
  beta = p[1::2]
  flux = x_data[:, 0::2]
  err = x_data[:, 1::2]
  flux_cal = alpha * flux
  err_cal = ((err * alpha) ** 2 + (beta * alpha * flux) ** 2) ** 0.5
  # Re-interleave to [flux_0, err_0, flux_1, err_1, ...].
  return np.stack([flux_cal, err_cal], axis=2).reshape(x_data.shape)
  

# Spectroscopic redshifts in redshift units (undoing the /6 scaling), flattened
# to 1-D; read by sig_func.
ans=6*y_spec.reshape(len(y_spec),)
def opt_func(opt_params):
  """Calibration objective: MDN loss on the calibrated spectroscopic sample."""
  calibrated = add_cal(opt_params,x_spec)
  return loss_fn(params, calibrated, y_spec)
def sig_func(opt_params):
  """sigma_68 of (z_phot - z_spec)/(1 + z_spec) on the calibrated spectroscopic sample.

  z_phot is the component mean with the highest mixture density, in redshift
  units. The half-width uses the sorted residuals at ranks int(0.841*n)-1 and
  int(0.159*n)-1.
  """
  network_out = the_network(params, add_cal(opt_params,x_spec))
  logmix, mu_data, logstd = get_mdn_coef(network_out)
  weights = softmax(logmix)
  widths = np.exp(logstd)
  best_component = np.argmax(pick(weights,mu_data*6,widths*6),axis=1)
  z_phot = mu_data[np.arange(len(mu_data)),best_component]*6

  ordered = np.sort((z_phot-ans)/(1+ans))
  n = len(x_spec)
  return 0.5*(ordered[int(0.841*n)-1]-ordered[int(0.159*n)-1])


def build_bounds(alpha_bound, beta_bound, n_bands=27, fixed_band=4):
  """Box bounds for the interleaved [alpha_0, beta_0, alpha_1, beta_1, ...] vector.

  Every band's alpha is bounded by alpha_bound and its beta by beta_bound.
  The exception is `fixed_band` (None for no exception), whose alpha is
  pinned to 1.0 as the reference band (presumably to remove the overall
  flux-scale degeneracy — confirm).

  Returns:
    (lower, upper) float numpy arrays of length 2 * n_bands.
  """
  lower = onp.empty(2 * n_bands)
  upper = onp.empty(2 * n_bands)
  lower[0::2] = alpha_bound[0]
  upper[0::2] = alpha_bound[1]
  lower[1::2] = beta_bound[0]
  upper[1::2] = beta_bound[1]
  if fixed_band is not None:
    lower[2 * fixed_band] = 1.0
    upper[2 * fixed_band] = 1.0
  return lower, upper


# Stage-1 bounds: alpha in [0.8, 1.2], beta pinned to 0 (alpha-only fit).
alpha_bound0=[0.8,1.2]
beta_bound0=[0.0,0.0]

lower_bounds, upper_bounds = build_bounds(alpha_bound0, beta_bound0)
bounds = (np.array(lower_bounds), np.array(upper_bounds))






# Stage 1: fit the per-band flux scales alpha (beta pinned to 0 by `bounds`)
# from 50 random starting points, keeping the lowest MDN loss.
best_score=loss_fn(params, x_spec, y_spec)
print(best_score)
# Start from the identity calibration (alpha=1, beta=0) so best_params is
# always defined. Previously print(best_params) raised NameError when no
# restart improved on the uncalibrated loss.
best_params = np.array([1.0, 0.0] * 27)

# Solver settings are loop-invariant.
tol = 1e-4
method = 'SLSQP' #'L-BFGS-B'
options = {'disp':False,'ftol':tol, 'gtol':tol}
opt = minimize(fun=opt_func,method=method, tol=tol,options=options,maxiter=100000)

for i in range(50):
  # Random alphas in [0.8, 1.2) for all bands except the pinned band 4; betas 0.
  p0 = []
  for k in range(27):
    p0.append(1.0 if k == 4 else 0.8+onp.random.random()*0.4)
    p0.append(0.)
  p0 = np.array(p0).reshape(54,)

  opt_params,_=opt.run(p0,bounds=bounds)
  score=opt_func(opt_params)
  if score<best_score:
    best_params=copy.deepcopy(opt_params)
    best_score=score
    print(score)
    print('sig:', sig_func(best_params))

print(best_score)
print(best_params)

# Stage 2 starts from the stage-1 best point and now also lets each band's
# beta (fractional error inflation) vary in [0, 0.1]; alpha_bound0 is reused
# from stage 1 and band 4's alpha stays pinned to 1.0.
# NOTE(review): duplicates the stage-1 bounds construction.
p0=onp.array(best_params)

beta_bound0=[0,0.1]

lower_bounds=onp.array([])
upper_bounds=onp.array([])
for i in range(27):

  if i != 4:

    lower_bounds=onp.append(lower_bounds,alpha_bound0[0])
    lower_bounds=onp.append(lower_bounds,beta_bound0[0])

    upper_bounds=onp.append(upper_bounds,alpha_bound0[1])
    upper_bounds=onp.append(upper_bounds,beta_bound0[1])

  else:

    lower_bounds=onp.append(lower_bounds,1.0)
    lower_bounds=onp.append(lower_bounds,beta_bound0[0])

    upper_bounds=onp.append(upper_bounds,1.0)
    upper_bounds=onp.append(upper_bounds,beta_bound0[1])


bounds = (np.array(lower_bounds), np.array(upper_bounds))

# Stage 2: 50 restarts that re-randomise only the betas (odd entries of p0,
# uniform in [0, 0.1)); the alphas in p0 stay at the stage-1 best values.
for k in range(50):
  for i in range(27):
    p0[2*i+1]=onp.random.random()*0.1
  tol = 1e-4
  method = 'SLSQP' #'L-BFGS-B' 
  # NOTE(review): `options` is built but, unlike stage 1, not passed to minimize.
  options = {'disp':False,'ftol':tol, 'gtol':tol}
  opt = minimize(fun=opt_func,method=method, tol=tol,maxiter=1000000)
  opt_params,_=opt.run(np.array(p0),bounds=bounds)
  score=opt_func(opt_params)
  #score= loss_fn(params, add_cal(best_params,x_spec), y_spec)
  if score<best_score:
    best_params=copy.deepcopy(opt_params)
    best_score=score
    print(score)
    print('sig:', sig_func(best_params))
print(best_score)


In [None]:
# Load previously saved calibration parameters; this overwrites the best_params
# found by the optimisation cell above. allow_pickle: load trusted files only.
best_params=np.load('mdn_best_calib.npy',allow_pickle=True)

In [None]:

# Metrics on the spectroscopic sample before and after calibration.
# NOTE(review): the original pit() reads the *global* `logmix` for its shape,
# so each pit() call below relies on logmix having just been reassigned.
logmix, mu_data, logstd = get_mdn_coef(the_network(params, x_spec))
pi_data = softmax(logmix)
sigma_data = np.exp(logstd)
# k / indices / rn are not used below (the random-draw estimate is commented out).
k = gumbel_sample(pi_data)
indices = (onp.arange(len(x_spec)), k)
rn = onp.random.randn(len(x_spec))
#sampled = rn * sigma_data[indices] + mu_data[indices]
# Point estimate (redshift units): component mean with highest mixture density.
sampled=mu_data[np.arange(len(mu_data)),np.argmax(pick(pi_data,mu_data*6,sigma_data*6),axis=1)]*6

y_pred_before=copy.deepcopy(sampled)
y_pit_before=pit(y_spec[:,0]*6,pi_data,mu_data*6,sigma_data*6)

loss=loss_fn(params, x_spec, y_spec)

print('RESULTS BEFORE CALLIBRATION')
print('         ')
print('Average Loss: ',loss)
print(r'Median delta z /(1+z) ',bias(y_pred_before,y_spec[:,0]*6))
print('Overall Sigma_NMAD/(1+z) ',sigma_nmad(y_pred_before,y_spec[:,0]*6))
print('Outlier Fraction ',outlier_frac(y_pred_before,y_spec[:,0]*6))
print('Median PIT: ', np.median(y_pit_before))
print('        ')



# Same evaluation on the calibrated photometry. The pi/mu/sigma/sampled values
# left here are reused by the example-PDF cell further down.
logmix, mu_data, logstd = get_mdn_coef(the_network(params, add_cal(best_params,x_spec)))
pi_data = softmax(logmix)
sigma_data = np.exp(logstd)
k = gumbel_sample(pi_data)
indices = (onp.arange(len(x_spec)), k)
rn = onp.random.randn(len(x_spec))
#sampled = rn * sigma_data[indices] + mu_data[indices]
sampled=mu_data[np.arange(len(mu_data)),np.argmax(pick(pi_data,mu_data*6,sigma_data*6),axis=1)]*6

y_pred=copy.deepcopy(sampled)
y_pit=pit(y_spec[:,0]*6,pi_data,mu_data*6,sigma_data*6)

loss=loss_fn(params,add_cal(best_params,x_spec), y_spec)

print('RESULTS AFTER CALLIBRATION')
print('         ')
print('Average Loss: ',loss)
print(r'Median delta z /(1+z) ',bias(y_pred,y_spec[:,0]*6))
print('Overall Sigma_NMAD/(1+z) ',sigma_nmad(y_pred,y_spec[:,0]*6))
print('Outlier Fraction ',outlier_frac(y_pred,y_spec[:,0]*6))
print('Median PIT: ', np.median(y_pit))
print('        ')

# Spec-z vs photo-z after calibration.
bins=np.linspace(0,6,100)
import matplotlib as mpl
plt.figure(figsize=(14,10))
plt.hist2d(y_spec[:,0]*6,y_pred,bins=bins,cmap='hot',norm=mpl.colors.LogNorm())
plt.xlabel('Spectroscopic Redshift',fontsize=18)
plt.ylabel('Photometric Redshift Estimate',fontsize=18)
plt.xticks(fontsize=16)
plt.yticks(fontsize=16)
cbar=plt.colorbar()
cbar.ax.tick_params(labelsize=16)
plt.show()

# Binned metrics before (red) and after (blue) calibration, plus PIT histograms.
fig, axs = plt.subplots(2, 2,figsize=(16,10))
truth = y_spec[:,0]*6
for metric, ax, ylabel in [(bias, axs[0,0], r'$<\Delta z /(1+z)>$'),
                           (sigma_nmad, axs[1,0], r'$\sigma_{NMAD}/(1+z)$'),
                           (outlier_frac, axs[0,1], r'$\eta_{0.15}$')]:
  plot_func(metric,y_pred_before,truth,20,6,ax=ax,color='red',label='Before Calibration')
  plot_func(metric,y_pred,truth,20,6,ax=ax,color='blue',label='After Calibration')
  ax.set_ylabel(ylabel,fontsize=20)
axs[0,0].legend(fontsize=20)

axs[1,1].hist(y_pit_before,bins=20,color='red',histtype='step',linewidth=2)
axs[1,1].hist(y_pit,bins=20,color='blue',histtype='step',linewidth=2)
axs[1,1].set_ylabel('Galaxies',fontsize=20)
axs[1,1].set_xlabel('PIT',fontsize=20)
axs[1,1].set_ylim(100,10000)
axs[1,1].set_yscale('log')
# tick_params replaces per-tick `tick.label.set_fontsize` (Tick.label was
# removed in recent matplotlib releases).
axs[1,1].tick_params(axis='both', which='major', labelsize=18)
fig.tight_layout()
plt.show()


In [None]:

# Post-calibration PDFs for four hand-picked spectroscopic galaxies. Uses the
# pi_data/mu_data/sigma_data/sampled left by the after-calibration evaluation
# (sampled is already in redshift units).
fig, axs = plt.subplots(2, 2,sharex=True,figsize=(16,8))

id_arr=[6154,5842,1395,8652]

# axs.flat visits panels in the order [0,0], [0,1], [1,0], [1,1].
for gal_id, ax in zip(id_arr, axs.flat):
  plot_pdf(pi_data[gal_id,:],mu_data[gal_id,:]*6,sigma_data[gal_id,:]*6,pred=sampled[gal_id],true=y_spec[gal_id,0]*6,zrange=6,ax=ax)

fig.subplots_adjust(hspace=0)

# tick_params replaces per-tick `tick.label.set_fontsize` (Tick.label was
# removed in recent matplotlib releases).
for ax in axs.flat:
  ax.set_ylabel(r'$p(z|x)$',fontsize=20)
  ax.tick_params(axis='y', which='major', labelsize=18)
for ax in axs[1,:]:
  ax.set_xlabel(r'Redshift $z$',fontsize=20)
  ax.tick_params(axis='x', which='major', labelsize=18)

axs[0,0].legend(fontsize=20)
plt.show()



In [None]:


# Load COSMOS2020 galaxies with lp_zq > 0 into x_data for photo-z inference.
# This overwrites the synthetic x_data (and n_samples) used by earlier cells.
cat0 = table.Table.read('COSMOS2020.fits',format='fits',hdu=1)
cat0 = cat0[cat0['lp_zq']>0]
n_samples=len(cat0)
x_data=onp.zeros((n_samples,n_input))
keys=cat0.keys()
filt_names=onp.array([])
count=0
# NOTE(review): assumes exactly n_input/2 '*FLUX' columns in training band order.
for key_name in keys:
  if key_name[len(key_name)-4:]=='FLUX':
    
    filt=key_name[:len(key_name)-5]
    filt_names=onp.append(filt_names,filt)
    x_data[:,count]=cat0[filt+'_FLUX']
    x_data[:,count+1]=cat0[filt+'_FLUXERR']

    count+=2


In [None]:
# Saved calibration vector (pickle; load trusted files only).
best_params=np.load('mdn_best_calib.npy',allow_pickle=True)

import copy 


def add_cal(p0,x_data):
  """Apply per-band photometric calibration parameters to flux/error columns.

  `x_data` columns are interleaved per band, [flux_0, err_0, flux_1, err_1, ...],
  and `p0` holds one (alpha, beta) pair per band, [alpha_0, beta_0, ...]. For band b:

    flux_b' = alpha_b * flux_b
    err_b'  = sqrt((alpha_b * err_b)^2 + (beta_b * alpha_b * flux_b)^2)

  Args:
    p0: calibration vector with 2 * n_bands entries (54 here).
    x_data: array of shape (N, 2 * n_bands).

  Returns:
    Calibrated array with the same shape and column layout as x_data.

  Notes:
    - Fluxes keep their sign; the previous sqrt(flux^2) form turned negative
      (noise-dominated) fluxes positive.
    - Vectorised with jax.numpy: no per-call 54x54 masks or O(N*54)
      helper arrays, and it stays traceable under jit/grad.
  """
  p = np.reshape(np.asarray(p0), (-1,))
  alpha = p[0::2]
  beta = p[1::2]
  flux = x_data[:, 0::2]
  err = x_data[:, 1::2]
  flux_cal = alpha * flux
  err_cal = ((err * alpha) ** 2 + (beta * alpha * flux) ** 2) ** 0.5
  # Re-interleave to [flux_0, err_0, flux_1, err_1, ...].
  return np.stack([flux_cal, err_cal], axis=2).reshape(x_data.shape)

In [None]:
import copy
import math
pi_symb=math.pi   # used by gaus()

# Per-galaxy output buffers for the photometric-sample inference below.
y_pred=onp.zeros(len(x_data))                  # point estimates (redshift units)
y_pit=onp.zeros(len(x_data))                   # NOTE(review): not filled below (no spec-z here)
sig_arr=onp.zeros((len(x_data),n_mixture))     # component std-devs (scaled units)
mu_arr=onp.zeros((len(x_data),n_mixture))      # component means (scaled units)
pi_arr=onp.zeros((len(x_data),n_mixture))      # mixture weights
def ceildiv(a, b):
    """Ceiling division: the smallest integer >= a / b (works for either sign)."""
    quotient, remainder = divmod(a, b)
    return quotient + (1 if remainder else 0)

def gaus(x,pi,mu,sig):
  """Weighted Gaussian density pi * N(x; mu, sig), elementwise."""
  norm_const = sig*(2*pi_symb)**0.5
  exponent = -(x-mu)**2/(2*sig**2)
  return pi*1/norm_const*np.exp(exponent)

def pick(pi,mu,sig):
  """Mixture density of each row evaluated at each of its own component means.

  density[n, i] = sum_j gaus(mu[n, i], pi[n, j], mu[n, j], sig[n, j]);
  callers take argmax over axis 1 to choose the point estimate.
  """
  n_comp = pi.shape[1]
  density = onp.zeros_like(pi)
  for target in range(n_comp):
    at_mean = mu[:, target]
    for src in range(n_comp):
      density[:, target] += gaus(at_mean, pi[:, src], mu[:, src], sig[:, src])
  return density


# Batched, calibrated inference on the photometric COSMOS2020 sample.
# ceildiv keeps the final partial batch. For the first batch, also tabulate
# the redshift PDFs of up to 50,000 galaxies on a 1200-point grid over [0, 6]
# and save them to 'mdn_phot_pdf.npy'.
loss_arr=np.array([])
batch_size=100000
n_pdf_max = 50000
z_grid = onp.linspace(0, 6, 1200)
for b in range(int(ceildiv(len(x_data),batch_size))):
  batch = slice(b*batch_size, (b+1)*batch_size)
  logmix, mu_data, logstd = get_mdn_coef(the_network(params, add_cal(best_params,x_data[batch,:])))

  pi_data = np.exp(logmix)
  sigma_data = np.exp(logstd)
  # Point estimate: component mean with highest mixture density.
  sampled=mu_data[np.arange(len(mu_data)),np.argmax(pick(pi_data,mu_data*6,sigma_data*6),axis=1)]
  mu_arr[batch,:]=mu_data
  sig_arr[batch,:]=sigma_data
  pi_arr[batch,:]= pi_data
  y_pred[batch]=sampled*6

  if b ==0:
    # Vectorised over galaxies and components. This replaces a Python loop of
    # 50,000 x 3 scalar scipy calls; chunking bounds memory
    # (chunk x 1200 x n_mixture floats per step). min() avoids an IndexError
    # when the first batch has fewer than 50,000 rows.
    n_pdf = min(n_pdf_max, len(pi_data))
    pi_np = onp.asarray(pi_data[:n_pdf], dtype=float)
    mu_np = 6*onp.asarray(mu_data[:n_pdf], dtype=float)
    sig_np = 6*onp.asarray(sigma_data[:n_pdf], dtype=float)
    pdf_arr = onp.zeros((n_pdf, len(z_grid)))
    chunk = 1000
    for start in range(0, n_pdf, chunk):
      stop = min(start + chunk, n_pdf)
      # (1, n_grid, 1) against (chunk, 1, n_mix) -> (chunk, n_grid, n_mix).
      dens = stats.norm.pdf(z_grid[None,:,None], mu_np[start:stop,None,:], sig_np[start:stop,None,:])
      pdf_arr[start:stop] = onp.sum(pi_np[start:stop,None,:]*dens, axis=2)
      print(start)

    np.save('mdn_phot_pdf.npy',pdf_arr,allow_pickle=True)





