## Full Test

### Data Processing

In [1]:
from pathlib import Path

from scripts.data_processing.align_and_parse import construct_master_alignment_df

reads_df = construct_master_alignment_df(
    amplicon_file="configs/pikh1_amplicon.gb",
    alignparse_config="configs/alignparse_config.yaml",
    locus_name="pikh1_amplicon",
    merged_reads_dir=Path("/home/oscar/rethinking_de_sequence_analysis/enrichment_analysis/merged_all")
)

reads_df.head()


Processing A2...
    Aligning reads from /home/oscar/rethinking_de_sequence_analysis/enrichment_analysis/merged_all/A2_merged.fastq to configs/pikh1_amplicon.gb...
    Parsing alignment from /tmp/tmptsza0cqi.sam...
    Adding mutation info columns for locus pikh1_amplicon...
    Expanding mutations to full sequences for A2...

Processing lib...
    Aligning reads from /home/oscar/rethinking_de_sequence_analysis/enrichment_analysis/merged_all/lib_merged.fastq to configs/pikh1_amplicon.gb...
    Parsing alignment from /tmp/tmpuwoais2w.sam...
    Adding mutation info columns for locus pikh1_amplicon...
    Expanding mutations to full sequences for lib...

Processing A4...
    Aligning reads from /home/oscar/rethinking_de_sequence_analysis/enrichment_analysis/merged_all/A4_merged.fastq to configs/pikh1_amplicon.gb...
    Parsing alignment from /tmp/tmpjvib8j5x.sam...
    Adding mutation info columns for locus pikh1_amplicon...
    Expanding mutations to full sequences for A4...

Processin

Unnamed: 0,library,sample,sequence,gene_mutations
0,library,A2,GTTCCAGACTACGCGCTGCAGGCTAGTGGTGGAGGAGGCTCCGGTG...,A15G T42C A47G T78C T107C A108G T209C
1,library,A2,GTTCCAGACTACGCGCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,A15G A34G A47G C48T T57C T70C T90C C224T
2,library,A2,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,T195C T209C C224T
3,library,A2,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,T117C A163G C200T T209C
4,library,A2,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,A54G T81C T205C T209C T216C A233G


In [2]:
from scripts.data_processing.construct_variant_table import construct_variant_table

variant_table = construct_variant_table(
    aligned_reads_df=reads_df,
    wt_dna_file='configs/pikh1_dna_seq.fasta'
)

variant_table.variant_count_df.head()

Constructing variant table...


Unnamed: 0,library,sample,barcode,count,variant_call_support,codon_substitutions,aa_substitutions,n_codon_substitutions,n_aa_substitutions
0,library,A1,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,12924,1,,,0,0
1,library,A1,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,3732,1,GAA13GGA GAT32GAC GAT42GGT,E13G D42G,3,2
2,library,A1,GTTCCAGACTGCGTTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,1667,1,CAA4CGA AAA5ATA ATG12AGG GCT26GCC GCT35GTT TAT...,Q4R K5I M12R A35V,6,4
3,library,A1,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,1272,1,GAA68GGA,E68G,1,1
4,library,A1,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,1076,1,GAA45GGA,E45G,1,1


In [12]:
variant_table.variant_count_df.to_parquet('outputs/variant_count_df.parquet')

In [1]:
import pandas as pd

variant_count_df = pd.read_parquet('outputs/variant_count_df.parquet')

In [2]:
from scripts.data_processing.preprocess_variant_data import get_simple_counts_df

simple_count_df = get_simple_counts_df(
    variant_count_df,
    'aa_substitutions',
)

simple_count_df.head()

Unnamed: 0,aa_substitutions,sample,count
0,,A1,25576.0
1,A11D,A1,4.0
2,A11D A23S A75V K77R D78G,A1,0.0
3,A11D A26V I54T D66G Q74R A75V D78G,A1,0.0
4,A11D A35T G65S D66N Q71P,A1,1.0


In [3]:
from scripts.data_processing.preprocess_variant_data import clean_variant_data

cleaned_df = clean_variant_data(
    simple_count_df,
    count_columns=[c for c in simple_count_df.columns if "count" in c],
    read_count_threshold=20,
    variants_to_remove={'aa_substitutions': ''},
    pivot_on='sample',
    return_format="long"
)

cleaned_df.head()

Cleaning complete. 108900 variants remaining (from 5047560).


Unnamed: 0,aa_substitutions,sample,count
0,A11D,A1,4.0
1,A11G D42G A75V,A1,0.0
2,A11P,A1,0.0
3,A11P A75V,A1,0.0
4,A11P A75V D78G,A1,0.0


In [21]:
from Bio import SeqIO

from scripts.data_processing.preprocess_variant_data import calculate_functional_scores

wt_amplicon = 'configs/pikh1_amplicon.gb'
wt_seq = str(SeqIO.read(wt_amplicon, "genbank").seq)

func_df = calculate_functional_scores(
    cleaned_df,
    'lib',
    wt_sequence=wt_seq,
)

func_df.head()

Calculating scores grouping by 'aa_substitutions' (type: aa_sub)...


Unnamed: 0,aa_substitutions,count_lib,score_A1,var_A1,count_A1,score_A2,var_A2,count_A2,score_A3,var_A3,...,count_B1,score_B2,var_B2,count_B2,score_B3,var_B3,count_B3,score_B4,var_B4,count_B4
0,A11D,11.0,-1.353637,8.968991,4.0,-2.938599,9.894044,1.0,-1.353637,8.968991,...,3.0,-1.716207,9.101141,3.0,0.0,8.687453,11.0,-0.616671,8.78398,7.0
1,A11G D42G A75V,0.0,0.0,16.650952,0.0,0.0,16.650952,0.0,0.0,16.650952,...,0.0,6.409391,12.537187,42.0,0.0,16.650952,0.0,0.0,16.650952,0.0
2,A11P,5.0,-3.459432,12.866645,0.0,2.423211,8.774462,29.0,-0.652077,9.298583,...,8.0,-0.652077,9.298583,3.0,0.0,9.082337,5.0,-3.459432,12.866645,0.0
3,A11P A75V,0.0,0.0,16.650952,0.0,7.60733,12.509561,97.0,3.459432,12.866645,...,1.0,1.584963,13.875793,1.0,3.169925,12.95074,4.0,0.0,16.650952,0.0
4,A11P A75V D78G,0.0,0.0,16.650952,0.0,5.209453,12.60072,18.0,1.584963,13.875793,...,1.0,2.321928,13.320761,2.0,3.169925,12.95074,4.0,0.0,16.650952,0.0


In [22]:
from scripts.data_processing.prepare_for_modeling import expand_mut_to_seq

func_df = expand_mut_to_seq(
    func_df,
    mut_col="aa_substitutions",
    wt_seq_path="configs/pikh1_aa_seq.fasta"
)

func_df.head()

Unnamed: 0,aa_substitutions,count_lib,score_A1,var_A1,count_A1,score_A2,var_A2,count_A2,score_A3,var_A3,...,score_B2,var_B2,count_B2,score_B3,var_B3,count_B3,score_B4,var_B4,count_B4,full_sequence
0,A11D,11.0,-1.353637,8.968991,4.0,-2.938599,9.894044,1.0,-1.353637,8.968991,...,-1.716207,9.101141,3.0,0.0,8.687453,11.0,-0.616671,8.78398,7.0,GLKQKIVIKVDMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...
1,A11G D42G A75V,0.0,0.0,16.650952,0.0,0.0,16.650952,0.0,0.0,16.650952,...,6.409391,12.537187,42.0,0.0,16.650952,0.0,0.0,16.650952,0.0,GLKQKIVIKVGMEGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV...
2,A11P,5.0,-3.459432,12.866645,0.0,2.423211,8.774462,29.0,-0.652077,9.298583,...,-0.652077,9.298583,3.0,0.0,9.082337,5.0,-3.459432,12.866645,0.0,GLKQKIVIKVPMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...
3,A11P A75V,0.0,0.0,16.650952,0.0,7.60733,12.509561,97.0,3.459432,12.866645,...,1.584963,13.875793,1.0,3.169925,12.95074,4.0,0.0,16.650952,0.0,GLKQKIVIKVPMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...
4,A11P A75V D78G,0.0,0.0,16.650952,0.0,5.209453,12.60072,18.0,1.584963,13.875793,...,2.321928,13.320761,2.0,3.169925,12.95074,4.0,0.0,16.650952,0.0,GLKQKIVIKVPMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...


In [13]:
import numpy as np

from scripts.data_processing.calculate_physics_scores import apply_biophysical_model

model_type = 'linear_factor'
variables = ['fold', 'C', 'F']
sort_names = ['A1', 'A2', 'A3', 'A4', 'B1', 'B2', 'B3', 'B4']
design_matrix = np.array([
    [1, 1, 0], # Sort A1
    [1, 1, 1], # Sort A2
    [1, 4, 1], # Sort A3
    [1, 4, 4], # Sort A4
    [1, 0, 1], # Sort B1
    [1, 1, 1], # Sort B2
    [1, 1, 4], # Sort B3
    [1, 4, 4], # Sort B4
])

func_df = apply_biophysical_model(
    func_df,
    model_type,
    variables,
    sort_names,
    design_matrix
)
func_df.head()

Fitting linear_factor model to 12100 variants...


Unnamed: 0,aa_substitutions,count_lib,score_A1,var_A1,count_A1,score_A2,var_A2,count_A2,score_A3,var_A3,...,var_B4,count_B4,full_sequence,fold,C,F,var_fold,var_C,var_F,R2
0,A11D,11.0,-1.353637,8.968991,4.0,-2.938599,9.894044,1.0,-1.353637,8.968991,...,8.78398,7.0,GLKQKIVIKVDMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,-2.177867,0.062655,0.436059,0.197153,0.032637,0.032419,0.644273
1,A11G D42G A75V,0.0,0.0,16.650952,0.0,0.0,16.650952,0.0,0.0,16.650952,...,16.650952,0.0,GLKQKIVIKVGMEGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV...,2.068939,-0.26696,-0.26696,3.079313,0.550393,0.550393,0.095203
2,A11P,5.0,-3.459432,12.866645,0.0,2.423211,8.774462,29.0,-0.652077,9.298583,...,12.866645,0.0,GLKQKIVIKVPMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,0.764324,-0.663017,-0.153657,1.681117,0.269605,0.282314,0.3297
3,A11P A75V,0.0,0.0,16.650952,0.0,7.60733,12.509561,97.0,3.459432,12.866645,...,16.650952,0.0,GLKQKIVIKVPMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,3.536214,-0.292797,-0.284824,3.541822,0.561113,0.576439,0.093943
4,A11P A75V D78G,0.0,0.0,16.650952,0.0,5.209453,12.60072,18.0,1.584963,13.875793,...,16.650952,0.0,GLKQKIVIKVPMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,2.90233,-0.565707,0.043465,1.436654,0.236202,0.239204,0.241167


In [23]:
from scripts.data_processing.prepare_for_modeling import add_weights

#variant_cols = ['var_fold','var_C','var_F']

func_df = add_weights(
    func_df,
    #value_cols=variant_cols
)

func_df.head()

Adding weights using scores approach (Pure Inverse Variance, no overdispersion)...


Unnamed: 0,aa_substitutions,count_lib,score_A1,var_A1,count_A1,score_A2,var_A2,count_A2,score_A3,var_A3,...,count_B4,full_sequence,weight_A1,weight_A2,weight_A3,weight_A4,weight_B1,weight_B2,weight_B3,weight_B4
0,A11D,11.0,-1.353637,8.968991,4.0,-2.938599,9.894044,1.0,-1.353637,8.968991,...,7.0,GLKQKIVIKVDMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,0.111495,0.101071,0.111495,0.115301,0.109876,0.109876,0.115109,0.113844
1,A11G D42G A75V,0.0,0.0,16.650952,0.0,0.0,16.650952,0.0,0.0,16.650952,...,0.0,GLKQKIVIKVGMEGNNCRSKAMALVASTGGVDSVALVGDLRGKIEV...,0.060057,0.060057,0.060057,0.060057,0.060057,0.079763,0.060057,0.060057
2,A11P,5.0,-3.459432,12.866645,0.0,2.423211,8.774462,29.0,-0.652077,9.298583,...,0.0,GLKQKIVIKVPMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,0.07772,0.113967,0.107543,0.07772,0.111747,0.107543,0.110104,0.07772
3,A11P A75V,0.0,0.0,16.650952,0.0,7.60733,12.509561,97.0,3.459432,12.866645,...,0.0,GLKQKIVIKVPMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,0.060057,0.079939,0.07772,0.060057,0.072068,0.072068,0.077216,0.060057
4,A11P A75V D78G,0.0,0.0,16.650952,0.0,5.209453,12.60072,18.0,1.584963,13.875793,...,0.0,GLKQKIVIKVPMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,0.060057,0.079361,0.072068,0.060057,0.072068,0.075071,0.077216,0.060057


In [24]:
from pathlib import Path
from scripts.data_processing.prepare_for_modeling import train_val_test_split

data_split_dir = Path('outputs/')

train_df, val_df, test_df = train_val_test_split(
    func_df,
    data_split_dir
)
train_df.head()

Unnamed: 0,aa_substitutions,count_lib,score_A1,var_A1,count_A1,score_A2,var_A2,count_A2,score_A3,var_A3,...,count_B4,full_sequence,weight_A1,weight_A2,weight_A3,weight_A4,weight_B1,weight_B2,weight_B3,weight_B4
5224,K5R E45D A75V,0.0,0.0,16.650952,0.0,3.169925,12.95074,4.0,8.194757,12.502421,...,0.0,GLKQRIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIDV...,0.060057,0.077216,0.079985,0.060057,0.060057,0.077216,0.072068,0.060057
5220,K5R E13K G29D,6.0,0.547488,8.864778,9.0,0.691878,8.843912,10.0,-2.115477,10.033266,...,0.0,GLKQRIVIKVAMKGNNCRSKAMALVASTDGVDSVALVGDLRDKIEV...,0.112806,0.113072,0.099668,0.108221,0.099668,0.112806,0.078074,0.078074
11833,Y49C I57T Q74R A75V,0.0,0.0,16.650952,0.0,5.357552,12.589744,20.0,7.467606,12.511732,...,0.0,GLKQKIVIKVAMEGNNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,0.060057,0.07943,0.079925,0.060057,0.060057,0.060057,0.060057,0.060057
8391,N15S K62R A75V,0.0,0.0,16.650952,0.0,2.321928,13.320761,2.0,2.807355,13.082891,...,0.0,GLKQKIVIKVAMEGSNCRSKAMALVASTGGVDSVALVGDLRDKIEV...,0.060057,0.075071,0.076436,0.060057,0.060057,0.075071,0.07964,0.060057
9431,R18G A75V D78G,0.0,1.584963,13.875793,1.0,8.280771,12.501599,155.0,10.072803,12.492079,...,454.0,GLKQKIVIKVAMEGNNCGSKAMALVASTGGVDSVALVGDLRDKIEV...,0.072068,0.07999,0.080051,0.079166,0.072068,0.079986,0.080071,0.080046


In [25]:
from scripts.data_processing.prepare_for_modeling import create_split_datasets

model_checkpoint = "facebook/esm2_t6_8M_UR50D"

wt_aa_file = 'configs/pikh1_aa_seq.fasta'
wt_aa_seq = str(SeqIO.read(wt_aa_file, "fasta").seq)

target_config = [
    {"value_col": "score_A1", "weight_col": "weight_A1"},
    {"value_col": "score_A2", "weight_col": "weight_A2"},
    {"value_col": "score_A3", "weight_col": "weight_A3"},
    {"value_col": "score_A4", "weight_col": "weight_A4"},
    {"value_col": "score_B1", "weight_col": "weight_B1"},
    {"value_col": "score_B2", "weight_col": "weight_B2"},
    {"value_col": "score_B3", "weight_col": "weight_B3"},
    {"value_col": "score_B4", "weight_col": "weight_B4"},
]

train_ds, val_ds, test_ds = create_split_datasets(
    model_checkpoint,
    target_config,
    wt_aa_seq,
    seq_col='full_sequence',
    data_split_dir=data_split_dir,
    mut_col='aa_substitutions'
)

Map: 100%|██████████| 9680/9680 [00:01<00:00, 6254.72 examples/s]
Map: 100%|██████████| 1210/1210 [00:00<00:00, 6148.91 examples/s]
Map: 100%|██████████| 1210/1210 [00:00<00:00, 6303.03 examples/s]


### Modeling

In [None]:
from pathlib import Path

from scripts.modeling.models import EsmMutNHeads
from scripts.modeling.train import initialize_trainer

model_output_dir = Path('outputs/models')

model_hyperparameters = {
    "learning_rate": 3e-5,
    "max_grad_norm": 1.0,
    "head_learning_rate": 5e-4,
    "use_differential_lr": True,
    "per_device_train_batch_size": 32,
    "per_device_eval_batch_size": 32,
    "epochs": 50,
    "weight_decay": 0.01,
    "metric_for_best_model": "eval_loss",
    "greater_is_better": False,
    "es_patience": 5,
}

model = EsmMutNHeads(model_checkpoint, len(target_config), mlp_hidden_dim=160)
trainer = initialize_trainer(model, model_output_dir, target_config, train_ds, val_ds, model_hyperparameters)
trainer.train()

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Using Differential Learning Rates: Backbone=3e-05, Head=0.0005


Epoch,Training Loss,Validation Loss,Pearson Score A1,Spearman Score A1,Pearson Score A2,Spearman Score A2,Pearson Score A3,Spearman Score A3,Pearson Score A4,Spearman Score A4,Pearson Score B1,Spearman Score B1,Pearson Score B2,Spearman Score B2,Pearson Score B3,Spearman Score B3,Pearson Score B4,Spearman Score B4,Avg Spearman
1,No log,0.271469,0.509891,0.464888,0.791466,0.807053,0.822965,0.832348,0.698427,0.75261,0.658966,0.668114,0.685638,0.692217,0.789668,0.812005,0.651466,0.746068,0.721913
2,0.310400,0.256785,0.516729,0.469889,0.812966,0.82423,0.837107,0.845202,0.730294,0.762721,0.675945,0.679824,0.706496,0.711591,0.806107,0.824239,0.68368,0.761166,0.734858
3,0.310400,0.243119,0.544759,0.489772,0.821586,0.836013,0.849693,0.859244,0.745752,0.778028,0.695079,0.7013,0.722576,0.734379,0.816384,0.832529,0.69487,0.76896,0.750028
4,0.251000,0.237694,0.564334,0.503207,0.829759,0.843407,0.84805,0.859069,0.760182,0.784411,0.701348,0.698878,0.741007,0.748483,0.818778,0.833733,0.717628,0.771438,0.755328
5,0.228100,0.228289,0.550904,0.487626,0.82406,0.837811,0.857925,0.865816,0.762069,0.784357,0.695659,0.703451,0.750042,0.761165,0.829938,0.84316,0.72311,0.78983,0.759152
6,0.228100,0.226683,0.562867,0.513766,0.827937,0.841888,0.863343,0.87042,0.759722,0.774545,0.709508,0.706728,0.752449,0.763646,0.831366,0.846056,0.716455,0.765459,0.760314
7,0.208700,0.228997,0.56169,0.506938,0.829057,0.844568,0.860748,0.869068,0.766429,0.787367,0.697576,0.695728,0.736724,0.748371,0.827815,0.843018,0.720061,0.77636,0.758927
8,0.208700,0.226033,0.561508,0.517658,0.832506,0.84592,0.862079,0.868801,0.776991,0.791093,0.70992,0.700966,0.762973,0.768486,0.835784,0.844715,0.728746,0.784009,0.765206
9,0.196500,0.226498,0.559709,0.504743,0.828743,0.845652,0.865298,0.870818,0.777611,0.790406,0.70185,0.696991,0.760725,0.763067,0.832912,0.844138,0.727327,0.777892,0.761714
10,0.177400,0.235308,0.565579,0.519544,0.821722,0.836386,0.860988,0.867433,0.772129,0.781501,0.695592,0.689289,0.756549,0.760022,0.826786,0.83523,0.718494,0.760191,0.7562


TrainOutput(global_step=3939, training_loss=0.21050683142842905, metrics={'train_runtime': 120.0741, 'train_samples_per_second': 4030.845, 'train_steps_per_second': 126.172, 'total_flos': 0.0, 'train_loss': 0.21050683142842905, 'epoch': 13.0})

In [None]:
from pathlib import Path

from scripts.modeling.models import PhysicsESM
from scripts.modeling.train import initialize_trainer

model_output_dir = Path('outputs/models')

model_hyperparameters = {
    "learning_rate": 3e-5,
    "max_grad_norm": 1.0,
    "head_learning_rate": 5e-4,
    "use_differential_lr": True,
    "per_device_train_batch_size": 32,
    "per_device_eval_batch_size": 32,
    "epochs": 50,
    "weight_decay": 0.01,
    "metric_for_best_model": "eval_loss",
    "greater_is_better": False,
    "es_patience": 5,
}

uM, nM_100 = 1e-6, 100e-9
sort_config = [
        # --- Trajectory A ---
        # A1 (Index 0): Independent
        {"ligands": {"C": uM, "F": 0}, "masks": {"C": 1, "F": 0}, 
         "history_indices": [0]},
        
        # A2 (Index 1): Depends on A1 + A2
        {"ligands": {"C": uM, "F": uM}, "masks": {"C": 1, "F": 1}, 
         "history_indices": [0, 1]},
        
        # A3 (Index 2): Depends on A1 + A2 + A3
        {"ligands": {"C": nM_100, "F": uM}, "masks": {"C": 1, "F": 1}, 
         "history_indices": [0, 1, 2]},
        
        # A4 (Index 3): Depends on A1 + A2 + A3 + A4
        {"ligands": {"C": nM_100, "F": nM_100}, "masks": {"C": 1, "F": 1}, 
         "history_indices": [0, 1, 2, 3]},

        # --- Trajectory B ---
        # B1 (Index 4): Independent
        {"ligands": {"C": 0, "F": uM}, "masks": {"C": 0, "F": 1}, 
         "history_indices": [4]},
        
        # B2 (Index 5): Depends on B1 + B2
        {"ligands": {"C": uM, "F": uM}, "masks": {"C": 1, "F": 1}, 
         "history_indices": [4, 5]},
        
        # B3 (Index 6): Depends on B1 + B2 + B3
        {"ligands": {"C": uM, "F": nM_100}, "masks": {"C": 1, "F": 1}, 
         "history_indices": [4, 5, 6]},
        
        # B4 (Index 7): Depends on B1 + B2 + B3 + B4
        {"ligands": {"C": nM_100, "F": nM_100}, "masks": {"C": 1, "F": 1}, 
         "history_indices": [4, 5, 6, 7]},
    ]

ligand_names = ['C','F']
decoder_type = 'hill_langmuir'
embedding_mode = 'mutant_mean'

RT = 0.593
wt_dG_fold = -4.83
wt_dG_CF = {
    'C': RT * np.log(200e-9), 
    'F': RT * np.log(30e-6)
}

learn_anchor = True
learn_physics = True

model = PhysicsESM(
    model_checkpoint=model_checkpoint,
    sort_configs=sort_config,
    ligand_names=ligand_names,
    decoder_type=decoder_type,
    embedding_mode=embedding_mode,
    wt_dGs=wt_dG_CF,
    wt_dG_fold=wt_dG_fold,
    learn_anchor=learn_anchor,
    learn_physics=learn_physics,
    mlp_hidden_dim=160
)
trainer = initialize_trainer(model, model_output_dir, target_config, train_ds, val_ds, model_hyperparameters)
trainer.train()

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Using Differential Learning Rates: Backbone=3e-05, Head=0.0005


Epoch,Training Loss,Validation Loss,Pearson Score A1,Spearman Score A1,Pearson Score A2,Spearman Score A2,Pearson Score A3,Spearman Score A3,Pearson Score A4,Spearman Score A4,Pearson Score B1,Spearman Score B1,Pearson Score B2,Spearman Score B2,Pearson Score B3,Spearman Score B3,Pearson Score B4,Spearman Score B4,Avg Spearman
1,No log,0.536687,0.240748,0.145169,0.706354,0.741805,0.746213,0.791329,0.610946,0.71118,0.56379,0.511432,0.568788,0.555481,0.697507,0.780915,0.586586,0.716624,0.619242
2,0.604700,0.496801,0.315932,0.156783,0.694239,0.746084,0.762458,0.797179,0.638925,0.723404,0.557737,0.5032,0.576716,0.539339,0.699738,0.760396,0.586427,0.70105,0.615929
3,0.604700,0.486756,0.342008,0.177632,0.628144,0.736687,0.682661,0.794047,0.577699,0.712286,0.551642,0.529801,0.567077,0.553476,0.643671,0.748244,0.551245,0.694467,0.61833
4,0.509000,0.448993,0.409338,0.205078,0.697003,0.747667,0.740505,0.817735,0.642374,0.748452,0.580838,0.516124,0.614518,0.573181,0.692727,0.767349,0.602775,0.72732,0.637863
5,0.467900,0.433877,0.408899,0.19012,0.733524,0.736645,0.775727,0.816942,0.663942,0.742605,0.580489,0.505608,0.615663,0.57022,0.726686,0.774645,0.614267,0.725802,0.632823
6,0.467900,0.42103,0.406034,0.20102,0.731313,0.736199,0.773712,0.81111,0.667112,0.742287,0.581859,0.512542,0.619451,0.568138,0.728115,0.777863,0.615638,0.724176,0.634167
7,0.435400,0.407131,0.41917,0.20263,0.760134,0.745265,0.802384,0.822495,0.692779,0.746883,0.583333,0.513712,0.631569,0.581935,0.743551,0.787459,0.63151,0.734675,0.641882


KeyboardInterrupt: 

In [34]:
test_results = trainer.predict(test_ds)
test_results.metrics

{'test_loss': 0.22626453638076782,
 'test_pearson_score_A1': 0.5616424083709717,
 'test_spearman_score_A1': 0.5221513721147265,
 'test_pearson_score_A2': 0.8394550681114197,
 'test_spearman_score_A2': 0.8545005884913137,
 'test_pearson_score_A3': 0.866632342338562,
 'test_spearman_score_A3': 0.8735465468784568,
 'test_pearson_score_A4': 0.777446448802948,
 'test_spearman_score_A4': 0.8051719168422578,
 'test_pearson_score_B1': 0.7359922528266907,
 'test_spearman_score_B1': 0.7194807334622257,
 'test_pearson_score_B2': 0.7752083539962769,
 'test_spearman_score_B2': 0.7726563020965223,
 'test_pearson_score_B3': 0.8199391961097717,
 'test_spearman_score_B3': 0.8382110243710251,
 'test_pearson_score_B4': 0.7004178166389465,
 'test_spearman_score_B4': 0.7612591452415339,
 'test_avg_spearman': 0.7683722036872578,
 'test_runtime': 0.6352,
 'test_samples_per_second': 1904.892,
 'test_steps_per_second': 59.823}

In [1]:
import torch
print(f"Is CUDA available? {torch.cuda.is_available()}")

Is CUDA available? True


## Scratch Pad

In [None]:
from pathlib import Path

from scripts.data_processing.align_and_parse import align_and_parse

aligned_df = align_and_parse(
    amplicon_file="configs/pikh1_amplicon.gb",
    alignparse_config="configs/alignparse_config.yaml",
    locus_name="pikh1_amplicon",
    merged_reads_fastq="/home/oscar/rethinking_de_sequence_analysis/enrichment_analysis/merged_all/A1_merged.fastq",
    alignment_path="A1_aligned.sam"
)

aligned_df

Unnamed: 0,query_name,query_clip5,query_clip3,gene_mutations,gene_accuracy,n_subs,n_indels
0,A01940:336:GW231120000:1:2101:12183:1016,0,0,A52G C83T T97C T174C A203G,0.999470,5,0
1,A01940:336:GW231120000:1:2101:23285:1016,0,0,T145C,0.999870,1,0
2,A01940:336:GW231120000:1:2101:31548:1016,0,0,T4C G66A T217C,0.998985,3,0
3,A01940:336:GW231120000:1:2101:3351:1031,0,0,A64G A95G T96C,0.999834,3,0
4,A01940:336:GW231120000:1:2101:13024:1031,0,0,T33C T74C T102C C104T T110C T147C T161C,0.999870,7,0
...,...,...,...,...,...,...,...
343189,A01940:336:GW231120000:1:2278:17662:36996,0,0,C80T A203G T208C,0.999867,3,0
343190,A01940:336:GW231120000:1:2278:5719:37012,0,0,A54G T119C A134G A160G A203G T209G,0.999518,6,0
343191,A01940:336:GW231120000:1:2278:9010:37012,0,0,T49C T111C T150C T217C,0.998826,4,0
343192,A01940:336:GW231120000:1:2278:24849:37043,0,0,A25G A125G T217C T234C,0.999529,4,0


In [1]:
from pathlib import Path

from scripts.data_processing.align_and_parse import construct_master_alignment_df

reads_df = construct_master_alignment_df(
    amplicon_file="configs/pikh1_amplicon.gb",
    alignparse_config="configs/alignparse_config.yaml",
    locus_name="pikh1_amplicon",
    merged_reads_dir=Path("/home/oscar/rethinking_de_sequence_analysis/enrichment_analysis/merged_all")
)


Processing A2...
    Aligning reads from /home/oscar/rethinking_de_sequence_analysis/enrichment_analysis/merged_all/A2_merged.fastq to configs/pikh1_amplicon.gb...
    Parsing alignment from /tmp/tmpx7hcehrk.sam...
    Adding mutation info columns for locus pikh1_amplicon...
    Expanding mutations to full sequences for A2...

Processing lib...
    Aligning reads from /home/oscar/rethinking_de_sequence_analysis/enrichment_analysis/merged_all/lib_merged.fastq to configs/pikh1_amplicon.gb...
    Parsing alignment from /tmp/tmpdxb694d0.sam...
    Adding mutation info columns for locus pikh1_amplicon...
    Expanding mutations to full sequences for lib...

Processing A4...
    Aligning reads from /home/oscar/rethinking_de_sequence_analysis/enrichment_analysis/merged_all/A4_merged.fastq to configs/pikh1_amplicon.gb...
    Parsing alignment from /tmp/tmppn6pkqtx.sam...
    Adding mutation info columns for locus pikh1_amplicon...
    Expanding mutations to full sequences for A4...

Processin

In [2]:
reads_df

Unnamed: 0,library,sample,sequence,gene_mutations
0,library,A2,GTTCCAGACTACGCGCTGCAGGCTAGTGGTGGAGGAGGCTCCGGTG...,A15G T42C A47G T78C T107C A108G T209C
1,library,A2,GTTCCAGACTACGCGCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,A15G A34G A47G C48T T57C T70C T90C C224T
2,library,A2,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,T195C T209C C224T
3,library,A2,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,T117C A163G C200T T209C
4,library,A2,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,A54G T81C T205C T209C T216C A233G
...,...,...,...,...
3248911,library,B3,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,T178C T209C A221G C224T
3248912,library,B3,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,A82G C176T
3248913,library,B3,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,A95G T106G T192C
3248914,library,B3,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,A7G T102C A125G T137C C224T


In [5]:
reads_df.to_parquet("outputs/master_reads_df.parquet")

In [1]:
import pandas as pd
reads_df = pd.read_parquet("outputs/master_reads_df.parquet")

In [3]:
from scripts.data_processing.construct_variant_table import construct_variant_table

variant_table = construct_variant_table(
    aligned_reads_df=reads_df,
    wt_dna_file='configs/pikh1_dna_seq.fasta'
)

variant_table

Constructing variant table...


<dms_variants.codonvarianttable.CodonVariantTable at 0x743b72fff3e0>

In [4]:
variant_table.barcode_variant_df

Unnamed: 0,library,barcode,variant_call_support,codon_substitutions,aa_substitutions,n_codon_substitutions,n_aa_substitutions
0,library,AATCCAGACTACGCTCCGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,1,GGT1AAT ATC6ACC TCT73TCC,G1N I6T,3,2
1,library,AATCCAGACTACGCTCTGCAGGCTAGGGGTGGAGCAGGCTCTGGTG...,1,GGT1AAT AAA9AAG ATG12ACG TTG24TTA TAT49TGT,G1N M12T Y49C,5,3
2,library,AATCCAGACTACGCTCTGCAGGCTAGTGCTGGAGGAGGCTCTGGTG...,1,GGT1AAT GTT10GCT TCT19CCT GAT78GGT,G1N V10A S19P D78G,4,4
3,library,AATCCAGACTACGCTCTGCAGGCTAGTGGTGGAGAAGGCTCTGGTG...,1,GGT1AAT ATG12AAG TCT27TCC ACT28GCT GAT42GGT AT...,G1N M12K T28A D42G Q74R D78G,9,6
4,library,AATCCAGACTACGCTCTGCAGGCTAGTGGTGGAGAAGGCTCTGGTG...,1,GGT1AAT ATG12AAG TTG56CTG AAG62GAG GCT75GTT AA...,G1N M12K K62E A75V K77R,6,5
...,...,...,...,...,...,...,...
941057,library,TTTCCAGGCTGCGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,1,GGT1TGT AAG3AGG CAA4CGA GAA13GGA AGA18AAA ATG2...,G1C K3R Q4R E13G R18K M22T Y49H,7,7
941058,library,TTTCCAGGCTGCGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,1,GGT1TGT AAG3AGG CAA4CGA AAC16AGC GGT50GGC GAC6...,G1C K3R Q4R N16S D66G L70S,8,6
941059,library,TTTCCAGGCTGCGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,1,GGT1TGT AAG3AGG CAA4CGA AAC16AGC ATT57GTT AAA6...,G1C K3R Q4R N16S I57V K63R E68G,7,7
941060,library,TTTCCAGGCTGCGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,1,GGT1TGT AAG3AGG CAA4CGA AAC16AGC GAC52GGC GAT7...,G1C K3R Q4R N16S D52G D78N,6,6


In [6]:
variant_table.variant_count_df.to_parquet("outputs/variant_count_df.parquet")

In [7]:
variant_table.variant_count_df

Unnamed: 0,library,sample,barcode,count,variant_call_support,codon_substitutions,aa_substitutions,n_codon_substitutions,n_aa_substitutions
0,library,A1,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,12924,1,,,0,0
1,library,A1,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,3732,1,GAA13GGA GAT32GAC GAT42GGT,E13G D42G,3,2
2,library,A1,GTTCCAGACTGCGTTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,1667,1,CAA4CGA AAA5ATA ATG12AGG GCT26GCC GCT35GTT TAT...,Q4R K5I M12R A35V,6,4
3,library,A1,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,1272,1,GAA68GGA,E68G,1,1
4,library,A1,GTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,1076,1,GAA45GGA,E45G,1,1
...,...,...,...,...,...,...,...,...,...
1129083,library,lib,TTTCCAGACTACGGTCTGCAGGCTAGGGGTGGAGGAGGCTCTGGTG...,1,1,GGT1TGT AAA5AGA AAA9AAG GCT21GCC GAT39AGT ATT5...,G1C K5R D39S I57V K77E,7,5
1129084,library,lib,TTTCCAGACTACGGTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,1,1,GGT1TGT AAG3GAG AAA5AGA ATC51ACC GAC52GGC,G1C K3E K5R I51T D52G,5,5
1129085,library,lib,TTTCCAGACTACGGTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,1,1,GGT1TGT AAA5AGA,G1C K5R,2,2
1129086,library,lib,TTTCCAGACTAGGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,1,1,GGT1TGT CAA4CAG GTT25ATT GAA68GGA TCT73CCT,G1C V25I E68G S73P,5,4


In [1]:
import pandas as pd

variant_count_df = pd.read_parquet("outputs/variant_count_df.parquet")

In [2]:
from scripts.data_processing.process_variant_data import get_simple_counts_df

simple_count_df = get_simple_counts_df(
    variant_count_df,
    'barcode',
    cols_to_keep=['aa_substitutions']
)

simple_count_df

Unnamed: 0,barcode,aa_substitutions,sample,count
0,AATCCAGACTACGCTCCGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1N I6T,A1,0.0
1,AATCCAGACTACGCTCTGCAGGCTAGGGGTGGAGCAGGCTCTGGTG...,G1N M12T Y49C,A1,1.0
2,AATCCAGACTACGCTCTGCAGGCTAGTGCTGGAGGAGGCTCTGGTG...,G1N V10A S19P D78G,A1,0.0
3,AATCCAGACTACGCTCTGCAGGCTAGTGGTGGAGAAGGCTCTGGTG...,G1N M12K T28A D42G Q74R D78G,A1,0.0
4,AATCCAGACTACGCTCTGCAGGCTAGTGGTGGAGAAGGCTCTGGTG...,G1N M12K K62E A75V K77R,A1,0.0
...,...,...,...,...
8469553,TTTCCAGGCTGCGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1C K3R Q4R E13G R18K M22T Y49H,lib,0.0
8469554,TTTCCAGGCTGCGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1C K3R Q4R N16S D66G L70S,lib,0.0
8469555,TTTCCAGGCTGCGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1C K3R Q4R N16S I57V K63R E68G,lib,0.0
8469556,TTTCCAGGCTGCGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1C K3R Q4R N16S D52G D78N,lib,0.0


In [3]:
simple_count_df.rename(columns={"barcode": "dna_sequence"}, inplace=True)

In [9]:
from Bio import SeqIO

from scripts.data_processing.process_variant_data import calculate_functional_scores

wt_amplicon = 'configs/pikh1_amplicon.gb'
wt_seq = str(SeqIO.read(wt_amplicon, "genbank").seq)

func_df = calculate_functional_scores(
    simple_count_df,
    'lib',
    group_col='dna_sequence',
    sequence_type='dna',
    wt_sequence=wt_seq,
    cols_to_keep=['aa_substitutions'],
)

func_df

Calculating scores grouping by 'dna_sequence' (type: dna)...


Unnamed: 0,dna_sequence,aa_substitutions,count_lib,score_A1,var_A1,count_A1,score_A2,var_A2,count_A2,score_A3,...,count_B1,score_B2,var_B2,count_B2,score_B3,var_B3,count_B3,score_B4,var_B4,count_B4
0,AATCCAGACTACGCTCCGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1N I6T,0.0,1.391070,8.325698,0.0,1.722111,8.325740,0.0,1.722672,...,0.0,1.669581,8.325733,0.0,1.148732,8.325673,0.0,1.403739,8.325700,0.0
1,AATCCAGACTACGCTCTGCAGGCTAGGGGTGGAGCAGGCTCTGGTG...,G1N M12T Y49C,0.0,2.976033,5.550540,1.0,1.722111,8.325740,0.0,1.722672,...,0.0,1.669581,8.325733,0.0,1.148732,8.325673,0.0,1.403739,8.325700,0.0
2,AATCCAGACTACGCTCTGCAGGCTAGTGCTGGAGGAGGCTCTGGTG...,G1N V10A S19P D78G,1.0,-0.193892,5.550540,0.0,0.137148,5.550581,0.0,0.137710,...,0.0,0.084618,5.550574,0.0,-0.436230,5.550515,0.0,-0.181223,5.550541,0.0
3,AATCCAGACTACGCTCTGCAGGCTAGTGGTGGAGAAGGCTCTGGTG...,G1N M12K T28A D42G Q74R D78G,0.0,1.391070,8.325698,0.0,1.722111,8.325740,0.0,1.722672,...,0.0,3.254543,5.550574,1.0,1.148732,8.325673,0.0,1.403739,8.325700,0.0
4,AATCCAGACTACGCTCTGCAGGCTAGTGGTGGAGAAGGCTCTGGTG...,G1N M12K K62E A75V K77R,0.0,1.391070,8.325698,0.0,1.722111,8.325740,0.0,1.722672,...,0.0,3.254543,5.550574,1.0,1.148732,8.325673,0.0,1.403739,8.325700,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
941057,TTTCCAGGCTGCGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1C K3R Q4R E13G R18K M22T Y49H,0.0,2.976033,5.550540,1.0,1.722111,8.325740,0.0,1.722672,...,0.0,1.669581,8.325733,0.0,1.148732,8.325673,0.0,1.403739,8.325700,0.0
941058,TTTCCAGGCTGCGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1C K3R Q4R N16S D66G L70S,0.0,2.976033,5.550540,1.0,1.722111,8.325740,0.0,1.722672,...,0.0,1.669581,8.325733,0.0,1.148732,8.325673,0.0,1.403739,8.325700,0.0
941059,TTTCCAGGCTGCGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1C K3R Q4R N16S I57V K63R E68G,0.0,2.976033,5.550540,1.0,1.722111,8.325740,0.0,1.722672,...,0.0,1.669581,8.325733,0.0,1.148732,8.325673,0.0,1.403739,8.325700,0.0
941060,TTTCCAGGCTGCGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1C K3R Q4R N16S D52G D78N,0.0,2.976033,5.550540,1.0,1.722111,8.325740,0.0,1.722672,...,0.0,1.669581,8.325733,0.0,1.148732,8.325673,0.0,1.403739,8.325700,0.0


In [10]:
from scripts.data_processing.process_variant_data import clean_variant_data

cleaned_df = clean_variant_data(
    func_df,
    count_columns=[c for c in func_df.columns if "count" in c],
    read_count_threshold=3,
    variants_to_remove={'aa_substitutions': ''},
)

cleaned_df

Cleaning complete. 111943 variants remaining (from 941062).


Unnamed: 0,dna_sequence,aa_substitutions,count_lib,score_A1,var_A1,count_A1,score_A2,var_A2,count_A2,score_A3,...,count_B1,score_B2,var_B2,count_B2,score_B3,var_B3,count_B3,score_B4,var_B4,count_B4
10,AATCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1N M22V,0.0,5.091510,4.483171,6.0,1.722111,8.325740,0.0,1.722672,...,0.0,1.669581,8.325733,0.0,1.148732,8.325673,0.0,1.403739,8.325700,0.0
55,ATCCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1S N16S E68G Q71R,3.0,-1.416285,4.757637,0.0,-1.085244,4.757679,0.0,-1.084682,...,0.0,-1.137774,4.757672,0.0,-1.658623,4.757612,0.0,-1.403615,4.757639,0.0
68,ATCCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1S,2.0,0.654105,2.220349,1.0,-0.599817,4.995550,0.0,1.722672,...,0.0,-0.652348,4.995542,0.0,0.411766,2.220324,1.0,-0.918189,4.995509,0.0
75,ATCCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1S D42G,2.0,-0.930858,4.995508,0.0,0.985145,2.220391,1.0,-0.599256,...,0.0,-0.652348,4.995542,0.0,-1.173196,4.995483,0.0,-0.918189,4.995509,0.0
109,ATCCCAGACTGCGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1S Q4R V37A Y49H,0.0,1.391070,8.325698,0.0,1.722111,8.325740,0.0,1.722672,...,0.0,4.839506,4.625521,4.0,1.148732,8.325673,0.0,1.403739,8.325700,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
940947,TTTCCAGACTACGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1C D39G V47I K77R,0.0,4.198425,4.757637,3.0,1.722111,8.325740,0.0,1.722672,...,0.0,1.669581,8.325733,0.0,1.148732,8.325673,0.0,1.403739,8.325700,0.0
941018,TTTCCAGACTACGCTGTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1C I6V A75V,0.0,1.391070,8.325698,0.0,4.892036,4.625528,4.0,1.722672,...,0.0,1.669581,8.325733,0.0,1.148732,8.325673,0.0,1.403739,8.325700,0.0
941036,TTTCCAGACTAGGCTCTGCAGGCTAGTGGTGGCGGAGGCTCTGGTG...,G1C E68D,3.0,-1.416285,4.757637,0.0,-1.085244,4.757679,0.0,-1.084682,...,0.0,-1.137774,4.757672,0.0,-1.658623,4.757612,0.0,-1.403615,4.757639,0.0
941042,TTTCCAGACTGCGCTCTGCAGGCTAGTGGTGGAGGAGGCTCTGGTG...,G1C Q4R A75V,0.0,4.560995,4.625487,4.0,1.722111,8.325740,0.0,1.722672,...,0.0,1.669581,8.325733,0.0,1.148732,8.325673,0.0,1.403739,8.325700,0.0


In [14]:
import pandas as pd
import numpy as np

def calculate_pooled_overdispersion(
    df: pd.DataFrame,
    group_by: list | str,
    data_type: str = 'scores',
    value_col: str | None = None,
    min_replicates: int = 2
) -> pd.DataFrame:
    """
    Calculates global overdispersion (experimental noise) by pooling variance across replicates.
    
    Args:
        df (pd.DataFrame): Input dataframe.
        group_by (list | str): Columns defining the groups. 
                               - Wide: ['aa_substitutions']
                               - Long: ['aa_substitutions', 'sample'] (Sample must be last!)
        data_type (str): 'scores' (Excess Variance) or 'counts' (Dispersion Index).
        value_col (str, optional): The column with data (Long format). If None, assumes Wide format.
        min_replicates (int): Minimum size of a group to include.

    Returns:
        pd.DataFrame: Summary of noise metrics per sample.
    """
    # 1. Setup
    if isinstance(group_by, str): group_by = [group_by]
    
    # Determine Format Mode
    is_long = value_col is not None
    
    # Select columns to aggregate
    if not is_long:
        prefix = 'score_' if data_type == 'scores' else 'count_'
        target_cols = [c for c in df.columns if c.startswith(prefix)]
    else:
        target_cols = [value_col]

    print(f"Calculating overdispersion for {data_type}...")

    # 2. Aggregation (The Heavy Lifting)
    # Calculate Variance, Mean, and N for every group in one vectorized pass
    # observed=True is critical for memory safety with Categoricals
    stats = df.groupby(group_by, observed=True)[target_cols].agg(['var', 'mean', 'count'])

    # 3. Process Results
    results = []

    # --- WIDE FORMAT LOGIC ---
    # The sample name is encoded in the column name (e.g. 'score_Library')
    if not is_long:
        for col in target_cols:
            sub_stats = stats[col]
            valid = sub_stats[sub_stats['count'] >= min_replicates]
            
            # Calculate and append
            sample_name = col.replace(prefix, '')
            _calc_metrics_and_append(results, sample_name, valid, data_type)

    # --- LONG FORMAT LOGIC ---
    # The sample name is encoded in the LAST level of the group_by index
    else:
        # Extract the relevant stats for the value column
        sub_stats = stats[value_col]
        valid = sub_stats[sub_stats['count'] >= min_replicates]
        
        # Calculate the raw metric for every variant group
        if data_type == 'scores':
            raw_metrics = valid['var']
        else:
            # Dispersion Index = Var / Mean
            with np.errstate(divide='ignore', invalid='ignore'):
                raw_metrics = valid['var'] / valid['mean']
            raw_metrics = raw_metrics[np.isfinite(raw_metrics)]

        # Now Aggregate BY SAMPLE (The last level of the index)
        # We assume group_by = [Variant, Sample], so level=-1 is Sample
        sample_level = group_by[-1]
        
        # Group the metrics by sample and take the average
        summary = raw_metrics.groupby(level=sample_level).agg(['mean', 'median', 'count'])
        
        for sample_name, row in summary.iterrows():
            res = {'sample': sample_name, 'n_variants_used': int(row['count'])}
            
            if data_type == 'scores':
                res['pooled_variance_mean'] = row['mean']
                res['pooled_variance_median'] = row['median']
            else:
                res['dispersion_index_mean'] = row['mean']
                res['dispersion_index_median'] = row['median']
            results.append(res)

    return pd.DataFrame(results)

def _calc_metrics_and_append(results, sample, df, data_type):
    """Helper to calculate final summary stats from valid groups."""
    res = {'sample': sample, 'n_variants_used': len(df)}
    
    if data_type == 'scores':
        res['pooled_variance_mean'] = df['var'].mean()
        res['pooled_variance_median'] = df['var'].median()
    else:
        # Dispersion Index logic
        with np.errstate(divide='ignore', invalid='ignore'):
            d_index = df['var'] / df['mean']
        d_index = d_index[np.isfinite(d_index)]
        
        res['dispersion_index_mean'] = d_index.mean()
        res['dispersion_index_median'] = d_index.median()
        
    results.append(res)

In [16]:
overdis_df = calculate_pooled_overdispersion(
    cleaned_df,
    group_by='aa_substitutions',
    min_replicates=3
)

overdis_df

Calculating overdispersion for scores...


Unnamed: 0,sample,n_variants_used,pooled_variance_mean,pooled_variance_median
0,A1,4499,1.552069,0.837369
1,A2,4499,2.062414,1.576248
2,A3,4499,2.402124,1.82089
3,A4,4499,1.427247,0.837369
4,B1,4499,1.420389,0.94092
5,B2,4499,1.981428,1.477327
6,B3,4499,2.139607,1.603905
7,B4,4499,1.606905,0.837369


In [2]:
import pandas as pd

from scripts.data_processing.preprocess_variant_data import get_simple_counts_df

df = pd.read_parquet("outputs/variant_count_df.parquet")

df = get_simple_counts_df(
    df,
    'aa_substitutions',
)
df

Unnamed: 0,aa_substitutions,sample,count
0,,A1,25576.0
1,A11D,A1,4.0
2,A11D A23S A75V K77R D78G,A1,0.0
3,A11D A26V I54T D66G Q74R A75V D78G,A1,0.0
4,A11D A35T G65S D66N Q71P,A1,1.0
...,...,...,...
5047555,Y49W I57T A75V D78G,lib,0.0
5047556,Y49W I57T D66G,lib,0.0
5047557,Y49W K63E D78E,lib,0.0
5047558,Y49W L60S S73P Q74R,lib,0.0


In [3]:
from scripts.data_processing.preprocess_variant_data import clean_variant_data

df = clean_variant_data(
    df,
    ['count'],
    9,
    pivot_on='sample',
    return_format='long'
)

df

Cleaning complete. 223389 variants remaining (from 5047560).


Unnamed: 0,aa_substitutions,sample,count
0,,A1,25576.0
1,A11D,A1,4.0
2,A11D I54T D66G Q74R A75V D78G,A1,0.0
3,A11D I54T E68K Q74R A75V K77E D78G,A1,0.0
4,A11D L70S A75V,A1,0.0
...,...,...,...
223384,Y49S L70S K77R,lib,0.0
223385,Y49S L70S Q74L A75V,lib,0.0
223386,Y49S L70S Q74R,lib,0.0
223387,Y49S L70S Q74R A75V,lib,0.0


In [4]:
from scripts.data_processing.preprocess_variant_data import calculate_functional_scores

func_df = calculate_functional_scores(
    df,
    'lib',
)

func_df

Calculating scores grouping by 'aa_substitutions' (type: aa_sub)...


Unnamed: 0,aa_substitutions,count_lib,score_A1,var_A1,count_A1,score_A2,var_A2,count_A2,score_A3,var_A3,...,count_B1,score_B2,var_B2,count_B2,score_B3,var_B3,count_B3,score_B4,var_B4,count_B4
0,,62613.0,0.000000,0.000229,25576.0,0.000000,0.000300,17793.0,0.000000,0.000308,...,25034.0,0.000000,0.000284,19149.0,0.000000,0.000228,25811.0,0.000000,0.000281,19381.0
1,A11D,11.0,-0.061982,0.643630,4.0,-1.123476,1.568718,1.0,0.506199,0.643669,...,3.0,-0.007040,0.775807,3.0,1.278460,0.362091,11.0,1.075122,0.458645,7.0
2,A11D I54T D66G Q74R A75V D78G,0.0,1.291655,8.325591,0.0,1.815123,8.325626,0.0,1.859836,8.325630,...,0.0,1.709167,8.325618,0.0,1.278460,8.325590,0.0,6.736188,4.289022,16.0
3,A11D I54T E68K Q74R A75V K77E D78G,0.0,1.291655,8.325591,0.0,1.815123,8.325626,0.0,1.859836,8.325630,...,0.0,1.709167,8.325618,0.0,1.278460,8.325590,0.0,6.901247,4.275385,18.0
4,A11D L70S A75V,0.0,1.291655,8.325591,0.0,3.400086,5.550467,1.0,3.444798,5.550471,...,0.0,1.709167,8.325618,0.0,4.085815,4.757529,3.0,4.499148,4.757555,3.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24816,Y49S L70S K77R,0.0,1.291655,8.325591,0.0,1.815123,8.325626,0.0,3.444798,5.550471,...,3.0,4.516522,4.757557,3.0,5.922316,4.329361,12.0,1.691794,8.325617,0.0
24817,Y49S L70S Q74L A75V,0.0,1.291655,8.325591,0.0,4.622478,4.757565,3.0,5.319267,4.541323,...,0.0,3.294130,5.550459,1.0,9.798096,4.174194,183.0,12.156339,4.165825,706.0
24818,Y49S L70S Q74R,0.0,1.291655,8.325591,0.0,8.891939,4.193723,67.0,7.837115,4.228967,...,4.0,4.031095,4.995427,2.0,4.085815,4.757529,3.0,4.013722,4.995426,2.0
24819,Y49S L70S Q74R A75V,0.0,1.291655,8.325591,0.0,1.815123,8.325626,0.0,5.319267,4.541323,...,0.0,1.709167,8.325618,0.0,2.863422,5.550431,1.0,1.691794,8.325617,0.0


In [5]:
import numpy as np
import pandas as pd

def solve_linear_factor_model(
        variables: list,
        targets: list[float],
        weights: list[float],
        coefficients: np.ndarray,
) -> pd.Series:
    # Define output schema
    # [variable_1, variable_2, ...], [variance_1, variance 2, ...], [R2]
    index_names = variables + [f'var_{v}' for v in variables] + ['R2']

    # Rename variables to match linear algebra
    X = coefficients
    y = np.array(targets)
    w = np.array(weights)

    # Handle missing data
    mask = np.isfinite(y) & np.isfinite(w)
    if mask.sum() < len(variables):
        raise RuntimeError("Not enough data points to solve. Ensure missing data is filtered out.")
    
    # Weighted normal equation
    W = np.diag(w)
    XtWX = X.T @ W @ X
    XtWy = X.T @ W @ y

    # Solve for beta
    XtWX_inv = np.linalg.pinv(XtWX)
    beta = XtWX_inv @ XtWy

    # Calculate goodness of fit
    y_pred = X @ beta
    y_mean = np.average(y, weights=w)
    ss_tot = np.sum(w * (y - y_mean)**2)
    ss_res = np.sum(w * (y - y_pred)**2)
    r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0

    # Calculate variance beta
    dof = len(y) - len(variables)
    sigma_sq = ss_res / dof
    beta_vars = np.diag(XtWX_inv * sigma_sq)

    # Concatenate for reporting
    results = np.concatenate([beta, beta_vars, [r2]])

    return pd.Series(results, index=index_names)

In [17]:
highest_mutant = func_df.sort_values(by='count_lib', ascending=False).iloc[0]
highest_mutant

aa_substitutions            
count_lib            62613.0
score_A1                 0.0
var_A1              0.000229
count_A1             25576.0
score_A2                 0.0
var_A2                0.0003
count_A2             17793.0
score_A3                 0.0
var_A3              0.000308
count_A3             17250.0
score_A4                 0.0
var_A4              0.000391
count_A4             12831.0
score_B1                 0.0
var_B1              0.000233
count_B1             25034.0
score_B2                 0.0
var_B2              0.000284
count_B2             19149.0
score_B3                 0.0
var_B3              0.000228
count_B3             25811.0
score_B4                 0.0
var_B4              0.000281
count_B4             19381.0
Name: 0, dtype: object

In [5]:
[c for c in func_df.columns if 'score' in c]

['score_A1',
 'score_A2',
 'score_A3',
 'score_A4',
 'score_B1',
 'score_B2',
 'score_B3',
 'score_B4']

In [102]:
highest_mutant = func_df.sort_values(by='count_lib', ascending=False).iloc[10]
variables = ['stab', 'bindC', 'bindF']
targets = [highest_mutant[f'{c}'] for c in func_df.columns if 'score' in c]
weights = [1/highest_mutant[f'{c}'] for c in func_df.columns if 'var' in c]
coefficients = np.array([
    [1, 1, 0],
    [1, 1, 1],
    [1, 4, 1],
    [1, 4, 4],
    [1, 0, 1],
    [1, 1, 1],
    [1, 1, 4],
    [1, 4, 4],
])

results = solve_linear_factor_model(
    variables,
    targets,
    weights,
    coefficients
)

results

stab         1.228868
bindC       -0.107665
bindF       -0.242344
var_stab     0.153701
var_bindC    0.042803
var_bindF    0.035217
R2           0.432085
dtype: float64

In [41]:
from scipy.optimize import least_squares
from scipy.special import expit
from typing import Callable
import numpy as np

def hill_langmuir(
        features: np.ndarray,
        ddGs: np.ndarray,
        wt_dGs: np.ndarray,
        RT: float = 0.593
) -> np.ndarray:
    '''
    Hill-Langmuir model.

    Args:
        features: Shape (N_samples, N_ligands) in Molar.
        ddGs: Latent variables [ddG_fold, ddG_bind_1...] (Deltas).
        dG_WTs: Fixed baselines [dG_fold_WT, dG_bind_1_WT...] (Absolute).
        RT: Thermal energy factor (default 0.593 kcal/mol).
    '''
    # Separate params
    dG_fold = wt_dGs[0] + ddGs[0]
    binding_dGs = wt_dGs[1:] + ddGs[1:]

    # Stability
    # dG_fold sign is negative as K_fold and dG_fold represent the same direction
    #K_fold = np.exp(-dG_fold / RT)
    #p_active = K_fold / (1.0 + K_fold)
    p_active = expit(-dG_fold / RT)

    # Binding
    Kds = np.exp(binding_dGs / RT)
    p_bound = features / (features + Kds + 1e-20)


    # Final pred
    p_final = p_active * p_bound.sum(axis=1)

    return p_final


def predict_log2_fold_change(
        features: np.ndarray,
        ddGs: np.ndarray,
        wt_dGs: np.ndarray,
        RT: float = 0.593
) -> np.ndarray:
    """
    Wraps hill_langmuir to predict Log2 Fold Change relative to WT.
    """
    # Calculate mutant probability
    # Uses the ddGs provided by the solver
    p_mut = hill_langmuir(features, ddGs, wt_dGs, RT)
    
    # Calculate WT probability
    # Uses 0.0 for all ddGs (definition of Wild Type)
    ddGs_wt = np.zeros_like(ddGs)
    p_wt = hill_langmuir(features, ddGs_wt, wt_dGs, RT)
    
    # Calculate Log2 Ratio
    # Add epsilon to prevent log(0)
    epsilon = 1e-20
    return np.log2((p_mut + epsilon) / (p_wt + epsilon))


def solve_global_epistasis_model(
        variables: list,
        targets: list[float],
        weights: list[float],
        features: np.ndarray,
        wt_energies: list[float],
        model_func: Callable = predict_log2_fold_change,
        bounds: tuple | None = None,
        alpha: float = 0.1       
) -> pd.Series:
    '''
    Fits the Hill-Langmuir thermodynamic model to weighted data. Specifically, fits ddG values for each input
    variable with a weight and a bias to explain the log-fold change in fraction bound relative to wild-type.
    Uses L2 regularization to constrain the ddG values to be more realistic.

    Args:
        variables (list): Names of the latent variables (e.g., ['ddG_fold', 'ddG_bind_C', 'ddG_bind_F']),
                   Note, ddG_fold must be first.
        targets (list[float]): Observed enrichment scores.
        weights (list[float]): Sample weights (inverse variance)
        features: Matrix of ligand concentrations for each sample. Shape: (N_samples, N_conc).
        wt_dGs (list[float]): Fixed baselines [dG_fold_WT, dG_bind_1_WT...] (Absolute).
        model_func (callable): The physics function (defaults to hill_langmuir)
        bounds (tuple): Bounds for predicted ddGs. (lower_bound, upper_bound). 
        alpha (float): Regularization parameter to prevent exploding predictions.

    Returns:
        pd.Series: [beta for each ddG, variance for each ddG, R2]
    '''
    # Schema for output
    param_names = variables + ['scale_factor', 'intercept']
    index_names = param_names + [f'var_{v}' for v in param_names] + ['R2']

    y = np.array(targets)
    w = np.array(weights)
    X = features
    wt_dGs = np.array(wt_energies)

    # Mask missing data
    mask = np.isfinite(y) & np.isfinite(w)
    if mask.sum() <= len(variables):
        print("Not enough data points to solve. Ensure missing data is filtered out. Returning NaNs.")
        return pd.Series([np.nan] * len(index_names), index=index_names)
    
    X = X[mask]
    y = y[mask]
    w = w[mask]

    # If user didn't provide bounds, enforce +/- 20 kcal/mol
    if bounds is None:
        limit = 20.0
        # Lower bound, Upper bound for each variable
        # +-20 for ddG, +-10 factor signal amplification, +-5 score offset
        lb = [-limit] * len(variables) + [0.1, -5.0]
        ub = [limit] * len(variables) + [10.0, 5.0]
        bounds = (lb, ub)

    # Define weighted residuals callback function for minimization
    # sum(weight * (observed - predicted)^2)
    def residuals(params):
        # Unpack: [Physics_Params, Scale, Intercept]
        beta_physics = params[:-2]
        A = params[-2]
        B = params[-1]
        
        # Pure Physics Prediction
        raw_pred = model_func(X, beta_physics, wt_dGs)
        
        # Scaled Prediction
        y_pred = A * raw_pred + B
        
        # Error + Regularization (only penalize physics betas, not scale/bias)
        err = np.sqrt(w) * (y - y_pred)
        penalty = np.sqrt(alpha) * beta_physics
        return np.concatenate([err, penalty])
    
    # Solve
    initial_guess = np.concatenate([np.zeros(len(variables)), [1.0, 0.0]])

    try:
        # Fit model
        res = least_squares(residuals, initial_guess, bounds=bounds, method='trf')
        beta_all = res.x

        # Unpack
        beta_phys = beta_all[:-2]
        A, B = beta_all[-2:]

        # Get prediction using parameters
        pred = A * model_func(X, beta_phys, wt_dGs) + B

        # Calculate pure residuals for R2
        pure_res = (y - pred) * np.sqrt(w)

        # Calculate Statistics
        ss_res = np.sum(pure_res**2)
        y_mean = np.average(y, weights=w)
        ss_tot = np.sum(w * (y - y_mean)**2)
        r2 = 1 - (ss_res / ss_tot) if ss_tot > 0 else 0

        # Estimate Variance (Covariance Matrix)
        # Cov = (J.T * J)^-1 * MSE
        dof = len(y) - len(variables)
        mse = ss_res / dof if dof > 0 else ss_res
        
        J = res.jac
        try:
            cov_matrix = np.linalg.pinv(J.T @ J) * mse
            beta_vars = np.diag(cov_matrix)
        except:
            beta_vars = np.zeros(len(beta_all))

        results = np.concatenate([beta_all, beta_vars, [r2]])
        return pd.Series(results, index=index_names)

    except Exception as e:
        return pd.Series([np.nan] * len(index_names), index=index_names)


In [7]:
RT = 0.593
wt_dG_fold_C = RT * np.log(200e-9)
wt_dG_fold_F = RT * np.log(30e-6)

wt_dG_fold_C, wt_dG_fold_F

(np.float64(-9.146994442946236), np.float64(-6.175687713547156))

In [None]:
highest_mutant = func_df.sort_values(by='count_lib', ascending=False).iloc[20]

variables = ['ddG_fold', 'ddG_bindC', 'ddG_bindF']
targets = [highest_mutant[f'{c}'] for c in func_df.columns if 'score' in c]
weights = [1/highest_mutant[f'{c}'] for c in func_df.columns if 'var' in c]
wt_dGs = [-4.83, wt_dG_fold_C, wt_dG_fold_F]
concentrations = np.array([
    [1e-6, 0],
    [1e-6, 1e-6],
    [100e-9, 1e-6],
    [100e-9, 100e-9],
    [0, 1e-6],
    [1e-6, 1e-6],
    [1e-6, 100e-9],
    [100e-9, 100e-9],
])

results = solve_global_epistasis_model(
    variables,
    targets,
    weights,
    concentrations,
    wt_dGs,
)

results

ddG_fold         0.715265
ddG_bindC        0.039557
ddG_bindF       -0.122970
var_ddG_fold     0.105313
var_ddG_bindC    0.009965
var_ddG_bindF    0.005017
R2               0.507144
dtype: float64

In [None]:
def apply_biophysical_model(
    df: pd.DataFrame,
    model_type: str,
    variables: list,
    sample_cols: list,
    design_info: np.ndarray,
    wt_energies: list[float] | None = None,
    alpha: float | None = None
) -> pd.DataFrame:
    """
    Applies a biophysical solver (Linear or Global Epistasis) variant-by-variant.

    Args:
        df: Dataframe containing variant data. Must have columns f'score_{sample}' 
            and f'var_{sample}' for each sample in sample_cols.
        model_type: 'linear_factor' or 'global_epistasis'.
        variables: Names of the latent variables to solve for.
        sample_cols: List of sample names (e.g., ['A1', 'A2', ...]) matching dataframe columns.
        design_info: 
            - If linear: Design matrix (X) of shape (n_samples, n_variables).
            - If epistasis: Concentration matrix of shape (n_samples, n_ligands).
        wt_energies: (Required for global_epistasis) List of [dG_fold_WT, dG_bind_WT...].
        alpha: (Required for global_epistasis) Regularization strength.

    Returns:
        pd.DataFrame: Original df with new columns for betas, variances, and R2.
    """
    print(f"Fitting {model_type} model to {len(df)} variants...")

    # Validate Inputs
    score_cols = [f'score_{s}' for s in sample_cols]
    var_cols = [f'var_{s}' for s in sample_cols]
    
    missing_cols = [c for c in score_cols + var_cols if c not in df.columns]
    if missing_cols:
        raise ValueError(f"Missing required columns in dataframe: {missing_cols}")

    if model_type == 'global_epistasis' and wt_energies is None:
        raise ValueError("Must provide `wt_energies` for global epistasis model.")

    # Define Row Processor
    def _fit_row(row):
        # Extract Targets
        targets = row[score_cols].values.astype(float)
        
        # Extract Variances and convert to Weights (1/sigma^2)
        variances = row[var_cols].values.astype(float)
        
        # Handle zero or infinite variance safely
        # If variance is 0 (or very close), weight becomes infinite. 
        # We clip variance to a tiny epsilon to keep weights finite but large.
        variances = np.maximum(variances, 1e-12)
        weights = 1.0 / variances
        
        # Dispatch to Solver
        if model_type == 'linear_factor':
            return solve_linear_factor_model(
                variables=variables,
                targets=targets,
                weights=weights,
                coefficients=design_info
            )
            
        elif model_type == 'global_epistasis':
            return solve_global_epistasis_model(
                variables=variables,
                targets=targets,
                weights=weights,
                features=design_info,
                wt_energies=wt_energies,
                model_func=predict_log2_fold_change,
                alpha=alpha
            )
        else:
            raise ValueError(f"Unknown model_type: {model_type}")

    # Apply to DataFrame
    # axis=1 applies the function to each row
    results_df = df.apply(_fit_row, axis=1)

    # Merge Results
    # This concatenates the new columns (beta_*, var_*, R2) horizontally
    return pd.concat([df, results_df], axis=1)

In [33]:
wt_dGs = [1.5, wt_dG_fold_C, wt_dG_fold_F]
samples = ['A1', 'A2', 'A3', 'A4', 'B1', 'B2', 'B3', 'B4']

concentrations = np.array([
    [1e-6, 0], [1e-6, 1e-6], [100e-9, 1e-6], [100e-9, 100e-9], # A samples
    [0, 1e-6], [1e-6, 1e-6], [1e-6, 100e-9], [100e-9, 100e-9]  # B samples
])

vars_to_solve = ['ddG_fold', 'ddG_bindC', 'ddG_bindF']

df_results = apply_biophysical_model(
    df=func_df,
    model_type='global_epistasis',
    variables=vars_to_solve,
    sample_cols=samples,
    design_info=concentrations,
    wt_energies=wt_dGs,
    alpha=0.1
)

df_results

Fitting global_epistasis model to 24821 variants...


Unnamed: 0,aa_substitutions,count_lib,score_A1,var_A1,count_A1,score_A2,var_A2,count_A2,score_A3,var_A3,...,ddG_bindC,ddG_bindF,scale_factor,intercept,var_ddG_fold,var_ddG_bindC,var_ddG_bindF,var_scale_factor,var_intercept,R2
0,,62613.0,0.000000,0.000229,25576.0,0.000000,0.000300,17793.0,0.000000,0.000308,...,0.000000,0.000000,1.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
1,A11D,11.0,-0.061982,0.643630,4.0,-1.123476,1.568718,1.0,0.506199,0.643669,...,-0.094923,0.025369,10.000000,0.093426,15.823993,13.099336,0.973411,141505.300308,8035.334536,0.453804
2,A11D I54T D66G Q74R A75V D78G,0.0,1.291655,8.325591,0.0,1.815123,8.325626,0.0,1.859836,8.325630,...,-0.254974,-0.022192,10.000000,0.665227,4.958901,4.902858,0.030589,6777.888475,2520.519594,0.485701
3,A11D I54T E68K Q74R A75V K77E D78G,0.0,1.291655,8.325591,0.0,1.815123,8.325626,0.0,1.859836,8.325630,...,-0.263010,-0.022949,10.000000,0.643052,5.322516,5.261183,0.032595,6801.674411,2705.367245,0.482740
4,A11D L70S A75V,0.0,1.291655,8.325591,0.0,3.400086,5.550467,1.0,3.444798,5.550471,...,-0.115746,0.052889,10.000000,2.504640,1.839603,1.470235,0.314563,10648.153560,934.648722,0.548331
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24816,Y49S L70S K77R,0.0,1.291655,8.325591,0.0,1.815123,8.325626,0.0,3.444798,5.550471,...,0.116541,0.012225,10.000000,4.460501,5.129167,5.034966,0.078282,37841.670558,2606.372940,0.185707
24817,Y49S L70S Q74L A75V,0.0,1.291655,8.325591,0.0,4.622478,4.757565,3.0,5.319267,4.541323,...,-0.185535,0.161077,10.000000,4.847519,31.439200,15.367853,11.343580,42429.047271,15970.352508,0.232397
24818,Y49S L70S Q74R,0.0,1.291655,8.325591,0.0,8.891939,4.193723,67.0,7.837115,4.228967,...,-0.271839,-0.016953,10.000000,4.196167,24.741964,24.312513,0.069193,29288.137867,12572.750245,0.246034
24819,Y49S L70S Q74R A75V,0.0,1.291655,8.325591,0.0,1.815123,8.325626,0.0,5.319267,4.541323,...,-0.262696,-0.008990,10.000000,1.156766,4.040947,3.975834,0.009128,5169.440712,2053.794648,0.588459


In [34]:
df_results.R2.mean()

np.float64(0.28540691740349544)

In [20]:
variables = ['ddG_fold', 'ddG_bindC', 'ddG_bindF']
samples = ['A1', 'A2', 'A3', 'A4', 'B1', 'B2', 'B3', 'B4']
coefficients = np.array([
    [1, 1, 0],
    [1, 1, 1],
    [1, 4, 1],
    [1, 4, 4],
    [1, 0, 1],
    [1, 1, 1],
    [1, 1, 4],
    [1, 4, 4],
])

df_linear_results = apply_biophysical_model(
    df=func_df,                     
    model_type='linear_factor',     
    variables=variables,
    sample_cols=samples,
    design_info=coefficients,      
    wt_energies=None,
    alpha=None
)

Fitting linear_factor model to 24821 variants...


In [21]:
df_linear_results

Unnamed: 0,aa_substitutions,count_lib,score_A1,var_A1,count_A1,score_A2,var_A2,count_A2,score_A3,var_A3,...,score_B4,var_B4,count_B4,ddG_fold,ddG_bindC,ddG_bindF,var_ddG_fold,var_ddG_bindC,var_ddG_bindF,R2
0,,62613.0,0.000000,0.000229,25576.0,0.000000,0.000300,17793.0,0.000000,0.000308,...,0.000000,0.000281,19381.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
1,A11D,11.0,-0.061982,0.643630,4.0,-1.123476,1.568718,1.0,0.506199,0.643669,...,1.075122,0.458645,7.0,-0.685765,0.205784,0.422490,0.187851,0.022300,0.021402,0.776413
2,A11D I54T D66G Q74R A75V D78G,0.0,1.291655,8.325591,0.0,1.815123,8.325626,0.0,1.859836,8.325630,...,6.736188,4.289022,16.0,0.389555,0.620874,0.449549,1.431713,0.241017,0.241016,0.528010
3,A11D I54T E68K Q74R A75V K77E D78G,0.0,1.291655,8.325591,0.0,1.815123,8.325626,0.0,1.859836,8.325630,...,6.901247,4.275385,18.0,0.355021,0.637125,0.465800,1.535357,0.258450,0.258450,0.525623
4,A11D L70S A75V,0.0,1.291655,8.325591,0.0,3.400086,5.550467,1.0,3.444798,5.550471,...,4.499148,4.757555,3.0,1.511834,0.295142,0.508810,0.215475,0.028351,0.028003,0.811722
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
24816,Y49S L70S K77R,0.0,1.291655,8.325591,0.0,1.815123,8.325626,0.0,3.444798,5.550471,...,1.691794,8.325617,0.0,3.665717,-0.580684,0.468173,0.996325,0.147573,0.150736,0.357826
24817,Y49S L70S Q74L A75V,0.0,1.291655,8.325591,0.0,4.622478,4.757565,3.0,5.319267,4.541323,...,12.156339,4.165825,706.0,1.751141,0.226235,1.667445,4.315474,0.558917,0.553948,0.578810
24818,Y49S L70S Q74R,0.0,1.291655,8.325591,0.0,8.891939,4.193723,67.0,7.837115,4.228967,...,4.013722,4.995426,2.0,4.250810,0.963527,0.000869,5.070314,0.739142,0.808355,0.244977
24819,Y49S L70S Q74R A75V,0.0,1.291655,8.325591,0.0,1.815123,8.325626,0.0,5.319267,4.541323,...,1.691794,8.325617,0.0,0.975866,0.914035,0.114511,1.204706,0.144365,0.147259,0.603719


In [42]:
df_linear_results.R2.mean()

np.float64(0.3693993034470977)