In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from netam import framework, models
from netam.common import nt_mask_tensor_of, BASES
from netam.framework import SHMoofDataset

from epam import molevol, sequences
from epam.molevol import reshape_for_codons, build_mutation_matrices, codon_probs_of_mutation_matrices

import sys
sys.path.append("/Users/matsen/re/netam-experiments-1")
from shmex.shm_data import load_shmoof_dataframes, pcp_df_of_non_shmoof_nickname, dataset_dict

In [2]:
# Configuration: which dataset to use and which trained model ("crepe") to evaluate.
data_nickname = "shmoof"
# NOTE(review): `pcp_df` is rebound to `train_df` in the next cell, so this
# load looks redundant — confirm before removing.
pcp_df = pcp_df_of_non_shmoof_nickname(data_nickname)
# Path prefix of the trained model; the same prefix is reused later to locate
# the matching `.train_branch_lengths.csv` file.
crepe_path = "../train/trained_models/cnn_joi_lrg-shmoof_small-fixed-0"
crepe = framework.load_crepe(crepe_path)
model = crepe.model
# NOTE(review): `site_count` is not referenced again in the visible notebook.
site_count = 500

Loading /Users/matsen/data/shmoof_pcp_2023-11-30_MASKED.csv.gz


In [3]:
# Load the shmoof train/validation split, using the "small" validation set.
train_df, val_df = load_shmoof_dataframes(dataset_dict["shmoof"], val_nickname="small")
# Reset both frames to a fresh 0..n-1 index so the branch-length column read
# below lines up with the rows by position.
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)
# Branch lengths stored next to the trained model. Column assignment aligns on
# the index; assumes the CSV has one row per training PCP in the same order —
# TODO confirm (a mismatch would silently produce NaN branch lengths).
train_df["branch_length"] = pd.read_csv(crepe_path+".train_branch_lengths.csv")["branch_length"]
# NOTE: this binds a second name to the same DataFrame (no copy is made), so
# in-place column edits made through `pcp_df` also change `train_df`.
pcp_df = train_df

In [4]:
def trim_seqs_to_codon_boundary(seqs):
    """Cut each sequence down to a whole number of codons.

    Args:
    - seqs: An iterable of sliceable sequences (e.g. nucleotide strings).

    Returns:
    - list: One entry per input, with the trailing 1 or 2 items removed when
            the length is not a multiple of 3. Inputs shorter than 3 become
            empty.
    """
    trimmed = []
    for seq in seqs:
        # Largest multiple of 3 that fits inside the sequence.
        whole_codon_len = 3 * (len(seq) // 3)
        trimmed.append(seq[:whole_codon_len])
    return trimmed

# Take the first 1000 rows, then trim parent and child to codon boundaries.
# Slicing and copying *before* trimming means only the rows we keep are
# processed, and the column assignments land on an independent copy instead of
# writing through to `train_df` (which `pcp_df` was bound to above).
pcp_df = pcp_df[:1000].copy()

pcp_df["parent"] = trim_seqs_to_codon_boundary(pcp_df["parent"])
pcp_df["child"] = trim_seqs_to_codon_boundary(pcp_df["child"])

In [5]:
def codon_probs_of_parent_scaled_rates_and_sub_probs(parent_idxs, scaled_rates, sub_probs):
    """
    Compute codon probabilities for one parent sequence.

    The first half is adapted from `aaprobs_of_parent_scaled_rates_and_sub_probs`
    and the second half from `aaprob_of_mut_and_sub` (per the original notes;
    presumably both live in `epam.molevol` — confirm).

    Args:
    - parent_idxs (torch.Tensor): Index-encoded parent nucleotide sequence.
    - scaled_rates (torch.Tensor): Per-site rates already multiplied by the
                                   branch length.
    - sub_probs (torch.Tensor): Per-site substitution probabilities.
    All three are assumed to have a site dimension whose length is a multiple
    of 3 — TODO confirm against `reshape_for_codons`.

    Returns:
    - The result of `codon_probs_of_mutation_matrices`; elsewhere in this
      notebook it is indexed as one 4x4x4 block per codon.
    """
    # Per-site probability of at least one mutation given the scaled rate.
    site_mut_probs = 1.0 - torch.exp(-scaled_rates)

    # Regroup every per-site input by codon, in the same order as the
    # arguments expected by `build_mutation_matrices`.
    per_codon_inputs = [
        reshape_for_codons(per_site)
        for per_site in (parent_idxs, site_mut_probs, sub_probs)
    ]

    return codon_probs_of_mutation_matrices(build_mutation_matrices(*per_codon_inputs))

In [6]:
def num_differences(codon1, codon2):
    """Count the positions at which two codons differ.

    The codons are compared pairwise up to the length of the shorter one, so
    any extra trailing items in the longer input are ignored.
    """
    mismatches = 0
    for left, right in zip(codon1, codon2):
        if left != right:
            mismatches += 1
    return mismatches

# Map every codon string to its (i, j, k) index triple, following BASES order.
codon_to_idxs = {
    base_1 + base_2 + base_3: (i, j, k)
    for i, base_1 in enumerate(BASES)
    for j, base_2 in enumerate(BASES)
    for k, base_3 in enumerate(BASES)
}

# For each reference codon, build a 4x4x4 int tensor whose [i, j, k] entry is
# the number of positions at which the codon BASES[i]+BASES[j]+BASES[k] differs
# from the reference. The tensors are filled through `codon_to_idxs` so the
# tensor layout and the index map cannot drift apart.
num_diff_tensors = {}
for codon in codon_to_idxs:
    num_diff_tensor = torch.zeros(4, 4, 4, dtype=torch.int)
    for other_codon, other_idxs in codon_to_idxs.items():
        num_diff_tensor[other_idxs] = num_differences(codon, other_codon)
    num_diff_tensors[codon] = num_diff_tensor

# Spot-check one reference codon against hand-counted differences.
ctc_num_diff_tensor = num_diff_tensors["CTC"]
assert ctc_num_diff_tensor[codon_to_idxs["CTC"]] == 0
assert ctc_num_diff_tensor[codon_to_idxs["ATC"]] == 1
assert ctc_num_diff_tensor[codon_to_idxs["ACC"]] == 2
assert ctc_num_diff_tensor[codon_to_idxs["ACT"]] == 3

In [7]:
def probs_of_difference_count(num_diff_tensor, codon_probs):
    """
    Sum codon probabilities grouped by how many positions differ from a reference codon.

    Args:
    - num_diff_tensor (torch.Tensor): A 4x4x4 integer tensor giving, for each codon, its number
                                       of differences from the reference codon.
    - codon_probs (torch.Tensor): A 4x4x4 tensor of codon probabilities.

    Returns:
    - torch.Tensor: A 1D tensor of length 4; entry d is the total probability of the codons
                    that differ from the reference at exactly d positions (d = 0..3). The
                    values pass through Python floats, so the result is a fresh tensor that
                    is not connected to `codon_probs`.
    """
    per_count_totals = [
        # Zero out every codon whose difference count is not n_diffs, then sum.
        (codon_probs * (num_diff_tensor == n_diffs).float()).sum().item()
        for n_diffs in range(4)
    ]
    return torch.tensor(per_count_totals)

In [22]:
def probs_of_difference_count_seq(parent_seq, codon_probs, num_diff_tensors):
    """
    Calculate, for every codon of a parent sequence, the probability of each difference count.

    Args:
    - parent_seq (str): The parent nucleotide sequence. Only complete codons are used; a
                        trailing partial codon (1 or 2 leftover bases) is ignored.
    - codon_probs (torch.Tensor): A tensor whose first dimension indexes codons and whose
                                  entries are 4x4x4 tensors of codon probabilities.
    - num_diff_tensors (dict): A dictionary containing num_diff_tensors indexed by codons.

    Returns:
    - probs (torch.Tensor): A (number of codons) x 4 tensor. Row c holds the probabilities of
                            0, 1, 2 and 3 differences from parent codon c. Codons containing
                            an "N" get a sentinel row of four -100.0 values so callers can
                            filter them out. An empty sequence gives an empty 0 x 4 tensor.

    Raises:
    - ValueError: If the first dimension of codon_probs does not equal len(parent_seq) // 3.
    """
    codon_count = len(parent_seq) // 3

    # Check if the size of the first dimension of codon_probs matches the length of parent_seq divided by 3
    if codon_count != codon_probs.size(0):
        raise ValueError("The size of the first dimension of codon_probs should match the length of parent_seq divided by 3.")

    probs = []

    # Walk complete codons only. Iterating by codon index (rather than over
    # range(0, len(parent_seq), 3)) keeps the loop consistent with the length
    # check above, so a trailing partial codon can neither be looked up in
    # num_diff_tensors nor index past the end of codon_probs.
    for codon_idx in range(codon_count):
        start = 3 * codon_idx
        codon = parent_seq[start:start + 3]

        # Codons containing an N have no lookup tensor: emit the -100.0
        # sentinel row instead and move on.
        if "N" in codon:
            probs.append(torch.tensor([-100.0] * 4))
            continue

        probs.append(
            probs_of_difference_count(num_diff_tensors[codon], codon_probs[codon_idx])
        )

    # torch.stack rejects an empty list, so handle the no-codon case explicitly.
    if not probs:
        return torch.empty((0, 4))

    return torch.stack(probs)


In [23]:
# Run the trained model on every parent sequence and store its two outputs
# next to the PCPs: per-site rates and per-site substitution probabilities.
# (presumably one entry per sequence, trimmed to that sequence's length —
# confirm against `netam.framework.trimmed_shm_model_outputs_of_crepe`)
rates, csps = framework.trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"])
pcp_df["rates"] = rates
pcp_df["subs_probs"] = csps

In [24]:
# Worked example on a single PCP: the row with index label 0.
parent, rates, subs_probs, branch_length = pcp_df.loc[0, ["parent", "rates", "subs_probs", "branch_length"]]
# Truncate each to a multiple of 3 in length so they can be grouped by codon.
parent = parent[:len(parent) - len(parent) % 3]
rates = rates[:len(rates) - len(rates) % 3]
subs_probs = subs_probs[:len(subs_probs) - len(subs_probs) % 3]

# NOTE(review): `mask` is not used again in the visible notebook.
mask = nt_mask_tensor_of(parent)
# N is swapped for A only so the sequence can be index-encoded; codons that
# contain N are flagged later by `probs_of_difference_count_seq`, which is
# given the original `parent` string.
parent_idxs = sequences.nt_idx_tensor_of_str(parent.replace("N", "A"))
parent_len = len(parent)
# Scale the per-site rates by this PCP's branch length. The extra slice keeps
# `rates` no longer than the parent sequence.
scaled_rates = branch_length * rates[:parent_len]

codon_probs = codon_probs_of_parent_scaled_rates_and_sub_probs(parent_idxs, scaled_rates, subs_probs)
last_codon_probs = codon_probs[-1]
# NOTE(review): this bare expression is not the last one in the cell, so with
# default notebook settings it displays nothing.
last_codon_probs

# Example usage: difference-count probabilities for the last six codons.
probs_of_difference_count_seq(parent, codon_probs, num_diff_tensors)[-6:]

tensor([[9.8838e-01, 1.1576e-02, 3.9784e-05, 3.9472e-08],
        [9.8436e-01, 1.5556e-02, 8.0096e-05, 1.3358e-07],
        [9.8991e-01, 1.0065e-02, 2.6426e-05, 1.5379e-08],
        [9.9122e-01, 8.7560e-03, 2.5097e-05, 2.3505e-08],
        [9.9409e-01, 5.8967e-03, 1.1123e-05, 6.6755e-09],
        [9.8765e-01, 1.2329e-02, 2.4462e-05, 1.1690e-08]])

In [25]:
def count_codon_mutations(pcp_df, child_col="child", seed=42):
    """
    Tally how many codons differ from parent to child at 0, 1, 2 and 3 positions.

    For each row a reading frame (offset 0, 1 or 2) is drawn at random, and the
    parent and child are compared codon by codon starting at that offset.

    Args:
    - pcp_df (pd.DataFrame): Must have a "parent" column and a `child_col` column of
                             nucleotide strings.
    - child_col (str): Name of the column holding the child sequences.
    - seed (int): Seed for the random frame choice, so the tally is reproducible.

    Returns:
    - dict: Maps each difference count (0, 1, 2, 3) to the number of codons, summed over all
            rows, that differ at exactly that many positions.

    NOTE(review): codons containing "N" are not skipped here, whereas
    `probs_of_difference_count_seq` excludes them from the model-side
    probabilities — confirm the two sides are meant to differ.
    """
    rng = np.random.default_rng(seed)

    # Total counts of mutations per codon across all PCPs.
    mutation_counts = {0: 0, 1: 0, 2: 0, 3: 0}

    # Iterate the two columns directly; the row index is never needed, and this
    # avoids building a Series per row as `iterrows()` does.
    for parent_seq, child_seq in zip(pcp_df["parent"], pcp_df[child_col]):
        # One frame draw per row, in row order, so results match a given seed.
        frame = rng.integers(0, 3)

        # The stop value guarantees every parent codon is complete.
        for start in range(frame, len(parent_seq) - 2, 3):
            parent_codon = parent_seq[start:start + 3]
            child_codon = child_seq[start:start + 3]

            # Skip positions where the child is too short to give a full codon.
            if len(parent_codon) != 3 or len(child_codon) != 3:
                continue

            mutations = sum(1 for a, b in zip(parent_codon, child_codon) if a != b)
            mutation_counts[mutations] += 1

    return mutation_counts

# Observed tally: number of codons with 0/1/2/3 parent-to-child differences
# across all PCPs, using the function's default seed for the frame choice.
child_codon_mut_count = count_codon_mutations(pcp_df, child_col="child")
child_codon_mut_count

{0: 117382, 1: 2451, 2: 197, 3: 22}

In [38]:
# Number of observed codons with at least one difference (1, 2 or 3).
mutated_codons_count = sum(child_codon_mut_count[i] for i in range(1, 4))
mutated_codons_count

2670

In [36]:
# Model-side expectation: average, over every usable codon of every PCP, of the
# predicted probability of 0, 1, 2 and 3 differences from the parent codon.
total_diff_probs = torch.zeros(4)  # running sum of the per-codon probability rows
prob_count = 0  # number of codons that contributed to the running sum

# loop through every row of pcp_df
for _, row in pcp_df.iterrows():
    parent = row['parent']
    rates = row['rates']
    subs_probs = row['subs_probs']
    branch_length = row['branch_length']

    # Trim everything to a whole number of codons.
    parent = parent[:len(parent) - len(parent) % 3]
    rates = rates[:len(rates) - len(rates) % 3]
    subs_probs = subs_probs[:len(subs_probs) - len(subs_probs) % 3]

    # N is swapped for A only so the sequence can be index-encoded; the
    # original `parent` (with its Ns) is what gets passed on below.
    parent_idxs = sequences.nt_idx_tensor_of_str(parent.replace("N", "A"))
    scaled_rates = branch_length * rates

    codon_probs = codon_probs_of_parent_scaled_rates_and_sub_probs(parent_idxs, scaled_rates, subs_probs)

    diff_probs = probs_of_difference_count_seq(parent, codon_probs, num_diff_tensors)

    # Codons containing N come back as rows of -100.0; keep only the rows whose
    # smallest entry is above -99, i.e. the real probability rows.
    real_diff_probs = diff_probs[diff_probs.min(axis=1).values > torch.tensor([-99.]), :]
    total_diff_probs += real_diff_probs.sum(dim=0)
    prob_count += real_diff_probs.shape[0]
    
# Mean probability of each difference count per usable codon.
mean_diff_probs = total_diff_probs / prob_count
mean_diff_probs

tensor([9.7247e-01, 2.6584e-02, 9.1313e-04, 2.8865e-05])

In [43]:
# Compare model and data. For codons with 2 and 3 differences, express both the
# model's mean probability and the observed count relative to the
# 1-difference case, then take observed / predicted.
relative_data = []

for i in range(2, 4):
    # Predicted: mean probability of i differences over that of 1 difference.
    relative_mutation_rate = mean_diff_probs[i] / mean_diff_probs[1]
    # Observed: count of codons with i differences over the count with 1.
    relative_mutation_count = child_codon_mut_count[i] / child_codon_mut_count[1]
    relative_data.append((i, relative_mutation_rate.item(), relative_mutation_count))
    
relative_df = pd.DataFrame(relative_data, columns=["num_diff", "relative_mutation_rate", "relative_mutation_count"])
# ratio > 1 means multi-difference codons are observed more often, relative to
# single-difference codons, than the model predicts.
relative_df["ratio"] =  relative_df["relative_mutation_count"] / relative_df["relative_mutation_rate"]
relative_df

Unnamed: 0,num_diff,relative_mutation_rate,relative_mutation_count,ratio
0,2,0.034348,0.080375,2.340012
1,3,0.001086,0.008976,8.26682
