In [None]:
# Execute scripts/strain_facts.py through IPython's %run magic with the
# hyper-parameters, optimiser settings, output path and input pileup listed below.
# Later cells use names that are not defined in this notebook (args, model,
# conditioned_model, find_map, ...); presumably they are left in the namespace
# by this script -- TODO confirm.
# NOTE(review): this passes `--nstrain` while a later cell reads `args.nstrains`;
# confirm the option spelling against the script's argument parser.
%run scripts/strain_facts.py \
                --cvrg-thresh 0.05 \
                --nstrain 100 \
                --npos 6000 \
                --gamma-hyper 0.001 \
                --rho-hyper 1e-05 \
                --pi-hyper 0.1 \
                --epsilon-hyper 0.01 \
                --alpha-hyper 500 \
                --collapse 0.05 \
                --device cuda \
                --learning-rate 0.50 \
                --stop-at 100 \
                --max-iter 10000 \
                --outpath test.nc \
                data/ucfmt.sp-100022.gtpro-pileup.nc

In [None]:
# Build the conditioned model, fit it with find_map, and collect the estimates
# together with the observed counts into a single xarray Dataset (`result`).
model_fit = conditioned_model(
    model,
    data=dict(
        # alpha and epsilon are supplied as constant per-library vectors here,
        # in place of conditioning on alpha_hyper=100. / epsilon_hyper=1e-2.
        alpha=np.ones(n) * 100.,
        epsilon=np.ones(n) * 1e-2,
        pi_hyper=1e-0 / s,
        y=y_obs_ss.values,
    ),
    s=s,
    m=m_ss.values,
    # NOTE(review): fixed at 1e-2 rather than args.gamma_hyper -- confirm intentional.
    gamma_hyper=1e-2,
    dtype=torch.float32,
    device=args.device,
)

info("Fitting model.")
# Optimiser settings come from the parsed arguments (as in the genotyping
# cell) so the notebook and the command line cannot drift apart.
mapest, history = find_map(
    model_fit,
    lag=args.lag,
    stop_at=args.stop_at,
    learning_rate=args.learning_rate,
    max_iter=args.max_iter,
    clip_norm=args.clip_norm,
)

# Release cached GPU memory once the optimisation is done.
if args.device.startswith('cuda'):
    torch.cuda.empty_cache()

info("Finished fitting model.")
result = xr.Dataset(
    {
        "gamma": (["strain", "position"], mapest["gamma"]),
        "rho": (["strain"], mapest["rho"]),
        "alpha_hyper": ([], mapest["alpha_hyper"]),
        "pi": (["library_id", "strain"], mapest["pi"]),
        "epsilon": (["library_id"], mapest["epsilon"]),
        "rho_hyper": ([], mapest["rho_hyper"]),
        "epsilon_hyper": ([], mapest["epsilon_hyper"]),
        "pi_hyper": ([], mapest["pi_hyper"]),
        "alpha": (["library_id"], mapest["alpha"]),
        "p_noerr": (["library_id", "position"], mapest["p_noerr"]),
        "p": (["library_id", "position"], mapest["p"]),
        # Pass plain arrays: a DataArray inside a (dims, data) tuple is
        # ambiguous to xarray and is rejected by recent releases.
        "y": (["library_id", "position"], y_obs_ss.values),
        "m": (["library_id", "position"], m_ss.values),
        "elbo_trace": (["iteration"], history),
    },
    coords=dict(
        strain=np.arange(s),
        position=data_fit.position,
        library_id=data_fit.library_id,
    ),
)

## Check Fit

In [None]:
# Loss trace recorded while fitting the model (stored as `elbo_trace`).
plot_loss_history(result["elbo_trace"].values)

In [None]:
# Per-strain comparison: pi averaged over library_id (x) against rho (y).
plt.scatter(result.pi.mean('library_id'), result.rho)
plt.xlabel("mean pi over library_id")
# Trailing ';' keeps the matplotlib object repr out of the cell output.
plt.ylabel("rho");

In [None]:
# Total absolute difference between the observed counts y and p * m,
# divided by the total of m.
(
    np.abs(result["y"] - result["p"] * result["m"]).sum()
    / result["m"].sum()
)

In [None]:
import seaborn as sns

# Heatmap of gamma (strain x position) after conversion to a pandas table.
sns.heatmap(result["gamma"].to_pandas())

In [None]:
# Clustered heatmap of pi (library_id x strain); relies on `sns` imported in an earlier cell.
sns.clustermap(result["pi"].to_pandas())

In [None]:
# Histogram of pi maximised over the strain dimension, in 20 bins.
# (Optional: add plt.yscale('log') before the final call for log-scaled counts.)
# The trailing ';' suppresses the returned (counts, bins, patches) output.
plt.hist(result["pi"].max("strain"), bins=20);

In [None]:
# Build and fit a second model in which pi is conditioned on the estimate
# from the previous fit (mapest['pi']).
info("Building genotyping model.")
# Keep the libraries listed in suff_cvrg_samples and the first 1000 positions.
data_geno = data.sel(library_id=suff_cvrg_samples).isel(position=slice(0, 1000))
m = data_geno.sum("allele")
n, g = m.shape
y_obs = data_geno.sel(allele="alt")
s = args.nstrains
info(f"Model shape: n={n}, g={g}, s={s}")
model_geno = conditioned_model(
    model,
    data=dict(
        # NOTE(review): the %run cell passes --alpha-hyper, not --alpha;
        # confirm that `args.alpha` exists.
        alpha=np.ones(n) * args.alpha,
        epsilon_hyper=args.epsilon_hyper,
        y=y_obs.values,
        # NOTE(review): assumes mapest['pi'] matches the (n, s) shape used here -- confirm.
        pi=mapest['pi']
    ),
    s=s,
    m=m.values,
    gamma_hyper=args.gamma_hyper,
    dtype=torch.float32,
    # Use the configured device, as the first fit does, instead of forcing CUDA.
    device=args.device,
)

info("Fitting genotyping model.")
mapest_geno, history_geno = find_map(
    model_geno,
    lag=args.lag,
    stop_at=args.stop_at,
    learning_rate=args.learning_rate,
    max_iter=args.max_iter,
    clip_norm=args.clip_norm,
)

# Release cached GPU memory after fitting, mirroring the first fit.
if args.device.startswith('cuda'):
    torch.cuda.empty_cache()

In [None]:
# Loss trace of the genotyping fit (history_geno), not the first fit's
# elbo_trace, which an earlier cell already plots.
plot_loss_history(np.asarray(history_geno))

In [None]:
# Load a previously saved dataset from disk fully into memory.
result2 = xr.load_dataset("data/core/100022/gtpro.sfacts.nc")