In [1]:
import bz2
import os
import shutil

# Fetch the compressed dataset once; later runs reuse the cached archive.
download_name = "worldcities.csv.bz2"
if not os.path.exists(download_name):
    # Imported lazily so `requests` is only required when a download is actually needed.
    import requests
    url = f"https://raw.githubusercontent.com/bzitko/nlp_repo/main/lectures/p03/{download_name}"
    partial_download = download_name + ".part"
    # The context manager releases the connection even if writing the file fails.
    with requests.get(url, timeout=60) as response:
        # Fail loudly on HTTP errors instead of caching an error page as the archive.
        response.raise_for_status()
        with open(partial_download, "wb") as fp:
            fp.write(response.content)
    # Publish the archive under its final name only after it was written completely,
    # so the existence check above never accepts a truncated file.
    os.replace(partial_download, download_name)

# Decompress the archive once; later runs reuse the extracted CSV.
name = "worldcities.csv"
if not os.path.exists(name):
    partial_name = name + ".part"
    # Stream the decompression instead of holding the whole file in memory.
    with bz2.open(download_name, "rb") as bzf, open(partial_name, "wb") as fp:
        shutil.copyfileobj(bzf, fp)
    # Same reasoning as above: a failed decompression must not leave a CSV behind.
    os.replace(partial_name, name)

In [2]:
import collections
import numpy as np
import pandas as pd
import re

from argparse import Namespace

In [3]:
# Notebook-wide settings, collected in one place so they are easy to tune.
args = Namespace(
    # Source CSV that is read into the raw DataFrame
    raw_dataset_csv="worldcities.csv",
    # Fractions used for the train / val / test split
    train_proportion=0.7, val_proportion=0.15, test_proportion=0.15,
    # Destination CSV for the data set annotated with its split
    output_munged_csv="worldcities_with_splits.csv",
    # Seed for the random number generator
    seed=1337,
)

### Read Dataset

👍 Read the raw dataset's CSV file into a pandas DataFrame. Use only the `city` and `country` columns.

In [4]:
# Load the raw CSV, keeping just the city name and its country label
cities = pd.read_csv(
    filepath_or_buffer=args.raw_dataset_csv,
    usecols=["city", "country"],
)
cities

Unnamed: 0,city,country
0,Tokyo,Japan
1,Jakarta,Indonesia
2,Delhi,India
3,Manila,Philippines
4,São Paulo,Brazil
...,...,...
42900,Tukchi,Russia
42901,Numto,Russia
42902,Nord,Greenland
42903,Timmiarmiut,Greenland


👍 Count how many datapoints (cities) are in each class (country).

In [5]:
# Number of cities per country, most frequent first
cities["country"].value_counts()


United States               7824
Brazil                      3604
Germany                     2643
Italy                       2140
France                      2019
                            ... 
Cook Islands                   1
Grenada                        1
Martinique                     1
Northern Mariana Islands       1
U.S. Virgin Islands            1
Name: country, Length: 239, dtype: int64

### Split to TRAIN, VAL, TEST 

👍 For each country that has 7 or more cities, make a train, val and test split. Use the proportions defined in `args` to determine how many datapoints go into train, val and test.
Add a `split` column to the DataFrame.

In [6]:
# Countries with fewer cities than this are dropped before splitting.
MIN_CITIES_PER_COUNTRY = 7

# Group the city records by country.
# `to_dict(orient="records")` yields one {"city": ..., "country": ...} dict per row,
# in row order, without building a Series per row the way `iterrows()` does.
by_country = collections.defaultdict(list)
for record in cities.to_dict(orient="records"):
    by_country[record["country"]].append(record)

# Remove countries with too few cities.
# Iterate over a snapshot of the keys because the dict is modified inside the loop.
for country in list(by_country):
    n_cities = len(by_country[country])
    if n_cities < MIN_CITIES_PER_COUNTRY:
        by_country.pop(country)
        print(f"removed {country} ({n_cities})")

removed Hong Kong (1)
removed Singapore (1)
removed Kuwait (4)
removed Sierra Leone (6)
removed Djibouti (6)
removed The Bahamas (3)
removed Martinique (1)
removed Gibraltar (1)
removed Reunion (2)
removed Bahrain (4)
removed Mauritius (6)
removed Curaçao (1)
removed French Polynesia (1)
removed Barbados (1)
removed Comoros (5)
removed New Caledonia (3)
removed Saint Lucia (2)
removed Vanuatu (6)
removed Bermuda (1)
removed Monaco (1)
removed Kiribati (2)
removed Aruba (2)
removed Jersey (1)
removed Mayotte (1)
removed Marshall Islands (1)
removed Isle Of Man (2)
removed Cayman Islands (1)
removed Seychelles (1)
removed Saint Vincent And The Grenadines (1)
removed Antigua And Barbuda (1)
removed Tonga (2)
removed Dominica (1)
removed Saint Kitts And Nevis (1)
removed American Samoa (1)
removed Gaza Strip (1)
removed Turks And Caicos Islands (1)
removed Federated States of Micronesia (5)
removed Tuvalu (1)
removed Cook Islands (1)
removed Grenada (1)
removed West Bank (1)
removed Northe

In [7]:
# Create split data
final_list = []
np.random.seed(args.seed)

# Countries are processed in sorted order so the shuffles consume the seeded
# random stream in a fixed sequence.
for _, item_list in sorted(by_country.items()):
    # Shuffle a copy: `by_country` stays untouched, so re-running this cell
    # reproduces exactly the same split.
    shuffled_items = list(item_list)
    np.random.shuffle(shuffled_items)

    n = len(shuffled_items)
    n_train = int(args.train_proportion * n)
    n_val = int(args.val_proportion * n)
    # The test split receives every remaining item, including what is left over
    # after the truncation above (never negative, so oversized proportions
    # behave like list slicing would).
    n_test = max(n - n_train - n_val, 0)

    # One split label per position in the shuffled list
    split_labels = ['train'] * n_train + ['val'] * n_val + ['test'] * n_test

    # Add labelled copies to the final list; the source records are not modified
    final_list.extend(
        {**item, 'split': split_label}
        for item, split_label in zip(shuffled_items, split_labels)
    )

In [8]:
# Collect the labelled records into a single DataFrame (one row per city)
final_cities = pd.DataFrame(data=final_list)

### Save dataset

👍 Save the DataFrame, including its header row, to the comma-separated file specified in `args`.

In [9]:
# Number of datapoints that ended up in each split
final_cities["split"].value_counts()

train    29867
test      6586
val       6335
Name: split, dtype: int64

In [10]:
# Preview the first five rows of the labelled data set
final_cities.head(n=5)

Unnamed: 0,city,country,split
0,Farāh,Afghanistan,train
1,Kabul,Afghanistan,train
2,Tāluqān,Afghanistan,train
3,Zarghūn Shahr,Afghanistan,train
4,Bāmyān,Afghanistan,train


In [11]:
# Persist the labelled data set as CSV (header kept, row index omitted)
final_cities.to_csv(
    path_or_buf=args.output_munged_csv,
    index=False,
)