Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Viz 2d #21

Merged
merged 4 commits into from Jul 5, 2012
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions dipy/viz/__init__.py
@@ -1,3 +1,4 @@
# Init file for visualization package

from ._show_odfs import show_odfs
from projections import *
110 changes: 110 additions & 0 deletions dipy/viz/projections.py
@@ -0,0 +1,110 @@
"""

Visualization tools for 2D projections of 3D functions on the sphere, such as
ODFs.

"""

import numpy as np
import scipy.interpolate as interp

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.tri as tri

import dipy.core.geometry as geo

def sph_project(vertices, val, ax=None, vmin=None, vmax=None,
cmap=matplotlib.cm.hot, cbar=True, triang=False):

"""Draw a signal on a 2D projection of the sphere.

Parameters
----------

vertices : (N,3) ndarray
unit vector points of the sphere

val: (N) ndarray
Function values.

ax : mpl axis, optional
If specified, draw onto this existing axis instead.

vmin, vmax: floats
Values to cut the z

cmap: mpl colormap

cbar: Whether to add the color-bar to the figure

triang: Whether to display the plot triangulated as a pseudo-color plot.

Returns
-------
fig : figure
Matplotlib figure

Examples
--------
>>> from dipy.data import get_sphere
>>> verts,faces=get_sphere('symmetric724')
>>> sph_project(verts,np.random.rand(len(verts)))

"""
if ax is None:
fig, ax = plt.subplots(1)

x = vertices[:, 0]
y = vertices[:, 1]

my_min = np.nanmin(val)
if vmin is not None:
my_min = vmin

my_max = np.nanmax(val)
if vmax is not None:
my_max = vmax

r = (val - my_min)/float(my_max-my_min)

# Enforce the maximum and minumum boundaries, if there are values
# outside those boundaries:
r[r<0]=0
r[r>1]=1

if triang:
triang = tri.Triangulation(x, y)
plt.tripcolor(triang, r, cmap=cmap)
else:
cmap_data = cmap._segmentdata
red_interp, blue_interp, green_interp = (
interp.interp1d(np.array(cmap_data[gun])[:,0],
np.array(cmap_data[gun])[:,1]) for gun in
['red', 'blue','green'])


for this_x, this_y, this_r in zip(x,y,r):
red = red_interp(this_r)
blue = blue_interp(this_r)
green = green_interp(this_r)
ax.plot(this_x, this_y, 'o',
c=[red.item(), green.item(), blue.item()])


plt.axis('equal')
plt.axis('off')
if cbar:
mappable = matplotlib.cm.ScalarMappable(cmap=cmap)
mappable.set_array([my_min, my_max])
# setup colorbar axes instance.
pos = ax.get_position()
l, b, w, h = pos.bounds
# setup colorbar axes
cax = fig.add_axes([l+w+0.075, b, 0.05, h], frameon=False)
fig.colorbar(mappable, cax=cax) # draw colorbar

ax.set_xlim([-1.1, 1.1])
ax.set_ylim([-1.1, 1.1])

return fig