In [7]:
import itertools
import os
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import glob
import json
import logging
import sys
import torch.nn as nn
import torch.optim as optim
import random
def batch_equalizer_fn(args):
    """Expand one batch into every rotation of its stimulus candidates.

    Args:
        args: sequence whose first element is the EEG tensor and whose
            remaining elements are ``(stimulus, mel)`` tensor pairs. The
            stacked pairs are merged with a 3-argument ``view``, so each
            tensor of a pair must be 3-D and share its shape with the
            corresponding tensor of the other pairs.

    Returns:
        ``(features, labels)`` where ``features`` is a tuple holding the EEG
        repeated once per candidate along dim 0, followed by one
        ``(stimulus, mel)`` pair per rotation, and ``labels`` is an integer
        tensor of one-hot rows: block ``k`` (``eeg.size(0)`` rows) has its 1
        in column ``k``.
    """
    eeg = args[0]
    n_candidates = len(args) - 1

    # The EEG is tiled along the batch axis, once per rotation of the candidates.
    features = [torch.cat([eeg] * n_candidates, dim=0)]

    # Strided slices starting at 1..n_candidates; zipping them regroups the
    # candidate pairs into tuples of length n_candidates.
    strided = [args[start::n_candidates] for start in range(1, n_candidates + 1)]

    for candidates in zip(*strided):
        width = len(candidates)
        for shift in range(n_candidates):
            # Rotate right by `shift`: slot `pos` receives candidate (pos - shift) mod width.
            rotated = [candidates[(pos - shift) % width] for pos in range(width)]

            stimulus_rolled = torch.stack(tuple(pair[0] for pair in rotated))
            mel_rolled = torch.stack(tuple(pair[1] for pair in rotated))

            # Fold the candidate axis into the batch axis.
            stimulus_rolled = stimulus_rolled.view(
                -1, stimulus_rolled.size(2), stimulus_rolled.size(3)
            )
            mel_rolled = mel_rolled.view(-1, mel_rolled.size(2), mel_rolled.size(3))

            features.append((stimulus_rolled, mel_rolled))

    # One block of identical one-hot rows per rotation, stacked in rotation order.
    label_blocks = []
    for hot in range(n_candidates):
        row = torch.tensor([[1 if col == hot else 0 for col in range(n_candidates)]])
        label_blocks.append(row.repeat(eeg.size(0), 1))
    labels = torch.cat(label_blocks, dim=0)

    return tuple(features), labels

def shuffle_fn(args, number_mismatch):
    """Append ``number_mismatch`` shuffled copies of the last stimulus pair.

    The last element of ``args`` must be a pair of tensors indexed along
    dim 0 (in this file: the framed envelope and mel of one recording).
    Each appended candidate is a random permutation, along dim 0, of the
    previously appended one.

    Both tensors of a pair are permuted with the SAME permutation so that the
    two representations of a mismatched candidate keep describing the same
    window. Permuting them independently would make every mismatched
    candidate internally inconsistent.

    Args:
        args: sequence ending with a ``(first, second)`` tensor pair.
        number_mismatch: how many shuffled candidates to append.

    Returns:
        A tuple with the original elements followed by the new pairs.

    Raises:
        ValueError: if the two tensors of the pair differ in size along dim 0,
            in which case no shared permutation exists.
    """
    args = list(args)
    for _ in range(number_mismatch):
        first, second = args[-1][0], args[-1][1]
        if first.size(0) != second.size(0):
            raise ValueError(
                "stimulus pair has mismatching frame counts: "
                "{} vs {}".format(first.size(0), second.size(0))
            )
        # A single permutation keeps the two features of the candidate aligned.
        perm = torch.randperm(first.size(0))
        args.append((first[perm], second[perm]))
    return tuple(args)

# Function to create frames from a tensor
def frame_tensor(tensor, window_length, hop_length):
    """Cut ``tensor`` into overlapping windows along its first dimension.

    Args:
        tensor: tensor to slice along dim 0.
        window_length: number of dim-0 entries per window.
        hop_length: offset between the starts of consecutive windows.

    Returns:
        A new tensor of shape ``(num_frames, window_length, *tensor.shape[1:])``
        with ``num_frames = (tensor.size(0) - window_length) // hop_length + 1``.
        When the input is shorter than one window, an empty tensor of shape
        ``(0, window_length, *tensor.shape[1:])`` is returned instead of
        letting ``torch.stack`` fail on an empty list.
    """
    num_frames = (tensor.size(0) - window_length) // hop_length + 1
    if num_frames <= 0:
        # Too short for a single window: keep dtype/device and trailing dims,
        # so downstream code simply sees zero frames for this input.
        return tensor.new_empty((0, window_length) + tuple(tensor.shape[1:]))
    return torch.stack(
        [tensor[i * hop_length : i * hop_length + window_length] for i in range(num_frames)]
    )

def process_eeg(original_tensors_list):
    """Regroup framed EEG into shuffled batches of shape (64, 320, 64).

    Every entry along dim 0 of every tensor in ``original_tensors_list`` is
    viewed as a (320, 64) window. Consecutive windows are collected in runs
    of 8 (an incomplete trailing run is dropped), the runs are shuffled with
    ``random.shuffle``, and every 8 shuffled runs form one batch of 64
    windows. Runs that do not fill a whole batch are dropped.

    Returns:
        A list of tensors, each of shape (64, 320, 64).
    """
    # Flatten all inputs into one ordered list of windows.
    windows = []
    for recording in original_tensors_list:
        for idx in range(recording.size(0)):
            windows.append(recording[idx].view(320, 64))

    # Only complete runs of 8 consecutive windows are kept.
    n_runs = len(windows) // 8
    runs = [windows[8 * r : 8 * r + 8] for r in range(n_runs)]

    # Shuffle the order of the runs (order inside a run is preserved).
    random.shuffle(runs)

    batches = []
    for first in range(0, len(runs) - 7, 8):
        members = [window for run in runs[first : first + 8] for window in run]
        batch = torch.cat([window.unsqueeze(0) for window in members], dim=0)
        batches.append(batch.view(64, 320, 64))

    return batches
def process_stimuli(original_tensors_list):
    """Regroup framed single-channel stimuli into shuffled (64, 320, 1) batches.

    Every entry along dim 0 of every tensor in ``original_tensors_list`` is
    viewed as a (320, 1) window. Consecutive windows are collected in runs of
    8 (an incomplete trailing run is dropped), the runs are shuffled with
    ``random.shuffle``, and every 8 shuffled runs form one batch of 64
    windows. Runs that do not fill a whole batch are dropped.

    Returns:
        A list of tensors, each of shape (64, 320, 1).
    """
    # Flatten all inputs into one ordered list of windows.
    windows = []
    for recording in original_tensors_list:
        for idx in range(recording.size(0)):
            windows.append(recording[idx].view(320, 1))

    # Only complete runs of 8 consecutive windows are kept.
    n_runs = len(windows) // 8
    runs = [windows[8 * r : 8 * r + 8] for r in range(n_runs)]

    # Shuffle the order of the runs (order inside a run is preserved).
    random.shuffle(runs)

    batches = []
    for first in range(0, len(runs) - 7, 8):
        members = [window for run in runs[first : first + 8] for window in run]
        batch = torch.cat([window.unsqueeze(0) for window in members], dim=0)
        batches.append(batch.view(64, 320, 1))

    return batches

def process_mel(original_tensors_list):
    """Regroup framed 10-band mel features into shuffled (64, 320, 10) batches.

    Every entry along dim 0 of every tensor in ``original_tensors_list`` is
    viewed as a (320, 10) window. Consecutive windows are collected in runs
    of 8 (an incomplete trailing run is dropped), the runs are shuffled with
    ``random.shuffle``, and every 8 shuffled runs form one batch of 64
    windows. Runs that do not fill a whole batch are dropped.

    Returns:
        A list of tensors, each of shape (64, 320, 10).
    """
    # Flatten all inputs into one ordered list of windows.
    windows = []
    for recording in original_tensors_list:
        for idx in range(recording.size(0)):
            windows.append(recording[idx].view(320, 10))

    # Only complete runs of 8 consecutive windows are kept.
    n_runs = len(windows) // 8
    runs = [windows[8 * r : 8 * r + 8] for r in range(n_runs)]

    # Shuffle the order of the runs (order inside a run is preserved).
    random.shuffle(runs)

    batches = []
    for first in range(0, len(runs) - 7, 8):
        members = [window for run in runs[first : first + 8] for window in run]
        batch = torch.cat([window.unsqueeze(0) for window in members], dim=0)
        batches.append(batch.view(64, 320, 10))

    return batches
class PyTorchDataGenerator(Dataset):
    """Dataset that yields, per recording, one tensor for each feature file.

    File names are expected to consist of fields separated by ``"_-_"``: the
    first three fields identify the recording and the last field (minus its
    extension) names the feature, e.g. ``train_-_<a>_-_<b>_-_eeg.npy``.
    """

    def __init__(self, files, window_length):
        """Store ``window_length`` and group ``files`` per recording."""
        # window_length is kept for callers; no method of this class reads it.
        self.window_length = window_length
        self.files = self.group_recordings(files)

    def group_recordings(self, files):
        """Group feature files by recording, with the EEG file first.

        Args:
            files: iterable of paths to feature files.

        Returns:
            A list with one list of paths per recording. Inside a recording
            the ``eeg`` feature comes first and the remaining features follow
            in alphabetical order of their feature name.
        """

        def recording_key(path):
            # First three "_-_"-separated fields of the file name.
            return "_-_".join(os.path.basename(path).split("_-_")[:3])

        def feature_order(path):
            # The feature name is the last field without its extension. The
            # previous key compared the whole path to "eeg", which never
            # matched, so "EEG first" only held by alphabetical coincidence.
            feature = os.path.basename(path).split("_-_")[-1].split(".")[0]
            return (feature != "eeg", feature)

        return [
            sorted(feature_paths, key=feature_order)
            for _, feature_paths in itertools.groupby(sorted(files), recording_key)
        ]

    def __len__(self):
        """Number of recordings."""
        return len(self.files)

    def __getitem__(self, recording_index):
        """Load every feature of one recording.

        Returns:
            A tuple of float32 tensors, one per feature file, in the order
            produced by ``group_recordings``. 1-D arrays get a trailing axis
            so that every feature is at least 2-D.
        """
        data = []
        for feature in self.files[recording_index]:
            f = np.load(feature).astype(np.float32)
            if f.ndim == 1:
                f = f[:, None]
            data += [f]
        data = self.prepare_data(data)
        return tuple(torch.tensor(x) for x in data)

    def __call__(self):
        """Generate all recordings once, then reshuffle the recording order."""
        for idx in range(self.__len__()):
            yield self.__getitem__(idx)

            if idx == self.__len__() - 1:
                self.on_epoch_end()

    def on_epoch_end(self):
        """Shuffle the recording order in place using numpy's global RNG."""
        np.random.shuffle(self.files)

    def prepare_data(self, data):
        """Hook applied to the loaded arrays; returns ``data`` unchanged."""
        return data

def create_pytorch_dataset(
    data_generator,
    window_length,
    batch_equalizer_fn=None,
    frame_tensor=None,
    process_eeg=None,
    process_stimuli=None,
    process_mel=None,
    hop_length=64,
    batch_size=64,
    number_mismatch=None,
    data_types=(torch.float32, torch.float32),
    feature_dims=(64, 1)
):
    """Turn per-recording tensors into a tuple of ready-made batches.

    Each stage runs only when its callable/setting is given, in this order:

    1. ``frame_tensor``: every recording ``(eeg, envelope, mel)`` becomes
       ``(eeg_frames, (envelope_frames, mel_frames))``.
    2. ``number_mismatch``: ``shuffle_fn`` appends that many shuffled copies
       of the stimulus pair to every recording.
    3. ``process_eeg`` / ``process_stimuli`` / ``process_mel`` (all three
       required): frames of all recordings are regrouped into batches, one
       call per feature. Batch ``i`` of the result is
       ``(eeg_i, (env_1_i, mel_1_i), ..., (env_k_i, mel_k_i))``.
    4. ``batch_equalizer_fn``: applied to every batch.

    The process callables shuffle internally. To keep batch ``i`` of the EEG
    aligned with batch ``i`` of every stimulus, the state of Python's
    ``random`` module is restored before each call so that all calls draw the
    same permutation. This relies on the callables using ``random`` (as the
    ``process_*`` helpers of this file do) and on every feature having the
    same number of frames.

    ``batch_size``, ``data_types`` and ``feature_dims`` are accepted for
    interface compatibility but are not used by this function.

    Returns:
        A tuple with one entry per item produced by the last active stage.

    Raises:
        ValueError: if the process callables return a different number of
            batches for different features (the batches could not be aligned).
    """
    dataset = data_generator
    if frame_tensor is not None:
        dataset = [
            (
                frame_tensor(data[0], window_length, hop_length),
                (
                    frame_tensor(data[1], window_length, hop_length),
                    frame_tensor(data[2], window_length, hop_length),
                ),
            )
            for data in dataset
        ]

    if number_mismatch is not None:
        # Append the shuffled (mismatched) stimulus candidates.
        dataset = [shuffle_fn(data, number_mismatch) for data in dataset]

    if process_eeg is not None and process_stimuli is not None and process_mel is not None:
        recordings = list(dataset)
        # Number of stimulus pairs per recording, derived from the data rather
        # than hard-coded so that any number_mismatch works.
        num_stimuli = len(recordings[0]) - 1 if recordings else 0

        rng_state = random.getstate()

        def run_aligned(process_fn, tensors):
            # Same RNG state + same number of frames => same shuffle order.
            random.setstate(rng_state)
            return process_fn(tensors)

        eeg_batches = run_aligned(process_eeg, [data[0] for data in recordings])
        envelope_batches = [
            run_aligned(process_stimuli, [data[k][0] for data in recordings])
            for k in range(1, num_stimuli + 1)
        ]
        mel_batches = [
            run_aligned(process_mel, [data[k][1] for data in recordings])
            for k in range(1, num_stimuli + 1)
        ]

        for batches in envelope_batches + mel_batches:
            if len(batches) != len(eeg_batches):
                raise ValueError(
                    "EEG and stimulus features produced different numbers of "
                    "batches ({} vs {}); they cannot be aligned".format(
                        len(eeg_batches), len(batches)
                    )
                )

        dataset = [
            (eeg_batches[i],)
            + tuple((envelope_batches[k][i], mel_batches[k][i]) for k in range(num_stimuli))
            for i in range(len(eeg_batches))
        ]

    if batch_equalizer_fn is not None:
        # Create the labels and make sure classes are balanced
        dataset = [tuple(batch_equalizer_fn(args)) for args in dataset]

    return tuple(dataset)

# ---- Windowing -------------------------------------------------------------
fs = 64
window_length_s = 5
window_length = window_length_s * fs  # decision window: 5 seconds
hop_length = 64  # hop between two consecutive decision windows

# ---- Run settings ----------------------------------------------------------
epochs = 100
patience = 5
batch_size = 64  # the process_* helpers build batches of 64 windows
only_evaluate = True
number_mismatch = 4  # shuffled stimulus candidates appended per recording

training_log_filename = "training_log_{}_{}.csv".format(number_mismatch, window_length_s)
data_folder = "split_data/split_data"

# ---- Features --------------------------------------------------------------
# Both stimulus representations are loaded next to the EEG: 'envelope'
# (dimension 1) and 'mel' (dimension 10, as viewed by process_mel).
stimulus_features = ["envelope","mel"]
stimulus_dimension = 1
# For a mel-only run use stimulus_features = ["mel"] and stimulus_dimension = 10.

features = ["eeg"] + stimulus_features

# ---- Training set ----------------------------------------------------------
# Keep only the files whose feature name (last "_-_" field, without the
# extension) is one of the requested features.
train_files = [
    path
    for path in glob.glob(os.path.join(data_folder, "train_-_*"))
    if os.path.basename(path).split("_-_")[-1].split(".")[0] in features
]
train_generator = PyTorchDataGenerator(train_files, window_length)
import pdb
dataset_train = create_pytorch_dataset(
    train_generator,
    window_length,
    batch_equalizer_fn=batch_equalizer_fn,
    frame_tensor=frame_tensor,
    process_eeg=process_eeg,
    process_stimuli=process_stimuli,
    process_mel=process_mel,
    hop_length=hop_length,
    batch_size=batch_size,
    number_mismatch=number_mismatch,
    data_types=(torch.float32, torch.float32),
    feature_dims=(64, stimulus_dimension),
)

# ---- Validation set --------------------------------------------------------
val_files = [
    path
    for path in glob.glob(os.path.join(data_folder, "val_-_*"))
    if os.path.basename(path).split("_-_")[-1].split(".")[0] in features
]
val_generator = PyTorchDataGenerator(val_files, window_length)
dataset_val = create_pytorch_dataset(
    val_generator,
    window_length,
    batch_equalizer_fn=batch_equalizer_fn,
    frame_tensor=frame_tensor,
    process_eeg=process_eeg,
    process_stimuli=process_stimuli,
    process_mel=process_mel,
    hop_length=hop_length,
    batch_size=batch_size,
    number_mismatch=number_mismatch,
    data_types=(torch.float32, torch.float32),
    feature_dims=(64, stimulus_dimension),
)




KeyboardInterrupt



In [6]:
# Print the tensor shapes of the first two training batches: the EEG, the
# two tensors of the first stimulus pair, and the labels.
i = 0
for i, batch in enumerate(dataset_train, start=1):
    print(*(t.shape for t in (batch[0][0], batch[0][1][0], batch[0][1][1], batch[1])))
    if i == 2:
        break

torch.Size([320, 320, 64]) torch.Size([320, 320, 1]) torch.Size([320, 320, 10]) torch.Size([320, 5])
torch.Size([320, 320, 64]) torch.Size([320, 320, 1]) torch.Size([320, 320, 10]) torch.Size([320, 5])
