This code is adapted from the tutorial code accompanying
Yang Song, Jascha Sohl-Dickstein, Diederik P. Kingma, Abhishek Kumar, Stefano Ermon, and Ben Poole. "[Score-Based Generative Modeling through Stochastic Differential Equations.](https://arxiv.org/pdf/2011.13456.pdf)" International Conference on Learning Representations, 2021.

In [102]:
#@title Define a time-dependent score-based model 

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class GaussianFourierProjection(nn.Module):
  """Encode scalar time steps with fixed random Fourier features.

  Each time value is multiplied by a set of frozen random frequencies
  (drawn once from N(0, scale^2)) and mapped to its sine and cosine.
  """

  def __init__(self, embed_dim, scale=30.):
    super().__init__()
    # Frequencies are sampled once at construction and kept frozen
    # (requires_grad=False), so they are not trained.
    frequencies = torch.randn(embed_dim // 2) * scale
    self.W = nn.Parameter(frequencies, requires_grad=False)

  def forward(self, x):
    """Return [sin(2*pi*x*W), cos(2*pi*x*W)] concatenated on the last axis."""
    angles = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
    return torch.cat((angles.sin(), angles.cos()), dim=-1)


class Dense(nn.Module):
  """Linear layer whose output gets two trailing singleton axes.

  The result has shape (..., output_dim, 1, 1), so it broadcasts over the
  spatial dimensions of a convolutional feature map.
  """

  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.dense = nn.Linear(input_dim, output_dim)

  def forward(self, x):
    projected = self.dense(x)
    return projected.unsqueeze(-1).unsqueeze(-1)


class ScoreNet(nn.Module):
  """A time-dependent score-based model built upon U-Net architecture.

  Takes single-channel images (conv1 has one input channel). With the
  unpadded 3x3 kernels used here, a 28x28 input is downsampled to
  26 -> 12 -> 5 -> 2 and upsampled back to 28x28; other input sizes must
  keep the skip-connection shapes aligned.
  """

  def __init__(self, marginal_prob_std, channels=(32, 64, 128, 256), embed_dim=256):
    """Initialize a time-dependent score-based network.

    Args:
      marginal_prob_std: A function that takes time t and gives the standard
        deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
      channels: The number of channels for feature maps of each resolution
        (an indexable sequence of four ints; a tuple default avoids a shared
        mutable default argument). Every entry must be divisible by 32
        because of the GroupNorm(32, ...) layers.
      embed_dim: The dimensionality of Gaussian random feature embeddings.
    """
    super().__init__()
    # Gaussian random feature embedding layer for time
    self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
         nn.Linear(embed_dim, embed_dim))
    # Encoding layers where the resolution decreases
    self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
    self.dense1 = Dense(embed_dim, channels[0])
    self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])
    self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
    self.dense2 = Dense(embed_dim, channels[1])
    self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])
    self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
    self.dense3 = Dense(embed_dim, channels[2])
    self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])
    self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
    self.dense4 = Dense(embed_dim, channels[3])
    self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

    # Decoding layers where the resolution increases
    self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
    self.dense5 = Dense(embed_dim, channels[2])
    self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])
    self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
    self.dense6 = Dense(embed_dim, channels[1])
    self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])
    self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
    self.dense7 = Dense(embed_dim, channels[0])
    self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])
    self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)

    # The swish activation function
    self.act = lambda x: x * torch.sigmoid(x)
    self.marginal_prob_std = marginal_prob_std

  def _time_bias(self, dense, embed):
    """Project the time embedding with `dense` and shape it as (B, C, 1, 1).

    The explicit reshape lets this class work with either `Dense` variant in
    this notebook (one appends two singleton axes, the later redefinition
    does not), whichever is bound when the model is built.
    """
    return dense(embed).reshape(embed.shape[0], -1, 1, 1)

  def forward(self, x, t):
    """Estimate the score for noisy images x at times t.

    Args:
      x: Image batch of shape (B, 1, H, W).
      t: Time steps of shape (B,).

    Returns:
      The tconv1 output, shaped like x, divided by marginal_prob_std(t).
    """
    # Obtain the Gaussian random feature embedding for t
    embed = self.act(self.embed(t))
    # Encoding path: each stage adds a time-dependent bias before GroupNorm
    h1 = self.conv1(x)
    h1 += self._time_bias(self.dense1, embed)
    h1 = self.act(self.gnorm1(h1))
    h2 = self.conv2(h1)
    h2 += self._time_bias(self.dense2, embed)
    h2 = self.act(self.gnorm2(h2))
    h3 = self.conv3(h2)
    h3 += self._time_bias(self.dense3, embed)
    h3 = self.act(self.gnorm3(h3))
    h4 = self.conv4(h3)
    h4 += self._time_bias(self.dense4, embed)
    h4 = self.act(self.gnorm4(h4))

    # Decoding path
    h = self.tconv4(h4)
    # Incorporate information from t
    h += self._time_bias(self.dense5, embed)
    h = self.act(self.tgnorm4(h))
    # Skip connections: concatenate the matching encoder features on channels
    h = self.tconv3(torch.cat([h, h3], dim=1))
    h += self._time_bias(self.dense6, embed)
    h = self.act(self.tgnorm3(h))
    h = self.tconv2(torch.cat([h, h2], dim=1))
    h += self._time_bias(self.dense7, embed)
    h = self.act(self.tgnorm2(h))
    h = self.tconv1(torch.cat([h, h1], dim=1))

    # Normalize output
    h = h / self.marginal_prob_std(t)[:, None, None, None]
    return h

In [103]:
#@title Define OUR score-based models 

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class GaussianFourierProjection(nn.Module):
  """Gaussian random features for encoding time steps.

  Maps a batch of scalars ``x`` to ``embed_dim`` features built from frozen
  random frequencies ``W``: the first half are sines, the second half cosines.
  """

  def __init__(self, embed_dim, scale=30.):
    super().__init__()
    # requires_grad=False keeps the random frequencies fixed during training.
    self.W = nn.Parameter(scale * torch.randn(embed_dim // 2), requires_grad=False)

  def forward(self, x):
    scaled = x[:, None] * self.W[None, :]
    phase = scaled * 2 * np.pi
    features = [torch.sin(phase), torch.cos(phase)]
    return torch.cat(features, dim=-1)


class Dense(nn.Module):
  """A plain fully connected layer (no reshaping to feature maps)."""

  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.dense = nn.Linear(input_dim, output_dim)

  def forward(self, x):
    # Delegate straight to the wrapped linear map; leading axes of x are kept.
    out = self.dense(x)
    return out


class ScoreNet_1HiddenLayerFC(nn.Module):
  """A time-dependent score-based model built upon 1-hidden-layer fully-connected NN architecture.

  Inputs x are expected as (batch, 1) (dense1 is a Linear(1, hidden_dim)).
  They are lifted to ``hidden_dim`` features, shifted by a learned
  embedding of t, passed through swish, and projected back to one output
  per sample.
  """

  def __init__(self, marginal_prob_std, hidden_dim=16, embed_dim=4, output_scale=16.):
    """Initialize a time-dependent score-based network.

    Args:
      marginal_prob_std: A function that takes time t and gives the standard
        deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
      hidden_dim: Width of the single hidden layer.
      embed_dim: The dimensionality of Gaussian random feature embeddings.
      output_scale: Constant the network output is divided by before the
        1/std normalization. Defaults to 16, the value previously hard-coded
        (independent of hidden_dim).
    """
    super().__init__()
    # Gaussian random feature embedding layer for time, projected to hidden_dim
    self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
         nn.Linear(embed_dim, hidden_dim))
    self.dense1 = Dense(1, hidden_dim)
    self.dense2 = Dense(hidden_dim, 1)
    self.output_scale = output_scale

    # The swish activation function
    self.act = lambda x: x * torch.sigmoid(x)
    self.marginal_prob_std = marginal_prob_std

  def forward(self, x, t):
    """Return the estimated score at (x, t), shape (batch, 1)."""
    # Time embedding; note that no activation is applied to it here.
    embed = self.embed(t)
    # Feature map
    h = self.dense1(x)
    # Incorporate information from t
    h += embed
    h = self.dense2(self.act(h))
    # Fixed output rescaling (see output_scale in __init__)
    h = h / self.output_scale

    # Normalize output
    h = h / self.marginal_prob_std(t)[:, None]
    return h

In [None]:
#@title Auxiliary code to test the substituted score models (can be ignored)

import torch
# from torchsummary import summary

x = torch.randn(32, 1)
t = torch.rand(32)
# marginal_prob_std must be callable: forward() calls it on t and indexes the
# result, so passing a bare float such as 1.2 raises a TypeError.
SN = ScoreNet_1HiddenLayerFC(marginal_prob_std=lambda tt: torch.full_like(tt, 1.2))
print(SN)
res = SN.forward(x, t)
print(res.size())

In [104]:
#@title Set up the SDE

import functools

device = 'cuda' #@param ['cuda', 'cpu'] {'type':'string'}
# Fall back to CPU when CUDA was selected but is not available, so the
# notebook still runs on CPU-only machines.
if device == 'cuda' and not torch.cuda.is_available():
  device = 'cpu'

def marginal_prob_std(t, sigma):
  r"""Compute the standard deviation of $p_{0t}(x(t) | x(0))$.

  Evaluates sqrt((sigma^{2t} - 1) / (2 log sigma)).

  Args:
    t: A vector of time steps (tensor, array, or Python number).
    sigma: The $\sigma$ in our SDE.

  Returns:
    The standard deviation, as a tensor on the module-level `device`.
  """
  # as_tensor avoids the copy-construct warning torch.tensor() emits when t is
  # already a tensor, and skips the copy when t is already on `device`.
  t = torch.as_tensor(t, device=device)
  return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma):
  r"""Compute the diffusion coefficient of our SDE.

  Args:
    t: A vector of time steps (tensor, array, or Python number).
    sigma: The $\sigma$ in our SDE.

  Returns:
    The vector of diffusion coefficients sigma**t, as a tensor on the
    module-level `device`.
  """
  # as_tensor avoids the copy-construct warning torch.tensor() emits when
  # sigma**t is already a tensor (i.e. when t is a tensor).
  return torch.as_tensor(sigma**t, device=device)
  
sigma = 25.0  #@param {'type':'number'}
# Bind sigma once so both SDE helpers can be called as f(t).
marginal_prob_std_fn, diffusion_coeff_fn = (
    functools.partial(helper, sigma=sigma)
    for helper in (marginal_prob_std, diffusion_coeff)
)

In [105]:
#@title Define the loss function

def loss_fn(model, x, marginal_prob_std, eps=1e-5):
  """The denoising score-matching loss for training score-based generative models.

  Args:
    model: A PyTorch model instance that represents a
      time-dependent score-based model, called as model(perturbed_x, t).
    x: A mini-batch of training data. The first axis is the batch axis and
      any number of trailing axes is allowed (e.g. (B, 1) or (B, C, H, W)).
    marginal_prob_std: A function that gives the standard deviation of
      the perturbation kernel.
    eps: A tolerance value for numerical stability; times are drawn from
      [eps, 1) so t is bounded away from 0.

  Returns:
    Scalar tensor: the batch mean of the per-sample sum of
    (score * std + z)**2.
  """
  random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
  z = torch.randn_like(x)
  std = marginal_prob_std(random_t)
  # Reshape std to (B, 1, ..., 1) so it broadcasts over every non-batch axis.
  std = std.reshape(-1, *([1] * (x.dim() - 1)))
  perturbed_x = x + z * std
  score = model(perturbed_x, random_t)
  sq_err = (score * std + z) ** 2
  # Sum over all non-batch axes, then average over the batch.
  return torch.mean(sq_err.reshape(sq_err.shape[0], -1).sum(dim=1))

In [115]:
#@title Sample data from the Gaussian mixture

import torch
import numpy as np
import torch.distributions as D

class GetLoader(torch.utils.data.Dataset):
    """Minimal map-style dataset over an indexable collection.

    Indexing returns ``data_root[index]`` unchanged and ``len()`` reports
    ``len(data_root)``; there are no labels or transforms.
    """

    def __init__(self, data_root):
        # Keep a reference to the backing collection (no copy is made).
        self.data = data_root

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        return sample

# Draw the training set from a two-component Gaussian mixture with equal
# weights and unit-variance components centred at -loc and +loc.
n_data = 1000  # number of samples
loc = 3.  # absolute position of the two modes
mix = D.Categorical(probs=torch.ones(2,))
comp = D.Normal(loc=torch.tensor([-loc, loc]), scale=torch.ones(2,))
gmm = D.MixtureSameFamily(mixture_distribution=mix, component_distribution=comp)
source_data = gmm.sample((n_data,))
dataset = GetLoader(source_data)

from torch.utils.data import DataLoader

# Shuffled mini-batches of 32; the final partial batch is kept.
datas = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=False)

# Print every batch once as a sanity check of the loader.
for i, data in enumerate(datas):
    print(f"Batch {i}\n{data}")

第 0 个Batch 
tensor([ 3.4545,  4.2307,  3.4537, -5.1684,  2.5124, -0.9774,  2.9295, -2.2165,
         4.1086, -2.3678,  1.8971, -1.9056,  3.8853,  2.8749, -2.8030,  3.8626,
        -2.3626, -3.7237,  3.4676, -2.9863,  1.0915, -4.7500,  2.5474,  3.6417,
         1.8369, -2.7469,  2.9156, -2.9955, -2.3876, -4.6952,  4.2452, -2.4893])
第 1 个Batch 
tensor([-3.4227,  2.2794, -4.7184, -4.1803,  3.1782,  2.4326,  4.1751, -3.3122,
        -2.5830,  3.2893, -2.9894,  2.6886,  3.1617,  1.7049,  3.0762, -4.0633,
        -4.3795, -2.8614,  3.3156, -3.6159, -3.6630,  0.5622,  1.0092,  2.2229,
         3.8862,  3.1689, -5.2860,  4.1451,  2.1583,  2.0636,  4.4552,  1.6712])
第 2 个Batch 
tensor([-2.3817, -1.6967, -3.6747, -3.0554,  3.0285, -2.4978, -3.6327, -2.7464,
         2.4673,  1.5665,  4.0536, -2.7231, -3.5562,  1.4634, -3.2955,  2.4066,
        -3.5431,  3.1838, -2.1958, -2.4230,  2.7110,  3.2057,  3.0468,  2.0105,
        -3.0200, -2.8328,  1.8880, -4.7602,  3.6504, -3.4159, -3.6731, -2.7953])
第

In [116]:
# Gaussian Mixture Density
def GMM_density(x, mode_loc=None):
    """Evaluate the equal-weight two-mode Gaussian mixture density at ``x``.

    Components have unit variance and are centred at ``-mode_loc`` and
    ``+mode_loc``.

    Args:
        x: Point(s) to evaluate; a Python number or a tensor (elementwise).
        mode_loc: Absolute location of the modes. Defaults to the global
            ``loc`` read at call time, as before.

    Returns:
        Tensor of densities with the broadcast shape of ``x``.
    """
    if mode_loc is None:
        mode_loc = loc
    mu = torch.as_tensor(mode_loc)
    # Mixture weight 0.5 times the standard-normal normalizer 1/sqrt(2*pi).
    norm = torch.tensor(0.5) / torch.sqrt(2 * torch.tensor(torch.pi))
    return norm * torch.exp(-(x - mu) ** 2 / 2) + norm * torch.exp(-(x + mu) ** 2 / 2)

In [117]:
import torch
import functools
from torch.optim import Adam
from torch.optim import SGD
# from torch.utils.data import DataLoader
# import torchvision.transforms as transforms
# from torchvision.datasets import MNIST
# import tqdm
# import tqdm.notebook
from scipy import integrate
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn
torch.manual_seed(42)

score_model = torch.nn.DataParallel(ScoreNet_1HiddenLayerFC(marginal_prob_std=marginal_prob_std_fn))
score_model = score_model.to(device)

n_epochs = 2001#@param {'type':'integer'}
# learning rate
lr=5e-1 #@param {'type':'number'}

# Time at which the learned score is evaluated when reconstructing the density.
eps = 0.065
# Evaluation grid: grid_size + 1 equally spaced points covering [-left, left].
grid_size = 50
left = 10.
point = [-left + (2 * left / grid_size) * i for i in range(grid_size + 1)]
# Target mixture density on the grid (GMM_density is elementwise).
target_dens = GMM_density(torch.tensor(point))

# One KL slot per evaluation. The training loop evaluates when
# epoch % 100 == 0, i.e. ceil(n_epochs / 100) times.
kl = torch.zeros((n_epochs + 99) // 100)
kl_iter = 0
def score(x):
  """Evaluate the trained score model at a single point x and time eps.

  Used as the integrand for scipy.integrate.quad, so it takes and returns a
  plain Python float. Gradients are disabled so that no autograd graph is
  built for each of quad's many evaluations.
  """
  with torch.no_grad():
    out = score_model(torch.tensor([[x]], device=device),
                      torch.tensor([eps], device=device))
  return out[0][0].item()

In [118]:
# Plain SGD (no momentum) over every parameter of score_model.
optimizer = SGD(params=score_model.parameters(), lr=lr)

In [None]:
#@title Training 

# optimizer = Adam(score_model.parameters(), lr=lr)

# tqdm_epoch = tqdm(range(n_epochs))
def _model_log_density_on_grid():
  """Return the unnormalized log-density of the trained model on `point`.

  log p(x) is, up to a constant, the integral of the learned score from
  -100 to x. The integral up to point[0] is computed once and then extended
  interval by interval, instead of re-integrating from -100 for every grid
  point.
  """
  log_dens = torch.zeros(grid_size + 1)
  running, _ = integrate.quad(score, -100.0, point[0])
  log_dens[0] = float(running)
  for i in tqdm(range(1, grid_size + 1)):
    increment, _ = integrate.quad(score, point[i - 1], point[i])
    running += increment
    log_dens[i] = float(running)
  return log_dens


dx = 2 * left / grid_size  # spacing between consecutive entries of `point`

for epoch in tqdm(range(n_epochs)):
  batch_total_loss = 0.
  num_items = 0
  for x in datas:
    x = x[:, None].to(device)  # (batch,) -> (batch, 1) for the model
    loss = loss_fn(score_model, x, marginal_prob_std_fn)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    batch_total_loss += loss.item() * x.shape[0]
    num_items += x.shape[0]

  # Every 100 epochs: report the loss, reconstruct the model density on the
  # grid, plot it against the data, and record KL(target || model).
  if epoch % 100 == 0:
    print('Average Loss: {:.6f}'.format(batch_total_loss / num_items))
    density = _model_log_density_on_grid()  # unnormalized log-density
    # softmax(density) == exp(density - max) / sum(...): grid probabilities.
    ddd = torch.softmax(density, dim=0)
    dd = ddd / dx  # probabilities -> density values for plotting
    seaborn.kdeplot(source_data, label = 'Target')
    plt.plot(point, dd, label = 'Trained')
    plt.title('epoch = {}'.format(epoch))
    plt.legend()
    plt.show()
    target_dens_ = target_dens * dx
    print(ddd)
    # log_softmax stays finite where exp() would underflow to 0, which would
    # otherwise make the KL term infinite.
    kl_value = F.kl_div(torch.log_softmax(density, dim=0), target_dens_, reduction='sum')
    if kl_iter >= kl.numel():
      # Grow the buffer instead of raising IndexError (e.g. when the cell is re-run).
      kl = torch.cat([kl, torch.zeros(1)])
    kl[kl_iter] = kl_value
    print(kl[kl_iter])
    kl_iter += 1