In [1]:
import sys
sys.path.append("../source")
import torch
from torch import nn

In [54]:
import torch
import math

R = 1.0    # main radius
r = 0.2    # tube radius

def make_init_dist(n_samples: int = 1000, sigma: float = 0.05):
    """
    Initial: single Gaussian on torus at theta=0, phi=0 (x=1, y=0, z=0)
    Returns (theta, phi)
    """
    theta = torch.randn(n_samples) * sigma + 0.0
    phi = torch.randn(n_samples) * sigma + 0.0
    theta %= 2*math.pi
    phi %= 2*math.pi
    return torch.stack([theta, phi], dim=1)


def make_final_dist(n_samples: int = 1000, sigma: float = 0.05):
    n_half = n_samples // 2

    # Top Gaussian
    theta_top = torch.randn(n_half) * sigma + 3*math.pi/4
    phi_top = torch.randn(n_half) * sigma + math.pi/2

    # Bottom Gaussian
    theta_bottom = torch.randn(n_half) * sigma + 5*math.pi/4
    phi_bottom = torch.randn(n_half) * sigma - math.pi/2
    """
    theta_bottom2 = torch.randn(n_half) * sigma + math.pi/2
    phi_bottom2 = torch.randn(n_half) * sigma + math.pi/2

    theta_bottom3 = torch.randn(n_half) * sigma
    phi_bottom3 = torch.randn(n_half) * sigma + math.pi
    """
    # Wrap angles
    theta_top %= 2*math.pi
    phi_top %= 2*math.pi
    theta_bottom %= 2*math.pi
    phi_bottom %= 2*math.pi
    """
    theta_bottom2 %= 2*math.pi
    phi_bottom2 %= 2*math.pi
    theta_bottom3 %= 2*math.pi
    phi_bottom3 %= 2*math.pi
    """
    theta = torch.cat([theta_top, theta_bottom])#, theta_bottom2, theta_bottom3])
    phi = torch.cat([phi_top, phi_bottom])# phi_bottom2, phi_bottom3])

    return torch.stack([theta, phi], dim=1)


In [55]:
import torch
import plotly.graph_objects as go
import math

# Torus parameters
R = 1.0
r = 0.2
n_theta = 50   # number of circles along the main axis
n_phi = 20     # number of circles along the tube
n_points = 100 # points per circle

# --- Generate torus wireframe ---
theta_vals = torch.linspace(0, 2*math.pi, n_theta)
phi_vals = torch.linspace(0, 2*math.pi, n_phi)
theta_circle = torch.linspace(0, 2*math.pi, n_points)
phi_circle = torch.linspace(0, 2*math.pi, n_points)

torus_lines = []

# Circles along the tube (phi direction) for each theta
for theta in theta_vals:
    x = (R + r * torch.cos(phi_circle)) * torch.cos(theta)
    y = (R + r * torch.cos(phi_circle)) * torch.sin(theta)
    z = r * torch.sin(phi_circle)
    torus_lines.append(go.Scatter3d(x=x.numpy(), y=y.numpy(), z=z.numpy(),
                                    mode='lines', line=dict(color='lightgrey', width=2),
                                    showlegend=False))

# Circles along the main axis (theta direction) for each phi
for phi in phi_vals:
    x = (R + r * torch.cos(phi)) * torch.cos(theta_circle)
    y = (R + r * torch.cos(phi)) * torch.sin(theta_circle)
    z = r * torch.sin(phi) * torch.ones_like(theta_circle)
    torus_lines.append(go.Scatter3d(x=x.numpy(), y=y.numpy(), z=z.numpy(),
                                    mode='lines', line=dict(color='lightgrey', width=2),
                                    showlegend=False))



def torus_to_xyz(theta_phi, R=1.0, r=0.2):
    theta = theta_phi[:,0]
    phi = theta_phi[:,1]
    x = (R + r * torch.cos(phi)) * torch.cos(theta)
    y = (R + r * torch.cos(phi)) * torch.sin(theta)
    z = r * torch.sin(phi)
    return torch.stack([x,y,z], dim=1)

# --- Generate scatter points ---
init_points = torus_to_xyz(make_init_dist())
final_points = torus_to_xyz(make_final_dist())

init_scatter = go.Scatter3d(
    x=init_points[:,0].numpy(),
    y=init_points[:,1].numpy(),
    z=init_points[:,2].numpy(),
    mode='markers',
    marker=dict(size=4, color='green'),
    name='Initial'
)

final_scatter = go.Scatter3d(
    x=final_points[:,0].numpy(),
    y=final_points[:,1].numpy(),
    z=final_points[:,2].numpy(),
    mode='markers',
    marker=dict(size=4, color='red'),
    name='Final'
)

# --- Combine everything ---
fig = go.Figure(data=torus_lines + [init_scatter, final_scatter])

fig.update_layout(
    scene=dict(
        xaxis=dict(range=[-1.3,1.3]),
        yaxis=dict(range=[-1.3,1.3]),
        zaxis=dict(range=[-0.3,0.3]),
        aspectmode='manual',
        aspectratio=dict(x=1, y=1, z=0.3)
    ),
    title="Initial (green) and final (red) distributions on torus wireframe"
)

fig.show()


In [56]:
from source.models import SimpleFlow, EquivariantFlow

In [57]:
flow = EquivariantFlow(2)
optimizer = torch.optim.Adam(flow.parameters(), 1e-3)
loss_fn = nn.MSELoss()

In [58]:
def log_map_torus(x, y):
    return (y - x + torch.pi)%(2*torch.pi)-torch.pi

def exp_map_torus(x, v):
    return (x+v)%(2*torch.pi)

In [60]:
B = 512
epochs = 100
for epoch in range(epochs):
    for i in range(100):

        x0 = make_init_dist(B)
        y0 = make_final_dist(B)

        t = torch.rand(size=(B, 1))

        x_t = exp_map_torus(x0, (1-t)*log_map_torus(x0, y0))
        u = log_map_torus(x0, y0)

        optimizer.zero_grad()
        loss = loss_fn(flow(x_t, t), u)
        loss.backward()
        optimizer.step()

    if (epoch%10==0):
        print("Epoch: ", epoch, "Loss: ", loss.item())


Epoch:  90 Loss:  0.11482641845941544


In [61]:
from source.sampling import integrate_torus
import math
x0 = make_init_dist(B)  # B points

trajectories = integrate_torus(x0, flow, t_end=1.0, steps=100)  # (101, B, 2)


In [62]:
def trajectories_xyz(traj):
    # traj: (steps+1, B, 2)
    steps, B, _ = traj.shape
    xyz = torch.stack([torus_to_xyz(traj[k]) for k in range(steps)], dim=0)
    return xyz  # shape: (steps+1, B, 3)

traj_xyz = trajectories_xyz(trajectories)


In [71]:
flow_lines = []
steps, B, _ = traj_xyz.shape
for i in range(B):
    x = traj_xyz[:,i,0].detach().numpy()
    y = traj_xyz[:,i,1].detach().numpy()
    z = traj_xyz[:,i,2].detach().numpy()
    flow_lines.append(go.Scatter3d(
        x=x, y=y, z=z,
        mode='lines',
        line=dict(color='blue', width=2),
        showlegend=False
    ))

init_points = torus_to_xyz(x0)
init_scatter = go.Scatter3d(
    x=init_points[:,0].numpy(),
    y=init_points[:,1].numpy(),
    z=init_points[:,2].numpy(),
    mode='markers',
    marker=dict(size=4, color='green'),
    name='Initial'
)


final_scatter = go.Scatter3d(
    x=traj_xyz[-1, :, 0].detach().cpu().numpy(),
    y=traj_xyz[-1, :, 1].detach().cpu().numpy(),
    z=traj_xyz[-1, :, 2].detach().cpu().numpy(),
    mode='markers',
    marker=dict(size=4, color='red'),
    name='Final'
)


# Combine with previous torus wireframe + scatter
fig = go.Figure(data=torus_lines + [init_scatter, final_scatter] + flow_lines)

fig.update_layout(
    scene=dict(
        xaxis=dict(range=[-1.3,1.3]),
        yaxis=dict(range=[-1.3,1.3]),
        zaxis=dict(range=[-0.3,0.3]),
        aspectmode='manual',
        aspectratio=dict(x=1, y=1, z=0.3)
    ),
    title="Flow trajectories on the torus"
)

fig.show()
