<a href="https://colab.research.google.com/github/fourmodern/toc_tutorial_colab/blob/main/teachopencadd/t043_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install rdkit

Collecting rdkit
  Downloading rdkit-2024.3.6-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (4.0 kB)
Downloading rdkit-2024.3.6-cp310-cp310-manylinux_2_28_x86_64.whl (32.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m32.8/32.8 MB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rdkit
Successfully installed rdkit-2024.3.6


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import pandas as pd
import numpy as np
from rdkit import Chem, RDLogger
from rdkit.Chem import BondType
import ast
import requests
import os
from pathlib import Path

RDLogger.DisableLog("rdApp.*")

# 상수 정의
SMILE_CHARSET = '["C", "B", "F", "I", "H", "O", "N", "S", "P", "Cl", "Br"]'
SMILE_CHARSET = ast.literal_eval(SMILE_CHARSET)

bond_mapping = {"SINGLE": 0, "DOUBLE": 1, "TRIPLE": 2, "AROMATIC": 3}
bond_mapping.update(
    {0: BondType.SINGLE, 1: BondType.DOUBLE, 2: BondType.TRIPLE, 3: BondType.AROMATIC}
)

SMILE_to_index = dict((c, i) for i, c in enumerate(SMILE_CHARSET))
index_to_SMILE = dict((i, c) for i, c in enumerate(SMILE_CHARSET))
atom_mapping = dict(SMILE_to_index)
atom_mapping.update(index_to_SMILE)

BATCH_SIZE = 100
EPOCHS = 10
NUM_ATOMS = 120
ATOM_DIM = len(SMILE_CHARSET)
BOND_DIM = 4 + 1
LATENT_DIM = 435

def download_file(url, filename):
    """파일 다운로드 함수"""
    response = requests.get(url)
    if response.status_code == 200:
        with open(filename, 'wb') as f:
            f.write(response.content)
        return True
    return False

# 데이터 파일 경로 설정
data_dir = Path('data')
data_dir.mkdir(exist_ok=True)
csv_path = data_dir / '250k_rndm_zinc_drugs_clean_3.csv'

# 데이터 다운로드 (파일이 없는 경우에만)
if not csv_path.exists():
    url = "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv"
    print("데이터 다운로드 중...")
    success = download_file(url, csv_path)
    if success:
        print(f"데이터가 성공적으로 다운로드되었습니다: {csv_path}")
    else:
        raise Exception("데이터 다운로드 실패")

데이터 다운로드 중...
데이터가 성공적으로 다운로드되었습니다: data/250k_rndm_zinc_drugs_clean_3.csv

데이터 샘플:
                                              smiles     logP       qed  \
0            CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1  5.05060  0.702012   
1       C[C@@H]1CC(Nc2cncc(-c3nncn3C)c2)C[C@@H](C)C1  3.11370  0.928975   
2  N#Cc1ccc(-c2ccc(O[C@@H](C(=O)N3CCCC3)c3ccccc3)...  4.96778  0.599682   
3  CCOC(=O)[C@@H]1CCCN(C(=O)c2nc(-c3ccc(C)cc3)n3c...  4.00022  0.690944   
4  N#CC1=C(SCC(=O)Nc2cccc(Cl)c2)N=C([O-])[C@H](C#...  3.60956  0.789027   

        SAS  
0  2.084095  
1  3.432004  
2  2.470633  
3  2.822753  
4  4.035182  


In [3]:
def smiles_to_graph(smiles):
    # Converts SMILES to molecule object
    molecule = Chem.MolFromSmiles(smiles)

    # Initialize adjacency and feature tensor
    adjacency = np.zeros((BOND_DIM, NUM_ATOMS, NUM_ATOMS), "float32")
    features = np.zeros((NUM_ATOMS, ATOM_DIM), "float32")

    # loop over each atom in molecule
    for atom in molecule.GetAtoms():
        i = atom.GetIdx()
        atom_type = atom_mapping[atom.GetSymbol()]
        features[i] = np.eye(ATOM_DIM)[atom_type]
        # loop over one-hop neighbors
        for neighbor in atom.GetNeighbors():
            j = neighbor.GetIdx()
            bond = molecule.GetBondBetweenAtoms(i, j)
            bond_type_idx = bond_mapping[bond.GetBondType().name]
            adjacency[bond_type_idx, [i, j], [j, i]] = 1

    # Where no bond, add 1 to last channel (indicating "non-bond")
    # Notice: channels-first
    adjacency[-1, np.sum(adjacency, axis=0) == 0] = 1

    # Where no atom, add 1 to last column (indicating "non-atom")
    features[np.where(np.sum(features, axis=1) == 0)[0], -1] = 1

    return adjacency, features


def graph_to_molecule(graph):
    # Unpack graph
    adjacency, features = graph

    # RWMol is a molecule object intended to be edited
    molecule = Chem.RWMol()

    # Remove "no atoms" & atoms with no bonds
    keep_idx = np.where(
        (np.argmax(features, axis=1) != ATOM_DIM - 1)
        & (np.sum(adjacency[:-1], axis=(0, 1)) != 0)
    )[0]
    features = features[keep_idx]
    adjacency = adjacency[:, keep_idx, :][:, :, keep_idx]

    # Add atoms to molecule
    for atom_type_idx in np.argmax(features, axis=1):
        atom = Chem.Atom(atom_mapping[atom_type_idx])
        _ = molecule.AddAtom(atom)

    # Add bonds between atoms in molecule; based on the upper triangles
    # of the [symmetric] adjacency tensor
    (bonds_ij, atoms_i, atoms_j) = np.where(np.triu(adjacency) == 1)
    for (bond_ij, atom_i, atom_j) in zip(bonds_ij, atoms_i, atoms_j):
        if atom_i == atom_j or bond_ij == BOND_DIM - 1:
            continue
        bond_type = bond_mapping[bond_ij]
        molecule.AddBond(int(atom_i), int(atom_j), bond_type)

    # Sanitize the molecule; for more information on sanitization, see
    # https://www.rdkit.org/docs/RDKit_Book.html#molecular-sanitization
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    # Let's be strict. If sanitization fails, return None
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        return None

    return molecule

class MoleculeDataset(Dataset):
    def __init__(self, df, num_samples=8000):
        self.adjacency_tensor = []
        self.feature_tensor = []
        self.qed_tensor = []

        # 데이터 수집
        for idx in range(min(num_samples, len(df))):
            adjacency, features = smiles_to_graph(df.loc[idx]["smiles"])
            qed = df.loc[idx]["qed"]
            self.adjacency_tensor.append(adjacency)
            self.feature_tensor.append(features)
            self.qed_tensor.append(qed)

        # numpy 배열로 먼저 변환
        self.adjacency_tensor = np.array(self.adjacency_tensor, dtype=np.float32)
        self.feature_tensor = np.array(self.feature_tensor, dtype=np.float32)
        self.qed_tensor = np.array(self.qed_tensor, dtype=np.float32)

        # numpy 배열을 torch 텐서로 변환
        self.adjacency_tensor = torch.from_numpy(self.adjacency_tensor)
        self.feature_tensor = torch.from_numpy(self.feature_tensor)
        self.qed_tensor = torch.from_numpy(self.qed_tensor)

    def __len__(self):
        return len(self.qed_tensor)

    def __getitem__(self, idx):
        return (self.adjacency_tensor[idx],
                self.feature_tensor[idx],
                self.qed_tensor[idx])

In [4]:
class RelationalGraphConvLayer(nn.Module):
    def __init__(self, in_features, out_features, activation=F.relu, use_bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.activation = activation
        self.use_bias = use_bias

        self.weight = nn.Parameter(torch.FloatTensor(BOND_DIM, in_features, out_features))
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(BOND_DIM, 1, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.weight)
        if self.use_bias:
            nn.init.zeros_(self.bias)

    def forward(self, adjacency, features):
        # [batch, bond_dim, num_atoms, num_atoms] x [batch, num_atoms, in_features]
        support = torch.matmul(adjacency, features.unsqueeze(1))
        # Apply weights
        output = torch.matmul(support, self.weight)
        if self.use_bias:
            output = output + self.bias
        # Sum over bond types
        output = torch.sum(output, dim=1)
        return self.activation(output)

class Sampling(nn.Module):
    def forward(self, z_mean, z_log_var):
        batch_size = z_mean.size(0)
        latent_dim = z_mean.size(1)
        epsilon = torch.randn(batch_size, latent_dim).to(z_mean.device)
        return z_mean + torch.exp(0.5 * z_log_var) * epsilon

In [5]:
class Encoder(nn.Module):
    def __init__(self, gconv_units, dense_units, dropout_rate=0.0):
        super().__init__()

        # Graph Convolution Layers
        self.gconv_layers = nn.ModuleList()
        in_features = ATOM_DIM
        for units in gconv_units:
            self.gconv_layers.append(
                RelationalGraphConvLayer(in_features, units)
            )
            in_features = units

        # Dense Layers
        self.dense_layers = nn.ModuleList()
        in_features = gconv_units[-1]
        for units in dense_units:
            self.dense_layers.append(nn.Sequential(
                nn.Linear(in_features, units),
                nn.ReLU(),
                nn.Dropout(dropout_rate)
            ))
            in_features = units

        self.z_mean = nn.Linear(dense_units[-1], LATENT_DIM)
        self.z_log_var = nn.Linear(dense_units[-1], LATENT_DIM)

    def forward(self, adjacency, features):
        x = features
        for gconv in self.gconv_layers:
            x = gconv(adjacency, x)

        # Global pooling
        x = torch.mean(x, dim=1)

        for dense in self.dense_layers:
            x = dense(x)

        return self.z_mean(x), self.z_log_var(x)

class Decoder(nn.Module):
    def __init__(self, dense_units, dropout_rate=0.2):
        super().__init__()

        # Dense layers
        self.dense_layers = nn.ModuleList()
        in_features = LATENT_DIM
        for units in dense_units:
            self.dense_layers.append(nn.Sequential(
                nn.Linear(in_features, units),
                nn.Tanh(),
                nn.Dropout(dropout_rate)
            ))
            in_features = units

        # Output layers
        self.adjacency_layer = nn.Sequential(
            nn.Linear(dense_units[-1], BOND_DIM * NUM_ATOMS * NUM_ATOMS),
            nn.Softmax(dim=1)
        )

        self.features_layer = nn.Sequential(
            nn.Linear(dense_units[-1], NUM_ATOMS * ATOM_DIM),
            nn.Softmax(dim=1)
        )

    def forward(self, z):
        x = z
        for dense in self.dense_layers:
            x = dense(x)

        # Reshape outputs
        adjacency = self.adjacency_layer(x)
        adjacency = adjacency.view(-1, BOND_DIM, NUM_ATOMS, NUM_ATOMS)
        # Ensure symmetry
        adjacency = (adjacency + adjacency.transpose(2, 3)) / 2

        features = self.features_layer(x)
        features = features.view(-1, NUM_ATOMS, ATOM_DIM)

        return adjacency, features

In [6]:
class MoleculeVAE(nn.Module):
    def __init__(self, gconv_units=[9], dense_units=[512]):
        super().__init__()
        self.encoder = Encoder(gconv_units, dense_units)
        self.sampling = Sampling()
        self.decoder = Decoder([128, 256, 512])
        self.property_predictor = nn.Linear(LATENT_DIM, 1)

    def forward(self, adjacency, features):
        z_mean, z_log_var = self.encoder(adjacency, features)
        z = self.sampling(z_mean, z_log_var)
        gen_adjacency, gen_features = self.decoder(z)
        property_pred = self.property_predictor(z_mean)

        return z_mean, z_log_var, property_pred, gen_adjacency, gen_features

    def compute_loss(self, z_mean, z_log_var, qed_true, qed_pred,
                    adjacency_real, features_real,
                    adjacency_gen, features_gen):
        # Reconstruction loss
        adjacency_loss = F.cross_entropy(
            adjacency_gen.view(-1, BOND_DIM),
            adjacency_real.view(-1, BOND_DIM)
        )
        features_loss = F.cross_entropy(
            features_gen.view(-1, ATOM_DIM),
            features_real.view(-1, ATOM_DIM)
        )

        # KL divergence
        kl_loss = -0.5 * torch.sum(
            1 + z_log_var - z_mean.pow(2) - z_log_var.exp(),
            dim=1
        ).mean()

        # Property prediction loss
        property_loss = F.binary_cross_entropy_with_logits(
            qed_pred.squeeze(), qed_true
        )

        return adjacency_loss + features_loss + kl_loss + property_loss

def train_model(model, train_loader, optimizer, device, epochs=10):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch_idx, (adjacency, features, qed) in enumerate(train_loader):
            adjacency = adjacency.to(device)
            features = features.to(device)
            qed = qed.to(device)

            optimizer.zero_grad()

            z_mean, z_log_var, qed_pred, gen_adjacency, gen_features = model(
                adjacency, features
            )

            loss = model.compute_loss(
                z_mean, z_log_var, qed, qed_pred,
                adjacency, features,
                gen_adjacency, gen_features
            )

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f'Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}')

In [None]:
def main():
    # 데이터 로드
    # 데이터 로드 및 전처리
    df = pd.read_csv(csv_path)
    df["smiles"] = df["smiles"].apply(lambda s: s.replace("\n", ""))
    print("\n데이터 샘플:")
    print(df.head())
    train_df = df.sample(frac=0.75, random_state=42)
    train_df.reset_index(drop=True, inplace=True)

    # 데이터셋 생성
    dataset = MoleculeDataset(train_df)
    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    # 디바이스 설정
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # 모델 초기화
    model = MoleculeVAE().to(device)
    optimizer = optim.Adam(model.parameters(), lr=5e-4)

    # 학습
    train_model(model, train_loader, optimizer, device, EPOCHS)

    # 분자 생성
    def generate_molecules(model, n_samples=1000):
        model.eval()
        with torch.no_grad():
            z = torch.randn(n_samples, LATENT_DIM).to(device)
            adjacency, features = model.decoder(z)

            # Convert to numpy for RDKit processing
            adjacency = adjacency.cpu().numpy()
            features = features.cpu().numpy()

            molecules = []
            for i in range(n_samples):
                mol = graph_to_molecule([adjacency[i], features[i]])
                if mol is not None:
                    molecules.append(mol)

            return molecules

if __name__ == "__main__":
    main()


데이터 샘플:
                                              smiles     logP       qed  \
0            CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1  5.05060  0.702012   
1       C[C@@H]1CC(Nc2cncc(-c3nncn3C)c2)C[C@@H](C)C1  3.11370  0.928975   
2  N#Cc1ccc(-c2ccc(O[C@@H](C(=O)N3CCCC3)c3ccccc3)...  4.96778  0.599682   
3  CCOC(=O)[C@@H]1CCCN(C(=O)c2nc(-c3ccc(C)cc3)n3c...  4.00022  0.690944   
4  N#CC1=C(SCC(=O)Nc2cccc(Cl)c2)N=C([O-])[C@H](C#...  3.60956  0.789027   

        SAS  
0  2.084095  
1  3.432004  
2  2.470633  
3  2.822753  
4  4.035182  


  self.adjacency_tensor = torch.FloatTensor(self.adjacency_tensor)


Epoch 1/10, Average Loss: 109.0409
Epoch 2/10, Average Loss: 5.2328
Epoch 3/10, Average Loss: 4.7633
Epoch 4/10, Average Loss: 4.7126
Epoch 5/10, Average Loss: 4.6857
