In [2]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import Fixed_step_size_NUTS_simulation as fv
import NUTSOrbitDiagnostic as nod
import step_size_adapt_NUTS_b_prime_transform
# Shared setup: funnel model, seeded generator, and the fixed starting point
# (theta_0, rho_0) that the orbit cells below reuse.
seed = 12909067
model = fv.create_model_stan_and_json("funnel", "funnel")
rng = np.random.default_rng(seed)

theta_log_sigma_initial = -2.0
velocity_log_sigma_component = -2.0
velocity_x_component = 1.0

# Position: log(sigma) coordinate set explicitly, every x coordinate at 0.
theta_0 = np.zeros(model.param_unc_num())
theta_0[0] = theta_log_sigma_initial

# Momentum: direction (v_log_sigma, v_x, ..., v_x), rescaled to Euclidean norm sqrt(11).
rho_0 = np.full(model.param_unc_num(), velocity_x_component, dtype=float)
rho_0[0] = velocity_log_sigma_component
rho_0 = (rho_0 / np.linalg.norm(rho_0)) * np.sqrt(11)

# Ten 1-valued draws, installed on the sampler below.
bernoulli_sequence = (1,) * 10

sampler = nod.NUTSBprimeTransformDiagnostic(
    model, rng, theta_0, rho_0, 0.7, 1/4, 10, 10, nod.NUTSOrbitDiagnostic,
)
sampler.set_bernoulli_sequence(bernoulli_sequence)
theta, rho = sampler.draw()



In [3]:
# Coarsest step size; the same value (1/4) is passed as `max_step_size` when
# the comparison orbits are built with obtain_orbit_with_step_size below.
# It must be defined here: no earlier cell sets `max_stepsize`.
max_stepsize = 1/4

unrefined_orbit = nod.NUTSOrbitDiagnostic(sampler,
                 rng,
                 theta_0,
                 rho_0,
                 max_stepsize,
                 1,
                 bernoulli_sequence,
                 tree_node_class=nod.NUTSTreeNodeDiagnostic)

In [25]:
class OrbitPlotter:
    """Plot NUTS orbit trees over dotted contours of a 2-D funnel log density.

    Points are drawn in the plane (theta[0], theta[1]) = (log(sigma), first x
    coordinate). Anything outside [-x_lim, x_lim] x [-y_lim, y_lim] is skipped.
    """

    def __init__(self, contour_data=None):
        """Create a plotter.

        Parameters
        ----------
        contour_data : tuple, optional
            ``(X, Y, Z, contour_levels)`` as returned by ``get_contour_data``.
            Pass a previous plotter's ``_contour_data`` to skip recomputing the
            (expensive) contour grid.
        """
        self._leaf_counter = 0
        self._intermediate_counter = 0
        self._model = fv.create_model_stan_and_json("funnel", "funnel")
        self.x_lim = 9
        self.y_lim = 5
        if contour_data is None:
            self._contour_data = self.get_contour_data()
        else:
            self._contour_data = contour_data

    def reset_fig_ax(self):
        """Start a new figure/axes pair (stored on self) and label its axes."""
        self.fig, self.ax = plt.subplots()
        # Raw string: "\l" and "\s" are invalid escapes (SyntaxWarning on 3.12+).
        self.ax.set_xlabel(r"$\log(\sigma)$")
        self.ax.set_ylabel("First $x$ coordinate")

    def set_energy_max_min(self, energy_max, energy_min):
        """Set the energy range that calculate_size maps onto marker sizes."""
        self._energy_max = energy_max
        self._energy_min = energy_min

    def log_joint(self, theta, rho):
        """Return the model log density at theta minus half the squared norm of rho."""
        return self._model.log_density(theta) - .5 * np.dot(rho, rho)

    def get_contour_data(self, n_grid=1000):
        """Evaluate the one-dimensional funnel log density on a grid over the window.

        Parameters
        ----------
        n_grid : int
            Number of grid points per axis (n_grid**2 log-density evaluations).

        Returns
        -------
        tuple
            ``(X, Y, Z, contour_levels)``. The levels are the 0/18/36/54/72/90th
            percentiles of Z.
        """
        one_dim_model = fv.create_model_stan_and_json("funnel", "one_dimensional_funnel")
        x = np.linspace(-self.x_lim, self.x_lim, n_grid)
        y = np.linspace(-self.y_lim, self.y_lim, n_grid)
        X, Y = np.meshgrid(x, y)
        Z = np.zeros(X.shape)
        for i in range(X.shape[0]):
            for j in range(X.shape[1]):
                Z[i, j] = one_dim_model.log_density(np.array([X[i, j], Y[i, j]]))
        contour_levels = np.percentile(Z, np.arange(0, 101, 18))
        return X, Y, Z, contour_levels

    def plot_contour(self):
        """Draw the cached contour data as thin dotted black lines beneath the points."""
        X, Y, Z, contour_levels = self._contour_data
        print(f"Contour levels: {contour_levels}")
        self.ax.contour(X, Y, Z,
                        levels=contour_levels,
                        colors='black',
                        linestyles='dotted',
                        linewidths=0.7,
                        alpha=1,
                        zorder=1)

    def _in_bounds(self, theta):
        """Return True when (theta[0], theta[1]) lies inside the plotting window."""
        return (-self.x_lim <= theta[0] <= self.x_lim
                and -self.y_lim <= theta[1] <= self.y_lim)

    def calculate_size(self, energy):
        """Map an energy to a scatter marker size.

        The maximum energy maps to 1 and the minimum energy maps to 41, so
        lower-energy points get larger markers. Requires set_energy_max_min first.
        A degenerate range (max == min) returns 1 instead of dividing by zero.
        """
        energy_span = self._energy_max - self._energy_min
        if energy_span == 0:
            return 1.0
        return 1 + 40 * (self._energy_max - energy) / energy_span

    def plot_coarse(self, orbit_root, index=""):
        """Scatter every leaf's left position, sized by its energy.

        Recurses into left/right children. ``index`` is the binary path from the
        starting node ('0' = left, '1' = right); it is used only in the printed
        diagnostics.
        """
        if orbit_root._height == 0:
            theta = orbit_root._left_theta
            rho = orbit_root._left_rho
            energy = -self.log_joint(theta, rho)
            size = self.calculate_size(energy)

            # A root that is itself a leaf has an empty path; int("", 2) raises.
            leaf_number = int(index, 2) if index else 0
            print(f"For index {leaf_number} the energy is {energy} and the size is {size}")
            print(f"The coordinates are ({theta[0]}, {theta[1]})")

            if self._in_bounds(theta):
                self.ax.scatter(theta[0], theta[1], s=size, color='black')
            self._leaf_counter += 1
            return

        for child, bit in ((orbit_root._left_child, "0"), (orbit_root._right_child, "1")):
            if child is not None:
                self.plot_coarse(child, index=index + bit)

    def plot_fine(self, orbit_root):
        """Scatter each leaf's intermediate grid points as small fixed-size dots."""
        if orbit_root._height == 0:
            grid_points = orbit_root._intermediate_grid_points
            if grid_points is None:
                return
            for theta, rho in grid_points:
                self._intermediate_counter += 1
                if len(theta) > 0 and len(rho) > 0 and self._in_bounds(theta):
                    self.ax.scatter(theta[0], theta[1], s=8, color='black')
            return

        if orbit_root._left_child is not None:
            self.plot_fine(orbit_root._left_child)

        if orbit_root._right_child is not None:
            self.plot_fine(orbit_root._right_child)

    def add_arrows(self, orbit_root):
        """Draw a length-0.25 arrow along each leaf's momentum direction."""
        if orbit_root._height == 0:
            theta = orbit_root._left_theta
            rho = orbit_root._left_rho
            rho_norm = np.sqrt(np.dot(rho, rho))
            if rho_norm == 0:
                return  # zero momentum has no direction to draw
            rho = .25 * rho / rho_norm
            self.ax.arrow(theta[0],
                          theta[1],
                          rho[0],
                          rho[1],
                          head_width=0.20,
                          head_length=0.25)
            return

        if orbit_root._left_child is not None:
            self.add_arrows(orbit_root._left_child)

        if orbit_root._right_child is not None:
            self.add_arrows(orbit_root._right_child)

    def plot_orbit(self, orbit_root, arrows=False, fine=False):
        """Draw a fresh figure: contours plus coarse leaf points.

        Parameters
        ----------
        orbit_root : tree node
            Root of the orbit tree to draw.
        arrows : bool
            Also draw momentum arrows at each leaf.
        fine : bool
            Also draw each leaf's intermediate grid points.
        """
        self.reset_fig_ax()
        self.plot_contour()
        self.plot_coarse(orbit_root)
        if arrows:
            self.add_arrows(orbit_root)
        if fine:
            self.plot_fine(orbit_root)

    def show(self):
        """Display all open matplotlib figures."""
        plt.show()

    def save_plot(self, filename):
        """Save the current figure to ``filename``."""
        self.fig.savefig(filename)


In [26]:
# Reuse the contour grid of a previously built plotter when one exists (the grid
# costs 1000 x 1000 log-density evaluations); on a fresh kernel there is no
# previous `Plotter`, so compute the grid from scratch instead of raising NameError.
_previous_plotter = globals().get("Plotter")
Plotter = OrbitPlotter(
    contour_data=None if _previous_plotter is None else _previous_plotter._contour_data
)

In [28]:
def obtain_orbit_with_step_size(theta_log_sigma_initial,
                                velocity_log_sigma_component,
                                velocity_x_component,
                                max_step_size,
                                halvings,
                                orbit_sampler=None,
                                seed=12909067,
                                ):
    """Build a funnel orbit from a fixed start point at a refined step size.

    The starting position has log(sigma) = ``theta_log_sigma_initial`` and all
    x coordinates 0. The starting momentum has log(sigma) component
    ``velocity_log_sigma_component`` and every x component ``velocity_x_component``,
    and is rescaled to Euclidean norm sqrt(11).

    Parameters
    ----------
    theta_log_sigma_initial, velocity_log_sigma_component, velocity_x_component : float
        Start-point coordinates as described above.
    max_step_size : float
        Step size before any halving.
    halvings : int
        The orbit uses step size ``max_step_size * 2**-halvings`` and receives
        ``2**halvings`` as its step-count argument. Their product stays
        ``max_step_size``, presumably so each leaf covers the same integration
        time (TODO confirm against NUTSOrbitDiagnostic).
    orbit_sampler : optional
        Sampler handed to ``nod.NUTSOrbitDiagnostic``. Defaults to the
        notebook-level ``sampler`` so existing calls keep working. Pass one
        explicitly to avoid depending on that global.
    seed : int
        Seed for a fresh generator, so every call is reproducible on its own.

    Returns
    -------
    tuple
        ``(orbit, filename)``: the orbit, plus a PNG filename that encodes the
        arguments.
    """
    if orbit_sampler is None:
        orbit_sampler = sampler  # notebook global created in the setup cell

    model = fv.create_model_stan_and_json("funnel", "funnel")
    rng = np.random.default_rng(seed)
    n_params = model.param_unc_num()

    theta_0 = np.zeros(n_params)
    theta_0[0] = theta_log_sigma_initial

    rho_0 = np.full(n_params, velocity_x_component, dtype=float)
    rho_0[0] = velocity_log_sigma_component
    rho_0 = (rho_0 / np.linalg.norm(rho_0)) * np.sqrt(11)

    bernoulli_sequence = (1,) * 10

    orbit = nod.NUTSOrbitDiagnostic(orbit_sampler,
                                    rng,
                                    theta_0,
                                    rho_0,
                                    max_step_size * 2**(-halvings),
                                    2**halvings,
                                    bernoulli_sequence,
                                    tree_node_class=nod.NUTSTreeNodeDiagnostic)
    filename = f"Larger_Dots_LSig_{theta_log_sigma_initial}_vls_{velocity_log_sigma_component}_vx_{velocity_x_component}_maxh_{max_step_size}_halvings_{halvings}.png"
    return orbit, filename
    

In [29]:
# Six orbits: two starting log(sigma) values x three step-size refinements,
# outer loop over the start point, inner loop over the number of halvings.
orbits = list(
    obtain_orbit_with_step_size(log_sigma_start, -2.0, 1.0, 1/4, n_halvings)
    for log_sigma_start in (-2.0, 2.0)
    for n_halvings in (0, 1, 2)
)

In [30]:
# Zoom the first orbit in on the subtree two levels down its left side.
# The trim mutates the orbit in place, so remember the trimmed root and skip
# the trim when orbits[0] already points at it. Re-running this cell is then a
# no-op, while a freshly rebuilt `orbits` is trimmed again.
unrefined_orbit, filename = orbits[0]
if unrefined_orbit._orbit_root is not globals().get("_zoomed_orbit_root"):
    unrefined_orbit._orbit_root = unrefined_orbit._orbit_root._left_child._left_child
    _zoomed_orbit_root = unrefined_orbit._orbit_root
orbits[0] = (unrefined_orbit, filename)

In [31]:
# One energy range shared by every orbit, so marker sizes are comparable across figures.
energy_max, energy_min = (
    max(o._orbit_root._energy_max for o, _name in orbits),
    min(o._orbit_root._energy_min for o, _name in orbits),
)

In [32]:
# Install the shared energy range used to scale marker sizes.
Plotter.set_energy_max_min(energy_max=energy_max, energy_min=energy_min)

In [33]:
# Output location for the orbit figures; change this one constant to save elsewhere.
ORBIT_FIGURE_DIR = Path("/Users/milostevenmarsden/Documents/Simulations/July2024/Fixed_Orbit_Evolution/Labeled")
ORBIT_FIGURE_DIR.mkdir(parents=True, exist_ok=True)

# plot_orbit starts its own fresh figure, so an extra reset_fig_ax() here would
# only leave an empty, unused figure behind for every orbit.
for orbit, filename in orbits:
    Plotter.plot_orbit(orbit._orbit_root)
    Plotter.save_plot(str(ORBIT_FIGURE_DIR / filename))