In [192]:
import re
import math
import json
from collections import defaultdict

import numpy as np
import torch
from sklearn.model_selection import train_test_split

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

Using cpu device


In [3]:
scope_file = "../../data/astral-scopedom-seqres-gd-sel-gs-bib-95-2.08.fa"
scope_pattern = re.compile("[abcdefghijkl]\.[0-9]{1,3}\.[0-9]{1,3}\.[0-9]{1,3}")

In [4]:
test_header = ">d5lqwy_ g.101.1.1 (Y:) Pre-mRNA splicing factor Phf5 / Rds3 {Baker's yeast (Saccharomyces cerevisiae) [TaxId: 4932]}"
scope_pattern.search(test_header).group(0)

'g.101.1.1'

In [14]:
def load_data(data_file):
    sequences = []
    scope_codes = []

    with open(data_file, "r") as fin:
        scope_code = ""
        sequence = ""
        for line in fin:
            if line.startswith(">"):
                sequences.append(sequence)
                scope_codes.append(scope_code)

                sequence = ""
                scope_code = ""

                header = line.strip()
                m = scope_pattern.search(header)
                if m is not None:
                    scope_code = m.group(0)
            else:
                sequence += line.strip()

    sequences = sequences[1:]
    scope_codes = scope_codes[1:]
    
    return sequences, scope_codes

In [15]:
def generate_buckets_(scope_codes):
    scope_levels = defaultdict(list)
    for i, scope_code in enumerate(scope_codes):
        splitted_code = scope_code.split(".")
        for j in range(1,5):
            subcode = ".".join(splitted_code[0:j])
            scope_levels[subcode].append(i)
            
    #filtered_scope_levels = dict()
    #for scope, sequence_ids in scope_levels.items():
    #    if len(sequence_ids) > bucket_size:
    #        filtered_scope_levels[scope] = sequence_ids
            
    return list(scope_levels.items())

In [187]:
def get_scope_similarity_level(code1, code2):
    split_code1 = code1.split(".")
    split_code2 = code2.split(".")
    for i in range(4,0,-1):
        if split_code1[:i] == split_code2[:i]:
            return ".".join(split_code1[:i])
    return ""
        

def generate_buckets(scope_codes):
    scope_levels = defaultdict(list)
    for i, scope_codei in enumerate(scope_codes):
        for j, scope_codej in enumerate(scope_codes[i+1:]):
            similarity_code = get_scope_similarity_level(scope_codei, scope_codej)
            scope_levels[similarity_code].append((i, j+i+1))
    return scope_levels
    


In [188]:
sequences, scope_codes = load_data(scope_file)
sequences_train, sequences_test, scope_codes_train, scope_codes_test = train_test_split(sequences, scope_codes, test_size=0.1)
scope_levels_train = generate_buckets(scope_codes_train)

In [234]:
del scope_levels_train[1]

In [235]:
level_distribution = defaultdict(int)
for k, v in scope_levels_train.items():
    print(k,len(v))
    if k == "":
        level_distribution[0] += len(v)
    else:
        level_distribution[len(k.split("."))] += len(v)
normalization_constant = sum(level_distribution.values())
for k, v in level_distribution.items():
    level_distribution[k] /= normalization_constant
level_distribution

 396664504
b 32691720
b.7.1.1 91
b.7.1 685
b.7 585
c 33251823
c.95.1.0 12880
c.95.1 9235
g 1111111
g.40.1.1 45
c.47.1 81164
c.47.1.13 28
c.47 504
g.37.1.1 11781
g.37.1 3734
f 180727
f.4 1889
f.4.1 122
f.4.1.0 36
c.37.1 242455
c.37.1.19 1596
b.19.1.2 630
b.19.1 1250
b.40 35841
b.40.4 16472
b.40.4.1 21
f.23 7022
f.23.12.1 3
f.23.12 3
d.118.1.0 45
d 22106211
d.118.1 120
a 11475676
a.4 122172
a.4.1 16577
a.4.1.12 15
c.97.1 762
c.97.1.0 66
c.97 396
b.122.1 973
b.122.1.1 36
d.79 6707
d.79.2.1 28
d.79.2 40
d.14.1 3360
d.14.1.5 55
c.1 615885
c.1.14 119
c.1.14.0 21
c.28.1.0 6
c.28.1 20
a.139.1.0 55
a.139.1 44
a.3.1 4230
a.3.1.1 1711
c.86.1.1 78
c.86.1 65
a.104.1 3393
a.104.1.0 6786
f.10.1.1 6
f.10.1 20
c.33.1 208
c.33.1.3 28
c.36.1 1118
c.36.1.0 120
b.1.1.0 1804050
b.1.1 3836940
b.1 2323709
d.144.1.7 8646
d.144.1 29079
c.2.1 170066
c.2.1.0 84666
c.92.2.0 528
c.92.2 614
c.92 449
a.104.1.1 406
g.46.1.1 153
c.2.1.2 8646
c.6.2 262
c.6 448
c.6.2.6 3
b.1.1.1 1216020
c.37.1.11 595
d.15 53069
d.15.1 15

b.40.6.1 1
d.157.1.5 1
d.162.1.1 465
d.162.1 1699
c.13 49
c.13.2 14
c.13.2.1 6
b.7.1.0 136
d.115.1 28
d.115.1.2 15
c.53.2.0 91
c.53 132
c.53.2 112
d.58.53 93
d.58.53.7 6
d.294.1.2 6
g.41.8 82
g.41.8.5 3
b.18.1.0 1176
d.38.1.9 1
c.1.10.5 1
c.69.1.17 78
g.9.1.1 300
g.9.1 51
g.98.1.1 1
a.208.1 2
a.208.1.1 1
c.8.5 46
c.8.5.1 28
c.111.1 7
a.103.1.1 91
g.69.1.1 36
a.4.5.20 1
a.118.1 3006
a.119.1 7
a.119.1.1 3
a.7.17.1 1
a.73.1 9
a.73.1.1 36
a.11.1 18
a.11.1.1 3
c.2.1.1 351
d.58.10 185
d.58.10.1 36
a.4.5.47 55
a.245.1.1 78
b.181.1 8
b.181.1.0 6
b.2.6.1 10
d.278.1 9
d.278.1.2 1
c.87.1.0 861
c.36.1.9 28
c.119.1 60
c.119.1.2 1
g.39.1.1 28
a.25.1.0 1128
c.56.3 24
c.56.3.0 28
g.41.8.1 3
b.18.1.2 28
c.30.1.0 406
b.47.1.1 210
d.321.1 1
c.23.16.0 231
b.1.6.1 66
b.60.1.2 703
b.60.1 6511
c.25.1.0 136
c.94.1.2 171
a.42.1.1 45
a.42.1 10
d.18.1.2 3
d.190.1.2 45
a.155.1.1 15
c.23.12.1 28
c.23.12 36
c.7.1 48
c.7.1.1 3
d.233.1.1 1
d.224.1 118
d.224.1.0 36
f.71.1.1 6
f.71.1 4
g.43.1 24
g.43.1.1 28
c.14.1.2 3


d.24.1.1 6
d.160.1.0 10
a.13.1.1 3
b.100.1.1 3
b.100.1 3
d.110.2.1 15
d.58.40.1 3
b.40.1.1 36
b.43.4.0 276
d.79.3.1 45
c.1.8.5 105
d.89.1.1 1
d.193.1.1 3
d.68.1.1 6
d.68.1 4
d.77.1 47
d.77.1.1 10
c.15.1 49
c.15.1.3 3
a.89.1.0 15
c.55.1.2 15
b.42.4.1 21
d.201.1.1 1
a.48.4.1 3
a.48.4 3
d.58.10.2 10
c.1.9.3 21
b.18.1.4 21
b.69.4.1 210
a.6.1 181
a.6.1.3 6
b.139.1.1 1
b.33.1.0 66
d.108.1.10 3
d.58.53.3 1
g.14.1.0 15
a.66.1.1 15
c.30.1.1 36
c.1.2.1 66
d.68.4.1 3
d.17.4.2 36
e.13.1 3
c.43.1.4 3
a.27.1 77
a.27.1.0 21
a.25.2.0 45
d.261.1 1
c.25.1.2 3
c.51.4.2 1
g.20.1.1 91
g.20.1 14
g.16.2.1 3
c.61.1.2 15
c.19.1.1 10
c.19.1 11
a.26.1.0 3
a.53.1.0 91
a.53.1 56
c.23.14 33
c.23.14.1 1
f.33.1.1 1
c.92.2.4 3
a.48.1 2
a.48.1.0 1
a.60.1.2 231
b.40.4.0 1540
d.67.3.1 21
d.67 65
d.67.3 14
d.275.1.1 1
c.30.1.8 1
a.12.1.1 1
d.325.1.2 1
d.81.1.5 10
b.21.1.3 1
d.96.1.0 120
d.96.1 408
d.96 68
b.60.1.0 496
b.113.1.1 15
b.113.1 6
d.44.1.0 496
c.55.3.2 45
a.138.1.1 171
b.82.1.10 1
a.159 38
d.13.1.0 36
d.13.1 237

b.21.1.1 78
d.20.1.4 1
a.24.2 9
a.24.2.0 3
a.161.1.1 3
b.82.2.5 6
b.169.1.1 3
f.19.1.0 21
c.55.3.8 3
d.219.1.0 10
d.219.1 10
c.108.1.16 3
b.118.1.1 3
b.123.1.0 1
a.39.3.1 1
d.96.1.3 6
b.29.1.6 21
c.69.1.22 1
c.46.1.0 10
c.66.1.13 6
f.23.17.1 1
b.34.4.4 3
a.60.12.0 6
f.1.1.1 6
d.66.1 31
d.66.1.4 1
d.111.1.0 3
d.264.1.1 3
d.264.1 3
a.41.1.0 3
a.41.1 6
c.146.1 2
c.146.1.0 1
d.96.1.2 6
c.154.1.1 1
f.37.1 6
f.37.1.1 3
b.80.7.1 1
b.82.1.11 6
g.41.5 141
g.41.5.1 153
b.49.3.1 21
e.29.1 23
b.80.5.1 1
f.25.1 3
c.65.1.0 45
d.110.3.9 3
c.1.31.0 3
c.34.1.0 36
f.45.1.1 1
c.1.13.1 10
a.118.8.9 36
b.40.9.1 1
b.34.13.3 6
a.24.16 14
d.80.1.3 45
a.29.2.1 36
c.57.1.2 6
d.79.2.0 10
f.23.43.1 15
c.23.16.9 1
b.68.6 15
c.48.1 89
c.48.1.2 10
d.170.1.0 28
d.170.1 17
d.170 10
d.17.4.16 1
d.391.1.1 3
d.58.4.14 1
g.3.2.1 91
a.86.1.1 6
a.118.1.14 36
a.4.1.18 6
c.47.1.8 1
d.47.1.1 10
d.17.4.10 10
d.91.1 1
d.194.1.2 10
b.88.1.0 3
a.2.3.0 66
c.53.2.1 28
a.24.28 1
d.96.1.4 15
a.192.1 1
b.34.11 18
b.34.11.1 3
a.29.5.0 3

b.87.1 5
d.159.1.2 3
a.2.5.1 1
d.222.1.1 1
d.222.1 2
a.216.1.1 3
a.60.9.1 3
b.1.6.3 1
a.24.22.1 1
c.1.26.1 6
c.37.1.26 1
d.122.1.2 10
g.96.1.1 6
c.37.1.25 1
b.157.1 6
c.26.2.4 21
b.29.1.26 1
c.73.1.2 6
b.77.2.0 10
g.34.1.1 1
b.155.1.1 3
c.111.1.2 3
d.110.2.2 3
d.298.1.0 3
c.1.18.2 1
d.284.1.1 1
a.5.7.0 1
d.58.42.1 3
a.124.1 12
a.124.1.0 3
a.60.2.4 3
f.4.4 1
c.103.1.1 10
a.175.1.1 1
a.175.1 4
b.3.6.1 15
b.179.1.1 3
a.32.1.1 6
a.118.19.1 1
b.143.1.1 1
d.12.1 16
d.12.1.1 1
d.159.1.0 10
d.186 1
f.4.7.1 3
d.150.1.1 1
c.76.1.2 3
c.1.16.3 3
d.33.1.2 1
e.19.1.0 10
d.223.1.2 10
b.40.15.0 1
g.80.1.1 15
a.2.13.1 3
b.40.7.1 1
f.23.26.1 6
d.144.1.6 10
b.29.1.18 3
a.160.1.2 1
a.4.5.32 1
a.102.1.5 1
a.7.7.1 15
d.58.36.0 15
d.41.2.1 15
d.17.4.3 6
a.29.8.2 1
a.4.5.1 1
a.29.10.1 1
a.217.1.1 10
e.62.1 5
e.62.1.1 1
a.28.3.2 1
b.115.1 12
b.115.1.0 3
b.157.1.0 15
d.165.1.0 36
e.76.1.0 3
a.24.10 50
b.36.1.5 1
d.149.1.1 10
d.149.1 5
d.60.1 10
f.5.1.1 3
d.230.5.1 1
g.95.1.1 6
b.20.1.1 1
b.85.6.1 3
b.85.6 3
d.7

a.102.1.4 1
a.118.9.3 1
g.3.19.1 1
e.15.1.1 1
b.92.1.6 1
a.4.13.3 1
c.106.1.1 1
c.107.1.0 1
e.38.1.1 3
b.35.1.0 1
a.118.28.1 1
c.23.16.8 1
d.322.1.2 1
a.202.1.1 1
a.4.5.55 1
b.121.2.3 1
a.118.1.22 1
d.113.1.4 1
d.92.1.14 1
b.85.3 1
d.58.18.8 1
d.17.4.20 1
d.198.3.1 1
b.60.1.4 1
a.209.1.0 1
d.8.1.0 1
c.31.1.2 3
d.13.1.3 1
d.38.1.7 1
a.29.6 2
g.3.11.5 1
a.280.1.0 1
d.52.2.0 3
d.58.33.0 1
g.42.1.1 3
d.301.1.1 3
g.79.1.1 1
c.2.1.13 1
e.22.1.1 1
a.156.1.1 1
b.125.1.1 1
b.1.33 1
a.71.1.1 1
a.171.1.1 3
d.58.18.11 1
d.279.1 3
d.279.1.1 3
a.4.5.67 1
d.280.1.1 1
c.77.1.3 3
a.246.2.1 1
a.142.1.1 1
b.1.27.0 1
c.55.2.1 6
a.2.17.1 1
d.102.1.0 1
d.295.1.1 1
a.4.5.68 1
a.24.10.0 1
g.21.1.0 1
a.301.1.1 3
d.58.15.1 1
b.142.1 2
b.142.1.2 1
c.44.2.0 6
a.257.1.1 1
d.95.2.0 1
b.87.1.1 1
a.6.1.4 1
c.66.1.53 1
a.118.1.29 1
d.232.1.1 1
a.8.6.1 1
f.46.1 2
c.136.1.1 1
c.10.2.4 1
d.83.1 2
d.83.1.1 1
g.44.1.3 1
d.241.1.1 1
c.56.5.3 1
d.142.2.4 1
b.61.7.1 3
g.81.1.1 3
a.213.1.2 3
d.169.1.3 1
d.396.1.0 1
d.41.2.2 6


defaultdict(int,
            {0: 0.7775270870858653,
             1: 0.1979144226271354,
             4: 0.007135514750262893,
             3: 0.009920308534048128,
             2: 0.007502667002688264})

In [206]:
def compute_scope_pair_cdf(scope_level_pairs, smoothing_func):
    scope_pdf = list()
    sum_npairs = 0
    for scope_level, pairs in scope_level_pairs.items():
        npairs = len(pairs)
        sum_npairs += npairs
        scope_pdf.append((scope_level, smoothing_func(npairs)))
        
    normalization_constant = sum([i[1] for i in scope_pdf])
    
    scope_pdf = [(k,v/normalization_constant) for k, v in scope_pdf]
   
    cdf = np.cumsum([i[1] for i in scope_pdf])
    scope_cdf = [(k,v) for k, v in zip([i[0] for i in scope_pdf], cdf)]

    return scope_cdf, scope_pdf

In [222]:
scope_pair_cdf_train, scope_pair_pdf_train = compute_scope_pair_cdf(scope_levels_train, lambda x: np.power(x, 1))

In [223]:

"""
def compute_scope_pair_cdf_(scope_levels, total_pairs, smoothing_func):
    scope_pdf = list()
    sum_npairs = 0
    for scope_level, sequence_ids in scope_levels:
        npairs = len(sequence_ids)*(len(sequence_ids)-1)/2.0
        sum_npairs += npairs
        scope_pdf.append((scope_level, smoothing_func(npairs)))
        
    different_scope_npairs = total_pairs - sum_npairs
    print(f"total pairs {total_pairs} and different pairs {different_scope_npairs}")
    scope_pdf.append(("", smoothing_func(different_scope_npairs)))
        
    normalization_constant = sum([i[1] for i in scope_pdf])
    
    scope_pdf = [(k,v/normalization_constant) for k, v in scope_pdf]
   
    cdf = np.cumsum([i[1] for i in scope_pdf])
    scope_cdf = zip([i[0] for i in scope_pdf], cdf)

    return scope_cdf, scope_pdf
"""

"""
def compute_scope_cdf(scope_levels, smoothing_func):
    scope_pdf = [(scope_level, smoothing_func(len(sequence_ids))) for scope_level, sequence_ids in scope_levels]

    normalization_constant = sum([i[1] for i in scope_pdf])
    
    scope_pdf = [(k,v/normalization_constant) for k, v in scope_pdf]
    
    cdf = np.cumsum([i[1] for i in scope_pdf])
    scope_cdf = zip([i[0] for i in scope_pdf], cdf)
    
    return scope_cdf, scope_pdf
"""

def get_sampled_element(cdf):
    a = np.random.uniform(0, 1)
    return np.argmax(cdf>=a)

def run_sampling(cdf, n=5000):
    for k in np.arange(n):
        yield get_sampled_element(cdf)

In [226]:
sequence_pairs = list()

for i in run_sampling(np.array([i[1] for i in scope_pair_cdf_train]), n=10000):
    k, _ = scope_pair_cdf_train[i]
    #print(f"code {k}")
    pairs = scope_levels_train[k]
    #print(f"number of pairs for code {len(pairs)}")
    pairi = np.random.randint(0,len(pairs), None)
    #print(f"sampled pair index {pairi}: {pairs[pairi]}")
    i1, i2 = pairs[pairi]
    sequence_pairs.append((sequences_train[i1], 
                           sequences_train[i2], 
                           scope_codes_train[i1], 
                           scope_codes_train[i2]))

In [227]:
level_distribution = defaultdict(int)
for p in sequence_pairs:
    common_level = get_scope_similarity_level(p[2],p[3])
    if common_level == "":
        level_distribution[0] += 1/len(sequence_pairs)
    else:
        level_distribution[len(common_level.split("."))] += 1/len(sequence_pairs)
level_distribution

defaultdict(int,
            {0: 0.7783999999999306,
             3: 0.010999999999999989,
             1: 0.19849999999999446,
             2: 0.006600000000000004,
             4: 0.005500000000000001})

In [165]:
sequences, scope_codes = load_data(scope_file)
sequences_train, sequences_test, scope_codes_train, scope_codes_test = train_test_split(sequences, scope_codes, test_size=0.1)
scope_levels_train = generate_buckets(scope_codes_train)
total_pairs_train = len(sequences_train)*(len(sequences_train)-1)/2.0
scope_pair_cdf_train, scope_pair_pdf_train = compute_scope_pair_cdf(scope_levels_train, total_pairs_train, lambda x: np.power(x, 1))
pair_cdf_train = np.array([i[1] for i in scope_pair_cdf_train])
scope_cdf_train, scope_pdf_train = compute_scope_cdf(scope_levels_train, lambda x: x)
cdf_train = np.array([i[1] for i in scope_cdf_train])

total pairs 510161653.0 and different pairs 371609922.0


In [166]:
level_counts = defaultdict(int)
n = 10000
for i in run_sampling(cdf_train, n):
    level = len(scope_pdf_train[i][0].split("."))
    level_counts[level] += 1.0/n
level_counts

defaultdict(int,
            {2: 0.2489999999999889,
             4: 0.24959999999998883,
             1: 0.2433999999999895,
             3: 0.2579999999999879})

In [167]:
scope_pair_pdf_train[-1]

('', 0.7284160222838231)

In [168]:
sequence_pairs = list()

for i in run_sampling(pair_cdf_train, n=100000):
    if i < len(scope_levels_train):
        label, sequenceids = scope_levels_train[i]
        i1, i2 = np.random.randint(0,len(sequenceids), 2)
        seqid1 = sequenceids[i1]
        seqid2 = sequenceids[i2]
        sequence_pairs.append((sequences_train[seqid1], 
                               sequences_train[seqid2], 
                               scope_codes_train[seqid1], 
                               scope_codes_train[seqid2]))
    else:
        scope1i = 0
        scope2i = 0
        while scope1i == scope2i:
            scope1i, scope2i = [i for i in run_sampling(cdf_train, n=2)]
        scope1, sequenceids1 = scope_levels_train[scope1i]
        scope2, sequenceids2 = scope_levels_train[scope2i]
        i1 = np.random.randint(0,len(sequenceids1), 1)[0]
        i2 = np.random.randint(0,len(sequenceids2), 1)[0]
        seqid1 = sequenceids1[i1]
        seqid2 = sequenceids2[i2]
        sequence_pairs.append((sequences_train[seqid1], 
                               sequences_train[seqid2], 
                               scope_codes_train[seqid1], 
                               scope_codes_train[seqid2]))

In [169]:
count = 0
for p in sequence_pairs:
    if p[2] == p[3]:
        count += 1
count

3277

In [144]:
sequences, scope_codes = load_data(scope_file)
sequences_train, sequences_test, scope_codes_train, scope_codes_test = train_test_split(sequences, scope_codes, test_size=0.1)
scope_levels_train = generate_buckets(scope_codes_train)
scope_cdf_train, scope_pdf_train = compute_scope_cdf(scope_levels_train, lambda x: np.power(x, 0.8))
cdf_train = np.array([i[1] for i in scope_cdf_train])

In [145]:
sequence_pairs = list()

for i in range(100000):
        scope1i, scope2i = [i for i in run_sampling(cdf_train, n=2)]
        scope1, sequences1 = scope_levels_train[scope1i]
        scope2, sequences2 = scope_levels_train[scope2i]
        i1 = np.random.randint(0,len(sequences1), 1)[0]
        i2 = np.random.randint(0,len(sequences2), 1)[0]
        sequence_pairs.append((sequences1[i1], sequences2[i2], scope1, scope2))

In [146]:
count = 0
for p in sequence_pairs:
    if p[2] == p[3]:
        count += 1
count

366

In [147]:
sequence_pairs[:100]

[(10865, 11330, 'd.185.1', 'c'),
 (30315, 13272, 'd.309', 'c.55.1.0'),
 (26171, 20988, 'a', 'c'),
 (13827, 21092, 'g.3.6', 'a.32.1.1'),
 (31278, 5936, 'd.110', 'c'),
 (25095, 282, 'b.18.1.4', 'a.38.1'),
 (11874, 27738, 'd.14', 'c.1.14.0'),
 (8526, 22023, 'b.121.4.5', 'b'),
 (7129, 10178, 'd.130.1', 'g.44.1.0'),
 (2954, 13339, 'a.4.1.12', 'c.47'),
 (4361, 20060, 'a.45.1.1', 'g.101.1.1'),
 (13080, 7702, 'd.245.1', 'b'),
 (12919, 18104, 'b.1.2', 'd.58.7'),
 (7390, 28877, 'c.55.1', 'a'),
 (7446, 22349, 'd.185.1.1', 'b.74.1'),
 (18957, 29736, 'd.20.1.0', 'a.22.1'),
 (20399, 26286, 'd.92.1', 'd.139.1.1'),
 (31810, 20457, 'a.24.19.0', 'b'),
 (10645, 19534, 'b.78.1.0', 'd.186.1.1'),
 (22816, 8796, 'd.66.1', 'd.58.7.0'),
 (11626, 29086, 'g.79', 'c.37.1'),
 (154, 11591, 'b.1.1', 'b'),
 (24399, 19071, 'b.82', 'g'),
 (29226, 8688, 'b.147.1', 'd.41.2.0'),
 (27507, 11461, 'b.23', 'd'),
 (20991, 21266, 'b.19.1', 'd.198.1.1'),
 (17573, 20858, 'c.94.1.0', 'b.50.1'),
 (12703, 10519, 'c.67.1', 'c.98.1'),