# CNN model

## Notebook set-up

In [None]:
# Set notebook root to project root.
# NOTE(review): this call sits before the local imports further down
# (e.g. `configuration`), presumably because those modules only resolve once
# the working directory / import path points at the project root — confirm
# against helper_functions.set_project_root before reordering anything here.
from helper_functions import set_project_root

set_project_root()

# Standard library imports
import math

# Third party imports
import h5py
import numpy as np

from keras.utils import PyDataset

# Local imports
import configuration as config

## 1. PyDataset class definition

In [None]:
class ArielPyDataset(PyDataset):
    """Keras ``PyDataset`` that batches samples drawn from a list of planet IDs.

    The dataset holds ``resamples * len(planet_ids)`` samples per epoch.
    Sample ``i`` belongs to planet ``planet_ids[i % len(planet_ids)]``, so
    every planet appears ``resamples`` times per epoch and its repeats land in
    different batches. Turning a planet ID into training data is delegated to
    ``load_sample``, because this class does not know the on-disk layout.

    Args:
        planet_ids: Indexable sequence of planet IDs (e.g. a list).
        resamples: Number of times each planet is drawn per epoch.
        batch_size: Maximum number of samples per batch.
        load_sample: Callable ``load_sample(planet_id) -> (x, y)`` that builds
            one sample. It is called once per sample, so a stochastic loader
            produces a different draw for each resample of the same planet.
        **kwargs: Forwarded unchanged to ``PyDataset.__init__``.

    Raises:
        ValueError: If ``batch_size`` < 1 or ``resamples`` < 0.
    """

    def __init__(
        self,
        planet_ids,
        resamples: int = 10,
        batch_size: int = 32,
        *,
        load_sample=None,
        **kwargs,
    ):

        # Fail fast: batch_size == 0 would otherwise surface later as a
        # ZeroDivisionError in __len__, and a negative resamples count
        # would make __len__ negative.
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        if resamples < 0:
            raise ValueError(f'resamples must be >= 0, got {resamples}')

        super().__init__(**kwargs)

        self.planet_ids = planet_ids
        self.resamples = resamples
        self.batch_size = batch_size
        self.load_sample = load_sample

    def _num_samples(self):
        """Return the total number of samples per epoch (planets x resamples)."""
        return self.resamples * len(self.planet_ids)

    def __len__(self):
        """Return the number of batches per epoch (the last may be partial)."""
        return math.ceil(self._num_samples() / self.batch_size)

    def _batch_planet_ids(self, idx):
        """Return the planet ID for each sample position in batch ``idx``."""
        n_planets = len(self.planet_ids)
        low = idx * self.batch_size

        # Cap at the total sample count, not at the number of planets: every
        # batch but the last is full, and the last holds the remainder.
        high = min(low + self.batch_size, self._num_samples())

        # Cycling with modulo makes each pass over planet_ids one "resample",
        # so repeated draws of a planet are spread across batches.
        return [self.planet_ids[i % n_planets] for i in range(low, high)]

    def __getitem__(self, idx):
        """Return ``(batch_x, batch_y)`` NumPy arrays for batch ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
            NotImplementedError: If no ``load_sample`` callable was given.
        """
        n_batches = len(self)
        if not 0 <= idx < n_batches:
            # IndexError also lets plain Python iteration over the dataset
            # stop cleanly after the last batch.
            raise IndexError(
                f'Batch index {idx} out of range for {n_batches} batches'
            )

        if self.load_sample is None:
            raise NotImplementedError(
                'ArielPyDataset needs a load_sample(planet_id) -> (x, y) '
                'callable to build batches'
            )

        samples = [
            self.load_sample(planet_id)
            for planet_id in self._batch_planet_ids(idx)
        ]

        # Transpose [(x0, y0), (x1, y1), ...] into (x0, x1, ...), (y0, y1, ...).
        batch_x, batch_y = zip(*samples)
        return np.array(batch_x), np.array(batch_y)


In [None]:
# Collect the planet IDs: one per top-level member of the processed training
# file. Iterating an h5py File yields its member names, same as .keys().
with h5py.File(f'{config.PROCESSED_DATA_DIRECTORY}/train.h5', 'r') as hdf:
    planet_ids = list(hdf)