In [None]:
import sys
sys.path.append('../')
from data_handling import load_data, Container, Denormalize, PCA_invtransform
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from sklearn.metrics import r2_score
import torch
import matplotlib 
import pickle
from ensemble.Bregressor_ import BaggingRegressor
from ensemble.utils import io
from analysis.scan_model import median,quantile

from plot_style.style_prab import load_preset,figsize,cmap_nicify
from plot_style.style_prab import colors as colors_preset
# Apply the project's matplotlib preset (defined in plot_style.style_prab);
# the font directory is given relative to this notebook's working directory.
load_preset(scale=1,font_path='../plot_style/font')
# Colormap used by every pcolormesh panel and for the line colours below.
# NOTE(review): presumably `cmap_nicify` blends the low end of 'YlGnBu_r' into
# white (idx_white / size_white) -- confirm against plot_style.style_prab.
mymap = cmap_nicify(cmap='YlGnBu_r',idx_white=1,size_white=50)

# Axis array used as the wavelength coordinate of the spectra plotted below
# (the figures label it 'Wavelength (nm)').
wl = np.load('../data/wavelength_axis.npy')


In [None]:

def load_model(model_path):
    """Restore a trained bagging ensemble and rebuild the test split it is evaluated on.

    Parameters
    ----------
    model_path : str
        Directory prefix of the saved model. It is concatenated directly with
        ``'config.pkl'``, so it must end with a path separator
        (e.g. ``'../models/VD/'``).

    Returns
    -------
    testset : Container
        Test split wrapped in a ``Container``; its ``x`` and ``y`` are moved to the CPU.
    ensemble : BaggingRegressor
        Ensemble rebuilt from the pickled configuration, with its saved state
        restored through ``io.load``.
    config : mapping
        The unpickled training configuration.
    index
        Fourth element returned by ``load_data(..., return_index=True)``.

    Raises
    ------
    TypeError
        If ``config['model']`` is not a list. Only that layout can be turned
        into an ensemble here; previously this case surfaced as an
        ``UnboundLocalError`` at the ``return`` after the data had been loaded.
    AssertionError
        If the normalisation recomputed by ``load_data`` differs from the one
        stored in the configuration.
    """
    # SECURITY: pickle.load can execute arbitrary code -- only point this at
    # model directories you trust.
    with open(f'{model_path}config.pkl', 'rb') as fp:
        config = pickle.load(fp)
    norm = config['norm']

    # Fail fast instead of loading all the data and then crashing on an
    # unbound `ensemble`.
    if not isinstance(config['model'], list):
        raise TypeError(
            "config['model'] must be a list of estimators to rebuild the "
            f"ensemble, got {type(config['model']).__name__}"
        )

    model = nn.ModuleList(config['model'])
    ensemble = BaggingRegressor(estimator=model,
                                n_estimators=config['estimators'],
                                cuda=False,)
    ensemble.set_criterion(config['loss_fun'])
    io.load(ensemble, model_path)

    # Rebuild the splits with the settings stored in the configuration.
    # NOTE(review): start_ind is fixed to 0 here while a later cell slices the
    # dataframe with config['start_ind'] -- confirm both agree.
    data = load_data(
        config["path"],
        config['inputs'],
        config["outputs"],
        start_ind=0,
        samples=config['samples'],
        ratio=config['ratio'],
        #norm=norm,
        random=False,
        pca=config['pca'],
        pca_components=config['pca_components'],
        return_index=True,
    )

    index = data[3]
    # The train split is wrapped exactly as before although only the test
    # split is returned (Container may have side effects not visible here).
    trainset = Container(data[0])
    testset = Container(data[1])

    testset.x = testset.x.to('cpu')
    testset.y = testset.y.to('cpu')

    # check that the pca and normalization is the same as during training
    for i, expected in enumerate(norm[0][0]):
        assert data[2][0][0][i] == expected, (
            f'normalisation entry {i} differs from the value stored in config'
        )

    return testset, ensemble, config, index




In [None]:
# Model stored under '../models/VD/' (plotted below as 'M3'); this is the only
# cell that keeps `index`, which later selects the measured spectra.
model_path = '../models/VD/'
testset, ensemble, config, index = load_model(model_path)

# Ensemble prediction from the project helper `median`, detached from the
# autograd graph and converted to NumPy.
y_pred = median(testset.x, ensemble)
y_pred = y_pred.detach().numpy()

# Back-transform the prediction with the stored output PCA transform.
y_pred_spec = PCA_invtransform(y_pred, config['pca_transform'][1])

In [None]:
# Model stored under '../models/VD_BPM/' (plotted below as 'M2').
# Note: this rebinds `testset`, `ensemble` and `config`; the returned index is
# discarded.
model_path = '../models/VD_BPM/'
testset, ensemble, config, _ = load_model(model_path)

# Ensemble prediction from the project helper `median`, converted to NumPy.
y_pred = median(testset.x, ensemble)
y_pred = y_pred.detach().numpy()

# Back-transform the prediction with the stored output PCA transform.
y_pred_spec_BPM = PCA_invtransform(y_pred, config['pca_transform'][1])

In [None]:
# Model stored under '../models/VD_LP/' (plotted below as 'M1').
# Note: this rebinds `testset`, `ensemble` and `config`; the returned index is
# discarded, and `config` keeps this model's configuration for later cells.
model_path = '../models/VD_LP/'
testset, ensemble, config, _ = load_model(model_path)

# Ensemble prediction from the project helper `median`, converted to NumPy.
y_pred = median(testset.x, ensemble)
y_pred = y_pred.detach().numpy()

# Back-transform the prediction with the stored output PCA transform.
y_pred_spec_LP = PCA_invtransform(y_pred, config['pca_transform'][1])

In [None]:
import re

# Parse every 'Loss: <number>' entry of a training log and plot log(loss).
# NOTE(review): this path starts with './models/' while the models themselves
# are loaded from '../models/' -- confirm the log really lives here.
fp = './models/VD/ensamble_net.log-2024_06_05_16_37.log'
loss = []
OD_loss = []  # not filled in this cell; kept in case another cell relies on it

# Raw string so that '\d' is a regex class rather than an invalid string
# escape; the decimal point is escaped, and an optional fraction / exponent is
# accepted so values such as '1.5e-05' are no longer truncated to '1.5'.
loss_pattern = re.compile(r'(?<=Loss: )\d+(?:\.\d+)?(?:[eE][-+]?\d+)?')

with open(fp) as file:
    for line in file:
        match = loss_pattern.search(line)
        if match:
            loss.append(float(match.group(0)))

# NOTE(review): the curve is labelled 'LP' although the log path points at the
# VD model -- confirm which one is intended (no legend is drawn here).
plt.plot(np.log(loss),'k',label='LP')



In [None]:
# Reload the raw table and keep the row window described by the configuration.
# NOTE(review): `config` is whatever the most recently executed load_model cell
# returned, while `index` comes from the '../models/VD/' cell -- confirm that
# all three configurations share path / start_ind / samples.
df = pd.read_hdf(config['path'])
df = df[config['start_ind']:config['start_ind']+config['samples']]

# index[1] picks the rows compared against the predictions
# (presumably the test split -- verify against load_data).
y_wl = df['xspec_1st_order_wavelangth'][index[1]].values
y_bw = df['xspec_1st_order_width'][index[1]].values

# From here on `df` is only the spectrum column; `y` stacks the selected
# spectra into one array.
df = df['xspec_spectrum']
y = np.array(df[index[1]].to_list())

In [None]:
from matplotlib.gridspec import GridSpec

# Four stacked panels:
#   (a) single-shot spectra predicted by the three models vs. the measurement
#   (b) measured spectra, (c) predicted spectra (M3), (d) absolute error (M3)
# Uses `y`, `y_pred_spec`, `y_pred_spec_BPM`, `y_pred_spec_LP`, `wl`, `mymap`
# and `figsize` from earlier cells.
fig_w = figsize['inch']['column_width']
fig = plt.figure(figsize=(fig_w,fig_w*1.3),constrained_layout=True)
ymax = 11706 #23337.211733071283   (divisor applied to every plotted intensity)
colors = [mymap(i) for i in np.linspace(0.2,0.6,3)]

gs = fig.add_gridspec(4, 1, height_ratios=[1,1,1,1],wspace=0.1,hspace=0.)
ax = [fig.add_subplot(gs[k]) for k in range(4)]

# Wavelength window shown in every panel (previously repeated inline).
wl_sel = slice(2**10//2+12, -300)
wl_plot = wl[wl_sel]

nr = 50
shot = np.arange(0,nr,1)

# (axes, data, upper colour limit before division by ymax) for the 2-D panels.
panels = [
    (ax[1], y[:nr, wl_sel].T, 20000),
    (ax[2], y_pred_spec[:nr, wl_sel].T, 20000),
    (ax[3], np.abs(y[:nr, wl_sel].T - y_pred_spec[:nr, wl_sel].T), 2000),
]
cax = []
for panel_ax, panel_data, vmax_raw in panels:
    im = panel_ax.pcolormesh(shot, wl_plot, panel_data/ymax,
                             vmin=0, vmax=vmax_raw/ymax, cmap=mymap)
    # Pass `ax` explicitly: without it, which axes the colorbar steals space
    # from depends on the matplotlib version (current axes vs. mappable axes).
    cax.append(plt.colorbar(im, ax=panel_ax, orientation='vertical',
                            location='right', pad=0, aspect=15))
cax0, cax1, cax2 = cax

for c in cax:
    c.ax.tick_params('y',direction='in')
cax[1].ax.set_ylabel('Intensity (a.u)')
cax[1].ax.yaxis.set_label_position('right')

for k in range(2):
    cax[k].ax.set_yticks([0.,1.,])
cax[2].ax.set_yticks([0.,0.1,])

# Panel (a): one example shot, all three models against the measurement.
shot_idx = 40
ax[0].plot(wl_plot,y_pred_spec_LP[shot_idx][wl_sel]/ymax,color=colors[0],label = f'M1 (laser)',lw=0.7)
ax[0].plot(wl_plot,y_pred_spec_BPM[shot_idx][wl_sel]/ymax,color=colors[1],label = f'M2 (laser+BPM)',lw=0.7)
str_ =  'M3 (laser+BPM+'+r'e$^{-}$'+'-spectrum)'
ax[0].plot(wl_plot,y_pred_spec[shot_idx][wl_sel]/ymax,color=colors[2],label = str_,lw=0.7)
ax[0].plot(wl_plot,y[shot_idx][wl_sel]/ymax,'k--',label = r'$measured$',lw=0.7)
ax[0].legend(frameon = False,ncol=1,handlelength=1.2,labelspacing=0.5,loc='upper left',fontsize=8,bbox_to_anchor=(0.115, 1.05))

ax[0].set_xlim(wl_plot[0],20)

for panel_ax in ax:
    panel_ax.tick_params('both',direction='in')

for k in range(3):
    ax[k].tick_params('y',direction='in',left=True,right=False,labelleft=True)

# The wavelength ticks belong to panel (a) only. Panels (b)-(d) have the shot
# number on x, so (b) and (c) keep the automatic ticks that line up with (d)
# and merely hide their labels.
ax[0].set_xticks([3,10,17])
ax[0].tick_params('x',direction='in',bottom=True,top=False,labelbottom=True)
ax[0].set_xlabel('Wavelength (nm)')
for k in (1, 2):
    ax[k].tick_params('x',direction='in',bottom=True,top=False,labelbottom=False)
ax[3].tick_params('x',direction='in',bottom=True,top=False,labelbottom=True)
ax[3].set_xlabel('Shots')

ax[0].set_ylabel('Intensity (a.u)')
ax[2].set_ylabel('Wavelength (nm)')
t = ['(a)','(b)','(c)','(d)','(e)']

for k in range(4):
    ax[k].text(0.96, 0.86,t[k],weight='normal',
               horizontalalignment='center',
               verticalalignment='center',
               transform = ax[k].transAxes)

# titles[0] is intentionally not drawn: panel (a) carries a legend instead.
titles = ['Prediction for models with different input features','Measured spectrum','Predicted spectrum (M3)','Absolute error (M3)']
for k in range(1,4):
    ax[k].text(0.5, 0.88,titles[k],weight='normal',
               horizontalalignment='center',
               verticalalignment='center',
               transform = ax[k].transAxes)

fig.savefig('./xray_prediction.png',bbox_inches='tight')

In [None]:
from matplotlib.gridspec import GridSpec

# Alternative layout: three maps on top (measured, predicted, absolute error),
# then the single-shot comparison and the quantile bands of the residuals.
# NOTE(review): `q`, `high`, `low`, `high_BPM`, `low_BPM`, `high_LP` and
# `low_LP` are not defined in any cell visible in this notebook; they must be
# computed (presumably with the imported `quantile` helper) before this cell
# can run on a fresh kernel.
# NOTE(review): this cell saves to the same './xray_prediction.png' as the
# previous figure and would overwrite it -- confirm which figure is wanted.
fig_w = figsize['inch']['column_width']
fig = plt.figure(figsize=(fig_w,fig_w*1.),constrained_layout=False)
ymax = 11706 #23337.211733071283   (divisor applied to every plotted intensity)
colors = [mymap(i) for i in np.linspace(0.2,0.6,3)]

# Top row: three side-by-side maps.
gs0 = fig.add_gridspec(nrows=1, ncols=3, bottom=0.6,
                       wspace=0.1,hspace=0)
ax = [fig.add_subplot(gs0[k]) for k in range(3)]

# Bottom block: two stacked line panels.
gs1 = fig.add_gridspec(nrows=2, ncols=1,
                       wspace=0.,hspace=0.1, top=0.55)
ax.extend([fig.add_subplot(gs1[k]) for k in range(2)])

# Wavelength window shown in every panel (previously repeated inline).
wl_sel = slice(2**10//2+12, -300)
wl_plot = wl[wl_sel]

shot = np.arange(0,50,1)

# (axes, data, upper colour limit before division by ymax) for the maps.
panels = [
    (ax[0], y[:50, wl_sel], 20000),
    (ax[1], y_pred_spec[:50, wl_sel], 20000),
    (ax[2], np.abs(y[:50, wl_sel] - y_pred_spec[:50, wl_sel]), 2000),
]
cax = []
for panel_ax, panel_data, vmax_raw in panels:
    im = panel_ax.pcolormesh(wl_plot, shot, panel_data/ymax,
                             vmin=0, vmax=vmax_raw/ymax, cmap=mymap)
    # Pass `ax` explicitly: without it, which axes the colorbar steals space
    # from depends on the matplotlib version (current axes vs. mappable axes).
    cax.append(plt.colorbar(im, ax=panel_ax, orientation='horizontal',
                            location='top', pad=0))
cax0, cax1, cax2 = cax

for c in cax:
    c.ax.tick_params('x',direction='in')
cax[1].ax.set_xlabel('Relative intensity (a.u)')
cax[1].ax.xaxis.set_label_position('top')

for k in range(2):
    cax[k].ax.set_xticks([0.,1.,])
cax[2].ax.set_xticks([0.,0.1,])

# Panel (d): one example shot, all three models against the measurement.
shot_idx = 40
ax[3].plot(wl_plot,y_pred_spec_LP[shot_idx][wl_sel]/ymax,color=colors[0],label = f'(1)',lw=0.7)
ax[3].plot(wl_plot,y_pred_spec_BPM[shot_idx][wl_sel]/ymax,color=colors[1],label = f'(2)',lw=0.7)
ax[3].plot(wl_plot,y_pred_spec[shot_idx][wl_sel]/ymax,color=colors[2],label = f'(3)',lw=0.7)
ax[3].plot(wl_plot,y[shot_idx][wl_sel]/ymax,'k--',label = r'$y_{true}$',lw=0.7)
ax[3].legend(frameon = False,ncol=4,handlelength=1.2,labelspacing=0.5,loc='upper center')

# Panel (e): central quantile interval per model; its width in percent is
# derived from the quantile level q.
prec = round((q*2-1)*100)
ax[4].fill_between(wl_plot,high_LP[wl_sel]/ymax,low_LP[wl_sel]/ymax,color=colors[0],label = f'{prec}% QI')
ax[4].fill_between(wl_plot,high_BPM[wl_sel]/ymax,low_BPM[wl_sel]/ymax,color=colors[1],label = f'{prec}% QI')
ax[4].fill_between(wl_plot,high[wl_sel]/ymax,low[wl_sel]/ymax,color=colors[2],label = f'{prec}% QI')

for panel_ax in ax:
    panel_ax.set_xlim(wl_plot[0],20)
    panel_ax.tick_params('both',direction='in')

ax[0].tick_params('y',direction='in',left=True,right=True,labelleft=True)
ax[1].tick_params('y',direction='in',left=True,right=True,labelleft=False)
ax[2].tick_params('y',direction='in',left=True,right=True,labelleft=False)
for k in range(3):
    ax[k].set_xticks([3,10,17])

ax[3].tick_params('x',direction='in',bottom=True,top=False,labelbottom=False)
ax[4].tick_params('x',direction='in',bottom=True,top=True,labelbottom=True)

ax[0].set_ylabel('Shots')
ax[3].set_ylabel('Relative \n intensity (a.u)')
# Use the computed interval width so the label cannot drift from the data.
ax[4].set_ylabel(f'{prec}% QI of relative \n residuals (a.u)')
ax[4].set_xlabel('Wavelength (nm)')

t = ['(a)','(b)','(c)','(d)','(e)']

for k in range(3):
    ax[k].text(0.88, 0.9,t[k],weight='normal',
               horizontalalignment='center',
               verticalalignment='center',
               transform = ax[k].transAxes)
for k in range(3,5):
    ax[k].text(0.96, 0.88,t[k],weight='normal',
               horizontalalignment='center',
               verticalalignment='center',
               transform = ax[k].transAxes)

fig.tight_layout()
fig.savefig('./xray_prediction.png',bbox_inches='tight')