In [1]:
import json
import os
import pickle
import matplotlib.pyplot as plt
import numpy as np
import pickle

from collections import namedtuple, defaultdict, Counter
from itertools import groupby, chain, count
from operator import itemgetter
from scipy.spatial.distance import cdist
from tqdm import tqdm

In [2]:
MODELS_DIR = '/home/mikhail/bioinformatics/data/sec_struct'
DATA_DIR = '/home/mikhail/bioinformatics/data/fragment_data_1'

In [3]:
def load_model(pdb_id):
    with open(os.path.join(MODELS_DIR, '{}.pickle'.format(pdb_id)), 'rb') as infile:
        return pickle.load(infile)

In [4]:
with open('/home/mikhail/bioinformatics/data/nonredundant.txt', 'r') as infile:
    nonredundant_chain_ids = {tuple(chain_id.split('.cif1_')) for chain_id in infile.read().splitlines()}

In [5]:
urs_models = {pdb_id: load_model(pdb_id) for pdb_id, _ in nonredundant_chain_ids}

In [6]:
Fragment = namedtuple('Fragment', 'pdb_id ch type index global_index members')

In [7]:
def get_fragment_chains(chain_ids):
    fragment_chains = []
    
    for pdb_id, chain_id in chain_ids:
        fragment_id_res = []
        
        for res in urs_models[pdb_id].chains[chain_id]['RES']:
            assert sum(res[key] is not None for key in ['WING', 'THREAD']) == 1
            type_ = 'WING' if res['THREAD'] is None else 'THREAD'
            fragment_id_res.append(((type_, res[type_]), res))
        
        fragment_chains.append([
            Fragment(pdb_id, chain_id, type_, index, global_index, [res for _, res in group]) 
            for global_index, ((type_, index), group) in enumerate(groupby(fragment_id_res, itemgetter(0)))])
    
    return fragment_chains


def print_fragment_chain(fragment_chain):
    print('meta:', fragment_chain[0].pdb_id, fragment_chain[0].ch)
    for fragment in fragment_chain:
        print(fragment.type, fragment.index)

In [8]:
def get_coords(res):
    return [[atom[key] for key in 'XYZ'] for atom in res['ATOMS']]


def get_fragment_atoms(fragment):
    return list(chain.from_iterable([get_coords(res) for res in fragment.members]))


def fragment_distance(left, right):
    return np.min(cdist(get_fragment_atoms(left), get_fragment_atoms(right), 'euclidean'))


def fragment_sequence_distance(left, right):
    indices = [[urs_models[fragment.pdb_id].dssrnucls[res['DSSR']][2] 
                for res in fragment.members] 
               for fragment in [left, right]]
    
    return min(abs(min(indices[r]) - max(indices[l])) for l, r in [(0, 1), (1, 0)])


def fragment_relation(left, right):
    if left.type == 'WING' and right.type == 'WING' and fragment_sequence_distance(left, right) == 1:
        return 'LC'
    
    return urs_models[left.pdb_id].NuclRelation(left.members[0]['DSSR'], right.members[0]['DSSR'])


def fragment_type(fragment):
    types = urs_models[fragment.pdb_id].NuclSS(fragment.members[0]['DSSR']).split(';')
    
    if types == ['S']:
        return [{'S'}, {}]
    else:
        return [{item[i] for item in types} for i in range(2)]

    
def fragment_length(fragment):
    return len(fragment.members)


def fragment_sequence(fragment, fill_length):
    sequence = []
    
    for res in fragment.members:
        name = res['NAME']
        if len(name) > 1 or name not in 'AUGC':
            name = 'M'
        
        sequence.append(name)
    
    while len(sequence) < fill_length:
        sequence.append('N')
    
    return sequence


def fragment_fragment_distance(left, right):
    return right.global_index - left.global_index

In [9]:
fragment_chains = get_fragment_chains(nonredundant_chain_ids)
len(fragment_chains)

44

In [10]:
chain_lengths = {'{}_{}'.format(fragment_chain[0].pdb_id, fragment_chain[0].ch): len(fragment_chain)
                 for fragment_chain in fragment_chains}

with open('chain_lengths.pickle', 'wb') as outfile:
    pickle.dump(chain_lengths, outfile)

In [11]:
fragments = list(chain.from_iterable(fragment_chains))
len(fragments)

3194

Проверим, что для LR пар фрагментов, таких что пара не (WING, WING), расстояние по сиквенсу между фрагментами не менее 2.

In [12]:
for fragment_chain in fragment_chains:
    for i, left in enumerate(fragment_chain):
        for right in fragment_chain[i + 1:]:
            rel = urs_models[left.pdb_id].NuclRelation(left.members[0]['DSSR'], right.members[0]['DSSR'])
            
            if rel == 'LR' and {left.type, right.type} != {'WING'}:
                assert not fragment_sequence_distance(left, right) <= 2

Посчитаем количество фрагментов с MISS нуклеотидами

In [13]:
sum([any(res['MISS'] for res in fragment.members) for fragment in fragments])

35

In [14]:
sum([all(res['MISS'] for res in fragment.members) for fragment in fragments])

9

Так как таких фрагментов мало, просто выкинем их из выборки

Посчитаем среднее евклидово расстояние между нуклеотидами в цепочках в зависимости от расстояния между нуклеотидами по сиквенсу

In [15]:
def get_distances_on_sequence(sequence_distance, chain_ids):
    distances = []
    
    for pdb_id, chain_id in chain_ids:
        atoms = [get_coords(res) for res in urs_models[pdb_id].chains[chain_id]['RES']]

        for left_i, left in enumerate(atoms):
            right_i = left_i + sequence_distance
            if right_i < len(atoms):
                right = atoms[right_i]

                if len(right) > 0 and len(left) > 0:
                    distances.append(np.min(cdist(left, right, 'euclidean')))
    
    return distances

In [16]:
for i in range(1, 30):
    distances = get_distances_on_sequence(i, nonredundant_chain_ids)
    mean = np.mean(distances)
    std = np.std(distances)
    
    print(i, mean, std)

1 1.6089178747660788 0.18143464158571168
2 5.805022637977844 1.139063344573624
3 8.826294468867898 2.1702721456212735
4 11.375384871050775 3.0911026326345694
5 13.580455084587362 3.9571849052277495
6 15.464713741164214 4.795388865068102
7 17.078913211242185 5.581391402615642
8 18.510487217879657 6.331259718240907
9 19.8926131849052 7.0879635534351415
10 21.27094212767807 7.852585143333483
11 22.634894209667436 8.647659194179155
12 23.986928225656513 9.455061807827107
13 25.30002633220102 10.2850526545864
14 26.548126949172993 11.102171647714874
15 27.701893580894154 11.898910820453164
16 28.778787287934346 12.640643528946892
17 29.777659563660933 13.32678626005696
18 30.6903085631811 13.964265321218692
19 31.51841637571828 14.582341197661746
20 32.28292385420625 15.173370705588935
21 33.008679956457186 15.753404230238889
22 33.713414484486165 16.31071002194597
23 34.410222311584924 16.8511306588473
24 35.09236218716058 17.38296834350296
25 35.75480877975653 17.913647820995998
26 36.385

In [17]:
for i in range(1, 20):
    distances = list(sorted(get_distances_on_sequence(i, nonredundant_chain_ids)))
    print(i, distances[-100])

1 1.6260765664629682
2 7.785276681017853
3 13.03325262549608
4 17.807200622220197
5 22.11585447139675
6 26.0258733571037
7 29.683053987081585
8 33.02704390344376
9 36.30172731427528
10 39.25754657387546
11 41.81045661075709
12 44.10900386542413
13 45.99871032974728
14 47.97641897849397
15 50.27048322823245
16 52.605720687012735
17 54.74848725764026
18 57.127149955165805
19 59.54524326090204


Видим, что по-хорошему, если ставить positive threshold в 8 ангстрем, то нужно брать выборку из фрагментов на расстоянии не менее 10, чтобы ловить реально LR взаимодействия.

Посмотрим теперь на расстояния по сиквенсу в LR фрагментах 

In [18]:
lr_sequence_distances = []

for fragment_chain in tqdm(fragment_chains):
    for i, left in enumerate(fragment_chain):
        for right in fragment_chain[i + 1:]:
            if fragment_relation(left, right) == 'LR':
                lr_sequence_distances.append(fragment_sequence_distance(left, right))

100%|██████████| 44/44 [00:04<00:00, 10.88it/s]


In [19]:
len(lr_sequence_distances)

701042

In [20]:
counter = Counter(lr_sequence_distances)

sum([counter[d] for d in range(2, 15)])

7376

Видим, что пар на малом расстоянии мало, не будем их включать в выборку.

Теперь самая проблемная часть: сиквенсы.

In [21]:
Counter([fragment_length(fragment) for fragment in fragments]).most_common()

[(2, 617),
 (3, 615),
 (4, 570),
 (5, 356),
 (1, 289),
 (6, 257),
 (7, 181),
 (8, 93),
 (9, 66),
 (10, 39),
 (11, 27),
 (12, 23),
 (13, 15),
 (15, 9),
 (14, 6),
 (20, 5),
 (16, 4),
 (19, 3),
 (22, 3),
 (17, 3),
 (18, 2),
 (21, 1),
 (28, 1),
 (23, 1),
 (48, 1),
 (47, 1),
 (25, 1),
 (122, 1),
 (37, 1),
 (26, 1),
 (31, 1),
 (35, 1)]

Видим, что фрагментов длинее 10 мало. Предлагается пока их просто выкинуть, а не обрезать, потому что скорее всего у них какая-то отличная от коротких геометрия.

In [22]:
FragmentData = namedtuple('FragmentData', 'type sequence length')
FragmentPairData = namedtuple('FragmentPairData', 'lefts rights relation fragment_distance sequence_distance')

FragmentPair = namedtuple('FragmentPair', 'data distance pdb_id chain_id left right')

In [23]:
def check_fragment(fragment, fragment_length_threshold):
    return all([
        not any(res['MISS'] for res in fragment.members),
        fragment_length(fragment) <= fragment_length_threshold
    ])


def check_fragment_pair(left, right, fragment_sequence_distance_threshold):
    return all([
        fragment_relation(left, right) == 'LR',
        fragment_sequence_distance(left, right) >= fragment_sequence_distance_threshold
    ])


def make_fragment_data(fragment, fragment_length_threshold):
    return FragmentData(
        type=fragment_type(fragment), 
        sequence=fragment_sequence(fragment, fragment_length_threshold), 
        length=fragment_length(fragment))


def make_fragment_pair_data(lefts, rights, fragment_length_threshold):
    center = len(lefts) // 2
    
    return FragmentPairData(
        lefts=[make_fragment_data(item, fragment_length_threshold) for item in lefts],
        rights=[make_fragment_data(item, fragment_length_threshold) for item in rights],
        relation=[[fragment_relation(l, r) for l in lefts] for r in rights], 
        fragment_distance=fragment_fragment_distance(lefts[center], rights[center]),
        sequence_distance=fragment_sequence_distance(lefts[center], rights[center]))


def make_data(fragment_chains, fragment_length_threshold, fragment_sequence_distance_threshold, num_neighbors):
    data = []
    
    for fragment_chain in tqdm(fragment_chains):
        for i, left in enumerate(fragment_chain):
            for j, right in enumerate(fragment_chain):
                ij_ok = all(num_neighbors <= index < len(fragment_chain) - num_neighbors for index in [i, j])
                if i != j and ij_ok:
                    lefts, rights = [fragment_chain[index - num_neighbors:index + num_neighbors + 1] 
                                     for index in [i, j]]
                    fragments_ok = all(check_fragment(fragment, fragment_length_threshold) 
                                       for fragment in chain.from_iterable([lefts, rights]))
                    fragment_pairs_ok = check_fragment_pair(left, right, fragment_sequence_distance_threshold)
                    
                    if fragments_ok and fragment_pairs_ok:
                        data.append(FragmentPair(
                            data=make_fragment_pair_data(lefts, rights, fragment_length_threshold),
                            distance=fragment_distance(left, right),
                            pdb_id=left.pdb_id, 
                            chain_id=left.ch, 
                            left=i, 
                            right=j))
    
    return data

In [24]:
data = make_data(fragment_chains, 10, 15, 1)

100%|██████████| 44/44 [04:36<00:00,  6.28s/it]


In [25]:
Counter([(item.pdb_id, item.chain_id) for item in data])

Counter({('2qwy', 'C'): 32,
         ('4wfl', 'A'): 368,
         ('4rge', 'C'): 110,
         ('6dtd', 'C'): 2,
         ('1f1t', 'A'): 4,
         ('3rw6', 'H'): 66,
         ('4enc', 'A'): 24,
         ('4rmo', 'B'): 2,
         ('5x2g', 'B'): 14,
         ('6qzp', 'L5'): 446592,
         ('1kh6', 'A'): 4,
         ('4jf2', 'A'): 68,
         ('4y4o', '2A'): 328676,
         ('1l9a', 'B'): 392,
         ('6qzp', 'L8'): 760,
         ('6dlr', 'A'): 506,
         ('4yaz', 'R'): 266,
         ('3la5', 'A'): 108,
         ('3pdr', 'X'): 962,
         ('5fjc', 'A'): 310,
         ('5u3g', 'B'): 148,
         ('4qlm', 'A'): 104,
         ('2z75', 'B'): 326,
         ('4far', 'A'): 4484,
         ('3k1v', 'A'): 2,
         ('4y4o', '1a'): 110206,
         ('4frg', 'B'): 180,
         ('4prf', 'B'): 64,
         ('4p95', 'A'): 1400,
         ('3f2q', 'X'): 316,
         ('3e5c', 'A'): 40,
         ('1u9s', 'A'): 1014,
         ('5kpy', 'A'): 20,
         ('4lvw', 'A'): 168,
         ('3npq'

Так как теперь у нас нет верхней границы на расстояние по сиквенсу для пар, то получили сильный дисбаланс по цепочкам и pdb_id. В цепочках ('6qzp', 'L5'), ('4y4o', '2A'), ('6qzp', 'S2'), ('4y4o', '1a') лежат почти все пары. Из-за этого получаются проблемы с кроссвалидацией по количеству фолдов, превышающему 2, если мы хотим чтобы все пары одной цепочки попадали в один и тот же фолд.

In [26]:
np.mean([item.distance < 8 for item in data if item.pdb_id not in ['4y4o', '6qzp']])

0.09967647511939609

In [27]:
np.mean([item.distance < 8 for item in data])

0.008621173773196515

Дисбаланс в коротких цепочках тоже ожидаемо меньше.

Сохраним датасет, потом преобразуем в матрицу для random forest.

In [28]:
BATCH_SIZE = 100000

In [29]:
for batch_index, i in enumerate(range(0, len(data), BATCH_SIZE)):
    with open(os.path.join(DATA_DIR, 'batch_{}'.format(batch_index)), 'wb') as outfile:
        pickle.dump(data[i:i + BATCH_SIZE], outfile, protocol=pickle.HIGHEST_PROTOCOL)

In [30]:
def generate_data():
    for i in count():
        fpath = os.path.join(DATA_DIR, 'batch_{}'.format(i))
        
        if os.path.exists(fpath):
            with open(fpath, 'rb') as infile:
                yield from pickle.load(infile)
        else:
            break

In [31]:
fragment_pair = next(generate_data())
fragment_pair

FragmentPair(data=FragmentPairData(lefts=[FragmentData(type=[{'H'}, {'I'}], sequence=['G', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N'], length=1), FragmentData(type=[{'S'}, {}], sequence=['C', 'G', 'C', 'G', 'G', 'C', 'N', 'N', 'N', 'N'], length=6), FragmentData(type=[{'H'}, {'P'}], sequence=['G', 'A', 'U', 'U', 'U', 'A', 'A', 'N', 'N', 'N'], length=7)], rights=[FragmentData(type=[{'S'}, {}], sequence=['U', 'U', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N'], length=2), FragmentData(type=[{'S'}, {}], sequence=['G', 'C', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N'], length=2), FragmentData(type=[{'H'}, {'P'}], sequence=['A', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N', 'N'], length=1)], relation=[['LR', 'LR', 'LC'], ['LR', 'LR', 'LC'], ['LR', 'LC', 'SM']], fragment_distance=5, sequence_distance=15), distance=7.462208118244893, pdb_id='2qwy', chain_id='C', left=1, right=6)

In [32]:
class Extractor:
    def extract(self, fragment_pair):
        pass
        
    def describe(self, fragment_pair):
        pass


def extract(fragment_pair, extractors, method='extract'):
    return list(chain.from_iterable([getattr(extractor, method)(fragment_pair) for extractor in extractors]))

In [33]:
def onehot(items, possible_items):
    result = [0] * len(possible_items)
    for item in items:
        result[possible_items.index(item)] = 1

    return result


class FragmentExtractor(Extractor):
    BASES = 'AUGCMN'
    FIRST_TYPES = 'SHBIJ'
    SECOND_TYPES = 'CIP'
    
    def extract(self, fragment_pair):
        result = []
        
        for fragment_data in chain.from_iterable([fragment_pair.data.lefts, fragment_pair.data.rights]):
            result.append(fragment_data.length)
            
            for base in fragment_data.sequence:
                result.extend(onehot([base], self.BASES))
            
            first, second = fragment_data.type
                
            result.extend(onehot(first, self.FIRST_TYPES))
            result.extend(onehot(second, self.SECOND_TYPES))
        
        return result
    
    def describe(self, fragment_pair):
        result = []
        
        for lr in ['lefts', 'rights']:
            for fragment_index, fragment_data in enumerate(fragment_pair.data.lefts):
                result.append('sequence {}[{}] length'.format(lr, fragment_index))
                
                for sequence_index, _ in enumerate(fragment_data.sequence):
                    for base in self.BASES:
                        result.append('sequence {}[{}][{}] == {}'.format(lr, fragment_index, sequence_index, base))
                
                result.extend(['fragment type {}[{}] for first {}'.format(lr, fragment_index, letter) 
                               for letter in self.FIRST_TYPES])
                result.extend(['fragment type {}[{}] for second {}'.format(lr, fragment_index, letter) 
                               for letter in self.SECOND_TYPES])
        
        return result


class TinyExtractor(Extractor):
    def extract(self, fragment_pair):
        return [fragment_pair.data.fragment_distance, fragment_pair.data.sequence_distance]
    
    def describe(self, fragment_pair):
        return ['fragment_distance', 'sequence_distance']


class RelationExtractor(Extractor):
    RELATIONS = ['SM', 'LC', 'LR']
    
    def extract(self, fragment_pair):
        result = []
        
        for rel in chain.from_iterable(fragment_pair.data.relation):
            result.extend(onehot([rel], self.RELATIONS))
        
        return result
    
    def describe(self, fragment_pair):
        result = []
        
        for i, rels in enumerate(fragment_pair.data.relation):
            for j, _ in enumerate(rels):
                for rel_type in self.RELATIONS:
                    result.append('rel_matrix[{}][{}] == {}'.format(i, j, rel_type))
                
        return result


class BowFragmentExtractor(Extractor):
    BASES = 'AUGCMN'
    FIRST_TYPES = 'SHBIJ'
    SECOND_TYPES = 'CIP'
    
    def extract(self, fragment_pair):
        result = []
        
        for fragment_data in chain.from_iterable([fragment_pair.data.lefts, fragment_pair.data.rights]):
            result.append(fragment_data.length)
            
            bow = [0] * len(self.BASES)
            for base in fragment_data.sequence:
                bow[self.BASES.index(base)] += 1
            result.extend(bow)
            
            first, second = fragment_data.type
            first_onehot, second_onehot = [0] * len(self.FIRST_TYPES), [0] * len(self.SECOND_TYPES)
            
            for item in first:
                first_onehot[self.FIRST_TYPES.index(item)] = 1
                
            for item in second:
                second_onehot[self.SECOND_TYPES.index(item)] = 1
                
            result.extend(first_onehot)
            result.extend(second_onehot)
        
        return result
    
    def describe(self, fragment_pair):
        result = []
        
        for lr in ['lefts', 'rights']:
            for fragment_index, _ in enumerate(fragment_pair.data.lefts):
                result.append('sequence {}[{}] length'.format(lr, fragment_index))
                for base in self.BASES:
                    result.append('sequence {}[{}] count of {}'.format(lr, fragment_index, base))
                
                result.extend(['fragment type {}[{}] for first {}'.format(lr, fragment_index, letter) 
                               for letter in self.FIRST_TYPES])
                result.extend(['fragment type {}[{}] for second {}'.format(lr, fragment_index, letter) 
                               for letter in self.SECOND_TYPES])
        
        return result


class BowRelationExtractor(Extractor):
    RELATIONS = ['SM', 'LC', 'LR']
    
    def extract(self, fragment_pair):
        bow = [0] * 3
        for rel in chain.from_iterable(fragment_pair.data.relation):
            bow[self.RELATIONS.index(rel)] += 1
        
        return bow
    
    def describe(self, fragment_pair):
        return ['relation count {}'.format(rel) for rel in self.RELATIONS]

In [41]:
extractors = [BowFragmentExtractor(), TinyExtractor(), BowRelationExtractor()]
description = extract(fragment_pair, extractors, method='describe')

In [42]:
len(description)

95

In [43]:
digitized_features = [extract(fragment_pair, extractors) for fragment_pair in tqdm(generate_data())]

1032806it [00:54, 18901.12it/s]


In [44]:
target_distance = []
pdb_ids = []
chain_ids = []

for fragment_pair in tqdm(generate_data()):
    target_distance.append(fragment_pair.distance)
    pdb_ids.append(fragment_pair.pdb_id)
    chain_ids.append(fragment_pair.chain_id)

1032806it [00:39, 26441.88it/s]


In [45]:
digitized_features = np.array(digitized_features, dtype=np.int16)
target_distance = np.array(target_distance)
pdb_ids = np.array(pdb_ids)
chain_ids = np.array(chain_ids)

In [46]:
def save_dataset(digitized_features, target_distance, pdb_ids, chain_ids, description, prefix=''):
    np.save(os.path.join(DATA_DIR, '{}_features.npy'.format(prefix)), digitized_features)
    np.save(os.path.join(DATA_DIR, '{}_distance.npy'.format(prefix)), target_distance)
    np.save(os.path.join(DATA_DIR, '{}_pdb_ids.npy'.format(prefix)), pdb_ids)
    np.save(os.path.join(DATA_DIR, '{}_chain_ids.npy'.format(prefix)), chain_ids)
    
    with open(os.path.join(DATA_DIR, '{}_description.pickle'.format(prefix)), 'wb') as outfile:
        pickle.dump(description, outfile)

In [47]:
save_dataset(digitized_features, target_distance, pdb_ids, chain_ids, description, prefix='bow')