# Model definition

In this notebook I define the `DiffusionLM` model class, together with the noise-schedule and sinusoidal timestep-embedding helpers it relies on.

In [203]:
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertConfig, BertModel
from transformers.models.bert.modeling_bert import BertEncoder

In [204]:
def diffusion_noise_schedule(t, T=2000, s=1e-4):
    """
    Evaluate the square-root noise schedule at diffusion step(s) ``t``.

    Computes ``alpha_bar = 1 - sqrt(t / T + s)`` and returns
    ``sqrt(1 - alpha_bar)``.

    :param t: diffusion step(s); a scalar or a NumPy array.
    :param T: value ``t`` is divided by (the total number of diffusion steps).
    :param s: small offset added to ``t / T``; with ``s > 0`` the result is
              non-zero even at ``t = 0``.
    :return: ``sqrt(1 - alpha_bar)``, with the same shape as ``t``.
    """
    # Requires NumPy to be imported as `np` in the notebook's import cell.
    alpha_bar = 1 - np.sqrt(t / T + s)
    return np.sqrt(1 - alpha_bar)

In [213]:
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.
    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] float32 Tensor of positional embeddings on the same
             device as `timesteps`: cosine features, then sine features,
             followed by a single zero column when `dim` is odd.
    """
    half = dim // 2
    # Build the frequencies on the input's device so the product below does not
    # fail with a device mismatch when `timesteps` lives on an accelerator.
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(
            start=0, end=half, dtype=torch.float32, device=timesteps.device
        )
        / half
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad with an explicit [N, 1] zero column. Slicing `embedding[:, :1]`
        # yields zero columns when half == 0 (dim == 1), which would leave the
        # result one column short of `dim`.
        zero_column = embedding.new_zeros(embedding.shape[0], 1)
        embedding = torch.cat([embedding, zero_column], dim=-1)
    return embedding

In [278]:
class DiffusionLM(nn.Module):
    """
    Diffusion LM network: maps text plus a diffusion timestep to one
    `d`-dimensional vector per token.

    Pipeline implemented by `forward`:
    tokenize -> `d`-dim token embedding -> project up to the BERT hidden size
    -> add timestep and position embeddings -> LayerNorm + dropout
    -> BERT encoder stack -> project back down to `d` dimensions.
    """

    def __init__(
        self,
        base_model="bert-base-uncased",
        T=2000,  # diffusion steps
        d=16,  # embedding dimensions
        lr=1e-4,
        dropout=0.1,
    ):
        """
        :param base_model: name/path handed to `BertTokenizer.from_pretrained`.
                           It only selects the tokenizer; the encoder is built
                           from a default `BertConfig()`, not from this
                           checkpoint.
        :param T: number of diffusion steps (accepted but not used here).
        :param d: dimension of the token embeddings and of the model output.
        :param lr: learning rate (accepted but not used here).
        :param dropout: dropout probability applied to the encoder inputs.
        """
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(base_model)
        self.embedding = nn.Embedding(self.tokenizer.vocab_size, d)
        self.bert_config = BertConfig()
        self.bert_model = BertModel(self.bert_config)
        # NOTE: `hidden_dim` is the small embedding size `d`, whereas
        # `hidden_size` (below) is the encoder width taken from the BERT config.
        self.hidden_dim = d
        self.time_embed_dim = 4 * d
        self.dropout = nn.Dropout(dropout)
        self.hidden_size = self.bert_config.hidden_size
        self.LayerNorm = nn.LayerNorm(
            self.hidden_size, eps=self.bert_config.layer_norm_eps
        )

        # Add position embeddings (the module is shared with `bert_model`)
        self.register_buffer(
            "position_ids",
            torch.arange(self.bert_config.max_position_embeddings).expand((1, -1)),
        )
        self.position_embeddings = self.bert_model.embeddings.position_embeddings

        # Add time embedding: d -> 4 * d -> hidden_size
        self.time_embedding = nn.Sequential(
            nn.Linear(d, self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(self.time_embed_dim, self.hidden_size),
        )

        # Upsample input vector: d -> hidden_size
        self.input_projection = nn.Sequential(
            nn.Linear(d, self.hidden_size),
            nn.Tanh(),
            nn.Linear(self.hidden_size, self.hidden_size),
        )

        # Downsample output vector: hidden_size -> d
        self.output_projection = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Tanh(),
            nn.Linear(self.hidden_size, d),
        )

    def forward(self, w, t):
        """
        :param w: text passed to the tokenizer. No padding is requested, so a
                  batch of strings presumably has to tokenize to equal lengths
                  (verify against the tokenizer before batching).
        :param t: 1-D Tensor of diffusion timesteps, one per batch element.
        :return: Tensor of shape [N, seq_length, d].
        """
        # The tokenizer always returns CPU tensors; move everything created
        # here onto the device that holds the module's parameters.
        device = self.embedding.weight.device

        # Convert text to tokens (via this instance, not a notebook global)
        tokens = self.tokenizer(w, return_tensors="pt")["input_ids"].to(device)
        seq_length = tokens.size(1)

        # Get d-dimensional token embeddings
        embeddings = self.embedding(tokens)

        # Upsample to `hidden_size` dimensional embeddings
        upsampled = self.input_projection(embeddings)
        print(f"upsampled.shape: {upsampled.shape}")

        # Add timestep embedding + unroll across each sequence
        time_features = timestep_embedding(t, self.hidden_dim).to(device)
        timesteps = self.time_embedding(time_features)
        timesteps = timesteps.unsqueeze(1).expand(-1, seq_length, -1)
        print(f"timestep.shape: {timesteps.shape}")

        # Calculate positional embedding
        position_embeddings = self.position_embeddings(
            self.position_ids[:, :seq_length]
        )
        print(f"position_embeddings.shape: {position_embeddings.shape}")

        # Apply dropout + layernorm
        encoder_inputs = self.dropout(
            self.LayerNorm(upsampled + timesteps + position_embeddings)
        )

        # Get `hidden_size`-dimensional bert representation
        representations = self.bert_model.encoder(encoder_inputs).last_hidden_state
        print(f"representations.shape: {representations.shape}")

        # Downsample to d-representation
        downsampled = self.output_projection(representations)

        return downsampled

In [279]:
# The model's weights are randomly initialised, so seed PyTorch's RNG first to
# make a Restart & Run All reproduce the same parameters.
torch.manual_seed(0)
model = DiffusionLM()

In [280]:
# Smoke test: run one sentence through the model at diffusion time step 0.
t = torch.tensor([0])  # one timestep per batch element
example_text = "In a hole in the ground there lived a hobbit"
# Call the module itself rather than `.forward()` so that nn.Module's
# registered hooks are run.
output = model(example_text, t=t)
print(output.shape)
output

upsampled.shape: torch.Size([1, 14, 768])
timestep.shape: torch.Size([1, 14, 768])
position_embeddings.shape: torch.Size([1, 14, 768])
representations.shape: torch.Size([1, 14, 768])
torch.Size([1, 14, 16])


tensor([[[-0.1186,  0.1950,  0.0790,  0.0560, -0.3271, -0.1720,  0.4681,
           0.0050,  0.1108, -0.0352, -0.3271, -0.1223,  0.2669,  0.5697,
          -0.0176,  0.1616],
         [-0.1581,  0.0814,  0.1467,  0.0228, -0.2938,  0.3322, -0.0022,
           0.1411, -0.1313,  0.0730, -0.0896,  0.0159,  0.5176,  0.2286,
           0.1918,  0.0072],
         [ 0.0645,  0.4671,  0.1659, -0.1761, -0.1825, -0.0747,  0.1974,
           0.1158, -0.0988, -0.0140,  0.1065,  0.0419,  0.6395,  0.3416,
           0.3220,  0.1937],
         [-0.1090,  0.1854,  0.2750, -0.0114, -0.2811, -0.0364,  0.3118,
           0.0676,  0.0936,  0.0554, -0.3419,  0.0867,  0.4015,  0.3010,
           0.1553,  0.0403],
         [-0.2477, -0.0603,  0.1905, -0.2386, -0.3107,  0.0805,  0.1016,
           0.0788,  0.0560,  0.0754, -0.0119, -0.1821,  0.3974,  0.0819,
           0.1427, -0.1553],
         [-0.0889,  0.2032,  0.0881, -0.4264, -0.1793, -0.2520, -0.1048,
           0.0356,  0.0330, -0.2745,  0.0748, -0.123