In [1]:
import os
import pickle
from itertools import combinations
import math
import pandas as pd
import re
import json
import binarymap as bmap
import matplotlib.pyplot as plt
import seaborn as sns
import plotnine
import jax
import jax.numpy as jnp
from jax.experimental import sparse
from jaxopt import ProximalGradient
import jaxopt
import numpy as onp
from scipy.stats import pearsonr
from tqdm.notebook import tqdm

import sys
sys.path.append("..")
from multidms.utils import *
from timeit import default_timer as timer
from multidms.model import ϕ, g, prox, cost_smooth

### Globals

In [2]:
# Load, for every homolog, (1) the table of functional selections together
# with the per-selection functional-score CSVs it points at, (2) the
# wild-type protein sequence, and (3) the site-numbering map.
#
# Produces three module-level objects:
#   func_score_data : one row per functional selection across all homologs
#   wt_seqs         : homolog name -> second line of that homolog's protein.fasta
#   sites           : homolog name -> site numbering map indexed by reference_site
func_sel_frames = []
sites = {}
wt_seqs = {}

for homolog in ["Delta", "Omicron_BA.1"]:

    # functional scores
    func_sel = pd.read_csv(f"../results/{homolog}/functional_selections.csv")
    func_sel = func_sel.assign(
        filename = f"../results/{homolog}/" +
        func_sel.library + "_" +
        func_sel.preselection_sample +
        "_vs_" + func_sel.postselection_sample +
        "_func_scores.csv"
    )
    # Each row carries its own scores table as a nested DataFrame.
    func_sel = func_sel.assign(
        func_sel_scores_df = func_sel.filename.apply(lambda f: pd.read_csv(f))
    )
    func_sel = func_sel.assign(
        len_func_sel_scores_df = func_sel.func_sel_scores_df.apply(lambda x: len(x)),
        homolog = homolog,
    )
    # Collect and concatenate once after the loop instead of re-concatenating
    # the growing frame on every iteration.
    func_sel_frames.append(func_sel)

    # WT Protein sequence
    # NOTE(review): only the first line after the header is read, so this
    # assumes the sequence is stored unwrapped on a single line -- confirm.
    with open(f"../results/{homolog}/protein.fasta", "r") as seq_file:
        seq_file.readline()  # skip the FASTA header line
        wt_seqs[homolog] = seq_file.readline().strip()

    # Sites
    # `columns=` is required: a bare mapping passed to DataFrame.rename is
    # applied to the index labels, which silently renames nothing here.
    # NOTE(review): any later cell that still reads `sequential_site` /
    # `sequential_wt` from `sites` must switch to the homolog-prefixed names.
    sites[homolog] = (
        pd.read_csv(f"../results/{homolog}/site_numbering_map.csv")
        .rename(columns={"sequential_site":f"{homolog}_site", "sequential_wt":f"{homolog}_wt"})
        .set_index(["reference_site"])
    )

func_score_data = pd.concat(func_sel_frames, ignore_index=True)

# Add a column that gives a unique ID to each homolog/DMS experiment
func_score_data['homolog_exp'] = func_score_data.apply(
    lambda row: f"{row['homolog']}-{row['library']}-{row['replicate']}".replace('-Lib',''),
    axis=1
)

In [9]:
# Default hyper-parameters for a single fit. The fit cell copies this dict
# and overrides "experiment_ref" / "experiment_2" for each experiment pair;
# the keys also become the leading columns of the `results` table.
# (The literal previously listed "min_pre_counts" twice with the same value;
# the redundant second entry is removed so the two cannot drift apart.)
fit_params = {
    "fs_scaling_group_column" : "homolog_exp",
    "min_pre_counts" : 100,
    "pseudocount" : 0.1,
    "agg_variants" : True,
    "sample" : None,
    "clip_target" : None,
    "func_score_target" : 'log2e',
    "experiment_ref" :'Delta-3-1',
    "experiment_2" : 'Omicron_BA.1-3-1',
    "shift_func_score_target_nonref" : -1,
    "warmup_to_ref" : False,
    "maxiter" : 5,
    "λ_lasso" : 5e-5,
    "λ_ridge" : 0
}

In [10]:
# Discard any cached results so the fit cell starts from an empty table.
# Done in Python rather than with `! rm` so that a missing file (e.g. on a
# fresh checkout) is not an error and the cell does not depend on a shell.
if os.path.exists("results.pkl"):
    os.remove("results.pkl")

In [11]:
# Fit a model for (Delta experiment, Omicron experiment) pairs and append one
# row per fit to `results`, which is cached on disk as results.pkl.

# Resume from the cached table when present, otherwise start an empty one
# whose columns are the fit parameters plus the fit outputs.
if os.path.exists("results.pkl"):
    # NOTE: unpickling can execute arbitrary code; this is acceptable only
    # because results.pkl is a cache written by this notebook itself.
    with open("results.pkl", "rb") as results_file:
        results = pickle.load(results_file)
else:
    cols = list(fit_params.keys()) + ["tuned_model_params", "all_subs", "dataset_preds"]
    results = pd.DataFrame(columns = cols)


# Experiment IDs per homolog, in the row order of func_score_data.
delta_exps = [exp_id for exp_id in func_score_data.homolog_exp if "Delta" in exp_id]
omicron_exps = [exp_id for exp_id in func_score_data.homolog_exp if "Omicron" in exp_id]

for delta_exp in delta_exps:
    for omicron_exp in omicron_exps:
        print(f"running {delta_exp} Vs. {omicron_exp}")
        fit_param_i = fit_params.copy()
        fit_param_i["experiment_ref"] = delta_exp
        fit_param_i["experiment_2"] = omicron_exp
        # run_fit is not defined in this notebook (presumably provided by the
        # `multidms.utils` star import); it receives a copy of the data so the
        # shared func_score_data frame cannot be modified by the fit.
        row = run_fit(func_score_data.copy(), fit_params = fit_param_i)
        results.loc[len(results)] = row
        # The two breaks stop after the first pair only.
        # TODO: remove both to fit every Delta x Omicron combination.
        break
    break

# Use a context manager so the file is flushed and closed before later cells
# (or other processes) read results.pkl.
with open("results.pkl", "wb") as results_file:
    pickle.dump(results, results_file)

running Delta-1-1 Vs. Omicron_BA.1-1-1


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/139478 [00:00<?, ?it/s]

  0%|          | 0/139478 [00:00<?, ?it/s]

Delta-1-1 15433309 0.01 100000.0
Omicron_BA.1-1-1 7340049 0.01 100000.0
Found 11 site(s) lacking data in at least one homolog.
5801 of the 90855 variants were removed because they had mutations at the above sites, leaving 85054 variants.


  0%|          | 0/51278 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [8]:
# Display the accumulated results table (one row is appended per fitted
# experiment pair by the fit cell).
results

Unnamed: 0,fs_scaling_group_column,min_pre_counts,pseudocount,agg_variants,sample,clip_target,func_score_target,experiment_ref,experiment_2,shift_func_score_target_nonref,warmup_to_ref,maxiter,λ_lasso,λ_ridge,tuned_model_params,all_subs,dataset_preds
0,homolog_exp,100,0.1,True,1000,,log2e,Delta-1-1,Omicron_BA.1-1-1,-1,False,5,5e-05,0,,,
1,homolog_exp,100,0.1,True,1000,,log2e,Delta-1-1,Omicron_BA.1-1-2,-1,False,5,5e-05,0,,,
2,homolog_exp,100,0.1,True,1000,,log2e,Delta-1-1,Omicron_BA.1-2-1,-1,False,5,5e-05,0,,,
3,homolog_exp,100,0.1,True,1000,,log2e,Delta-1-1,Omicron_BA.1-3-1,-1,False,5,5e-05,0,,,
