In [1]:
# Periodic Table 
import mendeleev as mv 
from mendeleev.fetch import fetch_table 

# basic XRF physics 
from fisx import Elements 

# advanced XRF stuff that we do not need now
#from fisx import Material
#from fisx import Detector
#from fisx import XRF 

# peak finding 
import scipy.signal as ssg
import scipy

# plotting 
import pandas as pd 
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches 
import matplotlib.cm as cm 
import seaborn as sb

In [2]:
def make_ptable():
    '''Build the periodic-table attribute array plus a mask of its regular elements.

    Returns:
        ptable (numpy.ndarray): one row per element of mendeleev's ``elements``
            table, with columns ``atomic_number, symbol, name, group_id, period``.
        is_regular (numpy.ndarray of bool): True for rows that have a group id.

    Rows without a group id are the irregular Lanthanide and Actinide series.
    They are rare, so callers use *is_regular* to leave them out of plots.
    '''

    # older mendeleev (0.5) exposed this as mv.get_table('elements')
    ptable_df = fetch_table('elements')
    ptable = ptable_df[['atomic_number', 'symbol', 'name', 'group_id', 'period']].values

    # notna() handles both float (NaN) and object (None) columns, whereas
    # np.isnan raises TypeError on an object-dtype column
    is_regular = ptable_df['group_id'].notna().values

    return ptable, is_regular

def _colorize(elem_select, crop=True):
    '''High contrast standard color palette for element(s) *elem_select*.

    Every element Z=1..118 owns a fixed RGBA color taken from concatenated
    seaborn palettes, so an element keeps the same color across plots.

    Args:
        elem_select (str or list of str): element symbol(s), e.g. 'Fe' or
            ['Fe', 'Ca']. An empty list selects no element.
        crop (bool): if True (default), return colors for the selected
            elements only. If False, return one color row per element of the
            full table.

    Returns:
        * empty selection, crop=True: the full RGBA palette array.
        * empty selection, crop=False: ``(palette, [])``.
        * single symbol string, crop=True: one RGBA row for that element.
        * list selection, crop=True: ``(colors, indices)``, where row i of
          *colors* is the palette color of ``elem_select[i]``.
        * non-empty selection, crop=False: ``(colors, indices)``, where
          *colors* has 118 rows. Selected elements keep their palette color
          and all others get the translucent DEFAULT_COLOR.
        *indices* are zero-based table positions (atomic number - 1).
    '''

    # translucent light blue-grey for unselected elements
    DEFAULT_COLOR = [0.4, 0.4, 0.9, 0.3]

    # a bare symbol such as 'Fe' yields a single color instead of a pair;
    # isinstance also accepts str subclasses such as numpy.str_
    return_single_color_tuple = isinstance(elem_select, str)
    if return_single_color_tuple:
        elem_select = [elem_select]

    # standard palette for Hydrogen (Z=1) to Oganesson (Z=118): twelve named
    # seaborn palettes concatenated, then cropped to 118 rows.
    # NOTE(review): this assumes the palettes supply >= 118 colors in total
    # (10 per palette in current seaborn); otherwise the palette is shorter.
    palette = []
    for tone in ['pastel', 'bright', 'deep', 'dark'] * 3:
        palette.extend(sb.palettes.color_palette(tone))
    palette = np.array(palette)[0:118]
    # append an opaque (unity) alpha channel -> RGBA rows
    palette = np.c_[palette, np.ones(len(palette))]

    if len(elem_select) == 0:
        # crop=False callers unpack a (colors, indices) pair, so give them one
        return palette if crop else (palette, [])

    # zero-based table position (Z - 1) of each selected element
    ptable_indices = [mv.element(e).atomic_number - 1 for e in elem_select]

    if crop:
        # only the selected elements' colors, in selection order
        colors = palette[ptable_indices]
        if return_single_color_tuple:
            return colors[0]
        return colors, ptable_indices

    # full table: default color everywhere, palette color for the selection
    colors = np.zeros([118, 4])
    colors[:] = DEFAULT_COLOR
    colors[ptable_indices] = palette[ptable_indices]

    return colors, ptable_indices

def _draw_box(ax, element_attrs, edgecolor=None, facecolor=None):
    '''Draw one periodic-table cell for *element_attrs* in subplot *ax*.

    Args:
        ax: matplotlib Axes to draw on.
        element_attrs: sequence ``(atomic_number, symbol, name, x, y)``.
            The box is centered on (x, y) in data coordinates; ptable_plot
            passes group id and period here.
        edgecolor: box edge color, defaults to a light blue.
        facecolor: box fill color, defaults to a very light blue.
    '''

    w = 0.9  # cell width in data units (< 1 leaves a gap between cells)
    h = 0.9  # cell height in data units

    atomic_number, symbol, name, x, y = element_attrs

    if edgecolor is None:
        edgecolor = [0.8, 0.8, 1.0]

    if facecolor is None:
        facecolor = [0.95, 0.95, 1.0]

    # Rectangle takes (xy, width, height); pass them by name so width and
    # height cannot be swapped if w and h ever differ
    rect = mpatches.Rectangle((x - w / 2, y - h / 2), width=w, height=h,
                              edgecolor=edgecolor, facecolor=facecolor)
    ax.add_patch(rect)

    # symbol in the middle, atomic number above it, full name below
    # (offsets are in points relative to the cell center)
    ax.annotate(symbol, [x, y], xytext=[0, -4], textcoords='offset points',
                va='center', ha='center', fontsize=17)

    ax.annotate(str(atomic_number), [x, y], xytext=[0, 12], textcoords='offset points',
                va='center', ha='center', fontsize=10)

    ax.annotate(name, [x, y], xytext=[0, -17], textcoords='offset points',
                va='center', ha='center', fontsize=5)

def ptable_plot(elem_select=None, figname=None):
    '''Create periodic table plot with selected elements colorized.

    Only the regular part of the table is drawn (rows marked regular by
    make_ptable).

    Args:
        elem_select (str or list of str, optional): element symbol(s) drawn in
            their palette color. All other cells get the faded default color.
            If None or empty, every cell is drawn in its own palette color.
        figname (str, optional): if given, the figure is saved to this path.
    '''

    if elem_select is None:
        elem_select = []

    # initialize full periodic table
    ptable, is_regular = make_ptable()

    # colorize full table; for an empty selection _colorize returns the bare
    # palette array, not a (colors, indices) pair, so it must not be unpacked
    if len(elem_select) == 0:
        colors = _colorize([])
    else:
        colors, _ = _colorize(elem_select, crop=False)

    # continue with regular elements only
    ptable = ptable[is_regular]
    colors = colors[is_regular]

    # create figure
    fig, ax = plt.subplots(figsize=[14, 8])

    # rows are (atomic_number, symbol, name, group_id, period), so each box
    # lands at x=group, y=period
    for element_attrs, color in zip(ptable, colors):
        _draw_box(ax, element_attrs, facecolor=color)

    ax.set_xlim(0, 19)
    ax.set_ylim(8, 0)  # reversed limits put period 1 at the top
    ax.axis('off')

    if figname is not None:
        fig.savefig(figname)

def moseley_law(E_K_alpha_keV):
    '''Square root form of Moseley's law, solved for the atomic number.

    Uses E_K_alpha = 10.2 eV * (Z - 1)**2, i.e. Z = 1 + sqrt(E_K_alpha / 10.2 eV).

    Args:
        E_K_alpha_keV (float, array, list or tuple of floats): K_alpha peak
            energy in keV. Lists and tuples are converted to a numpy array.

    Example:
        moseley_law(6.40)  # iron K_alpha -> about 26.05

    Returns:
        Z (float or array of floats): predicted atomic number
    '''
    # a plain list would otherwise be repeated by `1000 * [...]` instead of
    # scaled; arrays (and array-likes such as pandas Series) pass through as-is
    if isinstance(E_K_alpha_keV, (list, tuple)):
        E_K_alpha_keV = np.asarray(E_K_alpha_keV, dtype=float)

    Z = 1 + np.sqrt(1000 * E_K_alpha_keV / 10.2)  # keV -> eV, then invert the law

    return Z


def moseley_plot(tube_keV, elem_select=None, weight_list='equal', law=True, figname=None):
    '''Generate Moseley plot of simulated xrf spectra for the elements Si (Z=14) to Pb (Z=82).

    Args:
        tube_keV (float or list of floats): X-ray tube energy in keV.
            The largest value sets the energy axis range.
        elem_select (str or list of str, optional): symbol(s) of elements.
            All selected elements will be highlighted with bright colors and peak labels.
        weight_list (list of floats): X-ray tube intensities.
            Optional, defaults to 'equal'.
        law (bool): if True (default), overlay the curve of Moseley's law.
        figname (str, optional): if given, the figure is also saved to this path.

    Example:
        moseley_plot(tube_keV=40, elem_select=['Fe', 'Ca']);

    Returns:
        matplotlib figure
    '''

    if elem_select is None:
        elem_select = []
    elif isinstance(elem_select, str):
        # wrap a bare symbol: a substring test such as `'S' in 'Sn'` would
        # otherwise also highlight sulfur
        elem_select = [elem_select]

    # initialize full periodic table; column 1 holds the element symbols
    ptable, _ = make_ptable()
    symbols = ptable[:, 1]

    # colorize full table; for an empty selection _colorize returns the bare
    # palette array, not a (colors, indices) pair
    if len(elem_select) == 0:
        colors = _colorize([])
    else:
        colors, _ = _colorize(elem_select, crop=False)

    # tube_keV may be a list of energies; the highest one sets the axis range
    E_max_keV = np.max(tube_keV)

    fig, ax = plt.subplots(figsize=[12, 8])

    if law:
        E_keV_list = np.linspace(0, E_max_keV, 2000)
        ax.plot(E_keV_list, moseley_law(E_keV_list), zorder=-10, linestyle=':',
                color='k', label="Moseley's law")

    selected = set(elem_select)

    # rows 13..81 are Si (Z=14) .. Pb (Z=82), assuming the table is ordered by
    # atomic number; iterate from heavy to light
    for symbol, color in zip(symbols[13:82][::-1], colors[13:82][::-1]):
        if symbol in selected:
            peak_labels = 'simple'

            # ideal spectrum
            XFluo(symbol, tube_keV, weight_list=weight_list).plot(
                ax=ax, color=color, up=False, mos=True, peak_labels=peak_labels)

        else:
            peak_labels = 'none'

        XFluo(symbol, tube_keV, weight_list=weight_list).plot(
            ax=ax, color=color, up=True, peak_labels=peak_labels)

    ax.set_ylabel('Atomic number')
    ax.set_xticks(range(int(E_max_keV)))
    ax.set_xlim(0, E_max_keV)
    ax.set_ylim(0, 90)
    ax.grid(False)
    ax.set_title(f'Moseley plot (x-ray tube at {E_max_keV} keV)')

    # skip the legend when nothing is labelled (e.g. law=False), which would
    # otherwise trigger matplotlib's "no artists with labels" warning
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    plt.tight_layout()

    if figname is not None:
        print(f'Saving plot as: {figname}...')
        fig.savefig(figname)

    return fig