In [2]:
import datasets
# Load the GlossLM splits and keep only rows that have both a transcription
# and glosses.
dataset = datasets.load_dataset('lecslab/glosslm-split').filter(
    lambda row: not (row["transcription"] is None or row["glosses"] is None)
)

In [12]:
import os
import random
import re
from collections import defaultdict

import pandas as pd

from eval import strip_gloss_punctuation

# Segmented setting.
# The baseline defined below learns, for each language, how often every
# morpheme is paired with every gloss.

# Pool `train` with `train_OOD` and `test_ID` with `test_OOD`, keeping only the
# rows whose `is_segmented` field is "yes".
all_train, all_test = (
    datasets.concatenate_datasets([dataset[split_name] for split_name in split_names])
            .filter(lambda row: row["is_segmented"] == "yes")
    for split_names in (('train', 'train_OOD'), ('test_ID', 'test_OOD'))
)

def gloss_with_top_gloss(gloss_dict):
    """Return the key of `gloss_dict` that has the largest count.

    When several glosses share the largest count, the one that comes first in
    the dict's iteration order is returned. An empty dict raises ValueError.
    """
    best_gloss, _ = max(gloss_dict.items(), key=lambda pair: pair[1])
    return best_gloss

def gloss_with_random_gloss(gloss_dict):
    """Return one key of `gloss_dict`, chosen uniformly at random.

    The counts stored as values are ignored; every gloss is equally likely.
    """
    candidates = [gloss for gloss in gloss_dict]
    return random.choice(candidates)


def make_predictions(glottocode, method):
    """Gloss one language's test rows with a naive morpheme-lookup baseline.

    Counts, over the rows of `all_train` for `glottocode`, how often each
    lowercased morpheme was paired with each gloss, then predicts a gloss for
    every morpheme in the matching rows of `all_test`.

    Args:
        glottocode: Value of the `glottocode` column that selects the language.
        method: 'top' to pick the most frequent training gloss of a morpheme,
            'random' to pick uniformly among the glosses seen for it.

    Returns:
        A pandas DataFrame with one row per test example and the columns
        `id`, `glottocode`, `is_segmented`, `pred` and `gold`. Morphemes never
        seen in training are predicted as "???".

    Raises:
        KeyError: If `method` is neither 'top' nor 'random'.
    """
    select_gloss = {'top': gloss_with_top_gloss, 'random': gloss_with_random_gloss}[method]

    train_data = all_train.filter(lambda row: row['glottocode'] == glottocode)
    test_data = all_test.filter(lambda row: row['glottocode'] == glottocode)

    # lowercased morpheme -> gloss -> number of times the pair occurred in training
    morpheme_glosses = defaultdict(lambda: defaultdict(int))
    for row in train_data:
        words = strip_gloss_punctuation(row['transcription']).split()
        glossed_words = strip_gloss_punctuation(row['glosses']).split()
        # zip() stops at the shorter sequence, so any unmatched trailing
        # words or morphemes are left out of the counts.
        for word, glossed_word in zip(words, glossed_words):
            for morpheme, gloss in zip(re.split(r"\s|-", word), re.split(r"\s|-", glossed_word)):
                morpheme_glosses[morpheme.lower()][gloss] += 1

    preds = []
    for row in test_data:
        line_predictions = []
        for word in strip_gloss_punctuation(row['transcription']).split():
            word_predictions = []
            for morpheme in re.split(r"\s|-", word):
                # The table is keyed by lowercased morphemes, so the membership
                # test must use the same normalisation as the lookup; checking
                # the raw morpheme would mark every capitalised form as unseen.
                key = morpheme.lower()
                if key in morpheme_glosses:
                    word_predictions.append(select_gloss(morpheme_glosses[key]))
                else:
                    word_predictions.append("???")
            line_predictions.append('-'.join(word_predictions))
        preds.append(' '.join(line_predictions))

    gold = [strip_gloss_punctuation(g) for g in test_data["glosses"]]

    return pd.DataFrame({
        "id": test_data["id"],
        "glottocode": test_data["glottocode"],
        "is_segmented": test_data["is_segmented"],
        "pred": preds,
        "gold": gold,
    })


# Glottocodes to evaluate, grouped by the test split they are written to.
splits = {'ID': ['arap1274', 'dido1241', 'uspa1245'],
          'OOD': ['gitx1241', 'lezg1247', 'natu1246', 'nyan1302']}

# Seed the RNG so the 'random' baseline writes the same predictions on every run.
random.seed(0)

for method in ['top', 'random']:
    # to_csv does not create missing parent directories, so make sure the
    # per-method output folder exists before writing into it.
    output_dir = f'../preds/naive-{method}'
    os.makedirs(output_dir, exist_ok=True)

    for split in ['ID', 'OOD']:
        # One predictions frame per language, stacked into a single file per split.
        all_preds = [make_predictions(lang, method) for lang in splits[split]]
        combined = pd.concat(all_preds)
        combined.to_csv(f'{output_dir}/test_{split}-preds.csv', index=False)

In [9]:
# Display a single row of the `train` split to see which fields an example carries.
dataset['train'][100]

{'transcription': 'ɑʑ-ɑd',
 'glosses': 'front-POSS.2SG.INE/ILL',
 'translation': 'in front of you',
 'glottocode': 'udmu1245',
 'id': 'uratyp_124',
 'source': 'uratyp',
 'metalang_glottocode': 'stan1293',
 'is_segmented': 'yes',
 'language': 'Udmurt',
 'metalang': 'English'}