# Model definition

In this notebook I define the model object for the Diffusion LM model.

In [8]:
import math
import numpy as np
from einops import repeat

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from diffusion_lm.data import E2EDataset
from diffusion_lm.model import DiffusionLM

from diffusion_lm.utils import timestep_embedding
from transformers import BertTokenizer, BertConfig
from transformers.models.bert.modeling_bert import BertEncoder

In [9]:
class DiffusionLM(nn.Module):
    def __init__(
        self,
        base_model="bert-base-uncased",
        T=2000,  # diffusion steps
        d=16,  # embedding dimensions
        lr=1e-4,
        dropout=0.1,
    ):
        super().__init__()
        self.tokenizer = BertTokenizer.from_pretrained(base_model)
        self.embedding = nn.Embedding(self.tokenizer.vocab_size, d)
        self.bert_config = BertConfig()
        self.encoder = BertEncoder(self.bert_config)
        self.hidden_dim = d
        self.diffusion_steps = T
        self.time_embed_dim = 4 * d
        self.dropout = nn.Dropout(dropout)
        self.hidden_size = self.bert_config.hidden_size
        self.LayerNorm = nn.LayerNorm(
            self.hidden_size, eps=self.bert_config.layer_norm_eps
        )

        # Add time embedding
        self.time_embedding = nn.Sequential(
            nn.Linear(d, self.time_embed_dim),
            nn.SiLU(),
            nn.Linear(self.time_embed_dim, self.hidden_size),
        )

        # Calculate timestep embeddings
        self.timestep_embeddings = self.get_timestep_embeddings()

        # Add position embeddings
        self.register_buffer(
            "position_ids",
            torch.arange(self.bert_config.max_position_embeddings).expand((1, -1)),
        )
        self.position_embeddings = nn.Embedding(
            self.bert_config.max_position_embeddings, self.hidden_size
        )

        # Downsample input vector
        self.input_projection = nn.Sequential(
            nn.Linear(d, self.hidden_size),
            nn.Tanh(),
            nn.Linear(self.hidden_size, self.hidden_size),
        )

        # Downsample output vector
        self.output_projection = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Tanh(),
            nn.Linear(self.hidden_size, 2 * d),
        )

    def get_timestep_embeddings(self):
        timesteps = torch.arange(self.diffusion_steps)
        timesteps = timestep_embedding(timesteps, self.hidden_dim)
        timesteps = self.time_embedding(timesteps)
        return timesteps

    def q_sample(self, x, T):
        """
        Otherwise known as q
        """
        n_batches, seq_length, embed_dim = x.shape

        # Repeat x along time dimension T times
        x_t = einops.repeat(x, "b s x -> b t s x", t=2000)

        # Calculate and propagate noise schedule
        beta_t = torch.Tensor(diffusion_noise_schedule(np.arange(T)))
        beta_t = einops.repeat(
            beta_t, "t -> b t w x", b=n_batches, w=seq_length, x=embed_dim
        )

        # Generate noised samples
        q_t = torch.normal(
            (1 - torch.sqrt(beta_t)) * x_t, std=1 - torch.sqrt(1 - beta_t)
        )

        return q_t

    def forward(self, embeddings):
        """
        Otherwise known as p
        """

        # Convert text to tokens
        n_batches, n_timesteps, seq_length, embed_dim = embeddings.shape

        # Upsample to `hidden_size` dimensional embeddings
        upsampled = self.input_projection(embeddings)
        logging.debug(f"upsampled.shape: {upsampled.shape}")

        # Add timestep embedding + unroll across each sequence
        timesteps = einops.repeat(
            timesteps, "t e -> b t s e", b=n_batches, s=seq_length
        )
        logging.debug(f"timestep.shape: {timesteps.shape}")

        # Calculate positional embedding
        position_embeddings = self.position_embeddings(
            self.position_ids[:, :seq_length]
        )
        logging.debug(f"position_embeddings.shape: {position_embeddings.shape}")

        # Apply dropout + layernorm
        encoder_inputs = self.dropout(
            self.LayerNorm(upsampled + timesteps + position_embeddings)
        )

        # Get `hidden_size`-dimensional bert representation
        encoder_inputs = einops.rearrange(encoder_inputs, "b t s x -> (b t) s x")

        encoded = self.encoder(encoder_inputs).last_hidden_state
        logging.debug(f"encoded.shape: {encoded.shape}")

        # Downsample to d-representation
        downsampled = self.output_projection(encoded)

        return downsampled

In [10]:
model = DiffusionLM()
e2e_dataset = E2EDataset("train")
e2e_dataloader = DataLoader(e2e_dataset, batch_size=2, shuffle=True)

Found cached dataset e2e_nlg (/home/kakapo/.cache/huggingface/datasets/e2e_nlg/default/0.0.0/bfeceb720929c2705bd227d1cfe5eaaab102a0bdac10dad618dac1e00c737430)


  0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
# diffusion time step

embedding = model.get_embedding(example_text)
noised_embeddings = model.q_sample(embedding, 2000)

AttributeError: 'DiffusionLM' object has no attribute 'get_embedding'

In [24]:
# Convert text to tokens
embeddings = noised_embeddings
seq_length = embeddings.size(2)

In [30]:
# Upsample to `hidden_size` dimensional embeddings
upsampled = model.input_projection(embeddings)
upsampled.shape

torch.Size([1, 2000, 14, 768])

In [40]:
n_batches, seq_length, embed_dim = embedding.shape

In [72]:
T = 2000
timestep = torch.arange(T).unsqueeze(1).repeat(n_batches, 1).reshape(n_batches, T)

In [73]:
# Add timestep embedding + unroll across each sequence
timesteps = model.time_embedding(timestep_embedding(timestep, model.hidden_dim))
timesteps = timesteps.unsqueeze(2).expand(-1, seq_length, -1)
timesteps.shape

RuntimeError: The size of tensor a (2000) must match the size of tensor b (8) at non-singleton dimension 2

In [None]:
# Calculate positional embedding
position_embeddings = self.position_embeddings(
    self.position_ids[:, :seq_length]
)
logging.debug(f"position_embeddings.shape: {position_embeddings.shape}")

# Apply dropout + layernorm
encoder_inputs = self.dropout(
    self.LayerNorm(upsampled + timesteps + position_embeddings)
)

# Get `hidden_size`-dimensional bert representation
encoded = self.encoder(encoder_inputs).last_hidden_state
logging.debug(f"encoded.shape: {encoded.shape}")

# Downsample to d-representation
downsampled = self.output_projection(encoded)

In [23]:
output = model.forward(example_text, timestep=t)
print(output.shape)
output

TypeError: forward() got an unexpected keyword argument 'timestep'