Skip to content

Commit

Permalink
New Data Type + Action: Automatic geographic attribute visualizations (
Browse files Browse the repository at this point in the history
…#253)

* Merge upstream

* Add new default action for geographic data types (longitude and latitude): geoshape

* Add 'country' as a new secondary geographical feature

* Reformat

* remove unneeded file and run black

* Rename  to , remove vega_datasets from tests

* Clean up intents in map.py

* Add support to aggregate quantitative attributes

* Format and test

* Format with black

* Replace vega dependency

* Formatting

* SIGNIFICANT CHANGES: Reuse univariate action, modify symbolmap, modify tests, use PandasExecutor

* Format via black

* Clean up helper functions for detecting geotypes

* Resolve PR comments

* Add exportability for choropleths

* ARemove zero padding from fips and iso codes

* Resolve comments from @caitlynachen

Co-authored-by: Doris Lee <dorisjunglinlee@gmail.com>
  • Loading branch information
micahtyong and dorisjlee committed Mar 5, 2021
1 parent 952b642 commit 47ef480
Show file tree
Hide file tree
Showing 12 changed files with 295 additions and 9 deletions.
1 change: 1 addition & 0 deletions lux/action/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def register_default_actions():
lux.config.register_action("distribution", univariate, no_vis, "quantitative")
lux.config.register_action("occurrence", univariate, no_vis, "nominal")
lux.config.register_action("temporal", univariate, no_vis, "temporal")
lux.config.register_action("geographical", univariate, no_vis, "geographical")

lux.config.register_action("Enhance", enhance, one_current_vis)
lux.config.register_action("Filter", add_filter, one_current_vis)
Expand Down
16 changes: 16 additions & 0 deletions lux/action/univariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,22 @@ def univariate(ldf, *args):
"description": "Show frequency of occurrence for <p class='highlight-descriptor'>categorical</p> attributes.",
"long_description": f"Occurence displays bar charts of counts for all categorical attributes{examples}. Visualizations are ranked from most to least uneven across the bars. ",
}
elif data_type_constraint == "geographical":
possible_attributes = [
c
for c in ldf.columns
if ldf.data_type[c] == "geographical" and ldf.cardinality[c] > 5 and c != "Number of Records"
]
examples = ""
if len(possible_attributes) >= 1:
examples = f" (e.g., {possible_attributes[0]})"
intent = [lux.Clause("?", data_type="geographical"), lux.Clause("?", data_model="measure")]
intent.extend(filter_specs)
recommendation = {
"action": "Geographical",
"description": "Show choropleth maps of <p class='highlight-descriptor'>geographic</p> attributes",
"long_description": f"Occurence displays choropleths of averages for some geographic attribute{examples}. Visualizations are ranked by diversity of the geographic attribute.",
}
elif data_type_constraint == "temporal":
intent = [lux.Clause("?", data_type="temporal")]
intent.extend(filter_specs)
Expand Down
7 changes: 4 additions & 3 deletions lux/executor/Executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def compute_data_type(self):

def mapping(self, rmap):
group_map = {}
for val in ["quantitative", "id", "nominal", "temporal", "geoshape"]:
for val in ["quantitative", "id", "nominal", "temporal", "geographical"]:
group_map[val] = list(filter(lambda x: rmap[x] == val, rmap))
return group_map

Expand All @@ -74,10 +74,11 @@ def invert_data_type(self, data_type):
def compute_data_model(self, data_type):
data_type_inverted = self.invert_data_type(data_type)
data_model = {
"measure": data_type_inverted["quantitative"] + data_type_inverted["geoshape"],
"measure": data_type_inverted["quantitative"],
"dimension": data_type_inverted["nominal"]
+ data_type_inverted["temporal"]
+ data_type_inverted["id"],
+ data_type_inverted["id"]
+ data_type_inverted["geographical"],
}
return data_model

Expand Down
10 changes: 9 additions & 1 deletion lux/executor/PandasExecutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def execute(vislist: VisList, ldf: LuxDataFrame):
# TODO: Add some type of cap size on Nrows ?
vis._vis_data = vis.data[list(attributes)]

if vis.mark == "bar" or vis.mark == "line":
if vis.mark == "bar" or vis.mark == "line" or vis.mark == "geographical":
PandasExecutor.execute_aggregate(vis, isFiltered=filter_executed)
elif vis.mark == "histogram":
PandasExecutor.execute_binning(vis)
Expand Down Expand Up @@ -418,6 +418,8 @@ def compute_data_type(self, ldf: LuxDataFrame):
ldf._data_type[attr] = "temporal"
elif self._is_datetime_number(ldf[attr]):
ldf._data_type[attr] = "temporal"
elif self._is_geographical_attribute(ldf[attr]):
ldf._data_type[attr] = "geographical"
elif pd.api.types.is_float_dtype(ldf.dtypes[attr]):
# int columns gets coerced into floats if contain NaN
convertible2int = pd.api.types.is_integer_dtype(ldf[attr].convert_dtypes())
Expand Down Expand Up @@ -491,6 +493,12 @@ def _is_datetime_string(series):
return True
return False

@staticmethod
def _is_geographical_attribute(series):
# run detection algorithm
name = str(series.name).lower()
return utils.like_geo(name)

@staticmethod
def _is_datetime_number(series):
if series.dtype == int:
Expand Down
28 changes: 28 additions & 0 deletions lux/interestingness/interestingness.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def interestingness(vis: Vis, ldf: LuxDataFrame) -> int:
if v_size < 2:
return -1

if vis.mark == "geographical":
return n_distinct(vis, dimension_lst, measure_lst)
if n_filter == 0:
return unevenness(vis, ldf, measure_lst, dimension_lst)
elif n_filter == 1:
Expand Down Expand Up @@ -363,3 +365,29 @@ def monotonicity(vis: Vis, attr_specs: list, ignore_identity: bool = True) -> in
return -1
else:
return score


def n_distinct(vis: Vis, dimension_lst: list, measure_lst: list) -> int:
"""
Computes how many unique values there are for a dimensional data type.
Ignores attributes that are latitude or longitude coordinates.
For example, if a dataset displayed earthquake magnitudes across 48 states and
3 countries, return 48 and 3 respectively.
Parameters
----------
vis : Vis
dimension_lst: list
List of dimension Clause objects.
measure_lst: list
List of measure Clause objects.
Returns
-------
int
Score describing the number of unique values in the dimension.
"""
if measure_lst[0].get_attr() in {"longitude", "latitude"}:
return -1
return vis.data[dimension_lst[0].get_attr()].nunique()
10 changes: 7 additions & 3 deletions lux/processor/Compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def populate_data_type_model(ldf, vlist):
clause.data_type = ldf.data_type[clause.attribute]
if clause.data_type == "id":
clause.data_type = "nominal"
if clause.data_type == "geographical":
clause.data_type = "nominal"
if clause.data_model == "":
clause.data_model = data_model_lookup[clause.attribute]
if clause.value != "":
Expand Down Expand Up @@ -261,7 +263,7 @@ def determine_encoding(ldf: LuxDataFrame, vis: Vis):
filters = utils.get_filter_specs(vis._inferred_intent)

# Helper function (TODO: Move this into utils)
def line_or_bar(ldf, dimension: Clause, measure: Clause):
def line_or_bar_or_geo(ldf, dimension: Clause, measure: Clause):
dim_type = dimension.data_type
# If no aggregation function is specified, then default as average
if measure.aggregation == "":
Expand All @@ -272,6 +274,8 @@ def line_or_bar(ldf, dimension: Clause, measure: Clause):
# if cardinality large than 5 then sort bars
if ldf.cardinality[dimension.attribute] > 5:
dimension.sort = "ascending"
if utils.like_geo(dimension.get_attr()):
return "geographical", {"x": dimension, "y": measure}
return "bar", {"x": measure, "y": dimension}

# ShowMe logic + additional heuristics
Expand Down Expand Up @@ -299,7 +303,7 @@ def line_or_bar(ldf, dimension: Clause, measure: Clause):
vis._inferred_intent.append(count_col)
dimension = vis.get_attr_by_data_model("dimension")[0]
measure = vis.get_attr_by_data_model("measure")[0]
vis._mark, auto_channel = line_or_bar(ldf, dimension, measure)
vis._mark, auto_channel = line_or_bar_or_geo(ldf, dimension, measure)
elif ndim == 2 and (nmsr == 0 or nmsr == 1):
# Line or Bar chart broken down by the dimension
dimensions = vis.get_attr_by_data_model("dimension")
Expand All @@ -323,7 +327,7 @@ def line_or_bar(ldf, dimension: Clause, measure: Clause):
if nmsr == 0 and not ldf.pre_aggregated:
vis._inferred_intent.append(count_col)
measure = vis.get_attr_by_data_model("measure")[0]
vis._mark, auto_channel = line_or_bar(ldf, dimension, measure)
vis._mark, auto_channel = line_or_bar_or_geo(ldf, dimension, measure)
auto_channel["color"] = color_attr
elif ndim == 0 and nmsr == 2:
# Scatterplot
Expand Down
4 changes: 4 additions & 0 deletions lux/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ def like_nan(val):
return math.isnan(val)


def like_geo(val):
return isinstance(val, str) and val.lower() in {"state", "country"}


def matplotlib_setup(w, h):
plt.ioff()
fig, ax = plt.subplots(figsize=(w, h))
Expand Down
3 changes: 3 additions & 0 deletions lux/vislib/altair/AltairRenderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from lux.vislib.altair.LineChart import LineChart
from lux.vislib.altair.Histogram import Histogram
from lux.vislib.altair.Heatmap import Heatmap
from lux.vislib.altair.Choropleth import Choropleth


class AltairRenderer:
Expand Down Expand Up @@ -82,6 +83,8 @@ def create_vis(self, vis, standalone=True):
chart = LineChart(vis)
elif vis.mark == "heatmap":
chart = Heatmap(vis)
elif vis.mark == "geographical":
chart = Choropleth(vis)
else:
chart = None

Expand Down
180 changes: 180 additions & 0 deletions lux/vislib/altair/Choropleth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright 2019-2020 The Lux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from lux.vislib.altair.AltairChart import AltairChart
import altair as alt
import us
from iso3166 import countries

alt.data_transformers.disable_max_rows()


class Choropleth(AltairChart):
"""
Choropleth is a subclass of AltairChart that renders choropleth maps.
All rendering properties for proportional symbol maps are set here.
See Also
--------
altair-viz.github.io
"""

us_url = "https://cdn.jsdelivr.net/npm/vega-datasets@v1.29.0/data/us-10m.json"
world_url = "https://cdn.jsdelivr.net/npm/vega-datasets@v1.29.0/data/world-110m.json"

def __init__(self, dobj):
super().__init__(dobj)

def __repr__(self):
return f"Proportional Symbol Map <{str(self.vis)}>"

def initialize_chart(self):
x_attr = self.vis.get_attr_by_channel("x")[0]
y_attr = self.vis.get_attr_by_channel("y")[0]

x_attr_abv = str(x_attr.attribute)
y_attr_abv = str(y_attr.attribute)

background, background_str = self.get_background(x_attr_abv.lower())
geographical_name = self.get_geographical_name(x_attr_abv.lower())
geo_map, geo_map_str, map_type, map_translation = self.get_geomap(x_attr_abv.lower())
self.data[x_attr_abv] = self.data[x_attr_abv].apply(map_translation)

if len(x_attr_abv) > 25:
x_attr_abv = x_attr.attribute[:15] + "..." + x_attr.attribute[-10:]
if len(y_attr_abv) > 25:
y_attr_abv = y_attr.attribute[:15] + "..." + y_attr.attribute[-10:]

if isinstance(x_attr.attribute, str):
x_attr.attribute = x_attr.attribute.replace(".", "")
if isinstance(y_attr.attribute, str):
y_attr.attribute = y_attr.attribute.replace(".", "")

self.data = AltairChart.sanitize_dataframe(self.data)
height = 175
width = int(height * (5 / 3))

points = (
alt.Chart(geo_map)
.mark_geoshape()
.encode(
color=f"{y_attr_abv}:Q",
)
.transform_lookup(lookup="id", from_=alt.LookupData(self.data, x_attr_abv, [y_attr_abv]))
.project(type=map_type)
.properties(
width=width, height=height, title=f"Mean of {y_attr_abv} across {geographical_name}"
)
)

chart = background + points

######################################
## Constructing Altair Code String ##
#####################################

self.code += "import altair as alt\n"
dfname = "placeholder_variable"
self.code += f"""nan=float('nan')
df = pd.DataFrame({str(self.data.to_dict())})
background = {background_str}
points = alt.Chart({geo_map_str}).mark_geoshape().encode(
color='{y_attr_abv}:Q',
).transform_lookup(
lookup='id',
from_=alt.LookupData({dfname}, "{x_attr_abv}", ["{y_attr_abv}"])
).project(
type="{map_type}"
).properties(
width={width},
height={height},
title="Mean of {y_attr_abv} across {geographical_name}"
)
chart = background + points
"""
return chart

def get_background(self, feature):
"""Returns background projection based on geographic feature."""
maps = {
"state": (
alt.topo_feature(Choropleth.us_url, feature="states"),
"albersUsa",
f"alt.topo_feature('{Choropleth.us_url}', feature='states')",
),
"country": (
alt.topo_feature(Choropleth.world_url, feature="countries"),
"equirectangular",
f"alt.topo_feature('{Choropleth.world_url}', feature='countries')",
),
}
assert feature in maps
height = 175
background = (
alt.Chart(maps[feature][0])
.mark_geoshape(fill="lightgray", stroke="white")
.properties(width=int(height * (5 / 3)), height=height)
.project(maps[feature][1])
)
background_str = f"(alt.Chart({maps[feature][2]}).mark_geoshape(fill='lightgray', stroke='white').properties(width=int({height} * (5 / 3)), height={height}).project('{maps[feature][1]}'))"
return background, background_str

def get_geomap(self, feature):
"""Returns topological encoding, topological style,
and translation function based on geographic feature"""
maps = {
"state": (
alt.topo_feature(Choropleth.us_url, feature="states"),
f"alt.topo_feature('{Choropleth.us_url}', feature='states')",
"albersUsa",
self.get_us_fips_code,
),
"country": (
alt.topo_feature(Choropleth.world_url, feature="countries"),
f"alt.topo_feature('{Choropleth.world_url}', feature='countries')",
"equirectangular",
self.get_country_iso_code,
),
}
assert feature in maps
return maps[feature]

def get_us_fips_code(self, attribute):
"""Returns FIPS code given a US state"""
if not isinstance(attribute, str):
return attribute
try:
return int(us.states.lookup(attribute).fips)
except:
return attribute

def get_country_iso_code(self, attribute):
"""Returns country ISO code given a country"""
if not isinstance(attribute, str):
return attribute
try:
return int(countries.get(attribute).numeric)
except:
return attribute

def get_geographical_name(self, feature):
"""Returns geographical location label based on secondary feature."""
maps = {"state": "United States", "country": "World"}
return maps[feature]

def encode_color(self):
# Setting tooltip as non-null
self.chart = self.chart.configure_mark(tooltip=alt.TooltipContent("encoding"))
self.code += f"""chart = chart.configure_mark(tooltip=alt.TooltipContent('encoding'))"""
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ matplotlib>=3.0.0
# psycopg2>=2.8.5
# psycopg2-binary>=2.8.5
lux-widget>=0.1.4
us
iso3166

0 comments on commit 47ef480

Please sign in to comment.