# Init

In [None]:
import numpy as np
import logging
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from random import randint
import pywt
from scipy.interpolate import interp1d

%cd ..
%load_ext autoreload

import plotly.graph_objects as go

In [None]:
%autoreload 2
from src.MapData import MapData
from src.LineFinder import LineFinder

from src.plotting_functions import plot_embedding, plot_single_variable_map

In [None]:
# Load one map file and prepare its spectra for the analyses in the cells below.
# NOTE(review): the MapData methods are defined in src/MapData.py; the notes
# here are inferred from the method names only — confirm against that module.
map_data = MapData('./data/2022_03_22_P56B_307x532.libsdata')
map_data.get_metadata()
map_data.load_wavelenths()
map_data.load_all_data()
map_data.trim_spectra(64)  # presumably trims 64 samples per spectrum — TODO confirm
map_data.get_map_dimensions()

# Baseline estimation/correction is currently commented out (not applied).
# map_data.get_baseline(min_window_size=50, smooth_window_size=100)
# map_data.baseline_correct()

map_data.upsample_spectra()

# Systemic noise
- [ ] must avoid removing parts of the emission lines
  - using std of the difference spectra finds the emission lines but leaving these regions untreated defeats the purpose of the noise removal
  - the systemic noise is white noise -> it is supposed to oscillate around 0 => use this to detect or correct emission line regions

In [None]:
# Spread of the first differences per wavelength channel. The reduction runs
# over the spectra axis (axis=0) so the result has one entry per column; the
# masking step further down uses it to pick columns of the noise estimate.
# NOTE: this cell needs `diff_spectra`, which is computed in the next cell.
std_diff_spectrum = np.std(diff_spectra, axis=0, keepdims=True)

In [None]:
# First differences of every spectrum along its last (wavelength) axis.
diff_spectra = np.diff(map_data.spectra[:, :])

# Per-channel median of the differences across all spectra: the component
# that repeats in every spectrum, used here as the systemic-noise estimate.
noise_spectrum = np.median(diff_spectra, axis=0, keepdims=True)

# Per-channel spread of the differences. Reduce over the spectra axis (axis=0)
# so there is one value per column; reducing over axis=-1 would give one value
# per spectrum, which cannot be used to mask wavelength channels.
std_diff_spectrum = np.std(diff_spectra, axis=0, keepdims=True)

# NOTE(review): the estimate is halved before use; the reason is not stated
# in the notebook — confirm.
noise_spectrum /= 2

In [None]:
# Build a masked copy of the noise estimate: entries flagged by a high std of
# the difference spectra are set to 0 so they are not subtracted later.
plot_spectrum = noise_spectrum.copy()
# 95th percentile of the std values (q is a list, so the result is an array).
threshold = np.quantile(
    a=std_diff_spectrum,
    q=[.95],
)

# The mask's second-axis indices select the columns to zero, so
# std_diff_spectrum is expected to hold one entry per column of noise_spectrum.
# NOTE(review): confirm the axis used when std_diff_spectrum was computed.
plot_spectrum[:,np.where(std_diff_spectrum > threshold)[1]] = 0

# Overlay the std, the masked estimate and the unmasked estimate.
# The last add_trace call is the cell's final expression.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        y=np.squeeze(std_diff_spectrum)
    )
)
fig.add_trace(
    go.Scatter(
        y=np.squeeze(plot_spectrum)
    )
)
fig.add_trace(
    go.Scatter(
        y=np.squeeze(noise_spectrum)
    )
)

In [None]:
# Subtract the masked noise estimate from every spectrum. The first channel is
# skipped because the difference-based estimate is one sample shorter than the
# spectra. Use `noise_spectrum` here instead to subtract the unmasked estimate.
denoised_data = map_data.spectra[:, 1:] - plot_spectrum

In [None]:
# Compare ten randomly chosen spectra with their denoised versions, together
# with the noise estimate itself.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=map_data.wvl[1:],
        y=np.squeeze(noise_spectrum),
        name='differences'
    )
)

# random.randint includes BOTH endpoints, so the upper bound must be the last
# valid row index; using shape[0] could raise an IndexError.
last_row = map_data.spectra.shape[0] - 1
for ndx in [randint(0, last_row) for _ in range(10)]:
    fig.add_trace(
        go.Scatter(
            x=map_data.wvl,
            y=map_data.spectra[ndx,:],
            name=f'{ndx}'
        )
    )
    # The denoised data has one channel less, hence wvl[1:].
    fig.add_trace(
        go.Scatter(
            x=map_data.wvl[1:],
            y=denoised_data[ndx,:],
            name=f'dn_{ndx}'
        )
    )

fig.show()

In [None]:
# Subtract the unmasked noise estimate from every spectrum. The first channel
# is skipped because the difference-based estimate is one sample shorter.
denoised_data = map_data.spectra[:, 1:] - noise_spectrum

In [None]:
# Compare ten randomly chosen spectra with their denoised versions, together
# with the noise estimate itself.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        x=map_data.wvl[1:],
        y=np.squeeze(noise_spectrum),
        name='differences'
    )
)

# random.randint includes BOTH endpoints, so the upper bound must be the last
# valid row index; using shape[0] could raise an IndexError.
last_row = map_data.spectra.shape[0] - 1
for ndx in [randint(0, last_row) for _ in range(10)]:
    fig.add_trace(
        go.Scatter(
            x=map_data.wvl,
            y=map_data.spectra[ndx,:],
            name=f'{ndx}'
        )
    )
    # The denoised data has one channel less, hence wvl[1:].
    fig.add_trace(
        go.Scatter(
            x=map_data.wvl[1:],
            y=denoised_data[ndx,:],
            name=f'dn_{ndx}'
        )
    )

fig.show()

# Wavelet denoising

In [None]:
def get_threshold(data: np.ndarray):
    """Return the standard deviation of the absolute values of ``data``."""
    return np.std(np.abs(data))


def sigma_clip(
    spectrum: np.ndarray,
    level: int = 3,
    iters: int = 1
):
    """Zero every sample whose magnitude is at or below an iterated threshold.

    The threshold starts at the largest magnitude in ``spectrum`` and is
    re-estimated ``iters`` times as ``level * get_threshold(kept)``, where
    ``kept`` are the samples whose magnitude does not exceed the current
    threshold. Samples above the final threshold are returned unchanged.

    Parameters
    ----------
    spectrum : np.ndarray
        Input values (may be signed). The input is not modified.
    level : int
        Multiplier applied to the spread estimate.
    iters : int
        Number of threshold re-estimations.

    Returns
    -------
    np.ndarray
        Copy of ``spectrum`` with the small-magnitude samples set to 0.
    """
    spectrum = spectrum.copy()
    magnitude = np.abs(spectrum)
    if magnitude.size == 0:
        # Nothing to clip; np.max would raise on an empty array.
        return spectrum

    # Start from the largest MAGNITUDE, not the largest signed value: with
    # np.max(spectrum) a spectrum without positive samples selected nothing
    # in the first pass and the threshold became NaN.
    threshold = np.max(magnitude)
    for _ in range(iters):
        kept = spectrum[magnitude <= threshold]
        if kept.size == 0:
            # No samples left to estimate from; keep the current threshold
            # instead of producing NaN.
            break
        threshold = get_threshold(kept) * level

    spectrum[magnitude <= threshold] = 0
    return spectrum

In [None]:
from typing import Callable

In [None]:
def _denoise_spectrum(
    spectrum,
    wavelet,
    threshold,
    level=11
):
    """Denoise one spectrum by soft-thresholding its stationary wavelet transform.

    Parameters
    ----------
    spectrum
        1-D signal handed to ``pywt.swt``. PyWavelets requires its length to
        be a multiple of ``2**level``.
    wavelet
        Wavelet passed unchanged to ``pywt.swt`` and ``pywt.iswt``.
    threshold
        Either the threshold value itself or a callable that is called with
        ``spectrum`` and returns the threshold value.
    level : int, optional
        Number of decomposition levels (default 11, the value that used to be
        hard-coded).

    Returns
    -------
    The reconstruction returned by ``pywt.iswt`` from the thresholded
    coefficients.
    """
    decomposition = pywt.swt(
        spectrum,
        wavelet=wavelet,
        level=level,
        start_level=0,
        trim_approx=False
    )

    # The builtin callable() avoids depending on `typing.Callable`, which is
    # imported in a different notebook cell.
    if callable(threshold):
        threshold = threshold(spectrum)

    def _soft(coefficients):
        # Shrink coefficients towards zero by `threshold`.
        return pywt.threshold(
            data=coefficients,
            value=threshold,
            mode='soft',
            substitute=0
        )

    # Both members of every (approximation, detail) pair are thresholded.
    # NOTE(review): thresholding the approximation coefficients as well is
    # unusual for wavelet denoising — confirm this is intended.
    thresholded_decomposition = [
        (_soft(approx), _soft(detail))
        for approx, detail in decomposition
    ]

    return pywt.iswt(
        thresholded_decomposition,
        wavelet=wavelet
    )


# x[0],
# x[1]
# sigma_clip(spectrum=x[0],level=3, iters=2),
# sigma_clip(spectrum=x[1],level=3, iters=2)

In [None]:
# Denoise the first 1000 spectra one row at a time with a fixed threshold.
# NOTE: `wavelet` is assigned in a later cell, which has to be run first.
arr = np.apply_along_axis(
    _denoise_spectrum,
    1,
    map_data.spectra[:1000, :],
    wavelet=wavelet,
    threshold=35,
)

In [225]:
# For ten random denoised spectra, plot the input, the wavelet reconstruction
# and their difference.
fig = go.Figure()
for _ in range(10):
    # random.randint includes both endpoints; bound it by the last valid row
    # of `arr` (randint(0, 1000) could index row 1000 of a 1000-row array).
    ndx = randint(0, arr.shape[0] - 1)
    fig.add_trace(
        go.Scatter(
            y=map_data.spectra[ndx,:],
            name=f'spectrum {ndx}'
        )
    )
    fig.add_trace(
        go.Scatter(
            y=arr[ndx,:],
            name=f'reconstruction {ndx}'
        )
    )
    fig.add_trace(
        go.Scatter(
            y=map_data.spectra[ndx,:] - arr[ndx,:],
            name='removed noise'
        )
    )

fig.show()

In [None]:
# Custom wavelet built from an explicit filter bank.
# NOTE(review): PyWavelets documents the filter_bank order as
# (dec_lo, dec_hi, rec_lo, rec_hi) — confirm the four arrays follow it.
b3_wavelet = pywt.Wavelet(    
    'b3',
    filter_bank=(
        np.array([0,0,1,4,6,4,1,0,0])/16, # low-pass h(z)
        np.array([-1,-8,-28,-56,186,-56,-28,-8,-1])/256, # high-pass g(z)
        np.array([0,0,1,4,6,4,1,0,0])/16,
        np.array([0,0,0,0,1,0,0,0,0]),
        ######################################
        # Alternative two-tap filter bank, currently disabled:
        # np.array([1,1])/2, # low-pass h(z)
        # np.array([-1,1])/2, # high-pass g(z)
        # np.array([1,1])/2, #
        # np.array([-1,1])/2, #
    )
)

# Choose the wavelet used by the cells below; index 2 selects 'rbio6.8'.
wavelet = [b3_wavelet,pywt.Wavelet('bior6.8'),pywt.Wavelet('rbio6.8')][2]

# Pick one spectrum at random. random.randint includes both endpoints, so the
# upper bound must be the last valid row index (shape[0] itself is out of range).
data = map_data.spectra[
    randint(0, map_data.spectra.shape[0] - 1),
    :
]

# Denoise it with the selected wavelet and a fixed threshold of 35.
wavelet_reconstruction = _denoise_spectrum(
    data,
    wavelet,
    35
)

# Plot the selected spectrum, its wavelet reconstruction and the difference
# between the two. The last add_trace call is the cell's final expression.
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        y=data,
        name='spectrum'
    )
)
fig.add_trace(
    go.Scatter(
        y=wavelet_reconstruction,
        name='reconstruction'
    )
)
# Residual: what the denoising removed from the spectrum.
fig.add_trace(
    go.Scatter(
        y=data - wavelet_reconstruction,
        name='removed noise'
    )
)

In [None]:
# Histogram of the residual (spectrum minus its wavelet reconstruction).
plt.hist(data - wavelet_reconstruction)

In [None]:
# Decompose a spectrum and try to rebuild it by convolving each coefficient
# array with the wavelet's filter_bank[2] / filter_bank[3] and summing.
# NOTE(review): `interpolated_data` and `thresholded_decomposition` are not
# assigned at top level in any visible cell (the latter is only a local of
# _denoise_spectrum), so this cell relies on leftover kernel state — confirm.
wavelet_docomposition = pywt.swt(
    interpolated_data, 
    wavelet=wavelet,    
    level=11,
    start_level=0,
    axis=-1, 
    trim_approx=False,
    norm=False
)

# Accumulator with the same shape and dtype as one coefficient array.
manual_reconstruction = np.zeros_like(thresholded_decomposition[0][0])

# Each `scale` is an (approximation, detail) pair.
for scale in thresholded_decomposition:    
    manual_reconstruction += np.convolve(scale[0],wavelet.filter_bank[2],mode='same')
    manual_reconstruction += np.convolve(scale[1],wavelet.filter_bank[3],mode='same')    

# Overlay the input and the manual reconstruction.
fig = go.Figure()

fig.add_trace(
    go.Scatter(
        y=interpolated_data,
        name='spectrum'
    )
)
fig.add_trace(
    go.Scatter(
        y=manual_reconstruction,
        name='reconstruction'
    )
)

In [None]:
# Plot every (approximation, detail) pair of the decomposition together with
# the spectrum stored in `data`.
# NOTE(review): `level` is the position in the list returned by pywt.swt, which
# PyWavelets documents as ordered from the deepest level to level 1 — so
# 'detail 0' would be the coarsest scale; confirm the labels are as intended.
fig = go.Figure()

for level,coefs in enumerate(wavelet_docomposition):
    # coefs[1] is plotted as the detail, coefs[0] as the approximation.
    fig.add_trace(
        go.Scatter(
            y=coefs[1],
            name=f'detail {level}'
        )
    )
    fig.add_trace(
        go.Scatter(
            y=coefs[0],
            name=f'appr. {level}'
        )
    )
fig.add_trace(
    go.Scatter(
        y=data,
        name='spectrum'
    )
)

## "Manual" wavelet decomposition

In [None]:
# Reference result: 9-level stationary wavelet transform from PyWavelets, used
# to check the hand-written decomposition in the cells below.
wavelet_docomposition = pywt.swt(
    interpolated_data, wavelet=wavelet, level=9, start_level=0, trim_approx=False
)

In [None]:
def extend_wavelet(wavelet, level):
    """Return the filter taps of ``wavelet`` with ``level`` zeros after each tap.

    The result is a float array of length ``len(wavelet) * (level + 1)``;
    tap ``i`` sits at index ``i * (level + 1)``.

    NOTE(review): a stationary wavelet transform usually dilates the filter by
    ``2**level`` rather than ``level + 1`` — confirm which spacing is intended.
    """
    stride = level + 1
    # One row per tap: the tap goes in column 0, the padding zeros follow it.
    padded = np.zeros((len(wavelet), stride))
    padded[:, 0] = wavelet
    return padded.reshape(-1)

In [None]:
# NOTE(review): `previous_approx` is not assigned in any visible cell; this
# looks like a leftover debugging expression — confirm it is still needed.
previous_approx.shape

In [None]:
# Hand-written multi-level decomposition: at each level the current signal is
# convolved with the zero-extended filter_bank[0] and filter_bank[1] filters,
# and the first result is fed into the next level.
manual_wavelet_decomposition = []
previous_detail = 0  # not used anywhere else in this cell
filtered_data = interpolated_data
for level in range(9):
    # Store the pair (filter_bank[0] output, filter_bank[1] output).
    manual_wavelet_decomposition.append((
        np.convolve(
            filtered_data,
            extend_wavelet(wavelet=wavelet.filter_bank[0],level=level),
            mode='same'
        ),
        np.convolve(
            filtered_data,
            extend_wavelet(wavelet=wavelet.filter_bank[1],level=level),
            mode='same'
        )
    ))    
    # Input of the next level: this level's first output without its first
    # `level` samples, so the signal gets shorter at every level.
    # NOTE(review): presumably compensates the shift introduced by the
    # zero-extended filters — confirm.
    filtered_data = manual_wavelet_decomposition[-1][0][level:]

In [None]:
level = 3

# Coefficient pairs to compare at the chosen level. The built-in result is
# indexed from the end of the list returned by pywt.swt.
manual_pair = manual_wavelet_decomposition[level]
builtin_pair = wavelet_docomposition[-(1 + level)]

fig = go.Figure()
for ndx, detail in enumerate(['appr.', 'detail']):
    # The manual coefficients are plotted without their first level + 1 samples.
    manual_trace = go.Scatter(
        y=manual_pair[ndx][(level + 1):],
        name=f'manual {detail}',
    )
    builtin_trace = go.Scatter(
        y=builtin_pair[ndx],
        name=f'built-in {detail}',
    )
    fig.add_trace(manual_trace)
    fig.add_trace(builtin_trace)

fig.show()
