In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import sys, os
sys.path.append('../')
from data_handling import Container,load_data,Normalize,Denormalize
from plot_style.style_prab import load_preset,figsize,cmap_nicify
from plot_style.style_prab import colors as colors_preset
from train_ensemble import train, NN, Loss
from analysis.scan_model import median,quantile
import pickle

# Load the plot_style preset (fonts are read from ../plot_style/font).
load_preset(scale=1, font_path='../plot_style/font')

# Colormap used for the waterfall image, derived from 'YlGnBu_r' by cmap_nicify.
# NOTE(review): idx_white / size_white presumably control a white band in the
# map — confirm in plot_style.style_prab.
mymap = cmap_nicify(cmap='YlGnBu_r', idx_white=1, size_white=50)
from scipy.interpolate import griddata
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.patches as patches


def plot(x, y, ax, color, label=None, datalabel='', pos='left', rot=0):
    """Draw a central curve with a shaded band and a right-hand y label.

    Parameters
    ----------
    x : array-like
        Abscissa shared by the curve and the band.
    y : indexable
        ``y[0]`` is the central curve; ``y[1]`` and ``y[2]`` are the band edges
        passed to ``fill_between``.
    ax : matplotlib Axes to draw on.
    color : colour used for both the band and the curve.
    label : text for the y-axis label.
    datalabel : legend label attached to the curve.
    pos : not used; kept so existing calls keep working.
    rot : value passed to ``set_rotation`` of the y-label text.

    Returns
    -------
    The value returned by ``ax.plot`` for the central curve.
    """
    centre, band_hi, band_lo = (y[k] for k in range(3))

    # Band first, then the curve on top of it.
    ax.fill_between(x, band_hi, band_lo, alpha=0.5, color=color, linewidth=0)
    line_artists = ax.plot(x, centre, color=color, alpha=1, label=datalabel)

    ylabel_text = ax.set_ylabel(label)
    ax.yaxis.set_label_position("right")
    ylabel_text.set_rotation(rot)

    return line_artists

In [None]:
import pickle
from ensemble.Bregressor_ import BaggingRegressor
from ensemble.utils import io

# Load the trained ensemble together with the configuration saved alongside it.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only open config.pkl files produced by this project.
model_path = '../models/'
with open(f'{model_path}config.pkl', 'rb') as fp:
    config = pickle.load(fp)
# Normalization parameters: norm[0] is applied to model inputs, norm[1] to outputs.
norm = config['norm']

if not isinstance(config['model'], list):
    # Loading a non-list model needs `BaggingRegressor_te` / `io_te`, which are
    # not imported anywhere in this notebook; fail with an explicit message
    # instead of a NameError.
    # TODO: import the matching regressor / io helpers and restore this path.
    raise NotImplementedError(
        "config['model'] is not a list: loading it requires BaggingRegressor_te "
        "and io_te, which are not imported in this notebook."
    )

# The saved sub-models are wrapped in a ModuleList and handed to the bagging ensemble.
model = nn.ModuleList(config['model'])
ensemble = BaggingRegressor(
    estimator=model,
    n_estimators=config['estimators'],
    cuda=False,
)
ensemble.set_criterion(config['loss_fun'])
io.load(ensemble, model_path)  # presumably restores the trained weights from model_path

# The dataset location comes from config["path"]. start_ind=0 / random=False
# are passed explicitly — presumably to reproduce the training split; confirm
# in data_handling.load_data.
trainset, testset, _ = load_data(
    config["path"],
    config["inputs"],
    config["outputs"],
    samples=config['samples'],
    ratio=config['ratio'],
    start_ind=0,
    norm=norm,
    random=False,
)

trainset = Container(trainset)
testset = Container(testset)

# Evaluate on the CPU; ydata holds the denormalized test targets.
testset.x = testset.x.to('cpu')
testset.y = testset.y.to('cpu')
ydata = Denormalize(testset.y.detach().numpy(), norm[1])


In [None]:
# Scan the ensemble median along one input while all other inputs are held at
# 0 in normalized units, then map inputs/outputs back with Denormalize.
N_SCAN = 100     # points along the scan
N_INPUTS = 14    # width of the model input vector
N_OUTPUTS = 5    # width of the model output vector
SCAN_COL = 8     # input column that is varied

x_ = torch.linspace(-1.5, 1.5, N_SCAN)
x = torch.zeros(N_SCAN, N_INPUTS)
x[:, SCAN_COL] = x_

z = Denormalize(median(x, ensemble).detach().numpy(), norm[1]).reshape(N_SCAN, N_OUTPUTS)

# Scale factor applied to the scanned input after centring.
c_zfoc = (162 * 1e3) / 6.9
x = Denormalize(x, norm[0]).reshape(N_SCAN, N_INPUTS)[:, SCAN_COL].detach().numpy()
# NOTE(review): with an even N_SCAN, index N_SCAN // 2 is not the centre of the
# linspace (normalized value ~ +0.015), so this zero is offset by half a step.
# The resulting x is not used by the later cells shown here.
x = (x - x[N_SCAN // 2]) * c_zfoc

In [None]:
# Normalized x-ray waterfall data. Keys read later: 'wl', 'E', 'wfp', 'mad'.
# Alternative (previously commented-out) source: '../data/waterfall_data_xrays.pkl'.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
waterfall_file = '../data/waterfall_data_xray_normalized.pkl'
with open(waterfall_file, 'rb') as handle:
    wfp = pickle.load(handle)

In [None]:
import numpy as np
from matplotlib import colors
from matplotlib import pyplot as plt

# Divergent colormap interpolated through four anchor colours, previewed on a ramp.
anchor_colors = [colors_preset[-2], '#175381', '#710000', colors_preset[0]]
cmap_divergent = colors.LinearSegmentedColormap.from_list("cmap_name", anchor_colors)

ramp = np.linspace(-1.0, 1.0, 256)
gradient = np.tile(ramp, (2, 1))  # two identical rows -> (2, 256) image

plt.imshow(gradient, aspect="auto", cmap=cmap_divergent)
plt.show()

In [None]:
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes


# 2x2 layout: narrower left column (ax[1] top, ax[2] bottom),
# wider right column (ax[0] top, axs[0] bottom).
fig_w = figsize['inch']['column_width']
fig = plt.figure(constrained_layout=True, figsize=(fig_w, fig_w * 0.7))
gs = gridspec.GridSpec(
    2, 2, figure=fig,
    height_ratios=[1, 1], width_ratios=[1, 1.5], hspace=0.0, wspace=0.0,
)

ax = [fig.add_subplot(gs[row, col]) for row, col in ((0, 1), (0, 0), (1, 0))]
axs = [fig.add_subplot(gs[1, 1])]

color = colors_preset[2]

# Waterfall image with the model prediction (every second scan point) overlaid.
extent = [0, 17, 235, 285]
l1 = ax[0].plot(
    z[10:-10:2, -2].flatten(), z[10:-10:2, 0].flatten(), '--',
    lw=0.7, mew=0, ms=2, color=color, label='Explanatory model',
)
im = ax[0].pcolormesh(wfp['wl'], wfp['E'], wfp['wfp'], cmap=mymap, vmin=0, vmax=0.2)
ax[0].set_xlim(extent[:2])
ax[0].set_ylim(extent[2:])

# Colour bar in an inset spanning the right edge of the waterfall axes.
axins = inset_axes(ax[0], width="5%", height="100%", loc='right', borderpad=0)
cbar = plt.colorbar(im, cax=axins, orientation="vertical")
cbar.ax.tick_params(axis='y', which='both', labelright=True, right=True, direction='in')
cbar.ax.set_ylabel('Normalized \n intensity')
cbar.ax.set_yticks([0.0, 0.1, 0.2])


# Electron energy (output 0) vs x-ray wavelength (output -2):
# model curve over a hexbin density of the test data.
extent = [8, 15, 235, 285]
lam_col, e_col = -2, 0
ax[1].plot(z[10:-10, lam_col].flatten(), z[10:-10, e_col].flatten(), '--',
           lw=0.7, mew=0, ms=2, color=color, label='Explanatory model')
ax[1].hexbin(ydata[:, lam_col].flatten(), ydata[:, e_col].flatten(),
             extent=extent, bins=15, cmap='bone_r',
             lw=0.05, gridsize=[50, int(50 * 2. / 4)])

ax[1].set_xlim(extent[:2])
ax[1].set_ylim(extent[2:])
ax[1].set_ylabel('E (MeV)')
ax[1].set_xlabel(r'$\lambda_{xray}$ (nm)')


# Relative x-ray bandwidth vs relative energy spread:
# ΔE/E = output 1 / output 0, Δλ/λ = output -1 / output -2.
extent = [0.005, 0.07, 0.1, 0.25]
ax[2].plot(z[10:-10, 1].flatten() / z[10:-10, 0].flatten(),
           z[10:-10, -1].flatten() / z[10:-10, -2].flatten(),
           '--', lw=0.7, mew=0, ms=2, color=color,
           label='Explanatory model')  # same model curve as the other panels
ax[2].hexbin(ydata[:, 1].flatten() / ydata[:, 0].flatten(),
             ydata[:, -1].flatten() / ydata[:, -2].flatten(),
             extent=extent, bins=15, cmap='bone_r', lw=0.05,
             gridsize=[50, int(50 * 2. / 4)])
ax[2].set_xlabel(r'$\Delta$E/E ')
ax[2].set_ylabel(r'$\Delta\lambda_{xray}/\lambda_{xray}$')

# The waterfall axes (ax[0]) share their energy axis with ax[1].
ax[0].sharey(ax[1])

# Parameters for the analytic undulator resonance curve.
wl_u = 5e6    # undulator period, nm
K = 0.29      # undulator K parameter (unitless)
mec2 = 0.511  # electron rest energy, MeV


def xray_wl(E, period=None, K_und=None, rest_energy=None):
    """Undulator resonance wavelength for electron energy ``E``.

    Evaluates ``lambda = period / (2 * gamma**2) * (1 + K_und**2 / 2)`` with
    ``gamma = E / rest_energy``. Works element-wise on numpy arrays.

    Parameters
    ----------
    E : float or ndarray
        Electron energy, in the same unit as ``rest_energy`` (MeV by default).
    period, K_und, rest_energy : float, optional
        Override the module-level ``wl_u``, ``K`` and ``mec2``. The defaults are
        looked up at call time, so later changes to those globals take effect.

    Returns
    -------
    Wavelength in the unit of ``period`` (nm by default).

    NOTE(review): gamma = E / mec2 treats E as the total energy. If E is the
    kinetic energy, gamma = (E + mec2) / mec2 (about 0.4 % shift in lambda at
    260 MeV) — confirm which one the data uses.
    """
    period = wl_u if period is None else period
    K_und = K if K_und is None else K_und
    rest_energy = mec2 if rest_energy is None else rest_energy
    # Optional extra term (not applied): + (np.pi*2e-7)*(E)**2/(mec2)**2
    return (period * rest_energy**2 / (2 * E**2)) * (1 + K_und**2 / 2)




# Undulator-equation curve across the plotted energy window (50 points).
E = np.linspace(240, 280, num=50)
xwl = xray_wl(E)
l2 = ax[0].plot(xwl, E, 'k', lw=0.5, label='Undulator eq.')


# Panel letters, tick styling, guide arrows and the shared figure legend.
ax[0].text(0.06, 0.9, r'(b)', transform=ax[0].transAxes, verticalalignment='top')
# Waterfall: wavelength ticks/label on top; y tick labels hidden (axis shared with ax[1]).
ax[0].tick_params(axis='x', which='both', bottom=False, top=True, labeltop=True,
                  labelbottom=False, direction='in')
ax[0].tick_params(axis='y', which='both', left=True, right=False, labelleft=False,
                  labelright=False, direction='in')
ax[0].set_xlabel(r'$\lambda_{xray}$ (nm)')
ax[0].xaxis.set_label_position("top")

ax[1].tick_params(axis='x', which='both', bottom=False, top=True, labeltop=True,
                  labelbottom=False, direction='in')
ax[1].xaxis.set_label_position("top")
ax[1].set_xticks([9, 12, 14])

for panel in ax:
    panel.tick_params(axis='both', which='both', direction='in')

ax[1].text(0.06, 0.9, r'(a)', transform=ax[1].transAxes, verticalalignment='top')
ax[2].text(0.06, 0.9, r'(c)', transform=ax[2].transAxes, verticalalignment='top')

# Arrows pointing along the two branches of the correlation.
ax[1].annotate("", xy=(10, 285), xytext=(10.9, 275),
               arrowprops=dict(arrowstyle="->", linewidth=0.5))
ax[1].annotate("", xytext=(13.3, 238), xy=(13, 250),
               arrowprops=dict(arrowstyle="->", linewidth=0.5))

ax[2].annotate("", xy=(0.04, 0.205), xytext=(0.05, 0.22),
               arrowprops=dict(arrowstyle="->", linewidth=0.5))
ax[2].annotate("", xytext=(0.06, 0.184), xy=(0.07, 0.190),
               arrowprops=dict(arrowstyle="->", linewidth=0.5))

fig.legend([l1[0], l2[0]], ['Explanatory model', 'Undulator eq'], frameon=False,
           ncols=2, bbox_to_anchor=(0, 0), loc='upper left')

# Peak-normalized waterfall line-outs at three electron energies.
# NOTE: do not import matplotlib.colors.Normalize into the notebook namespace —
# it would shadow data_handling.Normalize imported in the first cell.

elim = (262 - 20, 262 + 20)                     # energy window mapped onto cmap_divergent
E = np.flip(np.linspace(elim[0], elim[1], 3))   # 282, 262, 242
WL_START = 40                                   # wavelength bins skipped at the start of each line-out
mad = []                                        # wfp['mad'] at each selected energy
for e in E:
    hits = np.argwhere(wfp['E'] > e)
    if len(hits) == 0:
        raise ValueError(f"waterfall has no energy above {e}; adjust elim")
    idx = hits[0]                               # first bin whose energy exceeds e
    profile = wfp['wfp'][idx][0][WL_START:]
    m = np.argmax(profile)
    profile = profile / profile[m]
    wl = wfp['wl'][WL_START:]
    line_color = cmap_divergent((e - elim[0]) / (elim[1] - elim[0]))
    mad.append(wfp['mad'][idx][0])

    # Mark the peak and the 4 nm window left of it on the waterfall.
    ax[0].hlines(e, wl[m] - 4, wl[m], linestyles='--', color=line_color, lw=0.5)
    ax[0].vlines([wl[m] - 4, wl[m]], e - 2, e + 2, linestyles='-', color=line_color, lw=0.5)

    # First energy: default z-order; last energy: on top; any others: at the back.
    if e == E[0]:
        z_kw = {}
    else:
        z_kw = {'zorder': 100 if e == E[-1] else 0}
    axs[0].plot(wl - wl[m], profile, color=line_color, alpha=1, linewidth=0.5, **z_kw)

axs[0].text(0.06, 0.9, r'(d)', transform=axs[0].transAxes, verticalalignment='top')

axs[0].set_xlim(-4, 0)
axs[0].set_xticks([-4, -2, 0])
axs[0].set_yticks([0.0, 0.4, 0.8])

axs[0].tick_params(axis='x', which='both', bottom=True, top=False, labeltop=False,
                   labelbottom=True, direction='in')
axs[0].tick_params(axis='y', which='both', right=True, left=False, labelright=True,
                   labelleft=False, direction='in')
axs[0].yaxis.set_label_position("right")
axs[0].set_ylabel('Normalized \n intensity')
axs[0].set_xlabel(r'$\lambda_{xray} - \lambda_{xray,max}$ (nm)')

fig.savefig('xray_correlations.png', dpi=300, bbox_inches='tight')


In [None]:
from scipy.odr import Data, Model, ODR


def fun(B, x):
    """Straight line ``B[0] * x + B[1]`` in the signature scipy.odr.Model expects."""
    return B[0] * x + B[1]


def tls(x, y):
    """Weighted orthogonal-distance (total least squares) line fit of ``y`` vs ``x``.

    Errors in both variables are weighted by the inverse variance of each
    sample (``wd = 1/std(x)**2``, ``we = 1/std(y)**2``), so the fit is
    insensitive to the units of x and y.

    Parameters
    ----------
    x, y : array-like of equal length

    Returns
    -------
    ndarray ``[slope, intercept]``.

    Raises
    ------
    ValueError
        If ``x`` or ``y`` is constant (zero standard deviation would give
        infinite weights).
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    sx = x.std()
    sy = y.std()
    if sx == 0 or sy == 0:
        raise ValueError("tls needs non-constant x and y (zero standard deviation)")

    linear = Model(fun)
    mydata = Data(x, y, wd=1. / sx**2, we=1. / sy**2)
    myodr = ODR(mydata, linear, beta0=[0., 1.])
    return myodr.run().beta

# TLS line fit of charge (output 2) against energy (output 0).
energy, charge = ydata[:, 0].flatten(), ydata[:, 2].flatten()
p = tls(energy, charge)
p

In [None]:
# Charge vs energy (left) and energy spread vs energy (right). The right panel
# gets separate TLS fits below and at/above E_SPLIT.
E_SPLIT = 265  # energy at which the energy-spread data are split for fitting


def _fit_and_annotate(axis, xs, ys, x_text, y_slope, y_offset):
    """Overlay a TLS line fit of ``ys`` vs ``xs`` on ``axis`` and label it.

    The line is drawn over x in [230, 290]. The slope and the intercept
    (labelled ``m=``) are written at axes-fraction ``x_text`` and heights
    ``y_slope`` / ``y_offset``. Returns ``[slope, intercept]`` from ``tls``.
    """
    beta = tls(xs.flatten(), ys.flatten())
    x_line = np.linspace(230, 290)
    axis.plot(x_line, beta[0] * x_line + beta[1])
    axis.text(x_text, y_slope, f'slope={beta[0]:0.3}', transform=axis.transAxes, verticalalignment='top')
    axis.text(x_text, y_offset, f'm={beta[1]:0.3}', transform=axis.transAxes, verticalalignment='top')
    return beta


fig, ax = plt.subplots(1, 2, figsize=(fig_w * 1.3, fig_w * 0.5))
energy = ydata[:, 0]

ax[0].hist2d(energy, ydata[:, 2], cmap='bone_r', bins=100)
p = _fit_and_annotate(ax[0], energy, ydata[:, 2], 0.3, 0.85, 0.75)

# Draw the histogram once; both fits are overlaid on it.
ax[1].hist2d(energy, ydata[:, 1], cmap='bone_r', bins=100)
low = energy < E_SPLIT
p = _fit_and_annotate(ax[1], energy[low], ydata[low, 1], 0.3, 0.85, 0.75)
high = energy >= E_SPLIT  # >= so points exactly at the split are not dropped
p = _fit_and_annotate(ax[1], energy[high], ydata[high, 1], 0.5, 0.45, 0.35)

ax[0].set_ylabel('Charge')
ax[1].set_ylabel('Energy spread')
ax[0].set_xlabel('Energy')
ax[1].set_xlabel('Energy')
fig.savefig('correlations.png', dpi=300)