In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split
from IPython.display import clear_output
import numpy as np

In [2]:
def parse_domain(dom_str):

    def sort_domains(domain_list):
        idx = sorted(range(len(domain_list)), key=lambda k: domain_list[k][1][0])
        return [domain_list[i] for i in idx]

    unsorted = [[it[0], list(map(int, it[1].split(':')))] for it in (dom.split(';') for dom in dom_str.split('|'))]

    return sort_domains(unsorted)

def dom_arc_remove_consec_disorder(dom_list):
    doms = [dom[0] for dom in dom_list]
    doms = [dom for i, dom in enumerate(doms) if (i==0) or not (dom == doms[i-1] == 'd')]
    return '|'.join(doms)


def count_real_domains(dom_list):
    return len([dom for dom in dom_list if dom[0][0] not in ['d', 'm']])


def train_seq_redundant(dom_list, test_doms):
    return set([dom[0] for dom in dom_list]).isdisjoint(test_doms)

In [3]:
from collections import Counter

train_dom_arcs = []
with open('train_unbalanced.csv') as f:
    for idx, line in enumerate(f):
        train_dom_arcs.append('|'.join([dom[0] for dom in parse_domain(line.split(',')[1]) if dom[0] != 'd']))
        if idx % 1000000 == 0:
            clear_output()
            print(f'Processing Line {idx}')

train_dom_arcs = Counter(train_dom_arcs)
print(train_dom_arcs.most_common(100))
print(len(train_dom_arcs))

Processing Line 24000000
[('PF00126|PF03466', 256240), ('PF00072', 159496), ('PF00903', 145433), ('PF00892|PF00892', 135718), ('PF00115', 131139), ('PF00248', 130834), ('PF00171', 122629), ('PF00378', 121962), ('PF01850', 88321), ('PF02771|PF02770|PF00441', 86910), ('PF00294', 85240), ('PF00873', 84865), ('PF00392|PF07729', 81142), ('PF00293', 80483), ('PF00155', 78640), ('PF00561', 75949), ('PF00266', 70701), ('PF00583', 70581), ('PF00069', 70193), ('PF01266', 67347), ('PF01925', 67095), ('PF00202', 66828), ('PF01261', 65744), ('PF12697', 64858), ('PF00067', 63450), ('PF04542|PF08281', 63051), ('PF02653', 62884), ('PF13439|PF00534', 59233), ('PF01381', 58815), ('PF00291', 58559), ('PF02515', 56072), ('PF01041', 55489), ('PF01810', 55202), ('PF00300', 54011), ('PF16363', 53986), ('PF01042', 53654), ('PF01053', 52961), ('PF00582', 52931), ('PF00108|PF02803', 52148), ('PF01425', 51554), ('PF03372', 50722), ('PF00701', 49522), ('PF07992|PF02852', 49516), ('PF00085', 48545), ('PF00459', 48

In [4]:
train_doms = []
with open('train_unbalanced.csv') as f:
    for idx, line in enumerate(f):
        train_doms += [dom[0] for dom in parse_domain(line.split(',')[1]) if dom[0]!='d']
        if idx % 1000000 == 0:
            clear_output()
            print(f'Processing Line {idx}')
train_doms = Counter(train_doms)
print(train_doms.most_common(20))
print(len(train_doms))

Processing Line 24000000
[('PF00892', 287541), ('PF03466', 270771), ('PF00126', 261369), ('PF00072', 245043), ('PF00903', 156893), ('PF00440', 146417), ('PF00248', 136379), ('PF00115', 131201), ('PF00571', 129879), ('PF00171', 125051), ('PF00378', 123724), ('PF07992', 121499), ('PF00392', 116539), ('PF00582', 104824), ('PF00441', 95936), ('PF02771', 94335), ('PF02770', 93739), ('PF07690', 93247), ('PF00153', 92688), ('PF01850', 88387)]
14036


In [5]:
print(len([dom for dom in train_dom_arcs if dom.count('PF')==1]))
print(len([dom for dom in train_dom_arcs if dom.count('PF')>1]))

13912
20280


In [6]:
pd.read_csv('train_unbalanced.csv', header=None).sample(frac=1).reset_index(drop=True).to_csv('train_unbalanced_shuffled.csv', header=None, index=None)

In [11]:
to_keep = 0
dom_arc_count = {dom_arc: 0 for dom_arc in train_dom_arcs}
dom_count = {dom: 0 for dom in train_doms}
with open('train_unbalanced_shuffled.csv') as f:
    with open('train_balanced_alt.csv', 'w') as fOut:
        for idx, line in enumerate(f):
            dom_list = [dom[0] for dom in parse_domain(line.split(',')[1])]
            dom_arc = '|'.join(dom_list)
            doms = [dom for dom in dom_list if dom!='d']
            dom_arc_no_disorder = '|'.join(doms)
            if ('d' not in dom_arc) and (dom_arc.count('PF') == 1):
                if max([dom_count[dom] for dom in doms]) < 1 and np.random.rand()<0.03:
                    fOut.write(line)
                    to_keep += 1
                    dom_arc_count[dom_arc_no_disorder] += 1
                    for dom in doms:
                        dom_count[dom] += 1
            elif ('d' not in dom_arc) and (dom_arc.count('PF') > 1):
                if np.random.rand() <= (0.5**dom_arc_count[dom_arc_no_disorder]):
                    fOut.write(line)
                    to_keep += 1
                    dom_arc_count[dom_arc_no_disorder] += 1
                    for dom in doms:
                        dom_count[dom] += 1
            elif ('d' in dom_arc) and (dom_arc.count('PF') == 1):
                if np.random.rand() <= (0.03**dom_arc_count[dom_arc_no_disorder]):
                    fOut.write(line)
                    to_keep += 1
                    dom_arc_count[dom_arc_no_disorder] += 1
                    for dom in doms:
                        dom_count[dom] += 1
            else:
                if np.random.rand() <= (0.7**dom_arc_count[dom_arc_no_disorder]):
                    fOut.write(line)
                    to_keep += 1
                    dom_arc_count[dom_arc_no_disorder] += 1
                    for dom in doms:
                        dom_count[dom] += 1
            
print(to_keep)

59403


In [12]:
with open('train_balanced_alt.csv') as f:
    sing_no_idr_counter = 0
    mult_no_idr_counter = 0
    sing_idr_counter = 0
    mult_idr_counter = 0
    for line in f:
        dom_arc_this_prot = '|'.join([dom[0] for dom in parse_domain(line.split(',')[1])])
        if '|' not in dom_arc_this_prot:
            sing_no_idr_counter += 1
        elif dom_arc_this_prot.count('PF') > 1 and not 'd' in dom_arc_this_prot:
            mult_no_idr_counter += 1
        elif dom_arc_this_prot.count('PF') == 1 and 'd' in dom_arc_this_prot:
            sing_idr_counter += 1
        else:
            mult_idr_counter += 1
print(sing_no_idr_counter)
print(mult_no_idr_counter)
print(sing_idr_counter)
print(mult_idr_counter)

5939
43234
6002
4228


In [13]:
from collections import Counter

train_balanced_dom_arcs = []
with open('train_balanced_alt.csv') as f:
    for idx, line in enumerate(f):
        train_balanced_dom_arcs.append('|'.join([dom[0] for dom in parse_domain(line.split(',')[1]) if dom[0]!='d']))
        if idx % 1000000 == 0:
            clear_output()
            print(f'Processing Line {idx}')

train_balanced_dom_arcs = Counter(train_balanced_dom_arcs)
print(train_balanced_dom_arcs.most_common(100))
print(len(train_balanced_dom_arcs))

Processing Line 0
[('PF08032|PF00588', 21), ('PF00076|PF00076', 21), ('PF00892|PF00892', 20), ('PF00126|PF03466', 19), ('PF02881|PF00448', 19), ('PF01281|PF03948', 19), ('PF05697|PF00254|PF05698', 19), ('PF01386|PF14693', 18), ('PF00364|PF02779|PF02780', 18), ('PF02771|PF02770|PF00441', 17), ('PF03453|PF00994|PF03454', 17), ('PF00781|PF19279', 17), ('PF07992|PF02852', 17), ('PF13499|PF13499', 17), ('PF00392|PF07729', 17), ('PF05198|PF00707', 17), ('PF14849|PF02096', 17), ('PF00333|PF03719', 17), ('PF00479|PF02781', 16), ('PF13439|PF00534', 16), ('PF01300|PF03481', 16), ('PF00153|PF00153|PF00153', 16), ('PF14622|PF00035', 16), ('PF02321|PF02321', 16), ('PF04542|PF08281', 16), ('PF13089|PF02503|PF17941|PF13090', 16), ('PF02881|PF00448|PF02978', 16), ('PF00440|PF02909', 16), ('PF01225|PF08245|PF02875', 15), ('PF00108|PF02803', 15), ('PF05201|PF01488|PF00745', 15), ('PF00081|PF02777', 15), ('PF03721|PF00984|PF03720', 15), ('PF19300|PF00528', 15), ('PF01321|PF00557', 15), ('PF02737|PF00725'