In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Later cells import local modules (utils, definitions, ...) and open relative
# "data/..." paths, so the working directory is switched to the project root.
# The root can be overridden with the CALIBVEP_ROOT environment variable; the
# original machine-specific location is kept as the default.
PROJECT_ROOT = os.environ.get("CALIBVEP_ROOT", r"C:\Users\sapir\out_of_onedrive\dina\CalibVEP")
os.chdir(PROJECT_ROOT)

In [3]:
import utils
import definitions as d
import torch
from data_classes import METHODS_TO_ESM, MutationVariantSet
from mutation_record import MutationRecord
from utils import is_disordered
from load_cluster_lookup import get_cluster_sizes_of_sequence


  from .autonotebook import tqdm as notebook_tqdm


In [4]:



# def get_trunctad_logits_tiling(protein_seq, seq_name, model, alphabet, rep_layer):
    # """Get logits for long sequences using tiling approach."""
# 
    # return process_long_sequence_chunking_with_overlapping_regions(alphabet, seq_name, protein_seq, model, rep_layer)


def process_long_sequence_chunking_with_overlapping_regions(alphabet, seq_name, protein_seq, model,
                                                            rep_layer):
    """
    Run the model over a sequence in overlapping chunks and stitch the logits.

    The sequence string is cut into windows of at most ``d.ESM_MAX_LENGTH``
    characters. Window boundaries are moved so that an occurrence of
    ``d.MASK_TOKEN`` is never cut in half. From every window only a "valid"
    slice is kept (the first / last ``d.OVERLAP_SIZE_LONG_PROTEIN`` positions
    are dropped except at the very start / end of the sequence), and the kept
    slices are concatenated in order.

    Args:
        alphabet: The ESM alphabet, forwarded to ``get_batch_token``.
        seq_name: Name of the sequence, passed to the tokenizer with each chunk.
        protein_seq: The full protein sequence (may contain ``d.MASK_TOKEN``).
        model: The ESM model, forwarded to ``get_trunctad_logits``.
        rep_layer: Representation layer(s), forwarded to ``get_trunctad_logits``.

    Returns:
        The valid-region logits of all chunks concatenated along dim 0, or
        ``None`` when no chunk was produced (e.g. an empty sequence).

    Raises:
        ValueError: If the sliding window cannot move forward (this happens
            when ``d.ESM_MAX_LENGTH <= 2 * d.OVERLAP_SIZE_LONG_PROTEIN``);
            previously this situation looped forever.
    """

    seq_length = len(protein_seq)
    chunk_size = d.ESM_MAX_LENGTH

    # Find all mask positions in the sequence as (start, end) character offsets.
    mask_positions = []
    start_pos = 0
    while True:
        mask_pos = protein_seq.find(d.MASK_TOKEN, start_pos)
        if mask_pos == -1:
            break
        mask_positions.append((mask_pos, mask_pos + len(d.MASK_TOKEN)))
        start_pos = mask_pos + len(d.MASK_TOKEN)

    def would_split_mask(chunk_start, chunk_end, mask_positions):
        """Check if the chunk boundaries would split any mask token."""
        for mask_start, mask_end in mask_positions:
            # Check if mask is partially inside the chunk (i.e., would be split)
            if (chunk_start < mask_end and chunk_end > mask_start and
                    not (chunk_start <= mask_start and chunk_end >= mask_end)):
                return True, mask_start, mask_end
        return False, None, None

    # NOTE(review): all offsets below are character positions in protein_seq,
    # while the logits are later sliced with the same numbers. A mask token
    # spans len(d.MASK_TOKEN) characters; if the tokenizer turns it into a
    # single token the two coordinate systems diverge for multi-chunk inputs
    # that contain a mask - TODO confirm.
    chunk_positions = []
    start_pos = 0
    end_pos = 0

    while start_pos < seq_length and end_pos < seq_length:
        end_pos = min(start_pos + chunk_size, seq_length)

        # Check if this chunk would split a mask
        would_split, mask_start, mask_end = would_split_mask(start_pos, end_pos, mask_positions)

        if would_split:
            # Adjust chunk_end to not split the mask
            if mask_start >= start_pos:
                # Mask starts within or after chunk start - end chunk before mask
                end_pos = mask_start
            else:
                # Mask starts before chunk start but ends within chunk - this shouldn't happen with proper chunking
                # but handle it by including the entire mask
                end_pos = mask_end

        # If the adjusted chunk is too small (less than overlap size),
        # we need to include the mask in this chunk
        if end_pos - start_pos < d.OVERLAP_SIZE_LONG_PROTEIN and would_split:
            # Find the mask that's causing the issue and include it
            for mask_start, mask_end in mask_positions:
                if mask_start < start_pos + chunk_size and mask_end > end_pos:
                    end_pos = min(mask_end, seq_length)
                    break

        # Skip empty chunks
        if end_pos <= start_pos:
            start_pos += 1
            continue

        # Define the valid region (excluding the overlap margins), expressed
        # relative to the start of the chunk.
        valid_start = d.OVERLAP_SIZE_LONG_PROTEIN if start_pos > 0 else 0
        valid_end = (end_pos - start_pos) - d.OVERLAP_SIZE_LONG_PROTEIN if end_pos < seq_length else (end_pos - start_pos)

        # Ensure valid_end doesn't exceed chunk size
        valid_end = min(valid_end, end_pos - start_pos)

        # Ensure we have a valid region
        if valid_end <= valid_start:
            valid_end = end_pos - start_pos

        chunk_positions.append({
            'chunk_start': start_pos,
            'chunk_end': end_pos,
            'valid_start': valid_start,
            'valid_end': valid_end
        })

        # Slide the window, considering the overlap
        if end_pos >= seq_length:
            break

        # Calculate next start position.
        # NOTE(review): the step is derived from chunk_size even when end_pos
        # was pulled back to avoid a mask, so consecutive valid regions may not
        # be contiguous in that case - verify.
        next_start = start_pos + chunk_size - 2 * d.OVERLAP_SIZE_LONG_PROTEIN

        # Make sure we don't start in the middle of a mask token
        for mask_start, mask_end in mask_positions:
            if mask_start < next_start < mask_end:
                # Adjust start to after the mask
                next_start = mask_end
                break

        # Without forward progress the same chunk would be planned forever.
        if next_start <= start_pos:
            raise ValueError(
                "Chunk window cannot advance: d.ESM_MAX_LENGTH must be larger than "
                "2 * d.OVERLAP_SIZE_LONG_PROTEIN"
            )

        start_pos = next_start

    # Run the model on every planned chunk and keep only its valid region.
    valid_parts = []
    for pos in chunk_positions:
        chunk_seq = protein_seq[pos['chunk_start']:pos['chunk_end']]

        batch_tokens = get_batch_token(alphabet, seq_name, chunk_seq)
        chunk_logits = get_trunctad_logits(True, batch_tokens, model, rep_layer)
        chunk_logits = chunk_logits.squeeze(0)

        valid_parts.append(chunk_logits[pos['valid_start']:pos['valid_end']])

    if not valid_parts:
        return None
    if len(valid_parts) == 1:
        # Single chunk: hand back the slice itself, no concatenation needed.
        return valid_parts[0]
    # Concatenate once instead of re-allocating the result for every chunk.
    return torch.cat(valid_parts, dim=0)


def get_batch_token(alphabet, example_name, sequence):
    """Tokenize one labelled sequence and return its tokens on ``d.DEVICE``.

    Args:
        alphabet: Object exposing ``get_batch_converter()``.
        example_name: Label paired with the sequence in the converter input.
        sequence: Sequence string to tokenize.

    Returns:
        The third value returned by the batch converter, moved to ``d.DEVICE``.
    """
    batch_converter = alphabet.get_batch_converter()
    # The converter receives a list of (label, sequence) pairs; only the third
    # of its three return values is used here.
    converted = batch_converter([(example_name, sequence)])
    _, _, tokens = converted
    return tokens.to(d.DEVICE)


def get_trunctad_logits(aa_only, batch_tokens, model, rep_layers):
    """Run the model and return a trimmed view of the first item's logits.

    Args:
        aa_only: When truthy, keep only vocabulary columns 4..23 and every
            position from index 1 onward; otherwise keep all columns and drop
            both the first and the last position.
        batch_tokens: Token batch passed to the model.
        model: Callable returning a mapping that contains ``'logits'``.
        rep_layers: Forwarded to the model as ``repr_layers``.

    Returns:
        The selected logits stacked into a leading dimension of size 1 and
        moved to ``d.DEVICE``.
    """
    logits = model(batch_tokens, repr_layers=rep_layers, return_contacts=False)['logits']

    # NOTE(review): the two branches trim the sequence axis differently
    # (1: versus 1:-1); an earlier variant used 1:-1 for both. Presumably this
    # depends on whether the tokenizer appends an end token - confirm.
    if aa_only:
        selected = logits[0, 1:, 4:24]
    else:
        selected = logits[0, 1:-1, :]

    return torch.stack([selected]).to(d.DEVICE)


def get_mutant_dest_and_seq(method_mutant, sequence, aa_mut):
    """Build the input sequence for one scoring method.

    Args:
        method_mutant: A ``METHODS_TO_ESM`` member selecting the variant.
        sequence: The original sequence string.
        aa_mut: Mutation object; ``mut_idx`` is the index of the replaced
            character and ``change_aa`` the substituted text.

    Returns:
        The sequence with position ``mut_idx`` replaced by ``change_aa``
        (MUTANTE) or by ``"<mask>"`` (MASKED), or the unchanged sequence (WT).

    Raises:
        ValueError: If ``method_mutant`` is none of the three known methods.
    """
    if method_mutant == METHODS_TO_ESM.MUTANTE:
        replacement = aa_mut.change_aa
    elif method_mutant == METHODS_TO_ESM.MASKED:
        replacement = "<mask>"
    elif method_mutant == METHODS_TO_ESM.WT:
        # Wild type: nothing to substitute.
        return sequence
    else:
        raise ValueError(f"Unknown method_mutant: {method_mutant}")

    position = aa_mut.mut_idx
    return sequence[:position] + replacement + sequence[position + 1:]


def run_esm_without_poem(model, alphabet, rep_layers, protein_seq, mutation_desc):
    """Score one mutation with every method in ``METHODS_TO_ESM.get_methods()``.

    The model is put in eval mode and all forward passes run under
    ``torch.no_grad()``.

    Args:
        model: The model to evaluate.
        alphabet: Alphabet used for tokenization.
        rep_layers: Representation layers forwarded to the chunked inference.
        protein_seq: The original protein sequence.
        mutation_desc: Mutation description parsed by
            ``utils.process_mutation_name``.

    Returns:
        A ``MutationVariantSet`` holding one ``MutationRecord`` per method.
    """
    variant_set = MutationVariantSet()

    model.eval()
    with torch.no_grad():
        parsed_mutation = utils.process_mutation_name(mutation_desc)

        for method in METHODS_TO_ESM.get_methods():
            variant_seq = get_mutant_dest_and_seq(method, protein_seq, parsed_mutation)
            logits = process_long_sequence_chunking_with_overlapping_regions(
                alphabet=alphabet,
                seq_name=f"{method.value}_{mutation_desc}",
                protein_seq=variant_seq,
                model=model,
                rep_layer=rep_layers,
            )

            # The record always stores the original sequence, not the variant
            # that was fed to the model.
            record = MutationRecord(
                protein_seq=protein_seq,
                aa_mut=parsed_mutation,
                truncated_logits=logits,
            )
            variant_set.add_mutation_record(record, method)

    return variant_set


In [None]:
import pickle

import numpy as np

        
def classify_mutation_to_tree_path(homolog_count, is_disordered, sequence_length):
    """
    Map a mutation's characteristics to its tree-path key.

    Args:
        homolog_count: Homolog count; values up to and including 450 fall in
            the "homologs_0_to_450" branch.
        is_disordered: Truthy when the mutation lies in a disordered region.
        sequence_length: Sequence length; values up to and including 1022 fall
            in the "shorter_than_1022" branch.

    Returns:
        A key of the form "all/<homolog part>/<disorder part>/<length part>".
    """
    homolog_part = "homologs_0_to_450" if homolog_count <= 450 else "homologs_450_plus"
    disorder_part = "disordered" if is_disordered else "ordered"
    length_part = "shorter_than_1022" if sequence_length <= 1022 else "longer_than_1022"

    return "/".join(("all", homolog_part, disorder_part, length_part))


def get_pathogenic_percentage(tree_path_key: str, leaf_best_scores: dict, bins_dict: dict) -> tuple[float, float]:
    """Look up the pathogenic percentage of the bin that a leaf's score falls into.

    Args:
        tree_path_key (str): The key representing the tree path for the mutation.
        leaf_best_scores (dict): Maps each tree path key to a ``(bin_key, score)``
            pair.
        bins_dict (dict): ``bins_dict[tree_path_key][bin_key]`` must provide
            ``'bin_edges'`` (ordered bin boundaries) and ``'bin_stats'``
            (indexable by integer bin index, each entry holding ``'path_pct'``).

    Returns:
        tuple[float, float]: The score that was binned and the pathogenic
        percentage of its bin. Scores outside the edges are assigned to the
        first / last bin.

    Raises:
        KeyError: If ``tree_path_key`` or the selected bin key is missing.
    """
    bin_key, score_llr = leaf_best_scores[tree_path_key]

    # Get the bin information
    bin_info = bins_dict[tree_path_key][bin_key]
    bin_edges = np.array(bin_info['bin_edges'])
    bin_stats = bin_info['bin_stats']

    # np.digitize returns the 1-based interval index; shift it to 0-based.
    bin_index = int(np.digitize(score_llr, bin_edges)) - 1

    # Clamp scores that lie below the first or at/above the last edge.
    if bin_index < 0:
        bin_index = 0
    elif bin_index >= len(bin_edges) - 1:
        bin_index = len(bin_edges) - 2

    # bin_stats is indexed with a plain Python int.
    pathogenic_percentage = bin_stats[bin_index]['path_pct']

    return score_llr, pathogenic_percentage

In [6]:
# import pandas as pd
# df = pd.read_parquet(r"c:\Users\sapir\out_of_onedrive\dina\entropy-missense-prediction\resources\2_sort_runnings\preprocessed_P53_HUMAN.parquet")

In [7]:
# Protein sequence to analyse; passed as protein_seq to run_esm_without_poem.
seq = "MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD"
# Mutation descriptor; parsed later with utils.process_mutation_name.
mutation = "M1A"
# Checkpoint name handed to utils.esm_setup. The commented-out line keeps an
# alternative checkpoint name for quick switching.
# esm_model_name = "esm1b_t33_650M_UR50S"
esm_model_name = "esm1_t6_43M_UR50S"

In [8]:
# Load the model and its alphabet for the selected checkpoint via the project
# helper, placing it on d.DEVICE.
orig_model, orig_alphabet = utils.esm_setup(esm_model_name, device=d.DEVICE)

Using cache found in C:\Users\sapir/.cache\torch\hub\facebookresearch_esm_main


model loaded on cpu


In [9]:
# Score the mutation with every method; [33] is forwarded to the model as
# repr_layers.
# NOTE(review): the active checkpoint name contains "t6" while the commented
# alternative contains "t33" - confirm that [33] is intended for this model.
mutation_variant_set = run_esm_without_poem(orig_model, orig_alphabet, [33], seq, mutation)

In [10]:
# Path is relative to the current working directory.
bins_dict_path = "data/bins_dict.pkl"

# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# load this file from a trusted source.
with open(bins_dict_path, "rb") as f:
    bins_dict = pickle.load(f)


# For every tree-path key: [bin key used to index bins_dict[tree_path_key],
# score of this mutation taken from the matching record in mutation_variant_set].
# get_pathogenic_percentage unpacks each entry as (bin_key, score).
leaf_best_scores = {
    'all/homologs_0_to_450/disordered/longer_than_1022': ['wt_marginals_base_wt_score_126', mutation_variant_set.wt.nadav_base_wt_score.item()],
    'all/homologs_0_to_450/disordered/shorter_than_1022': ['masked_marginals_entropy_weighted_llr_score_142', mutation_variant_set.masked.entropy_weighted_llr_score.item()],
    'all/homologs_0_to_450/ordered/longer_than_1022': ['wt_marginals_base_wt_score_200', mutation_variant_set.wt.nadav_base_wt_score.item()],
    'all/homologs_0_to_450/ordered/shorter_than_1022': ['mutant_marginals_entropy_weighted_llr_score_200', mutation_variant_set.mutante.entropy_weighted_llr_score.item()],
    'all/homologs_450_plus/disordered/longer_than_1022': ['wt_not_nadav_marginals_base_wt_score_203', mutation_variant_set.wt.llr_base_score.item()],
    'all/homologs_450_plus/disordered/shorter_than_1022': ['wt_not_nadav_marginals_base_wt_score_77', mutation_variant_set.wt.llr_base_score.item()],
    'all/homologs_450_plus/ordered/longer_than_1022': ['mutant_marginals_entropy_weighted_llr_score_200', mutation_variant_set.mutante.entropy_weighted_llr_score.item()],
    'all/homologs_450_plus/ordered/shorter_than_1022': ['masked_marginals_entropy_weighted_llr_score_200', mutation_variant_set.masked.entropy_weighted_llr_score.item()],
}

In [None]:
# NOTE(review): data_path is defined here but not passed to the lookup, which
# receives the literals 'data' and 'parquet' instead - confirm which is intended.
data_path = 'data/new_clusters_sequences.parquet'
# Cluster-size lookup for the sequence via the project helper.
homolog_count = get_cluster_sizes_of_sequence(seq, 'data', 'parquet')

In [None]:


# Parse the mutation and check whether its position lies in a disordered region.
aa_mut = utils.process_mutation_name(mutation)
is_disorder_region = is_disordered(seq, aa_mut.mut_idx)

# Classify using the homolog count computed by get_cluster_sizes_of_sequence.
# The previous code passed data_path (a file-path string) here, which cannot be
# compared against the numeric threshold inside classify_mutation_to_tree_path.
# Assumes homolog_count is a single number - TODO confirm the helper's return type.
tree_path_key = classify_mutation_to_tree_path(
    homolog_count=homolog_count,
    is_disordered=is_disorder_region,
    sequence_length=len(seq)
)

In [None]:
# Look up the score and pathogenic percentage for the selected tree leaf and
# print both together with the leaf key.
score_llr, pathogenic_percentage = get_pathogenic_percentage(tree_path_key, leaf_best_scores, bins_dict)
print(f"Tree Path Key: {tree_path_key}")
print(f"Score LLR: {score_llr}, Pathogenic Percentage: {pathogenic_percentage:.2f}%")

Tree Path Key: all/homologs_0_to_450/disordered/shorter_than_1022
Bin Index: 3, Pathogenic Percentage: 52.82%
