def comparison_vector_distribution(
    df_gammas: DataFrame,
    sort_by_colname=None,
):
    """Chart how often each distinct comparison vector occurs in ``df_gammas``.

    Every column whose name starts with ``gamma_`` is treated as one element
    of the comparison vector.  Rows are grouped by the full vector and
    counted; the result is injected into the ``gamma_histogram.json`` chart
    definition, with one bar per distinct vector (labelled by the
    comma-joined gamma values, field ``cc``).

    Bars are ordered by ``sum_gam`` unless ``sort_by_colname`` is given.
    ``sum_gam`` is the sum of the gamma values after remapping -1 to 0 and
    0 to -1 (other values are left unchanged).
    NOTE(review): the remapping presumably makes a non-match rank below a
    null comparison -- confirm against the gamma encoding used upstream.

    Args:
        df_gammas: Spark DataFrame containing one or more ``gamma_*`` columns.
        sort_by_colname: Optional name of a numeric column in ``df_gammas``.
            When given, its per-vector average is computed, shown in the
            tooltip and used as the x-axis sort field instead of ``sum_gam``.

    Returns:
        Whatever ``altair_if_installed_else_json`` produces for the populated
        chart definition.

    Raises:
        ValueError: If ``df_gammas`` has no column starting with ``gamma_``.
    """
    spark = df_gammas.sql_ctx.sparkSession

    g_cols = [c for c in df_gammas.columns if c.startswith("gamma_")]
    if not g_cols:
        # Without this guard the generated SQL would be syntactically invalid
        # ("select , concat_ws(',', ) ...") and fail with an opaque parse error.
        raise ValueError(
            "df_gammas must contain at least one column starting with 'gamma_'"
        )

    # Only carry the columns the query needs.
    sel_cols = g_cols + [sort_by_colname] if sort_by_colname else g_cols
    df_gammas = df_gammas.select(sel_cols)

    cols_expr = ", ".join(g_cols)

    case_tem = "(case when {g} = -1 then 0 when {g} = 0 then -1 else {g} end)"
    sum_gams = " + ".join([case_tem.format(g=c) for c in g_cols])

    sort_col_expr = ""
    if sort_by_colname:
        sort_col_expr = f", avg({sort_by_colname}) as {sort_by_colname}"

    # Use a private view name, and drop it afterwards, so that a view called
    # "df_gammas" registered by the caller is not silently replaced by the
    # column-pruned frame built above.
    view_name = "__splink_comparison_vector_distribution"

    sql = f"""
    select
        {cols_expr},
        concat_ws(',', {cols_expr}) as cc,
        {sum_gams} as sum_gam,
        count(*) as count
        {sort_col_expr}
    from {view_name}
    group by {cols_expr}
    order by {cols_expr}
    """

    df_gammas.createOrReplaceTempView(view_name)
    try:
        # toPandas() materialises the result, so the view can be dropped
        # as soon as it returns.
        gammas_counts = spark.sql(sql).toPandas()
    finally:
        spark.catalog.dropTempView(view_name)

    hist_def_dict = load_chart_definition("gamma_histogram.json")
    hist_def_dict["data"]["values"] = gammas_counts.to_dict(orient="records")

    # Tooltip: the count, then the ordering score, then each gamma value.
    score_field = sort_by_colname if sort_by_colname else "sum_gam"
    tt = [
        {"field": "count", "type": "quantitative"},
        {"field": score_field, "type": "quantitative"},
    ]
    tt.extend({"field": c, "type": "nominal"} for c in g_cols)
    hist_def_dict["encoding"]["tooltip"] = tt

    if sort_by_colname:
        # The chart definition sorts the x axis by "sum_gam" by default.
        hist_def_dict["encoding"]["x"]["sort"]["field"] = sort_by_colname

    return altair_if_installed_else_json(hist_def_dict)
file mode 100644 index 0000000000..2ef9e90ab5 --- /dev/null +++ b/splink/files/chart_defs/gamma_histogram.json @@ -0,0 +1,35 @@ +{ + "config": { + "view": { + "width": 400, + "height": 300 + } + }, + "data": { + "values": null + }, + "mark": "bar", + "encoding": { + "x": { + "type": "nominal", + "field": "cc", + "sort": { + "field": "sum_gam", + "op": "sum", + "order": "ascending" + }, + "title": "gammas" + }, + "y": { + "type": "quantitative", + "field": "count", + "scale": { + "constant": 10, + "type": "symlog" + }, + "title": "count" + } + }, + "width": 1000, + "$schema": "https://vega.github.io/schema/vega-lite/v4.8.1.json" +}