# Effect of acceleration on EM algorithm (1/n)

plain EM vs. parabolic EM (original) vs. parabolic EM (modified)

Collect the trajectories of each EM variant, with all variants started from the same initial values.  
The initial values are taken from [this fitting notebook](../inspect-real-data/240718-1-fitting-bacteria.ipynb); they were chosen as the best results of multiple short EM runs (i.e., the emEM initialization).

In [1]:
import json
import gzip
from collections import namedtuple
from multiprocessing import Pool

from tqdm.auto import tqdm

from colaml.__main__ import model_from_json, phytbl_from_json, CoLaMLEncoder
from myconfig import DATA_DIR
# Directory of the reference bacteria run; the input table and the reference
# fitting files used in this notebook are read from here.
REF_RUN_DIR = DATA_DIR.joinpath('inspect-real-data', 'bacteria')

In [2]:
from threadpoolctl import threadpool_limits
# Cap BLAS at one thread.
# NOTE(review): presumably to avoid oversubscribing the CPU once the worker
# processes are started further down — confirm.
threadpool_limits(limits=1, user_api='blas')

<threadpoolctl.threadpool_limits at 0x7fd9bec79690>

In [3]:
# Load the reference data set; `phytbl` is the table every EM run below is fitted to.
# NOTE(review): the meaning of the second argument (2) is not visible here —
# presumably it matches the `lmax2` in the file name; confirm.
phytbl, columns = phytbl_from_json(REF_RUN_DIR/'bacteria-lmax2-filt05.json.gz', 2)

In [4]:
# Keyword arguments forwarded to the model's `fit` call for each variant:
# plain EM vs. parabolic EM (original) vs. parabolic EM (modified).
configs = {
    'plainEM': {'method': 'EM'},
    'parabolicEMorig': {'method': 'parabolic_EM', 'heuristics': False},
    'parabolicEMmod': {'method': 'parabolic_EM', 'heuristics': True},
}

In [5]:
# One job = one reference fitting file combined with one EM variant.
Job = namedtuple('Job', 'ref_fitting method_key')

# Cross every reference fitting file (in sorted order) with every key of `configs`.
jobs = [
    Job(ref_fitting=fit_path, method_key=key)
    for fit_path in sorted(REF_RUN_DIR.glob('bacteria-lmax2-filt05.fit*.json.gz'))
    for key in configs
]

In [6]:
class NaiveLogger:
    """Minimal logger stand-in that keeps logged messages in memory.

    Only ``info`` is provided.  Each call records its message object, in call
    order, in the ``storage`` list.
    """

    def __init__(self):
        # Messages in the order they were logged.
        self.storage = list()

    def info(self, fmt, msg):
        """Record ``msg``; ``fmt`` is accepted but ignored."""
        self.storage += [msg]

In [7]:
def collect_EM_traj(job):
    """Re-run one EM variant from a reference initialization and save its log.

    The model is rebuilt from ``job.ref_fitting``, updated with the values
    stored under ``['result']['init']`` of that same file, and fitted to the
    module-level ``phytbl`` with the settings in ``configs[job.method_key]``.
    Every message the fit passes to the logger is then written as gzipped
    JSON to
    ``DATA_DIR/'misc-EM-acceleration'/<method_key>/<method_key>-<ref file name>``.

    Parameters
    ----------
    job : Job
        The reference fitting file (``ref_fitting``) and the key into
        ``configs`` (``method_key``).

    Returns
    -------
    None
    """
    logger = NaiveLogger()

    # The reference file is read twice: here for the stored initial values,
    # and by model_from_json for the model itself.
    with gzip.open(job.ref_fitting, 'rt') as file:
        init = json.load(file)['result']['init']
    mmm = model_from_json(job.ref_fitting)
    mmm.update(**init)

    mmm.fit(
        phytbl, logger=logger, **configs[job.method_key],
        max_rounds=5000, stop_criteria=(1e-6, 1e-6), show_progress=False
    )

    log_path = DATA_DIR/'misc-EM-acceleration'/job.method_key/f'{job.method_key}-{job.ref_fitting.name}'
    # Create the per-method output directory on demand; without it gzip.open
    # fails on a fresh data tree.  exist_ok tolerates several pool workers
    # creating the same directory at once.
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with gzip.open(log_path, 'wt') as file:
        json.dump(logger.storage, file, cls=CoLaMLEncoder, indent=2)

# Run all jobs on 10 worker processes.  The results are consumed only to drive
# the progress bar.  imap_unordered yields an iterator without a length, so
# the total is passed explicitly to let tqdm report completion and ETA.
with Pool(10) as pool:
    for _ in tqdm(pool.imap_unordered(collect_EM_traj, jobs), total=len(jobs)):
        pass

0it [00:00, ?it/s]