In [3]:
import datasets

# Fetch the "lecslab/glosslm" dataset. No `split` argument is passed, so the
# result is indexed by split name further down (e.g. `data['train']`).
data = datasets.load_dataset("lecslab/glosslm")

In [None]:
# Map glottocodes to languages

from tqdm import tqdm
from pyglottolog import Glottolog
glottolog = Glottolog('../../glottolog')

# Gather codes from BOTH glottocode columns and from every split. The mapping is
# later indexed with `metalang_glottocode` as well as `glottocode`, and the row
# mapper runs over all splits, so a code that only ever appears as a metalanguage
# (or only outside 'train') still needs an entry to avoid a KeyError.
all_glottocodes = {
    code
    for split in data.values()
    for column in ('glottocode', 'metalang_glottocode')
    for code in split.unique(column)
}

glottocode_mapping = dict()
for code in tqdm(all_glottocodes):
    # Empty-string and None codes are never looked up; they, and any lookup that
    # yields a falsy result, are labelled 'Unknown language'.
    languoid = glottolog.languoid(code) if code != '' and code is not None else None
    glottocode_mapping[code] = languoid.name if languoid else 'Unknown language'

glottocode_mapping

In [16]:
def map_glottocodes(row, mapping=None):
    """Attach human-readable `language` and `metalang` names to one example.

    Args:
        row: Dict-like example providing `glottocode` and `metalang_glottocode`.
        mapping: Optional glottocode -> language-name dict. When omitted, the
            notebook-level `glottocode_mapping` is used.

    Returns:
        The same `row` object with two keys set:
        `language` - the mapped name, or 'Unknown language' if the code has no entry;
        `metalang` - the mapped name, or '' if the metalanguage is unknown/unmapped.
    """
    if mapping is None:
        mapping = glottocode_mapping

    # Use .get so a code missing from the mapping degrades to the same
    # 'Unknown language' placeholder instead of raising KeyError mid-map.
    row['language'] = mapping.get(row['glottocode'], 'Unknown language')
    metalang_name = mapping.get(row['metalang_glottocode'], 'Unknown language')

    # An unknown metalanguage is stored as an empty string, not as the placeholder.
    row['metalang'] = '' if metalang_name == 'Unknown language' else metalang_name
    return row

# Run the per-row name lookup over the dataset and rebind `data` to the result,
# which now carries the added `language` and `metalang` columns.
data = data.map(map_glottocodes)

Map:   0%|          | 0/425020 [00:00<?, ? examples/s]

In [19]:
# Upload the current `data` to the 'lecslab/glosslm' Hub repository.
# NOTE(review): this is the same repo id the notebook loads from, so the source
# dataset is updated in place — confirm that is intended before re-running.
data.push_to_hub('lecslab/glosslm')

In [21]:
# NOTE: `data` is rebound from the loaded dataset to a pandas DataFrame here, so
# this cell cannot be re-run without first re-running the cells that create `data`.
data = data['train'].to_pandas()

# Languages whose shared-task data forms the in-distribution (ID) and
# out-of-distribution (OOD) evaluation sets.
in_dist = ['arap1274', 'uspa1245', 'dido1241']
oo_dist = ['nyan1302', 'natu1246', 'lezg1247']

# Reusable row predicates. `na=False` makes a missing ID count as "no match"
# instead of propagating a missing value into the boolean masks.
is_sigmorphon = data['source'] == 'sigmorphon_st'
is_in_dist = data['glottocode'].isin(in_dist)
is_ood = data['glottocode'].isin(oo_dist)
id_has_train = data['ID'].str.contains('train', na=False)
id_has_dev = data['ID'].str.contains('dev', na=False)
id_has_test = data['ID'].str.contains('test', na=False)

mask_in_dist_train = is_sigmorphon & is_in_dist & id_has_train
mask_in_dist_eval = is_sigmorphon & is_in_dist & id_has_dev
mask_in_dist_test = is_sigmorphon & is_in_dist & id_has_test

# For OOD languages every non-sigmorphon row goes to train_OOD, so eval/test must
# be restricted to sigmorphon rows; otherwise a non-sigmorphon row whose ID happens
# to contain 'dev'/'test' would be placed in train_OOD *and* eval/test_OOD.
mask_ood_dist_train = is_ood & (id_has_train | ~is_sigmorphon)
mask_ood_dist_eval = is_ood & is_sigmorphon & id_has_dev
mask_ood_dist_test = is_ood & is_sigmorphon & id_has_test

# Guard against leakage before anything is built or uploaded: no row may satisfy
# more than one of the six held-apart masks (e.g. an ID containing two markers).
exclusive_masks = [
    mask_in_dist_train, mask_in_dist_eval, mask_in_dist_test,
    mask_ood_dist_train, mask_ood_dist_eval, mask_ood_dist_test,
]
overlapping_rows = sum(mask.astype(int) for mask in exclusive_masks) > 1
assert not overlapping_rows.any(), f"{int(overlapping_rows.sum())} rows fall into more than one split"

# Everything not claimed by an ID/OOD split is general training data.
mask_other_train = ~(mask_in_dist_train | mask_in_dist_eval | mask_in_dist_test | mask_ood_dist_train | mask_ood_dist_eval | mask_ood_dist_test)

# Split name -> row mask. 'train' deliberately includes the ID training rows,
# which are also exposed on their own as 'train_ID'.
split_masks = {
    'train': mask_in_dist_train | mask_other_train,
    'train_ID': mask_in_dist_train,
    'eval_ID': mask_in_dist_eval,
    'test_ID': mask_in_dist_test,
    'train_OOD': mask_ood_dist_train,
    'eval_OOD': mask_ood_dist_eval,
    'test_OOD': mask_ood_dist_test,
}

# preserve_index=False keeps the filtered frame's index out of the dataset, so no
# '__index_level_0__' column is created and none has to be removed afterwards.
split_dataset = datasets.DatasetDict({
    name: datasets.Dataset.from_pandas(data[mask], preserve_index=False)
    for name, mask in split_masks.items()
})

split_dataset = split_dataset.remove_columns(['type'])
split_dataset

DatasetDict({
    train: Dataset({
        features: ['ID', 'glottocode', 'transcription', 'glosses', 'translation', 'metalang_glottocode', 'is_segmented', 'source', 'language', 'metalang'],
        num_rows: 392951
    })
    train_ID: Dataset({
        features: ['ID', 'glottocode', 'transcription', 'glosses', 'translation', 'metalang_glottocode', 'is_segmented', 'source', 'language', 'metalang'],
        num_rows: 104928
    })
    eval_ID: Dataset({
        features: ['ID', 'glottocode', 'transcription', 'glosses', 'translation', 'metalang_glottocode', 'is_segmented', 'source', 'language', 'metalang'],
        num_rows: 11138
    })
    test_ID: Dataset({
        features: ['ID', 'glottocode', 'transcription', 'glosses', 'translation', 'metalang_glottocode', 'is_segmented', 'source', 'language', 'metalang'],
        num_rows: 11940
    })
    train_OOD: Dataset({
        features: ['ID', 'glottocode', 'transcription', 'glosses', 'translation', 'metalang_glottocode', 'is_segmented',

In [22]:
# Upload the split DatasetDict to the 'lecslab/glosslm-split' Hub repository,
# recording the reason for the update in the commit message.
split_dataset.push_to_hub('lecslab/glosslm-split', commit_message='Add language names')

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/393 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/105 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/12 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/12 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/8 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

In [23]:
# Sanity check

# 'train' already contains every 'train_ID' row, so that split is subtracted once
# to avoid double counting when comparing against the full table.
rows_per_split = [len(split) for split in split_dataset.values()]
split_dataset_total_rows = sum(rows_per_split) - len(split_dataset['train_ID'])

if split_dataset_total_rows == len(data):
    print("Looks good :)")
else:
    print(f"Mismatch! {split_dataset_total_rows} in split and {len(data)} total")

Looks good :)
