diff --git a/.coveragerc b/.coveragerc index b2273aa..3fa0c7b 100644 --- a/.coveragerc +++ b/.coveragerc @@ -2,6 +2,7 @@ branch = true source = mplaltair +omit = *tests* [report] exclude_lines = diff --git a/mplaltair/__init__.py b/mplaltair/__init__.py index 721598e..001b3a4 100644 --- a/mplaltair/__init__.py +++ b/mplaltair/__init__.py @@ -1,35 +1,23 @@ import matplotlib import altair +from ._convert import _convert -# TODO rename this? -def convert(encoding, *, figure=None): +def convert(chart): """Convert an altair encoding to a Matplotlib figure Parameters ---------- - encoding - The Altair encoding of the plot. - - figure : matplotib.figure.Figure, optional - # TODO: generalize this to 'thing that supports gridspec slicing? + chart + The Altair chart object generated by Altair Returns ------- - figure : matplotlib.figure.Figure - The Figure with all artists in it (ready to be saved or shown) - mapping : dict Mapping from parts of the encoding to the Matplotlib artists. This is for later customization. """ - if figure is None: - from matplotlib import pyplot as plt - figure = plt.figure() - - mapping = {} - - return figure, mapping + return _convert(chart) diff --git a/mplaltair/_convert.py b/mplaltair/_convert.py index 8b13789..ab87dc0 100644 --- a/mplaltair/_convert.py +++ b/mplaltair/_convert.py @@ -1 +1,129 @@ +import matplotlib.dates as mdates +from ._data import _locate_channel_data, _locate_channel_dtype +def _allowed_ranged_marks(enc_channel, mark): + """TODO: DOCS + """ + return mark in ['area', 'bar', 'rect', 'rule'] if enc_channel in ['x2', 'y2'] else True + +def _process_x(dtype, data): + """Returns the MPL encoding equivalent for Altair x channel + """ + return ('x', data) + + +def _process_y(dtype, data): + """Returns the MPL encoding equivalent for Altair y channel + """ + return ('y', data) + + +def _process_x2(dtype, data): + """Returns the MPL encoding equivalent for Altair x2 channel + """ + raise NotImplementedError + + +def _process_y2(dtype, data): + """Returns the MPL encoding equivalent for Altair y2 channel + """ + raise NotImplementedError + + +def _process_color(dtype, data): + """Returns the MPL encoding equivalent for Altair color channel + """ + if dtype == 'quantitative': + return ('c', data) + elif dtype == 'nominal': + raise NotImplementedError + elif dtype == 'ordinal': + return ('c', data) + else: # temporal + return ('c', data) + + +def _process_fill(dtype, data): + """Returns the MPL encoding equivalent for Altair fill channel + """ + return _process_color(dtype, data) + + +def _process_shape(dtype, data): + """Returns the MPL encoding equivalent for Altair shape channel + """ + raise NotImplementedError + + +def _process_opacity(dtype, data): + """Returns the MPL encoding equivalent for Altair opacity channel + """ + raise NotImplementedError + + +def _process_size(dtype, data): + """Returns the MPL encoding equivalent for Altair size channel + """ + if dtype == 'quantitative': + return ('s', data) + elif dtype == 'nominal': + raise NotImplementedError + elif dtype == 'ordinal': + return ('s', data) + elif dtype == 'temporal': + raise NotImplementedError + + +def _process_stroke(dtype, data): + """Returns the MPL encoding equivalent for Altair stroke channel + """ + raise NotImplementedError + +_mappings = { + 'x': _process_x, + 'y': _process_y, + 'x2': _process_x2, + 'y2': _process_y2, + 'color': _process_color, + 'fill': _process_fill, + 'shape': _process_shape, + 'opacity': _process_opacity, + 'size': _process_size, + 'stroke': _process_stroke, +} + +def _convert(chart): + """Convert an altair encoding to a Matplotlib figure + + + Parameters + ---------- + chart + The Altair chart. + + Returns + ------- + mapping : dict + Mapping from parts of the encoding to the Matplotlib artists. This is + for later customization. + """ + mapping = {} + + if not chart.to_dict().get('encoding'): + raise ValueError("Encoding not provided with the chart specification") + + for enc_channel, enc_spec in chart.to_dict()['encoding'].items(): + if not _allowed_ranged_marks(enc_channel, chart.to_dict()['mark']): + raise ValueError("Ranged encoding channels like x2, y2 not allowed for Mark: {}".format(chart['mark'])) + + for channel in chart.to_dict()['encoding']: + data = _locate_channel_data(chart, channel) + dtype = _locate_channel_dtype(chart, channel) + if dtype == 'temporal': + try: + data = mdates.date2num(data) # Convert dates to Matplotlib dates + except AttributeError: + raise + mapping[_mappings[channel](dtype, data)[0]] = _mappings[channel](dtype, data)[1] + + return mapping diff --git a/mplaltair/_data.py b/mplaltair/_data.py new file mode 100644 index 0000000..41ecfc0 --- /dev/null +++ b/mplaltair/_data.py @@ -0,0 +1,64 @@ +from ._exceptions import ValidationError + +def _locate_channel_dtype(chart, channel): + """Locates dtype used for each channel + Parameters + ---------- + chart + The Altair chart + channel + The Altair channel being examined + + Returns + ------- + A string representing the data type from the Altair chart ('quantitative', 'ordinal', 'numeric', 'temporal') + """ + + channel_val = chart.to_dict()['encoding'][channel] + if channel_val.get('type'): + return channel_val.get('type') + else: + # TODO: find some way to deal with 'value' so that, opacity, for instance, can be plotted with a value defined + if channel_val.get('value'): + raise NotImplementedError + raise NotImplementedError + + +def _locate_channel_data(chart, channel): + """Locates data used for each channel + + Parameters + ---------- + chart + The Altair chart + channel + The Altair channel being examined + + Returns + ------- + A numpy ndarray containing the data used for the channel + + Raises + ------ + ValidationError + Raised when the specification does not contain any data attribute + + """ + + channel_val = chart.to_dict()['encoding'][channel] + if channel_val.get('value'): + return channel_val.get('value') + elif channel_val.get('aggregate'): + return _aggregate_channel() + elif channel_val.get('timeUnit'): + return _handle_timeUnit() + else: # field is required if the above are not present. + return chart.data[channel_val.get('field')].values + + +def _aggregate_channel(): + raise NotImplementedError + + +def _handle_timeUnit(): + raise NotImplementedError \ No newline at end of file diff --git a/mplaltair/_exceptions.py b/mplaltair/_exceptions.py new file mode 100644 index 0000000..c741942 --- /dev/null +++ b/mplaltair/_exceptions.py @@ -0,0 +1,2 @@ +class ValidationError(Exception): + pass \ No newline at end of file diff --git a/mplaltair/tests/test_convert.py b/mplaltair/tests/test_convert.py new file mode 100644 index 0000000..05e022c --- /dev/null +++ b/mplaltair/tests/test_convert.py @@ -0,0 +1,231 @@ +import pytest + +import altair as alt +import pandas as pd +import matplotlib.pyplot as plt +import matplotlib.dates as mdates + +from mplaltair import convert + + +df = pd.DataFrame({ + 'quant': [1, 1.5, 2, 2.5, 3], 'ord': [0, 1, 2, 3, 4], 'nom': ['A', 'B', 'C', 'D', 'E'], + "years": pd.date_range('01/01/2015', periods=5, freq='Y'), "months": pd.date_range('1/1/2015', periods=5, freq='M'), + "days": pd.date_range('1/1/2015', periods=5, freq='D'), "hrs": pd.date_range('1/1/2015', periods=5, freq='H'), + "combination": pd.to_datetime(['1/1/2015', '1/1/2015 10:00:00', '1/2/2015 00:00', '1/4/2016 10:00', '5/1/2016']), + "quantitative": [1.1, 2.1, 3.1, 4.1, 5.1] +}) + +df_quant = pd.DataFrame({ + "a": [1, 2, 3], "b": [1.2, 2.4, 3.8], "c": [7, 5, 3], + "s": [50, 100, 200.0], "alpha": [.1, .5, .8], "shape": [1, 2, 3], "fill": [1, 2, 3] +}) + + +def test_encoding_not_provided(): + chart_spec = alt.Chart(df).mark_point() + with pytest.raises(ValueError): + convert(chart_spec) + +def test_invalid_encodings(): + chart_spec = alt.Chart(df).encode(x2='quant').mark_point() + with pytest.raises(ValueError): + convert(chart_spec) + +@pytest.mark.xfail(raises=AttributeError) +def test_invalid_temporal(): + chart = alt.Chart(df).mark_point().encode(alt.X('quant:T')) + convert(chart) + +@pytest.mark.parametrize('channel', ['quant', 'ord', 'nom']) +def test_convert_x_success(channel): + chart_spec = alt.Chart(df).encode(x=channel).mark_point() + mapping = convert(chart_spec) + assert list(mapping['x']) == list(df[channel].values) + +@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) +def test_convert_x_success_temporal(column): + chart = alt.Chart(df).mark_point().encode(alt.X(column)) + mapping = convert(chart) + assert list(mapping['x']) == list(mdates.date2num(df[column].values)) + +def test_convert_x_fail(): + chart_spec = alt.Chart(df).encode(x='b:N').mark_point() + with pytest.raises(KeyError): + convert(chart_spec) + +@pytest.mark.parametrize('channel', ['quant', 'ord', 'nom']) +def test_convert_y_success(channel): + chart_spec = alt.Chart(df).encode(y=channel).mark_point() + mapping = convert(chart_spec) + assert list(mapping['y']) == list(df[channel].values) + +@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) +def test_convert_y_success_temporal(column): + chart = alt.Chart(df).mark_point().encode(alt.Y(column)) + mapping = convert(chart) + assert list(mapping['y']) == list(mdates.date2num(df[column].values)) + +def test_convert_y_fail(): + chart_spec = alt.Chart(df).encode(y='b:N').mark_point() + with pytest.raises(KeyError): + convert(chart_spec) + +@pytest.mark.xfail(raises=ValueError, reason="It doesn't make sense to have x2 and y2 on scatter plots") +def test_quantitative_x2_y2(): + chart = alt.Chart(df_quant).mark_point().encode(alt.X('a'), alt.Y('b'), alt.X2('c'), alt.Y2('alpha')) + convert(chart) + +@pytest.mark.xfail(raises=ValueError) +@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) +def test_convert_x2_y2_fail_temporal(column): + chart = alt.Chart(df).mark_point().encode(alt.X2(column), alt.Y2(column)) + convert(chart) + +@pytest.mark.parametrize('channel,dtype', [('quant','quantitative'), ('ord','ordinal')]) +def test_convert_color_success(channel, dtype): + chart_spec = alt.Chart(df).encode(color=alt.Color(field=channel, type=dtype)).mark_point() + mapping = convert(chart_spec) + assert list(mapping['c']) == list(df[channel].values) + +def test_convert_color_success_nominal(): + chart_spec = alt.Chart(df).encode(color='nom').mark_point() + with pytest.raises(NotImplementedError): + convert(chart_spec) + +@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) +def test_convert_color_success_temporal(column): + chart = alt.Chart(df).mark_point().encode(alt.Color(column)) + mapping = convert(chart) + assert list(mapping['c']) == list(mdates.date2num(df[column].values)) + +def test_convert_color_fail(): + chart_spec = alt.Chart(df).encode(color='b:N').mark_point() + with pytest.raises(KeyError): + convert(chart_spec) + +@pytest.mark.parametrize('channel,type', [('quant', 'Q'), ('ord', 'O')]) +def test_convert_fill(channel, type): + chart_spec = alt.Chart(df).encode(fill='{}:{}'.format(channel, type)).mark_point() + mapping = convert(chart_spec) + assert list(mapping['c']) == list(df[channel].values) + +def test_convert_fill_success_nominal(): + chart_spec = alt.Chart(df).encode(fill='nom').mark_point() + with pytest.raises(NotImplementedError): + convert(chart_spec) + +@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) +def test_convert_fill_success_temporal(column): + chart = alt.Chart(df).mark_point().encode(alt.Fill(column)) + mapping = convert(chart) + assert list(mapping['c']) == list(mdates.date2num(df[column].values)) + + +def test_convert_fill_fail(): + chart_spec = alt.Chart(df).encode(fill='b:N').mark_point() + with pytest.raises(KeyError): + convert(chart_spec) + +@pytest.mark.xfail(raises=NotImplementedError, reason="The marker argument in scatter() cannot take arrays") +def test_quantitative_shape(): + chart = alt.Chart(df_quant).mark_point().encode(alt.Shape('shape')) + mapping = convert(chart) + assert list(mapping['marker']) == list(df_quant['shape'].values) + +@pytest.mark.xfail(raises=NotImplementedError, reason="The marker argument in scatter() cannot take arrays") +@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) +def test_convert_shape_fail_temporal(column): + chart = alt.Chart(df).mark_point().encode(alt.Shape(column)) + mapping = convert(chart) + assert list(mapping['s']) == list(mdates.date2num(df[column].values)) + +@pytest.mark.xfail(raises=NotImplementedError, reason="Merge: the dtype for opacity isn't assumed to be quantitative") +def test_quantitative_opacity_value(): + chart = alt.Chart(df_quant).mark_point().encode(opacity=alt.value(.5)) + mapping = convert(chart) + assert mapping['alpha'] == 0.5 + +@pytest.mark.xfail(raises=NotImplementedError, reason="The alpha argument in scatter() cannot take arrays") +def test_quantitative_opacity_array(): + chart = alt.Chart(df_quant).mark_point().encode(alt.Opacity('alpha')) + convert(chart) + +@pytest.mark.xfail(raises=NotImplementedError, reason="The alpha argument in scatter() cannot take arrays") +@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) +def test_convert_opacity_fail_temporal(column): + chart = alt.Chart(df).mark_point().encode(alt.Opacity(column)) + convert(chart) + +@pytest.mark.parametrize('channel,type', [('quant', 'Q'), ('ord', 'O')]) +def test_convert_size_success(channel, type): + chart_spec = alt.Chart(df).encode(size='{}:{}'.format(channel, type)).mark_point() + mapping = convert(chart_spec) + assert list(mapping['s']) == list(df[channel].values) + +def test_convert_size_success_nominal(): + chart_spec = alt.Chart(df).encode(size='nom').mark_point() + with pytest.raises(NotImplementedError): + convert(chart_spec) + +def test_convert_size_fail(): + chart_spec = alt.Chart(df).encode(size='b:N').mark_point() + with pytest.raises(KeyError): + convert(chart_spec) + +@pytest.mark.xfail(raises=NotImplementedError, reason="Dates would need to be normalized for the size.") +@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) +def test_convert_size_fail_temporal(column): + chart = alt.Chart(df).mark_point().encode(alt.Size(column)) + convert(chart) + + +@pytest.mark.xfail(raises=NotImplementedError, reason="Stroke is not well supported in Altair") +def test_quantitative_stroke(): + chart = alt.Chart(df_quant).mark_point().encode(alt.Stroke('fill')) + convert(chart) + +@pytest.mark.xfail(raises=NotImplementedError, reason="Stroke is not well defined in Altair") +@pytest.mark.parametrize("column", ["years", "months", "days", "hrs", "combination"]) +def test_convert_stroke_fail_temporal(column): + chart = alt.Chart(df).mark_point().encode(alt.Stroke(column)) + convert(chart) + + +# Aggregations + +@pytest.mark.xfail(raises=NotImplementedError, reason="Aggregate functions are not supported yet") +def test_quantitative_x_count_y(): + df_count = pd.DataFrame({"a": [1, 1, 2, 3, 5], "b": [1.4, 1.4, 2.9, 3.18, 5.3]}) + chart = alt.Chart(df_count).mark_point().encode(alt.X('a'), alt.Y('count()')) + mapping = convert(chart) + assert list(mapping['x']) == list(df_count['a'].values) + assert list(mapping['y']) == list(df_count.groupby(['a']).count().values) + +@pytest.mark.xfail(raises=NotImplementedError, reason="specifying timeUnit is not supported yet") +def test_timeUnit(): + chart = alt.Chart(df).mark_point().encode(alt.X('date(combination)')) + convert(chart) + +# Plots + +chart_quant = alt.Chart(df_quant).mark_point().encode( + alt.X(field='a', type='quantitative'), alt.Y('b'), alt.Color('c:Q'), alt.Size('s') +) +chart_fill_quant = alt.Chart(df_quant).mark_point().encode( + alt.X(field='a', type='quantitative'), alt.Y('b'), alt.Fill('fill:Q') +) + +@pytest.mark.parametrize("chart", [chart_quant, chart_fill_quant]) +def test_quantitative_scatter(chart): + mapping = convert(chart) + plt.scatter(**mapping) + plt.show() + +@pytest.mark.parametrize("channel", [alt.Color("years"), alt.Fill("years")]) +def test_scatter_temporal(channel): + chart = alt.Chart(df).mark_point().encode(alt.X("years"), channel) + mapping = convert(chart) + mapping['y'] = df['quantitative'].values + plt.scatter(**mapping) + plt.show() diff --git a/mplaltair/tests/test_data.py b/mplaltair/tests/test_data.py new file mode 100644 index 0000000..c6dcf35 --- /dev/null +++ b/mplaltair/tests/test_data.py @@ -0,0 +1,82 @@ +import altair as alt +import pandas as pd +import mplaltair._data as _data +import pytest + +df = pd.DataFrame({ + "a": [1, 2, 3, 4, 5], "b": [1.1, 2.2, 3.3, 4.4, 5.5], "c": [1, 2.2, 3, 4.4, 5], + "nom": ['a', 'b', 'c', 'd', 'e'], "ord": [1, 2, 3, 4, 5], + "years": pd.date_range('01/01/2015', periods=5, freq='Y'), "months": pd.date_range('1/1/2015', periods=5, freq='M'), + "days": pd.date_range('1/1/2015', periods=5, freq='D'), "hrs": pd.date_range('1/1/2015', periods=5, freq='H'), + "combination": pd.to_datetime(['1/1/2015', '1/1/2015 10:00:00', '1/2/2015 00:00', '1/4/2016 10:00', '5/1/2016']), + "quantitative": [1.1, 2.1, 3.1, 4.1, 5.1] +}) + + +# _locate_channel_data() tests + +@pytest.mark.parametrize("column, dtype", [ + ('a', 'quantitative'), ('b', 'quantitative'), ('c', 'quantitative'), ('combination', 'temporal') +]) +def test_data_field_quantitative(column, dtype): + chart = alt.Chart(df).mark_point().encode(alt.X(field=column, type=dtype)) + for channel in chart.to_dict()['encoding']: + data = _data._locate_channel_data(chart, channel) + assert list(data) == list(df[column].values) + + +@pytest.mark.parametrize("column", ['a', 'b', 'c', 'combination']) +def test_data_shorthand_quantitative(column): + chart = alt.Chart(df).mark_point().encode(alt.X(column)) + for channel in chart.to_dict()['encoding']: + data = _data._locate_channel_data(chart, channel) + assert list(data) == list(df[column].values) + + +def test_data_value_quantitative(): + chart = alt.Chart(df).mark_point().encode(opacity=alt.value(0.5)) + for channel in chart.to_dict()['encoding']: + data = _data._locate_channel_data(chart, channel) + assert data == 0.5 + + +@pytest.mark.parametrize("column", ['a', 'b', 'c']) +@pytest.mark.xfail(raises=NotImplementedError) +def test_data_aggregate_quantitative(column): + chart = alt.Chart(df).mark_point().encode(alt.X(field=column, type='quantitative', aggregate='average')) + for channel in chart.to_dict()['encoding']: + data = _data._locate_channel_data(chart, channel) + + +@pytest.mark.xfail(raises=NotImplementedError) +def test_data_timeUnit_shorthand_temporal(): + chart = alt.Chart(df).mark_point().encode(alt.X('month(combination):T')) + for channel in chart.to_dict()['encoding']: + data = _data._locate_channel_data(chart, channel) + + +@pytest.mark.xfail(raises=NotImplementedError) +def test_data_timeUnit_field_temporal(): + chart = alt.Chart(df).mark_point().encode(alt.X(field='combination', type='temporal', timeUnit='month')) + for channel in chart.to_dict()['encoding']: + data = _data._locate_channel_data(chart, channel) + + +# _locate_channel_dtype() tests + +@pytest.mark.parametrize('column, expected', [ + ('a:Q', 'quantitative'), ('nom:N', 'nominal'), ('ord:O', 'ordinal'), ('combination:T', 'temporal') +]) +def test_data_dtype(column, expected): + chart = alt.Chart(df).mark_point().encode(alt.X(column)) + for channel in chart.to_dict()['encoding']: + dtype = _data._locate_channel_dtype(chart, channel) + assert dtype == expected + + +@pytest.mark.xfail(raises=NotImplementedError) +def test_data_dtype_fail(): + chart = alt.Chart(df).mark_point().encode(opacity=alt.value(.5)) + for channel in chart.to_dict()['encoding']: + dtype = _data._locate_channel_dtype(chart, channel) + assert dtype == 'quantitative'