In [89]:
import torch
from transformers import DistilBertTokenizer, DistilBertModel

model = DistilBertModel.from_pretrained('distilbert-base-uncased')
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

distilbert_weights = model.state_dict()

embedding_dimension = 768
num_heads = 6
head_size = embedding_dimension // num_heads 

def layer_norm(x, weight, bias, eps=1e-6):
    
    mean = x.mean(dim=-1, keepdim=True)
    std_dev = x.std(dim=-1, keepdim=True)
    x_normalized = (x - mean) / (std_dev + eps)
    output = weight * x_normalized + bias
    
    return output

def get_head_tensor(X_expanded, layer, Q_K_or_V):
    
    #Weight matrix W_Q, W_K, or W_V
    weight_matrix = distilbert_weights['transformer.layer.' + str(layer) + '.attention.' + Q_K_or_V.lower() + '_lin.weight']
    head_divided_weight_matrix = weight_matrix.view(num_heads, head_size, embedding_dimension)

    #Bias matrix b_Q, b_K, or b_V
    bias_matrix = distilbert_weights['transformer.layer.' + str(layer) + '.attention.' + Q_K_or_V.lower() + '_lin.bias']
    head_divided_bias_matrix = bias_matrix.view(num_heads, head_size)

    # Multiply X with W_Q, W_K, or W_V
    head_matrices = torch.matmul(X_expanded, head_divided_weight_matrix.transpose(1, 2)) + head_divided_bias_matrix.unsqueeze(1)

    # Reshape to get the head tensor
    head_matrices = head_matrices.squeeze(1)
    
    return head_matrices

def embed(sentence):
    
    distilbert_weights = model.state_dict()
    
    # Tokenize the sentence
    inputs = tokenizer(sentence, return_tensors="pt")
    inputs = inputs["input_ids"][0]
    tokens_length = len(inputs)
    
    # Full token embeddings
    W = distilbert_weights['embeddings.word_embeddings.weight']

    # Sentence token embeddings
    X = W[inputs]
    
    # Positional embeddings
    P_full = distilbert_weights['embeddings.position_embeddings.weight']
    P = P_full[:tokens_length, :]

    # Add position embeddings to token embeddings
    X = X + P

    # Normalize
    X = layer_norm(X, distilbert_weights['embeddings.LayerNorm.weight'], distilbert_weights['embeddings.LayerNorm.bias'])
    
    return X

tokens_len = X.shape[0] #TODO: redundant, initialize when creating a class

def attention(X, layer):
    
    # For pytorch broadcasting to work, we need to expand the tensor to (1, 9, 768)
    X_expanded = X.unsqueeze(0)  # Shape: (1, 9, 768)
    
    # Query, Key, and Value heads
    Q = get_head_tensor(X_expanded, layer, 'Q')
    K = get_head_tensor(X_expanded, layer, 'K')
    V = get_head_tensor(X_expanded, layer, 'V')

    # Attention Weights
    A = torch.softmax(torch.matmul(Q, K.transpose(1, 2) / torch.sqrt(torch.tensor(head_size).float())),dim=-1)

    # Update V
    V = torch.matmul(A, V)

    # Concatenating the heads
    V = V.view(tokens_len,embedding_dimension)

    #Linear layer
    W_out_lin = distilbert_weights['transformer.layer.' + str(layer) + '.attention.out_lin.weight']
    b_out_lin = distilbert_weights['transformer.layer.' + str(layer) + '.attention.out_lin.bias']
    b_out_lin_matrix = b_out_lin.repeat(tokens_len, 1)

    residual = torch.matmul(V, W_out_lin) + b_out_lin_matrix  #TODO: Need to transpose W_out_lin as per copilot suggestion?

    # Residual Connections
    X = X + residual

    # Normalize
    W_sa = distilbert_weights['transformer.layer.' + str(layer) + '.sa_layer_norm.weight']
    b_sa = distilbert_weights['transformer.layer.' + str(layer) + '.sa_layer_norm.bias']
    X = layer_norm(X, W_sa, b_sa)
    
    return X

def feed_forward(X, layer):
    
    # ff Linear 1
    W_ff1 = distilbert_weights['transformer.layer.' + str(layer) + '.ffn.lin1.weight']
    b_ff1 = distilbert_weights['transformer.layer.' + str(layer) + '.ffn.lin1.bias']
    b_ff1_matrix = b_ff1.repeat(9, 1)

    FF_data = torch.matmul(X, W_ff1.transpose(0,1) ) + b_ff1_matrix

    # FF ReLU
    FF_data = torch.relu(FF_data)

    # FF Linear 2
    W_ff2 = distilbert_weights['transformer.layer.' + str(layer) + '.ffn.lin2.weight']
    b_ff2 = distilbert_weights['transformer.layer.' + str(layer) + '.ffn.lin2.bias']
    b_ff2_matrix = b_ff2.repeat(9, 1)

    X = torch.matmul(FF_data, W_ff2.transpose(0,1) ) + b_ff2_matrix

    # Normalize
    W_ff = distilbert_weights['transformer.layer.' + str(layer) + '.output_layer_norm.weight']
    b_ff = distilbert_weights['transformer.layer.' + str(layer) + '.output_layer_norm.bias']
    X = layer_norm(X, W_ff, b_ff)
    
    return X


In [90]:
X = embed("The cat sat on the mat.")
for layer in range(6):
    X = attention(X, layer)
    X = feed_forward(X, layer)
print(X)

tensor([[-0.9836, -0.1272, -0.6853,  ...,  1.0959,  0.5143,  0.6237],
        [ 0.2990, -0.2386,  0.7141,  ..., -0.2033,  0.0946,  0.0958],
        [ 1.0037, -0.1444,  0.3477,  ..., -0.0689, -0.6273,  0.4990],
        ...,
        [ 1.1924, -0.2387,  1.2874,  ...,  1.0004, -0.8292, -0.3271],
        [-0.2179,  0.1382,  1.5083,  ...,  0.2565,  1.0773, -0.4967],
        [ 0.4736,  0.0803,  0.5177,  ...,  1.0466, -0.1284, -0.4433]])
