In [1]:
#Import Stuff
import numpy as np
import pandas as pd
from astropy.stats import sigma_clip
import itertools
import matplotlib.pyplot as plt
from scipy import signal
from scipy.ndimage import gaussian_filter1d
from scipy.optimize import curve_fit, minimize
from scipy.signal import detrend, savgol_filter
from scipy.optimize import brute, minimize
from scipy.interpolate import interp1d
from scipy.stats import pearsonr
import gc

In [2]:
def normalize_data(data):
    """
    Standardize every pixel's time series to zero mean and (approximately) unit variance.

    Statistics are taken along axis 0, so each column of ``data`` is centred on its own
    mean and scaled by its own standard deviation.
    """
    epsilon = 1e-10  # keeps the division finite for constant columns
    centered = data - np.mean(data, axis=0)
    scale = np.std(data, axis=0) + epsilon
    return centered / scale

In [3]:
def simple_transit_model(params, time):
    """
    Periodic box-shaped transit model.

    :param params: Sequence (t0, per, depth, duration)
    :param time: Array of timestamps
    :return: Array equal to 1 outside transit and 1 - depth inside it
    """
    t0, per, depth, duration = params
    half_period = 0.5 * per
    # Fold the timestamps so that the transit centre sits at phase 0
    folded = (time - t0 + half_period) % per - half_period
    in_transit = np.abs(folded) < 0.5 * duration
    return np.ones_like(time) - depth * in_transit

In [4]:
def process_planet(planet_id, data_path, axis_info):
    """
    Load, calibrate and jitter-correct one planet's AIRS-CH0 observation, then fit a
    box-shaped transit to every wavelength channel.

    :param planet_id: Planet identifier, used as the directory name under ``{data_path}/train``
    :param data_path: Root directory of the dataset
    :param axis_info: Table with 'AIRS-CH0-axis0-h' and 'FGS1-axis0-h' columns; the
        difference between the first two entries of each column is used as the time step
    :return: dict with keys 'wavelengths', 'spectrum', 'uncertainties' and
        'transit_params', or None if any step raised (the error is printed)
    """
    try:
        planet_dir = f"{data_path}/train/{planet_id}"

        # Load data
        airs_signal = pd.read_parquet(f"{planet_dir}/AIRS-CH0_signal.parquet")
        fgs1_signal = pd.read_parquet(f"{planet_dir}/FGS1_signal.parquet")

        # Load calibration files
        dark = pd.read_parquet(f"{planet_dir}/AIRS-CH0_calibration/dark.parquet")
        flat = pd.read_parquet(f"{planet_dir}/AIRS-CH0_calibration/flat.parquet")

        # Load wavelengths
        wavelengths = pd.read_csv(f"{data_path}/wavelengths.csv")

        # Reshape to (frame, 32, 356) and calibrate: subtract dark, divide by flat
        airs_data = airs_signal.values.reshape(-1, 32, 356)
        airs_data = (airs_data - dark.values) / flat.values

        # Collapse the 32-pixel axis, leaving one light curve per column: (frame, 356)
        spectral_data = np.mean(airs_data, axis=1)

        # Time array
        airs_time_step = axis_info['AIRS-CH0-axis0-h'].iloc[1] - axis_info['AIRS-CH0-axis0-h'].iloc[0]
        time = np.arange(len(spectral_data)) * airs_time_step

        # Load FGS1 frames as (frame, 32, 32) and measure one centroid per frame
        fgs1_data = fgs1_signal.values.reshape(-1, 32, 32)
        fgs1_centroids = np.array([measure_centroid(frame) for frame in fgs1_data])

        # FGS1 time array
        fgs1_time_step = axis_info['FGS1-axis0-h'].iloc[1] - axis_info['FGS1-axis0-h'].iloc[0]
        fgs1_time = np.arange(len(fgs1_data)) * fgs1_time_step

        # Apply jitter correction
        corrected_spectral_data = correct_jitter(spectral_data, fgs1_centroids, time, fgs1_time)

        def transit_model(params, t):
            """Box transit: 1 outside the window, 1 - depth inside it."""
            t0, duration, depth = params
            transit = np.abs(t - t0) < duration / 2
            return 1 - depth * transit

        # Same starting point for every channel: mid-observation centre,
        # 10 % of the time span as duration, 1 % depth.
        initial_guess = [np.median(time), 0.1 * (time[-1] - time[0]), 0.01]

        def fit_transit(flux):
            """Fit (t0, duration, depth) to one light curve by least squares."""
            # The model's out-of-transit level is 1, so the light curve has to be put
            # on a relative scale first; fitting raw detector counts against a unit
            # baseline makes the optimiser chase the flux level instead of the dip.
            baseline = np.median(flux)
            if not np.isfinite(baseline) or baseline == 0:
                baseline = 1.0  # cannot normalise; fall back to the raw values
            relative_flux = flux / baseline

            def residuals(params):
                model = transit_model(params, time)
                return np.sum((relative_flux - model) ** 2)

            result = minimize(residuals, initial_guess, method='Nelder-Mead')
            return result.x

        # Fit transit for each wavelength using corrected data
        n_channels = corrected_spectral_data.shape[1]
        transit_params = np.array(
            [fit_transit(corrected_spectral_data[:, i]) for i in range(n_channels)]
        )

        # Time-averaged flux per channel and the standard error of that mean
        spectrum = np.mean(corrected_spectral_data, axis=0)
        uncertainties = np.std(corrected_spectral_data, axis=0) / np.sqrt(len(corrected_spectral_data))

    except Exception as e:
        # Best-effort batch processing: report the failure and let the caller skip this planet
        print(f"Error processing planet {planet_id}: {str(e)}")
        return None

    return {
        'wavelengths': wavelengths.iloc[0].values,  # First row of wavelengths.csv as a single array
        'spectrum': spectrum,
        'uncertainties': uncertainties,
        'transit_params': transit_params  # One (t0, duration, depth) row per channel
    }

In [5]:
def measure_centroid(image):
    """
    Calculate the intensity-weighted centroid of a 2D image.

    :param image: 2D array of pixel values
    :return: Tuple (x_center, y_center), i.e. (column coordinate, row coordinate).
        A frame whose pixel values sum to zero has no defined centroid; its geometric
        centre is returned instead of dividing by zero and producing NaN.
    """
    image = np.asarray(image)
    total = image.sum()
    if total == 0:
        n_rows, n_cols = image.shape
        return (n_cols - 1) / 2.0, (n_rows - 1) / 2.0

    y, x = np.indices(image.shape)
    x_center = (x * image).sum() / total
    y_center = (y * image).sum() / total
    return x_center, y_center

In [6]:
def process_and_save_batch(planet_ids, data_path, results_path, axis_info=None):
    """
    Process a batch of planets and save each result as ``planet_<id>_results.npz``.

    :param planet_ids: Iterable of planet identifiers
    :param data_path: Root directory of the dataset
    :param results_path: Directory the ``.npz`` files are written to
    :param axis_info: Axis table forwarded to ``process_planet``; when omitted it is
        read from ``{data_path}/axis_info.parquet``
    """
    if axis_info is None:
        axis_info = pd.read_parquet(f"{data_path}/axis_info.parquet")

    for planet_id in planet_ids:
        try:
            results = process_planet(str(planet_id), data_path, axis_info)

            # process_planet prints its own error and returns None on failure
            if results is None:
                continue

            # Save results for this planet
            np.savez(f"{results_path}/planet_{planet_id}_results.npz", **results)

            # Drop the reference right after saving so the arrays can be reclaimed
            del results
        except Exception as e:
            print(f"Error processing planet {planet_id}: {str(e)}")
        finally:
            # Runs on success, skip and failure alike
            gc.collect()

In [7]:
def correct_jitter(airs_data, fgs1_data, airs_time, fgs1_time,
                   x_scale=0.1, y_scale=0.1, window_length=51, polyorder=3):
    """
    Correct jitter noise in AIRS-CH0 data using FGS1 data.

    The FGS1 centroid track is interpolated onto the AIRS-CH0 timestamps, smoothed with
    a Savitzky-Golay filter and turned into offsets from its median. Every wavelength
    channel is then resampled at timestamps displaced by the scaled offsets.

    :param airs_data: AIRS-CH0 spectral data, 2D array (time, wavelength)
    :param fgs1_data: FGS1 centroids, 2D array (time, 2) with columns x, y
    :param airs_time: Time array for AIRS-CH0 data
    :param fgs1_time: Time array for FGS1 data
    :param x_scale: Factor converting the x offset into a time displacement
    :param y_scale: Factor converting the y offset into a time displacement
    :param window_length: Savitzky-Golay window; shrunk to the largest odd value that
        fits when the series is shorter than the window
    :param polyorder: Savitzky-Golay polynomial order
    :return: Jitter-corrected AIRS-CH0 data, same shape and dtype as ``airs_data``
    """
    airs_data = np.asarray(airs_data)
    fgs1_data = np.asarray(fgs1_data)
    airs_time = np.asarray(airs_time)
    fgs1_time = np.asarray(fgs1_time)

    def _smooth(values):
        """Savitzky-Golay smoothing that degrades gracefully on short series."""
        window = min(window_length, len(values))
        if window % 2 == 0:
            window -= 1  # keep the window odd so it is centred on a sample
        if window <= polyorder:
            return values  # too few samples to fit the polynomial; leave untouched
        return savgol_filter(values, window_length=window, polyorder=polyorder)

    # Keep only the FGS1 samples inside the time span shared with AIRS-CH0
    overlap_start = max(fgs1_time.min(), airs_time.min())
    overlap_end = min(fgs1_time.max(), airs_time.max())
    mask = (fgs1_time >= overlap_start) & (fgs1_time <= overlap_end)
    fgs1_time = fgs1_time[mask]
    fgs1_data = fgs1_data[mask]

    # Interpolate FGS1 centroid positions to AIRS-CH0 timestamps
    interp_x = interp1d(fgs1_time, fgs1_data[:, 0], kind='cubic', fill_value='extrapolate')
    interp_y = interp1d(fgs1_time, fgs1_data[:, 1], kind='cubic', fill_value='extrapolate')

    # Smooth the centroid positions to reduce noise
    x_smooth = _smooth(interp_x(airs_time))
    y_smooth = _smooth(interp_y(airs_time))

    # Offsets relative to the typical pointing
    x_shift = x_smooth - np.median(x_smooth)
    y_shift = y_smooth - np.median(y_smooth)

    # The displaced timestamps are identical for every wavelength, so compute them once
    corrected_time = airs_time - x_shift * x_scale - y_shift * y_scale

    # One cubic interpolant along the time axis handles all wavelength columns together
    resample = interp1d(airs_time, airs_data, axis=0, kind='cubic', fill_value='extrapolate')

    # Assign into a like-typed buffer so the output keeps the input's dtype
    corrected_data = np.zeros_like(airs_data)
    corrected_data[...] = resample(corrected_time)

    return corrected_data

In [8]:
# Pipeline option flags.
# NOTE(review): none of the cells visible here read `options`; confirm whether a
# later cell consumes it before relying on these switches.
options = {
    'DO_MASK': True,
    'DO_THE_NL_CORR': False,
    'DO_DARK': True,
    'DO_FLAT': True,
    'TIME_BINNING': True,
    'BINNING_FACTOR': 30
}

# Set up paths: dataset root (input) and output directory for the per-planet .npz files
data_path = "/kaggle/input/ariel-data-challenge-2024"
results_path = "/kaggle/working"
# Axis table passed to process_planet, which derives the AIRS-CH0 and FGS1 time steps from it
axis_info = pd.read_parquet(f"{data_path}/axis_info.parquet")

# Load planet IDs: the frame is indexed by 'planet_id', and the index drives the main loop
train_adc_info = pd.read_csv(f"{data_path}/train_adc_info.csv", index_col='planet_id')

In [9]:
# Main processing loop: process every planet in train_adc_info and save one .npz per planet
total_planets = len(train_adc_info.index)
for i, planet_id in enumerate(train_adc_info.index, 1):
    # process_planet prints the error and returns None when any step fails
    results = process_planet(str(planet_id), data_path, axis_info)
    if results is not None:
        # Each dict entry becomes a named array inside the archive
        np.savez(f"{results_path}/planet_{planet_id}_results.npz", **results)
    # Release the previous planet's arrays before loading the next one
    gc.collect()
    # Progress line; printed whether or not a result was saved for this planet
    print(f"Processed planet {i}/{total_planets}: {planet_id}")

print("All planets processed and saved.")

Processed planet 1/673: 785834
Processed planet 2/673: 14485303
Processed planet 3/673: 17002355
Processed planet 4/673: 24135240
Processed planet 5/673: 25070640
Processed planet 6/673: 26372015
Processed planet 7/673: 29348276
Processed planet 8/673: 33548644
Processed planet 9/673: 50637799
Processed planet 10/673: 53735941
Processed planet 11/673: 57518461
Processed planet 12/673: 57669658
Processed planet 13/673: 61961538
Processed planet 14/673: 62282335
Processed planet 15/673: 75219649
Processed planet 16/673: 77523557
Processed planet 17/673: 79009616
Processed planet 18/673: 90833891
Processed planet 19/673: 92029723
Processed planet 20/673: 93311500
Processed planet 21/673: 100468857
Processed planet 22/673: 112782545
Processed planet 23/673: 113675418
Processed planet 24/673: 137536026
Processed planet 25/673: 141323216
Processed planet 26/673: 145136829
Processed planet 27/673: 155746318
Processed planet 28/673: 159506197
Processed planet 29/673: 165516591
Processed planet