In [None]:
import os
import numpy as np
import pickle
import shutil 
import pImpactR as impact
import matplotlib.pyplot as plt
import time

# Simulation settings

In [None]:
# Number of macro-particles generated for every objective evaluation (4096).
npt = 1 << 12
print(npt)
# Scale of the relative momentum deviation (delta) of the generated beam:
# objFunc maps its truncated-normal sd of 0.15 onto this value.
Espread = 1e-3
# Aperture applied to every non-dipole element that has a pipe_radius entry
# (units follow the IMPACT input file — presumably metres, confirm).
pipe_radius = 4e-2

# Lattice

In [None]:
# Load the beam definition and element list from the IMPACT input file, then
# set the y and z core counts of the beam to 1.
beam, lattice = impact.readInputFile('IOTA.Chris.in')
beam.nCore_y = beam.nCore_z = 1

In [None]:
# Remove any existing raw-particle writers and elements of type '-8'.
lattice = [elem for elem in lattice if elem.type not in ('write_raw_ptcl', '-8')]
lattice[0].turns = 1   # presumably the turn count: a single pass only

for elem in lattice:
    if 'length' in elem.keys():
        # n_sckick set to 80 per unit length, rounded up
        elem.n_sckick = int(np.ceil(elem.length*80))
    if 'pipe_radius' in elem.keys() and elem.type != 'dipole':
        # dipoles keep the aperture from the input file
        elem.pipe_radius = pipe_radius

# Append one writer at the end of the line so the final particle distribution
# is dumped to file id 999 (read back in objFunc).
elemWrite = impact.getElem('write_raw_ptcl')
elemWrite.file_id = 999
elemWrite.format_id = 2
lattice.append(elemWrite)

In [None]:
# Parameters of lattice element 1 (presumably the nonlinear insert: it carries
# tune_advance and is removed from the arc below).  alfx and betx are used as
# the alpha/beta of the normalisation in Impact2norm / norm2Impact, and NL_c
# as its length scale.
NL_nu = lattice[1].tune_advance
NL_L = lattice[1].length
NL_c = 0.01
half_phase = np.pi*NL_nu          # half of the phase advance 2*pi*nu
alfx = np.tan(half_phase)
betx = NL_L/np.sin(2.0*half_phase)

In [None]:
# `arc` is a shallow copy of the lattice with element 1 removed; the removed
# element is shown as the cell output.
arc = list(lattice)
removed_elem = arc.pop(1)
removed_elem

In [None]:
# Reference-particle quantities taken from / derived from the beam definition.
ke, freq, mass = beam.kinetic_energy, beam.frequency, beam.mass
gam0 = 1.0 + ke/mass               # Lorentz gamma
bet0 = np.sqrt(1.0 - 1.0/gam0**2)  # beta = v/c
bg0 = np.sqrt(gam0**2 - 1.0)       # beta*gamma
q_m = beam.multi_charge.q_m[0]     # first entry of the charge-to-mass list

In [None]:
def Impact2norm(data_in,bg0,bet0,sign=1):
    """Convert IMPACT particle data to normalised coordinates.

    Works on a copy of ``data_in``; the input array is not modified.
    Column 5 is first converted from IMPACT's energy coordinate to delta,
    then columns 0-3 (x, px, y, py) are normalised with the module-level
    ``alfx``, ``betx`` and ``NL_c``; ``sign`` flips the sign of the alpha term.

    bg0, bet0 : reference beta*gamma and beta (the notebook passes the values
                computed from the beam definition).  ``mass`` is read from the
                module scope.

    NOTE(review): the column-5 formula below is not the exact algebraic inverse
    of the one in norm2Impact (its linear term has the opposite sign and the
    result is negated); the two agree only to first order in the energy
    deviation — confirm against IMPACT's sign convention for column 5.
    """
    data=data_in.copy()
    # column 5 must be converted first: the momentum conversions below use (1+delta)
    data[:,5] = -(np.sqrt(1.0-2.0*data[:,5]/mass/(bet0*bg0)+(data[:,5]/mass)**2/bg0**2)-1.0)
    # momenta are computed before columns 0/2 are rescaled because they use the raw positions
    data[:,1] = (data[:,0]*alfx*sign/np.sqrt(betx) + data[:,1]/(1+data[:,5])*np.sqrt(betx))/NL_c
    data[:,3] = (data[:,2]*alfx*sign/np.sqrt(betx) + data[:,3]/(1+data[:,5])*np.sqrt(betx))/NL_c
    data[:,0] = data[:,0]/(np.sqrt(betx)*NL_c)
    data[:,2] = data[:,2]/(np.sqrt(betx)*NL_c)
    return data

def norm2Impact(data_in,bg0,bet0,sign=1):
    """Convert normalised coordinates back to IMPACT particle data.

    Works on a copy of ``data_in``; the input array is not modified.
    Columns 0-3 are de-normalised with the module-level ``alfx``, ``betx`` and
    ``NL_c`` (``sign`` flips the sign of the alpha term), and column 5
    (delta) is converted to ``(gamma - gamma0)*mass``, where
    ``gamma0 = bg0/bet0``.

    bg0, bet0 : reference beta*gamma and beta.  ``mass`` is read from the
                module scope.
    """
    data=data_in.copy()
    # momenta first: they need the still-normalised positions and column 5 still holding delta
    data[:,1] = (-data[:,0]*alfx*sign + data[:,1])*NL_c/np.sqrt(betx)*(1+data[:,5])
    data[:,3] = (-data[:,2]*alfx*sign + data[:,3])*NL_c/np.sqrt(betx)*(1+data[:,5])
    data[:,0] = data[:,0]*np.sqrt(betx)*NL_c
    data[:,2] = data[:,2]*np.sqrt(betx)*NL_c
    # bg0*sqrt(1/bet0**2 + 2*delta + delta**2) equals gamma of a particle with momentum p0*(1+delta)
    data[:,5] = (bg0*np.sqrt(1/bet0**2+2.0*data[:,5]+data[:,5]**2)-bg0/bet0)*mass
    return data

In [None]:
from scipy.stats import truncnorm

def get_truncated_normal(mean=0, sd=1, low=0, upp=10,n=1):
    """Draw ``n`` samples from Normal(mean, sd) truncated to [low, upp].

    scipy's ``truncnorm`` expects its clip points in units of ``sd`` measured
    from ``mean``, hence the rescaling of ``low`` and ``upp``.  Samples come
    from numpy's global random state.
    """
    a_std = (low - mean) / sd
    b_std = (upp - mean) / sd
    return truncnorm.rvs(a_std, b_std, loc=mean, scale=sd, size=n)

In [None]:
# Positions (within `arc`) and design KL_sext values of every thin multipole;
# these elements are the knobs varied by objFunc.
SextIndex = [k for k, elem in enumerate(arc) if elem['type'] == 'multipole_thin']
SextStrength = [arc[k]['KL_sext'] for k in SextIndex]

In [None]:
#%%
def objFunc(arg):
    """Objective for the thin-multipole (sextupole + octupole) optimisation.

    Parameters
    ----------
    arg : sequence of float
        Length ``2*len(SextIndex)``, interleaved per multipole as
        ``[KL_sext_0, KL_oct_0, KL_sext_1, KL_oct_1, ...]``.

    Returns
    -------
    float
        Sum of squared differences between the initial and final normalised
        coordinates (columns 0-3) after tracking through ``arc``, or ``1.0e22``
        when ``impact.readLostAt(-1)`` reports fewer than ``npt`` particles.

    Each call runs IMPACT inside its own randomly named copy of the ``origin``
    directory.  The original working directory is restored and the copy is
    removed even if the simulation raises.
    """
    home = os.getcwd()
    # Pick an unused random directory name.  copytree raises FileExistsError if
    # a concurrent evaluation claimed the same name after the exists() check,
    # so draw a new name in that case instead of failing.
    while True:
        target = impact.opt.id_generator()
        if os.path.exists(target):
            continue
        try:
            shutil.copytree('origin', target)
        except FileExistsError:
            continue
        break

    try:
        os.chdir(target)

        # NOTE(review): arc.copy() is shallow, so these assignments also mutate
        # the element objects shared with `arc` (and `lattice`).  Every call
        # overwrites all of them, so evaluations do not leak into each other,
        # but `arc` holds the last trial strengths afterwards.
        arcTmp = arc.copy()
        for i, j in enumerate(SextIndex):
            arcTmp[j]['KL_sext'] = arg[2*i]
            arcTmp[j]['KL_oct'] = arg[2*i+1]

        # Fresh beam in normalised coordinates: x, px, y, py and delta drawn
        # from a normal (sd 0.15) truncated to +-0.3; delta is then rescaled so
        # that 0.15 corresponds to Espread.
        x = get_truncated_normal(sd=0.15, low=-0.3, upp=0.3, n=npt*5)
        pData = np.zeros([npt, 9])
        pData[:, [0, 1, 2, 3, 5]] = x.reshape([npt, 5])
        pData[:, 5] = pData[:, 5]/0.15*Espread
        pData[:, 6] = q_m
        pData[:, -1] = np.arange(1, npt+1)   # presumably particle ids

        pData2 = norm2Impact(pData, bg0, bet0, -1)
        impact.writeParticleData(pData2, ke, mass, freq)
        impact.writeInputFile(beam, arcTmp)
        impact.run(beam)
        time.sleep(1)   # NOTE(review): presumably lets output files settle; confirm it is needed

        if npt > impact.readLostAt(-1):
            # fewer than npt particles reported: return a large penalty
            return 1.0e22

        # 999 is the file_id of the write_raw_ptcl element appended to the lattice
        pData2 = impact.readParticleData(999, ke, mass, freq, format_id=2)
        pData2 = Impact2norm(pData2, bg0, bet0, 1)
        return np.sum((pData[:, :4] - pData2[:, :4])**2)
    finally:
        # always leave the scratch directory and remove it, even on error
        os.chdir(home)
        shutil.rmtree(target)

In [None]:
# Baseline: design sextupole strengths with octupoles off.  objFunc reads its
# argument interleaved per multipole ([KL_sext_0, KL_oct_0, KL_sext_1, ...]),
# so the design values must be interleaved with zeros, not concatenated.
objFunc([kl for s in SextStrength for kl in (s, 0)])

In [None]:
#%% run optim
# One (KL_sext, KL_oct) pair of bounds per thin multipole, matching the
# interleaved argument layout expected by objFunc.
bounds = [(-2.0, 2.0)] * (2 * len(SextIndex))
# maxtime is in seconds: this call stops after at most 2 hours of wall time.
result = impact.opt.differential_evolution(objFunc, bounds, ncore=32, popsize=128,
                                           disp=True, polish=False, maxtime=60*60*2)
print(result)
# Checkpoint the result (or intermediate state) so a later cell can resume from it.
with open('result.12sext.12oct', 'wb') as fp:
    pickle.dump(result, fp)

In [None]:
# Resume the optimisation from the previous state until a result carrying an
# `x` attribute is returned, checkpointing to disk after every pass.  If the
# first run already produced `x`, no extra 2-hour pass is started.
# NOTE(review): presumably a result without `x` means the maxtime budget ran out
# before completion — confirm against impact.opt.differential_evolution.
while not hasattr(result, 'x'):
    previous_result = result
    result = impact.opt.differential_evolution(objFunc, bounds, ncore=32,
                                               prev_result=previous_result,
                                               disp=True, polish=False, maxtime=60*60*2)
    with open('result.12sext.12oct', 'wb') as fp:
        pickle.dump(result, fp)

In [None]:
# Final optimisation result (same text that print(result) would show).
print(f"{result}")