# Profile `Polyclonal` fit

First, initialize a `Polyclonal` object:

```python
import pandas as pd

import polyclonal

# keep only the "avg2muts" library, measured at three antibody concentrations
data_to_fit = (
    pd.read_csv("../notebooks/RBD_variants_escape_exact.csv", na_filter=False)
    .query('library == "avg2muts"')
    .query("concentration in [0.25, 1, 4]")
)

poly_abs = polyclonal.Polyclonal(
    data_to_fit=data_to_fit,
    # initial wildtype activity for each of the three epitopes
    activity_wt_df=pd.DataFrame.from_records(
        [
            ("epitope 1", 3.0),
            ("epitope 2", 2.0),
            ("epitope 3", 1.0),
        ],
        columns=["epitope", "activity"],
    ),
    # initial escape values: one key site per epitope
    site_escape_df=pd.DataFrame.from_records(
        [
            ("epitope 1", 484, 10.0),
            ("epitope 2", 446, 10.0),
            ("epitope 3", 417, 10.0),
        ],
        columns=["epitope", "site", "escape"],
    ),
    # how to reconcile the mutations in the data with those in site_escape_df
    data_mut_escape_overlap="fill_to_data",
)
```
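The two chained `.query` calls keep only rows from the `avg2muts` library measured at a concentration of 0.25, 1, or 4, and `na_filter=False` keeps empty CSV fields as empty strings instead of converting them to `NaN`. A minimal sketch of that selection on a small, made-up table (the rows, the `other_lib` library, and the `mutations` column are hypothetical; only the `library` and `concentration` column names come from the query above):

```python
import io

import pandas as pd

# hypothetical toy CSV; the empty "mutations" field mimics a wildtype variant
csv_text = (
    "library,concentration,mutations\n"
    "avg2muts,0.25,E484K\n"
    "avg2muts,0.5,K417N\n"
    "avg2muts,4,\n"
    "other_lib,1,N501Y\n"
)

# na_filter=False: the empty field stays "" rather than becoming NaN
toy = pd.read_csv(io.StringIO(csv_text), na_filter=False)

selected = (
    toy.query('library == "avg2muts"')  # keep one library
    .query("concentration in [0.25, 1, 4]")  # keep three concentrations
)
print(selected)
```

Only the first and third rows survive: the second is at an excluded concentration and the fourth belongs to another library.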

Now fit the model while profiling the call with `cProfile`, writing the statistics to a file named `pstats`:

```python
import cProfile

# run the fit under the profiler and save the raw statistics to the file "pstats"
cProfile.run("poly_abs.fit(verbosity=1)", "pstats")
```

```
First fitting site-level model.
Starting optimization of 522 parameters at Wed Nov 17 08:05:53 2021.
 Initial loss: 8988.466
Step 1: loss=4635.964 at Wed Nov 17 08:05:53 2021
Step 11: loss=1166.957 at Wed Nov 17 08:05:53 2021
Step 21: loss=1036.383 at Wed Nov 17 08:05:54 2021
Step 31: loss=1003.888 at Wed Nov 17 08:05:54 2021
Step 41: loss=984.702 at Wed Nov 17 08:05:54 2021
Step 51: loss=975.8869 at Wed Nov 17 08:05:55 2021
Step 61: loss=971.6521 at Wed Nov 17 08:05:55 2021
Step 71: loss=968.4942 at Wed Nov 17 08:05:55 2021
Step 81: loss=966.3823 at Wed Nov 17 08:05:56 2021
Step 91: loss=964.7508 at Wed Nov 17 08:05:56 2021
Step 101: loss=963.6208 at Wed Nov 17 08:05:57 2021
Step 111: loss=962.6531 at Wed Nov 17 08:05:57 2021
Step 121: loss=961.6781 at Wed Nov 17 08:05:57 2021
Step 131: loss=960.427 at Wed Nov 17 08:05:58 2021
Step 141: loss=958.6037 at Wed Nov 17 08:05:58 2021
Step 151: loss=956.3534 at Wed Nov 17 08:05:59 2021
Step 161: loss=954.4476 at Wed Nov 17 08:05:59 2021
Step
```
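`cProfile.run` passes its string argument to `exec` in the `__main__` namespace, which is why it can see `poly_abs` inside the notebook. A minimal sketch of an equivalent approach that profiles a function call directly with a `cProfile.Profile` object instead of a code string; `slow_workload` is a hypothetical stand-in for `poly_abs.fit`, and `pstats_sketch` is an arbitrary output filename:

```python
import cProfile
import pstats


def slow_workload(n: int = 200_000) -> float:
    """Hypothetical stand-in for poly_abs.fit(); just burns some CPU."""
    total = 0.0
    for i in range(1, n):
        total += (i % 7) ** 0.5
    return total


profiler = cProfile.Profile()
profiler.enable()  # start collecting timing data
result = slow_workload()
profiler.disable()  # stop collecting

# write the same binary format that cProfile.run(..., filename) produces
profiler.dump_stats("pstats_sketch")

# load it back and show the five entries with the most internal time
stats = pstats.Stats("pstats_sketch")
stats.sort_stats("tottime").print_stats(5)
```

Because the profiled code is an ordinary function call, it does not depend on which namespace the string would be executed in, and the saved file can be analyzed with `pstats.Stats` exactly like the `pstats` file above.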

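Analyze the profiling results, sorting by internal time (`tottime`, the time spent in a function excluding its subcalls) and printing the 25 most expensive entries: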

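```python
import pstats

# load the saved profile from the file written by cProfile.run
stats = pstats.Stats("pstats")

# sort by time spent inside each function itself and show the top 25 entries
stats.sort_stats("tottime").print_stats(25)
```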

```
Wed Nov 17 08:08:25 2021    pstats

         10110272 function calls (10100509 primitive calls) in 153.629 seconds

   Ordered by: internal time
   List reduced from 1310 to 25 due to restriction <25>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     4289   26.745    0.006   50.052    0.012 /fh/fast/bloom_j/software/miniconda3/envs/BloomLab/lib/python3.8/site-packages/scipy/sparse/compressed.py:365(multiply)
     8581   16.187    0.002   16.187    0.002 {built-in method scipy.sparse._sparsetools.coo_tocsr}
458345/458321   12.586    0.000   12.601    0.000 {built-in method numpy.array}
   102167   12.295    0.000   12.295    0.000 {method 'reduce' of 'numpy.ufunc' objects}
     8578   10.706    0.001   18.211    0.002 /fh/fast/bloom_j/software/miniconda3/envs/BloomLab/lib/python3.8/site-packages/scipy/sparse/construct.py:403(_compressed_sparse_stack)
     4307   10.374    0.002   10.374    0.002 {method 'nonzero' of 'numpy.ndarray' objects}
     4289    8.280
```

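In the table above, the entry with the most internal time is the `scipy.sparse` `multiply` method. A natural follow-up is to find which functions call it, which `pstats.Stats.print_callers` reports. It accepts the same restrictions as `print_stats`; a string is treated as a regular expression matched against each `filename:lineno(function)` entry. A self-contained sketch on a toy profile, with hypothetical `inner` and `outer` functions standing in for the real call chain (on the saved profile, a call such as `stats.print_callers("multiply")` would be the analogue):

```python
import cProfile
import io
import pstats


def inner(x: float) -> float:
    """Hypothetical hot function (stands in for the sparse multiply)."""
    return x * x


def outer(n: int) -> float:
    """Hypothetical caller that invokes inner() many times."""
    total = 0.0
    for i in range(n):
        total += inner(float(i))
    return total


profiler = cProfile.Profile()
profiler.runcall(outer, 10_000)

# send the report to a string buffer instead of stdout so it can be inspected
buffer = io.StringIO()
stats = pstats.Stats(profiler, stream=buffer)
stats.sort_stats("tottime")
stats.print_callers(r"\(inner\)")  # regex restriction: only entries named inner
report = buffer.getvalue()
print(report)
```

Each caller line lists the number of calls from that caller together with the time spent, which helps decide whether the hot function is slow per call or simply called very often.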