In [1]:
from model import SEDD
from omegaconf import OmegaConf
import torch
import torch.nn as nn
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
cfg_path='configs//config.yaml'
cfg = OmegaConf.load(cfg_path)
cfg.model=OmegaConf.load('configs//model//small.yaml')
score_model = SEDD(cfg).to('cpu')

In [3]:
param_names=[param[0] for param in score_model.named_parameters()] #see if they are naming all their parameters
for k, v in score_model.state_dict().items():
    if k not in param_names:
        print(k)


rotary_emb.inv_freq


In [4]:
score_model.state_dict()['rotary_emb.inv_freq'].shape #this is the only model parameter that we can't access via looking at named_parmeters()

torch.Size([32])

In [5]:
for param in score_model.named_parameters():
    print(param[0])
    print(param[1].shape)

vocab_embed.embedding
torch.Size([50258, 768])
sigma_map.mlp.0.weight
torch.Size([128, 256])
sigma_map.mlp.0.bias
torch.Size([128])
sigma_map.mlp.2.weight
torch.Size([128, 128])
sigma_map.mlp.2.bias
torch.Size([128])
blocks.0.norm1.weight
torch.Size([768])
blocks.0.attn_qkv.weight
torch.Size([2304, 768])
blocks.0.attn_out.weight
torch.Size([768, 768])
blocks.0.norm2.weight
torch.Size([768])
blocks.0.mlp.0.weight
torch.Size([3072, 768])
blocks.0.mlp.0.bias
torch.Size([3072])
blocks.0.mlp.2.weight
torch.Size([768, 3072])
blocks.0.mlp.2.bias
torch.Size([768])
blocks.0.adaLN_modulation.weight
torch.Size([4608, 128])
blocks.0.adaLN_modulation.bias
torch.Size([4608])
blocks.1.norm1.weight
torch.Size([768])
blocks.1.attn_qkv.weight
torch.Size([2304, 768])
blocks.1.attn_out.weight
torch.Size([768, 768])
blocks.1.norm2.weight
torch.Size([768])
blocks.1.mlp.0.weight
torch.Size([3072, 768])
blocks.1.mlp.0.bias
torch.Size([3072])
blocks.1.mlp.2.weight
torch.Size([768, 3072])
blocks.1.mlp.2.bias
to

In [6]:
print(score_model)

SEDD(
  (vocab_embed): EmbeddingLayer()
  (sigma_map): TimestepEmbedder(
    (mlp): Sequential(
      (0): Linear(in_features=256, out_features=128, bias=True)
      (1): SiLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
    )
  )
  (rotary_emb): Rotary()
  (blocks): ModuleList(
    (0-11): 12 x DDiTBlock(
      (norm1): LayerNorm()
      (attn_qkv): Linear(in_features=768, out_features=2304, bias=False)
      (attn_out): Linear(in_features=768, out_features=768, bias=False)
      (dropout1): Dropout(p=0.1, inplace=False)
      (norm2): LayerNorm()
      (mlp): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate='tanh')
        (2): Linear(in_features=3072, out_features=768, bias=True)
      )
      (dropout2): Dropout(p=0.1, inplace=False)
      (adaLN_modulation): Linear(in_features=128, out_features=4608, bias=True)
    )
  )
  (output_layer): DDitFinalLayer(
    (norm_final): LayerNorm()
    (linear): Lin

The 12 DDit blocks (these are slightly modified transformer blocks from https://github.com/facebookresearch/DiT) should all be the same in terms of the form of the functions. Let's just make sure that is indeed the convention used for the named parameters.

In [7]:
block_param_names=[]
for param in score_model.named_parameters():
    if param[0].startswith('blocks.0'):
        block_param_names.append(param[0].partition('blocks.0.')[2])

In [8]:
block_param_names

['norm1.weight',
 'attn_qkv.weight',
 'attn_out.weight',
 'norm2.weight',
 'mlp.0.weight',
 'mlp.0.bias',
 'mlp.2.weight',
 'mlp.2.bias',
 'adaLN_modulation.weight',
 'adaLN_modulation.bias']

In [9]:
block_param_size_dict={key:[] for key in block_param_names}
for param in score_model.named_parameters():
    if param[0].startswith('blocks'):
        matches_name=False
        for ending in block_param_names:
            if param[0].endswith(ending):
                block_param_size_dict[ending].append(param[1].shape)
                matches_name=True
        if matches_name==False:
            print(f'Warning: paramater {param[0]} has a mismatched naming convention')


In [10]:
for value in block_param_size_dict.values():
    print(len(value))
    print(set(value))
    if len(set(value))!=1:
        print('Warning: mismatched size!')

12
{torch.Size([768])}
12
{torch.Size([2304, 768])}
12
{torch.Size([768, 768])}
12
{torch.Size([768])}
12
{torch.Size([3072, 768])}
12
{torch.Size([3072])}
12
{torch.Size([768, 3072])}
12
{torch.Size([768])}
12
{torch.Size([4608, 128])}
12
{torch.Size([4608])}


Okay good, so indeed we only need to understand the form of a given block to understand the form of the entire model. To summarize, we need to understand the form of the vocab embedding, the sigma map for the timestep embedding, the rotary embedding, a given DDit block, and the final layer.

In class SEDD, we see self.vocab_embed = EmbeddingLayer(config.model.hidden_size, vocab_size). 

We also see the first line in the method forward for SEDD is x = self.vocab_embed(indices), so this is the first operation performed when passing a sequence into the network.

The inputs into SEDD are indices and sigma, where I believe indices are the tokenized words (just represented by what position they take in the dictionary) and sigma is a time-step

In config.yaml, we have vocab_size = 50257 and in small.yaml we have hidden_size = 768. 

For some reason they define their own embedding layer in transformer.py rather than using the standard nn.embedding. We see there that self.vocab_embed(indices) takes in a tensor of indices (necessarily natural numbers between 0 and vocab_size) and outputs for each index a vector of length 768.

The matrix holding these 50257 vectors of length 768 is vocab_embed.embedding

In [11]:
#example on a batch of 2 sequences with length 4
vocab_size=50257
hidden_size=768
embedding = nn.Parameter(torch.empty((vocab_size,hidden_size)))
torch.nn.init.kaiming_uniform_(embedding, a=math.sqrt(5))
indices=torch.tensor([[1,2,3,199],[5,6,7,40]])
embedding[indices]

tensor([[[ 0.0110, -0.0273,  0.0203,  ...,  0.0276,  0.0107, -0.0220],
         [-0.0201,  0.0239,  0.0255,  ...,  0.0160, -0.0254,  0.0331],
         [ 0.0352, -0.0040,  0.0310,  ...,  0.0106,  0.0093,  0.0100],
         [-0.0292, -0.0103,  0.0148,  ..., -0.0120, -0.0008, -0.0164]],

        [[ 0.0097, -0.0285, -0.0270,  ...,  0.0166, -0.0031,  0.0328],
         [ 0.0329,  0.0052,  0.0140,  ..., -0.0037,  0.0115,  0.0331],
         [ 0.0172, -0.0130,  0.0213,  ...,  0.0102,  0.0156, -0.0292],
         [-0.0351, -0.0051, -0.0329,  ..., -0.0021,  0.0172, -0.0065]]],
       grad_fn=<IndexBackward0>)

Next we have c = F.silu(self.sigma_map(sigma)). c is passed as an argument into each block and into the output layer. self.sigma_map = TimestepEmbedder(config.model.cond_dim).

In small.yaml we see cond_dim=128. 

Looking at the forward method for TimestepEmbedder, we have 
def forward(self, t):
    t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
    t_emb = self.mlp(t_freq)
    return t_emb

frequency_embedding_size has default 256, and timestep_embedding is some fixed fourier transform thing which is not learned during training. It takes in a tensor of times and returns for each a 256 dimensional vector.

mlp is given by 
self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, cond_dim, bias=True),
            nn.SiLU(),
            nn.Linear(cond_dim, cond_dim, bias=True),
        )
Looking at the torch documentation, https://pytorch.org/docs/stable/nn.html, we see that the 256 dimensional vector v from each time is transformed via W_2*silu(W_1v+b_1)+b_2, where W_1 is a cond_dim by 256 matrix, b_1 and b_2 cond_dim dimensional vectors, W_2 a cond_dim by cond_dim matrix, and silu is the element-wise silu function.

These are sigma_map.mlp.0.weight, sigma_map.mlp.0.bias, sigma_map.mlp.2.weight, and sigma_map.mlp.2.bias

Next we have rotary_cos_sin = self.rotary_emb(x). Supposing the input was a single sequence of L words, x will be L by hidden_size, since each word gets embedded into a hidden_size-dimensional vector.

self.rotary_emb = rotary.Rotary(config.model.hidden_size // config.model.n_heads)

In small.yaml, we have hidden_size = 768, n_heads=12.

Looking at the Rotary class, we see inv_freq's values are not actually changed during training. This is also doing some kind of fourier thing, and returns two tensors of size...


In [12]:
from model import rotary
Rotary=rotary.Rotary(768/12)
x=embedding[torch.tensor([1,2,3,4,5])]
print(Rotary(x)[0].shape)
print(Rotary(x)[1].shape)
#okay I'm not sure what the role of this is tbh

torch.Size([1, 768, 3, 1, 64])
torch.Size([1, 768, 3, 1, 64])


Next, we have 
for i in range(len(self.blocks)):
    x = self.blocks[i](x, rotary_cos_sin, c, seqlens=None)
where 
self.blocks = nn.ModuleList([
            DDiTBlock(config.model.hidden_size, config.model.n_heads, config.model.cond_dim, dropout=config.model.dropout) for _ in range(config.model.n_blocks)
        ])

In small.yaml, we have hidden_size = 768, n_heads = 12, cond_dim = 128, dropout= 0.1

In forward for DDiTBlock:
        batch_size, seq_len = x.shape[0], x.shape[1]

        bias_dropout_scale_fn = self._get_bias_dropout_scale()

        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c)[:, None].chunk(6, dim=2)

        # attention operation
        x_skip = x
        x = modulate_fused(self.norm1(x), shift_msa, scale_msa)
        # dtype0 = x.dtype

        qkv = self.attn_qkv(x)
        qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, h=self.n_heads)
        with torch.cuda.amp.autocast(enabled=False):
            cos, sin = rotary_cos_sin
            qkv = rotary.apply_rotary_pos_emb(
                qkv, cos.to(qkv.dtype), sin.to(qkv.dtype)
            )
        
        #qkv = rearrange(qkv, 'b s ... -> (b s) ...')
        #if seqlens is None:
        #    cu_seqlens = torch.arange(
        #        0, (batch_size + 1) * seq_len, step=seq_len,
        #        dtype=torch.int32, device=qkv.device
        #    )
        #else:
        #    cu_seqlens = seqlens.cumsum(-1)

        #x = flash_attn_varlen_qkvpacked_func(
        #    qkv, cu_seqlens, seq_len, 0., causal=False)
        
        #x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size)

        # Z: I replaced the above with the below because I think flash attention was the thing which is picky about what GPU you use 
        
        # Separate Q, K, V (b, s, h, d_head)
        q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]


        # Create Attention Mask for Variable-Length Sequences
        if seqlens is not None:
            attn_mask = torch.ones(batch_size, seq_len, seq_len, device=x.device) * float('-inf')
            for i in range(batch_size):
                valid_len = seqlens[i]
                attn_mask[i, :valid_len, :valid_len] = 0  # Allow valid positions only
        else:
            attn_mask = None #not sure about this

        # Apply Attention
        x = F.scaled_dot_product_attention(
                q, k, v,attn_mask=attn_mask
            )
        
        x = rearrange(x, '(b s) h d -> b s (h d)', b=batch_size)

        x = bias_dropout_scale_fn(self.attn_out(x), None, gate_msa, x_skip, self.dropout)

        # mlp operation
        x = bias_dropout_scale_fn(self.mlp(modulate_fused(self.norm2(x), shift_mlp, scale_mlp)), None, gate_mlp, x, self.dropout)

adaLN_modulation is a linear transformation from cond_dim to 6*hidden_size, with matrix and vector given by blocks.0.adaLN_modulation.weight and blocks.0.adaLN_modulation.bias

norm1 = LayerNorm(hidden_size). This normalizes the each vector representing a token of length hidden_size to mean 0 and variance 1 then multiplies it by a diagonal matrix with diagonal given by blocks.0.norm1.weight

modulate_fused is just shifting and scaling x

attn_qkv is a linear layer with no bias from hidden_size to 3*hidden_size with matrix given by blocks.0.attn_qkv.weight 

apply_rotary_pos_emb is the same operation as before, doesn't require any learning

The bit with attention mask is to set the values past the length of the sequence to -infty when doing F.scaled_dot_product_attention

F.scaled_dot_product_attention is actually applying the attention mechanism with qkv. I'm not sure how the masking works so IDK for sure if it is set up correctly. This is just a function of q k and v, doesn't introduce any new paramaters

attn_out is a linear layer with no bias from hidden_size to hidden_size

bias_dropout_scale_fn is using F.dropout to randomly zero some values in the tensor 

norm2 is another LayerNorm(hidden_size)

self.mlp is nn.Sequential(
            nn.Linear(dim, mlp_ratio * hidden_size, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_ratio * hidden_size, hidden_size, bias=True)
        )

mlp_ratio defaults to 4


It is plain to see then where each of these paramaters comes in:

blocks.0.norm1.weight

torch.Size([768])

blocks.0.attn_qkv.weight

torch.Size([2304, 768])

blocks.0.attn_out.weight

torch.Size([768, 768])

blocks.0.norm2.weight

torch.Size([768])

blocks.0.mlp.0.weight

torch.Size([3072, 768])

blocks.0.mlp.0.bias

torch.Size([3072])

blocks.0.mlp.2.weight

torch.Size([768, 3072])

blocks.0.mlp.2.bias

torch.Size([768])

blocks.0.adaLN_modulation.weight

torch.Size([4608, 128])

blocks.0.adaLN_modulation.bias

torch.Size([4608])






The last thing used in foward for SEDD is 

x = self.output_layer(x, c)

(other than some scatter thing that I'm not sure what its doing yet)

self.output_layer=DDitFinalLayer(config.model.hidden_size, vocab_size, config.model.cond_dim)

these are hidden_size: 768, tokens: 50257, cond_dim: 128

forward for DDitFinalLater is:
shift, scale = self.adaLN_modulation(c)[:, None].chunk(2, dim=2)
x = modulate_fused(self.norm_final(x), shift, scale)
x = self.linear(x)

adaLN_modulation is nn.Linear(cond_dim, 2 * hidden_size, bias=True)
norm final is LayerNorm(hidden_size)
self.linear = nn.Linear(hidden_size, vocab_size). By default bias is true

So this final later converts the hidden_size vector for each token to a vocab_size vector.

This final output vector is used for our transition probabilities. In particular, for every position we now have a vector of length the number of tokens. So sample what token to update that position to, we will exponentiate each element and normalize to a probability vector. The x'th element is then interpreted as approximating p_t(x) in the score function. Unfortunately in machine learning the output vectors before applying this "softmax" function are also called the "scores" which is a bit confusing.