# Split Dataset

Split the dataset into train, validation, and test sets:

- train: 70%
- validation: 10%
- test: 20%

In [4]:
import os
import json
import torch

# Directory that the cells below read the input tensors from (eeg_adhd.pt,
# eeg_control.pt) and write the split files and metadata.json to.
# NOTE(review): "..." looks like a placeholder path — confirm the real location.
data_dir = os.path.abspath("...")

# Seed torch's global RNG so the torch.randperm calls used for splitting and
# shuffling are reproducible across runs.
torch.manual_seed(42)

<torch._C.Generator at 0x2b9ed2bc830>

In [5]:
def split_dataset(
    dataset: torch.Tensor, train_ratio=0.7, val_ratio=0.1, *, generator=None
):
    """Randomly split a dataset into train, val, and test sets.

    Samples are taken along the first dimension of ``dataset``. A random
    permutation of the sample indices is drawn, the first
    ``int(n * train_ratio)`` indices form the train set, the next
    ``int(n * val_ratio)`` the validation set, and every remaining index goes
    to the test set.

    :param dataset: Tensor whose first dimension indexes the samples.
    :param train_ratio: Fraction of samples assigned to the train set.
    :param val_ratio: Fraction of samples assigned to the validation set.
    :param generator: Optional ``torch.Generator`` used to draw the
        permutation. ``None`` (the default) uses torch's global RNG.
    :return: Tuple of (train_dataset, val_dataset, test_dataset)
    :raises ValueError: If a ratio is negative or the two ratios sum to more
        than 1.
    """
    if train_ratio < 0 or val_ratio < 0:
        raise ValueError(
            f"Ratios must be non-negative, got train_ratio={train_ratio}, "
            f"val_ratio={val_ratio}"
        )
    # Small tolerance so sums such as 0.7 + 0.3 are not rejected because of
    # floating-point rounding.
    if train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f"train_ratio + val_ratio must not exceed 1, got "
            f"{train_ratio} + {val_ratio}"
        )

    n_dataset = dataset.shape[0]
    train_size = int(n_dataset * train_ratio)
    val_size = int(n_dataset * val_ratio)
    # The test set is not sized explicitly: it receives every sample left over
    # after truncating the train and validation sizes, so no sample is dropped.

    indices = torch.randperm(n_dataset, generator=generator)
    train_indices = indices[:train_size]
    val_indices = indices[train_size : train_size + val_size]
    test_indices = indices[train_size + val_size :]

    return dataset[train_indices], dataset[val_indices], dataset[test_indices]

## Split ADHD Dataset

In [6]:
# Load the ADHD recordings and swap the last two axes before splitting.
adhd_path = os.path.join(data_dir, "eeg_adhd.pt")
adhd_dataset = torch.load(adhd_path, weights_only=True)
adhd_dataset = adhd_dataset.permute(0, 2, 1)

adhd_splits = split_dataset(adhd_dataset)
train_set_adhd, val_set_adhd, test_set_adhd = adhd_splits

# Report the shape of each ADHD split.
for caption, subset in (
    ("Train set:", train_set_adhd),
    ("Val set:", val_set_adhd),
    ("Test set:", test_set_adhd),
):
    print(caption, subset.shape)

Train set: torch.Size([71, 19, 9250])
Val set: torch.Size([10, 19, 9250])
Test set: torch.Size([21, 19, 9250])


## Split Control Dataset

In [7]:
# Load the control recordings and swap the last two axes before splitting.
control_path = os.path.join(data_dir, "eeg_control.pt")
control_dataset = torch.load(control_path, weights_only=True)
control_dataset = control_dataset.permute(0, 2, 1)

control_splits = split_dataset(control_dataset)
train_set_control, val_set_control, test_set_control = control_splits

# Report the shape of each control split.
for caption, subset in (
    ("Train set:", train_set_control),
    ("Val set:", val_set_control),
    ("Test set:", test_set_control),
):
    print(caption, subset.shape)

Train set: torch.Size([50, 19, 9250])
Val set: torch.Size([7, 19, 9250])
Test set: torch.Size([15, 19, 9250])


## Save with labels

- 1: ADHD
- 0: Control (non-ADHD)

In [8]:
def save_dataset(dataset, labels, filename):
    """Save data and labels together in one file under ``data_dir``.

    The file holds a dict with the keys ``"data"`` and ``"label"`` and is
    written with ``torch.save``.

    :param dataset: Object stored under the ``"data"`` key.
    :param labels: Object stored under the ``"label"`` key.
    :param filename: File name, joined onto the module-level ``data_dir``.
    """
    torch.save(
        {"data": dataset, "label": labels},
        os.path.join(data_dir, filename),
    )

In [9]:
# Stack both classes: ADHD samples first (label 1), control samples after (label 0).
n_adhd_train = train_set_adhd.size(0)
n_control_train = train_set_control.size(0)
train_set = torch.cat([train_set_adhd, train_set_control], dim=0)
train_labels = torch.cat(
    [
        torch.ones(n_adhd_train, dtype=torch.int8),
        torch.zeros(n_control_train, dtype=torch.int8),
    ],
    dim=0,
)

# Shuffle for training: one shared permutation indexes both data and labels,
# so each sample keeps its label.
assert train_set.size(0) == train_labels.size(0), "Data and label size mismatch"
random_indices = torch.randperm(train_set.size(0))
print(f"Random indices: {random_indices}")

save_dataset(train_set[random_indices], train_labels[random_indices], "train.pt")

Random indices: tensor([ 11,  12,  76,  54,   9,  43,  64,  17, 104, 103,  16,  71,  46,  29,
         39, 117,   7,  94,  40,  57,  98,  91,  53,  42,  61,  84,  37, 101,
         86,  68, 111,  99,  90,  22,  58,  82,  50,  93,   4, 107,   1,  38,
        118,  24, 110,  33,  78,  97,  21,  14,   5,  92,  25,  23,  18, 102,
        116, 109,  35,  52,  31, 100,  85, 106,  28,  55,  81,  36,  77,  60,
         10,  47,  89,  45,  15,  83,  26,  30,   3,  48,   2,  80,  20,  27,
         44,  32,  74,  79, 108,  73,  62, 113,  41,  72, 120,  96,  95, 115,
         59,  88,  67,  69,  19,  63,  66,   8, 105,  75,  49,  70,  65,  13,
         51,   0,  56, 119, 112,  34,  87, 114,   6])


In [10]:
# Validation split: ADHD samples first (label 1), control samples after
# (label 0). Saved in that order, without shuffling.
n_adhd_val = val_set_adhd.size(0)
n_control_val = val_set_control.size(0)
val_set = torch.cat([val_set_adhd, val_set_control], dim=0)
val_labels = torch.cat(
    [
        torch.ones(n_adhd_val, dtype=torch.int8),
        torch.zeros(n_control_val, dtype=torch.int8),
    ],
    dim=0,
)
save_dataset(val_set, val_labels, "val.pt")

In [11]:
# Test split: ADHD samples first (label 1), control samples after (label 0).
# Saved in that order, without shuffling.
n_adhd_test = test_set_adhd.size(0)
n_control_test = test_set_control.size(0)
test_set = torch.cat([test_set_adhd, test_set_control], dim=0)
test_labels = torch.cat(
    [
        torch.ones(n_adhd_test, dtype=torch.int8),
        torch.zeros(n_control_test, dtype=torch.int8),
    ],
    dim=0,
)
save_dataset(test_set, test_labels, "test.pt")

## Save metadata

In [12]:
metadata_path = os.path.join(data_dir, "metadata.json")

# Describe the saved splits; the sizes are read from the concatenated tensors.
metadata = {
    "name": "EEG data for ADHD / Control children",
    "description": "EEG dataset for ADHD classification.",
    "license": "CC BY 4.0",
    "train_size": list(train_set.shape),
    "val_size": list(val_set.shape),
    "test_size": list(test_set.shape),
    "data_length": train_set.size(2),
    "channel": train_set.size(1),
    "label": {1: "ADHD", 0: "Control"},
}

# Echo every entry, then write the same mapping to metadata.json.
for field, value in metadata.items():
    print(f"{field}: {value}")

with open(metadata_path, "w") as fp:
    json.dump(metadata, fp, indent=2)

name: EEG data for ADHD / Control children
description: EEG dataset for ADHD classification.
license: CC BY 4.0
train_size: [121, 19, 9250]
val_size: [17, 19, 9250]
test_size: [36, 19, 9250]
data_length: 9250
channel: 19
label: {1: 'ADHD', 0: 'Control'}
