# Finetuning after getting an iteration of experiment results

In [1]:
import os
import pickle
from pathlib import Path

import lmdb
import pandas as pd

In [2]:
class LMDBDataset:
    """Read-only, index-addressable view over a pickled-record LMDB file.

    Records are stored under keys ``str(i)`` (UTF-8 encoded) and each value is
    a pickled Python object; in this notebook they are dicts with fields such
    as 'atoms', 'coordinates' and 'smi'.

    NOTE(review): values are unpickled, which can execute arbitrary code --
    only open LMDB files from trusted sources.
    """

    def __init__(self, db_path):
        """Open ``db_path`` read-only and cache the list of raw keys.

        Args:
            db_path: path to a single-file LMDB database (opened with
                ``subdir=False``).

        Raises:
            AssertionError: if ``db_path`` is not an existing file.
        """
        self.db_path = db_path
        assert os.path.isfile(self.db_path), "{} not found".format(self.db_path)
        self.env = self.connect_db(self.db_path)
        with self.env.begin() as txn:
            self._keys = list(txn.cursor().iternext(values=False))

    def connect_db(self, lmdb_path, save_to_self=False):
        """Open ``lmdb_path`` as a read-only, lock-free LMDB environment.

        Args:
            lmdb_path: path to the LMDB file.
            save_to_self: if True, also store the environment on ``self.env``.

        Returns:
            The opened ``lmdb.Environment`` (in both modes).
        """
        env = lmdb.open(
            lmdb_path,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
            max_readers=256,
        )
        if save_to_self:
            self.env = env
        return env

    def __len__(self):
        """Number of keys present in the database when it was opened."""
        return len(self._keys)

    def __getitem__(self, idx):
        """Return the unpickled record stored under key ``str(idx)``.

        Args:
            idx: record index; converted with ``str()`` and UTF-8 encoded to
                form the LMDB key.

        Raises:
            IndexError: if no record is stored under that key. (Raising
                IndexError also lets ``for record in dataset`` terminate.)
        """
        key = str(idx).encode("utf-8")
        # Use the read transaction as a context manager so it is released
        # right after the lookup instead of lingering until garbage collection.
        with self.env.begin() as txn:
            datapoint_pickled = txn.get(key)
        if datapoint_pickled is None:
            raise IndexError(f"no record with key {idx!r} in {self.db_path}")
        return pickle.loads(datapoint_pickled)

### 1. Making a new LMDB dataset containing the lipids and results of the previous iterations

This step takes a file of such result data, which can usually be exported from the dashboard's visualize_result page.

The format of the LMDB to be created is described in `model/pretrain/notebooks/test_lmdb.ipynb`. Essentially, every entry should contain the following fields: 'atoms', 'coordinates', 'mol', 'smi', 'target'


In [3]:
# Input: the experiment-results export (see the markdown above for where it comes from).
exp_res_file = Path("exp1001_export.csv")

# The 220k library is split into NUM_LIBRARY_SHARDS subfolders
# (library_folder/0, library_folder/1, ..., library_folder/9). Each holds an LMDB
# of molecule records ("test.lmdb") and a SMILES list ("smi_name_list.txt")
# whose line numbers are used as the LMDB keys further below.
NUM_LIBRARY_SHARDS = 10
library_folder = Path("/datasets/cellxgene/3d_molecule_data/220k-lib/lmdb")
library_lmdb_paths = [library_folder / str(i) / "test.lmdb" for i in range(NUM_LIBRARY_SHARDS)]
library_smi_paths = [library_folder / str(i) / "smi_name_list.txt" for i in range(NUM_LIBRARY_SHARDS)]

# Fail early and report exactly which inputs are missing, rather than a bare
# assertion over all 20 library files.
missing_paths = [
    p for p in [exp_res_file, *library_lmdb_paths, *library_smi_paths] if not p.exists()
]
if missing_paths:
    raise FileNotFoundError(
        "missing input files:\n" + "\n".join(str(p) for p in missing_paths)
    )

In [4]:
# Load the exported results. The first (unnamed) CSV column holds the
# combinatorial ID (e.g. "A4B12C5D25", built from the amine / isocyanide /
# aldehyde / carboxylic_acid codes) and becomes the index.
exp_res_df = pd.read_csv(exp_res_file, index_col=0)

# Only 'smiles' and 'max' are used downstream ('max' becomes the finetuning
# target). Check them here so a malformed export cannot silently put NaN
# labels or missing molecules into the new LMDB.
missing_columns = {"smiles", "max"} - set(exp_res_df.columns)
assert not missing_columns, f"{exp_res_file} is missing columns: {sorted(missing_columns)}"
rows_with_nulls = exp_res_df[["smiles", "max"]].isna().any(axis=1)
assert not rows_with_nulls.any(), (
    f"{int(rows_with_nulls.sum())} rows have a missing smiles/max value, "
    f"e.g. {list(exp_res_df.index[rows_with_nulls][:10])}"
)

exp_res_df

Unnamed: 0,smiles,amine,isocyanide,aldehyde,carboxylic_acid,max,mean,std,reading.66f1cc436f02a6b43df5b162,reading.66ef37371a176fc8bea75a5e,reading.66f1cc436f02a6b43df5b160,reading.66ef37371a176fc8bea75a5c,reading.66ef2feca83176d144d82d36,reading.66d32a929d365f235a91b104,reading.66cef8347d9aecc0ed02fa79,reading.66c61458d27c8b5806cea711,reading.66d0ee9c819f784009d33374,reading.66ccf7214f75e1ea9c980576,reading.66c61458d27c8b5806cea70f,reading.66bc532a880da25d86ed778d
A4B12C5D25,CCCCCCCCCCCCCCCCCC(=O)N(CCN1CCNCC1)C(CCCCCCCCC...,A4,B12,C5,D25,16.314794,11.075759,4.798635,,,,,,,,,4.776386,10.772577,12.439281,16.314794
A17B10C6D26,CCCCCCCC/C=C\CCCCCCCC(=O)N(CCCN1CCCC1)C(CCCCCC...,A17,B10,C6,D26,13.555879,12.170939,1.245765,,,,,,13.555879,11.14174,11.815198,,,,
A8B4C3D36,C=CCCCCCCCCCOC(=O)CCCCC(=O)N(CCN(CCN)CCN)C(CCC...,A8,B4,C3,D36,13.140386,7.533560,6.418891,,,13.140386,0.531959,8.928336,,,,,,,
A4B7C6D1,CCCCCCCCCCCC(C(=O)NC1CCCCC1)N(CCN1CCNCC1)C(=O)...,A4,B7,C6,D1,12.901808,5.470778,5.953271,,,,,,,,,7.656558,0.456265,0.868483,12.901808
A11B6C12D15,CCCCCCCCCCCCCCCC(C(=O)NC1CCCC1)N(CCCN(C)CCCN)C...,A11,B6,C12,D15,12.706683,10.412705,2.052349,,,9.780870,8.750561,12.706683,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
A17B12C9D9,CCCCCCCCC(CCCCCC)C(=O)OCCCCCC(C(=O)NC12C[C@H]3...,A17,B12,C9,D9,0.323954,0.179873,0.164916,,,0.215665,0.323954,0.000000,,,,,,,
A20B5C10D5,CCCCCCC(CCCC)C(=O)OCCCCCC(C(=O)NCC(=O)OCC)N(C(...,A20,B5,C10,D5,0.228568,0.114284,0.161622,0.000000,0.228568,,,,,,,,,,
A1B9C1D31,CCCCCCC(CCC)OC(=O)CCCCC(=O)N(CCN(C)C)C(CCCCC)C...,A1,B9,C1,D31,0.213804,0.071268,0.123440,,,0.000000,0.213804,0.000000,,,,,,,
A24B8C5D21,C=CCCCCCCCCC(=O)N(c1nc2ccccc2[nH]1)C(CCCCCCCCC...,A24,B8,C5,D21,0.206409,0.193546,0.018192,0.180682,0.206409,,,,,,,,,,


In [5]:
# One record per result row: the SMILES string plus its best ('max') reading,
# renamed to 'target' -- the label used for finetuning.
exp_res = (
    exp_res_df[["smiles", "max"]]
    .rename(columns={"max": "target"})
    .to_dict(orient="records")
)

# Open every library shard, in folder order (0..9).
library_datasets = [LMDBDataset(str(lmdb_path)) for lmdb_path in library_lmdb_paths]


def _read_smi_index(smi_path):
    """Map each stripped SMILES line of ``smi_path`` to its 0-based line number.

    If a SMILES appears on several lines, the last occurrence wins.
    """
    with open(smi_path, "r") as fh:
        return dict((line.strip(), line_no) for line_no, line in enumerate(fh))


# One SMILES -> record-index mapping per shard, aligned with library_datasets.
smi_to_idx_mappings = [_read_smi_index(smi_file) for smi_file in library_smi_paths]
total_smiles = sum(len(mapping) for mapping in smi_to_idx_mappings)
print(f"number of smiles in mappings: {total_smiles}")


number of smiles in mappings: 221184


In [6]:
# find the folder num and idx in the folder for each exp_res
# and get the corresponding molecule info
# Enrich each result in place: find the first shard whose SMILES list contains
# the molecule, then copy that shard record's atoms and coordinates.
for record in exp_res:
    smi = record["smiles"]
    hit = next(
        (
            (shard_no, mapping[smi])
            for shard_no, mapping in enumerate(smi_to_idx_mappings)
            if smi in mapping
        ),
        None,
    )
    if hit is None:
        raise ValueError(f"smiles {smi} not found in any of the mappings")
    folder_num, idx = hit
    data = library_datasets[folder_num][idx]
    # Sanity check: the shard's SMILES list and its LMDB must be aligned line-for-key.
    assert data["smi"].strip() == smi, f"folder {folder_num} idx {idx}: {data['smi']} != {smi}"
    # NOTE(review): the record's 'mol' field is not copied, although the format
    # notes list it -- confirm the finetuning step does not need it.
    record.update(
        folder_num=folder_num,
        idx=idx,
        atoms=data["atoms"],
        coordinates=data["coordinates"],
    )

In [7]:
# Keep only the fields of the LMDB record format; 'smiles' is stored as 'smi'.
# NOTE(review): the format notes list a 'mol' field as well, but it is not
# written here -- confirm the finetuning code does not need it.
exp_res_to_save = [
    {
        "atoms": item["atoms"],
        "coordinates": item["coordinates"],
        "smi": item["smiles"],
        "target": item["target"],
    }
    for item in exp_res
]
assert exp_res_to_save, "no results to save"

# Write the records to a new LMDB next to the CSV (same stem, ".lmdb" suffix),
# keyed "0", "1", ... so that LMDBDataset can read them back by integer index.
output_lmdb_path = exp_res_file.with_suffix(".lmdb")
if output_lmdb_path.exists():
    raise ValueError(f"{output_lmdb_path} already exists")
env = lmdb.open(
    str(output_lmdb_path),
    subdir=False,
    map_size=1099511627776 * 2,  # 2 TiB: maximum size the database may grow to
    readonly=False,
    meminit=False,
    map_async=True,
    max_dbs=0,
    lock=False,
    max_readers=1,
)
try:
    with env.begin(write=True) as txn:
        for idx, exp_res_item in enumerate(exp_res_to_save):
            txn.put(str(idx).encode(), pickle.dumps(exp_res_item))
finally:
    # Close the writer so its handle is not left open in the kernel: LMDB must
    # not have the same environment open twice in one process, and the next
    # step is expected to reopen this file.
    env.close()
print(f"results saved to {output_lmdb_path}")

results saved to exp1001_export.lmdb


### 2. Finetuning the model with k-fold cross-validation