# First Model: CNN

## Import packages

In [1]:
import os
import gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import itertools
import glob 
import seaborn as sns
import tensorflow as tf
import multiprocessing as mp
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Concatenate, Lambda
from tensorflow.keras.models import Model
from astropy.stats import sigma_clip
from tqdm import tqdm
from multiprocessing import Pool
from concurrent.futures import ThreadPoolExecutor, as_completed

# Plot styling: seaborn "dark" theme; the "muted" colour palette is kept in `palette`.
sns.set_theme(style='dark')
palette = sns.color_palette('muted')
# Do not truncate wide DataFrames when they are displayed.
pd.set_option('display.max_columns', None)

2025-09-30 03:58:16.582752: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759204696.790062      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759204696.849907      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Load and calibrate the data

In [2]:
# Read-only competition data and the scratch directory for intermediate .npy files.
path_folder = '/kaggle/input/ariel-data-challenge-2025/'
path_out = '/kaggle/tmp/data_light_raw/'

# Create the scratch directory on first use and report which case applied.
if os.path.exists(path_out):
    print(f"Directory {path_out} already exists.")
else:
    os.makedirs(path_out)
    print(f"Directory {path_out} created.")

# Number of observations grouped into one processing chunk.
CHUNKS_SIZE = 4

Directory /kaggle/tmp/data_light_raw/ created.


In [3]:
def ADC_convert(signal, gain=0.4369, offset=-1000):
    """Undo the detector's analog-to-digital conversion.

    The input is copied to float64, divided by ``gain`` and then shifted by
    ``offset``. The array passed in is left untouched.

    Args:
        signal: NumPy array of raw detector values.
        gain: Divisor applied to every value (hard-coded default).
        offset: Constant added after the division (hard-coded default).

    Returns:
        A new float64 array with the same shape as ``signal``.
    """
    converted = signal.astype(np.float64)
    converted = converted / gain
    return converted + offset

def mask_hot_dead(signal, dead, dark):
    """Mask dead and hot pixels in every frame of ``signal``.

    Hot pixels are the entries of ``dark`` rejected by a 5-sigma clip (at most
    5 iterations); dead pixels are the truthy entries of ``dead``. Both maps
    are repeated along the first axis of ``signal`` before being applied.

    Args:
        signal: Frame stack whose first axis indexes frames.
        dead: Per-pixel dead map (assumed 2-D, matching one frame).
        dark: Per-pixel dark frame (assumed 2-D, matching one frame).

    Returns:
        A NumPy masked array over ``signal`` with dead and hot pixels masked.
    """
    n_frames = signal.shape[0]
    hot_map = sigma_clip(dark, sigma=5, maxiters=5).mask

    hot_per_frame = np.tile(hot_map, (n_frames, 1, 1))
    dead_per_frame = np.tile(dead, (n_frames, 1, 1))

    # Second masked_where keeps the entries masked by the first one.
    dead_masked = np.ma.masked_where(dead_per_frame, signal)
    return np.ma.masked_where(hot_per_frame, dead_masked)

def apply_linear_corr(linear_corr, clean_signal):
    """Apply a per-pixel polynomial correction to ``clean_signal`` in place.

    ``linear_corr[:, x, y]`` holds the polynomial coefficients of pixel
    ``(x, y)`` in increasing order of power; they are reversed here because
    ``np.poly1d`` expects the highest power first.

    Args:
        linear_corr: Coefficient cube indexed as (coefficient, x, y).
        clean_signal: Frame stack indexed as (frame, x, y); overwritten.

    Returns:
        The same ``clean_signal`` object, now holding the corrected values.
    """
    coeffs_high_first = np.flip(linear_corr, axis=0)
    n_rows, n_cols = clean_signal.shape[1], clean_signal.shape[2]
    for row in range(n_rows):
        for col in range(n_cols):
            pixel_poly = np.poly1d(coeffs_high_first[:, row, col])
            clean_signal[:, row, col] = pixel_poly(clean_signal[:, row, col])
    return clean_signal

def clean_dark(signal, dead, dark, dt):
    """Subtract the dark current, scaled by integration time, from ``signal``.

    The dark frame is masked where ``dead`` is truthy, repeated once per frame
    and multiplied by that frame's entry of ``dt`` before being subtracted
    in place.

    Args:
        signal: Frame stack indexed as (frame, x, y); modified in place.
        dead: Per-pixel dead map matching one frame.
        dark: Per-pixel dark frame matching one frame.
        dt: 1-D array with one integration time per frame.

    Returns:
        ``signal`` after the subtraction.
    """
    masked_dark = np.ma.masked_where(dead, dark)
    dark_per_frame = np.tile(masked_dark, (signal.shape[0], 1, 1))
    # Broadcast the per-frame integration time over both pixel axes.
    exposure = dt[:, np.newaxis, np.newaxis]
    signal -= dark_per_frame * exposure
    return signal

def get_cds(signal):
    """Correlated double sampling along axis 1.

    Pairs consecutive frames and returns, for each pair, the odd-indexed frame
    minus the even-indexed frame, halving the length of axis 1.

    Args:
        signal: 4-D array whose axis 1 indexes frames.

    Returns:
        Array of pairwise differences.
    """
    first_reads = signal[:, 0::2, :, :]
    second_reads = signal[:, 1::2, :, :]
    return second_reads - first_reads

def bin_obs(cds_signal, binning):
    """Swap the last two axes and sum consecutive groups of frames.

    Axis 1 is split into ``shape[1] // binning`` groups of ``binning`` frames;
    any leftover frames at the end are ignored.

    Args:
        cds_signal: 4-D array whose axis 1 indexes frames.
        binning: Number of consecutive frames summed into one output frame.

    Returns:
        A float64 array of shape
        ``(shape[0], shape[1] // binning, shape[3], shape[2])``.
    """
    swapped = cds_signal.transpose(0, 1, 3, 2)
    n_obs, n_frames, dim_a, dim_b = swapped.shape
    n_bins = n_frames // binning
    binned = np.zeros((n_obs, n_bins, dim_a, dim_b))
    for bin_id, start in enumerate(range(0, n_bins * binning, binning)):
        binned[:, bin_id, :, :] = swapped[:, start:start + binning, :, :].sum(axis=1)
    return binned

def correct_flat_field(flat, dead, signal):
    """Divide every frame of ``signal`` by the flat field, in place.

    ``flat`` and ``dead`` are transposed to match the frame orientation of
    ``signal``; the flat is masked where ``dead`` is truthy and repeated once
    per frame before the division.

    Args:
        flat: 2-D flat-field map, transposed relative to one frame.
        dead: 2-D dead-pixel map with the same orientation as ``flat``.
        signal: Frame stack indexed as (frame, a, b); modified in place.

    Returns:
        ``signal`` after the division.
    """
    dead_t = dead.transpose(1, 0)
    flat_t = np.ma.masked_where(dead_t, flat.transpose(1, 0))
    flat_per_frame = np.tile(flat_t, (signal.shape[0], 1, 1))
    signal /= flat_per_frame
    return signal

def get_index(files, CHUNKS_SIZE):
    """Group the planet ids found in ``files`` into chunks of at most ``CHUNKS_SIZE``.

    A path contributes a planet id when its file name is exactly
    ``AIRS-CH0_signal_0.parquet``; the id is the name of the directory that
    contains the file, converted to ``int``. Ids are sorted in increasing order
    and cut into consecutive chunks. Every chunk holds exactly ``CHUNKS_SIZE``
    ids except possibly the last one, which holds the remainder. When the
    number of ids is a multiple of ``CHUNKS_SIZE`` the result matches an even
    ``np.array_split``.

    Args:
        files: Iterable of path strings (typically a ``glob`` result).
        CHUNKS_SIZE: Maximum number of planet ids per chunk; must be >= 1.

    Returns:
        A list of 1-D integer arrays; an empty list when no signal file is
        present.

    Raises:
        ValueError: If ``CHUNKS_SIZE`` is smaller than 1.
    """
    if CHUNKS_SIZE < 1:
        raise ValueError("CHUNKS_SIZE must be a positive integer")

    # Exact file-name match: other entries (calibration folders, FGS1 files,
    # oddly named files) are ignored instead of risking an IndexError.
    planet_ids = [
        int(os.path.basename(os.path.dirname(file)))
        for file in files
        if os.path.basename(file) == 'AIRS-CH0_signal_0.parquet'
    ]
    index = np.sort(np.array(planet_ids, dtype=np.int64))

    # Chunking idea credited to DennisSakva in the original notebook. Fixed-size
    # slices never exceed CHUNKS_SIZE and still work when there are fewer ids
    # than CHUNKS_SIZE, where np.array_split was asked for 0 sections and raised.
    return [index[start:start + CHUNKS_SIZE] for start in range(0, len(index), CHUNKS_SIZE)]

In [4]:
# Every entry directly inside each train/<planet_id>/ folder.
train_pattern = os.path.join(path_folder + 'train/', '*/*')
files = glob.glob(train_pattern)

# Sorted planet ids, grouped into chunks for processing.
index = get_index(files, CHUNKS_SIZE)

In [5]:
# Switches for the optional calibration steps applied while processing.
DO_MASK = False          # mask hot/dead pixels
DO_THE_NL_CORR = False   # per-pixel polynomial linearity correction
DO_DARK = False          # dark-current subtraction
DO_FLAT = False          # flat-field division
TIME_BINNING = True      # sum groups of consecutive frames to shrink the data

# Columns kept along the 356-wide detector axis, and how many that is.
cut_inf = 39
cut_sup = 321
l = cut_sup - cut_inf

# Axis metadata shipped with the dataset (the integration times are read from it).
axis_info = pd.read_parquet(os.path.join(path_folder,'axis_info.parquet'))

In [6]:
def load_calibration_data_batch(path_folder, index_chunk, cut_inf, cut_sup, dataset):
    """Read the AIRS-CH0 calibration frames of every planet in ``index_chunk``.

    For each planet id the ``flat``, ``dark``, ``dead`` and ``linear_corr``
    parquet tables are read from
    ``<path_folder>/<dataset>/<id>/AIRS-CH0_calibration_0/``, cast to float32,
    reshaped, and cropped to ``cut_inf:cut_sup`` along the last axis.

    Args:
        path_folder: Root folder of the dataset.
        index_chunk: Iterable of planet ids.
        cut_inf: First kept column of the 356-wide axis.
        cut_sup: One past the last kept column.
        dataset: Sub-folder name, e.g. ``'train'`` or ``'test'``.

    Returns:
        ``{planet_id: {'airs_flat', 'airs_dark', 'airs_dead', 'airs_linear'}}``
        where the first three are ``(32, cut_sup - cut_inf)`` arrays and
        ``'airs_linear'`` is ``(6, 32, cut_sup - cut_inf)``.
    """
    # (dictionary key, file name, shape the flat table is reshaped to)
    specs = (
        ('airs_flat', 'flat.parquet', (32, 356)),
        ('airs_dark', 'dark.parquet', (32, 356)),
        ('airs_dead', 'dead.parquet', (32, 356)),
        ('airs_linear', 'linear_corr.parquet', (6, 32, 356)),
    )

    calibration_data = {}
    for idx in index_chunk:
        per_planet = {}
        for key, file_name, shape in specs:
            table = pd.read_parquet(
                os.path.join(path_folder, f'{dataset}/{idx}/AIRS-CH0_calibration_0/{file_name}')
            )
            frame = table.values.astype(np.float32).reshape(shape)
            # Ellipsis crops the last axis for both the 2-D and the 3-D frames.
            per_planet[key] = frame[..., cut_inf:cut_sup]
        calibration_data[idx] = per_planet

    return calibration_data

In [7]:
def process_single_observation(args):
    """Load one training observation's AIRS-CH0 signal and apply the optional cleaning steps.

    Args:
        args: Tuple ``(i, index_chunk, path_folder, cut_inf, cut_sup, l,
            axis_info, calibration_data, DO_MASK, DO_THE_NL_CORR, DO_DARK)``.
            ``i`` is the position inside ``index_chunk`` of the planet to
            process, ``calibration_data`` is the dictionary built by
            ``load_calibration_data_batch``, and the three ``DO_*`` flags
            enable the corresponding step. ``l`` is accepted for call-site
            compatibility and is not used here.

    Returns:
        ``(i, airs_result)`` where ``airs_result`` is the ADC-converted signal
        cropped to ``cut_inf:cut_sup`` along the last axis, after any enabled
        steps (a masked array when ``DO_MASK`` is set).
    """
    # Unpack the arguments
    (i, index_chunk, path_folder, cut_inf, cut_sup, l, axis_info, calibration_data, DO_MASK, DO_THE_NL_CORR, DO_DARK) = args

    idx = index_chunk[i]

    # Load the AIRS signal: one row per frame, reshaped to (frames, 32, 356).
    df = pd.read_parquet(os.path.join(path_folder, f'train/{idx}/AIRS-CH0_signal_0.parquet'))
    signal = df.values.astype(np.float32).reshape((df.shape[0], 32, 356))

    # 1. ADC Conversion, then keep only the selected detector columns.
    signal = ADC_convert(signal)
    chopped_signal = signal[:, :, cut_inf:cut_sup]
    del signal, df

    # Pre-loaded calibration frames for this planet.
    planet_calibration = calibration_data[idx]
    dark = planet_calibration['airs_dark']
    dead_airs = planet_calibration['airs_dead']
    linear_corr = planet_calibration['airs_linear']

    # 2. Mask Hot/Dead Pixels
    if DO_MASK:
        chopped_signal = mask_hot_dead(chopped_signal, dead_airs, dark)

    # 3. Linearity Correction
    if DO_THE_NL_CORR:
        chopped_signal = apply_linear_corr(linear_corr, chopped_signal)

    # 4. Dark Current Subtraction
    if DO_DARK:
        # Take a private, writable copy: `.values` can be a read-only view
        # (pandas Copy-on-Write), and the in-place edit below must never reach
        # the shared `axis_info` frame used by the other worker threads.
        dt_airs = axis_info['AIRS-CH0-integration_time'].dropna().to_numpy(copy=True)
        # NOTE(review): 0.1 is added to every second integration time --
        # presumably the extra delay of the second read of each pair; confirm.
        dt_airs[1::2] += 0.1
        chopped_signal = clean_dark(chopped_signal, dead_airs, dark, dt_airs)

    return i, chopped_signal

In [8]:
def process_single_observation_test(args):
    """Load one test observation's AIRS-CH0 signal and apply the optional cleaning steps.

    Identical to the training variant except that the signal is read from the
    ``test/`` sub-folder.

    Args:
        args: Tuple ``(i, index_chunk, path_folder, cut_inf, cut_sup, l,
            axis_info, calibration_data, DO_MASK, DO_THE_NL_CORR, DO_DARK)``.
            ``i`` is the position inside ``index_chunk`` of the planet to
            process, ``calibration_data`` is the dictionary built by
            ``load_calibration_data_batch``, and the three ``DO_*`` flags
            enable the corresponding step. ``l`` is accepted for call-site
            compatibility and is not used here.

    Returns:
        ``(i, airs_result)`` where ``airs_result`` is the ADC-converted signal
        cropped to ``cut_inf:cut_sup`` along the last axis, after any enabled
        steps (a masked array when ``DO_MASK`` is set).
    """
    # Unpack the arguments
    (i, index_chunk, path_folder, cut_inf, cut_sup, l, axis_info, calibration_data, DO_MASK, DO_THE_NL_CORR, DO_DARK) = args

    idx = index_chunk[i]

    # Load the AIRS signal: one row per frame, reshaped to (frames, 32, 356).
    df = pd.read_parquet(os.path.join(path_folder, f'test/{idx}/AIRS-CH0_signal_0.parquet'))
    signal = df.values.astype(np.float32).reshape((df.shape[0], 32, 356))

    # 1. ADC Conversion, then keep only the selected detector columns.
    signal = ADC_convert(signal)
    chopped_signal = signal[:, :, cut_inf:cut_sup]
    del signal, df

    # Pre-loaded calibration frames for this planet.
    planet_calibration = calibration_data[idx]
    dark = planet_calibration['airs_dark']
    dead_airs = planet_calibration['airs_dead']
    linear_corr = planet_calibration['airs_linear']

    # 2. Mask Hot/Dead Pixels
    if DO_MASK:
        chopped_signal = mask_hot_dead(chopped_signal, dead_airs, dark)

    # 3. Linearity Correction
    if DO_THE_NL_CORR:
        chopped_signal = apply_linear_corr(linear_corr, chopped_signal)

    # 4. Dark Current Subtraction
    if DO_DARK:
        # Take a private, writable copy: `.values` can be a read-only view
        # (pandas Copy-on-Write), and the in-place edit below must never reach
        # the shared `axis_info` frame used by the other worker threads.
        dt_airs = axis_info['AIRS-CH0-integration_time'].dropna().to_numpy(copy=True)
        # NOTE(review): 0.1 is added to every second integration time --
        # presumably the extra delay of the second read of each pair; confirm.
        dt_airs[1::2] += 0.1
        chopped_signal = clean_dark(chopped_signal, dead_airs, dark, dt_airs)

    return i, chopped_signal

In [9]:
# Process the training set chunk by chunk and write one .npy file per chunk.
for n, index_chunk in enumerate(tqdm(index)):
    # Size every per-chunk structure from the chunk itself rather than from the
    # global CHUNKS_SIZE, so no observation is silently skipped when a chunk
    # holds a different number of planets.
    n_obs = len(index_chunk)

    # Load all calibration data once at the beginning
    calibration_data = load_calibration_data_batch(path_folder, index_chunk, cut_inf, cut_sup, 'train')

    # Pre-allocate the float32 output array.
    # NOTE(review): 11250 frames per observation is hard-coded; the assignment
    # below fails loudly if an observation has a different frame count.
    AIRS_CH0_clean = np.zeros((n_obs, 11250, 32, l), dtype=np.float32)

    # Parallel processing with at most 2 worker threads.
    num_workers = max(1, min(2, n_obs))

    # One argument tuple per observation of the chunk.
    args_list = [
        (i, index_chunk, path_folder, cut_inf, cut_sup, l, axis_info, calibration_data, DO_MASK, DO_THE_NL_CORR, DO_DARK)
        for i in range(n_obs)
    ]

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(process_single_observation, args) for args in args_list]

        # Each result carries its own position `i`, so it is written straight
        # into its slot as soon as it completes; no intermediate list or sort
        # is needed to restore the order.
        for future in tqdm(as_completed(futures), total=n_obs, desc="Processing observations"):
            i, airs_result = future.result()
            AIRS_CH0_clean[i] = airs_result
            del airs_result
    del futures

    # 5. Get Correlated Double Sampling
    AIRS_cds = get_cds(AIRS_CH0_clean)

    del AIRS_CH0_clean

    # 6. (Optional) Time Binning (to reduce space)
    if TIME_BINNING:
        AIRS_cds_binned = bin_obs(AIRS_cds, binning=30)
    else:
        # bin_obs also swaps the last two axes; do the same here so both
        # branches produce the same axis order.
        AIRS_cds = AIRS_cds.transpose(0,1,3,2)
        AIRS_cds_binned = AIRS_cds

    del AIRS_cds

    # 7. Flat Field Correction - use pre-loaded calibration data
    if DO_FLAT:
        for i in range(n_obs):
            flat_airs = calibration_data[index_chunk[i]]['airs_flat']
            dead_airs = calibration_data[index_chunk[i]]['airs_dead']

            corrected_AIRS_cds_binned = correct_flat_field(flat_airs, dead_airs, AIRS_cds_binned[i])
            AIRS_cds_binned[i] = corrected_AIRS_cds_binned

    # Save data
    np.save(os.path.join(path_out, 'AIRS_clean_train_{}.npy'.format(n)), AIRS_cds_binned)

    del AIRS_cds_binned, calibration_data
    gc.collect()

  0%|          | 0/275 [00:00<?, ?it/s]
Processing observations:   0%|          | 0/4 [00:00<?, ?it/s][A
Processing observations:  25%|██▌       | 1/4 [00:03<00:09,  3.13s/it][A
Processing observations:  75%|███████▌  | 3/4 [00:05<00:01,  1.86s/it][A
Processing observations: 100%|██████████| 4/4 [00:06<00:00,  1.52s/it][A
  0%|          | 1/275 [00:09<44:58,  9.85s/it]
Processing observations:   0%|          | 0/4 [00:00<?, ?it/s][A
Processing observations:  25%|██▌       | 1/4 [00:03<00:09,  3.00s/it][A
Processing observations:  50%|█████     | 2/4 [00:03<00:02,  1.35s/it][A
Processing observations:  75%|███████▌  | 3/4 [00:05<00:01,  1.91s/it][A
Processing observations: 100%|██████████| 4/4 [00:05<00:00,  1.49s/it][A
  1%|          | 2/275 [00:19<44:27,  9.77s/it]
Processing observations:   0%|          | 0/4 [00:00<?, ?it/s][A
Processing observations:  25%|██▌       | 1/4 [00:02<00:08,  2.89s/it][A
Processing observations:  50%|█████     | 2/4 [00:02<00:02,  1.25s/it][A


In [10]:
# NOTE: this cell rebinds CHUNKS_SIZE. Re-running it without re-running the
# configuration cell first would also set TRAIN_CHUNKS_SIZE to 1.
TRAIN_CHUNKS_SIZE = CHUNKS_SIZE
CHUNKS_SIZE = 1
test_files = glob.glob(os.path.join(path_folder + 'test/', '*/*'))
test_index = get_index(test_files, CHUNKS_SIZE)

# Process the test set chunk by chunk and write one .npy file per chunk.
for n, index_chunk in enumerate(tqdm(test_index)):
    # Size every per-chunk structure from the chunk itself rather than from the
    # global CHUNKS_SIZE, so no observation is silently skipped when a chunk
    # holds a different number of planets.
    n_obs = len(index_chunk)

    # Load all calibration data once at the beginning
    calibration_data = load_calibration_data_batch(path_folder, index_chunk, cut_inf, cut_sup, 'test')

    # Pre-allocate the float32 output array.
    # NOTE(review): 11250 frames per observation is hard-coded; the assignment
    # below fails loudly if an observation has a different frame count.
    AIRS_CH0_clean = np.zeros((n_obs, 11250, 32, l), dtype=np.float32)

    # Parallel processing with at most 2 worker threads.
    num_workers = max(1, min(2, n_obs))

    # One argument tuple per observation of the chunk.
    args_list = [
        (i, index_chunk, path_folder, cut_inf, cut_sup, l, axis_info, calibration_data, DO_MASK, DO_THE_NL_CORR, DO_DARK)
        for i in range(n_obs)
    ]

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(process_single_observation_test, args) for args in args_list]

        # Each result carries its own position `i`, so it is written straight
        # into its slot as soon as it completes; no intermediate list or sort
        # is needed to restore the order.
        for future in tqdm(as_completed(futures), total=n_obs, desc="Processing observations"):
            i, airs_result = future.result()
            AIRS_CH0_clean[i] = airs_result
            del airs_result
    del futures

    # 5. Get Correlated Double Sampling
    AIRS_cds = get_cds(AIRS_CH0_clean)

    del AIRS_CH0_clean

    # 6. (Optional) Time Binning (to reduce space)
    if TIME_BINNING:
        AIRS_cds_binned = bin_obs(AIRS_cds, binning=30)
    else:
        # bin_obs also swaps the last two axes; do the same here so both
        # branches produce the same axis order.
        AIRS_cds = AIRS_cds.transpose(0,1,3,2)
        AIRS_cds_binned = AIRS_cds

    del AIRS_cds

    # 7. Flat Field Correction - use pre-loaded calibration data
    if DO_FLAT:
        for i in range(n_obs):
            flat_airs = calibration_data[index_chunk[i]]['airs_flat']
            dead_airs = calibration_data[index_chunk[i]]['airs_dead']

            corrected_AIRS_cds_binned = correct_flat_field(flat_airs, dead_airs, AIRS_cds_binned[i])
            AIRS_cds_binned[i] = corrected_AIRS_cds_binned

    # Save data
    np.save(os.path.join(path_out, 'AIRS_clean_test_{}.npy'.format(n)), AIRS_cds_binned)

    # Release the chunk's arrays before the next iteration, as the training loop does.
    del AIRS_cds_binned, calibration_data
    gc.collect()

  0%|          | 0/1 [00:00<?, ?it/s]
Processing observations:   0%|          | 0/1 [00:00<?, ?it/s][A
Processing observations: 100%|██████████| 1/1 [00:02<00:00,  2.88s/it][A
100%|██████████| 1/1 [00:03<00:00,  3.82s/it]


In [11]:
def load_data(file, chunk_size, nb_files): 
    """Reassemble chunked ``.npy`` files into a single array.

    Reads ``file + '_0.npy'`` through ``file + '_{nb_files - 1}.npy'`` in order
    and stacks their rows along the first axis.

    Args:
        file: Path prefix shared by all chunk files (without the ``_<i>.npy`` suffix).
        chunk_size: Maximum number of rows stored in one chunk file.
        nb_files: Number of chunk files to read.

    Returns:
        A float64 array holding the rows of every chunk, in file order. Its
        first dimension is the total number of rows actually read, so a chunk
        that is shorter than ``chunk_size`` (typically the last one) is
        supported and leaves no padding rows behind.

    Raises:
        ValueError: If a chunk file contains more than ``chunk_size`` rows.
    """
    data0 = np.load(file + '_0.npy')
    # Allocate once for the largest possible result and copy the chunks in one
    # at a time; the trailing dimensions are taken from the first chunk.
    data_all = np.zeros((nb_files * chunk_size,) + data0.shape[1:])
    n_rows = 0
    for i in range(nb_files):
        chunk = data0 if i == 0 else np.load(file + '_{}.npy'.format(i))
        if len(chunk) > chunk_size:
            raise ValueError(
                "chunk {} holds {} rows, more than chunk_size={}".format(i, len(chunk), chunk_size)
            )
        # Write at the running offset (not at i * chunk_size) so that short
        # chunks are neither broadcast nor followed by gaps.
        data_all[n_rows:n_rows + len(chunk)] = chunk
        n_rows += len(chunk)
    # Trim the unused tail; this is the whole array when every chunk is full.
    return data_all[:n_rows]

# Stitch the per-chunk training arrays saved under `path_out` back together
data_train_AIRS = load_data(
    f'{path_out}AIRS_clean_train',
    chunk_size=TRAIN_CHUNKS_SIZE,
    nb_files=len(index),
)

print(data_train_AIRS.shape)


(1100, 187, 282, 32)


In [12]:
# Load every test chunk file rather than only the first one: the submission
# cell flattens all of `test_index`, so predictions are needed for each chunk.
# NOTE(review): assumes the preprocessing loop wrote one
# 'AIRS_clean_test_<n>.npy' file per entry of `test_index` - confirm.
n_test_files = len(test_index) if isinstance(test_index, list) else 1
data_test_AIRS = load_data(path_out + 'AIRS_clean_test', CHUNKS_SIZE, n_test_files) 

print(data_test_AIRS.shape)

(1, 187, 282, 32)


In [13]:
# Ground-truth table, one row per planet, indexed by planet id
df_train = pd.read_csv(path_folder + 'train.csv', index_col='planet_id')

# Select the labels in the order of `planet_ids` instead of filtering with
# `isin`, which keeps the CSV order and silently drops unknown ids. `.loc`
# with a list of labels returns the rows in the requested order and raises a
# KeyError if an id is missing from train.csv.
# NOTE(review): presumably `index` lists the planets in the order their
# arrays were written to the chunk files - confirm against the preprocessing.
planet_ids = np.concatenate(index)
df_train = df_train.loc[planet_ids]

print(df_train.shape)
df_train.head()

(1100, 283)


Unnamed: 0_level_0,wl_1,wl_2,wl_3,wl_4,wl_5,wl_6,wl_7,wl_8,wl_9,wl_10,wl_11,wl_12,wl_13,wl_14,wl_15,wl_16,wl_17,wl_18,wl_19,wl_20,wl_21,wl_22,wl_23,wl_24,wl_25,wl_26,wl_27,wl_28,wl_29,wl_30,wl_31,wl_32,wl_33,wl_34,wl_35,wl_36,wl_37,wl_38,wl_39,wl_40,wl_41,wl_42,wl_43,wl_44,wl_45,wl_46,wl_47,wl_48,wl_49,wl_50,wl_51,wl_52,wl_53,wl_54,wl_55,wl_56,wl_57,wl_58,wl_59,wl_60,wl_61,wl_62,wl_63,wl_64,wl_65,wl_66,wl_67,wl_68,wl_69,wl_70,wl_71,wl_72,wl_73,wl_74,wl_75,wl_76,wl_77,wl_78,wl_79,wl_80,wl_81,wl_82,wl_83,wl_84,wl_85,wl_86,wl_87,wl_88,wl_89,wl_90,wl_91,wl_92,wl_93,wl_94,wl_95,wl_96,wl_97,wl_98,wl_99,wl_100,wl_101,wl_102,wl_103,wl_104,wl_105,wl_106,wl_107,wl_108,wl_109,wl_110,wl_111,wl_112,wl_113,wl_114,wl_115,wl_116,wl_117,wl_118,wl_119,wl_120,wl_121,wl_122,wl_123,wl_124,wl_125,wl_126,wl_127,wl_128,wl_129,wl_130,wl_131,wl_132,wl_133,wl_134,wl_135,wl_136,wl_137,wl_138,wl_139,wl_140,wl_141,wl_142,wl_143,wl_144,wl_145,wl_146,wl_147,wl_148,wl_149,wl_150,wl_151,wl_152,wl_153,wl_154,wl_155,wl_156,wl_157,wl_158,wl_159,wl_160,wl_161,wl_162,wl_163,wl_164,wl_165,wl_166,wl_167,wl_168,wl_169,wl_170,wl_171,wl_172,wl_173,wl_174,wl_175,wl_176,wl_177,wl_178,wl_179,wl_180,wl_181,wl_182,wl_183,wl_184,wl_185,wl_186,wl_187,wl_188,wl_189,wl_190,wl_191,wl_192,wl_193,wl_194,wl_195,wl_196,wl_197,wl_198,wl_199,wl_200,wl_201,wl_202,wl_203,wl_204,wl_205,wl_206,wl_207,wl_208,wl_209,wl_210,wl_211,wl_212,wl_213,wl_214,wl_215,wl_216,wl_217,wl_218,wl_219,wl_220,wl_221,wl_222,wl_223,wl_224,wl_225,wl_226,wl_227,wl_228,wl_229,wl_230,wl_231,wl_232,wl_233,wl_234,wl_235,wl_236,wl_237,wl_238,wl_239,wl_240,wl_241,wl_242,wl_243,wl_244,wl_245,wl_246,wl_247,wl_248,wl_249,wl_250,wl_251,wl_252,wl_253,wl_254,wl_255,wl_256,wl_257,wl_258,wl_259,wl_260,wl_261,wl_262,wl_263,wl_264,wl_265,wl_266,wl_267,wl_268,wl_269,wl_270,wl_271,wl_272,wl_273,wl_274,wl_275,wl_276,wl_277,wl_278,wl_279,wl_280,wl_281,wl_282,wl_283
planet_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1,Unnamed: 50_level_1,Unnamed: 51_level_1,Unnamed: 52_level_1,Unnamed: 53_level_1,Unnamed: 54_level_1,Unnamed: 55_level_1,Unnamed: 56_level_1,Unnamed: 57_level_1,Unnamed: 58_level_1,Unnamed: 59_level_1,Unnamed: 60_level_1,Unnamed: 61_level_1,Unnamed: 62_level_1,Unnamed: 63_level_1,Unnamed: 64_level_1,Unnamed: 65_level_1,Unnamed: 66_level_1,Unnamed: 67_level_1,Unnamed: 68_level_1,Unnamed: 69_level_1,Unnamed: 70_level_1,Unnamed: 71_level_1,Unnamed: 72_level_1,Unnamed: 73_level_1,Unnamed: 74_level_1,Unnamed: 75_level_1,Unnamed: 76_level_1,Unnamed: 77_level_1,Unnamed: 78_level_1,Unnamed: 79_level_1,Unnamed: 80_level_1,Unnamed: 81_level_1,Unnamed: 82_level_1,Unnamed: 83_level_1,Unnamed: 84_level_1,Unnamed: 85_level_1,Unnamed: 86_level_1,Unnamed: 87_level_1,Unnamed: 88_level_1,Unnamed: 89_level_1,Unnamed: 90_level_1,Unnamed: 91_level_1,Unnamed: 92_level_1,Unnamed: 93_level_1,Unnamed: 94_level_1,Unnamed: 95_level_1,Unnamed: 96_level_1,Unnamed: 97_level_1,Unnamed: 98_level_1,Unnamed: 99_level_1,Unnamed: 
100_level_1,Unnamed: 101_level_1,Unnamed: 102_level_1,Unnamed: 103_level_1,Unnamed: 104_level_1,Unnamed: 105_level_1,Unnamed: 106_level_1,Unnamed: 107_level_1,Unnamed: 108_level_1,Unnamed: 109_level_1,Unnamed: 110_level_1,Unnamed: 111_level_1,Unnamed: 112_level_1,Unnamed: 113_level_1,Unnamed: 114_level_1,Unnamed: 115_level_1,Unnamed: 116_level_1,Unnamed: 117_level_1,Unnamed: 118_level_1,Unnamed: 119_level_1,Unnamed: 120_level_1,Unnamed: 121_level_1,Unnamed: 122_level_1,Unnamed: 123_level_1,Unnamed: 124_level_1,Unnamed: 125_level_1,Unnamed: 126_level_1,Unnamed: 127_level_1,Unnamed: 128_level_1,Unnamed: 129_level_1,Unnamed: 130_level_1,Unnamed: 131_level_1,Unnamed: 132_level_1,Unnamed: 133_level_1,Unnamed: 134_level_1,Unnamed: 135_level_1,Unnamed: 136_level_1,Unnamed: 137_level_1,Unnamed: 138_level_1,Unnamed: 139_level_1,Unnamed: 140_level_1,Unnamed: 141_level_1,Unnamed: 142_level_1,Unnamed: 143_level_1,Unnamed: 144_level_1,Unnamed: 145_level_1,Unnamed: 146_level_1,Unnamed: 147_level_1,Unnamed: 148_level_1,Unnamed: 149_level_1,Unnamed: 150_level_1,Unnamed: 151_level_1,Unnamed: 152_level_1,Unnamed: 153_level_1,Unnamed: 154_level_1,Unnamed: 155_level_1,Unnamed: 156_level_1,Unnamed: 157_level_1,Unnamed: 158_level_1,Unnamed: 159_level_1,Unnamed: 160_level_1,Unnamed: 161_level_1,Unnamed: 162_level_1,Unnamed: 163_level_1,Unnamed: 164_level_1,Unnamed: 165_level_1,Unnamed: 166_level_1,Unnamed: 167_level_1,Unnamed: 168_level_1,Unnamed: 169_level_1,Unnamed: 170_level_1,Unnamed: 171_level_1,Unnamed: 172_level_1,Unnamed: 173_level_1,Unnamed: 174_level_1,Unnamed: 175_level_1,Unnamed: 176_level_1,Unnamed: 177_level_1,Unnamed: 178_level_1,Unnamed: 179_level_1,Unnamed: 180_level_1,Unnamed: 181_level_1,Unnamed: 182_level_1,Unnamed: 183_level_1,Unnamed: 184_level_1,Unnamed: 185_level_1,Unnamed: 186_level_1,Unnamed: 187_level_1,Unnamed: 188_level_1,Unnamed: 189_level_1,Unnamed: 190_level_1,Unnamed: 191_level_1,Unnamed: 192_level_1,Unnamed: 193_level_1,Unnamed: 194_level_1,Unnamed: 
195_level_1,Unnamed: 196_level_1,Unnamed: 197_level_1,Unnamed: 198_level_1,Unnamed: 199_level_1,Unnamed: 200_level_1,Unnamed: 201_level_1,Unnamed: 202_level_1,Unnamed: 203_level_1,Unnamed: 204_level_1,Unnamed: 205_level_1,Unnamed: 206_level_1,Unnamed: 207_level_1,Unnamed: 208_level_1,Unnamed: 209_level_1,Unnamed: 210_level_1,Unnamed: 211_level_1,Unnamed: 212_level_1,Unnamed: 213_level_1,Unnamed: 214_level_1,Unnamed: 215_level_1,Unnamed: 216_level_1,Unnamed: 217_level_1,Unnamed: 218_level_1,Unnamed: 219_level_1,Unnamed: 220_level_1,Unnamed: 221_level_1,Unnamed: 222_level_1,Unnamed: 223_level_1,Unnamed: 224_level_1,Unnamed: 225_level_1,Unnamed: 226_level_1,Unnamed: 227_level_1,Unnamed: 228_level_1,Unnamed: 229_level_1,Unnamed: 230_level_1,Unnamed: 231_level_1,Unnamed: 232_level_1,Unnamed: 233_level_1,Unnamed: 234_level_1,Unnamed: 235_level_1,Unnamed: 236_level_1,Unnamed: 237_level_1,Unnamed: 238_level_1,Unnamed: 239_level_1,Unnamed: 240_level_1,Unnamed: 241_level_1,Unnamed: 242_level_1,Unnamed: 243_level_1,Unnamed: 244_level_1,Unnamed: 245_level_1,Unnamed: 246_level_1,Unnamed: 247_level_1,Unnamed: 248_level_1,Unnamed: 249_level_1,Unnamed: 250_level_1,Unnamed: 251_level_1,Unnamed: 252_level_1,Unnamed: 253_level_1,Unnamed: 254_level_1,Unnamed: 255_level_1,Unnamed: 256_level_1,Unnamed: 257_level_1,Unnamed: 258_level_1,Unnamed: 259_level_1,Unnamed: 260_level_1,Unnamed: 261_level_1,Unnamed: 262_level_1,Unnamed: 263_level_1,Unnamed: 264_level_1,Unnamed: 265_level_1,Unnamed: 266_level_1,Unnamed: 267_level_1,Unnamed: 268_level_1,Unnamed: 269_level_1,Unnamed: 270_level_1,Unnamed: 271_level_1,Unnamed: 272_level_1,Unnamed: 273_level_1,Unnamed: 274_level_1,Unnamed: 275_level_1,Unnamed: 276_level_1,Unnamed: 277_level_1,Unnamed: 278_level_1,Unnamed: 279_level_1,Unnamed: 280_level_1,Unnamed: 281_level_1,Unnamed: 282_level_1,Unnamed: 283_level_1
34983,0.018291,0.018088,0.018087,0.018085,0.018084,0.018084,0.018084,0.018084,0.018085,0.018084,0.018083,0.01808,0.018076,0.018072,0.018068,0.018063,0.01806,0.01806,0.018062,0.018069,0.018079,0.018091,0.018103,0.018111,0.018115,0.018117,0.018119,0.018123,0.018129,0.018136,0.018139,0.018137,0.018133,0.018131,0.018131,0.018132,0.018133,0.018134,0.018135,0.018136,0.018137,0.018136,0.018134,0.018133,0.018132,0.018132,0.018131,0.018129,0.018127,0.018126,0.018126,0.018126,0.018124,0.018119,0.018115,0.018111,0.018108,0.018106,0.018104,0.018102,0.0181,0.018098,0.018096,0.018095,0.018095,0.018095,0.018095,0.018094,0.018092,0.01809,0.018089,0.018088,0.018087,0.018084,0.018082,0.01808,0.018079,0.018079,0.01808,0.01808,0.018079,0.018076,0.018074,0.018073,0.018076,0.018082,0.018093,0.018105,0.018116,0.018124,0.01813,0.018134,0.018136,0.018136,0.018136,0.018137,0.018137,0.018137,0.018138,0.01814,0.018144,0.018147,0.018148,0.018148,0.01815,0.018152,0.018154,0.018156,0.018157,0.018157,0.018158,0.018158,0.018158,0.018157,0.018158,0.018159,0.01816,0.018161,0.018162,0.018163,0.018164,0.018167,0.018169,0.018171,0.018171,0.018171,0.01817,0.01817,0.018169,0.018168,0.018167,0.018167,0.018169,0.018173,0.018174,0.018171,0.018167,0.018165,0.018164,0.018164,0.018164,0.018163,0.018162,0.018162,0.018161,0.01816,0.018159,0.01816,0.018161,0.018163,0.018165,0.018166,0.018168,0.018171,0.018173,0.018174,0.018176,0.018178,0.018178,0.018175,0.018173,0.018171,0.01817,0.018169,0.018168,0.018166,0.018165,0.018165,0.018163,0.018159,0.018155,0.018154,0.018152,0.018148,0.018144,0.018142,0.018142,0.018146,0.018155,0.018172,0.018192,0.018207,0.018211,0.018208,0.0182,0.01819,0.018179,0.018169,0.018161,0.018156,0.018152,0.01815,0.018147,0.018147,0.01815,0.018154,0.018156,0.018157,0.018157,0.018157,0.018156,0.018155,0.018155,0.018154,0.018153,0.018153,0.018154,0.018155,0.018155,0.018153,0.01815,0.018148,0.018148,0.018148,0.018148,0.018148,0.018148,0.018146,0.018145,0.018145,0.018144,0.018144,0.018143,0.018139,0.
018136,0.018135,0.018133,0.018129,0.018125,0.018123,0.018121,0.01812,0.01812,0.018119,0.018117,0.018115,0.018113,0.018112,0.01811,0.018108,0.018106,0.018104,0.018103,0.018101,0.0181,0.018101,0.0181,0.018098,0.018098,0.018097,0.018095,0.018094,0.018094,0.018094,0.018095,0.018096,0.018097,0.018098,0.018097,0.018097,0.018096,0.018096,0.018096,0.018095,0.018096,0.018096,0.018096,0.018097,0.018097,0.018098,0.018099,0.018101,0.018106,0.018109,0.018112,0.018118,0.018123,0.018125,0.018127,0.01813,0.018134,0.018138,0.018142
1873185,0.006347,0.006343,0.006343,0.006343,0.006343,0.006343,0.006343,0.006342,0.006342,0.006341,0.006341,0.00634,0.00634,0.00634,0.00634,0.006341,0.006342,0.006342,0.006342,0.006341,0.006341,0.00634,0.006341,0.006341,0.006341,0.006341,0.006341,0.006341,0.006341,0.006341,0.006341,0.006341,0.006342,0.006342,0.006342,0.006342,0.006342,0.006342,0.006342,0.006341,0.006341,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.006339,0.006339,0.006339,0.006339,0.00634,0.006341,0.006342,0.006343,0.006343,0.006342,0.006342,0.006342,0.006341,0.006341,0.006341,0.006342,0.006342,0.006342,0.006342,0.006342,0.006342,0.006341,0.006341,0.00634,0.00634,0.00634,0.006341,0.006341,0.00634,0.00634,0.00634,0.006339,0.00634,0.00634,0.00634,0.00634,0.006341,0.006341,0.006341,0.006341,0.00634,0.00634,0.00634,0.006339,0.006339,0.006339,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.006339,0.006339,0.00634,0.00634,0.00634,0.00634,0.00634,0.006341,0.006342,0.006342,0.006342,0.006342,0.006342,0.006342,0.006342,0.006342,0.006342,0.006341,0.006341,0.006341,0.006341,0.006341,0.006341,0.006341,0.006341,0.006341,0.006341,0.006341,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.006339,0.00634,0.00634,0.006341,0.006341,0.006341,0.006341,0.006341,0.006341,0.006341,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.006341,0.00634,0.00634,0.00634,0.006339,0.006339,0.006339,0.006339,0.006339,0.006339,0.006339,0.006339,0.006339,0.006339,0.006339,0.00634,0.00634,0.00634,0.00634,0.006339,0.006339,0.006339,0.006338,0.006338,0.006338,0.006338,0.006339,0.006338,0.006338,0.006339,0.006338,0.006338,0.006338,0.006338,0.006338,0.006338,0.006338,0.006338,0.006338,0.006338,0.006338,0.006338,0.006338,0.006338,0.006338,0.006338,0.006338,0.006338,0.006337,0.006337,0.006337,0.006337,0.006337,0.006338,0.006338,0.006338,0.006338,0.006338,0.006338,0.006337,0.006337,0.006338,0.006338,0.006337,0.006337,0.006338,0.006338,0.006338,0.006338,0.00633
8,0.006337,0.006337,0.006337,0.006337,0.006337,0.006337,0.006337,0.006338,0.006339,0.006341,0.006341,0.006341,0.006341,0.006341,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.00634,0.006339,0.006339,0.006339,0.006339,0.006339,0.006339,0.006339,0.006339,0.006339,0.006339,0.006338,0.006338,0.006338,0.006338,0.006339,0.006341,0.006341,0.006341,0.006341,0.006341,0.00634,0.00634,0.00634,0.00634,0.00634,0.006339,0.006339,0.006339,0.006339,0.006339,0.006339,0.006339,0.006339
3849793,0.046061,0.046139,0.04613,0.046117,0.046107,0.046105,0.046109,0.046112,0.046111,0.046104,0.046095,0.046088,0.046082,0.046078,0.046075,0.046076,0.046078,0.046079,0.04608,0.046081,0.046083,0.046086,0.04609,0.046095,0.046103,0.046111,0.046121,0.046137,0.046164,0.046198,0.046225,0.046234,0.046228,0.046225,0.046235,0.046254,0.046269,0.046278,0.046288,0.0463,0.04631,0.04631,0.046302,0.046296,0.046302,0.046321,0.046341,0.046352,0.046353,0.046354,0.046357,0.046353,0.046329,0.046294,0.046267,0.046255,0.046251,0.046247,0.046244,0.046241,0.046234,0.04622,0.046202,0.046187,0.046182,0.046188,0.046198,0.0462,0.046191,0.046174,0.04616,0.046154,0.04615,0.046145,0.046146,0.046155,0.046164,0.046163,0.046156,0.046157,0.046168,0.046177,0.046176,0.046178,0.046194,0.046219,0.046241,0.046252,0.046261,0.046289,0.046331,0.046362,0.046359,0.046333,0.046312,0.046302,0.04629,0.046273,0.046268,0.046289,0.046324,0.046339,0.046318,0.04628,0.046257,0.046266,0.046299,0.046327,0.046321,0.046293,0.046277,0.046278,0.046271,0.046247,0.046222,0.04621,0.04621,0.046211,0.046201,0.04618,0.046156,0.046136,0.046124,0.04612,0.04612,0.046117,0.046116,0.046127,0.04614,0.04614,0.046131,0.046129,0.046136,0.046141,0.046136,0.046127,0.046122,0.046127,0.046138,0.046151,0.046166,0.046183,0.046199,0.046211,0.046216,0.046221,0.046235,0.046262,0.046291,0.046315,0.046339,0.046373,0.046418,0.046461,0.046486,0.046492,0.046496,0.046506,0.046509,0.046497,0.046483,0.046481,0.046497,0.046515,0.046515,0.046505,0.046498,0.046487,0.046468,0.046455,0.046456,0.046455,0.046439,0.046421,0.046406,0.046392,0.04638,0.046392,0.046447,0.046545,0.046651,0.04672,0.046737,0.046717,0.046675,0.046622,0.046565,0.046517,0.046488,0.046471,0.046453,0.046434,0.046419,0.046413,0.046423,0.046442,0.046454,0.046461,0.046465,0.046463,0.046458,0.046452,0.046449,0.046442,0.046432,0.046437,0.046451,0.046456,0.046456,0.046445,0.046427,0.046417,0.046411,0.04641,0.046415,0.046423,0.046432,0.04643,0.04641,0.046394,0.046383,0.046364,0.046355,0.046357,0.
046356,0.046351,0.046341,0.046326,0.046316,0.046309,0.0463,0.046291,0.046287,0.046282,0.046273,0.04626,0.046251,0.046252,0.046253,0.04624,0.046224,0.046225,0.046232,0.046225,0.046219,0.046217,0.046203,0.046193,0.046196,0.046189,0.046175,0.046169,0.046168,0.046169,0.046162,0.046156,0.046166,0.046179,0.046182,0.046171,0.046156,0.046153,0.04616,0.046158,0.04615,0.046148,0.046147,0.046142,0.046136,0.046131,0.046125,0.04613,0.046143,0.046144,0.046133,0.046131,0.046138,0.046141,0.046147,0.046147,0.046139,0.046134,0.046133
8456603,0.015363,0.015387,0.015385,0.015385,0.015385,0.015385,0.015384,0.015383,0.015383,0.015384,0.015385,0.015385,0.015386,0.015387,0.015388,0.01539,0.015393,0.015396,0.015399,0.015403,0.015408,0.015412,0.015417,0.015421,0.015426,0.015433,0.01544,0.01545,0.015462,0.015475,0.015485,0.01549,0.015491,0.015492,0.015496,0.015502,0.015508,0.015511,0.015512,0.015514,0.015515,0.015516,0.015516,0.015517,0.015519,0.015521,0.015525,0.015527,0.015528,0.015528,0.015526,0.015525,0.015522,0.015518,0.015515,0.015511,0.015508,0.015504,0.0155,0.015496,0.015493,0.01549,0.015487,0.015484,0.01548,0.015477,0.015473,0.015468,0.015463,0.015459,0.015457,0.015455,0.015453,0.015451,0.015449,0.015449,0.01545,0.01545,0.015448,0.015445,0.015444,0.015443,0.015443,0.015442,0.015441,0.01544,0.015437,0.015434,0.015431,0.015428,0.015425,0.015422,0.015418,0.015414,0.015412,0.01541,0.015408,0.015406,0.015404,0.015404,0.015407,0.01541,0.015412,0.015411,0.015409,0.015408,0.015409,0.015411,0.01541,0.015407,0.015403,0.015399,0.015397,0.015397,0.015397,0.015398,0.015398,0.015398,0.015399,0.015402,0.015405,0.015407,0.015406,0.015404,0.015402,0.015402,0.015405,0.015409,0.015413,0.015416,0.015418,0.01542,0.015425,0.015431,0.015435,0.015439,0.015444,0.015451,0.015459,0.015467,0.015477,0.01549,0.015502,0.015512,0.015517,0.015521,0.015526,0.015534,0.015545,0.015554,0.015558,0.015559,0.015562,0.015567,0.015571,0.015572,0.015575,0.015579,0.015583,0.015584,0.015585,0.015586,0.015588,0.01559,0.015589,0.015585,0.015581,0.015578,0.015576,0.015575,0.015575,0.015573,0.01557,0.015565,0.01556,0.015556,0.015554,0.015557,0.015569,0.015592,0.015615,0.015627,0.015631,0.01563,0.015627,0.015619,0.015609,0.0156,0.015592,0.015585,0.015579,0.015574,0.01557,0.015567,0.015567,0.015569,0.01557,0.015572,0.015574,0.015574,0.015573,0.015573,0.015574,0.015571,0.015569,0.01557,0.015573,0.015574,0.015576,0.015576,0.015576,0.015574,0.015574,0.015574,0.015572,0.015567,0.015565,0.015566,0.015566,0.015565,0.015562,0.015559,0.015555,0.015551,0
.01555,0.01555,0.01555,0.015549,0.015547,0.015545,0.015543,0.015541,0.015539,0.015537,0.015535,0.015534,0.015532,0.01553,0.015529,0.015526,0.015521,0.015518,0.015516,0.015513,0.01551,0.015508,0.015508,0.015506,0.015505,0.015505,0.015504,0.015501,0.015497,0.015496,0.015494,0.015493,0.015493,0.015494,0.015491,0.015486,0.015484,0.015484,0.015482,0.01548,0.015482,0.015481,0.015478,0.015477,0.015475,0.015472,0.01547,0.015471,0.01547,0.015471,0.015471,0.015467,0.015465,0.015465,0.015464,0.015461,0.01546,0.01546,0.01546
23615382,0.014474,0.014636,0.014628,0.014635,0.014643,0.014642,0.014637,0.014635,0.014639,0.014644,0.014643,0.014635,0.01462,0.014597,0.014572,0.014549,0.014534,0.014523,0.014506,0.014482,0.014456,0.014436,0.01442,0.014407,0.014392,0.014374,0.014353,0.014335,0.014329,0.014335,0.014341,0.014338,0.014331,0.014326,0.014324,0.014316,0.014298,0.014281,0.014284,0.014313,0.014356,0.014387,0.014402,0.01441,0.014417,0.014415,0.014398,0.014381,0.014383,0.014399,0.014416,0.014437,0.014474,0.01452,0.014558,0.014578,0.014591,0.014607,0.014622,0.014625,0.014618,0.014618,0.014637,0.014672,0.014714,0.014755,0.014785,0.014793,0.014787,0.014791,0.014817,0.014839,0.014836,0.014823,0.014827,0.014848,0.014864,0.014864,0.014857,0.014853,0.014854,0.01485,0.014839,0.014832,0.014844,0.014872,0.014902,0.014917,0.014908,0.014883,0.014865,0.014863,0.014863,0.014852,0.014846,0.014864,0.014893,0.014908,0.014902,0.014898,0.014919,0.014961,0.014993,0.014995,0.014978,0.014967,0.014976,0.014993,0.014999,0.014981,0.014948,0.014917,0.014896,0.014882,0.014875,0.014877,0.014881,0.01488,0.014887,0.01492,0.014965,0.014984,0.01496,0.014909,0.014868,0.014859,0.014876,0.014901,0.014917,0.014917,0.014907,0.01491,0.014939,0.014961,0.014946,0.014914,0.014883,0.01484,0.014798,0.014783,0.014783,0.014767,0.014755,0.014771,0.014794,0.014794,0.014778,0.014756,0.014727,0.014699,0.014693,0.014728,0.014781,0.014803,0.014788,0.014774,0.014778,0.014777,0.01476,0.014735,0.014719,0.014724,0.014735,0.014725,0.014705,0.014699,0.01469,0.01467,0.014663,0.014663,0.014647,0.014635,0.014646,0.014661,0.014654,0.014634,0.014614,0.014602,0.0146,0.014603,0.014599,0.014593,0.014582,0.014556,0.014533,0.014538,0.01456,0.014572,0.014562,0.01454,0.014521,0.014516,0.01453,0.014558,0.014588,0.014624,0.014649,0.014635,0.014599,0.01458,0.014578,0.014585,0.014601,0.01461,0.014594,0.014558,0.014547,0.014574,0.014581,0.014557,0.014543,0.014549,0.014554,0.01457,0.014594,0.014597,0.014593,0.01458,0.014536,0.014508,0.014526,0.014538,0.01453,0.01454
2,0.01456,0.014542,0.014512,0.014492,0.014469,0.014466,0.014482,0.014485,0.01449,0.014498,0.014492,0.014494,0.014504,0.0145,0.014505,0.01454,0.014566,0.014553,0.014532,0.014514,0.0145,0.014511,0.01452,0.014493,0.014471,0.014477,0.014469,0.014454,0.014471,0.014499,0.014491,0.014466,0.014465,0.014473,0.014473,0.014464,0.014455,0.014458,0.014446,0.014426,0.014442,0.014472,0.014485,0.014476,0.014473,0.014457,0.014413,0.014397,0.014428,0.014473,0.014467,0.014433,0.014426,0.014435,0.014422,0.014399,0.014429,0.014444,0.014418


## Split into train and validation sets

In [14]:
# Simple 80/20 split to start - no shuffling
# Features and labels are cut at the same row position, so they must have the
# same number of rows for each observation to stay paired with its label.
assert len(data_train_AIRS) == len(df_train), (
    f"feature/label row mismatch: {len(data_train_AIRS)} vs {len(df_train)}"
)
n = round(.8 * len(df_train))

train_AIRS, val_AIRS = data_train_AIRS[:n], data_train_AIRS[n:]
print(len(train_AIRS), len(val_AIRS))

train_labels, val_labels = df_train.iloc[:n], df_train.iloc[n:]
print(train_labels.shape, val_labels.shape)

# Drop the name of the full array; the two slices above still reference its data
del data_train_AIRS

880 220
(880, 283) (220, 283)


## Define the model

In [15]:
# Take the per-sample input shape from the training array instead of
# hard-coding it: the arrays loaded above print as (N, 187, 282, 32), which a
# fixed (375, 282, 32) input would reject at fit/predict time.
inputs = Input(shape=train_AIRS.shape[1:], name='inputs')

# Convolutional feature extractor followed by a small dense bottleneck
x = Conv2D(32, (3, 3), activation='relu')(inputs)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(64, (3, 3), activation='relu')(x)
x = MaxPooling2D((2, 2))(x)
x = Conv2D(64, (3, 3), activation='relu')(x)
x = Flatten()(x)
x = Dense(64, activation='relu')(x)

# Two output heads: the predicted mean, and an unconstrained value that is
# mapped to a strictly positive standard deviation.
mean_output = Dense(283, activation='linear', name='mean')(x)
log_std_output = Dense(283, activation='linear', name='log_std')(x)
# The factor 0.5 means the 'log_std' head is effectively treated as a
# log-variance: std = exp(0.5 * head).
std_output = Lambda(lambda log_var: tf.exp(0.5 * log_var), name='std')(log_std_output)

# Concatenate outputs for submission: means first, standard deviations second
outputs = Concatenate(name='outputs')([mean_output, std_output])

model = Model(inputs=inputs, outputs=outputs)
model.summary()

I0000 00:00:1759207665.923259      36 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13942 MB memory:  -> device: 0, name: Tesla T4, pci bus id: 0000:00:04.0, compute capability: 7.5
I0000 00:00:1759207665.926191      36 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 13942 MB memory:  -> device: 1, name: Tesla T4, pci bus id: 0000:00:05.0, compute capability: 7.5


## Compile and train the model

In [16]:
def nll_loss(y_true, y_pred):
    """Mean Gaussian negative log-likelihood of the targets.

    Args:
        y_true: Targets with one column per predicted quantity.
        y_pred: Model output whose first half of the last axis holds the
            predicted means and whose second half holds the predicted
            standard deviations (the layout produced by the 'outputs'
            concatenation of the model).

    Returns:
        Scalar tensor: the negative log-likelihood averaged over all samples
        and targets.
    """
    # Split at half the output width rather than at a hard-coded column count.
    n_targets = y_pred.shape[-1] // 2
    # Labels coming from pandas/NumPy are float64; align them with the
    # prediction dtype so the function also works when called directly.
    y_true = tf.cast(y_true, y_pred.dtype)
    mu, std = y_pred[:, :n_targets], y_pred[:, n_targets:]
    # Floor the variance: if std underflows to 0 the log term becomes -inf and
    # the quadratic term inf, which turns the loss into NaN. The floor is far
    # below any useful variance, so ordinary predictions are unaffected.
    var = tf.maximum(tf.square(std), 1e-20)
    log_term = 0.5 * tf.math.log(2 * np.pi * var)
    quad_term = 0.5 * tf.square(y_true - mu) / var
    return tf.reduce_mean(log_term + quad_term)

model.compile(optimizer='adam', loss=nll_loss)

In [None]:
# Train silently for a fixed number of epochs, monitoring the held-out split
model.fit(
    x=train_AIRS,
    y=train_labels.values,
    validation_data=(val_AIRS, val_labels.values),
    batch_size=8,
    epochs=10,
    verbose=0,
)

# The training arrays and label frames are no longer needed: release them
del train_AIRS, val_AIRS, train_labels, val_labels
gc.collect()

## Generate predictions

In [None]:
# Extra garbage-collection pass before running inference.
gc.collect()

In [None]:
# Run the model on the test set and separate the two halves of its output:
# predicted means first, predicted standard deviations second
n_wl = 283
predictions = model.predict(data_test_AIRS)
means, stds = predictions[:, :n_wl], np.abs(predictions[:, n_wl:])

# The planet ids may be stored as a list of per-chunk arrays; flatten if so
is_chunked = isinstance(test_index, list) and len(test_index) > 0
test_index_flat = np.concatenate(test_index) if is_chunked else test_index

# Start from the sample submission so column names and order match it exactly
df_sample = pd.read_csv(path_folder + 'sample_submission.csv')
df_submission = df_sample.copy()
df_submission['planet_id'] = test_index_flat

# Column 0 is planet_id; the means and the standard deviations follow it
df_submission.iloc[:, 1:1 + n_wl] = means.astype(np.float32)
df_submission.iloc[:, 1 + n_wl:1 + 2 * n_wl] = stds.astype(np.float32)

# Any infinite or missing value is written out as 0.0
df_submission = df_submission.replace([np.inf, -np.inf], np.nan).fillna(0.0)

# Store every column except planet_id as float32
value_cols = df_submission.columns[1:]
df_submission = df_submission.astype({col: np.float32 for col in value_cols})

In [None]:
# Write the submission table to the Kaggle working directory, omitting the DataFrame row index.
df_submission.to_csv('/kaggle/working/submission.csv', index=False)