In [1]:
import random
from typing import List, Optional, Tuple, TypeVar

In [2]:
# Generic element type used in the type hints of the splitting helpers below.
X = TypeVar("X")

In [3]:
def split_data(data: List[X], prob: float,
               rng: Optional[random.Random] = None) -> Tuple[List[X], List[X]]:
    """Randomly split ``data`` into two lists in a ``[prob, 1 - prob]`` ratio.

    Args:
        data: Items to split. The caller's list is never modified.
        prob: Fraction of items (0 <= prob <= 1) placed in the first list.
            The cut position is ``int(len(data) * prob)`` (truncated).
        rng: Optional ``random.Random`` instance used for shuffling, so a split
            can be reproduced from a fixed seed. When omitted, the global
            ``random`` module is used (the original behavior).

    Returns:
        ``(first, second)``: ``first`` holds ``int(len(data) * prob)`` items and
        ``second`` the rest; together they are a permutation of ``data``.

    Raises:
        ValueError: if ``prob`` is not within ``[0, 1]`` (NaN included).
    """
    # Out-of-range prob used to yield silently wrong slices (e.g. a negative cut).
    if not 0 <= prob <= 1:
        raise ValueError(f"prob must be between 0 and 1, got {prob!r}")

    shuffled = list(data)  # shallow copy: shuffle() reorders its argument in place
    shuffle = random.shuffle if rng is None else rng.shuffle
    shuffle(shuffled)

    cut = int(len(shuffled) * prob)
    return shuffled[:cut], shuffled[cut:]

In [5]:
# Seed the global RNG so Restart & Run All reproduces the same shuffles
# (this cell and every later split that uses the global `random` module).
random.seed(0)

data = list(range(1000))
train, test = split_data(data, 0.75)

In [6]:
# A 0.75 split of the 1,000 items puts 750 of them in the training list.
assert 750 == len(train)

In [10]:
# The remaining 250 of the 1,000 items form the test list.
assert 250 == len(test)

In [8]:
# The two lists together are a permutation of `data`: nothing lost or duplicated.
assert data == sorted([*train, *test])

In [11]:
# Second generic type: the element type of the `ys` paired with `xs` in train_test_split.
Y = TypeVar("Y")

In [12]:
def train_test_split(xs: List[X],
                     ys: List[Y],
                     test_pct: float) -> Tuple[List[X], List[X], List[Y], List[Y]]:
    """Split paired inputs ``xs`` and targets ``ys`` into train and test sets.

    One shuffled split of the *indices* is applied to both lists, so ``xs[i]``
    and ``ys[i]`` always land in the same set.

    Args:
        xs: Input values.
        ys: Target values; must have the same length as ``xs``.
        test_pct: Fraction of the pairs, expected in ``[0, 1]``, assigned to
            the test set (the training set gets ``1 - test_pct``).

    Returns:
        ``(x_train, x_test, y_train, y_test)``.

    Raises:
        ValueError: if ``xs`` and ``ys`` have different lengths.
    """
    # Without this check a shorter `ys` raises an opaque IndexError and a
    # longer `ys` has its tail silently dropped.
    if len(xs) != len(ys):
        raise ValueError(
            f"xs and ys must have the same length, got {len(xs)} and {len(ys)}")

    # Split indices rather than the data so both lists are split identically.
    idxs = list(range(len(xs)))
    train_idxs, test_idxs = split_data(idxs, 1 - test_pct)

    return ([xs[i] for i in train_idxs],  # x_train
            [xs[i] for i in test_idxs],   # x_test
            [ys[i] for i in train_idxs],  # y_train
            [ys[i] for i in test_idxs])   # y_test

In [13]:
# Toy paired data: every target is exactly twice its input (y = 2x), which
# makes it easy to verify below that the split keeps pairs together.
xs = list(range(1000))
ys = [2 * x for x in xs]
x_train, x_test, y_train, y_test = train_test_split(xs, ys, 0.25)

In [14]:
# Sanity checks: with test_pct=0.25, 750 of the 1,000 inputs go to training.
assert 750 == len(x_train)

In [15]:
# The other 250 inputs form the test set.
assert 250 == len(x_test)

In [16]:
# Targets are split into sets of the same sizes as the inputs.
assert (len(y_train), len(y_test)) == (750, 250)

In [17]:
# Pairing survives the shuffle: each target is still twice the input it was split with.
assert all(2 * x == y for x, y in zip(x_train, y_train))
assert all(2 * x == y for x, y in zip(x_test, y_test))