# データの分割

In [1]:
import pandas as pd

In [2]:
# Read the CSV (path is relative to the working directory) into a DataFrame.
DATA_PATH = "input/pn.csv"
data = pd.read_csv(DATA_PATH)

## 学習・テストセットに分割

[train_test_split](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html#sklearn.model_selection.train_test_split)
を使います。

In [3]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the rows as the test set; the fixed seed keeps the
# split reproducible across runs.
train, test = train_test_split(
    data,
    test_size=0.2,
    random_state=0,
)

In [4]:
def check_label_distribution(splits, labels):
    """Return the percentage of rows carrying each label, per split.

    Parameters
    ----------
    splits : iterable of pandas.DataFrame
        Data splits to inspect; each must have a ``label`` column.
    labels : iterable
        Label values to measure, in the order they should be reported.

    Returns
    -------
    list of list of float
        One inner list per split, holding the share of rows (in percent,
        rounded to 2 decimals) for each entry of ``labels``. An empty
        split reports 0.0 for every label instead of dividing by zero.
    """
    res = []
    for s in splits:
        # Count rows with len() rather than DataFrame.size, which is
        # rows * columns and only gives the right ratio by coincidence.
        n_rows = len(s)
        percs = []
        for label in labels:
            # A boolean mask avoids query()'s '@label' lookup, which depends
            # on inspecting the caller's stack frame.
            n_hits = int((s["label"] == label).sum())
            # Built-in round() is used because the quotient of two Python
            # ints is a plain float, which has no .round() method.
            percs.append(round(n_hits * 100 / n_rows, 2) if n_rows else 0.0)
        res.append(percs)
    return res

In [5]:
# Per-split percentage of each label for the plain (non-stratified) split.
check_label_distribution(
    splits=[train, test],
    labels=["positive", "neutral", "negative"],
)

[[61.32, 24.02, 14.66], [61.39, 23.58, 15.03]]

## ラベルの分布を保って学習・テストセットに分割

ラベルの分布を保つには`stratify`引数にラベルを使います。

In [6]:
# Same 80/20 split and seed as before, but stratified on the label column.
train, test = train_test_split(
    data,
    test_size=0.2,
    stratify=data["label"],
    random_state=0,
)

In [7]:
# Per-split percentage of each label for the stratified split.
check_label_distribution(
    splits=[train, test],
    labels=["positive", "neutral", "negative"],
)

[[61.35, 23.93, 14.72], [61.3, 23.94, 14.76]]

## 交差検証用に分割

In [8]:
from sklearn.model_selection import KFold


# 4-fold cross-validation with shuffling; the fixed seed keeps folds reproducible.
fold = KFold(n_splits=4, shuffle=True, random_state=0)

# No labels are passed to split() here, only the text column as X.
for fold_id, (train_idx, val_idx) in enumerate(fold.split(X=data["text"])):
    # Positional indices from split() are mapped back to rows with iloc.
    train_cv, val_cv = data.iloc[train_idx], data.iloc[val_idx]
    positive_shape = train_cv.query('label == "positive"').shape
    # Print fold number, train/validation shapes, and the shape of the
    # "positive" subset of the training fold.
    print(fold_id, train_cv.shape, val_cv.shape, positive_shape)

0 (4164, 3) (1389, 3) (2544, 3)
1 (4165, 3) (1388, 3) (2569, 3)
2 (4165, 3) (1388, 3) (2547, 3)
3 (4165, 3) (1388, 3) (2558, 3)


## ラベルの分布を保って交差検証用に分割

[StratifiedKFold](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold.html)
を使います。
`.split`メソッドにstratifyするラベル`y`を渡します。

In [9]:
from sklearn.model_selection import StratifiedKFold


# 4-fold stratified cross-validation with shuffling and a fixed seed.
fold = StratifiedKFold(n_splits=4, shuffle=True, random_state=0)

# The label column is passed as y so that split() can stratify on it.
for fold_id, (train_idx, val_idx) in enumerate(fold.split(X=data, y=data["label"])):
    # Positional indices from split() are mapped back to rows with iloc.
    train_cv, val_cv = data.iloc[train_idx], data.iloc[val_idx]
    positive_shape = train_cv.query('label == "positive"').shape
    # Print fold number, train/validation shapes, and the shape of the
    # "positive" subset of the training fold.
    print(fold_id, train_cv.shape, val_cv.shape, positive_shape)

0 (4164, 3) (1389, 3) (2555, 3)
1 (4165, 3) (1388, 3) (2554, 3)
2 (4165, 3) (1388, 3) (2554, 3)
3 (4165, 3) (1388, 3) (2555, 3)
