# Strong Lensing Challenge data preparation for Hyrax
## Goals of this work
* Produce a small version of the Strong Lensing Challenge dataset that contains only N samples from each class.
* Produce a `slsim_combined` and an `hsc_combined` dataset, each of which is a single directory containing equal numbers of lens and non-lens samples of either simulated or HSC data.

## Prerequisites
* You have downloaded all the training data available for the SL Challenge https://slchallenge.cbpf.br/ 
  * 4 zip files; 1. hsc_lenses 2. hsc_nonlenses 3. slsim_lenses 4. slsim_nonlenses
* You have extracted all the training data into subdirectories under a single data directory, e.g. `/home/user/sl_data/`
* IMPORTANT - You have moved each `parameters.fits` file into its corresponding subdirectory.
  * Note - When unzipped the `parameters.fits` file will be outside the folder containing the .fits image files.

## Notes about methodology
When creating the smaller dataset classes, it's tempting to simply take the first N .fits images from the original dataset.
Unfortunately the `parameters.fits` files don't always contain a row for a given object or lens id.
Therefore, we take the following approach, roughly defined in pseudocode:
0) Loop over .fits files by index `i`
1) Select and open .fits image file `i`
2) Set `obj_id` equal to the object id in the file
3) Find `obj_id` in the `parameters.fits` table
4) If the `obj_id` is in the table
   1) Append `obj_id` to the list `ids`
   2) Copy the set of .fits files to the smaller data directory
   3) Rename the .fits files to have `index = length(ids)`
5) Stop when `len(ids)` is equal to the desired size
6) Use `ids` to index into `parameters.fits` table to select a subset of rows
7) Write out the subset of rows as a new `parameters.fits` file in the smaller data directory

## Set up the imports

In [None]:
from tqdm import tqdm
from pathlib import Path
import numpy as np
from astropy.table import Table, vstack
from astropy.io import fits
import shutil
import os

## Edit these parameters as needed

In [None]:
# Directory that holds the uncompressed original data
original_data_directory = Path("/Users/drew/sl_data_challenge")

# How many samples to take from each of the 4 classes of data.
# e.g. a value of 10 gives 40 objects with 5 bands apiece, i.e. 200 .fits files in total
smaller_dataset_size = 100

# Parent directory of the smaller dataset, i.e. the place the smaller dataset is saved to
smaller_data_directory = Path(f"/Users/drew/sl_data_challenge/sl_{smaller_dataset_size}")

### Function to copy subset of original data

In [None]:
# Original datasets each contain this many objects (5 bands each, so 250_000 total .fits files)
original_dataset_size = 50_000  #! Be careful editing this number

def create_smaller_sample(subset, smaller_dataset_size=smaller_dataset_size):
    """Copy a subset of one original dataset class into the smaller data directory.

    File sets are visited in index order. A file set is kept only when it has
    exactly 5 .fits files and the id stored in its first file is present in the
    original ``parameters.fits`` table. Kept files are copied with their index
    rewritten to the running count of kept objects, and a new
    ``parameters.fits`` is written whose row ``k`` is the row that matched the
    object copied to file index ``k``.

    Parameters
    ----------
    subset : str
        Name of the subdirectory of ``original_data_directory`` to read from.
        The same name is used under ``smaller_data_directory`` for the output.
    smaller_dataset_size : int
        Number of objects to copy. The default is the value the notebook-level
        ``smaller_dataset_size`` had when this cell was executed; re-run this
        cell after changing that parameter.

    Returns
    -------
    tuple
        ``(small_params_table, output_dir)``: the table that was written and
        the directory it was written to.
    """
    # directory containing the original data
    root_dir = original_data_directory / subset

    # directory where the smaller dataset will be saved
    output_dir = smaller_data_directory / subset
    output_dir.mkdir(parents=True, exist_ok=True)

    # Read in the original parameters table
    params_table = Table.read(root_dir / "parameters.fits")
    if "Lens ID" in params_table.columns:
        id_column_name = "Lens ID"
    else:
        id_column_name = "Object ID"

    # Convert the id column to strings once (not once per file set) and remember
    # the first row in which each id appears, so membership tests are O(1).
    table_ids = np.asarray(params_table[id_column_name].astype(str)).tolist()
    first_row_for_id = {}
    for row_index, table_id in enumerate(table_ids):
        first_row_for_id.setdefault(table_id, row_index)

    ids = []
    row_indices = []  # row_indices[k] is the table row for the object copied to file index k
    i = 0
    with tqdm(total=smaller_dataset_size) as pbar:

        # Work through the original files until we have accumulated the number we want, or we run out of files
        while len(ids) < smaller_dataset_size and i < original_dataset_size:
            # glob pattern to get the file set for the i_th index
            old_index = f"_{str(i).zfill(8)}_"
            pattern = f"*{old_index}*.fits"
            files = list(root_dir.glob(pattern))

            # the next iteration always looks at the next file set
            i += 1

            # if there are not 5 files, move to the next file set
            if len(files) != 5:
                print(f"\tExpected 5 files, found {len(files)} for pattern {pattern}")
                continue

            # open one of the files and get the object id as a string
            object_id = str(fits.getdata(files[0], 1)[0][0])

            # skip objects that have no row in the parameters table
            row_index = first_row_for_id.get(object_id)
            if row_index is None:
                continue

            # copy the 5 files over, and update their indexes to match.
            new_index = f"_{str(len(ids)).zfill(8)}_"
            for file in files:
                shutil.copy(file, output_dir / file.name.replace(old_index, new_index))

            ids.append(object_id)
            row_indices.append(row_index)

            # update the progress bar
            pbar.update(1)

    if len(ids) < smaller_dataset_size:
        print(f"\tOnly found {len(ids)} usable objects out of the {smaller_dataset_size} requested")

    # Select the matching rows in the order the files were copied, so that row k
    # of the new table describes the files with index k (a boolean mask would keep
    # the original table order and any duplicate-id rows instead).
    small_params_table = params_table[np.asarray(row_indices, dtype=int)]

    # Write out the new smaller parameters file.
    small_params_table.write(output_dir / "parameters.fits", format='fits', overwrite=True)

    return small_params_table, output_dir

In [None]:
"""
The allowed file_prefixes are:
"D1_L" - slsim lenses
"D1_N" - slsim non-lenses
"D2_L" - HSC lenses
"D2_N" - HSC non-lenses
"""
allowed_file_prefixes = ["D1_L", "D1_N", "D2_L", "D2_N"]

def verify_file_and_table_order(output_dir, file_prefix="D2_L"):
    """Check that row ``i`` of ``parameters.fits`` matches the file with index ``i``.

    For every row of the table in ``output_dir``, the ``g`` band file with the
    same zero-padded index is opened and the id stored in it is compared with
    the id in the table row. Each mismatch is printed; nothing is raised for a
    mismatch.

    Parameters
    ----------
    output_dir : pathlib.Path
        Directory holding ``parameters.fits`` and the indexed .fits files.
    file_prefix : str
        Filename prefix of the files in ``output_dir``; must be one of
        ``allowed_file_prefixes``.

    Raises
    ------
    ValueError
        If ``file_prefix`` is not one of ``allowed_file_prefixes``.
    """
    if file_prefix not in allowed_file_prefixes:
        raise ValueError(f"file_prefix must be one of {allowed_file_prefixes}")

    # Open the newly created small parameters table
    small_table = Table.read(output_dir / "parameters.fits")
    if "Lens ID" in small_table.columns:
        id_column_name = "Lens ID"
    else:
        id_column_name = "Object ID"

    # For each row in the table
    for i in tqdm(range(len(small_table))):
        # Open one of the corresponding files. HDU 1 is requested explicitly, the
        # same HDU create_smaller_sample reads the id from.
        image_path = output_dir / f"{file_prefix}_{str(i).zfill(8)}_g.fits"
        raw_data = fits.getdata(image_path, 1)

        # Compare the object id in the file to the object id in the table
        file_id = str(raw_data[0][0])
        table_id = str(small_table[i][id_column_name])
        if table_id != file_id:
            print(
                f"Problem: row {i} of parameters.fits has id {table_id}, "
                f"but {image_path.name} contains id {file_id}"
            )

## HSC Lenses

In [None]:
# Build the small "hsc_lenses" sample, then compare the id in each copied file with
# the matching row of the new parameters table (mismatches are printed).
subset = "hsc_lenses"
smaller_params_table, output_dir = create_smaller_sample(subset)
print("Completed creating smaller sample.")
verify_file_and_table_order(output_dir, file_prefix="D2_L")
print("Finished verification test.")

## HSC Non-Lenses

In [None]:
# Build the small "hsc_nonlenses" sample, then compare the id in each copied file with
# the matching row of the new parameters table (mismatches are printed).
subset = "hsc_nonlenses"
smaller_params_table, output_dir = create_smaller_sample(subset)
print("Completed creating smaller sample.")
verify_file_and_table_order(output_dir, file_prefix="D2_N")
print("Finished verification test.")

## Simulated Lenses

In [None]:
# Build the small "slsim_lenses" sample, then compare the id in each copied file with
# the matching row of the new parameters table (mismatches are printed).
subset = "slsim_lenses"
smaller_params_table, output_dir = create_smaller_sample(subset)
print("Completed creating smaller sample.")
verify_file_and_table_order(output_dir, file_prefix="D1_L")
print("Finished verification test.")

## Simulated Non-Lenses

In [None]:
# Build the small "slsim_nonlenses" sample, then compare the id in each copied file with
# the matching row of the new parameters table (mismatches are printed).
subset = "slsim_nonlenses"
smaller_params_table, output_dir = create_smaller_sample(subset)
print("Completed creating smaller sample.")
verify_file_and_table_order(output_dir, file_prefix="D1_N")
print("Finished verification test.")

# Make combined datasets for training on both lens and non-lens images
The method for creating the combined datasets is identical for both slsim and HSC data.
1) Open both the parameters.fits tables, concatenate them together and write that out.
2) Create symlinks for all of the lens .fits files in the new combined directory
3) Create symlinks for all of the non-lens .fits files, but update the index to be `i+N`
   * `i` = the original index. i.e. `00000012`.
   * `N` = the number of objects in the smaller dataset. i.e. `100`
   * This update means that the filename index will match the row in the `parameters.fits` file.

### Create combined SLSIM dataset
This contains N samples of lens and non-lens objects. There will be 5 .fits files per object.
The naming scheme is the same as for the original dataset.
The `parameters.fits` is simply the concatenation of the two original `parameters.fits` files, therefore there is almost no column overlap.

In [None]:
# Build the combined SLSIM dataset: one parameters table that stacks the lens rows
# on top of the non-lens rows, plus symlinks to every lens and non-lens .fits file.

small_lenses_dir = smaller_data_directory / "slsim_lenses"
slsim_lenses_table = Table.read(small_lenses_dir / "parameters.fits")

small_nonlenses_dir = smaller_data_directory / "slsim_nonlenses"
slsim_nonlenses_table = Table.read(small_nonlenses_dir / "parameters.fits")

combined_dir = smaller_data_directory / "slsim_combined"
combined_dir.mkdir(parents=True, exist_ok=True)

# Concatenate the lens and non-lens parameters files together (lens rows first).
combined = vstack([slsim_lenses_table, slsim_nonlenses_table])
combined.write(combined_dir / "parameters.fits", format='fits', overwrite=True)

# Non-lens rows start directly after the lens rows in the combined table, so the
# non-lens file indexes are shifted by the number of lens rows actually written.
index_offset = len(slsim_lenses_table)

# Create symlinks in the combined directory for the lens files
for i in range(len(slsim_lenses_table)):
    pattern = f"*L*_{str(i).zfill(8)}_*.fits"

    for src in small_lenses_dir.glob(pattern):
        dst = combined_dir / src.name
        # replace a link left over from a previous run
        if dst.is_symlink():
            dst.unlink()
        # link to the absolute path so the link is valid from inside combined_dir
        dst.symlink_to(src.resolve())

# Create symlinks in the combined directory for the non-lens files
for i in range(len(slsim_nonlenses_table)):
    old_index = f"_{str(i).zfill(8)}_"
    # e.g. with 100 lens rows, D1_N_00000010_g.fits becomes D1_N_00000110_g.fits
    new_index = f"_{str(i + index_offset).zfill(8)}_"
    pattern = f"*N*{old_index}*.fits"

    for src in small_nonlenses_dir.glob(pattern):
        dst = combined_dir / src.name.replace(old_index, new_index)
        # replace a link left over from a previous run
        if dst.is_symlink():
            dst.unlink()
        # link to the absolute path so the link is valid from inside combined_dir
        dst.symlink_to(src.resolve())



### Create combined HSC dataset
This contains N samples of lens and non-lens objects. There will be 5 .fits files per object.
The naming scheme is the same as for the original dataset.
The `parameters.fits` is simply the concatenation of the two original `parameters.fits` files, therefore there is almost no column overlap.

In [None]:
# Build the combined HSC dataset: one parameters table that stacks the lens rows
# on top of the non-lens rows, plus symlinks to every lens and non-lens .fits file.

small_lenses_dir = smaller_data_directory / "hsc_lenses"
hsc_lenses_table = Table.read(small_lenses_dir / "parameters.fits")

small_nonlenses_dir = smaller_data_directory / "hsc_nonlenses"
hsc_nonlenses_table = Table.read(small_nonlenses_dir / "parameters.fits")

combined_dir = smaller_data_directory / "hsc_combined"
combined_dir.mkdir(parents=True, exist_ok=True)

# Concatenate the lens and non-lens parameters files together (lens rows first).
combined = vstack([hsc_lenses_table, hsc_nonlenses_table])
combined.write(combined_dir / "parameters.fits", format='fits', overwrite=True)

# Non-lens rows start directly after the lens rows in the combined table, so the
# non-lens file indexes are shifted by the number of lens rows actually written.
index_offset = len(hsc_lenses_table)

# Create symlinks in the combined directory for the lens files
for i in range(len(hsc_lenses_table)):
    pattern = f"*L*_{str(i).zfill(8)}_*.fits"

    for src in small_lenses_dir.glob(pattern):
        dst = combined_dir / src.name
        # replace a link left over from a previous run
        if dst.is_symlink():
            dst.unlink()
        # link to the absolute path so the link is valid from inside combined_dir
        dst.symlink_to(src.resolve())

# Create symlinks in the combined directory for the non-lens files
for i in range(len(hsc_nonlenses_table)):
    old_index = f"_{str(i).zfill(8)}_"
    # e.g. with 100 lens rows, D2_N_00000010_g.fits becomes D2_N_00000110_g.fits
    new_index = f"_{str(i + index_offset).zfill(8)}_"
    pattern = f"*N*{old_index}*.fits"

    for src in small_nonlenses_dir.glob(pattern):
        dst = combined_dir / src.name.replace(old_index, new_index)
        # replace a link left over from a previous run
        if dst.is_symlink():
            dst.unlink()
        # link to the absolute path so the link is valid from inside combined_dir
        dst.symlink_to(src.resolve())