In [2]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from plotting_utils import (
    plot_sd,
    basic_plotting,
    plot_overlapping_signal
)
from scipy.interpolate import interp1d

In [3]:
# Definir seed global
def set_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)

In [4]:
# Processo de Ruído (OUProcess)
class OUProcess:
    def __init__(self, sigma_squared, ell, signal_length):
        self.sigma_squared = sigma_squared
        self.ell = ell
        self.signal_length = signal_length

    def sample(self, shape):
        dt = 1 / self.signal_length
        noise = np.random.normal(
            0, np.sqrt(self.sigma_squared * (1 - np.exp(-2 * self.ell * dt))), size=shape
        )
        return torch.tensor(noise, dtype=torch.float32)

In [5]:
# Modelo de Difusão
class DiffusionModel:
    def __init__(self, num_steps, beta_start, beta_end, noise_sampler):
        self.num_steps = num_steps
        self.betas = torch.linspace(beta_start, beta_end, num_steps)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.noise_sampler = noise_sampler

    def forward_diffusion(self, data, t):
        noise = self.noise_sampler.sample(data.shape)
        alpha_t = self.alphas_cumprod[t].view(-1, 1, 1)
        noisy_data = (
            torch.sqrt(alpha_t) * data +
            torch.sqrt(1.0 - alpha_t) * noise
        )
        return noisy_data, noise

In [6]:
# Rede Neural de Denoising (CatConv)
class CatConv(nn.Module):
    def __init__(self, signal_length, signal_channel, hidden_channel, kernel_size):
        super(CatConv, self).__init__()
        self.conv1 = nn.Conv1d(signal_channel, hidden_channel, kernel_size, padding=1)
        self.conv2 = nn.Conv1d(hidden_channel, hidden_channel, kernel_size, padding=1)
        self.conv3 = nn.Conv1d(hidden_channel, signal_channel, kernel_size, padding=1)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = self.conv3(x)
        return x

In [7]:
# Dataset para EEG
class EEGDataset(Dataset):
    def __init__(self, data_dir):
        self.data_files = [
            np.load(os.path.join(data_dir, file)) for file in os.listdir(data_dir) if file.endswith('.npy')
        ]

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, idx):
        data = self.data_files[idx]
        if data.ndim == 1:  # Adicionar dimensão do canal
            data = np.expand_dims(data, axis=0)
        return torch.tensor(data, dtype=torch.float32)

In [8]:
# Função de Treinamento
def train_model(diffusion, denoiser, dataset, num_epochs, batch_size, lr, seed):
    set_seed(seed)  # Definir seed
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = optim.Adam(denoiser.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        for data in dataloader:
            t = torch.randint(0, diffusion.num_steps, (data.shape[0],)).long()
            noisy_data, noise = diffusion.forward_diffusion(data, t)
            optimizer.zero_grad()
            predicted_noise = denoiser(noisy_data)
            loss = criterion(predicted_noise, noise)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")

In [9]:
# Geração de Dados Sintéticos
def generate_synthetic_data(diffusion, denoiser, sample_shape, num_samples):
    samples = []
    for _ in range(num_samples):
        noisy_sample = torch.randn(sample_shape)
        for t in reversed(range(diffusion.num_steps)):
            noisy_sample = denoiser(noisy_sample)
        samples.append(noisy_sample)
    return torch.stack(samples)

In [10]:
# Função para Salvar e Plotar
def generate_and_save_plots(real_data, synthetic_data, channels, fs, output_dir="./results"):
    os.makedirs(output_dir, exist_ok=True)

    # Garantir que os sinais estão na forma correta
    real_signal = real_data.mean(axis=0).squeeze()  # Sinais médios reais
    synthetic_signal = synthetic_data.mean(axis=0).squeeze()  # Sinais médios sintéticos

    # Recalcular o eixo do tempo com base no tamanho de real_signal
    time = np.linspace(0, len(real_signal) / fs, len(real_signal))

    # Interpolar o sinal sintético para coincidir com o tempo e real_signal
    if len(synthetic_signal) != len(real_signal):
        interp_func = interp1d(
            np.linspace(0, 1, len(synthetic_signal)),
            synthetic_signal,
            kind="linear",
            fill_value="extrapolate",
        )
        synthetic_signal = interp_func(np.linspace(0, 1, len(real_signal)))

    # Verificação de dimensões
    if len(time) != len(real_signal) or len(time) != len(synthetic_signal):
        raise ValueError(
            f"Dimensões incompatíveis: time ({len(time)}), "
            f"real_signal ({len(real_signal)}), "
            f"synthetic_signal ({len(synthetic_signal)})"
        )

    # Plotagem básica: Real vs Sintético
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(time, real_signal, label="Real Data")
    ax.plot(time, synthetic_signal, label="Synthetic Data")
    basic_plotting(
        fig,
        ax,
        x_label="Time (s)",
        y_label="Amplitude",
    )
    plt.legend(loc="best")
    plt.title("Real vs Synthetic Data")
    plt.savefig(f"{output_dir}/basic_plotting.png")
    plt.close(fig)

    # Plotagem de densidade espectral
    fig, ax = plt.subplots(figsize=(12, 6))
    plot_sd(
        fig,
        ax,
        arr_one=real_data.squeeze(1),
        arr_two=synthetic_data.squeeze(1),
        fs=fs,
        nperseg=256,
        with_quantiles=True
    )
    plt.title("Spectral Density Comparison")
    plt.savefig(f"{output_dir}/spectral_density_comparison.png")
    plt.close(fig)

    # Plotagem de sinais sobrepostos
    fig, ax = plt.subplots(figsize=(12, 6))
    plot_overlapping_signal(fig, ax, sig=synthetic_data[0, :, :])
    plt.title("Overlapping Signals - Synthetic Data")
    plt.savefig(f"{output_dir}/overlapping_signal.png")
    plt.close(fig)

    print(f"Gráficos salvos na pasta: {output_dir}")

In [11]:
# Main Script
if __name__ == "__main__":
    # Parâmetros
    num_steps = 100
    beta_start = 1e-4
    beta_end = 0.02
    signal_length = 256
    signal_channel = 1
    hidden_channel = 64
    kernel_size = 3
    num_epochs = 10
    batch_size = 32
    lr = 1e-3
    num_samples = 5
    seed = 42  # Seed global

    # Configuração de Seed
    set_seed(seed)

    # Processo de Ruído
    noise_sampler = OUProcess(sigma_squared=0.02, ell=0.1, signal_length=signal_length)

    # Modelo de Difusão
    diffusion = DiffusionModel(num_steps, beta_start, beta_end, noise_sampler)

    # Rede de Denoising
    denoiser = CatConv(
        signal_length=signal_length,
        signal_channel=signal_channel,
        hidden_channel=hidden_channel,
        kernel_size=kernel_size,
    )

    # Dataset
    dataset = EEGDataset(data_dir='/Users/analuiza/Documents/codes/templedata/00007656_s010_t000_processed_data')  # Substitua pelo caminho correto

    # Treinamento
    train_model(diffusion, denoiser, dataset, num_epochs=num_epochs, batch_size=batch_size, lr=lr, seed=seed)

    # Geração de Dados Sintéticos
    synthetic_data = generate_synthetic_data(diffusion, denoiser, sample_shape=(signal_channel, signal_length), num_samples=num_samples)

    # Plotagens
    real_data = torch.stack([dataset[i] for i in range(3)]).numpy()
    synthetic_data_np = synthetic_data.detach().numpy()
    generate_and_save_plots(real_data, synthetic_data_np, channels=["Fp1", "Fp2", "C3"], fs=256, output_dir="./results")

Epoch 1/10, Loss: 0.0087
Epoch 2/10, Loss: 0.0037
Epoch 3/10, Loss: 0.0036
Epoch 4/10, Loss: 0.0027
Epoch 5/10, Loss: 0.0012
Epoch 6/10, Loss: 0.0003
Epoch 7/10, Loss: 0.0005
Epoch 8/10, Loss: 0.0009
Epoch 9/10, Loss: 0.0010
Epoch 10/10, Loss: 0.0009
Gráficos salvos na pasta: ./results
