# Model definition

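This notebook defines `DiffusionLM`, the model class for a Diffusion-LM style text model, together with its helper functions (noise schedule and sinusoidal timestep embedding).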

In [1]:
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import BertTokenizer, BertConfig
from transformers.models.bert.modeling_bert import BertEncoder

In [2]:
def diffusion_noise_schedule(t, T=2000, s=1e-4):
    """
    Square-root noise schedule evaluated at diffusion step `t`.

    Computes ``alpha = 1 - sqrt(t / T + s)`` and returns ``sqrt(1 - alpha)``.

    :param t: diffusion step(s); a scalar or a NumPy array (evaluated
              element-wise).
    :param T: total number of diffusion steps; must be positive.
    :param s: small offset added inside the inner square root, so the
              returned value is non-zero at ``t = 0``.
    :return: ``sqrt(1 - alpha)`` with the same shape as `t`.
    :raises ValueError: if `T` is not positive.
    """
    # Fail early with a clear message instead of a ZeroDivisionError (scalar
    # `t`) or silent inf/nan values (array `t`).
    if T <= 0:
        raise ValueError(f"T must be a positive number of diffusion steps, got {T}")

    alpha = 1 - np.sqrt(t / T + s)
    return np.sqrt(1 - alpha)

In [3]:
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    The first ``dim // 2`` columns are cosines and the next ``dim // 2``
    columns are sines of ``timesteps * freqs``, where the frequencies decay
    geometrically from 1 towards ``1 / max_period``. If `dim` is odd, one
    trailing zero column is appended so the output always has `dim` columns.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] float Tensor of positional embeddings, on the same
             device as `timesteps`.
    """
    half = dim // 2
    # Build the frequency table on the device of `timesteps`; creating it on
    # the default (CPU) device breaks the multiplication below for tensors
    # that live on an accelerator.
    freqs = torch.exp(
        -math.log(max_period)
        * torch.arange(
            start=0, end=half, dtype=torch.float32, device=timesteps.device
        )
        / half
    )
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad with an explicit [N x 1] zero column. Slicing `embedding[:, :1]`
        # yields an empty tensor when dim == 1 (half == 0), which would leave
        # the output at [N x 0] instead of [N x dim].
        padding = embedding.new_zeros(embedding.shape[0], 1)
        embedding = torch.cat([embedding, padding], dim=-1)
    return embedding

In [4]:
class DiffusionLM(nn.Module):
    """
    Diffusion-LM style model built around a BERT encoder.

    Text is tokenized and embedded into a small `d`-dimensional space,
    projected up to the encoder's hidden size, combined with timestep and
    position embeddings, passed through a `BertEncoder`, and projected back
    down to `d` dimensions.

    Note: `base_model` is only used to load the tokenizer. The encoder is
    constructed from a default `BertConfig()`; no pretrained encoder weights
    are loaded here.
    """

    def __init__(
        self,
        base_model="bert-base-uncased",
        T=2000,  # diffusion steps
        d=16,  # embedding dimensions
        lr=1e-4,
        dropout=0.1,
    ):
        """
        :param base_model: name/path passed to `BertTokenizer.from_pretrained`.
        :param T: number of diffusion steps.
        :param d: dimension of the token embedding space.
        :param lr: learning rate; stored on the instance for later use
                   (no optimizer is created in this class yet).
        :param dropout: dropout probability applied to the encoder inputs.
        """
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(base_model)
        self.embedding = nn.Embedding(self.tokenizer.vocab_size, d)
        self.bert_config = BertConfig()
        self.encoder = BertEncoder(self.bert_config)
        self.hidden_dim = d
        self.diffusion_steps = T
        self.time_embed_dim = 4 * d
        # Keep the learning rate instead of silently discarding the argument.
        self.lr = lr
        self.dropout = nn.Dropout(dropout)
        self.hidden_size = self.bert_config.hidden_size
        self.LayerNorm = nn.LayerNorm(
            self.hidden_size, eps=self.bert_config.layer_norm_eps
        )

        # Add position embeddings
        self.register_buffer(
            "position_ids",
            torch.arange(self.bert_config.max_position_embeddings).expand((1, -1)),
        )
        self.position_embeddings = nn.Embedding(
            self.bert_config.max_position_embeddings, self.hidden_size
        )

        # Add time embedding: d -> 4 * d -> hidden_size
        self.time_embedding = nn.Sequential(
            nn.Linear(d, self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(self.time_embed_dim, self.hidden_size),
        )

        # Upsample input vector: d -> hidden_size
        self.input_projection = nn.Sequential(
            nn.Linear(d, self.hidden_size),
            nn.Tanh(),
            nn.Linear(self.hidden_size, self.hidden_size),
        )

        # Downsample output vector: hidden_size -> d
        self.output_projection = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Tanh(),
            nn.Linear(self.hidden_size, d),
        )

    def get_embedding(self, text):
        """
        Return the `d`-dimensional token embedding sequence for `text`.

        :param text: input passed to the tokenizer.
        :return: the embedded token ids; the last dimension has size `d`.
        """
        # Convert text to tokens. Use this instance's tokenizer (not a
        # notebook-level global) so the method works for any instance.
        tokens = self.tokenizer(text, return_tensors="pt")["input_ids"]

        # The tokenizer output may live on a different device than the model
        # parameters, so move the ids next to the embedding table.
        tokens = tokens.to(self.embedding.weight.device)

        # Get d-dimensional token embeddings
        return self.embedding(tokens)

    def forward_diffusion(self, x, T):
        """
        Placeholder for the forward (noising) diffusion process.

        Not implemented yet: currently does nothing and returns `None`.

        :param x: embeddings to be diffused.
        :param T: number of diffusion steps.
        """
        pass

    def forward(self, text, timestep):
        """
        Run the model on `text` at the given diffusion `timestep`.

        :param text: input passed to the tokenizer.
        :param timestep: 1-D tensor of diffusion steps, one per batch element.
        :return: tensor whose last dimension has size `d` (the encoder output
                 projected back down from `hidden_size`).
        """
        # Convert text to tokens
        embeddings = self.get_embedding(text)
        seq_length = embeddings.size(1)

        # Upsample to `hidden_size` dimensional embeddings
        upsampled = self.input_projection(embeddings)
        # NOTE: shape diagnostics, kept as-is for the notebook walkthrough.
        print(f"upsampled.shape: {upsampled.shape}")

        # Add timestep embedding + unroll across each sequence. The sinusoidal
        # features are moved next to the model inputs before the MLP is applied.
        time_features = timestep_embedding(timestep, self.hidden_dim).to(
            upsampled.device
        )
        timesteps = self.time_embedding(time_features)
        timesteps = timesteps.unsqueeze(1).expand(-1, seq_length, -1)
        print(f"timestep.shape: {timesteps.shape}")

        # Calculate positional embedding
        position_embeddings = self.position_embeddings(
            self.position_ids[:, :seq_length]
        )
        print(f"position_embeddings.shape: {position_embeddings.shape}")

        # Apply dropout + layernorm
        encoder_inputs = self.dropout(
            self.LayerNorm(upsampled + timesteps + position_embeddings)
        )

        # Get `hidden_size`-dimensional bert representation. Use this
        # instance's encoder rather than the notebook-level `model` global.
        representations = self.encoder(encoder_inputs).last_hidden_state
        print(f"representations.shape: {representations.shape}")

        # Downsample to d-representation
        downsampled = self.output_projection(representations)

        return downsampled

    def fit(self, train_dataset, epochs, batch_size):
        """
        Training loop scaffold (work in progress).

        Currently only embeds each text and calls `forward_diffusion`; no
        loss is computed and no parameters are updated.

        :param train_dataset: iterable of batches, each an iterable of texts.
        :param epochs: number of passes over `train_dataset` (an int), or an
                       iterable yielding one item per epoch.
        :param batch_size: currently unused.
        """
        # An integer epoch count is not iterable; turn it into a range while
        # still accepting an explicit iterable of epochs.
        epoch_iter = range(epochs) if isinstance(epochs, int) else epochs

        for _ in epoch_iter:

            for batch in train_dataset:

                for text in batch:

                    # Embed the provided text
                    embedding = self.get_embedding(text)

                    # Calculate the forward diffusion steps
                    # TODO: use the result once `forward_diffusion` is implemented.
                    diffused_embeddings = self.forward_diffusion(
                        embedding, T=self.diffusion_steps
                    )

In [5]:
# Instantiate the model with its configuration spelled out explicitly
# (these are the same values as the constructor defaults).
model = DiffusionLM(
    base_model="bert-base-uncased",
    T=2000,
    d=16,
    lr=1e-4,
    dropout=0.1,
)

In [6]:
# Smoke test: run one forward pass on an example sentence at diffusion
# time step 0 and inspect the shape of the result.
t = torch.tensor([0])
example_text = "In a hole in the ground there lived a hobbit"
# Call the module instance rather than `.forward(...)` directly, so that
# `nn.Module.__call__` runs (including any registered hooks).
output = model(example_text, timestep=t)
print(output.shape)
output

upsampled.shape: torch.Size([1, 14, 768])
timestep.shape: torch.Size([1, 14, 768])
position_embeddings.shape: torch.Size([1, 14, 768])
representations.shape: torch.Size([1, 14, 768])
torch.Size([1, 14, 16])


tensor([[[-3.6522e-01,  2.3675e-01,  2.6880e-02, -1.2815e-01,  1.8887e-01,
          -8.4596e-02,  8.7145e-02,  2.6924e-01, -3.9145e-01,  3.8011e-01,
           6.2855e-02, -3.1256e-01,  1.9716e-01,  2.9512e-02,  1.0252e-01,
           1.5824e-01],
         [-4.0389e-01, -1.0364e-01, -2.9069e-01, -7.6002e-02, -1.6227e-01,
           1.8669e-01,  6.3557e-02,  3.6291e-01, -4.8197e-01,  1.7780e-01,
           2.1744e-01,  4.3620e-02,  1.2060e-01, -1.8396e-01, -4.1723e-02,
          -8.8776e-02],
         [-3.5117e-01,  5.6830e-02, -1.4911e-01, -4.2641e-01,  2.2400e-02,
           2.4115e-01,  5.4709e-01,  4.7731e-01, -5.8525e-03,  1.0133e-01,
          -3.9800e-01,  4.1449e-01,  4.5522e-01,  8.5924e-02,  2.0514e-02,
           2.2399e-01],
         [-4.1458e-01, -1.1675e-01,  1.6530e-01,  2.9977e-03,  7.2667e-02,
           2.5901e-01, -1.5625e-01,  4.8887e-01, -5.4611e-01,  1.1773e-01,
          -3.2767e-03,  5.0806e-02,  5.3547e-02,  4.4382e-03,  2.5977e-01,
           2.1639e-01],
    