In [1]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib notebook
from mpl_toolkits.mplot3d import axes3d
from matplotlib import cm

In [2]:
%matplotlib notebook
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')
X,Y,Z = axes3d.get_test_data(0.05)
cset = ax.contour(X,Y,Z, cmap=cm.viridis)


<IPython.core.display.Javascript object>

In [116]:
def get_vertex_pairs(vertices):
    pairs = []
    for i, v1 in enumerate(vertices):
        for j, v2 in enumerate(vertices):
            if j > i:
                pairs.append(v1)
                pairs.append(v2)
    pairs = np.asarray(pairs)
    return pairs

def plot_outline(vertices,ax=None,**kwargs):
    dim = vertices.shape[1]
    pairs = get_vertex_pairs(vertices)
    if ax==None:
        fig = plt.figure()
        if dim==2:
            ax = fig.add_subplot(111)
            p = ax.plot(pairs[:,0],pairs[:,1],**kwargs)
        if dim==3:
            ax = fig.add_subplot(111,projection='3d')
            p = ax.plot(xs=pairs[:,0],ys=pairs[:,1],zs=pairs[:,2],**kwargs)
    else:
        if dim==2:
            p = ax.plot(pairs[:,0],pairs[:,1],**kwargs)
        if dim==3:
            p = ax.plot(xs=pairs[:,0],ys=pairs[:,1],zs=pairs[:,2],**kwargs)
    return p

def make_quat_axes(ax=None,**kwargs):
    if ax==None:
        fig = plt.figure()
        ax = fig.add_subplot(111,projection='3d')
    vertices = np.array([[0,0,0],[1,0,0],[1/2,3**0.5/2,0],[1/2,3**0.5/6,6**0.5/3]])
    plot_outline(vertices,ax=ax,**kwargs)
    return ax
 

def fractions_to_quaternary(points,normalize=True):
    "Convert atomic fractions to quaternary coordinates (3D hexagonal Miller indices)."
    if type(points) == list:
        points = np.array(points)
    
    if normalize==True:
        #normalize all points to 1 (independent of scale since cartesian space is not scaled)
        ptsum = np.sum(points,axis=1)
        points = points/ptsum[:,None]

    xa = points[:,0]
    xb = points[:,1]
    xc = points[:,2]
    xd = points[:,3]
    u = 2*xa + xb + xd - 1
    v = 2*xb + xa + xd - 1
    w = xd
    
    return np.column_stack((u,v,w))

def quaternary_to_cartesian(points):
    "Convert quaternary coordinates to Cartesian coordinates."
    u = points[:,0]
    v = points[:,1]
    w = points[:,2]
    origin = (1/2,3**0.5/6,0)
    dx = -u/2 + v/2 
    dy = -(3**0.5/6)*u - (3**0.5/6)*v 
    dz = (6**0.5/3)*w
    x = origin[0] + dx
    y = origin[1] + dy
    z = origin[2] + dz
    return np.column_stack((x,y,z))

def fractions_to_cartesian(points,normalize=True):
    """Convert atomic fractions directly to Cartesian coordinates. 
    First coordinate corresponds to left corner, second to right, third to back, fourth to c."""
    qp = fractions_to_quaternary(points,normalize=normalize)
    cp = quaternary_to_cartesian(qp)
    return cp

def expand_axeslims(ax,factor):
    xlim = ax.get_xlim()
    xrng = max(xlim) - min(xlim)
    xoffset = xrng*(factor-1)/2
    ax.set_xlim([xlim[0]-xoffset,xlim[1]+xoffset])
    
    ylim = ax.get_ylim()
    yrng = max(ylim) - min(ylim)
    yoffset = yrng*(factor-1)/2
    ax.set_ylim([ylim[0]-yoffset,ylim[1]+yoffset])
    
    zlim = ax.get_zlim()
    zrng = max(zlim) - min(zlim)
    zoffset = zrng*(factor-1)/2
    ax.set_zlim([zlim[0]-zoffset,zlim[1]+zoffset])


    
class quaternary_axes:
    def __init__(self,scale=1,ax=None):
        self.scale = scale
        self.ax = ax
        self.axis_labels = {}
        self.axis_num = {'l':0,'r':1,'b':2,'t':3}
#         self.vertices = np.array([[0,0],[1/2,3**0.5/2],[1,0]])#*self.scale
#         self.corner_coords=dict(zip(['left','top','right'], list(self.vertices)))

        
    def draw_axes(self,**kwargs):
        if self.ax==None:
            self.ax = make_quat_axes(**kwargs)
        else:
            self.ax = make_quat_axes(ax=self.ax,**kwargs)

    
    def plot(self,points,normalize=True,**kwargs): 
        #"not updated"
        cp = fractions_to_cartesian(points,normalize=normalize)
        x = cp[:,0]
        y = cp[:,1]
        z = cp[:,2]
        p = self.ax.plot(x,y,z,**kwargs)
        return p
    
    def axis_ticks(self,axis,tick_format='.1f',multiple=0.2,offset=0.05,shrink=0.6):
        if axis not in ('l','r','b','t'):
            try:
                axis = self.axis_labels[axis]
                if axis not in ('l','r','b','t'):
                    raise Exception('Lookup on axis resulted in unexpected key: {}'.format(axis))
            except KeyError:
                raise Exception('{} not found in axis_labels'.format(axis))
        
        "not updated"
        x1 = np.arange(multiple,self.scale,multiple) #ascending order
        x2 = self.scale - x1 #descending order
        zeros = np.zeros(len(x1))
        if axis=='l': #xa - ticks between C and A, keep xb=xd=0
            pts = np.column_stack((x1,zeros,x2,zeros))
        elif axis=='r': #xb - ticks between A and B, keep xc=xd=0
            pts = np.column_stack((x2,x1,zeros,zeros))
        elif axis=='b': #xc - ticks between B anc C, keep xa=xd=0
            pts = np.column_stack((zeros,x2,x1,zeros))
        elif axis=='t': #xd - ticks between A and D, keep xb=xc=0
            pts = np.column_stack((x2,zeros,zeros,x1))
        
        ps = np.sum(pts,axis=0)
        nonzero_idx = list(np.where(ps!=0)[0])
        zero_idx = list(np.where(ps==0)[0])
        
        txt_pts = np.copy(pts)
        dash_end = np.copy(pts)
        
        #offset labels
        txt_pts[:,nonzero_idx] = txt_pts[:,nonzero_idx] + offset
        dash_end[:,nonzero_idx] = dash_end[:,nonzero_idx] + offset*(1-shrink)
        txt_pts[:,zero_idx] = txt_pts[:,zero_idx] - offset
        dash_end[:,zero_idx] = dash_end[:,zero_idx] - offset*(1-shrink)
        
        
        
        
#         if axis=='l': #xa - ticks between A and C, keep xb=xd=0
#             pts = np.column_stack((x1,zeros,x2,zeros))
#             txt_pts = np.copy(pts)
#             dash_end = np.copy(pts)
#             #text locations
#             txt_pts[:,[0,2]] = txt_pts[:,[0,2]] + offset #offset in A-C direction
#             txt_pts[:,[1,3]] = txt_pts[:,[1,3]] - offset
#             #dash endpoints
#             dash_end[:,[0,2]] = dash_end[:,[0,2]] + offset*(1-shrink) 
#             dash_end[:,[1,3]] = dash_end[:,[1,3]] - offset*(1-shrink)
            
#             print (txt_pts)
#             print(dash_end)
#         elif axis=='r': #xb - ticks between A and B, keep xc=xd=0
#             pts = np.column_stack((x2,x1,zeros,zeros))
#             txt_pts = np.copy(pts)
#             dash_end = np.copy(pts)
#             txt_pts[:,[0,1]] = txt_pts[:,[0,1]] + offset #offset in A-B direction
#             txt_pts[:,[2,3]] = txt_pts[:,[2,3]] - offset
#             #dash endpoints
#             dash_end[:,[0,1]] = dash_end[:,[0,1]] + offset*(1-shrink) 
#             dash_end[:,[2,3]] = dash_end[:,[2,3]] - offset*(1-shrink)
#             print (txt_pts)
#             print(dash_end)
#         elif axis=='b': #xc
#             pts = np.column_stack((x2,x1,np.zeros(len(x1))))
#             txt_pts = np.copy(pts)
#             txt_pts[:,2] = txt_pts[:,2] - offset #offset in negative c direction
        labels = x1
        for p,tp,de, l in zip(pts,txt_pts,dash_end,labels):
            cp = fractions_to_cartesian([p])[0]
            cpt = fractions_to_cartesian([tp])[0]
            cde = fractions_to_cartesian([de])[0]
            self.ax.text(x=cpt[0],y=cpt[1],z=cpt[2],
                         s=('{:>{fmt}}'.format(l,fmt=tick_format)),ha='center',va='center')
                            #xy=cp, xytext=cpt)#, arrowprops=dict(facecolor='black',arrowstyle='-'))
            #print(p,tp)
            #print(np.stack((p,de)))
            self.plot(np.stack((p,de)), color='k') #plot dashes
    

    
    def gridlines(self,gridspace=0.2,add_ticks=True,**kwargs):
        "not updated"
        pos = np.arange(0,self.scale,gridspace)
        points = []
        for x1 in pos:
            x2 = self.scale - x1
            pts1 = [[x1,x2,0],[x1,0,x2]]
            pts2 = [[x2,x1,0],[0,x1,x2]]
            pts3 = [[x2,0,x1],[0,x2,x1]]
            points = [pts1,pts2,pts3]
            for p in points:
                self.plot(p,**kwargs)
#                 self.plot(pts1,**kwargs)
#                 self.plot(pts2,**kwargs)
#                 self.plot(pts3,**kwargs)
#                 if add_ticks==True:
#                     cp = frac_to_cart(p)
#                     self.ax.annotate()
        
    def label_corner(self,corner,label,offset=0.05,**kwargs):
        """
        Label a single corner of the triange.
        
        corner:    corner to label (left, right, back, or top)
        label:  label text
        offset: distance from corner in fractional coordinates. Default 0.05
        kwargs: kwargs to pass to pyploy annotate
        """
        if corner in ['left','l']:
            frac = np.array([[1+offset,-offset/3,-offset/3,-offset/3]])
        elif corner in ['right','r']:
            frac = np.array([[-offset/3,1+offset,-offset/3,-offset/3]])
        elif corner in ['back','b']:
            frac = np.array([[-offset/3,-offset/3,1+offset,-offset/3]])
        elif corner in ['top','t']:
            frac = np.array([[-offset/3,-offset/3,-offset/3,1+offset]])
        else:
            raise Exception('Invalid corner argument: {}. Choose left, right, back, or top (or short name l, r, b, or t)'.format(corner))
        x,y,z = fractions_to_cartesian(frac,normalize=False)[0]
        self.ax.text(x,y,z,s=label,**kwargs)
        #update axis_labels dict
        self.axis_labels[label] = corner
    
    def label_all_corners(self,labels,offset=0.05,order='lrbt',**kwargs):
        """
        Label all corners of the triangle.
        
        labels: list of label strings
        offset: distance from corner in fractional coordinates. Default 0.05
        order: order in which labels are given. Default is lrbt (left, right, back, top)
        kwargs: kwargs to pass to pyplot text
        """
        #corner_dict = {'l':'left','r':'right','b':'back','t':'top'}
        for i, s in enumerate(labels):
            corner = order[i]
            #corner = corner_dict[c]
            self.label_corner(corner,s,offset=offset,**kwargs)
            
   
        
    
    #def draw_gridlines(self, gridspace=0.2,**kwargs):

    
    def draw_normslice(self,slice_axis,slice_start,**kwargs):
        if slice_axis not in ('l','r','b','t'):
            try:
                slice_axis = self.axis_labels[slice_axis]
            except KeyError:
                raise Exception('{} not found in axis_labels'.format(slice_axis))
        axis_num = self.axis_num[slice_axis]
        
        #create array of points to define plane (in fractional coords)
        plane_pts = np.zeros([4,4])
        plane_pts[:,axis_num] = slice_start
        for r in range(len(plane_pts)):
            plane_pts[r,r] = 1 - slice_start
        plane_pts = np.delete(plane_pts,axis_num,axis=0)
        
        #convert to cartesian
        cpts = fractions_to_cartesian(plane_pts)
        x = cpts[:,0]
        y = cpts[:,1]
        z = cpts[:,2]
        
        #print(plane_pts)
        
        self.ax.plot_trisurf(x,y,z,**kwargs)
        
        
            
        

In [117]:
fig = plt.figure(figsize=[8,6])
ax = fig.add_subplot(111,projection='3d')
qf = quaternary_axes(ax=ax)
qf.draw_axes(color='k')
#expand_axeslims(qf.ax,1.1)
qf.label_all_corners(labels=['Zr','Co','Fe','Y'],offset=0.15,ha='center',va='center',size=16)



#qf.ax.plot_trisurf(x,y,z,antialiased=True,cmap=plt.cm.viridis,alpha=0.5)
qf.ax.axis('off')

qf.draw_normslice(slice_axis='Y',slice_start=0.,alpha=0.8)
qf.draw_normslice(slice_axis='Y',slice_start=0.1,alpha=0.8)
qf.draw_normslice(slice_axis='Y',slice_start=0.2,alpha=0.8)
qf.axis_ticks('l',shrink=0.6)
qf.axis_ticks('r')
qf.axis_ticks('b')
qf.axis_ticks('t')

<IPython.core.display.Javascript object>

In [111]:
x1 = np.arange(0.2,1,0.2) #ascending order
x2 = 1 - x1 #descending order
zeros = np.zeros(len(x1))
pts = np.column_stack((x1,zeros,x2,zeros))
t = np.sum(pts,axis=0)
list(np.where(t==0)[0])

[1, 3]

In [219]:
qf.axis_labels

{'b': 'Fe', 'l': 'Zr', 'r': 'Co', 't': 'Y'}

In [154]:
import matplotlib.tri as mtri

In [66]:
fig = plt.figure()
ax = fig.add_subplot(111,projection='3d')

ax.scatter(0,1,1)

ax.text(x=0.02,y=1.02,z=1.02,s='test',withdash=True)
twd.set_position((0,1,1))
twd.draw(fig.renderer)

<IPython.core.display.Javascript object>

AttributeError: 'Figure' object has no attribute 'renderer'

In [189]:
X


array([[0, 1, 1],
       [0, 1, 1],
       [0, 1, 1]])

In [190]:
Y

array([[0, 0, 0],
       [1, 1, 1],
       [0, 0, 0]])

In [192]:
tri.neighbors

array([[-1, -1,  1],
       [-1, -1,  0]], dtype=int32)

In [34]:
from mpl_toolkits.mplot3d.proj3d import proj_transform
from matplotlib.text import Annotation

class Annotation3D(Annotation):
    '''Annotate the point xyz with text s'''

    def __init__(self, s, xyz, *args, **kwargs):
        Annotation.__init__(self,s, xy=(0,0), *args, **kwargs)
        self._verts3d = xyz        

    def draw(self, renderer):
        xs3d, ys3d, zs3d = self._verts3d
        xs, ys, zs = proj_transform(xs3d, ys3d, zs3d, renderer.M)
        self.xy=(xs,ys)
        Annotation.draw(self, renderer)
        
def annotate3D(ax, s, *args, **kwargs):
    '''add anotation text s to to Axes3d ax'''

    tag = Annotation3D(s, *args, **kwargs)
    ax.add_artist(tag)

In [46]:
import matplotlib.pyplot as plt    
from mpl_toolkits.mplot3d import axes3d
from mpl_toolkits.mplot3d.art3d import Line3DCollection

# data: coordinates of nodes and links
xn = [1.1, 1.9, 0.1, 0.3, 1.6, 0.8, 2.3, 1.2, 1.7, 1.0, -0.7, 0.1, 0.1, -0.9, 0.1, -0.1, 2.1, 2.7, 2.6, 2.0]
yn = [-1.2, -2.0, -1.2, -0.7, -0.4, -2.2, -1.0, -1.3, -1.5, -2.1, -0.7, -0.3, 0.7, -0.0, -0.3, 0.7, 0.7, 0.3, 0.8, 1.2]
zn = [-1.6, -1.5, -1.3, -2.0, -2.4, -2.1, -1.8, -2.8, -0.5, -0.8, -0.4, -1.1, -1.8, -1.5, 0.1, -0.6, 0.2, -0.1, -0.8, -0.4]
group = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 2, 2, 3, 3, 3, 3]
edges = [(1, 0), (2, 0), (3, 0), (3, 2), (4, 0), (5, 0), (6, 0), (7, 0), (8, 0), (9, 0), (11, 10), (11, 3), (11, 2), (11, 0), (12, 11), (13, 11), (14, 11), (15, 11), (17, 16), (18, 16), (18, 17), (19, 16), (19, 17), (19, 18)]
xyzn = np.column_stack((xn,yn,zn))#zip(xn, yn, zn)
segments = [(xyzn[s], xyzn[t]) for s, t in edges]                

# create figure        
fig = plt.figure(dpi=60)
ax = fig.gca(projection='3d')
#ax.set_axis_off()

# plot vertices
ax.scatter(xn,yn,zn, marker='o', c = group, s = 64)    
# plot edges
edge_col = Line3DCollection(segments, lw=0.2)
ax.add_collection3d(edge_col)
# add vertices annotation.
for j, xyz_ in enumerate(xyzn): 
    annotate3D(ax, s=str(j), xyz=xyz_, fontsize=10, xytext=(-10,0),
               textcoords='offset points', ha='right',va='bottom')    
plt.show()

<IPython.core.display.Javascript object>

In [39]:
np.column_stack((xn,yn,zn))

array([[ 1.1, -1.2, -1.6],
       [ 1.9, -2. , -1.5],
       [ 0.1, -1.2, -1.3],
       [ 0.3, -0.7, -2. ],
       [ 1.6, -0.4, -2.4],
       [ 0.8, -2.2, -2.1],
       [ 2.3, -1. , -1.8],
       [ 1.2, -1.3, -2.8],
       [ 1.7, -1.5, -0.5],
       [ 1. , -2.1, -0.8],
       [-0.7, -0.7, -0.4],
       [ 0.1, -0.3, -1.1],
       [ 0.1,  0.7, -1.8],
       [-0.9, -0. , -1.5],
       [ 0.1, -0.3,  0.1],
       [-0.1,  0.7, -0.6],
       [ 2.1,  0.7,  0.2],
       [ 2.7,  0.3, -0.1],
       [ 2.6,  0.8, -0.8],
       [ 2. ,  1.2, -0.4]])