In [25]:
import os
import numpy as np
from collections import defaultdict

In [7]:
class UnionFind:
    def __init__(self):
        self.parent = {}

    def find(self,x):
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])
        return self.parent[x]

    def union(self,x,y):
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x != root_y:
            self.parent[root_x] = root_y

In [11]:
def get_disjoint_sets(setlist):
    uf = UnionFind()
    for s in setlist:
        for x in s:
            if x not in uf.parent:
                uf.parent[x] = x

    for reverse_id in [x for x in uf.parent if x.endswith("_Reversed")]:
        forward_id = reverse_id.replace("_Reversed","")
        if forward_id in uf.parent:
            uf.union(reverse_id,forward_id)

    for s in setlist:
        s_list = list(s)
        for i in range(1,len(s_list)):
            uf.union(s_list[0],s_list[i])

    groups = defaultdict(set)
    for x in uf.parent:
        groups[uf.find(x)].add(x)

    return list(groups.values())

In [12]:
setlist = [
    {"A", "B", "C"},
    {"C", "D", "E"},
    {"E", "F"},
    {"X", "Y"},
    {"Y_Reversed"},  # Should be merged with {"X", "Y"}
    {"G", "H_Reversed"},
]

In [13]:
groups = get_disjoint_sets(setlist)
groups

[{'A', 'B', 'C', 'D', 'E', 'F'}, {'X', 'Y', 'Y_Reversed'}, {'G', 'H_Reversed'}]

In [68]:
import pandas as pd

In [39]:
pairs[:10]

[{'BCL11A_1542', 'RBM38_4242'},
 {'BCL11A_6910_Reversed', 'peak34207_Reversed'},
 {'BCL11A_6910_Reversed', 'peak11111_Reversed'},
 {'BCL11A_6910_Reversed', 'peak79481'},
 {'BCL11A_6910_Reversed', 'peak3064'},
 {'BCL11A_6910_Reversed', 'HBA2_1011_Reversed'},
 {'BCL11A_6910_Reversed', 'peak69368_Reversed'},
 {'BCL11A_6910_Reversed', 'peak27018_Reversed'},
 {'BCL11A_6910_Reversed', 'peak451_Reversed'},
 {'BCL11A_6910_Reversed', 'HBA2_4317_Reversed'}]

In [36]:
work_dir = "/home/brett/work/OrthogonalTrainValSplits/hashFrag/data/tutorial.create_orthogonal_splits.work"
path = os.path.join(work_dir,"hashFrag.blastn.processed.tsv")


In [71]:
for chunk_df in pd.read_csv(path,sep="\t",chunksize=250_000,header=None,names=["qseqid","sseqid","score"]):
    display(chunk_df)
    break

Unnamed: 0,qseqid,sseqid,score
0,peak83127_Reversed,peak83127_Reversed,200
1,peak83127_Reversed,peak70672,18
2,peak83127_Reversed,peak3437,18
3,peak83127_Reversed,peak70885,18
4,peak83127_Reversed,peak17162_Reversed,16
...,...,...,...
249995,peak49129_Reversed,BCL11A_2470_Reversed,16
249996,peak49129_Reversed,peak1432,18
249997,peak49129_Reversed,peak52511_Reversed,19
249998,peak49129_Reversed,peak52938_Reversed,16


In [73]:
for qseqid,sseqid,score in zip(chunk_df["qseqid"],chunk_df["sseqid"],chunk_df["score"]):
    print(score,type(score))
    break

200 <class 'int'>


In [52]:
threshold = 80
setlist = []
idset = set()
homology_set = set()
covered_ids = set()
with open(path,"r") as handle:
    handle.readline()
    for line in handle:
        id_i,id_j,score = line.strip().split("\t")
        idset.add(id_i)
        idset.add(id_j)
        if float(score) >= threshold:
            pairset = {id_i,id_j}
            setlist.append(pairset)
            homology_set.update(pairset)
        else:
            setlist.append({id_i})
            setlist.append({id_j})
len(setlist),len(idset),len(homology_set)

(1453331, 10000, 9999)

In [53]:
groups = get_disjoint_sets(setlist)
len(groups)

4700

In [54]:
n_folds = 5

In [55]:
import random
random.seed(1)
# random.shuffle(groups)
groups = sorted(groups, key=len, reverse=True)

In [56]:
N = np.sum([len(g) for g in groups])
N

10000

In [57]:
max_group_size = np.max([len(g) for g in groups])

In [61]:
import random
random.seed(1)
# random.shuffle(groups)
groups = sorted(groups, key=len, reverse=True)
folds = [[] for _ in range(n_folds)]
foldsizes = [0]*n_folds
for group in groups:
    i = foldsizes.index(min(foldsizes))
    folds[i].extend(group)
    foldsizes[i] += len(group)

foldsizes

[2000, 2000, 2000, 2000, 2000]

In [66]:
for i,group_i in enumerate(groups):
    for j,group_j in enumerate(groups):
        if i == j: continue
        assert group_i.isdisjoint(group_j)

In [62]:
foldsizes

[2000, 2000, 2000, 2000, 2000]

In [63]:
groups_path = os.path.join(work_dir,"homologous_groups.pure.csv")
group_dict = defaultdict(set)
with open(groups_path,"r") as handle:
    for line in handle:
        id_i,group = line.strip().split("\t")
        group_dict[group].add(id_i)
previous_groups = list(group_dict.values())
len(previous_groups)

92