### General next-best experiments

In [2]:
import numpy as np
import pandas as pd

from typdiv_sampling.distance import get_summed_dist_dict
from typdiv_sampling.evaluation import Evaluator

In [3]:
# Language-code files available under lang_codes/; D_NUM indexes into this list.
DATASETS = [
    "flores200.csv",
    "ud_214.csv",
    "tydiqa.csv",
    "xcopa.csv",
    "aya_eval-human.csv",
]

# SELECT PARAMETERS HERE:

# Dataset: position (0-4) in DATASETS of the dataset to analyse
D_NUM = 4

# Number of languages to add to the current sample
N = 1

In [4]:
def get_dists(frame, dist_df):
    """Restrict the distance table to `frame` and return it as an array.

    Args:
        frame: list of language labels to keep (used for rows and columns).
        dist_df: DataFrame of pairwise distances labelled by language on
            both axes.

    Returns:
        A tuple `(dists, id2lang)`: the selected sub-matrix as a NumPy array
        and the list of column labels, so `id2lang[i]` names row/column `i`.
    """
    sub_matrix = dist_df.loc[frame, frame]
    return sub_matrix.to_numpy(), sub_matrix.columns.tolist()


def sample_maxsum(pre_sample, frame, dist_df, k):
    """Extend `pre_sample` to `k` languages by greedy max-sum selection.

    Starting from the languages in `pre_sample`, repeatedly add the candidate
    that `get_summed_dist_dict` scores highest against the languages selected
    so far, until `k` languages are selected.

    Args:
        pre_sample: languages already in the sample; each must occur in `frame`.
        frame: list of all languages that may be selected (the sampling frame).
        dist_df: DataFrame of pairwise distances labelled by language on
            both axes.
        k: total size of the returned sample, pre-sample included.

    Returns:
        List of languages: `pre_sample` in its given order, followed by the
        added languages in the order they were picked.

    Raises:
        ValueError: if a pre-sample language is not in `frame`, or if `k` is
            larger than the number of languages in `frame`.
    """
    dists, id2lang = get_dists(frame, dist_df)

    # Derive matrix positions from id2lang so they always follow the row/column
    # order of `dists`, even when `frame` is not sorted (and without re-sorting
    # the frame for every pre-sample language).
    lang2id = {lang: idx for idx, lang in enumerate(id2lang)}
    missing = [lang for lang in pre_sample if lang not in lang2id]
    if missing:
        raise ValueError(f"pre_sample languages not in frame: {missing}")
    if k > len(id2lang):
        raise ValueError(
            f"cannot sample k={k} languages from a frame of {len(id2lang)}"
        )

    langs = [lang2id[lang] for lang in pre_sample]

    # Candidates are the languages that are not selected yet; leaving the
    # pre-sample in the candidate list would allow it to be picked again.
    already_selected = set(langs)
    all_langs = [i for i in range(len(dists)) if i not in already_selected]

    while len(langs) < k:
        summed_dist = get_summed_dist_dict(dists, all_langs, langs)
        next_most_distant = max(summed_dist, key=lambda x: summed_dist[x])
        all_langs.remove(next_most_distant)
        langs.append(next_most_distant)

    return [id2lang[i] for i in langs]

In [5]:
# Read the language-code table of the dataset selected via D_NUM.
lang_code_file = f"lang_codes/{DATASETS[D_NUM]}"
df = pd.read_csv(lang_code_file, sep=";")

In [6]:
# Pairwise language distances, indexed by the file's "Unnamed: 0" column;
# dist_dict is the same table as a nested {column: {index: value}} dict.
dist_df = pd.read_csv("../../data/gb_lang_dists.csv").set_index("Unnamed: 0")
dist_dict = dist_df.to_dict("dict")

# Grambank table indexed by Lang_ID: drop the leftover unnamed columns and
# recode "no_cov" entries as "?".
gb = (
    pd.read_csv("../../data/gb_processed.csv", index_col="Lang_ID")
    .drop(columns=["Unnamed: 0", "Unnamed: 0.1"])
    .replace(to_replace="no_cov", value="?")
)
# One array of feature values per language id.
gb_by_lang = {lang_id: np.array(features) for lang_id, features in gb.iterrows()}

In [7]:
# glottocode to language name mapping
glt = pd.read_csv("../../data/languoid.csv")
# Pair the two columns directly instead of building a Series per row with
# iterrows(); later rows still win for a repeated id, as before.
gltc_to_name = dict(zip(glt["id"], glt["name"]))

In [8]:
# Language sets: glottocodes listed for the dataset, and ids with Grambank data.
dataset_langs = set(df["glottocode"])
gb_langs = set(gb_by_lang)

In [9]:
# Dataset languages that also have Grambank data form the current sample.
current_sample = dataset_langs & gb_langs
coverage_msg = (
    f"We have GB data for {len(current_sample)} out of {len(dataset_langs)} languages."
)
print(coverage_msg)

We have GB data for 5 out of 7 languages.


In [10]:
# current diversity metrics
# The evaluator is built from the per-language Grambank arrays and the nested
# distance dict loaded earlier.
evaluator = Evaluator(gb_by_lang, dist_dict)
# NOTE(review): the second argument appears to be a run id (the Result repr
# reports run=1) - confirm against Evaluator.evaluate_sample.
res = evaluator.evaluate_sample(current_sample, 1)
res

Result(run=1, ent_score_with=0.6623633424086156, ent_score_without=0.5714527405538357, fvi_score=0.8407960199004975, mpd_score=0.7549133172838374, fvo_score=0.6295528816549381, sample={'mand1415', 'stan1318', 'stan1293', 'nucl1301', 'port1283'})

In [11]:
# Grow the current sample by N languages, drawing from all Grambank languages.
frame = sorted(gb_langs)
k = len(current_sample) + N
new_sample = sample_maxsum(list(current_sample), frame, dist_df, k)

In [12]:
# analysis
# sample_maxsum returns the pre-sample first and the added languages after it,
# so slice from the pre-sample size; `new_sample[-N:]` would return the whole
# list when N == 0.
added_langs = new_sample[len(current_sample):]
print("New languages:", [gltc_to_name[g] for g in added_langs])
res = evaluator.evaluate_sample(new_sample, 1)
res

New languages: ['Yele']


Result(run=1, ent_score_with=0.7841457371605929, ent_score_without=0.6602670669895019, fvi_score=0.8980099502487562, mpd_score=0.7965600701194073, fvo_score=0.5920069229961086, sample={'yele1255', 'mand1415', 'stan1318', 'stan1293', 'nucl1301', 'port1283'})

### Case study: UD extension languages

In [31]:
# Languages currently in UD and the proposed extension languages.
df = pd.read_csv("lang_codes/ud_214.csv", sep=";")
df_e = pd.read_csv("lang_codes/ud_extensions.csv", sep=";")
current_langs = set(df["glottocode"])
ext_langs = set(df_e["glottocode"])

# Keep only languages that have a column in the distance table, and drop
# extension languages that are already part of the current sample.
dist_langs = set(dist_df.columns)
current_sample = current_langs & dist_langs
ext_langs = (ext_langs & dist_langs) - current_sample

In [45]:
# Diversity metrics of the current UD sample, kept as the baseline for the
# comparison with the extended sample.
old_res = evaluator.evaluate_sample(current_sample, 1)
old_res

Result(run=1, ent_score_with=0.8983371130601235, ent_score_without=0.6810527872902721, fvi_score=0.9850746268656716, mpd_score=0.7245766498834872, fvo_score=0.6790589232469976, sample={'yaku1245', 'sout1528', 'tata1255', 'wels1247', 'stan1290', 'nucl1235', 'bela1254', 'nhen1239', 'iris1253', 'faro1244', 'veps1250', 'komi1269', 'russ1264', 'komi1268', 'russ1263', 'stan1289', 'yuec1235', 'esto1258', 'mace1250', 'nucl1643', 'nort2697', 'livv1243', 'mbya1239', 'kore1280', 'zaca1241', 'anci1242', 'bret1244', 'czec1258', 'akka1240', 'erzy1239', 'mala1464', 'gali1258', 'apur1254', 'tswa1253', 'stan1318', 'dani1285', 'stan1293', 'finn1318', 'ital1282', 'malt1254', 'port1283', 'tami1289', 'bhoj1244', 'mund1330', 'west2354', 'jama1261', 'lati1261', 'poli1260', 'hind1269', 'kich1262', 'mara1378', 'xava1240', 'anci1244', 'west2369', 'copt1239', 'chuk1273', 'gheg1238', 'swed1254', 'mode1248', 'mand1415', 'nort2671', 'skol1241', 'boro1282', 'moks1248', 'hebr1245', 'clas1249', 'slov1268', 'hung1274',

In [32]:
# New sampling frame: current UD languages plus the candidate extensions.
langs = sorted(ext_langs | current_sample)
# NOTE: this rebinds the notebook-wide dist_df to the restricted table, so
# cells that need the full distance table must reload it first.
dist_df = dist_df.loc[langs, langs]

In [43]:
# sample with extension language(s)
k = len(current_sample) + N
new_sample = sample_maxsum(list(current_sample), langs, dist_df, k)
# Report every added language in the order it was picked. Popping one element
# from a set difference shows a single arbitrary language when N > 1 and
# raises when nothing new was added.
added_langs = [g for g in new_sample if g not in current_sample]
print("new language:", ", ".join(str(gltc_to_name[g]) for g in added_langs))

new language: Seri


In [44]:
# evaluate diversity of the extended sample (compare with old_res)
new_res = evaluator.evaluate_sample(new_sample, 1)
new_res

Result(run=1, ent_score_with=0.9025768344198168, ent_score_without=0.6846505746879855, fvi_score=0.9850746268656716, mpd_score=0.7267890875768118, fvo_score=0.6771515080936407, sample={'yaku1245', 'sout1528', 'tata1255', 'stan1290', 'wels1247', 'iris1253', 'veps1250', 'komi1269', 'russ1263', 'yuec1235', 'esto1258', 'mace1250', 'livv1243', 'kore1280', 'anci1242', 'bret1244', 'czec1258', 'erzy1239', 'mala1464', 'apur1254', 'tswa1253', 'stan1318', 'dani1285', 'stan1293', 'ital1282', 'malt1254', 'port1283', 'mund1330', 'west2354', 'jama1261', 'lati1261', 'poli1260', 'kich1262', 'west2369', 'chuk1273', 'gheg1238', 'mand1415', 'boro1282', 'clas1249', 'hung1274', 'akun1241', 'ukra1253', 'latv1249', 'tupi1273', 'nucl1235', 'bela1254', 'nhen1239', 'faro1244', 'russ1264', 'komi1268', 'stan1289', 'nucl1643', 'nort2697', 'mbya1239', 'zaca1241', 'akka1240', 'gali1258', 'finn1318', 'tami1289', 'bhoj1244', 'seri1257', 'hind1269', 'mara1378', 'xava1240', 'anci1244', 'copt1239', 'swed1254', 'mode1248',