In [1]:
from pathlib import Path

import numpy as np
import polars as pl
from tqdm.auto import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data_dir = "../data"
# Row budget per column chunk: calculate_statistics stops reading further
# files once this many rows have been loaded.
N_ROWS = 500_000_000
# Path.glob yields entries in arbitrary, filesystem-dependent order. Sorting
# makes the file order reproducible, and with it the subset of files that
# fall under N_ROWS.
files = sorted(Path(data_dir).glob("raw_train*"))
print(f"# of train files: {len(files)}")

# of train files: 100


In [3]:
# Only the headers are needed, so read a single row from each CSV.
weight_df = pl.read_csv(Path(data_dir, "sample_submission.csv"), n_rows=1)
# Label columns: every sample_submission column except the first
# (presumably an id column — TODO confirm).
label_cols = weight_df.columns[1:]
tmp_df = pl.read_csv(Path(data_dir, "train.csv"), n_rows=1)
# Feature columns: train.csv columns that are not labels, minus the first
# such column (presumably the same id column — TODO confirm).
label_set = set(label_cols)
non_label_cols = [name for name in tmp_df.columns if name not in label_set]
feat_cols = non_label_cols[1:]
len(feat_cols), len(label_cols)

(556, 368)

In [4]:
def _read_columns(sub_cols, files, max_rows):
    """Concatenate ``sub_cols`` from ``files`` (in order) until >= ``max_rows`` rows are loaded.

    The last file read is not truncated, so the result may hold slightly
    more than ``max_rows`` rows.
    """
    frames = []
    n_loaded = 0
    for path in files:
        frames.append(pl.read_parquet(path, columns=sub_cols))
        n_loaded += frames[-1].height
        if n_loaded >= max_rows:  # budget reached: don't read another file
            break
    return pl.concat(frames)


def _column_stats(s):
    """Return the summary statistics of one polars Series as an ordered dict."""
    # Mean of the strictly positive part (values > 1e-7). Filtering only this
    # Series avoids copying every column of the chunk just to average one.
    pos_mean = s.filter(s > 1e-7).mean()
    if pos_mean is None:
        # No positive values: sentinel mean, so lambda becomes 1e50.
        pos_mean = 1e-50
    return {
        "mean": s.mean(),
        "std": s.std(),
        "q1_4": s.quantile(0.25),
        "q2_4": s.quantile(0.5),
        "q3_4": s.quantile(0.75),
        "min": s.min(),
        "max": s.max(),
        # Reciprocal of the positive-part mean; presumably a rate parameter
        # for an exponential-style scaling — TODO confirm intended use.
        "lambda": 1 / pos_mean,
        # Root-mean-square (sqrt of the mean of squares), not a centred std.
        "std_y": np.sqrt((s ** 2).mean()),
    }


def calculate_statistics(cols, files, chunk_size=50, max_rows=None):
    """Compute per-column summary statistics over a set of parquet files.

    Columns are processed in groups of ``chunk_size``, so only one group's
    data is held in memory at a time. For each group, ``files`` are re-read
    in order until at least ``max_rows`` rows have been loaded.

    Args:
        cols: Column names to summarise. They are also used, in order, as
            the column names of the returned frame.
        files: Iterable of parquet paths, read in the given order.
        chunk_size: Number of columns loaded per pass over ``files``.
        max_rows: Stop reading further files once this many rows are loaded.
            Defaults to the notebook-level ``N_ROWS``.

    Returns:
        A polars DataFrame whose first column ``"stats"`` names each statistic
        (mean, std, q1_4, q2_4, q3_4, min, max, lambda, std_y), followed by
        one column per entry in ``cols``.

    Raises:
        ValueError: if ``files`` is empty or ``chunk_size`` is < 1.
    """
    if max_rows is None:
        max_rows = N_ROWS
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    # Materialise once: every column chunk iterates over the files again, and
    # a generator would be exhausted after the first chunk.
    files = list(files)
    if not files:
        raise ValueError("calculate_statistics: no input files to read")

    stats = []
    n_chunks = (len(cols) - 1) // chunk_size + 1  # ceil(len(cols) / chunk_size)
    for i in tqdm(range(n_chunks)):
        sub_cols = cols[i * chunk_size:(i + 1) * chunk_size]
        df = _read_columns(sub_cols, files, max_rows)
        stats.extend(_column_stats(df[col]) for col in sub_cols)
        # Release this chunk before loading the next one, to halve peak memory.
        del df

    return (
        pl.from_dicts(stats)
        .transpose(include_header=True, header_name="stats", column_names=cols)
    )


# Each call re-reads the raw parquet files once per column chunk, so this cell
# takes minutes per call (see the recorded progress bars).
feat_stats_df, label_stats_df = (
    calculate_statistics(group, files) for group in (feat_cols, label_cols)
)

100% 12/12 [09:40<00:00, 48.41s/it]
100% 8/8 [06:00<00:00, 45.08s/it]


In [5]:
# Expect one row per statistic, and one column per input column plus the
# leading "stats" label column.
tuple(frame.shape for frame in (feat_stats_df, label_stats_df))

((9, 557), (9, 369))

In [6]:
# The horizontal concat pairs rows by position, so both frames must list the
# statistics in the same order. Check that, then keep the "stats" label
# column only once (dropped by name rather than by position).
assert feat_stats_df["stats"].to_list() == label_stats_df["stats"].to_list(), (
    "stat rows of feature and label frames are not aligned"
)
stats_df = pl.concat([feat_stats_df, label_stats_df.drop("stats")], how="horizontal")
stats_df

stats,state_t_0,state_t_1,state_t_2,state_t_3,state_t_4,state_t_5,state_t_6,state_t_7,state_t_8,state_t_9,state_t_10,state_t_11,state_t_12,state_t_13,state_t_14,state_t_15,state_t_16,state_t_17,state_t_18,state_t_19,state_t_20,state_t_21,state_t_22,state_t_23,state_t_24,state_t_25,state_t_26,state_t_27,state_t_28,state_t_29,state_t_30,state_t_31,state_t_32,state_t_33,state_t_34,state_t_35,…,ptend_v_31,ptend_v_32,ptend_v_33,ptend_v_34,ptend_v_35,ptend_v_36,ptend_v_37,ptend_v_38,ptend_v_39,ptend_v_40,ptend_v_41,ptend_v_42,ptend_v_43,ptend_v_44,ptend_v_45,ptend_v_46,ptend_v_47,ptend_v_48,ptend_v_49,ptend_v_50,ptend_v_51,ptend_v_52,ptend_v_53,ptend_v_54,ptend_v_55,ptend_v_56,ptend_v_57,ptend_v_58,ptend_v_59,cam_out_NETSW,cam_out_FLWDS,cam_out_PRECSC,cam_out_PRECC,cam_out_SOLS,cam_out_SOLL,cam_out_SOLSD,cam_out_SOLLD
str,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""mean""",215.612015,227.878262,237.309279,247.912285,256.227916,259.451296,255.271369,246.698109,236.980836,230.274286,225.054863,220.928353,217.108561,213.894105,210.537094,207.052169,202.847109,200.041338,199.271441,201.226921,203.712135,206.900332,210.434994,214.236472,218.053071,221.900889,225.667783,229.347179,232.912757,236.379712,239.742185,243.019442,246.204743,249.292946,252.269269,255.123318,…,-2.868e-09,2.0688e-08,3.4469e-08,3.1557e-08,7.1815e-09,-6.3038e-08,-1.5727e-07,-2.4531e-07,-3.0017e-07,-3.0094e-07,-2.2172e-07,-7.6118e-08,1.1783e-07,3.2319e-07,4.9466e-07,6.1731e-07,6.3501e-07,5.3735e-07,3.3996e-07,8.3744e-08,-2.3336e-07,-6.8066e-07,-1e-06,-1e-06,-1e-06,-5.333e-07,6.2734e-07,1e-06,6.5831e-07,158.311244,351.273302,2.7343e-09,2.9194e-08,61.803931,67.324685,33.451249,17.676368
"""std""",6.654266,8.653689,8.248617,6.786184,6.242818,8.249606,10.157319,10.112921,9.218096,8.81164,8.137254,7.43026,6.723265,6.321632,6.513009,7.6074,9.966565,11.792149,10.03546,8.244552,5.475349,4.073948,4.320332,5.478911,6.961327,8.309486,9.521203,10.517859,11.344936,12.00057,12.514359,12.893314,13.160995,13.32794,13.405421,13.407873,…,8e-06,8e-06,7e-06,7e-06,7e-06,7e-06,8e-06,9e-06,9e-06,1e-05,1.1e-05,1.1e-05,1.2e-05,1.3e-05,1.4e-05,1.5e-05,1.6e-05,1.7e-05,1.8e-05,2e-05,2.2e-05,2.5e-05,2.8e-05,2.9e-05,3e-05,2.9e-05,2.8e-05,2.4e-05,3.4e-05,246.661342,71.979277,7.39e-09,8.1856e-08,110.255081,116.424189,46.397329,29.70672
"""q1_4""",211.795845,222.243162,231.759324,243.668977,252.799997,256.235228,251.635947,243.449935,234.437349,228.097416,223.197921,219.205218,215.187527,211.551624,207.162932,201.871982,194.857171,189.825833,190.678675,194.104774,199.497332,204.247093,208.499765,211.28707,213.509142,215.931538,218.53578,221.356252,224.305857,227.357786,230.46927,233.648802,236.847711,240.04982,243.221606,246.322576,…,-5.5398e-07,-5.2138e-07,-4.9227e-07,-4.6982e-07,-4.4817e-07,-4.3488e-07,-4.3559e-07,-4.4466e-07,-4.6692e-07,-5.0584e-07,-5.6217e-07,-6.3387e-07,-7.3155e-07,-8.551e-07,-1e-06,-1e-06,-1e-06,-2e-06,-2e-06,-3e-06,-3e-06,-4e-06,-5e-06,-6e-06,-7e-06,-8e-06,-7e-06,-6e-06,-9e-06,0.0,311.199335,0.0,0.0,0.0,0.0,0.0,0.0
"""q2_4""",215.558519,228.043463,236.547149,247.060054,255.854526,260.29541,257.143055,248.743841,239.033878,232.473792,227.008668,222.471784,217.942912,213.953813,209.621341,205.616373,201.811222,199.379132,198.640697,200.417245,203.053164,206.646559,210.628659,215.049439,219.470938,223.699178,227.787865,231.76676,235.647794,239.425994,243.096914,246.65641,250.090561,253.374147,256.474984,259.427914,…,1.6653e-18,4.4409e-18,5.181e-18,4.811e-18,2.9606e-18,1.1102e-18,-2.313e-18,-8.1416e-18,-1.2212e-17,-1.1842e-17,-3.7007e-18,3.7007e-19,5.9212e-18,5.9614e-16,1.5969e-14,2.1081e-13,4.4083e-13,1.6354e-13,8.4976e-14,4.3469e-15,5.4956e-17,-1.3878e-17,-3.6602e-15,-6.1598e-12,-5.9597e-10,-1.2688e-08,-6.2603e-10,5.262e-11,2.409e-07,0.016224,363.88266,0.0,3.4721e-09,4.9415e-18,1.3405e-15,0.890961,0.735684
"""q3_4""",219.287002,233.795692,242.07833,251.166304,259.263806,263.889139,261.154271,252.387389,242.001918,235.063319,229.312532,224.574035,220.025695,216.310257,213.453864,211.56516,210.132128,209.441833,206.956568,207.596944,207.551485,209.507191,212.932085,217.680374,222.951198,228.15249,233.138784,237.828987,242.221519,246.335236,250.199286,253.830744,257.231715,260.441533,263.444719,266.252868,…,5.2116e-07,5.0744e-07,4.8461e-07,4.5758e-07,4.3468e-07,4.1226e-07,3.985e-07,4.0108e-07,4.2253e-07,4.6338e-07,5.3128e-07,6.346e-07,7.7652e-07,9.7377e-07,1e-06,2e-06,2e-06,2e-06,2e-06,3e-06,3e-06,4e-06,4e-06,5e-06,6e-06,7e-06,7e-06,8e-06,1.2e-05,247.601915,409.656305,3.3229e-10,1.9853e-08,77.337254,91.911309,57.72635,24.646142
"""min""",142.895886,139.066952,168.246148,203.463584,210.054672,203.917759,196.501913,188.514183,181.753242,177.027108,174.911598,174.306802,173.959746,174.677959,175.853027,177.496814,168.125578,168.136816,175.502996,182.705558,184.31821,185.494009,186.734346,188.217112,189.4428,191.103577,192.623814,194.461734,195.771591,197.105422,197.228953,199.455183,201.346508,202.799973,204.643092,206.046347,…,-0.000666,-0.00048,-0.000454,-0.000623,-0.000574,-0.000645,-0.000565,-0.000522,-0.000551,-0.000554,-0.000539,-0.000619,-0.000504,-0.000572,-0.000558,-0.000435,-0.000453,-0.000457,-0.000512,-0.000545,-0.000751,-0.000943,-0.00111,-0.001148,-0.00068,-0.000568,-0.000486,-0.000551,-0.000771,0.0,57.217559,0.0,0.0,0.0,0.0,0.0,0.0
"""max""",427.707429,313.808117,292.269848,310.633125,309.992211,298.913839,291.821315,288.738876,273.351677,266.611876,259.086748,251.285942,244.165575,239.915048,236.040182,234.870412,233.052619,233.088381,230.410183,229.080932,227.625933,226.132173,229.046415,231.763924,237.234112,241.730175,247.452713,252.747082,255.713055,259.626455,263.605989,266.503223,270.11649,273.81052,277.026922,279.045703,…,0.000723,0.000526,0.000431,0.000438,0.000429,0.000559,0.000528,0.000531,0.000479,0.000437,0.000519,0.000613,0.000631,0.000755,0.000557,0.000746,0.000629,0.000672,0.000759,0.000664,0.000759,0.000495,0.000667,0.000607,0.000614,0.000578,0.000486,0.000502,0.000731,1106.681756,528.098066,2.9351e-07,2e-06,518.902766,575.257308,425.175775,271.069632
"""lambda""",0.004638,0.004388,0.004214,0.004034,0.003903,0.003854,0.003917,0.004054,0.00422,0.004343,0.004443,0.004526,0.004606,0.004675,0.00475,0.00483,0.00493,0.004999,0.005018,0.00497,0.004909,0.004833,0.004752,0.004668,0.004586,0.004507,0.004431,0.00436,0.004293,0.00423,0.004171,0.004115,0.004062,0.004011,0.003964,0.00392,…,253567.14686,270095.982533,285395.839227,298594.249955,308016.870304,313925.365697,312038.447775,301743.890388,285239.88992,262553.756033,233856.871288,205428.505703,179065.358406,156572.340253,138147.423086,123897.372151,113937.42772,106547.136689,100244.65864,93433.200825,85890.184672,77612.698486,69987.88213,63609.390158,59110.291081,56469.007977,56956.200814,61523.05607,50151.003141,0.003166,0.002847,8356800.0,3803500.0,0.007911,0.007292,0.015347,0.029044
"""std_y""",215.714673,228.042515,237.452592,248.005148,256.303956,259.582416,255.473371,246.905302,237.160051,230.442817,225.201924,221.053264,217.212637,213.987502,210.63781,207.191875,203.091807,200.388601,199.523978,201.395745,203.785705,206.940437,210.479338,214.30652,218.164163,222.056416,225.868549,229.588227,233.188894,236.68414,240.068583,243.361227,246.556256,249.648968,252.625196,255.475397,…,8e-06,8e-06,7e-06,7e-06,7e-06,7e-06,8e-06,9e-06,9e-06,1e-05,1.1e-05,1.1e-05,1.2e-05,1.3e-05,1.4e-05,1.5e-05,1.6e-05,1.7e-05,1.8e-05,2e-05,2.2e-05,2.5e-05,2.8e-05,2.9e-05,3e-05,2.9e-05,2.8e-05,2.4e-05,3.4e-05,293.094288,358.572096,7.8796e-09,8.6907e-08,126.395836,134.488675,57.198759,34.567949


In [7]:
# Write each statistics frame to parquet alongside the raw data, in this order.
stats_outputs = {
    "feat_stats.parquet": feat_stats_df,
    "label_stats.parquet": label_stats_df,
    "stats.parquet": stats_df,
}
for out_name, out_df in stats_outputs.items():
    out_df.write_parquet(Path(data_dir, out_name))