In [2]:
import torch
import numpy as np

In [23]:
def loadModel(weight_path: str, mapping_location: str='cpu', weights_only: bool=False):
    """Load a fully pickled model checkpoint and report its encoder depth.

    Args:
        weight_path: Path to a checkpoint file readable by ``torch.load``.
        mapping_location: Value forwarded to ``torch.load`` as ``map_location``
            (defaults to ``'cpu'``).
        weights_only: Forwarded to ``torch.load``. Defaults to ``False`` because
            this function needs the whole pickled object (it reads
            ``model.bert.encoder.layer``), not just a state dict.

    Returns:
        The object returned by ``torch.load``.

    Raises:
        AttributeError: If the loaded object has no ``bert.encoder.layer``.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, which can
    # execute code. Only load checkpoints from a source you trust.
    # The flag is passed explicitly so the call does not depend on the
    # torch.load default, which PyTorch warns is changing to True.
    model = torch.load(
        weight_path,
        map_location=mapping_location,
        weights_only=weights_only,
    )
    encoder_layers = model.bert.encoder.layer
    print(f"Number of encoder layers: {len(encoder_layers)}")
    return model

if __name__ == "__main__":
    # Checkpoint to inspect; point this at your own weight file.
    path = 'bert_imdb_weights_store/bert_imdb45.pth'
    # Kept as a notebook-level name: extract_bert_tensor_weights reads it.
    model_saved = loadModel(weight_path=path)

  model = torch.load(weight_path, map_location=mapping_location)


Number of encoder layers: 12


In [24]:
import pickle
import os

def min_max_normalize(tensor):
    """Linearly rescale ``tensor`` so its minimum maps to 0 and its maximum to 1.

    Args:
        tensor: Array-like object exposing ``.min()`` / ``.max()`` and
            elementwise arithmetic (e.g. a NumPy array).

    Returns:
        ``(tensor - min) / (max - min)``, same shape as the input. When every
        element is equal (``max == min``) the result is all zeros instead of
        the NaNs a division by zero would produce.
    """
    min_val = tensor.min()
    max_val = tensor.max()
    value_range = max_val - min_val
    if value_range == 0:
        # Constant input: dividing by 1 keeps the usual dtype promotion while
        # yielding zeros rather than NaN.
        value_range = 1
    return (tensor - min_val) / value_range

def _tiled_normalized_flat(weights, hidden_size, batch_size, sequence_length):
    """Gather weight columns per sequence position, min-max normalize, tile, flatten."""
    # Sequence position j reads column (j % hidden_size), so columns wrap
    # around when sequence_length exceeds hidden_size.
    column_index = np.arange(sequence_length) % hidden_size
    # slab[j, k] == weights[k, j % hidden_size]; one slab is shared by every
    # batch entry, so it is built and normalized once.
    slab = weights[:, column_index].T.astype(np.float32)
    normalized = min_max_normalize(slab)
    # Repeating the flattened slab batch_size times equals the C-order
    # flattening of a (batch_size, sequence_length, hidden) array.
    return np.tile(normalized.reshape(-1), batch_size)


def extract_bert_tensor_weights(output_dir, model=None, batch_size=53, sequence_length=256):
    """Dump per-layer (query, key, value) weight vectors of a BERT encoder to a pickle.

    For each layer in ``model.bert.encoder.layer`` the self-attention query,
    key and value weight matrices are converted to NumPy. For every sequence
    position ``j`` the column ``j % hidden_size`` is taken, the result is
    repeated for each batch entry, min-max normalized and flattened to a 1-D
    float32 vector of ``batch_size * sequence_length * hidden_size`` elements.

    Args:
        output_dir: Base directory; the pickle is written to
            ``<output_dir>/bert_imdb_pickle_store/bert_imdb45.pkl``.
        model: Object exposing ``bert.encoder.layer``. When ``None`` the
            notebook-level ``model_saved`` is used (NameError if it has not
            been defined yet).
        batch_size: Number of times each per-layer slab is repeated.
        sequence_length: Number of sequence positions (columns gathered).

    Returns:
        None. The list of ``(q, k, v)`` triplets, one per layer, is pickled
        to the output file.
    """
    if model is None:
        # Backward-compatible fallback to the checkpoint loaded in an earlier cell.
        model = model_saved

    triplets = []
    for layer_idx, layer in enumerate(model.bert.encoder.layer):
        print(f"Processing layer: {layer_idx}")

        attention = layer.attention.self
        # Extract query, key and value weights as NumPy arrays.
        query_weights = attention.query.weight.detach().cpu().numpy()
        key_weights = attention.key.weight.detach().cpu().numpy()
        value_weights = attention.value.weight.detach().cpu().numpy()

        # The query matrix's first dimension sets the hidden size for all three.
        hidden_size = query_weights.shape[0]

        triplets.append(tuple(
            _tiled_normalized_flat(weights, hidden_size, batch_size, sequence_length)
            for weights in (query_weights, key_weights, value_weights)
        ))

    output_file = os.path.join(output_dir, 'bert_imdb_pickle_store/bert_imdb45.pkl')
    # Create the target folder so a fresh checkout does not fail on open().
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    with open(output_file, 'wb') as f:
        pickle.dump(triplets, f)
    print("Created triplets and created a pickle file")

if __name__ == "__main__":
    # Not read by the call below; kept as a notebook-level name.
    num_hidden_layers = 12
    # Empty base directory: the pickle path resolves relative to the working directory.
    output_dir = ''
    extract_bert_tensor_weights(output_dir)

Processing layer: 0
Processing layer: 1
Processing layer: 2
Processing layer: 3
Processing layer: 4
Processing layer: 5
Processing layer: 6
Processing layer: 7
Processing layer: 8
Processing layer: 9
Processing layer: 10
Processing layer: 11
Created triplets and created a pickle file
