In [1]:
import argparse
import csv
import logging
import os
import sys
import time
from random import random
import traceback
import rdkit
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem

import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader
import math, random, sys
import numpy as np
import argparse
import os
from tqdm import tqdm
import pandas as pd

sys.path.append('/home/lnptest/SB_jin/git_test/hgraph2graph_wandb/')
from hgraph import *
from hgraph import vocab2
import rdkit

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

def setup_logging():
    """Configure root logging for the run and quiet RDKit's own logger.

    Root logging is set to DEBUG with a timestamped format and two handlers:
    a file handler writing to ``output.log`` (relative to the current working
    directory) and a stream handler. RDKit's logger threshold is then raised
    to CRITICAL.
    """
    log_format = '%(asctime)s [%(levelname)s] %(message)s'
    # The file handler is created first, then the stream handler, as before.
    file_handler = logging.FileHandler('output.log')
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.DEBUG,
        format=log_format,
        handlers=[file_handler, console_handler],
    )

    rdkit_logger = rdkit.RDLogger.logger()
    rdkit_logger.setLevel(rdkit.RDLogger.CRITICAL)

# Set the arguments directly instead of parsing them with argparse.
class Args:
    """Static configuration namespace standing in for parsed command-line arguments.

    All values are class attributes. NOTE(review): main() binds its own local
    `args` from parse_arguments(), so the module-level instance created below
    is only seen by code that reads the global `args`.
    """
    # NOTE(review): relative path — resolved against the current working directory.
    vocab = "vocab.txt"
    # presumably provided by the `from hgraph import *` star import — TODO confirm
    atom_vocab = common_atom_vocab
    # NOTE(review): hard-coded absolute, machine-specific checkpoint path.
    model = "/home/lnptest/SB_jin/git_test/hgraph2graph_wandb/ckpt/chembl-240627-2/model_95232.ckpt"
    # NOTE(review): no seeding call using this value appears in the visible cells.
    seed = 7
    nsample = 10000000
    # The remaining values look like network hyper-parameters; presumably they
    # must match the settings the checkpoint was trained with — TODO confirm.
    rnn_type = 'LSTM'
    hidden_size = 512
    embed_size = 512
    batch_size = 50
    latent_size = 32
    depthT = 15
    depthG = 15
    diterT = 1
    diterG = 3
    dropout = 0.0

# Module-level instance of the configuration above.
args = Args()

In [None]:
def parse_arguments(argv=None):
    """Parse the command-line options of the generator.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, as before. Passing
            an explicit list is useful inside a notebook, where ``sys.argv``
            holds the kernel's own flags.

    Returns:
        argparse.Namespace with ``sample`` (int), ``ideacode`` (str),
        ``output_dir`` (str) and ``temp_suffix`` (str).

    Raises:
        SystemExit: raised by argparse when a required option is missing or a
            value cannot be converted.
    """
    parser = argparse.ArgumentParser(description='Description of your program')
    parser.add_argument("-s", "--sample", required=True, help="Number of samples to generate", type=int)
    parser.add_argument("-i", "--ideacode", required=True, help="Ideacode for the query", type=str)
    parser.add_argument("-o", "--output_dir", default="/home/dwpdp/ideadb/data/", help="Directory to save the result", type=str)
    parser.add_argument("-t", "--temp_suffix", default="", help="Temporary suffix for the output file", type=str)
    # The original call was missing its closing parenthesis (SyntaxError).
    return parser.parse_args(argv)

In [None]:
def write_results(sampled_dict, args, execution_time):
    """Write the generated molecules and the run time to a CSV file.

    The file is ``<output_dir>/<ideacode>/hgraph/Gen_result_<ideacode><temp_suffix>.csv``;
    the directory is created when missing.

    Args:
        sampled_dict: Mapping of generated SMILES string to its similarity
            value; rows are written in the mapping's iteration order.
        args: Object exposing ``output_dir``, ``ideacode`` and ``temp_suffix``.
        execution_time: Elapsed time in seconds.

    Any exception raised while opening or writing the file is logged as an
    error and not re-raised.
    """
    target_dir = os.path.join(args.output_dir, args.ideacode, 'hgraph')
    os.makedirs(target_dir, exist_ok=True)
    target_path = os.path.join(target_dir, f'Gen_result_{args.ideacode}{args.temp_suffix}.csv')

    def _molecule_rows():
        # One row per generated molecule; IDs are 1-based and zero-padded to 3 digits.
        for rank, (smiles, similarity) in enumerate(sampled_dict.items(), start=1):
            yield [args.ideacode, smiles, f"{args.ideacode}-S{rank:03d}", f"{similarity:.3f}"]

    try:
        with open(target_path, 'w', newline='') as handle:
            out = csv.writer(handle)
            out.writerow(["Query_ID", "Gen_SMILE", "Gen_ID", "Sim(ECFP4)"])
            out.writerows(_molecule_rows())

            # Record the execution time after a blank separator row.
            hours = int(execution_time / 3600)
            within_hour = execution_time % 3600
            minutes = int(within_hour / 60)
            seconds = int(within_hour % 60)
            out.writerow([])
            out.writerow([f"Execution time: {hours}h {minutes}m {seconds}s"])
        logging.info('Generation complete')
    except Exception as err:
        logging.error(f"Failed to write results: {err}")

def main():
    """Generate analogues of a query molecule and write them to a CSV file.

    Steps:
      1. Configure logging, parse the command line and build the model.
      2. Read the query SMILES from the first line of ``<ideacode>.smi``
         (relative to the current working directory).
      3. Repeatedly call ``model.reconstruct_scaled`` with a random scale in
         [0, 1) until ``args.sample`` unique molecules have been collected
         whose ECFP4 (Morgan radius 2, 1024 bits) similarity to the query is
         at least 0.1.
      4. Hand the results and the elapsed time to ``write_results``.

    Raises:
        ValueError: if the query file's first line is empty or is not a
            SMILES string RDKit can parse.
        OSError: if the query file cannot be opened.

    Note:
        Generation failures are logged and retried without limit, so the loop
        does not terminate if the model can never produce enough molecules.
    """
    # Local import on purpose: the notebook's import cell first binds `random`
    # to the function (`from random import random`) and later rebinds it to the
    # module (`import math, random, sys`). A bare `random()` call therefore
    # raised TypeError, which the broad handler below swallowed on every
    # iteration, turning the loop into an endless error spin.
    import random as _random

    setup_logging()
    args = parse_arguments()
    # NOTE(review): initialize_model() is not defined in the cells visible
    # here — confirm it is provided elsewhere before running.
    model = initialize_model()

    query_path = f'{args.ideacode}.smi'
    with open(query_path, 'r') as f:
        query_smi = f.readline().strip()

    # Validate the query up front instead of failing inside the sampling loop.
    query_mol = Chem.MolFromSmiles(query_smi) if query_smi else None
    if query_mol is None:
        raise ValueError(
            f"Query SMILES read from {query_path!r} is empty or cannot be parsed: {query_smi!r}"
        )

    os.makedirs(args.output_dir, exist_ok=True)
    start_time = time.time()

    # Similarity threshold for keeping a molecule; adjust as needed.
    min_similarity = 0.1
    # The query fingerprint is loop-invariant, so compute it once.
    fps_query_ecfp = AllChem.GetMorganFingerprintAsBitVect(query_mol, 2, nBits=1024)

    # Canonical SMILES -> ECFP4 similarity; insertion order is generation order.
    sampled_dict = {}

    while len(sampled_dict) < args.sample:
        try:
            raw_smi = model.reconstruct_scaled(
                query_smi, scale=_random.random(), prob_decode=False, debug=False
            )
        except IndexError as e:
            logging.error(f"IndexError: {e}")
            continue
        except Exception as e:
            logging.error(f"Unexpected error: {e}, Traceback: {traceback.format_exc()}")
            continue

        if not raw_smi:
            continue

        # Parse first and check for None *before* canonicalising: passing None
        # to MolToSmiles raises instead of returning a falsy value.
        raw_mol = Chem.MolFromSmiles(raw_smi)
        if raw_mol is None:
            continue
        gen_smi = Chem.MolToSmiles(raw_mol)

        # Already collected: skip before doing any fingerprint work.
        if gen_smi in sampled_dict:
            continue

        # Re-parse the canonical form, as the original did, so the fingerprint
        # is computed on the same molecule object as before.
        gen_mol = Chem.MolFromSmiles(gen_smi)
        if gen_mol is None:
            continue

        # Compute ECFP4 similarity to the query.
        fps_gen_ecfp = AllChem.GetMorganFingerprintAsBitVect(gen_mol, 2, nBits=1024)
        similarity_ecfp = DataStructs.FingerprintSimilarity(fps_query_ecfp, fps_gen_ecfp)

        # Keep only SMILES whose similarity reaches the threshold.
        if similarity_ecfp >= min_similarity:
            sampled_dict[gen_smi] = similarity_ecfp

            # Progress output for monitoring generation; can be removed later.
            print(f'succeed >> {gen_smi}, {len(sampled_dict)}, {similarity_ecfp:.3f}')

    execution_time = time.time() - start_time

    write_results(sampled_dict, args, execution_time)

# Entry point: run the generation pipeline when this code is executed directly.
# NOTE(review): inside a notebook kernel `__name__` is also '__main__', so
# executing this cell calls main() there as well.
if __name__ == '__main__':
    main()


# Example invocation: rd_python JTNN_generator.py -s 10 -i IDEA0138008