In [1]:
import numpy as np
import mne
import matplotlib.pyplot as plt
from scipy.signal import welch, spectrogram
import math
import librosa
import scipy.signal
import os
import warnings
import random
import torch
import torch.nn as nn
import torch.optim as optim
import scipy.signal
from torch.utils.data import Dataset, DataLoader, random_split
from collections import defaultdict
import torchaudio.transforms as T
import torch.nn.functional as F
from scipy.signal import resample
from tqdm import tqdm
from geomloss import SamplesLoss
import os
import time
from tqdm import tqdm
import csv
import geomloss
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
warnings.filterwarnings('ignore')
os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [2]:
base_path = "../features/time_domain_win_cut_pad/"
categories = ["chew", "elpp", "shiv"]
file_limit = 4266  
grouped_files = defaultdict(list)

if not os.path.exists(base_path):
    print(f"Error: Directory '{base_path}' not found.")
    exit()

print(f"Scanning directory: {base_path}\n")

for root, dirs, files in os.walk(base_path):
    folder_name = os.path.basename(root)
    
    if folder_name in categories:
        print(f"Processing folder: {folder_name}, Found {len(files)} files.")

        for file in files:
            if file.endswith(".pt") and len(grouped_files[folder_name]) < file_limit:
                file_path = os.path.join(root, file)
                grouped_files[folder_name].append(file_path)
                print(f"Added {file} to category '{folder_name}'")

Scanning directory: ../features/time_domain_win_cut_pad/

Processing folder: chew, Found 4266 files.
Added C3-CZ_0.0_21.0048.pt to category 'chew'
Added C3-CZ_0.2405_2.5624.pt to category 'chew'
Added C3-CZ_10.4038_12.6573.pt to category 'chew'
Added C3-CZ_100.0431_110.1627.pt to category 'chew'
Added C3-CZ_1002.0443_1005.8908.pt to category 'chew'
Added C3-CZ_1007.3117_1010.1994.pt to category 'chew'
Added C3-CZ_1026.2655_1029.0685.pt to category 'chew'
Added C3-CZ_105.0722_110.9278.pt to category 'chew'
Added C3-CZ_105.7222_115.6.pt to category 'chew'
Added C3-CZ_1067.0648_1079.0259.pt to category 'chew'
Added C3-CZ_1086.1008_1120.6059.pt to category 'chew'
Added C3-CZ_1130.2026_1142.4228.pt to category 'chew'
Added C3-CZ_114.0141_129.7887.pt to category 'chew'
Added C3-CZ_114.0481_119.7789.pt to category 'chew'
Added C3-CZ_115.0957_119.5215.pt to category 'chew'
Added C3-CZ_115.2294_179.9893.pt to category 'chew'
Added C3-CZ_1152.0601_1154.6503.pt to category 'chew'
Added C3-CZ_116.

In [3]:
chew_data = np.stack([torch.load(x, weights_only=False) for x in grouped_files['chew']], axis=0)
elpp_data = np.stack([torch.load(x, weights_only=False) for x in grouped_files['elpp']], axis=0)
shiv_data = np.stack([torch.load(x, weights_only=False) for x in grouped_files['shiv']], axis=0)

In [4]:
print(chew_data.shape)
print(elpp_data.shape)
print(shiv_data.shape)

(4266, 512)
(156, 512)
(2762, 512)


In [5]:
EEG_PATH = 'data'
EOG_ALL_EPOCHS = 'EOG_all_epochs.npy'
EMG_ALL_EPOCHS = 'EMG_all_epochs.npy'
EEG_ALL_EPOCHS = 'EEG_all_epochs.npy'

In [6]:
def show_data_informations(signal, signal_type):
    print(f"Data type {signal_type}:", type(signal))
    print(f"Data shape {signal_type}:", signal.shape)

In [7]:
eog_data = np.load(os.path.join(EEG_PATH, EOG_ALL_EPOCHS))
emg_data = np.load(os.path.join(EEG_PATH, EMG_ALL_EPOCHS))
eeg_data = np.load(os.path.join(EEG_PATH, EEG_ALL_EPOCHS))

show_data_informations(eog_data, 'EOG')
show_data_informations(emg_data, 'EMG')
show_data_informations(eeg_data, 'EEG')

Data type EOG: <class 'numpy.ndarray'>
Data shape EOG: (3400, 512)
Data type EMG: <class 'numpy.ndarray'>
Data shape EMG: (5598, 512)
Data type EEG: <class 'numpy.ndarray'>
Data shape EEG: (4514, 512)


In [8]:
def get_rms(records):
    return math.sqrt(sum([x ** 2 for x in records]) / len(records))

def random_signal(signal, combine_num):
    random_result=[]
    for i in range(combine_num):
        random_num = np.random.permutation(signal.shape[0])
        shuffled_dataset = signal[random_num, :]
        shuffled_dataset = shuffled_dataset.reshape(signal.shape[0],signal.shape[1])
        random_result.append(shuffled_dataset)
    random_result  = np.array(random_result)
    return  random_result

In [9]:
EEG_all_random = np.squeeze(random_signal(signal=eeg_data, combine_num=1))

NOISE_all_random_chew = np.squeeze(random_signal(signal=chew_data, combine_num=1))
NOISE_all_random_elpp = np.squeeze(random_signal(signal=elpp_data, combine_num=1))
NOISE_all_random_shiv = np.squeeze(random_signal(signal=shiv_data, combine_num=1))
NOISE_all_random_eog = np.squeeze(random_signal(signal=eog_data, combine_num=1))
NOISE_all_random_emg = np.squeeze(random_signal(signal=emg_data, combine_num=1))

In [10]:
SNR_dB_chew = np.random.uniform(-7, 2, (eeg_data.shape[0]))
SNR_dB_elpp = np.random.uniform(-7, 2, (eeg_data.shape[0]))
SNR_dB_shiv = np.random.uniform(-7, 2, (eeg_data.shape[0]))
SNR_dB_eog = np.random.uniform(-7, 2, (eeg_data.shape[0]))
SNR_dB_emg = np.random.uniform(-7, 2, (eeg_data.shape[0]))

In [11]:
SNR_chew = 10 ** (0.1 * SNR_dB_chew)
SNR_elpp = 10 ** (0.1 * SNR_dB_elpp)
SNR_shiv = 10 ** (0.1 * SNR_dB_shiv)
SNR_eog = 10 ** (0.1 * SNR_dB_eog)
SNR_emg = 10 ** (0.1 * SNR_dB_emg)

In [12]:
num_eeg_samples = EEG_all_random.shape[0]
num_chew_samples = chew_data.shape[0]
num_elpp_samples = elpp_data.shape[0]
num_shiv_samples = shiv_data.shape[0]
num_eog_samples = eog_data.shape[0]
num_emg_samples = emg_data.shape[0]

In [13]:
needed_repetitions_chew = int(np.ceil(num_eeg_samples / num_chew_samples))
needed_repetitions_elpp = int(np.ceil(num_eeg_samples / num_elpp_samples))
needed_repetitions_shiv = int(np.ceil(num_eeg_samples / num_shiv_samples))
needed_repetitions_eog = int(np.ceil(num_eeg_samples / num_eog_samples))
needed_repetitions_emg = int(np.ceil(num_eeg_samples / num_emg_samples))

In [14]:
NOISE_all_random_chew = random_signal(signal=chew_data, combine_num=needed_repetitions_chew)
NOISE_all_random_elpp = random_signal(signal=elpp_data, combine_num=needed_repetitions_elpp)
NOISE_all_random_shiv = random_signal(signal=shiv_data, combine_num=needed_repetitions_shiv)
NOISE_all_random_eog = random_signal(signal=eog_data, combine_num=needed_repetitions_eog)
NOISE_all_random_emg = random_signal(signal=emg_data, combine_num=needed_repetitions_emg)

NOISE_all_random_chew = NOISE_all_random_chew.reshape(-1, chew_data.shape[1])
NOISE_all_random_elpp = NOISE_all_random_elpp.reshape(-1, elpp_data.shape[1])
NOISE_all_random_shiv = NOISE_all_random_shiv.reshape(-1, shiv_data.shape[1])
NOISE_all_random_eog = NOISE_all_random_eog.reshape(-1, eog_data.shape[1])
NOISE_all_random_emg = NOISE_all_random_emg.reshape(-1, emg_data.shape[1])

In [15]:
if NOISE_all_random_chew.shape[0] < num_eeg_samples:
    extra_needed = num_eeg_samples - NOISE_all_random_chew.shape[0]
    extra_noise = random_signal(signal=chew_data, combine_num=1).reshape(-1, chew_data.shape[1])
    NOISE_all_random_chew = np.concatenate((NOISE_all_random_chew, extra_noise), axis=0)

if NOISE_all_random_elpp.shape[0] < num_eeg_samples:
    extra_needed = num_eeg_samples - NOISE_all_random_elpp.shape[0]
    extra_noise = random_signal(signal=elpp_data, combine_num=1).reshape(-1, elpp_data.shape[1])
    NOISE_all_random_elpp = np.concatenate((NOISE_all_random_elpp, extra_noise), axis=0)
    
if NOISE_all_random_shiv.shape[0] < num_eeg_samples:
    extra_needed = num_eeg_samples - NOISE_all_random_shiv.shape[0]
    extra_noise = random_signal(signal=shiv_data, combine_num=1).reshape(-1, shiv_data.shape[1])
    NOISE_all_random_shiv = np.concatenate((NOISE_all_random_shiv, extra_noise), axis=0)

if NOISE_all_random_eog.shape[0] < num_eeg_samples:
    extra_needed = num_eeg_samples - NOISE_all_random_eog.shape[0]
    extra_noise = random_signal(signal=eog_data, combine_num=1).reshape(-1, eog_data.shape[1])
    NOISE_all_random_eog = np.concatenate((NOISE_all_random_eog, extra_noise), axis=0)

if NOISE_all_random_emg.shape[0] < num_eeg_samples:
    extra_needed = num_eeg_samples - NOISE_all_random_emg.shape[0]
    extra_noise = random_signal(signal=emg_data, combine_num=1).reshape(-1, emg_data.shape[1])
    NOISE_all_random_emg = np.concatenate((NOISE_all_random_emg, extra_noise), axis=0)

In [16]:
NOISE_all_random_chew = NOISE_all_random_chew[:num_eeg_samples]
NOISE_all_random_elpp = NOISE_all_random_elpp[:num_eeg_samples]
NOISE_all_random_shiv = NOISE_all_random_shiv[:num_eeg_samples]
NOISE_all_random_eog = NOISE_all_random_eog[:num_eeg_samples]
NOISE_all_random_emg = NOISE_all_random_emg[:num_eeg_samples]

In [17]:
noiseEEG_CHEW = []
for i in range(num_eeg_samples):
    eeg = EEG_all_random[i]
    noise_chew = NOISE_all_random_chew[i]
    coe_chew = get_rms(eeg) / (get_rms(noise_chew) * SNR_chew[i])
    noise_chew = noise_chew * coe_chew
    noise_eeg_chew = eeg + noise_chew
    noiseEEG_CHEW.append(noise_eeg_chew)

noiseEEG_ELPP = []
for i in range(num_eeg_samples):
    eeg = EEG_all_random[i]
    noise_elpp = NOISE_all_random_elpp[i]
    coe_elpp = get_rms(eeg) / (get_rms(noise_elpp) * SNR_elpp[i])
    noise_elpp = noise_elpp * coe_elpp
    noise_eeg_elpp = eeg + noise_elpp
    noiseEEG_ELPP.append(noise_eeg_elpp)

noiseEEG_SHIV = []
for i in range(num_eeg_samples):
    eeg = EEG_all_random[i]
    noise_shiv = NOISE_all_random_shiv[i]
    coe_shiv = get_rms(eeg) / (get_rms(noise_shiv) * SNR_shiv[i])
    noise_shiv = noise_shiv * coe_shiv
    noise_eeg_shiv = eeg + noise_shiv
    noiseEEG_SHIV.append(noise_eeg_shiv)

noiseEEG_EOG = []
for i in range(num_eeg_samples):
    eeg = EEG_all_random[i]
    noise_eog = NOISE_all_random_eog[i]
    coe_eog = get_rms(eeg) / (get_rms(noise_eog) * SNR_eog[i])
    noise_eog = noise_eog * coe_eog
    noise_eeg_eog = eeg + noise_eog
    noiseEEG_EOG.append(noise_eeg_eog)

noiseEEG_EMG = []
for i in range(num_eeg_samples):
    eeg = EEG_all_random[i]
    noise_emg = NOISE_all_random_emg[i]
    coe_emg = get_rms(eeg) / (get_rms(noise_emg) * SNR_emg[i])
    noise_emg = noise_emg * coe_emg
    noise_eeg_emg = eeg + noise_emg
    noiseEEG_EMG.append(noise_eeg_emg)

In [18]:
noiseEEG_CHEW = np.array(noiseEEG_CHEW)
noiseEEG_ELPP = np.array(noiseEEG_ELPP)
noiseEEG_SHIV = np.array(noiseEEG_SHIV)
noiseEEG_EOG = np.array(noiseEEG_EOG)
noiseEEG_EMG = np.array(noiseEEG_EMG)

In [19]:
EEG_end_standard = EEG_all_random / np.std(EEG_all_random)
noiseEEG_CHEW_standard = noiseEEG_CHEW / np.std(noiseEEG_CHEW)
noiseEEG_ELPP_standard = noiseEEG_ELPP / np.std(noiseEEG_ELPP)
noiseEEG_SHIV_standard = noiseEEG_SHIV / np.std(noiseEEG_SHIV)
noiseEEG_EOG_standard = noiseEEG_EOG / np.std(noiseEEG_EOG)
noiseEEG_EMG_standard = noiseEEG_EMG / np.std(noiseEEG_EMG)

In [20]:
EEG = EEG_end_standard.flatten()
EEG_CHEW = noiseEEG_CHEW_standard.flatten()
EEG_ELPP = noiseEEG_ELPP_standard.flatten()
EEG_SHIV = noiseEEG_SHIV_standard.flatten()
EEG_EOG = noiseEEG_EOG_standard.flatten()
EEG_EMG = noiseEEG_EMG_standard.flatten()

In [21]:
print(f'EEG: {EEG.shape}')
print(f'EEG_CHEW: {EEG_CHEW.shape}')
print(f'EEG_ELPP: {EEG_ELPP.shape}')
print(f'EEG_SHIV: {EEG_SHIV.shape}')
print(f'EEG_EOG: {EEG_EOG.shape}')
print(f'EEG_EMG: {EEG_EMG.shape}')

EEG: (2311168,)
EEG_CHEW: (2311168,)
EEG_ELPP: (2311168,)
EEG_SHIV: (2311168,)
EEG_EOG: (2311168,)
EEG_EMG: (2311168,)


In [22]:
import pandas as pd

EEG_signal = EEG[0:512*20]
EEG_CHEW_signal = EEG_CHEW[0:512*20]
EEG_ELPP_signal = EEG_ELPP[0:512*20]
EEG_SHIV_signal = EEG_SHIV[0:512*20]
EEG_EOG_signal = EEG_EOG[0:512*20]
EEG_EMG_signal = EEG_EMG[0:512*20]

# Lista de señales a alternar
signals_list = [EEG_signal[0:512*20], EEG_CHEW_signal[0:512*20], EEG_ELPP_signal[0:512*20], EEG_SHIV_signal[0:512*20], EEG_EOG_signal[0:512*20], EEG_EMG_signal[0:512*20]]

# Lista de canales
channels = ['FP1', 'FP2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2',
            'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'FZ', 'CZ', 'PZ', 'PZ2']

# Generar el diccionario de señales
signals = {}
for i, channel in enumerate(channels):
    # Seleccionar la señal correspondiente cíclicamente
    signal = signals_list[i % len(signals_list)]
    signals[channel] = signal.copy()  # Copiar la señal para que no se modifique

# Crear un DataFrame
df = pd.DataFrame(signals)

# Guardar en un archivo CSV
df.to_csv('generated_eeg_data.csv', index=False)
print("Data generation complete. Saved as 'generated_eeg_data.csv'.")

Data generation complete. Saved as 'generated_eeg_data.csv'.


In [24]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class AlignedEEGDataset(Dataset):
    def __init__(self, csv_file, channel_name="FP1", window_size=512, n_fft=128, hop_length=64, norm_type="zscore"):
        self.df = pd.read_csv(csv_file)
        self.signal = self.df[channel_name].values.astype(np.float32)
        self.window_size = window_size  # Ventana de 512 para dividir la señal
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.norm_type = norm_type
        self.segments = self.segment_signal()  # Segmentar la señal

    def segment_signal(self):
        """Dividir la señal en ventanas de tamaño 512"""
        segments = []
        for i in range(0, len(self.signal), self.window_size):
            segment = self.signal[i:i + self.window_size]
            if len(segment) == self.window_size:  # Solo usar ventanas completas
                segments.append(segment)
        return segments

    def __len__(self):
        return len(self.segments)

    def apply_stft(self, segment):
        eeg_tensor = torch.tensor(segment, dtype=torch.float32)
        stft_result = torch.stft(eeg_tensor, n_fft=self.n_fft, hop_length=self.hop_length, 
                                 return_complex=True).abs()
        return stft_result

    def normalize_minmax(self, tensor):
        min_val = tensor.min()
        max_val = tensor.max()
        return (tensor - min_val) / (max_val - min_val + 1e-8)

    def normalize_zscore(self, tensor):
        mean = tensor.mean()
        std = tensor.std() + 1e-8
        return (tensor - mean) / std

    def normalize(self, tensor):
        if self.norm_type == "minmax":
            return self.normalize_minmax(tensor)
        elif self.norm_type == "zscore":
            return self.normalize_zscore(tensor)
        else:
            return tensor

    def __getitem__(self, idx):
        segment = self.segments[idx]
        
        x_feat_stft = self.apply_stft(segment)
        
        x_feat_stft = self.normalize(x_feat_stft)
        
        return x_feat_stft.unsqueeze(0)


def get_dataloader(csv_file, channel_name="FP1", batch_size=32, shuffle=False, norm_type="zscore"):
    dataset = AlignedEEGDataset(csv_file, channel_name=channel_name, window_size=512, n_fft=128, hop_length=64, norm_type=norm_type)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataloader


csv_file = 'generated_eeg_data.csv'
dataloader = get_dataloader(csv_file)

for x_feat in dataloader:
    print(x_feat.shape)
    break


torch.Size([20, 1, 65, 9])


In [25]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms

# Modelo ResNet2D definido anteriormente
class ResidualBlock2D(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock2D, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels)
        )
        self.downsample = downsample
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.conv2(out)
        
        if self.downsample:
            residual = self.downsample(x)
        
        out += residual
        return self.relu(out)


class ResNet2D(nn.Module):
    def __init__(self, block, layers, num_classes=10):
        super(ResNet2D, self).__init__()
        self.in_channels = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.layer0 = self._make_layer(block, 64, layers[0], stride=1)
        self.layer1 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer2 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer3 = self._make_layer(block, 512, layers[3], stride=2)
        
        self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels),
            )
        
        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels
        
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))
        
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool(x)

        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.global_avgpool(x)
        x = x.view(x.size(0), -1)

        features = x
        logits = self.fc(x)
        
        return logits, features

In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = ResNet2D(ResidualBlock2D, layers=[2, 2, 2, 2], num_classes=6).to(device)
model.load_state_dict(torch.load('best_model_target_stft_da.pt', map_location=device))
model.to(device)
model.eval()

ResNet2D(
  (conv1): Sequential(
    (0): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer0): Sequential(
    (0): ResidualBlock2D(
      (conv1): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (conv2): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (relu): ReLU()
    )
    (1): ResidualBlock2D(
      (conv1): Sequential(
        (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1,

In [36]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn.functional as F

class AlignedEEGDataset(Dataset):
    def __init__(self, csv_file, channel_name, window_size=512, n_fft=128, hop_length=64, norm_type="zscore"):
        self.df = pd.read_csv(csv_file)
        self.signal = self.df[channel_name].values.astype(np.float32)
        self.window_size = window_size
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.norm_type = norm_type
        self.segments = self.segment_signal()

    def segment_signal(self):
        segments = []
        for i in range(0, len(self.signal), self.window_size):
            segment = self.signal[i:i + self.window_size]
            if len(segment) == self.window_size:
                segments.append(segment)
        return segments

    def __len__(self):
        return len(self.segments)

    def apply_stft(self, segment):
        eeg_tensor = torch.tensor(segment, dtype=torch.float32)
        stft_result = torch.stft(eeg_tensor, n_fft=self.n_fft, hop_length=self.hop_length, 
                                 return_complex=True).abs()
        return stft_result

    def normalize(self, tensor):
        mean = tensor.mean()
        std = tensor.std() + 1e-8
        return (tensor - mean) / std

    def __getitem__(self, idx):
        segment = self.segments[idx]
        x_feat_stft = self.apply_stft(segment)
        x_feat_stft = self.normalize(x_feat_stft)
        return x_feat_stft.unsqueeze(0)


def get_dataloader(csv_file, channel_name, batch_size=8, shuffle=False):
    dataset = AlignedEEGDataset(csv_file, channel_name=channel_name)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
    return dataloader


# Cargar modelo preentrenado
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNet2D(ResidualBlock2D, layers=[2, 2, 2, 2], num_classes=6).to(device)
model.load_state_dict(torch.load('best_model_target_stft_da.pt', map_location=device))
model.eval()

# Definir tus canales
channels = ['FP1', 'FP2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2',
            'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'FZ', 'CZ', 'PZ', 'PZ2']

# Cargar el archivo CSV
csv_file = 'generated_eeg_data.csv'

# Inferencia para cada canal por separado
all_predictions = {}

with torch.no_grad():
    for channel in channels:
        print(f"\nProcessing Channel: {channel}")
        
        dataloader = get_dataloader(csv_file, channel_name=channel, batch_size=8)
        
        channel_predictions = []
        
        for x_feat in dataloader:
            x_feat = x_feat.to(device)
            logits, _ = model(x_feat)
            predicted_labels = torch.argmax(logits, dim=1)
            channel_predictions.extend(predicted_labels.cpu().numpy())
        
        # Almacenar predicciones de cada canal
        all_predictions[channel] = channel_predictions
        print(f"Predictions for {channel}: {channel_predictions[:10]}...")  # Imprime las primeras 10 predicciones



Processing Channel: FP1
Predictions for FP1: [np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0)]...

Processing Channel: FP2
Predictions for FP2: [np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(5), np.int64(1), np.int64(1), np.int64(1), np.int64(1), np.int64(1)]...

Processing Channel: F3
Predictions for F3: [np.int64(2), np.int64(2), np.int64(2), np.int64(2), np.int64(2), np.int64(2), np.int64(2), np.int64(5), np.int64(2), np.int64(1)]...

Processing Channel: F4
Predictions for F4: [np.int64(5), np.int64(5), np.int64(5), np.int64(5), np.int64(5), np.int64(5), np.int64(5), np.int64(5), np.int64(5), np.int64(5)]...

Processing Channel: C3
Predictions for C3: [np.int64(3), np.int64(0), np.int64(3), np.int64(3), np.int64(3), np.int64(3), np.int64(3), np.int64(3), np.int64(1), np.int64(3)]...

Processing Channel: C4
Predictions for C4: [np.int64(1), np.int64(4), np.int64(4), np.int64(4), np.int64(