In [None]:
import torch
# Project-local helpers: distribution classes, device/seed utilities, and plotting functions.
from util import Gaussian, GaussianMixture, get_device, set_all_seeds
from util import hist1d_sampleable, plot_path, plot_path_animation
import matplotlib.pyplot as plt
# Compute device chosen once by the util helper; later cells move tensors/modules onto it via .to(device).
device = get_device()
# Seed up front so the sampled paths are reproducible on Restart & Run All
# (presumably seeds torch and any other RNGs util cares about — see util.set_all_seeds).
set_all_seeds(42)

In [None]:
# Settings shared by every Gaussian conditional-probability-path experiment below.
# They live in one dict so the notebook namespace stays uncluttered.
PARAMS = dict(
    scale=15.0,          # NOTE(review): not referenced by any cell in this notebook
    target_scale=8.0,    # passed as `scale` to the target mixture; also sets plot y-limits
    target_std=1,        # passed as `std` to the target mixture
    sample_num=8,        # number of trajectories sampled and plotted
    mode_num=2,          # number of mixture modes
    aggregate=False,     # True: draw mode_num conditioning points and tile them
    path_timesteps=200,  # number of time points on [0, 1]
    sde_sigma=2,         # noise level handed to ConditionalVectorFieldSDE
    no_border=True,      # strip ticks/spines/labels from the exported animation
)

# Target distribution: symmetric 1-D Gaussian mixture, moved to the compute device.
p_data = GaussianMixture.symmetric_1D(
    nmodes=PARAMS["mode_num"],
    std=PARAMS["target_std"],
    scale=PARAMS["target_scale"],
).to(device)

In [None]:
# Visual sanity check of the target: histogram of draws from p_data.
n_hist_samples = 1_000
hist1d_sampleable(p_data, n_hist_samples)

In [None]:
from flow import GaussianConditionalProbabilityPath, LinearAlpha, SquareRootBeta

In [None]:
# Source ("simple") distribution: isotropic Gaussian matching the data dimensionality.
# NOTE(review): confirm whether 4.0 is interpreted as a std or a variance by Gaussian.isotropic.
p_simple = Gaussian.isotropic(p_data.dim, 4.0)

# Gaussian conditional probability path between p_simple and p_data; the schedule
# objects are a linear alpha and a square-root beta (per their class names).
alpha_schedule = LinearAlpha()
beta_schedule = SquareRootBeta()
path = GaussianConditionalProbabilityPath(
    p_data=p_data,
    alpha=alpha_schedule,
    beta=beta_schedule,
    p_simple=p_simple,
).to(device)

In [None]:
# Conditioning variables z: exactly PARAMS["sample_num"] rows, one per trajectory.
n_paths = PARAMS["sample_num"]
if PARAMS["aggregate"]:
    # Draw only mode_num conditioning points and tile them so several trajectories
    # share each one. Round the repeat count up and trim, so z has exactly n_paths
    # rows even when sample_num is not a multiple of mode_num (floor division used
    # to leave z short, mismatching code that assumes sample_num trajectories).
    z_base = path.sample_conditioning_variable(PARAMS["mode_num"])
    n_repeats = -(-n_paths // PARAMS["mode_num"])  # ceiling division
    z = z_base.repeat(n_repeats, 1)[:n_paths]
else:
    z = path.sample_conditioning_variable(n_paths)

# Evenly spaced times on [0, 1], endpoints included -- same construction the later
# simulation cells use. Unlike arange(n) / (n - 1), linspace stays finite when
# path_timesteps == 1.
ts = torch.linspace(0.0, 1.0, PARAMS["path_timesteps"]).to(device)

In [None]:
# Sample the conditional path for every conditioning variable at each time in ts.
# The time batch is sized from z itself (not PARAMS["sample_num"]), so the two
# arguments always have matching batch size even if z's row count differs from
# sample_num (e.g. in aggregate mode).
n_paths = z.shape[0]
xts = torch.stack([
    path.sample_conditional_path(z, t.repeat(n_paths).unsqueeze(1))  # t batch: (n_paths, 1)
    for t in ts
])  # (timestep, path_num, dim)

In [None]:
fig, ax = plt.subplots(figsize=(10, 5))

# Sampled conditional paths against time; y-range is 1.5x the mixture scale.
plot_path(xts, ts, ax)
ax.set(
    xlim=(0, 1),
    ylim=(-1.5 * PARAMS["target_scale"], 1.5 * PARAMS["target_scale"]),
    xlabel='Time',
    title='Probability Paths Over Timesteps',
)
ax.legend()
ax.grid(True)

plt.show()

In [None]:
# Shape of a single timestep slice, i.e. (path_num, dim) per the stacking above.
xts.shape[1:]

In [None]:
from flow import ConditionalVectorFieldODE, ConditionalVectorFieldSDE
from flow import EulerSimulator, EulerMaruyamaSimulator
from einops import rearrange
# ODE built from the conditional vector field of `path`, conditioned on the z sampled earlier.
ode = ConditionalVectorFieldODE(path, z)
# Euler-scheme simulator for that ODE (per the class name); driven below via simulate_with_trajectory.
simulator = EulerSimulator(ode)

In [None]:
# Integrate the conditional ODE starting from fresh draws of the simple distribution.
x0 = path.p_simple.sample(PARAMS["sample_num"])  # (num_samples, dim)

# One shared time grid, broadcast to every trajectory: (num_samples, nts, 1).
time_grid = torch.linspace(0.0, 1.0, PARAMS["path_timesteps"])
ts = (
    time_grid
    .view(1, -1, 1)
    .expand(PARAMS["sample_num"], -1, 1)
    .to(device)
)

ode_trajectories = simulator.simulate_with_trajectory(x0, ts)  # (bs, nts, dim)
xts = rearrange(ode_trajectories, 'b t d -> t b d')  # (nts, bs, dim)

In [None]:
fig, ax = plt.subplots(figsize=(10, 5))

# ODE trajectories. Every row of `ts` holds the same grid, so one row is passed.
# NOTE(review): ts[0] is (nts, 1) here, whereas the first path plot passes a (nts,) grid.
plot_path(xts, ts[0], ax)
ax.set(
    xlim=(0, 1),
    ylim=(-1.5 * PARAMS["target_scale"], 1.5 * PARAMS["target_scale"]),
    xlabel='Time',
    title='ODE Vector Field Trajectory',
)
ax.grid(True)

plt.show()

In [None]:
# Same experiment with noise: conditional SDE integrated by the Euler-Maruyama simulator.
sde = ConditionalVectorFieldSDE(path, z, PARAMS["sde_sigma"])
simulator = EulerMaruyamaSimulator(sde)
x0 = path.p_simple.sample(PARAMS["sample_num"])  # (num_samples, dim)

time_grid = torch.linspace(0.0, 1.0, PARAMS["path_timesteps"])  # (nts,)
ts = time_grid.view(1, -1, 1).expand(PARAMS["sample_num"], -1, 1).to(device)  # (num_samples, nts, 1)

sde_trajectories = simulator.simulate_with_trajectory(x0, ts)  # (bs, nts, dim)
xts = rearrange(sde_trajectories, 'b t d -> t b d')  # (nts, bs, dim)

In [None]:
fig, ax = plt.subplots(figsize=(10, 5))

# SDE trajectories on the shared time grid (first row of the broadcast `ts`).
plot_path(xts, ts[0], ax)
ax.set(
    xlim=(0, 1),
    ylim=(-1.5 * PARAMS["target_scale"], 1.5 * PARAMS["target_scale"]),
    xlabel='Time',
    title='SDE Vector Field Trajectory',
)
ax.grid(True)

plt.show()

In [None]:
from IPython.display import HTML

# Animated version of the SDE trajectories, rendered inline as an HTML/JS player.
fig, ax = plt.subplots(figsize=(10, 5))
anim = plot_path_animation(xts, ts[0], fig, ax)

ax.set_xlim(0, 1)
ax.set_ylim(-PARAMS["target_scale"]*1.5, PARAMS["target_scale"]*1.5)
ax.set_xlabel('Time')
ax.set_title('SDE Vector Field Trajectory')
ax.grid(True)

# Close exactly the figure this cell created (rather than whatever pyplot considers
# "current") so only the HTML player is displayed, not an extra static image.
plt.close(fig)
HTML(anim.to_jshtml())

In [None]:
# Large Monte-Carlo run of the conditional SDE, used for the density heatmap.
heat_samples_num = 100_000

# Start points are drawn before conditioning variables (RNG consumption order matters).
heat_x0 = path.p_simple.sample(heat_samples_num)  # (num_samples, dim)
heat_z = path.sample_conditioning_variable(heat_samples_num)

heat_sde = ConditionalVectorFieldSDE(path, heat_z, PARAMS["sde_sigma"])
heat_simulator = EulerMaruyamaSimulator(heat_sde)

heat_grid = torch.linspace(0.0, 1.0, PARAMS["path_timesteps"])
heat_ts = heat_grid.view(1, -1, 1).expand(heat_samples_num, -1, 1).to(device)  # (num_samples, nts, 1)

heat_xts = rearrange(
    heat_simulator.simulate_with_trajectory(heat_x0, heat_ts),  # (bs, nts, dim)
    'b t d -> t b d',
)  # (nts, bs, dim)

In [None]:
from util import plot_heatmap
fig, ax = plt.subplots(figsize=(10, 5))

# Heatmap of the large SDE trajectory ensemble over (time, x) -- presumably a
# normalized histogram, judging by the returned names.
# NOTE(review): unlike the other plots this passes the full (num_samples, nts, 1)
# time tensor rather than one row such as heat_ts[0]; confirm plot_heatmap expects that.
heat_map, H_normalized = plot_heatmap(heat_xts, heat_ts, ax)
ax.set(
    xlim=(0, 1),
    ylim=(-1.5 * PARAMS["target_scale"], 1.5 * PARAMS["target_scale"]),
    xlabel='Time',
    title='SDE Vector Field Trajectory',
)
ax.grid(False)

plt.show()

In [None]:
# Animate the SDE trajectories on top of the density heatmap, optionally borderless
# for export as a standalone video.
# A fresh figure is built and the heatmap redrawn here instead of reusing the `fig`/`ax`
# left over from the heatmap cell: reusing them made this cell depend on execution
# order, and re-running it stacked another set of animated artists on the same axes.
fig, ax = plt.subplots(figsize=(10, 5))
plot_heatmap(heat_xts, heat_ts, ax)
anim = plot_path_animation(xts, ts[0], fig, ax)

ax.set_xlim(0, 1)
ax.set_ylim(-PARAMS["target_scale"]*1.5, PARAMS["target_scale"]*1.5)
if PARAMS["no_border"]:
    # Let the axes fill the whole figure, leaving a 1% margin on each side.
    fig.subplots_adjust(
        left=0.01,
        right=0.99,
        bottom=0.01,
        top=0.99,
        hspace=0.0,
        wspace=0.0
    )
    # Remove tick marks and tick labels.
    ax.set_xticks([])
    ax.set_yticks([])

    # Hide the spines (the box drawn around the plot area).
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Clear any axis labels and title so nothing but the plot content remains.
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_title('')
else:
    ax.set_xlabel('Time')
    ax.set_title('SDE Vector Field Trajectory')
    ax.legend()

ax.grid(False)
# Close this figure so only the HTML player is shown, not a static image.
plt.close(fig)
HTML(anim.to_jshtml())

In [None]:
from matplotlib.animation import FFMpegWriter
import os

# Export the most recent animation (`anim`) to MP4 through ffmpeg.
# NOTE(review): a previous note claimed fps should equal 1000 / interval and assumed a
# 50 ms interval (-> 20 fps), yet fps=40 is used. Check plot_path_animation's interval;
# fps only changes the playback speed of the saved file, not which frames are written.
writer = FFMpegWriter(fps=40, bitrate=1800)

output_filename = os.path.join("..", "output", "probability_paths_animation.mp4")
# The writer does not create missing parent directories, so ensure ../output exists;
# otherwise the save fails on a fresh checkout.
os.makedirs(os.path.dirname(output_filename), exist_ok=True)
print(f"Saving animation to {output_filename}...")

anim.save(output_filename, writer=writer)

print("Save complete!")