Skip to content

Commit

Permalink
black format params
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed Dec 14, 2020
1 parent d354b38 commit 212d14b
Showing 1 changed file with 58 additions and 37 deletions.
95 changes: 58 additions & 37 deletions splink/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Params:
"""

def __init__(self, settings:dict, spark:SparkSession):
def __init__(self, settings: dict, spark: SparkSession):
"""[summary]
Args:
Expand Down Expand Up @@ -72,8 +72,7 @@ def describe_gammas(self):
return {k: i["desc"] for k, i in self.params["π"].items()}

def _generate_param_dict(self):
"""Uses the splink settings object to generate a parameter dictionary
"""
"""Uses the splink settings object to generate a parameter dictionary"""

for col_dict in self.settings["comparison_columns"]:
if "col_name" in col_dict:
Expand All @@ -91,7 +90,9 @@ def _generate_param_dict(self):

if "custom_name" in col_dict:
self.params["π"][f"gamma_{col_name}"]["custom_comparison"] = True
self.params["π"][f"gamma_{col_name}"]["custom_columns_used"] = col_dict["custom_columns_used"]
self.params["π"][f"gamma_{col_name}"]["custom_columns_used"] = col_dict[
"custom_columns_used"
]
else:
self.params["π"][f"gamma_{col_name}"]["custom_comparison"] = False

Expand Down Expand Up @@ -121,7 +122,9 @@ def _generate_param_dict(self):
}

self.params["π"][f"gamma_{col_name}"]["prob_dist_match"] = prob_dist_match
self.params["π"][f"gamma_{col_name}"]["prob_dist_non_match"] = prob_dist_non_match
self.params["π"][f"gamma_{col_name}"][
"prob_dist_non_match"
] = prob_dist_non_match

def _set_pi_value(self, gamma_str, level_int, match_str, prob_float):
"""
Expand Down Expand Up @@ -179,7 +182,7 @@ def _convert_params_dict_to_bayes_factor_data(self):
"""
data = []
# Want to compare the u and m probabilities
lam = self.params['λ']
lam = self.params["λ"]
pi = gk = self.params["π"]
gk = list(pi.keys())

Expand All @@ -192,7 +195,7 @@ def _convert_params_dict_to_bayes_factor_data(self):
row["column"] = this_gamma["column_name"]
row["m"] = this_gamma["prob_dist_match"][level]["probability"]
row["u"] = this_gamma["prob_dist_non_match"][level]["probability"]
row["level_proportion"] = row["m"]*lam + row["u"]*(1-lam)
row["level_proportion"] = row["m"] * lam + row["u"] * (1 - lam)
try:
row["bayes_factor"] = row["m"] / row["u"]
except ZeroDivisionError:
Expand Down Expand Up @@ -230,10 +233,10 @@ def _convert_params_dict_to_bayes_factor_iteration_history(self):
except ZeroDivisionError:
row["bayes_factor"] = None

if it_num == len(self.param_history)-1:
row["final"]=True
if it_num == len(self.param_history) - 1:
row["final"] = True
else:
row["final"]=False
row["final"] = False

data.append(row)
return data
Expand Down Expand Up @@ -359,42 +362,51 @@ def is_converged(self):
p_previous = self.param_history[-1]
threshold = self.settings["em_convergence"]

p_new = {key:value for key, value in _flatten_dict(p_latest).items() if '_probability' in key.lower()}
p_old = {key:value for key, value in _flatten_dict(p_previous).items() if '_probability' in key.lower()}
p_new = {
key: value
for key, value in _flatten_dict(p_latest).items()
if "_probability" in key.lower()
}
p_old = {
key: value
for key, value in _flatten_dict(p_previous).items()
if "_probability" in key.lower()
}

diff = [abs(p_new[item] - p_old[item]) < threshold for item in p_new]

biggest_change = 0
biggest_change_key = ''
biggest_change_key = ""
for key in p_new.keys():
new_change = abs(p_new[key] - p_old[key])
if new_change > biggest_change:
biggest_change = new_change
biggest_change_key = key

logger.info(f"The maximum change in parameters was {biggest_change} for key {biggest_change_key}")
logger.info(
f"The maximum change in parameters was {biggest_change} for key {biggest_change_key}"
)

return(all(diff))
return all(diff)

def get_settings_with_current_params(self):
return get_or_update_settings(self)

### The rest of this module is just 'presentational' elements - charts, and __repr__ etc.

def _print_m_u_probs(self):

def field_value_to_probs(fv):
m_probs = []
for key, val in fv['prob_dist_match'].items():
for key, val in fv["prob_dist_match"].items():
m_probs.append(val["probability"])
u_probs = []
for key, val in fv['prob_dist_non_match'].items():
for key, val in fv["prob_dist_non_match"].items():
u_probs.append(val["probability"])

print(f'"m_probabilities": {m_probs},')
print(f'"u_probabilities": {u_probs}')

for field, value in self.params['π'].items():
for field, value in self.params["π"].items():
print(field)
field_value_to_probs(value)

Expand Down Expand Up @@ -489,28 +501,31 @@ def bayes_factor_history_charts(self):

chart_def = copy.deepcopy(bayes_factor_history_chart_def)
# Assign iteration history to values of chart_def
chart_def["data"]["values"] = [d for d in data if d['column']==col_name]
chart_def["data"]["values"] = [d for d in data if d["column"] == col_name]
chart_def["title"]["text"] = col_name
chart_def["hconcat"][1]["layer"][0]["encoding"]["color"]["legend"]["tickCount"] = col_dict["num_levels"]-1
chart_def["hconcat"][1]["layer"][0]["encoding"]["color"]["legend"][
"tickCount"
] = (col_dict["num_levels"] - 1)
chart_defs.append(chart_def)

combined_charts = {
"config": {
"view": {"width": 400, "height": 120},
},
"title": {"text":"Bayes factor iteration history", "anchor": "middle"},
"title": {"text": "Bayes factor iteration history", "anchor": "middle"},
"vconcat": chart_defs,
"resolve": {"scale":{"color": "independent"}},
'$schema': 'https://vega.github.io/schema/vega-lite/v4.8.1.json'
"resolve": {"scale": {"color": "independent"}},
"$schema": "https://vega.github.io/schema/vega-lite/v4.8.1.json",
}

if altair_installed:
return alt.Chart.from_dict(combined_charts)
else:
return combined_charts


def all_charts_write_html_file(self, filename="splink_charts.html", overwrite=False):
def all_charts_write_html_file(
self, filename="splink_charts.html", overwrite=False
):

if os.path.isfile(filename):
if not overwrite:
Expand Down Expand Up @@ -542,7 +557,7 @@ def all_charts_write_html_file(self, filename="splink_charts.html", overwrite=Fa
spec3=c2,
spec4=c3,
spec5=c4,
spec6=c5
spec6=c5,
)
)
else:
Expand All @@ -569,7 +584,7 @@ def all_charts_write_html_file(self, filename="splink_charts.html", overwrite=Fa
spec3=c2,
spec4=c3,
spec5=c4,
spec6=c5
spec6=c5,
)
)

Expand Down Expand Up @@ -658,7 +673,6 @@ def load_params_from_dict(param_dict):
if keys == expected_keys:
p = Params(settings=param_dict["settings"], spark=None)


p.params = param_dict["current_params"]
p.param_history = param_dict["historical_params"]
else:
Expand All @@ -678,36 +692,43 @@ def _flatten_dict(dictionary, accumulator=None, parent_key=None, separator="_"):
accumulator[k] = v
return accumulator


@check_types
def get_or_update_settings(params: Params, settings: dict = None):

if not settings:
settings = params.settings

settings["proportion_of_matches"] = params.params['λ']
settings["proportion_of_matches"] = params.params["λ"]

for comp in settings["comparison_columns"]:
if "col_name" in comp.keys():
label = "gamma_"+comp["col_name"]
label = "gamma_" + comp["col_name"]
else:
label = "gamma_"+comp["custom_name"]
label = "gamma_" + comp["custom_name"]

if "num_levels" in comp.keys():
num_levels = comp["num_levels"]
else:
num_levels = _get_default_value("num_levels", is_column_setting=True)


if label in params.params["π"].keys():
saved = params.params["π"][label]

if num_levels == saved["num_levels"]:
m_probs = [val['probability'] for key, val in saved["prob_dist_match"].items()]
u_probs = [val['probability'] for key, val in saved["prob_dist_non_match"].items()]
m_probs = [
val["probability"] for key, val in saved["prob_dist_match"].items()
]
u_probs = [
val["probability"]
for key, val in saved["prob_dist_non_match"].items()
]

comp["m_probabilities"] = m_probs
comp["u_probabilities"] = u_probs
else:
warnings.warn(f"{label}: Saved m and u probabilities do not match the specified number of levels ({num_levels}) - default probabilities will be used")
warnings.warn(
f"{label}: Saved m and u probabilities do not match the specified number of levels ({num_levels}) - default probabilities will be used"
)

return(settings)
return settings

0 comments on commit 212d14b

Please sign in to comment.