In [2]:
from neuralplayground.arenas import Sphere
from neuralplayground.agents import Stachenfeld2018, RatInASphere
from neuralplayground.backend import episode_based_training_loop
from neuralplayground.backend import SingleSim
from neuralplayground.comparison import GridScorer
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import multiprocessing

## Grid cells observed in spherical space

In [3]:
# Shared configuration for the simulations built in this notebook.
simulation_id = "SR_rat_in_a_sphere"
agent_class = RatInASphere
env_class = Sphere
training_loop = episode_based_training_loop

# Angular discretisation shared by the spherical agent and the arena.
N_SLICES = N_STACKS = 36

agent_params = {"discount":  0.9,
                "threshold": 1e-6,
                "lr_td":  1e-2,
                "n_slices": N_SLICES,
                "n_stacks": N_STACKS}

# Each key is paired with its matching constant. The two were previously
# crossed, which only went unnoticed because both constants are equal.
env_params = {"n_stacks": N_STACKS,
              "n_slices": N_SLICES}

training_loop_params = {"t_episode": 1000, "n_episode": 1000}

In [4]:
# Simulation pairing the spherical SR agent with the Sphere arena, using the
# configuration defined in the previous cell.
sim_spherical = SingleSim(
    simulation_id=simulation_id,
    agent_class=agent_class,
    agent_params=agent_params,
    env_class=env_class,
    env_params=env_params,
    training_loop=training_loop,
    training_loop_params=training_loop_params,
)

In [5]:
# trained_agent, env, training_hist = sim.load_results('spherical_results')
# gridness_scores = []
# for eig in range(trained_agent.srmat.shape[0]):
#     r_out_im, _= trained_agent.get_rate_map_matrix(trained_agent.srmat, eigen_vector=eig)
#     GridScorer_Stachenfeld2018 = GridScorer(trained_agent.n_stacks + 1)
#     # GridScorer_Stachenfeld2018.plot_grid_score(r_out_im=r_out_im, plot= True)
#     sac, grid_field_props = GridScorer_Stachenfeld2018.get_scores(r_out_im)
#     gridness_scores.append(grid_field_props['gridscore'])
# fig, ax = plt.subplots(1,1)
# sns.histplot(gridness_scores, ax=ax, bins=20, kde=True, color='green'), ax.set_title("Histogram of gridness scores")
# plt.show()

## Grid cells explored on tangent space (orthogonal projection)

In [6]:
from numpy import ndarray
import matplotlib.pyplot as plt 
import matplotlib as mpl
from neuralplayground.plotting.plot_utils import make_plot_rate_map
import random
from typing import Union

class RatOnTangent(Stachenfeld2018):
    """Successor-representation agent driven by the first two coordinates of each observation.

    On every step the agent keeps only ``obs[0][:2]``, converts it to a discrete
    state with the inherited ``obs_to_state`` and counts the visit in
    ``freq_map``. Actions are drawn uniformly at random.
    """

    def __init__(self, agent_name: str = "SR", discount: float = 0.9, threshold: float = 0.000001, lr_td: float = 0.01, room_width: float = 12, room_depth: float = 12, state_density: float = 1, twoD: bool = True, **mod_kwargs):
        """Initialise the parent agent and an all-zero visit counter.

        All arguments are forwarded unchanged to ``Stachenfeld2018.__init__``.
        """
        super().__init__(agent_name, discount, threshold, lr_td, room_width, room_depth, state_density, twoD, **mod_kwargs)
        # One counter per discrete state; incremented in act().
        self.freq_map = np.zeros(self.n_state)

    def act(self, obs):
        """Record the observation and return a random 3-component action.

        Parameters
        ----------
        obs
            Observation whose first element is a position; only its first two
            components are used to look up the discrete state.

        Returns
        -------
        np.ndarray or None
            Three values drawn uniformly from [-1, 1), or ``None`` when ``obs``
            is empty.
        """
        # obs_history is presumably created by the parent constructor -- TODO confirm.
        self.obs_history.append(obs)
        if len(self.obs_history) >= 1000:
            # Keep the history bounded: restart it from the current observation.
            self.obs_history = [
                obs,
            ]

        if len(obs) == 0:
            return None

        # Random policy
        action = np.random.uniform(-1, 1, 3)
        self.next_state = self.obs_to_state(obs[0][:2])
        self.freq_map[self.next_state] += 1
        return action

    def get_rate_map_matrix(
        self,
        sr_matrix=None,
        eigen_vector: Union[int, list, tuple] = None,
    ):
        """Reshape eigenvectors of the successor representation into 2D rate maps.

        Parameters
        ----------
        sr_matrix
            Matrix to decompose. When ``None``, the result of
            ``self.successor_rep_solution()`` is used.
        eigen_vector
            Index, or iterable of indices, of the eigenvector(s) to return.
            When ``None``, every eigenvector is returned.

        Returns
        -------
        np.ndarray or list of np.ndarray
            A single ``(resolution_width, resolution_depth)`` array for an
            integer index, otherwise a list with one such array per index.
        """
        if sr_matrix is None:
            sr_matrix = self.successor_rep_solution()
        evals, evecs = np.linalg.eig(sr_matrix)
        if eigen_vector is None:
            eigen_vector = np.arange(evecs.shape[1])

        # This agent's states live on the resolution_width x resolution_depth
        # grid (the same one plot_freq_map reshapes freq_map to), not on the
        # n_stacks x n_slices grid of the spherical agent.
        map_shape = (self.resolution_width, self.resolution_depth)

        # Accept NumPy integer scalars as well as built-in ints.
        if isinstance(eigen_vector, (int, np.integer)):
            return evecs[:, eigen_vector].reshape(map_shape).real
        return [evecs[:, evec_idx].reshape(map_shape).real for evec_idx in eigen_vector]

    def plot_rate_map(
        self,
        sr_matrix=None,
        eigen_vectors: Union[int, list, tuple] = None,
        ax: mpl.axes.Axes = None,
        save_path: str = None,
    ):
        """Plot the rate map of one or several eigenvectors.

        Parameters
        ----------
        sr_matrix
            Forwarded to ``get_rate_map_matrix``.
        eigen_vectors
            Index or iterable of indices to plot. When ``None``, a single index
            is drawn at random from 5 to 19 inclusive.
        ax
            Axes (or sequence of axes) to draw on; created when ``None``.
        save_path
            When given, the current figure is saved to this path.

        Returns
        -------
        The axes that were drawn on: a single ``Axes`` for one eigenvector,
        otherwise an indexable collection of axes.
        """
        if eigen_vectors is None:
            eigen_vectors = random.randint(5, 19)

        if isinstance(eigen_vectors, (int, np.integer)):
            rate_map_mat = self.get_rate_map_matrix(sr_matrix, eigen_vector=eigen_vectors)

            if ax is None:
                f, ax = plt.subplots(1, 1, figsize=(4, 5))
            make_plot_rate_map(rate_map_mat, ax, "Rate map: Eig" + str(eigen_vectors), "azimuthal", "polar", "Firing rate")
        else:
            if ax is None:
                f, ax = plt.subplots(1, len(eigen_vectors), figsize=(4 * len(eigen_vectors), 5))
            if isinstance(ax, mpl.axes.Axes):
                # A single Axes object is wrapped so it can be indexed below.
                ax = [ax,]

            rate_map_mats = self.get_rate_map_matrix(sr_matrix, eigen_vector=eigen_vectors)
            for i, rate_map_mat in enumerate(rate_map_mats):
                make_plot_rate_map(rate_map_mat, ax[i], "Rate map: " + "Eig" + str(eigen_vectors[i]), "azimuthal", "polar", "Firing rate")

        if save_path is not None:
            plt.savefig(save_path, bbox_inches="tight")

        # Always hand the axes back, whether or not the figure was saved.
        return ax

    def plot_freq_map(
        self,
        ax: mpl.axes.Axes = None,
        save_path: str = None
    ):
        """Plot how often each discrete state was visited.

        Parameters
        ----------
        ax
            Axes to draw on; created when ``None``.
        save_path
            When given, the current figure is saved to this path.

        Returns
        -------
        The axes that were drawn on.
        """
        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=(4, 5))

        make_plot_rate_map(self.freq_map.reshape(self.resolution_width, self.resolution_depth), ax, "Frequency Map", "x", "y", "Frequency")

        if save_path is not None:
            plt.savefig(save_path, bbox_inches="tight")

        return ax

In [7]:
# Orthogonal-projection experiment: same arena, tangent-plane agent.
agent_class = RatOnTangent
env_class = Sphere

# Room extent and state density for the tangent-plane agent.
agent_params = dict(room_width=2, room_depth=2, state_density=25.5)

sim_tangent = SingleSim(
    simulation_id=simulation_id,
    agent_class=agent_class,
    agent_params=agent_params,
    env_class=env_class,
    env_params=env_params,
    training_loop=training_loop,
    training_loop_params=training_loop_params,
)


In [8]:
# # trained_agent, env, training_hist = sim.load_results('tangential_results')
# gridness_scores = []
# for eig in range(trained_agent.srmat.shape[0]):
#     r_out_im, _= trained_agent.get_rate_map_matrix(trained_agent.srmat, eigen_vector=eig)
#     GridScorer_Stachenfeld2018 = GridScorer(trained_agent.room_width + 1)
#     # GridScorer_Stachenfeld2018.plot_grid_score(r_out_im=r_out_im, plot= True)
#     sac, grid_field_props = GridScorer_Stachenfeld2018.get_scores(r_out_im)
#     gridness_scores.append(grid_field_props['gridscore'])
# fig, ax = plt.subplots(1,1)
# sns.histplot(gridness_scores, ax=ax, bins=20, kde=True, color='green'), ax.set_title("Histogram of gridness scores")
# plt.show()

In [9]:
# import os
# import cv2

# current_dir = os.getcwd()
# frames_dir = 'frames'
# abs_frames_dir = os.path.join(current_dir, frames_dir)
# if not os.path.exists(abs_frames_dir):
#     os.makedirs(abs_frames_dir)
# for i in range(1000):
#     env.render(history_length=i+1, save_dir=os.path.join(abs_frames_dir, f'frame_{i:03d}.png'))

# image_folder = 'frames'
# video_name = 'video_gravity_2.avi'

# images = [img for img in os.listdir(image_folder) if img.endswith(".png")]
# frame = cv2.imread(os.path.join(image_folder, images[0]))
# height, width, layers = frame.shape

# video = cv2.VideoWriter(video_name, 0, 20, (width,height))

# for image in images:
#     video.write(cv2.imread(os.path.join(image_folder, image)))

# cv2.destroyAllWindows()
# video.release()

## Grid cells on tangent space (logarithmic map)

In [10]:
from numpy import ndarray


class RatOnGeometricTangent(RatOnTangent):
    """Variant of ``RatOnTangent`` that maps positions through ``Sphere.logarithmic_map``.

    The full observed position is handed to ``obs_to_state``, which applies the
    logarithmic map and keeps the first two components before delegating to the
    parent discretisation.
    """

    def __init__(self, agent_name: str = "SR", discount: float = 0.9, threshold: float = 0.000001, lr_td: float = 0.01, room_width: float = 12, room_depth: float = 12, state_density: float = 1, twoD: bool = True, **mod_kwargs):
        """Forward every argument to the parent and reset the visit counter."""
        super().__init__(
            agent_name,
            discount,
            threshold,
            lr_td,
            room_width,
            room_depth,
            state_density,
            twoD,
            **mod_kwargs,
        )
        self.freq_map = np.zeros(self.n_state)

    def act(self, obs):
        """Log ``obs``, count the visited state and return a random action.

        Returns ``None`` when ``obs`` is empty, otherwise an array of three
        values drawn uniformly from [-1, 1).
        """
        self.obs_history.append(obs)
        if len(self.obs_history) >= 1000:
            # Restart the bounded history from the latest observation.
            self.obs_history = [obs]

        if len(obs) == 0:
            return None

        # Random policy (drawn before the state lookup, as in the original order).
        random_step = np.random.uniform(-1, 1, 3)
        visited = self.obs_to_state(obs[0])
        self.next_state = visited
        self.freq_map[visited] += 1
        return np.array(random_step)

    def obs_to_state(self, pos: ndarray):
        """Convert a position to a discrete state via the logarithmic map.

        The map is evaluated with ``np.zeros(3)`` as its first argument
        (presumably the base point -- TODO confirm against ``Sphere``), and
        only the first two components of the result are discretised.
        """
        tangent_vec = Sphere.logarithmic_map(np.zeros(3), pos)
        return super().obs_to_state(tangent_vec[:2])

In [11]:
# Build a fresh parameter dict instead of mutating `agent_params` in place:
# that dict object was already handed to `sim_tangent`, so editing it here
# could silently change the tangent simulation's settings before it is run.
agent_params = {**agent_params,
                "room_width": 3,
                "room_depth": 3,
                "state_density": 10}

sim_logarithmic = SingleSim(simulation_id = simulation_id,
                agent_class = RatOnGeometricTangent,
                agent_params = agent_params,
                env_class = env_class,
                env_params = env_params,
                training_loop = training_loop,
                training_loop_params = training_loop_params)

In [14]:
# Run the three simulations one after another, each with its own results path.
simulations = [
    sim_spherical,
    sim_tangent,
    sim_logarithmic,
]
sim_paths = [
    'spherical_results',
    'tangential_results',
    'logarithmic_results',
]

sim_args = list(zip(simulations, sim_paths))

for sim, sim_path in sim_args:
    sim.run_sim(sim_path)

In [13]:
# from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
# from mpl_toolkits.mplot3d import Axes3D

# rate_map_15, _ = agent.get_rate_map_matrix(agent.srmat, eigen_vector=10)


# f = plt.figure()

# # Top down projection for 3d plot
# ax = plt.axes(projection = "3d")  
# ax.view_init(90, -90)

# # Plot 3d-hemisphere
# phi = np.linspace(np.pi/2, np.pi, N_STACKS)
# theta = np.linspace(0, 2*np.pi, N_SLICES)
# phi, theta = np.meshgrid(phi, theta)
# x = np.sin(phi)*np.cos(theta)
# y = np.sin(phi)*np.sin(theta)
# z = np.cos(phi)
# norm = plt.Normalize(rate_map_15.min(), rate_map_15.max())
# ax.plot_surface(x,y,z, facecolors = plt.cm.jet(norm(rate_map_15.T)), linewidth=0, antialiased=False, alpha=0.8)

# save_dir = None

# if save_dir is not None:
#     plt.savefig(save_dir)
#     plt.close(f)
# else:
#     plt.show()