In [64]:
import setup

# Project bootstrap. Judging by this cell's printed output it reports the working
# directory and the directories added to the path -- confirm in setup.py.
setup.main()
# Reload edited project modules automatically before each cell execution.
%load_ext autoreload
%autoreload 2
# Auto-format code cells with black.
%load_ext jupyter_black

import neurometry.datasets.synthetic as synthetic
import numpy as np

import matplotlib.pyplot as plt


import os

# The backend is chosen through this environment variable, so it is set before
# the `geomstats.backend` import on the next line.
# NOTE(review): `neurometry.datasets.synthetic` is imported above this point; if it
# imports geomstats itself, the backend may already be fixed by then unless
# setup.main() sets the variable -- confirm.
os.environ["GEOMSTATS_BACKEND"] = "pytorch"
import geomstats.backend as gs

Working directory:  /home/facosta/neurometry/neurometry
Directory added to path:  /home/facosta/neurometry
Directory added to path:  /home/facosta/neurometry/neurometry
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
The jupyter_black extension is already loaded. To reload it, use:
  %reload_ext jupyter_black


In [65]:
# Draw task-space points and their intrinsic coordinates
# (presumably 1000 samples on the 1-sphere, i.e. a circle -- confirm against
# synthetic.hypersphere's signature).
task_points, intrinsic_coords = synthetic.hypersphere(1, 1000)

# Encode the task points into a 3-dimensional synthetic neural manifold.
noisy_points, manifold_points = synthetic.synthetic_neural_manifold(
    points=task_points, encoding_dim=3, nonlinearity="sigmoid", scales=gs.array([5, 3, 1])
)

# Training pair for the estimator: noiseless manifold points and their coordinates.
X, labels = manifold_points, intrinsic_coords

In [66]:
from neurometry.estimators.geometry.immersion_estimator import ImmersionEstimator

In [67]:
import torch

# Estimator configuration.
extrinsic_dim = 3  # matches encoding_dim used when generating the synthetic manifold
topology = "circle"
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at estimator construction / fit time.
device = "cuda" if torch.cuda.is_available() else "cpu"

immersion_estimator = ImmersionEstimator(extrinsic_dim, topology, device=device)

In [68]:
# Fit the estimator on the manifold points (X) and their intrinsic coordinates (labels).
# The call prints per-epoch train/test average losses, as seen in the output below.
immersion_estimator.fit(X, labels)

====> Train Epoch: 1 Average loss: 188.1801
====> Test Epoch: 1 Average loss: 165.0614
====> Train Epoch: 2 Average loss: 74.1690
====> Test Epoch: 2 Average loss: 70.4762
====> Train Epoch: 3 Average loss: 47.7297
====> Test Epoch: 3 Average loss: 57.2686
====> Train Epoch: 4 Average loss: 44.2286
====> Test Epoch: 4 Average loss: 54.5528
====> Train Epoch: 5 Average loss: 44.1926
====> Test Epoch: 5 Average loss: 53.0346
====> Train Epoch: 6 Average loss: 43.8985
====> Test Epoch: 6 Average loss: 53.7503
====> Train Epoch: 7 Average loss: 43.8415
====> Test Epoch: 7 Average loss: 55.3904
====> Train Epoch: 8 Average loss: 43.9621
====> Test Epoch: 8 Average loss: 57.2641
====> Train Epoch: 9 Average loss: 43.0592
====> Test Epoch: 9 Average loss: 56.2048
====> Train Epoch: 10 Average loss: 41.9819
====> Test Epoch: 10 Average loss: 47.0864
====> Train Epoch: 11 Average loss: 32.4589
====> Test Epoch: 11 Average loss: 24.6105
====> Train Epoch: 12 Average loss: 11.0092
====> Test Epoc

In [69]:
# Grab the fitted result; it is used below as a callable on intrinsic coordinates.
immersion = immersion_estimator.estimate_

In [70]:
import torch

# Move the intrinsic coordinates to the estimator's device.
# torch.as_tensor accepts arrays as well as tensors and returns its input unchanged
# when it is already a tensor on the target device, so re-running this cell is a
# no-op. torch.tensor(...) on an existing tensor would copy again and emit a
# "copy construct from a tensor" UserWarning.
intrinsic_coords = torch.as_tensor(intrinsic_coords, device=device)

In [71]:
# Evaluate the learned immersion at the intrinsic coordinates, then detach from the
# autograd graph, move to host memory, convert to NumPy and drop singleton axes.
recon_points = immersion(intrinsic_coords).detach().cpu().numpy().squeeze()

In [72]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots


def _scatter3d_trace(points, name):
    """Build a 3-D marker trace from columns 0, 1 and 2 of `points`.

    Parameters
    ----------
    points : array-like indexable as ``points[:, k]`` for k in 0..2
    name : str
        Label shown for this trace in the legend and on hover.
    """
    return go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode="markers",
        marker=dict(size=3),
        name=name,
    )


# Plot the manifold points and the reconstructed points in 3-D, side by side,
# with titled panels so the figure can be read on its own.
fig = make_subplots(
    rows=1,
    cols=2,
    specs=[[{"type": "scatter3d"}, {"type": "scatter3d"}]],
    subplot_titles=("Manifold points", "Reconstructed points"),
)

fig.add_trace(_scatter3d_trace(manifold_points, "manifold"), row=1, col=1)
fig.add_trace(_scatter3d_trace(recon_points, "reconstruction"), row=1, col=2)

fig.show()