Skip to content

Commit

Permalink
AMD brain labels
Browse files Browse the repository at this point in the history
  • Loading branch information
christianbrodbeck committed May 19, 2021
1 parent 01fd78b commit a5c835f
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions eelbrain/plot/_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,13 +686,16 @@ def add_column_titles(self, titles, x=None, y=0.25, va='center', ha='center', **
"""
if len(titles) > self._n_columns:
raise ValueError(f"titles={titles}: {len(titles)} titles for {self._n_columns} columns")
# distribute titles over columns
n = self._n_columns / len(titles) # columns per title
column_width = self._layout.axw * n + self._layout.margins['wspace'] * (n - 1)
# center
if x is None:
x = self._layout.axw / 2
x_left = self._layout.margins['left'] + x
x_offset = self._layout.margins['wspace'] + self._layout.axw
x = column_width / 2
first = self._layout.margins['left'] + x
y_ = 1 - y / self._layout.h
for i, label in enumerate(titles):
x_ = (x_left + i * x_offset) / self._layout.w
x_ = (first + i * column_width) / self._layout.w
self.figure.text(x_, y_, label, va=va, ha=ha, **kwargs)
self.draw()

Expand All @@ -713,13 +716,16 @@ def add_row_titles(self, titles, x=0.1, y=None, va='center', ha='center', **kwar
"""
if len(titles) > self._n_rows:
raise ValueError(f"titles={titles}: {len(titles)} titles for {self._n_rows} rows")
# distribute titles over columns
n = self._n_rows / len(titles) # columns per title
row_height = self._layout.axh * n + self._layout.margins['hspace'] * (n - 1)
# center
if y is None:
y = self._layout.axh / 2
y_top = self._layout.margins['top'] + self._layout.axh - y
y_offset = self._layout.margins['hspace'] + self._layout.axh
y = row_height / 2
y_top = self._layout.margins['top'] + y
x_ = x / self._layout.w
for i, label in enumerate(titles):
y_ = 1 - (y_top + i * y_offset) / self._layout.h
y_ = 1 - (y_top + i * row_height) / self._layout.h
self.figure.text(x_, y_, label, va=va, ha=ha, **kwargs)
self.draw()

Expand Down Expand Up @@ -1558,13 +1564,9 @@ def plot_table(
# labels
bin_in_column = 'bin' in columns
bin_axis = columns if bin_in_column else rows
if len(bin_axis) > 1:
components = [labels if key == 'bin' else ('',)*ns[key] for key in bin_axis]
index = bin_axis.index('bin')
labels = [items[index] for items in product(*components)]
for i in range(len(labels) - 1, 0, -1):
if labels[i] == labels[i - 1]:
labels[i] = None
if labels and (bin_pos := bin_axis.index('bin')):
n_tile = reduce(operator.mul, (ns[key] for key in bin_axis[:bin_pos]))
labels = labels * n_tile
if bin_in_column:
column_labels, row_labels = labels, None
else:
Expand Down

0 comments on commit a5c835f

Please sign in to comment.