# CNN model

## Notebook set-up

In [1]:
# Set notebook root to project root
from helper_functions import set_project_root

set_project_root()

# Standard library imports
import random
from functools import partial

# Third party imports
import h5py
import numpy as np
import tensorflow as tf

# Local imports
from ariel_data_preprocessing.utils import load_masked_frames
import configuration as config

# Width of each signal frame and of the target spectrum (one value per
# wavelength), and the number of frames drawn per planet by the data generator.
wavelengths = 283
sample_size = 100

# Seed every RNG this notebook touches (stdlib `random`, NumPy's legacy global
# state, TensorFlow) so that Restart & Run All reproduces the same results.
random_seed = 42
random.seed(random_seed)
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

Working directory: /mnt/arkk/kaggle/ariel-data-challenge


## 1. Data preparation

### 1.1. Load planet list

In [2]:
# Enumerate the processed training file: every top-level group in it is
# keyed by a planet ID, so the group names are the full planet list.
with h5py.File(f'{config.PROCESSED_DATA_DIRECTORY}/train.h5', 'r') as hdf:
    planet_ids = [planet_id for planet_id in hdf.keys()]

print(f'Found {len(planet_ids)} planets in training data.')

Found 1100 planets in training data.


### 1.2. Split planets into training & validation

In [16]:
# Deterministic 50/50 split of the planets into training and validation sets.
# Sorting first removes any dependence on the HDF5 key order. A dedicated,
# seeded RNG then shuffles that *copy*. As a result, `planet_ids` is left
# untouched and re-running this cell always produces the same split.
split_seed = 42
shuffled_planet_ids = sorted(planet_ids)
random.Random(split_seed).shuffle(shuffled_planet_ids)

split_index = len(shuffled_planet_ids) // 2
training_planet_ids = shuffled_planet_ids[:split_index]
validation_planet_ids = shuffled_planet_ids[split_index:]

# Sanity checks: no planet lands in both sets and none is dropped.
assert not set(training_planet_ids) & set(validation_planet_ids)
assert len(training_planet_ids) + len(validation_planet_ids) == len(planet_ids)

print(f'Training planets: {len(training_planet_ids)}')
print(f'Validation planets: {len(validation_planet_ids)}')

Training planets: 550
Validation planets: 550


## 2. Data generator

### 2.1. Data loader function

In [17]:
def data_loader(
    planet_ids: list,
    data_file: str,
    sample_size: int = 100,
    seed: 'int | None' = None
):
    '''Generator that yields signal, spectrum pairs for training/validation/testing.

    Loops over the planets forever, reshuffling their order at the start of
    every pass. For each planet, it draws `sample_size` distinct frames at
    random from the planet's signal, keeping them in their original order.

    Args:
        planet_ids (list): List of planet IDs to include in the generator.
            The list is copied, so the caller's list is never reordered.
        data_file (str): Path to the HDF5 file containing the data.
        sample_size (int, optional): Number of frames to draw from each planet. Defaults to 100.
        seed (int, optional): Seed for the generator's private random number
            generator. Pass an int for a reproducible sequence of planets and
            frames. Defaults to None (fresh, unseeded randomness).

    Yields:
        tuple: (sample, spectrum). `sample` holds the selected rows of the
        planet's 'signal' dataset; `spectrum` is the planet's full 'spectrum'
        dataset.

    Raises:
        ValueError: If `planet_ids` is empty, or if a planet's signal has fewer
            than `sample_size` frames.
    '''

    # Copy so that shuffling does not mutate the list the caller holds
    # (e.g. the list bound inside a functools.partial).
    planet_ids = list(planet_ids)

    if not planet_ids:
        # With no planets, the `while True` loop below would spin forever
        # without ever yielding.
        raise ValueError('planet_ids must contain at least one planet ID.')

    # One private RNG drives both planet order and frame selection, instead of
    # mixing the global `random` and `np.random` states.
    rng = np.random.default_rng(seed)

    with h5py.File(data_file, 'r') as hdf:

        while True:
            rng.shuffle(planet_ids)

            for planet_id in planet_ids:

                signal = hdf[planet_id]['signal'][:]
                spectrum = hdf[planet_id]['spectrum'][:]

                n_frames = signal.shape[0]

                if n_frames < sample_size:
                    raise ValueError(
                        f'Planet {planet_id} has {n_frames} frames; '
                        f'cannot draw a sample of {sample_size}.'
                    )

                # Distinct frame indices, sorted so the sample keeps the
                # frames' original order.
                indices = np.sort(rng.choice(n_frames, size=sample_size, replace=False))
                sample = signal[indices, :]

                yield sample, spectrum


### 2.2. Prefill the arguments to `data_loader()`

In [18]:
# Pre-bind each split's arguments so that the resulting callables can be
# passed to `tf.data.Dataset.from_generator` with no further arguments.
# `sample_size` is the notebook-level constant, so the frames each generator
# yields always match the TensorSpec shapes declared for the datasets.
# Both splits read from the same training file; they differ only in planets.
train_data_file = f'{config.PROCESSED_DATA_DIRECTORY}/train.h5'

training_data_generator = partial(
    data_loader,
    planet_ids=training_planet_ids,
    data_file=train_data_file,
    sample_size=sample_size
)

validation_data_generator = partial(
    data_loader,
    planet_ids=validation_planet_ids,
    data_file=train_data_file,
    sample_size=sample_size
)

### 2.3. Create TF datasets

In [19]:
# Element spec shared by both datasets: a (sample_size, wavelengths) block of
# signal frames and a (wavelengths,) target spectrum.
# NOTE(review): assumes the processed arrays are stored as float64 — confirm,
# and consider float32 for training.
dataset_signature = (
    tf.TensorSpec(shape=(sample_size, wavelengths), dtype=tf.float64),
    # Trailing comma matters: `(wavelengths)` is a bare int, not a 1-tuple.
    tf.TensorSpec(shape=(wavelengths,), dtype=tf.float64)
)

training_dataset = tf.data.Dataset.from_generator(
    training_data_generator,
    output_signature=dataset_signature
)

validation_dataset = tf.data.Dataset.from_generator(
    validation_data_generator,
    output_signature=dataset_signature
)

### 2.4. Manually check batch shape

In [20]:
# Pull one small batch through the pipeline and confirm its shapes match the
# declared element spec, instead of relying on eyeballing the printout.
check_batch_size = 4

batch = training_dataset.batch(check_batch_size)
signals, spectrums = next(iter(batch))

print(f'Signal batch shape: {signals.shape}')
print(f'Spectrum batch shape: {spectrums.shape}')

assert signals.shape == (check_batch_size, sample_size, wavelengths)
assert spectrums.shape == (check_batch_size, wavelengths)

Signal batch shape: (4, 100, 283)
Spectrum batch shape: (4, 283)


## 3. CNN