In [None]:
from functools import partial

import imageio
import matplotlib
import matplotlib.patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from IPython import display
from PIL import Image
from tqdm import tqdm

def logpdf(x, rx=2.5, ry=2.5, cx=0.0, cy=0.0):
  """Unnormalized log-density concentrated on an axis-aligned ellipse.

  Each point is shifted by the centre (cx, cy) and divided per axis by the
  semi-axes (rx, ry). With r the norm of that scaled point, the log-density
  is -(r - 1)**2 / 0.033, so it is maximal (0) on the ellipse where r == 1.

  Args:
    x: tensor whose last dimension holds the (x, y) coordinates.
    rx, ry: semi-axes of the ellipse; must be positive.
    cx, cy: centre of the ellipse.

  Returns:
    Tensor of shape x.shape[:-1] with the log-density of every point.

  Raises:
    ValueError: if rx or ry is not positive (the scaling would divide by zero).
  """
  if rx <= 0 or ry <= 0:
    raise ValueError(f"ellipse radii must be positive, got rx={rx}, ry={ry}")
  # Build the constants on x's device (and floating dtype) so CUDA / float64
  # inputs work; integer inputs fall back to the default float dtype so the
  # radii are not truncated.
  dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
  center = torch.tensor([cx, cy], dtype=dtype, device=x.device)
  radii = torch.tensor([rx, ry], dtype=dtype, device=x.device)
  r = torch.linalg.norm((x - center) / radii, dim=-1)
  # 0.033 controls how sharply the density falls off away from the ellipse.
  return -(r - 1) ** 2 / 0.033

def create_grad_func(logpdf, **kwargs):
  """Return a function computing the gradient of ``logpdf`` w.r.t. its input.

  Args:
    logpdf: callable taking a tensor x (plus **kwargs) and returning
      log-densities; they are summed before differentiation, so when logpdf
      acts point-wise each point receives its own gradient.
    **kwargs: extra keyword arguments forwarded to logpdf on every call.

  Returns:
    A function grad_logpdf(x) returning a tensor shaped like x that does not
    track gradients itself.
  """
  def grad_logpdf(x):
    # Differentiate a detached copy so the caller's tensor is left untouched
    # and no autograd graph is carried into the returned value.
    x = x.detach().requires_grad_(True)
    # enable_grad lets this run even when the caller is inside torch.no_grad().
    with torch.enable_grad():
      log_prob = logpdf(x, **kwargs)
      return torch.autograd.grad(log_prob.sum(), x)[0]
  return grad_logpdf

def create_shape(rx=2.5, ry=2.5, cx=0.0, cy=0.0):
  """Describe the matplotlib patch that outlines the target ellipse.

  Returns a dict holding the patch class name under "class" and its
  constructor keyword arguments under "kwargs". Width and height are full
  axis lengths, i.e. twice the semi-axes rx and ry; xy is the centre.
  """
  patch_kwargs = dict(width=rx * 2, height=ry * 2, xy=(cx, cy))
  return {"class": "Ellipse", "kwargs": patch_kwargs}

def plot_frame(particles, step, shape, figsize=(4,4), lim=(-3, 3)):
  """Render one time step of a particle trajectory to an RGBA pixel array.

  Args:
    particles: tensor indexed as particles[step, particle, coord]; coord 0 is
      plotted on the x axis and coord 1 on the y axis.
    step: time index to draw (also shown in the title).
    shape: dict with "class" (name of a matplotlib.patches class) and
      "kwargs" (its constructor arguments), e.g. from create_shape; drawn as
      a red outline.
    figsize: matplotlib figure size in inches.
    lim: (low, high) limits applied to both axes.

  Returns:
    numpy array copied from the canvas's buffer_rgba() (RGBA pixels).
  """
  # Only the requested step is moved to host memory; converting the whole
  # trajectory for every frame would make rendering all frames quadratic.
  points = particles[step].detach().cpu().numpy()
  fig, ax = plt.subplots(figsize=figsize)
  try:
    ax.scatter(points[:, 0], points[:, 1], alpha=0.1, s=1, color='blue')
    ax.set_xlim(*lim)
    ax.set_ylim(*lim)
    ax.set_xlabel("x coord")
    ax.set_ylabel("y coord")
    ax.set_aspect('equal')
    ax.set_title(f'Langevin Sampler at t={step}')

    patch_cls = getattr(matplotlib.patches, shape["class"])
    ax.add_patch(patch_cls(
        edgecolor='red',
        facecolor='none',
        linewidth=2,
        **shape["kwargs"]
    ))

    fig.canvas.draw()
    # Copy: buffer_rgba() exposes the canvas's own memory.
    image = np.array(fig.canvas.buffer_rgba())
  finally:
    # Close exactly this figure, even if drawing failed, so figures do not
    # accumulate across many frames.
    plt.close(fig)
  return image


# Backward-compatible alias for the original (misspelled) name.
plot_farme = plot_frame

def plot_trajectory(particles, particle_idx, axis_names=("x", "y"), figsize=(10,3), lim=(-3, 3)):
  """Plot one particle's coordinates over time, one subplot per axis.

  Args:
    particles: tensor indexed as particles[step, particle, coord].
    particle_idx: index of the particle to plot.
    axis_names: label for each coordinate; coordinate i is drawn in subplot i.
    figsize: matplotlib figure size in inches.
    lim: (low, high) y-limits for every subplot.
  """
  # Only the selected particle's trajectory is needed on the host.
  trajectory_np = particles[:, particle_idx].detach().cpu().numpy()
  # squeeze=False always yields a 2-D array of Axes, so a single axis name
  # works too.
  fig, axes = plt.subplots(1, len(axis_names), figsize=figsize, squeeze=False)
  for axis, (ax, axis_name) in enumerate(zip(axes[0], axis_names)):
    ax.plot(trajectory_np[:, axis])
    ax.set_ylim(*lim)
    ax.set_title(f"Trajectory of particle{particle_idx} along {axis_name}-axis")
    ax.set_xlabel("timestep")
    # Label with the whole axis name (indexing it by `axis` would pick a
    # single character and fail past the first axis).
    ax.set_ylabel(f"{axis_name} coord")

def frames_to_image(frames):
  """Paste frames left to right into a single RGB strip image.

  Args:
    frames: non-empty sequence of image arrays, all the size of frames[0]
      (rows x cols [x channels]).

  Returns:
    PIL image in RGB mode, len(frames) frames wide and one frame tall.

  Raises:
    ValueError: if frames is empty.
  """
  if len(frames) == 0:
    raise ValueError("frames_to_image needs at least one frame")
  # Arrays are (rows, cols, ...): height is shape[0], width is shape[1].
  h, w = frames[0].shape[:2]
  collated = Image.new('RGB', (w * len(frames), h))
  for i, frame in enumerate(frames):
    collated.paste(Image.fromarray(frame), (i * w, 0))
  return collated

def frames_to_gif(frames, filename="temp.gif", fps=5):
  """Write frames to an endlessly looping animated GIF.

  Args:
    frames: sequence of image arrays accepted by imageio.mimsave.
    filename: output path.
    fps: playback speed in frames per second (previously fixed at 5).

  Returns:
    The filename that was written.
  """
  # loop=0 asks the GIF to repeat forever.
  imageio.mimsave(filename, frames, fps=fps, loop=0)
  return filename

def langevin_update(grad_func, current_particles, noise, eta):
  """Take one unadjusted Langevin step.

  Computes  x_next = x + eta * grad_log_p(x) + sqrt(2 * eta) * noise.

  Args:
    grad_func: callable returning the gradient of the log-density at
      current_particles (same shape as its input).
    current_particles: tensor of particle positions.
    noise: noise tensor shaped like current_particles; expected to be
      standard normal.
    eta: step size; a float or a tensor broadcastable against the particles
      (e.g. one value per coordinate).

  Returns:
    Tensor of updated particles that does not track gradients.
  """
  grad = grad_func(current_particles)
  # The update itself needs no autograd graph; building one would chain all
  # steps together and grow memory over a long run.
  with torch.no_grad():
    next_particles = current_particles + eta * grad + (2 * eta) ** 0.5 * noise
  return next_particles

def sample_langevin(grad_func, particles, num_steps, eta):
  """Run num_steps Langevin updates starting from ``particles``.

  Args:
    grad_func: callable returning the log-density gradient at a batch of
      particles (same shape as its input).
    particles: initial particle tensor.
    num_steps: number of updates to perform (0 returns only the start).
    eta: step size, float or tensor broadcastable against the particles.

  Returns:
    Tensor of shape (num_steps + 1, *particles.shape): the initial state
    followed by the state after every update.
  """
  current = particles.detach()
  particles_over_time = [current]
  for _ in range(num_steps):
    # Fresh standard-normal noise for every particle at every step.
    noise = torch.randn_like(current)
    current = langevin_update(grad_func, current, noise, eta)
    particles_over_time.append(current)
  particles_over_time = torch.stack(particles_over_time)
  return particles_over_time

def sample_and_viz_langevin(device, langevin_kwargs, ellipse_kwargs, init_particles=None):
  """Sample the ellipse density with Langevin dynamics and render each step.

  Args:
    device: torch device for the default random initial particles.
    langevin_kwargs: dict with "num_particles" and "num_dims" (used only when
      init_particles is None), "num_steps" and "eta".
    ellipse_kwargs: rx/ry/cx/cy, forwarded to both logpdf and create_shape.
    init_particles: optional starting particles; standard normal if None.

  Returns:
    List of rendered frames, one per time step returned by sample_langevin.
  """
  if init_particles is None:
    init_particles = torch.randn(
        langevin_kwargs["num_particles"],
        langevin_kwargs["num_dims"],
        device=device,
    )
  eta = langevin_kwargs["eta"]
  if isinstance(eta, torch.Tensor):
    # Keep the step size on the same device as the particles.
    eta = eta.to(init_particles.device)
  data = sample_langevin(
      create_grad_func(logpdf, **ellipse_kwargs),
      init_particles,
      langevin_kwargs["num_steps"],
      eta,
  )
  # The target outline is identical for every frame, so build it once.
  shape = create_shape(**ellipse_kwargs)
  # plot_farme (sic) is the name the frame-plotting helper is bound to here.
  return [plot_farme(data, t, shape) for t in tqdm(range(data.shape[0]))]


device = "cpu"
langevin_kwargs = {
    "num_particles": 10000,
    "num_dims": 2,
    "num_steps": 10,
    "eta": torch.tensor([1e-2, 1e-2])
}
# ry must be positive: logpdf divides the y offset by it (ry=0.0 gives
# inf/NaN). NOTE(review): 1.5 assumed to match rx (a circle) — confirm.
ellipse_kwargs = {
    "rx": 1.5,
    "ry": 1.5,
    "cx": 0.0,
    "cy": 0.0
}

frames = sample_and_viz_langevin(device, langevin_kwargs, ellipse_kwargs)
frames_to_image(frames)

device = "cpu"
# Isotropic step size 1e-2 against an ellipse with semi-axes 1.0 (x) and
# 2.5 (y), centred at the origin.
langevin_kwargs = dict(
    num_particles=10000,
    num_dims=2,
    num_steps=10,
    eta=torch.tensor([1e-2, 1e-2]),
)
ellipse_kwargs = dict(rx=1.0, ry=2.5, cx=0.0, cy=0.0)

frames = sample_and_viz_langevin(device, langevin_kwargs, ellipse_kwargs)
frames_to_image(frames)

device = "cpu"
ellipse_kwargs = {
    "rx": 1.0,
    "ry": 2.5,
    "cx": 0.0,
    "cy": 0.0
}
# Per-axis step size. logpdf divides coordinate i by r_i, so in the scaled
# coordinates u_i = x_i / r_i the target is the unit-circle density, and a
# step of eta_i = 1e-2 * r_i**2 in x_i is exactly a step of 1e-2 (drift and
# noise alike) in u_i. Both axes then mix at the same rate.
# NOTE(review): this cell was left as an exercise (eta = None) — confirm.
eta = 1e-2 * torch.tensor([ellipse_kwargs["rx"] ** 2, ellipse_kwargs["ry"] ** 2])
langevin_kwargs = {
    "num_particles": 10000,
    "num_dims": 2,
    "num_steps": 10,
    "eta": eta
}

frames = sample_and_viz_langevin(device, langevin_kwargs, ellipse_kwargs)
frames_to_image(frames)

# NOTE(review): placeholder — this cell appears meant to supply custom
# starting particles. With None, sample_and_viz_langevin draws standard
# normal particles on `device`.
init_particles = None

device = "cpu"
# Circle of radius 1.5 centred at (1, 1), isotropic step size 1e-2.
langevin_kwargs = dict(
    num_particles=10000,
    num_dims=2,
    num_steps=10,
    eta=torch.tensor([1e-2, 1e-2]),
)
ellipse_kwargs = dict(rx=1.5, ry=1.5, cx=1.0, cy=1.0)

frames = sample_and_viz_langevin(
    device, langevin_kwargs, ellipse_kwargs, init_particles=init_particles
)
frames_to_image(frames)