In [1]:
import pynverse
import numpy as np
import sys
import random

sys.path.append("..")

from rascal import models
from rascal import util

import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm

import numpy as np

from rascal.calibrator import Calibrator
from rascal.util import refine_peaks


  from tqdm.autonotebook import tqdm


In [2]:
class SyntheticSpectrum:
    def __init__(self, coefficients, model_type='cubic', degree=None):
        """
        Creates a synthetic spectrum generator which, given a suitable model,
        outputs the expected pixel locations of input wavelengths.

        It is expected that this will be used mainly for model testing:
        the model maps pixel -> wavelength, and `get_pixels` inverts it to
        produce ground-truth peak positions for known wavelengths.

        Parameters
        ----------
        coefficients:
          list, coefficients for the model, passed unchanged to the
          corresponding `rascal.models` constructor
        model_type:
          str, model type: 'linear', 'quadratic', 'cubic' or 'poly'
        degree:
          int, degree of the polynomial when model_type is 'poly'
          (not used by the other model types), default None

        Raises
        ------
        ValueError
          if model_type is 'poly' and degree is None
        NotImplementedError
          if model_type is not one of the supported types
        """
        self.model = None

        # Default wavelength filter used by `get_pixels`, in Angstrom
        # (approx. range of Silicon).
        self.min_wavelength = 2000
        self.max_wavelength = 12000

        self.coefficients = coefficients

        # Model to fit
        if model_type == 'linear':
            # A linear model is a degree-1 polynomial, built the same way
            # the rest of this notebook builds linear models.
            self.model = models.polynomial(self.coefficients, 1)
        elif model_type == 'quadratic':
            self.model = models.quadratic(self.coefficients)
        elif model_type == 'cubic':
            self.model = models.cubic(self.coefficients)
        elif model_type == 'poly':

            if degree is None:
                raise ValueError("You should specify a polynomial degree.")

            self.model = models.polynomial(self.coefficients, degree)
        else:
            raise NotImplementedError(
                "Unsupported model_type: {!r}".format(model_type))

    def set_wavelength_limit(self, min_w, max_w):
        """
        Set a wavelength filter for the `get_pixels` function.

        Only wavelengths strictly between `min_w` and `max_w` are kept.
        """
        self.min_wavelength = min_w
        self.max_wavelength = max_w

    def get_pixels(self, wavelengths):
        """
        Returns the pixel locations for the wavelengths provided.

        Wavelengths outside the open interval
        (min_wavelength, max_wavelength) are dropped first, so the result
        may be shorter than the input. Pixels are obtained by numerically
        inverting the model with `pynverse.inversefunc`.

        Raises
        ------
        ValueError
          if no model has been initialised
        """

        if self.model is None:
            raise ValueError("Model not initiated")

        wavelengths = np.array(wavelengths)
        wavelengths = wavelengths[wavelengths > self.min_wavelength]
        wavelengths = wavelengths[wavelengths < self.max_wavelength]

        return pynverse.inversefunc(self.model, wavelengths)

class RandomSyntheticSpectrum(SyntheticSpectrum):
    """
    SyntheticSpectrum with randomly jittered endpoints and (for degree >= 2)
    random higher-order coefficients, for generating calibration test cases
    with a known ground truth.
    """

    def __init__(self, num_pixels=1024, min_wavelength=4000, max_wavelength=8000, endpoint_jitter=100, model_type='poly', degree=1, seed=None):
        """
        Parameters
        ----------
        num_pixels:
          int, number of detector pixels; stored as `self.num_pixels` and
          used for the linear dispersion term
        min_wavelength:
          float, nominal wavelength (Angstrom) at pixel 0, before jitter
        max_wavelength:
          float, nominal wavelength (Angstrom) at pixel `num_pixels` for a
          linear model, before jitter
        endpoint_jitter:
          float, each endpoint is shifted by a uniform random amount in
          [-endpoint_jitter, endpoint_jitter)
        model_type:
          str, passed to SyntheticSpectrum, default 'poly'
        degree:
          int, polynomial degree. degree >= 2 adds a random quadratic
          coefficient in [0, 0.002); degree >= 3 adds a random cubic
          coefficient in [-0.001, 0.001)
        seed:
          optional seed for a private random generator used here and in
          `add_atlas`; when None the global `random` module is used
        """
        # Only use a private generator when a seed is requested, so the
        # default still honours any global `random.seed(...)` call.
        self._rng = random if seed is None else random.Random(seed)

        # This is the maximum and minimum wavelength of the sensor
        min_wavelength = min_wavelength + endpoint_jitter*(-1 + 2*self._rng.random())
        max_wavelength = max_wavelength + endpoint_jitter*(-1 + 2*self._rng.random())
        self.num_pixels = num_pixels

        x0 = min_wavelength
        x1 = (max_wavelength - x0)/num_pixels

        # Built lowest order first: constant term, then dispersion, ...
        coefficients = [x0, x1]

        if degree >= 2:
            coefficients.append(0.002*self._rng.random())

        if degree >= 3:
            coefficients.append(-0.001+0.002*self._rng.random())

        # NOTE(review): degree > 3 adds no further terms, so fewer than
        # degree + 1 coefficients reach the model — confirm intended.

        # ... then reversed to highest order first (presumably the order
        # `rascal.models` expects).
        coefficients.reverse()

        super().__init__(coefficients, model_type, degree)
        self.set_wavelength_limit(x0, max_wavelength)

    def add_atlas(self, elements, n_lines=30, min_intensity=1000, min_distance=10):
        """
        Randomly pick up to `n_lines` calibration lines of `elements` inside
        this spectrum's wavelength limits and store them as ground truth.

        Sets `self.elements`, `self.wavelengths` and `self.intensities` as
        parallel sequences in random order. Assumes the three entries
        returned by `util.load_calibration_lines` support list indexing
        (e.g. numpy arrays) — TODO confirm.
        """
        lines = util.load_calibration_lines(elements,
                           min_atlas_wavelength=self.min_wavelength,
                           max_atlas_wavelength=self.max_wavelength,
                           min_intensity=min_intensity,
                           min_distance=min_distance,
                           vacuum=False)

        # Fall back to the global generator if __init__ was bypassed.
        rng = getattr(self, '_rng', random)
        n_available = len(lines[0])
        idx = rng.sample(range(n_available), min(n_available, n_lines))
        self.elements = lines[0][idx]
        self.wavelengths = lines[1][idx]
        self.intensities = lines[2][idx]

In [41]:
def test_fit(spectrum, uncertainty=0, range_tolerance=500, pixels=None, elements=None):
    """
    Build a Calibrator for a synthetic spectrum and run its Hough transform.

    No fit is performed here; call `fit` on the returned calibrator.

    Parameters
    ----------
    spectrum:
      synthetic spectrum providing `num_pixels`, `min_wavelength`,
      `max_wavelength` and, when `pixels` is None, `wavelengths` and
      `get_pixels`
    uncertainty:
      float, margin (same units as the spectrum wavelengths) by which the
      atlas wavelength range is widened on both sides
    range_tolerance:
      passed to `set_hough_properties`
    pixels:
      optional peak pixel positions; defaults to the ground-truth pixel
      positions of the spectrum's atlas lines
    elements:
      optional list of element symbols for the atlas, default ['Hg', 'Ar']

    Returns
    -------
    Calibrator
      with calibrator, Hough, atlas and RANSAC properties set and the
      Hough transform computed
    """
    if pixels is None:
        # Derive peaks from the spectrum itself rather than a notebook
        # global, so the function is self-contained.
        pixels = spectrum.get_pixels(spectrum.wavelengths)
    if elements is None:
        elements = ['Hg', 'Ar']

    c = Calibrator(pixels)

    c.set_calibrator_properties(num_pix=spectrum.num_pixels, plotting_library='matplotlib',
                                log_level='info')

    c.set_hough_properties(num_slopes=5000,
                           xbins=100,
                           ybins=100,
                           min_wavelength=spectrum.min_wavelength,
                           max_wavelength=spectrum.max_wavelength,
                           range_tolerance=range_tolerance,
                           linearity_tolerance=200)

    c.add_atlas(elements, min_atlas_wavelength=spectrum.min_wavelength - uncertainty,
                          max_atlas_wavelength=spectrum.max_wavelength + uncertainty)

    c.set_ransac_properties(sample_size=5,
                            top_n_candidate=5,
                            linear=True,
                            filter_close=True,
                            ransac_thresh=5,
                            candidate_weighted=True,
                            hough_weight=1.0)

    c.do_hough_transform()

    return c

In [57]:
# Spectrum setup
seed = 0  # fixes endpoint jitter and atlas line sampling for reproducible runs
min_wavelength = 4000
max_wavelength = 8000
endpoint_jitter = 0
elements = ['Hg', 'Ar']
poly_degree = 1
num_pixels = 1024

# Generate spectrum + ground truth
random.seed(seed)
spectrum = RandomSyntheticSpectrum(num_pixels,
                                   min_wavelength,
                                   max_wavelength,
                                   endpoint_jitter,
                                   'poly',
                                   poly_degree)
spectrum.add_atlas(elements)

wavs = spectrum.wavelengths
pixs = spectrum.get_pixels(spectrum.wavelengths)

# Materialise as a list: a bare zip() is a one-shot iterator.
ground_truth = list(zip(wavs, pixs))

calibrator = test_fit(spectrum, uncertainty=0, range_tolerance=400)
polyfit_coeff, rms, residual, peak_utilisation = calibrator.fit(polydeg=poly_degree, max_tries=500)

print(polyfit_coeff, residual, spectrum.coefficients)

# Check fit: fitted wavelength minus true wavelength at each ground-truth
# pixel. The fit coefficients appear to be lowest order first, so they are
# reversed to match the order passed to models.polynomial elsewhere.
fit_model = models.polynomial(polyfit_coeff[::-1], poly_degree)
np.array([fit_model(px) for px in pixs]) - wavs

INFO:rascal.calibrator:num_pix is set to 1024.
INFO:rascal.calibrator:pixel_list is set to None.
INFO:rascal.calibrator:Plotting with matplotlib.


HBox(children=(FloatProgress(value=0.0, max=500.0), HTML(value='')))


[4.00127609e+03 3.85675385e+00] [4.05816536e-09 4.05816536e-09 4.05816536e-09 4.09727363e-09
 4.09727363e-09 4.09727363e-09 4.29372449e-09 4.29372449e-09
 4.29372449e-09 4.29372449e-09 6.56382326e-09 8.87393981e-09
 9.82254278e-09 1.02609192e-08 1.02609192e-08] [3.90625, 4000.0]


array([-37.58760718, -21.14590238, -36.29787985, -40.19420428,
       -17.23241517, -38.59903622, -41.6012251 , -37.77365082,
         0.68633513,  -3.26400816, -43.1203052 , -44.78319017,
       -35.56340032, -48.75002804,   0.2900968 , -43.25693932])

In [44]:
# Pair each candidate peak with its candidate arc wavelength: column 0 is
# the pixel, column 1 the wavelength.
candidates = np.array(list(zip(calibrator.candidate_peak, calibrator.candidate_arc)))

# For every ground-truth (pixel, wavelength), report how far each candidate
# assigned to that wavelength lies from the true pixel.
for true_peak, true_arc in zip(pixs, wavs):
    match_rows = np.argwhere(candidates[:, 1] == true_arc)

    print("-- {} --".format(true_arc))

    if not len(match_rows):
        print("Could not find {} in candidates".format(true_arc))
        continue

    for matched in candidates[match_rows]:
        print(matched[:, 0] - true_peak)

    

-- 7067.13427734375 --
Could not find 7067.13427734375 in candidates
-- 4358.30615234375 --
[0.]
-- 4046.5439453125 --
[0.]
[0.]
[8.0054375]
[79.811125]
-- 7503.7763671875 --
[-30.690875]
[2.7605]
[33.59625]
[33.59625]
-- 7272.84765625 --
Could not find 7272.84765625 in candidates
-- 5460.69677734375 --
Could not find 5460.69677734375 in candidates
-- 7948.07568359375 --
Could not find 7948.07568359375 in candidates
-- 6965.3486328125 --
[-14.839125]
[46.491625]
[78.71975]
-- 4077.815185546875 --
[0.]
[0.]
[0.]
[0.]
-- 7081.81689453125 --
Could not find 7081.81689453125 in candidates
-- 5769.55029296875 --
Could not find 5769.55029296875 in candidates
-- 6907.38330078125 --
Could not find 6907.38330078125 in candidates
-- 7383.89013671875 --
Could not find 7383.89013671875 in candidates
-- 7146.95654296875 --
[0.]
[32.228125]
[94.106375]
-- 7514.5595703125 --
[-33.451375]
[0.]
-- 7635.01171875 --
[-33.59625]
[0.]


array([[  11.91525   , 4044.39770508],
       [  11.91525   , 4046.54394531],
       [  11.91525   , 4046.54394531],
       [  11.91525   , 4046.54394531],
       [  11.91525   , 4044.39770508],
       [  19.9206875 , 4044.39770508],
       [  19.9206875 , 4046.54394531],
       [  19.9206875 , 4077.81518555],
       [  19.9206875 , 4077.81518555],
       [  19.9206875 , 4077.81518555],
       [  91.726375  , 4300.07421875],
       [  91.726375  , 4358.30615234],
       [  91.726375  , 4044.39770508],
       [  91.726375  , 4046.54394531],
       [  91.726375  , 4702.27929688],
       [ 373.938375  , 5365.44677734],
       [ 373.938375  , 5120.59082031],
       [ 373.938375  , 5650.64648438],
       [ 373.938375  , 5572.48486328],
       [ 373.938375  , 5353.98291016],
       [ 453.004875  , 5572.48486328],
       [ 453.004875  , 5495.81982422],
       [ 453.004875  , 5890.13574219],
       [ 453.004875  , 5803.72167969],
       [ 453.004875  , 5890.13574219],
       [ 744.290125  , 71