In [1]:
import torch
import sys
import os
from sklearn.model_selection import train_test_split


sys.path.insert(0, "..")

from utils import *

from customdataset import TestingDataset

# Paths used by this notebook. The cache root can be overridden with the
# ODL_CACHE_ROOT environment variable instead of editing a hardcoded absolute
# path; the default reproduces the original D:/storage/odl/cache layout exactly.
CUR_DIR = os.path.abspath(os.curdir)
CACHE_ROOT = os.environ.get("ODL_CACHE_ROOT", "D:/storage/odl/cache").rstrip("/\\")
SAVE_DIR_ONLINE = f"{CACHE_ROOT}/online_session/reintel2020/"
SAVE_DIR_OFFLINE = f"{CACHE_ROOT}/offline/reintel2020/"
CACHE_DIR = f"{CACHE_ROOT}/clean_metadata_with_image/reintel2020/"

# exist_ok=True avoids the check-then-create race of isdir()+makedirs(), and
# the loop variable no longer shadows the builtin `dir`.
for directory in (SAVE_DIR_OFFLINE, SAVE_DIR_ONLINE, CACHE_DIR):
    os.makedirs(directory, exist_ok=True)

In [2]:
# Split every cached dataset file into stratified train / test / dev subsets
# (80% / 10% / 10%), save each file's three splits to SAVE_DIR_ONLINE as
# dataset_NN.pt, and accumulate all splits into combined "offline" datasets.
SEED = 42  # fixed random_state so the splits (and the saved files) are reproducible

offline_train = TestingDataset([])
offline_test = TestingDataset([])
offline_dev = TestingDataset([])

# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so
# that dataset_NN.pt always maps to the same cache file, and skip directories.
cache_files = sorted(
    name
    for name in os.listdir(CACHE_DIR)
    if os.path.isfile(os.path.join(CACHE_DIR, name))
)

for file_no, cache_file in enumerate(cache_files, start=1):
    # create dataset from cache; torch.load unpickles arbitrary objects, so
    # CACHE_DIR must only contain trusted files
    dataset = TestingDataset(torch.load(os.path.join(CACHE_DIR, cache_file)))

    # positional indices and their labels (labels drive the stratification)
    indices = list(range(len(dataset)))
    labels = dataset.get_labels()

    # hold out 20% as test+dev, preserving class proportions
    train_idx, test_idx, _, test_label = train_test_split(
        indices, labels, test_size=0.2, stratify=labels, random_state=SEED
    )

    # split the held-out portion evenly into test and dev
    test_idx, dev_idx, _, _ = train_test_split(
        test_idx, test_label, test_size=0.5, stratify=test_label, random_state=SEED
    )

    # restore ascending index order (presumably so each subset keeps the
    # original sample order — confirm against TestingDataset.subset)
    train_idx.sort()
    test_idx.sort()
    dev_idx.sort()

    train_dataset = dataset.subset(train_idx)
    test_dataset = dataset.subset(test_idx)
    dev_dataset = dataset.subset(dev_idx)

    # assumes TestingDataset supports concatenation via += (as the original did)
    offline_train += train_dataset
    offline_test += test_dataset
    offline_dev += dev_dataset

    # save into cache
    file_name = f"dataset_{file_no:02}.pt"
    file_path = os.path.join(SAVE_DIR_ONLINE, file_name)
    torch.save(
        {"train": train_dataset, "test": test_dataset, "dev": dev_dataset},
        file_path,
    )
    print(f"Saved {file_name} with size: {os.path.getsize(file_path)} bytes")

Saved dataset_01.pt with size: 968348722 bytes
Saved dataset_02.pt with size: 944616242 bytes
Saved dataset_03.pt with size: 822588850 bytes
Saved dataset_04.pt with size: 861975730 bytes
Saved dataset_05.pt with size: 857127023 bytes


In [6]:
# Sizes of the combined (train, test, dev) splits, displayed as a tuple.
tuple(len(split) for split in (offline_train, offline_test, offline_dev))

(3497, 435, 440)

In [5]:
# Persist every combined split into a single offline dataset file.
file_name = "dataset.pt"
file_path = os.path.join(SAVE_DIR_OFFLINE, file_name)
offline_splits = {"train": offline_train, "test": offline_test, "dev": offline_dev}
torch.save(offline_splits, file_path)
print(f"Saved {file_name} with size: {os.path.getsize(file_path)} bytes")

Saved dataset.pt with size: 4454757809 bytes
