In [15]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from scipy.signal import welch
import matplotlib.pyplot as plt

In [3]:
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None
        self.counter = 0
        self.stop_training = False

    def __call__(self, current_loss):
        if self.best_loss is None or current_loss < self.best_loss - self.min_delta:
            self.best_loss = current_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.stop_training = True

In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

Using device: cuda


In [5]:
class AdvancedLSTM(nn.Module):
    def __init__(self, input_length):
        super(AdvancedLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=512, num_layers=3, batch_first=True, dropout=0.3)
        self.layer_norm = nn.LayerNorm(512)
        self.fc = nn.Linear(512, input_length)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # Reshape to (batch_size, seq_length, channels)
        x, _ = self.lstm(x)  # LSTM output
        x = self.layer_norm(x[:, -1, :])  # Apply layer normalization on the last hidden state
        x = self.fc(x)
        return x

In [6]:
def compute_rrmse_t(clean, denoised):
    per_sample_rrmse = np.sqrt(np.sum((clean - denoised) ** 2, axis=1) / np.sum(clean ** 2, axis=1))
    return np.mean(per_sample_rrmse)

def compute_rrmse_s(clean, denoised, fs):
    f_clean, psd_clean = welch(clean.flatten(), fs=fs, nperseg=512)
    f_denoised, psd_denoised = welch(denoised.flatten(), fs=fs, nperseg=512)
    return np.sqrt(np.sum((psd_clean - psd_denoised) ** 2) / np.sum(psd_clean ** 2))

def compute_cc(clean, denoised):
    return np.corrcoef(clean.flatten(), denoised.flatten())[0, 1]


In [7]:
def normalize_signals(data):
    """Normalize the signals to zero mean and unit variance per sample."""
    return (data - np.mean(data, axis=1, keepdims=True)) / np.std(data, axis=1, keepdims=True)
bands = ["Delta_band", "Theta_band", "Alpha_band", "Beta_band", "Gamma_band", "High_Frequencies_band"]
clean_dir = "/home/tulgaa/Desktop/eeg_last/eeg_last/Extracted_Bands/Clean"
contaminated_dir = "/home/tulgaa/Desktop/eeg_last/eeg_last/Extracted_Bands/Contaminated"

In [8]:
def plot_signals(contaminated, denoised, ground_truth, sample_indices, title):
    num_samples = len(sample_indices)
    plt.figure(figsize=(12, 4 * num_samples))
    for i, idx in enumerate(sample_indices):
        plt.subplot(num_samples, 1, i + 1)
        plt.plot(contaminated[idx], label="Contaminated Signal", alpha=0.6)
        plt.plot(denoised[idx], label="Denoised Signal", alpha=0.8)
        plt.plot(ground_truth[idx], label="Ground Truth Signal", alpha=0.8)
        plt.title(f"{title} - Sample {idx}")
        plt.xlabel("Time Points")
        plt.ylabel("Amplitude")
        plt.legend()
    plt.tight_layout()
    plt.show()

In [9]:
def train_and_denoise_band(clean_band, contaminated_band, band_name):
    split_idx = int(0.8 * len(clean_band))
    train_clean = clean_band[:split_idx]
    test_clean = clean_band[split_idx:]
    train_contaminated = contaminated_band[:split_idx]
    test_contaminated = contaminated_band[split_idx:]

    print(f"\n{band_name} Dataset Shapes:")
    print(f"  Train Clean: {train_clean.shape}, Train Contaminated: {train_contaminated.shape}")
    print(f"  Test Clean: {test_clean.shape}, Test Contaminated: {test_contaminated.shape}")

    train_clean = torch.tensor(train_clean, dtype=torch.float32).unsqueeze(1).to(device)
    test_clean = torch.tensor(test_clean, dtype=torch.float32).unsqueeze(1).to(device)
    train_contaminated = torch.tensor(train_contaminated, dtype=torch.float32).unsqueeze(1).to(device)
    test_contaminated = torch.tensor(test_contaminated, dtype=torch.float32).unsqueeze(1).to(device)

    train_dataset = TensorDataset(train_contaminated, train_clean)
    test_dataset = TensorDataset(test_contaminated, test_clean)
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

    input_length = train_contaminated.shape[2]
    model = AdvancedLSTM(input_length).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()
    scaler = torch.cuda.amp.GradScaler()

    early_stopping = EarlyStopping(patience=10, min_delta=1e-4)

    num_epochs = 1000
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():  # Mixed precision
                outputs = model(inputs)
                loss = criterion(outputs, targets.squeeze(1))
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)

        model.eval()
        test_loss = 0.0
        with torch.no_grad():
            for inputs, targets in test_loader:
                with torch.cuda.amp.autocast():
                    outputs = model(inputs)
                    loss = criterion(outputs, targets.squeeze(1))
                test_loss += loss.item()

        avg_test_loss = test_loss / len(test_loader)

        print(f"Epoch {epoch + 1}: Train Loss = {avg_train_loss:.4f}, Test Loss = {avg_test_loss:.4f}")

        early_stopping(avg_test_loss)
        if early_stopping.stop_training:
            print(f"Early stopping triggered for {band_name} at epoch {epoch + 1}")
            break

    # Process test data in smaller batches to reduce memory usage
    with torch.no_grad():
        denoised_batches = []
        batch_size = 256  # Adjust this batch size based on GPU memory
        for i in range(0, len(test_contaminated), batch_size):
            batch = test_contaminated[i:i+batch_size]
            denoised_batch = model(batch).squeeze(1).cpu().numpy()
            denoised_batches.append(denoised_batch)
        denoised_band = np.concatenate(denoised_batches, axis=0)

    return denoised_band, test_clean.cpu().numpy(), test_contaminated.cpu().numpy()


In [10]:
denoised_bands = []
clean_bands = []
contaminated_bands = []

for band in bands:
    print(f"Processing {band}...")
    clean_band_path = os.path.join(clean_dir, f"{band}.npy")
    contaminated_band = []

    for snr_folder in sorted(os.listdir(contaminated_dir)):
        contaminated_band_path = os.path.join(contaminated_dir, snr_folder, f"{band}.npy")
        contaminated_band.append(np.load(contaminated_band_path))
    contaminated_band = np.concatenate(contaminated_band, axis=0)

    clean_band = np.load(clean_band_path)
    clean_band_repeated = np.tile(clean_band, (len(contaminated_band) // len(clean_band) + 1, 1))[:len(contaminated_band)]

    # Normalize signals
    clean_band_repeated = normalize_signals(clean_band_repeated)
    contaminated_band = normalize_signals(contaminated_band)

    denoised_band, clean_band_test, contaminated_band_test = train_and_denoise_band(clean_band_repeated, contaminated_band, band)
    denoised_bands.append(denoised_band)
    clean_bands.append(clean_band_test)
    contaminated_bands.append(contaminated_band_test)


Processing Delta_band...

Delta_band Dataset Shapes:
  Train Clean: (27200, 512), Train Contaminated: (27200, 512)
  Test Clean: (6800, 512), Test Contaminated: (6800, 512)
Epoch 1: Train Loss = 0.9753, Test Loss = 0.9401
Epoch 2: Train Loss = 0.9625, Test Loss = 0.9316
Epoch 3: Train Loss = 0.9609, Test Loss = 0.9362
Epoch 4: Train Loss = 0.9588, Test Loss = 0.9304
Epoch 5: Train Loss = 0.9583, Test Loss = 0.9252
Epoch 6: Train Loss = 0.9560, Test Loss = 0.9178
Epoch 7: Train Loss = 0.9523, Test Loss = 0.9096
Epoch 8: Train Loss = 0.9519, Test Loss = 0.9155
Epoch 9: Train Loss = 0.9514, Test Loss = 0.9136
Epoch 10: Train Loss = 0.9481, Test Loss = 0.9008
Epoch 11: Train Loss = 0.9515, Test Loss = 0.9217
Epoch 12: Train Loss = 0.9512, Test Loss = 0.9138
Epoch 13: Train Loss = 0.9362, Test Loss = 0.8887
Epoch 14: Train Loss = 0.9180, Test Loss = 0.8613
Epoch 15: Train Loss = 0.9009, Test Loss = 0.8234
Epoch 16: Train Loss = 0.8836, Test Loss = 0.8080
Epoch 17: Train Loss = 0.8610, Test 

In [11]:
print("\nReconstructing EEG signals...")
clean_bands = np.array([np.squeeze(band, axis=1) for band in clean_bands])
contaminated_bands = np.array([np.squeeze(band, axis=1) for band in contaminated_bands])
denoised_eeg = np.sum(denoised_bands, axis=0)
clean_eeg = np.sum(clean_bands, axis=0)
contaminated_eeg = np.sum(contaminated_bands, axis=0)

print(f"Reconstructed EEG Shapes - Clean: {clean_eeg.shape}, Denoised: {denoised_eeg.shape}, Contaminated: {contaminated_eeg.shape}")


Reconstructing EEG signals...
Reconstructed EEG Shapes - Clean: (6800, 512), Denoised: (6800, 512), Contaminated: (6800, 512)


In [12]:
print("\nCalculating evaluation metrics...")
rrmse_t = compute_rrmse_t(clean_eeg, denoised_eeg)
rrmse_s = compute_rrmse_s(clean_eeg, denoised_eeg, fs=512)
cc = compute_cc(clean_eeg, denoised_eeg)

print("\nEvaluation Metrics for Full EEG Denoising:")
print(f"  RRMSE_t: {rrmse_t:.4f}")
print(f"  RRMSE_s: {rrmse_s:.4f}")
print(f"  CC: {cc:.4f}")


Calculating evaluation metrics...

Evaluation Metrics for Full EEG Denoising:
  RRMSE_t: 0.4633
  RRMSE_s: 0.3792
  CC: 0.8787
