<a href="https://colab.research.google.com/github/hongqin/Generative_AI_Fa25/blob/main/2d_langevin_dynamics_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Langevin Dynamics Toy Example: Sampling 2D Gaussian Mixture
# Run in Google Colab

!pip install -q matplotlib seaborn

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Target distribution: equal-weight mixture of two 2-D Gaussians
def target_density(x):
    """Evaluate the two-component Gaussian mixture density at each row of ``x``.

    Components are centred at (2, 2) and (-2, -2), share covariance 0.5 * I,
    and each carries weight 0.5. Every component density is computed by
    ``multivariate_gaussian``.
    """
    shared_cov = np.eye(2) * 0.5
    centres = (np.array([2.0, 2.0]), np.array([-2.0, -2.0]))
    weight = 0.5
    left, right = (multivariate_gaussian(x, centre, shared_cov) for centre in centres)
    return weight * left + weight * right

def multivariate_gaussian(x, mean, cov):
    """Evaluate the multivariate normal density N(mean, cov) at the given points.

    Works in any dimension d. The normalising constant is
    1 / sqrt((2*pi)**d * det(cov)), so it is correct beyond the 2-D case.

    Args:
        x: Points of shape (n, d), one per row. A single point of shape (d,)
            is also accepted and yields a scalar.
        mean: Mean vector of shape (d,).
        cov: Symmetric positive-definite covariance matrix of shape (d, d).

    Returns:
        Density values of shape (n,), or a scalar for a single point.
    """
    x = np.asarray(x, dtype=float)
    mean = np.asarray(mean, dtype=float)
    cov = np.asarray(cov, dtype=float)
    d = mean.shape[-1]

    diff = x - mean
    inv = np.linalg.inv(cov)
    # Squared Mahalanobis distance diff^T cov^{-1} diff for every point.
    # axis=-1 reduces over the coordinate axis for both (n, d) and (d,) input.
    exponent = -0.5 * np.sum(diff @ inv * diff, axis=-1)
    norm = 1.0 / np.sqrt((2 * np.pi) ** d * np.linalg.det(cov))
    return norm * np.exp(exponent)

# Score function: exact analytic gradient of log p(x) for the mixture target
def score_function(x):
    """Return the score grad_x log p(x) of the two-component Gaussian mixture.

    The mixture has equal weights, means (2, 2) and (-2, -2), and shared
    covariance 0.5 * I. Its score is the responsibility-weighted average of
    the per-component scores -cov^{-1} (x - mean_k).

    Responsibilities are computed as a softmax over log-densities. This keeps
    them exact even when both component densities underflow (points far from
    both modes). Adding a small epsilon to the density sum instead would shrink
    the score toward zero there and leave distant chains stuck.

    Args:
        x: Array-like of shape (n, 2), one point per row.

    Returns:
        Array of shape (n, 2) with the score at each point.
    """
    x = np.asarray(x, dtype=float)
    means = np.array([[2.0, 2.0], [-2.0, -2.0]])  # (K, 2)
    cov = 0.5 * np.eye(2)
    inv = np.linalg.inv(cov)

    diffs = x[:, None, :] - means[None, :, :]  # (n, K, 2)
    component_scores = -diffs @ inv.T  # (n, K, 2): -cov^{-1} (x - mean_k)

    # Log-densities up to a constant shared by both components (equal weights
    # and equal covariance), so the normalisation cancels in the softmax.
    log_dens = -0.5 * np.sum((diffs @ inv) * diffs, axis=-1)  # (n, K)
    log_dens -= log_dens.max(axis=1, keepdims=True)  # log-sum-exp shift
    resp = np.exp(log_dens)
    resp /= resp.sum(axis=1, keepdims=True)

    return np.sum(resp[..., None] * component_scores, axis=1)

# Langevin Dynamics Sampling
def langevin_sampling(score_fn, steps=100, step_size=0.1, noise_scale=0.1, n_samples=500,
                      rng=None, init_scale=4.0):
    """Run unadjusted Langevin dynamics and return every intermediate state.

    Update rule:  x <- x + step_size * score_fn(x) + noise_scale * z,  z ~ N(0, I).

    The textbook unadjusted Langevin algorithm, which targets p(x) itself, uses
    noise_scale = sqrt(2 * step_size). Other values target, in the small-step
    limit, a tempered density proportional to p(x) ** (2 * step_size / noise_scale**2).
    That density is sharper than p when noise_scale is smaller. Pass
    ``noise_scale=None`` to use the textbook value.

    Args:
        score_fn: Callable mapping an (n_samples, 2) array to the gradient of
            the log-density at those points (same shape).
        steps: Number of update steps.
        step_size: Multiplier applied to the score in each update.
        noise_scale: Standard deviation of the Gaussian noise added per step,
            or None for sqrt(2 * step_size).
        n_samples: Number of independent chains.
        rng: Optional ``numpy.random.Generator`` (anything exposing
            ``standard_normal(size)``) for reproducible runs. When None,
            NumPy's global random state is used, drawing the same stream as
            ``np.random.randn``.
        init_scale: Standard deviation of the isotropic Gaussian used for the
            initial positions.

    Returns:
        List of ``steps + 1`` arrays of shape (n_samples, 2): the initial
        state followed by the state after each step.
    """
    if noise_scale is None:
        noise_scale = np.sqrt(2.0 * step_size)
    draw = np.random.standard_normal if rng is None else rng.standard_normal

    x = draw((n_samples, 2)) * init_scale  # Start from broad noise
    trajectory = [x.copy()]
    for _ in range(steps):
        grad = score_fn(x)
        noise = draw(x.shape) * noise_scale
        x = x + step_size * grad + noise
        trajectory.append(x.copy())
    return trajectory

# Visualize
from matplotlib import animation
from IPython.display import HTML

def plot_trajectory(traj, lim=5.0, interval=200):
    """Animate a sequence of 2-D sample clouds as filled KDE plots.

    Args:
        traj: Sequence of (n, 2) arrays, one per animation frame (e.g. the
            list returned by ``langevin_sampling``).
        lim: Half-width of the square plotting window [-lim, lim] x [-lim, lim].
        interval: Delay between frames in milliseconds.

    Returns:
        An ``IPython.display.HTML`` object holding the JavaScript animation.
    """
    fig, ax = plt.subplots(figsize=(6, 6))

    def animate(i):
        frame = traj[i]
        ax.clear()
        ax.set_xlim(-lim, lim)
        ax.set_ylim(-lim, lim)
        sns.kdeplot(x=frame[:, 0], y=frame[:, 1], ax=ax, fill=True, cmap="Blues")
        ax.set_title(f"Step {i}")

    anim = animation.FuncAnimation(fig, animate, frames=len(traj), interval=interval)
    # Close this specific figure so the notebook does not also render a static
    # copy of it; a bare plt.close() would close whichever figure is current.
    plt.close(fig)
    return HTML(anim.to_jshtml())

# Run the demo: 1000 chains for 50 Langevin steps, then animate the sample cloud.
# noise_scale=0.15 is below sqrt(2 * step_size) ~= 0.447, so the chains settle
# into tighter clusters than the target mixture itself.
trajectory = langevin_sampling(
    score_function,
    steps=50,
    step_size=0.1,
    noise_scale=0.15,
    n_samples=1000,
)
plot_trajectory(trajectory)
