Skip to content

Commit

Permalink
Truncate group names in the group scatter graph if the group names is…
Browse files Browse the repository at this point in the history
… longer than 20 chars

Added unit test.
  • Loading branch information
cmmorrow committed Dec 21, 2017
1 parent 85818b2 commit 4ac626f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
3 changes: 3 additions & 0 deletions sci_analysis/graphs/vector.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
import six

# matplotlib imports
from matplotlib.pyplot import show, subplot, yticks, xlabel, ylabel, figure, setp, savefig, close, xticks, \
Expand Down Expand Up @@ -506,6 +507,8 @@ def draw(self):
alpha_trans = 0.2
except TypeError:
alpha_trans = 0.6
if isinstance(grp, six.string_types) and len(grp) > 20:
grp = grp[0:21] + '...'

# Draw the points
if self._points:
Expand Down
15 changes: 14 additions & 1 deletion sci_analysis/test/test_graph_groupscatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,20 @@ def test_25_scatter_two_groups_no_ydata(self):
input_array = pd.DataFrame({'a': cs_x, 'b': cs_y, 'c': grp})
self.assertRaises(AttributeError, lambda: GraphGroupScatter(input_array['a'], groups=input_array['c']))

# TODO: Test with long group names
def test_26_scatter_three_groups_long_group_names(self):
np.random.seed(987654321)
input_1_x = st.norm.rvs(size=100)
input_1_y = [x + st.norm.rvs(0, 0.5, size=1)[0] for x in input_1_x]
input_2_x = st.norm.rvs(size=100)
input_2_y = [(x / 2) + st.norm.rvs(0, 0.2, size=1)[0] for x in input_2_x]
input_3_x = st.norm.rvs(size=100)
input_3_y = np.array([(x * 1.5) + st.norm.rvs(size=100)[0] for x in input_3_x]) - 0.5
grp = ['11111111111111111111'] * 100 + ['222222222222222222222'] * 100 + ['3333333333333333333333'] * 100
cs_x = np.concatenate((input_1_x, input_2_x, input_3_x))
cs_y = np.concatenate((input_1_y, input_2_y, input_3_y))
input_array = pd.DataFrame({'a': cs_x, 'b': cs_y, 'c': grp})
self.assertTrue(GraphGroupScatter(input_array['a'], input_array['b'], groups=input_array['c'],
save_to='{}test_group_scatter_26'.format(self.save_path)))


if __name__ == '__main__':
Expand Down

0 comments on commit 4ac626f

Please sign in to comment.