In [1]:
import os
import numpy as np
import pandas as pd
import torch 
import matplotlib.pyplot as plt
from datasets import Dataset
from transformers import AutoConfig, AutoTokenizer, DefaultDataCollator
from LAMAR.modeling_nucESM2 import EsmForMaskedLM
from safetensors.torch import load_file

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
model_max_length = 1026
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("tokenizer/single_nucleotide/", model_max_length=model_max_length)
config = AutoConfig.from_pretrained(
    "config/config_150M.json", 
    vocab_size=len(tokenizer), 
    pad_token_id=tokenizer.pad_token_id,
    mask_token_id=tokenizer.mask_token_id,
    token_dropout=False,
    positional_embedding_type="rotary",
    hidden_size=768,
    intermediate_size=3072,
    num_attention_heads=12,
    num_hidden_layers=12
    )
model = EsmForMaskedLM(config)
weights = load_file("/home/fr/fr_fr/fr_ml642/Thesis/LAMAR/weights")
weight_dict = {}

for k, v in weights.items():
    if k.startswith("esm.lm_head"):
        new_k = k.replace("esm", '', 1)
    elif k.startswith("lm_head"):
        new_k = k
    elif k.startswith("esm."):
        new_k = k
    else:
        if k.startswith("contact_head"):
            new_k = "esm." + k
        else:
            new_k = "esm." + k
    weight_dict[new_k] = v
model.load_state_dict(weight_dict)
model.to(device)
result = model.load_state_dict  (weight_dict, strict=False)
print("Missing keys:", result.missing_keys)
print("Unexpected keys:", result.unexpected_keys)

Missing keys: []
Unexpected keys: []


In [None]:
torch.cuda.is_available()

In [8]:
model.eval()

EsmForMaskedLM(
  (esm): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(10, 768, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 768, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0-11): 12 x EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise

In [10]:
## Visualization functions

import seaborn as sns
import matplotlib.pyplot as plt

def plot_map_with_seq(matrix, dna_sequence,  plot_size=10, vmax=5, tick_label_fontsize=8):

    fig, ax = plt.subplots(figsize=(plot_size, plot_size))
    
   
    sns.heatmap(matrix, cmap='coolwarm', vmax=vmax, ax=ax, 
                xticklabels=False, yticklabels=False)  
    ax.set_aspect('equal')

    tick_positions = np.arange(len(dna_sequence)) + 0.5 # Center the ticks

    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(list(dna_sequence), fontsize=tick_label_fontsize, rotation=0)
    ax.set_yticklabels(list(dna_sequence), fontsize=tick_label_fontsize)

    plt.show()
    
def plot_map(matrix, vmax=None, display_values=False, annot_size=8, fig_size=10):
   
    plt.figure(figsize=(fig_size, fig_size))

    ax = sns.heatmap(matrix, cmap="coolwarm", vmax=vmax, annot=display_values, 
                     fmt=".2f", annot_kws={"size": annot_size})

    ax.set_aspect('equal')

    plt.show()

In [24]:
#dependency map generation functions

nuc_table = {"A" : 0, "C":1, "G":2, "T":3}

def mutate_sequence(seq):
    
    seq = seq.upper()
    mutated_sequences = {'seq':[], 'mutation_pos':[], 'nuc':[], 'var_nt_idx':[]}
    mutated_sequences['seq'].append(seq)
    mutated_sequences['mutation_pos'].append(-1)
    mutated_sequences['nuc'].append('real sequence')
    mutated_sequences['var_nt_idx'].append(-1)


    mutate_until_position = len(seq)

    for i in range(mutate_until_position):
        for nuc in ['A', 'C', 'G', 'T']:
            if nuc != seq[i]:
                mutated_sequences['seq'].append(seq[:i] + nuc + seq[i+1:])
                mutated_sequences['mutation_pos'].append(i)
                mutated_sequences['nuc'].append(nuc)
                mutated_sequences['var_nt_idx'].append(nuc_table[nuc])

    mutations_df = pd.DataFrame(mutated_sequences)

    return mutations_df

def tok_func_species(seq):
    
    return tokenizer.encode_plus(
        seq, 
        return_tensors='pt', 
        return_attention_mask=True,
        return_tokentype_ids=False,
        add_special_tokens=False,
    )


def create_dataloader(dataset, batch_size=64):
    ds = Dataset.from_pandas(dataset[['seq']])
    # batched tokenization returns lists (no tensor objects) and is safer for multiprocessing
    def tok_batch(examples):
        out = tokenizer(
            examples['seq'],
            padding=False,
            add_special_tokens=False,
            return_attention_mask=True,
        )
        return {'input_ids': out['input_ids'], 'attention_mask': out['attention_mask']}
    # Use batched=True and start with num_proc=1; increase only if stable
    tok_ds = ds.map(tok_batch, batched=True, batch_size=64, num_proc=1)
    rem_tok_ds = tok_ds.remove_columns('seq')
    data_collator = DefaultDataCollator()
    data_loader = torch.utils.data.DataLoader(rem_tok_ds, batch_size=batch_size, num_workers=4, shuffle=False, collate_fn=data_collator)
    return data_loader

acgt_idxs = [tokenizer.get_vocab()[nuc] for nuc in ['A', 'C', 'G', 'T']]

def model_inference(model, data_loader):

    output_arrays = []
    for i, batch in enumerate(data_loader):
        # get some tokenized sequences (B, L_in)
        input_ids = batch['input_ids'].to(device).squeeze(1)  # (B, L_in)
        attention_mask = batch['attention_mask'].to(device).squeeze(1)  #
        # predict
        with torch.autocast(device):
            with torch.no_grad():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask).logits.cpu().to(torch.float32)  # (B, L_in, Vocab_size)
        output_probs = torch.nn.functional.softmax(outputs, dim=-1)[:,:,acgt_idxs] # B, L_seq, 4
        output_arrays.append(output_probs) 

    # rebuild to B, L_seq, 4
    snp_reconstruct = torch.concat(output_arrays, axis=0)

    return snp_reconstruct.to(torch.float32).numpy()

def compute_dependency_map(seq, epsilon=1e-10):

    dataset = mutate_sequence(seq) 
    data_loader = create_dataloader(dataset)
    snp_reconstruct = model_inference(model, data_loader)
    
     # for the logit add a small value epsilon and renormalize such that every prob in one position sums to 1
    snp_reconstruct = snp_reconstruct + epsilon
    snp_reconstruct = snp_reconstruct/snp_reconstruct.sum(axis=-1)[:,:, np.newaxis]

    seq_len = snp_reconstruct.shape[1]
    snp_effect = np.zeros((seq_len, seq_len,4, 4))
    reference_probs = snp_reconstruct[dataset[dataset['nuc'] == 'real sequence'].index[0]]

    
    snp_effect[dataset.iloc[1:]['mutation_pos'].values, : ,  dataset.iloc[1:]['var_nt_idx'].values,:] = np.log2(snp_reconstruct[1:]) - np.log2(1 - snp_reconstruct[1:]) \
        - np.log2(reference_probs) + np.log2(1-reference_probs)

    dep_map = np.max(np.abs(snp_effect), axis=(2,3))
    #zero main diagonal values
    dep_map[np.arange(dep_map.shape[0]), np.arange(dep_map.shape[0])] = 0

    return dep_map

In [None]:
chr1_14262_25035_weighted_1762_2274 = "AGCCTGCTGGGAGGGAAGTCACCTCCCCTCAAACGAGGAGCCCTGCGCTGGGGAGGCCGGACCTTTGGAGACTGTGTGTGGGGGCCTGGGCACTGACTTCTGCAACCACCTGAGCGCGGGCATCCTGTGTGCAGATACTCCCTGCTTCCTCTCTAGCCCCCACCCTGCAGAGCTGGACCCCTGAGCTAGCCATGCTCTGACAGTCTCAGTTGCACACACGAGCCAGCAGAGGGGTTTTGTGCCACTTCTGGATGCTAGGGTTACACTGGGAGACACAGCAGTGAAGCTGAAATGAAAAATGTGTTGCTGTAGTTTGTTATTAGACCCCTTCTTTCCATTGGTTTAATTAGGAATGGGGAACCCAGAGCCTCACTTGTTCAGGCTCCCTCTGCCCTAGAAGTGAGAAGTCCAGAGCTCTACAGTTTGAAAACCACTATTTTATGAACCAAGTAGAACAAGATATTTGAAATGGAAACTATTCAAAAAATTGAGAATTTCTGACCACTTAACAA"
len(chr1_14262_25035_weighted_1762_2274)

In [None]:
dep_map = compute_dependency_map(chr1_14262_25035_weighted_1762_2274)

Map:   0%|          | 0/1537 [00:00<?, ? examples/s]

Map: 100%|██████████| 1537/1537 [00:01<00:00, 1023.23 examples/s]

Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/fr/fr_fr/fr_ml642/.conda/envs/torch2.1.2/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/home/fr/fr_fr/fr_ml642/.conda/envs/torch2.1.2/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/home/fr/fr_fr/fr_ml642/.conda/envs/torch2.1.2/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/home/fr/fr_fr/fr_ml642/.conda/envs/torch2.1.2/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/home/fr/fr_fr/fr_ml642/.conda/envs/torch2.1.2/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "/home/fr/fr_fr/fr_ml64

In [None]:
plot_map(dep_map, vmax=5)    