In [13]:
import torch
import pandas as pd
import numpy as np

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Available device {device}")

Available device cpu


In [3]:
df = pd.read_csv("cola_public/raw/in_domain_train.tsv", delimiter='\t', header=None, names=['sentence_source', 'label', 'label_notes', 'sentence'])

In [4]:
df.sample(10)

Unnamed: 0,sentence_source,label,label_notes,sentence
7691,sks13,0,*,I sent Bill money to Mary to Sam.
446,bc01,0,*,John thinks what Mary bought.
8005,ad03,1,,That Jason arrived infuriated Medea.
7400,sks13,1,,They are special.
153,cj99,0,?*,"I can well imagine if he eats more, him gettin..."
5635,c_13,1,,I gave my brother a birthday present.
6447,d_98,1,,Mary didn't pick any of the flowers.
8046,ad03,1,,Who ate the cake?
4443,ks08,1,,He will have been seeing his children.
5191,kl93,1,,I don't have potatoes.


In [5]:
sentences = df.sentence.values
labels = df.label.values
print(sentences[0])

Our friends won't buy this analysis, let alone the next one we propose.


In [6]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
tokens = tokenizer.tokenize(sentences[0])
print(f"tokens {tokens}")

tokens ['our', 'friends', 'won', "'", 't', 'buy', 'this', 'analysis', ',', 'let', 'alone', 'the', 'next', 'one', 'we', 'propose', '.']




In [7]:
token_ids = tokenizer.convert_tokens_to_ids(tokens)
print(token_ids)

[2256, 2814, 2180, 1005, 1056, 4965, 2023, 4106, 1010, 2292, 2894, 1996, 2279, 2028, 2057, 16599, 1012]


In [8]:
max_len = 0
for text in sentences:
    input_ids = tokenizer.encode(text, add_special_tokens=True)
    max_len = max(max_len, len(input_ids))
print(f"Max length of input ids {max_len}")

Max length of input ids 47


In [9]:
input_ids = []
attention_mask = []

for text in sentences:
    encoded_dict = tokenizer.encode_plus(
        text,
        add_special_tokens = True,
        max_length = 64,
        truncation = True,
        padding = 'max_length',
        return_attention_mask = True,
        return_tensors = 'pt'
    )
    input_ids.append(encoded_dict['input_ids'])
    attention_mask.append(encoded_dict['attention_mask'])

In [10]:
print(type(input_ids))
print(type(attention_mask))

<class 'list'>
<class 'list'>


In [11]:
input_ids = torch.cat(input_ids, dim=0)
attention_mask = torch.cat(attention_mask, dim=0)
labels = torch.tensor(labels)

In [12]:
print(f"Original {sentences[0]}")
print(f"Input_ids {input_ids[0]}")

Original Our friends won't buy this analysis, let alone the next one we propose.
Input_ids tensor([  101,  2256,  2814,  2180,  1005,  1056,  4965,  2023,  4106,  1010,
         2292,  2894,  1996,  2279,  2028,  2057, 16599,  1012,   102,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0])


In [13]:
from torch.utils.data.dataset import TensorDataset, random_split
dataset = TensorDataset(input_ids, attention_mask, labels)

In [14]:
train_size = int(len(dataset)*(0.9))
test_size = len(dataset)-train_size
train_data, test_data = random_split(dataset, lengths=[train_size, test_size])

In [15]:
from torch.utils.data import DataLoader, RandomSampler

batch_size = 32
train_dataloader = DataLoader(train_data,
                              sampler=RandomSampler(train_data),
                              batch_size=batch_size)

test_dataloader = DataLoader(test_data,
                             sampler=RandomSampler(test_data),
                             batch_size=batch_size)

In [16]:
from transformers import BertForSequenceClassification, AdamW, BertConfig

model = BertForSequenceClassification.from_pretrained('bert-base-uncased',
                                                      num_labels = 2,
                                                      output_attentions = False,
                                                      output_hidden_states = False)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
optimizer = torch.optim.AdamW(params=model.parameters(), lr=0.001, eps = 1e-8)


In [18]:
import numpy as np

# Function to calculate the accuracy of our predictions vs labels
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

In [19]:
from transformers import get_linear_schedule_with_warmup

epochs = 1

total_steps = len(train_dataloader) * epochs
print(f"Total steps: {total_steps}")
# Create the learning rate scheduler.
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps = 0, # Default value in run_glue.py
                                            num_training_steps = total_steps)


Total steps: 241


In [20]:
import time
import datetime

def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed)))

    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))


In [25]:
import random
import numpy as np

seed_val = 42

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
training_stats = []

# Measure the total training time for the whole run.
total_t0 = time.time()

for epoch_i in range(0, epochs):

    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    # Measure how long the training epoch takes.
    t0 = time.time()
    total_train_loss = 0

    model.train()

    # For each batch of training data...
    for step, batch in enumerate(train_dataloader):

        # Progress update every 40 batches.
        if step % 40 == 0 and not step == 0:
            # Calculate elapsed time in minutes.
            elapsed = format_time(time.time() - t0)

            # Report progress.
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        # `batch` contains three pytorch tensors:
        #   [0]: input ids
        #   [1]: attention masks
        #   [2]: labels
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        model.zero_grad()

        # outputs prior to activation.
        train_output = model(b_input_ids,
                             token_type_ids=None,
                             attention_mask=b_input_mask,
                             labels=b_labels)
        # print(train_output.loss)
        loss = train_output.loss
        total_train_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
    avg_train_loss = total_train_loss / len(train_dataloader)

    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epcoh took: {:}".format(training_time))


    print("")
    print("Running Validation...")

    t0 = time.time()
    model.eval()

    # Tracking variables
    total_eval_accuracy = 0
    total_eval_loss = 0
    nb_eval_steps = 0

    # Evaluate data for one epoch
    for batch in test_dataloader:
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        with torch.no_grad():
            test_output = model(b_input_ids,
                                   token_type_ids=None,
                                   attention_mask=b_input_mask,
                                   labels=b_labels)
        loss = test_output.loss
        logits = test_output.logits
        total_eval_loss += loss.item()

        # Move logits and labels to CPU
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()

        total_eval_accuracy += flat_accuracy(logits, label_ids)


    # Report the final accuracy for this validation run.
    avg_val_accuracy = total_eval_accuracy / len(test_dataloader)
    print("  Accuracy: {0:.2f}".format(avg_val_accuracy))

    # Calculate the average loss over all of the batches.
    avg_val_loss = total_eval_loss / len(test_dataloader)

    # Measure how long the validation run took.
    validation_time = format_time(time.time() - t0)

    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Validation took: {:}".format(validation_time))

    # Record all statistics from this epoch.
    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Valid. Accur.': avg_val_accuracy,
            'Training Time': training_time,
            'Validation Time': validation_time
        }
    )

print("")
print("Training complete!")

print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))


Training...
  Batch    40  of    241.    Elapsed: 0:11:24.
  Batch    80  of    241.    Elapsed: 0:20:35.
  Batch   120  of    241.    Elapsed: 0:30:29.
  Batch   160  of    241.    Elapsed: 0:40:02.
  Batch   200  of    241.    Elapsed: 0:50:03.
  Batch   240  of    241.    Elapsed: 1:01:28.

  Average training loss: 0.63
  Training epcoh took: 1:01:36

Running Validation...
  Accuracy: 0.70
  Validation Loss: 0.61
  Validation took: 0:02:21

Training complete!
Total training took 1:03:57 (h:mm:ss)


In [27]:
save_model_name = "bert_qkv_e1.pth"
torch.save(model, save_model_name)

In [23]:
model_saved = torch.load("bert_imdb5.pth", map_location=torch.device('cpu'))
encoder_layers = model_saved.bert.encoder.layer
print(len(encoder_layers))

  model_saved = torch.load("bert_imdb5.pth", map_location=torch.device('cpu'))


12


In [24]:
import pickle
import os

def min_max_normalize(tensor):
    min_val = tensor.min()
    max_val = tensor.max()
    normalized_tensor = (tensor-min_val)/(max_val-min_val)
    return normalized_tensor

def extract_bert_tensor_weights(output_dir):
    triplets = []
    encoder_layers = model_saved.bert.encoder.layer
    for layer_idx, layer in enumerate(encoder_layers):
        print(f"Processing layer: {layer_idx}")

        attention = layer.attention.self
        #extracting query,key, and value weights
        query_weights = attention.query.weight.detach().cpu().numpy()
        key_weights = attention.key.weight.detach().cpu().numpy()
        value_weights = attention.value.weight.detach().cpu().numpy()

        # print(f"query weights shape: {query_weights.shape}")

        batch_size = 53
        sequence_length = 256
        hidden_size = query_weights.shape[0]

        q_result = np.zeros((batch_size, sequence_length, hidden_size), dtype=np.float32)
        k_result = np.zeros((batch_size, sequence_length, hidden_size), dtype=np.float32)
        v_result = np.zeros((batch_size, sequence_length, hidden_size), dtype=np.float32)

        for i in range(batch_size):
            for j in range(sequence_length):
                q_result[i, j, :] = query_weights[:, j % hidden_size]
                k_result[i, j, :] = key_weights[:, j % hidden_size]
                v_result[i, j, :] = value_weights[:, j % hidden_size]
        # print(f"q_result shape: {q_result.shape}")

        q_normalized = min_max_normalize(q_result)
        k_normalized = min_max_normalize(k_result)
        v_normalized = min_max_normalize(v_result)

        q_flat_vector = q_normalized.reshape(-1)
        # print(f"Checking flat vector size of q: {type(q_flat_vector)}")
        k_flat_vector = k_normalized.reshape(-1)
        # print(f"Checking flat vector size of k: {type(k_flat_vector)}")
        v_flat_vector = v_normalized.reshape(-1)
        # print(f"Checking flat vector size of v: {type(v_flat_vector)}")

        triplets.append((q_flat_vector, k_flat_vector, v_flat_vector))
    
    output_file = os.path.join(output_dir,'bert_imdb5.pkl')
    with open(output_file, 'wb') as f:
        pickle.dump(triplets, f)
    print(f"Created triplets and created a pickle file")

if __name__=="__main__":
    output_dir = ''
    num_hidden_layers = 12
    extract_bert_tensor_weights(output_dir=output_dir)

Processing layer: 0
Processing layer: 1
Processing layer: 2
Processing layer: 3
Processing layer: 4
Processing layer: 5
Processing layer: 6
Processing layer: 7
Processing layer: 8
Processing layer: 9
Processing layer: 10
Processing layer: 11
Created triplets and created a pickle file
