# Post-processing (1/n)

- Log-likelihood
- Parameter estimation accuracy (see also snippets.py)
    - rates: Mean Absolute Percentage Error (MAPE)
    - root probs: Earth Mover's Distance (EMD)
        - copy: $d(i, j) = |i - j|$
        - category: $d(x, y) = 1 - \delta_{x,y}$ (i.e., a discrete metric space)
- Reconstruction

Category identification was performed by finding the optimal assignment that minimizes the total MAPE of the gain/loss rates $r$.


In [1]:
import gzip, json
from multiprocessing import Pool
from operator import itemgetter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

from colaml.__main__ import model_from_json, phytbl_from_json
from myconfig import ROOT_DIR, DATA_DIR, DATASET_DIR
from snippets import assign_mmm_category, mmm_param_dist

In [2]:
from threadpoolctl import threadpool_limits
# Limit the native thread pools managed by threadpoolctl to a single thread.
# NOTE(review): presumably this keeps the Pool(16) workers used below from
# oversubscribing cores -- confirm intent.
# The returned limiter object is the cell's last expression, so it is displayed.
threadpool_limits(1)

<threadpoolctl.threadpool_limits at 0x7fa8cbaec0a0>

In [3]:
import colaml
# Show the installed colaml version as the cell output so the notebook records
# which build produced the results below.
colaml.__version__

'0.1.dev14+g6c01617'

In [4]:
# Tab-separated job table; the batch_* functions below read the columns
# infile, outfile, lmax, conditionID, data_rep and fit_rep from each row.
jobs = pd.read_csv('240628-batch-job-array.txt', sep='\t')
# Tab-separated table of simulation conditions for dataset 01-simulation01
# (loaded here; not referenced by the cells shown in this notebook section).
conditions = pd.read_csv(DATASET_DIR/'01-simulation01'/'conditions.tsv', sep='\t')

## Log-likelihood

In [5]:
def batch_loglik(job):
    """Return the total log-likelihood of one fitted model on its input data.

    Parameters
    ----------
    job : pandas.Series
        One row of the job table. It must provide ``infile``, ``outfile``,
        ``lmax``, ``conditionID``, ``data_rep`` and ``fit_rep``.

    Returns
    -------
    dict
        The three job identifiers plus ``loglik``. ``loglik`` is NaN when
        loading or evaluating the model raises any ``Exception``.
    """
    record = job.loc[['conditionID', 'data_rep', 'fit_rep']].to_dict()

    try:
        table, _ = phytbl_from_json(ROOT_DIR/job.infile, job.lmax)
        model = model_from_json(ROOT_DIR/job.outfile)
        # Sum the per-column log-likelihoods into a single value.
        total = model.sufficient_stats(table).col_loglik.sum()
    except Exception:
        # Best effort: a failed job is reported as NaN instead of aborting
        # the whole batch.
        total = np.nan

    return record | {'loglik': total}

In [6]:
# Evaluate every job in parallel (16 worker processes) and collect one row per
# job, in job-table order (imap preserves input order).
# `total` is given explicitly because `pool.imap` returns an iterator without a
# length; without it the progress bar cannot show a percentage or an ETA.
with Pool(16) as pool:
    loglik = pd.DataFrame(tqdm(
        pool.imap(
            batch_loglik,
            map(itemgetter(1), jobs.iterrows())  # iterrows yields (label, row); keep the row
        ),
        total=len(jobs),
    ))

0it [00:00, ?it/s]

In [7]:
# Persist the per-job log-likelihood table as a pickle; with pandas' default
# compression inference the .bz2 suffix selects bz2 compression.
loglik.to_pickle(DATA_DIR/'post-batch'/'01-simulation01'/'loglik.pkl.bz2')

## Parameter estimation accuracy

In [8]:
def load_ans_par(infile):
    """Load the ground-truth parameters stored in a gzipped JSON file.

    Parameters
    ----------
    infile : path or file object
        Gzip-compressed JSON document with a top-level ``'params'`` mapping.

    Returns
    -------
    dict
        Each entry of ``'params'`` converted with ``numpy.asarray``, keyed by
        the original parameter name.
    """
    with gzip.open(infile, mode='rt') as handle:
        params = json.load(handle)['params']

    arrays = {}
    for name, value in params.items():
        arrays[name] = np.asarray(value)
    return arrays

def load_est_par(outfile):
    """Load a fitted model from JSON and return its decompressed parameters.

    Parameters
    ----------
    outfile : path
        Model file readable by ``model_from_json``.

    Returns
    -------
    The result of the model's ``_decompress_flat_params`` applied to its
    ``flat_params``.
    NOTE(review): presumably a dict of arrays with the same keys as the output
    of ``load_ans_par`` -- confirm against colaml.
    """
    model = model_from_json(outfile)
    decompress = model._decompress_flat_params
    return decompress(model.flat_params)

In [9]:
def batch_dist(job):
    """Measure how far one job's estimated parameters are from the truth.

    Parameters
    ----------
    job : pandas.Series
        One row of the job table. It must provide ``infile``, ``outfile``,
        ``conditionID``, ``data_rep`` and ``fit_rep``.

    Returns
    -------
    dict
        The three job identifiers merged with the entries returned by
        ``mmm_param_dist``. Only the identifiers are returned when loading or
        comparing the parameters raises any ``Exception``.
    """
    record = job.loc[['conditionID', 'data_rep', 'fit_rep']].to_dict()

    try:
        truth = load_ans_par(ROOT_DIR/job.infile)
        fitted = load_est_par(ROOT_DIR/job.outfile)
        distances = mmm_param_dist(truth, fitted)
    except Exception:
        # Best effort: a failed job contributes no distance columns.
        distances = {}

    return record | distances

In [10]:
# Compute the parameter distances of every job in parallel (16 worker
# processes) and collect one row per job, in job-table order.
# `total` is given explicitly because `pool.imap` returns an iterator without a
# length; without it the progress bar cannot show a percentage or an ETA.
with Pool(16) as pool:
    param_dist = pd.DataFrame(tqdm(
        pool.imap(
            batch_dist,
            map(itemgetter(1), jobs.iterrows())  # iterrows yields (label, row); keep the row
        ),
        total=len(jobs),
    ))

0it [00:00, ?it/s]

In [11]:
# Persist the per-job parameter-distance table as a pickle; with pandas'
# default compression inference the .bz2 suffix selects bz2 compression.
param_dist.to_pickle(DATA_DIR/'post-batch'/'01-simulation01'/'param-accuracy.pkl.bz2')

## Reconstruction

In [12]:
def batch_recon_pcorrect(job):
    """Score the ancestral reconstruction of one job against the ground truth.

    For each reconstruction method (``'joint'`` and ``'marginal'``) the
    reconstructed copy numbers and categories are compared with the true
    states stored in the input file, giving, per row of the state tables, the
    fraction of columns that were reconstructed correctly.

    Parameters
    ----------
    job : pandas.Series
        One row of the job table. It must provide ``infile``, ``outfile``,
        ``lmax``, ``conditionID``, ``data_rep`` and ``fit_rep``.

    Returns
    -------
    dict
        The three job identifiers, plus

        * ``'depth'``: ``{node name: tree.get_distance(node)}`` for every node
          of the tree; an empty dict when the tree could not be loaded.
        * ``(method, 'pcorrect_cpy')`` and ``(method, 'pcorrect_cat')`` for
          each method: ``{row label: fraction correct}``, or ``None`` when
          that reconstruction (or the shared setup) failed.

    Raises
    ------
    Errors while reading the ground-truth file are not caught: the gzip/JSON
    load and the construction of the truth tables happen outside the ``try``.
    """
    with gzip.open(ROOT_DIR/job.infile, 'rt') as file:
        truth = json.load(file)
    # The stored objects are passed as DataFrame keyword arguments.
    cpy_ans = pd.DataFrame(**truth['recon']).sort_index()
    cat_ans = pd.DataFrame(**truth['otherstates'][0]['states']).sort_index()

    # Pre-fill every result slot so failed methods are reported as None.
    correct_rate = {}
    for method in ('joint', 'marginal'):
        correct_rate[method, 'pcorrect_cpy'] = None
        correct_rate[method, 'pcorrect_cat'] = None

    # Bind `depth` before the try block: it used to be assigned only inside
    # it, so a failure while loading the table or tree was swallowed and then
    # resurfaced as an UnboundLocalError at the return statement.
    # NOTE(review): with an empty `depth`, the downstream pd.DataFrame(...)
    # should yield no rows for this job -- verify that is the desired outcome.
    depth = {}

    try:
        phytbl, columns = phytbl_from_json(ROOT_DIR/job.infile, job.lmax)
        tree = phytbl.tree.to_ete3()
        depth = {node.name: tree.get_distance(node) for node in tree.traverse()}

        mmm = model_from_json(ROOT_DIR/job.outfile)
        ans = load_ans_par(ROOT_DIR/job.infile)
        est = load_est_par(ROOT_DIR/job.outfile)
        # Match estimated categories to true ones, then build the lookup
        # `inv[estimated label] -> true label` used to relabel reconstructions.
        ans_idx, est_idx = assign_mmm_category(ans, est)
        inv = np.empty_like(est_idx)
        inv[est_idx] = ans_idx

        for method in ('joint', 'marginal'):
            try:
                recon = mmm.reconstruct(phytbl, method=method)
                cpy_recon = pd.DataFrame.from_dict(
                    recon.to_dict(),
                    orient='index', columns=columns
                ).sort_index()
                # Translate estimated category labels into the truth's labels
                # before comparing.
                cat_recon = pd.DataFrame.from_dict(
                    recon.otherstates['categories'].to_dict(),
                    orient='index', columns=columns
                ).sort_index().apply(inv.__getitem__)

                # Row-wise mean of the element-wise match: fraction of columns
                # reconstructed correctly for each row label.
                correct_rate[method, 'pcorrect_cpy'] = cpy_recon.eq(cpy_ans).mean(axis=1).to_dict()
                correct_rate[method, 'pcorrect_cat'] = cat_recon.eq(cat_ans).mean(axis=1).to_dict()

            except Exception:
                # Best effort: leave this method's entries as None.
                continue

    except Exception:
        # Best effort: keep the defaults (empty depth, None scores).
        pass

    jobinfo = job.loc[['conditionID', 'data_rep', 'fit_rep']].to_dict()

    return jobinfo | dict(depth=depth) | correct_rate

In [13]:
# Score the reconstructions of every job in parallel (16 worker processes).
# Each result dict becomes one DataFrame and the frames are stacked.
# `total` is given explicitly because `pool.imap` returns an iterator without a
# length; without it the progress bar cannot show a percentage or an ETA.
with Pool(16) as pool:
    recon_pcorrect = (
        pd.concat(tqdm(
            map(pd.DataFrame, pool.imap(
                batch_recon_pcorrect,
                map(itemgetter(1), jobs.iterrows())  # iterrows yields (label, row); keep the row
            )),
            total=len(jobs),
        ))
        # Turn the row index into a 'node' column, then move the four
        # identifying columns to the front via set_index/reset_index.
        .reset_index(names='node')
        .set_index(['conditionID','data_rep','fit_rep','node'])
        .reset_index()
    )

0it [00:00, ?it/s]

  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP = np.log(self.root_probs)
  log_rootP 

In [14]:
# Persist the reconstruction-accuracy table as a pickle; with pandas' default
# compression inference the .bz2 suffix selects bz2 compression.
recon_pcorrect.to_pickle(DATA_DIR/'post-batch'/'01-simulation01'/'recon-accuracy.pkl.bz2')