In [1]:
import pandas as pd

import torch
from epam.dnsm import TransformerBinarySelectionModel, DNSMBurrito
from epam.sequences import translate_sequences

In [2]:
# Load the parent-child pairs (PCPs), drop every pair whose parent and child
# sequences are identical, then keep a fixed-seed random subset.
# NOTE downsampling here
# NOTE(review): the recorded output of this cell reports 2000 PCPs, but the
# code samples 20000 -- confirm the intended sample size and re-run.
pcp_df = (
    pd.read_csv("~/data/wyatt-10x-1p5m_pcp_2023-10-07.csv")
    .loc[lambda df: df["parent"] != df["child"]]
    .sample(20000, random_state=42)
)

print(f"We have {len(pcp_df)} PCPs.")

We have 2000 PCPs.


In [3]:
# Hyperparameters passed to the transformer selection model.
nhead = 4
dim_feedforward = 2048
layer_count = 3

# Directory handed to DNSMBurrito as its second positional argument.
shmple_weights_directory = "/Users/matsen/re/epam/data/shmple_weights/my_shmoof"

dnsm = TransformerBinarySelectionModel(
    nhead=nhead,
    dim_feedforward=dim_feedforward,
    layer_count=layer_count,
)

burrito = DNSMBurrito(
    pcp_df,
    shmple_weights_directory,
    dnsm,
    batch_size=1024,
    learning_rate=0.001,
    checkpoint_dir="./_checkpoints",
    log_dir="./_logs",
)

# Schedule: a short initial training run, then two rounds of
# "optimize branch lengths, then train again".
burrito.train(3)
for epochs_this_round in (15, 15):
    burrito.optimize_branch_lengths()
    burrito.train(epochs_this_round)

Using Metal Performance Shaders
preparing data...




predicting mutabilities and substitutions...
consolidating this into substitution probabilities...
predicting mutabilities and substitutions...
consolidating this into substitution probabilities...
Epoch [0/3], Training Loss: 0.14439820498228073, Validation Loss: 0.14166443049907684
training model...
Epoch [1/3], Training Loss: 0.14281409978866577, Validation Loss: 0.1405491679906845
Epoch [2/3], Training Loss: 0.13496620953083038, Validation Loss: 0.13836506009101868
Epoch [3/3], Training Loss: 0.13516797125339508, Validation Loss: 0.13909541070461273


Finding optimal branch lengths:  29%|██▉       | 471/1600 [00:55<02:13,  8.46it/s]


KeyboardInterrupt: 

In [None]:
# Translate the parent sequence of the first PCP (by position, whatever its
# index label is) and query the trained model for its selection factors.
first_parent = pcp_df["parent"].iloc[0]
# Single-element unpacking: fails loudly unless exactly one translation comes back.
(aa_str,) = translate_sequences([first_parent])
burrito.dnsm.selection_factors_of_aa_str(aa_str)

array([0.16746567, 0.18125452, 0.17475654, 0.14934239, 0.19283801,
       0.17483354, 0.22327097, 0.15719528, 0.2670195 , 0.15506317,
       0.15051098, 0.3086634 , 0.30574813, 0.16796537, 0.15364574,
       0.27505472, 0.2421864 , 0.19693483, 0.30720553, 0.20181261,
       0.2733704 , 0.15934032, 0.32970226, 0.30831188, 0.28278762,
       0.16281629, 0.25050843, 0.3119859 , 0.16555636, 0.2933701 ,
       0.25556317, 0.21109025, 0.15875149, 0.15081373, 0.2629757 ,
       0.15331246, 0.22006264, 0.17811066, 0.21454369, 0.3095778 ,
       0.23798394, 0.1833422 , 0.23784602, 0.17310007, 0.15132064,
       0.1633514 , 0.15469873, 0.31436643, 0.17545184, 0.15432854,
       0.30028984, 0.3328726 , 0.34694675, 0.34256014, 0.34860313,
       0.23313344, 0.3446737 , 0.3190169 , 0.35119778, 0.32210308,
       0.34833214, 0.30879936, 0.23360679, 0.28930157, 0.3455299 ,
       0.24755357, 0.30245617, 0.3314089 , 0.32135543, 0.33970687,
       0.3460881 , 0.20939623, 0.3135103 , 0.34284648, 0.33584

In [None]:
# Rebuild a model with the same hyperparameters, restore weights from a saved
# checkpoint, and evaluate it on `aa_str` (defined in an earlier cell).
nhead = 4
dim_feedforward = 2048
layer_count = 3

model = TransformerBinarySelectionModel(
    nhead=nhead, dim_feedforward=dim_feedforward, layer_count=layer_count
)

# map_location="cpu" lets the checkpoint deserialize even when the device it
# was saved from (e.g. MPS or CUDA) is not available on this machine;
# load_state_dict then copies the values into the model's own parameters.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
checkpoint = torch.load(
    "/Users/matsen/re/epam/trained_dnsms/dnsm-2023-11-01-09-32.pth",
    map_location="cpu",
)
model.load_state_dict(checkpoint["model_state_dict"])
# Switch to evaluation mode before inference.
model.eval()
model.selection_factors_of_aa_str(aa_str)

Using Metal Performance Shaders


array([0.16746567, 0.18125452, 0.17475654, 0.14934239, 0.19283801,
       0.17483354, 0.22327097, 0.15719528, 0.2670195 , 0.15506317,
       0.15051098, 0.3086634 , 0.30574813, 0.16796537, 0.15364574,
       0.27505472, 0.2421864 , 0.19693483, 0.30720553, 0.20181261,
       0.2733704 , 0.15934032, 0.32970226, 0.30831188, 0.28278762,
       0.16281629, 0.25050843, 0.3119859 , 0.16555636, 0.2933701 ,
       0.25556317, 0.21109025, 0.15875149, 0.15081373, 0.2629757 ,
       0.15331246, 0.22006264, 0.17811066, 0.21454369, 0.3095778 ,
       0.23798394, 0.1833422 , 0.23784602, 0.17310007, 0.15132064,
       0.1633514 , 0.15469873, 0.31436643, 0.17545184, 0.15432854,
       0.30028984, 0.3328726 , 0.34694675, 0.34256014, 0.34860313,
       0.23313344, 0.3446737 , 0.3190169 , 0.35119778, 0.32210308,
       0.34833214, 0.30879936, 0.23360679, 0.28930157, 0.3455299 ,
       0.24755357, 0.30245617, 0.3314089 , 0.32135543, 0.33970687,
       0.3460881 , 0.20939623, 0.3135103 , 0.34284648, 0.33584