# Access data from an NSLS-II experiment and fit a model to it

This notebook was adapted from work by Yu-chen Karen Chen-Wiegart, Cheng-Chu Chung, and Xiaoyin Zheng

<table>
    <tr>
        <td>
            <img src="https://www.bnl.gov/assets/global/images/render.php?q=0|24919.jpg|500" width="150px" />
        </td>
        <td>
            <img src="https://www.stonybrook.edu/commcms/chen-wiegart/group/_images/2020PhD_Cheng-Chu.jpg" width="150 px" />
        </td>
        <td>
            <img src="https://www.stonybrook.edu/commcms/chen-wiegart/group/_images/Xiaoyin%20Zheng.jpg" width="150 px" />
        </td>
    </tr>
</table>

In [None]:
%matplotlib inline

from tiled.client import from_uri
from tiled.utils import tree
import numpy as np
import matplotlib.pyplot as plt
from pprint import pprint
from IPython.display import display
from larch.xafs import pre_edge, find_e0, preedge
import palettable as pltt
from collections import defaultdict
from lmfit import Model

# A third-party library configures verbose logging on import. Raise the root
# logger's level to CRITICAL so that loggers inheriting their level from the
# root stay quiet for the rest of the notebook.
import logging

logger = logging.root
logger.setLevel(logging.CRITICAL)

Connect to the demo Tiled data server, which hosts a copy of real data from the BMM beamline at NSLS-II.

In [None]:
# Demo server root -> "bmm" -> "raw": the container of runs, indexed by scan ID below.
client = (
    from_uri("https://tiled-demo.blueskyproject.io")
    ["bmm"]
    ["raw"]
)
client

In [None]:
# Look up a single run by its scan ID; the bare expression displays it.
client[21695]

In [None]:
# The "primary" stream of run 21696.
client[21696]["primary"]

In [None]:
# The data node of that stream; `extract` reads columns such as It, I0 and dcm_energy from here.
client[21696]["primary"]["data"]

In [None]:
# Fetch the whole primary data node in one call.
client[21696]["primary"]["data"].read()

In [None]:
# Same, for the "baseline" stream of the run.
client[21696]["baseline"]["data"].read()

In [None]:
# A single column (It). Values are presumably fetched only when sliced or read, as in the next cells.
client[21696]["primary"]["data"]["It"]

In [None]:
# Only the first 10 points of It.
client[21696]["primary"]["data"]["It"][:10]

In [None]:
# Every point of It, via a full slice.
client[21696]["primary"]["data"]["It"][:]

In [None]:
# Every point of It, via .read() (presumably the same data as the full slice above).
client[21696]["primary"]["data"]["It"].read()

# Functions

### Extract, normalize, and plot data dictionaries

In [None]:
def plot(
    input_dict,
    reference_list=(),
    transmission_list=(),
    offset=0,
    plot_range=None,
):
    """Overlay every spectrum in ``input_dict`` on one figure and show it.

    Parameters
    ----------
    input_dict : dict
        Maps sample name -> ``(energy, signal)``. Both entries may be any
        array-like (plain lists included); they are converted with ``np.asarray``.
    reference_list, transmission_list : sequence of str
        Sample names drawn with a dashed line; all other samples are solid.
    offset : float
        Vertical spacing between curves: the i-th curve (in dict order,
        starting at 0) is shifted up by ``i * offset``.
    plot_range : sequence of four numbers, optional
        ``(xmin, xmax, ymin, ymax)`` axis limits. When ``None`` no limits are set.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``. Nothing is returned, so a
        bare ``plot(...)`` call at the end of a cell adds no extra output.
    """
    fig, ax = plt.subplots(figsize=(6, 7.5), layout="constrained")
    # Alternative palettes: pltt.colorbrewer.diverging.Spectral_4_r,
    # pltt.colorbrewer.sequential.Greens_9.
    cmap = pltt.colorbrewer.sequential.YlOrRd_3.mpl_colormap
    # One evenly spaced colormap position per curve, in dict order.
    color_idx = np.linspace(0, 1, len(input_dict))

    for plot_increment, (sample_name, (x, y)) in enumerate(input_dict.items()):
        # Reference and transmission samples are dashed to set them apart.
        if sample_name in reference_list or sample_name in transmission_list:
            linestyle = "--"
        else:
            linestyle = "-"

        ax.plot(
            np.asarray(x),
            # asarray so that list inputs can be offset too (list + float fails).
            np.asarray(y) + offset * plot_increment,
            label=f"{sample_name}",
            linestyle=linestyle,
            color=cmap(color_idx[plot_increment]),
        )

    ax.set_xlabel(r"$\mathregular{Energy\ (eV)}$")
    ax.set_ylabel(r"$\mathregular{\chi\mu(E)}$")

    if plot_range is not None:
        xmin, xmax, ymin, ymax = plot_range
        ax.set_xlim(xmin, xmax)
        ax.set_ylim(ymin, ymax)

    ax.legend(loc="lower right")
    plt.show()


def _scan_signal(data, mode, element, number_of_detectors):
    """Return the absorption signal of one scan for ``mode`` (see ``extract``)."""
    if mode == "reference":
        return np.log(np.array(data["It"]) / np.array(data["Ir"]))
    if mode == "transmission":
        return np.log(np.array(data["I0"]) / np.array(data["It"]))
    # Fluorescence: I0 is the same for every channel, so fetch it once per scan
    # and as an in-memory array (it was previously re-read per channel).
    i0 = np.array(data["I0"])
    return sum(
        np.array(data[f"{element}{index}"]) / i0
        for index in range(1, number_of_detectors + 1)
    )


def extract(
    element,
    data_dictionary,
    reference_list=(),
    transmission_list=(),
    *,
    client=client,
    number_of_detectors=4,
):
    """Merge the scans of each sample into one ``(energy, signal)`` pair.

    Parameters
    ----------
    element : str
        Prefix of the fluorescence detector channels. Channels
        ``f"{element}1"`` ... ``f"{element}{number_of_detectors}"`` are read.
    data_dictionary : dict
        Maps sample name -> list of scan IDs to average together.
    reference_list : sequence of str
        Samples whose signal is ``ln(It / Ir)``.
    transmission_list : sequence of str
        Samples whose signal is ``ln(I0 / It)``. A name in both lists is
        treated as a reference. Every other sample is treated as fluorescence:
        the sum over detector channels of ``channel / I0``.
    client : keyword-only
        Container indexed by scan ID. Defaults to the module-level ``client``
        bound when this cell ran.
    number_of_detectors : int, keyword-only
        Number of fluorescence channels to sum (default 4).

    Returns
    -------
    dict
        Maps sample name -> ``(energy, merged_signal)``, both NumPy arrays.
        ``merged_signal`` is the point-wise mean over the sample's scans.

    Raises
    ------
    ValueError
        If a sample has no scan IDs.
    """
    output_dictionary = {}

    for sample_name, scan_id_list in data_dictionary.items():
        print(f"Processing {sample_name}...")
        if len(scan_id_list) == 0:
            raise ValueError(f"No scan IDs given for sample {sample_name!r}")

        # The measurement mode depends only on the sample, not on the scan.
        if sample_name in reference_list:
            mode = "reference"
        elif sample_name in transmission_list:
            mode = "transmission"
        else:
            mode = "fluorescence"

        merge_list = [
            _scan_signal(
                client[scan_id]["primary"]["data"], mode, element, number_of_detectors
            )
            for scan_id in scan_id_list
        ]
        # np.mean needs every scan to have the same number of points.
        # NOTE(review): scans are not interpolated onto a common energy grid.
        merge_counts = np.mean(merge_list, axis=0)

        # Materialize the energy axis of the first scan. Otherwise a lazy Tiled
        # node is returned and every later use (e.g. np.interp in the fit)
        # works on that node instead of an array.
        energy = np.array(client[scan_id_list[0]]["primary"]["data"]["dcm_energy"])

        output_dictionary[sample_name] = (energy, merge_counts)

    return output_dictionary

def normalize_xafs(input_dict):
    """Normalize every spectrum in ``input_dict`` with larch's ``preedge``.

    Parameters
    ----------
    input_dict : dict
        Maps sample name -> ``(energy, mu)``. ``energy`` may be any array-like,
        including the Tiled array node returned by older ``extract`` output.

    Returns
    -------
    dict
        A new dict (the input is not modified) mapping sample name ->
        ``(energy, normalized_mu)``. ``energy`` is a NumPy array and
        ``normalized_mu`` is the ``'norm'`` entry of ``preedge``'s result.
    """
    normalized = {}
    for sample_name, (energy, mu) in input_dict.items():
        # Materialize the energy axis once so downstream code (plotting,
        # np.interp inside the fit model) always gets an in-memory array.
        energy = np.asarray(energy)
        normalized[sample_name] = (energy, preedge(energy, mu)["norm"])
    return normalized

# Data plot

### Basic plot

Set the element whose K edge you measured (e.g. Cu, Nb, Sc).

Build a data dictionary with sample names as keys and lists of scan IDs as values. The `extract` function merges (averages) all the scans listed under the same sample name.

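If a sample is a reference, list it in `reference_list` (its signal is computed as ln(It/Ir)). If it was measured in transmission mode, list it in `transmission_list` (signal ln(I0/It)). Every other sample is treated as a fluorescence measurement.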

`extract` returns a dictionary mapping each sample name to its merged (energy, signal) pair, so assign the result to a new variable.

Let's do the simple plotting!

In [None]:
element = "Nb"
# Sample name -> scan IDs to merge. Dict order sets the plotting order and colors.
data_dictionary = {
    "Pristine": [21695, 21696, 21697],
    "400C60M": [21747, 21748, 21749],
    "500C60M": [21760, 21761, 21762],
    "600C60M": [21773, 21774, 21775],
    "700C60M": [21786, 21787, 21788],
    "Pure Nb": [21555, 21556],
}
                   

In [None]:
# Merge the Nb scans; "Pure Nb" is the reference, no samples were measured in transmission.
Nb_dictionary = extract(
    element,
    data_dictionary,
    reference_list=("Pure Nb",),
    transmission_list=(),
    client=client,
)

In [None]:
# Raw merged spectra, before normalization.
plot(input_dict=Nb_dictionary, reference_list=["Pure Nb"])

### Data normalization

Well... the raw spectra sit on very different scales, so the pure-element data are barely visible and the samples are hard to compare.

Let's normalize the data!

To focus on the region of interest, we zoom in on the white-line region by setting the plot range, and we stack the curves by adding a vertical offset.

In [None]:
# Normalize into a NEW variable so this cell is safe to re-run. Overwriting
# Nb_dictionary would normalize already-normalized data on every re-run.
plot_range = [18900, 19150, 0, 1.9]  # xmin, xmax (eV), ymin, ymax
Nb_normalized = normalize_xafs(Nb_dictionary)
plot(Nb_normalized, reference_list=['Pure Nb'], plot_range=plot_range, offset=0.15)

Great! That plot is much easier to read.

Let's try another case: Sc.

In [None]:
element = "Sc"
# "Pure Sc" and "Sc2O3" are the transmission references used for the fit below.
data_dictionary = {
    "Pure Sc": [36495, 36502],
    "Pristine": [21589, 21590, 21591],
    "400C60M": [21289, 21290, 21291],
    "500C60M": [21654, 21655, 21656],
    "600C60M": [21667, 21668, 21669],
    "700C60M": [21680, 21681, 21682],
    "Sc2O3": [36508, 36509],
}
plot_range = [4425, 4625, 0, 2.5]  # xmin, xmax (eV), ymin, ymax
Sc_dictionary = normalize_xafs(
    extract(element, data_dictionary, transmission_list=["Pure Sc", "Sc2O3"])
)

In [None]:
 plot(Sc_dictionary, transmission_list=['Pure Sc', 'Sc2O3'],  offset=0.1, plot_range=plot_range)

# Linear combination fitting

### Sc_dictionary: {sample_name: (energy, normalized absorption)}

In [None]:
# Two reference phases and the unknown ("mystery") sample, each an
# (energy, normalized absorption) pair taken from the normalized Sc data.
PHASE_1, PHASE_2, MYSTERY_PHASE = (
    Sc_dictionary["Pure Sc"],
    Sc_dictionary["Sc2O3"],
    Sc_dictionary["Pristine"],
)

print("Phase 1 length:", len(PHASE_1[0]))
print("Phase 2 length:", len(PHASE_2[0]))
print("Mystery phase length:", len(MYSTERY_PHASE[0]))

## Define fitting function

In [None]:
def _interp_reference(energy, sample_name):
    """Linearly interpolate ``Sc_dictionary[sample_name]`` onto ``energy``.

    ``Sc_dictionary`` is looked up at call time, so its current value is used.
    """
    spectrum = Sc_dictionary[sample_name]
    return np.interp(energy, spectrum[0], spectrum[1])


def phase1(energy):
    """Pure Sc reference evaluated at ``energy`` (the mystery phase's energy points in the fit)."""
    return _interp_reference(energy, 'Pure Sc')


def phase2(energy):
    """Sc2O3 reference evaluated at ``energy`` (the mystery phase's energy points in the fit)."""
    return _interp_reference(energy, 'Sc2O3')


def f(energy, fraction):
    """Two-phase linear-combination model for lmfit.

    ``fraction`` weights Pure Sc (``phase1``) and ``1 - fraction`` weights
    Sc2O3 (``phase2``).
    """
    pure_sc = phase1(energy)
    sc2o3 = phase2(energy)
    return fraction * pure_sc + (1 - fraction) * sc2o3

## Fitting parameters

In [None]:
# Baseline fit: `fraction` is left unconstrained.
mystery_energy = MYSTERY_PHASE[0]
mystery_phase = MYSTERY_PHASE[1]
# Fit window as (start, stop) array INDICES into the mystery scan, not energies.
fitting_range = 20, 150

m = Model(f)  # lmfit Model built from the two-phase model function

# Sanity check: these should list `fraction` as the parameter and `energy` as
# the independent variable.
print(m.param_names)
print(m.independent_vars)

params = m.make_params(fraction=0.5)  # initial guess

sliced_mystery_energy = mystery_energy[slice(*fitting_range)]
sliced_mystery_phase = mystery_phase[slice(*fitting_range)]

result = m.fit(sliced_mystery_phase, params=params, energy=sliced_mystery_energy)

print(result.params)
print(result.fit_report())

In [None]:
# Repeat the fit with `fraction` bounded to [0, 1]. It is the weight of Pure Sc
# in the two-phase model f, so values outside this range are not physical.
# (This cell used to be an exact copy of the unconstrained fit above.)
# Reuses `m`, `sliced_mystery_energy` and `sliced_mystery_phase` from the
# previous cell; `result` is the fit plotted below.
params = m.make_params(fraction=0.5)
params["fraction"].set(min=0.0, max=1.0)

result = m.fit(sliced_mystery_phase, energy=sliced_mystery_energy, params=params)

print(result.params)
print(result.fit_report())

In [None]:
# Overlay both references, the mystery (Pristine) spectrum, and the
# linear-combination fit over the index window used for fitting.
fig, ax = plt.subplots()
ax.plot(Sc_dictionary['Sc2O3'][0], Sc_dictionary['Sc2O3'][1], "-", label="Sc2O3")
ax.plot(Sc_dictionary['Pure Sc'][0], Sc_dictionary['Pure Sc'][1], "-", label="Pure Sc")
ax.plot(Sc_dictionary['Pristine'][0], Sc_dictionary['Pristine'][1], "-", label="Mystery (Pristine)")
ax.plot(sliced_mystery_energy, result.best_fit, "--", label="Best fit")

# The spectra came out of normalize_xafs, so the y axis is normalized
# absorption, not raw counts.
ax.set_xlabel(r"$\mathregular{Energy\ (eV)}$")
ax.set_ylabel("Normalized absorption")
ax.set_title(f"LCF: fraction of Pure Sc = {result.params['fraction'].value:.3f}")

ax.set_xlim(4425, 4625)
ax.set_ylim(-0.1, 2)
ax.legend()
plt.show()