In [None]:
import numpy as np
import graphinglib as gl
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.transforms import ScaledTranslation
from torch.utils.data import DataLoader
from string import ascii_lowercase

from projet.src.tools.smart_figure import SmartFigure
from projet.src.spectrums.spectrum import Spectrum
from projet.src.data_structures.spectrum_dataset import SpectrumDataset
from projet.src.fitters.cnn_fitter import CNNFitter
from projet.src.fitters.res_net_fitter import ResNetFitter
from projet.src.fitters.score import *
from projet.src.tools.utilities import *

# Seed numpy's global RNG so that numpy-based random draws made in the cells below are reproducible.
# NOTE(review): only numpy's global RNG is seeded here; if project code (e.g. dataset generation)
# draws from torch's RNG instead, those draws are not covered — confirm.
np.random.seed(0)

# All spectra figure

In [None]:
filenames = ["single_gaussian", "distinct_gaussians", "distinct_twin_gaussians", "merged_twin_gaussians",
             "pointy_gaussians", "contaminated_gaussians", "two_gaussian_components"]
NOISE_LEVELS = ["smooth", "noisy", "very_noisy"]

# One file per (spectrum type, noise level), grouped by spectrum type: every run of
# len(NOISE_LEVELS) consecutive entries belongs to the same spectrum type.
spectra_filenames = [
    f"projet/data/spectra/{filename}/{noise_level}.txt"
    for filename in filenames
    for noise_level in NOISE_LEVELS
]

spectrums = [Spectrum.load(spectrum_file) for spectrum_file in spectra_filenames]

# Apply a small correction to the two_gaussian_components spectra (the last len(NOISE_LEVELS)
# entries) so that their last two gaussians are shown better.
for corrected_spectrum in spectrums[-len(NOISE_LEVELS):]:
    for component in corrected_spectrum.models[-2:]:
        component.mean = (69, 75)
# --------------------------------------

# Group the spectra by type with plain list slicing. np.array(..., dtype=object) would try to
# unpack the elements if Spectrum objects ever behaved like sequences.
n_noise_levels = len(NOISE_LEVELS)
spectra_by_type = [spectrums[k:k + n_noise_levels] for k in range(0, len(spectrums), n_noise_levels)]

# One row of panels (one panel per noise level) for each spectrum type, titled "a) ...", "b) ", ...
figs = []
for i, (filename, spectrum_group) in enumerate(zip(filenames, spectra_by_type)):
    figs.append(SmartFigure(
        num_rows=1,
        num_cols=len(spectrum_group),
        # Single quotes inside the f-string: nesting the outer quote type is a SyntaxError before Python 3.12
        title=f"{ascii_lowercase[i]}) {filename.replace('_', ' ')}",
        remove_x_ticks=True,
        share_y=True,
        elements=[spectrum.plot for spectrum in spectrum_group],
        reference_labels=False,
        width_padding=0,
        size=(12, 4),
    ))

# Stack the rows vertically; the axis labels are set once, on the outer figure.
large_fig = SmartFigure(
    num_rows=len(figs),
    num_cols=1,
    x_label="Numéro du canal [-]",
    y_label="Intensité [u. arb.]",
    size=(14, 17),
    elements=figs,
    height_padding=0.02,
    reference_labels=True
)
large_fig.show()
# large_fig.save("projet/figures/all_spectra.pdf", dpi=600)

# Scores (MSE and R²) of the CNNFitter and ResNetFitter

In [None]:
filenames = ["single_gaussian", "distinct_gaussians", "distinct_twin_gaussians", "merged_twin_gaussians",
             "pointy_gaussians", "contaminated_gaussians", "two_gaussian_components"]

cnn_fitter_info = pd.read_csv("projet/data/neural_networks/CNNFitter/info.csv")
res_net_fitter_info = pd.read_csv("projet/data/neural_networks/ResNetFitter/info.csv")

# Each info.csv is expected to hold 3 consecutive rows (one per noise level) per spectrum type.
# The two permutations below reorder those rows into the plotting order.
# NOTE(review): both hard-code the row order of info.csv — re-check them if the networks are retrained.
NOISE_ORDER = [1, 2, 0]  # intended column order: smooth, noisy, very noisy (matches the legend labels below)
SPECTRUM_ORDER = [2, 3, 5, 6, 0, 1, 4]  # row order matching `filenames`
for fitter_info in (cnn_fitter_info, res_net_fitter_info):
    # Fancy indexing with SPECTRUM_ORDER would silently drop any extra rows, so check the size up front.
    assert len(fitter_info) == 3 * len(filenames), f"expected {3 * len(filenames)} rows, got {len(fitter_info)}"

# MSE values as arrays of shape (spectrum type, noise level)
cnn_mse_values = cnn_fitter_info["MSE"].values.reshape(-1, 3)[:, NOISE_ORDER][SPECTRUM_ORDER]
res_net_mse_values = res_net_fitter_info["MSE"].values.reshape(-1, 3)[:, NOISE_ORDER][SPECTRUM_ORDER]

x = np.arange(len(cnn_mse_values))  # one group of bars per spectrum type
width = 0.25  # width of each bar (three bars per group)
offsets = [-width, 0, width]
noise_labels = [
    "Spectres non bruités",
    r"Spectres légèrement bruités avec $\sigma=0.4$",
    r"Spectres très bruités avec $\sigma=1$",
]
hatch_style = dict(hatch="//", facecolor="none", edgecolor="black", linewidth=1.5)

fig, ax = plt.subplots(figsize=(10, 6), layout="constrained")

# CNNFitter: filled bars coloured by the default colour cycle, one per noise level
for j, (offset, label) in enumerate(zip(offsets, noise_labels)):
    ax.bar(x + offset, cnn_mse_values[:, j], width, label=label)

# ResNetFitter: hatched, unfilled bars drawn on top of the CNNFitter bars
for j, offset in enumerate(offsets):
    ax.bar(x + offset, res_net_mse_values[:, j], width, **hatch_style)

ax.set_ylabel("MSE [u. arb.]")
ax.set_xticks(x)
ax.set_xticklabels([filename.replace("_", " ") for filename in filenames], rotation=10)
ax.tick_params(axis="both", direction="in")
ax.legend()
plt.show()

In [None]:
filenames = ["single_gaussian", "distinct_gaussians", "distinct_twin_gaussians", "merged_twin_gaussians",
             "pointy_gaussians", "contaminated_gaussians", "two_gaussian_components"]

cnn_fitter_info = pd.read_csv("projet/data/neural_networks/CNNFitter/info.csv")
res_net_fitter_info = pd.read_csv("projet/data/neural_networks/ResNetFitter/info.csv")

# Each info.csv is expected to hold 3 consecutive rows (one per noise level) per spectrum type.
# The two permutations below reorder those rows into the plotting order.
# NOTE(review): both hard-code the row order of info.csv — re-check them if the networks are retrained.
NOISE_ORDER = [1, 2, 0]  # intended column order: smooth, noisy, very noisy (matches the legend labels below)
SPECTRUM_ORDER = [2, 3, 5, 6, 0, 1, 4]  # row order matching `filenames`
for fitter_info in (cnn_fitter_info, res_net_fitter_info):
    # Fancy indexing with SPECTRUM_ORDER would silently drop any extra rows, so check the size up front.
    assert len(fitter_info) == 3 * len(filenames), f"expected {3 * len(filenames)} rows, got {len(fitter_info)}"

# R² values (not MSE) as arrays of shape (spectrum type, noise level)
cnn_r2_values = cnn_fitter_info["R^2"].values.reshape(-1, 3)[:, NOISE_ORDER][SPECTRUM_ORDER]
res_net_r2_values = res_net_fitter_info["R^2"].values.reshape(-1, 3)[:, NOISE_ORDER][SPECTRUM_ORDER]

x = np.arange(len(cnn_r2_values))  # one group of bars per spectrum type
width = 0.25  # width of each bar (three bars per group)
offsets = [-width, 0, width]
noise_labels = [
    "Spectres non bruités",
    r"Spectres légèrement bruités avec $\sigma=0.4$",
    r"Spectres très bruités avec $\sigma=1$",
]
hatch_style = dict(hatch="//", facecolor="none", edgecolor="black", linewidth=1.5)

fig, ax = plt.subplots(figsize=(10, 6), layout="constrained")

# CNNFitter: filled bars coloured by the default colour cycle, one per noise level
for j, (offset, label) in enumerate(zip(offsets, noise_labels)):
    ax.bar(x + offset, cnn_r2_values[:, j], width, label=label)

# ResNetFitter: hatched, unfilled bars drawn on top of the CNNFitter bars
for j, offset in enumerate(offsets):
    ax.bar(x + offset, res_net_r2_values[:, j], width, **hatch_style)

ax.set_ylabel("Coefficient de détermination R$^2$ [-]")
ax.set_xticks(x)
ax.set_xticklabels([filename.replace("_", " ") for filename in filenames], rotation=10)
ax.tick_params(axis="both", direction="in")
ax.legend()
plt.show()

In [None]:
# Combined figure: R² (top, panel a) and MSE (bottom, panel b) of both fitters, saved as a PDF.
# Reuses filenames, cnn_r2_values, res_net_r2_values, cnn_mse_values and res_net_mse_values
# from the two previous cells, so those must be run first.
x = np.arange(len(cnn_r2_values))  # one group of bars per spectrum type
width = 0.25  # width of each bar (three bars per group)
offsets = [-width, 0, width]
noise_labels = [
    "Spectres non bruités",
    r"Spectres légèrement bruités avec $\sigma=0.4$",
    r"Spectres très bruités avec $\sigma=1$",
]
hatch_style = dict(hatch="//", facecolor="none", edgecolor="black", linewidth=1.5)

fig, axes = plt.subplots(2, 1, figsize=(12, 8), layout="constrained")

panels = [
    # (axes, CNNFitter values, ResNetFitter values, y label, whether its bars carry legend labels)
    (axes[0], cnn_r2_values, res_net_r2_values, "Coefficient de détermination R$^2$ [-]", False),
    (axes[1], cnn_mse_values, res_net_mse_values, "MSE [u. arb.]", True),
]
for ax, cnn_values, res_net_values, y_label, labelled in panels:
    # CNNFitter: filled bars. Only one panel is labelled so that fig.legend() lists each noise level once.
    for j, offset in enumerate(offsets):
        label_kwargs = {"label": noise_labels[j]} if labelled else {}
        ax.bar(x + offset, cnn_values[:, j], width, **label_kwargs)
    # ResNetFitter: hatched, unfilled bars drawn on top of the CNNFitter bars
    for j, offset in enumerate(offsets):
        ax.bar(x + offset, res_net_values[:, j], width, **hatch_style)
    ax.set_ylabel(y_label)
    ax.set_xticks(x)
    ax.set_xticklabels([filename.replace("_", " ") for filename in filenames], rotation=10)
    ax.tick_params(axis="y", direction="in")

# Panel letters, offset 5 pt left and 10 pt up from the top-left corner of each axes
for ax, letter in zip(axes, ["a)", "b)"]):
    ax.text(
        0,
        1,
        letter,
        transform=ax.transAxes + ScaledTranslation(-5 / 72, 10 / 72, fig.dpi_scale_trans),
    )

# General legend at the bottom. It is anchored below the figure's lower edge (y < 0 in figure
# coordinates), so bbox_inches="tight" is needed to keep it from being clipped in the saved PDF.
fig.legend(loc="lower center", ncol=3, bbox_to_anchor=(0.5, -0.04), frameon=False)

fig.savefig("projet/figures/dl_mse_r2_comparison.pdf", dpi=600, bbox_inches="tight")
plt.show()

# Fits of the merged_twin_gaussians spectrum with the two deep-learning fitters (CNNFitter and ResNetFitter)

In [None]:
np.random.seed(0)
# NOTE(review): only numpy's global RNG is seeded; if SpectrumDataset.generate_from_spectrum draws from
# torch's RNG, the generated samples (and therefore SELECTED_SAMPLES below) are not reproducible — confirm.

SPEC_FILE = "merged_twin_gaussians/smooth"
N_SAMPLES = 200
SELECTED_SAMPLES = [12, 21, 63]  # indices of the generated samples shown, one panel each, in this order

model_stem = SPEC_FILE.replace('/', '_')
cnn_fitter = CNNFitter.load(f"projet/data/neural_networks/CNNFitter/{model_stem}.pt")
rn_fitter = ResNetFitter.load(f"projet/data/neural_networks/ResNetFitter/{model_stem}.pt")

spec = Spectrum.load(f"projet/data/spectra/{SPEC_FILE}.txt")
dataset = SpectrumDataset.generate_from_spectrum(spec, N_SAMPLES)
data_loader = DataLoader(dataset, batch_size=1)

cnn_fits = cnn_fitter.predict(data_loader)
rn_fits = rn_fitter.predict(data_loader)

data = dataset.data.squeeze(1)  # generated input spectra (not used by the plots below)
x_space = np.linspace(1, spec.number_of_channels, 1000)


def _spectrum_curves(params, label, **style):
    """Build the curves of `spec` evaluated with one parameter set.

    Returns the full-spectrum curve first (the only one carrying `label`, so each parameter
    set is named once), followed by one curve per model component of `spec`. `style` is
    forwarded to every `gl.Curve` (e.g. color, line_style).
    """
    curves = [gl.Curve(x_space, spec(x_space, params), label=label, **style)]
    # presumably params holds one row of parameters per entry of spec.models — confirm
    for model, model_params in zip(spec.models, params):
        curves.append(gl.Curve(x_space, model(x_space, *model_params.numpy()), **style))
    return curves


plots = []
figs = []
for i, (params, cnn_fit, rn_fit) in enumerate(zip(dataset.params, cnn_fits, rn_fits)):
    if i not in SELECTED_SAMPLES:
        continue
    # Legend labels are attached only to the first selected sample's curves, so each appears once.
    labelled = i == SELECTED_SAMPLES[0]
    plottables = [
        *_spectrum_curves(params, "Données réelles" if labelled else None, color="black"),
        *_spectrum_curves(cnn_fit, "Prédictions du CNNFitter" if labelled else None,
                          color="red", line_style=":"),
        *_spectrum_curves(rn_fit, "Prédictions du ResNetFitter" if labelled else None,
                          color="limegreen", line_style=":"),
    ]
    plots.append(plottables)
    fig = gl.Figure()
    fig.add_elements(*plottables)
    figs.append(fig)

# Only the middle panel gets an x-axis label, so the label appears once for the whole row.
figs[len(figs) // 2].x_axis_name = "Numéro du canal [-]"
multifig = gl.MultiFigure.from_row(size=(12, 4), figures=figs)
multifig.y_label = "Intensité [u. arb.]"
multifig.show(general_legend=True, legend_cols=3)
multifig.save("projet/figures/merged_twin_gaussians_fits.pdf", dpi=600)