# Import packages

In [None]:
%matplotlib inline

from tiled.client import from_uri
from tiled.client.cache import Cache
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint
from IPython.display import display
from larch.xafs import pre_edge, find_e0, preedge
import palettable as pltt
from collections import defaultdict
from lmfit import Model

import logging
# Configure the root logger (getLogger() with no name) so that only CRITICAL
# messages are emitted; this keeps library log output out of the notebook.
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

# Constants

In [None]:
# Connect to the Tiled server as user "dallan" with an in-memory client cache
# (Cache.in_memory(1e9)).
# NOTE(review): verify=False disables TLS certificate verification — presumably
# tolerated only because the server runs on localhost; confirm before pointing
# this at any other host.
client = from_uri("https://localhost:8008", verify=False, username="dallan", cache=Cache.in_memory(1e9))
print(client)
# Catalog key of the beamline whose data is analysed below.
BEAMLINE = 'bmm'
# Bare expression: when run as a notebook cell this displays the beamline node.
client[BEAMLINE]

# Functions

### Merge scans

In [None]:
def merge(counts_list):
    """Return the element-wise average of the scans in ``counts_list``.

    Parameters
    ----------
    counts_list : sequence
        Scans to average. Each entry may be a number or a NumPy array; arrays
        must be broadcastable against each other (normally: same length).

    Returns
    -------
    The sum of all entries divided by ``len(counts_list)``.

    Raises
    ------
    ZeroDivisionError
        If ``counts_list`` is empty.
    """
    total_counts = 0
    for counts in counts_list:
        # Out-of-place addition on purpose: an in-place ``+=`` on an integer
        # array accumulator fails with a casting error as soon as a
        # float-valued scan is added.
        total_counts = total_counts + counts
    return total_counts / len(counts_list)

### Plot and return data dictionary

In [None]:
def _scan_signal(scan_id, element, mode, number_of_detector=4):
    """Return the raw signal of a single scan as a NumPy array.

    ``mode`` selects how the signal is built from the scan's data columns:

    * ``'reference'``    -> ``ln(It / Ir)``
    * ``'transmission'`` -> ``ln(I0 / It)``
    * anything else      -> fluorescence: the sum over detectors
      ``1..number_of_detector`` of ``<element><n> / I0``
    """
    data = client[BEAMLINE]['raw'][scan_id]['primary']['data']

    if mode == 'reference':
        it = np.array(data['It'])
        ir = np.array(data['Ir'])
        return np.log(it / ir)

    if mode == 'transmission':
        i0 = np.array(data['I0'])
        it = np.array(data['It'])
        return np.log(i0 / it)

    # Fluorescence: I0 is the same for every detector, so fetch it once.
    i0 = np.array(data['I0'])
    fluorescence_total_counts = 0
    for index in range(1, 1 + number_of_detector):
        fluorescence_detector = f'{element}{index}'
        fluorescence_total_counts = fluorescence_total_counts + np.array(data[fluorescence_detector]) / i0
    return fluorescence_total_counts


def plot(element, data_dictionary, reference_list=None, transmission_list=None, offset=0, plot_range=None, norm=True):
    """Merge the scans of each sample, plot the spectra and return the merged data.

    Parameters
    ----------
    element : str
        Prefix of the fluorescence detector columns (``f'{element}1'`` ...
        ``f'{element}4'``).
    data_dictionary : dict
        Maps a sample name to the list of scan IDs to merge for that sample.
    reference_list : sequence of str, optional
        Sample names whose signal is computed as ``ln(It / Ir)``.
    transmission_list : sequence of str, optional
        Sample names whose signal is computed as ``ln(I0 / It)``.
        Samples in neither list are treated as fluorescence measurements.
    offset : float
        Vertical shift applied between consecutive curves.
    plot_range : sequence of four numbers, optional
        ``[xmin, xmax, ymin, ymax]``. ``None`` or ``[0, 0, 0, 0]`` keeps the
        automatic axis limits.
    norm : bool
        If true, plot the ``'norm'`` entry returned by ``preedge``; otherwise
        plot the merged signal as is.

    Returns
    -------
    collections.defaultdict
        ``{sample_name: [energy, merged_signal]}``. The merged signal is the
        un-normalized average, regardless of ``norm``.

    Notes
    -----
    The energy axis is taken from the first scan of each sample; all scans of
    a sample are assumed to share that energy grid (``merge`` averages them
    element-wise).
    """
    reference_list = [] if reference_list is None else reference_list
    transmission_list = [] if transmission_list is None else transmission_list

    fig, ax = plt.subplots(figsize=(6, 7.5))
    palette = pltt.colorbrewer.sequential.YlOrRd_3     # alternatives: .colorbrewer.diverging.Spectral_4_r, .colorbrewer.sequential.Greens_9
    cmap = palette.mpl_colormap
    color_idx = np.linspace(0, 1, len(data_dictionary))
    output_dictionary = defaultdict(list)

    for plot_increment, (sample_name, scan_id_list) in enumerate(data_dictionary.items()):
        print(f'Processing {sample_name}...')

        if sample_name in reference_list:
            mode, linestyle = 'reference', '--'
        elif sample_name in transmission_list:
            mode, linestyle = 'transmission', '--'
        else:
            mode, linestyle = 'fluorescence', '-'

        # Every scan contributes to the merge (previously only the last scan
        # of a reference/transmission sample was kept).
        merge_list = [_scan_signal(scan_id, element, mode) for scan_id in scan_id_list]
        merge_counts = merge(merge_list)

        x = np.array(client[BEAMLINE]['raw'][scan_id_list[0]]['primary']['data']['dcm_energy'])
        y = merge_counts

        output_dictionary[sample_name] = [x, y]

        # Only run the pre-edge normalization when its result is actually plotted.
        if norm:
            curve = preedge(x, y)['norm']
        else:
            curve = y
        ax.plot(x, curve + offset * plot_increment, label=f'{sample_name}', linestyle=linestyle, color=cmap(color_idx[plot_increment]))

    ax.set_xlabel(r'$\mathregular{Energy\ (eV)}$')
    ax.set_ylabel(r'$\mathregular{\chi\mu(E)}$')

    # [0, 0, 0, 0] is the legacy "no explicit range" sentinel.
    if plot_range is not None and list(plot_range) != [0, 0, 0, 0]:
        ax.set_xlim(plot_range[0], plot_range[1])
        ax.set_ylim(plot_range[2], plot_range[3])

    ax.legend(loc='lower right')
    plt.tight_layout()
    plt.show()

    return output_dictionary

# Data plot

### Basic plot

Set the element whose K-edge you measured (e.g., Cu, Nb, Sc).

Build a data dictionary containing sample name as key and scan IDs as value. The plot function will merge all the scans in the same sample name.

If some of your samples are references or were measured in transmission mode, list their names in `reference_list` or `transmission_list`, respectively.

Because the plot function returns the merged energy and counts of every sample, assign its return value to a new variable.

Let's do the simple plotting!

In [None]:
# Detector-column prefix passed to plot() for the fluorescence channels.
element = 'Nb'
# Sample name -> scan IDs that plot() merges into one spectrum.
data_dictionary = {'Pristine':[21695, 21696, 21697],
                   '400C60M':[21747, 21748, 21749], 
                   '500C60M':[21760, 21761, 21762],
                   '600C60M':[21773, 21774, 21775],
                   '700C60M':[21786, 21787, 21788],
                   'Pure Nb':[21555, 21556]}
# 'Pure Nb' is handled as a reference sample; norm=False plots the merged,
# un-normalized signal. The return value holds {sample_name: [energy, counts]}.
Nb_dictionary = plot(element, data_dictionary, reference_list=['Pure Nb'], norm=False)

### Data normalization

Well... your data did not align. You couldn't even see the pure element data! It's hard to compare the different data sets this way.

Let's normalize the data!

To focus on the region we're interested in, we zoom in on the white-line region by setting the plot range, and we shift the curves apart by adding an offset.

In [None]:
# Axis limits in the order [xmin, xmax, ymin, ymax] expected by plot().
plot_range = [18900, 19150, 0, 1.9]
# Same Nb data as before, now normalized and with each curve shifted by 0.15.
Nb_dictionary = plot(element, data_dictionary, reference_list=['Pure Nb'], plot_range=plot_range, offset=0.15, norm=True)

Great! You made a better plot!

Let's try another case: Sc.

In [None]:
# Detector-column prefix passed to plot() for the fluorescence channels.
element = 'Sc'
# Sample name -> scan IDs that plot() merges into one spectrum.
data_dictionary = {'Pure Sc':[36495, 36502],
                   'Pristine':[21589, 21590, 21591],
                   '400C60M':[21289, 21290, 21291], 
                   '500C60M':[21654, 21655, 21656],
                   '600C60M':[21667, 21668, 21669],
                   '700C60M':[21680, 21681, 21682],
                   'Sc2O3':[36508, 36509]}
# Axis limits in the order [xmin, xmax, ymin, ymax].
plot_range = [4425, 4625, 0, 2.5]
# 'Pure Sc' and 'Sc2O3' are handled as transmission samples; every other
# sample is handled as fluorescence. The return value is reused by the
# fitting cells below.
Sc_dictionary = plot(element, data_dictionary, transmission_list=['Pure Sc', 'Sc2O3'], offset=0.1, plot_range=plot_range, norm=True)

# Linear combination fitting

The dictionary returned by `plot` has the form {sample_name: [energy, counts]}.

In [None]:
# Number of energy points in the merged 'Pristine' spectrum (index 0 is the energy axis).
len(Sc_dictionary['Pristine'][0])

In [None]:
# Report how many energy points each spectrum has: the unknown sample first,
# then the two reference phases.
for _caption, _sample in (
    ('Mystery phase length:', 'Pristine'),
    ('Phase 1 length:', 'Pure Sc'),
    ('Phase 2 length:', 'Sc2O3'),
):
    print(_caption, len(Sc_dictionary[_sample][0]))

def phase1(energy):
    """Evaluate the merged 'Pure Sc' spectrum at ``energy`` by linear interpolation."""
    reference_energy, reference_counts = Sc_dictionary['Pure Sc']
    return np.interp(energy, reference_energy, reference_counts)
   
def phase2(energy):
    """Evaluate the merged 'Sc2O3' spectrum at ``energy`` by linear interpolation."""
    reference_energy, reference_counts = Sc_dictionary['Sc2O3']
    return np.interp(energy, reference_energy, reference_counts)

   
def f(energy, fraction):
    """Two-phase linear combination evaluated at ``energy``.

    ``fraction`` weights phase 1 and ``1 - fraction`` weights phase 2.
    """
    first_phase = phase1(energy)
    second_phase = phase2(energy)
    return fraction * first_phase + (1 - fraction) * second_phase
   

# NOTE(review): Sc_dictionary stores the merged, un-normalized signal, and
# 'Pristine' was processed as fluorescence while 'Pure Sc'/'Sc2O3' were
# processed as transmission. The fit below therefore combines signals on
# different scales — presumably normalized spectra are intended; confirm.
pristine_energy = Sc_dictionary['Pristine'][0]  # independent variable, "x axis"
pristine_phase = Sc_dictionary['Pristine'][1] # dependent variable, mystery phase to fit

# Wrap the two-phase linear combination f(energy, fraction) in an lmfit Model.
m = Model(f)  # lmfit Model
# Check that these make sense...
print(m.param_names)
print(m.independent_vars)
# lmfit Parameter, with initial guess set here; fraction starts at 0.5 and is
# bounded to the interval [0, 1].
params = m.make_params(fraction=dict(value=0.5, min=0, max=1))
# Fit model to y, given energy (x axis) and params (with initial guess)
result = m.fit(pristine_phase, energy=pristine_energy, params=params)
# Bare expression: in a notebook cell this displays the fitted parameters.
result.params

In [None]:
fig, ax = plt.subplots()

# The two reference spectra on their own energy grids.
for _name in ('Sc2O3', 'Pure Sc'):
    ax.plot(Sc_dictionary[_name][0], Sc_dictionary[_name][1], label=_name)

# The same references interpolated onto the 'Pristine' energy grid.
_grid = Sc_dictionary['Pristine'][0]
for _interpolate, _label in ((phase2, "Sc2O3 (interp)"), (phase1, "Pure Sc (interp)")):
    ax.plot(_grid, _interpolate(_grid), ls="dashed", label=_label)

# The spectrum being fitted.
ax.plot(_grid, Sc_dictionary['Pristine'][1], label="Pristine")

ax.set_yscale("log")
ax.set_xlabel("energy")
ax.set_ylabel("counts")
ax.legend()

In [None]:
# Print the energy at index 20 of each spectrum to compare the energy grids.
for _caption, _sample in (
    ('Mystery phase:', 'Pristine'),
    ('Phase 1:', 'Pure Sc'),
    ('Phase 2:', 'Sc2O3'),
):
    print(_caption, Sc_dictionary[_sample][0][20])