In [1]:
import sys
sys.path.append('/mnt/d/ariel2/code/core/')
sys.path.append('d:/ariel2/code/core/')
sys.path.append('/kaggle/input/my-ariel2-library')
import kaggle_support as kgs
import ariel_model
import ariel_numerics
import ariel_gp
import ariel_load
import time
import numpy as np
import multiprocess
import importlib
import copy
import cupy as cp
import matplotlib.pyplot as plt
from tqdm import tqdm

local


In [2]:
import numpy as np

def fit_als_multi(S, C0, C_list, w_init=None, max_iter=100, tol=1e-8, eps=1e-12, cond_thresh=1e8, ridge_scale=1e-8):
    """
    Minimize || S - diag(w0) * (C0 + sum_k w[k] * C_list[k]) ||_F^2
    by alternating least squares, where:
      S, C0 : (N, M)
      C_list: list/tuple/array of K matrices each (N, M)  (K may be 0)
      w0    : (N,) rowwise scales (solved internally)
      w     : (K,) global scalars (solved internally)

    All array inputs are converted to float, so integer data is accepted.

    Parameters
    ----------
    w_init      : optional (K,) starting value for w (zeros if None)
    max_iter    : maximum number of alternating sweeps; if <= 0 only the
                  closed-form w0 for the initial w is computed
    tol         : stop once |prev - cur| <= tol * (prev + 1) for the objective
    eps         : small constant added to the w0 denominator and the ridge term
    cond_thresh : condition number of the KxK normal matrix above which a
                  ridge term is added before solving
    ridge_scale : ridge strength relative to the mean diagonal of the normal matrix

    Returns
    -------
    w0 : (N,)
    w  : (K,)
    obj: float, final squared Frobenius residual
    it : int, iterations performed
    """

    # --- shape checks & setup (float cast: the in-place updates below would
    #     otherwise fail for integer inputs)
    S = np.asarray(S, dtype=float)
    C0 = np.asarray(C0, dtype=float)
    N, M = S.shape[0], S.shape[1]
    C_arrays = [np.asarray(C, dtype=float) for C in C_list]
    K = len(C_arrays)
    Cs = np.asarray(C_arrays) if K else np.zeros((0, N, M))  # (K, N, M)
    assert C0.shape == (N, M) and Cs.shape == (K, N, M), "Shape mismatch"

    # --- helpers
    def model(w):
        # T = C0 + sum_k w[k]*Ck
        T = C0.copy()
        if K:
            T += np.tensordot(w, Cs, axes=(0, 0))  # (N, M)
        return T

    def update_w0(T):
        # Closed-form rowwise least squares: w0[n] = <S_n, T_n> / <T_n, T_n>
        num = np.sum(S * T, axis=1)              # (N,)
        den = np.sum(T * T, axis=1) + eps        # (N,)
        return num / den

    def update_w(w0):
        if K == 0:
            return np.zeros(0)  # no global scalars

        # B = S - diag(w0) C0
        B = S - (w0[:, None] * C0)            # (N, M)

        # X_k = diag(w0) C_k  -> stack into X: (K, N, M)
        X = w0[:, None] * Cs                  # (K, N, M)

        # Build KxK normal matrix A and rhs b via Frobenius inner products
        # A_ij = <X_i, X_j>,  b_i = <X_i, B>
        A = np.einsum('knm,lnm->kl', X, X)    # (K, K)
        b = np.einsum('knm,nm->k', X, B)      # (K,)

        # Regularize if ill-conditioned
        try:
            cnd = np.linalg.cond(A)
        except np.linalg.LinAlgError:
            cnd = np.inf
        if not np.isfinite(cnd) or cnd > cond_thresh:
            reg = ridge_scale * (np.trace(A) / max(K, 1) + eps)
            A = A + reg * np.eye(K)

        return np.linalg.solve(A, b)

    def objective(w0, T):
        R = S - (w0[:, None] * T)
        return float(np.sum(R * R))

    # --- initialize
    if w_init is None:
        w = np.zeros(K)
    else:
        w = np.asarray(w_init, dtype=float)
        assert w.shape == (K,)
    T = model(w)
    w0 = update_w0(T)
    # Defined up front so that max_iter <= 0 still returns a valid objective.
    cur = objective(w0, T)

    # --- iterate
    prev = np.inf
    it = 0
    for it in range(1, max_iter + 1):
        # Step B: solve for global scalars given w0
        w = update_w(w0)

        # Step A: closed-form rowwise w0 given globals
        T = model(w)
        w0 = update_w0(T)

        # Objective
        cur = objective(w0, T)

        # Convergence (relative decrease in objective); never stop on the first sweep
        if it > 1 and abs(prev - cur) <= tol * (prev + 1.0):
            break
        prev = cur

    return w0, w, cur, it

In [3]:
# Sizes of the synthetic test problem; arrays below follow the (comp, t, r, wl) axis order.
N_comp = 3
N_r = 10
N_t = 100
N_wl = 15


def _seeded_normal(seed, shape):
    """Standard-normal draws from a fresh generator with the given seed."""
    return np.random.default_rng(seed=seed).normal(size=shape)


# Ground-truth factors, one independent seed per array so the cell is reproducible.
c0 = _seeded_normal(0, (N_r, N_wl))
c = _seeded_normal(1, (N_comp, N_r, N_wl))
w0 = _seeded_normal(2, (N_t, N_wl)) + 1
w = _seeded_normal(3, (N_comp, N_t))

# signal[t, r, wl] = w0[t, wl] * (c0[r, wl] + sum_k w[k, t] * c[k, r, wl])
signal = w0[:, None, :] * c0[None, :, :]
for w_k, c_k in zip(w, c):
    signal += w0[:, None, :] * w_k[:, None, None] * c_k[None, :, :]
signal.shape

(100, 10, 15)

In [4]:
# Starting point for the alternating fit: the ground-truth factors perturbed by
# independent Gaussian noise (sigma = 0.01), each array with its own fixed seed.
# `truth + noise` allocates new arrays, so the ground truth is never aliased.
c0_guess = c0 + np.random.default_rng(seed=12).normal(0, 0.01, size=c0.shape)
c_guess = c + np.random.default_rng(seed=11).normal(0, 0.01, size=c.shape)
w0_guess = w0 + np.random.default_rng(seed=10).normal(0, 0.01, size=w0.shape)
w_guess = w + np.random.default_rng(seed=9).normal(0, 0.01, size=w.shape)

In [12]:
import scipy
# Alternating least squares on the synthetic problem
#   signal[t, r, wl] ~ w0[t, wl] * (c0[r, wl] + sum_k w[k, t] * c[k, r, wl]).
# Each sweep updates (c0, c), then w0, then w, and prints the RMS distance of every
# factor from the ground truth plus the RMS reconstruction error.
PRINT_EVERY = 1  # print diagnostics every PRINT_EVERY sweeps

for i_iter in tqdm(range(200)):
    # --- c update: for fixed (w0, w) the model is linear in (c0, c).
    # The regressors depend only on the wavelength, so all N_r rows of one
    # wavelength are solved together as a multi-right-hand-side least squares.
    for i_wl in range(N_wl):
        w_matrix = np.concatenate([w0_guess[:, i_wl][None, :], w0_guess[:, i_wl] * w_guess])  # (1+N_comp, N_t)
        coeffs = np.linalg.lstsq(w_matrix.T, signal[:, :, i_wl], rcond=None)[0]              # (1+N_comp, N_r)
        c0_guess[:, i_wl] = coeffs[0]
        c_guess[:, :, i_wl] = coeffs[1:]

    # Fix the scale of every component to unit RMS.
    c0_guess = c0_guess / kgs.rms(c0_guess)
    for i_comp in range(N_comp):
        c_guess[i_comp, ...] = c_guess[i_comp, ...] / kgs.rms(c_guess[i_comp, ...])

    # --- w0 update: closed-form scalar least squares per (t, wl), summed over r.
    for i_t in range(N_t):
        x = c0_guess + np.tensordot(w_guess[:, i_t], c_guess, axes=(0, 0))  # (N_r, N_wl)
        w0_guess[i_t, :] = np.sum(x * signal[i_t], axis=0) / np.sum(x * x, axis=0)

    # --- w update: after dividing out w0 the model is linear in w; the design
    # matrix is shared by all time steps, so they are solved in a single call.
    design_matrix = c_guess.reshape(N_comp, -1)                                 # (N_comp, N_r*N_wl)
    rhs = (signal / w0_guess[:, None, :] - c0_guess[None, :, :]).reshape(N_t, -1)  # (N_t, N_r*N_wl)
    w_guess[:, :] = np.linalg.lstsq(design_matrix.T, rhs.T, rcond=None)[0]      # (N_comp, N_t)

    # --- diagnostics
    signal_guess = w0_guess[:, None, :] * c0_guess[None, :, :]
    for ii in range(N_comp):
        signal_guess += w0_guess[:, None, :] * w_guess[ii, ...][:, None, None] * c_guess[ii, ...][None, :, :]
    errors = (kgs.rms(w0_guess - w0), kgs.rms(w_guess - w), kgs.rms(c0_guess - c0),
              kgs.rms(c_guess - c), kgs.rms(signal - signal_guess))
    if i_iter % PRINT_EVERY == 0:
        print(*errors)
# Final state (same values as the last sweep's diagnostics).
print(*errors)

  4%|██▉                                                                                | 7/200 [00:00<00:03, 62.80it/s]

0.07000891645496359 0.054385819354872064 0.044007433828409465 0.09203415376467207 0.0002035008371836811
0.07000965735699098 0.05437992384132541 0.044005771365677473 0.09203397961305093 0.000189963168585207
0.07001049141816724 0.054374546982169016 0.04400466365006654 0.09203386916115978 0.00017854292422408944
0.07001135257684166 0.0543695108747627 0.04400380298734997 0.09203378240186659 0.0001681857944867961
0.07001221363486174 0.0543647554627854 0.04400308002194311 0.09203371600448233 0.00015855280941825135
0.07001306250709542 0.0543602495480048 0.044002450038529206 0.09203366478834185 0.00014951084424450113
0.07001389112070049 0.05435597376946295 0.0440018888436724 0.0920336247058994 0.00014099750763788067
0.07001469361419223 0.05435191434036834 0.04400138124561083 0.09203359295744584 0.0001329748308859549
0.07001546602564321 0.054348060230441184 0.04400091702858533 0.09203356761478852 0.00012541336717182796
0.07001620600295783 0.054344401819026875 0.04400048904443438 0.09203354731507

 10%|████████▌                                                                         | 21/200 [00:00<00:03, 57.86it/s]

0.07001822461889451 0.05433451427299307 0.0439993769730293 0.09203350786246259 9.927559906632655e-05
0.07001883162584177 0.05433155393615353 0.043999053416249725 0.09203349986636052 9.365301907289809e-05
0.07001940729023724 0.05432874848422513 0.04399874983044276 0.09203349374098142 8.83530847225392e-05
0.07001995284893749 0.054326090558118593 0.04399846460077052 0.09203348917557902 8.335677343308455e-05
0.07002046960782368 0.054323573046176396 0.043998196341521786 0.09203348591121181 7.864619799432618e-05
0.0700209588939142 0.054321189090685264 0.04399794384120403 0.09203348373038626 7.420456090327051e-05
0.07002142202470818 0.054318932092686865 0.043997706024477216 0.092033482449434 7.001609845749994e-05
0.07002186028911779 0.05431679571437132 0.04399748192523963 0.0920334819127221 6.606602030714488e-05
0.07002227493608114 0.054314773879058316 0.04399727066718488 0.09203348198812449 6.234044823832279e-05
0.07002266716818116 0.05431286076904993 0.043997071449432004 0.09203348256339021

 18%|██████████████▎                                                                   | 35/200 [00:00<00:02, 61.31it/s]

0.0700243306994916 0.054304740013653686 0.04399623202650276 0.09203349006405942 4.4025007692139805e-05
0.07002461090447941 0.05430337094980156 0.043996091343069664 0.09203349207471619 4.154727890026588e-05
0.0700248757295614 0.054302076677096196 0.043995958529137166 0.09203349415796225 3.9209390477284706e-05
0.07002512599482087 0.05430085323030175 0.04399583313615045 0.09203349628506864 3.700338493784171e-05
0.0700253624792228 0.05429969684072553 0.0439957147431284 0.09203349843219664 3.492176842368395e-05
0.07002558592240246 0.054298603928097794 0.0439956029545622 0.0920335005796881 3.2957481911225e-05
0.07002579702641377 0.05429757109257406 0.043995497398548765 0.092033502711452 3.1103874485902206e-05
0.0700259964574226 0.05429659510688776 0.0439953977251212 0.09203350481443628 2.9354678498186936e-05
0.07002618484734958 0.05429567290868825 0.043995303604745024 0.09203350687817365 2.7703986434470954e-05
0.07002636279545225 0.05429480159308082 0.043995214726957985 0.0920335088943897 2.

 24%|████████████████████                                                              | 49/200 [00:00<00:02, 62.86it/s]

0.07002711478391087 0.05429111671191036 0.04399483929798595 0.09203351808789101 1.9577826778332028e-05
0.0700272410217107 0.054290497623895366 0.043994776280211485 0.09203351973105937 1.8477313320377397e-05
0.07002736021918549 0.05428991291449569 0.043994716773563695 0.09203352130713126 1.7438695104498952e-05
0.07002747276463113 0.0542893606977222 0.04399466058302638 0.09203352281652052 1.6458485796820127e-05
0.07002757902547356 0.0542888391884553 0.043994607524437915 0.09203352426006266 1.5533396276724087e-05
0.07002767934932387 0.0542883466973401 0.04399455742387133 0.09203352563893222 1.4660323343379725e-05
0.07002777406498985 0.05428788162591156 0.04399451011705303 0.09203352695457227 1.383633909144575e-05
0.07002786348344132 0.0542874424619346 0.04399446544881727 0.09203352820863546 1.3058680912815362e-05
0.07002794789873239 0.05428702777495941 0.04399442327259487 0.09203352940293329 1.2324742085281068e-05
0.07002802758888078 0.05428663621207965 0.04399438344993271 0.0920335305393

 32%|█████████████████████████▊                                                        | 63/200 [00:01<00:02, 59.22it/s]

0.07002836387295815 0.054284982842179356 0.04399421530739626 0.09203353542580076 8.710751599815898e-06
0.07002842025339369 0.05428470546973559 0.04399418709922669 0.09203353625847192 8.221216539393092e-06
0.07002847347143266 0.05428444360613963 0.04399416046779437 0.09203353704773165 7.759195966005923e-06
0.07002852370384495 0.054284196388610755 0.04399413532529677 0.09203353779556225 7.323142947815809e-06
0.07002857111760451 0.05428396300187497 0.043994111588793794 0.09203353850389569 6.911597597127781e-06
0.07002861587041852 0.054283742675602834 0.04399408917994081 0.09203353917460819 6.52318216213567e-06
0.070028658111235 0.05428353468198252 0.04399406802473443 0.09203353980951655 6.156596395449738e-06
0.0700286979807184 0.05428333833341543 0.043994048053274275 0.09203354041037523 5.810613186704436e-06
0.07002873561170472 0.05428315298033419 0.043994029199537564 0.09203354097887449 5.484074441890624e-06
0.07002877112962905 0.05428297800913426 0.04399401140116472 0.09203354151663984 

 38%|███████████████████████████████▏                                                  | 76/200 [00:01<00:02, 59.79it/s]

0.07002886615677693 0.05428250974974901 0.0439939637654223 0.09203354296080665 4.351406150901156e-06
0.07002889434259571 0.05428237082278421 0.043993949631341005 0.0920335433905882 4.106873732533649e-06
0.07002892094502917 0.05428223968400641 0.04399393628908733 0.09203354379679271 3.876083568130035e-06
0.07002894605293145 0.054282115897860535 0.04399392369441043 0.09203354418066434 3.6582633029535163e-06
0.0700289697501857 0.0542819990530331 0.04399391180552177 0.09203354454338825 3.4526840006292997e-06
0.07002899211597877 0.05428188876111442 0.04399390058295936 0.09203354488609289 3.258657700948086e-06
0.07002901322506379 0.054281784655336784 0.04399388998945888 0.09203354520985113 3.075535114929219e-06
0.07002903314800553 0.05428168638937798 0.04399387998983229 0.09203354551568294 2.9027034502122057e-06
0.07002905195141519 0.054281593636231175 0.04399387055085277 0.09203354580455708 2.7395843587736267e-06
0.07002906969816841 0.05428150608713636 0.04399386164114608 0.0920335460773933

 45%|████████████████████████████████████▉                                             | 90/200 [00:01<00:01, 62.42it/s]

0.0700291312568062 0.0542812023395938 0.0439938307268068 0.09203354702514775 2.0516113151940456e-06
0.07002914454669085 0.054281136750218346 0.043993824050809364 0.09203354723001263 1.936320424100639e-06
0.07002915708965937 0.05428107484259962 0.043993817749370785 0.09203354742343896 1.8275084054799866e-06
0.07002916892768186 0.054281016410291594 0.04399381180150458 0.09203354760605816 1.7248111645006192e-06
0.07002918010037117 0.054280961258398805 0.04399380618739789 0.0920335477784675 1.6278850687281347e-06
0.07002919064511588 0.05428090920293382 0.04399380088834502 0.09203354794123268 1.536405797953854e-06
0.07002920059720509 0.05428086007020835 0.04399379588668646 0.09203354809488822 1.4500672584760091e-06
0.07002920998994511 0.05428081369625977 0.04399379116575014 0.09203354823994009 1.3685805590646858e-06
0.07002921885477094 0.05428076992630694 0.04399378670979735 0.09203354837686648 1.2916730436826078e-06
0.07002922722135375 0.05428072861424015 0.04399378250396947 0.092033548506

 81%|████████████████████████████████████████████████████████████████▊               | 162/200 [00:01<00:00, 241.58it/s]

0.07002936382680931 0.054280053761517885 0.04399371378440343 0.0920335506179438 3.3786255474050156e-08
0.07002936404557882 0.05428005268019308 0.04399371367426583 0.09203355062132382 3.1887644510420234e-08
0.07002936425205468 0.05428005165963171 0.04399371357031724 0.09203355062451402 3.009572558767254e-08
0.07002936444692749 0.054280050696420105 0.043993713472209754 0.09203355062752476 2.8404503104095742e-08
0.07002936463084905 0.05428004978733524 0.04399371337961525 0.09203355063036635 2.6808318507714483e-08
0.07002936480443514 0.05428004892933537 0.04399371329222427 0.09203355063304827 2.5301831276235883e-08
0.07002936496826635 0.05428004811955011 0.043993713209743775 0.09203355063557946 2.3880000636401426e-08
0.07002936512289111 0.05428004735527006 0.04399371313189817 0.09203355063796827 2.2538069474097713e-08
0.07002936526882646 0.05428004663393826 0.04399371305842691 0.09203355064022285 2.1271547825134268e-08
0.07002936540656088 0.05428004595314102 0.04399371298908461 0.092033550

 94%|██████████████████████████████████████████████████████████████████████████▊     | 187/200 [00:02<00:00, 131.65it/s]

0.07002936694894399 0.054280038329345796 0.043993712212559655 0.09203355066617812 6.690317586630245e-09
0.0700293669922634 0.05428003811522039 0.0439937121907497 0.09203355066684733 6.3143567954199646e-09
0.0700293670331489 0.05428003791312758 0.04399371217016547 0.09203355066747886 5.959523032720925e-09
0.07002936707173658 0.05428003772239134 0.04399371215073773 0.09203355066807489 5.624629049682966e-09
0.07002936710815635 0.054280037542373295 0.043993712132401955 0.09203355066863754 5.308554284953576e-09
0.0700293671425291 0.05428003737247126 0.04399371211509645 0.09203355066916853 5.010241338430026e-09
0.07002936717497048 0.05428003721211699 0.04399371209876333 0.09203355066966967 4.728692003722694e-09
0.07002936720558874 0.05428003706077401 0.0439937120833483 0.09203355067014266 4.462964298781019e-09
0.07002936723448618 0.05428003691793524 0.04399371206879931 0.09203355067058908 4.2121690858421405e-09
0.07002936726175985 0.05428003678312368 0.04399371205506782 0.09203355067101032 3

100%|█████████████████████████████████████████████████████████████████████████████████| 200/200 [00:02<00:00, 89.52it/s]

0.07002936737679046 0.054280036214536705 0.043993711997154 0.0920335506727873 2.9771457194993235e-09
0.07002936739606778 0.05428003611925236 0.043993711987448635 0.09203355067308512 2.8098456728382673e-09
0.07002936741426127 0.054280036029322465 0.04399371197828857 0.0920335506733662 2.651947094023378e-09
0.0700293674314323 0.05428003594444576 0.04399371196964338 0.09203355067363139 2.5029215503787316e-09
0.07002936744763913 0.05428003586433901 0.04399371196148413 0.09203355067388172 2.3622704303563043e-09
0.07002936746293496 0.05428003578873333 0.043993711953783265 0.09203355067411809 2.2295232145611277e-09
0.07002936747737132 0.054280035717376725 0.04399371194651517 0.09203355067434105 2.1042356370618604e-09
0.07002936749099609 0.054280035650030097 0.043993711939655615 0.09203355067455152 1.9859886551216427e-09
0.07002936749099609 0.054280035650030097 0.043993711939655615 0.09203355067455152 1.9859886551216427e-09





##### fit_als_multi(signal[i_t,...], c0_guess, [c_guess[0,...], c_guess[1,...], c_guess[2,...]])

In [6]:
# NOTE(review): `a` and `b` are not assigned by any live code visible in this notebook,
# so this cell raises NameError on a fresh kernel (see the recorded output below).
a,b

NameError: name 'a' is not defined

In [None]:
# Calibration cell. The `if False:` guard disables the recomputation of C0_combined and
# its save to disk; only the load-and-plot statements after the guarded block run.
# NOTE(review): `train_data`, `base_scaling` and `C_combined` are not defined in the visible
# part of this notebook — they must already exist in the kernel if the guard is enabled.
if False:
    def correct_AIRS_jitter(data):
        """Subtract a least-squares fit of the first two `C_combined` rows from `data`.

        `data` is divided by its mean over axes (0, 1) and multiplied by `base_scaling`,
        flattened to (-1, shape[1]*shape[2]), fitted against the first two rows of
        `C_combined`, the fit is subtracted, and the scaling is undone.  The rescaling
        writes into the caller's array (`data[...] = ...`); the returned array has the
        original shape.
        """

        data_inpaint = copy.deepcopy(data)

        # Rescale
        # presumably inpaints NaNs in place (the return value is not used) — TODO confirm
        ariel_load.inpaint_vectorized(data_inpaint)
        if kgs.debugging_mode>=1:
            assert not cp.any(cp.isnan(data_inpaint))
        # NOTE(review): the mean is taken from `data`, not from the inpainted copy, which is
        # only used by the assertion above; get_coeffs uses `data_inpaint` here — confirm intent.
        x=cp.mean(data,(0,1))
        data[...] = data/x*base_scaling
        orig_shape = data.shape   
        data = data.reshape(-1,orig_shape[1]*orig_shape[2])

        # Fit only the first two rows of C_combined and remove their contribution.
        design_matrix = cp.array(C_combined)[:2,:]
        coeffs = ariel_numerics.lstsq_nanrows_normal_eq_with_pinv_sigma(data.T, design_matrix.T, return_A_pinv_w=False)[0]
        data -= (design_matrix.T@coeffs).T

        # Unscale
        data = data.reshape(orig_shape)
        data[...] = data*x/base_scaling

        return data


    # Read-only memmap with 1125 rows per training item and 32*282 columns.
    R_row = np.memmap(kgs.temp_dir + 'AIRS_row.memmap',  dtype=np.float32, mode='r', shape=(1125*len(train_data), 32*282), order='C')#[:50000,...]

    # One jitter-corrected mean frame of shape (1, 32, 282) per training item.
    data = []
    for ii in tqdm(range(len(train_data))):
        this_data = cp.array(np.mean(R_row[1125*ii:1125*(ii+1),...],0), dtype=cp.float64).reshape(1,32,282)
        this_data = correct_AIRS_jitter(this_data)
        data.append(this_data)
    data = cp.concatenate(data)
    #data = data-cp.mean(data,0)
    data.shape
    # For each of the 282 wavelengths, store a length-32 profile taken from nan_pca.
    C0_combined = cp.zeros( (32,282) )
    for i_wavelength in tqdm(range(282)):
        this_data = data[:,:,i_wavelength]    
        # presumably the first principal component across training items — TODO confirm nan_pca's return layout
        C0_combined[:,i_wavelength]=ariel_numerics.nan_pca(this_data,1)[1][0,:]
        # Fix the sign so that the entry at row index 15 is non-negative.
        C0_combined[:,i_wavelength] = C0_combined[:,i_wavelength]*np.sign(C0_combined[15,i_wavelength])
        
    kgs.dill_save(kgs.calibration_dir + 'AIRS_C0.pickle', (C0_combined))
    
# Always executed: load the stored profiles and show their logarithm as an image.
# NOTE(review): cp.log yields NaN for any negative entries; only row 15 is sign-normalized above.
C0_combined = kgs.dill_load(kgs.calibration_dir + 'AIRS_C0.pickle')
plt.figure(figsize=(12,12))
plt.imshow(cp.log(C0_combined).get(), aspect='auto', interpolation='none')
plt.colorbar()

In [None]:
def get_coeffs(data):
    """Fit per-frame coefficients of two `C_combined` rows and the per-wavelength `C0_combined` profiles.

    `data` is a cupy array of shape (n_frames, n_rows, n_wl); the caller in this notebook
    passes (1125, 32, 282).  It is rescaled IN PLACE: divided by the per-frame mean over
    axis 1 of an inpainted copy and multiplied by the global `base_scaling`.

    The flattened frames are then fitted by weighted least squares against a design matrix
    with 2 + n_wl rows: the first two rows of `C_combined`, plus one row per wavelength
    holding `C0_combined[:, wl]` at that wavelength's pixels.  Three passes are run; after
    each pass the per-pixel noise is re-estimated from the residual and used as `sigma`
    for the next pass.  The first and last 8 detector rows are de-weighted by inflating
    their noise estimate by 1e4 before every solve.

    Returns a list with one numpy array per wavelength, of shape (n_frames, 3), whose
    columns are [that wavelength's C0 coefficient, C_combined row 0 coefficient,
    C_combined row 1 coefficient] from the final pass, converted back to the original
    data scale.

    Relies on the globals `base_scaling`, `C_combined` and `C0_combined`.
    """
    n_passes = 3
    n_edge_rows = 8        # rows de-weighted at each edge of the detector
    edge_inflation = 10000

    data_inpaint = copy.deepcopy(data)

    # Rescale
    # presumably inpaints NaNs in place (the return value is not used) — TODO confirm
    ariel_load.inpaint_vectorized(data_inpaint)
    if kgs.debugging_mode>=1:
        assert not cp.any(cp.isnan(data_inpaint))
    x = cp.mean(data_inpaint, 1)                     # (n_frames, n_wl)
    data[...] = data/x[:,None,:]*base_scaling
    orig_shape = data.shape
    n_rows, n_wl = orig_shape[1], orig_shape[2]
    data = data.reshape(-1, n_rows*n_wl)             # pixel index = row*n_wl + wavelength

    noise_est = ariel_numerics.estimate_noise_cp(data)

    # The design matrix does not change between passes, so build it once.
    design_matrix = cp.zeros((2+n_wl, n_rows*n_wl))
    design_matrix[:2,:] = cp.array(C_combined[:2,:])
    for i_wavelength in range(n_wl):
        # Stride n_wl selects every row's pixel for this wavelength in the flattened frame.
        design_matrix[i_wavelength+2, i_wavelength::n_wl] = C0_combined[:,i_wavelength]
    assert not cp.any(cp.isnan(design_matrix))

    edge_rows = list(range(n_edge_rows)) + list(range(n_rows-n_edge_rows, n_rows))

    for _ in range(n_passes):
        # Inflate the noise of the edge rows so they barely contribute to the fit.
        for i_row in edge_rows:
            noise_est[n_wl*i_row:n_wl*(i_row+1)] *= edge_inflation

        # NOTE(review): only the coefficients (element 0) are used; return_A_pinv_w could
        # probably be False as in correct_AIRS_jitter — confirm before changing.
        res = ariel_numerics.lstsq_nanrows_normal_eq_with_pinv_sigma(data.T, design_matrix.T, return_A_pinv_w=True, sigma=noise_est)
        coeffs = res[0]
        assert not cp.any(cp.isnan(coeffs))

        # Re-estimate the noise from the residual for the next pass.
        residual = (data.T - design_matrix.T@coeffs).T
        noise_est = ariel_numerics.estimate_noise_cp(residual)

    # Assemble the per-wavelength outputs once, from the final pass, and undo the scaling.
    coeffs_wl = []
    for i_wavelength in range(n_wl):
        this_coeffs = coeffs[[2+i_wavelength,0,1],:]
        this_coeffs = this_coeffs / base_scaling[i_wavelength] * x[:,i_wavelength]
        coeffs_wl.append(this_coeffs.T.get())

    return coeffs_wl
 
# Read-only memmap with 1125 rows per training item and 32*282 columns.
R_row = np.memmap(
    kgs.temp_dir + 'AIRS_row.memmap',
    dtype=np.float32,
    mode='r',
    shape=(1125*len(train_data), 32*282),
    order='C',
)

# Run get_coeffs on (at most) the first 100 training items.
all_coeffs = []
n_items = len(train_data[:100])
for ii in tqdm(range(n_items)):
    first_row = 1125*ii
    data = cp.array(R_row[first_row:first_row+1125, ...], dtype=cp.float64).reshape(1125, 32, 282)
    all_coeffs.append(get_coeffs(data))

In [None]:
# Flattened pixel count per frame for the (32, 282) layout used in the reshape calls above: 9024.
32*282

In [None]:
# Configure and train the prediction model.
# NOTE(review): `train_data` is not defined in the visible part of this notebook.
kgs.debugging_mode = 1                               # module-level flag; also gates the NaN assertions in the cells above
model = ariel_gp.PredictionModel()
model.run_in_parallel = False
model.model_options.n_iter=0                         # presumably disables iterative refinement — TODO confirm
model.model_options.use_training_labels = True
model.train(train_data)

In [None]:
# Load previously stored per-item results; the transit cell below pairs them with
# train_data and passes element [0] of each entry to set_parameters.
# NOTE(review): dill_load presumably unpickles the file — only load files from a trusted source.
loaded_res = kgs.dill_load(kgs.temp_dir + '/prep.pickle')

In [None]:
# Run inference on the first training item only (trailing `;` suppresses the cell output).
# NOTE(review): the next cell reads model.results['model_mean'] and model.results['obs'];
# presumably this call is what populates them — confirm.
model.infer(train_data[0:1]);

In [None]:
# Build, for every training item, a label matrix from the fitted model:
# (stellar spectrum prediction) * (transit prediction), with the transit model driven by
# the item's own transit parameters and training spectrum.
transits = []
for d,r in tqdm(zip(train_data,loaded_res)):
    # Fresh copy of the mean model per item (a single deep copy is sufficient).
    mm = copy.deepcopy(model.results['model_mean'])
    # NOTE(review): d.transit_params is assigned by reference, so setting `.u = [0,0]`
    # below also changes d.transit_params.u on the training item itself — confirm intent.
    mm.m['signal'].m['main'].m['transit'].transit_params[0][0] = d.transit_params
    mm.m['signal'].m['main'].m['transit'].transit_params[0][0].u = [0,0]
    mm.m['signal'].m['main'].m['transit'].transit_params[0][1] = d.transit_params
    mm.m['signal'].m['main'].m['transit'].transit_params[0][1].u = [0,0]
    mm.set_parameters(r[0])
    mm.m['signal'].m['main'].m['transit'].depth_model.offset = -d.diagnostics['training_spectrum']
    obs_transit = copy.deepcopy(model.results['obs'])

    stellar_labels = mm.m['signal'].m['main'].m['spectrum'].get_prediction(obs_transit)
    transit_labels = mm.m['signal'].m['main'].m['transit'].get_prediction(obs_transit)
    obs_transit.labels = stellar_labels*transit_labels
    transits.append(obs_transit.export_matrix(True))
# Release the large objects that are no longer needed.
del loaded_res
del mm
del model
transits = np.array(np.concatenate(transits))

In [None]:
# For each wavelength, stack that wavelength's coefficient matrices from all items and
# solve the least-squares problem  stacked @ rr[wl] ~ transits[:, wl].
rr = []
for i_wavelength in tqdm(range(282)):
    this_coeffs = np.concatenate([item_coeffs[i_wavelength] for item_coeffs in all_coeffs], axis=0)
    this_transit = transits[:, i_wavelength]
    fit = np.linalg.lstsq(this_coeffs, this_transit, rcond=None)
    rr.append(fit[0])

In [None]:
# Define the wavelength index and its stacked coefficients here so this cell runs
# top-to-bottom on a fresh kernel (it previously relied on names set by a later cell).
jj = 0
coeffs0 = np.concatenate([c[jj] for c in all_coeffs])
coeffs0.shape, rr[jj].shape

In [None]:
# Compare, for one wavelength (jj) and one training item (ii), the RMS-normalized
# fitted prediction, the first coefficient column, the model transit labels, and the
# data loaded with modified loader settings.
plt.figure(figsize=(12,12))
jj=0   # wavelength index
coeffs0 = np.concatenate([c[jj] for c in all_coeffs])
pred = np.sum(coeffs0*rr[jj],1)
ii=0   # training item index
slic = slice(ii*1125,(ii+1)*1125)
plt.plot(pred[slic]/kgs.rms(pred[slic]), label='pred')
plt.plot(coeffs0[slic,0]/kgs.rms(coeffs0[slic,0]), label='coeffs0[:, 0]')
plt.plot(transits[slic,jj]/kgs.rms(transits[slic,jj]), label='transits')

d = copy.deepcopy(train_data[ii])
loaders = ariel_load.default_loaders()
# Separate loop variable so the item index `ii` is not overwritten.
for i_loader in range(2):
    loaders[i_loader].apply_pixel_corrections.mask_hot=False
loaders[1].apply_full_sensor_corrections.remove_background_based_on_rows = False
loaders[1].apply_full_sensor_corrections.remove_background_remove_used_rows = False
d.load_to_step(5, loaders)

xx = d.transits[0].data[1].data[:,jj].get()
plt.plot(xx/kgs.rms(xx), label='loaded data')
plt.legend()

plt.figure()
plt.scatter(xx, pred[slic])
plt.xlabel('loaded data')
plt.ylabel('pred')
print(ariel_numerics.estimate_noise_cp(cp.array(xx/kgs.rms(xx))), ariel_numerics.estimate_noise_cp(cp.array(pred[slic]/kgs.rms(pred[slic]))))
# NOTE(review): `pred` covers every item in all_coeffs while `xx` is a single item;
# `pred[slic]` may have been intended here — confirm.
print(kgs.rms(pred), kgs.rms(xx))