In [None]:
import argparse
import configparser
import numpy as np
import sys
import pickle
import copy

import os
from os import path
import subprocess

from saltshaker.util import snana,readutils
from saltshaker.util.estimate_tpk_bazin import estimate_tpk_bazin
from saltshaker.util.txtobj import txtobj

from saltshaker.training.init_hsiao import init_hsiao, init_kaepora, init_errs,init_errs_percent,init_custom,init_salt2

from saltshaker.training.TrainSALT import TrainSALT,RunTraining

from saltshaker.training.saltfit import fitting
from saltshaker.training import saltfit as saltfit

from saltshaker.data import data_rootdir
from saltshaker.initfiles import init_rootdir
from saltshaker.config import config_rootdir,loggerconfig

import astropy.units as u

from astropy.table import Table
from saltshaker.initfiles import init_rootdir as salt2dir
_flatnu=f'{init_rootdir}/flatnu.dat'

# validation utils
import logging
log=logging.getLogger(__name__)
from matplotlib import pyplot as plt

from scipy.sparse import linalg as sprslinalg

%matplotlib inline


In [None]:
from importlib import reload

from saltshaker.training import TrainSALT as ts

# Reload the training module so source edits are picked up without restarting the kernel.
ts=reload(ts)
# reload() creates new class objects, so every name imported from the module in the
# first cell has to be rebound. Rebinding only RunTraining would leave `TrainSALT()`
# below instantiating the stale, pre-reload class.
RunTraining=ts.RunTraining
TrainSALT=ts.TrainSALT

In [None]:
# Create the trainer state object and fill in its options from 'testing.conf',
# parsed from a hard-coded argument list instead of sys.argv.
salt = TrainSALT()

parser = argparse.ArgumentParser(usage='', conflict_handler="resolve", add_help=False)
parser.add_argument(
    'configpositional', nargs='?', default=None, type=str, help='configuration file',
)
parser.add_argument(
    '-c', '--configfile', default=None, type=str, help='configuration file',
)

# parse_known_args tolerates options this minimal parser does not declare.
options, args = parser.parse_known_args(['testing.conf'])

trainer = RunTraining()
trainer.get_config_options(
    salt, options.configfile, options.configpositional, ['testing.conf'],
)


In [None]:
salt.options

In [None]:
salt.options.resume_from_outputdir=True

In [None]:
import time

stage='initialization'
if not len(salt.surveylist):
    raise RuntimeError('surveys are not defined - see documentation')

# Read the kcor files for the configured surveys and log how long it took.
tkstart = time.time()
salt.kcordict=readutils.rdkcor(salt.surveylist,salt.options)
log.info(f'took {time.time()-tkstart:.3f} to read in kcor files')
# TODO: ASCII filter files

# exist_ok=True replaces the check-then-create pair, which could fail if the
# directory appeared between the existence test and the makedirs call.
os.makedirs(salt.options.outputdir,exist_ok=True)

# Only pass a spectral bin resolution on to the reader when binning is switched on.
binspecres = salt.options.binspecres if salt.options.binspec else None

# Read the supernova data and log how long it took.
tdstart = time.time()
datadict = readutils.rdAllData(salt.options.snlists,salt.options.estimate_tpk,
                               dospec=salt.options.dospec,
                               peakmjdlist=salt.options.tmaxlist,
                               binspecres=binspecres,snparlist=salt.options.snparlist,maxsn=salt.options.maxsn)
log.info(f'took {time.time()-tdstart:.3f} to read in data files')

# mkcuts returns a sequence; only its first element is kept as the data dictionary.
tcstart = time.time()
datadict = salt.mkcuts(datadict)[0]
log.info(f'took {time.time()-tcstart:.3f} to apply cuts')


# Evenly spaced phase / wavelength grids over the configured ranges. The point count
# is (range width / spline resolution) + 1, and the trailing True is linspace's
# `endpoint` argument, so both ends of the range are included.
phasebins=np.linspace(*salt.options.phaserange,int((salt.options.phaserange[1]-salt.options.phaserange[0])/salt.options.phasesplineres)+1,True)
wavebins=np.linspace(*salt.options.waverange,int((salt.options.waverange[1]-salt.options.waverange[0])/salt.options.wavesplineres)+1,True)


In [None]:
# Display the configured SN list option (the dangling `+` made this cell a SyntaxError).
salt.options.snlists

In [None]:
parlist,x_modelpars,phaseknotloc,waveknotloc,errphaseknotloc,errwaveknotloc = salt.initialParameters(datadict)

saltfitkwargs = salt.get_saltkw(phaseknotloc,waveknotloc,errphaseknotloc,errwaveknotloc)
n_phaseknots,n_waveknots = len(phaseknotloc)-4,len(waveknotloc)-4
n_errphaseknots,n_errwaveknots = len(errphaseknotloc)-4,len(errwaveknotloc)-4


In [None]:
X=x_modelpars

In [None]:
# Re-import saltfit so that source edits are used when the fitter is rebuilt.
saltfit=reload(saltfit)

# Copy the two fit-control options from the parsed configuration into the fitter kwargs.
for option_name in ('regularize','fitting_sequence'):
    saltfitkwargs[option_name] = getattr(salt.options,option_name)

# The fitter is given a shallow copy of datadict rather than the original mapping.
sf = saltfit.GaussNewton(x_modelpars,dict(datadict),parlist,**saltfitkwargs)


In [None]:
# Start from a copy of X with the model-error parameters divided by 4.
Xmodded=X.copy()
Xmodded[sf.imodelerr]/=4


# The assignment of Xfit was commented out, so on a fresh kernel the call below (and
# later cells reading Xfit) raised NameError. Restored so the cell runs top to bottom.
Xfit=sf.iterativelyfiterrmodel(Xmodded)
sf.getChi2Contributions(Xfit)

In [None]:
Xmodded=Xfit.copy()
Xmodded[sf.imodelerr]/=1.4

sf.getChi2Contributions(Xmodded)

In [None]:
result=sf.iterativelyfiterrmodel(Xmodded)
sf.getChi2Contributions(result)

In [None]:
lc=sf.datadict['2004ef'].photdata['CSP-B/u']

In [None]:
lc.modelloglikelihood(Xmodded),lc.modelresidual(Xmodded)

In [None]:
photresids=sf.batchedphotresiduals(X)

(photresids==0).sum()

In [None]:
# `uncertainties` was read here before any cell had assigned it; compute it for the
# current X so this cell does not depend on out-of-order execution.
uncertainties=sf.calculatecachedvals(X,target='variances')

# Keep only the part of each source label before the first underscore.
sources=np.array(sf.lsqwrap_sources(X,uncertainties))
sources=np.array([x.split('_')[0] for x in  sources])

In [None]:
(sources=='phot').sum()

In [None]:
sf.getChi2Contributions(X)

In [None]:
sf.num_spec

In [None]:
photresids.size

In [None]:
residuals=lc.modelresidual(Xfit)['residuals']

(residuals**2).sum()

In [None]:
(residuals!=0).sum()

In [None]:
sf.maxlikefit(Xmodded),sf.maxlikefit(Xfit)

In [None]:
uncertainties=sf.calculatecachedvals(x_modelpars,target='variances')



In [None]:
resids=sf.lsqwrap(X,uncertainties)

In [None]:
result=sf.process_fit(X,sf.iModelParam,uncertainties)

In [None]:
includepars=np.zeros(sf.npar,dtype=bool)

includepars[[sf.imodelerr0[33],sf.imodelerr1[33],sf.imodelcorr01[33]]]=True

In [None]:
includepars.sum()

In [None]:
from saltshaker.training import saltfit

saltfit=reload(saltfit)

sf.minuitoptimize=lambda *args,**kwargs: saltfit.GaussNewton.minuitoptimize(sf,*args,**kwargs)
result=sf.minuitoptimize(X,includepars)

In [None]:
iterate=sf.iterativelyfiterrmodel(X)

In [None]:
seciterate=sf.iterativelyfiterrmodel(iterate)


In [None]:
sf.getChi2Contributions(seciterate)

In [None]:
X[includepars],result[0][includepars],result[1]

In [None]:
result[0]

In [None]:
sf.maxlikefit(X),sf.maxlikefit(iterate)

In [None]:
# Load the pickled history written to the output directory.
# NOTE(review): pickle.load executes arbitrary code; only load files produced locally.
with open('output/gaussnewtonhistory.pickle','rb') as history_file:
    history=pickle.load(history_file)

In [None]:
X=np.array(history[-1][0])

In [None]:
Xfit=sf.iterativelyfiterrmodel(X)
sf.getChi2Contributions(Xfit)

In [None]:
uncertainties=sf.calculatecachedvals(X,target='variances')

In [None]:
[[np.isnan(y).any() for y in x]for x in uncertainties]

In [None]:
X[sf.iclscat[-1]]=-np.inf

In [None]:
resids=sf.batchedphotlikelihood(X)

In [None]:
resids

In [None]:
sf.maxlikefit(Xmodded),sf.maxlikefit(X)

In [None]:
sf.getChi2Contributions(X)

In [None]:
def gradient_descent(
    gradient,  init, learn_rate=0.1, n_iter=50, tolerance=1e-06,
    dtype="float64"):
    """Run fixed-step gradient descent and return the final position.

    Parameters
    ----------
    gradient : callable
        Called with the current position; its result is converted to an array of
        `dtype` and used as the gradient.
    init : array-like
        Starting position. It is copied, so the caller's object is not modified.
    learn_rate : float or array-like
        Step size(s); every entry must be strictly positive.
    n_iter : int
        Maximum number of gradient evaluations; must be positive.
    tolerance : float or array-like
        The loop stops as soon as every component of the proposed step is no
        larger than this in absolute value; must be strictly positive.
    dtype : str or numpy dtype
        Data type used for the position, learning rate, tolerance and gradient.

    Returns
    -------
    numpy.ndarray
        The position after the last applied step.

    Raises
    ------
    TypeError
        If `gradient` is not callable.
    ValueError
        If `learn_rate`, `n_iter` or `tolerance` is not positive.
    """
    if not callable(gradient):
        raise TypeError("'gradient' must be callable")

    work_dtype = np.dtype(dtype)

    # np.array copies, so the in-place updates below never touch `init`.
    position = np.array(init, dtype=work_dtype)

    learn_rate = np.array(learn_rate, dtype=work_dtype)
    if np.any(learn_rate <= 0):
        raise ValueError("'learn_rate' must be greater than zero")

    n_iter = int(n_iter)
    if n_iter <= 0:
        raise ValueError("'n_iter' must be greater than zero")

    tolerance = np.array(tolerance, dtype=work_dtype)
    if np.any(tolerance <= 0):
        raise ValueError("'tolerance' must be greater than zero")

    evaluations = 0
    while evaluations < n_iter:
        evaluations += 1
        step = -learn_rate * np.array(gradient(position), work_dtype)

        # Converged: the proposed step is negligible in every component, so it
        # is discarded rather than applied.
        if np.all(np.abs(step) <= tolerance):
            break

        position += step
    return position

In [None]:
gradient_descent(lambda x: sf.maxlikefit(x,cachedresults=uncertainties,fixuncertainties=True,diff='grad'),X)

In [None]:
sf.minuitoptimize(result[0],sf.fitOptions['2005bo'][1])

In [None]:
jshape=resids.size,sf.npar
args,kwargs=[x_modelpars,uncertainties],{}

# The operator callbacks used to be lambdas reading the globals `args`/`kwargs` at
# call time. Later cells rebind both names, which silently changed the point the
# operator was evaluated at. Default arguments freeze the values in force here.
def _jacobian_matvec(x,_args=tuple(args),_kwargs=dict(kwargs)):
    """Apply the function returned by sf.lsqwrap(..., diff='jvp') to x."""
    return sf.lsqwrap(*_args,**_kwargs,diff='jvp',jit=True)(x)

def _jacobian_rmatvec(x,_args=tuple(args),_kwargs=dict(kwargs)):
    """Apply the function returned by sf.lsqwrap(..., diff='vjp') to x."""
    return sf.lsqwrap(*_args,**_kwargs,diff='vjp',jit=True)(x)

# Matrix-free operator of shape (number of residuals, number of parameters).
jaclinop=sprslinalg.LinearOperator(matvec=_jacobian_matvec,
                                 rmatvec=_jacobian_rmatvec,shape=jshape)


In [None]:
# sf.rngkey=jax.random.PRNGKey(18327534917853348)
# sf.randomvjpevalfuns={}`

saltfit=reload(saltfit)
# sf.evalrandomvjp= lambda *args,**kwargs: saltfit.GaussNewton.evalrandomvjp(sf,*args,**kwargs)
sf.iteratedampings= lambda *args,**kwargs: saltfit.GaussNewton.iteratedampings(sf,*args,**kwargs)



In [None]:
sf.pre

In [None]:
sf.vectorizedstochasticbinormpreconditioning(sf.iModelParam,X,uncertainties)

In [None]:
%timeit sf.maxlikefit(X,uncertainties,fixuncertainties=True,diff='grad')

In [None]:
sf.im0

In [None]:
sf.minuitoptimize(X,sf.im0, cachedresults=uncertainties,fixuncertainties=True)

In [None]:
        if staticargs in self.randomvjpevalfuns: 
            preconevalfun= self.randomvjpevalfuns[staticargs]
            
        else:
            preconevalfun = jax.jit(jax.vmap(lambda parindex,x,y: self.evalrandomvjp(parindex,x,y,*staticargs),
            
                    in_axes=(0,None,[[None]*len(self.batchedphotdata),[None]*len(self.batchedspecdata)])))
                    
            self.randomvjpevalfuns[staticargs]=preconevalfun


In [None]:
# This cell uses jax, but the notebook's first `import jax` sits in a later cell.
# Importing here lets the cell run on a fresh kernel.
import jax

key=jax.random.PRNGKey(18327534917853348)

# Split once into a key kept for later use and a key that is fanned out below.
nextkey,veckey=jax.random.split(key,2)

# 100 independent keys, one per random vector.
veckey=jax.random.split(veckey,100)

In [None]:
args,kwargs=[X,uncertainties],{}

In [None]:
# The lambda originally had no body, which is a SyntaxError.
# NOTE(review): the body is reconstructed from other cells in this notebook — the
# `diff='vjp'` call used for jaclinop's rmatvec, applied to a standard-normal vector
# with one entry per residual (jshape[0]) drawn from each key. Confirm this is the
# intended computation.
calcrandomvjp=jax.jit(jax.vmap(
    lambda key: sf.lsqwrap(*args,**kwargs,diff='vjp',jit=True)(jax.random.normal(key,(jshape[0],)))
))




In [None]:
%timeit calcrandomvjp(veckey)


In [None]:
binormprecon=stochasticbinormpreconditioning(sf,sf.iModelParam,X,uncertainties)
plt.hist((np.log10(binormprecon[1])-np.log10(preconditioning)[sf.iModelParam]) )


In [None]:
plt.plot(convergencelog)
plt.yscale('log')

In [None]:
sf.parlist[sf.iModelParam][(binormprecon[1])>1e10].size

In [None]:
np.median(binormprecon[1])

In [None]:
sf.parlist[np.where(preconditioning==np.median(preconditioning))]

In [None]:
jacobian,preconinv=constructoperator(sf,binormprecon,sf.iModelParam,X,uncertainties)

In [None]:
result=gaussNewtonFit(sf,X,jacobian,preconinv,resids ,0.18,[[uncertainties],{}])


In [None]:
result

In [None]:
(sf.lsqwrap(X-preconinv(result.lsmrresult.precondstep),uncertainties))

In [None]:
sf.lsqwrap(X,uncertainties)

In [None]:
from tqdm.notebook import tqdm,trange

In [None]:
(jaclinop.T @ r).size

In [None]:
# Iteratively estimate row factors r (length jshape[0]) and column factors c (length
# jshape[1]) for jaclinop, using only products with the operator and its transpose
# applied to random vectors.
# NOTE(review): nmv, r and c are not assigned in any visible cell, so this relies on
# kernel state — presumably nmv is the iteration count and r/c start as positive
# vectors; confirm before running on a fresh kernel.
# NOTE(review): the draws come from the unseeded global NumPy RNG, so results differ
# from run to run.
for k in trange(nmv):
    # Mixing weight for the new sample: 1/2 for the first few iterations, halving as
    # k grows, with a floor of 1/16.
    omega=2**(-max(min(np.floor(np.log2(k+1))-1,4),1))
    
    # Probe the operator with a random vector scaled by the current column factors,
    # then blend the normalised squared response into the row factors.
    s=np.random.normal(size=jshape[1])/np.sqrt(c)
    y= jaclinop @ s
    r= (1-omega)*r/r.sum() + omega* y**2 / (y**2).sum()
    
    # Same update for the column factors, via the transpose and the row factors
    # that were just updated (so statement order matters).
    s= np.random.normal(size=(jshape[0])) / np.sqrt(r)
    y= jaclinop.T @ s
    c=(1-omega)*c/c.sum() + omega*y**2 /(y**2).sum()

# Final scalings are the inverse square roots of the factors. This rebinds the loop
# temporary y: from here on x has length jshape[0] and y has length jshape[1].
x=1/np.sqrt(r)
y=1/np.sqrt(c)

In [None]:
x

In [None]:
x.size,y.size

In [None]:
%matplotlib inline
plt.hist(np.log10(preconditioning)[preconditioning<1e8])

In [None]:
from jax import numpy as jnp

hvp= jax.jit( lambda *args,**kwargs: wrapjvpmultipleargs(sf.maxlikefit,[0])( *args,**kwargs,diff='grad',jit=False),static_argnames=['fixuncertainties'])




In [None]:
preconditioning[100]

In [None]:

np.sqrt(-hvp((np.arange(sf.npar)==i)*1.,x_modelpars,uncertainties,fixuncertainties=True)[i])*preconditioning[i]

In [None]:
np.where(hvp)

In [None]:
193*1e4/1e3/60

In [None]:
gradprecon=sf.maxlikefit(x_modelpars,uncertainties,fixuncertainties=True,diff='grad',jit=True)

In [None]:
%matplotlib inline 
plt.hist(np.log((np.abs(gradprecon) *preconditioning)[np.nonzero(gradprecon)]))

In [None]:
np.percentile((np.abs(gradprecon) *preconditioning)[np.nonzero(gradprecon)],[1,16,50,84,99])

In [None]:
gradpreconjacobian,gradpreconinv= sf.constructoperator(np.nan_to_num(1/np.abs(gradprecon)), np.ones(sf.npar,dtype=bool), X,uncertainties)




In [None]:
jacobian,preconinv= sf.constructoperator(preconditioning, np.ones(sf.npar,dtype=bool), X,uncertainties)


In [None]:
sf.damping['all']=.1

In [None]:
sf.iteratedampings('all',X,jacobian, preconinv,resids ,lsqwrapargs=[[uncertainties],{}],)

In [None]:
sf.gaussNewtonFit(X,gradpreconjacobian,gradpreconinv,resids ,0.18,[[uncertainties],{}])

In [None]:
np.isnan(gradpreconjacobian @ np.random.normal(size=sf.npar)).any()

In [None]:
sf.damping

In [None]:
sf.iteratedampings('all',X,gradpreconjacobian,gradpreconinv,resids ,lsqwrapargs=[[uncertainties],{}],)

In [None]:
# `preco` was a truncated identifier (NameError); the array is named `preconditioning`.
(gradprecon**2)*preconditioning

In [None]:
preconditioning=1/jac**2

In [None]:
hvpresults= hvp(np.ones(sf.npar),x_modelpars,uncertainties,fixuncertainties=True)

In [None]:
np.sqrt(np.abs(hvpresults))*preconditioning

In [None]:
(gradprecon**2)/ preconditioning

In [None]:
# gradprecon=graddot(x_modelpars,uncertainties,fixuncertainties=True)

In [None]:
preconditioning/gradprecon

In [None]:
X=x_modelpars

In [None]:
import jax
def wrapjvpmultipleargs(fun,argnums):
    """Wrap `fun` so that it returns a Jacobian-vector product in one positional argument.

    Parameters
    ----------
    fun : callable
        Function to differentiate.
    argnums : sequence of int
        Must contain exactly one index: the position of the argument of `fun` that
        is differentiated.

    Returns
    -------
    callable
        `wrapped(vec, *args, **kwargs)` evaluates the JVP of `fun` at `args`, with
        respect to `args[argnums[0]]`, along the tangent `vec`. All other positional
        and keyword arguments are held fixed. Only the tangent output of `jax.jvp`
        is returned.

    Raises
    ------
    NotImplementedError
        If more than one argument index is supplied.
    """
    if len(argnums)>1:
        raise NotImplementedError('Wrapped jacobian-vector products differentiated w.r.t. multiple arguments have not been implemented')
    diffargidx=argnums[0]

    def jvp_in_selected_arg(vec,*args,**kwargs):
        before,after=args[:diffargidx],args[diffargidx+1:]

        def as_function_of_selected(x):
            # Re-insert x in the differentiated slot, all other inputs unchanged.
            return fun(*before,x,*after,**kwargs)

        # jax.jvp returns (primal_out, tangent_out); only the tangent is wanted.
        return jax.jvp(as_function_of_selected,[args[diffargidx]],[vec])[1]

    return jvp_in_selected_arg



In [None]:
sf.maxlikefit(x_modelpars,uncertainties,fixuncertainties=True,diff='grad')

In [None]:
saltfit=reload(saltfit)
# sf.evalpreconditioningscales= lambda *args,**kwargs: saltfit.GaussNewton.evalpreconditioningscales(sf,*args,**kwargs)
sf.preconditioningscales= lambda *args,**kwargs: saltfit.GaussNewton.preconditioningscales(sf,*args,**kwargs)
sf.constructoperator= lambda *args,**kwargs: saltfit.GaussNewton.constructoperator(sf,*args,**kwargs)
sf.iteratedampings= lambda *args,**kwargs: saltfit.GaussNewton.iteratedampings(sf,*args,**kwargs)
sf.gaussNewtonFit= lambda *args,**kwargs: saltfit.GaussNewton.gaussNewtonFit(sf,*args,**kwargs)
sf.process_fit= lambda *args,**kwargs: saltfit.GaussNewton.process_fit(sf,*args,**kwargs)

In [None]:
# Preconditioning scales with every parameter selected (all-True boolean mask).
preconditioning= sf.preconditioningscales( np.ones(sf.npar,dtype=bool),X,uncertainties)


In [None]:
{ ('as',3):10}

In [None]:
sf.preconditioningchunksize=10

In [None]:
sf.cachedpreconevalfuns={}

In [None]:
sf.preconditioningscales( np.arange(sf.npar)<103,X,uncertainties)

# jacobian,preconinv= sf.constructoperator(preconditioning, varyingParams, X,uncertainties)


In [None]:
uncertainties=sf.calculatecachedvals(x_modelpars,target='variances')

result=sf.process_fit(X,sf.iModelParam,uncertainties)

In [None]:
sf.damping 


In [None]:
sf.iteratedampings('all',X,jacobian,preconinv,residuals,([uncertainties],{}))

In [None]:
# Every other call to sf.iteratedampings in this notebook passes the fit name first
# and the lsqwrap arguments last. This call omitted both, shifting X into the
# fit-name slot.
result=sf.iteratedampings('all',X,jacobian,preconinv,residuals,([uncertainties],{}))

In [None]:
uncertainties=sf.calculatecachedvals(x_modelpars,target='variances')

In [None]:
X=x_modelpars.copy()
varyingParams=sf.iModelParam

residuals=sf.lsqwrap(X,uncertainties)
oldChi=(residuals**2).sum()





In [None]:
oldChi

In [None]:
tol=1e-8

result=sprslinalg.lsmr(jacobian,residuals,damp=.5,atol=tol,btol=tol)

gaussNewtonStep= preconinv(result[0])


In [None]:
postGN=(sf.lsqwrap(X-gaussNewtonStep,uncertainties)**2).sum() #
oldChi,postGN,oldChi-postGN

In [None]:
kwargs={}

prevresult=result
currentresids=sf.lsqwrap(X-gaussNewtonStep,uncertainties,**kwargs)
prevresids=residuals
prevstep=gaussNewtonStep.copy()
currentchi2=(currentresids**2).sum()


In [None]:
jacobian,_= sf.constructoperator(preconditioning, varyingParams, X-prevstep,uncertainties, **kwargs)

currentresult=sprslinalg.lsmr(jacobian,currentresids,damp=.5,atol=tol,btol=tol)
precondstep,stopsignal,itn,normr,normar,norma,conda,normx=currentresult

currentstep=prevstep + preconinv(precondstep)


In [None]:
nextresids=sf.lsqwrap(X-currentstep,uncertainties,**kwargs)
nextchi2=(nextresids**2).sum()
chi2improvement=currentchi2-nextchi2
nextchi2,chi2improvement


In [None]:
((currentresids-(jacobian@precondstep))**2).sum(),(currentresids**2).sum()

In [None]:
fit='all'
scale=1.5
result=iteratedampings(jacobian,preconinv,residuals,.5)

In [None]:
result.gaussNewtonStep

In [None]:
np.where(~np.isclose(sf.lsqwrap(X-prevstep- preconinv(precondstep)*delta,uncertainties),currentresids+(-sf.lsqwrap(X-prevstep,uncertainties,diff='jvp')(preconinv(precondstep)* delta ))
          
          
          ))

In [None]:
sf.lsqwrap(X,uncertainties,jit=False,diff='jvp')(np.random.normal(size=sf.npar))

In [None]:
import jax
from jax import numpy as jnp
preconscales=(jax.vmap(lambda i: 1/jnp.sqrt((sf.lsqwrap(X,uncertainties,jit=True,diff='jvp')(jnp.zeros(sf.npar).at[i].set(1.))**2).sum())))




In [None]:
%timeit preconscales( np.arange(10))

In [None]:
preconscales=(jax.vmap(lambda i: 1/jnp.sqrt((sf.lsqwrap(X,uncertainties,usespec=False,jit=True,diff='jvp')(jnp.zeros(sf.npar).at[i].set(1.))**2).sum())))
%timeit preconscales( np.arange(10))

In [None]:
%timeit [((sf.lsqwrap(X,uncertainties,jit=True,diff='jvp'))((np.arange(sf.npar)==i)*1.)**2).sum() for i in range(10)]

In [None]:
# preconscales is a vmapped function of a parameter-index array; calling it with no
# argument fails.
# NOTE(review): assumed intent is to evaluate the scale for every parameter — confirm.
preconscales(np.arange(sf.npar))

In [None]:

delta=.2
(sf.lsqwrap(X-prevstep- preconinv(precondstep)*delta,uncertainties)**2).sum()

In [None]:
((currentresids+(-sf.lsqwrap(X-prevstep,uncertainties,diff='jvp')(preconinv(precondstep)* delta )))**2).sum()

In [None]:
%timeit sf.linesearch(X,gaussNewtonStep,uncertainties)  

In [None]:
np.std(resids)

In [1]:
import os
from os import path


In [2]:
# Rename every file in the output directory: drop each 'M', then turn '10' into '01'.
output_dir='output/'  # no longer bound to `dir`, which shadowed the builtin
for old_name in sorted(os.listdir(output_dir)):
    new_name=old_name.replace('M','').replace('10','01')
    if new_name==old_name:
        continue  # nothing to do for this file
    target=path.join(output_dir,new_name)
    # os.rename may silently replace an existing file; refuse rather than lose data
    # when two names collapse onto the same target.
    if path.exists(target):
        raise FileExistsError(f'refusing to overwrite {target} while renaming {old_name}')
    os.rename(path.join(output_dir,old_name),target)


In [4]:
# Load the covariance file into `text`. The next cell iterates over `text`, which was
# undefined on a fresh kernel while this read was commented out.
# NOTE(review): the last cell writes the reformatted text back to this same path, so
# re-running the whole sequence reformats already-reformatted content.
with open('output/salt3_lc_covariance_01.dat','r') as file:
    text=file.read()

In [5]:
from tqdm import tqdm

def insert_newlines_after_exponents(chars):
    """Return the characters joined into a string, with line breaks after exponents.

    Each 'e' starts a countdown of four characters; a newline is inserted
    immediately before the fourth character that follows it (so 'e+00X' becomes
    'e+00' + newline + 'X'). An 'e' seen during a countdown restarts it. No newline
    is added when the input ends before the countdown expires.

    Parameters
    ----------
    chars : iterable of str
        Characters to process, e.g. a string or a progress-bar wrapper around one.

    Returns
    -------
    str
        The input with the newlines inserted.
    """
    pieces=[]  # collected and joined once, instead of repeated string concatenation
    countdown=-1  # -1 means no exponent is currently being tracked
    for char in chars:
        if countdown>0:
            countdown-=1
        if char=='e':
            countdown=4
        if countdown==0:
            countdown=-1
            pieces.append('\n')
        pieces.append(char)
    return ''.join(pieces)

out=insert_newlines_after_exponents(tqdm(text))


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2213975/2213975 [00:00<00:00, 2515420.34it/s]


In [6]:
with open('output/salt3_lc_covariance_01.dat','w') as file:
    file.write(out)