<a href="https://colab.research.google.com/github/cshangRL/iap-diffusion-labs/blob/main/mytests/lab3-1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from abc import ABC, abstractmethod
from typing import Optional, List, Type, Tuple, Dict
import math

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes._axes import Axes
import torch
import torch.nn as nn
import torch.distributions as D
from torch.func import vmap, jacrev
from tqdm import tqdm
import seaborn as sns
from sklearn.datasets import make_moons, make_circles
from torchvision import datasets, transforms
from torchvision.utils import make_grid

# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

Part 0: Recycling components from previous labs

In [3]:
class OldSampleable(ABC):
  """Legacy sampling interface that returns only a tensor of samples.

  Unlike ``Sampleable``, ``sample`` here returns a bare tensor rather than a
  ``(samples, label)`` tuple.
  """

  @abstractmethod
  def sample(self, num_samples: int) -> torch.Tensor:
    """Draw ``num_samples`` samples.

    Args:
      num_samples: how many samples to draw.

    Returns:
      A tensor of samples; its exact shape is defined by the subclass.
    """

In [4]:
class Sampleable(ABC):
  """Sampling interface returning samples together with optional labels."""

  @abstractmethod
  def sample(self, num_samples: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Draw ``num_samples`` samples.

    Args:
      num_samples: how many samples to draw.

    Returns:
      A ``(samples, labels)`` pair; ``labels`` may be ``None`` when the
      distribution carries no labels.
    """

In [5]:
class IsotropicGaussian(nn.Module, Sampleable):
  """Isotropic Gaussian N(0, std^2 I) over tensors of a fixed per-sample shape.

  Args:
    shape: per-sample shape; ``sample(n)`` returns a tensor of shape
      ``(n, *shape)``.
    std: standard deviation applied to every coordinate.
  """

  def __init__(self, shape: List[int], std: float = 1.0):
    super().__init__()
    self.shape = shape
    self.std = std
    # One-element buffer used only to track which device the module lives on
    # (buffers follow ``.to(device)``). ``register_buffer`` works on every
    # PyTorch version, whereas ``nn.Buffer`` only exists in recent releases.
    self.register_buffer("dummy", torch.zeros(1))

  def sample(self, num_samples: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Draw ``num_samples`` i.i.d. samples.

    Args:
      num_samples: how many samples to draw.

    Returns:
      ``(x, None)``, where ``x`` has shape ``(num_samples, *shape)`` and sits on
      the same device as this module. The second element is the (absent) label.
    """
    # Generate the noise directly on the target device instead of allocating
    # it on the CPU and copying it over.
    noise = torch.randn(num_samples, *self.shape, device=self.dummy.device)
    return self.std * noise, None

In [7]:
class ConditionalProbabilityPath(nn.Module, ABC):
  """Abstract conditional probability path p_t(x | z) between two distributions.

  Args:
    p_simple: the simple (initial) distribution.
    p_data: the data distribution.
  """

  def __init__(self, p_simple: Sampleable, p_data: Sampleable):
    super().__init__()
    self.p_simple = p_simple
    self.p_data = p_data

  def sample_marginal_path(self, t: torch.Tensor) -> torch.Tensor:
    """Sample from the marginal path by drawing z first, then x ~ p_t(x | z).

    Args:
      t: times, one per sample; ``t.shape[0]`` sets the number of samples.

    Returns:
      Samples from the conditional path evaluated at the drawn ``z`` and ``t``.
    """
    num_samples = t.shape[0]
    # The conditioning sampler returns a (z, label) pair; the label is unused.
    # (Previously misspelled as ``samle_...``, which raised AttributeError.)
    z, _ = self.sample_conditioning_variable(num_samples)
    return self.sample_conditional_path(z, t)

  @abstractmethod
  def sample_conditioning_variable(self, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Draw ``num_samples`` conditioning variables as a ``(z, label)`` pair."""
    pass

  @abstractmethod
  def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Sample x ~ p_t(x | z) for the given conditioning variables and times."""
    pass

  @abstractmethod
  def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Evaluate the conditional vector field u_t(x | z)."""
    pass

  @abstractmethod
  def conditional_score(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Evaluate the conditional score of p_t(x | z) at ``x``."""
    pass

In [8]:
class Alpha(ABC):
  """Scalar schedule alpha_t required to satisfy alpha_0 = 0 and alpha_1 = 1.

  Subclasses implement ``__call__``. ``dt`` falls back to automatic
  differentiation of ``__call__`` and may be overridden with a closed form.
  """

  def __init__(self):
    # Check the boundary conditions once, when the schedule is constructed:
    # first at t = 0 (expect 0), then at t = 1 (expect 1).
    for t_value, expected_value in ((0.0, 0.0), (1.0, 1.0)):
      probe = torch.full((1, 1, 1, 1), t_value)
      expected = torch.full((1, 1, 1, 1), expected_value)
      assert torch.allclose(self(probe), expected)

  @abstractmethod
  def __call__(self, t: torch.Tensor) -> torch.Tensor:
    """Evaluate alpha_t elementwise at times ``t``."""

  def dt(self, t: torch.Tensor) -> torch.Tensor:
    """Return d(alpha_t)/dt for every time in ``t``, reshaped to (-1, 1, 1, 1).

    The per-sample derivative is computed with ``jacrev`` mapped over the
    leading axis by ``vmap``. Assumes ``t`` is shaped (batch, 1, 1, 1) — TODO
    confirm against callers.
    """
    # Insert a singleton axis so that vmap hands each jacrev call one
    # sample-shaped tensor.
    per_sample_jacobian = vmap(jacrev(self))(torch.unsqueeze(t, 1))
    return per_sample_jacobian.view(-1, 1, 1, 1)

class Beta(ABC):
  """Scalar schedule beta_t required to satisfy beta_0 = 1 and beta_1 = 0.

  Subclasses implement ``__call__``. ``dt`` falls back to automatic
  differentiation of ``__call__`` and may be overridden with a closed form.
  """

  def __init__(self):
    # Check the boundary conditions once, when the schedule is constructed:
    # first at t = 0 (expect 1), then at t = 1 (expect 0).
    for t_value, expected_value in ((0.0, 1.0), (1.0, 0.0)):
      probe = torch.full((1, 1, 1, 1), t_value)
      expected = torch.full((1, 1, 1, 1), expected_value)
      assert torch.allclose(self(probe), expected)

  @abstractmethod
  def __call__(self, t: torch.Tensor) -> torch.Tensor:
    """Evaluate beta_t elementwise at times ``t``."""

  def dt(self, t: torch.Tensor) -> torch.Tensor:
    """Return d(beta_t)/dt for every time in ``t``, reshaped to (-1, 1, 1, 1).

    The per-sample derivative is computed with ``jacrev`` mapped over the
    leading axis by ``vmap``. Assumes ``t`` is shaped (batch, 1, 1, 1) — TODO
    confirm against callers.
    """
    # Insert a singleton axis so that vmap hands each jacrev call one
    # sample-shaped tensor.
    per_sample_jacobian = vmap(jacrev(self))(torch.unsqueeze(t, 1))
    return per_sample_jacobian.view(-1, 1, 1, 1)

class LinearAlpha(Alpha):
  """Linear schedule alpha_t = t."""

  def __call__(self, t: torch.Tensor) -> torch.Tensor:
    # alpha_t is the identity map on t.
    return t

  def dt(self, t: torch.Tensor) -> torch.Tensor:
    """Closed-form derivative: a tensor of ones with the same shape as ``t``."""
    return torch.full_like(t, 1)

class LinearBeta(Beta):
  """Linear schedule beta_t = 1 - t."""

  def __call__(self, t: torch.Tensor) -> torch.Tensor:
    remaining = 1 - t
    return remaining

  def dt(self, t: torch.Tensor) -> torch.Tensor:
    """Closed-form derivative: a tensor of -1 with the same shape as ``t``."""
    return torch.neg(torch.ones_like(t))

class GaussianConditionalProbabilityPath(ConditionalProbabilityPath):
  """Gaussian conditional path p_t(x | z) = N(alpha_t * z, beta_t^2 * I).

  Args:
    p_data: data distribution from which the conditioning variable z is drawn.
    p_simple_shape: per-sample shape of the standard Gaussian ``p_simple``.
    alpha: schedule multiplying z (``Alpha`` enforces alpha_0 = 0, alpha_1 = 1).
    beta: schedule scaling the noise (``Beta`` enforces beta_0 = 1, beta_1 = 0).

  Note: ``conditional_vector_field`` and ``conditional_score`` divide by
  ``beta_t``. Since beta_1 = 0, they are not finite at t = 1.
  """

  def __init__(self, p_data: Sampleable, p_simple_shape: List[int], alpha: Alpha, beta: Beta):
    p_simple = IsotropicGaussian(shape=p_simple_shape, std=1.0)
    super().__init__(p_simple, p_data)
    self.alpha = alpha
    self.beta = beta

  def sample_conditioning_variable(self, num_samples: int) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Draw ``num_samples`` conditioning variables from ``p_data``.

    Returns:
      The ``(z, label)`` pair produced by ``p_data.sample``; ``label`` may be
      ``None``.
    """
    return self.p_data.sample(num_samples)

  def sample_conditional_path(self, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Sample x = alpha_t * z + beta_t * eps with eps ~ N(0, I) shaped like z."""
    return self.alpha(t) * z + self.beta(t) * torch.randn_like(z)

  def conditional_vector_field(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Conditional vector field u_t(x | z) of the Gaussian path.

    u_t(x | z) = (alpha'_t - beta'_t / beta_t * alpha_t) * z + beta'_t / beta_t * x
    """
    alpha_t = self.alpha(t)
    beta_t = self.beta(t)
    dt_alpha_t = self.alpha.dt(t)
    dt_beta_t = self.beta.dt(t)
    # beta'_t / beta_t appears in both terms; compute it once.
    log_beta_rate = dt_beta_t / beta_t
    return (dt_alpha_t - log_beta_rate * alpha_t) * z + log_beta_rate * x

  def conditional_score(self, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Score of N(alpha_t * z, beta_t^2 * I) at ``x``: (alpha_t * z - x) / beta_t^2."""
    alpha_t = self.alpha(t)
    beta_t = self.beta(t)
    return (z * alpha_t - x) / beta_t ** 2

In [9]:
class ODE(ABC):
  """Interface for an ordinary differential equation defined by its drift."""

  @abstractmethod
  def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
    """Return the drift evaluated at state ``xt`` and time ``t``.

    Extra keyword arguments are accepted for subclass-specific inputs.
    """

class SDE(ABC):
  """Interface for a stochastic differential equation given by drift and diffusion."""

  @abstractmethod
  def drift_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
    """Return the drift evaluated at state ``xt`` and time ``t``.

    Extra keyword arguments are accepted for subclass-specific inputs.
    """

  @abstractmethod
  def diffusion_coefficient(self, xt: torch.Tensor, t: torch.Tensor, **kwargs) -> torch.Tensor:
    """Return the diffusion coefficient evaluated at state ``xt`` and time ``t``.

    Extra keyword arguments are accepted for subclass-specific inputs.
    """

In [None]:
class Simulator(ABC):
  """Base class for simulators that advance a state along a fixed time grid."""

  @abstractmethod
  def step(self, xt: torch.Tensor, t: torch.Tensor, dt: torch.Tensor, **kwargs):
    """Advance state ``xt`` from time ``t`` by step size ``dt`` and return the new state."""

  @torch.no_grad()
  def simulate(self, x: torch.Tensor, ts: torch.Tensor, **kwargs) -> torch.Tensor:
    """Apply ``step`` between consecutive grid times and return the final state.

    Args:
      x: initial state.
      ts: time grid indexed as ``ts[:, k]``, where axis 1 enumerates the grid
        points (presumably shaped (batch, num_times, ...) — confirm with
        callers).
      **kwargs: forwarded unchanged to every ``step`` call.

    Returns:
      The state after the last step (no gradients are tracked).
    """
    state = x
    for k in tqdm(range(ts.shape[1] - 1)):
      t_now, t_next = ts[:, k], ts[:, k + 1]
      state = self.step(state, t_now, t_next - t_now, **kwargs)
    return state