<a href="https://colab.research.google.com/github/fourmodern/toc_tutorial_colab/blob/main/teachopencadd/t060_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install rdkit-pypi

Collecting rdkit-pypi
  Downloading rdkit_pypi-2022.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.9 kB)
Downloading rdkit_pypi-2022.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m30.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rdkit-pypi
Successfully installed rdkit-pypi-2022.9.5


In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from rdkit import Chem
from rdkit.Chem import BondType
from sklearn.model_selection import train_test_split

In [3]:
# 하이퍼파라미터 설정
MAX_MOLSIZE = 120  # 최대 원자 수
ATOM_DIM = 11  # 원자 종류 수
BOND_DIM = 5  # 결합 종류 수 (단일, 이중, 삼중, 방향족, 없음)
LATENT_DIM = 435  # 잠재 공간 크기
BATCH_SIZE = 100
EPOCHS = 10
LEARNING_RATE = 5e-4

# SMILES 문자 집합 정의
SMILE_CHARSET = ["C", "B", "F", "I", "H", "O", "N", "S", "P", "Cl", "Br"]
SMILE_TO_INDEX = {char: idx for idx, char in enumerate(SMILE_CHARSET)}
INDEX_TO_SMILE = {idx: char for idx, char in enumerate(SMILE_CHARSET)}

# 결합 종류 매핑
BOND_MAPPING = {
    "SINGLE": 0,
    "DOUBLE": 1,
    "TRIPLE": 2,
    "AROMATIC": 3,
    0: BondType.SINGLE,
    1: BondType.DOUBLE,
    2: BondType.TRIPLE,
    3: BondType.AROMATIC,
}

In [4]:
# SMILES를 그래프로 변환하는 함수
def smiles_to_graph(smiles):
    molecule = Chem.MolFromSmiles(smiles)
    adjacency = np.zeros((BOND_DIM, MAX_MOLSIZE, MAX_MOLSIZE), dtype=np.float32)
    features = np.zeros((MAX_MOLSIZE, ATOM_DIM), dtype=np.float32)

    for atom in molecule.GetAtoms():
        i = atom.GetIdx()
        atom_type = SMILE_TO_INDEX.get(atom.GetSymbol(), ATOM_DIM - 1)
        features[i, atom_type] = 1.0

        for neighbor in atom.GetNeighbors():
            j = neighbor.GetIdx()
            bond = molecule.GetBondBetweenAtoms(i, j)
            bond_type_idx = BOND_MAPPING[bond.GetBondType().name]
            adjacency[bond_type_idx, i, j] = 1
            adjacency[bond_type_idx, j, i] = 1

    # 결합이 없는 부분에 대해 마지막 채널을 1로 설정
    adjacency[-1, np.sum(adjacency, axis=0) == 0] = 1
    # 원자가 없는 부분에 대해 마지막 열을 1로 설정
    features[np.sum(features, axis=1) == 0, -1] = 1

    return adjacency, features

In [5]:
# 데이터셋 클래스 정의
class MoleculeDataset(Dataset):
    def __init__(self, smiles_list):
        self.smiles_list = smiles_list

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        adjacency, features = smiles_to_graph(smiles)
        return torch.tensor(adjacency), torch.tensor(features)

In [6]:
# VAE 모델 정의
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(ATOM_DIM * MAX_MOLSIZE + BOND_DIM * MAX_MOLSIZE * MAX_MOLSIZE, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, LATENT_DIM * 2),  # mean과 logvar를 위한 2배 크기
        )
        self.decoder = nn.Sequential(
            nn.Linear(LATENT_DIM, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, ATOM_DIM * MAX_MOLSIZE + BOND_DIM * MAX_MOLSIZE * MAX_MOLSIZE),
            nn.Sigmoid(),
        )

    def encode(self, x):
        h = self.encoder(x)
        mean, logvar = torch.chunk(h, 2, dim=-1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        recon_x = self.decode(z)
        return recon_x, mean, logvar

In [7]:
# 손실 함수 정의
def loss_function(recon_x, x, mean, logvar):
    BCE = nn.BCELoss(reduction='sum')(recon_x, x)
    KLD = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
    return BCE + KLD

In [8]:
# 데이터 로드 및 전처리
import os
import requests

# 데이터 다운로드
url = "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv"
filename = "250k_rndm_zinc_drugs_clean_3.csv"

if not os.path.exists(filename):
    response = requests.get(url)
    with open(filename, 'wb') as file:
        file.write(response.content)

# 데이터 로드
df = pd.read_csv(filename)
smiles_list = df['smiles'].values
train_smiles, test_smiles = train_test_split(smiles_list, test_size=0.1, random_state=42)

train_dataset = MoleculeDataset(train_smiles)
test_dataset = MoleculeDataset(test_smiles)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [9]:
# 모델, 옵티마이저 및 장치 설정
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# 모델 학습
model.train()
for epoch in range(EPOCHS):
    total_loss = 0
    for adjacency, features in train_loader:
        adjacency = adjacency.to(device)
        features = features.to(device)
        optimizer.zero_grad()
        x = torch.cat([features.view(features.size(0), -1), adjacency.view(adjacency.size(0), -1)], dim=1)
        recon_x, mean, logvar = model(x)
        loss = loss_function(recon_x, x, mean, logvar)
        loss.backward()
        total_loss += loss.item()
        optimizer.step()
    print(f'Epoch {epoch + 1}, Loss: {total_loss / len(train_loader)}')

Epoch 1, Loss: 35950.35744448464
Epoch 2, Loss: 25028.74730630061
Epoch 3, Loss: 23371.717327984126
Epoch 4, Loss: 21603.0854497079
Epoch 5, Loss: 20250.13860855714
Epoch 6, Loss: 18777.277863119503
Epoch 7, Loss: 17344.578938131766
Epoch 8, Loss: 16499.90376636373
Epoch 9, Loss: 15569.668880961988


In [None]:
# 모델 평가
model.eval()
with torch.no_grad():
    total_loss = 0
    for adjacency, features in test_loader:
        adjacency = adjacency.to(device)
        features = features.to(device)
        x = torch.cat([features.view(features.size(0), -1), adjacency.view(adjacency.size(0), -1)], dim=1)
        recon_x, mean, logvar = model(x)
        loss = loss_function(recon_x, x, mean, logvar)
        total_loss += loss.item()
    print(f'Test Loss: {total_loss / len(test_loader)}')

In [None]:
# 새로운 분자 생성
def generate_molecule():
    model.eval()
    with torch.no_grad():
        z = torch.randn(1, LATENT_DIM).to(device)
        generated = model.decode(z)
        generated = generated.cpu().numpy()
        features = generated[:, :ATOM_DIM * MAX_MOLSIZE].reshape(-1, MAX_MOLSIZE, ATOM_DIM)
        adjacency = generated[:, ATOM_DIM * MAX_MOLSIZE:].reshape(-1, BOND_DIM, MAX_MOLSIZE, MAX_MOLSIZE)
        return adjacency, features

In [None]:

# 생성된 분자를 SMILES로 변환하는 함수
def graph_to_smiles(adjacency, features):
    molecule = Chem.RWMol()
    atom_indices = []
    for i in range(MAX_MOLSIZE):
        atom_feature = features[0, i]
        if atom_feature[-1] == 1:  # '없음' 원자
            continue
        atom_type_idx = np.argmax(atom_feature)
        atom_type = INDEX_TO_SMILE.get(atom_type_idx, 'C')
        atom = Chem.Atom(atom_type)
        idx = molecule.AddAtom(atom)
        atom_indices.append(idx)

    for i, idx_i in enumerate(atom_indices):
        for j, idx_j in enumerate(atom_indices):
            if i >= j:
                continue
            bond_type_idx = np.argmax(adjacency[0, :, i, j])
            if bond_type_idx == BOND_DIM - 1:  # '없음' 결합
                continue
            bond_type = BOND_MAPPING.get(bond_type_idx, BondType.SINGLE)
            molecule.AddBond(idx_i, idx_j, bond_type)

    smiles = Chem.MolToSmiles(molecule)
    return smiles

In [None]:
# 새로운 분자 생성 및 SMILES 출력
adjacency, features = generate_molecule()
smiles = graph_to_smiles(adjacency, features)
print(f'Generated SMILES: {smiles}')