Skip to content

Commit

Permalink
bump vega versions
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinL committed Oct 26, 2020
1 parent e480d66 commit 3fb056d
Showing 1 changed file with 33 additions and 33 deletions.
66 changes: 33 additions & 33 deletions splink/params.py
Expand Up @@ -200,7 +200,7 @@ def _convert_params_dict_to_bayes_factor_data(self):

data.append(row)
return data

def _convert_params_dict_to_bayes_factor_iteration_history(self):
"""
Get the data needed for a chart that shows which comparison
Expand All @@ -227,9 +227,9 @@ def _convert_params_dict_to_bayes_factor_iteration_history(self):
row["u"] = this_gamma["prob_dist_non_match"][level]["probability"]
try:
row["bayes_factor"] = row["m"] / row["u"]
except ZeroDivisionError:
except ZeroDivisionError:
row["bayes_factor"] = None

if it_num == len(self.param_history)-1:
row["final"]=True
else:
Expand Down Expand Up @@ -371,11 +371,11 @@ def is_converged(self):
if new_change > biggest_change:
biggest_change = new_change
biggest_change_key = key

logger.info(f"The maximum change in parameters was {biggest_change} for key {biggest_change_key}")

return(all(diff))

def get_settings_with_current_params(self):
return get_or_update_settings(self)

Expand All @@ -390,10 +390,10 @@ def field_value_to_probs(fv):
u_probs = []
for key, val in fv['prob_dist_non_match'].items():
u_probs.append(val["probability"])

print(f'"m_probabilities": {m_probs},')
print(f'"u_probabilities": {u_probs}')

for field, value in self.params['π'].items():
print(field)
field_value_to_probs(value)
Expand Down Expand Up @@ -438,7 +438,7 @@ def probability_distribution_chart(self): # pragma: no cover
return alt.Chart.from_dict(probability_distribution_chart)
else:
return probability_distribution_chart

def gamma_distribution_chart(self): # pragma: no cover
"""
If altair is installed, returns the chart
Expand Down Expand Up @@ -466,34 +466,34 @@ def bayes_factor_chart(self): # pragma: no cover
return alt.Chart.from_dict(bayes_factor_chart_def)
else:
return bayes_factor_chart_def

def bayes_factor_history_charts(self):
"""
If altair is installed, returns the chart
Otherwise will return the chart spec as a dictionary
"""
# Empty list of chart definitions
chart_defs = []

# Full iteration history
data = self._convert_params_dict_to_bayes_factor_iteration_history()

# Create charts for each column
for col_dict in self.settings["comparison_columns"]:

# Get column name
if "col_name" in col_dict:
col_name = col_dict["col_name"]
elif "custom_name" in col_dict:
col_name = col_dict["custom_name"]
col_name = col_dict["custom_name"]

chart_def = copy.deepcopy(bayes_factor_history_chart_def)
# Assign iteration history to values of chart_def
chart_def["data"]["values"] = [d for d in data if d['column']==col_name]
chart_def["title"]["text"] = col_name
chart_def["hconcat"][1]["layer"][0]["encoding"]["color"]["legend"]["tickCount"] = col_dict["num_levels"]-1
chart_defs.append(chart_def)

combined_charts = {
"config": {
"view": {"width": 400, "height": 120},
Expand All @@ -503,12 +503,12 @@ def bayes_factor_history_charts(self):
"resolve": {"scale":{"color": "independent"}},
'$schema': 'https://vega.github.io/schema/vega-lite/v4.8.1.json'
}

if altair_installed:
return alt.Chart.from_dict(combined_charts)
else:
return combined_charts


def all_charts_write_html_file(self, filename="splink_charts.html", overwrite=False):

Expand All @@ -530,7 +530,7 @@ def all_charts_write_html_file(self, filename="splink_charts.html", overwrite=Fa

c5 = self.bayes_factor_history_charts().to_json(indent=None)
c6 = self.gamma_distribution_chart().to_json(indent=None)

with open(filename, "w") as f:
f.write(
multi_chart_template.format(
Expand All @@ -554,16 +554,16 @@ def all_charts_write_html_file(self, filename="splink_charts.html", overwrite=Fa
c4 = json.dumps(self.ll_iteration_chart())
else:
c4 = ""

c5 = json.dumps(self.bayes_factor_history_charts())
c6 = json.dumps(self.gamma_distribution_chart())

with open(filename, "w") as f:
f.write(
multi_chart_template.format(
vega_version="5",
vegalite_version="3.3.0",
vegaembed_version="4",
vega_version="5.17.0",
vegalite_version="4.17.0",
vegaembed_version="6",
spec1=c1,
spec2=c6,
spec3=c2,
Expand Down Expand Up @@ -680,34 +680,34 @@ def _flatten_dict(dictionary, accumulator=None, parent_key=None, separator="_"):

@check_types
def get_or_update_settings(params: Params, settings: dict = None):

if not settings:
settings = params.settings
settings["proportion_of_matches"] = params.params['λ']

settings["proportion_of_matches"] = params.params['λ']

for comp in settings["comparison_columns"]:
if "col_name" in comp.keys():
label = "gamma_"+comp["col_name"]
else:
label = "gamma_"+comp["custom_name"]

if "num_levels" in comp.keys():
num_levels = comp["num_levels"]
else:
num_levels = _get_default_value("num_levels", is_column_setting=True)


if label in params.params["π"].keys():
saved = params.params["π"][label]

if num_levels == saved["num_levels"]:
m_probs = [val['probability'] for key, val in saved["prob_dist_match"].items()]
u_probs = [val['probability'] for key, val in saved["prob_dist_non_match"].items()]

comp["m_probabilities"] = m_probs
comp["u_probabilities"] = u_probs
else:
warnings.warn(f"{label}: Saved m and u probabilities do not match the specified number of levels ({num_levels}) - default probabilities will be used")

return(settings)

0 comments on commit 3fb056d

Please sign in to comment.