In [15]:
import pathlib as pl
import pandas as pd
import pickle as pck
import re as re
import collections as col
import os

# Two roots for the project tree: `mount` is used to read files from this
# machine, `remote` is the prefix written into the output table.
# NOTE(review): presumably both point at the same file system (local mount vs.
# cluster path) -- confirm.
mount = pl.Path("/mounts/hilbert/project")
remote = pl.Path("/gpfs/project")

# Folder (relative to the roots above) that is searched for *.fofn listings.
sample_folder = pl.Path("projects/medbioinf/data/00_RESTRUCTURE/sample-centric")

# Base folder (relative to the roots above) for the paths listed in the fofn files.
data_root = pl.Path("projects/medbioinf/data/00_RESTRUCTURE")

# Pickle cache of the file/cell mapping, stored below the working directory.
cache_mapping = pl.Path(".").joinpath(".cache", "file_cell_map.pck")
cache_mapping.parent.mkdir(exist_ok=True, parents=True)

# Cell ID patterns per read type. Raw strings avoid the invalid "\-" escape
# sequence warning; the resulting pattern text is unchanged.
hifi_cell = re.compile(r"(m[0-9a-z_U]{16,24}|[ABCDEFSPL0-9_\-]{22,28})")
ont_cell = re.compile(r"(P|G)[A-Z0-9_\-]{8,16}")

# Substrings that are removed from a file name before searching for a cell ID.
contains_date = re.compile(r"20[0-9]{2}[0-9]{4}")
contains_sample = re.compile(r"(HG|NA|GM)[0-9]{5}")

# Combined pattern with one named group per read type. The pattern *text* must
# be interpolated here: formatting the compiled objects directly would insert
# their repr ("re.compile('...')") into the expression.
cells = re.compile(f"((?P<hifi>{hifi_cell.pattern})|(?P<ont>{ont_cell.pattern}))")

# Output table (relative to the working directory).
samples_out = pl.Path("../samples/check_kmer.tsv")

# Load the file/cell lookup table from the cache, or build it from the fofn
# listings and cache it.
#
# `mapping` is a defaultdict(set) that holds four kinds of keys:
#   ("size", file_name)               -> file size in bytes
#   ("remote", file_name)             -> path of the file below `remote`
#   (read_type, cell_id, fofn_name)   -> set of file names
#   (read_type, file_name, fofn_name) -> set of cell IDs
# with read_type being "hifi" or "ont".
if cache_mapping.is_file():
    # NOTE: unpickling runs arbitrary code; this is only acceptable because
    # the cache file is produced by this very script.
    with open(cache_mapping, "rb") as dump:
        mapping = pck.load(dump)
else:
    mapping = col.defaultdict(set)
    for fofn in mount.joinpath(sample_folder).glob("**/*.fofn"):
        if "strandseq" in fofn.name:
            continue
        fofn_name = fofn.name
        with open(fofn, "r") as listing:
            for line in listing:
                # skip blank lines and comment lines
                if not line.strip() or line.startswith("#"):
                    continue

                file_rel_path = pl.Path(line.strip())
                file_name = file_rel_path.name
                file_path = mount.joinpath(data_root, file_rel_path)
                remote_path = remote.joinpath(data_root, file_rel_path)

                # The read type is inferred from the path; it selects the
                # pattern used to find the cell ID.
                remote_str = str(remote_path)
                if "nanopore" in remote_str:
                    read_type = "ont"
                    cell_pattern = ont_cell
                elif "pacbio_hifi" in remote_str:
                    read_type = "hifi"
                    cell_pattern = hifi_cell
                else:
                    # A bare `raise` without an active exception only yields
                    # an unhelpful RuntimeError; name the offending path.
                    raise ValueError(
                        f"Cannot determine read type (nanopore/pacbio_hifi): {remote_path}"
                    )
                file_size = os.stat(file_path).st_size

                # Remove the date and the sample ID from the file name
                # before looking for the cell ID.
                stripped_name = file_name
                for pattern in (contains_date, contains_sample):
                    mobj = pattern.search(stripped_name)
                    if mobj is not None:
                        stripped_name = stripped_name.replace(mobj.group(0), "")

                mobj = cell_pattern.search(stripped_name)
                if mobj is None:
                    raise ValueError(f"None: {file_name}")

                cell_id = mobj.group(0).strip("-_.")
                mapping[("size", file_name)] = file_size
                mapping[("remote", file_name)] = remote_path
                mapping[(read_type, cell_id, fofn_name)].add(file_name)
                mapping[(read_type, file_name, fofn_name)].add(cell_id)

    with open(cache_mapping, "wb") as dump:
        pck.dump(mapping, dump)


def to_gb(size_in_byte):
    """Return a byte count as decimal gigabytes (1 GB = 1e9 bytes), rounded to one decimal."""
    gigabytes = size_in_byte / 1e9
    return round(gigabytes, ndigits=1)


def determine_mean_file_size(data_files, read_type):
    """Return the mean file size over all files of one read type.

    Args:
        data_files: Mapping of ``(sample, read_type)`` to a collection of
            records whose first element is the file size.
        read_type: Only entries whose key carries this read type are counted.

    Returns:
        Sum of the sizes divided by the number of files, or ``0.0`` if no
        file of the requested read type exists (instead of failing with a
        ZeroDivisionError).
    """
    total_size = 0
    total_files = 0
    for (_, rtype), files in data_files.items():
        if rtype != read_type:
            continue
        total_size += sum(record[0] for record in files)
        total_files += len(files)
    if total_files == 0:
        return 0.0
    return total_size / total_files


def group_files(data_files, mean_size, read_type):
    """Bundle the files of each sample into groups for one read type.

    Samples with fewer than three files get one group per file, ordered by
    descending record order. For all other samples the files are walked in
    ascending record order (size first) and accumulated; a group is closed as
    soon as its accumulated size exceeds ``mean_size``. Files left over after
    the last closed group form a final group, which is the only group whose
    size may be at or below ``mean_size``.

    Args:
        data_files: Mapping of ``(sample, read_type)`` to a list of
            ``(size, file_name, remote_path)`` records.
        mean_size: Size threshold that closes a group.
        read_type: Only entries whose key carries this read type are grouped.

    Returns:
        pandas.DataFrame with the columns ``sample`` (the group name
        ``<sample>-<read_type>-G<n>``), ``cardinality`` (number of files),
        ``size_byte`` (summed size) and ``input`` (comma-separated paths).

    Raises:
        ValueError: If a sample with three or more files lists a record twice.
    """
    grouping = []
    for (smp, rtype), files in data_files.items():
        if rtype != read_type:
            continue
        prefix = f"{smp}-{read_type}-G"

        if len(files) < 3:
            # too few files to bundle: every file is its own group
            ordered = sorted(files, reverse=True)
            for gnum, (size, _, remote_path) in enumerate(ordered, start=1):
                grouping.append((f"{prefix}{gnum}", 1, size, str(remote_path)))
            continue

        if len(set(files)) != len(files):
            raise ValueError(f"Duplicate file records for {smp} / {read_type}")

        gnum = 1
        gsize = 0
        gfiles = []
        for size, _, remote_path in sorted(files):
            gsize += size
            gfiles.append(str(remote_path))
            if gsize > mean_size:
                grouping.append(
                    (f"{prefix}{gnum}", len(gfiles), gsize, ",".join(gfiles))
                )
                gnum += 1
                gsize = 0
                gfiles = []
        if gfiles:
            # Leftover files never exceed the threshold (otherwise the loop
            # would have closed the group), so they are kept as a final,
            # smaller group instead of being asserted on.
            grouping.append(
                (f"{prefix}{gnum}", len(gfiles), gsize, ",".join(sorted(gfiles)))
            )

    df = pd.DataFrame.from_records(
        grouping,
        columns=["sample", "cardinality", "size_byte", "input"]
    )
    return df

# Collect one (size, file name, remote path) record per file, keyed by
# (sample, read type). Only the cell-keyed entries of `mapping` are used.
group_by_sample = col.defaultdict(list)
for k, v in mapping.items():
    if k[0] == "size" or k[0] == "remote":
        continue
    read_type, cell_or_file, fofn_name = k
    sample = fofn_name.split("_")[0]
    # NOTE(review): file-keyed entries are recognised by this suffix alone --
    # confirm that every listed file ends with ".fastq.gz".
    if cell_or_file.endswith(".fastq.gz"):
        continue
    for file_name in v:
        file_size = mapping[("size", file_name)]
        remote_path = mapping[("remote", file_name)]
        record = (file_size, file_name, remote_path)
        group_by_sample[(sample, read_type)].append(record)

mean_hifi_size = determine_mean_file_size(group_by_sample, "hifi")
print("HiFi ", mean_hifi_size / 1e9)
mean_ont_size = determine_mean_file_size(group_by_sample, "ont")
print("ONT ", mean_ont_size / 1e9)

# Groups are closed once they exceed half of the mean file size of their read type.
hifi_groups = group_files(group_by_sample, int(mean_hifi_size * 0.5), "hifi")
ont_groups = group_files(group_by_sample, int(mean_ont_size * 0.5), "ont")
file_groups = pd.concat([hifi_groups, ont_groups])

# preparing a subset of the samples to establish a baseline:
# select all samples that have a |group| > 1
samples = {
    group_name.split("-")[0]
    for group_name in file_groups.loc[file_groups["cardinality"] > 1, "sample"]
}


def select_sample(sample):
    """Return True if the group name starts with any of the selected sample IDs."""
    return sample.startswith(tuple(samples))


file_groups["select"] = file_groups["sample"].apply(select_sample)
print(file_groups.shape)
subset = file_groups.loc[file_groups["select"]].copy()
subset = subset.drop(columns=["select"])
subset = subset.sort_values(["sample", "cardinality"])
subset.to_csv(samples_out, sep="\t", header=True, index=False)
print(subset.shape)

print(subset.head(20))


HiFi  22.71897417145091
ONT  23.479925036283827
(474, 5)
(275, 4)
              sample  cardinality    size_byte   
223  HG00096-hifi-G1            1  27056024726  \
224  HG00096-hifi-G2            1  27894899294   
225  HG00096-hifi-G3            1  28440271135   
191   HG00096-ont-G1            2  30962974958   
192   HG00096-ont-G2            1  31844157844   
182  HG00171-hifi-G1            1  29252258578   
183  HG00171-hifi-G2            1  30387867787   
184  HG00171-hifi-G3            1  30608884961   
149   HG00171-ont-G1            4  13708644482   
150   HG00171-ont-G2            1  13052043593   
151   HG00171-ont-G3            1  15349687373   
152   HG00171-ont-G4            1  15767547329   
153   HG00171-ont-G5            1  16176464864   
154   HG00171-ont-G6            1  18188089102   
155   HG00171-ont-G7            1  20460243472   
156   HG00171-ont-G8            1  22892175968   
24   HG00512-hifi-G1            2  23097046267   
25   HG00512-hifi-G2            1 