# Split Dataset

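Split the dataset into train, validation, and test sets. The ADHD and control recordings are split separately with the same ratios and then combined, so every set contains both classes.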

- train: 70%
- validation: 10%
- test: 20%

In [1]:
import os
import torch

# Directory containing the IEEE EEG tensors: the input .pt files are read from here and
# the train/val/test files are written back to it. The relative path is resolved against
# the kernel's current working directory.
data_dir = os.path.abspath("../../IEEE")

# Seed torch's global RNG so the torch.randperm-based shuffling in split_dataset is
# repeatable when the notebook is run top to bottom on a fresh kernel.
torch.manual_seed(42)

<torch._C.Generator at 0x21654f6c870>

In [2]:
def split_dataset(dataset: torch.Tensor, train_ratio=0.7, val_ratio=0.1, generator=None):
    """Randomly split a dataset along its first dimension into train, val, and test sets.

    The train and validation sizes are ``int(n * ratio)`` (rounded down); the
    test set receives every remaining sample, so no sample is dropped or
    duplicated across the three sets.

    :param dataset: Tensor whose first dimension indexes the samples.
    :param train_ratio: Fraction of the samples assigned to the train set.
    :param val_ratio: Fraction of the samples assigned to the validation set.
    :param generator: Optional ``torch.Generator`` used to draw the permutation.
        When ``None`` (the default) the global torch RNG is used.
    :raises ValueError: If a ratio is negative or the two ratios sum to more than 1.
    :return: Tuple of (train_dataset, val_dataset, test_dataset)
    """
    if train_ratio < 0 or val_ratio < 0:
        raise ValueError(
            f"train_ratio and val_ratio must be non-negative, "
            f"got train_ratio={train_ratio}, val_ratio={val_ratio}"
        )
    # Small tolerance so that pairs such as 0.9 + 0.1 are not rejected because of
    # floating-point rounding in the sum.
    if train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f"train_ratio + val_ratio must not exceed 1, "
            f"got {train_ratio} + {val_ratio}"
        )

    n_dataset = dataset.shape[0]
    train_size = int(n_dataset * train_ratio)
    val_size = int(n_dataset * val_ratio)
    val_end = train_size + val_size
    # The test split is whatever remains: n_dataset - train_size - val_size samples.

    # One shared permutation guarantees the three index ranges are disjoint.
    indices = torch.randperm(n_dataset, generator=generator)
    train_indices = indices[:train_size]
    val_indices = indices[train_size:val_end]
    test_indices = indices[val_end:]

    return dataset[train_indices], dataset[val_indices], dataset[test_indices]


def save_dataset(dataset, labels, filename):
    """Write samples and their labels to ``filename`` inside ``data_dir``.

    The file holds a single dict with the samples under ``"data"`` and the
    labels under ``"label"``, serialized with ``torch.save``.
    """
    target = os.path.join(data_dir, filename)
    torch.save({"data": dataset, "label": labels}, target)

## Split ADHD Dataset

In [3]:
adhd_path = os.path.join(data_dir, "ieee_eeg_adhd.pt")
adhd_dataset = torch.load(adhd_path)
train_set_adhd, val_set_adhd, test_set_adhd = split_dataset(adhd_dataset)

# Report the size of each ADHD split.
for caption, subset in (
    ("Train set:", train_set_adhd),
    ("Val set:", val_set_adhd),
    ("Test set:", test_set_adhd),
):
    print(caption, subset.shape)

Train set: torch.Size([306, 2560, 19])
Val set: torch.Size([43, 2560, 19])
Test set: torch.Size([89, 2560, 19])


## Split Control Dataset

In [4]:
control_path = os.path.join(data_dir, "ieee_eeg_control.pt")
control_dataset = torch.load(control_path)
train_set_control, val_set_control, test_set_control = split_dataset(control_dataset)

# Report the size of each control split.
for caption, subset in (
    ("Train set:", train_set_control),
    ("Val set:", val_set_control),
    ("Test set:", test_set_control),
):
    print(caption, subset.shape)

Train set: torch.Size([241, 2560, 19])
Val set: torch.Size([34, 2560, 19])
Test set: torch.Size([70, 2560, 19])


## Save with labels

- 1: ADHD
- 0: Control (non-ADHD)

In [5]:
# Stack the ADHD samples first, then the control samples.
train_set = torch.cat([train_set_adhd, train_set_control], dim=0)

# Build a matching (n, 1) int8 label column: 1 for each ADHD row, 0 for each control row.
train_labels_adhd = torch.ones(len(train_set_adhd), 1, dtype=torch.int8)
train_labels_control = torch.zeros(len(train_set_control), 1, dtype=torch.int8)
train_labels = torch.cat([train_labels_adhd, train_labels_control], dim=0)

save_dataset(train_set, train_labels, "train.pt")

In [6]:
# Stack the ADHD samples first, then the control samples.
val_set = torch.cat([val_set_adhd, val_set_control], dim=0)

# Build a matching (n, 1) int8 label column: 1 for each ADHD row, 0 for each control row.
val_labels_adhd = torch.ones(len(val_set_adhd), 1, dtype=torch.int8)
val_labels_control = torch.zeros(len(val_set_control), 1, dtype=torch.int8)
val_labels = torch.cat([val_labels_adhd, val_labels_control], dim=0)

save_dataset(val_set, val_labels, "val.pt")

In [7]:
# Stack the ADHD samples first, then the control samples.
test_set = torch.cat([test_set_adhd, test_set_control], dim=0)

# Build a matching (n, 1) int8 label column: 1 for each ADHD row, 0 for each control row.
test_labels_adhd = torch.ones(len(test_set_adhd), 1, dtype=torch.int8)
test_labels_control = torch.zeros(len(test_set_control), 1, dtype=torch.int8)
test_labels = torch.cat([test_labels_adhd, test_labels_control], dim=0)

save_dataset(test_set, test_labels, "test.pt")