In [1]:
import pandas as pd

import polyclonal
import polyclonal.fit as fit

We start with the noisy data set and a set of parameters that already fit it reasonably well: the ones obtained by fitting the "exact" version of the data set.

In [2]:
data_path = "RBD_variants_escape_noisy.csv"
data = (
    pd.read_csv(data_path, na_filter=None)
    .query('library == "avg2muts"')
    .query("concentration in [0.25, 1, 4]")
    .reset_index(drop=True)
)
mut_escape_df = pd.read_csv("exact_mut_escape_df.csv")
activity_wt_df = pd.read_csv("exact_activity_wt_df.csv")
poly_abs = polyclonal.Polyclonal(
    data_to_fit=data,
    activity_wt_df=activity_wt_df,
    mut_escape_df=mut_escape_df,
)

# Copy so that later fits cannot modify the saved starting point in place.
orig_params = poly_abs._params.copy()

We start by making an object we can use for fitting.
Construction is slow (about 19 s of wall time here), and could probably be made a lot faster by not converting to and from a dense matrix.

In [3]:
%%time

prox_grad = fit.prox_grad_of_polyclonal(poly_abs)



CPU times: user 16 s, sys: 8.44 s, total: 24.4 s
Wall time: 18.8 s


Now we run the JAX version of the fitting code.

In [4]:
%%time

jax_params = prox_grad.run(poly_abs._params, tol=1e-7, max_iter=5000)

initial objective 7.293724e+02
iteration 2670, objective 6.650e+02, relative change 9.976e-08
relative change in objective function 1e-07 is within tolerance 1e-07 after 2670 iterations
CPU times: user 2min 26s, sys: 3.16 s, total: 2min 29s
Wall time: 2min 16s
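
The log above shows the run stopping once the relative change in the objective drops below `tol`. The internals of `prox_grad.run` are not shown in this notebook, so the following is only a minimal sketch of that kind of stopping rule, written in plain Python for a toy problem (a separable quadratic plus an L1 penalty). The function names `soft_threshold` and `prox_grad_l1` and the toy objective are assumptions made for illustration; they are not part of `polyclonal`.

```python
def soft_threshold(v, thresh):
    """Proximal operator of thresh * |v| (the L1 penalty)."""
    if v > thresh:
        return v - thresh
    if v < -thresh:
        return v + thresh
    return 0.0


def prox_grad_l1(curv, target, lam, x0, tol=1e-7, max_iter=5000):
    """Minimize sum(0.5 * c * (x - t)**2) + lam * sum(|x|) by proximal gradient.

    Hypothetical toy solver: stops when the relative change in the
    objective between consecutive iterations is at most `tol`.
    """

    def objective(x):
        smooth = sum(0.5 * c * (xi - t) ** 2 for c, xi, t in zip(curv, x, target))
        return smooth + lam * sum(abs(xi) for xi in x)

    step = 1.0 / max(curv)  # 1 / Lipschitz constant of the smooth gradient
    x = list(x0)
    obj = objective(x)
    iteration = 0
    for iteration in range(1, max_iter + 1):
        # Gradient step on the smooth part, then the prox of the L1 part.
        grad = [c * (xi - t) for c, xi, t in zip(curv, x, target)]
        x = [soft_threshold(xi - step * g, step * lam) for xi, g in zip(x, grad)]
        new_obj = objective(x)
        rel_change = abs(obj - new_obj) / max(abs(obj), 1e-300)
        obj = new_obj
        if rel_change <= tol:
            break
    return x, obj, iteration


# Toy usage: three parameters with different curvatures.
x, obj, n_iter = prox_grad_l1(
    curv=[1.0, 4.0, 0.5],
    target=[2.0, -1.0, 0.1],
    lam=0.5,
    x0=[0.0, 0.0, 0.0],
)
print(f"stopped after {n_iter} iterations, objective {obj:.6f}, x = {x}")
```

A relative-change criterion like this says nothing about how far the parameters are from the optimum, which is one reason it is worth handing the result to a second optimizer as a check.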


Let's see what the original Polyclonal fitting code thinks of these parameters:

In [5]:
poly_abs._params = jax_params
poly_abs.fit(logfreq=100, fit_site_level_first=False)

# Starting optimization of 5799 parameters at Sat Jan  1 04:55:26 2022.
       step   time_sec       loss   fit_loss reg_escape  regspread
          0     3.8362     665.01     577.56     53.165     34.287
          3     4.2407        665     577.56     53.166     34.279
# Successfully finished at Sat Jan  1 04:55:30 2022.


      fun: 665.0017719527469
 hess_inv: <5799x5799 LbfgsInvHessProduct with dtype=float64>
      jac: array([-7.37749212e-02, -3.12951089e-01, -1.08871757e-01, ...,
       -5.09604685e-04,  1.63759945e-03, -3.19200632e-05])
  message: 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
     nfev: 5
      nit: 2
     njev: 5
   status: 0
  success: True
        x: array([1.07566616, 3.22843841, 1.94317282, ..., 0.33487041, 0.69202319,
       0.33282094])

It thinks they are good! L-BFGS stops after 2 iterations with the loss essentially unchanged (665.01 to 665.00).

Now let's see how fast the original code is:

In [6]:
%%time

poly_abs._params = orig_params
poly_abs.fit(logfreq=100, fit_site_level_first=False)

# Starting optimization of 5799 parameters at Sat Jan  1 04:55:30 2022.
       step   time_sec       loss   fit_loss reg_escape  regspread
          0   0.097575     729.37     643.32     53.454     32.601
        100     8.4999     669.78      582.8     53.237      33.75
        200     16.668     666.22     579.05       53.2     33.963
        300     24.949     665.51     577.98     53.234     34.302
        400     33.036     665.19     577.68     53.228     34.281
        500     41.564     665.07     577.63       53.2     34.248
        554     46.148     665.05      577.6     53.198      34.25
# Successfully finished at Sat Jan  1 04:56:16 2022.
CPU times: user 46.2 s, sys: 0 ns, total: 46.2 s
Wall time: 46.2 s


      fun: 665.0514594266459
 hess_inv: <5799x5799 LbfgsInvHessProduct with dtype=float64>
      jac: array([-1.39084482e-01, -6.24127208e-01, -2.24530265e-01, ...,
        2.38510511e-04, -3.84253065e-04, -8.32234644e-04])
  message: 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
     nfev: 612
      nit: 553
     njev: 612
   status: 0
  success: True
        x: array([1.07311803, 3.22757754, 1.94642557, ..., 0.32831496, 0.69195178,
       0.33004297])

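The JAX code takes about 3.4 times as long in wall-clock time if you include the setup time (about 155 s versus 46 s), and about 2.9 times as long without it. It does end at a slightly lower objective (665.002 versus 665.051).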
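
For reference, the slowdown can be recomputed from the wall times printed by the `%%time` cells above. The variable names are ours; the numbers are copied from the cell outputs.

```python
# Wall-clock times in seconds, copied from the %%time outputs above.
jax_setup = 18.8          # fit.prox_grad_of_polyclonal(poly_abs)
jax_fit = 2 * 60 + 16     # prox_grad.run(...), reported as 2min 16s
orig_fit = 46.2           # poly_abs.fit(...) starting from orig_params

ratio_with_setup = (jax_setup + jax_fit) / orig_fit
ratio_fit_only = jax_fit / orig_fit

print(f"with setup: {ratio_with_setup:.2f}x, fit only: {ratio_fit_only:.2f}x")
```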