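# Synthetic Neural Manifolds

This notebook samples points on a 1-dimensional hypersphere (a circle) in a 2-D feature space. It maps them into the activity space of N = 3 neurons with a random linear encoding matrix, then plots both spaces together with the encoding vectors.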

### Setup and imports

In [45]:
import setup
setup.main()
%load_ext autoreload
%autoreload 2

import neurometry.datasets.synthetic as synthetic
import numpy as np

import matplotlib.pyplot as plt


import os

os.environ["GEOMSTATS_BACKEND"] = "pytorch"
import geomstats.backend as gs

import plotly.graph_objects as go

Working directory:  /Users/facosta/Desktop/code/neurometry/neurometry
Directory added to path:  /Users/facosta/Desktop/code/neurometry
Directory added to path:  /Users/facosta/Desktop/code/neurometry/neurometry
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [83]:
num_points = 100
data = synthetic.hypersphere(1, num_points)

# Feature-space coordinates of the samples. Only the first two coordinates are
# plotted, and the points are drawn in the z = 0 plane. Dedicated names keep this
# cell from overwriting the generic `x`, `y`, `z` globals that other cells reuse;
# overwriting them made out-of-order runs plot the wrong points.
feature_x = data[:, 0]
feature_y = data[:, 1]
feature_z = np.zeros(num_points)

N = 3
encoding_matrix = synthetic.random_encoding_matrix(data.shape[1], N)

# One arrow per neuron: column i of the encoding matrix, drawn in the feature plane.
vectors = [encoding_matrix[:, i] for i in range(N)]
colors = ['red', 'green', 'blue']

# Fixed [-1.2, 1.2] axes clip any point or vector with a coordinate above 1 in
# magnitude. Keep that range when everything fits in [-1, 1]; otherwise widen the cube.
extent = max(
    1.0,
    float(np.abs(np.asarray(data)).max(initial=0.0)),
    float(np.abs(np.asarray(encoding_matrix)).max(initial=0.0)),
)
axis_range = [-1.2 * extent, 1.2 * extent]

scatter = go.Scatter3d(
    x=feature_x, y=feature_y, z=feature_z,
    mode='markers', marker=dict(size=5), name='Data Points',
)

# Each encoding vector is drawn as an arrow: a line for the shaft plus a cone head.
lines_and_cones = []
for idx, vector in enumerate(vectors):
    # Cycle the palette: zip(vectors, colors) would silently drop vectors when N > len(colors).
    color = colors[idx % len(colors)]

    # Line (shaft of the arrow)
    lines_and_cones.append(
        go.Scatter3d(
            x=[0, vector[0]], y=[0, vector[1]], z=[0, 0],
            mode='lines', line=dict(color=color, width=5),
            name=f'Encoding Vector {idx+1}'
        )
    )

    # Cone (head of the arrow), pointing along the vector
    lines_and_cones.append(
        go.Cone(
            x=[vector[0]], y=[vector[1]], z=[0],
            u=[vector[0]/10], v=[vector[1]/10], w=[0],
            showscale=False, colorscale=[[0, color], [1, color]],
            sizemode='absolute', sizeref=0.1
        )
    )

fig = go.Figure(data=[scatter] + lines_and_cones)

fig.update_layout(
    title={
        'text': "Feature Space",
        'y': 0.5,
        'x': 0.1,
        'xanchor': 'center',
        'yanchor': 'top',
        'font': dict(size=25),
    },
    scene=dict(
        aspectmode='cube',
        xaxis=dict(range=axis_range, title='Feature 1'),
        yaxis=dict(range=axis_range, title='Feature 2'),
        zaxis=dict(range=axis_range, title=''),
    ),
    margin=dict(l=0, r=0, b=0, t=0),
)

fig.show()

In [60]:
# Neural-space image of every sampled point: row p is data[p] @ encoding_matrix.
encoded_data = np.einsum("ij,jk->ik", data, encoding_matrix)

# Neural-space image of each encoding vector. Column i of `encoding_matrix` (e_i)
# maps to e_i @ encoding_matrix, which is row i of encoding_matrix.T @ encoding_matrix.
# That product is symmetric, so taking its columns yields the same vectors.
new_encoding_vectors = gs.einsum('ij,jk->ik', encoding_matrix.T, encoding_matrix)
new_vectors = list(new_encoding_vectors[:, column] for column in range(N))

# Activity of neurons 1-3 for every encoded point.
x, y, z = (encoded_data[:, neuron] for neuron in range(3))

In [84]:
colors = ['red', 'green', 'blue']

# Read the neural-space coordinates straight from `encoded_data` instead of the
# generic `x`, `y`, `z` globals. Other cells may reassign those names, and running
# cells out of order then plots the wrong points here.
scatter = go.Scatter3d(
    x=encoded_data[:, 0], y=encoded_data[:, 1], z=encoded_data[:, 2],
    mode='markers', marker=dict(size=5), name='Data Points',
)

# Fixed [-1.2, 1.2] axes clip encoded points or vectors with a coordinate above 1
# in magnitude. Keep that range when everything fits in [-1, 1]; otherwise widen the cube.
extent = max(
    1.0,
    float(np.abs(np.asarray(encoded_data)[:, :3]).max(initial=0.0)),
    max(
        (float(np.abs(np.asarray(vector)[:3]).max(initial=0.0)) for vector in new_vectors),
        default=0.0,
    ),
)
axis_range = [-1.2 * extent, 1.2 * extent]

# Each mapped encoding vector is drawn as an arrow: a line for the shaft plus a cone head.
lines_and_cones = []
for idx, vector in enumerate(new_vectors):
    # Cycle the palette: zip(new_vectors, colors) would silently drop vectors when N > len(colors).
    color = colors[idx % len(colors)]

    # Line (shaft of the arrow)
    lines_and_cones.append(
        go.Scatter3d(
            x=[0, vector[0]], y=[0, vector[1]], z=[0, vector[2]],
            mode='lines', line=dict(color=color, width=5),
            name=f'Encoding Vector {idx+1}'
        )
    )

    # Cone (head of the arrow), pointing along the vector
    lines_and_cones.append(
        go.Cone(
            x=[vector[0]], y=[vector[1]], z=[vector[2]],
            u=[vector[0]/10], v=[vector[1]/10], w=[vector[2]/10],
            showscale=False, colorscale=[[0, color], [1, color]],
            sizemode='absolute', sizeref=0.1
        )
    )

fig = go.Figure(data=[scatter] + lines_and_cones)

fig.update_layout(
    title={
        'text': "Neural space",
        'y': 0.5,
        'x': 0.1,
        'xanchor': 'center',
        'yanchor': 'top',
        'font': dict(size=25),
    },
    scene=dict(
        aspectmode='cube',
        xaxis=dict(range=axis_range, title='Neuron 1'),
        yaxis=dict(range=axis_range, title='Neuron 2'),
        zaxis=dict(range=axis_range, title='Neuron 3'),
    ),
    margin=dict(l=0, r=0, b=0, t=0),
)

fig.show()