In [None]:
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt


In [None]:
class FermiEvents:
    """Container for a tensor of events with preprocessing and plotting helpers.

    Features live on the last axis of ``data``: column 0 is ``theta``,
    column 1 is ``phi`` and column 2 is ``energy``. The preprocessing methods
    replace ``self.data`` with the transformed tensor and store the statistics
    needed to invert the transform as attributes on the instance.
    """

    # Column index of each named feature along the last axis of ``data``.
    _FEATURE_INDEX = {'theta': 0, 'phi': 1, 'energy': 2}

    def __init__(self, data):
        # presumably a torch.Tensor: standardize/normalize apply torch ops to it.
        self.data = data

    @property
    def theta(self):
        """Feature column 0 (``theta``)."""
        return self.data[..., 0]

    @property
    def phi(self):
        """Feature column 1 (``phi``)."""
        return self.data[..., 1]

    @property
    def energy(self):
        """Feature column 2 (``energy``)."""
        return self.data[..., 2]

    @property
    def data_size(self):
        """Number of entries along the first axis of ``data``."""
        return self.data.shape[0]

    def _feature_index(self, feature):
        """Return the column index for ``feature``; raise KeyError if it is unknown."""
        try:
            return self._FEATURE_INDEX[feature]
        except KeyError:
            raise KeyError('unknown feature {!r}; expected one of {}'.format(
                feature, sorted(self._FEATURE_INDEX))) from None

    def selection_cuts(self, feature, cut=None):
        """Keep only events whose ``feature`` lies in the closed interval ``cut``.

        Args:
            feature: one of ``'theta'``, ``'phi'``, ``'energy'``.
            cut: ``[low, high]`` inclusive bounds. ``None`` means
                ``[-inf, inf]``, which still drops rows whose value is NaN
                (NaN fails both comparisons).

        Returns:
            ``self``, so calls can be chained.
        """
        if cut is None:
            cut = [-np.inf, np.inf]
        column = self.data[..., self._feature_index(feature)]
        mask = (column >= cut[0]) & (column <= cut[1])
        self.data = self.data[mask]
        return self

    def standardize(self, inverse=False, sigma=1, verbose=True):
        """Shift/scale every feature to zero mean and standard deviation ``sigma``.

        The forward pass computes mean and std over dim 0 and stores them in
        ``self.mean`` / ``self.std``. ``inverse=True`` undoes the transform
        with those stored values. Pass the same ``sigma`` as in the forward
        call, because it is not stored.
        """
        if not inverse:
            if verbose: print('INFO: standardizing data to zero-mean and std={}'.format(sigma))
            self.mean = torch.mean(self.data, dim=0)
            self.std = torch.std(self.data, dim=0)
            self.data = (self.data - self.mean) / (self.std / sigma)
        else:
            if verbose: print('INFO: inverting data standardization')
            self.data = self.data * (self.std / sigma) + self.mean

    def normalize(self, inverse=False, verbose=True):
        """Min-max scale every feature to ``[0, 1]`` (or invert that scaling).

        The per-feature max/min are stored in ``self.max`` / ``self.min`` for
        the inverse. A feature whose max equals its min divides by zero.
        """
        if not inverse:
            # Flatten all leading axes so the extrema are taken per feature
            # over every event, whatever the number of features is.
            flat = torch.reshape(self.data, (-1, self.data.shape[-1]))
            self.max, _ = torch.max(flat, dim=0)
            self.min, _ = torch.min(flat, dim=0)
            if verbose: print('INFO: normalizing data')
            self.data = (self.data - self.min) / (self.max - self.min)
        else:
            if verbose: print('INFO: inverting data normalization')
            self.data = self.data * (self.max - self.min) + self.min

    def logit_transform(self, alpha=1e-6, inverse=False, verbose=True):
        """Apply ``logit`` to the data, or ``expit`` when ``inverse`` is True.

        ``logit`` / ``expit`` are helpers defined outside this class (not shown
        here). Both are called with an ``alpha`` keyword.
        """
        if not inverse:
            if verbose: print('INFO: applying logit transform')
            self.data = logit(self.data, alpha=alpha)
        else:
            if verbose: print('INFO: applying expit transform')
            self.data = expit(self.data, alpha=alpha)

    def preprocess(self, reverse=False, sigma=1.0, alpha=1e-6, verbose=True):
        """Run normalize -> logit -> standardize, or exactly the reverse chain.

        Use the same ``sigma`` / ``alpha`` for the forward and reverse calls.
        """
        if not reverse:
            self.normalize(verbose=verbose)
            self.logit_transform(alpha=alpha, verbose=verbose)
            self.standardize(sigma=sigma, verbose=verbose)
        else:
            self.standardize(sigma=sigma, inverse=True, verbose=verbose)
            self.logit_transform(inverse=True, alpha=alpha, verbose=verbose)
            self.normalize(inverse=True, verbose=verbose)

    def plot(self, feature, plot_dir, target=None, color='k', bins=100, log_scale=(False, False)):
        """Save a step histogram of ``feature`` to ``<plot_dir>/<feature>.pdf``.

        If ``target`` is given, its ``feature`` column is overlaid as a filled
        histogram. The plot is drawn on a fresh figure that is closed after
        saving, so it never draws onto, or leaves behind, another open figure.
        """
        idx = self._feature_index(feature)
        fig, ax = plt.subplots()
        try:
            sns.histplot(x=self.data[..., idx],
                         color=color,
                         bins=bins,
                         log_scale=log_scale,
                         element="step", lw=0.75, fill=False, alpha=0.1, ax=ax)
            if target is not None:
                sns.histplot(x=target[..., idx],
                             color=color,
                             bins=bins,
                             log_scale=log_scale,
                             element="step", lw=0., fill=True, alpha=0.1, ax=ax)
            fig.savefig(plot_dir + '/{}.pdf'.format(feature))
        finally:
            plt.close(fig)

    def plot_milky_way(self, plot_dir, cmap='plasma'):
        """Save a log-scaled hexbin of ``theta`` vs ``phi`` to ``<plot_dir>/milky_way.pdf``.

        A new 6x6 figure is used on every call and closed after saving. Reusing
        a fixed figure number would stack each call's axes on top of the last.
        """
        fig, ax = plt.subplots(figsize=(6, 6))
        try:
            ax.hexbin(self.phi, self.theta, cmap=cmap, gridsize=300, bins='log')
            fig.savefig(plot_dir + '/milky_way.pdf')
        finally:
            plt.close(fig)

    