In [1]:
import sys
sys.path.append('/mnt/d/ariel2/code/core/')
import kaggle_support as kgs
import importlib
import matplotlib.pyplot as plt
import numpy as np
import cupy as cp
import ariel_load
import ariel_simple
import tqdm

local


In [2]:
%%time
# Load the full training and test sets through the kaggle_support (kgs) helpers.
# The cell is timed with %%time so readers can see what a reload costs.
train_data = kgs.load_all_train_data()
test_data = kgs.load_all_test_data()
# Last expression of the cell: display how many items each set contains.
len(train_data), len(test_data)

CPU times: user 964 ms, sys: 183 ms, total: 1.15 s
Wall time: 3.29 s


(1100, 1)

In [3]:
def robust_linear_fit_rms(X, Y, drop_frac=0.05, do_print=False):
    """Fit ``Y ~ a*X + b`` by least squares and return a trimmed RMS of the residuals.

    The line is fitted to *all* points. The RMS is then computed after
    discarding the ``drop_frac`` fraction of points with the largest absolute
    residuals, so a handful of outliers does not dominate the figure.

    Parameters
    ----------
    X, Y : array_like
        1-D samples of equal length.
    drop_frac : float, default 0.05
        Fraction of points (those with the largest ``|residual|``) excluded
        from the RMS. Values ``<= 0`` keep every point. Must be ``< 1``.
    do_print : bool, default False
        If true, print the fitted slope and intercept.

    Returns
    -------
    rms : float
        Root-mean-square of the retained residuals.
    residuals : numpy.ndarray
        Residuals ``Y - (a*X + b)`` of all points (untrimmed), in input order.

    Raises
    ------
    ValueError
        If ``drop_frac >= 1``; no well-defined set of points would remain.
    """
    if drop_frac >= 1:
        # A negative "keep" count would otherwise be passed to argpartition
        # and silently select a meaningless subset.
        raise ValueError(f"drop_frac must be < 1, got {drop_frac}")

    X = np.asarray(X)
    Y = np.asarray(Y)

    # Least-squares fit of Y = a*X + b (design matrix columns: X and ones).
    A = np.vstack([X, np.ones_like(X)]).T
    a, b = np.linalg.lstsq(A, Y, rcond=None)[0]

    residuals = Y - (a * X + b)

    n = len(residuals)
    keep = int(n * (1 - drop_frac)) if drop_frac > 0 else n
    if keep >= n:
        # Nothing to trim. This also covers a tiny positive drop_frac for
        # which int(n * (1 - drop_frac)) == n; argpartition would reject
        # kth == n as out of bounds.
        kept = residuals
    else:
        # Indices of the `keep` smallest absolute residuals (order irrelevant).
        indices = np.argpartition(np.abs(residuals), keep)[:keep]
        kept = residuals[indices]

    rms = np.sqrt(np.mean(kept ** 2))
    if do_print:
        print(a,b)
    return rms,residuals

In [None]:
df = 0.05
kgs.sanity_checks_without_errors = True
for jj in range(13):
    model = ariel_simple.SimpleModel()
    model.run_in_parallel = (jj<6)
    match jj:
        case 0:
            name = 'Default'
        case 1:
            name = 'Supersample'
            model.supersample_factor = 5
        case 2:
            name = 'Fit eccentricity'
            model.fit_ecc = True
        case 3:
            name = 'Correction factor'
            model.use_correction_factor = True
        case 4:
            name = '4th order'
            model.order_list = [0,1,2,3,4]
        case 5:
            name = '6th order'
            model.order_list = [0,1,2,3,4,5,6]
        case 6:
            name = 'Use more AIRS rows'
            model.loaders[1].apply_full_sensor_corrections.remove_background_n_rows = 4            
        case 7:
            name = 'Don''t mask hot'
            for ii in range(2):
                model.loaders[ii].apply_pixel_corrections.mask_hot = False
            model.run_in_parallel = False
        case 8:
            name = 'No linear correction'
            for ii in range(2):
                model.loaders[ii].apply_pixel_corrections.linear_correction = False
        case 9:
            name = 'Dark current sign'
            for ii in range(2):
                model.loaders[ii].apply_pixel_corrections.dark_current_sign *= -1
        case 10:
            name = 'ADC sign'
            for ii in range(2):
                model.loaders[ii].apply_pixel_corrections.adc_offset_sign *= -1
        case 11:
            name = 'No flat field'
            for ii in range(2):
                model.loaders[ii].apply_pixel_corrections.flat_field = False
        case 12:
            name = 'Time binning x2'
            for ii in range(2):
                model.loaders[ii].apply_time_binning.time_binning*=2                
    model.train(train_data)
    data = train_data
    inferred_data = model.infer(data)
    print(name)
    solution = kgs.make_submission_dataframe(data, include_sigma=False)
    submission = kgs.make_submission_dataframe(inferred_data, False)
    print(1e6*robust_linear_fit_rms(solution.iloc[:,1].to_numpy(), solution.iloc[:,1].to_numpy()-submission.iloc[:,1].to_numpy(), drop_frac=df)[0], 1e6*robust_linear_fit_rms(solution.iloc[:,1].to_numpy(), solution.iloc[:,1].to_numpy()-submission.iloc[:,1].to_numpy(), drop_frac=0)[0])
    print(1e6*robust_linear_fit_rms(np.mean(solution.iloc[:,2:284].to_numpy(),1), np.mean(solution.iloc[:,2:284].to_numpy()-submission.iloc[:,2:284].to_numpy(),1),drop_frac=df)[0],
         1e6*robust_linear_fit_rms(np.mean(solution.iloc[:,2:284].to_numpy(),1), np.mean(solution.iloc[:,2:284].to_numpy()-submission.iloc[:,2:284].to_numpy(),1),drop_frac=0)[0])
    kgs.sanity_checks['simple_residual_diff_FGS'] = kgs.SanityCheckValue('simple_residual_diff_FGS', 12, [-1,1])
    kgs.sanity_checks['simple_residual_diff_FGS'].seen_all = [d.diagnostics['simple_residual_diff_FGS'] for d in inferred_data]
    kgs.sanity_checks['simple_residual_diff_AIRS'] = kgs.SanityCheckValue('simple_residual_diff_AIRS', 12, [-1,1])
    kgs.sanity_checks['simple_residual_diff_AIRS'].seen_all = [d.diagnostics['simple_residual_diff_AIRS'] for d in inferred_data]
    kgs.dill_save(kgs.temp_dir + '/compare_simple'+str(jj)+'.pickle', (data,inferred_data,kgs.sanity_checks,name))

Processing in parallel:   0%|                                                                  | 0/1100 [00:00<?, ?it/s]

local
local
local
local
local
local
local


Processing in parallel: 100%|███████████████████████████████████████████████████████| 1100/1100 [07:11<00:00,  2.55it/s]


Default
178.4929154418088 296.4701367272985
78.93757942054788 203.07053082721023


Processing in parallel:   0%|                                                                  | 0/1100 [00:00<?, ?it/s]

local
local
local
local
local
local
local


Processing in parallel: 100%|███████████████████████████████████████████████████████| 1100/1100 [15:32<00:00,  1.18it/s]


Supersample
182.77557466948127 309.8766964519231
76.99328969241745 211.6311795625154


Processing in parallel:   0%|                                                                  | 0/1100 [00:00<?, ?it/s]

local
local
local
local
local
local
local


Processing in parallel: 100%|███████████████████████████████████████████████████████| 1100/1100 [07:24<00:00,  2.48it/s]


Fit eccentricity
177.07325706231543 283.0428304261293
76.83788458882636 216.8438874295818


Processing in parallel:   0%|                                                                  | 0/1100 [00:00<?, ?it/s]

local
local
local
local
local
local
local


Processing in parallel: 100%|███████████████████████████████████████████████████████| 1100/1100 [06:27<00:00,  2.84it/s]


Correction factor
220.18984721947587 334.3936587047069
87.1514205781326 239.72748535899663


Processing in parallel:   0%|                                                                  | 0/1100 [00:00<?, ?it/s]

local
local
local
local
local
local
local


Processing in parallel: 100%|███████████████████████████████████████████████████████| 1100/1100 [05:10<00:00,  3.54it/s]


4th order
153.60367000646355 264.9329797359857
69.57088140922706 196.40767404769832


Processing in parallel:   0%|                                                                  | 0/1100 [00:00<?, ?it/s]

local
local
local
local
local
local
local


Processing in parallel: 100%|███████████████████████████████████████████████████████| 1100/1100 [06:29<00:00,  2.82it/s]


6th order
207.91568872624228 316.84165785211724
96.67052743259785 224.5989298636515


Inferring:  30%|████████████████████▌                                                | 328/1100 [12:12<41:13,  3.20s/it]

In [None]:
def process_solution(dat):
    """Convert one saved comparison tuple into per-item linear-fit residuals.

    ``dat`` is indexed positionally: ``dat[0]`` is the reference data,
    ``dat[1]`` the inferred data and ``dat[3]`` the variant label.

    Returns ``(fgs_err, airs_err, label)``, where the two error arrays are the
    residual vectors returned by ``robust_linear_fit_rms`` for column 1 and for
    the row-wise mean of columns 2:284 of the submission frames.
    """
    truth = kgs.make_submission_dataframe(dat[0], include_sigma=False)
    predicted = kgs.make_submission_dataframe(dat[1], False)

    # Column 1: regress (truth - prediction) on truth, keep only the residuals.
    fgs_truth = truth.iloc[:,1].to_numpy()
    fgs_delta = fgs_truth - predicted.iloc[:,1].to_numpy()
    _, fgs_err = robust_linear_fit_rms(fgs_truth, fgs_delta, do_print=False)

    # Columns 2:284: same regression on the per-row means.
    airs_truth = truth.iloc[:,2:284].to_numpy()
    airs_delta = airs_truth - predicted.iloc[:,2:284].to_numpy()
    _, airs_err = robust_linear_fit_rms(np.mean(airs_truth,1), np.mean(airs_delta,1), do_print=False)

    return fgs_err,airs_err,dat[3]
# Scatter each variant's residuals against the baseline (variant 0): one figure
# per variant, two panels (index 0 and 1 of the process_solution result).
# Points below the black y=x line have a smaller residual than the baseline.
dat_base = process_solution(kgs.dill_load(kgs.temp_dir + '/compare_simple'+str(0)+'.pickle'))
for jj in range(1,13):
    dat_this = process_solution(kgs.dill_load(kgs.temp_dir + '/compare_simple'+str(jj)+'.pickle'))
    _,ax = plt.subplots(1,2,figsize=(8,4))
    for ii in range(2):
        panel = ax[ii]
        panel.grid(True)
        panel.scatter(dat_base[ii], dat_this[ii])
        panel.axline((0,0), slope=1, color='black')
        # Title shows baseline RMS -> variant RMS, scaled by 1e6.
        rms_base = 1e6*kgs.rms(dat_base[ii])
        rms_this = 1e6*kgs.rms(dat_this[ii])
        panel.set_title(f'{rms_base:.4f}->{rms_this:.4f}')
        panel.set_xlabel(dat_base[2])
        panel.set_ylabel(dat_this[2])