In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

from netam import framework, models
from netam.common import nt_mask_tensor_of, BASES
from netam.framework import SHMoofDataset

from epam import molevol, sequences
from epam.molevol import reshape_for_codons, build_mutation_matrices, codon_probs_of_mutation_matrices

import sys
sys.path.append("/Users/matsen/re/netam-experiments-1")
from shmex.shm_data import load_shmoof_dataframes, pcp_df_of_non_shmoof_nickname, dataset_dict

In [2]:
data_nickname = "shmoof"
pcp_df = pcp_df_of_non_shmoof_nickname(data_nickname)
crepe_path = "../train/trained_models/cnn_joi_lrg-shmoof_small-fixed-0"
crepe = framework.load_crepe(crepe_path)
model = crepe.model
site_count = 500

Loading /Users/matsen/data/shmoof_pcp_2023-11-30_MASKED.csv.gz


In [3]:
train_df, val_df = load_shmoof_dataframes(dataset_dict["shmoof"], val_nickname="small")
train_df = train_df.reset_index(drop=True)
val_df = val_df.reset_index(drop=True)
train_df["branch_length"] = pd.read_csv(crepe_path+".train_branch_lengths.csv")["branch_length"]
pcp_df = train_df

In [4]:
def trim_seqs_to_codon_boundary(seqs):
    return [seq[:len(seq) - len(seq) % 3] for seq in seqs]

pcp_df["parent"] = trim_seqs_to_codon_boundary(pcp_df["parent"])
pcp_df["child"] = trim_seqs_to_codon_boundary(pcp_df["child"])

# take the first 1000 row
pcp_df = pcp_df[:1000].copy()

In [5]:
rates, csps = framework.trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"])
pcp_df["rates"] = rates
pcp_df["subs_probs"] = csps

In [6]:
# print the last entry of pcp_df["subs_probs"]
pcp_df["subs_probs"].iloc[-1]

tensor([[0.2500, 0.2500, 0.2500, 0.2500],
        [0.2500, 0.2500, 0.2500, 0.2500],
        [0.2500, 0.2500, 0.2500, 0.2500],
        ...,
        [0.2058, 0.0000, 0.1938, 0.6004],
        [0.0952, 0.0000, 0.2724, 0.6324],
        [0.2352, 0.5070, 0.2578, 0.0000]])

In [7]:
# get the minimum value of the tensors in pcp_df["subs_probs"]
min([x.min() for x in pcp_df["subs_probs"]])

tensor(0.)

In [8]:
def codon_probs_of_parent_scaled_rates_and_sub_probs(parent_idxs, scaled_rates, sub_probs):
    # This is from `aaprobs_of_parent_scaled_rates_and_sub_probs`.
    mut_probs = 1.0 - torch.exp(-scaled_rates)
    parent_codon_idxs = reshape_for_codons(parent_idxs)
    codon_mut_probs = reshape_for_codons(mut_probs)
    codon_sub_probs = reshape_for_codons(sub_probs)
    
    # This is from `aaprob_of_mut_and_sub` 
    mut_matrices = build_mutation_matrices(parent_codon_idxs, codon_mut_probs, codon_sub_probs)
    codon_probs = codon_probs_of_mutation_matrices(mut_matrices)

    return codon_probs

In [9]:
parent, rates, subs_probs, branch_length = pcp_df.loc[0, ["parent", "rates", "subs_probs", "branch_length"]]
# truncate each to be a multiple of 3 in length
parent = parent[:len(parent) - len(parent) % 3]
rates = rates[:len(rates) - len(rates) % 3]
subs_probs = subs_probs[:len(subs_probs) - len(subs_probs) % 3]

In [10]:
mask = nt_mask_tensor_of(parent)
parent_idxs = sequences.nt_idx_tensor_of_str(parent.replace("N", "A"))

In [16]:
parent_idxs = sequences.nt_idx_tensor_of_str(parent.replace("N", "A"))
parent_len = len(parent)

scaled_rates = branch_length * rates[:parent_len]

codon_probs = codon_probs_of_parent_scaled_rates_and_sub_probs(parent_idxs, scaled_rates, subs_probs)
last_codon_probs = codon_probs[-1]
last_codon_probs

tensor([[[1.2721e-10, 6.0225e-07, 1.3777e-10, 1.7612e-10],
         [1.5850e-10, 7.5039e-07, 1.7166e-10, 2.1944e-10],
         [9.2743e-11, 4.3907e-07, 1.0044e-10, 1.2840e-10],
         [2.3795e-07, 1.1265e-03, 2.5770e-07, 3.2942e-07]],

        [[1.1153e-07, 5.2802e-04, 1.2079e-07, 1.5441e-07],
         [1.3897e-07, 6.5790e-04, 1.5050e-07, 1.9239e-07],
         [8.1312e-08, 3.8495e-04, 8.8062e-08, 1.1257e-07],
         [2.0862e-04, 9.8765e-01, 2.2594e-04, 2.8882e-04]],

        [[3.7351e-10, 1.7683e-06, 4.0452e-10, 5.1711e-10],
         [4.6539e-10, 2.2033e-06, 5.0402e-10, 6.4430e-10],
         [2.7231e-10, 1.2892e-06, 2.9492e-10, 3.7700e-10],
         [6.9864e-07, 3.3076e-03, 7.5665e-07, 9.6724e-07]],

        [[6.3252e-10, 2.9945e-06, 6.8503e-10, 8.7569e-10],
         [7.8811e-10, 3.7311e-06, 8.5354e-10, 1.0911e-09],
         [4.6114e-10, 2.1832e-06, 4.9942e-10, 6.3842e-10],
         [1.1831e-06, 5.6012e-03, 1.2813e-06, 1.6380e-06]]])

In [14]:
# Function to calculate the number of differences between two codons
def num_differences(codon1, codon2):
    return sum(c1 != c2 for c1, c2 in zip(codon1, codon2))

# Initialize a dictionary to store num_diff_tensors indexed by codons
num_diff_tensors = {}

# Iterate over all possible codons and calculate the num_diff_tensors
for i, base1 in enumerate(BASES):
    for j, base2 in enumerate(BASES):
        for k, base3 in enumerate(BASES):
            codon = base1 + base2 + base3
            num_diff_tensor = torch.zeros(4, 4, 4, dtype=torch.int)
            for i2, base1_2 in enumerate(BASES):
                for j2, base2_2 in enumerate(BASES):
                    for k2, base3_2 in enumerate(BASES):
                        codon_2 = base1_2 + base2_2 + base3_2
                        num_diff = num_differences(codon, codon_2)
                        num_diff_tensor[i2, j2, k2] = num_diff
            num_diff_tensors[codon] = num_diff_tensor

# make a dict mapping from codon to triple integer index
codon_to_idxs = {base_1+base_2+base_3: (i, j, k) for i, base_1 in enumerate(BASES) for j, base_2 in enumerate(BASES) for k, base_3 in enumerate(BASES)}
codon_to_idxs

ctc_num_diff_tensor = num_diff_tensors["CTC"]
assert ctc_num_diff_tensor[codon_to_idxs["CTC"]] == 0
assert ctc_num_diff_tensor[codon_to_idxs["ATC"]] == 1
assert ctc_num_diff_tensor[codon_to_idxs["ACC"]] == 2
assert ctc_num_diff_tensor[codon_to_idxs["ACT"]] == 3

In [17]:
def probs_of_difference_count(num_diff_tensor, codon_probs):
    """
    Calculate total probabilities for each number of differences between codons.

    Args:
    - num_diff_tensor (torch.Tensor): A 4x4x4 integer tensor containing the number of differences
                                       between each codon and a reference codon.
    - codon_probs (torch.Tensor): A 4x4x4 tensor containing the probabilities of various codons.

    Returns:
    - total_probs (torch.Tensor): A 1D tensor containing the total probabilities for each number
                                   of differences (0 to 3).
    """
    total_probs = []

    for differences_count in range(4):
        # Create a mask of codons with the desired number of differences
        mask = num_diff_tensor == differences_count

        # Multiply componentwise with the codon_probs tensor and sum
        total_prob = (codon_probs * mask.float()).sum()

        # Append the total probability to the list
        total_probs.append(total_prob.item())

    return torch.tensor(total_probs)

probs_of_difference_count(ctc_num_diff_tensor, last_codon_probs)


tensor([9.8765e-01, 1.2329e-02, 2.4462e-05, 1.1690e-08])

In [22]:
import torch

def probs_of_difference_count_seq(parent_seq, codon_probs, num_diff_tensors):
    """
    Calculate probabilities of difference count between parent codons and all other codons for all the sites of a sequence.

    Args:
    - parent_seq (str): The parent nucleotide sequence.
    - codon_probs (torch.Tensor): A tensor containing the probabilities of various codons.
    - num_diff_tensors (dict): A dictionary containing num_diff_tensors indexed by codons.

    Returns:
    - probs (torch.Tensor): A tensor containing the probabilities of different
                                                 counts of differences between parent codons and
                                                 all other codons.
    """
    # Check if the size of the first dimension of codon_probs matches the length of parent_seq divided by 3
    if len(parent_seq) // 3 != codon_probs.size(0):
        raise ValueError("The size of the first dimension of codon_probs should match the length of parent_seq divided by 3.")

    # Initialize a list to store the probabilities of different counts of differences
    probs = []

    # Iterate through codons in parent_seq
    for i in range(0, len(parent_seq), 3):
        # Extract the codon from parent_seq
        codon = parent_seq[i:i+3]

        # if codon contains an N, append a tensor of 4 -1s to probs then continue
        if "N" in codon:
            probs.append(torch.tensor([-1.0] * 4))
            continue

        # Get the corresponding num_diff_tensor from num_diff_tensors
        num_diff_tensor = num_diff_tensors[codon]

        # Get the ith entry of codon_probs
        codon_probs_i = codon_probs[i // 3]

        # Calculate the probabilities of different counts of differences using the num_diff_tensor and codon_probs_i
        total_probs = probs_of_difference_count(num_diff_tensor, codon_probs_i)

        # Append the probabilities to the list
        probs.append(total_probs)

    # Concatenate all the probabilities into a tensor
    probs = torch.stack(probs)

    return probs

# Example usage:
probs_of_difference_count_seq(parent, codon_probs, num_diff_tensors)


tensor([[-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        [-1.0000e+00, -1.0000e+00, -1.0000e+00, -1.0000e+00],
        

In [27]:
mut_probs = 1.0 - torch.exp(-scaled_rates)
parent_codon_idxs = reshape_for_codons(parent_idxs)
codon_mut_probs = reshape_for_codons(mut_probs)
codon_sub_probs = reshape_for_codons(normed_subs_probs)

# This is from `aaprob_of_mut_and_sub` 
mut_matrices = build_mutation_matrices(parent_codon_idxs, codon_mut_probs, codon_sub_probs)
codon_probs = codon_probs_of_mutation_matrices(mut_matrices)

codon_probs[-1]
mut_matrices[-1]
mut_probs[-10:]
codon_sub_probs[-2:]
normed_subs_probs[-2:]
subs_probs[:parent_len, :]

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        ...,
        [-1.3564,  0.0000, -0.2794,  0.2474],
        [ 0.0166,  0.2365, -0.2994,  0.0000],
        [-0.3850,  0.0000, -0.3053, -0.0597]])

In [11]:
# get smallest entry of codon_probs[-1]
codon_probs[-1].min()

tensor(nan)