In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


from sklearn.datasets import make_swiss_roll
from sklearn.datasets import make_moons


from pytagi import NetProp

from diffuser_v0 import Diffuser

import matplotlib.animation as animation
from ipywidgets import interact, IntSlider
import mplcursors

In [None]:
# Fixed seed so the generated training data is reproducible across kernel restarts.
# (Any randomness inside Diffuser training is not controlled by this seed.)
SEED = 0

# Swiss roll point cloud; keep columns 0 and 2 (dropping column 1) so the data is
# two-dimensional and easy to visualize.
x, _ = make_swiss_roll(n_samples=100_000, noise=0.5, random_state=SEED)
x = x[:, [0, 2]]

# Standardize with a single scalar mean/std taken over both coordinates (not per
# column): both axes are rescaled by the same factor, so the roll's shape is kept.
# The plots further down use fixed [-2.2, 2.2] axis limits that assume data on
# roughly this scale.
# To try another 2-D dataset (e.g. `make_moons`, imported above), swap the
# generator call and keep this normalization step.
x = (x - x.mean()) / x.std()


# Length of the noising chain.
diffusion_steps = 40

# Cosine noise schedule from Nichol & Dhariwal, "Improved Denoising Diffusion
# Probabilistic Models" (https://arxiv.org/pdf/2102.09672.pdf):
#   f(t) = cos^2( (t / T + s) / (1 + s) * pi / 2 ),   baralpha_t = f(t) / f(0)
# where s is the small offset used in the paper.
s = 0.008
timesteps = np.arange(0, diffusion_steps)
angle = (timesteps / diffusion_steps + s) / (1 + s) * np.pi / 2
schedule = np.cos(angle) ** 2

# Cumulative signal fraction, normalized so that baralphas[0] == 1.
baralphas = schedule / schedule[0]

# Per-step variances beta_t = 1 - baralpha_t / baralpha_{t-1}. Step 0 is paired
# with itself, so betas[0] == 0 (the first step adds no noise).
baralphas_prev = np.concatenate([baralphas[:1], baralphas[:-1]])
betas = 1 - baralphas / baralphas_prev
alphas = 1 - betas

# Training data as a float32 array (presumably the dtype the Diffuser / pytagi
# backend expects — confirm).
X = np.array(x, dtype=np.float32)

# Training / sampling configuration (named instead of inline magic numbers).
NUM_EPOCHS = 100
BATCH_SIZE = 2500
NUM_GENERATED = 2500  # first entry of sampling_dim; presumably points per sampling run

diffuser = Diffuser(
    num_epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    X_data=X,
    diffusion_steps=diffusion_steps,
    # Presumably (number of generated points, data dimensionality) — TODO confirm
    # against Diffuser. The dimensionality is taken from the data instead of being
    # hard-coded, so it stays consistent if the dataset changes (it is 2 here).
    sampling_dim=(NUM_GENERATED, X.shape[1]),
    alphas=alphas,
    betas=betas,
)

# Train the model. The return value is stored as `error_var` but is not used in
# the cells shown here; its meaning is defined by Diffuser.train.
error_var = diffuser.train()

In [None]:
# Draw samples from the trained diffuser. The plotting code below indexes the
# second returned value once per animation frame (xt2[i]) and colours the points
# by var_temp[i].
# NOTE(review): this rebinds `x`, which held the training data from the previous cell.
x, xt, var, var_temp = diffuser.sample()

# Aliases consumed by the plotting code further down.
x2, xt2 = x, xt

# Every sampled point is tracked: one trajectory per row of the first snapshot.
num_points = len(xt2[0])

# trajectories[j] will collect point j's coordinates, one entry per snapshot in xt2.
trajectories = [list() for _point in range(num_points)]

def draw_frame(i, point_index, fig=None):
    """Render frame ``i`` of the sampling animation with one point's path highlighted.

    Clears the target figure and scatters the point cloud ``xt2[i]``, coloured by
    ``var_temp[i][:, 0]``. It then overlays the full stored path of point
    ``point_index`` (from the global ``trajectories``), with an arrow on the path's
    last segment.

    Parameters
    ----------
    i : int
        Frame index into ``xt2`` and ``var_temp``.
    point_index : int
        Index of the point whose trajectory is overlaid.
    fig : matplotlib.figure.Figure, optional
        Figure to draw into. Defaults to pyplot's current figure. Passing the figure
        explicitly keeps drawing on the intended figure even after it has been
        closed with ``plt.close`` (at which point pyplot's "current figure" is a
        different one).

    Returns
    -------
    tuple
        One-element tuple holding the scatter artist, as FuncAnimation expects.
    """
    if fig is None:
        fig = plt.gcf()
    fig.clf()
    ax = fig.add_subplot()

    points = xt2[i]
    # Fixed colour range [0, 0.01] keeps the colour scale comparable across frames.
    sc = ax.scatter(points[:, 0], points[:, 1], marker="1", c=var_temp[i][:, 0],
                    cmap='viridis', s=1, vmin=0.0, vmax=0.01, alpha=0.4)
    # Fixed limits so the view does not jump between frames.
    ax.set_xlim([-2.2, 2.2])
    ax.set_ylim([-2.2, 2.2])

    # Overlay the selected point's stored trajectory.
    trajectory = np.asarray(trajectories[point_index])
    if len(trajectory) > 1:
        ax.plot(trajectory[:, 0], trajectory[:, 1], color='green', alpha=0.8)

        # Arrow along the final segment indicates the direction of the last move.
        start, end = trajectory[-2], trajectory[-1]
        ax.arrow(start[0], start[1], end[0] - start[0], end[1] - start[1],
                 color='green', alpha=0.8, width=0.005, head_width=0.05, head_length=0.1)

    return sc,

# Record every point's position in every snapshot: after this loop trajectories[j]
# holds point j's coordinates in xt2[0], xt2[1], ... in order.
for step_points in xt2:
    for j, path in enumerate(trajectories):
        path.append(step_points[j])

# Define interactive visualization
class _HTMLOutput:
    """Minimal rich-display wrapper: Jupyter renders objects exposing ``_repr_html_`` as HTML.

    ``interact`` passes the decorated function's return value to ``display``. A
    plain ``str`` would be shown as its text repr (the raw HTML source) instead of
    being rendered as the animation player.
    """

    def __init__(self, html):
        self.html = html

    def _repr_html_(self):
        return self.html


# continuous_update=False: re-render only when the slider is released. Each render
# draws every frame of the animation, which is too slow to repeat while dragging.
@interact(point_index=IntSlider(min=0, max=num_points - 1, step=1, value=0,
                                continuous_update=False))
def visualize_trajectory(point_index=0):
    """Build the sampling animation with the trajectory of ``point_index`` highlighted.

    Returns an object that Jupyter renders as the JavaScript player produced by
    ``FuncAnimation.to_jshtml``.
    """
    fig = plt.figure()
    # blit=False: draw_frame clears and rebuilds the whole figure every frame, so
    # blitting has nothing to reuse.
    anim = animation.FuncAnimation(fig, draw_frame, frames=len(xt2), interval=1,
                                   blit=False, fargs=(point_index,))
    # Render BEFORE closing. draw_frame may draw via pyplot's *current* figure, and
    # after plt.close(fig) that would be a new, different figure, so the frames
    # would not show the animation.
    html = anim.to_jshtml()
    plt.close(fig)  # Close the figure to avoid a duplicate static plot
    return _HTMLOutput(html)