In [8]:
import pathlib as pl
import functools as fnt

import pandas as pd

# Root of the local project checkout. Resolving with strict=True makes this
# line raise immediately if the directory does not exist on this machine.
_PROJECT_BASE = pl.Path(
    "/home/ebertp/work/code/cubi/project-run-hgsvc-hybrid-assemblies"
).resolve(strict=True)

# Statistics table for the Verkko assemblies.
VERKKO_ASSM_STATS_TABLE_FILE = (
    _PROJECT_BASE / "annotations" / "autogen" / "verkko_assemblies.hgsvc3.tsv"
)

# Statistics table for the hifiasm assemblies.
HIFIASM_ASSM_STATS_TABLE_FILE = (
    _PROJECT_BASE / "annotations" / "autogen" / "hifiasm_assemblies.hgsvc3.tsv"
)


class TableManager:
    """Positional lookup helper around an assembly statistics table.

    The table is read from a tab-separated file with two header rows
    (assembly unit, statistic name) and three index columns. One of the
    index levels must be named "sample" and hold unique values. Rows and
    columns are mapped to integer positions once, so that single values can
    be fetched with short keys of the form (unit_suffix, statistic, threshold).
    """

    def __init__(self, table_file_path):
        """Load the table and build the row/column position caches.

        Args:
            table_file_path: path to the tab-separated table; text following
                a "#" is ignored as a comment by the parser.
        """
        self.table = pd.read_csv(
            table_file_path, sep="\t", header=[0, 1],
            comment="#", index_col=[0, 1, 2]
        )

        self.rows = self._build_row_cache()
        self.columns = self._build_column_cache()
        # Python type used by get_stat to cast a value, keyed by the
        # simplified statistic name.
        self.dtypes = {
            "cov": float,
            "length": int,
            "num": int,
            "aun": int,
            "n50": int,
            "pct_dip": float
        }

    def _build_row_cache(self):
        """Map each value of the "sample" index level to its row position.

        Raises:
            AssertionError: if a sample name occurs more than once.
        """
        sample_rows = dict()
        samples = self.table.index.get_level_values("sample")
        for row_num, sample in enumerate(samples, start=0):
            assert sample not in sample_rows, f"duplicate sample: {sample}"
            sample_rows[sample] = row_num
        assert len(sample_rows) == self.table.shape[0]
        return sample_rows

    def _simplify_statistic_name(self, stat_name):
        """Split a statistic column name into (statistic, threshold).

        Everything from the first "_at_" onwards is dropped. The threshold
        is the text after the last "_grt_", with leading/trailing "b"/"p"
        characters stripped and the rest lower-cased.

        Raises:
            ValueError: if the name starts with none of the known prefixes
                ("length", "total", "cov", "pct").
        """
        # split() returns the whole string when "_at_" is absent, so no
        # membership check is needed beforehand
        stat_name = stat_name.split("_at_")[0]
        # last part now must be threshold
        threshold_value = stat_name.split("_grt_")[-1]
        threshold_value = threshold_value.strip("bp").lower()
        if stat_name.startswith(("length", "total")):
            statistic = stat_name.split("_")[1].lower()
        elif stat_name.startswith("cov"):
            statistic = stat_name.split("_")[0]
        elif stat_name.startswith("pct"):
            # drop the last three underscore-separated fields
            statistic = stat_name.rsplit("_", 3)[0]
        else:
            raise ValueError(stat_name)
        return statistic, threshold_value

    def _build_column_cache(self):
        """Map (unit_suffix, statistic, threshold) keys to column positions.

        The unit suffix is the part of the assembly unit label after its
        first "-" (the whole label if there is no "-").

        Raises:
            AssertionError: if two columns collapse onto the same key.
        """
        asm_columns = dict()
        for col_num, (asm_unit, statistic) in enumerate(self.table.columns, start=0):
            unit_suffix = asm_unit.split("-", 1)[-1]
            stat, threshold = self._simplify_statistic_name(statistic)
            asm_columns[(unit_suffix, stat, threshold)] = col_num
        assert len(asm_columns) == self.table.shape[1], "ambiguous column keys"
        return asm_columns

    @staticmethod
    def _normalize_unit_stat(unit_stat):
        """Expand a 2-item (unit, statistic) key with the default threshold "0"."""
        if len(unit_stat) == 2:
            return unit_stat[0], unit_stat[1], "0"
        return unit_stat

    # NOTE: lru_cache on a method keys on `self` too; the cache is shared by
    # all instances and keeps them alive. All arguments must be hashable,
    # i.e. `unit_stat` has to be a tuple.
    @fnt.lru_cache(10000)
    def get_stat(self, sample, unit_stat, divide=None):
        """Return one table value cast to the statistic's Python type.

        Args:
            sample: sample name as found in the "sample" index level.
            unit_stat: tuple (unit_suffix, statistic) or
                (unit_suffix, statistic, threshold); the threshold
                defaults to "0".
            divide: optional divisor; if given, the value is divided by it
                and rounded to one decimal place.

        Raises:
            KeyError: for an unknown sample, statistic or column key.
        """
        dtype = self.dtypes[unit_stat[1]]
        column_key = self._normalize_unit_stat(unit_stat)
        value = dtype(self.table.iat[self.rows[sample], self.columns[column_key]])
        if divide is not None:
            value = round(value / divide, 1)
        return value

    def get_outliers(self, unit_stat, threshold):
        """Return the entries of one column that are strictly above `threshold`.

        Args:
            unit_stat: tuple (unit_suffix, statistic) or
                (unit_suffix, statistic, threshold); the threshold part
                of the key defaults to "0".
            threshold: numeric cut-off applied to the column values.

        Returns:
            A pandas Series (indexed like the table rows) holding only the
            values greater than `threshold`.
        """
        column_idx = self.columns[self._normalize_unit_stat(unit_stat)]
        # a single integer position selects a Series, hence only one
        # indexer (the boolean mask) may be passed to .loc below
        subset = self.table.iloc[:, column_idx]
        outliers = subset.loc[subset > threshold]
        return outliers


# Default every table handle to None so that later code can test for
# availability instead of failing with a NameError when a file is missing.
ASSM_STATS = None
VRK_ASSM_STATS = None
HSM_ASSM_STATS = None

if VERKKO_ASSM_STATS_TABLE_FILE.is_file():
    VRK_ASSM_STATS = TableManager(VERKKO_ASSM_STATS_TABLE_FILE)
    # ASSM_STATS is an alias for the Verkko table
    ASSM_STATS = VRK_ASSM_STATS

if HIFIASM_ASSM_STATS_TABLE_FILE.is_file():
    HSM_ASSM_STATS = TableManager(HIFIASM_ASSM_STATS_TABLE_FILE)