In [1]:
import torch
from transformers import AutoTokenizer
from src.transformers.models.bert import BertModel
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
model = BertModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device)

  from .autonotebook import tqdm as notebook_tqdm
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
  return self.fget.__get__(instance, owner)()


In [45]:
tokens_dict = {"tokens": list()}
vocab_reverse = dict((value, key) for key, value in tokenizer.vocab.items())

attention_dict = dict()
point_position_dict = dict()
agg_attn_dict = dict()
for layer in range(12):
    attention_dict[layer] = dict()
    point_position_dict[layer] = dict()
    for head in range(12):
        agg_attn_dict[f"{layer}_{head}"] = list()
        attention_dict[layer][head] = dict()
        attention_dict[layer][head]["layer"] = layer 
        attention_dict[layer][head]["head"] = head
        attention_dict[layer][head]["tokens"] = list() 
        point_position_dict[layer][head] = dict()
        point_position_dict[layer][head]["layer"] = layer 
        point_position_dict[layer][head]["head"] = head
        point_position_dict[layer][head]["tokens"] = list()
        point_position_dict[layer][head]["query"] = list()  
        point_position_dict[layer][head]["key"] = list()         

In [46]:
dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
dataset = [dna]
for sequence in dataset:
    inputs = tokenizer(" ".join([sequence[i:i+6] for i in range(0, len(sequence), 1)]), return_tensors = 'pt')
    out = model(**inputs.to(device), output_attentions=True, return_dict=True)
    tokens = [vocab_reverse[x] for x in inputs['input_ids'].tolist()[0]]    
    for i,value in enumerate(tokens):
        single_token = {}
        single_token['value'] = value
        single_token['type'] = "query"
        single_token["length"] = len(tokens)
        single_token['pos_int'] = i
        single_token['position'] = i/(len(tokens)- 1)
        single_token['sentence'] = " ".join(tokens) 
   
        tokens_dict['tokens'].append(single_token)
        for layer in range(12): 
            for head in range(12):
                point_position_dict[layer][head]['query'].append(out.query[layer][head][i].detach().cpu().numpy())
                point_position_dict[layer][head]['key'].append(out.key[layer][head][i].detach().cpu().numpy())
                attention_dict[layer][head]['tokens'].append({'attention' : out.attentions[layer][0][head][i].detach().cpu().numpy()})
        

In [47]:
# Getting point positions

#Some code from chatGPT!!!!!!!
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import umap
import numpy as np
import tqdm

def get_pca_embeddings(vectors, n_components=2):
    pca = PCA(n_components=n_components)
    return pca.fit_transform(vectors)
def get_tsne_embeddings(vectors, n_components=2, perplexity=30.0):
    tsne = TSNE(n_components=n_components, perplexity=perplexity)
    return tsne.fit_transform(vectors)
def get_umap_embeddings(vectors, n_components=2, n_neighbors=15, min_dist=0.1):
    umap_model = umap.UMAP(n_components=n_components, n_neighbors=n_neighbors, min_dist=min_dist)
    return umap_model.fit_transform(vectors)


def calculate_norm(vector):
    return np.linalg.norm(vector)

for layer in tqdm.tqdm(range(12)): 
    for head in tqdm.tqdm(range(12)):
        for vector in ['query', 'key']:
            vectors = np.stack(point_position_dict[layer][head][vector])

            pca_2d = get_pca_embeddings(vectors, n_components=2)
            pca_3d = get_pca_embeddings(vectors, n_components=3)

            tsne_2d = get_tsne_embeddings(vectors, n_components=2)
            tsne_3d = get_tsne_embeddings(vectors, n_components=3)

            umap_2d = get_umap_embeddings(vectors, n_components=2)
            umap_3d = get_umap_embeddings(vectors, n_components=3)
            
            for token in range(umap_2d.shape[0]):
                point_position_dict[layer][head]['tokens'].append({
                    "tsne_x" : tsne_2d[token][0],
                    "tsne_y" : tsne_2d[token][1],
                    
                    "tsne_x_3d" : tsne_3d[token][0],
                    "tsne_y_3d" : tsne_3d[token][1],                
                    "tsne_z_3d" : tsne_3d[token][2],                
                    
                    "umap_x" : umap_2d[token][0],
                    "umap_y" : umap_2d[token][1],          

                    "umap_x_3d" : umap_3d[token][0],
                    "umap_y_3d" : umap_3d[token][1],                
                    "umap_z_3d" : umap_3d[token][2],                         
                    
                    "pca_x" : pca_2d[token][0],
                    "pca_y" : pca_2d[token][1],
                    
                    "pca_x_3d" : pca_3d[token][0],
                    "pca_y_3d" : pca_3d[token][1],                                    
                    "pca_z_3d" : pca_3d[token][2],    
                    
                    "norm" : calculate_norm(vectors[token])                              
                })

  0%|          | 0/12 [00:00<?, ?it/s]

100%|██████████| 12/12 [00:42<00:00,  3.57s/it]
100%|██████████| 12/12 [00:43<00:00,  3.59s/it]
100%|██████████| 12/12 [00:45<00:00,  3.80s/it]
100%|██████████| 12/12 [00:44<00:00,  3.70s/it]
100%|██████████| 12/12 [00:44<00:00,  3.68s/it]
100%|██████████| 12/12 [00:44<00:00,  3.70s/it]
100%|██████████| 12/12 [00:45<00:00,  3.80s/it]
100%|██████████| 12/12 [00:44<00:00,  3.75s/it]
100%|██████████| 12/12 [00:49<00:00,  4.10s/it]
100%|██████████| 12/12 [00:46<00:00,  3.90s/it]
100%|██████████| 12/12 [00:46<00:00,  3.88s/it]
100%|██████████| 12/12 [00:46<00:00,  3.83s/it]
100%|██████████| 12/12 [09:03<00:00, 45.30s/it]


In [43]:
for layer in tqdm.tqdm(range(12)): 
    for head in tqdm.tqdm(range(12)):
        attention_dict[]np.average(np.stack([token['attention'] for token in attention_dict[0][0]['tokens']]), axis=1)


(61, 2)

In [61]:
np.stack([token['attention'] for token in attention_dict[0][0]['tokens']]).shape

(61, 61)

In [55]:
np.average(np.stack([token['attention'] for token in attention_dict[0][0]['tokens']]), axis=1)

array([0.01639344, 0.01639344, 0.01639344, 0.01639345, 0.01639344,
       0.01639344, 0.01639345, 0.01639344, 0.01639344, 0.01639344,
       0.01639344, 0.01639344, 0.01639344, 0.01639345, 0.01639344,
       0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639345,
       0.01639344, 0.01639344, 0.01639344, 0.01639345, 0.01639344,
       0.01639344, 0.01639344, 0.01639344, 0.01639345, 0.01639344,
       0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344,
       0.01639344, 0.01639344, 0.01639344, 0.01639344, 0.01639344,
       0.01639345, 0.01639344, 0.01639344, 0.01639344, 0.01639344,
       0.01639345, 0.01639345, 0.01639345, 0.01639344, 0.01639344,
       0.01639344, 0.01639345, 0.01639344, 0.01639344, 0.01639344,
       0.01639345, 0.01639344, 0.01639345, 0.01639344, 0.01639344,
       0.01639344], dtype=float32)

In [22]:

inputs = tokenizer(" ".join([dna[i:i+6] for i in range(0, len(dna)-6, 1)]), return_tensors = 'pt')


In [10]:
len(out.attentions)

12

In [5]:
out.attentions[0].shape

torch.Size([1, 12, 61, 61])

In [11]:
out.attentions[0][0][1][1]

tensor([0.0209, 0.0144, 0.0059, 0.0198, 0.0140, 0.0352, 0.0210, 0.0125, 0.0037,
        0.0059, 0.0109, 0.0372, 0.0382, 0.0058, 0.0201, 0.0153, 0.0225, 0.0069,
        0.0243, 0.0233, 0.0055, 0.0011, 0.0012, 0.0033, 0.0259, 0.0102, 0.0288,
        0.0083, 0.0144, 0.0206, 0.0613, 0.0212, 0.0308, 0.0196, 0.0060, 0.0021,
        0.0089, 0.0165, 0.0643, 0.0132, 0.0056, 0.0143, 0.0109, 0.0035, 0.0046,
        0.0101, 0.0302, 0.0428, 0.0203, 0.0044, 0.0099, 0.0223, 0.0420, 0.0128,
        0.0215, 0.0047, 0.0051, 0.0047, 0.0036, 0.0041, 0.0014],
       device='cuda:0', grad_fn=<SelectBackward0>)

In [28]:
" ".join([vocab_reverse[x] for x in inputs['input_ids'].tolist()[0]])

'[CLS] ACGTAG CGTAGC GTAGCA TAGCAT AGCATC GCATCG CATCGG ATCGGA TCGGAT CGGATC GGATCT GATCTA ATCTAT TCTATC CTATCT TATCTA ATCTAT TCTATC CTATCG TATCGA ATCGAC TCGACA CGACAC GACACT ACACTT CACTTG ACTTGG CTTGGT TTGGTT TGGTTA GGTTAT GTTATC TTATCG TATCGA ATCGAT TCGATC CGATCT GATCTA ATCTAC TCTACG CTACGA TACGAG ACGAGC CGAGCA GAGCAT AGCATC GCATCT CATCTC ATCTCG TCTCGT CTCGTT TCGTTA CGTTAG [SEP]'

In [10]:
out = model(**inputs.to(device), output_attentions=True, return_dict=True)

In [12]:
out.query.shape[]

torch.Size([12, 12, 61, 64])

In [68]:
model.embeddings.forward(inputs.to(device)["input_ids"], inputs.to(device)["token_type_ids"])

tensor([[[-0.0622, -0.0385,  0.1013,  ..., -0.0403, -0.0234, -0.0343],
         [ 0.1521,  0.0156,  0.0175,  ..., -0.1897,  0.0935,  0.0241],
         [ 0.0246,  0.0086,  0.0833,  ..., -0.0325,  0.0045,  0.0782],
         ...,
         [-0.0537,  0.1426,  0.0506,  ...,  0.0769,  0.0726,  0.1250],
         [ 0.1343, -0.1633, -0.0138,  ..., -0.4566, -0.1181, -0.0063],
         [-0.0328, -0.0395,  0.0512,  ..., -0.0384, -0.0244, -0.0343]]],
       device='cuda:0', grad_fn=<NativeLayerNormBackward0>)

In [70]:
hidden_states = model.encoder.forward(model.embeddings.forward(inputs.to(device)["input_ids"], inputs.to(device)["token_type_ids"]),inputs.to(device)["attention_mask"] ,output_all_encoded_layers =True) # [1, sequence_length, 768]

In [73]:
hidden_states[0].shape

torch.Size([17, 768])

In [38]:
hidden_states[1].shape

torch.Size([1, 768])

In [8]:


# embedding with mean pooling
embedding_mean = torch.mean(hidden_states[0], dim=0)
print(embedding_mean.shape) # expect to be 768

# embedding with max pooling
embedding_max = torch.max(hidden_states[0], dim=0)[0]
print(embedding_max.shape) # expect to be 768

torch.Size([768])
torch.Size([768])


In [5]:
torch.cuda.is_available()

True

In [6]:
torch.cuda.current_device()

0

In [10]:
hidden_states.shape

torch.Size([1, 17, 768])

In [14]:
len(tokenizer.vocab)

4096

In [13]:
model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(4101, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
   