In [22]:
import os
path = "/home/marta/Documenti/eeg-ml-thesis/"
os.chdir(path)

In [23]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
import torch.optim as optim 
from torch.utils.data import Dataset
import numpy as np 
from sklearn.preprocessing import StandardScaler
import r_pca 
import scipy.io


In [4]:
def create_dataset(subject_list, window, overlap, num_columns=16, num_classes=2):

    x_data = np.empty((0, window, num_columns))
    y_data = np.empty((0, 1))  # Labels
    subj_inputs = []  # Tracks number of windows per subject
    
    dataset_dir = '/home/marta/Documenti/eeg_rnn_repo/rnn-eeg-ad/eeg2'
    # print('\n### Creating dataset')
    tot_rows = 0
    
    # for subject_id, category_label in subject_list:
    subject_id = subject_list[0]
    category_label = subject_list[1]
    
    # print(f"aaaaaaaaaaaaaa{subject_id}")
    # print(f"bbbbbbbbbbbbb{category_label}")
    subj_inputs.append(0)  # Initialize window count for this subject
    
    # Load EEG data
    file_path = f"{dataset_dir}/S{subject_id}_{category_label}.npz"
    eeg = np.load(file_path)['eeg'].T  # Transpose if necessary to get [samples, channels]
    
    # Scale EEG data
    scaler = StandardScaler()
    eeg = scaler.fit_transform(eeg)
    
    assert eeg.shape[1] == num_columns, f"Expected {num_columns} channels, got {eeg.shape[1]}"
    
    # Calculate number of windows
    num_windows = 0
    i = 0
    while i + window <= len(eeg):
        i += (window - overlap)
        num_windows += 1
    
    # Preallocate x_data for this subject
    x_data_part = np.empty((num_windows, window, num_columns))
    
    # Extract windows
    i = 0
    for w in range(num_windows):
        x_data_part[w] = eeg[i:i + window]
        i += (window - overlap)
    
    # Update x_data and y_data
    x_data = np.vstack((x_data, x_data_part))
    y_data = np.vstack((y_data, np.full((num_windows, 1), (category_label == 'AD'))))  # Binary label
    subj_inputs[-1] = num_windows
    tot_rows += len(eeg)
    
    # print(f"Total samples: {tot_rows}")
    # print(f"x_data shape: {x_data.shape}")
    # print(f"y_data shape: {y_data.shape}")
    # print(f"Windows per subject: {subj_inputs}")
    # print(f"Class distribution: {[np.sum(y_data == cl) for cl in range(num_classes)]}")
    
    return x_data, y_data, subj_inputs


In [2]:
from torch.utils.data import Dataset, DataLoader

In [24]:
def pca_reduction(A, tol, comp = 0):
  rpca = False
  rpca_mu = 0
  multiscale_pca = False

  assert(len(A.shape) == 2)
  dmin = min(A.shape)
  if rpca:
    r = r_pca.R_pca(A, mu = rpca_mu)
    print('Auto tol:', 1e-7 * r.frobenius_norm(r.D), 'used tol:', tol)
    print('mu', r.mu, 'lambda', r.lmbda)
    L, S = r.fit(tol = tol, max_iter = 10, iter_print = 1)
    global norm_s
    norm_s = np.linalg.norm(S, ord='fro')  # for debug
    print('||A,L,S||:', np.linalg.norm(A, ord='fro'), np.linalg.norm(L, ord='fro'), np.linalg.norm(S, ord='fro'))
    #np.savez_compressed('rpca.npz', pre = A, post = L)
  elif multiscale_pca:
    print('MSPCA...')
    #ms = mspca.MultiscalePCA()
    #L = ms.fit_transform(A, wavelet_func='sym4', threshold=0.1, scale = True )
    print('saving MAT file and calling Matlab...')
    scipy.io.savemat('mspca.mat', {'A': A}, do_compression = True)
    os.system('matlab -batch "mspca(\'mspca.mat\')"')
    L = scipy.io.loadmat('mspca.mat')['L'] 
  else:
    L = A
  U, lam, V = np.linalg.svd(L, full_matrices = False)  # V is transposed
  assert(U.shape == (A.shape[0], dmin) and lam.shape == (dmin,) and V.shape == (dmin, A.shape[1]))
  #np.savetxt('singular_values.csv', lam)
  lam_trunc = lam[lam > 0.015 * lam[0]]  # magic number
  p = comp if comp else len(lam_trunc)
  assert(p <= dmin)
  print('PCA truncation', dmin, '->', p)
  return L, V.T[:,:p]

def reduce_matrix(A, V):
  # (N, w, 16) → (N, 16, w) → ((N*16), w) → compute V
  # (N, 16, w) * V → transpose again last dimensions
  B = np.swapaxes(A, 1, 2)  # (N, 16, w)
  C = B.reshape((-1, B.shape[2]))  # ((N*16), w)
  if V is None:
    L, V = pca_reduction(C, 5e-6, comp = 50)
  B = C @ V  # ((N*16), p)
  B = B.reshape((A.shape[0], A.shape[2], B.shape[1]))  # (N, 16, p)
  return np.swapaxes(B, 1, 2), V  # B = (N, p, 16)

def adjust_size(x, y):
  # when flattening the data matrix on the first dimension, y must be made compatible
  if len(x) == len(y): return y
  factor = len(x) // len(y)
  ynew = np.empty((len(x), 1))
  for i in range(0, len(y)):
    ynew[i * factor : (i + 1) * factor] = y[i]
  return ynew


In [29]:
class EegDataset(Dataset):
    
    def __init__(self, 
                 file_paths, 
                #  labels, 
                 create_dataset_crop, 
                 window, 
                 overlap):
        
        super().__init__()
        self.file_paths = file_paths
        # self.labels = labels
        self.create_dataset_crop = create_dataset_crop
        self.window = window
        self.overlap = overlap
        
        self.crops_index = self._compute_crops_index()
    
    def _compute_crops_index(self):
        crops_index = []
        for file_idx, (file_path) in enumerate(self.file_paths):
            # print(f"file_path: {file_path}")
            crops, _, _ = self.create_dataset_crop(file_path, self.window, self.overlap)
            
            num_crops = len(crops)
            
            crops_index.extend([(file_idx, crop_idx) for crop_idx in range(num_crops)])
            
        return crops_index
    
    def __len__(self):
        return len(self.crops_index)
    
    def __getitem__(self, idx):
        
        file_idx, crop_idx = self.crops_index[idx]
        file_path = self.file_paths[file_idx]
        
        crops, labels, _ = self.create_dataset_crop(file_path, self.window, self.overlap)
        x_data_reduced, Vpca = reduce_matrix(crops, None)
        labels = adjust_size(x_data_reduced, labels)
        # print(np.unique(label[0]))
        # print(label.shape)
        crop = x_data_reduced[crop_idx]
        label = labels[0] 
        
        label = torch.tensor(label).float().squeeze().unsqueeze(0)        
        # label = self.labels[file_idx]
        
        return torch.tensor(crop).float(), label

subj_list = (
      tuple((f'{i:02d}', 'N') for i in range(1, 16)) +  # normal subjects, S01 to S15
      tuple((f'{i:02d}', 'AD') for i in range(1, 21))   # alzheimer's subjects, S01 to S20
  )

subjs_test = (0, 1, 15, 16, 17)  

test_subject_list = [subj_list[i] for i in subjs_test]
train_val_subjects = [subj for i, subj in enumerate(subj_list) if i not in subjs_test]   
    
dataset = EegDataset(file_paths=train_val_subjects,
                     create_dataset_crop=create_dataset,
                     window=128,
                     overlap=25)

data_loader = DataLoader(dataset, batch_size = 32)



In [30]:
dataset

<__main__.EegDataset at 0x72b241ab1c30>

In [28]:
for X, y in data_loader:
    print(f"{X.shape}")
    print(f"{y.shape}")
    break

PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
PCA truncation 128 -> 50
torch.Size([32, 50, 16])
torch.Size([32, 1])


In [31]:
test = ('03', 'N')
prova = list(test)

print(prova)

['03', 'N']


In [27]:
for i, data in enumerate(train_val_subjects):
    print(i, data)

0 ('03', 'N')
1 ('04', 'N')
2 ('05', 'N')
3 ('06', 'N')
4 ('07', 'N')
5 ('08', 'N')
6 ('09', 'N')
7 ('10', 'N')
8 ('11', 'N')
9 ('12', 'N')
10 ('13', 'N')
11 ('14', 'N')
12 ('15', 'N')
13 ('04', 'AD')
14 ('05', 'AD')
15 ('06', 'AD')
16 ('07', 'AD')
17 ('08', 'AD')
18 ('09', 'AD')
19 ('10', 'AD')
20 ('11', 'AD')
21 ('12', 'AD')
22 ('13', 'AD')
23 ('14', 'AD')
24 ('15', 'AD')
25 ('16', 'AD')
26 ('17', 'AD')
27 ('18', 'AD')
28 ('19', 'AD')
29 ('20', 'AD')


In [7]:
import random 

In [16]:


# normal_subjects = [subj for subj in train_val_subjects if subj[1] == 'N']
# ad_subjects = [subj for subj in train_val_subjects if subj[1] == 'AD']

# random.seed(42)  
# random.shuffle(normal_subjects)
# random.shuffle(ad_subjects)

# split_index_normal = int(0.8 * len(normal_subjects))
# split_index_ad = int(0.8 * len(ad_subjects))

# train_normal = normal_subjects[:split_index_normal]
# val_normal = normal_subjects[split_index_normal:]

# train_ad = ad_subjects[:split_index_ad]
# val_ad = ad_subjects[split_index_ad:]

# train_subject_list = train_normal + train_ad
# val_subject_list = val_normal + val_ad

# random.shuffle(train_subject_list)
# random.shuffle(val_subject_list)


### Creating dataset


FileNotFoundError: [Errno 2] No such file or directory: '/home/marta/Documenti/eeg_rnn_repo/rnn-eeg-ad/eeg2/S0_3.npz'