In [109]:
class StaticModel:
    """A class to act as parent class for submodels
    
    It should not be initialized; its methods are static."""
    def __new__(cls, *args, **kwargs):
        raise TypeError('Intended for static use only')

    @classmethod
    def eval(x, params):
        raise NotImplementedError

    @staticmethod
    def guess_params(chunk):
        raise NotImplementedError
    
    @staticmethod
    def name():
        return __class__.__name__ #'MultiGaussian'



In [110]:
import numpy as np

class LSFModel:
    """Abstract base class for the LSF models
    """
    @staticmethod
    def generate_x(osample_factor, conv_width=10.):
        """Generate a pixel vector to sample the LSF over
        
        Args:
            osample_factor (float): The oversample factor of the model.
            conv_width (Optional[float]): The number of pixels on either side.
                Defaults to 10.
        
        Return:
            ndarray[(2*int(conv_width)*int(osample_factor)+1)]: The evaluated
                pixel vector.
        """
        #return np.linspace(-4.0, 4.0, 8 * osample_factor + 1)
        return np.linspace(-conv_width, conv_width, 
                           int(2*conv_width) * int(osample_factor) + 1)

class MultiGaussian(LSFModel, StaticModel):
    """The LSF model of a Multi Gaussian
    
    10 free parameters: Amplitudes of the Satellite Gaussians (5 left, 5 right).
    """
    param_names = [
        'left_5', 'left_4', 'left_3', 'left_2', 'left_1',
        'right1', 'right2', 'right3', 'right4', 'right5',
    ]
    
    param_guess = np.array([
        0.1, 0.2, 0.3, 0.5, 0.7,
        0.7, 0.5, 0.3, 0.2, 0.1
    ])
    
    positions = np.array([
        -2.4, -2.1, -1.6, -1.1, -0.6,
        0.0,
        0.6, 1.1, 1.6, 2.1, 2.4
    ])
    
    sigmas = np.array([
        0.3, 0.3, 0.3, 0.3, 0.3,
        0.4,
        0.3, 0.3, 0.3, 0.3, 0.3
    ])
    
    @classmethod
    def change_positions(cls, positions):
        if isinstance(positions, (list,tuple,np.ndarray)) and len(positions) == 11:
            cls.positions = positions
        else:
            raise ValueError('You need to supply 11 positions as one of (list, tuple, ndarray)!')
    
    @classmethod
    def change_sigmas(cls, sigmas):
        if isinstance(sigmas, (list,tuple,np.ndarray)) and len(sigmas) == 11:
            cls.sigmas = sigmas
        else:
            raise ValueError('You need to supply 11 sigmas as one of (list, tuple, ndarray)!')

    @classmethod
    def eval(cls, x, params):
        # convert input dict to list (nope, this is not pretty..)
        params = np.array([params[k] for k in cls.param_names])

        # Set up parameter vectors, including central gaussian
        a = np.array([
            params[0], params[1], params[2], params[3], params[4],
            1.0,
            params[5], params[6], params[7], params[8], params[9],
        ])
        # In Butler 1996: Gaussians placed at 0.5 pixels apart
        # This is from the cf's
        b = cls.positions
        c = cls.sigmas
        n = 11

        # Multigauss function
        def func(x):
            xarr = np.repeat([x], n, axis=0)
            f = np.sum(a * np.exp(-0.5 * ((np.transpose(xarr) - b) / c)**2.), axis=1)
            #f[np.where(f < 0.0)] = 0.0
            return f

        # Evaluate function and find centroid
        y = func(x)
        
        # added for NaN-debugging
        y_sum = np.sum(y)
        #if y_sum==0:
        #    print('Sum of lsf is 0. Setting to 1e-4.')
        #    print(y)
        #    y_sum = 1e-4
        
        # Calculate centroid and re-center the LSF
        offset = np.sum(x * y) / y_sum #np.sum(y)  # TODO: Is this the correct way of weighting?
        y = func(x + offset)
        
        # added for NaN-debugging
        if any(np.isnan(y)):
            print('NaN value detected in un-normalized lsf function.')
        y_sum = np.sum(y)
        #if y_sum==0:
        #    print('Sum of lsf is 0. Setting to 1e-4.')
        #    y_sum = 1e-4
        
        y = y / y_sum
        if any(np.isnan(y)):
            print('NaN value detected in lsf function.')
            print('Sum of y: ', y_sum)
            print(params)
        
        return y# / np.sum(y)  # FIXME: Normalize to unit area?

    @classmethod
    def guess_params(cls, chunk):
        # These are the median parameters from the cf of rs15.31
        # 0.4    0.0820768    0.0557928     0.167839     0.417223     0.453222
        #        0.400129     0.390609     0.138321    0.0599234    0.0588782
        return ParameterSet(
            {name: guess for name, guess in zip(cls.param_names, cls.param_guess)}
        )  # FIXME: Make a better guess

In [33]:
keys = ['bla', 'bli']
values = [1, 2]
a = {key: value for key, value in zip(keys, values)}
print(a)

{'bla': 1, 'bli': 2}


In [108]:
a = (1, 2, 3)
np.array(a)

array([1, 2, 3])

In [22]:
class ParameterSet(dict):
    """A general set of parameters for a model (a dict with extra methods)"""

    def __getitem__(self, item):
        """The dedicated get-method
        
        Args:
            item (str): A string corresponding to a parameter key or a key
                prefix.
        
        Returns:
            :class:'ParameterSet' or value: Either a set of parameters 
                corresponding to the prefix, or the parameter value
                corresponding to the key name.
        """
        if item in self.keys():
            return super().__getitem__(item)
        else:
            return self.filter(prefix=item)

    def filter(self, prefix): #=None):
        """Return a subset of parameters, defined by prefix
        (or something else in the future)
        
        Args:
            prefix (str): A prefix to filter the parameter keys by (either
                   of 'lsf', 'wave' or 'cont' at the moment).
        
        Returns:
            :class:'ParameterSet': The parameters corresponding to the prefix.
        """
        if prefix is not None:
            new = {k[len(prefix) + 1:]: self[k] for k in self.keys() if k.startswith(prefix + '_')}
            return ParameterSet(new)
        else:
            raise ValueError('No filter keywords set')
    

    def add(self, parameter_set, prefix=''):
        """Add the parameters of another ParameterSet, adding a prefix if set.
        This will override existing items with the same key, if any.
        
        Args:
            parameter_set (:class:'ParameterSet'): A set of parameters to add.
            prefix (Optional[string]): A prefix to set in front of the
                parameter keys.
        """
        if prefix:
            prefix += '_'
        for k in parameter_set:
            self[prefix + k] = parameter_set[k]
    
    
    def __str__(self):
        """Return information about the contained parameters
        
        Returns:
            str: The names and values of the parameters nicely formatted.
        """
        string = "<ParameterSet (values: {})>".format(len(self))
        if len(self) > 0:
            fill = max([len(k) for k in self])
            for k in sorted(self):
                string += "\n    {name:<{fill}}  =  {value}".format(
                    fill=fill,
                    name=k,
                    value=self[k]
                )
        return string

In [35]:
lsf = MultiGaussian
x = lsf.generate_x(6, 6)
params = ParameterSet(left_5=0.1, left_4=0.2, left_3=0.3, left_2=0.4, left_1=0.5,
            right1=0.5, right2=0.4, right3=0.3, right4=0.2, right5=0.1)

In [36]:
y = lsf.eval(x, params)
print(y)

[2.75179453e-34 1.85308013e-31 9.16500439e-29 3.32914538e-26
 8.88171360e-24 1.74031843e-21 2.50458426e-19 2.64746233e-17
 2.05557091e-15 1.17241346e-13 4.91291649e-12 1.51293583e-10
 3.42546020e-09 5.70651874e-08 7.00415472e-07 6.34827022e-06
 4.26493116e-05 2.13691683e-04 8.06115443e-04 2.32089557e-03
 5.19128357e-03 9.21447081e-03 1.33270553e-02 1.63736953e-02
 1.83152083e-02 2.01270954e-02 2.23226861e-02 2.46960123e-02
 2.73414911e-02 3.04702893e-02 3.42026731e-02 3.91065458e-02
 4.50359429e-02 5.04258284e-02 5.43553716e-02 5.70450674e-02
 5.81176440e-02 5.70450674e-02 5.43553716e-02 5.04258284e-02
 4.50359429e-02 3.91065458e-02 3.42026731e-02 3.04702893e-02
 2.73414911e-02 2.46960123e-02 2.23226861e-02 2.01270954e-02
 1.83152083e-02 1.63736953e-02 1.33270553e-02 9.21447081e-03
 5.19128357e-03 2.32089557e-03 8.06115443e-04 2.13691683e-04
 4.26493116e-05 6.34827022e-06 7.00415472e-07 5.70651874e-08
 3.42546020e-09 1.51293583e-10 4.91291649e-12 1.17241346e-13
 2.05557091e-15 2.647462

In [37]:
lsf.change_positions(np.linspace(-3, 3, 11))

In [38]:
print(lsf.positions)

[-3.  -2.4 -1.8 -1.2 -0.6  0.   0.6  1.2  1.8  2.4  3. ]


In [39]:
y = lsf.eval(x, params)
print(y)

[9.86487004e-25 2.18684367e-22 3.56043122e-20 4.25741474e-18
 3.73893419e-16 2.41161627e-14 1.14242360e-12 3.97470483e-11
 1.01564488e-09 1.90607408e-08 2.62726245e-07 2.65979271e-06
 1.97796187e-05 1.08081370e-04 4.34371929e-04 1.28766420e-03
 2.83983381e-03 4.77427129e-03 6.50417601e-03 8.02795126e-03
 9.88627856e-03 1.19620867e-02 1.35887808e-02 1.50515831e-02
 1.70966985e-02 1.91806989e-02 2.05729102e-02 2.21582751e-02
 2.46157947e-02 2.70100798e-02 2.95811484e-02 3.45168540e-02
 4.19315706e-02 4.89845665e-02 5.38873408e-02 5.69352156e-02
 5.80820916e-02 5.69352156e-02 5.38873408e-02 4.89845665e-02
 4.19315706e-02 3.45168540e-02 2.95811484e-02 2.70100798e-02
 2.46157947e-02 2.21582751e-02 2.05729102e-02 1.91806989e-02
 1.70966985e-02 1.50515831e-02 1.35887808e-02 1.19620867e-02
 9.88627856e-03 8.02795126e-03 6.50417601e-03 4.77427129e-03
 2.83983381e-03 1.28766420e-03 4.34371929e-04 1.08081370e-04
 1.97796187e-05 2.65979271e-06 2.62726245e-07 1.90607408e-08
 1.01564488e-09 3.974704

In [111]:
b = MultiGaussian
b.positions

array([-2.4, -2.1, -1.6, -1.1, -0.6,  0. ,  0.6,  1.1,  1.6,  2.1,  2.4])

In [112]:
MultiGaussian.__name__

'MultiGaussian'

In [113]:
b.name()

'StaticModel'

In [115]:
b.positions = np.linspace(-3, 3, 11)

In [116]:
b.positions

array([-3. , -2.4, -1.8, -1.2, -0.6,  0. ,  0.6,  1.2,  1.8,  2.4,  3. ])

In [117]:
lsf.positions

array([-3. , -2.4, -1.8, -1.2, -0.6,  0. ,  0.6,  1.2,  1.8,  2.4,  3. ])