In [30]:
import os
import numpy as np
import torch

# Check if GPU is available and set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"‚úÖ Using device: {device}")

# Set paths for clean and contaminated EEG data
clean_dir = "/home/tulgaa/Desktop/eeg_last/eeg_last/Extracted_Bands/Clean"
contaminated_dir = "/home/tulgaa/Desktop/eeg_last/eeg_last/Extracted_Bands/Contaminated"

# Define EEG frequency bands
bands = ["Delta_band", "Theta_band", "Alpha_band", "Beta_band", "Gamma_band", "High_Frequencies_band"]

print(f"‚úÖ Clean EEG Directory: {clean_dir}")
print(f"‚úÖ Contaminated EEG Directory: {contaminated_dir}")
print(f"‚úÖ EEG Bands: {bands}")


‚úÖ Using device: cuda
‚úÖ Clean EEG Directory: /home/tulgaa/Desktop/eeg_last/eeg_last/Extracted_Bands/Clean
‚úÖ Contaminated EEG Directory: /home/tulgaa/Desktop/eeg_last/eeg_last/Extracted_Bands/Contaminated
‚úÖ EEG Bands: ['Delta_band', 'Theta_band', 'Alpha_band', 'Beta_band', 'Gamma_band', 'High_Frequencies_band']


In [31]:
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None
        self.counter = 0
        self.stop_training = False

    def __call__(self, current_loss):
        if self.best_loss is None or current_loss < self.best_loss - self.min_delta:
            self.best_loss = current_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.stop_training = True

In [32]:
class AdvancedLSTM(nn.Module):
    def __init__(self, input_length):
        super(AdvancedLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=512, num_layers=3, batch_first=True, dropout=0.3)
        self.layer_norm = nn.LayerNorm(512)
        self.fc = nn.Linear(512, input_length)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # Reshape to (batch_size, seq_length, channels)
        x, _ = self.lstm(x)  # LSTM output
        x = self.layer_norm(x[:, -1, :])  # Apply layer normalization on the last hidden state
        x = self.fc(x)
        return x

In [33]:
def normalize_signals(data):
    """Normalize the signals to zero mean and unit variance per sample."""
    return (data - np.mean(data, axis=1, keepdims=True)) / np.std(data, axis=1, keepdims=True)
bands = ["Delta_band", "Theta_band", "Alpha_band", "Beta_band", "Gamma_band", "High_Frequencies_band"]
clean_dir = "/home/tulgaa/Desktop/eeg_last/eeg_last/Extracted_Bands/Clean"
contaminated_dir = "/home/tulgaa/Desktop/eeg_last/eeg_last/Extracted_Bands/Contaminated"

In [40]:
from sklearn.model_selection import train_test_split

def load_and_split_data(band_name):
    """
    Loads EEG data for a specific frequency band, splits into train and test sets based on SNR levels.
    """
    print(f"\nüìå Loading and Splitting Data for {band_name}...")

    # Load clean EEG data
    clean_band_path = os.path.join(clean_dir, f"{band_name}.npy")
    clean_band = np.load(clean_band_path)

    # Load contaminated EEG data across all SNR levels
    contaminated_band = []
    snr_labels = []

    for snr_folder in sorted(os.listdir(contaminated_dir)):  # Iterate through SNR folders
        contaminated_band_path = os.path.join(contaminated_dir, snr_folder, f"{band_name}.npy")
        contaminated_data = np.load(contaminated_band_path)

        contaminated_band.append(contaminated_data)
        snr_labels.extend([snr_folder] * len(contaminated_data))  # Track SNR levels

    # Convert lists to numpy arrays
    contaminated_band = np.concatenate(contaminated_band, axis=0)
    # Convert string labels ("SNR_-7") to integers (-7)
    
    snr_labels = np.array([int(snr.replace("SNR_", "")) for snr in snr_labels])


    # Ensure clean_band is correctly repeated to match contaminated EEG samples
    clean_band_repeated = np.tile(clean_band, (len(contaminated_band) // len(clean_band) + 1, 1))[:len(contaminated_band)]

    # Stratified Train-Test Split
    train_clean, test_clean, train_contaminated, test_contaminated, snr_labels_train, snr_labels_test = train_test_split(
        clean_band_repeated, contaminated_band, snr_labels, test_size=0.2, stratify=snr_labels, random_state=42
    )

    # Inspect data shapes
    print(f"‚úÖ {band_name} Train Clean Shape: {train_clean.shape}, Test Clean Shape: {test_clean.shape}")
    print(f"‚úÖ {band_name} Train Contaminated Shape: {train_contaminated.shape}, Test Contaminated Shape: {test_contaminated.shape}")
    print(f"‚úÖ {band_name} SNR Labels Train: {np.unique(snr_labels_train)}, Test: {np.unique(snr_labels_test)}")

    return train_clean, train_contaminated, test_clean, test_contaminated, snr_labels_train, snr_labels_test

# Load and split data for each EEG band
train_data = {}
test_data = {}

for band in bands:
    train_clean, train_contaminated, test_clean, test_contaminated, snr_labels_train, snr_labels_test = load_and_split_data(band)
    train_data[band] = (train_clean, train_contaminated)
    test_data[band] = (test_clean, test_contaminated, snr_labels_test)

print("\n‚úÖ All EEG bands successfully loaded and split!")



üìå Loading and Splitting Data for Delta_band...
‚úÖ Delta_band Train Clean Shape: (27200, 512), Test Clean Shape: (6800, 512)
‚úÖ Delta_band Train Contaminated Shape: (27200, 512), Test Contaminated Shape: (6800, 512)
‚úÖ Delta_band SNR Labels Train: [-7 -6 -5 -4 -3 -2 -1  0  1  2], Test: [-7 -6 -5 -4 -3 -2 -1  0  1  2]

üìå Loading and Splitting Data for Theta_band...
‚úÖ Theta_band Train Clean Shape: (27200, 512), Test Clean Shape: (6800, 512)
‚úÖ Theta_band Train Contaminated Shape: (27200, 512), Test Contaminated Shape: (6800, 512)
‚úÖ Theta_band SNR Labels Train: [-7 -6 -5 -4 -3 -2 -1  0  1  2], Test: [-7 -6 -5 -4 -3 -2 -1  0  1  2]

üìå Loading and Splitting Data for Alpha_band...
‚úÖ Alpha_band Train Clean Shape: (27200, 512), Test Clean Shape: (6800, 512)
‚úÖ Alpha_band Train Contaminated Shape: (27200, 512), Test Contaminated Shape: (6800, 512)
‚úÖ Alpha_band SNR Labels Train: [-7 -6 -5 -4 -3 -2 -1  0  1  2], Test: [-7 -6 -5 -4 -3 -2 -1  0  1  2]

üìå Loading and Splitti

In [45]:
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import torch
def train_and_denoise_band(train_clean, train_contaminated, test_clean, test_contaminated, band_name):
    print(f"\nüìå {band_name} Dataset Shapes:")
    print(f"  Train Clean: {train_clean.shape}, Train Contaminated: {train_contaminated.shape}")
    print(f"  Test Clean: {test_clean.shape}, Test Contaminated: {test_contaminated.shape}")

    train_clean = torch.tensor(train_clean, dtype=torch.float32).unsqueeze(1).to(device)
    test_clean = torch.tensor(test_clean, dtype=torch.float32).unsqueeze(1).to(device)
    train_contaminated = torch.tensor(train_contaminated, dtype=torch.float32).unsqueeze(1).to(device)
    test_contaminated = torch.tensor(test_contaminated, dtype=torch.float32).unsqueeze(1).to(device)

    train_dataset = TensorDataset(train_contaminated, train_clean)
    test_dataset = TensorDataset(test_contaminated, test_clean)
    train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

    input_length = train_contaminated.shape[2]
    model = AdvancedLSTM(input_length).to(device)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()
    scaler = torch.cuda.amp.GradScaler()

    early_stopping = EarlyStopping(patience=10, min_delta=1e-4)

    num_epochs = 1000
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            with torch.cuda.amp.autocast():  # Mixed precision
                outputs = model(inputs)
                loss = criterion(outputs, targets.squeeze(1))
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)

        model.eval()
        test_loss = 0.0
        with torch.no_grad():
            for inputs, targets in test_loader:
                with torch.cuda.amp.autocast():
                    outputs = model(inputs)
                    loss = criterion(outputs, targets.squeeze(1))
                test_loss += loss.item()

        avg_test_loss = test_loss / len(test_loader)

        print(f"Epoch {epoch + 1}: Train Loss = {avg_train_loss:.4f}, Test Loss = {avg_test_loss:.4f}")

        early_stopping(avg_test_loss)
        if early_stopping.stop_training:
            print(f"üöÄ Early stopping triggered for {band_name} at epoch {epoch + 1}")
            break

    # Process test data in smaller batches to reduce memory usage
    with torch.no_grad():
        denoised_batches = []
        batch_size = 256  # Adjust this batch size based on GPU memory
        for i in range(0, len(test_contaminated), batch_size):
            batch = test_contaminated[i:i+batch_size]
            denoised_batch = model(batch).squeeze(1).cpu().numpy()
            denoised_batches.append(denoised_batch)
        denoised_band = np.concatenate(denoised_batches, axis=0)

    return denoised_band, test_clean.cpu().numpy(), test_contaminated.cpu().numpy()


In [46]:
denoised_bands = []
clean_bands = []
contaminated_bands = []

for band in bands:
    print(f"üöÄ Processing {band}...")

    # Load and split data
    train_clean, train_contaminated, test_clean, test_contaminated, _, _ = load_and_split_data(band)

    # Normalize signals
    train_clean = normalize_signals(train_clean)
    train_contaminated = normalize_signals(train_contaminated)
    test_clean = normalize_signals(test_clean)
    test_contaminated = normalize_signals(test_contaminated)

    # Train the model and denoise the test set
    denoised_band, clean_band_test, contaminated_band_test = train_and_denoise_band(
        train_clean, train_contaminated, test_clean, test_contaminated, band
    )

    denoised_bands.append(denoised_band)
    clean_bands.append(clean_band_test)
    contaminated_bands.append(contaminated_band_test)


üöÄ Processing Delta_band...

üìå Loading and Splitting Data for Delta_band...
‚úÖ Delta_band Train Clean Shape: (27200, 512), Test Clean Shape: (6800, 512)
‚úÖ Delta_band Train Contaminated Shape: (27200, 512), Test Contaminated Shape: (6800, 512)
‚úÖ Delta_band SNR Labels Train: [-7 -6 -5 -4 -3 -2 -1  0  1  2], Test: [-7 -6 -5 -4 -3 -2 -1  0  1  2]

üìå Delta_band Dataset Shapes:
  Train Clean: (27200, 512), Train Contaminated: (27200, 512)
  Test Clean: (6800, 512), Test Contaminated: (6800, 512)
Epoch 1: Train Loss = 0.9712, Test Loss = 0.9545
Epoch 2: Train Loss = 0.9575, Test Loss = 0.9532
Epoch 3: Train Loss = 0.9533, Test Loss = 0.9434
Epoch 4: Train Loss = 0.9493, Test Loss = 0.9415
Epoch 5: Train Loss = 0.9468, Test Loss = 0.9397
Epoch 6: Train Loss = 0.9401, Test Loss = 0.9432
Epoch 7: Train Loss = 0.9344, Test Loss = 0.9204
Epoch 8: Train Loss = 0.9226, Test Loss = 0.9067
Epoch 9: Train Loss = 0.9068, Test Loss = 0.8901
Epoch 10: Train Loss = 0.8931, Test Loss = 0.8838
E

In [49]:
import numpy as np

def compute_rrmse(true, pred):
    """Compute Relative Root Mean Square Error (RRMSE)."""
    return np.sqrt(np.mean((true - pred) ** 2)) / np.sqrt(np.mean(true ** 2))

def compute_cc(true, pred):
    """Compute Correlation Coefficient (CC)."""
    true_mean = np.mean(true)
    pred_mean = np.mean(pred)
    return np.corrcoef(true.flatten(), pred.flatten())[0, 1]

def compute_ts_metric(rrmse_t, rrmse_s):
    """Compute T&S Metric in dB."""
    return 10 * np.log10(rrmse_t + rrmse_s)


In [53]:
import numpy as np

# Check if the denoised EEG data exists
try:
    print(f"‚úÖ Number of EEG bands processed: {len(denoised_bands)}")
    print(f"‚úÖ Type of first denoised band: {type(denoised_bands[0])}")
    print(f"‚úÖ Shape of first denoised band: {denoised_bands[0].shape}")
except:
    print("‚ùå `denoised_bands` is empty or not defined. You may need to retrain.")


‚úÖ Number of EEG bands processed: 1
‚úÖ Type of first denoised band: <class 'ellipsis'>
‚ùå `denoised_bands` is empty or not defined. You may need to retrain.
