In [None]:
import numpy as np
import pandas as pd
import scipy
import scipy.stats

In [None]:
class Statistician:
    """
    Container for methods of combining satellite measurements and simulations and measuring implausibility.
    """

    # Columns used to match regridded observations with the emulated grid cells of a cluster.
    _COORDINATE_COLUMNS = ["latitude", "longitude", "time"]

    def __init__(
        self,
        emulator,
        clusterer,
        observer,
        pixelwise=True
    ):
        """
        Arguments

        emulator : Emulator
            Container for methods of training Gaussian Process emulators of various flexibilities and for evaluating the
            quality of these emulators.
        clusterer : Clusterer
            Container for methods of selecting physically homogeneous regions of the map to emulate as though they observe
            the same physical processes.
        observer : Observer
            Container for methods of collecting satellite measurements and regridding them to match simulation grid.
        pixelwise : bool, default True
            Forwarded as ``pixelwise`` to ``emulator.emulate_variant`` whenever a model variant is emulated.


        Value

        None
        """

        self.emulator = emulator
        self.clusterer = clusterer

        # Put observations all in the same data frame, with a fresh 0..n-1 index.
        self.observations = pd.concat(
            list(observer.rgdatasets.values()), axis=0
        ).reset_index(
            drop=True
        )

        # 95th percentile of a chi-squared distribution with one degree of freedom per
        # entry of clusterer.labels; plot_statistics draws it as the reference line.
        self.threshold = scipy.stats.chi2.ppf(
            0.95,
            len(clusterer.labels)
        )

        self.pixelwise = pixelwise

    # ------------------------------------------------------------------
    # Evaluate implausibility
    # ------------------------------------------------------------------

    def compute_statistics(
        self,
        nVariants,
        regionalEmulators,
        variantsFromTraining=True,
        whichParameter='acure_sea_spray',
        whichPoint=0
    ):
        """
        Emulate a set of model variants and score each one against the observations.

        Arguments

        nVariants : int
            Number of model variants to evaluate.
        regionalEmulators
            Trained emulators, forwarded unchanged to ``emulator.emulate_variant``.
        variantsFromTraining : bool, default True
            If True, sweep ``whichParameter`` over [0, 1] starting from training point ``whichPoint``.
            Otherwise draw variants uniformly at random from the unit hypercube.
        whichParameter : str
            Input parameter swept when ``variantsFromTraining`` is True.
        whichPoint : int
            Row position in the training-point table used as the base variant.

        Value

        pandas.DataFrame with one row per variant: one column per emulator input plus a
        ``metric`` column holding the Mahalanobis implausibility of that variant.
        """

        if variantsFromTraining:
            model_variants = self.__get_variants_from_training_point__(
                nVariants=nVariants,
                whichParameter=whichParameter,
                whichPoint=whichPoint
            )
        else:
            model_variants = self.__get_random_model_variants__(nVariants)

        # The observation side does not depend on the variant, so build it once.
        # Covariance diagonals hold variances (sd**2), as in get_regional_observations;
        # the chi-squared threshold is only meaningful for a variance-scaled distance.
        z = np.asarray(self.observations.meanResponse, dtype=float)
        E = np.diag(np.asarray(self.observations.sdResponse, dtype=float) ** 2)

        metrics = []
        for variant in model_variants:
            self.get_regional_emulations(
                regionalEmulators,
                variant
            )

            y = np.asarray(self.emulated.meanResponse, dtype=float)
            S = np.diag(np.asarray(self.emulated.sdResponse, dtype=float) ** 2)

            metrics.append(self.__mahalanobis_metric__(y, z, S, E))

        results = pd.DataFrame(data=model_variants, columns=self.emulator.inputs)
        results['metric'] = metrics
        return results

    def plot_statistics(
        self,
        results,
        whichParameter='acure_sea_spray'
    ):
        """
        Scatter implausibility against one input parameter and draw ``self.threshold`` as a horizontal line.

        Arguments

        results : pandas.DataFrame
            Output of ``compute_statistics``. It must contain ``whichParameter`` and ``metric`` columns.
        whichParameter : str
            Column plotted on the x axis.

        Value

        None
        """
        import matplotlib.pyplot as plt

        plt.scatter(results[whichParameter], results.metric)
        plt.axhline(self.threshold)
        plt.ylabel('implausibility')
        plt.xlabel(whichParameter)
        plt.show()

    def __get_random_model_variants__(
        self,
        nVariants=10
    ):
        """Return ``nVariants`` variants (lists of floats) drawn uniformly from [0, 1) for every emulator input."""
        nFeatures = len(self.emulator.inputs)
        return np.random.rand(nVariants, nFeatures).tolist()

    def __get_variants_from_training_point__(
        self,
        nVariants=10,
        whichParameter='acure_sea_spray',
        whichPoint=0,
        customVariant=None,
        trainingPointsFile="training_points"
    ):
        """
        Build a one-at-a-time sweep: copy a base variant ``nVariants`` times and set
        ``whichParameter`` to ``np.linspace(0, 1, nVariants)``.

        Arguments

        nVariants : int
            Number of variants (rows) to return.
        whichParameter : str
            Emulator input to sweep. It must be one of ``self.emulator.inputs``.
        whichPoint : int
            Row position in the training-point table used as the base variant.
        customVariant : dict, pandas.Series or sequence, optional
            Base variant used instead of a training point. Mappings and Series are matched by
            input name. Plain sequences must follow the order of ``self.emulator.inputs``.
        trainingPointsFile : str or path-like
            CSV of training points, indexed by its first column.

        Value

        List of ``nVariants`` lists of floats, columns ordered as ``self.emulator.inputs``.

        Raises

        ValueError if ``whichParameter`` is not one of the emulator inputs.
        """
        inputs = list(self.emulator.inputs)

        if customVariant is not None:
            base = pd.Series(customVariant, index=inputs)
        else:
            training_points = pd.read_csv(trainingPointsFile, index_col=0).loc[:, inputs]
            base = training_points.iloc[whichPoint, :]

        emulation_points = np.tile(base.to_numpy(dtype=float), (nVariants, 1))
        emulation_points[:, inputs.index(whichParameter)] = np.linspace(0, 1, nVariants)

        return emulation_points.tolist()

    # ------------------------------------------------------------------
    # Collect data
    # ------------------------------------------------------------------

    def get_regional_emulations(
        self,
        regionalEmulators,
        queryVariant
    ):
        """
        Emulate one model variant and store the result in ``self.emulated``.

        ``queryVariant`` (a flat sequence of input values) is turned into a single-row
        DataFrame before being passed to ``emulator.emulate_variant`` together with the
        clusterer's length scales and labels. The returned object is read elsewhere through
        ``meanResponse`` / ``sdResponse``.
        """

        self.emulated = self.emulator.emulate_variant(
            regionalEmulators,
            pd.DataFrame(queryVariant).transpose(),
            pixelwise=self.pixelwise,
            lengthScales=self.clusterer.lengthscales,
            labels=self.clusterer.labels
        )

    def get_regional_observations(
        self,
        lengthScales=None,
        labels=None
    ):
        """
        Group the observations by cluster.

        For each distinct label, select the observations whose rounded
        (latitude, longitude, time) match a row of ``lengthScales`` carrying that label.
        Results are stored in ``self.z[label]`` (array of mean responses) and ``self.E[label]``
        (diagonal matrix of squared sd responses).

        Arguments

        lengthScales : pandas.DataFrame, optional
            Must contain latitude, longitude and time columns. Defaults to ``self.clusterer.lengthscales``.
        labels : array-like, optional
            One cluster label per row of ``lengthScales``. Defaults to ``self.clusterer.labels``.

        Value

        None
        """
        if lengthScales is None:
            lengthScales = self.clusterer.lengthscales
        if labels is None:
            labels = self.clusterer.labels
        labels = np.asarray(labels)

        observed_keys = self._coordinate_keys(self.observations)

        self.z = {}
        self.E = {}

        for idx in np.unique(labels):
            # A set makes each membership test O(1) instead of scanning a list.
            cluster_keys = set(self._coordinate_keys(lengthScales.loc[labels == idx]))
            in_cluster = np.array([key in cluster_keys for key in observed_keys], dtype=bool)

            regional = self.observations.loc[in_cluster]
            self.z[idx] = regional["meanResponse"].to_numpy()
            self.E[idx] = np.diag(regional["sdResponse"].to_numpy() ** 2)

    # ------------------------------------------------------------------
    # Implausibility metrics
    # ------------------------------------------------------------------

    def __history_matching_metric__(self, *args, **kwargs):
        """Placeholder for a history-matching implausibility metric. It is not implemented yet."""
        raise NotImplementedError("history matching metric is not implemented")

    def __mahalanobis_metric__(
        self, y, z, S, E
    ):
        """
        Return the squared Mahalanobis distance (y - z)^T (S + E)^{-1} (y - z).

        NaN entries of the residual ``y - z`` and of ``E`` are replaced by 0 (via
        ``np.nan_to_num``) before the distance is computed.

        Arguments

        y, z : array-like
            Emulated and observed values.
        S, E : 2-D array-like
            Covariance matrices of the emulation and of the observations.
        """
        xf = np.nan_to_num(np.asarray(y) - np.asarray(z))
        Mf = np.asarray(S) + np.nan_to_num(E)

        # Solving the linear system is more accurate than forming an explicit inverse.
        return np.transpose(xf) @ np.linalg.solve(Mf, xf)

    # ------------------------------------------------------------------
    # Helper functions
    # ------------------------------------------------------------------

    @classmethod
    def _coordinate_keys(cls, frame):
        """Return one (latitude, longitude, time) tuple per row of ``frame``, numeric columns rounded to 2 decimals."""
        coords = frame.loc[:, cls._COORDINATE_COLUMNS].round(decimals=2)
        return [tuple(row) for row in coords.to_numpy().tolist()]
    
    
    