# imports

In [1]:
import numpy as np
import pandas as pd
import pickle
from scipy.stats import bernoulli, uniform, norm, beta, gamma, binom, poisson
from scipy.special import expit
import sys
import os
from utilities.plotUtils import *
from utilities.utilityFunctions import unpickle_object, pickle_object
from pathlib import Path
# from scipy.ndimage.filters import gaussian_filter1d
from MutualInfo.library import *
import copy
import palettable
from scipy.signal import convolve


ModuleNotFoundError: No module named 'utilities'

# prior class and functions

In [None]:
# Font size handed to finalize(..., fontsize=FONT_SIZE) by every plot in this notebook.
FONT_SIZE = 15

class Prior1D:
    """Finite, uniform prior over 1-D curves, with blurring and a Monte Carlo
    estimate of mutual information.

    Every curve is a 2-D array whose column 0 holds the x-coordinates and
    whose column 1 holds the curve values (required to be non-negative).

    Intended call order (the one used by ``mcmc``):
    ``gaussian_blur_kernel`` -> ``blur_xs`` -> optionally
    ``subset_blurred_xs`` -> ``create_prior_dataframe`` ->
    ``mcmc_mutual_information``.

    Attributes set along the way:
        xs: dict mapping curve index -> original curve.
        blur_sd, blur_kernel: set by ``gaussian_blur_kernel``.
        xs_blurred: dict mapping curve index -> blurred curve (``blur_xs``).
        prior_df: DataFrame with columns ``prob`` and ``blurred``.
        mutual_infos, mutual_info: per-draw values and their mean.
    """

    def __init__(self, xs):
        """Store the curves, keyed by their position in ``xs``.

        Args:
            xs: iterable of 2-D arrays (column 0 = x, column 1 = value).

        Raises:
            ValueError: if any curve has a negative value in column 1.
        """
        self.xs = {}
        for ii, x in enumerate(xs):
            negative = x[:, 1] < 0
            if np.any(negative):
                # Raise instead of `assert` so the check survives `python -O`.
                raise ValueError(
                    f'curve {ii} contains negative values:\n{x[negative]}'
                )
            self.xs[ii] = x

    def evaluate_function_on_grid(self, fcn, grid):
        """Return an array with ``grid`` in column 0 and ``fcn(grid)`` in column 1.

        ``fcn`` is called once per grid point with a scalar argument.
        """
        val = np.zeros((len(grid), 2))
        val[:, 0] = grid
        val[:, 1] = [fcn(x) for x in grid]
        return val

    def gaussian_blur_kernel(self, blur_sd, grid):
        """Build the normalised blur kernel on ``grid`` and store it.

        The kernel is ``exp(-grid**2 / blur_sd**2)`` divided by its sum.
        NOTE(review): there is no factor 2 in the exponent, so the standard
        deviation of this Gaussian is ``blur_sd / sqrt(2)``, not ``blur_sd``.
        The formula is left unchanged so results stay comparable.

        Raises:
            ValueError: if the kernel sums to zero or to a non-finite value
                (e.g. ``blur_sd == 0`` or complete underflow).
        """
        self.blur_sd = blur_sd
        gaussian_kernel = np.exp(-np.power(grid, 2)/(self.blur_sd**2))
        normalization_constant = np.sum(gaussian_kernel)
        # A NaN sum would have slipped through a plain `!= 0` check.
        if normalization_constant == 0 or not np.isfinite(normalization_constant):
            raise ValueError(
                'blur kernel cannot be normalised: its sum is '
                f'{normalization_constant} (blur_sd={blur_sd})'
            )
        self.blur_kernel = gaussian_kernel/normalization_constant

    def blur_xs(self):
        """Convolve every curve's values with ``self.blur_kernel``.

        Results go to ``self.xs_blurred`` with the x-coordinates unchanged.
        ``mode='same'`` keeps the output the same length as the curve.
        Requires ``gaussian_blur_kernel`` to have been called first.
        """
        self.xs_blurred = {}
        for key, val in self.xs.items():
            blurred_values = convolve(val[:, 1], self.blur_kernel, mode='same')
            self.xs_blurred[key] = np.column_stack((val[:, 0], blurred_values))

    def subset_blurred_xs(self, lower_bound, upper_bound):
        """Keep only rows with ``lower_bound <= x <= upper_bound``.

        Applied to both the blurred and the original curves, in place on the
        two dictionaries.
        """
        def _within_bounds(curve):
            keep = (curve[:, 0] >= lower_bound) & (curve[:, 0] <= upper_bound)
            return curve[keep]

        for key in list(self.xs_blurred):
            self.xs_blurred[key] = _within_bounds(self.xs_blurred[key])
            self.xs[key] = _within_bounds(self.xs[key])

    def plot_blur_kernel(self, grid):
        """Plot the stored blur kernel against ``grid``.

        ``fig_setup`` and ``finalize`` are plotting helpers that are not
        defined in this file.
        """
        fig, axs = fig_setup(1, 1)
        axs[0].plot(
            grid, self.blur_kernel,
            label='blur kernel'
        )
        axs[0].set_title(f'blur sd = {self.blur_sd}')
        finalize(
            axs,
            fontsize=FONT_SIZE,
        )

    def plot(self, noisy=True):
        """Plot each curve and its blurred version, one axis per curve.

        Args:
            noisy: when truthy, also overlay one Poisson draw whose rate is
                the blurred curve.
        """
        ncolms = 2
        nrows = int(np.ceil(len(self.xs)/ncolms))
        fig, axs = fig_setup(nrows, ncolms)
        for ii, curve in self.xs.items():
            ax = axs[ii]
            blurred = self.xs_blurred[ii]
            ax.set_prop_cycle(
                'color',
                palettable.cartocolors.qualitative.Bold_3.mpl_colors
            )

            # original curve
            ax.plot(
                curve[:, 0], curve[:, 1],
                marker='x',
                label=f'x{ii}',
            )

            # blurred curve
            ax.plot(
                blurred[:, 0], blurred[:, 1],
                marker='.',
                label=f'blurred x{ii}',
            )
            if noisy:
                # Only the value column is a Poisson rate; the x column can be
                # negative and must not be passed to np.random.poisson.
                rates = np.clip(blurred[:, 1], 0, None)
                ax.plot(
                    blurred[:, 0], np.random.poisson(rates),
                    label=f'blurred+noisy x{ii}',
                )

        finalize(
            axs,
            fontsize=FONT_SIZE,
        )

    def create_prior_dataframe(self):
        """Build ``self.prior_df`` with one row per blurred curve.

        Columns:
            prob: uniform probability ``1 / number_of_curves``.
            blurred: the blurred values, clipped at zero so they are valid
                Poisson rates. These are fresh arrays, so later use of
                ``prior_df`` cannot modify ``self.xs_blurred``.
        """
        n_curves = len(self.xs_blurred)
        self.prior_df = pd.DataFrame({
            'prob': [1/n_curves for _ in range(n_curves)],
            'blurred': [
                np.clip(val[:, 1], 0, None) for val in self.xs_blurred.values()
            ],
        })

    def plot_mcmc_trace(self):
        """Plot the running mean of ``self.mutual_infos`` against iteration."""
        fig, axs = fig_setup(1, 1)
        cumulative_mean = np.divide(
            np.cumsum(self.mutual_infos), np.arange(1, len(self.mutual_infos)+1)
        )
        axs[0].plot(
            cumulative_mean,
        )
        axs[0].set_ylabel('mutual information\ncumulative mean')
        axs[0].set_xlabel('iteration')
        finalize(
            axs,
            fontsize=FONT_SIZE,
        )

    def mcmc_mutual_information(self, num_mcmc_draws):
        """Estimate mutual information by averaging over independent draws.

        For each draw a curve is sampled from the prior, a Poisson observation
        ``y`` is generated with the blurred curve as rate, and
        ``pointwise_mutual_information(y, x_blurred, prior_df)`` is recorded
        (that helper is not defined in this file). The per-draw values are
        stored in ``self.mutual_infos`` and their mean in ``self.mutual_info``.

        Requires ``create_prior_dataframe`` to have been called first.
        """
        curves = list(self.prior_df['blurred'])
        # Sample row indices rather than the arrays themselves; with `p` given
        # this uses the same random draws as choosing among the arrays.
        drawn = np.random.choice(
            len(curves),
            size=num_mcmc_draws,
            p=self.prior_df['prob'].to_numpy(),
        )
        self.mutual_infos = np.zeros(num_mcmc_draws)
        for ii, curve_idx in enumerate(drawn):
            # Rates were already clipped at zero in create_prior_dataframe,
            # so no in-place modification of shared arrays is needed here.
            x_blurred = curves[curve_idx]
            y = np.random.poisson(x_blurred)
            self.mutual_infos[ii] = pointwise_mutual_information(
                y, x_blurred, self.prior_df
            )
        self.mutual_info = np.mean(self.mutual_infos)

        
def mcmc(xs, kernel_grid, num_mcmc_draws, blur_sd, subset_bounds=False):
    """Run the full blur -> prior -> mutual-information pipeline once.

    Args:
        xs: iterable of curves passed to ``Prior1D``.
        kernel_grid: grid on which the Gaussian blur kernel is evaluated.
        num_mcmc_draws: number of draws used for the mutual-information estimate.
        blur_sd: width parameter of the blur kernel.
        subset_bounds: ``False`` or ``None`` to keep every sample, otherwise
            an indexable ``(lower, upper)`` pair; only samples whose
            x-coordinate lies in that closed interval are kept after blurring.

    Returns:
        ``[blur_sd, prior]`` where ``prior`` is the populated ``Prior1D``
        (its ``mutual_info`` attribute holds the estimate).
    """
    prior = Prior1D(xs)
    prior.gaussian_blur_kernel(blur_sd, kernel_grid)
    prior.blur_xs()
    # Identity checks on purpose: `subset_bounds != False` is evaluated
    # elementwise for NumPy arrays and then fails in the `if`.
    if subset_bounds is not False and subset_bounds is not None:
        prior.subset_blurred_xs(subset_bounds[0], subset_bounds[1])
    prior.create_prior_dataframe()
    prior.mcmc_mutual_information(num_mcmc_draws)
    return [blur_sd, prior]

# test

In [None]:
# Smoke test: build four example curves, blur them with a narrow kernel,
# restrict to the central half of the domain, and plot (without noise).
x_size = 51
x_grid = np.linspace(-np.pi, np.pi, x_size)
kernel_grid = np.linspace(-np.pi, np.pi, 2*x_size)
subset_bounds = [-np.pi/2, np.pi/2]
# `scale` was previously only defined by the loops in later cells, so this
# cell raised NameError on a fresh kernel. The value is illustrative.
scale = 10

# x0: raised cosine; x1: x0 with one sample (three past -pi/2) overwritten.
x0 = np.column_stack((x_grid, scale+scale*np.cos(x_grid)))
x1 = copy.deepcopy(x0)
idx = np.argwhere((x1[:, 0] > -np.pi/2))[0][0]
x1[idx+3, 1] = scale+scale*0.8

# x2: constant curve; x3: raised sine at twice the frequency.
x2 = np.column_stack((x_grid, np.repeat(scale, len(x_grid))))
x3 = np.column_stack((x_grid, scale+scale*np.sin(2*x_grid)))
xs = [x0, x1, x2, x3]

prior = Prior1D(xs)
prior.gaussian_blur_kernel(0.01, kernel_grid)
prior.blur_xs()
prior.subset_blurred_xs(subset_bounds[0], subset_bounds[1])
prior.plot(noisy=False)

# effect of scale and blur

In [None]:
# Sweep blur width for every signal scale using a two-curve prior
# (raised cosine vs. the same cosine with one sample overwritten),
# then pickle the list of [blur_sd, prior] results per scale.
num_mcmc_draws = 100000
x_size = 51
scales = [5, 10, 100, 1000]
subset_bounds = [-np.pi/2, np.pi/2]
blur_sds = [10**(exponent) for exponent in np.linspace(-2, 1, 31)]
x_grid = np.linspace(-np.pi, np.pi, x_size)
kernel_grid = np.linspace(-np.pi, np.pi, 2*x_size)

for scale in scales:
    print(f'scale={scale}')

    # reference curve
    x0 = np.column_stack((x_grid, scale+scale*np.cos(x_grid)))

    # altered copy: overwrite the sample three steps past x = -pi/2
    x1 = x0.copy()
    idx = np.argwhere(x1[:, 0] > -np.pi/2)[0, 0]
    x1[idx+3, 1] = scale+scale*0.8

    xs = [x0, x1]
    results = [
        mcmc(xs, kernel_grid, num_mcmc_draws, blur_sd, subset_bounds)
        for blur_sd in blur_sds
    ]

    pickle_object(f'../results/cosine_alteredCosine_scale{scale}.pkl', results)

In [None]:
# Plot mutual information against blur width (log-log), one line per scale,
# from the pickled two-curve-prior results.
fig, axs = fig_setup(1,1)
ax = axs[0]
ax.set_prop_cycle('color', palettable.mycarta.Cube1_4.mpl_colors)

for scale in scales:
    results = unpickle_object(f'../results/cosine_alteredCosine_scale{scale}.pkl')
    # each result is [blur_sd, prior]; read both plotted quantities off the prior
    mis = np.array([
        [prior_obj.blur_sd, prior_obj.mutual_info]
        for _, prior_obj in results
    ])
    blur_values, mi_values = mis[:,0], mis[:,1]
    _ = ax.loglog(
        blur_values, mi_values,
        marker='o',
        markersize=5,
        label=f'scale={scale}'
    )

ax.set_xlabel('blur sigma')
ax.set_ylabel('mutual information')
finalize(axs, fontsize=FONT_SIZE)

# effect of the number of curves in the prior

In [None]:
# Same sweep as the two-curve experiment, but with a four-curve prior:
# raised cosine, altered cosine, constant, and raised double-frequency sine.
num_mcmc_draws = 100000
x_size = 51
scales = [5, 10, 100, 1000]
subset_bounds = [-np.pi/2, np.pi/2]
blur_sds = [10**(exponent) for exponent in np.linspace(-2, 1, 31)]
x_grid = np.linspace(-np.pi, np.pi, x_size)
kernel_grid = np.linspace(-np.pi, np.pi, 2*x_size)

for scale in scales:
    print(f'scale={scale}')

    x0 = np.column_stack((x_grid, scale+scale*np.cos(x_grid)))

    # altered copy: overwrite the sample three steps past x = -pi/2
    x1 = x0.copy()
    idx = np.argwhere(x1[:, 0] > -np.pi/2)[0, 0]
    x1[idx+3, 1] = scale+scale*0.8

    x2 = np.column_stack((x_grid, np.repeat(scale, len(x_grid))))
    x3 = np.column_stack((x_grid, scale+scale*np.sin(2*x_grid)))
    xs = [x0, x1, x2, x3]

    results = [
        mcmc(xs, kernel_grid, num_mcmc_draws, blur_sd, subset_bounds=subset_bounds)
        for blur_sd in blur_sds
    ]

    pickle_object(
        f'../results/cosine_alteredCosine_constant_sine_scale{scale}.pkl',
        results
    )

In [None]:
# Side-by-side log-log plots of mutual information vs. blur width:
# left panel = four-curve prior, right panel = two-curve prior.
fig, axs = fig_setup(1,2)

# (axis index, results-file stem). The stems match the pickle names written
# by the two sweep cells of this notebook; the previous names
# ('4_things_in_prior', 'cosine_altered_cosine') are not written anywhere here.
panels = [
    (0, 'cosine_alteredCosine_constant_sine'),
    (1, 'cosine_alteredCosine'),
]

for panel_idx, stem in panels:
    ax = axs[panel_idx]
    _ = ax.set_prop_cycle('color', palettable.mycarta.Cube1_4.mpl_colors)

    for scale in scales:
        # scale 1 is excluded from these plots if it is ever added to `scales`
        if scale == 1:
            continue
        results = unpickle_object(f'../results/{stem}_scale{scale}.pkl')
        # each result is [blur_sd, prior]
        mis = np.array([[tmp[1].blur_sd, tmp[1].mutual_info] for tmp in results])
        _ = ax.loglog(
            mis[:,0], mis[:,1],
            marker='o',
            markersize=5,
            label=f'scale={scale}'
        )

    _ = ax.set_xlabel('blur sigma')
    _ = ax.set_ylabel('mutual information')

finalize(
    axs,
    fontsize=FONT_SIZE,
)