# 3D slices

A simple, pure-matplotlib figure for slicing interactively through a 3D volume.

In [1]:
import numpy as np
from matplotlib import pyplot as plt

In [2]:
# Use the 'ggplot' style for all following figures, and switch the notebook to
# the interactive `notebook` backend so figures can receive mouse-scroll
# events (a static inline backend would not deliver them).
plt.style.use('ggplot')
%matplotlib notebook

In [3]:
def slice_index(centers, location=None):
    """Return the index of the cell center closest to `location`.

    centers :: 1D array-like of cell-center coordinates.
    location :: slice location in axis units, or None for the middle cell.

    A location of 0 is a valid position and is NOT treated as "unset";
    only None selects the default (middle) index.
    """
    centers = np.asarray(centers)
    if location is None:
        return centers.size // 2
    return int(np.argmin(np.abs(centers - location)))


def step_index(index, step, size):
    """Move slice `index` by `step`, wrapping around within range(size).

    Every index 0 .. size-1 is reachable, including the last one.
    """
    return (index + step) % size


def plot_interactive(data, xslice=None, yslice=None, zslice=None):
    """Plot slices of a 3D volume, interactively.

    Three linked panels are drawn: an X-Y slice (top left), an X-Z slice
    (bottom left) and a Z-Y slice (top right). Scrolling the mouse wheel over
    a panel moves that panel's slice through the volume; dashed white lines
    in the other two panels mark where the current slices intersect them.

    data :: dict containing data; needs the following keywords:
        - data : the actual data, indexed as data[ix, iy, iz].
        - xc, yc, zc : center points of cells.
        - xx, yy, zz : edges of cells (have to have one element more than
                       corresponding centers).

    xslice, yslice, zslice: Initial slice locations (in axis units).
            Optional; defaults to the middle of the volume.

    """

    class IndexTracker:
        def __init__(self, data, xslice, yslice, zslice):
            """Initialize interactive figure."""

            # 1. Store relevant data

            # Private copy, so later changes to the caller's array do not
            # alter what the figure shows.
            self.data = data['data'].copy()

            # Cell edges: passed to pcolormesh as the mesh coordinates.
            self.x = data['xx']
            self.y = data['yy']
            self.z = data['zz']
            # Cell centers: used to locate slices and place cross-lines.
            self.xc = np.asarray(data['xc'])
            self.yc = np.asarray(data['yc'])
            self.zc = np.asarray(data['zc'])

            # Initial slice indices; each axis uses its own location.
            self.xind = slice_index(self.xc, xslice)
            self.yind = slice_index(self.yc, yslice)
            self.zind = slice_index(self.zc, zslice)

            # 2. Start figure

            # Create subplots (on the current figure)
            plt.subplots_adjust(wspace=.075, hspace=.1)

            # X-Y (top left); x tick labels are shown on X-Z instead
            self.ax1 = plt.subplot2grid((3, 3), (0, 0), colspan=2, rowspan=2)
            self.ax1.set_ylabel('y-axis (units)')
            self.ax1.xaxis.set_ticks_position('top')
            plt.setp(self.ax1.get_xticklabels(), visible=False)

            # X-Z (bottom left); z-axis inverted so z increases downwards
            self.ax2 = plt.subplot2grid((3, 3), (2, 0), colspan=2,
                                        sharex=self.ax1)
            self.ax2.yaxis.set_ticks_position('both')
            self.ax2.invert_yaxis()
            self.ax2.set_xlabel('x-axis (units)')
            self.ax2.set_ylabel('z-axis (units)')

            # Z-Y (top right); y tick labels are shown on X-Y instead
            self.ax3 = plt.subplot2grid((3, 3), (0, 2), rowspan=2,
                                        sharey=self.ax1)
            self.ax3.yaxis.set_ticks_position('right')
            self.ax3.xaxis.set_ticks_position('both')
            plt.setp(self.ax3.get_yticklabels(), visible=False)

            # Title
            plt.suptitle('Scroll with your mouse\nwhile hovering over the '
                         'subplot you want to slice through.')

            # Cross-line properties
            self.clprops = {'c': 'w', 'ls': '--', 'lw': 1, 'zorder': 10}

            # Fix the color scale to the full data range so that every slice
            # is comparable and one colorbar is valid for all panels.
            self.pcm_props = {'vmin': np.min(self.data),
                              'vmax': np.max(self.data)}

            # Create colorbar from a mesh of the initial Z-Y slice; this mesh
            # stays underneath the one drawn (and replaced) by update_zy.
            cb_mesh = self.ax3.pcolormesh(self.z, self.y,
                                          self.data[self.xind, :, :],
                                          **self.pcm_props)
            plt.colorbar(cb_mesh, ax=self.ax3,
                         label='Colorbar legend (units)', pad=0.15)

            # Initial draw
            self.update_xy()
            self.update_xz()
            self.update_zy()

            # 3. Keep depth in X-Z (y-axis) and Z-Y (x-axis) in sync

            def do_adjust():
                """Return True if z-axis in X-Z and Z-Y are different."""
                one = np.array(self.ax2.get_ylim())
                two = np.array(self.ax3.get_xlim())[::-1]
                # Small tolerance (in axis units): the two callbacks below
                # trigger each other, and this stops them once limits agree.
                return sum(abs(one - two)) > 0.001

            def on_ylims_changed(ax):
                """Adjust Z-Y if X-Z changed."""
                if do_adjust():
                    self.ax3.set_xlim([self.ax2.get_ylim()[1],
                                       self.ax2.get_ylim()[0]])

            def on_xlims_changed(ax):
                """Adjust X-Z if Z-Y changed."""
                if do_adjust():
                    self.ax2.set_ylim([self.ax3.get_xlim()[1],
                                       self.ax3.get_xlim()[0]])

            self.ax3.callbacks.connect('xlim_changed', on_xlims_changed)
            self.ax2.callbacks.connect('ylim_changed', on_ylims_changed)

        def onscroll(self, event):
            """Move the slice of the panel under the mouse by one cell."""
            step = 1 if event.button == 'up' else -1

            # Update slice index depending on subplot over which mouse is
            if event.inaxes is self.ax1:    # X-Y: move through z
                self.zind = step_index(self.zind, step, self.zc.size)
                self.update_xy()
            elif event.inaxes is self.ax2:  # X-Z: move through y
                self.yind = step_index(self.yind, step, self.yc.size)
                self.update_xz()
            elif event.inaxes is self.ax3:  # Z-Y: move through x
                self.xind = step_index(self.xind, step, self.xc.size)
                self.update_zy()

            # Redraw the figure that received the event, which need not be
            # pyplot's "current" figure.
            event.canvas.draw_idle()

        def update_xy(self):
            """Update plot for change in Z-index."""

            # Clean up
            self.clear_element('xy_pc')
            self.clear_element('xz_ah')
            self.clear_element('zy_av')

            # Draw X-Y slice (transposed: pcolormesh wants C[row=y, col=x])
            zdat = self.data[:, :, self.zind].transpose()
            self.xy_pc = self.ax1.pcolormesh(self.x, self.y, zdat,
                                             **self.pcm_props)

            # Mark Z-slice in X-Z plot (z is its vertical axis); the line
            # spans the full axes width by default.
            self.xz_ah = self.ax2.axhline(self.zc[self.zind], **self.clprops)

            # Mark Z-slice in Z-Y plot (z is its horizontal axis)
            self.zy_av = self.ax3.axvline(self.zc[self.zind], **self.clprops)

        def update_xz(self):
            """Update plot for change in Y-index."""

            # Clean up
            self.clear_element('xz_pc')
            self.clear_element('zy_ah')
            self.clear_element('xy_ah')

            # Draw X-Z slice
            ydat = self.data[:, self.yind, :].transpose()
            self.xz_pc = self.ax2.pcolormesh(self.x, self.z, ydat,
                                             **self.pcm_props)

            # Mark Y-slice in X-Y plot
            self.xy_ah = self.ax1.axhline(self.yc[self.yind], **self.clprops)

            # Mark Y-slice in Z-Y plot
            self.zy_ah = self.ax3.axhline(self.yc[self.yind], **self.clprops)

        def update_zy(self):
            """Update plot for change in X-index."""

            # Clean up
            self.clear_element('zy_pc')
            self.clear_element('xz_av')
            self.clear_element('xy_av')

            # Draw Z-Y slice (already [row=y, col=z])
            xdat = self.data[self.xind, :, :]
            self.zy_pc = self.ax3.pcolormesh(self.z, self.y, xdat,
                                             **self.pcm_props)

            # Mark X-slice in X-Y plot
            self.xy_av = self.ax1.axvline(self.xc[self.xind], **self.clprops)

            # Mark X-slice in X-Z plot
            self.xz_av = self.ax2.axvline(self.xc[self.xind], **self.clprops)

        def clear_element(self, name):
            """Remove element <name> from plot if it exists."""
            artist = getattr(self, name, None)
            if artist is not None:
                artist.remove()

    # Figure (plt.figure() always creates a new, current figure)
    fig = plt.figure()

    tracker = IndexTracker(data, xslice, yslice, zslice)

    # Matplotlib keeps only weak references to bound methods used as
    # callbacks; connecting a plain function keeps the tracker alive after
    # this function returns.
    fig.canvas.mpl_connect('scroll_event',
                           lambda event: tracker.onscroll(event))

    plt.show()

In [4]:
# Create some dummy data. The grid may be irregularly spaced, which is why
# both cell centers and cell edges are provided. Each direction needs one
# more edge than it has cells, as required by plot_interactive.
nx, ny, nz = 10, 5, 8  # Number of cells in x, y, z direction

# Random numbers
dat = np.random.random((nx, ny, nz))
dat[2:-3, 1:-2, 2:-3] = -1  # Add a block in the middle

data = {
    'xc': np.arange(nx),            # cell centers in x-direction
    'yc': np.arange(ny),            # cell centers in y-direction
    'zc': np.arange(nz),            # cell centers in z-direction
    'xx': np.arange(nx + 1) - 0.5,  # cell edges in x-direction (nx + 1)
    'yy': np.arange(ny + 1) - 0.5,  # cell edges in y-direction (ny + 1)
    'zz': np.arange(nz + 1) - 0.5,  # cell edges in z-direction (nz + 1)
    'data': dat
}

# Plot the model
plot_interactive(data)

<IPython.core.display.Javascript object>