In [1]:
from pnpl.datasets import LibriBrainSpeech
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [62]:
base_path = "/srv/nfs-data/sisko/storage/libribrain"
train_keys = []
val_keys = [("0", "11", "Sherlock1", "2")] 
test_keys = [("0", "12", "Sherlock1", "2")]  

In [63]:
# Sherlock1: solo le prime 10 sessioni run1 → training
for sess_id in range(1,11):
    train_keys.append(("0", str(sess_id), "Sherlock1", "1"))

# Sherlock2–4: 12 sessioni run1 → training
for subject_idx, stim_name in zip(range(1, 2), ["Sherlock2"]):
    for sess_id in range(1,13):
        train_keys.append((0, str(sess_id), stim_name, "1"))

# Sherlock5: 15 sessioni run1 → training
# for sess_id in range(1,16):
#     train_keys.append(("0", str(sess_id), "Sherlock5", "1"))

# # Sherlock6: 14 sessioni run1 → training
# for sess_id in range(1,15):
#     train_keys.append(("0", str(sess_id), "Sherlock6", "1"))

# # Sherlock7: 14 sessioni run1 → training
# for sess_id in range(1,15):
#     train_keys.append(("0", str(sess_id), "Sherlock7", "1"))

In [64]:
train_dataset = LibriBrainSpeech(
    data_path=f"{base_path}/data/",
    include_run_keys=train_keys,
    tmin=0.0,
    tmax=0.8,
    preload_files=True
)

val_dataset = LibriBrainSpeech(
    data_path=f"{base_path}/data/",
    include_run_keys=val_keys,
    tmin=0.0,
    tmax=0.8,
    preload_files=True
)

test_dataset = LibriBrainSpeech(
    data_path=f"{base_path}/data/",
    include_run_keys=test_keys,
    tmin=0.0,
    tmax=0.8,
    preload_files=True
)

In [65]:
num_workers = 4
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=num_workers)

print("Number of training samples:", len(train_dataset))
print("Number of validation samples:", len(val_dataset))
print("Number of test samples:", len(test_dataset))

Number of training samples: 35263
Number of validation samples: 1671
Number of test samples: 1772


In [54]:
first_batch = next(iter(train_loader))
inputs, labels = first_batch
print("Batch input shape:", inputs.shape)
print("Batch label shape:", labels.shape)

Batch input shape: torch.Size([32, 306, 200])
Batch label shape: torch.Size([32, 200])


In [55]:
import random
import torch
from torch.utils.data import DataLoader
import platform


SENSORS_SPEECH_MASK = [18, 20, 22, 23, 45, 120, 138, 140, 142, 143, 145,
                       146, 147, 149, 175, 176, 177, 179, 180, 198, 271, 272, 275]

class FilteredDataset(torch.utils.data.Dataset):
    """
    Parameters:
        dataset: LibriBrain dataset.
        limit_samples (int, optional): If provided, limits the length of the dataset to this
                          number of samples.
        speech_silence_only (bool, optional): If True, only includes segments that are either
                          purely speech or purely silence (with additional balancing).
        apply_sensors_speech_mask (bool, optional): If True, applies a fixed sensor mask to the sensor
                          data in each sample.
    """
    def __init__(self,
                 dataset,
                 limit_samples=None,
                 disable=False,
                 apply_sensors_speech_mask=True):
        self.dataset = dataset
        self.limit_samples = limit_samples
        self.apply_sensors_speech_mask = apply_sensors_speech_mask

        # These are the sensors we identified:
        self.sensors_speech_mask = SENSORS_SPEECH_MASK

        self.balanced_indices = list(range(len(dataset.samples)))
        # Shuffle the indices
        self.balanced_indices = random.sample(self.balanced_indices, len(self.balanced_indices))

    def __len__(self):
        """Returns the number of samples in the filtered dataset."""
        if self.limit_samples is not None:
            return self.limit_samples
        return len(self.balanced_indices)

    def __getitem__(self, index):
        # Map index to the original dataset using balanced indices
        original_idx = self.balanced_indices[index]
        if self.apply_sensors_speech_mask:
            sensors = self.dataset[original_idx][0][self.sensors_speech_mask]
        else:
            sensors = self.dataset[original_idx][0][:]
        label_from_the_middle_idx = self.dataset[original_idx][1].shape[0] // 2
        return [sensors, self.dataset[original_idx][1][label_from_the_middle_idx]]


print("Filtered dataset:")
train_data_filtered = FilteredDataset(train_dataset)
train_loader_filtered = DataLoader(train_data_filtered, batch_size=32, shuffle=True, num_workers=num_workers)
print(f"Train data contains {len(train_data_filtered)} samples")

val_data_filtered = FilteredDataset(val_dataset)
val_loader_filtered = DataLoader(val_data_filtered, batch_size=32, shuffle=False, num_workers=num_workers)
print(f"Validation data contains {len(val_data_filtered)} samples")

test_data_filtered = FilteredDataset(test_dataset)
test_loader_filtered = DataLoader(test_data_filtered, batch_size=32, shuffle=False, num_workers=num_workers)
print(f"Test data contains {len(test_data_filtered)} samples\n")

# Let's look at the first batch:
first_batch = next(iter(train_loader_filtered))
inputs, labels = first_batch
print("Batch input shape:", inputs.shape)
print("Batch label shape:", labels.shape)

first_input = inputs[0]
first_label = labels[0]
print("\nSingle sample input shape:", first_input.shape)
print("Single sample label is just a single value now!")

Filtered dataset:
Train data contains 85516 samples
Validation data contains 1671 samples
Test data contains 1772 samples



Batch input shape: torch.Size([32, 23, 200])
Batch label shape: torch.Size([32])

Single sample input shape: torch.Size([23, 200])
Single sample label is just a single value now!
