In [1]:
from tqdm import tqdm
import random
import os
import argparse
import numpy as np
import pandas as pd
from Bio import SeqIO
import re
import esm
import torch
import json

fastafile = "/biodata/franco/datasets/disprot/disprot_OK_fullset_2023_12.fasta"

start = 0
end = 1
outdir = "ESM2_dev_output"
output_attentions = True


def sequence_masker(seq, i, j):
    masked_sequence_list = seq.split()
    if j<=i:
        print(f"index j={j} must be greater than i={i}")
        raise
    for x in range(i, j):
        if j > len(seq):
            break
        masked_sequence_list[x] = f"<mask>"
    return " ".join(masked_sequence_list)


#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

print("Parameters:")
print(f"\t- Output attentions: {output_attentions}")

if not os.path.exists(outdir):
    os.makedirs(outdir)
if not os.path.exists(outdir+"/logits"):
    os.makedirs(outdir+"/logits")
if output_attentions:
    if not os.path.exists(outdir+"/attentions"):
        os.makedirs(outdir+"/attentions")

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"\t- Device: {device}")
# Load model and tokenizer
model, alphabet = esm.pretrained.esm2_t36_3B_UR50D()

if device.type == 'cuda':
    model = model.eval().cuda()  # disables dropout for deterministic results
    torch.cuda.empty_cache()
else:
    model = model.eval()

# Define the loss function obj
loss = torch.nn.CrossEntropyLoss()

# Read fasta sequences
counter = 0
batch_converter = alphabet.get_batch_converter()


Parameters:
	- Output attentions: True
	- Device: cuda:0


In [2]:
dir(esm.pretrained)

['ESM2',
 'Namespace',
 'Path',
 '__builtins__',
 '__cached__',
 '__doc__',
 '__file__',
 '__loader__',
 '__name__',
 '__package__',
 '__spec__',
 '_download_model_and_regression_data',
 '_has_regression_weights',
 '_load_model_and_alphabet_core_v1',
 '_load_model_and_alphabet_core_v2',
 'esm',
 'esm1_t12_85M_UR50S',
 'esm1_t34_670M_UR100',
 'esm1_t34_670M_UR50D',
 'esm1_t34_670M_UR50S',
 'esm1_t6_43M_UR50S',
 'esm1b_t33_650M_UR50S',
 'esm1v_t33_650M_UR90S',
 'esm1v_t33_650M_UR90S_1',
 'esm1v_t33_650M_UR90S_2',
 'esm1v_t33_650M_UR90S_3',
 'esm1v_t33_650M_UR90S_4',
 'esm1v_t33_650M_UR90S_5',
 'esm2_t12_35M_UR50D',
 'esm2_t30_150M_UR50D',
 'esm2_t33_650M_UR50D',
 'esm2_t36_3B_UR50D',
 'esm2_t48_15B_UR50D',
 'esm2_t6_8M_UR50D',
 'esm_if1_gvp4_t16_142M_UR50',
 'esm_msa1_t12_100M_UR50S',
 'esm_msa1b_t12_100M_UR50S',
 'esmfold_structure_module_only_150M',
 'esmfold_structure_module_only_150M_270K',
 'esmfold_structure_module_only_15B',
 'esmfold_structure_module_only_35M',
 'esmfold_structur

In [3]:
sequences = []
lengths = []
for record in SeqIO.parse(fastafile, "fasta"):
    sequences.append(str(record.seq))
    lengths.append(len(str(record.seq)))


In [14]:
ix = lengths.index(np.max(lengths))
ix = lengths.index(300)
longest_seq = sequences[ix]
start = ix
end = start+1
print(ix)
print(len(longest_seq))

1521
300


In [15]:
counter = 0
if output_attentions:
    need_head_weights = True
else:
    need_head_weights = False
for record in SeqIO.parse(fastafile, "fasta"):
    #sequences.append(record)
    if counter >= start and counter < end:
        uniprot_id = record.id
        aa_sequence = str(record.seq)

        pred_dict = dict()
        mask_sizes = [1]
        print(f" Processing {uniprot_id}, protnum {counter}, len: {len(aa_sequence)}")
        if uniprot_id not in pred_dict:
            pred_dict[uniprot_id] = dict()
        
        target_seq = aa_sequence
        input_seq = [" ".join(list(re.sub(r"[UZOB]", "X", target_seq)))]
        batch_labels, batch_strs, batch_tokens = batch_converter([(uniprot_id, input_seq[0])])
        if not os.path.exists(f"{outdir}/{uniprot_id}.json"):
            for mask_size in mask_sizes:
                print(f"#### Mask size: {mask_size} ####")
                    
                loss_sequence = list()
                match_sequence = list()
                logits_sequence = list()
                meanatt_sequence = list()
                maxatt_sequence = list()
                for i in tqdm(range(len(target_seq)-mask_size+1)):

                    masked_seq = sequence_masker(input_seq[0], i, i+mask_size)
                    mbatch_labels, mbatch_strs, mbatch_tokens = batch_converter([(uniprot_id, masked_seq)])
                    with torch.no_grad():
                        results = model(mbatch_tokens.to(device), repr_layers=[36], return_contacts=False, need_head_weights=need_head_weights)
                    cpulogits = results['logits'][0].cpu()
                    loss_val = float(loss(cpulogits[1:-1,], mbatch_tokens[0][1:-1]).numpy())  ## recently corrected to discard cls and eos tokens
                    loss_sequence.append(loss_val)
                    logits_sequence.append(cpulogits[1:-1,].numpy().tolist())
                    #fastpred = tokenizer.decode(torch.tensor(cpulogits[:,:-1,:].numpy().argmax(-1)[0]), skip_special_tokens=False).replace("<"," <").replace(">","> ")
                    fastpred = " ".join([alphabet.get_tok(t) for t in results['logits'][0].cpu().numpy().argmax(-1)][1:-1]) ## delete first and last tokens
                    if input_seq[0] == fastpred:
                        match_sequence.append(True)
                    else:
                        pred_arr = fastpred.split()
                        seq_arr  = input_seq[0].split()
                        if len(pred_arr) == len(seq_arr):
                            local_match_sequence = list()
                            for j in range(len(pred_arr)):
                                if pred_arr[j] != seq_arr[j]:
                                    local_match_sequence.append((j,pred_arr[j], seq_arr[j]))
                            match_sequence.append(local_match_sequence)
                        else:
                            print(f"{i} - Mismatch length error")
                            match_sequence.append(False)
                            loss_sequence
                    if output_attentions:
                        att_cpu = results['attentions'].squeeze().cpu()
                        fullmax = torch.amax(att_cpu, dim=(0,1))
                        fullmean = torch.mean(att_cpu, dim=(0,1))
                        meanatt_sequence.append(fullmean)
                        maxatt_sequence.append(fullmean)
                pred_dict[uniprot_id][f"aamask_{mask_size}"] = dict()
                pred_dict[uniprot_id][f"aamask_{mask_size}"]["match"] = match_sequence
                pred_dict[uniprot_id][f"aamask_{mask_size}"]["loss"] = loss_sequence
                ### This takes too much time and space, we will save the logits somewhere else
                #pred_dict[uniprot_id][f"aamask_{mask_size}"]["logits"] = logits_sequence
                np.save(f"{outdir}/logits/{uniprot_id}_logits_sequence.npy",  np.array(logits_sequence, dtype=object), allow_pickle=True)
                if output_attentions:
                    np.save(f"{outdir}/attentions/{uniprot_id}_max_attentions_sequence.npy",  np.array(meanatt_sequence, dtype=object), allow_pickle=True)
                    np.save(f"{outdir}/attentions/{uniprot_id}_mean_attentions_sequence.npy",  np.array(maxatt_sequence, dtype=object), allow_pickle=True)
            with open(f"{outdir}/{uniprot_id}.json", 'w') as outfmt:
                json.dump(pred_dict, outfmt)
        else:
            print(f"Skipping {uniprot_id} masks")
        if not os.path.exists(f"{outdir}/logits/{uniprot_id}_logits.pt"):
            ## Output the complete attention matrices with a full pass, no mask
            with torch.no_grad():
                #emb = fullmodel(input_ids=true_tok, labels=true_tok, output_attentions=output_attentions, attention_mask=attention_mask, decoder_attention_mask=attention_mask)
                if device.type == 'cuda':
                    # if len(aa_sequence) > 600:
                    #     results = model(batch_tokens.to(device), repr_layers=[36], return_contacts=False)
                    # else:
                    results = model(batch_tokens.to(device), repr_layers=[36], return_contacts=False, need_head_weights=need_head_weights)
                    # torch.save(results['contacts'][0].cpu(), f"{outdir}/logits/{uniprot_id}_contacts.pt")
                else:
                    results = model(batch_tokens.to(device), repr_layers=[36], return_contacts=True, need_head_weights=need_head_weights)
                    torch.save(results['contacts'][0].cpu(), f"{outdir}/logits/{uniprot_id}_contacts.pt")
            cpulogits = results['logits'][0].cpu()[1:-1,]
            torch.save(cpulogits, f"{outdir}/logits/{uniprot_id}_logits.pt")
            if output_attentions:
                att_cpu = results['attentions'].squeeze().cpu()
                fullmax = torch.amax(att_cpu, dim=(0,1))
                fullmean = torch.mean(att_cpu, dim=(0,1))
                torch.save(fullmax, f"{outdir}/attentions/{uniprot_id}_original_max_attentions.pt")
                torch.save(fullmean, f"{outdir}/attentions/{uniprot_id}_original_mean_attentions.pt")
        else:
            print(f"Skipping {uniprot_id} attentions matrices and logits")
    counter += 1
    if device == 'cuda:0':
        torch.cuda.empty_cache()


 Processing Q8GY88, protnum 1521, len: 300
#### Mask size: 1 ####


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [02:07<00:00,  2.35it/s]


In [12]:
np.sum([True for l in lengths if l<700])

2164

In [10]:
lengths

[529,
 170,
 107,
 105,
 318,
 256,
 122,
 295,
 165,
 76,
 164,
 316,
 198,
 827,
 691,
 328,
 98,
 1018,
 495,
 97,
 118,
 211,
 777,
 293,
 506,
 638,
 424,
 677,
 107,
 70,
 239,
 190,
 169,
 206,
 423,
 350,
 416,
 756,
 62,
 336,
 616,
 462,
 205,
 279,
 147,
 206,
 116,
 140,
 372,
 313,
 595,
 765,
 380,
 530,
 952,
 281,
 160,
 551,
 393,
 490,
 663,
 315,
 419,
 267,
 521,
 498,
 262,
 194,
 423,
 443,
 500,
 197,
 286,
 376,
 421,
 247,
 440,
 155,
 117,
 49,
 449,
 310,
 771,
 458,
 651,
 549,
 72,
 395,
 507,
 632,
 833,
 219,
 65,
 177,
 85,
 209,
 276,
 92,
 512,
 549,
 336,
 189,
 288,
 363,
 685,
 884,
 328,
 88,
 265,
 106,
 210,
 261,
 445,
 648,
 331,
 619,
 720,
 961,
 68,
 222,
 362,
 427,
 143,
 254,
 332,
 206,
 599,
 190,
 165,
 90,
 708,
 180,
 164,
 108,
 65,
 314,
 200,
 126,
 346,
 720,
 433,
 915,
 229,
 408,
 540,
 743,
 741,
 304,
 639,
 737,
 726,
 84,
 273,
 380,
 453,
 161,
 258,
 575,
 268,
 309,
 373,
 196,
 439,
 826,
 254,
 350,
 69,
 382,
 163,
 

In [8]:
att_cpu = results['attentions'].squeeze().cpu()
fullmax = torch.amax(att_cpu, dim=(0,1))
fullmean = torch.mean(att_cpu, dim=(0,1))
plt.imshow(np.array(fullmax))
plt.colorbar()
plt.show()

plt.imshow(np.array(fullmean))
plt.colorbar()
plt.show()

NameError: name 'plt' is not defined