In [None]:
import json
import sys

import blitzbeaver as bb
import matplotlib.pyplot as plt
import polars as pl
from matplotlib.ticker import PercentFormatter

# Add the parent directory to the import search path
# (presumably so the project-local `genealogy` package resolves — TODO confirm).
sys.path.append("..")
# Let polars print up to 100 rows when displaying a table.
pl.Config.set_tbl_rows(100)

In [None]:
from genealogy.processing import (
    load_dataframes,
)
from genealogy.models import GenealogyNode, Element

In [None]:
# First year of the range handed to load_dataframes.
START_YEAR = 1835
# Load the normalized dataframes for the START_YEAR..1898 range.
# NOTE(review): presumably one dataframe per year — confirm in load_dataframes.
dataframes = load_dataframes(
    folder_path="../data/normalized",
    start_year=START_YEAR,
    end_year=1898,
)
# Stack every loaded dataframe vertically into a single dataframe.
combined_dataframe = pl.concat(dataframes, how="vertical")

In [None]:
# Schema of the record fields handed to blitzbeaver, declared as
# (column name, element type) pairs in the order they are registered.
record_schema = bb.RecordSchema(
    [
        bb.FieldSchema(field_name, element_type)
        for field_name, element_type in (
            ("nom_rue", bb.ElementType.String),
            ("chef_prenom", bb.ElementType.String),
            ("chef_nom", bb.ElementType.String),
            ("chef_origine", bb.ElementType.String),
            ("epouse_nom", bb.ElementType.String),
            ("chef_vocation", bb.ElementType.String),
            ("enfants_chez_parents_prenom", bb.ElementType.MultiStrings),
        )
    ]
)

In [None]:
# Location of the serialized blitzbeaver graph file
# (file name suggests the 1835-1898 range — TODO confirm).
path_graph = "../graph_35_98.beaver"

# Load the graph; it is used later to materialize tracking chains.
graph = bb.read_beaver(path_graph)

In [None]:
public_path = "../data/public"

# Read the genealogy trees from JSON. The encoding is given explicitly so the
# file decodes the same way on every platform instead of depending on the
# locale's default encoding (JSON interchange files are UTF-8).
with open(f"{public_path}/trees.json", "r", encoding="utf-8") as f:
    genealogy_trees = json.load(f)

In [None]:
import Levenshtein

# One parent/child comparison: (parent value, child value, similarity score).
Metric = tuple[str, str, float]


def get_chain(id: bb.ID) -> bb.MaterializedTrackingChain:
    """Materialize the tracking chain identified by ``id``.

    Uses the notebook-level ``graph``, ``dataframes`` and ``record_schema``
    objects.
    """
    materialized = graph.materialize_tracking_chain(id, dataframes, record_schema)
    return materialized

def get_cluster_values(df: pl.DataFrame, col: str, min_count: int=3) -> list[Element]:
    """Return the non-null values of ``col`` seen at least ``min_count`` times.

    Values come back in the order produced by the sorted value counts.
    """
    counted = df.select([pl.col(col).value_counts(sort=True)]).get_column(col)
    kept = []
    for entry in counted:
        value = entry[col]
        # Missing values are never reported.
        if value is None:
            continue
        # Drop values that do not occur often enough.
        if entry["count"] < min_count:
            continue
        kept.append(value)
    return kept

def lv(v1: str, v2: str) -> float:
    """Return the normalized Levenshtein similarity of two strings.

    The edit distance is divided by the length of the longer string and
    subtracted from 1, so the result lies in [0, 1]: 1.0 for identical
    strings, 0.0 when every character has to change.

    Args:
        v1: First string.
        v2: Second string.

    Returns:
        The similarity score. Two empty strings are considered identical
        and yield 1.0.
    """
    longest = max(len(v1), len(v2))
    if longest == 0:
        # Both strings are empty: they are identical, and dividing by the
        # zero length would raise ZeroDivisionError.
        return 1.0
    return 1.0 - Levenshtein.distance(v1, v2) / longest

def compute_tree_metrics(node: GenealogyNode, col: str) -> list[Metric]:
    """Compare the values of ``col`` between every parent and its children.

    For each parent/child edge in the tree rooted at ``node``, every frequent
    value of the parent (see ``get_cluster_values``) is paired with every
    frequent value of the child and scored with ``lv``.

    Args:
        node: Root of the (sub)tree; must expose ``"id"`` and ``"children"``.
        col: Name of the column whose values are compared.

    Returns:
        A list of ``(parent value, child value, similarity)`` tuples. For each
        child, its pairs with the parent come first, followed by the metrics
        of the child's own subtree.
    """

    def _values_of(current) -> list:
        # Frequent values of `col` in the tracking chain of `current`.
        return get_cluster_values(get_chain(current["id"]).as_dataframe(), col)

    def _collect(current, current_values: list, out: list) -> None:
        # Walk the subtree depth-first. The child's values are handed down to
        # the recursive call so each chain is materialized only once instead
        # of once as a child and again as a parent.
        for child in current["children"]:
            child_values = _values_of(child)
            for pv in current_values:
                for cv in child_values:
                    out.append((pv, cv, lv(pv, cv)))
            _collect(child, child_values, out)

    metrics: list = []
    _collect(node, _values_of(node), metrics)
    return metrics

def get_most_common_values(
    df: pl.DataFrame, col: str, top_n: int = 10
) -> list[str]:
    """Return the ``top_n`` most frequent non-null values of ``col``.

    Args:
        df: Dataframe to inspect.
        col: Name of the column whose values are counted.
        top_n: Maximum number of values to return.

    Returns:
        The values themselves (not their counts), most frequent first.
    """
    # Nulls are dropped before counting so that a frequent null does not use
    # up one of the `top_n` slots and shorten the result.
    top_counts = (
        df.select([pl.col(col).drop_nulls().value_counts(sort=True).head(top_n)])
        .get_column(col)
        .to_list()
    )
    return [entry[col] for entry in top_counts]

def compute_match_ratio(
    metrics: list[Metric], threshold: float, parent_value: str | None = None
) -> tuple[int, int]:
    """Count how many metrics score strictly above ``threshold``.

    When ``parent_value`` is given, only metrics whose first element equals
    it are taken into account.

    Returns:
        ``(number of matches, number of metrics considered)``.
    """
    if parent_value is None:
        selected = metrics
    else:
        selected = [m for m in metrics if m[0] == parent_value]
    total = len(selected)
    hits = sum(1 for m in selected if m[2] > threshold)
    return (hits, total)

In [None]:
# Column whose parent/child transmission is analysed.
col = "nom_rue"
# Most frequent values of the column across all loaded dataframes.
mcvs = get_most_common_values(combined_dataframe, col, top_n=50)
# Collect the parent/child comparisons of every genealogy tree.
metrics = []
for tree in genealogy_trees:
    metrics += compute_tree_metrics(tree, col)

print(f"Number of metrics: {len(metrics)}")
match, count = compute_match_ratio(metrics, 0.95)
if count:
    print(f"Overall ratio: {match/count:.4f} ({match}/{count})")
else:
    # No comparison was produced: avoid dividing by zero.
    print("Overall ratio: n/a (0/0)")

In [None]:
# Maximum number of bars shown in the plots.
max_entries = 20
# (value, matches, comparisons) for each frequent value that has enough
# comparisons to be worth plotting.
cats = []
for mcv in mcvs:
    match, count = compute_match_ratio(metrics, 0.95, parent_value=mcv)
    if count >= 10:
        cats.append((mcv, match, count))

In [None]:
# Keep the `max_entries` values with the most comparisons, in ascending order
# so the largest bar ends up at the top of the horizontal chart.
gcats = list(cats)
gcats.sort(key=lambda cat: cat[2])
gcats = gcats[-max_entries:]
labels = [value for value, _, _ in gcats]
counts = [total for _, _, total in gcats]
matches = [hits for _, hits, _ in gcats]

plt.rcParams.update({'font.size': 14})

fig, ax = plt.subplots(figsize=(12, 10))
# The matches are drawn over the totals so both are visible on one bar.
ax.barh(labels, counts, color="lightblue", label="Chef de famille")
ax.barh(labels, matches, color="orange", label="Enfant")
ax.set_title("Nombre d'occurrences des adresses chef de famille/enfant")
ax.set_xlabel("Nombre d'occurrences")
ax.legend()
fig.savefig("adr_chef_enfant.png", bbox_inches="tight")

In [None]:
# Keep the `max_entries` values with the highest match ratio, in ascending
# order so the best ratio ends up at the top of the horizontal chart.
gcats = sorted(cats, key=lambda x: x[1] / x[2])[-max_entries:]
labels = [cat[0] for cat in gcats]
ratios = [cat[1] / cat[2] for cat in gcats]

plt.rcParams.update({'font.size': 14})

fig, ax = plt.subplots(figsize=(12, 10))
# Full-width background bar (100 %) with the actual ratio drawn on top.
ax.barh(labels, [1] * len(ratios), color="lightblue")
ax.barh(labels, ratios, color="orange")
# French spelling is "adresses" (single d), as in the other figure's title.
ax.set_xlabel("Pourcentage d'adresses transmises")
ax.set_xlim(0, 1)
# The axis is labelled as a percentage, so show the 0-1 ratios as 0 %-100 %.
ax.xaxis.set_major_formatter(PercentFormatter(xmax=1))
ax.set_title("Pourcentage d'adresses transmises")
fig.savefig("adr_chef_enfant_ratio.png", bbox_inches="tight")