# Split Dataset

Split the dataset into train, validation, and test sets:

- train: 70%
- validation: 10%
- test: 20%

In [1]:
import os
import json
import torch

# Folder holding the IEEE dataset files, resolved to an absolute path from the
# current working directory.
data_dir = os.path.abspath(os.path.join("..", "..", "IEEE"))

# Seed torch's global RNG so the random shuffles in this notebook are repeatable.
torch.manual_seed(42)

<torch._C.Generator at 0x29a1ab9c830>

In [2]:
def split_dataset(
    dataset: torch.Tensor, train_ratio=0.7, val_ratio=0.1, generator=None
):
    """Randomly split a dataset into train, val, and test sets.

    Samples are taken along the first dimension of ``dataset``. A random
    permutation of the sample indices is cut into three disjoint slices.
    The train and validation sizes are truncated to integers; every
    remaining sample goes to the test set, so no sample is dropped.

    :param dataset: Tensor whose first dimension indexes the samples.
    :param train_ratio: Fraction of the samples assigned to the train set.
    :param val_ratio: Fraction of the samples assigned to the validation set.
    :param generator: Optional ``torch.Generator`` driving the shuffle. When
        ``None`` (default) the global torch RNG is used.
    :return: Tuple of (train_dataset, val_dataset, test_dataset)
    :raises ValueError: If a ratio is negative or the ratios sum to more
        than 1.
    """
    if train_ratio < 0 or val_ratio < 0:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative, "
            f"got {train_ratio} and {val_ratio}"
        )
    # Small tolerance so float pairs that should add up to 1 are not rejected.
    if train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            "train_ratio + val_ratio must not exceed 1, "
            f"got {train_ratio} + {val_ratio}"
        )

    n_dataset = dataset.shape[0]
    train_size = int(n_dataset * train_ratio)
    val_size = int(n_dataset * val_ratio)
    val_end = train_size + val_size

    # The test set is whatever is left after the train and validation slices.
    indices = torch.randperm(n_dataset, generator=generator)
    train_indices = indices[:train_size]
    val_indices = indices[train_size:val_end]
    test_indices = indices[val_end:]

    return dataset[train_indices], dataset[val_indices], dataset[test_indices]

## Split ADHD Dataset

In [3]:
# Load the ADHD recordings and swap their last two axes.
# NOTE(review): presumably (samples, time, channels) -> (samples, channels, time);
# the metadata cell records dim 1 as "channel" and dim 2 as "data_length".
adhd_path = os.path.join(data_dir, "eeg_adhd.pt")
adhd_dataset = torch.load(adhd_path, weights_only=True).permute(0, 2, 1)
train_set_adhd, val_set_adhd, test_set_adhd = split_dataset(adhd_dataset)

# Report the shape of each split.
for title, subset in (
    ("Train set:", train_set_adhd),
    ("Val set:", val_set_adhd),
    ("Test set:", test_set_adhd),
):
    print(title, subset.size())

Train set: torch.Size([71, 19, 9250])
Val set: torch.Size([10, 19, 9250])
Test set: torch.Size([21, 19, 9250])


## Split Control Dataset

In [4]:
# Load the control recordings and swap their last two axes.
# NOTE(review): presumably (samples, time, channels) -> (samples, channels, time);
# the metadata cell records dim 1 as "channel" and dim 2 as "data_length".
control_path = os.path.join(data_dir, "eeg_control.pt")
control_dataset = torch.load(control_path, weights_only=True).permute(0, 2, 1)
train_set_control, val_set_control, test_set_control = split_dataset(control_dataset)

# Report the shape of each split.
for title, subset in (
    ("Train set:", train_set_control),
    ("Val set:", val_set_control),
    ("Test set:", test_set_control),
):
    print(title, subset.size())

Train set: torch.Size([50, 19, 9250])
Val set: torch.Size([7, 19, 9250])
Test set: torch.Size([15, 19, 9250])


## Save with labels

- 1: ADHD
- 0: Control (non-ADHD)

In [5]:
def save_dataset(dataset, labels, filename):
    """Save a dataset together with its labels under ``data_dir``.

    The file written by ``torch.save`` holds a dict with two entries:
    ``"data"`` (the dataset) and ``"label"`` (the labels).

    :param dataset: Samples to store under the ``"data"`` key.
    :param labels: Labels to store under the ``"label"`` key.
    :param filename: Name of the output file inside ``data_dir``.
    """
    torch.save(
        {"data": dataset, "label": labels},
        os.path.join(data_dir, filename),
    )

In [6]:
# Training split: ADHD samples first, control samples after them.
train_set = torch.cat([train_set_adhd, train_set_control])

# Labels follow the same order: 1 for each ADHD sample, 0 for each control sample.
train_labels = torch.cat(
    [
        torch.ones(train_set_adhd.size(0), dtype=torch.int8),
        torch.zeros(train_set_control.size(0), dtype=torch.int8),
    ]
)
save_dataset(train_set, train_labels, "ieee_train.pt")

In [7]:
# Validation split: ADHD samples first, control samples after them.
val_set = torch.cat([val_set_adhd, val_set_control])

# Labels follow the same order: 1 for each ADHD sample, 0 for each control sample.
val_labels = torch.cat(
    [
        torch.ones(val_set_adhd.size(0), dtype=torch.int8),
        torch.zeros(val_set_control.size(0), dtype=torch.int8),
    ]
)
save_dataset(val_set, val_labels, "ieee_val.pt")

In [8]:
# Test split: ADHD samples first, control samples after them.
test_set = torch.cat([test_set_adhd, test_set_control])

# Labels follow the same order: 1 for each ADHD sample, 0 for each control sample.
test_labels = torch.cat(
    [
        torch.ones(test_set_adhd.size(0), dtype=torch.int8),
        torch.zeros(test_set_control.size(0), dtype=torch.int8),
    ]
)
save_dataset(test_set, test_labels, "ieee_test.pt")

## Save metadata

In [9]:
# Summary of the saved splits. Shapes are converted to plain lists so they
# can be serialized to JSON.
# NOTE: JSON object keys are always strings, so the integer keys of "label"
# are written to the file as "1" and "0".
metadata = {
    "name": "EEG data for ADHD / Control children",
    "description": "EEG dataset for ADHD classification.",
    "license": "CC BY 4.0",
    "train_size": list(train_set.shape),
    "val_size": list(val_set.shape),
    "test_size": list(test_set.shape),
    "data_length": train_set.shape[2],
    "channel": train_set.shape[1],
    "label": {1: "ADHD", 0: "Control"},
}

# Echo every entry so the written values are visible in the notebook output.
for key, value in metadata.items():
    print(f"{key}: {value}")

# Store the metadata next to the saved splits.
metadata_path = os.path.join(data_dir, "metadata.json")
with open(metadata_path, "w") as metadata_file:
    json.dump(metadata, metadata_file, indent=2)

name: EEG data for ADHD / Control children
description: EEG dataset for ADHD classification.
license: CC BY 4.0
train_size: [121, 19, 9250]
val_size: [17, 19, 9250]
test_size: [36, 19, 9250]
data_length: 9250
channel: 19
label: {1: 'ADHD', 0: 'Control'}
