# Init

In [None]:
import numpy as np
import logging
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from random import randint
%cd ..
%load_ext autoreload

import plotly.graph_objects as go

In [None]:
%autoreload 2
from src.MapData import MapData
from src.LineFinder import LineFinder

from src.plotting_functions import plot_embedding, plot_single_variable_map

In [None]:
# Load the LIBS map used throughout the notebook. The path is relative to the
# repository root because the first cell runs `%cd ..`.
map_data = MapData('./data/2022_03_22_P56B_307x532.libsdata')
map_data.get_metadata()
# NOTE(review): "wavelenths" (sic) must match the method name in src.MapData.
map_data.load_wavelenths()
map_data.load_all_data()
# presumably drops 64 channels from the spectra — TODO confirm in src.MapData
map_data.trim_spectra(64)
map_data.get_map_dimensions()

# Systemic noise
- [ ] must avoid removing parts of the emission lines
  - using the std of the difference spectra finds the emission lines, but leaving these regions untreated defeats the purpose of the noise removal
  - the systemic noise is white noise -> it is supposed to oscillate around 0 => use this to detect or correct emission line regions

In [165]:
# First differences along the last (wavelength) axis of every spectrum.
diff_spectra = np.diff(map_data.spectra[:, :])

# Per-channel spread of the differences; large values mark channels that vary
# strongly across the map (the emission lines, per the notes above).
std_diff_spectrum = np.std(diff_spectra, axis=0, keepdims=True)

# Per-channel median difference, halved, used as the systemic noise estimate.
noise_spectrum = np.median(diff_spectra, axis=0, keepdims=True) / 2

# 95th percentile of the per-channel spread.
threshold = np.quantile(a=std_diff_spectrum, q=[.95])

# Masked copy of the noise estimate: channels above the threshold are zeroed
# so that subtracting it leaves those channels untouched.
high_variance_columns = np.where(std_diff_spectrum > threshold)[1]
plot_spectrum = noise_spectrum.copy()
plot_spectrum[:, high_variance_columns] = 0

fig = go.Figure()
for trace_values in (std_diff_spectrum, plot_spectrum):
    fig.add_trace(go.Scatter(y=np.squeeze(trace_values)))
# Kept as the final expression so the notebook renders the returned figure.
fig.add_trace(go.Scatter(y=np.squeeze(noise_spectrum)))

In [166]:
# Subtract the masked noise estimate from every spectrum. np.diff shortened
# the spectral axis by one, so the first channel is dropped to align shapes.
# (Use `noise_spectrum` instead of `plot_spectrum` for the unmasked estimate.)
denoised_data = map_data.spectra[:, 1:] - plot_spectrum

In [167]:
# Overlay the noise estimate with ten randomly chosen spectra, each drawn
# before ('<row>') and after ('dn_<row>') the noise subtraction.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=map_data.wvl[1:],
        y=np.squeeze(noise_spectrum),
        name='differences'
    )
)
# randint includes BOTH end points, so the upper bound has to be the last
# valid row index; passing shape[0] could raise IndexError.
for ndx in [randint(0, map_data.spectra.shape[0] - 1) for _ in range(10)]:
    fig.add_trace(
        go.Scatter(
            x=map_data.wvl,
            y=map_data.spectra[ndx, :],
            name=f'{ndx}'
        )
    )
    # denoised_data is one channel shorter than the raw spectra, so it is
    # plotted against wvl[1:].
    fig.add_trace(
        go.Scatter(
            x=map_data.wvl[1:],
            y=denoised_data[ndx, :],
            name=f'dn_{ndx}'
        )
    )

fig.show()

## !baseline correction

In [None]:
# Estimate the baseline and remove it from the spectra.
# NOTE(review): the units of both window sizes (channels vs. nm) are not
# visible from this notebook — confirm in src.MapData.
map_data.get_baseline(min_window_size=50, smooth_window_size=100)
# presumably updates map_data.spectra in place (the next cell reads it again)
# — TODO confirm
map_data.baseline_correct()

In [None]:
# Subtract the full (unmasked) noise estimate from the spectra as they are
# after the baseline step. The first channel is dropped because the
# difference-based noise estimate is one channel shorter than the spectra.
denoised_data = map_data.spectra[:, 1:] - noise_spectrum

In [None]:
# Overlay the noise estimate with ten randomly chosen spectra, each drawn
# before ('<row>') and after ('dn_<row>') the noise subtraction.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=map_data.wvl[1:],
        y=np.squeeze(noise_spectrum),
        name='differences'
    )
)
# randint includes BOTH end points, so the upper bound has to be the last
# valid row index; passing shape[0] could raise IndexError.
for ndx in [randint(0, map_data.spectra.shape[0] - 1) for _ in range(10)]:
    fig.add_trace(
        go.Scatter(
            x=map_data.wvl,
            y=map_data.spectra[ndx, :],
            name=f'{ndx}'
        )
    )
    # denoised_data is one channel shorter than the raw spectra, so it is
    # plotted against wvl[1:].
    fig.add_trace(
        go.Scatter(
            x=map_data.wvl[1:],
            y=denoised_data[ndx, :],
            name=f'dn_{ndx}'
        )
    )

fig.show()

# Wavelet denoising

In [None]:
# The wavelet denoiser exported by skimage.restoration is `denoise_wavelet`.
from skimage.restoration import denoise_wavelet
import pywt
from scipy.interpolate import interp1d

In [170]:
# Single spectrum (row 1521 of the map) used as the test signal for the
# wavelet experiments below.
data = map_data.spectra[1521,:]
# Alternative: use the noise-subtracted spectrum. Note that denoised_data is
# one channel shorter than map_data.wvl, so the wavelength axis would have to
# be wvl[1:] wherever `data` is paired with it.
# data = denoised_data[1521,:]

In [171]:
# Evenly spaced wavelength grid spanning the measured range, with the number
# of points rounded up to the next power of two (presumably so the stationary
# wavelet transform used later accepts the signal length — TODO confirm).
n_channels = map_data.spectra.shape[1]
new_wvl = np.linspace(
    map_data.wvl[0],
    map_data.wvl[-1],
    int(2 ** np.ceil(np.log2(n_channels))),
)

# interp1d interpolates linearly by default.
interpolator = interp1d(map_data.wvl, data)
interpolated_data = interpolator(new_wvl)

# Compare the measured spectrum with its resampled version.
fig = go.Figure()
fig.add_trace(go.Scatter(x=map_data.wvl, y=data, name='initial'))
# Kept as the final expression so the notebook renders the returned figure.
fig.add_trace(go.Scatter(x=new_wvl, y=interpolated_data, name='interpolated'))

In [142]:
def get_threshold(data: np.ndarray) -> np.floating:
    """Return the standard deviation of the absolute values of ``data``."""
    return np.std(np.abs(data))


def sigma_clip(
    spectrum: np.ndarray,
    level: int = 3,
    iters: int = 1
) -> np.ndarray:
    """Zero every value whose magnitude does not exceed a sigma threshold.

    The threshold is ``level * get_threshold(selection)``. The selection
    starts as the whole spectrum and, on each further iteration, is narrowed
    to the values at or below the previous threshold.

    Args:
        spectrum: Values to clip; the input array is not modified.
        level: Multiplier applied to the estimated sigma.
        iters: Number of threshold refinements. With ``0`` the threshold is
            the largest magnitude, so every value is zeroed.

    Returns:
        A copy of ``spectrum`` in which every value with
        ``abs(value) <= threshold`` is set to 0; larger values are kept.
    """
    spectrum = spectrum.copy()
    if spectrum.size == 0:
        # Nothing to clip; np.max would raise on an empty array.
        return spectrum

    magnitudes = np.abs(spectrum)
    # Start from the largest magnitude so the first pass sees all values,
    # including negative ones (e.g. wavelet coefficients).
    threshold = np.max(magnitudes)
    for _ in range(iters):
        selection = magnitudes[magnitudes <= threshold]
        if selection.size == 0:
            # Nothing lies under the current threshold: keep it instead of
            # turning it into NaN via the std of an empty selection.
            break
        threshold = get_threshold(selection) * level

    spectrum[magnitudes <= threshold] = 0
    return spectrum

In [179]:
def upscale_wvl(wvl):
    """Return an evenly spaced axis over the range of ``wvl``.

    The axis runs from the first to the last element of ``wvl`` and its
    number of points is ``len(wvl)`` rounded up to the next power of two.
    """
    first, last = wvl[0], wvl[-1]
    exponent = np.ceil(np.log2(len(wvl)))
    n_points = int(2 ** exponent)
    return np.linspace(first, last, n_points)

def upscale_spectrum(
    spectrum: np.array,
    wvl: np.array,
    new_wvl: np.array = None
):
    """Resample ``spectrum`` from the ``wvl`` axis onto ``new_wvl``.

    Uses scipy's ``interp1d`` with its default (linear) interpolation. When
    ``new_wvl`` is not given, the target axis is ``upscale_wvl(wvl)``.
    Returns the interpolated values at ``new_wvl``.
    """
    target_wvl = upscale_wvl(wvl) if new_wvl is None else new_wvl
    return interp1d(wvl, spectrum)(target_wvl)

In [181]:
# Resample every row of the map onto the power-of-two wavelength axis. The
# target axis is built once and shared by all rows. The result is only
# displayed as the cell output; it is not stored.
np.apply_along_axis(
    upscale_spectrum,
    1,
    map_data.spectra,
    wvl=map_data.wvl,
    new_wvl=upscale_wvl(map_data.wvl),
)


array([[15.2291779 , 25.85169534, 18.4449409 , ..., 35.61127087,
        17.10118788, 50.66686268],
       [13.44127819, 20.1104539 ,  1.92579483, ..., 21.16789068,
        -3.72945542, 37.07516882],
       [25.1948179 , 31.701413  , 15.30906603, ..., 38.95791386,
        19.36426323, 50.86193654],
       ...,
       [46.24481766, 69.7902253 , 58.98000519, ..., 27.59865435,
        15.76054001, 51.54904458],
       [22.53051127, 31.70083271,  5.54555942, ..., 23.64800103,
        10.14217013, 60.66219649],
       [28.09228903, 37.18009677, 29.69574157, ..., 40.73890359,
        30.77121006, 58.97030241]])

In [176]:
# Custom wavelet defined by an explicit filter bank. PyWavelets expects the
# order (dec_lo, dec_hi, rec_lo, rec_hi).
b3_wavelet = pywt.Wavelet(
    'b3',
    filter_bank=(
        np.array([0,0,1,4,6,4,1,0,0])/16, # low-pass h(z)
        np.array([-1,-8,-28,-56,186,-56,-28,-8,-1])/256, # high-pass g(z)
        np.array([0,0,1,4,6,4,1,0,0])/16,
        np.array([0,0,0,0,1,0,0,0,0]),
        ######################################
        # np.array([1,1])/2, # low-pass h(z)
        # np.array([-1,1])/2, # high-pass g(z)
        # np.array([1,1])/2, #
        # np.array([-1,1])/2, #
    )
)

# Candidate wavelets; the trailing index picks the one used ('bior6.8').
wavelet = [b3_wavelet,pywt.Wavelet('bior6.8'),pywt.Wavelet('rbio6.8')][1]

# Optional inspection of the selected filter bank:
# fig = go.Figure()
# for ndx,filter_coefs in enumerate(wavelet.filter_bank):
# # for ndx,filter_coefs in enumerate(pywt.Wavelet('db8').filter_bank):
#     fig.add_trace(
#         go.Scatter(
#             y=filter_coefs,
#             name=f'{ndx}'
#         )
#     )

# fig.show()

# Stationary (undecimated) wavelet transform of the resampled spectrum.
# NOTE(review): pywt.swt needs the signal length to be divisible by
# 2**level, so level=11 requires a multiple of 2048 samples.
wavelet_docomposition = pywt.swt(
    interpolated_data,
    wavelet=wavelet,
    level=11,
    start_level=0,
    # axis=-1,
    trim_approx=False,
    # norm=True
)

# Not referenced in this cell; kept in case other cells rely on it.
VALUE = 35

# The soft threshold is the std of the systemic noise estimate. It does not
# depend on the level, so it is computed once instead of twice per level.
noise_threshold = np.std(noise_spectrum)

# Soft-threshold both the approximation and the detail coefficients of every
# level. Alternatives tried earlier: pass the coefficients through unchanged,
# or use sigma_clip(spectrum=..., level=3, iters=3) on each of them.
thresholded_decomposition = [
    tuple(
        pywt.threshold(
            data=coefs,
            value=noise_threshold,
            mode='soft',
            substitute=0
        )
        for coefs in level_coefs
    )
    for level_coefs in wavelet_docomposition
]

wavelet_reconstruction = pywt.iswt(
    thresholded_decomposition,
    wavelet=wavelet,
    # norm=True
)

# Compare the resampled spectrum with its thresholded reconstruction.
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        y=interpolated_data,
        name='spectrum'
    )
)
# Kept as the final expression so the notebook renders the returned figure.
fig.add_trace(
    go.Scatter(
        y=wavelet_reconstruction,
        name='reconstruction'
    )
)

In [None]:
# Display the reconstructed signal as the cell output.
wavelet_reconstruction

In [156]:
# Recompute the stationary wavelet decomposition with `axis` and `norm`
# spelled out explicitly. It is not used further in this cell.
wavelet_docomposition = pywt.swt(
    interpolated_data, 
    wavelet=wavelet,    
    level=11,
    start_level=0,
    axis=-1, 
    trim_approx=False,
    norm=False
)

# Accumulator with the same shape and dtype as one coefficient array.
# NOTE: the loop below reads `thresholded_decomposition` from the earlier
# cell, not the `wavelet_docomposition` computed just above.
manual_reconstruction = np.zeros_like(thresholded_decomposition[0][0])

# Hand-rolled reconstruction experiment: every level's approximation
# (scale[0]) and detail (scale[1]) coefficients are convolved with
# filter_bank[2] and filter_bank[3] (rec_lo and rec_hi in PyWavelets' order)
# and summed over all levels.
# NOTE(review): the same, non-upsampled filters are applied at every level —
# confirm this matches the intended reconstruction scheme.
for scale in thresholded_decomposition:    
    manual_reconstruction += np.convolve(scale[0],wavelet.filter_bank[2],mode='same')
    manual_reconstruction += np.convolve(scale[1],wavelet.filter_bank[3],mode='same')    

# Compare the resampled spectrum with the manual reconstruction.
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        y=interpolated_data,
        name='spectrum'
    )
)
# Final expression of the cell; the returned figure is what gets displayed.
fig.add_trace(
    go.Scatter(
        y=manual_reconstruction,
        name='reconstruction'
    )
)

In [None]:
# Plot the detail and approximation coefficients of every entry of the
# decomposition, together with the original (non-resampled) spectrum.
# NOTE(review): `level` is the position in the list returned by pywt.swt,
# which per the PyWavelets docs starts with the coarsest level — confirm
# before reading the labels as scale numbers.
fig = go.Figure()

for level, coefs in enumerate(wavelet_docomposition):
    detail_trace = go.Scatter(y=coefs[1], name=f'detail {level}')
    approx_trace = go.Scatter(y=coefs[0], name=f'appr. {level}')
    fig.add_trace(detail_trace)
    fig.add_trace(approx_trace)

# Kept as the final expression so the notebook renders the returned figure.
fig.add_trace(go.Scatter(y=data, name='spectrum'))