# Fit the neutralization curves for a plate
The fitting is done using [neutcurve](https://jbloomlab.github.io/neutcurve).

In [None]:
import neutcurve

import pandas as pd

Get variables from `snakemake`:

In [None]:
# Input files produced by upstream rules.
process_counts_qc_failures, frac_infectivity_csv = (
    snakemake.input.qc_failures,
    snakemake.input.frac_infectivity_csv,
)
# Output files this notebook writes: fit parameters (CSV) and curve plots (PDF).
output_csv, output_pdf = snakemake.output.csv, snakemake.output.pdf
# Curve-fitting settings and the plate being processed.
curvefit_params = snakemake.params.curvefit_params
plate = snakemake.wildcards.plate

Make sure the process counts ran OK:

In [None]:
# The QC failures file lists one failure per line; blank lines are ignored.
# Any listed failure aborts the fit so the upstream problem gets fixed first.
with open(process_counts_qc_failures) as qc_file:
    failures = [stripped for stripped in map(str.strip, qc_file) if stripped]

if not failures:
    print("No failures for the `process_counts` QC, so proceeding with fitting.")
else:
    raise ValueError(
        "First fix `process_counts` QC failures:\n\t" + "\n\t".join(failures)
    )

Read the fraction infectivity values, clipping them at the configured `frac_infectivity_ceiling`:

In [None]:
# Upper limit applied to the fraction-infectivity values before fitting.
frac_infectivity_ceiling = curvefit_params["frac_infectivity_ceiling"]

# Test for None *before* comparing: `None > 0` raises TypeError, which would
# pre-empt the intended ValueError. `not (... > 0)` also rejects NaN.
if frac_infectivity_ceiling is None or not (frac_infectivity_ceiling > 0):
    raise ValueError(f"invalid {frac_infectivity_ceiling=}")
print(f"Clipping with {frac_infectivity_ceiling=}")

# Build the per-curve label `serum_replicate`: rows whose plate_replicate is
# this plate keep the plain serum name; other replicates get a suffix in
# parentheses made from plate_replicate with the plate name removed.
# Note `.str.replace` is needed for substring removal -- `Series.replace`
# only matches whole values and would leave plate_replicate unchanged here.
# Concentration is taken as the reciprocal of the dilution factor, and
# fraction infectivity is clipped at the ceiling validated above.
frac_infectivity = pd.read_csv(frac_infectivity_csv).assign(
    serum_replicate=lambda x: x["serum"].where(
        x["plate_replicate"] == plate,
        x["serum"]
        + " ("
        + x["plate_replicate"].str.replace(plate, "", regex=False)
        + ")",
    ),
    serum_concentration=lambda x: 1 / x["dilution_factor"],
    frac_infectivity=lambda x: x["frac_infectivity"].clip(
        upper=frac_infectivity_ceiling,
    ),
)

Fit all the neutralization curves:

In [None]:
print(f"Fitting with {curvefit_params['fixtop']=} and {curvefit_params['fixbottom']=}")

# neutcurve is pointed at human-readable column names, so rename the
# concentration and fraction-infectivity columns before handing them over.
neutcurve_input = frac_infectivity.rename(
    columns={
        "serum_concentration": "serum concentration",
        "frac_infectivity": "fraction infectivity",
    }
)

# One curve per (serum_replicate, strain, barcode); top/bottom fixing comes
# from the curve-fitting config.
fits = neutcurve.CurveFits(
    neutcurve_input,
    serum_col="serum_replicate",
    virus_col="strain",
    replicate_col="barcode",
    conc_col="serum concentration",
    fracinf_col="fraction infectivity",
    fixtop=curvefit_params["fixtop"],
    fixbottom=curvefit_params["fixbottom"],
)

Plot all the neutralization curves:

In [None]:
# Plot settings: no shared legend across panels, small legend/title/tick
# fonts, and a grid laid out with ncol=6.
replicate_plot_kwargs = {
    "ncol": 6,
    "ticksize": 10,
    "titlesize": 10,
    "legendfontsize": 9,
    "attempt_shared_legend": False,
}
fig, _ = fits.plotReplicates(**replicate_plot_kwargs)

Save the curves to a file:

In [None]:
# Write the figure of per-replicate neutralization curves to the PDF output.
print(f"Saving to {output_pdf}")
fig.savefig(fname=output_pdf)

Get the fit parameters, check that there is one fit per barcode and serum replicate, and save them to a CSV:

In [None]:
# Per-replicate fit parameters (no averaging across replicates), annotated with
# nt50 = 1 / ic50 and with the serum and plate_replicate each curve came from.
# The merge key is stated explicitly so that extra columns added by neutcurve
# can never silently become join keys; validate= guarantees each
# serum_replicate label maps to exactly one (serum, plate_replicate) pair.
fit_params = (
    fits.fitParams(average_only=False, no_average=True)
    .rename(columns={"serum": "serum_replicate", "replicate": "barcode"})
    .assign(nt50=lambda x: 1 / x["ic50"])
    .merge(
        frac_infectivity[
            ["serum", "serum_replicate", "plate_replicate"]
        ].drop_duplicates(),
        on="serum_replicate",
        validate="many_to_one",
    )
    .drop(columns=["ic50_str", "nreplicates", "serum_replicate"])
    .sort_values(["serum", "virus", "plate_replicate", "barcode"])
)

# Expect exactly one fit per (barcode, serum_replicate) in the input data.
# Raise rather than assert so the check is not stripped under `python -O`.
n_expected_fits = frac_infectivity.groupby(["barcode", "serum_replicate"]).ngroups
if len(fit_params) != n_expected_fits:
    raise ValueError(
        f"Got {len(fit_params)} fits but expected {n_expected_fits} "
        "(one per barcode and serum replicate)"
    )

print(f"Saving to {output_csv}")

fit_params.to_csv(output_csv, index=False, float_format="%.4g")