# Noise2Noise training on L2-normalized spectra

## Loading libraries and configuration parameters

### Loading libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import json

In [None]:
from ddae1d.paths import PROJECT_ROOT

### Importing locally installed denoising autoencoder

In [None]:
from ddae1d.model import DiamondDAE1D

### Importing parameters from config file

In [None]:
# Load model and training hyperparameters from config.json (resolved relative to
# the current working directory). JSON text is UTF-8 by specification, so decode
# it explicitly instead of relying on the platform's locale-dependent default.
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)

model_params = config["model_params"]  # keyword arguments for DiamondDAE1D(...)
training_params = config["training_params"]  # keyword arguments for model.train_noise2noise(...)
model_filename = config["model_filename"]  # saved under PROJECT_ROOT / "models"

## Load data

In [None]:
# Noisy training spectra from the preprocessed training set.
noisy = np.load(PROJECT_ROOT.joinpath("data", "preprocessed", "trainset", "noisy.npy"))

In [None]:
print(f"Noisy data shape: {noisy.shape}")  # (num_points, num_reps, num_samples)

## L2 normalization of the data before training

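To keep a consistent scale across all repetitions of a point, every spectrum of that point is divided by the L2 norm of the point's last noisy repetition. If that norm is zero, it is replaced by 1 to avoid division by zero, so the point is left unscaled. After prediction, the denoised spectra are multiplied by the same norms to restore the original scale.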

In [None]:
# One scale factor per point: the L2 norm of that point's LAST repetition,
# taken over the spectral (last) axis. Slicing with `-1:` keeps the repetition
# axis, so with keepdims the result has shape (num_points, 1, 1) and broadcasts
# across every repetition of the point.
l2_norms = np.linalg.norm(noisy[:, -1:, :], ord=2, axis=-1, keepdims=True)

# A zero norm would divide by zero; substitute 1 so such points stay unscaled.
l2_norms_safe = np.where(l2_norms == 0, 1, l2_norms)

# Scale every repetition of each point by that point's factor.
noisy_normed = np.divide(noisy, l2_norms_safe)

## Perform training

In [None]:
# Noise2Noise: only the normalized noisy spectra are handed to training (no clean targets).
X_train: np.ndarray = noisy_normed

In [None]:
# Construct the autoencoder from the hyperparameters read from config.json.
print("Building model...")
model = DiamondDAE1D(**model_params)

# Train in Noise2Noise mode on the noisy spectra; remaining options come from
# `training_params`. The returned history is kept for later inspection.
print("Training model...")
history = model.train_noise2noise(X_train, **training_params)
print("Model trained.")

In [None]:
# Denoise every (normalized) noisy spectrum with the trained model; output is still L2-scaled.
denoised_normed = model.predict(noisy_normed)

In [None]:
# Undo the per-point L2 scaling so the denoised spectra match the raw noisy data's scale.
# (Operand order kept: the type returned by model.predict is not visible here.)
denoised = denoised_normed * l2_norms_safe

## Plot examples of denoised vs. noisy spectra

In [None]:
# Layout of the example grid: n_rows x n_cols randomly chosen spectra.
n_rows, n_cols = 4, 4

In [None]:
# Plot a grid of randomly chosen (point, repetition) spectra, overlaying the
# model output on the raw noisy input. Both curves are on the original
# (de-normalized) intensity scale, because `denoised` was multiplied back by
# the per-point L2 norms.
n_panels = n_rows * n_cols
# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() always works.
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(20, 14), sharex=True, sharey=True, squeeze=False
)
axes = axes.flatten()
num_points = noisy.shape[0]
num_reps = noisy.shape[1]

# Points and repetitions are drawn independently (with replacement),
# so the same (point, rep) pair may appear more than once.
pts = np.random.randint(0, num_points, size=n_panels)
reps = np.random.randint(0, num_reps, size=n_panels)
for i, (ax, point, rep) in enumerate(zip(axes, pts, reps)):
    ax.plot(noisy[point, rep], label='Noisy', color='C0', lw=1, alpha=0.4)
    ax.plot(denoised[point, rep], label='Denoised', color='red', lw=1.5, ls='--')
    ax.set_title(f"Point {point}, Repetition {rep}", fontsize=11)
    ax.tick_params(axis='both', labelsize=10)
    if i % n_cols == 0:  # left column
        # Data shown here is NOT normalized (see the rescaling step above).
        ax.set_ylabel("Intensity", fontsize=12)
    if i >= (n_rows - 1) * n_cols:  # bottom row
        # Spectra are plotted against their sample index, not physical wavelength units.
        ax.set_xlabel("Wavelength index", fontsize=12)
    if i == 0:
        ax.legend(fontsize=10)

# Only denoised and noisy spectra are drawn (no clean reference), so the title says so.
fig.suptitle(f"Denoised vs Noisy Spectra (Random {n_panels})", fontsize=22, y=1.02)
fig.tight_layout()
plt.show()

## Save model

In [None]:
# Persist the trained model under PROJECT_ROOT/models using the file name from config.json.
save_path = PROJECT_ROOT / "models" / model_filename
# Create the target directory (including any subdirectory in model_filename) if it
# does not exist yet; save_model may not create missing parent directories itself.
save_path.parent.mkdir(parents=True, exist_ok=True)
model.save_model(save_path)