In [3]:
import matplotlib.pyplot as plt
import pandas as pd
import torch

import netam.framework as framework
import netam.hit_class as hit_class
from epam import evaluation
from netam.common import BASES_AND_N_TO_INDEX
from netam.molevol import reshape_for_codons
from netam.multihit import (
    MultihitBurrito,
    train_test_datasets_of_pcp_df,
    HitClassModel,
    prepare_pcp_df,
)
from shmex.shm_data import pcp_df_of_non_shmoof_nickname, dataset_dict

## How to use a hit class crepe
Load some data, an existing crepe whose predictions should be adjusted, and the hit class crepe that provides the adjustment.

In [4]:
# Keyword arguments that are unpacked into the MultihitBurrito constructor.
burrito_params = dict(
    batch_size=1024,
    learning_rate=0.1,
    min_learning_rate=1e-4,
    # l2_regularization_coeff=1e-6,  # currently switched off
)
# NOTE(review): `epochs` is not referenced by the cells visible here -- confirm
# it is still needed.
epochs = 200
# Passed straight through to prepare_pcp_df below.
site_count = 500

# Neutral nucleotide crepe: this is the model whose predictions get adjusted.
# Another candidate path, currently unused:
#   "../train/trained_models/cnn_joi_lrg-shmoof_small-fixed-0"
nt_crepe_path = "../train/fixed_models/cnn_ind_med-shmoof_small-full-0"
nt_crepe = framework.load_crepe(nt_crepe_path)

# Hit class crepe: the model that performs the adjustment.
hc_crepe_path = "shmoof_small-hc"
hc_crepe = framework.load_crepe(hc_crepe_path)

# Load the "tangshm" parent-child-pair table and prepare it with the neutral
# crepe.
tang_df = pcp_df_of_non_shmoof_nickname("tangshm")
# NOTE(review): despite the name, nothing is subsampled here -- this is a full
# copy of tang_df with a fresh 0..n-1 index.
subsampled_tang_df = tang_df.copy().reset_index(drop=True)
pcp_df = prepare_pcp_df(subsampled_tang_df, nt_crepe, site_count)

train_data, val_data = train_test_datasets_of_pcp_df(pcp_df)
# Hold on to the branch lengths the training set starts out with.
starting_branch_lengths_estimates = train_data.branch_lengths

Loading /Users/wdumm/data/v1/tang-deepshm-oof_pcp_2024-04-09_MASKED_NI.csv.gz



The model adjustments were produced with `joint_train`, so they are not based solely on the branch lengths fit under the uncorrected model.
You may therefore want to re-fit branch lengths on your own data:

In [5]:
# Pair the hit class model with the train/validation datasets.
burrito = MultihitBurrito(
    train_data,
    val_data,
    hc_crepe.model,
    **burrito_params,
)

# The model's `values` are exponentiated for display.
print(torch.exp(burrito.model.values))

# Re-fit the training-set branch lengths under the hit class model and store
# them back on the dataset. This is slow: the recorded run took ~20 minutes.
new_branch_lengths = burrito.find_optimal_branch_lengths(train_data)
train_data.branch_lengths = new_branch_lengths

tensor([0.8997, 1.4292, 3.2327], grad_fn=<ExpBackward0>)


Optimizing branch lengths: 100%|██████████| 1997/1997 [19:49<00:00,  1.68it/s]    



If you only want to use the model to adjust your neutral codon probability predictions, you can call it directly, without re-fitting branch lengths:

In [6]:
# Seeded, local random generator: the demo tensors (and therefore the numbers
# printed further down) are identical on every Restart & Run All, and torch's
# global RNG state is left untouched.
demo_rng = torch.Generator().manual_seed(0)

# Produce some random codon probs: one 4x4x4 table per codon, normalized so
# that each table sums to 1.
sample_codon_probs = torch.rand((100, 4, 4, 4), generator=demo_rng)
sample_codon_probs /= sample_codon_probs.sum(dim=(1, 2, 3), keepdim=True)

# Sanity check on the normalization above.
codon_prob_totals = sample_codon_probs.sum(dim=(1, 2, 3))
assert torch.allclose(
    codon_prob_totals, torch.ones_like(codon_prob_totals), atol=1e-5
), "each codon probability table should sum to 1"

# Make some random parent codon nt indices between 0 and 3 (three per codon):
parent_codon_nt_indices = torch.randint(0, 4, (100, 3), generator=demo_rng)

# Here are your adjusted codon probs! The model is called with log
# probabilities and its output is exponentiated back to probabilities.
adjusted_codon_probs = hc_crepe.model(
    parent_codon_nt_indices, sample_codon_probs.log()
).exp()

In [13]:
# Aggregate both codon-level distributions with hit_class_probs_tensor
# (presumably into per-hit-class probabilities, one column per hit class).
adjusted_aggregated = hit_class.hit_class_probs_tensor(
    parent_codon_nt_indices, adjusted_codon_probs
)
original_aggregated = hit_class.hit_class_probs_tensor(
    parent_codon_nt_indices, sample_codon_probs
)

# Alternative route for comparison: aggregate the *uncorrected* probabilities
# first, then apply the model's corrections in log space and renormalize.
original_log_probs = original_aggregated.log()

# One log-space offset per column; a 0.0 is prepended so that column 0 is left
# unadjusted.
corrections = torch.cat([torch.tensor([0.0]), hc_crepe.model.values])

# Broadcasting adds the same offsets to every row. The result goes into a new
# tensor so that `original_log_probs` keeps holding the uncorrected values.
corrected_log_probs = original_log_probs + corrections
aggregate_first = torch.softmax(corrected_log_probs, dim=1)

In [14]:
# Show the first row of each tensor side by side for comparison.
for label, probs in (
    ("adjusted", adjusted_aggregated),
    ("original", original_aggregated),
    ("aggregate_first", aggregate_first),
):
    print(label, probs[0])

adjusted tensor([1.4226e-04, 7.2570e-02, 2.8584e-01, 6.4145e-01],
       grad_fn=<SelectBackward0>)
original tensor([2.9686e-04, 1.6832e-01, 4.1734e-01, 4.1404e-01])
aggregate_first tensor([1.4226e-04, 7.2570e-02, 2.8584e-01, 6.4145e-01],
       grad_fn=<SelectBackward0>)
