## Create k-fold Dataset

In [140]:
from datasets import load_dataset, load_from_disk
import random
import numpy as np
import datasets

import os

# Seed both Python's and NumPy's global RNGs with one value so every shuffle
# performed later in the notebook is reproducible.
seed = 42
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(seed)

### 데이터셋 로드

In [141]:
# Load the base DatasetDict from disk (its train/validation splits are shown in the
# cell output further down). The path is relative to the notebook's working directory.
train_dataset = load_from_disk("../../data/train_dataset")

In [142]:
# Add every dataset that should be k-fold split to this list.
# Example: [train_dataset, qg_data_v1, qg_data_v1_2]
# Each entry is indexed with ['train'] (and ['validation'] when its flag below is
# True) by the splitting loop further down.
ori_dataset = [train_dataset]

# One flag per entry of ori_dataset, in the same order: whether to also split that
# dataset's 'validation' data into the folds. If several datasets share the same
# validation data, set only one of them to True and the rest to False so the
# shared rows are not added more than once.
add_validation = [True]

# Display the configured datasets.
ori_dataset

[DatasetDict({
     train: Dataset({
         features: ['__index_level_0__', 'answers', 'context', 'document_id', 'id', 'question', 'title'],
         num_rows: 3952
     })
     validation: Dataset({
         features: ['__index_level_0__', 'answers', 'context', 'document_id', 'id', 'question', 'title'],
         num_rows: 240
     })
 })]

### train, validation 데이터 split

In [143]:
def createDataDict(k_split, column_list):
    """Build empty column-name -> list containers.

    When ``k_split`` is 0 a single ``{column: []}`` dict is returned; otherwise
    a list of ``k_split`` independent dicts of that shape (an empty list for a
    negative ``k_split``). Every dict gets its own fresh lists.
    """
    def _blank():
        return {name: [] for name in column_list}

    if k_split != 0:
        return [_blank() for _ in range(k_split)]
    return _blank()

def addSplitData(split_dataset, dataset, k_split, column_list=None):
    """Shuffle ``dataset``'s rows and distribute them across ``k_split`` buckets.

    Row indices are shuffled with the global ``random`` module, cut into
    ``k_split`` nearly equal contiguous chunks with ``np.array_split``, and
    chunk ``i`` is appended column by column to ``split_dataset[i]``.

    Args:
        split_dataset: Sequence of ``k_split`` dicts mapping column name to a
            list (e.g. from ``createDataDict``); extended in place.
        dataset: Object supporting ``len()`` and ``dataset[list_of_ints]``,
            where the latter returns a mapping from column name to values.
        k_split: Number of buckets; must be positive (``np.array_split``
            raises ``ValueError`` otherwise).
        column_list: Columns to copy. Defaults to the keys of each bucket dict,
            so the function no longer depends on a global ``column_list``.
    """
    index_list = list(range(len(dataset)))
    random.shuffle(index_list)

    chunks = np.array_split(index_list, k_split)
    for i in range(k_split):
        # Plain Python ints; also avoids a float-dtype empty array when the
        # dataset itself is empty.
        indices = chunks[i].tolist()
        if not indices:
            # More buckets than rows: nothing to add to this bucket.
            continue
        bucket = split_dataset[i]
        # Fetch the chunk's rows once and reuse them for every column.
        batch = dataset[indices]
        columns = list(bucket) if column_list is None else column_list
        for column in columns:
            bucket[column].extend(batch[column])

In [144]:
# Number of folds (k) to split the data into.
k_split = 5

column_list = ['__index_level_0__', 'answers', 'context', 'document_id', 'id', 'question', 'title']

# zip() below would silently drop trailing datasets (or flags) if these two
# configuration lists had different lengths, so fail loudly instead.
if len(add_validation) != len(ori_dataset):
    raise ValueError(
        f"add_validation has {len(add_validation)} entries but ori_dataset has "
        f"{len(ori_dataset)}; provide exactly one flag per dataset."
    )

split_dataset = createDataDict(k_split, column_list)

# Distribute each dataset's train split (and, when flagged, its validation
# split) across the k buckets. The call order is kept fixed because every
# addSplitData call consumes the seeded `random` state when shuffling.
for dataset, vali in zip(ori_dataset, add_validation):
    addSplitData(
        split_dataset=split_dataset,
        dataset=dataset['train'],
        k_split=k_split
    )
    if not vali:
        continue
    addSplitData(
        split_dataset=split_dataset,
        dataset=dataset['validation'],
        k_split=k_split
    )

In [145]:
# Sanity check: number of examples that landed in the first of the k splits.
len(split_dataset[0]['title'])

839

### Folding

In [146]:
# Build one train/validation pair per fold: split `held_out` becomes the
# fold's validation set and all other splits, in order, form its train set.
fold_dataset = []
for held_out in range(k_split):
    train_cols = createDataDict(0, column_list)
    valid_cols = createDataDict(0, column_list)

    # The held-out split is this fold's validation data.
    for name in column_list:
        valid_cols[name].extend(split_dataset[held_out][name])

    # Every remaining split is appended to this fold's training data.
    for other in range(k_split):
        if other != held_out:
            for name in column_list:
                train_cols[name].extend(split_dataset[other][name])

    fold_dataset.append({"train": train_cols, "validation": valid_cols})

In [147]:
# Sanity check: (train size, validation size) of fold 0.
len(fold_dataset[0]['train']['title']), len(fold_dataset[0]['validation']['title'])

(3353, 839)

### dataset 저장

In [148]:
# Persist every fold as its own DatasetDict ("train" + "validation") under
# ../../data/fold<index>.
for fold_no in range(k_split):
    fold = fold_dataset[fold_no]
    train_ds = datasets.arrow_dataset.Dataset.from_dict(fold['train'])
    valid_ds = datasets.arrow_dataset.Dataset.from_dict(fold['validation'])
    out_dir = os.path.join("../../data", f"fold{fold_no}")
    datasets.DatasetDict({"train": train_ds, "validation": valid_ds}).save_to_disk(out_dir)