In [2]:
from speclearn.io.transform.rebin import *
from speclearn.io.data.m3_data import remove_outliers_spectral_line, remove_zeros_and_ones, M3_image
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Granule to inspect. Both paths below are built from this one ID so the L2
# image and its rebinned counterpart cannot silently refer to different granules.
GRANULE_ID = 'M3G20090202T111347'
# NOTE(review): machine-specific absolute paths — adjust for your environment.
M3_L2_DIR = '/media/freya/data_1/data/m3/CH1M3_0004/DATA/20081118_20090214/200902/L2'
REBIN_DIR = '/media/freya/rebin/M3/rebin/nlong7200_nlat3600/spec'

m3_img = M3_image(f'{M3_L2_DIR}/{GRANULE_ID}_V01_RFL.HDR', level=2)
m3_img.fill()
m3_img.load_data()
rebin_img, rebin_img_bins = load_file(f'{REBIN_DIR}/{GRANULE_ID}.npz')


def remove_zeros_and_ones(img, return_zeros=False, threshold=0.01, max_value=1.2):
    """Blank low-signal and over-range spectra in a 3-D spectral cube.

    NOTE(review): this redefinition shadows ``remove_zeros_and_ones`` imported
    from ``speclearn.io.data.m3_data`` at the top of the notebook for every
    later cell; confirm which version is intended.

    Each pixel's spectrum is ``img[i, j, :]`` (axis 2 is the reduced axis). The
    whole spectrum is set to NaN when any of these hold, all computed while
    ignoring NaN bands:

    * mean over axis 2 is below ``threshold`` ("zeros"),
    * standard deviation over axis 2 is below ``threshold`` (flat spectra),
    * maximum over axis 2 exceeds ``max_value`` (values larger than 1, with
      some wiggle room).

    Parameters
    ----------
    img : np.ndarray
        3-D array of a float dtype (NaN is written into copies of it).
        The input itself is not modified.
    return_zeros : bool, optional
        If True, also return the spectra rejected by the mean/std test.
    threshold : float, optional
        Cut-off for the mean and std tests. Default 0.01.
    max_value : float, optional
        Upper cut-off for the per-spectrum maximum. Default 1.2.

    Returns
    -------
    img_rm : np.ndarray
        Copy of ``img`` with every rejected spectrum set to NaN.
    img_zeros : np.ndarray or None
        When ``return_zeros`` is True, a copy of ``img`` that keeps exactly the
        spectra rejected by the mean/std test (all other pixels NaN). This is
        the complement of that test, so every spectrum it removed from
        ``img_rm`` appears here. Otherwise ``None``.
    """
    import warnings

    # Pixels that are NaN in every band make the nan-reductions emit
    # "empty slice" RuntimeWarnings. Their result is NaN, which compares False
    # below, so such pixels are simply left as NaN and the warnings are noise.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', category=RuntimeWarning)
        mean = np.nanmean(img, axis=2)
        std = np.nanstd(img, axis=2)
        peak = np.nanmax(img, axis=2)
        low_signal = (mean < threshold) | (std < threshold)
        too_bright = peak > max_value

    img_rm = img.copy()
    # A 2-D (i, j) boolean mask indexes the first two axes, blanking all bands.
    img_rm[low_signal | too_bright] = np.nan

    img_zeros = None
    if return_zeros:
        img_zeros = img.copy()
        # Keep only what the mean/std test removed, using the same mask as above.
        img_zeros[~low_signal] = np.nan
    return img_rm, img_zeros


# Stage 1: blank near-zero, flat and over-bright (> 1.2) spectra, and also
# return the low-signal spectra (`img_zeros`) for the "Removed zeros" panel.
img_zeros_rm, img_zeros = remove_zeros_and_ones(
    rebin_img,
    return_zeros=True,
)

# Stage 2: outlier rejection on the stage-1 cube.
# NOTE(review): `remove_outliers_spectral_line` comes from speclearn; with
# return_outliers=True it presumably returns (cleaned, outliers) — confirm.
img_out_rm, img_out = remove_outliers_spectral_line(
    img_zeros_rm,
    return_outliers=True,
    threshold=5.,
)

In [None]:
# 2 x 3 grid. The top row shows the raw and progressively cleaned spectra, and
# the bottom row shows the removed spectra. The bottom-left slot is deleted
# further down.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(9, 5))
# NOTE: plt.tight_layout() at the end of this cell recomputes subplot spacing,
# so this wspace value is effectively overridden.
fig.subplots_adjust(wspace=0.2)

# Spatial window (rows i, columns j) whose spectra plot_data draws.
i_min, i_max = 0, 1000
j_min, j_max = 0, 200

def plot_data(ax, img, color, title):
    """Draw every spectrum in the i/j window of ``img`` that is not all NaN.

    The window is ``img[i_min:i_max, j_min:j_max, :]``, using the module-level
    bounds set earlier in this cell. Each pixel's last-axis vector is drawn as
    one translucent line against ``GLOBAL_WAVELENGTH``, which is presumably
    supplied by the ``speclearn.io.transform.rebin`` star import (confirm).
    Pixels whose spectrum is NaN in every band are skipped. If the image is
    smaller than the window, the window is clipped to the image instead of
    raising IndexError.

    Parameters
    ----------
    ax : matplotlib Axes
        Panel to draw on.
    img : np.ndarray
        3-D array indexed ``[i, j, band]``.
    color : str
        Matplotlib colour used for every line.
    title : str
        Panel title.
    """
    window = img[i_min:i_max, j_min:j_max, :]
    # Flatten to (n_pixels, n_bands). C-order keeps the row-by-row drawing
    # order (i outer, j inner).
    spectra = window.reshape(-1, window.shape[-1])
    spectra = spectra[~np.all(np.isnan(spectra), axis=1)]
    # A single vectorised plot call replaces one errorbar() call per pixel,
    # which could mean up to (i_max - i_min) * (j_max - j_min) Python calls.
    if len(spectra):
        # ax.plot draws one line per column, so the bands go on axis 0.
        ax.plot(GLOBAL_WAVELENGTH, spectra.T, alpha=0.15, color=color)
    ax.set_title(title)
    ax.set_ylim(-0.1, 0.7)

# Top row: raw data and what survives each cleaning stage.
# Bottom row: what each stage removed (the bottom-left slot is unused).
plot_data(axes[0, 0], rebin_img, 'C0', 'Image data')
plot_data(axes[0, 1], img_zeros_rm, 'C1', 'Data after removed zeros')
plot_data(axes[0, 2], img_out_rm, 'C2', 'Data after removed zeros & outliers')
plot_data(axes[1, 1], img_zeros, 'C1', 'Removed zeros')
plot_data(axes[1, 2], img_out, 'C2', 'Removed outliers')

# The bottom row uses a tighter y range than the (-0.1, 0.7) set by plot_data.
axes[1, 1].set_ylim(-0.1, 0.2)
axes[1, 2].set_ylim(-0.1, 0.2)

# Drop the unused bottom-left slot.
axes[1, 0].remove()

# Link x AND y limits within each row to that row's first visible panel.
# Only the other panels are linked; sharing an axes with itself does nothing.
for ax in axes[0, 1:]:
    ax.sharex(axes[0, 0])
    ax.sharey(axes[0, 0])
for ax in axes[1, 2:]:
    ax.sharex(axes[1, 1])
    ax.sharey(axes[1, 1])

# Axis labels go only on the outer panels (bottom row / first visible column).
# The other panels never had a label set, so nothing needs clearing.
axes[1, 1].set_xlabel('Wavelength [nm]')
axes[1, 2].set_xlabel('Wavelength [nm]')
axes[0, 0].set_ylabel('Reflectance')
axes[1, 1].set_ylabel('Reflectance')

plt.tight_layout()
plt.show()