In [23]:
import os
# Change directory to the main repository directory so that project modules and
# repository-relative paths (such as the "unet/data/..." files written when saving) resolve.
# os.path.expanduser is used instead of os.environ['HOME'] so the cell does not
# raise KeyError when the HOME variable is not set.
os.chdir(os.path.expanduser("~/Code/ExoTiC-NEAT-training/"))

In [24]:
import h5py
import numpy as np

from astropy.io import fits
from matplotlib import pyplot as plt


In [25]:
simulations = "/data/typhon2/hattie/jwst/soss_simulations/10000_soss_sims_randomised_target.h5"

simulation_list = []
clean_simulation_list = []
contaminant_simulation_list = []

# Dataset-name prefix -> list that collects every dataset whose key starts with it.
destination_by_prefix = {
    "data": simulation_list,
    "clean": clean_simulation_list,
    "contaminant": contaminant_simulation_list,
}

with h5py.File(simulations, "r") as f:
    for key in f.keys():
        for prefix, destination in destination_by_prefix.items():
            if key.startswith(prefix):
                # [:] reads the full dataset into memory as an array.
                destination.append(f[key][:])


In [26]:
# Turn each list of per-sample arrays into a single array with the sample index first.
simulation_array, clean_simulation_array, contaminant_simulation_array = map(
    np.array,
    (simulation_list, clean_simulation_list, contaminant_simulation_list),
)

print(
    f"Simulation array shape: {simulation_array.shape}",
    f"Clean simulation array shape: {clean_simulation_array.shape}",
    f"Contaminant simulation array shape: {contaminant_simulation_array.shape}",
    sep="\n",
)

Simulation array shape: (10000, 256, 2048)
Clean simulation array shape: (10000, 256, 2048)
Contaminant simulation array shape: (10000, 256, 2048)


In [27]:
def get_class_indeces(data_array, label_array, clean_array):
    """Sort sample positions into three classes by comparing per-sample maxima.

    The three inputs are iterated in lockstep; element ``i`` of each describes
    the same sample. For every sample the maximum of each slice is taken and:

    * if the label maximum equals 0, the position goes to the "zero" class;
    * otherwise, if the data maximum is strictly greater than the clean
      maximum, the position goes to the "overlap" class;
    * otherwise it goes to the "no overlap" class.

    Parameters
    ----------
    data_array, label_array, clean_array : iterable of array-like
        Per-sample slices; each slice must provide a ``.max()`` method.

    Returns
    -------
    tuple of (list, list, list)
        ``(overlap_index, zero_index, no_overlap_index)`` — the positions of
        the samples assigned to each class, in ascending order.
    """
    overlap_index, zero_index, no_overlap_index = [], [], []

    samples = zip(data_array, label_array, clean_array)
    for position, (observed, contaminant, clean) in enumerate(samples):
        peak_observed = observed.max()
        peak_contaminant = contaminant.max()
        peak_clean = clean.max()

        # An all-zero label slice is its own class, whatever the data looks like.
        if peak_contaminant == 0:
            zero_index.append(position)
            continue

        bucket = overlap_index if peak_observed > peak_clean else no_overlap_index
        bucket.append(position)

    return overlap_index, zero_index, no_overlap_index

# Index lists for each class; the contaminant array plays the role of the label.
overlap, zero, no_overlap = get_class_indeces(
    data_array=simulation_array,
    label_array=contaminant_simulation_array,
    clean_array=clean_simulation_array,
)

# Share of the full simulation set that falls into each class.
print(
    f"Overlap: {(len(overlap) / len(simulation_array)) * 100}%",
    f"Zero: {(len(zero) / len(simulation_array)) * 100}%",
    f"No overlap: {(len(no_overlap) / len(simulation_array)) * 100}%",
    sep="\n",
)

Overlap: 35.02%
Zero: 10.45%
No overlap: 54.53%


In [28]:
# Sanity check: the three classes together account for every simulation.
sum(len(class_indices) for class_indices in (overlap, zero, no_overlap)) == len(simulation_array)

True

In [29]:
# Select the samples of each class from every array using the class index lists.
# Unpacking order is always (overlap, zero, no_overlap).
class_index_lists = (overlap, zero, no_overlap)

overlap_simulations, zero_simulations, no_overlap_simulations = (
    simulation_array[indices] for indices in class_index_lists
)

overlap_labels, zero_labels, no_overlap_labels = (
    contaminant_simulation_array[indices] for indices in class_index_lists
)

overlap_clean, zero_clean, no_overlap_clean = (
    clean_simulation_array[indices] for indices in class_index_lists
)



In [30]:
# Report how many samples ended up in each class.
print(
    f"Overlap simulations shape: {overlap_simulations.shape}",
    f"Zero simulations shape: {zero_simulations.shape}",
    f"No overlap simulations shape: {no_overlap_simulations.shape}",
    sep="\n",
)


Overlap simulations shape: (3502, 256, 2048)
Zero simulations shape: (1045, 256, 2048)
No overlap simulations shape: (5453, 256, 2048)


In [34]:
# create balanced simulation dataset using the same number of samples (at most 1000) from each class
# The count is capped at the size of the smallest class: slicing past the end of a
# shorter class would silently return fewer samples and leave the dataset unbalanced.
slice_index = min(1000, len(overlap_simulations), len(zero_simulations), len(no_overlap_simulations))

# The classes are concatenated one after another (overlap, zero, no overlap), so the
# resulting arrays are ordered by class.
balanced_simulations = np.concatenate((overlap_simulations[:slice_index], zero_simulations[:slice_index], no_overlap_simulations[:slice_index]))
balanced_labels = np.concatenate((overlap_labels[:slice_index], zero_labels[:slice_index], no_overlap_labels[:slice_index]))
balanced_clean = np.concatenate((overlap_clean[:slice_index], zero_clean[:slice_index], no_overlap_clean[:slice_index]))

print(f"Balanced simulations shape: {balanced_simulations.shape}")
print(f"Balanced labels shape: {balanced_labels.shape}")
print(f"Balanced clean shape: {balanced_clean.shape}")

Balanced simulations shape: (3000, 256, 2048)
Balanced labels shape: (3000, 256, 2048)
Balanced clean shape: (3000, 256, 2048)


In [35]:
# verify that the balanced dataset is balanced

def calculate_class_percentage(data_array, label_array, clean_array):
    """Return the fraction of samples that fall into each class.

    The three inputs are iterated in lockstep and each sample is classified
    from the maxima of its slices:

    * "zero": the label maximum equals 0;
    * "overlap": the label maximum is non-zero and the data maximum is
      strictly greater than the clean maximum;
    * "no overlap": every remaining sample.

    Parameters
    ----------
    data_array, label_array, clean_array : sequence of array-like
        Per-sample slices; each slice must provide a ``.max()`` method.
        ``data_array`` must support ``len()``.

    Returns
    -------
    tuple of (float, float, float)
        ``(overlap, zero, no_overlap)`` as fractions of ``len(data_array)``
        (values between 0 and 1; multiply by 100 for a percentage).

    Raises
    ------
    ValueError
        If ``data_array`` is empty, since the fractions are undefined. Raised
        explicitly so an empty dataset is reported clearly instead of
        surfacing as a division by zero.
    """
    total = len(data_array)
    if total == 0:
        raise ValueError("Cannot compute class fractions for an empty dataset.")

    trace_overlap_count = 0
    zero_count = 0
    for data_slice, label_slice, clean_slice in zip(data_array, label_array, clean_array):
        if label_slice.max() == 0:
            zero_count += 1
        elif data_slice.max() > clean_slice.max():
            trace_overlap_count += 1

    # Everything that is neither "zero" nor "overlap"; computed once after the loop.
    no_overlap_count = total - trace_overlap_count - zero_count
    return trace_overlap_count / total, zero_count / total, no_overlap_count / total

# Use dedicated names for the fractions. `overlap`, `zero` and `no_overlap` already hold
# the class index lists returned by get_class_indeces and used to index the simulation
# arrays; rebinding them to floats here would break re-running those earlier cells.
overlap_fraction, zero_fraction, no_overlap_fraction = calculate_class_percentage(balanced_simulations, balanced_labels, balanced_clean)

print(f"Overlap: {overlap_fraction * 100}%")
print(f"Zero: {zero_fraction * 100}%")
print(f"No overlap: {no_overlap_fraction * 100}%")

Overlap: 33.33333333333333%
Zero: 33.33333333333333%
No overlap: 33.33333333333333%


In [37]:
# split data into 80:10:10 % train:val:test
#
# The balanced arrays are built by concatenating the classes one after another
# (overlap, then zero, then no overlap), so contiguous slices would place only the
# last class in the validation and test sets. The sample order is therefore shuffled
# first, with ONE permutation shared by all three arrays so that data, clean and
# contaminant images stay aligned, and with a fixed seed so the split is reproducible.
split_seed = 42

number_data_points = balanced_simulations.shape[0]
number_train_points = int(0.8 * number_data_points)
number_validation_points = int(0.1 * number_data_points)
number_test_points = number_data_points - number_train_points - number_validation_points

shuffled_indices = np.random.default_rng(split_seed).permutation(number_data_points)
train_indices = shuffled_indices[:number_train_points]
validation_indices = shuffled_indices[number_train_points:number_train_points + number_validation_points]
test_indices = shuffled_indices[number_train_points + number_validation_points:]

# NOTE: indexing with an index array copies the selected samples (plain slices are views),
# so this cell needs enough free memory for one extra copy of the balanced dataset.
train_simulations = balanced_simulations[train_indices]
validation_simulations = balanced_simulations[validation_indices]
test_simulations = balanced_simulations[test_indices]

train_clean_simulations = balanced_clean[train_indices]
validation_clean_simulations = balanced_clean[validation_indices]
test_clean_simulations = balanced_clean[test_indices]

train_contaminant_simulations = balanced_labels[train_indices]
validation_contaminant_simulations = balanced_labels[validation_indices]
test_contaminant_simulations = balanced_labels[test_indices]

print(f"Number of training data points: {number_train_points}")
print(f"Number of validation data points: {number_validation_points}")
print(f"Number of test data points: {number_test_points}")


Number of training data points: 2400
Number of validation data points: 300
Number of test data points: 300


In [9]:
# training_data_dir = "/data/typhon2/hattie/jwst/soss_simulations/training_data/"
# for data_set, name in zip([train_simulations, validation_simulations, test_simulations], ["train", "validation", "test"]):
#     for index, spectra in enumerate(data_set):
#         plt.imsave(f"{training_data_dir}/{name}/data/spectra_{index}.png", spectra)

In [10]:
# for data_set, name in zip([train_clean_simulations, validation_clean_simulations, test_clean_simulations], ["train", "validation", "test"]):
#     for index, spectra in enumerate(data_set):
#         plt.imsave(f"{training_data_dir}/{name}/clean/spectra_{index}.png", spectra)

In [11]:
# for data_set, name in zip([train_contaminant_simulations, validation_contaminant_simulations, test_contaminant_simulations], ["train", "validation", "test"]):
#     for index, spectra in enumerate(data_set):
#         plt.imsave(f"{training_data_dir}/{name}/contaminant/spectra_{index}.png", spectra)

In [38]:
# For each split, stack the three image types along a new leading axis:
# index 0 = simulations, 1 = clean simulations, 2 = contaminant simulations.
training_data, validation_data, test_data = (
    np.asarray(image_group)
    for image_group in (
        [train_simulations, train_clean_simulations, train_contaminant_simulations],
        [validation_simulations, validation_clean_simulations, validation_contaminant_simulations],
        [test_simulations, test_clean_simulations, test_contaminant_simulations],
    )
)

print(
    f"Training data shape: {training_data.shape}",
    f"Validation data shape: {validation_data.shape}",
    f"Test data shape: {test_data.shape}",
    sep="\n",
)

Training data shape: (3, 2400, 256, 2048)
Validation data shape: (3, 300, 256, 2048)
Test data shape: (3, 300, 256, 2048)


In [39]:
# Write each split to its own .npy file (paths are relative to the current working directory).
for output_path, split_array in (
    ("unet/data/training_data_3000.npy", training_data),
    ("unet/data/validation_data_3000.npy", validation_data),
    ("unet/data/test_data_3000.npy", test_data),
):
    np.save(output_path, split_array)