In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import sys, os
sys.path.append('../')
from data_handling import Container,load_data,Normalize,Denormalize
from plot_style.style_prab import load_preset,figsize,cmap_nicify
from plot_style.style_prab import colors as colors_preset
from train_ensemble import train, NN, Loss
from analysis.scan_model import median,quantile
load_preset(scale=1,font_path='../plot_style/font')
import pickle
mymap = cmap_nicify(cmap='YlGnBu_r',idx_white=1,size_white=50)
from scipy.interpolate import griddata
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.patches as patches
import matplotlib


from mpl_toolkits.axes_grid1.inset_locator import inset_axes

def plot(x,y,low,high,ax,color,label=None,pos='left',rot=90):
    """Plot a curve with a shaded band and colour the y axis to match.

    Parameters
    ----------
    x, y : array-like
        Abscissa and central curve.
    low, high : array-like
        Lower and upper edges of the band filled between them (alpha 0.5).
    ax : matplotlib Axes
        Target axes; its y-axis label and y tick labels are recoloured.
    color : matplotlib colour
        Used for the band, the line and the y axis.
    label : str, optional
        Legend label for the line. This was previously accepted but ignored;
        it is now forwarded to ``ax.plot``. ``None`` keeps matplotlib's
        default, unlabelled behaviour.
    pos, rot :
        Unused; kept so existing callers passing them keep working.

    Returns
    -------
    list
        Whatever ``ax.plot`` returns (the list of created lines).
    """
    ax.fill_between(x,low,high,alpha=0.5,color=color)
    lines = ax.plot(x,y,color=color,alpha=1.,label=label)
    # Make the y axis visually belong to this curve (useful with twin axes).
    ax.yaxis.label.set_color(color)
    ax.tick_params(axis='y', colors=color)
    return lines

def coma(rho,phi,even=True):
    """Coma term sqrt(8) * (3*rho**3 - 2*rho) * {cos, sin}(phi) on a polar grid.

    ``rho`` and ``phi`` are 1-D samples; they are expanded with
    ``np.meshgrid``, so the result has shape (len(phi), len(rho)).
    ``even`` selects the cosine (truthy) or sine (falsy) variant.
    """
    r, theta = np.meshgrid(rho, phi)
    radial = np.sqrt(8)*(3*r**3-2*r)
    angular = np.cos(theta) if even else np.sin(theta)
    return radial*angular
    
def tilt(rho,phi,even=True):
    """Tilt term rho * {cos, sin}(phi) on a polar grid.

    Unlike the other terms in this cell, no normalisation prefactor is
    applied (the Noll-normalised tilt would carry a factor 2).
    The result has shape (len(phi), len(rho)); ``even`` selects cosine
    (truthy) or sine (falsy).
    """
    r, theta = np.meshgrid(rho, phi)
    return r*(np.cos(theta) if even else np.sin(theta))
    
def astigmatism(rho,phi,even=True):
    """Astigmatism term sqrt(6) * rho**2 * {cos, sin}(2*phi) on a polar grid.

    Output shape is (len(phi), len(rho)) via ``np.meshgrid``; ``even``
    selects the cosine (truthy) or sine (falsy) variant.
    """
    r, theta = np.meshgrid(rho, phi)
    radial = np.sqrt(6)*(r**2)
    if not even:
        return radial*np.sin(2*theta)
    return radial*np.cos(2*theta)

    
def primary_spherical(rho,phi,even=None):
    """Primary spherical term sqrt(5) * (6*rho**4 - 6*rho**2 + 1).

    The term is rotationally symmetric: ``phi`` only sets the number of
    rows of the (len(phi), len(rho)) output, and ``even`` is accepted for
    signature compatibility with the other terms but ignored.
    """
    r, _unused_theta = np.meshgrid(rho, phi)
    return np.sqrt(5)*(6*r**4 - 6*r**2+1)


In [None]:
import pickle
from ensemble.Bregressor_ import BaggingRegressor
from ensemble.utils import io

# Directory holding the pickled training config and the saved ensemble weights.
model_path = '../models/' #'../models/rand_pretraining/'
# NOTE(review): pickle.load can execute arbitrary code — only load trusted configs.
with open(f'{model_path}config.pkl', 'rb') as fp:
    config = pickle.load(fp)
# norm[0] is applied to inputs and norm[1] to outputs in the cells below.
norm = config['norm']


if isinstance(config['model'],list):
    # A list of sub-networks is wrapped so the bagging wrapper sees one nn.Module.
    model = nn.ModuleList(config['model'])
    regressor_cls, regressor_io = BaggingRegressor, io
else:
    model = config['model']
    # BaggingRegressor_te / io_te are never imported in this notebook, so on a
    # fresh kernel this branch used to die with a bare NameError. Fail with an
    # actionable message instead (still works if they were imported elsewhere).
    try:
        regressor_cls, regressor_io = BaggingRegressor_te, io_te
    except NameError as err:
        raise NameError(
            "config['model'] is not a list: import BaggingRegressor_te and io_te "
            "before running this cell"
        ) from err

ensemble = regressor_cls(estimator=model,               # estimator is your pytorch model
                         n_estimators=config['estimators'],
                         cuda=False,)
ensemble.set_criterion(config['loss_fun'])
regressor_io.load(ensemble, model_path)


# NOTE(review): `path` is not used — the data loaded below comes from config["path"].
path='../data/dataframe_combined_espec_interpolated_gaia_energy_2022.h5'
trainset, testset, _ = load_data(
config["path"],
config["inputs"],
config["outputs"],
samples=config['samples'],
ratio=config['ratio'],
start_ind=0,
norm=norm,
random=False,
)

trainset = Container(trainset)
testset = Container(testset)

# Evaluation below runs on CPU (the ensemble was built with cuda=False).
testset.x = testset.x.to('cpu')
testset.y = testset.y.to('cpu')

In [None]:
# Names (entries of config['inputs']) of the three inputs scanned pairwise below.
labels = ['4','6','8']

n_grid = 100                    # samples per scanned axis
n_inputs = trainset.x.shape[1]  # model input width (was hard-coded as 14)

i = [config['inputs'].index(l) for l in labels]
# Column-index pairs for the three 2-D scans: (0,1), (0,2), (1,2) of `labels`.
indices = [[i[0],i[1]],[i[0],i[2]],[i[1],i[2]]]

# Scan range in normalized input units; the same grid is reused for every pair.
x_ = torch.linspace(-4,4,n_grid)
y_ = torch.linspace(-4,4,n_grid)
grid_x, grid_y = torch.meshgrid(x_, y_, indexing='ij')

D = []       # denormalized median prediction per pair, shape (n_grid, n_grid, n_outputs)
D_ = []      # quantile spread (high - low) per pair, same shape
x_axes = []  # mean-centred values of pair[0] along the grid columns
y_axes = []  # mean-centred values of pair[1] along the grid rows
for pair in indices:
    # Every input not being scanned is held at 0 in normalized units.
    x = torch.zeros(n_grid**2, n_inputs)
    # indexing='ij': grid_y varies along columns, grid_x along rows.
    x[:,pair[0]] = grid_y.flatten()
    x[:,pair[1]] = grid_x.flatten()

    y_pred = median(x,ensemble).detach().numpy()
    # NOTE(review): assumes quantile() returns (high, low) in that order — confirm.
    y_pred_q_high,y_pred_q_low = quantile(x,ensemble)
    y_pred = Denormalize(y_pred,norm[1])
    y_pred_q_low = Denormalize(y_pred_q_low,norm[1])
    y_pred_q_high = Denormalize(y_pred_q_high,norm[1])

    # -1 lets the output width follow the model (was hard-coded as 5).
    y_pred = y_pred.reshape(n_grid,n_grid,-1)
    y_pred_q_low = y_pred_q_low.reshape(n_grid,n_grid,-1)
    y_pred_q_high = y_pred_q_high.reshape(n_grid,n_grid,-1)

    x = Denormalize(x,norm[0])
    x = x.reshape(n_grid,n_grid,n_inputs)

    D.append(y_pred)
    D_.append(y_pred_q_high-y_pred_q_low)

    x_axes.append((x[0,:,pair[0]]-x[0,:,pair[0]].mean()).detach().numpy())
    y_axes.append((x[:,0,pair[1]]-x[:,0,pair[1]].mean()).detach().numpy())

In [None]:

def make_polar_plot(rho,phi,ax,zp=coma,even=True,anchor=(0,0.15,1,1),label=''):
    """Draw a small polar inset of a polar-grid surface on top of ``ax``.

    Parameters
    ----------
    rho, phi : 1-D arrays
        Radial and azimuthal sample points passed to ``zp``.
    ax : matplotlib Axes
        Parent axes hosting the inset and its text label.
    zp : callable
        ``zp(rho, phi, even=...)`` returning an array of shape
        (len(phi), len(rho)), e.g. ``coma``, ``astigmatism`` or
        ``primary_spherical`` from this notebook.
    even : bool
        Forwarded to ``zp`` (cosine vs. sine variant for those that use it).
    anchor : 4-tuple
        ``bbox_to_anchor`` for the inset, in ``ax`` axes coordinates.
    label : str
        Text drawn to the side of the inset.

    Returns
    -------
    The polar inset axes (previously nothing was returned), so callers can
    restyle it further.
    """
    inset_ax = inset_axes(ax,
                          axes_class = matplotlib.projections.get_projection_class('polar'),
                          width="20%", # 20% of the bbox_to_anchor width
                          height=1., # 1 inch
                          loc='upper right',
                          bbox_to_anchor=anchor,
                          bbox_transform=ax.transAxes)
    # Label x is fixed; y is offset from the inset's anchor, both in ax coordinates.
    ax.text(0.4, anchor[1]+0.52, label, horizontalalignment='left',
            verticalalignment='center', transform=ax.transAxes)

    z = zp(rho,phi,even=even)
    # zp yields (len(phi), len(rho)); pcolormesh(X=phi, Y=rho) wants (len(rho), len(phi)).
    inset_ax.pcolormesh(phi,rho,z.T,edgecolors='face',cmap=mymap)
    inset_ax.set_xticks([])
    inset_ax.set_yticks([])
    return inset_ax



    
# 2x2 figure: ax[0,0], ax[1,0] and ax[1,1] show the three pairwise scans;
# ax[0,1] is blanked and hosts the colorbar and the polar legend insets.
fig_w = figsize['inch']['column_width']
fig, ax = plt.subplots(2, 2, figsize=(fig_w, fig_w*0.8))

# Unit-disc sample grid for the polar legend insets.
rho = np.linspace(0, 1, 100)
phi = np.linspace(0, 2*np.pi, 100)
for zernike, y_anchor, inset_label in [
        (coma, 0.04-0.05, 'Horizontal \ncoma'),
        (astigmatism, 0.35-0.05, 'Vertical \nastigmatism'),
        (primary_spherical, -0.27-0.05, 'Primary \nspherical'),
]:
    make_polar_plot(rho, phi, ax[0, 1], anchor=(-0.6, y_anchor, 1, 1),
                    zp=zernike, even=True, label=inset_label)

# Shared colour range for all three panels (ranges used earlier: [0,4], [240,280]).
vlim = [2, 18]

# Output channel 1 of the median prediction, axes scaled by 1e3 for display.
mesh_kw = dict(cmap=mymap, vmin=vlim[0], vmax=vlim[1])
meshes = [target.pcolormesh(x_axes[k]*1e3, y_axes[k]*1e3, D[k][:, :, 1], **mesh_kw)
          for target, k in ((ax[0, 0], 0), (ax[1, 0], 1), (ax[1, 1], 2))]
im = meshes[0]
ax[0, 1].axis('off')

# Colorbar squeezed onto the left edge of the blank top-right panel.
divider = make_axes_locatable(ax[0, 1])
cax = divider.append_axes('left', size='7%', pad=0.05)
cb = fig.colorbar(im, cax=cax, orientation='vertical')
cb.ax.tick_params('y', direction='in', length=4)
cb.ax.set_xlabel(r'$\Delta$E (MeV)')
cb.ax.xaxis.set_label_position('top')

# Top-left: keep bottom ticks but hide their labels (same x input as ax[1,0]).
ax[0, 0].tick_params('x', top=False, labeltop=False,
                     bottom=True, labelbottom=False, direction='in')
# Bottom-right: keep left ticks but hide their labels (same y input as ax[1,0]).
ax[1, 1].tick_params('y', right=False, labelright=False,
                     left=True, labelleft=False, direction='in')
for panel in ax.flat:
    panel.tick_params('both', direction='in')

ax[0, 0].set_ylabel(r'Horizontal coma$\,$($\mu$m)')
ax[1, 0].set_ylabel(r'Primary spherical$\,$($\mu$m)')
ax[1, 0].set_xlabel(r'Vertical astigmatism$\,$($\mu$m)')
ax[1, 1].set_xlabel(r'Horizontal coma$\,$($\mu$m)')

plt.savefig('zernike_corr.png', bbox_inches='tight')

In [None]:
import pandas as pd


# Display names for input columns 6 onward of trainset.x (eight inputs).
nice_labels = ['Horizontal tilt', 'Vertical tilt', 'Focus position',
               'Vertical astigmatism', 'Oblique astigmatism',
               'Horizontal coma', 'Vertical coma', 'Primary spherical']

# Pairwise (Pearson, pandas' default) correlation of those inputs in the training set.
ts = trainset.x[:, 6:].cpu().detach().numpy()
df = pd.DataFrame(ts, columns=nice_labels)
corr = df.corr().round(2)
(corr.style
     .background_gradient(cmap='seismic', vmin=-1, vmax=1)
     .format(precision=2))