In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import selfies as sf
import os
from rdkit import Chem
from rdkit.Chem import MolFromSmiles
from rdkit.Chem import MolToSmiles
from rdkit.Chem import DataStructs
from rdkit.Chem import AllChem

In [2]:
# Define the CHEMENV class (same as in your training script)
class CHEMENV:
    """SELFIES generation environment with a Tanimoto-similarity reward.

    Actions are indices into ``action_space``: 0 is the start token ``'<'``,
    1 is the end token ``'>'`` and the remaining entries are SELFIES tokens.
    The state is the list of action indices emitted so far, starting with [0].
    The reward for a state is the Tanimoto similarity between the RDKit
    fingerprints of the decoded molecule and of ``reference_smiles``,
    multiplied by ``reward_scale``.

    NOTE: the order of ``action_space`` must match the training script, because
    a trained policy outputs indices into it.
    """

    # Reference molecule for the reward (was hard-coded inside calc_reward).
    DEFAULT_REFERENCE_SMILES = 'COCCOC1=C(C=C2C(=C1)C(=NC=N2)NC3=CC=CC(=C3)C#C)OCCOC'

    def __init__(self, max_actions=150, reference_smiles=DEFAULT_REFERENCE_SMILES, reward_scale=3):
        """
        Args:
            max_actions: episode is terminated after this many steps.
            reference_smiles: SMILES of the molecule the reward compares against.
            reward_scale: multiplier applied to the Tanimoto similarity.
        """
        self.tokens = ['[#Branch1]', '[#Branch2]', '[=Branch1]', '[=Branch2]', '[=Ring1]', '[=Ring2]', '[Branch1]', '[Branch2]', '[Ring1]', '[Ring2]']
        self.tokens += ['[P]', '[O]', '[N]', '[F]', '[Cl]', '[C]', '[Br]', '[#C]', '[S]']
        self.tokens += ['[=C]', '[=N]', '[=O]', '[=P]', '[=S]', '[#N]']  # disabled: '[=S+1]', '[=PH1]'
        self.action_space = ['<', '>'] + self.tokens
        self.atoi = {a: i for i, a in enumerate(self.action_space)}
        self.current_state = [0]
        self.max_actions = max_actions
        self.action_count = 0
        self.reference_smiles = reference_smiles
        self.reward_scale = reward_scale
        # Built lazily on the first reward call, then reused for every step.
        self._fp_generator = None
        self._ref_fp = None

    def get_num_action(self):
        """Return the number of available actions (tokens incl. start/end)."""
        return len(self.action_space)

    def reset(self):
        """Start a new episode; the state is just the start token."""
        self.current_state = [0]
        self.action_count = 0
        return self.current_state

    def to_index_list(self, s):
        """Map a sequence of tokens to their action indices."""
        return [self.atoi[a] for a in s]

    def to_selfies(self, s):
        """Join action indices into a SELFIES string, dropping start/end tokens."""
        selfi = "".join([self.action_space[i] for i in s])
        selfi = selfi.replace(self.action_space[1], "").replace(self.action_space[0], "")
        return selfi

    def _reference_fingerprint(self):
        """Return ``(fingerprint_generator, reference_fingerprint)``, building them once.

        Raises:
            ValueError: if ``reference_smiles`` cannot be parsed by RDKit.
        """
        if self._ref_fp is None:
            ref_mol = Chem.MolFromSmiles(self.reference_smiles)
            if ref_mol is None:
                raise ValueError(f"Invalid reference SMILES: {self.reference_smiles!r}")
            self._fp_generator = AllChem.GetRDKitFPGenerator()
            self._ref_fp = self._fp_generator.GetFingerprint(ref_mol)
        return self._fp_generator, self._ref_fp

    def calc_reward(self, selfies1):
        """Return the scaled Tanimoto similarity of ``selfies1`` to the reference.

        Returns 0 when the SELFIES string cannot be decoded, when the decoded
        SMILES does not parse, or when fingerprinting fails.
        """
        try:
            smiles1 = sf.decoder(selfies1)
        except sf.DecoderError:
            # Was `return False`; 0 keeps the reward numeric like the other failure paths.
            return 0
        # Outside the best-effort try below: a bad reference must fail loudly, not print per step.
        fpgen, ref_fp = self._reference_fingerprint()
        try:
            mol1 = MolFromSmiles(smiles1)
            if mol1 is None:
                return 0
            fp = fpgen.GetFingerprint(mol1)
            return DataStructs.TanimotoSimilarity(fp, ref_fp) * self.reward_scale
        except Exception as e:
            # Best effort: a reward failure should not abort the episode.
            print(f"Error in calc_reward: {e}")  # Debug print
            return 0

    def step(self, a):
        """Append action ``a`` and return ``(state, reward, terminated, truncated, info)``.

        The episode terminates when the end token (index 1) is chosen and the
        resulting state (start and end token included) has more than 5 entries,
        or when ``max_actions`` steps have been taken. ``truncated`` is always
        False and ``info`` is always None.
        """
        self.current_state = self.current_state + [a]
        self.action_count += 1
        selfies_string = self.to_selfies(self.current_state)
        r = self.calc_reward(selfies_string)
        # Ignore an early end token so that a minimum number of tokens is generated.
        terminated = (a == 1) and (len(self.current_state) > 5)
        terminated = terminated or self.action_count >= self.max_actions
        return self.current_state, r, terminated, False, None

In [3]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
# Actor model definition (same as in your training script)
class Actor(nn.Module):
    """Policy network: embeds the token sequence, runs a stacked GRU and returns
    a softmax distribution over the next action.

    Args:
        num_action: vocabulary size (embedding rows and number of output probabilities).
        embedding_dim: token embedding size.
        hidden_size: GRU hidden size.
        num_layers: number of stacked GRU layers (default 4, as originally hard-coded).
        dropout: dropout between GRU layers (default 0.2, as originally hard-coded).

    The submodule names ``emb``, ``rnn`` and ``fc`` define the ``state_dict`` keys;
    renaming them would break loading of existing checkpoints.
    """

    def __init__(self, num_action, embedding_dim, hidden_size, num_layers=4, dropout=0.2):
        super().__init__()
        self.emb = nn.Embedding(num_action, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size, num_action)

    def forward(self, state):
        """Return next-action probabilities.

        Args:
            state: integer tensor of token indices, either ``(seq_len,)`` for a
                single sequence or ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, num_action)``; batch is 1 for 1-D input.
        """
        if state.dim() == 1:
            state = state.unsqueeze(0)  # add the batch dimension
        x = self.emb(state)             # (batch, seq_len, embedding_dim)
        o, _ = self.rnn(x)              # o: (batch, seq_len, hidden_size)
        logits = self.fc(o[:, -1, :])   # only the last time step predicts the next action
        return torch.softmax(logits, dim=-1)

# Critic model definition (same as in your training script)
class Critic(nn.Module):
    """Value network: embeds the token sequence, runs a stacked GRU and returns
    one scalar per sequence.

    Args:
        num_action: vocabulary size (number of embedding rows).
        embedding_dim: token embedding size.
        hidden_size: GRU hidden size.
        num_layers: number of stacked GRU layers (default 4, as originally hard-coded).
        dropout: dropout between GRU layers (default 0.2, as originally hard-coded).

    The submodule names ``emb``, ``rnn`` and ``fc`` define the ``state_dict`` keys;
    renaming them would break loading of existing checkpoints.
    """

    def __init__(self, num_action, embedding_dim, hidden_size, num_layers=4, dropout=0.2):
        super().__init__()
        self.emb = nn.Embedding(num_action, embedding_dim)
        self.rnn = nn.GRU(embedding_dim, hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, state):
        """Return the value estimate.

        Args:
            state: integer tensor of token indices, either ``(seq_len,)`` for a
                single sequence or ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, 1)``; batch is 1 for 1-D input.
        """
        if state.dim() == 1:
            state = state.unsqueeze(0)  # add the batch dimension
        x = self.emb(state)             # (batch, seq_len, embedding_dim)
        o, _ = self.rnn(x)              # o: (batch, seq_len, hidden_size)
        return self.fc(o[:, -1, :])     # value from the last time step

In [5]:
# Sampling function
def sample_from_model(actor, env, num_samples=10, device=None):
    """Generate molecules by rolling out ``actor`` in ``env``.

    Each rollout starts from ``env.reset()`` and draws one action per step from
    the actor's output distribution until the environment reports termination
    or truncation. The final token sequence is converted with
    ``env.to_selfies`` and decoded to SMILES.

    Args:
        actor: module mapping a 1-D tensor of token indices to action probabilities.
        env: object providing ``reset``, ``step`` and ``to_selfies`` (e.g. CHEMENV).
        num_samples: number of rollouts.
        device: device for the input tensors; defaults to the device of the
            actor's first parameter (CPU if the actor has no parameters).

    Returns:
        List of length ``num_samples``; each entry is a SMILES string, or None
        when the SELFIES string could not be decoded.

    Side effect: the actor is left in eval mode.
    """
    if device is None:
        # Follow the actor instead of relying on a notebook-global `device`.
        first_param = next(actor.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    actor.eval()
    samples = []
    # Pure inference: skip autograd bookkeeping on every forward pass.
    with torch.no_grad():
        for _ in range(num_samples):
            state = env.reset()
            done = False
            while not done:
                state_tensor = torch.tensor(state, dtype=torch.long, device=device)
                probs = actor(state_tensor)
                action = torch.distributions.Categorical(probs).sample()
                # The per-step reward returned by env.step is not needed for sampling.
                state, _reward, terminated, truncated, _info = env.step(action.item())
                done = terminated or truncated

            selfies_string = env.to_selfies(state)
            try:
                samples.append(sf.decoder(selfies_string))
            except sf.DecoderError:
                samples.append(None)

    return samples

In [6]:
# Load the trained actor weights.
# The model hyperparameters below must match the ones used in training; with the
# default strict loading, a mismatch makes load_state_dict raise (missing/unexpected
# keys or shape errors).
checkpoint_path = 'checkpoints/checkpoint_3000.pth'
MAX_ACTIONS = 150
EMBEDDING_DIM = 16
HIDDEN_SIZE = 32

# NOTE(review): torch.load is pickle-based; only load checkpoints from a trusted source.
checkpoint = torch.load(checkpoint_path, map_location=device)

# Initialize environment and actor model
env = CHEMENV(max_actions=MAX_ACTIONS)
actor = Actor(env.get_num_action(), EMBEDDING_DIM, HIDDEN_SIZE).to(device)
actor.load_state_dict(checkpoint['actor_state_dict'])

<All keys matched successfully>

In [7]:
# Reference molecule for the Tanimoto similarity evaluation.
# It should be the same molecule that CHEMENV.calc_reward compares against;
# otherwise these scores are not comparable to the training reward.
reference_smiles = 'COCCOC1=C(C=C2C(=C1)C(=NC=N2)NC3=CC=CC(=C3)C#C)OCCOC'
reference_mol = Chem.MolFromSmiles(reference_smiles)
# MolFromSmiles returns None (rather than raising) on malformed SMILES; fail here,
# not later inside the fingerprint generator with a less helpful error.
if reference_mol is None:
    raise ValueError(f"Could not parse reference SMILES: {reference_smiles}")
fpgen = AllChem.GetRDKitFPGenerator()
ref_fp = fpgen.GetFingerprint(reference_mol)

In [8]:
num_samples = 500  # Number of samples to generate
# Actions are drawn from a Categorical distribution; seed torch (CPU and CUDA RNGs)
# so re-running the notebook on the same setup yields the same samples.
SAMPLING_SEED = 0
torch.manual_seed(SAMPLING_SEED)
samples = sample_from_model(actor, env, num_samples=num_samples)

In [9]:
# Score every sample against the reference. A score of None marks a sample that
# did not yield a usable molecule (decoding error, unparsable SMILES, or no atoms).
tanimoto_scores = []

for i, smi in enumerate(samples):
    if smi is None:
        tanimoto_scores.append(None)
        print(f"Sample {i + 1}: {smi}, Tanimoto score: None (Decoding error)")
        continue
    sample_mol = Chem.MolFromSmiles(smi)
    if sample_mol is None:
        tanimoto_scores.append(None)
        print(f"Sample {i + 1}: {smi}, Tanimoto score: None (Invalid molecule)")
        continue
    if sample_mol.GetNumAtoms() == 0:
        # An empty SMILES parses into a zero-atom molecule; counting it as valid
        # would inflate the validity rate and drag down the min/average scores.
        tanimoto_scores.append(None)
        print(f"Sample {i + 1}: {smi!r}, Tanimoto score: None (Empty molecule)")
        continue
    sample_fp = fpgen.GetFingerprint(sample_mol)
    tanimoto = DataStructs.TanimotoSimilarity(ref_fp, sample_fp)
    tanimoto_scores.append(tanimoto)
    print(f"Sample {i + 1}: {smi}, Tanimoto score: {tanimoto}")

# Calculate max, min, and average Tanimoto scores over the valid samples only
valid_scores = [score for score in tanimoto_scores if score is not None]
max_tanimoto = max(valid_scores) if valid_scores else None
min_tanimoto = min(valid_scores) if valid_scores else None
avg_tanimoto = sum(valid_scores) / len(valid_scores) if valid_scores else None

print()
print(f"Valid molecules: {len(valid_scores)}/{len(samples)}")
print(f"Max Tanimoto score: {max_tanimoto}")
print(f"Min Tanimoto score: {min_tanimoto}")
print(f"Average Tanimoto score: {avg_tanimoto}")

# Rank samples by score, best first, unscored samples last. The sort is stable
# (also with reverse=True), so ties keep generation order. The original sample
# number is kept so each ranked entry can be matched to the listing above.
indexed_samples = [
    (sample_no, smi, score)
    for sample_no, (smi, score) in enumerate(zip(samples, tanimoto_scores), start=1)
]
ranked_samples = sorted(
    indexed_samples,
    key=lambda item: (item[2] is not None, item[2] if item[2] is not None else 0.0),
    reverse=True,
)
sorted_samples = [(smi, score) for _, smi, score in ranked_samples]

print("\nSamples sorted by Tanimoto score (in descending order):")
for rank, (sample_no, smi, score) in enumerate(ranked_samples, start=1):
    print(f"Rank {rank} (sample {sample_no}): {smi}, Tanimoto score: {score}")

Sample 1: C=S=S=S=S=S=C=C=C=C=C=SOC=C=C=C=C=S=C(NS=C=C=C=C=COCOC=C=C=C)C=C=S=C=CP=C=C=C=C=C=C=S(=C=S=COS=S=CNNC)(CP=C=C=C=C=COC=C=C=C=C)N=C=S=S=C=S=C=S=CP=SP(=C=CC=CC=S=CN=C=COC=C=CP=C=C)=C=C=C=C=S=CC=CP=C=C=C=C=C=C=C=C=C=C(C=C=S=C=C)OC=SN=C=C=C=C=C=C=C=C, Tanimoto score: 0.3074391046741277
Sample 2: C=CC=C=C=C=C=C=SC=CP=C=C=CP=COC=C=C(C=COO)C=S=CC=C=C=C=C=C=CNOC=C=C=C=C=C=C=C=COC=C=C=C=C=S=COC=C=CPC=C=C=C=C=CCC=C=CP=C=CC=S=S=C=S=CC=C=C=C=C=CC=CP=C=C=C=C=COS=CP=C=C=C=C=C=CC=S=C=C=C=C=C=C=S=C=CP=C=C=COC=C=C=C=C=C=C=C=SC=CPOC=C=CC=COC=C=C=C, Tanimoto score: 0.23322932917316694
Sample 3: C=C=C=CP=CC=C=C(C=C=C=C=CC=COCC=C=C=C)C=C=CP=C=C=COC=SC=C=C=CNOC=C=C(S)C=C=C=C=CC=S=S(=S=C=CC=C=CC=C=C=CCC=C=CN=C)ON=C(C=C=C=CP=C=C=C=CP=C=C=C=C)S=C=SN=C=C=C=CC=C=C=C=C=COON=C=C=C=C=CP=C=C=C=C=CNOC=C=C=C=CP=C=C(O)C=C=C=C=S=C=C=C=C=C=S, Tanimoto score: 0.27245508982035926
Sample 4: CP=SP=C=C=S=S=COCP=CC=C=C=C=C=SP=COC=C=C=C(C=C=C=COC=C=C=C=C=C=S=C)CC=C=C=C=CCC=C=C=C=CPPCC=C=C=S=C=C=COCPC=CN=SP=C=S=SP=C=S=C