# In [53]:
import pandas as pd
import pathlib as pl
import collections as col
import re
import itertools as itt

# Project directory: the .fai, GAF and normalized-PAF inputs are looked up
# below this folder, and the per-contig summary tables are written into it.
root_folder = pl.Path("/home/ebertp/work/projects/hgsvc/2024_debug_hsmdup")

def load_assembly_info(fai_file):
    """Read a FASTA index (.fai) and summarize the sequences it lists.

    Only the first two tab-separated columns (sequence name, length) are
    read. Every name is shortened to the part before the first "#" or,
    if it has no "#", before the first "_"; names containing neither
    separator are used unchanged.

    Args:
        fai_file: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        A 3-tuple ``(seq_names, total_length, length_lookup)``: the set of
        shortened names, the sum of the length column, and a dict mapping
        each shortened name to its length. If several rows share a
        shortened name, the last row wins in the dict.

    Raises:
        AssertionError: if a shortened name contains anything other than
            digits and the letters ``h``, ``u``, ``t``, ``g``, ``l``, ``c``.
    """
    allowed_name = re.compile("^[0-9hutglc]+$")

    def plain_name(seqname):
        # "#" takes precedence over "_" when both occur in the name.
        for separator in ("#", "_"):
            if separator in seqname:
                plain = seqname.split(separator)[0]
                break
        else:
            plain = seqname
        assert allowed_name.match(plain), plain
        return plain

    table = pd.read_csv(
        fai_file,
        sep="\t",
        header=None,
        usecols=[0, 1],
        names=["contig", "length"],
    )
    table["plain"] = table["contig"].apply(plain_name)

    length_lookup = dict(zip(table["plain"], table["length"]))
    return set(table["plain"].values), table["length"].sum(), length_lookup


def summarize_assembly(filename, sample="NA19240"):
    """Derive the assembly key encoded in an index file name.

    File names containing "hap1" or "hap2" are keyed as
    ``(sample, "lin-phased", unit)``, where ``unit`` is the last
    "-"-separated piece of the part before the first ".". All other
    names are keyed as ``(sample, phasing, level)`` with

    * ``level``: "contig" if the name contains "p_ctg", "unitig" if it
      contains "p_utg";
    * ``phasing``: "gfa-partial" if the name contains ".bp.",
      "gfa-phased" if it contains ".dip.".

    Args:
        filename: Base name of the assembly index file.
        sample: Sample identifier that must occur in ``filename``.

    Returns:
        The 3-tuple assembly key described above.

    Raises:
        ValueError: if the sample is not part of the file name, or the
            level or phasing cannot be determined from it.
    """
    if sample not in filename:
        raise ValueError(
            f"sample {sample!r} not found in file name {filename!r}"
        )

    if "hap1" in filename or "hap2" in filename:
        asm_unit = filename.split(".")[0].split("-")[-1]
        return (sample, "lin-phased", asm_unit)

    if "p_ctg" in filename:
        level = "contig"
    elif "p_utg" in filename:
        level = "unitig"
    else:
        raise ValueError(
            f"neither 'p_ctg' nor 'p_utg' in file name {filename!r}"
        )

    if ".bp." in filename:
        phasing = "gfa-partial"
    elif ".dip." in filename:
        phasing = "gfa-phased"
    else:
        # A bare `raise` here has no active exception to re-raise and only
        # produces an uninformative RuntimeError; name the problem instead.
        raise ValueError(
            f"neither '.bp.' nor '.dip.' in file name {filename!r}"
        )

    return (sample, phasing, level)


def compute_ab_similarity(contigs_a, contigs_b):
    """Measure how much of assembly A is found, by sequence name, in B.

    Args:
        contigs_a: Mapping of sequence name to length for assembly A.
        contigs_b: Mapping of sequence name to length for assembly B;
            only its keys are consulted.

    Returns:
        A 3-tuple, each value rounded to two decimals:
        percentage of A's sequences whose name is absent from B,
        percentage of A's total length in sequences also named in B,
        and the summed length of A's absent sequences divided by 1e6.

    Raises:
        ZeroDivisionError: if ``contigs_a`` is empty or its lengths sum
            to zero.
    """
    absent_lengths = [
        length for name, length in contigs_a.items() if name not in contigs_b
    ]
    shared_bp = sum(
        length for name, length in contigs_a.items() if name in contigs_b
    )
    total_bp = sum(contigs_a.values())

    missed_seq_pct = round(len(absent_lengths) / len(contigs_a) * 100, 2)
    matched_bp_pct = round(shared_bp / total_bp * 100, 2)
    missed_seq_len = round(sum(absent_lengths) / int(1e6), 2)
    return missed_seq_pct, matched_bp_pct, missed_seq_len
    

def check_assembly_consistency(key1, contigs1, key2, contigs2):
    """Compare two assemblies by sequence name in both directions.

    Args:
        key1: Identifier of the first assembly.
        contigs1: Mapping of sequence name to length for the first assembly.
        key2: Identifier of the second assembly.
        contigs2: Mapping of sequence name to length for the second assembly.

    Returns:
        Two records ``(key_a, key_b, *similarity)``, first for 1 -> 2 and
        then for 2 -> 1. When the assemblies share no sequence name at
        all, the similarity part is the fixed triple ``(0, 0, -1)``;
        otherwise it is the result of ``compute_ab_similarity``.
    """
    if set(contigs1).isdisjoint(contigs2):
        # No common names: skip the comparison and flag it with -1.
        no_overlap = (0, 0, -1)
        return (key1, key2, *no_overlap), (key2, key1, *no_overlap)

    forward = compute_ab_similarity(contigs1, contigs2)
    backward = compute_ab_similarity(contigs2, contigs1)
    return (key1, key2, *forward), (key2, key1, *backward)


def parse_rukki_paths(gaf_file, assembler="hifiasm"):
    """Collect unitig/haplotype assignments from a Rukki paths table.

    The first line of the file is discarded. Every following non-blank
    line must consist of exactly three whitespace-separated fields; the
    second is the path and the third one of "HAPLOTYPE1", "HAPLOTYPE2"
    or "NA", which are mapped to "h1", "h2" and "h0". Unitig names are
    extracted from the path with a pattern that depends on the assembler.

    Args:
        gaf_file: Path of the table to read.
        assembler: "verkko" selects names of the form ``utig4-<digits>``;
            any other value selects ``utg<digits>l`` / ``utg<digits>c``.

    Returns:
        A ``defaultdict(set)`` used as a two-way lookup: each haplotype
        label maps to the set of unitigs seen in its paths, and each
        unitig maps to the set of haplotype labels it was seen with.

    Raises:
        ValueError: if a line does not have three fields or its path
            contains no unitig name.
        KeyError: if the third field is not a known haplotype label.
    """
    haps = {
        "HAPLOTYPE1": "h1",
        "HAPLOTYPE2": "h2",
        "NA": "h0"
    }

    # Raw strings keep the patterns free of invalid escape sequences.
    if assembler == "verkko":
        re_utigs = re.compile(r"utig4-[0-9]+")
    else:
        re_utigs = re.compile(r"utg[0-9]+[lc]")

    tig_assignment = col.defaultdict(set)
    with open(gaf_file, "r") as table:
        _ = table.readline()  # first line is not data
        for line in table:
            line = line.strip()
            if not line:
                # Tolerate blank lines, e.g. at the end of the file.
                continue
            _, path, assigned_hap = line.split()
            hap = haps[assigned_hap]
            utigs_in_path = re_utigs.findall(path)
            if not utigs_in_path:
                raise ValueError(f"no unitig found in path: {line}")
            tig_assignment[hap].update(utigs_in_path)
            for utig in utigs_in_path:
                tig_assignment[utig].add(hap)

    return tig_assignment


def read_alignments(norm_paf, only_primary=True, lower_cap_mapq=0):
    """Load a normalized PAF table and keep the rows that pass the filters.

    Args:
        norm_paf: Tab-separated table with a header line, as accepted by
            ``pandas.read_csv``.
        only_primary: If true, drop rows whose ``tp_align_type`` equals 2
            (presumably the code for secondary alignments; confirm
            against the normalization step) and rows whose
            ``align_matching`` is not above 10000.
        lower_cap_mapq: Rows are kept only if ``mapq`` is strictly
            greater than this value.

    Returns:
        A copy of the remaining rows; the original row index is kept.
    """
    table = pd.read_csv(norm_paf, sep="\t", header=0)

    criteria = []
    if only_primary:
        criteria.append(table["tp_align_type"] != 2)
        criteria.append(table["align_matching"] > 10000)
    criteria.append(table["mapq"] > lower_cap_mapq)

    # Combine all row filters into one boolean mask.
    keep = criteria[0]
    for criterion in criteria[1:]:
        keep = keep & criterion

    return table.loc[keep, :].copy()


def assign_haplotype_to_contigs(paf, tig_assignment):
    """Attach haplotype labels to contigs via unitig-to-contig alignments.

    Alignments are grouped by (query, target). The unitig name is the
    part of the query name before the first "_"; its haplotype labels
    are looked up in ``tig_assignment`` and the ``align_matching`` values
    of the group are summed.

    Args:
        paf: DataFrame with the columns ``query_name``, ``target_name``
            and ``align_matching``.
        tig_assignment: Mapping of unitig name to a non-empty collection
            of haplotype labels.

    Returns:
        A ``defaultdict(list)`` mapping each target (contig) name to a
        list of ``(labels, unitig, matched)`` tuples, where ``labels`` is
        the sorted tuple of the unitig's haplotype labels.

    Raises:
        ValueError: if a unitig has no haplotype assignment.
    """
    contig_haps = col.defaultdict(list)
    for (unitig, contig), alns in paf.groupby(["query_name", "target_name"]):
        utig = unitig.split("_")[0]
        # .get() avoids inserting empty entries when the lookup is a
        # defaultdict and the unitig is unknown.
        hap = tig_assignment.get(utig)
        if not hap:
            raise ValueError(
                f"no haplotype assignment for unitig {utig!r} "
                f"(query {unitig!r})"
            )
        # Sets have no stable iteration order; sort the labels so the
        # result is reproducible between runs.
        labels = tuple(sorted(hap))
        contig_haps[contig].append(
            (labels, utig, alns["align_matching"].sum())
        )

    return contig_haps


def summarize_contig_haps(contig_haps, contigs):
    """Tabulate, per contig, the aligned bases assigned to each haplotype.

    Args:
        contig_haps: Mapping of contig name to a list of
            ``(labels, unitig, matched)`` tuples, where ``labels`` is a
            tuple holding one or two haplotype labels.
        contigs: Mapping of contig name to contig length; one output row
            is produced per entry.

    Returns:
        A DataFrame with ``seq``, ``seqlen`` and ``<label>_bp`` /
        ``<label>_pct`` column pairs for "h1", "h2", "h0" and "miss",
        plus "MX" pairs once a unitig with two labels occurs. "miss" is
        the contig length minus all assigned bases. Percentages are
        relative to the contig length and rounded to two decimals. Rows
        are sorted by decreasing length, then by name; missing values
        are filled with 0.

    Raises:
        AssertionError: if a label tuple has two entries including "h0",
            or does not have one or two entries.
    """
    stat_columns = (
        "h1_bp", "h1_pct",
        "h2_bp", "h2_pct",
        "h0_bp", "h0_pct",
        "miss_bp", "miss_pct",
    )

    def label_for(labels, utig, matched):
        # Two labels on one unitig are reported under the mixed label.
        if len(labels) == 2:
            assert "h0" not in labels
            return "MX"
        assert len(labels) == 1, (labels, utig, matched)
        return labels[0]

    rows = []
    for name, size in contigs.items():
        row = {"seq": name, "seqlen": size}
        row.update(dict.fromkeys(stat_columns, 0))
        rows.append(row)

        if name not in contig_haps:
            # Nothing aligned to this contig: all of it counts as missing.
            row["miss_bp"] = size
            row["miss_pct"] = 100
            continue

        tally = col.Counter({"miss_bp": size})
        for labels, utig, matched in contig_haps[name]:
            tally[f"{label_for(labels, utig, matched)}_bp"] += matched
            tally["miss_bp"] -= matched

        for column, bases in tally.items():
            row[column] = bases
            row[column.replace("_bp", "_pct")] = round(bases / size * 100, 2)

    table = pd.DataFrame.from_records(rows)
    table = table.fillna(0)
    return table.sort_values(["seqlen", "seq"], ascending=[False, True])


# Unitig/haplotype lookup taken from the hifiasm-based Rukki paths.
# To look at the Verkko run instead, parse
# root_folder / "verkko" / "NA19240_rukki_paths.gaf" with assembler="verkko".
tig_assignment = parse_rukki_paths(
    root_folder / "NA19240hifi_rukki_paths.gaf",
    assembler="hifiasm",
)

# Index every assembly that has a .fai file somewhere below root_folder,
# keyed by the tuple that summarize_assembly derives from the file name.
assembly_infos = col.defaultdict(dict)

for fai_file in root_folder.glob("**/*.fai"):
    assm_key = summarize_assembly(fai_file.name)
    seq_names, total_length, length_lookup = load_assembly_info(fai_file)
    assembly_infos[assm_key] = dict(
        seq_names=seq_names,
        total_length=total_length,
        seq_to_length=length_lookup,
    )

# Compare each unordered pair of assemblies once, recording both directions.
records = []

for a, b in itt.combinations(sorted(assembly_infos), 2):
    a_to_b, b_to_a = check_assembly_consistency(
        a, assembly_infos[a]["seq_to_length"],
        b, assembly_infos[b]["seq_to_length"],
    )
    records += [a_to_b, b_to_a]

summary = pd.DataFrame.from_records(
    records,
    columns=[
        "asm1",
        "asm2",
        "missed_seq_pct",
        "matched_bp_pct",
        "missed_seq_len_Mbp",
    ],
)

# Unitig-to-contig alignment tables, one per contig-level assembly.
paf_files = [
    root_folder / paf_name
    for paf_name in (
        "NA19240.utg-to-bp.p_ctg.norm-paf.tsv.gz",
        "NA19240.utg-to-dip.p_ctg.norm-paf.tsv.gz",
    )
]

# The second dot-separated field of a table's file name selects the
# assembly whose contig lengths are used for the summary.
aln_to_key = {
    "utg-to-bp": ("NA19240", "gfa-partial", "contig"),
    "utg-to-dip": ("NA19240", "gfa-phased", "contig"),
}

for paf_file in paf_files:
    paf = read_alignments(paf_file)
    contig_haps = assign_haplotype_to_contigs(paf, tig_assignment)
    assm_key = aln_to_key[paf_file.name.split(".")[1]]
    contig_lengths = assembly_infos[assm_key]["seq_to_length"]
    hap_contig_summary = summarize_contig_haps(contig_haps, contig_lengths)

    if assm_key[1] == "gfa-phased":
        # Annotate each "dip" contig with the "bp" contigs aligned to it and
        # the share of query (Q) and target (T) length covered by matches.
        dip_support = read_alignments(
            root_folder / "NA19240.bp-to-dip.p_ctg.norm-paf.tsv.gz"
        )
        collect_support = col.defaultdict(list)
        grouped = dip_support.groupby(["query_name", "target_name"])
        for (bp_ctg, dip_ctg), alns in grouped:
            support = alns["align_matching"].sum()
            qry_support = round(support / alns["query_length"].iloc[0] * 100, 2)
            trg_support = round(support / alns["target_length"].iloc[0] * 100, 2)
            collect_support[dip_ctg].append(
                f"{bp_ctg}:Q-{qry_support}:T-{trg_support}"
            )
        add_support = {
            dip_ctg: "|".join(support_info)
            for dip_ctg, support_info in collect_support.items()
        }
        hap_contig_summary["bp_support"] = hap_contig_summary["seq"].map(
            lambda seq: add_support.get(seq, "no-support")
        )

    # The summary is written next to the inputs, preceded by one comment
    # line stating the summed contig length.
    outfile = root_folder / paf_file.name.replace(".tsv.gz", ".ctg-hap-summary.tsv")
    wg_length = hap_contig_summary["seqlen"].sum()
    with open(outfile, "w") as dump:
        _ = dump.write(f"# wg: {wg_length} bp\n")
        hap_contig_summary.to_csv(dump, sep="\t", header=True, index=False)
    
    
