<a target="_blank" href="https://colab.research.google.com/github/felixp8/text-to-nn/blob/main/experiments/mlp/diffusion/diffusion.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

In [1]:
!git clone https://github.com/felixp8/text-to-nn.git

Cloning into 'text-to-nn'...
remote: Enumerating objects: 184, done.[K
remote: Counting objects: 100% (78/78), done.[K
remote: Compressing objects: 100% (64/64), done.[K
remote: Total 184 (delta 29), reused 28 (delta 9), pack-reused 106[K
Receiving objects: 100% (184/184), 71.23 MiB | 21.25 MiB/s, done.
Resolving deltas: 100% (67/67), done.


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import functools

In [3]:
class GaussianFourierProjection(nn.Module):
    """Gaussian random features for encoding time steps."""
    def __init__(self, embed_dim, scale=30.):
        super().__init__()
        # Randomly sample weights during initialization. These weights are fixed
        # during optimization and are not trainable.
        self.W = nn.Parameter(torch.randn(embed_dim // 2) * scale, requires_grad=False)
    def forward(self, x):
        x_proj = x[:, None] * self.W[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

In [4]:
class ScoreNet(nn.Module):
    """A time-dependent score-based model built upon U-Net architecture."""

    def __init__(self, marginal_prob_std, input_dim, hidden_dims=[], embed_dim=256, context_dim=768):
        """Initialize a time-dependent score-based network.

        Args:
          marginal_prob_std: A function that takes time t and gives the standard
            deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
          channels: The number of channels for feature maps of each resolution.
          embed_dim: The dimensionality of Gaussian random feature embeddings.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time
        self.t_embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim))
        self.y_embed = nn.Linear(context_dim, embed_dim)
        # Encoding layers where the resolution decreases
        hidden_dims = [input_dim,] + hidden_dims + [input_dim,]
        x_layers = []
        t_layers = []
        norm_layers = []
        for i in range(len(hidden_dims) - 2):
            x_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1]))
            t_layers.append(nn.Linear(embed_dim, hidden_dims[i+1]))
            norm_layers.append(nn.LayerNorm(hidden_dims[i+1]))
        self.x_layers = nn.ModuleList(x_layers)
        self.t_layers = nn.ModuleList(t_layers)
        self.norm_layers = nn.ModuleList(norm_layers)
        self.final = nn.Linear(hidden_dims[-2], hidden_dims[-1])

        # The swish activation function
        self.act = nn.SiLU()
        self.marginal_prob_std = marginal_prob_std

    def forward(self, x, t):
        # Obtain the Gaussian random feature embedding for t
        t_embed = self.act(self.t_embed(t))

        h = x
        for i in range(len(self.x_layers)):
            h = self.x_layers[i](h)
            h += self.t_layers[i](t_embed)
            h = self.norm_layers[i](h)
            h = self.act(h)

        h = self.act(self.final(h))

        # Normalize output
        h = h / self.marginal_prob_std(t)[:, None]
        return h


class ScoreNetConditional(nn.Module):
    """A time-dependent score-based model built upon U-Net architecture."""

    def __init__(self, marginal_prob_std, input_dim, hidden_dims=[], embed_dim=256, context_dim=768):
        """Initialize a time-dependent score-based network.

        Args:
          marginal_prob_std: A function that takes time t and gives the standard
            deviation of the perturbation kernel p_{0t}(x(t) | x(0)).
          channels: The number of channels for feature maps of each resolution.
          embed_dim: The dimensionality of Gaussian random feature embeddings.
        """
        super().__init__()
        # Gaussian random feature embedding layer for time
        self.t_embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim))
        self.y_embed = nn.Linear(context_dim, embed_dim)
        # Encoding layers where the resolution decreases
        hidden_dims = [input_dim,] + hidden_dims + [input_dim,]
        x_layers = []
        t_layers = []
        y_layers = []
        norm_layers = []
        for i in range(len(hidden_dims) - 2):
            x_layers.append(nn.Linear(hidden_dims[i], hidden_dims[i+1]))
            t_layers.append(nn.Linear(embed_dim, hidden_dims[i+1]))
            y_layers.append(nn.Linear(embed_dim, hidden_dims[i+1]))
            norm_layers.append(nn.LayerNorm(hidden_dims[i+1]))
        self.x_layers = nn.ModuleList(x_layers)
        self.t_layers = nn.ModuleList(t_layers)
        self.y_layers = nn.ModuleList(y_layers)
        self.norm_layers = nn.ModuleList(norm_layers)
        self.final = nn.Linear(hidden_dims[-2], hidden_dims[-1])

        # The swish activation function
        self.act = nn.SiLU()
        self.marginal_prob_std = marginal_prob_std

    def forward(self, x, t, y):
        # Obtain the Gaussian random feature embedding for t
        t_embed = self.act(self.t_embed(t))
        y_embed = self.act(self.y_embed(y))

        h = x
        for i in range(len(self.x_layers)):
            h = self.x_layers[i](h)
            h += self.t_layers[i](t_embed)
            h += self.y_layers[i](y_embed)
            h = self.norm_layers[i](h)
            h = self.act(h)

        h = self.act(self.final(h))

        # Normalize output
        h = h / self.marginal_prob_std(t)[:, None]
        return h

In [5]:
device = 'cpu' # ['cuda', 'cpu']

def marginal_prob_std(t, sigma):
    """Compute the mean and standard deviation of $p_{0t}(x(t) | x(0))$.

    Args:
      t: A vector of time steps.
      sigma: The $\sigma$ in our SDE.

    Returns:
      The standard deviation.
    """
    t = torch.tensor(t, device=device)
    return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma):
    """Compute the diffusion coefficient of our SDE.

    Args:
      t: A vector of time steps.
      sigma: The $\sigma$ in our SDE.

    Returns:
      The vector of diffusion coefficients.
    """
    return torch.tensor(sigma**t, device=device)

sigma =  50.0
marginal_prob_std_fn = functools.partial(marginal_prob_std, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

In [6]:
def loss_fn(model, x, marginal_prob_std, y=None, eps=1e-5):
    """The loss function for training score-based generative models.

    Args:
      model: A PyTorch model instance that represents a
        time-dependent score-based model.
      x: A mini-batch of training data.
      marginal_prob_std: A function that gives the standard deviation of
        the perturbation kernel.
      eps: A tolerance value for numerical stability.
    """
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps
    z = torch.randn_like(x)
    std = marginal_prob_std(random_t)
    perturbed_x = x + z * std[:, None]
    if y is not None:
        score = model(perturbed_x, random_t, y)
    else:
        score = model(perturbed_x, random_t)
    loss = torch.mean(torch.sum((score * std[:, None] + z)**2, dim=(1,)))
    return loss

In [21]:
import pandas as pd
import h5py

expressions_file = "./text-to-nn/experiments/mlp/data_generation/data/tiny/expressions.csv"
parameters_file = "./text-to-nn/experiments/mlp/data_generation/data/tiny/parameters.h5"
embeddings_file = "./text-to-nn/experiments/mlp/data_generation/data/tiny/instructor_embeddings.h5"

expr_csv = pd.read_csv(expressions_file)
with h5py.File(parameters_file, 'r') as h5f:
    parameters = h5f['nn_parameters'][:h5f['counter'][()].item()]
with h5py.File(embeddings_file, 'r') as h5f:
    embeddings = h5f['embeddings'][()]

assert expr_csv.shape[0] == parameters.shape[0]
assert expr_csv.shape[0] == embeddings.shape[0]

if True:
    from sklearn.preprocessing import OneHotEncoder
    manual_embeddings = OneHotEncoder(sparse_output=False).fit_transform(np.array(expr_csv.expr)[:, None])
    embeddings = manual_embeddings

In [22]:
good_mask = (expr_csv['best_mse_loss'] < 1.)

expr_csv = expr_csv[good_mask]
parameters = parameters[good_mask]
embeddings = embeddings[good_mask]

In [35]:
train_mask = np.random.choice(parameters.shape[0], size=int(parameters.shape[0]*0.6), replace=False)
train_mask = np.isin(np.arange(parameters.shape[0]), train_mask)
valid_mask = ~train_mask

train_expr = expr_csv[train_mask]
train_parameters = parameters[train_mask]
train_embeddings = embeddings[train_mask]

valid_expr = expr_csv[valid_mask]
valid_parameters = parameters[valid_mask]
valid_embeddings = embeddings[valid_mask]

In [36]:
train_parameters = torch.from_numpy(train_parameters).to(torch.float)
train_embeddings = torch.from_numpy(train_embeddings).to(torch.float)

In [37]:
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

score_model = ScoreNetConditional(
    input_dim=65,
    hidden_dims=[128, 64, 32, 64, 128],
    embed_dim=128,
    context_dim=8,
    marginal_prob_std=marginal_prob_std_fn,
)
score_model = score_model.to(device)

n_epochs = 100
## size of a mini-batch
batch_size =  256
## learning rate
lr=1e-3
## log freq
log_freq = 10

dataset = TensorDataset(train_parameters, train_embeddings)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

optimizer = optim.Adam(score_model.parameters(), lr=lr)
for epoch in range(n_epochs):
    avg_loss = 0.
    num_items = 0
    for x, y in data_loader:
        x = x.to(device)
        loss = loss_fn(score_model, x, marginal_prob_std_fn, y)
        # loss = loss_fn(score_model, x, marginal_prob_std_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]
    # Print the averaged training loss so far.
    if epoch % log_freq == 0:
        print('Epoch {:04d} Average Loss: {:5f}'.format(epoch, avg_loss / num_items))
        # Update the checkpoint after each epoch of training.
        torch.save(score_model.state_dict(), 'ckpt.pth')

  t = torch.tensor(t, device=device)


Epoch 0000 Average Loss: 64.284260
Epoch 0010 Average Loss: 49.406865
Epoch 0020 Average Loss: 47.394924
Epoch 0030 Average Loss: 47.193865
Epoch 0040 Average Loss: 46.446875
Epoch 0050 Average Loss: 46.043321
Epoch 0060 Average Loss: 45.320599
Epoch 0070 Average Loss: 45.278607
Epoch 0080 Average Loss: 44.911371
Epoch 0090 Average Loss: 44.357430


In [38]:
num_steps =  500
def Euler_Maruyama_sampler(score_model,
                           input_dim,
                           y,
                           marginal_prob_std,
                           diffusion_coeff,
                           batch_size=64,
                           num_steps=num_steps,
                           device='cuda',
                           eps=1e-3):
    """Generate samples from score-based models with the Euler-Maruyama solver.

    Args:
      score_model: A PyTorch model that represents the time-dependent score-based model.
      marginal_prob_std: A function that gives the standard deviation of
        the perturbation kernel.
      diffusion_coeff: A function that gives the diffusion coefficient of the SDE.
      batch_size: The number of samplers to generate by calling this function once.
      num_steps: The number of sampling steps.
        Equivalent to the number of discretized time steps.
      device: 'cuda' for running on GPUs, and 'cpu' for running on CPUs.
      eps: The smallest time step for numerical stability.

    Returns:
      Samples.
    """
    t = torch.ones(batch_size, device=device)
    init_x = torch.randn(batch_size, input_dim, device=device) \
        * marginal_prob_std(t)[:, None]
    time_steps = torch.linspace(1., eps, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]
    x = init_x
    with torch.no_grad():
        for time_step in time_steps:
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff(batch_time_step)
            mean_x = x + (g**2)[:, None] * score_model(x, batch_time_step, y) * step_size
            x = mean_x + torch.sqrt(step_size) * g[:, None] * torch.randn_like(x)
    # Do not include any noise in the last sampling step.
    return mean_x

In [40]:
## Load the pre-trained checkpoint from disk.
device = 'cpu' # ['cuda', 'cpu']
ckpt = torch.load('ckpt.pth', map_location=device)
score_model.load_state_dict(ckpt)

sample_batch_size = 64
sampler = Euler_Maruyama_sampler
y = torch.from_numpy(valid_embeddings[:sample_batch_size]).to(torch.float)

## Generate samples using the specified sampler.
samples = sampler(score_model,
                  65,
                  y,
                  marginal_prob_std_fn,
                  diffusion_coeff_fn,
                  sample_batch_size,
                  device=device)

  t = torch.tensor(t, device=device)
  return torch.tensor(sigma**t, device=device)


In [41]:
samples.shape

torch.Size([64, 65])

In [42]:
class MLP(nn.Module):
    def __init__(self, input_dim: int, hidden_dims: list, output_dim: int, activation: str = "relu", bias=True):
        super().__init__()
        dims = [input_dim] + hidden_dims + [output_dim]

        if activation == "relu":
            activation = nn.ReLU
        elif activation == "sigmoid":
            activation = nn.Sigmoid
        elif activation == "tanh":
            activation = nn.Tanh
        elif activation == "gelu":
            activation = nn.GELU
        else:
            raise ValueError()

        layerlist = []
        for i in range(len(dims) - 2):
            layerlist.append(nn.Linear(dims[i], dims[i+1], bias=bias))
            layerlist.append(activation())
        layerlist.append(nn.Linear(dims[-2], dims[-1], bias=True))

        self.layers = nn.Sequential(*layerlist)

    def forward(self, x):
        return self.layers(x)

In [43]:
sampled_model = MLP(input_dim=2, hidden_dims=[16,], output_dim=1)
nn.utils.vector_to_parameters(samples[0, :], sampled_model.parameters())

In [44]:
valid_expr.iloc[0]

Unnamed: 0                      1
expr                    (i1 - i0)
index                           1
best_mse_loss            0.000099
best_scaled_mse_loss     0.000005
Name: 1, dtype: object

In [46]:
sampled_model(torch.tensor([[1., 1.]]))

tensor([[2574.7905]], grad_fn=<AddmmBackward0>)