In [1]:
import pandas as pd
import numpy as np
import os
import sys
import polars as pl
import json
from joblib import Parallel, delayed
import deepchem
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

# Configure Polars 
cfg = pl.Config()
cfg.set_tbl_rows(20)
cfg.set_tbl_cols(50)
from sklearn.model_selection import StratifiedGroupKFold, GroupKFold, KFold, GroupShuffleSplit, ShuffleSplit, StratifiedKFold
import gc
import random
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
import lap
from typing import List, Tuple, Union
import lap
from matplotlib import pyplot as plt
from rdkit.Chem import MACCSkeys
from deepchem.splits.splitters import _generate_scaffold

No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!
Skipped loading some Tensorflow models, missing a dependency. No module named 'tensorflow'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'deepchem.models.torch_models' (/opt/conda/lib/python3.10/site-packages/deepchem/models/torch_models/__init__.py)
Skipped loading some Jax models, missing a dependency. No module named 'jax'
Skipped loading some PyTorch models, missing a dependency. No module named 'tensorflow'


In [2]:
with open('/home/dangnh36/datasets/competitions/leash_belka/processed/meta/building_blocks.json', 'r') as f:
    bbs_meta = json.load(f)
    
for k, v in bbs_meta.items():
    print(k, '-->', len(v)) 

train_bbs --> 1145
train_bb1s --> 271
train_bb2s --> 693
train_bb3s --> 872
test_bb1s --> 341
test_bb2s --> 1140
test_bb3s --> 1389
test_bbs --> 2110
all_bbs --> 2110


In [3]:
train_scaffolds = pl.scan_csv('/home/dangnh36/datasets/competitions/leash_belka/processed/train_scaffold.csv').collect()
train_df = pl.scan_csv('/home/dangnh36/datasets/competitions/leash_belka/processed/train_v2.csv')\
    .select(
        pl.col('molecule'),
        pl.col('bb1', 'bb2', 'bb3').cast(pl.UInt16),
        pl.col('BRD4', 'HSA', 'sEH').cast(pl.UInt8),
        scaffold_idx = train_scaffolds['scaffold_idx'],
    )\
    .collect()
print(train_df.shape, train_df.estimated_size('mb'))
train_df

(98415610, 8) 8601.91998577118


molecule,bb1,bb2,bb3,BRD4,HSA,sEH,scaffold_idx
str,u16,u16,u16,u8,u8,u8,i64
"""C#CCOc1ccc(CNc…",1640,1653,765,0,0,0,4283326
"""C#CCOc1ccc(CNc…",1640,1653,205,0,0,0,4486646
"""C#CCOc1ccc(CNc…",1640,1653,1653,0,0,0,1015728
"""C#CCOc1ccc(CNc…",1640,1653,146,0,0,0,5301385
"""C#CCOc1ccc(CNc…",1640,1653,439,0,0,0,5301385
"""C#CCOc1ccc(CNc…",1640,1653,196,0,0,0,5301385
"""C#CCOc1ccc(CNc…",1640,1653,253,0,0,0,5301385
"""C#CCOc1ccc(CNc…",1640,1653,1219,0,0,0,5301385
"""C#CCOc1ccc(CNc…",1640,1653,604,0,0,0,543172
"""C#CCOc1ccc(CNc…",1640,1653,121,0,0,0,2571428


In [4]:
test_scaffolds = pl.scan_csv('/home/dangnh36/datasets/competitions/leash_belka/processed/test_scaffold.csv').collect()
mol_groups = pl.scan_csv('/home/dangnh36/datasets/competitions/leash_belka/processed/test_v4.csv').select(pl.col('mol_group').cast(pl.UInt8)).collect()
test_df = pl.scan_csv('/home/dangnh36/datasets/competitions/leash_belka/processed/test_v2.csv')\
    .select(
        pl.col('id','molecule'),
        pl.col('bb1', 'bb2', 'bb3').cast(pl.UInt16),
        pl.col('protein'),
        scaffold_idx = test_scaffolds['scaffold_idx']
    ).group_by('molecule').first().sort('id').with_columns(mol_group = mol_groups['mol_group']).collect()
print(test_df.shape, test_df.estimated_size('mb'))
test_df

(878022, 8) 85.05668830871582


molecule,id,bb1,bb2,bb3,protein,scaffold_idx,mol_group
str,i64,u16,u16,u16,str,i64,u8
"""C#CCCC[C@H](Nc…",295246830,1989,409,409,"""BRD4""",2217250,2
"""C#CCCC[C@H](Nc…",295246833,1989,409,1012,"""BRD4""",602641,2
"""C#CCCC[C@H](Nc…",295246836,1989,409,1722,"""BRD4""",4502748,2
"""C#CCCC[C@H](Nc…",295246839,1989,409,1078,"""BRD4""",3936208,2
"""C#CCCC[C@H](Nc…",295246842,1989,409,605,"""BRD4""",4550856,2
"""C#CCCC[C@H](Nc…",295246845,1989,409,521,"""BRD4""",4414349,2
"""C#CCCC[C@H](Nc…",295246848,1989,409,41,"""BRD4""",5367715,2
"""C#CCCC[C@H](Nc…",295246851,1989,409,1826,"""BRD4""",1422452,2
"""C#CCCC[C@H](Nc…",295246854,1989,409,1970,"""BRD4""",4752663,2
"""C#CCCC[C@H](Nc…",295246857,1989,409,598,"""BRD4""",5758930,2


In [5]:
# for i in range(10000):
#     s = _generate_scaffold(train_df[0, 'molecule'])
#     assert 'y' not in s

In [6]:
all_bbs = bbs_meta['all_bbs']
train_bb1s = bbs_meta['train_bb1s']
train_bb2s = bbs_meta['train_bb2s']
train_bb3s = bbs_meta['train_bb3s']
train_bb23s = sorted(list(set(train_bb2s + train_bb3s)))
len(train_bb1s), len(train_bb23s)

(271, 874)

In [7]:
train_bb1s_scaffolds = [_generate_scaffold(bb) for bb in train_bb1s]
train_bb23s_scaffolds = [_generate_scaffold(bb) for bb in train_bb23s]
print('Molecule:', len(train_bb1s), len(train_bb23s))
print('Scaffold:', len(set(train_bb1s_scaffolds)), len(set(train_bb23s_scaffolds)))

Molecule: 271 874
Scaffold: 62 270


In [36]:
def split_to_chunk(arr, num_chunks):
    num_per_chunk = len(arr) // num_chunks
    residual = len(arr) % num_chunks
    ret = []
    cur_start = 0
    for i in range(num_chunks):
        cur_end = cur_start + num_per_chunk
        if residual > 0:
            cur_end += 1
            residual -= 1
        ret.append(arr[cur_start:cur_end])
        cur_start = cur_end
    return ret


def make_combination_idxs(n, num_combine):
    init_arr = list(range(n))
    arrs = []
    for i in range(num_combine):
        arrs.append(init_arr[i:] + init_arr[:i])
    return [[arrs[i][j] for i in range(num_combine)] for j in range(n)]

def make_grid(bb1_splits, bb23_splits, bb1_cell_num_splits = 2, bb23_cell_num_splits = 2, shuffle = True):
    """
    Args:
        bb1_splits: List[List[int]]
    """
    if shuffle:
        random.seed(42)
        from copy import deepcopy
        bb1_splits = deepcopy(bb1_splits)
        bb23_splits = deepcopy(bb23_splits)
        random.shuffle(bb1_splits)
        random.shuffle(bb23_splits)
    
    bb1_grid_idxs = make_combination_idxs(len(bb1_splits), bb1_cell_num_splits)
    bb23_grid_idxs = make_combination_idxs(len(bb23_splits), bb23_cell_num_splits)
    
    print('BB1 GRID:', bb1_grid_idxs)
    print('BB23 GRID:', bb23_grid_idxs)
    
    ret = []
    grid_idx = -1
    for i, bb1_split_idxs in enumerate(bb1_grid_idxs):
        for j, bb23_split_idxs in enumerate(bb23_grid_idxs):
            grid_idx += 1
            bb1_idxs = []
            bb23_idxs = []
            for _split_idx in bb1_split_idxs:
                bb1_idxs.extend(bb1_splits[_split_idx])
            for _split_idx in bb23_split_idxs:
                bb23_idxs.extend(bb23_splits[_split_idx])
            ret.append({
                'grid_idx': grid_idx,
                'bb1_grid_idx': i,
                'bb23_grid_idx': j,
                'bb1_split_idxs': bb1_split_idxs,
                'bb23_split_idxs': bb23_split_idxs,
                'bb1_idxs': bb1_idxs,
                'bb23_idxs': bb23_idxs,
                'num_bb1s': len(bb1_idxs),
                'num_bb23s': len(bb23_idxs),
                'expected_num_samples': len(bb1_idxs) * len(bb23_idxs) * (len(bb23_idxs) + 1) / 2
            })
                     
    return ret
    
print(split_to_chunk(list(range(19)), 4))
print(make_combination_idxs(4, 2))
print(make_combination_idxs(4, 3))

[[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14], [15, 16, 17, 18]]
[[0, 1], [1, 2], [2, 3], [3, 0]]
[[0, 1, 2], [1, 2, 3], [2, 3, 0], [3, 0, 1]]


BB1 scaffold distribution: [87, 47, 20, 13, 9, 8, 5, 4, 4, 4, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]

In [57]:
# scaffold_gkf, scaffold, 
BB_SPLIT_METHOD = 'scaffold_gkf'
BB1_SPLITS = 5
BB23_SPLITS = 8
BB1_CELL_LEN = 2
BB23_CELL_LEN= 3
SHUFFLE = True

ret = {
    'split_method': BB_SPLIT_METHOD,
    'bb1_splits': BB1_SPLITS,
    'bb23_splits': BB23_SPLITS,
    'bb1_cell_len': BB1_CELL_LEN,
    'bb23_cell_len': BB23_CELL_LEN,
    'shuffle': SHUFFLE
}

bb1s = []
bb23s = []
if BB_SPLIT_METHOD == 'scaffold':
    from collections import Counter
    bb1_scaf_counter = dict(Counter(train_bb1s_scaffolds))
    bb23_scaf_counter = dict(Counter(train_bb23s_scaffolds))
    bb1_list = [(train_bb1s[i], train_bb1s_scaffolds[i], bb1_scaf_counter[train_bb1s_scaffolds[i]]) for i in range(len(train_bb1s))]
    bb23_list = [(train_bb23s[i], train_bb23s_scaffolds[i], bb23_scaf_counter[train_bb23s_scaffolds[i]]) for i in range(len(train_bb23s))]
    bb1_list.sort(key = lambda x: (x[2], x[0]), reverse=True)
    bb23_list.sort(key = lambda x: (x[2], x[0]), reverse=True)
#     print(bb1_list)
#     print(bb23_list)
    bb1s = split_to_chunk([e[0] for e in bb1_list], BB1_SPLITS)
    bb23s = split_to_chunk([e[0] for e in bb23_list], BB23_SPLITS)
elif BB_SPLIT_METHOD == 'scaffold_gkf':
    from sklearn.model_selection import GroupKFold   
    splitter1 = GroupKFold(n_splits=BB1_SPLITS)
    for i, (train_idxs, val_idxs) in enumerate(splitter1.split(train_bb1s, train_bb1s, train_bb1s_scaffolds)):
        bb1s.append([train_bb1s[_j] for _j in val_idxs])
        
    splitter23 = GroupKFold(n_splits=BB23_SPLITS)
    for i, (train_idxs, val_idxs) in enumerate(splitter23.split(train_bb23s, train_bb23s, train_bb23s_scaffolds)):
        bb23s.append([train_bb23s[_j] for _j in val_idxs])
elif BB_SPLIT_METHOD == 'random':
    pass
else:
    raise ValueError

split_bb1_idxs = []
split_bb23_idxs = []

for split in bb1s:
    split_bb1_idxs.append([all_bbs.index(e) for e in split])
    
for split in bb23s:
    split_bb23_idxs.append([all_bbs.index(e) for e in split])

print([len(split) for split in split_bb1_idxs], sum([len(split) for split in split_bb1_idxs]))
print([len(split) for split in split_bb23_idxs], sum([len(split) for split in split_bb23_idxs]))

# print(split_bb1_idxs)
# print(split_bb23_idxs)

ret['splits'] = {
    'bb1': split_bb1_idxs,
    'bb23': split_bb23_idxs
}
ret['grid'] = make_grid(split_bb1_idxs, split_bb23_idxs, BB1_CELL_LEN, BB23_CELL_LEN, shuffle=SHUFFLE)
print('GRID LEN:', len(ret['grid']))
expected_samples_per_cell = min([e['expected_num_samples'] for e in ret['grid']])
print('Expected len:', round(expected_samples_per_cell / 1_000_000, 1), 'M')
# print(ret)

[87, 47, 46, 46, 45] 271
[126, 117, 106, 105, 105, 105, 105, 105] 874
BB1 GRID: [[0, 1], [1, 2], [2, 3], [3, 4], [4, 0]]
BB23 GRID: [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6], [5, 6, 7], [6, 7, 0], [7, 0, 1]]
GRID LEN: 40
Expected len: 4.5 M


In [58]:
SAVE_DIR = '/home/dangnh36/datasets/competitions/leash_belka/processed/cv/bb_grid/'
os.makedirs(SAVE_DIR, exist_ok=True)
SAVE_PATH = os.path.join(SAVE_DIR,
  f'{BB_SPLIT_METHOD}_{BB1_SPLITS}_{BB23_SPLITS}_{BB1_CELL_LEN}_{BB23_CELL_LEN}{"_shuffle" if SHUFFLE else ""}_{round(expected_samples_per_cell / 1_000_000, 1)}M.json')

print(SAVE_PATH)
with open(SAVE_PATH, 'w') as f:
    json.dump(ret, f, indent=4)

/home/dangnh36/datasets/competitions/leash_belka/processed/cv/bb_grid/scaffold_gkf_5_8_2_3_shuffle_4.5M.json
