In [None]:
import pathlib
import pandas as pd
import pickle as pkl
import collections

%cd -q "/home/ebertp/work/code/cubi/project-run-hgsvc-hybrid-assemblies/notebooks"

_PROJECT_CONFIG_NB = str(pathlib.Path("00_project_config.ipynb").resolve(strict=True))
_PLOT_CONFIG_NB = str(pathlib.Path("05_plot_config.ipynb").resolve(strict=True))

%run $_PROJECT_CONFIG_NB
%run $_PLOT_CONFIG_NB

# Short identifier of this notebook; passed to get_nb_stamp() below.
_MYNAME="cache-seq-lengths"
# Stamp that is stored under the "source" key of every cache file written here.
# NOTE(review): get_nb_stamp is not defined in this cell - presumably provided
# by one of the config notebooks executed via %run above; confirm.
_MYSTAMP=get_nb_stamp(_MYNAME)

# Assembler label; interpolated into the input folder paths and cache file names.
ASSEMBLER="verkko"

# Cache name -> pickle file location below PROJECT_NB_CACHE.
# NOTE(review): PROJECT_NB_CACHE is presumably set by the project config notebook.
#   seqlens: sequence lengths read from the FASTA index (.fai) files
#   gaplens: gap lengths read from the N-gap BED files
#   adjlens: sequence lengths minus gap lengths
#   gapseq:  per-sample sets of sequences that have a gap length > 0
#   gapfree: per-sample sets of sequences without any gap length
_SEQLENS_CACHE_FILES = {
    "seqlens": PROJECT_NB_CACHE.joinpath(f"cache.seqlens.{ASSEMBLER}.pck"),
    "gaplens": PROJECT_NB_CACHE.joinpath(f"cache.gaplens.{ASSEMBLER}.pck"),
    "adjlens": PROJECT_NB_CACHE.joinpath(f"cache.adjlens.{ASSEMBLER}.pck"),
    "gapseq": PROJECT_NB_CACHE.joinpath(f"cache.gapseq.{ASSEMBLER}.pck"),
    "gapfree": PROJECT_NB_CACHE.joinpath(f"cache.gapfree.{ASSEMBLER}.pck"),
}


def get_sample_asm_unit(file_name):
    """Split a data file name into its sample label and assembly unit.

    Two naming layouts are recognized:

    * names containing ``-hifiasm-``: the part before the first dot must
      consist of exactly three dash-separated fields; the first one is the
      sample (suffixed with ``.hsm-ps-sseq``), the last one the assembly
      unit, the middle one is ignored.
    * all other names: the first two dot-separated fields form the sample,
      and the assembly unit is whatever follows the first dash in the third
      field.

    Returns:
        tuple: ``(sample, asm_unit)``; ``asm_unit`` is ``None`` if a
        non-hifiasm name has no third field or no dash in it.

    Raises:
        ValueError: hifiasm-style name without exactly three dash fields.
        IndexError: non-hifiasm name with fewer than two dot fields.
    """
    if "-hifiasm-" not in file_name:
        fields = file_name.split(".")
        try:
            unit = fields[2].split("-")[1]
        except IndexError:
            # file name w/o asm unit
            unit = None
        return f"{fields[0]}.{fields[1]}", unit

    stem = file_name.partition(".")[0]
    name, _, unit = stem.split("-")
    return f"{name}.hsm-ps-sseq", unit


def read_ngap_files():
    """Sum up gap lengths per sample and per sequence from the N-gap BED files.

    Every ``*.bed`` file in the assembler-specific ``2024_ngaps`` folder below
    ``PROJECT_DATA_ROOT`` is read. The first line of each file is skipped
    (presumably a header line) and columns 0, 3 and 4 are interpreted as
    sequence name, sample label and gap length.

    Returns:
        collections.Counter: summed gap lengths keyed by ``(sample, seq)``
        for each sequence and by ``sample`` for the per-sample total.

    Raises:
        ValueError: if a file yields no records, or if the sample label of
            its first record differs from the sample derived from the file
            name.
        RuntimeError: if no gap record was collected from any file.
    """
    source_folder = PROJECT_DATA_ROOT.joinpath(
        f"2024_ngaps/hgsvc/{ASSEMBLER}"
    ).resolve(strict=True)

    print("Reading cache input data from: ", source_folder)

    gaplens = collections.Counter()
    for bed_file in source_folder.glob("*.bed"):
        sample, _ = get_sample_asm_unit(bed_file.name)
        df = pd.read_csv(
            bed_file, sep="\t", header=None, skiprows=1,
            usecols=[0, 3, 4], names=["seq", "sample", "length"]
        )
        # Explicit checks instead of ``assert``: they must also run under
        # ``python -O`` and should name the offending file.
        if df.empty:
            raise ValueError(f"No gap records in file: {bed_file}")
        file_sample = df["sample"].iloc[0]
        if file_sample != sample:
            raise ValueError(
                f"Sample mismatch in {bed_file}: "
                f"file name says '{sample}', content says '{file_sample}'"
            )
        # Column-wise iteration yields the same Python scalars as
        # itertuples() without building a namedtuple per row.
        for seq, length in zip(df["seq"], df["length"]):
            gaplens[(sample, seq)] += length
            gaplens[sample] += length
    if not gaplens:
        raise RuntimeError(f"No N-gap records found in: {source_folder}")
    return gaplens


def read_fasta_index_files():
    """Collect sequence lengths from the FASTA index (``*.fai``) files.

    Every ``*.fai`` file in the assembler-specific ``2024_fasta_index`` folder
    below ``PROJECT_DATA_ROOT`` is read, except files with ``contaminants``
    in their name. Columns 0 and 1 are interpreted as sequence name and
    sequence length.

    Returns:
        collections.Counter: lengths under four kinds of keys:

        * ``(sample, None, seq)``: length of one sequence
        * ``(sample, asm_unit, seq)``: same, qualified by assembly unit
        * ``(sample, asm_unit, None)``: total length of the assembly unit
        * ``sample``: total length of all sequences of the sample

    Raises:
        ValueError: if a sequence name occurs more than once for a sample.
        RuntimeError: if no sequence was collected from any file.
    """
    source_folder = PROJECT_DATA_ROOT.joinpath(
        f"2024_fasta_index/hgsvc/{ASSEMBLER}"
    ).resolve(strict=True)

    print("Reading cache input data from: ", source_folder)

    seqlens = collections.Counter()
    for fai_file in source_folder.glob("*.fai"):
        if "contaminants" in fai_file.name:
            continue
        sample, asm_unit = get_sample_asm_unit(fai_file.name)
        df = pd.read_csv(
            fai_file, sep="\t", header=None,
            usecols=[0, 1], names=["seq", "length"]
        )
        for seq, length in zip(df["seq"], df["length"]):
            # The duplicate check has to use the 3-tuple key that is actually
            # stored; a 2-tuple (sample, seq) never exists in this counter,
            # so testing for it could never detect a duplicate.
            if (sample, None, seq) in seqlens:
                raise ValueError(
                    f"Duplicate sequence '{seq}' for sample '{sample}' "
                    f"(file: {fai_file})"
                )
            seqlens[(sample, None, seq)] = length
            seqlens[(sample, asm_unit, seq)] = length
            seqlens[(sample, asm_unit, None)] += length
            seqlens[sample] += length
    if not seqlens:
        # same policy as for the gap lengths: never cache an empty result
        raise RuntimeError(f"No sequence lengths found in: {source_folder}")
    return seqlens


def compute_adjusted_sequence_length():
    """Subtract gap lengths from sequence lengths for every cached sequence.

    Loads the "seqlens" and "gaplens" caches and processes each per-sequence
    entry of the form ``(sample, None, seq)``.

    Returns:
        tuple: ``(adj_seqlen, seq_w_gap, seq_wo_gap)`` where

        * ``adj_seqlen`` is a ``collections.Counter`` with the adjusted
          length under ``(sample, seq)`` and the per-sample sum under
          ``sample``,
        * ``seq_w_gap`` maps sample -> set of sequences with gap length > 0,
        * ``seq_wo_gap`` maps sample -> set of all other sequences.

    Raises:
        TypeError: if the gap length cache is not a ``collections.Counter``.
        ValueError: if an adjusted length is not positive.
    """
    seqlens = load_seqlen_cache("seqlens")
    gaplens = load_seqlen_cache("gaplens")
    # Counter semantics are required below: looking up a sequence without
    # any gap record must return 0 instead of raising KeyError.
    if not isinstance(gaplens, collections.Counter):
        raise TypeError(
            f"Gap length cache must be a Counter, got {type(gaplens).__name__}"
        )

    seq_w_gap = collections.defaultdict(set)
    seq_wo_gap = collections.defaultdict(set)
    adj_seqlen = collections.Counter()
    for key, seqlen in seqlens.items():
        # Only per-sequence entries (sample, None, seq) are of interest.
        # Plain sample keys are strings, and (sample, None, None) is the
        # assembly-unit total of files without an assembly unit - treating
        # the latter as a sequence would count its length twice.
        if not (isinstance(key, tuple) and len(key) == 3):
            continue
        smp, asm_unit, seq = key
        if asm_unit is not None or seq is None:
            continue
        gaplen = gaplens[(smp, seq)]
        if gaplen > 0:
            seq_w_gap[smp].add(seq)
        else:
            seq_wo_gap[smp].add(seq)
        adj_len = seqlen - gaplen
        if adj_len <= 0:
            raise ValueError(
                f"Non-positive adjusted length for {smp} / {seq}: "
                f"{seqlen} - {gaplen} = {adj_len}"
            )
        adj_seqlen[(smp, seq)] = adj_len
        adj_seqlen[smp] += adj_len
    return adj_seqlen, seq_w_gap, seq_wo_gap


def _build_any_seqlen_cache(cache_file, data):
    """Pickle ``data`` into ``cache_file`` together with the notebook stamp.

    The file holds a dict with the payload under "data" and ``_MYSTAMP``
    under "source". An existing file is overwritten. Returns ``None``.
    """
    with open(cache_file, "wb") as dump_target:
        pkl.dump({"data": data, "source": _MYSTAMP}, dump_target)


def build_seqlen_cache(which, force_rebuild=False):
    """Create the cache file for ``which`` unless it already exists.

    Args:
        which: key of ``_SEQLENS_CACHE_FILES``; an unknown key raises
            ``KeyError`` on lookup.
        force_rebuild: if true, build even though the cache file exists.

    "seqlens" and "gaplens" are built from their input files. The derived
    caches "adjlens", "gapfree" and "gapseq" are always written together,
    after their two prerequisites have been built (the force flag is passed
    on to them). Returns ``None``.
    """
    target = _SEQLENS_CACHE_FILES[which]
    if not force_rebuild and target.is_file():
        # cache is in place and no rebuild was requested
        return

    if which == "seqlens":
        _build_any_seqlen_cache(target, read_fasta_index_files())
    elif which == "gaplens":
        _build_any_seqlen_cache(target, read_ngap_files())
    elif which in ("adjlens", "gapfree", "gapseq"):
        for prerequisite in ("seqlens", "gaplens"):
            build_seqlen_cache(prerequisite, force_rebuild)
        adjusted, gapped, gapless = compute_adjusted_sequence_length()
        derived = (
            ("adjlens", adjusted),
            ("gapfree", gapless),
            ("gapseq", gapped),
        )
        for label, payload in derived:
            _build_any_seqlen_cache(_SEQLENS_CACHE_FILES[label], payload)
    else:
        raise ValueError(which)


def load_seqlen_cache(which, force_rebuild=False, report_source=False):
    """Load the cached data for ``which``, building the cache first if needed.

    Args:
        which: key of ``_SEQLENS_CACHE_FILES``.
        force_rebuild: passed on to ``build_seqlen_cache``.
        report_source: if true, print the cache file path and the "source"
            stamp stored in the cache file.

    Returns:
        The object stored under "data" in the cache file.
    """
    cache_file = _SEQLENS_CACHE_FILES[which]
    build_seqlen_cache(which, force_rebuild)
    with open(cache_file, "rb") as cache:
        # NOTE: unpickling can execute arbitrary code - only load cache
        # files that were written by this notebook.
        cache_obj = pkl.load(cache)
    if report_source:
        src = cache_obj["source"]
        # report the file that was actually loaded; the previous message
        # referenced a name (CACHE_FILE_SEQLEN) that is not defined here
        print(f"Loading cache file {cache_file}\nSource: {src}\n")
    return cache_obj["data"]
