In [None]:
# Set the Python working directory so the relative paths used in later cells resolve.
# NOTE(review): this is a machine-specific absolute path — adjust before running elsewhere.
import os

os.chdir("/home/huabei/projects/SMTarRNA")

In [None]:
import logging
import pickle
import time
from collections import Counter, defaultdict
from functools import partial
from queue import Queue
from threading import Thread

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
from pandas import HDFStore
from tqdm import tqdm

# Select matplotlib's non-interactive "Agg" backend.
matplotlib.use("Agg")
# Emit INFO-level log records from the cells below.
logging.basicConfig(level=logging.INFO)
# Display the contents of the current working directory as a quick sanity check.
os.listdir()

In [None]:
# Directory that later cells walk for downloaded files and list for the
# *_index.h5 / *_coor.h5 stores.
ZINC_DATA_PATH = "../data/dataset/ligand/index/"
# ZINC_HDF5_PATH = '../data/dataset/ZINC20-test.h5'

# ZINC_DATA_PATH = 'ZINC-DrugLike-3D-20230407'
# ZINC_DATA_PATH = 'test'
# ZINC_HDF5_PATH = 'ZINC20-DrugLike-3D-20230402.h5'

# 构建h5数据集

In [None]:
# from create_total_dataset_hdf5 import create_total_dataset_hdf5
import logging

from dataset.create_total_dataset_hdf5 import create_total_dataset_hdf5

# logging.basicConfig(level=logging.INFO)

In [None]:
# Build the HDF5 dataset from ZINC_DATA_PATH; the actual behaviour lives in
# dataset.create_total_dataset_hdf5, which is not shown in this notebook.
create_total_dataset_hdf5(ZINC_DATA_PATH)
print("Construct Done! Good Job!")

In [None]:
# Open read-only and close the file as soon as the keys are listed, so the handle is not
# left open for the rest of the session (the default mode is append, which would also
# create the file if the path were wrong).
with HDFStore("../data/dataset/ZINC20-test/BA_index.h5", mode="r") as store:
    store_keys = store.keys()
store_keys

# 筛选缺失的数据

In [None]:
# Already-downloaded data: every file name found below ZINC_DATA_PATH, HDF5 stores excluded.
total_file = [
    file_name
    for _, _, file_names in os.walk(ZINC_DATA_PATH)
    for file_name in file_names
    if not file_name.endswith(".h5")
]

In [None]:
# Data that has to be downloaded: the file holds one download command per line.
with open("ZINC-downloader-3D-pdbqt.gz.curl", "r") as f:
    need_download_cmd = f.readlines()
# Map file name -> full command line. The file name is the last "/"-separated component of
# the second-to-last space-separated token of each line.
# NOTE(review): presumably that token is curl's output path — confirm against the .curl file;
# a blank or differently formatted line would break this indexing.
need_download = {file.split(" ")[-2].split("/")[-1]: file for file in need_download_cmd}

In [None]:
# Verify that every downloaded file is one of the files that were supposed to be downloaded,
# then compute the set of files that are still missing.
total_file_set = set(total_file)
need_download_set = set(need_download.keys())
assert total_file_set.issubset(
    need_download_set
), "total_file_set is not subset of need_download_set"
not_download = need_download_set - total_file_set

In [None]:
# Build the download commands for the files that are still missing.
download_cmd = [need_download[file] for file in not_download]
# Show the count together with a preview: only a cell's last expression is displayed, so a
# bare `len(...)` in the middle of the cell would be discarded.
len(download_cmd), download_cmd[:5]

# 获取数据集的统计信息

In [None]:
# Collect the index / coordinate HDF5 stores found directly in ZINC_DATA_PATH.
# The directory is listed once and sorted, so both lists come out in the same order.
_hdf_names = sorted(os.listdir(ZINC_DATA_PATH))
hdf_index_file = [
    os.path.join(ZINC_DATA_PATH, name) for name in _hdf_names if name.endswith("_index.h5")
]
hdf_coor_file = [
    os.path.join(ZINC_DATA_PATH, name) for name in _hdf_names if name.endswith("_coor.h5")
]
# The two lists are paired positionally below; fail loudly if a missing or extra file would
# shift the pairing instead of silently matching an index store with the wrong coor store.
_index_prefixes = [path[: -len("_index.h5")] for path in hdf_index_file]
_coor_prefixes = [path[: -len("_coor.h5")] for path in hdf_coor_file]
assert _index_prefixes == _coor_prefixes, "index/coor HDF5 files do not pair up one-to-one"
all_hdf_file = list(zip(hdf_index_file, hdf_coor_file))
all_hdf_file[:5]

### 单线程完成分子个数和原子个数的统计

In [None]:
# For every (index, coor) pair: count the rows of each index table (molecules) and the
# occurrences of each value in the "atom" column of the matching coor table.
dataset_info = dict()
for index_file, coor_file in tqdm(all_hdf_file):
    total_molecule = 0
    total_atom = Counter()
    # Read-only context managers: the stores are closed even if a table fails to load, and
    # opening can never create or modify a file.
    with HDFStore(index_file, mode="r") as index_store:
        with HDFStore(coor_file, mode="r") as coor_store:
            for k in index_store.keys():
                total_molecule += index_store[k].shape[0]
                total_atom.update(coor_store[k]["atom"].to_list())
    # Key by the index file's base name (the path we opened, not the store's private `_path`).
    dataset_info[os.path.basename(index_file)] = {
        "total_molecule": total_molecule,
        "total_atom": dict(total_atom),
    }
# NOTE(review): this writes into the working directory, while the plotting section loads
# "../data/dataset/zinc20_druglike_dataset_info.pkl" — confirm which location is intended.
with open("zinc20_druglike_dataset_info.pkl", "wb") as f:
    pickle.dump(dataset_info, f)
print("All work completed")

### 检查数据集的基本统计信息（文件数，分子数）

In [None]:
# Count the tables (keys) and molecules (index rows) across all index stores, logging how
# long opening and scanning each store takes.
# NOTE(review): `total_file` is rebound here from the list of downloaded file names built
# earlier to an int; re-running the download-check cells after this one will fail.
total_file = 0
total_molecule = 0
for index_file, _ in tqdm(all_hdf_file):
    # Timing: t2 - t1 is the open time, t3 - t2 the time to list keys and read the tables.
    t1 = time.time()
    with HDFStore(index_file, mode="r") as index_store:
        t2 = time.time()
        # List the keys once and reuse the result instead of calling keys() three times.
        keys = index_store.keys()
        total_file += len(keys)
        for k in keys:
            total_molecule += index_store[k].shape[0]
        t3 = time.time()
    logging.info(f"\nfile: {index_file}, keys: {len(keys)}, time: {t2-t1} s, {t3-t2} s")
total_file, total_molecule

### 多线程完成分子个数和原子个数的统计（待优化）

In [None]:
# Two threads are used: one opens the data, the other processes it.


def read_data(q: Queue, file_list: list):
    """Open each (index, coor) HDF5 pair read-only and put the two open stores on ``q``.

    :param q: queue receiving ``(index_store, coor_store)`` tuples.
    :param file_list: iterable of ``(index_file, coor_file)`` path pairs.
    """
    for paths in file_list:
        index_path, coor_path = paths
        logging.info(f"index_file: {index_path}, coor_file: {coor_path}")
        opened_index = HDFStore(index_path, mode="r")
        opened_coor = HDFStore(coor_path, mode="r")
        q.put((opened_index, opened_coor))


def process_data(index_store, coor_store):
    """Count molecules and atoms for one pair of open HDF5 stores, then close both stores.

    Every dataset reported by ``index_store.walk()`` is read from both stores under the same
    key: the row count of the index table is added to the molecule total, and the values of
    the coor table's ``"atom"`` column are tallied.

    :param index_store: open store holding the index tables.
    :param coor_store: open store holding the coordinate tables under the same keys.
    :return: dict with the molecule total under both ``"total_molecule_num"`` (original key)
        and ``"total_molecule"`` (the key used by the single-threaded statistics and by the
        plotting cells), plus ``"total_atom"`` mapping atom value -> count.
    """
    logging.info("process_data is running")
    total_molecule_num = 0
    total_atom = Counter()
    try:
        for path, _sub_groups, datasets in tqdm(index_store.walk()):
            for dataset in datasets:
                d = os.path.join(path, dataset)
                total_molecule_num += len(index_store.get(d))
                # Tally the atom values of the matching coordinate table.
                total_atom.update(coor_store.get(d)["atom"].to_list())
    finally:
        # The stores are owned by this function: close them even when a read fails.
        try:
            index_store.close()
        finally:
            coor_store.close()
    return {
        "total_molecule_num": total_molecule_num,
        "total_molecule": total_molecule_num,
        "total_atom": dict(total_atom),
    }


# Producer thread class
class Producer(Thread):
    """Daemon thread that calls ``func(q, file_list)`` once when it runs."""

    def __init__(self, func, q: Queue, file_list: list):
        super().__init__(daemon=True)
        self.func = func
        self.q = q
        self.file_list = file_list

    def run(self):
        produce = self.func
        produce(self.q, self.file_list)


# Consumer thread class
class Consumer(Thread):
    """Daemon thread that takes ``loops`` store pairs from ``q`` and processes each one.

    Attributes:
        results: base name of the index store -> dict returned by ``process_data``.
        errors: exceptions raised while processing items (one entry per failed item).
    """

    def __init__(self, q: Queue, loops: int):
        super().__init__()
        self.q = q
        self.daemon = True
        self.loops = loops
        # Created here so the attributes exist even before the thread has run.
        self.results = dict()
        self.errors = []

    def run(self):
        for _ in range(self.loops):
            logging.info("Consumer is running")
            index_store, coor_store = self.q.get()
            logging.info("Consumer get data")
            try:
                result = process_data(index_store, coor_store)
                self.results[os.path.basename(index_store.filename)] = result
            except Exception as exc:
                # Record the failure and keep going with the remaining items.
                logging.exception("Consumer failed to process an item")
                self.errors.append(exc)
            finally:
                # Always mark the item done; otherwise Queue.join() in the main thread
                # blocks forever after a single failure.
                self.q.task_done()

### 多线程

In [None]:
# Run the producer/consumer pair over every (index, coor) file pair, then pickle the results.
q = Queue()
producer = Producer(read_data, q, all_hdf_file)
consumer = Consumer(q, len(all_hdf_file))
producer.start()
consumer.start()
# Blocks until task_done() has been called once for every item that was put on the queue.
q.join()
producer.join()
consumer.join()
# NOTE(review): this writes the same file name as the single-threaded statistics cell.
with open("zinc20_druglike_dataset_info.pkl", "wb") as f:
    pickle.dump(consumer.results, f)
print("All work completed")

### 随机抽取数据集中的分子
抽取zinc id然后分组，通过路径找到对应的pdbqt.gz文件抽取分子。

In [None]:
import random

# Seed the generator so the sample (and the sub-samples drawn from it later with
# `random.sample`) can be regenerated after a kernel restart.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

# For every table in every index store, keep each index entry (zinc id) with probability
# `ratio`. Keys of the result are the in-store paths of the tables.
random_sample_molecule = dict()
ratio = 1 / 600
for index_file, _ in tqdm(all_hdf_file):
    # The context manager closes the store even if reading a table fails.
    with HDFStore(index_file, mode="r") as index_store:
        for path, sub_group, datasetes in tqdm(index_store.walk()):
            for dataset in datasetes:
                d = os.path.join(path, dataset)
                zinc_id_tmp = index_store.get(d).index.to_list()
                random_sample_molecule[d] = [zi for zi in zinc_id_tmp if random.random() < ratio]

with open("zinc20_druglike_random_sample_molecule_1f600.pkl", "wb") as f:
    pickle.dump(random_sample_molecule, f)

In [None]:
import gzip
import io
import re


class ZincPdbqt:
    """
    Parse a gzip-compressed pdbqt file (.pdbqt.gz) into ids and molecule text blocks.

    Attributes:
        f_str: the decompressed file content as one string.
        zinc_id: every value captured after ``"Name = "`` up to the end of that line.
        molecules: the text between each ``MODEL`` line and its ``ENDMDL`` line when the
            file starts with ``MODEL``; otherwise a single entry holding the whole file.
        data: ``zinc_id`` and ``molecules`` zipped positionally into a list of tuples.
    """

    def __init__(self, pdbqt_file):
        # Read and decode inside a ``with`` block so the gzip handle is always closed.
        with gzip.open(pdbqt_file, mode="rb") as gz_file:
            self.f_str = gz_file.read().decode()
        # Ids found in the file
        self.zinc_id = re.findall("Name = (.*?)\n", self.f_str)
        # Molecule blocks found in the file
        if self.f_str.startswith("MODEL"):
            self.molecules = re.findall("MODEL.*?\n(.*?)ENDMDL\n", self.f_str, re.S)
        else:
            self.molecules = [self.f_str]
        # zip() truncates to the shorter list; make a mismatch visible instead of silent.
        if len(self.zinc_id) != len(self.molecules):
            logging.warning(
                "%s: found %d ids but %d molecule blocks; extra entries are dropped",
                pdbqt_file,
                len(self.zinc_id),
                len(self.molecules),
            )
        # List of (zinc_id, molecule text) pairs
        self.data = list(zip(self.zinc_id, self.molecules))

    @property
    def data_dict(self):
        """Return a new dict mapping each id to its molecule text."""
        return dict(zip(self.zinc_id, self.molecules))


def gz_writer(file_name: str) -> io.TextIOWrapper:
    """Open ``file_name`` as a gzip file in "wb" mode and wrap it for UTF-8 text writing."""
    return io.TextIOWrapper(gzip.open(file_name, "wb"), encoding="utf-8")


def write_pdbqt_to_gz(pdbqt_list, gz_file):
    """Write each molecule as a ``MODEL`` ... ``ENDMDL`` record into a gzip file.

    :param pdbqt_list: iterable of items whose element ``[1]`` is the molecule text.
    :param gz_file: path of the gzip file to write.
    """
    with gz_writer(gz_file) as f:
        for pdbqt in tqdm(pdbqt_list, desc="write to gz"):
            # write() emits the record in one call; writelines() on a str iterates over it
            # and writes one character at a time.
            f.write("MODEL\n" + pdbqt[1] + "ENDMDL\n")


def ele_filter(zinc_pdbqt_item, elements_list=None):
    """
    Predicate for filter(): keep an item only if all of its atom records are allowed.

    For every line of the molecule text that starts with "ATOM" or "HETATM", the characters
    at positions 12-13 are stripped, upper-cased and looked up in the upper-cased
    ``elements_list``.
    :param zinc_pdbqt_item: [..., pdbqt_str]
    :param elements_list: ['H', 'C', 'O']
    :return: False as soon as one record is not in the list, otherwise True.
    """
    assert elements_list is not None, "elements_list is None"
    records = zinc_pdbqt_item[1].strip().split("\n")
    allowed = [symbol.upper() for symbol in elements_list]
    return all(
        record[12:14].strip().upper() in allowed
        for record in records
        if record.startswith(("ATOM", "HETATM"))
    )

In [None]:
# Pull the sampled molecules out of the original .pdbqt.gz files.
pdbqt_list = []
for k, v in tqdm(random_sample_molecule.items(), desc="read pdbqt.gz"):
    # Skip tables from which nothing was sampled
    if not v:
        continue
    # Derive the file path from the table key
    file = ZINC_DATA_PATH + k.replace("_", ".") + ".pdbqt.gz"
    # Parse the file into a dict: zinc id -> molecule text
    zinc_pdbqt = ZincPdbqt(file).data_dict
    for zinc_id in v:
        # Handle a missing id per molecule; a try around the whole loop would drop every
        # remaining id of this file after the first miss.
        try:
            pdbqt_list.append((zinc_id, zinc_pdbqt[zinc_id]))
        except KeyError:
            print(f"{zinc_id} not in {file}")

In [None]:
# Allowed element symbols; only the keys are used here (ele_filter upper-cases them).
elements_dict = dict(C=0, N=1, O=2, H=3, F=4, S=5, CL=6, Br=7, I=8, SI=9, P=10)
elements_list = list(elements_dict.keys())
# Bind the allowed list so the predicate can be passed to filter().
ele_filter_ = partial(ele_filter, elements_list=elements_list)
# Keep only the molecules for which ele_filter returns True.
pdbqt_list = list(filter(ele_filter_, pdbqt_list))

In [None]:
# Number of molecules currently in pdbqt_list.
len(pdbqt_list)

In [None]:
# Write pdbqt_list to a gz file
write_pdbqt_to_gz(pdbqt_list, "outputs/zinc20_druglike_random_sample_molecule_1f600.pdbqt.gz")

### 从pdbqt.gz文件中抽取分子

In [None]:
# Draw nested sub-samples: 100k molecules from pdbqt_list, then 10k out of those 100k.
# NOTE(review): random.sample raises ValueError when the population is smaller than the
# requested size — confirm pdbqt_list holds at least 100_000 entries.
pdbqt_list_100k = random.sample(pdbqt_list, 100_000)
write_pdbqt_to_gz(
    pdbqt_list_100k, "outputs/zinc20_druglike_random_sample_molecule_1f600_100k.pdbqt.gz"
)
pdbqt_list_10k = random.sample(pdbqt_list_100k, 10000)
write_pdbqt_to_gz(
    pdbqt_list_10k, "outputs/zinc20_druglike_random_sample_molecule_1f600_10k.pdbqt.gz"
)

In [None]:
# Read the written sample back as a dict: zinc id -> molecule text.
random_sample = ZincPdbqt(
    "outputs/zinc20_druglike_random_sample_molecule_1f600.pdbqt.gz"
).data_dict

# 绘图

In [None]:
# Load the per-file statistics; the `with` block closes the file handle after reading.
with open("../data/dataset/zinc20_druglike_dataset_info.pkl", "rb") as f:
    dataset_info = pickle.load(f)

In [None]:
# Aggregate the molecule total and the per-atom counts over all entries of dataset_info.
total_molecule_num = sum(info["total_molecule"] for info in dataset_info.values())
total_atom_count = defaultdict(lambda: 0)
for info in dataset_info.values():
    for atom, num in info["total_atom"].items():
        total_atom_count[atom] += num

In [None]:
total_atom_count = dict(total_atom_count)

# Normalise: each count is divided by the total over ALL atom types, sorted descending.
# NOTE(review): "HD", "HH" and "HE" are left out of the plot but still counted in the
# denominator, so the plotted fractions do not sum to 1 — confirm this is intended.
total_atom_num = sum(total_atom_count.values())
total_atom_count_p = {
    k: v / total_atom_num
    for k, v in sorted(total_atom_count.items(), key=lambda item: item[1], reverse=True)
    if k not in ["HD", "HH", "HE"]
}
# Bar chart of the normalised counts
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(list(total_atom_count_p.keys()), list(total_atom_count_p.values()))
ax.set_xlabel("Atom")
# The bars show normalised values, so label the axis as a fraction rather than a count.
ax.set_ylabel("Atom Fraction")
ax.set_title("ZINC20 Druglike Subset Atom Count")

# Save as an image without background, which requires transparent=True.
fig.savefig("zinc20_druglike_atom_count.png", transparent=True)

## 提取信息

In [None]:
# Molecules whose information is to be extracted; the `with` block closes the file handle.
with open(
    "../data/dataset/outputs/zinc20_druglike_random_sample_molecule_1f600.pkl", "rb"
) as f:
    molecules = pickle.load(f)

In [None]:
# For every sampled zinc id, compute "end" - "start" from its row in the index table.
molecule_size = dict()
for k, v in tqdm(molecules.items()):
    if not v:
        continue
    # The index store name is built from characters 1-2 of the table key.
    file = os.path.join(ZINC_DATA_PATH, k[1:3] + "_index.h5")
    with HDFStore(file, "r") as store:
        # Read the table once per key instead of once per zinc id.
        try:
            index_table = store[k]
        except KeyError:
            # Table missing: report every requested id, as the per-id lookup used to do.
            for zinc_id in v:
                print(f"{zinc_id} not in {file}")
            continue
        for zinc_id in v:
            try:
                mol = index_table.loc[zinc_id]
                atom_num = mol["end"] - mol["start"]
                molecule_size[zinc_id] = atom_num
            except KeyError:
                print(f"{zinc_id} not in {file}")
                continue
with open(
    "../data/dataset/outputs/zinc20_druglike_random_sample_molecule_1f600_molecule_size.pkl", "wb"
) as f:
    pickle.dump(molecule_size, f)

In [None]:
# Reload the saved sizes (the `with` block closes the file handle) and show the last entry.
with open(
    "../data/dataset/outputs/zinc20_druglike_random_sample_molecule_1f600_molecule_size.pkl",
    "rb",
) as f:
    molecule_size = pickle.load(f)
list(molecule_size.items())[-1]