In [None]:
import numpy as np
import time
import datetime
from pathlib import Path

In [None]:
# Run parameters; every later cell reads these names.
population_size = 5 * 10 ** 8   # total population size that is split across strains
rate = 2                        # Poisson rate (`lam`) for the per-sample draw sizes
entropy = 42                    # root entropy handed to np.random.SeedSequence
number_samples = 1000           # number of samples (droplets) drawn in one simulation
number_simulations = 500        # number of child seeds spawned; also sets the zero-pad width of file names
simulation_number = 0           # which simulation this run is; selects the child seed and names the outputs

In [None]:
# Derive this run's seed as child number `simulation_number - 1` of the
# shared root SeedSequence.
#
# The lower bound matters: a negative simulation_number would not fail, it
# would index the spawned list from the end and silently reuse the seed of a
# different run. The value 0 is still accepted (it is the default in the
# parameter cell), but note that it wraps to index -1, i.e. the last child,
# which is the same seed that simulation_number == number_simulations gets.
assert 0 <= simulation_number <= number_simulations, (
    'simulation_number must lie in [0, {}], got {}'.format(number_simulations,
                                                          simulation_number)
)
seed_sequence = np.random.SeedSequence(entropy)
seed = seed_sequence.spawn(number_simulations)[simulation_number - 1]

In [None]:
# Starting abundances, one per order of magnitude.
base_relative_abundances = [1e-4, 1e-3, 1e-2]

# Scale every base value by 1, 2 and 5, and list each scaled value ten times
# in a row: 3 bases x 3 multipliers x 10 copies = 90 entries.
relative_abundances = [
    base * multiplier
    for base in base_relative_abundances
    for multiplier in (1, 2, 5)
    for _ in range(10)
]

# A final entry takes up whatever is left so that the abundances total one.
relative_abundances.append(1 - sum(relative_abundances))
frequencies = np.array(relative_abundances)

## CTPMHg Simulation - Iterate over droplets, then over marginals

In [None]:
def CTPMHg_simulation_droplets_strains(population_size, rate, seed, number_samples, frequencies):
    """Draw a sequence of samples (droplets) without replacement from a finite population.

    Each sample's total size is a Poisson draw with mean ``rate``. The
    composition of each sample across strains is then drawn from a
    multivariate hypergeometric distribution over whatever is still left in
    the population, and the drawn individuals are removed before the next
    sample is taken.

    Parameters
    ----------
    population_size : int
        Nominal total population size.
    rate : float
        Poisson rate (``lam``) for the total size of each sample.
    seed : int, SeedSequence, or anything ``np.random.default_rng`` accepts
        Seed for the random generator.
    number_samples : int
        Number of samples to draw. Zero is allowed and yields empty arrays.
    frequencies : array-like of float
        Relative abundance of each strain.

    Returns
    -------
    dict
        ``"pop_sizes"``: int array of shape (strains, number_samples); column
        ``d`` holds the per-strain population remaining *before* sample ``d``
        is drawn.
        ``"sample_sizes"``: int array of the same shape; column ``d`` holds
        the per-strain counts drawn into sample ``d``.

    Raises
    ------
    NotImplementedError
        If the Poisson totals add up to more than ``population_size``.
    """
    frequencies = np.asarray(frequencies)

    # NOTE: astype(int) truncates toward zero, so the strain sizes may add up
    # to slightly less than population_size. Kept as-is so that results for a
    # given seed stay reproducible.
    sub_population_sizes = (population_size * frequencies).astype(int)
    number_strains = len(sub_population_sizes)

    rng = np.random.default_rng(seed)
    total_sample_sizes = rng.poisson(lam=rate, size=number_samples)

    # The samples must not ask for more individuals than exist. This is an
    # explicit check rather than an assert so it still runs under `python -O`.
    total_drawn = int(np.sum(total_sample_sizes))
    if total_drawn > population_size:
        raise NotImplementedError(
            'Poisson sample sizes total {}, which exceeds the population size {}.'
            .format(total_drawn, population_size)
        )

    remaining_sub_population_sizes = np.zeros((number_strains, number_samples), dtype=int)
    sample_sizes = np.zeros((number_strains, number_samples), dtype=int)

    # Population still available to the current droplet.
    remaining = sub_population_sizes.copy()
    for d in range(number_samples):
        remaining_sub_population_sizes[:, d] = remaining
        droplet_d_sample = rng.multivariate_hypergeometric(
                                            colors=remaining,
                                            nsample=total_sample_sizes[d],
                                            method='marginals'
                                            )
        sample_sizes[:, d] = droplet_d_sample
        remaining = remaining - droplet_d_sample

    # Internal consistency checks: the total population remaining before each
    # droplet must equal the starting total minus everything drawn earlier,
    # and each droplet must contain exactly its Poisson-drawn total.
    drawn_before_each_sample = np.cumsum(total_sample_sizes) - total_sample_sizes
    remaining_population_sizes = np.sum(sub_population_sizes) - drawn_before_each_sample
    assert np.all(remaining_population_sizes == np.sum(remaining_sub_population_sizes, axis=0))
    assert np.all(total_sample_sizes == np.sum(sample_sizes, axis=0))

    return {"pop_sizes": remaining_sub_population_sizes, "sample_sizes": sample_sizes}

In [None]:
def prettify(integer):
    """Return ``integer`` as a string, zero-padded to the digit count of ``number_simulations``."""
    width = len(str(number_simulations))
    return str(integer).zfill(width)

In [None]:
results_filename = 'npzfiles/CTPMHg_results.{}.npz'.format(prettify(simulation_number))
results_file = Path(results_filename)
log_filename = 'notebook_logs/runtime.{}.log'.format(prettify(simulation_number))

if results_file.is_file():
    # simulation already ran successfully on previous attempt
    pass
else:
    # Create the output folders before the (long) simulation, so a missing
    # directory cannot throw away a finished run.
    results_file.parent.mkdir(parents=True, exist_ok=True)
    Path(log_filename).parent.mkdir(parents=True, exist_ok=True)

    start_time = time.time()
    results = CTPMHg_simulation_droplets_strains(population_size=population_size,
                                                rate=rate, seed=seed,
                                                number_samples=number_samples,
                                                frequencies=frequencies)
    runtime = time.time() - start_time

    with open(log_filename, 'a') as file_pointer:
        # https://stackoverflow.com/a/775095/10634604
        runtime_string = str(datetime.timedelta(seconds=runtime))
        # The file is opened in append mode, so end each entry with a newline
        # to keep repeated runs on separate lines.
        file_pointer.write('Runtime was {} in Hours:Minutes:Seconds.\n'.format(runtime_string))

    # Write to a temporary file and rename it into place. The is_file() check
    # above treats any existing results file as a completed run, so a write
    # that gets interrupted must never leave a file under the final name.
    # The temporary name ends in '.npz' so numpy does not append the suffix.
    partial_file = results_file.with_name(results_file.name + '.partial.npz')
    np.savez_compressed(str(partial_file), **results)
    partial_file.replace(results_file)

    # Maybe this will help prevent memory leaks?
    # Honestly not sure what happens when using papermill with multiprocessing.
    del results