In [1]:
import torch
from transformers import AutoTokenizer
from src.transformers.models.bert import BertModel
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tokenizer = AutoTokenizer.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True)
model = BertModel.from_pretrained("zhihan1996/DNA_bert_6", trust_remote_code=True).to(device)

  from .autonotebook import tqdm as notebook_tqdm
The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
  return self.fget.__get__(instance, owner)()


In [2]:
num_layers = 12
num_heads = 12

In [3]:
tokens_dict = {"tokens": list()}
vocab_reverse = dict((value, key) for key, value in tokenizer.vocab.items())


attention_dict = dict()
point_position_dict = dict()
agg_attn_dict = dict()
for layer in range(num_layers):
    attention_dict[layer] = dict()
    point_position_dict[layer] = dict()
    for head in range(num_heads):
        agg_attn_dict[f"{layer}_{head}"] = list()
        attention_dict[layer][head] = dict()
        attention_dict[layer][head]["layer"] = layer 
        attention_dict[layer][head]["head"] = head
        attention_dict[layer][head]["tokens"] = list() 
        point_position_dict[layer][head] = dict()
        point_position_dict[layer][head]["layer"] = layer 
        point_position_dict[layer][head]["head"] = head
        point_position_dict[layer][head]["tokens"] = list()
        point_position_dict[layer][head]["query"] = list()  
        point_position_dict[layer][head]["key"] = list()         

In [4]:
dna_1 = "GGGGTAATCAGAGCAGAACCAGGCACCTGCCCTGCCTGATGTCCTCTGCTCAGGGCTGGCAGCTGTGTCCTGTGTCCTCCCCACCCCCTGGGACCACAAAGCTCCACCCCTGCCACACCCTGACATACTCAAGCCCAGGAGCCTGACCCAGGGCTCAGGGTGGGGTCAAAAACCGGGGGGATCTGATTTGCATGGATGGACTCTCCCCCTCTCAGAGTATGAAGAGAGGGAGAGATCTGGGGGAAGCTCAGCTTCAGCTGTGGTAGAGAAGACAGGATTCAGGACAATCTCCAGCATGGC"
dna_2 = "AAAGAGACCCGGGGAGCATCTGGGCTTCCAAGGTCCTCGGTACGGCCCAAGGCAGCGAAGGACGCGCGGCTCCAGGCTGCGGGAGCCAGGACGACCGGGGGCTCCCAGAGCGCGAAGTCGCGATCCTCGGCGGTGGAGAGCTCGTGCCAAAACGTCCTCCCCTGCGCCAGTCAGGCCTTCGCGGGGCTGGCAGGCGGGCGGGGGCGGGGCCGCCGCACTTTAAGAGGCTGTGCAGGCAGACAGACCTCCAGGCCCGCTAGGGGATCCGCGCCATGGAGGCCGCCCGGGACTATGCAGGAG"
dna_3 = "AGACCCCGGAGCCACAAGGAGAGGGCTGGATCCCCGGCTCAGAGGGAAGAGGTCGGATCCCCAGCTGAGAGGGAGGAGGGTCCCGGACCCTAGGAGTGGGAAGGAAAGGCTCGGATCCCCTGATCCCCAGGAGGAGGGGACCCGGCTGCCTCCCGGTTGGGGCCGCGCGAGGGCGGGGCGCGGAAGGATCCGGGAGGGCCGTGCTCCGCCACCCAGTATATATCTGTCCCCAGTCCCCGGGGCCGCCTCATTCCCTGTCCTCGGATCACAGTCTCTTCTCACTACAGTGTCGCCGCCTCT"
dna_4 = "GTCTTTCCTTGGAGGAGGCATTGGCACGAGTTACTATAAACTCCCTCTGAATCTCAAGACTTCTGGGACGCCGATTCCGCTCCTGGCCTGGGGCAAGGCGTGGGAGCTTGGAAGCCAGCGCTGCGCTCCCCGTGGGAAGCGATCGTCTCCTCTGTCAACTCGCGCCTGGGCACTTAGCCCCTCCCGTTTCAGGGCGCCGCCTCCCCGGATGGCAAACACTATAAAGTGGCGGCGAATAAGGTTCCTCCTGCTGCTCTCGGTTTAGTCCAAGATCAGCGATATCACGCGTCCCCCGGAGCA"

dataset = [dna_1, dna_2, dna_3, dna_4]

sentence_stops = list()
sentence_starts = list()
pos = 0
for sequence in dataset:
    sentence_starts.append(pos)
    inputs = tokenizer(" ".join([sequence[i:i+6] for i in range(0, len(sequence)-6, 1)]), return_tensors = 'pt')
    out = model(**inputs.to(device), output_attentions=True, return_dict=True)
    tokens = [vocab_reverse[x] for x in inputs['input_ids'].tolist()[0]]   
    pos = pos+len(tokens)
    sentence_stops.append(pos) 
    for i,value in enumerate(tokens):
        single_token = {}
        single_token['value'] = value
        single_token['type'] = "query"
        single_token["length"] = len(tokens)
        single_token['pos_int'] = i
        single_token['position'] = i/(len(tokens)- 1)
        single_token['sentence'] = " ".join(tokens) 
   
        tokens_dict['tokens'].append(single_token)
        for layer in range(num_layers): 
            for head in range(num_heads):
                point_position_dict[layer][head]['query'].append(out.query[layer][head][i].detach().cpu().numpy())
                point_position_dict[layer][head]['key'].append(out.key[layer][head][i].detach().cpu().numpy())
                attention_dict[layer][head]['tokens'].append({'attention' : out.attentions[layer][0][head][i].detach().cpu().numpy()})
        

In [5]:
# Getting point positions

#Some code from chatGPT!!!!!!!
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import umap
import numpy as np
import tqdm

def get_pca_embeddings(vectors, n_components=2):
    pca = PCA(n_components=n_components)
    return pca.fit_transform(vectors)
def get_tsne_embeddings(vectors, n_components=2, perplexity=30.0):
    tsne = TSNE(n_components=n_components, perplexity=perplexity, n_iter=250)
    return tsne.fit_transform(vectors)
def get_umap_embeddings(vectors, n_components=2, n_neighbors=15, min_dist=0.1):
    umap_model = umap.UMAP(n_components=n_components, n_neighbors=n_neighbors, min_dist=min_dist)
    return umap_model.fit_transform(vectors)
def calculate_centroid(vectors):
    """Calculate the centroid of a list of vectors."""
    return np.mean(vectors, axis=0)
def translate_vectors(source_vectors, target_vectors):
    """Translate source_vectors so their centroid matches that of target_vectors."""
    source_centroid = calculate_centroid(source_vectors)
    target_centroid = calculate_centroid(target_vectors)
    translation = target_centroid - source_centroid
    translated_vectors = source_vectors + translation
    return translated_vectors

def calculate_norm(vector):
    return np.linalg.norm(vector)

for layer in tqdm.tqdm(range(num_layers)): 
    for head in tqdm.tqdm(range(num_heads)):
        translated_key = translate_vectors(point_position_dict[layer][head]['key'], point_position_dict[layer][head]['query'])
        vectors = np.stack(point_position_dict[layer][head]['query'] + [np.array(row) for row in translated_key])
        
        pca_2d = get_pca_embeddings(vectors, n_components=2)
        pca_3d = get_pca_embeddings(vectors, n_components=3)

        tsne_2d = get_tsne_embeddings(vectors, n_components=2)
        tsne_3d = get_tsne_embeddings(vectors, n_components=3)

        umap_2d = get_umap_embeddings(vectors, n_components=2)
        umap_3d = get_umap_embeddings(vectors, n_components=3)
        
        for token in range(umap_2d.shape[0]):
            point_position_dict[layer][head]['tokens'].append({
                "tsne_x" : tsne_2d[token][0],
                "tsne_y" : tsne_2d[token][1],
                
                "tsne_x_3d" : tsne_3d[token][0],
                "tsne_y_3d" : tsne_3d[token][1],                
                "tsne_z_3d" : tsne_3d[token][2],                
                
                "umap_x" : umap_2d[token][0],
                "umap_y" : umap_2d[token][1],          

                "umap_x_3d" : umap_3d[token][0],
                "umap_y_3d" : umap_3d[token][1],                
                "umap_z_3d" : umap_3d[token][2],                         
                
                "pca_x" : pca_2d[token][0],
                "pca_y" : pca_2d[token][1],
                
                "pca_x_3d" : pca_3d[token][0],
                "pca_y_3d" : pca_3d[token][1],                                    
                "pca_z_3d" : pca_3d[token][2],    
                
                "norm" : calculate_norm(vectors[token])                              
            })

100%|██████████| 12/12 [01:39<00:00,  8.29s/it]
100%|██████████| 12/12 [01:42<00:00,  8.53s/it]
100%|██████████| 12/12 [01:41<00:00,  8.48s/it]
100%|██████████| 12/12 [01:41<00:00,  8.44s/it]
100%|██████████| 12/12 [01:43<00:00,  8.63s/it]
100%|██████████| 12/12 [01:37<00:00,  8.11s/it]
100%|██████████| 12/12 [01:35<00:00,  7.96s/it]
100%|██████████| 12/12 [01:34<00:00,  7.86s/it]
100%|██████████| 12/12 [01:34<00:00,  7.87s/it]
100%|██████████| 12/12 [01:35<00:00,  7.99s/it]
100%|██████████| 12/12 [01:38<00:00,  8.21s/it]
100%|██████████| 12/12 [01:41<00:00,  8.48s/it]
100%|██████████| 12/12 [19:46<00:00, 98.85s/it]


In [6]:
for layer in tqdm.tqdm(range(num_layers)): 
    for head in tqdm.tqdm(range(num_heads)):
        avg =  np.average(np.stack([np.stack([token['attention'] for token in attention_dict[layer][head]['tokens'][sentence_starts[i]:sentence_stops[i]]]) for i in range(len(dataset))]),axis=0)
        agg_attn_dict[f"{layer}_{head}"] = [{"attention" : avg[i].tolist()} for i in range(avg.shape[0])]


100%|██████████| 12/12 [00:00<00:00, 260.69it/s]
100%|██████████| 12/12 [00:00<00:00, 174.76it/s]
100%|██████████| 12/12 [00:00<00:00, 146.11it/s]
100%|██████████| 12/12 [00:00<00:00, 173.26it/s]
100%|██████████| 12/12 [00:00<00:00, 174.34it/s]
100%|██████████| 12/12 [00:00<00:00, 157.38it/s]
100%|██████████| 12/12 [00:00<00:00, 163.82it/s]
100%|██████████| 12/12 [00:00<00:00, 191.62it/s]
100%|██████████| 12/12 [00:00<00:00, 170.55it/s]
100%|██████████| 12/12 [00:00<00:00, 171.89it/s]
100%|██████████| 12/12 [00:00<00:00, 157.59it/s]
100%|██████████| 12/12 [00:00<00:00, 158.60it/s]
100%|██████████| 12/12 [00:00<00:00, 13.87it/s]


In [7]:
for layer in tqdm.tqdm(range(num_layers)): 
    for head in tqdm.tqdm(range(num_heads)):
        del point_position_dict[layer][head]["query"]
        del point_position_dict[layer][head]["key"]

100%|██████████| 12/12 [00:00<00:00, 539.48it/s]
100%|██████████| 12/12 [00:00<00:00, 629.93it/s]
100%|██████████| 12/12 [00:00<00:00, 609.91it/s]
100%|██████████| 12/12 [00:00<00:00, 623.66it/s]
100%|██████████| 12/12 [00:00<00:00, 604.93it/s]
100%|██████████| 12/12 [00:00<00:00, 442.57it/s]
100%|██████████| 12/12 [00:00<00:00, 662.04it/s]
100%|██████████| 12/12 [00:00<00:00, 530.47it/s]
100%|██████████| 12/12 [00:00<00:00, 423.45it/s]
100%|██████████| 12/12 [00:00<00:00, 425.98it/s]
100%|██████████| 12/12 [00:00<00:00, 544.43it/s]
100%|██████████| 12/12 [00:00<00:00, 611.00it/s]
100%|██████████| 12/12 [00:00<00:00, 32.05it/s]


In [8]:
for layer in tqdm.tqdm(range(num_layers)): 
    for head in tqdm.tqdm(range(num_heads)):
        for i in range(len(point_position_dict[layer][head]['tokens'])):
            for data_feature in ["tsne_x", "tsne_y", "umap_x", "umap_y", "norm", "tsne_x_3d", "tsne_y_3d", "tsne_z_3d", "umap_x_3d", "umap_y_3d", "umap_z_3d", "pca_x", "pca_y", "pca_x_3d", "pca_y_3d", "pca_z_3d"]:
                point_position_dict[layer][head]['tokens'][i][data_feature] = float(point_position_dict[layer][head]['tokens'][i][data_feature])

100%|██████████| 12/12 [00:00<00:00, 131.74it/s]
100%|██████████| 12/12 [00:00<00:00, 126.65it/s]
100%|██████████| 12/12 [00:00<00:00, 125.97it/s]
100%|██████████| 12/12 [00:00<00:00, 130.81it/s]
100%|██████████| 12/12 [00:00<00:00, 135.13it/s]
100%|██████████| 12/12 [00:00<00:00, 138.67it/s]
100%|██████████| 12/12 [00:00<00:00, 146.87it/s]
100%|██████████| 12/12 [00:00<00:00, 138.59it/s]
100%|██████████| 12/12 [00:00<00:00, 149.73it/s]
100%|██████████| 12/12 [00:00<00:00, 148.83it/s]
100%|██████████| 12/12 [00:00<00:00, 143.30it/s]
100%|██████████| 12/12 [00:00<00:00, 141.16it/s]
100%|██████████| 12/12 [00:01<00:00, 11.18it/s]


In [9]:
import copy
tokens = copy.deepcopy(tokens_dict["tokens"])
pls = list()
for token in tqdm.tqdm(tokens):
    token["type"] = "key"
    tokens_dict["tokens"].append(token)
    

100%|██████████| 1184/1184 [00:00<00:00, 4414271.94it/s]


In [10]:
for layer in tqdm.tqdm(range(num_layers)): 
    for head in tqdm.tqdm(range(num_heads)):
        for att in attention_dict[layer][head]['tokens']:
            att['attention'] = att['attention'].tolist() 
        

100%|██████████| 12/12 [00:00<00:00, 57.77it/s]
100%|██████████| 12/12 [00:00<00:00, 55.87it/s]
100%|██████████| 12/12 [00:00<00:00, 88.04it/s]
100%|██████████| 12/12 [00:00<00:00, 135.33it/s]
100%|██████████| 12/12 [00:00<00:00, 142.98it/s]
100%|██████████| 12/12 [00:00<00:00, 144.19it/s]
100%|██████████| 12/12 [00:00<00:00, 150.71it/s]
100%|██████████| 12/12 [00:00<00:00, 145.44it/s]
100%|██████████| 12/12 [00:00<00:00, 148.05it/s]
100%|██████████| 12/12 [00:00<00:00, 152.99it/s]
100%|██████████| 12/12 [00:00<00:00, 150.12it/s]
100%|██████████| 12/12 [00:00<00:00, 163.29it/s]
100%|██████████| 12/12 [00:01<00:00,  9.10it/s]


## Writing to files

In [11]:
import json

In [12]:
with open("/home/cameron/repos/attention-viz-bio/web/data/DNABERT/agg_attn.json", "w") as fp:
    json.dump(agg_attn_dict , fp) 

In [13]:
with open("/home/cameron/repos/attention-viz-bio/web/data/DNABERT/tokens.json", "w") as fp:
    json.dump(tokens_dict , fp) 

In [14]:
for layer in tqdm.tqdm(range(num_layers)): 
    for head in tqdm.tqdm(range(num_heads)):
        with open(f"/home/cameron/repos/attention-viz-bio/web/data/DNABERT/attention/layer{layer}_head{head}.json", "w") as fp:
            json.dump(attention_dict[layer][head] , fp) 
        with open(f"/home/cameron/repos/attention-viz-bio/web/data/DNABERT/byLayerHead/layer{layer}_head{head}.json", "w") as fp:
            json.dump(point_position_dict[layer][head] , fp)             
        

100%|██████████| 12/12 [00:03<00:00,  3.94it/s]
100%|██████████| 12/12 [00:03<00:00,  3.96it/s]
100%|██████████| 12/12 [00:03<00:00,  3.92it/s]
100%|██████████| 12/12 [00:03<00:00,  3.93it/s]
100%|██████████| 12/12 [00:02<00:00,  4.33it/s]
100%|██████████| 12/12 [00:02<00:00,  4.37it/s]
100%|██████████| 12/12 [00:02<00:00,  4.29it/s]
100%|██████████| 12/12 [00:03<00:00,  3.62it/s]
100%|██████████| 12/12 [00:03<00:00,  3.13it/s]
100%|██████████| 12/12 [00:03<00:00,  3.78it/s]
100%|██████████| 12/12 [00:03<00:00,  3.84it/s]
100%|██████████| 12/12 [00:02<00:00,  4.01it/s]
100%|██████████| 12/12 [00:36<00:00,  3.08s/it]


: 