<div class="alert alert-warning">
    
<b>Disclaimer:</b> The main objective of the <i>Jupyter</i> notebooks is to show how to use the models of the <i>QENS library</i> by
    
- building a fitting model: composition of models, convolution with a resolution function
- setting up and running the fit
- extracting and displaying information about the results

The syntax of these steps depends on the minimizer, which is one reason why different minimizers are used in the example notebooks.  
However, the initial guesses for the parameters may not be optimal, which can lead to a poor fit of the reference data.
</div>
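This notebook does not convolve the model with a resolution function, although the disclaimer lists it as one of the steps. The cell below is a minimal NumPy sketch of that step, assuming a uniform grid that is symmetric about zero with an odd number of points. `lorentzian` and `gaussian` are hypothetical stand-ins for a QENS model and a measured resolution; they are not QENSmodels functions.

```python
import numpy as np

def lorentzian(x, hwhm):
    """Area-normalised Lorentzian (hypothetical stand-in for a QENS model)."""
    return (hwhm / np.pi) / (x**2 + hwhm**2)

def gaussian(x, sigma):
    """Area-normalised Gaussian (hypothetical stand-in for a resolution function)."""
    return np.exp(-0.5 * (x / sigma) ** 2) / (sigma * np.sqrt(2.0 * np.pi))

def convolve_with_resolution(x, model_values, resolution_values):
    """Approximate the convolution integral (model * resolution)(x).

    Both arrays must be sampled on the same uniform grid `x`, symmetric about
    zero with an odd number of points, so that mode="same" keeps the result
    aligned with x. Multiplying the discrete sum by the grid step turns it
    into a Riemann approximation of the integral.
    """
    dx = x[1] - x[0]
    return np.convolve(model_values, resolution_values, mode="same") * dx

x = np.linspace(-10, 10, 2001)  # odd number of points, step 0.01
model = lorentzian(x, 0.5)
convolved = convolve_with_resolution(x, model, gaussian(x, 0.2))
```

Because `mode="same"` only sums over the available samples, values within a few resolution widths of the grid edges are underestimated; a grid wider than the fitted range avoids this.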

# Example: Jump sites diffusion model with log-normal distribution, fitted with scipy

## Table of Contents

- [Introduction](#Introduction)
- [Importing libraries](#Importing-libraries)
- [Plot of the fitting model](#Plot-of-the-fitting-model)
- [Creating reference data](#Creating-reference-data)
- [Setting and fitting](#Setting-and-fitting)
- [Plotting the results](#Plotting-the-results)

[Top](#Table-of-Contents)

## Introduction

<div class="alert alert-info">
    
The objective of this notebook is to show how to use one of the models of 
the <a href="https://github.com/QENSlibrary/QENSmodels">QENSlibrary</a>, <b>sqwJumpSitesLogNormDist</b>, to perform some fits. 

<a href="https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html">scipy.optimize.curve_fit</a> is used for fitting.
</div>
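The model name indicates that the residence times follow a log-normal distribution whose width is set by `sigma`. The sketch below only illustrates that averaging idea on a single Lorentzian. It is not the QENSmodels implementation, which also includes the jump-sites geometry through `radius` and `Nsites`; treating `resTime` as the median of the distribution, `sigma` as the standard deviation of ln τ, and the Lorentzian half width as 1/τ are assumptions made for illustration only. Both function names are hypothetical.

```python
import numpy as np

def lorentzian(w, hwhm):
    """Area-normalised Lorentzian centred on zero with half width at half maximum `hwhm`."""
    return (hwhm / np.pi) / (w**2 + hwhm**2)

def lognormal_averaged_lorentzian(w, res_time, sigma, n_points=201, n_sigma=5.0):
    """Average Lorentzians of HWHM 1/tau over a log-normal distribution of tau.

    Assumes `res_time` is the median of the distribution and `sigma` the
    standard deviation of ln(tau). ln(tau) is sampled on a uniform grid
    spanning +/- n_sigma standard deviations; the discrete Gaussian weights
    in ln(tau) are normalised to sum to one.
    """
    t = np.linspace(-n_sigma, n_sigma, n_points)  # standardised ln(tau) grid
    weights = np.exp(-0.5 * t**2)
    weights /= weights.sum()
    taus = res_time * np.exp(sigma * t)  # tau = median * exp(sigma * t)
    w = np.asarray(w, dtype=float)
    # one row per sampled residence time, one column per energy transfer
    profiles = lorentzian(w[None, :], 1.0 / taus[:, None])
    return weights @ profiles
```

As `sigma` tends to zero the average reduces to a single Lorentzian of half width 1/`res_time`; larger values add contributions from both shorter and longer residence times.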

### Physical units

For information about unit conversion, please refer to the Jupyter notebook called `Convert_units.ipynb` in the `tools` folder.
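As a quick reference, an energy axis expressed in 1/ps can be converted to meV with E = ħω if it is an angular frequency, or E = hν if it is an ordinary frequency. Which convention applies to the data is an assumption to check against the data file. The sketch below uses the CODATA values from `scipy.constants`; the helper names are hypothetical and not taken from `Convert_units.ipynb`.

```python
from scipy import constants

# hbar and h converted from J*s to meV*ps: divide by e (J -> eV), then
# multiply by 1e3 (eV -> meV) and by 1e12 (s -> ps)
HBAR_MEV_PS = constants.hbar / constants.e * 1e15  # about 0.6582 meV*ps
H_MEV_PS = constants.h / constants.e * 1e15        # about 4.1357 meV*ps

def angular_frequency_to_mev(omega_per_ps):
    """E = hbar * omega, assuming omega is an angular frequency in rad/ps."""
    return HBAR_MEV_PS * omega_per_ps

def mev_to_angular_frequency(energy_mev):
    """Inverse of angular_frequency_to_mev: omega = E / hbar, in rad/ps."""
    return energy_mev / HBAR_MEV_PS

def frequency_to_mev(nu_per_ps):
    """E = h * nu, assuming nu is an ordinary frequency in 1/ps."""
    return H_MEV_PS * nu_per_ps
```

The two conventions differ by a factor of 2π, so mixing them up shifts fitted centers and widths by that factor.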

The dictionary defined in the cell below specifies the units of the refined parameters, following the convention used in the experimental data file.

In [None]:
# Units of parameters for selected QENS model and experimental data

dict_physical_units = {'scale': "unit_of_signal.ps", 'center': "1/ps", 'radius': 'Angstrom', 'resTime': 'ps'}
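The dictionary has no entry for `sigma`, which is printed without a unit later in the notebook. A small hypothetical helper that formats a refined value with its error and unit, falling back to no unit for parameters missing from the dictionary:

```python
def format_parameter(name, value, error, units):
    """Return 'name: value +/- error unit'.

    Parameters absent from `units` (for example the width sigma) are printed
    without a unit; trailing whitespace is stripped in that case.
    """
    unit = units.get(name, "")
    return f"{name}: {value:.4g} +/- {error:.2g} {unit}".rstrip()
```

After the fit, it can be called in a loop over `zip(['scale', 'center', 'radius', 'resTime', 'sigma'], popt, perr)` with `dict_physical_units` as the last argument.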

[Top](#Table-of-Contents)

## Importing libraries

In [None]:
# import python modules for plotting, fitting
from __future__ import print_function
import numpy as np

import matplotlib.pyplot as plt
from scipy.optimize import curve_fit

# for interactivity (plots, buttons...)
from pandas import DataFrame
import panel.widgets as pnw
import panel as pn
pn.extension()

#%matplotlib notebook

In [None]:
# install QENSmodels (if not already installed)
# importlib.util.find_spec replaces pkgutil.find_loader, which is deprecated
import importlib.util
import sys
from IPython.display import display

if importlib.util.find_spec("QENSmodels") is None:
    buttonY = pnw.Button(name='Yes', button_type='success')
    buttonN = pnw.Button(name='No', button_type='danger')
    choice_installation = pn.Column("Do you want to install the QENSmodels library?", pn.Row(buttonY, buttonN))
    display(choice_installation)

In [None]:
if importlib.util.find_spec("QENSmodels") is None:
    if buttonY.clicks > 0:
        !{sys.executable} -m pip install git+https://github.com/QENSlibrary/QENSmodels#egg=QENSmodels
    elif buttonN.clicks > 0:
        print("You will not be able to run some of the remaining parts of this notebook")

In [None]:
# import model from QENS library
import QENSmodels

[Top](#Table-of-Contents)

## Plot of the fitting model

The widget below plots the `sqwJumpSitesLogNormDist` function imported from QENSmodels. Its parameters can be varied with the sliders and restored to their initial values with the Reset button.

In [None]:
ini_values = {'q': 1., 'scale': 5., 'center': 5., 'Nsites': 3, 'radius': 1., 'resTime':1., 'sigma': 1.}

# Define function to plot
def mplplot(df, **kwargs):
    fig = df.plot(legend=False).get_figure()
    plt.grid()
    plt.ylabel('jump sites log normal distribution')
    plt.xlabel('x')
    plt.close(fig)
    return fig

def jump_log_norm(q=1., scale=1., center=1., Nsites=3, radius=1., resTime=1., sigma=1., view_fn=mplplot):
    xs = np.linspace(-10,10,100)
    ys = QENSmodels.sqwJumpSitesLogNormDist(xs, q, scale, center, Nsites, radius, resTime, sigma)
    df = DataFrame(dict(y=ys), index=xs)
    return view_fn(df, q=q, scale=scale, center=center, Nsites=Nsites, radius=radius, resTime=resTime, sigma=sigma)

# Define sliders and actions on plot
slider_q = pnw.FloatSlider(name='q', value=ini_values['q'], start=.1, end=10)
slider_scale  = pnw.FloatSlider(name='scale', value=ini_values['scale'], start=1., end=10)
slider_center = pnw.FloatSlider(name='center', value=ini_values['center'], start=0, end=10)
slider_Nsites = pnw.IntSlider(name='Nsites', value=ini_values['Nsites'], start=2, end=10)
slider_radius = pnw.FloatSlider(name='radius', value=ini_values['radius'], start=1, end=10)
slider_resTime = pnw.FloatSlider(name='resTime', value=ini_values['resTime'], start=1, end=10)
slider_sigma = pnw.FloatSlider(name='sigma', value=ini_values['sigma'], start=0.1, end=10)

def update(event):
    jump_log_norm_panel[0] = jump_log_norm(slider_q.value, 
                                           slider_scale.value, 
                                           slider_center.value, 
                                           slider_Nsites.value, 
                                           slider_radius.value, 
                                           slider_resTime.value, 
                                           slider_sigma.value)
    
slider_q.param.watch(update, 'value')
slider_scale.param.watch(update, 'value')
slider_center.param.watch(update, 'value')
slider_Nsites.param.watch(update, 'value')
slider_radius.param.watch(update, 'value')
slider_resTime.param.watch(update, 'value')
slider_sigma.param.watch(update, 'value')

# Define reset button
reset_button = pnw.Button(name='Reset')

def on_click(event):
    """Reset the interactive plots to initial values."""
    print("reset values")
    slider_q.value = ini_values['q']
    slider_scale.value = ini_values['scale']
    slider_center.value = ini_values['center']
    slider_Nsites.value = ini_values['Nsites']
    slider_radius.value = ini_values['radius']
    slider_resTime.value = ini_values['resTime']
    slider_sigma.value = ini_values['sigma']

reset_button.param.watch(on_click, 'clicks')

# Define layout: title, plot, sliders and reset button
widgets = pn.Column("#### sqwJumpSitesLogNormDist",
                    slider_q, slider_scale, slider_center, slider_Nsites,
                    slider_radius, slider_resTime, slider_sigma, reset_button)
jump_log_norm_panel = pn.Row(jump_log_norm(slider_q.value, slider_scale.value, slider_center.value,
                                           slider_Nsites.value, slider_radius.value,
                                           slider_resTime.value, slider_sigma.value),
                             widgets)

jump_log_norm_panel

[Top](#Table-of-Contents)

## Creating reference data

**Input:** the reference data for this simple example are generated with `sqwJumpSitesLogNormDist` (q=0.89, scale=1, center=0.3, Nsites=5, radius=2, resTime=0.45, sigma=0.25), to which multiplicative (4%) and additive Gaussian noise is added.

The fit is performed using `scipy.optimize.curve_fit`. <br> The example is adapted from https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html

In [None]:
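# Creation of reference data: model with known parameters plus
# multiplicative (4 %) and additive (standard deviation 0.02) Gaussian noise
xx = np.linspace(-10, 10, 100)
sqw_jump_sites_noisy = (QENSmodels.sqwJumpSitesLogNormDist(xx, 0.89, 1, 0.3, 5, 2, 0.45, 0.25)
                        * (1. + 0.04 * np.random.normal(0, 1, 100))
                        + 0.02 * np.random.normal(0, 1, 100))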

fig0, ax0 = plt.subplots()
ax0.plot(xx, sqw_jump_sites_noisy, label='reference data')
ax0.set_xlabel('x')
ax0.grid()
ax0.legend();
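The cell above draws the noise with `np.random.normal` without a seed, so every run produces different reference data and slightly different refined parameters. A sketch of the same noise model using a seeded generator from `np.random.default_rng`; `add_noise` is a hypothetical helper, not part of QENSmodels.

```python
import numpy as np

def add_noise(signal, rel_sigma=0.04, abs_sigma=0.02, seed=None):
    """Return signal * (1 + rel_sigma * N(0, 1)) + abs_sigma * N(0, 1).

    Mirrors the noise model of the reference-data cell; passing an integer
    seed makes the noise, and hence the fit, reproducible.
    """
    rng = np.random.default_rng(seed)
    signal = np.asarray(signal, dtype=float)
    return (signal * (1.0 + rel_sigma * rng.standard_normal(signal.shape))
            + abs_sigma * rng.standard_normal(signal.shape))
```

For example, `add_noise(QENSmodels.sqwJumpSitesLogNormDist(xx, 0.89, 1, 0.3, 5, 2, 0.45, 0.25), seed=0)` generates a reference data set whose noise is identical on every run.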

[Top](#Table-of-Contents)

## Setting and fitting

In [None]:
# Adapted from https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.curve_fit.html
# Fit with initial guesses scale=0.95, center=0.2, radius=2, resTime=0.45, sigma=0.25;
# Nsites=5 and q=0.89 are fixed

def func_to_fit(xx, scale, center, radius, resTime, sigma):
    return QENSmodels.sqwJumpSitesLogNormDist(xx, 0.89, scale, center, 5, radius, resTime, sigma)

fig0, ax0 = plt.subplots()
ax0.plot(xx, sqw_jump_sites_noisy, 'b-', label='reference data')
ax0.plot(xx, func_to_fit(xx, 0.95, 0.2, 2, 0.45, 0.25), 'r-', label='model with initial guesses')
ax0.set_xlabel('x')
ax0.grid()
ax0.legend(bbox_to_anchor=(0.6, 1), loc=2, borderaxespad=0.);
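`func_to_fit` fixes `q` and `Nsites` by writing their values into the function body. A more general pattern, sketched below, wraps any model so that chosen keyword arguments are held constant and the remaining ones are passed positionally, as `curve_fit` expects. `stand_in_model` and `with_fixed` are hypothetical names used only for illustration.

```python
import numpy as np
from scipy.optimize import curve_fit

def stand_in_model(x, q, scale, center, width):
    """Hypothetical model: Lorentzian whose HWHM scales as width * q**2."""
    hwhm = width * q**2
    return scale * (hwhm / np.pi) / ((x - center) ** 2 + hwhm**2)

def with_fixed(model, free_names, **fixed):
    """Return f(x, *free_values) that calls model(x, **free, **fixed).

    curve_fit passes the free parameters positionally, in the order of
    `free_names`. Because the wrapper's signature is (x, *free_values),
    curve_fit cannot count the parameters itself, so p0 must be supplied.
    """
    def wrapped(x, *free_values):
        params = dict(zip(free_names, free_values))
        params.update(fixed)
        return model(x, **params)
    return wrapped

# Example: fit scale, center and width with q held at 0.89 (noise-free data)
x = np.linspace(-10, 10, 201)
y = stand_in_model(x, 0.89, 2.0, 0.3, 1.5)
f = with_fixed(stand_in_model, ["scale", "center", "width"], q=0.89)
popt, pcov = curve_fit(f, x, y, p0=[1.5, 0.1, 1.0])
```

With QENSmodels the same pattern would read `with_fixed(QENSmodels.sqwJumpSitesLogNormDist, ['scale', 'center', 'radius', 'resTime', 'sigma'], q=0.89, Nsites=5)`, assuming the function accepts all of these names as keyword arguments (the plotting call above passes every one of them except `q` by keyword).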

In [None]:
popt, pcov = curve_fit(func_to_fit, xx, sqw_jump_sites_noisy, p0=[0.95, 0.2, 2, 0.45, 0.25], 
                       bounds=((0.1, -2, 0.1, 0.1, 0.1), (5, 2, 5, 11, 1)))
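When `bounds` are given, `curve_fit` raises a `ValueError` if any initial guess lies outside them. A small hypothetical helper that reports which guesses are infeasible before the fit is attempted, treating the bounds as inclusive:

```python
import numpy as np

def out_of_bounds(p0, bounds, names=None):
    """List the initial guesses lying outside curve_fit-style bounds.

    `bounds` is a 2-tuple (lower, upper); each entry is either a scalar or a
    sequence with one value per free parameter. An empty list means every
    guess lies within its (inclusive) bounds.
    """
    p0 = np.asarray(p0, dtype=float)
    lower = np.broadcast_to(np.asarray(bounds[0], dtype=float), p0.shape)
    upper = np.broadcast_to(np.asarray(bounds[1], dtype=float), p0.shape)
    if names is None:
        names = [f"p{i}" for i in range(p0.size)]
    return [f"{name}={value:g} not in [{lo:g}, {hi:g}]"
            for name, value, lo, hi in zip(names, p0, lower, upper)
            if not lo <= value <= hi]
```

For the fit above, `out_of_bounds([0.95, 0.2, 2, 0.45, 0.25], ((0.1, -2, 0.1, 0.1, 0.1), (5, 2, 5, 11, 1)))` returns an empty list.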

[Top](#Table-of-Contents)

## Plotting the results

In [None]:
# Errors on the refined parameters: square roots of the diagonal of the covariance matrix
perr = np.sqrt(np.diag(pcov))

print('Values of refined parameters:')
print('scale:  ', popt[0], '+/-', perr[0], dict_physical_units['scale'])
print('center: ', popt[1], '+/-', perr[1], dict_physical_units['center'])
print('radius: ', popt[2], '+/-', perr[2], dict_physical_units['radius'])
print('resTime:', popt[3], '+/-', perr[3], dict_physical_units['resTime'])
print('sigma:  ', popt[4], '+/-', perr[4])
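`perr` keeps only the diagonal of `pcov`; the off-diagonal terms show how strongly pairs of refined parameters are correlated, and coefficients close to ±1 mean the individual error bars should be read with care. Note also that, with the default `absolute_sigma=False` and no `sigma` passed to `curve_fit`, `pcov` is scaled so that the reduced chi-square equals one. A sketch that normalises the covariance matrix into correlation coefficients (`correlation_matrix` is a hypothetical helper):

```python
import numpy as np

def correlation_matrix(pcov):
    """Normalise a covariance matrix to correlation coefficients in [-1, 1]."""
    pcov = np.asarray(pcov, dtype=float)
    perr = np.sqrt(np.diag(pcov))
    return pcov / np.outer(perr, perr)
```

With the `DataFrame` already imported in this notebook, `DataFrame(correlation_matrix(pcov), index=names, columns=names)` displays it as a labelled table, where `names` is the list of free-parameter names.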

In [None]:
# Comparison of reference data with fitting result
fig1, ax1 = plt.subplots()
ax1.plot(xx, sqw_jump_sites_noisy, 'b-', label='reference data')
ax1.plot(xx, func_to_fit(xx, *popt), 'g--',
         label='fit: scale=%5.3f, center=%5.3f, radius=%5.3f, resTime=%5.3f, sigma=%5.3f' % tuple(popt))
ax1.legend(bbox_to_anchor=(0., 1.15), loc='upper left', borderaxespad=0.)
ax1.set_xlabel('x')
ax1.grid();
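A visual comparison can be complemented by the residuals and a goodness-of-fit number. The sketch below computes both; `fit_statistics` is a hypothetical helper. Since the reference data carry no per-point uncertainties, the default unit weights give the residual variance rather than a true reduced chi-square.

```python
import numpy as np

def fit_statistics(y_data, y_fit, n_params, sigma=None):
    """Return the residuals and the reduced chi-square of a fit.

    Without per-point uncertainties (sigma=None) every point gets unit
    weight, so the second value is the residual variance.
    """
    y_data = np.asarray(y_data, dtype=float)
    residuals = y_data - np.asarray(y_fit, dtype=float)
    if sigma is None:
        weights = np.ones_like(residuals)
    else:
        weights = 1.0 / np.asarray(sigma, dtype=float)
    dof = residuals.size - n_params  # degrees of freedom
    chi2_red = np.sum((residuals * weights) ** 2) / dof
    return residuals, chi2_red
```

For this notebook, `fit_statistics(sqw_jump_sites_noisy, func_to_fit(xx, *popt), len(popt))` returns residuals that can be plotted against `xx` on a separate axis to check for systematic deviations.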