In [None]:
import os
import numpy as np
import ramanchada2 as rc2
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# Assuming you have a dataset of Raman spectra with x and y values
# Replace this with your actual dataset loading code

def load_raman_spectra(root_data_folder, subset, data=None, expected_length=1600):
    """Load every spectrum file in ``root_data_folder/subset`` via ramanchada2.

    Files are read in sorted name order so the result is reproducible
    (``os.listdir`` order is arbitrary). Spectra whose ``x`` axis does not
    have ``expected_length`` points are skipped, so the returned spectra can
    later be stacked into a single tensor.

    Args:
        root_data_folder: Directory that contains the subset folders.
        subset: Name of the sub-folder to read (e.g. ``"PST"``).
        data: Optional list to append the loaded spectra to. A fresh list is
            created when omitted (the previous ``data=[]`` default was one
            list shared by all calls, so repeated calls accumulated spectra).
        expected_length: Required number of points per spectrum; ``None``
            keeps every spectrum regardless of its length.

    Returns:
        The list of loaded spectra (``data`` itself when it was supplied).
    """
    if data is None:
        data = []
    folder = os.path.join(root_data_folder, subset)
    for filename in sorted(os.listdir(folder)):
        file_path = os.path.join(folder, filename)
        if not os.path.isfile(file_path):
            continue  # skip sub-directories; the reader expects a file
        spe = rc2.spectrum.from_local_file(file_path)
        # Skip (rather than stop at) spectra of the wrong length: the original
        # `break` silently dropped every file listed after the first mismatch.
        if expected_length is not None and len(spe.x) != expected_length:
            continue
        data.append(spe)
    return data

In [None]:
# Location of the round-robin data. Set the RAMAN_DATA_DIR environment variable
# to point elsewhere instead of editing this machine-specific default.
# Raw string: the original literal contained "\O", an invalid escape sequence
# that only worked by accident (SyntaxWarning on Python 3.12+).
path = os.environ.get(
    "RAMAN_DATA_DIR",
    r"D:\nina\OneDrive\CHARISMA - RUNTIME\WP 3 - TF3_RoundRobin\Round Robin 1 internal\Data Raman RR 1.1\LBF",
)
# Load Raman spectra data
raman_data = load_raman_spectra(path, "PST")
# Overlay every loaded spectrum on one shared axis (spe.plot returns the axis).
ax = None
for spe in raman_data:
    ax = spe.plot(ax=ax)


In [None]:
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion
#from denoising_diffusion_pytorch.losses import gaussian_nll
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt



In [None]:
# Helper: turn array-like data into a float32 PyTorch tensor.
def convert_to_tensor(data):
    """Copy ``data`` (NumPy array or nested sequence) into a new float32 tensor."""
    as_tensor = torch.tensor(data)
    return as_tensor.to(torch.float32)

# Hyperparameters -- adjust as needed.
batch_size = 16
num_epochs = 10
lr = 0.001
timesteps = 1000
dim_mults = (1, 2, 4, 8)
noise_sigma = 0.1  # std of the synthetic test noise, relative to the spectrum's [0, 1] range
n_reverse_steps = 100  # TODO(review): tune; reverse-diffusion steps used to denoise the example
torch.manual_seed(0)  # reproducible initialisation, batch order and sampling

# The 2-D `Unet` / `GaussianDiffusion` expect images shaped (b, c, h, w); a
# spectrum is a 1-D sequence, so use the library's 1-D variants instead.
from denoising_diffusion_pytorch import Unet1D, GaussianDiffusion1D

assert raman_data, "no spectra loaded -- check the data path"
assert 0 < n_reverse_steps <= timesteps

# Per-spectrum axes as tensors. The x (wavenumber) axis is NOT a model input:
# the diffusion model learns the distribution of the intensities y only.
x_data = [torch.tensor(spec.x).float() for spec in raman_data]
y_data = [torch.tensor(spec.y).float() for spec in raman_data]

# Shape (N, channels=1, length), min-max scaled per spectrum into [0, 1], the
# input range GaussianDiffusion1D assumes with its default auto-normalisation.
y_all = torch.stack(y_data).unsqueeze(1)
y_min = y_all.amin(dim=-1, keepdim=True)
y_scale = (y_all.amax(dim=-1, keepdim=True) - y_min).clamp_min(1e-12)  # avoid /0 on flat spectra
y_norm = (y_all - y_min) / y_scale
# Create DataLoader for batch processing
data_loader = DataLoader(TensorDataset(y_norm), batch_size=batch_size, shuffle=True)

seq_length = y_norm.shape[-1]
# Unet1D halves the sequence once per level except the last one.
down_factor = 2 ** (len(dim_mults) - 1)
assert seq_length % down_factor == 0, (
    f"spectrum length {seq_length} must be divisible by {down_factor}"
)

# `dim` is the network's base width, not the spectrum length (the original
# passed 1600 here); one input channel per spectrum.
model = Unet1D(dim=64, dim_mults=dim_mults, channels=1)
diffusion = GaussianDiffusion1D(model, seq_length=seq_length, timesteps=timesteps)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Training loop: calling the diffusion wrapper on a clean batch noises it at
# random timesteps and returns the denoising loss directly, so there is no
# separate prediction/target pair to compare with an MSE.
for epoch in range(num_epochs):
    total_loss = 0.0
    for (y_batch,) in data_loader:
        optimizer.zero_grad()
        loss = diffusion(y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    average_loss = total_loss / len(data_loader)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {average_loss:.4f}')

# Example: Denoise a single spectrum
example_spectrum = raman_data[0]
x_example, y_example = torch.tensor(example_spectrum.x).float(), torch.tensor(example_spectrum.y).float()

# Add noise in normalised [0, 1] space so `noise_sigma` is relative to the
# spectrum's dynamic range (0.1 on raw intensities may be negligible).
ex_min = y_example.min()
ex_scale = (y_example.max() - ex_min).clamp_min(1e-12)
noisy_norm = (y_example - ex_min) / ex_scale + torch.randn_like(y_example) * noise_sigma

# Batch and channel dimensions: (length,) -> (1, 1, length), in raw units for plotting.
noisy_spectrum = (noisy_norm * ex_scale + ex_min).unsqueeze(0).unsqueeze(1)

# `diffusion.sample(batch_size=...)` generates new spectra from pure noise; it
# takes a batch size, not a spectrum to denoise. Instead, re-noise the
# observation to step `n_reverse_steps - 1` and run the learned reverse
# process back to step 0.
with torch.no_grad():
    x_t = diffusion.normalize(noisy_norm.clamp(0.0, 1.0)).view(1, 1, -1)
    x_t = diffusion.q_sample(x_t, torch.full((1,), n_reverse_steps - 1, dtype=torch.long))
    for step in reversed(range(n_reverse_steps)):
        x_t, _ = diffusion.p_sample(x_t, step)
    denoised_spectrum = diffusion.unnormalize(x_t) * ex_scale + ex_min

# Plot the results
plt.plot(x_example.numpy(), y_example.numpy(), label='Original Spectrum')
plt.plot(x_example.numpy(), noisy_spectrum.squeeze().detach().numpy(), label='Noisy Spectrum')
plt.plot(x_example.numpy(), denoised_spectrum.squeeze().detach().numpy(), label='Denoised Spectrum')
plt.legend()
plt.show()

In [None]:
# Example: Denoise a single spectrum
example_spectrum = raman_data[0]
# A ramanchada2 spectrum is not a 2-tuple, so read its axes explicitly
# (`x_example, y_example = example_spectrum` cannot work).
x_example = torch.tensor(example_spectrum.x).float()
y_example = torch.tensor(example_spectrum.y).float()

# Add noise in per-spectrum [0, 1] units so the 0.1 level is relative to the
# spectrum's dynamic range rather than to raw intensities.
y_min = y_example.min()
y_scale = (y_example.max() - y_min).clamp_min(1e-12)  # avoid /0 on a flat spectrum
noisy_norm = (y_example - y_min) / y_scale + torch.randn_like(y_example) * 0.1

# Add batch and channel dimensions: (length,) -> (1, 1, length), raw units for plotting.
noisy_spectrum = (noisy_norm * y_scale + y_min).unsqueeze(0).unsqueeze(1)

# `model` is the bare denoising network: calling it requires a timestep and it
# predicts noise, not a clean spectrum. Instead, partially re-noise the
# observation and run the diffusion wrapper's reverse process back to step 0.
# NOTE(review): assumes `diffusion` is a 1-D diffusion trained on [0, 1]-scaled spectra.
n_reverse_steps = min(100, diffusion.num_timesteps)  # TODO(review): tune
with torch.no_grad():
    x_t = diffusion.normalize(noisy_norm.clamp(0.0, 1.0)).view(1, 1, -1)
    x_t = diffusion.q_sample(x_t, torch.full((1,), n_reverse_steps - 1, dtype=torch.long))
    for step in reversed(range(n_reverse_steps)):
        x_t, _ = diffusion.p_sample(x_t, step)
    denoised_spectrum = diffusion.unnormalize(x_t) * y_scale + y_min


In [None]:

# Overlay the original, noisy and denoised intensities on the shared x axis.
x_values = x_example.numpy()
plt.plot(x_values, y_example.numpy(), label='Original Spectrum')
for curve, curve_label in ((noisy_spectrum, 'Noisy Spectrum'), (denoised_spectrum, 'Denoised Spectrum')):
    # Drop the (batch, channel) singleton dims and leave autograd before plotting.
    plt.plot(x_values, curve.squeeze().detach().numpy(), label=curve_label)
plt.legend()
plt.show()