In [15]:
from alphadock import utils
from alphadock import features
from tqdm import tqdm
from path import Path
from collections import Counter

In [28]:
# Load the original (uncleaned) 12k train / validation split definitions.
train, valid = (
    utils.read_json('data/train_split/train_12k.json'),
    utils.read_json('data/train_split/valid_12k.json'),
)

In [29]:
# Number of cases in each split before cleaning.
n_train_raw, n_valid_raw = len(train['cases']), len(valid['cases'])
n_train_raw, n_valid_raw

(11491, 305)

In [14]:
# Parse every data/cases/<case_name>/case.json, keyed by its parent directory name.
case_dicts = {}
for case_json_path in Path('data/cases').glob('*/case.json'):
    case_name = case_json_path.dirname().basename()
    case_dicts[case_name] = utils.read_json(case_json_path)

In [30]:
# Cleaning criteria shared by the train and valid splits.
# Chemical component IDs treated as "junk": a case whose ligand group consists only
# of these components is dropped. (Presumably buffer/solvent/additive/cofactor-like
# components — NOTE(review): confirm the rationale with whoever curated this list.)
junk_chemids = set(['SO4', 'GOL', 'EDO', 'PO4', 'ACT', 'HEM', 'PEG', 'DMS', 'MPD', 'ACE', 'NAG'])
# Ligand groups whose ligands have fewer heavy atoms than this in total are dropped.
MIN_GROUP_HEAVY_ATOMS = 6


def filter_cases(cases, case_dicts, junk_ids, min_heavy_atoms=MIN_GROUP_HEAVY_ATOMS):
    """Return the cases whose ligand group is neither all-junk nor too small.

    Args:
        cases: list of dicts with 'case_name' and 'group_name' keys.
        case_dicts: mapping case_name -> parsed case.json; each entry must have a
            'ligand_groups' list whose items carry 'name' and 'ligands'
            (each ligand with 'num_heavy_atoms').
        junk_ids: set of component IDs; a group made only of these is dropped.
        min_heavy_atoms: groups with fewer total heavy atoms than this are dropped.

    Returns:
        A new list holding the kept case dicts (the dicts themselves are not copied).

    Raises:
        KeyError: if a case's case_name or group_name is missing from case_dicts.
    """
    kept = []
    for case in tqdm(cases):
        # group_name appears to join component IDs with '_' (e.g. 'GSH_SAS').
        if set(case['group_name'].split('_')).issubset(junk_ids):
            continue
        groups_by_name = {g['name']: g for g in case_dicts[case['case_name']]['ligand_groups']}
        group = groups_by_name[case['group_name']]
        heavy_atoms = sum(lig['num_heavy_atoms'] for lig in group['ligands'])
        if heavy_atoms < min_heavy_atoms:
            continue
        kept.append(case)
    return kept


def annotate_seqclus_size(cases, case_dicts):
    """Set case['seqclus_size'] in place for every case in `cases`.

    The size is the number of cases *within `cases`* that share the same
    'seqclus30' value, so sizes are computed per split, after filtering.
    """
    cluster_of = lambda c: case_dicts[c['case_name']]['seqclus30']
    counts = Counter(cluster_of(c) for c in cases)
    for c in cases:
        c['seqclus_size'] = counts[cluster_of(c)]


train['cases'] = filter_cases(train['cases'], case_dicts, junk_chemids)
annotate_seqclus_size(train['cases'], case_dicts)

100%|██████████| 11491/11491 [00:00<00:00, 299328.93it/s]


In [31]:
# Validation split: drop cases whose ligand group is made only of `junk_chemids`
# components or whose ligands total fewer than 6 heavy atoms, then record for each
# kept case how many kept validation cases share its 'seqclus30' cluster.
cases_filt = []
for case in tqdm(valid['cases']):
    components = set(case['group_name'].split('_'))
    if components <= junk_chemids:
        continue  # every component of the group is a junk chem ID
    groups_by_name = {grp['name']: grp for grp in case_dicts[case['case_name']]['ligand_groups']}
    heavy_atom_total = sum(lig['num_heavy_atoms'] for lig in groups_by_name[case['group_name']]['ligands'])
    if heavy_atom_total < 6:
        continue  # ligand group too small
    cases_filt.append(case)

valid['cases'] = cases_filt
seqclus_counts = Counter(case_dicts[kept['case_name']]['seqclus30'] for kept in cases_filt)
for case in valid['cases']:
    cluster_id = case_dicts[case['case_name']]['seqclus30']
    case['seqclus_size'] = seqclus_counts[cluster_id]

100%|██████████| 305/305 [00:00<00:00, 442804.68it/s]


In [32]:
# Persist the cleaned splits alongside the originals (train first, then valid).
for split_dict, out_path in (
    (train, 'data/train_split/train_12k_cleaned.json'),
    (valid, 'data/train_split/valid_12k_cleaned.json'),
):
    utils.write_json(split_dict, out_path)

In [33]:
# Number of training cases remaining after cleaning.
n_train_clean = len(train['cases'])
n_train_clean

8904

In [34]:
# Peek at the first 100 cleaned training cases (with their cluster sizes).
train_preview = train['cases'][:100]
train_preview

[{'case_name': '10GS_A', 'group_name': 'VWW', 'seqclus_size': 44},
 {'case_name': '11GS_A', 'group_name': 'EAA_GSH', 'seqclus_size': 44},
 {'case_name': '12AS_A', 'group_name': 'AMP', 'seqclus_size': 1},
 {'case_name': '12GS_B', 'group_name': '0HH', 'seqclus_size': 44},
 {'case_name': '13GS_B', 'group_name': 'GSH_SAS', 'seqclus_size': 44},
 {'case_name': '17GS_A', 'group_name': 'GTX', 'seqclus_size': 44},
 {'case_name': '18GS_A', 'group_name': 'GDN', 'seqclus_size': 44},
 {'case_name': '19GS_A', 'group_name': 'GSH', 'seqclus_size': 44},
 {'case_name': '1A28_A', 'group_name': 'STR', 'seqclus_size': 110},
 {'case_name': '1A2B_A', 'group_name': 'GSP_MG', 'seqclus_size': 60},
 {'case_name': '1A42_A', 'group_name': 'BZU_ZN', 'seqclus_size': 256},
 {'case_name': '1A4L_A', 'group_name': 'DCF_ZN', 'seqclus_size': 10},
 {'case_name': '1A4M_A', 'group_name': 'PRH_ZN', 'seqclus_size': 10},
 {'case_name': '1A52_A', 'group_name': 'EST', 'seqclus_size': 495},
 {'case_name': '1A69_A', 'group_name': '