In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import os
import dill
import numpy as np
import multiprocessing_on_dill as mp

from collections import defaultdict, Counter
from tqdm import tqdm

from f723.tools.urs.extraction import assemble_chains, get_sec_struct_model
from f723.tools.dataset.entities import NucleotideFeatures, PairFeatures, PairMeta, PairData, make_pair, Pair

In [3]:
NRLIST_PATH = '/home/mikhail/bioinformatics/data/nrlist_3.76_3.0A.csv'
JSON_DIR = '/home/mikhail/bioinformatics/data/NR_3.0/json'
CIF_DIR = '/home/mikhail/bioinformatics/data/NR_3.0/cif/'
OUT_DIR = '/home/mikhail/bioinformatics/data/NR_3.0/out/'
SEC_STRUCT_DIR = '/home/mikhail/bioinformatics/data/sec_struct'
DATASET_DIR = '/home/mikhail/bioinformatics/data/dataset_all_60'

MAX_PAIR_DIST = 60
NUM_NEIGHBOURS = 5

In [4]:
all_chains = {(chain.pdb_id, chain.id): chain 
              for chain in assemble_chains(NRLIST_PATH, CIF_DIR, OUT_DIR, SEC_STRUCT_DIR)}

100%|██████████| 1074/1074 [00:42<00:00, 25.04it/s]


In [5]:
def make_nucleotide_features(urs_model, chain):
    nucleotide_features = {}
    
    for nt in chain.nts:
        chain_entry = urs_model.chains[chain.id]['RES'][nt.index]
        
        secondary_structure = urs_model.NuclSS(nt.id)
        base = nt.base.lower()
        
        if chain_entry['WING']:
            wing = urs_model.wings['LU'][chain_entry['WING'] - 1]
            fragment_length = wing['LEN']
            fragment_index = nt.index - wing['START'][2]
        elif chain_entry['THREAD']:
            thread = urs_model.threads[chain_entry['THREAD'] - 1]
            fragment_length = thread['LEN']
            fragment_index = nt.index - thread['START'][2]
        else:
            fl, fed = None, None
        
        nucleotide_features[nt] = NucleotideFeatures(
            secondary_structure=secondary_structure, 
            base=base, 
            fragment_length=fragment_length, 
            fragment_index=fragment_index)
    
    return nucleotide_features


def make_neighbours(chain, nt, nucleotide_features, num_neighbours):
    begin_index, end_index = nt.index - num_neighbours - 1, nt.index + num_neighbours
    
    left_padding = [None] * max(0, -begin_index)
    right_padding = [None] * max(0, end_index - len(chain.nts))
    middle = chain.nts[max(0, begin_index):min(end_index, len(chain.nts))]
    nts = left_padding + middle + right_padding
    
    assert len(nts) == 2 * num_neighbours + 1
    
    return [nucleotide_features.get(nt) for nt in nts]


def make_relation(model, pair):
    return model.NuclRelation(pair.nt_left.id, pair.nt_right.id)


def make_features(model, chain, pair, num_neighbours):
    nucleotide_features = make_nucleotide_features(model, chain)
    
    return PairFeatures(
        neighbours_left=make_neighbours(chain, pair.nt_left, nucleotide_features, num_neighbours),
        neighbours_right=make_neighbours(chain, pair.nt_right, nucleotide_features, num_neighbours),
        relation=make_relation(model, pair))


def make_pair_type(pair, pair_types):
    for key, value in pair_types.items():
        if pair in value:
            return key
    
    return 'random'

    
def make_meta(chain, pair, pair_types):
    return PairMeta(pdb_id=chain.pdb_id, pair=pair, type=make_pair_type(pair, pair_types))
    

def make_pair_data(model, chain, pair, num_neighbours, pair_types):
    return PairData(
        features=make_features(model, chain, pair, num_neighbours), 
        meta=make_meta(chain, pair, pair_types))

In [6]:
with open('/home/mikhail/bioinformatics/data/chains_for_classification.json', 'r') as infile:
    chains_for_classification = json.load(infile)

In [7]:
chains_by_pdb_id = defaultdict(list)
for pdb_id, chain_id in chains_for_classification:
    chains_by_pdb_id[pdb_id].append(all_chains[pdb_id, chain_id])

In [8]:
def make_samples(batch_index, chains_by_pdb_id):
    samples = []
    
    for pdb_id, chains in chains_by_pdb_id.items():
        urs_model = get_sec_struct_model(SEC_STRUCT_DIR, pdb_id)

        for chain in chains:
            pair_types = {
                attr: [make_pair(bp.nt_left, bp.nt_right) for bp in getattr(chain, attr)]
                for attr in ['ss_bps', 'noncanonical_bps']
            }

            print(len(chain.nts))

            for index, nt_left in enumerate(chain.nts):
                for nt_right in chain.nts[index + 1:index + 1 + MAX_PAIR_DIST]:
                    pair_data = make_pair_data(
                        urs_model, chain, make_pair(nt_left, nt_right), NUM_NEIGHBOURS, pair_types)
                    samples.append(pair_data)

    with open(os.path.join(DATASET_DIR, 'batch_{}'.format(batch_index)), 'wb') as outfile:
        dill.dump(samples, outfile, protocol=dill.HIGHEST_PROTOCOL)

In [9]:
sum_nts = sum(len(chain.nts) for chains in chains_by_pdb_id.values() for chain in chains)
num_batches = 30

In [10]:
batches = [{}]
current_nts = 0

chains_by_pdb_id_items = list(chains_by_pdb_id.items())
np.random.shuffle(chains_by_pdb_id_items)

for pdb_id, chains in chains_by_pdb_id_items:
    pdb_nts = sum(len(chain.nts) for chain in chains)
    
    if pdb_nts + current_nts > sum_nts / num_batches:
        print(current_nts)
        current_nts = 0
        batches.append({})
    
    current_nts += pdb_nts
    batches[-1][pdb_id] = chains

batches = list(enumerate(batches))

316
3032
653
2254
2167
771
1800
3038
409
4515
1394
4558
2253
909
4191
883
4436
2254
1614
3003
2223
1009
5864
1129
1889
3428
1045
2866
2211


In [11]:
pool = mp.Pool(4)
result = pool.starmap(make_samples, batches)

32
36
54
75
35
159
1555
47
47
18
61
1800
117
57
137
122
2910
94
2203
67
74
54
56
15
53
62
69
34
24
69
35
96
71
126
28
28
101
75
39
39
80
14
74
77
94
16
17
58
14
63
70
118
2810
2923
29
22
76
20
20
24
42
92
161
33
77
144
34
74
390
94
70
17
26
120
2904
48
12
34
112
92
38
29
75
49
76
76
55
21
40
57
56
65
59
188
107
21
34
86
16
87
28
125
78
71
18
84
26
75
123
52
20
20
77
75
34
94
76
77
78
67
37
86
16
47
47
38
22
38
1778
106
1481
115
171
211
1526
1521
84
27
87
57
20
84
141
22
22
39
38
28
28
77
76
76
72
36
16
76
107
75
35
35
57
34
52
36
77
2881
1534
2915
71
133
183
118
22
22
68
61
95
55
40
58
107
37
71
92
35
120
112
75
29
40
36
76
25
38
161
21
23
71
92
16
28
94
26
26
71
102
27
40
16
76
32
174
58
37
55
77
76
11
71
71
75
18
93
54
64
35
71
75
21
29
107
188
35
76
88
17
76
76
51
3773
1465
84
75
65
75
37
88
3149
122
88
75
73
17
192
47
74
52
76
21
21
53
22
30
30
76
81
17
123
91
146
118
941
1278
36
28
16
80
77
24
47
89
36
55
34
34
69
125
113
25
46
85
53
64
35
25
86
76
43
73
78
55
77
74
133
213
25
41
