In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import epam.models
import epam.sequences

from epam.sequences import translate_sequences, aa_index_of_codon, CODON_AA_INDICATOR_MATRIX, CODONS

In [2]:
pcp_df = pd.read_csv("/Users/matsen/re/epam/_ignore/wyatt-10x-1p5m_pcp_2023-10-07.csv")

# drop first column (index)
pcp_df = pcp_df.drop(pcp_df.columns[0], axis=1)

# just take first 1000 rows
pcp_df = pcp_df.iloc[:1000]

pcp_df["aa_parent"] = translate_sequences(pcp_df["parent"])
pcp_df["aa_child"] = translate_sequences(pcp_df["child"])

def mutation_vector_of(parent, child):
    return np.array(list(parent)) != np.array(list(child))

def mutation_column_of(col1, col2):
    return [mutation_vector_of(p, c) for p, c in zip(col1, col2)]

pcp_df["nt_mutations"] = mutation_column_of(pcp_df["parent"], pcp_df["child"])
pcp_df["aa_mutations"] = mutation_column_of(pcp_df["aa_parent"], pcp_df["aa_child"])

pcp_df["nt_mutation_count"] = [np.sum(x) for x in pcp_df["nt_mutations"]]
pcp_df["aa_mutation_count"] = [np.sum(x) for x in pcp_df["aa_mutations"]]

pcp_df["nt_mutation_frequency"] = pcp_df["nt_mutation_count"] / len(pcp_df["parent"])
pcp_df["aa_mutation_frequency"] = pcp_df["aa_mutation_count"] / len(pcp_df["aa_parent"])

pcp_df

Unnamed: 0,sample_id,family,parent,child,v_gene,child_is_leaf,aa_parent,aa_child,nt_mutations,aa_mutations,nt_mutation_count,aa_mutation_count,nt_mutation_frequency,aa_mutation_frequency
0,0,149198,CAGGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGTCCAGCCTGGGA...,CAGGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGTCCAGCCTGGGA...,IGHV3-33*01,False,QVQLVESGGGVVQPGRSLRLSCAASGFTFSSSGMHWVRQAPGKGLE...,QVQLVESGGGVVQPGRSLRLSCAASGFTFNSSGMHWVRQAPGKGLE...,"[False, False, False, False, False, False, Fal...","[False, False, False, False, False, False, Fal...",3,2,0.003,0.002
1,0,149198,CAGGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGTCCAGCCTGGGA...,CAGGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGTCCAGCCTGGGA...,IGHV3-33*01,False,QVQLVESGGGVVQPGRSLRLSCAASGFTFNSSGMHWVRQAPGKGLE...,QVQLVESGGGVVQPGRSLRLSCAASGFTFDSSGMHWVRQAPGKGLE...,"[False, False, False, False, False, False, Fal...","[False, False, False, False, False, False, Fal...",4,3,0.004,0.003
2,0,149198,CAGGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGTCCAGCCTGGGA...,CAAGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGTCCAGCCTGGGA...,IGHV3-33*01,False,QVQLVESGGGVVQPGRSLRLSCAASGFTFDSSGMHWVRQAPGKGLE...,QVQLVESGGGVVQPGRSLRLSCATSGFNFDTSGMHWVRQAPGKGLE...,"[False, False, True, False, False, False, Fals...","[False, False, False, False, False, False, Fal...",16,7,0.016,0.007
3,0,149198,CAAGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGTCCAGCCTGGGA...,CAAATGCAGATGGTGGAGTCGGGGGGAGGCGTGGTCCAGCCAGGGA...,IGHV3-33*01,True,QVQLVESGGGVVQPGRSLRLSCATSGFNFDTSGMHWVRQAPGKGLE...,QMQMVESGGGVVQPGRSLTLSCATSGFNFETSALHWVRQAPGKGLE...,"[False, False, False, True, False, False, Fals...","[False, True, False, True, False, False, False...",19,9,0.019,0.009
4,0,149198,CAAGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGTCCAGCCTGGGA...,CAAGTGCAACTGGTGGAGTCTGGGGGAGGCGTGGTCCAGCCTGGGA...,IGHV3-33*01,True,QVQLVESGGGVVQPGRSLRLSCATSGFNFDTSGMHWVRQAPGKGLE...,QVQLVESGGGVVQPGRSLRLSCATSGINFDTSGMHWVRQAPGKGLE...,"[False, False, False, False, False, False, Fal...","[False, False, False, False, False, False, Fal...",7,4,0.007,0.004
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,0,366797,CAGCTGCAGCTGCAGGAGTCGGGCCCAGGACTGGTGAAGCCCTCGG...,CAGCTGCAGCTGCAGGAGTCGGGCCCAGGACTGGTGAAGCCCTCGG...,IGHV4-39*01,True,QLQLQESGPGLVKPSETLSLTCSVSGGSITSGTYYWGWIRQPPGKG...,QLQLQESGPGLVKPSETLSLTCSVSGGSITSGTYYWGWIRQPPGKG...,"[False, False, False, False, False, False, Fal...","[False, False, False, False, False, False, Fal...",0,0,0.000,0.000
996,0,367593,GAGGTGCAGCTGGGGGAGTCTGGGGGAAACTTGGTCCAGCCTGGGG...,GAGGTGCAGCTGGGGGAGTCTGGGGGAACCTTGGTCCAGCCTGGGG...,IGHV3-7*03,False,EVQLGESGGNLVQPGGSLRLSCAASGFTFSSYWMSWVRQAPGKGLE...,EVQLGESGGTLVQPGGSLRLSCAASGFTFSSYWMSWVRQAPGKGLE...,"[False, False, False, False, False, False, Fal...","[False, False, False, False, False, False, Fal...",2,1,0.002,0.001
997,0,367593,GAGGTGCAGCTGGGGGAGTCTGGGGGAACCTTGGTCCAGCCTGGGG...,GAGGTGCAGCTGGGGGAGTCTGGGGGAACCTTGGTCCAGCCTGGGG...,IGHV3-7*03,True,EVQLGESGGTLVQPGGSLRLSCAASGFTFSSYWMSWVRQAPGKGLE...,EVQLGESGGTLVQPGGSLRLSCAASGFTFSSYWMSWVRQAPGKGLE...,"[False, False, False, False, False, False, Fal...","[False, False, False, False, False, False, Fal...",0,0,0.000,0.000
998,0,367593,GAGGTGCAGCTGGGGGAGTCTGGGGGAACCTTGGTCCAGCCTGGGG...,GAGGTGCAGCTGGGGGAGTCTGGGGGAACCTTGGTCCAGCCTGGGG...,IGHV3-7*03,True,EVQLGESGGTLVQPGGSLRLSCAASGFTFSSYWMSWVRQAPGKGLE...,EVQLGESGGTLVQPGGSLRLSCAASGFTFSSYWMSWVRQAPGKGLE...,"[False, False, False, False, False, False, Fal...","[False, False, False, False, False, False, Fal...",1,0,0.001,0.000


In [3]:
# make the first row as a bunch of globals
for key, value in pcp_df.iloc[0].to_dict().items():
    globals()[key] = value

print(parent)  # Output will be 1


CAGGTGCAGCTGGTGGAGTCTGGGGGAGGCGTGGTCCAGCCTGGGAGGTCCCTGAGACTCTCCTGTGCAGCGTCTGGATTCACCTTCAGTAGCTCTGGCATGCACTGGGTCCGCCAGGCTCCAGGCAAGGGGCTGGAGTGGGTGGCAGTTATATGGTATGATGGAAGTAATAAATATTATGCAGACTCCGTGAAGGGCCGATTCACCATCTCCAGAGACAATTCCAAGAACACGGTGTATCTTCAAATGAACAGCCTAAGAGCCGAGGACACGGCTGTGTATTACTGTGCGAGAGAGGGGCACAGTAACTACCCCTACTACTACTACTACATGGACGTCTGGGGCAAAGGGACCACGGTCACCGTCTCCTCA


In [4]:
import shmple
import logging

weights_directory = "/Users/matsen/re/epam/data/shmple_weights/my_shmoof"

model = shmple.AttentionModel(
    weights_dir=weights_directory, log_level=logging.WARNING
)

[mut_rates], [subs_probs] = model.predict_mutabilities_and_substitutions([parent], [nt_mutation_frequency])

mut_rates = mut_rates.squeeze()



In [28]:
subs_probs.shape

(372, 4)

In [21]:
#for i in range(0, len(parent), 3):

i=0
parent_codon = parent[i : i + 3]
codon_mut_rates = mut_rates[i : i + 3]
codon_subs = subs_probs[i : i + 3]

naive_codon_mut_rates = np.sum(mut_rates.reshape(-1, 3), axis=1)

naive_codon_mut_rates.shape

# Sum across the rows of the codon_subs matrix
codon_subs.sum(axis=1)

parent_aa_index = aa_index_of_codon(parent_codon)

aa_trans_probs = epam.models.SHMple._codon_probs_of_mutation_matrix(codon_subs).reshape(-1) @ CODON_AA_INDICATOR_MATRIX

nonstop_prob = np.sum(aa_trans_probs)
nonself_prob = aa_trans_probs[parent_aa_index]

In [26]:
parent_codon, nonstop_prob

('CAG', 0.7106073687276928)

In [23]:
tensor_probs = epam.models.SHMple._codon_probs_of_mutation_matrix(codon_subs)
flat_probs = tensor_probs.reshape(-1)

for codon, prob in zip(CODONS, flat_probs):
    print(codon, prob)
    
tensor_probs

AAA 9.821687e-06
AAC 2.9093333e-06
AAG 1.9263482e-09
AAT 2.9950825e-06
ACA 0.008858625
ACC 0.0026240596
ACG 1.7374607e-06
ACT 0.0027014008
AGA 0.08571688
AGC 0.025390644
AGG 1.6811831e-05
AGT 0.026139004
ATA 0.011979606
ATC 0.0035485416
ATG 2.3495852e-06
ATT 0.0036531307
CAA 3.3231597e-09
CAC 9.843705e-10
CAG 6.5177833e-13
CAT 1.0133837e-09
CCA 2.9973085e-06
CCC 8.878484e-07
CCG 5.878684e-10
CCT 9.140167e-07
CGA 2.9002233e-05
CGC 8.590902e-06
CGG 5.6882685e-09
CGT 8.844109e-06
CTA 4.0532896e-06
CTC 1.200646e-06
CTG 7.949802e-10
CTT 1.2360337e-06
GAA 1.4574411e-05
GAC 4.317162e-06
GAG 2.8585097e-09
GAT 4.4444055e-06
GCA 0.013145321
GCC 0.003893844
GCG 2.5782194e-06
GCT 0.0040086107
GGA 0.12719534
GGC 0.037677195
GGG 2.4947089e-05
GGT 0.03878769
GTA 0.017776545
GTC 0.0052656834
GTG 3.4865513e-06
GTT 0.0054208837
TAA 3.315566e-05
TAC 9.82121e-06
TAG 6.502889e-09
TAT 1.0110679e-05
TCA 0.02990459
TCC 0.0088581955
TCG 5.8652504e-06
TCT 0.009119281
TGA 0.28935957
TGC 0.08571271
TGG 5.6752702e

array([[[9.8216869e-06, 2.9093333e-06, 1.9263482e-09, 2.9950825e-06],
        [8.8586248e-03, 2.6240596e-03, 1.7374607e-06, 2.7014008e-03],
        [8.5716881e-02, 2.5390644e-02, 1.6811831e-05, 2.6139004e-02],
        [1.1979606e-02, 3.5485416e-03, 2.3495852e-06, 3.6531307e-03]],

       [[3.3231597e-09, 9.8437047e-10, 6.5177833e-13, 1.0133837e-09],
        [2.9973085e-06, 8.8784839e-07, 5.8786842e-10, 9.1401671e-07],
        [2.9002233e-05, 8.5909023e-06, 5.6882685e-09, 8.8441093e-06],
        [4.0532896e-06, 1.2006460e-06, 7.9498019e-10, 1.2360337e-06]],

       [[1.4574411e-05, 4.3171622e-06, 2.8585097e-09, 4.4444055e-06],
        [1.3145321e-02, 3.8938441e-03, 2.5782194e-06, 4.0086107e-03],
        [1.2719534e-01, 3.7677195e-02, 2.4947089e-05, 3.8787689e-02],
        [1.7776545e-02, 5.2656834e-03, 3.4865513e-06, 5.4208837e-03]],

       [[3.3155658e-05, 9.8212104e-06, 6.5028889e-09, 1.0110679e-05],
        [2.9904591e-02, 8.8581955e-03, 5.8652504e-06, 9.1192806e-03],
        [2.893

In [25]:
# dot product of flat_probs and CODON_AA_INDICATOR_MATRIX.sum(axis=1)
flat_probs @ CODON_AA_INDICATOR_MATRIX.sum(axis=1)

0.7106073687276928

In [27]:
total = 0

for stop in epam.sequences.STOP_CODONS:
    print(stop, CODONS.index(stop))
    total = total + flat_probs[CODONS.index(stop)]

total

TAA 48
TAG 50
TGA 56


0.2893927317109064