### imports

In [None]:
import pjlsa_gsipro as pjlsa

import random
import time
import json
from datetime import datetime

import psycopg2
import numpy as np
from scipy.interpolate import LinearNDInterpolator
import matplotlib as mpl
import matplotlib.pyplot as plt

%matplotlib inline  

In [None]:
from pyda import SimpleClient
from pyda.data import DataFilter, TimingSelector
from pyda_rda3 import RdaProvider

In [None]:
import jpype
with pjlsa.LSAClientGSI().java_api():
    from cern.lsa.client import ServiceLocator, ContextService, ParameterService, TrimService, SettingService
    from cern.accsoft.commons.value import ValueFactory
    from cern.lsa.domain.settings import TrimRequest, ContextSettingsRequest, Settings, SettingPartEnum
    from cern.lsa.domain.settings.type import  BeamProcessPurposes
    from de.gsi.lsa.domain.settings import GsiBeamProcessPurpose
    from java.util import Set, Collections

In [None]:
# LSA service handles used throughout the notebook.
ts = ServiceLocator.getService(TrimService)  # sends trim requests (see set_tunes)
cs = ServiceLocator.getService(ContextService)  # looks up patterns by name
ps = ServiceLocator.getService(ParameterService)  # looks up parameters by name
ss = ServiceLocator.getService(SettingService)  # reads context settings

In [None]:
# Escape sequences that switch the terminal foreground colour; the first seven
# are the plain codes 31-37, the remaining five the ";1" variants of 31-35.
TERMINAL_COLORS = [
    "\u001b[31m",
    "\u001b[32m",
    "\u001b[33m",
    "\u001b[34m",
    "\u001b[35m",
    "\u001b[36m",
    "\u001b[37m",
    "\u001b[31;1m",
    "\u001b[32;1m",
    "\u001b[33;1m",
    "\u001b[34;1m",
    "\u001b[35;1m",
]

# Escape sequence appended after coloured text to restore the default style.
TERMINAL_COLOR_RESET = "\u001b[0m"


## fetch pattern and related beam processes

In [None]:
# Select the LSA pattern to work on. Exactly one findPattern() call is active;
# the other candidate pattern names are kept commented out for quick switching.
# pattern = cs.findPattern("SCRATCH_RM_SIS18_PYTHON_TEST_20220728_124253")

# pattern = cs.findPattern("SIS18_FAST_HHD_12C6_HOCHSTROM_4H1")
# pattern = cs.findPattern("SIS18_SLOW_HHD_Langsam")
# pattern = cs.findPattern("SIS18_FAST_HHD_20231109_EXP")
# pattern = cs.findPattern("SIS18_SLOW_HADES_20231106_171354")
pattern = cs.findPattern("SIS18_FAST_HHD_20231111_221700")
# pattern = cs.findPattern("SIS18_FAST_HHD_20231112_042341")


In [None]:
# Copy the pattern's beam processes into a Python list so they can be indexed.
beamProcesses = list(pattern.getBeamProcesses())

# Print index and name of every beam process; the index is what
# beamProcess_to_trim_id refers to when picking the process to trim.
for i, p in enumerate(beamProcesses):
    print(i, p.getName())

### beam process for tunescan

In [None]:
# Index into beamProcesses of the beam process whose tunes are trimmed.
beamProcess_to_trim_id = 13
beamProcess_to_trim = beamProcesses[beamProcess_to_trim_id]
# NOTE(review): the "ns" unit below is the author's annotation; it is not
# verifiable from this notebook — confirm against the LSA API.
process_length = beamProcess_to_trim.getLength()  # ns
print("trim process {} has length {} ms".format(beamProcess_to_trim.getName(), process_length / 1e6))

# Parameter objects for "SIS18BEAM/QH" and "SIS18BEAM/QV", in that order; the
# order must match the rows of the tune points passed to set_tunes.
parameterList = [
    ps.findParameterByName(p) for p in ["SIS18BEAM/QH","SIS18BEAM/QV",]
]

###### fetch process length & tune settings

In [None]:
# Fetch the settings of the tune parameters for the selected beam process.
tuneSettings = ss.findContextSettings(
        ContextSettingsRequest.byStandAloneContextAndParameters(beamProcess_to_trim, Set.of(parameterList))
    )

# One 2-row array per parameter (same order as parameterList):
# row 0 holds the function's x values, row 1 its y values.
tuneSettings_values = list()
for p in parameterList:
    df = Settings.getFunction(tuneSettings, p) 
    val = np.array((df.toXArray(), df.toYArray(),))
    tuneSettings_values.append(val)

In [None]:
# Display the length of the beam process selected for trimming.
process_length

# Tunescan

#### helper functions

In [None]:
def get_tunePoints(qx_start, qx_end, qy_start, qy_end, tuneSettings_values=tuneSettings_values):
    """Build the four tune values per plane for one sweep.

    Each returned row is ``[stored[1, 0], scan start, scan end, stored[1, 1]]``
    where ``stored`` is the entry of ``tuneSettings_values`` for that plane
    (index 0 for the qx row, index 1 for the qy row).

    NOTE(review): ``stored[1, 1]`` is the *second* value of row 1, which equals
    the last value only if the stored function has exactly two support
    points - confirm that this is always the case.

    Args:
        qx_start, qx_end: tune values placed in the middle of the first row.
        qy_start, qy_end: tune values placed in the middle of the second row.
        tuneSettings_values: per-plane 2-D arrays supplying the outer values.

    Returns:
        A list with two four-element lists (first plane, second plane).
    """
    def _row(plane, scan_start, scan_end):
        stored = tuneSettings_values[plane]
        return [stored[1, 0], scan_start, scan_end, stored[1, 1]]

    return [
        _row(0, qx_start, qx_end),
        _row(1, qy_start, qy_end),
    ]

In [None]:
def set_tunes(timePoints, tunePoints, pattern=pattern, parameterList=parameterList):
    """Trim the TARGET tune functions of ``beamProcess_to_trim``.

    For every parameter in ``parameterList`` a function is created from
    ``timePoints`` (x values) and the matching row of ``tunePoints``
    (y values), written into the setting of ``beamProcess_to_trim`` and
    submitted in a single trim request.

    Args:
        timePoints: x values shared by all trimmed functions.
        tunePoints: one sequence of y values per parameter, in the order of
            ``parameterList``.
        pattern: context whose settings are loaded for the trim.
        parameterList: parameters to trim.

    Raises:
        ValueError: if ``tunePoints`` does not hold exactly one row per
            parameter, or a row's length differs from ``timePoints``.
    """
    # zip() below would silently drop surplus entries and trim only part of
    # the parameters, so reject inconsistent input before building anything.
    if len(tunePoints) != len(parameterList):
        raise ValueError(
            "expected {} tune point rows (one per parameter), got {}".format(
                len(parameterList), len(tunePoints)
            )
        )
    for parValues in tunePoints:
        if len(parValues) != len(timePoints):
            raise ValueError(
                "tune point row has {} values but there are {} time points".format(
                    len(parValues), len(timePoints)
                )
            )

    trimRequestBuilder = TrimRequest.builder()
    trimRequestBuilder.setSettingPart(SettingPartEnum.TARGET)
    trimRequestBuilder.setDescription("Tunescan, CC")

    patternSettings = ss.findContextSettings(
            ContextSettingsRequest.byStandAloneContextAndParameters(pattern, Set.of(parameterList))
    )

    for parameter, parValues in zip(parameterList, tunePoints):
        parameterSettings = patternSettings.getParameterSettings(parameter)

        bpSetting = parameterSettings.getSetting(beamProcess_to_trim)
        # Java double[] arrays are required by the function factory.
        discreteFunction = ValueFactory.createFunction(jpype.JArray(float)(timePoints), jpype.JArray(float)(parValues))
        bpSetting.updateValue(discreteFunction, SettingPartEnum.TARGET)
        trimRequestBuilder.addSetting(bpSetting)

    trimRequest = trimRequestBuilder.build()
    ts.trimSettings(trimRequest)

In [None]:
def get_tunePoints_tuneScan(qx_min, qx_max, qy_min, qy_max, direction: str,
                      numSweeps: int = 25, tuneSettings_values=tuneSettings_values):
    """Build the tune points of every sweep of a 2-D tune scan.

    The scanned plane is swept between its min and max within one sweep,
    while the other plane is held constant per sweep and stepped over
    ``numSweeps`` equidistant values from its min to its max.

    Args:
        qx_min, qx_max: horizontal tune range.
        qy_min, qy_max: vertical tune range.
        direction: ``"horizontallyUpwards"`` / ``"horizontallyDownwards"``
            sweep qx (min->max / max->min) at fixed qy;
            ``"verticallyUpwards"`` / ``"verticallyDownwards"`` sweep qy
            (min->max / max->min) at fixed qx.
        numSweeps: number of sweeps, i.e. steps of the fixed plane.
        tuneSettings_values: forwarded to ``get_tunePoints``.

    Returns:
        A list with one ``get_tunePoints`` result per sweep.

    Raises:
        ValueError: if ``direction`` is not one of the four names above.
    """
    if direction == "horizontallyUpwards":
        sweeps = [(qx_min, qx_max, qy, qy) for qy in np.linspace(qy_min, qy_max, numSweeps)]
    elif direction == "horizontallyDownwards":
        sweeps = [(qx_max, qx_min, qy, qy) for qy in np.linspace(qy_min, qy_max, numSweeps)]
    elif direction == "verticallyUpwards":
        sweeps = [(qx, qx, qy_min, qy_max) for qx in np.linspace(qx_min, qx_max, numSweeps)]
    elif direction == "verticallyDownwards":
        sweeps = [(qx, qx, qy_max, qy_min) for qx in np.linspace(qx_min, qx_max, numSweeps)]
    else:
        raise ValueError("invalid scan direction")

    # Pass the caller's tuneSettings_values on; previously the argument was
    # accepted but ignored, so get_tunePoints always used its own default.
    return [
        get_tunePoints(qx_start, qx_end, qy_start, qy_end, tuneSettings_values=tuneSettings_values)
        for qx_start, qx_end, qy_start, qy_end in sweeps
    ]

In [None]:
class TuneScan(dict):
    """Dictionary-backed record of one tune scan.

    Being a ``dict``, the object can be passed to ``json.dump`` directly.
    Keys set on construction: ``timePoints``, ``qx_min``, ``qx_max``,
    ``qy_min``, ``qy_max``, ``direction``, ``numSweeps`` and
    ``all_tunePoints``. ``run`` adds ``startTime``, ``stopTime`` and
    ``tuneScan_results``.
    """

    def __init__(self, timePoints, qx_min, qx_max, qy_min, qy_max, direction, numSweeps):
        """Store the scan configuration and pre-compute the tune points of all sweeps."""
        super().__init__()
        self["timePoints"] = timePoints
        self["qx_min"] = qx_min
        self["qx_max"] = qx_max
        self["qy_min"] = qy_min
        self["qy_max"] = qy_max
        self["direction"] = direction
        self["numSweeps"] = numSweeps

        self["all_tunePoints"] = get_tunePoints_tuneScan(qx_min, qx_max, qy_min, qy_max, direction, numSweeps)

    def run(self, shotCount):
        """Execute all sweeps and record ``shotCount`` shots for each.

        For every entry of ``self["all_tunePoints"]`` the tunes are trimmed
        via ``set_tunes``, then ``PROPERTY_NAME`` is subscribed to. The first
        notification is discarded and the ``intensity`` arrays of the next
        ``shotCount`` notifications are stored under
        ``self["tuneScan_results"][str(sweep index)]``.

        ``startTime`` / ``stopTime`` are ``time.time()`` converted to integer
        nanoseconds.

        Raises:
            ValueError: if ``shotCount`` is smaller than 1 (the acquisition
                loop would otherwise never terminate).
        """
        # Notification 0 is skipped and the loop stops at index == shotCount,
        # so a value below 1 can never hit the stop condition.
        if shotCount < 1:
            raise ValueError("shotCount must be at least 1")

        self["startTime"] = int(time.time() * 1e9)
        self["stopTime"] = None

        tuneScan_results = dict()
        # Use this instance's own configuration rather than notebook globals,
        # so the scan that runs is the one this object describes.
        for tune_idx, tunePoints in enumerate(self["all_tunePoints"]):
            # send to hardware
            print("perform trim")
            set_tunes(self["timePoints"], tunePoints)

            # listen to DCCT
            # NOTE(review): neither client nor subscription is released
            # explicitly here - check whether pyda requires that.
            client = SimpleClient(provider=RdaProvider())
            subscription = client.subscribe(
                PROPERTY_NAME,
                context=[
                    TimingSelector(FAIR_SELECTOR),
                    DATA_FILTER
                ],
            )

            intensities = list()
            print("subscribing")
            for iteration, response in enumerate(subscription):
                if iteration == 0:
                    # old, potentially fake data
                    continue

                # presumably the stamp is in ns since the epoch - TODO confirm
                acquisition_time = response.value["acquisitionStamp"] / 1e9
                ts_datetime = datetime.utcfromtimestamp(acquisition_time).strftime('%Y-%m-%d %H:%M:%S')

                intensities.append(response.value["intensity"].tolist())

                print(TERMINAL_COLORS[1] + "DCCT" +
                      " received S={}:P={}".format(response.value["sequenceIndex"], response.value["processIndex"]) +
                      "    " + ts_datetime +
                      TERMINAL_COLOR_RESET)

                if iteration == shotCount:
                    # results are only stored once the sweep is complete
                    tuneScan_results[str(tune_idx)] = intensities
                    break

        self["tuneScan_results"] = tuneScan_results
        self["stopTime"] = int(time.time() * 1e9)

In [None]:
# Property subscribed to in TuneScan.run to receive the intensity data.
PROPERTY_NAME = "GS09DT_ML/Acquisition"
# Data filter passed as part of the subscription context in TuneScan.run.
DATA_FILTER = DataFilter(requestPartialData=False,frequencyFilter=np.int32(2))

## perform scan

In [None]:
# Selector string for the subscription's TimingSelector, taken from the
# beam process that is trimmed.
FAIR_SELECTOR = beamProcess_to_trim.getUser()
FAIR_SELECTOR  # display the value

In [None]:
# Display the length of the beam process to trim (same value as process_length).
beamProcess_to_trim.getLength()

In [None]:
# Duration of the actual tune sweep within the beam process.
scanLength = 1.7  # s

# NOTE(review): possible unit mismatch - the scan length is converted to μs
# here, while process_length is annotated as ns where it is fetched, and both
# are combined in the subtraction below. Confirm the unit returned by
# getLength() and the time unit expected by the trimmed function.
tuneScan_length = scanLength * 1e6 # μs
# The process time not used by the sweep is split evenly into a ramp before
# and a ramp after the sweep.
tuneRamp_length = (process_length - tuneScan_length) / 2
# Support times: process start, sweep start, sweep end, process end.
timePoints = [0, tuneRamp_length, process_length - tuneRamp_length, process_length]

# Guard: the support times must be strictly increasing, i.e. the sweep must
# leave room for both ramps.
assert np.all(np.diff(timePoints) > 0), "tune ramp too long"

In [None]:
# Scan qx in [4.01, 4.5] and qy in [3.01, 3.5]: 35 sweeps, each sweeping qy
# from min to max at a fixed qx.
tuneScan = TuneScan(timePoints, 4.01, 4.5, 3.01, 3.5, "verticallyUpwards", 35)

In [None]:
# Execute the scan, recording 2 shots per sweep (trims hardware settings).
tuneScan.run(2)

### dump

###### dump helper tools

In [None]:
# Connection parameters for the PostgreSQL database the scan is dumped to.
# NOTE(review): the password is stored in plain text in this file - consider
# loading it from the environment or a ~/.pgpass file, and rotating this
# credential if the file has been shared.
DBNAME = "bpm_fesa_dump"
HOST = "pgsql.gsi.de"
PORT = "8646"
USER = "bpm_fesa_dump_slave"
PASSWORD = "kuwLMKTcAap6mKTP"

In [None]:
# Parameterised INSERT for one scan record; the eight %s placeholders are
# filled by cursor.execute() from a value list in column order.
insertionStatement = """
INSERT INTO
    bpm_fesa_dump.tunescan (scanCompleted, scanStarted, qx_min, qx_max, qy_min, qy_max, direction, tuneScan)
VALUES
    (%s, %s, %s, %s, %s, %s, %s, %s)
;
"""

#### json file

In [None]:
# Write the complete scan record (configuration and results) to a JSON file.
# The date and time parts of the file name come from one clock reading so
# they cannot disagree when the cell runs across a second/midnight boundary.
now = datetime.now()
fileName = "/home/bphy/ccaliari/lnx/Tunescan/Tunescan_Results/tunescan_{}_{}_{}_{}.json".format(
    int(time.time() * 1e9), now.strftime("%Y-%m-%d"), tuneScan["direction"], now.strftime("%H:%M:%S")
)

with open(fileName, "w") as file:
    json.dump(tuneScan, file)

# Shown as the cell output.
"wrote file to {}".format(fileName)

##### SQL

In [None]:
# Insert the finished scan into the database.
# Keyword arguments avoid hand-building a DSN string, which breaks when a
# value contains spaces or quotes.
try:
    dbcon = psycopg2.connect(dbname=DBNAME, user=USER, host=HOST, port=PORT, password=PASSWORD)
except psycopg2.Error as e:
    print("Unable to connect to database")
    print(e)
    # Re-raise: continuing without a connection would only fail later with a
    # misleading NameError on dbcon.
    raise

try:
    # The cursor context manager closes the cursor when the block exits.
    with dbcon.cursor() as crsr:
        values = [tuneScan["stopTime"], tuneScan["startTime"],
                  tuneScan["qx_min"], tuneScan["qx_max"], tuneScan["qy_min"], tuneScan["qy_max"],
                  tuneScan["direction"], json.dumps(tuneScan)]
        crsr.execute(insertionStatement, values)
    dbcon.commit()
finally:
    # Always release the connection, also when the insert fails.
    dbcon.close()

# plot result

In [None]:
def crop_signal(timePoints, signal):
    """Cut the sweep window out of a signal recorded over the whole process.

    ``signal`` is taken to be sampled equidistantly from ``timePoints[0]`` to
    ``timePoints[3]``. It is resampled by linear interpolation onto an
    equidistant time axis from ``timePoints[1]`` to ``timePoints[2]`` whose
    sample count is the corresponding fraction of ``len(signal)``, rounded.

    Args:
        timePoints: sequence whose first four entries are process start,
            sweep start, sweep end and process end.
        signal: 1-D sequence of samples covering the whole process.

    Returns:
        Tuple ``(time_scan, values)`` of two equally long 1-D arrays.
    """
    process_start, process_end = timePoints[0], timePoints[3]
    scan_start, scan_end = timePoints[1], timePoints[2]
    n_recorded = len(signal)

    # Use absolute times for the recording axis as well; starting it at 0
    # only matched the absolute sweep window when timePoints[0] was 0.
    time_recording = np.linspace(process_start, process_end, n_recorded)

    scan_fraction = (scan_end - scan_start) / (process_end - process_start)
    # np.linspace needs a plain integer sample count
    n_scan = int(round(scan_fraction * n_recorded))
    time_scan = np.linspace(scan_start, scan_end, n_scan)

    return time_scan, np.interp(time_scan, time_recording, signal)

In [None]:
def interpolate_tuneScan_results(tuneScan):
    """Turn the recorded shots of a scan into an interpolator over (qx, qy).

    Per sweep, every shot is cropped to the sweep window with
    ``crop_signal``, divided by its first sample and differentiated with
    ``np.diff``; the shots of a sweep are then averaged. The averaged values
    are assigned to tunes spaced linearly between the sweep's start and end
    tune in both planes.

    Args:
        tuneScan: mapping with the keys ``timePoints``, ``tuneScan_results``
            and ``all_tunePoints``.

    Returns:
        A ``LinearNDInterpolator`` built from all (qx, qy) samples.
    """
    timePoints = tuneScan["timePoints"]
    results_by_sweep = tuneScan["tuneScan_results"]

    coordinates, values = [], []
    for sweep_idx, sweep in enumerate(np.array(tuneScan["all_tunePoints"])):
        per_shot = []
        for raw_signal in results_by_sweep[str(sweep_idx)]:
            cropped = crop_signal(timePoints, raw_signal)[1]
            # normalise to the first sample, then take the sample-to-sample change
            per_shot.append(np.diff(cropped / cropped[0]))

        sweep_mean = np.array(per_shot).mean(axis=0)
        n_samples = len(sweep_mean)

        # columns 1 and 2 of each row hold the sweep's start and end tune
        qx_axis = np.linspace(sweep[0][1], sweep[0][2], n_samples)
        qy_axis = np.linspace(sweep[1][1], sweep[1][2], n_samples)

        coordinates.extend(zip(qx_axis, qy_axis))
        values.extend(sweep_mean)

    return LinearNDInterpolator(coordinates, values)

In [None]:
# Build the (qx, qy) interpolator from the recorded scan.
interp = interpolate_tuneScan_results(tuneScan)

In [None]:
# Regular 20 x 20 grid spanning the scanned tune ranges.
qx_range = np.linspace(tuneScan["qx_min"], tuneScan["qx_max"], 20)
qy_range = np.linspace(tuneScan["qy_min"], tuneScan["qy_max"], 20)
# From here on qx_range / qy_range are the 2-D meshgrid arrays.
qx_range, qy_range = np.meshgrid(qx_range, qy_range)

# Evaluate the interpolator on every grid point.
result_scanInterpolated = interp(qx_range, qy_range)

In [None]:
# Colour map of the absolute interpolated values over the tune grid.
fig, ax = plt.subplots()

# norm = mpl.colors.Normalize(vmin=0, vmax=0.01)
pcm = ax.pcolormesh(qx_range, qy_range, np.abs(result_scanInterpolated),
                   )

cbar = fig.colorbar(pcm)

# make nice
ax.set_xlabel("hor. tune")
ax.set_ylabel("ver. tune")
cbar.set_label(r"$\frac{1}{I} \cdot \frac{\partial I}{\partial t}$")

# debug

In [None]:
# Absolute interpolated values, as used for the colour map.
plottable = np.abs(result_scanInterpolated)

In [None]:
# Histogram of the plotted values (plottable is 2-D, so each column is
# treated as a separate data set).
plt.hist(plottable)

In [None]:
# Index of the sweep to inspect.
idx = 5

tuneScan_results = tuneScan["tuneScan_results"]

# Cropped sweep-window signal of every recorded shot of that sweep.
shots = [crop_signal(timePoints, s)[1] for s in tuneScan_results[str(idx)]]