In [11]:
import collections
from argparse import Namespace

import numpy as np
import pandas as pd

In [2]:
# Run configuration, collected on a single namespace object so later cells
# can read every tunable from one place.
args = Namespace()
# Location of the raw input CSV.
args.raw_dataset_csv = "./data/surnames/surnames.csv"
# Fractions of each nationality's rows assigned to train / val / test.
args.train_proportion = 0.7
args.val_proportion = 0.15
args.test_proportion = 0.15
# Destination for the CSV that carries the added `split` column.
args.output_munged_csv = "./data/surnames/surnames_with_splits2.csv"
# Seed for NumPy's global RNG, used when shuffling before splitting.
args.seed = 1337

In [3]:
# Load the raw surnames table; row 0 of the CSV supplies the column names.
surnames = pd.read_csv(filepath_or_buffer=args.raw_dataset_csv, header=0)
# Show the first few rows as a quick sanity check of the load.
surnames.head()

Unnamed: 0,surname,nationality
0,Woodford,English
1,Coté,French
2,Kore,English
3,Koury,Arabic
4,Lebzak,Russian


In [5]:
# Distinct nationality labels present in the data (the classification targets)
set(surnames["nationality"])

{'Arabic',
 'Chinese',
 'Czech',
 'Dutch',
 'English',
 'French',
 'German',
 'Greek',
 'Irish',
 'Italian',
 'Japanese',
 'Korean',
 'Polish',
 'Portuguese',
 'Russian',
 'Scottish',
 'Spanish',
 'Vietnamese'}

In [9]:
# Group every row of `surnames` (converted to a plain dict) under its
# nationality label, so each class can be shuffled and split on its own.
by_nationality = collections.defaultdict(list)
for record in (row.to_dict() for _, row in surnames.iterrows()):
    by_nationality[record["nationality"]].append(record)

In [13]:
# Create split data
#
# Each nationality is split independently so every class keeps roughly the
# same train/val/test ratio. Classes are visited in sorted order after
# re-seeding NumPy's global RNG, so the sequence of random draws is fixed.
final_list = []
np.random.seed(args.seed)
for _, item_list in sorted(by_nationality.items()):
    # Shuffle a copy rather than the list stored in `by_nationality`.
    # Shuffling in place would leave the source lists permuted, so re-running
    # this cell would shuffle already-shuffled data and yield a different
    # split despite the fixed seed.
    shuffled_items = list(item_list)
    np.random.shuffle(shuffled_items)

    n = len(shuffled_items)
    n_train = int(args.train_proportion * n)
    n_val = int(args.val_proportion * n)
    # The test split is everything left after train and val, so rows lost to
    # int() truncation above land in test; args.test_proportion is not
    # consulted here.

    # Give data point a split attribute
    for position, item in enumerate(shuffled_items):
        if position < n_train:
            split_name = 'train'
        elif position < n_train + n_val:
            split_name = 'val'
        else:
            split_name = 'test'
        # Build a new record (original keys first, then 'split') so the dicts
        # held by `by_nationality` are not mutated.
        final_list.append({**item, 'split': split_name})

In [15]:
# Assemble the labelled records into a DataFrame (nothing is written to disk
# in this cell) and tally how many rows landed in each split.
final_surnames = pd.DataFrame(final_list)
final_surnames["split"].value_counts()

train    7680
test     1660
val      1640
Name: split, dtype: int64

In [16]:
# Preview the first rows of the assembled frame
final_surnames.head()

Unnamed: 0,surname,nationality,split
0,Totah,Arabic,train
1,Abboud,Arabic,train
2,Fakhoury,Arabic,train
3,Srour,Arabic,train
4,Sayegh,Arabic,train


In [17]:
# Persist the split-annotated table to the configured output path.
# index=False keeps the DataFrame's row index out of the file.
final_surnames.to_csv(path_or_buf=args.output_munged_csv, index=False)