# M3: Targeted to Global Conversion

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'retina'
import seaborn as sns

# import os ; replaced below
from pathlib import Path

In [None]:
# consider simplifying to assume input files and if not found, create

# Menu text printed by the interactive file search before each listing;
# one entry per displayed line.
SEARCH_MSG = "\n".join((
    "> Select a folder to add it to the path and search inside.",
    "> Select a file to finish the search process.",
    "> To select option [#], enter #.",
    "[E] Exit",
    "[U] Search up one level",
))

def findFile(input_path):
    """Interactively browse folders from *input_path* until a file is chosen.

    The entries of the current folder are printed with a numeric identifier
    and the user's choice is read with ``input()``: a number selects that
    entry (a folder is searched in turn, a file ends the search), ``U`` moves
    up one level and ``E`` aborts.

    Parameters
    ----------
    input_path : str or pathlib.Path
        Folder in which the search starts.

    Returns
    -------
    pathlib.Path or str
        The selected file as a ``Path``. When no file was selected, one of
        the status strings ``"Invalid Path"``, ``"Exiting."`` or
        ``"Error reading choice, exiting."`` is returned instead.
    """
    current_path = Path(input_path)
    if not current_path.exists():
        return "Invalid Path"
    print(f"Current path: {current_path}")
    print(SEARCH_MSG)

    # list out entries with [#] as identifier; sorted so the numbering is
    # reproducible (iterdir() yields entries in arbitrary order)
    dirList = sorted(current_path.iterdir())
    print("> Files in current folder:")
    for number, entry in enumerate(dirList):
        print(f"[{number}] {entry.name}")
    print("")

    valid_numbers = {str(number) for number in range(len(dirList))}
    choice = input().strip().lower()
    while choice not in valid_numbers and choice not in ("u", "e"):
        choice = input("Couldn't read input, trying again. Target #: ").strip().lower()

    if choice == "e":
        return "Exiting."
    if choice == "u":
        return findFile(current_path.parent)

    # iterdir() entries already contain current_path, so use the entry as-is:
    # joining it onto a relative current_path again would duplicate the prefix
    selected = dirList[int(choice)]
    if selected.is_dir():
        print(f"Selected folder to search: {selected}")
        return findFile(selected)
    if selected.is_file():
        print(f"Selected file: {selected}")
        return selected

    # neither a folder nor a regular file (e.g. a broken link)
    return "Error reading choice, exiting."

In [None]:
# Interactively pick the input file, starting the search in the current
# working directory. findFile() returns a Path when a file was selected and a
# status string when the search was aborted or failed.
data_path = findFile(Path.cwd())
print(data_path)

In [None]:
# Pin the input to the su25 M3 CSV located next to the working directory:
# any other selection made above is replaced by this path.
default_data_path = Path.cwd().parent / Path('input files/su25/clean M3 wavelengths targeted global.csv')
if data_path == default_data_path:
    print("path unchanged")
else:
    print("path updated")
    data_path = default_data_path

In [None]:
# Load the selected CSV, then discard every row that contains a missing value.
df = pd.read_csv(data_path)
df = df.dropna()

In [None]:
# Preview the first rows of the loaded table (rows with missing values were
# already dropped when df was created).
df.head()

In [None]:
# Per-column summary statistics, as a quick sanity check of the loaded data.
df.describe()

In [None]:
# Step-size table consumed by targeted_to_global().
# process:
#   while the current wavelength is greater than wl[k], move on to entry k+1;
#   the input index then advances by step[k] of the entry that was reached.
#   The last entry is used for every wavelength beyond the second-to-last wl.
# format:
#   "wl"[k]   wavelength up to which step[k] applies
#   "step"[k] number of input samples to advance by
breakpoints = {
    "wl"   : [0, 0.44, 0.68, 0.71, 1.53, 1.56, 1.60], # 9999],
    "step" : [0,    4,    4,    3,    2,    3,    4]  #    4]
}


def _column_as_array(frame, position, label):
    """Return column number *position* of *frame* as a NumPy array.

    Raises ``IndexError`` (with the column's *label* in the message) when the
    frame has no such column, instead of printing and carrying on with a
    missing or stale variable.
    """
    try:
        return np.array(frame[frame.columns[position]])
    except IndexError as err:
        raise IndexError(f"{label} column not found") from err


def _report_nans(label, values):
    """Print how many NaNs *values* holds and where; return the NaN count."""
    nan_mask = np.isnan(values)
    nan_count = nan_mask.sum()
    if nan_count > 0:
        print(f"{label} nans: {nan_count}")
        print(f"{label} nan indices: {np.where(nan_mask)[0]}")
    return nan_count


# first column: wavelengths
wavelengths = _column_as_array(df, 0, "wavelength")
print(f"wl shape: {wavelengths.shape}")
wl_nans = _report_nans("wavelength", wavelengths)


# second column: spectra
spectra = _column_as_array(df, 1, "spectra")
print(f"spectra shape: {spectra.shape}")
spectra_nans = _report_nans("spectra", spectra)

In [None]:
# used to find breakpoints and compare to empirical process
# empirical = np.array(df[df.columns[-1]].dropna())
# if testing:
#         print(f"\nlengths:\n" + f"output    {len(output)}\n" + f"empirical {len(empirical)}\n")
#         print(f"sum diff:  {sum(output - empirical):f}")

#         print(f"\n" + f"idx: generated vs empirical")
#         for i in range(min(len(output), len(empirical))):
#             if empirical[i] != output[i]:
#                 print(f"{i}: {output[i]:.04f} | {empirical[i]:.04f}")

In [None]:
# average 4 at a time
# set step size to last breakpoint
# increment by step size
# round to 4 decimals

def targeted_to_global(wavelengths, spectra, verbose = False, bp_table = None):
    """Resample paired wavelength/spectra samples by windowed averaging.

    Starting at input index 0, the next 4 wavelengths and the next 4 spectra
    values are averaged (each mean rounded to 4 decimals) and stored as one
    output sample. The input index then advances by a step taken from the
    breakpoint table: the table position moves forward while the wavelength
    at the current index is greater than the table's ``"wl"`` entry, and the
    ``"step"`` entry at the reached position is the advance. The loop stops
    when fewer than 4 input samples remain.

    Parameters
    ----------
    wavelengths : sequence of float
        Input wavelengths; indexed and sliced positionally.
    spectra : sequence of float
        Values paired with *wavelengths* by position.
    verbose : bool, default False
        Print a trace line every time the table position advances.
    bp_table : dict, optional
        Mapping with equally long ``"wl"`` and ``"step"`` lists. Defaults to
        the module-level ``breakpoints`` table.

    Returns
    -------
    dict
        ``{"wavelengths": ndarray, "spectra": ndarray}`` with the averaged
        samples.

    Raises
    ------
    ValueError
        If the step selected for a sample is not positive (for example a
        first wavelength that is not greater than the first ``"wl"`` entry,
        whose step is 0); the conversion could never advance past it.
    """
    if bp_table is None:
        bp_table = breakpoints
    bp_wl = bp_table["wl"]
    bp_step = bp_table["step"]

    if verbose: print(f"     wl: step, index change")

    i = 0 # input index
    bp_idx = 0
    out_wl = []
    out_spectra = []
    while i+3 < len(wavelengths):
        # average over next 4
        avg_wl = np.round(np.mean(wavelengths[i:i+4]),4)
        avg_spectra = np.round(np.mean(spectra[i:i+4]),4)

        # move to the next table entry while one exists
        # and the current wl is past the entry's wl
        while bp_idx + 1 < len(bp_step) and wavelengths[i] > bp_wl[bp_idx]:
            bp_idx += 1
            if verbose: print(f"{wavelengths[i]:0.5f}: {bp_step[bp_idx-1]}->{bp_step[bp_idx]}, {bp_idx-1}->{bp_idx}")

        # a step of 0 (or less) would repeat this sample forever
        step = bp_step[bp_idx]
        if step <= 0:
            raise ValueError(
                f"non-positive step {step} selected for wavelength "
                f"{wavelengths[i]} at index {i}; conversion cannot advance"
            )

        # save and increment by values consumed
        out_wl.append(avg_wl)
        out_spectra.append(avg_spectra)
        i += step

    return {"wavelengths": np.array(out_wl), "spectra": np.array(out_spectra)}

In [None]:
# Run the conversion once with verbose=True so every change of step size is
# printed; the returned dict is shown as the cell output.
targeted_to_global(wavelengths, spectra, verbose=True)

In [None]:
### inconsistencies
## round up, not truncating
# mean(0.99497, 1.005, 1.0149, 1.0249) = 1.0099425 -> 1.0099, manual is 1.01 > 1.0099
## round down
# mean(2.3624, 2.3724, 2.3823, 2.3923) = 2.37735 -> 2.3774, manual is 2.3773 < 2.3774

In [None]:
# Toggle for the comparison figure below.
plotting = True
if plotting:

    # one rug-plot row per entry; the key doubles as the row's label
    data = {}
    data["Input Wavelengths"] = wavelengths
    # data["Empirical M3"] = empirical
    data["Algorithm Output"] = targeted_to_global(wavelengths, spectra)["wavelengths"]

    # number of samples before and after conversion
    for series in data.values():
        print(series.shape)

    colors = ['#1f77b4', '#2ca02c']
    fig, axes = plt.subplots(nrows=2, sharex=True, figsize=(10, 4), height_ratios=[1, 1])

    # text styling shared by every row label
    label_kwargs = dict(ha='center', va='top', fontsize=10, weight='bold')

    # each series as a rug plot with its own row
    for ax, label, color in zip(axes, data, colors):
        series = data[label]
        sns.rugplot(x=series, ax=ax, height=0.5, color=color)

        # rows carry no y information: hide ticks and axis label
        ax.set_yticks([])
        ax.set_ylabel("")

        # label placed in axes coordinates (x offset 0.5, y offset 0.8)
        ax.text(0.5, 0.8, label, transform=ax.transAxes, color=color, **label_kwargs)

    # last plot label acts as legend label
    axes[-1].set_xlabel("Wavelength distribution comparison", weight='bold')
    plt.tight_layout()

    # plt.savefig(
    #     "../output files/graphs/test/m3 wavelength comparison.png", 
    #     dpi=2000, 
    #     bbox_inches='tight',
    #     facecolor='white'
    # )

    # display the figure
    plt.show()

In [None]:
### file writing:
# check if output files/data_path exists
# prompt to create
# write file to output_path

def guess_output_path(input_path):
    """Derive the 'output files' folder that sits beside 'input files'.

    The components of *input_path* are scanned from the end towards the
    root; at the first component named exactly ``"input files"`` the path up
    to (not including) that component is returned with ``"output files"``
    appended.

    Parameters
    ----------
    input_path : pathlib.Path
        Path expected to contain an ``"input files"`` component.

    Returns
    -------
    pathlib.Path or None
        The guessed folder, or ``None`` when no component matches.
    """
    components = input_path.parts
    for depth in reversed(range(len(components))):
        if components[depth] != "input files":
            continue
        return Path(*components[:depth]) / "output files"
    return None

def prompt_output_path(input_path):
    """Ask where to create an 'output files' folder and create it.

    After a yes/no confirmation, the folders leading to *input_path* are
    listed with a number each; the user enters the number of the folder that
    should contain the new ``"output files"`` directory.

    Parameters
    ----------
    input_path : pathlib.Path
        Path of the input file (or folder) whose ancestors are offered as
        locations.

    Returns
    -------
    pathlib.Path or None
        The created (or already existing) ``"output files"`` folder, or
        ``None`` when the user declined, exited with ``E``, or the folder
        could not be created.
    """
    print("Folder 'output files' not found, create one (Y/N)? ")
    ans = input()
    while ans.lower() not in ["y", "n", "e"]:
        ans = input()
    if ans.lower() != "y":
        return None

    # Only folders can contain the new directory, so drop a trailing file
    # name; resolve() guarantees at least the root component is listed.
    base = input_path if input_path.is_dir() else input_path.parent
    parts = base.resolve().parts

    print("Enter # to create 'output files' folder in corresponding parent folder.")
    print("[#] folder name")
    for number, part in enumerate(parts):
        print(f"[{number}] {part}")

    # use the number the user typed, re-prompting until it is a listed one
    valid_numbers = {str(number) for number in range(len(parts))}
    ans = input().strip()
    while ans not in valid_numbers:
        if ans.lower() == "e":
            return None
        ans = input().strip()

    # folder [#] is the parent, hence the inclusive slice
    target = Path(*parts[:int(ans) + 1]) / "output files"
    try:
        target.mkdir(exist_ok=True)
    except OSError as err:
        print(f"Could not create '{target}': {err}")
        return None
    return target

def writeFile(data, filename, output_path, mode):
    """Write *data* to ``output_path / filename``.

    NOTE(review): the content is still a placeholder - the literal text
    ``header`` followed by ``data formatted`` once per item of *data*; the
    items themselves are not written yet.

    Parameters
    ----------
    data : iterable
        Items to write; currently only their count affects the output.
    filename : str
        Name of the file inside *output_path*.
    output_path : pathlib.Path
        Existing folder that receives the file.
    mode : str
        ``open()`` mode: ``"x"`` creates a new file, ``"w"`` overwrites.

    Raises
    ------
    FileExistsError
        If *mode* is ``"x"`` and the file already exists.
    """
    # x - new, w - overwrite
    # the context manager closes the file even if a write fails
    with open(output_path / filename, mode) as output:
        output.write("header")
        for _ in data:
            output.write("data formatted")


# Decide where to save, then write the converted data under the input file's
# name. `saving` only becomes True once an existing output folder is known.
output_path = None
filename = data_path.name
saving = False

# NOTE(review): `converted_data` is not assigned anywhere else in the visible
# notebook; it is assumed to be the conversion result - confirm.
converted_data = targeted_to_global(wavelengths, spectra)

if input("Guess output path?").lower() == "y":
    output_path = guess_output_path(data_path)
    # guess_output_path() yields None when the path has no 'input files' part
    if output_path is not None and output_path.exists():
        print(f"Saving to '{output_path}'.")
        saving = True
    else:
        print(f"Path not found, exiting.")
        saving = False
if not saving:
    # fall back to asking the user for a location
    output_path = prompt_output_path(data_path)
    # explicit None check: a Path object is always truthy
    if output_path is not None and output_path.exists():
        print(f"Saving to '{output_path}'.")
        saving = True
    else:
        print(f"Path not recognized, exiting.")
        saving = False
if saving:
    try:
        # mode "x" refuses to clobber an existing file
        writeFile(converted_data, filename, output_path, mode = "x")
    except FileExistsError:
        print(f"File path already exists, overwrite (Y/N)?")
        ans = input()
        while ans.lower() not in ["y", "n"]:
            ans = input()
        if ans.lower() == "y":
            writeFile(converted_data, filename, output_path, mode = "w")