In [None]:
# 导入必要的库
import sys
# 将自定义模块路径添加到系统路径中
sys.path.append('../DLEPS/')
import numpy as np
import pandas as pd
import h5py
from rdkit.Chem import MolFromSmiles, MolToSmiles, Draw
import molecule_vae
import zinc_grammar
import nltk
from functools import reduce
from multiprocessing import Pool, cpu_count



In [None]:
# Load the train and test SMILES tables.
train_smiles_df = pd.read_csv('../../results/train_SMILES_demo.csv')
test_smiles_df = pd.read_csv('../../results/test_SMILES_demo.csv')

# Stack the 'smiles' columns of both tables into one array (train rows first, then test).
combined_smiles = np.concatenate(
    [frame['smiles'].values for frame in (train_smiles_df, test_smiles_df)],
    axis=0,
)

print("Number of SMILES from train and test datasets:", len(combined_smiles))


In [None]:
# Load the L1000 landmark gene-expression table.
gene_expression_df = pd.read_csv('../../results/L1000_landmark.csv')

# Report its size and column layout (later cells rely on a 'smiles' column and a column named '780').
print("Number of SMILES in L1000_landmark.csv:", gene_expression_df.shape[0])
print("Columns in gene expression data:", list(gene_expression_df.columns))


In [None]:
# Canonicalise SMILES strings with RDKit.
def normalize_smiles(smiles_list):
    """Convert each SMILES string to RDKit's canonical SMILES.

    Args:
        smiles_list: iterable of SMILES strings.

    Returns:
        A list with one entry per input, in input order. An entry is ``None``
        when RDKit cannot parse the input (``MolFromSmiles`` returns a falsy
        value) or when any step of the conversion raises.
    """
    def _canonicalize(smi):
        # Best-effort by design: any failure is recorded as None instead of aborting the batch.
        try:
            mol = MolFromSmiles(smi)
            return MolToSmiles(mol) if mol else None
        except Exception:
            return None

    return [_canonicalize(smi) for smi in smiles_list]

# Canonicalise the pooled train+test SMILES; unparsable entries come back as None.
canonical_smiles = normalize_smiles(combined_smiles)

# Keep each original SMILES next to its canonical form.
smiles_df = pd.DataFrame(
    {'original_smiles': combined_smiles, 'canonical_smiles': canonical_smiles}
)

# Preview the first rows.
print(smiles_df.head())


In [None]:
# Canonicalise the gene-expression table's SMILES (None where RDKit fails) so that it
# shares a join key with smiles_df.
gene_expression_canonical_smiles = normalize_smiles(gene_expression_df['smiles'])
gene_expression_df = gene_expression_df.assign(canonical_smiles=gene_expression_canonical_smiles)

# Preview the first rows.
print(gene_expression_df.head())


In [None]:
# Join the SMILES table with the gene-expression table on canonical SMILES.
#
# Rows whose SMILES RDKit could not parse carry canonical_smiles = None. pandas
# merge matches null keys against each other (unlike SQL), which would pair every
# invalid row on the left with every invalid row on the right. Drop them first.
# No index reset is needed: the join is on a column, not on the index.
smiles_with_keys = smiles_df.dropna(subset=['canonical_smiles'])
gene_expression_with_keys = gene_expression_df.dropna(subset=['canonical_smiles'])

merged_df = pd.merge(
    smiles_with_keys,
    gene_expression_with_keys,
    on='canonical_smiles',
    how='inner',
    suffixes=('_smiles', '_gene'),
)

print("Number of matched SMILES after merging:", len(merged_df))


In [None]:
# Index labels of the merged rows (not used further in this part of the notebook).
matched_smiles_indices = merged_df.index.values

# Assumption: the gene-expression values occupy every column from '780' through the last column.
gene_expression_columns = merged_df.columns[merged_df.columns.get_loc('780'):]
gene_expression_data = merged_df.loc[:, gene_expression_columns].values

# Original (non-canonical) SMILES of the matched rows, in merged-row order.
smiles_to_process = merged_df['original_smiles'].values

print("Extracted gene expression data shape:", gene_expression_data.shape)
print("Number of SMILES to process:", len(smiles_to_process))


In [None]:
# Re-parse the matched SMILES with RDKit. Keep the canonical form of each success
# and its row position so the gene-expression rows can be filtered to match.
processed_smiles = []
valid_indices = []

for position, raw_smi in enumerate(smiles_to_process):
    try:
        mol = MolFromSmiles(raw_smi)
        if not mol:
            print(f"Invalid molecule at index {position}")
            continue
        processed_smiles.append(MolToSmiles(mol))
        valid_indices.append(position)
    except Exception as e:
        print(f"Error processing SMILES at index {position}: {e}")

print("Number of valid SMILES after RDKit processing:", len(processed_smiles))


In [None]:
# Keep only the gene-expression rows whose SMILES passed the RDKit check,
# so that rows stay aligned with processed_smiles.
valid_gene_expression_data = gene_expression_data[np.asarray(valid_indices, dtype=int)]

print("Filtered gene expression data shape:", valid_gene_expression_data.shape)


In [10]:
# Helper: count the items of any iterable (string, list, generator, ...).
def calculate_length(sequence):
    """Return the number of items yielded by ``sequence``.

    Unlike ``len`` this also accepts iterators (which it consumes).
    """
    return sum(1 for _ in sequence)


# Factory for the ZINC-grammar SMILES tokenizer.
def get_zinc_tokenizer(cfg):
    """Build a tokenizer that splits a SMILES string into grammar terminals.

    Terminals in ``cfg._lexical_index`` that are longer than one character are
    first replaced by single placeholder characters (``$``, ``%``, ``^``), so the
    string can be split character by character. The placeholders are then
    expanded back to the terminals they stand for.

    Args:
        cfg: grammar object exposing a ``_lexical_index`` mapping keyed by terminal strings.

    Returns:
        A function ``tokenize(smiles)`` returning a list of terminal strings.

    Raises:
        AssertionError: if the number of multi-character terminals differs from
            the number of placeholders, or if a placeholder is already a terminal.
    """
    multi_char_terminals = [sym for sym in cfg._lexical_index.keys() if calculate_length(sym) > 1]
    placeholders = ['$', '%', '^']

    assert calculate_length(multi_char_terminals) == len(placeholders), "Mismatch between long tokens and replacements."
    for placeholder in placeholders:
        assert placeholder not in cfg._lexical_index, f"Replacement token {placeholder} already in lexical index."

    # placeholder character -> the multi-character terminal it stands for
    expand = dict(zip(placeholders, multi_char_terminals))

    def tokenize(smiles):
        # NOTE(review): a literal '%' in the input (e.g. ring-closure labels like '%10')
        # cannot be told apart from a placeholder and is expanded as well.
        for terminal, placeholder in zip(multi_char_terminals, placeholders):
            smiles = smiles.replace(terminal, placeholder)
        return [expand.get(ch, ch) for ch in smiles]

    return tokenize


In [None]:
# Tokenizer and chart parser for the ZINC SMILES grammar.
tokenizer = get_zinc_tokenizer(zinc_grammar.GCFG)
parser = nltk.ChartParser(zinc_grammar.GCFG)

# Every grammar production, plus a production -> column-index lookup for one-hot encoding.
productions = zinc_grammar.GCFG.productions()
production_map = dict(zip(productions, range(len(productions))))

# Encoding dimensions. Despite its name, MAX_SMILES_LENGTH caps the number of
# production steps per molecule, not characters. Presumably it must match the
# grammar VAE the encodings are fed to — TODO confirm.
MAX_SMILES_LENGTH = 277
NUM_PRODUCTIONS = len(productions)

print(f"Number of productions: {NUM_PRODUCTIONS}")
print(f"Maximum SMILES length: {MAX_SMILES_LENGTH}")


In [12]:
# Parse one tokenized SMILES into a grammar parse tree (shaped for Pool.map).
def parse_smiles(args, grammar_parser=None):
    """Parse a single tokenized SMILES with a chart parser.

    Takes one ``(index, tokens)`` tuple so it can be used directly with
    ``Pool.map(parse_smiles, enumerate(tokenized_smiles))``. The index is passed
    through so results can be matched back to their input.

    Args:
        args: ``(index, tokens)`` pair.
        grammar_parser: object whose ``parse(tokens)`` returns an iterator of
            trees. Defaults to the notebook-level ``parser``. The global is looked
            up only at call time, so worker processes use whatever ``parser`` they
            see (presumably inherited via fork).

    Returns:
        ``(index, tree, None)`` on success, or ``(index, None, message)`` on failure.
    """
    index, tokens = args
    active_parser = parser if grammar_parser is None else grammar_parser
    try:
        parse_tree = next(active_parser.parse(tokens), None)
    except Exception as e:
        # e.g. presumably nltk's coverage check raising ValueError for unknown tokens.
        # Fall back to the exception type if the message is empty.
        return (index, None, str(e) or type(e).__name__)
    if parse_tree is None:
        # The grammar accepted the tokens but produced no tree. A bare next() raised
        # StopIteration here, whose str() is '', which left the error message empty.
        return (index, None, "no parse tree found")
    return (index, parse_tree, None)


In [None]:
# Split every canonical SMILES into grammar tokens.
tokenized_smiles = [tokenizer(smi) for smi in processed_smiles]

# Parse in parallel; pool.map returns results in input order.
# NOTE(review): the workers call the notebook-defined parse_smiles, which relies on the global
# `parser`. This works with the 'fork' start method. Under 'spawn' (the Windows/macOS default),
# functions defined in a notebook generally cannot be loaded by the workers — confirm on the
# target platform.
with Pool(cpu_count()) as pool:
    parse_results = pool.map(parse_smiles, enumerate(tokenized_smiles))

# Keep the successful trees (in input order) and record the positions that failed.
parse_trees = []
failed_indices = []
for position, tree, message in parse_results:
    if tree is None:
        print(f"Parse tree error at index {position}: {message}")
        failed_indices.append(position)
    else:
        parse_trees.append(tree)

print("Number of successfully parsed SMILES:", len(parse_trees))
print("Number of failed parses:", len(failed_indices))


In [None]:
# Drop the SMILES whose grammar parse failed. This keeps SMILES, gene-expression rows and
# parse_trees aligned (parse_trees holds only the successful parses, in input order).
failed_index_set = set(failed_indices)  # set lookup; `in` on the list made this O(n * n_failed)
valid_parse_indices = [i for i in range(len(valid_indices)) if i not in failed_index_set]

final_gene_expression_data = valid_gene_expression_data[valid_parse_indices]
final_smiles = [processed_smiles[i] for i in valid_parse_indices]

# Each surviving SMILES must correspond to exactly one parse tree.
assert len(final_smiles) == len(parse_trees), "final_smiles and parse_trees are out of sync"

print("Final gene expression data shape:", final_gene_expression_data.shape)
print("Number of final SMILES to encode:", len(final_smiles))


In [None]:
# Turn each parse tree into its sequence of grammar productions (nltk Tree.productions()).
production_sequences = [tree.productions() for tree in parse_trees]

# Map every production to its integer id in the grammar.
production_indices = [
    np.array([production_map[prod] for prod in seq], dtype=int)
    for seq in production_sequences
]

# Shape: (molecules, production steps, grammar productions).
one_hot_encoded = np.zeros((len(production_indices), MAX_SMILES_LENGTH, NUM_PRODUCTIONS), dtype=np.float32)

for i, indices in enumerate(production_indices):
    if len(indices) > MAX_SMILES_LENGTH:
        # NOTE(review): a truncated derivation is incomplete and presumably cannot be
        # decoded back to a molecule; consider dropping such samples instead.
        print(f"SMILES at index {i} exceeds max length, truncating.")
    kept = indices[:MAX_SMILES_LENGTH]
    steps = np.arange(len(kept))
    # Pair step t with production kept[t]. Two integer arrays index element-wise.
    # Mixing a slice with an array index (the former `[i, :n, indices]`) broadcasts
    # to an n x n outer product and sets every used production at every step.
    one_hot_encoded[i, steps, kept] = 1.0
    # Pad the remaining steps with the last production, used as the padding symbol
    # (assumed to be the grammar's terminating/"Nothing" rule — TODO confirm in zinc_grammar).
    one_hot_encoded[i, len(kept):, -1] = 1.0

# Every step of every molecule must now have exactly one active production.
assert np.all(one_hot_encoded.sum(axis=2) == 1.0), "One-hot encoding is not one-hot"

print("One-Hot encoded SMILES shape:", one_hot_encoded.shape)


In [None]:
# Shuffle once with a fixed seed, then hold out the first rows as the test set.
# The legacy np.random.seed + permutation is kept so splits match earlier runs.
np.random.seed(42)

num_samples = final_gene_expression_data.shape[0]
# Both arrays must describe the same molecules row by row before they are shuffled together.
assert one_hot_encoded.shape[0] == num_samples, (
    f"Row mismatch: {one_hot_encoded.shape[0]} one-hot rows vs {num_samples} gene-expression rows"
)

shuffled_indices = np.random.permutation(num_samples)
shuffled_gene_expression = final_gene_expression_data[shuffled_indices]
shuffled_one_hot = one_hot_encoded[shuffled_indices]

# Preferred number of held-out samples.
TEST_SET_SIZE = 3000
# Fraction held out instead when there are no more than TEST_SET_SIZE samples.
# Slicing off TEST_SET_SIZE rows there would leave the training set empty.
FALLBACK_TEST_FRACTION = 0.2

if num_samples > TEST_SET_SIZE:
    effective_test_size = TEST_SET_SIZE
else:
    effective_test_size = int(num_samples * FALLBACK_TEST_FRACTION)
    print(f"Warning: only {num_samples} samples (<= TEST_SET_SIZE={TEST_SET_SIZE}); "
          f"holding out {effective_test_size} samples for the test set instead.")

gene_expression_test = shuffled_gene_expression[:effective_test_size]
gene_expression_train = shuffled_gene_expression[effective_test_size:]

one_hot_test = shuffled_one_hot[:effective_test_size]
one_hot_train = shuffled_one_hot[effective_test_size:]

print("Training set gene expression shape:", gene_expression_train.shape)
print("Training set One-Hot shape:", one_hot_train.shape)
print("Test set gene expression shape:", gene_expression_test.shape)
print("Test set One-Hot shape:", one_hot_test.shape)


In [None]:
# Persist one array to its own HDF5 file.
def save_to_h5(file_path, dataset_name, data):
    """Write ``data`` to ``file_path`` as a single HDF5 dataset named ``dataset_name``.

    The file is opened in ``'w'`` mode, so any existing file at ``file_path``
    is truncated and overwritten.
    """
    with h5py.File(file_path, mode='w') as handle:
        handle.create_dataset(name=dataset_name, data=data)
    print(f"Data saved to {file_path} with dataset name '{dataset_name}'.")

# Write the gene-expression and one-hot SMILES train/test splits. Each file stores
# its array under the dataset name 'data'.
for h5_path, split_array in (
    ('../../results/L1000_train.h5', gene_expression_train),
    ('../../results/L1000_test.h5', gene_expression_test),
    ('../../results/SMILES_train_demo.h5', one_hot_train),
    ('../../results/SMILES_test_demo.h5', one_hot_test),
):
    save_to_h5(h5_path, 'data', split_array)
