# In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

def draw_BZ_border(ax, BZ_array, order_dir=None, *, lim=-0.4, **kwargs):
    """
    Plot the border of the first Brillouin zone as a 3D polyline.

    Parameters
    ----------
    ax : matplotlib 3D axes
        Any object with an ``ax.plot(xs, ys, zs, **kw)`` method.
    BZ_array : array of shape (N, 3)
        Consecutive vertices of the border path; segment ``i`` joins rows
        ``i`` and ``i + 1``.
    order_dir : None or array-like of length 2 or 3, optional
        If None, the whole path is drawn with a single ``ax.plot`` call.
        Otherwise it defines a viewing direction used to split the segments
        into "front" (solid line, zorder 10) and "back" (dashed line,
        zorder 0). A length-3 value is used as a direction vector (it is
        normalised); a length-2 value is read as ``(azim, elev)`` in degrees.
    lim : float, keyword-only
        A segment is drawn in front when both of its endpoints project onto
        the viewing direction with a value ``>= lim``.
    **kwargs
        Forwarded unchanged to every ``ax.plot`` call.

    Raises
    ------
    ValueError
        If ``order_dir`` does not have 2 or 3 components, or is a zero vector.
    """
    if order_dir is None:
        ax.plot(BZ_array[:, 0], BZ_array[:, 1], BZ_array[:, 2], **kwargs)
        return

    # The direction is loop-invariant: compute and validate it once.
    dirvec = _bz_view_direction(order_dir)

    front_zorder, front_ls = 10, '-'
    back_zorder, back_ls = 0, '--'

    BZ_array = np.asarray(BZ_array)
    # Projection of every vertex onto the viewing direction ("depth" toward the viewer).
    depth = BZ_array @ dirvec

    # NOTE(review): in recent matplotlib versions Axes3D sorts artists itself
    # while ``ax.computed_zorder`` is True (the default), so the manual zorder
    # below only takes effect with ``ax.computed_zorder = False`` — confirm for
    # the matplotlib version in use. The line style (solid/dashed) always applies.
    for r1, r2, p1, p2 in zip(BZ_array[:-1], BZ_array[1:], depth[:-1], depth[1:]):
        if p1 >= lim and p2 >= lim:
            zorder, ls = front_zorder, front_ls
        else:
            zorder, ls = back_zorder, back_ls
        ax.plot([r1[0], r2[0]], [r1[1], r2[1]], [r1[2], r2[2]],
                zorder=zorder, ls=ls, **kwargs)


def _bz_view_direction(order_dir):
    """Return a unit 3-vector for ``order_dir`` (a 3-vector, or ``(azim, elev)`` in degrees)."""
    order_dir = np.asarray(order_dir, dtype=float)
    if order_dir.shape == (3,):
        norm = np.linalg.norm(order_dir)
        if norm == 0:
            raise ValueError('draw_BZ_border: Wrong order_dir')
        return order_dir / norm
    if order_dir.shape == (2,):
        azim, elev = np.deg2rad(order_dir)
        # Same spherical-to-Cartesian convention as the original code (r = 1).
        return np.array([np.cos(elev) * np.cos(azim),
                         np.cos(elev) * np.sin(azim),
                         np.sin(elev)])
    raise ValueError('draw_BZ_border: Wrong order_dir')



def plot_border(fig, ax, kpts, bg, *, azim=45, elev=19):
    """
    Draw the first-Brillouin-zone border on a 3D axes and set the viewpoint.

    Parameters
    ----------
    fig : matplotlib Figure
        Not modified; returned unchanged so callers can reassign ``fig, ax``.
    ax : matplotlib 3D axes
        Axes to draw on; its axis decorations are switched off.
    kpts : array
        Unused; kept so existing calls ``plot_border(fig, ax, kpts, bg)`` still work.
    bg : array of shape (3, 3)
        Matrix that maps the hard-coded fractional vertex coordinates below to
        plot coordinates via ``fractional @ bg``. NOTE(review): presumably its
        rows are the reciprocal lattice vectors b1, b2, b3 — confirm against
        the file that ``bg`` is loaded from.
    azim, elev : float, keyword-only
        View angles in degrees passed to ``ax.view_init``. They are also used
        to decide which border segments are drawn solid (front) or dashed (back).

    Returns
    -------
    (fig, ax)
    """
    # Border path as consecutive vertices in fractional coordinates of ``bg``.
    # Some edges are traversed twice so that the closed border can be drawn as
    # one continuous polyline.
    # NOTE(review): the shape looks like the truncated-octahedron BZ of an fcc
    # lattice — confirm before reusing this table for another crystal.
    bz_fractional = np.array([[ 0.25, -0.25, -0.5 ],
          [-0.25, -0.5 , -0.75],
          [-0.5 , -0.25, -0.75],
          [-0.25,  0.25, -0.5 ],
          [ 0.25,  0.5 , -0.25],
          [ 0.5 ,  0.25, -0.25],
          [ 0.25, -0.25, -0.5 ],
          [ 0.25, -0.5 , -0.25],
          [-0.25, -0.75, -0.5 ],
          [-0.25, -0.5 , -0.75],
          [-0.25, -0.75, -0.5 ],
          [-0.5 , -0.75, -0.25],
          [-0.25, -0.5 ,  0.25],
          [ 0.25, -0.25,  0.5 ],
          [ 0.5 , -0.25,  0.25],
          [ 0.25, -0.5 , -0.25],
          [ 0.5 , -0.25,  0.25],
          [ 0.75,  0.25,  0.5 ],
          [ 0.75,  0.5 ,  0.25],
          [ 0.5 ,  0.25, -0.25],
          [ 0.75,  0.5 ,  0.25],
          [ 0.5 ,  0.75,  0.25],
          [ 0.25,  0.5 , -0.25],
          [ 0.5 ,  0.75,  0.25],
          [ 0.25,  0.75,  0.5 ],
          [-0.25,  0.5 ,  0.25],
          [-0.5 ,  0.25, -0.25],
          [-0.25,  0.25, -0.5 ],
          [-0.5 ,  0.25, -0.25],
          [-0.75, -0.25, -0.5 ],
          [-0.5 , -0.25, -0.75],
          [-0.75, -0.25, -0.5 ],
          [-0.75, -0.5 , -0.25],
          [-0.5 , -0.75, -0.25],
          [-0.25, -0.5 ,  0.25],
          [-0.5 , -0.25,  0.25],
          [-0.75, -0.5 , -0.25],
          [-0.5 , -0.25,  0.25],
          [-0.25,  0.25,  0.5 ],
          [ 0.25,  0.5 ,  0.75],
          [ 0.5 ,  0.25,  0.75],
          [ 0.25, -0.25,  0.5 ],
          [ 0.5 , -0.25,  0.25],
          [ 0.75,  0.25,  0.5 ],
          [ 0.5 ,  0.25,  0.75],
          [ 0.25,  0.5 ,  0.75],
          [ 0.25,  0.75,  0.5 ],
          [-0.25,  0.5 ,  0.25],
          [-0.25,  0.25,  0.5 ]])

    bz_cartesian = np.matmul(bz_fractional, bg)

    ax.set_axis_off()
    ax.view_init(azim=azim, elev=elev)
    # Reuse the same (azim, elev) so the front/back split matches the camera.
    draw_BZ_border(ax, bz_cartesian, order_dir=np.array([azim, elev]),
                   c='grey', alpha=0.8, lw=1)

    return fig, ax



### load data
# k-point coordinates: columns 0..2 are used as x, y, z below.
# ndmin=2 keeps a 2D table even if the file holds a single k-point.
kpts = np.loadtxt('kpts.dat', ndmin=2)
bg = np.loadtxt('bg.dat')

# Electron populations at three snapshot times: each file is one row per
# k-point (matching kpts) and one column per band. ndmin=2 keeps the band axis
# even for a single-band file, so the [np.newaxis, :, :] stacking below works.
epop0 = np.loadtxt('electron_pop_0ps.dat', ndmin=2)[np.newaxis, :, :]
epop1 = np.loadtxt('electron_pop_1ps.dat', ndmin=2)[np.newaxis, :, :]
epop2 = np.loadtxt('electron_pop_20ps.dat', ndmin=2)[np.newaxis, :, :]
epop = np.concatenate((epop0, epop1, epop2))  # (snapshot, k-point, band)

# plotting electron population

colors = ['orangered', 'ivory', 'blue']
alphas = [0.8, 0.3, 0.05]  # in the manuscript plot, this list is [0.8, 0.3, 0.01]
if not (len(colors) == len(alphas) == epop.shape[0]):
    raise ValueError('need one color and one alpha per population snapshot')

# NOTE(review): cnorms_factor is not used anywhere below. The comment suggests
# it was meant to scale the marker sizes (s=...); confirm before wiring it in.
cnorms_factor = 200  # for the manuscript figure, this number is 200

fig = plt.figure(figsize=(5, 4))  # in the original plot, this number is (10,8)
ax = fig.add_subplot(projection='3d')
fig, ax = plot_border(fig, ax, kpts, bg)  # also switches the axis decorations off

n_bands = epop.shape[2]
for it, (color, alpha) in enumerate(zip(colors, alphas)):
    # shift the dots by a bit so that they don't lie on top of each other
    shift = 0.001 * it
    x, y, z = (kpts[:, :3] + shift).T
    for ibnd in range(n_bands):
        # marker area is the raw population of this band at each k-point
        ax.scatter(x, y, z, s=epop[it, :, ibnd], marker='o',
                   c=color, alpha=alpha, lw=0)

ax.set_xlim([-1, 1])
ax.set_ylim([-1, 1])
ax.set_zlim([-1, 1])


# Out: (-1.0, 1.0)