In [12]:
# DO NOT CHANGE!
!rm -r CH*
!rm -r assignment*
!rm -r practice*
!git clone https://github.com/oneoftwo/KAIST_CH453_AI_chemistry/
!mv ./KAIST_CH453_AI_chemistry/assignments/assignment_4/ ./
!rm -r KAIST_CH453*
!pip install rdkit
import rdkit
print(rdkit.__version__)
!ls

rm: cannot remove 'CH*': No such file or directory
rm: cannot remove 'practice*': No such file or directory
Cloning into 'KAIST_CH453_AI_chemistry'...
remote: Enumerating objects: 1981, done.[K
remote: Counting objects: 100% (1981/1981), done.[K
remote: Compressing objects: 100% (1971/1971), done.[K
remote: Total 1981 (delta 14), reused 1910 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (1981/1981), 5.56 MiB | 11.63 MiB/s, done.
Resolving deltas: 100% (14/14), done.
2024.03.6
assignment_4  README.md  sample_data


In [13]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from rdkit import Chem
from rdkit.Chem.Descriptors import ExactMolWt
from rdkit.Chem.Crippen import MolLogP
from functools import reduce # optional
from tqdm import tqdm # optional
import numpy as np
from rdkit import RDLogger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

# **1. Define Dataset**

In [14]:
class MolDataset(Dataset):

    def __init__(self, smi_list):
        super().__init__()
        self.smi_list = smi_list

        self._set_c_to_i()
        self._set_i_to_c()
        self.vec_dim = self._get_num_char()

    def __len__(self):
        return len(self.smi_list)

    def __getitem__(self, idx):
        '''
        return a dict of {"input": input, "output": output, "length", length},
        where input is a long tensor of seq. encoded smiles,
        output is a float tensor of corresponding logp value, and
        length is a long tensor of "a length of smiles".
        use self._encode_smi and self._get_logp.
        '''
        sample = dict()
        smi = self.smi_list[idx]

        input = self._encode_smi(smi)
        logp = self._get_logp(smi)

        sample = {
                "input": torch.LongTensor(input),
                "length": torch.LongTensor([len(smi) + 1]), ### +1 due to <EOS>
                "logp": torch.Tensor(logp)
        }
        return sample

    def _set_c_to_i(self):
        '''
        Obtain c_to_i dictionary from smi_list.
        We'll use the characters in self.smi_list, and auxiliary character 'X'.
        '''
        whole_char = ['X'] ##### auxiliary token, padding_value = 0
        whole_char += list(reduce(lambda x, y: x | y, \
                [set(smi) for smi in self.smi_list]))
        c_to_i = {c: i for i, c in enumerate(whole_char)}
        self.c_to_i = c_to_i

    def _set_i_to_c(self):
        self.i_to_c = {v:k for k, v in self.c_to_i.items()}

    def _get_c_to_i(self):
        return self.c_to_i

    def _get_i_to_c(self):
        return self.i_to_c

    def _encode_smi(self, smi):
        return np.array([self.c_to_i[c] for c in smi + 'X'])

    def _get_num_char(self):
        return len(getattr(self, "c_to_i", dict()))

    def _get_logp(self, smi):
        '''
        return a numpy array of logP of given smiles.
        '''
        mol = Chem.MolFromSmiles(smi)
        return np.array([MolLogP(mol)])

In [15]:
def random_splitter(dataset, train_ratio, validation_ratio, test_ratio):
    import random
    import copy
    assert train_ratio + validation_ratio + test_ratio == 1.0
    N = len(dataset)
    all_idx = list(range(N))
    random.shuffle(all_idx)

    train_idx = all_idx[:int(train_ratio * N)]
    valid_idx = all_idx[int(train_ratio * N):int(validation_ratio * N) \
                        + int(train_ratio * N)]
    test_idx = all_idx[int(validation_ratio * N) + int(train_ratio * N):]
    train_dataset = copy.deepcopy(dataset)
    valid_dataset = copy.deepcopy(dataset)
    test_dataset  = copy.deepcopy(dataset)
    train_dataset.smi_list = [dataset.smi_list[i] for i in train_idx]
    valid_dataset.smi_list = [dataset.smi_list[i] for i in valid_idx]
    test_dataset.smi_list =  [dataset.smi_list[i] for i in  test_idx]
    return train_dataset, valid_dataset, test_dataset

In [16]:
def sample_collate_fn(samples):
    '''
    Dataloader will make a list of samples with a len(samples) = batch_size.
    Collate function should pad all the tensors in every sample at maximum size,
    and stack them on a batch dimension.

    Example)
    if four tensors of shape (3, 7), (2, 7), (6, 7), (4, 7) is given,
    collated tensor will have a shape of (4, 6, 7) where 4 is a batch size.
    '''
    inputs = pad_sequence([sample["input"] for sample in samples], \
            batch_first=True, padding_value=0)
    lengths = torch.cat([sample["length"] for sample in samples], dim=0)
    logps = torch.cat([sample["logp"] for sample in samples], dim=0)

    sample_batch = {
            "input": inputs,
            "length": lengths,
            "logp": logps
    }
    return sample_batch

In [17]:
sample_smi_list = ["c1ccccc1", "COCC", "CCCCCCCCCCCCN"]

sample_dataset = MolDataset(sample_smi_list)
print(sample_dataset.c_to_i)
print(sample_dataset.i_to_c)
print(sample_dataset[1])
print(sample_dataset._get_num_char())

{'X': 0, 'c': 1, 'C': 2, 'N': 3, 'O': 4, '1': 5}
{0: 'X', 1: 'c', 2: 'C', 3: 'N', 4: 'O', 5: '1'}
{'input': tensor([2, 4, 2, 2, 0]), 'length': tensor([5]), 'logp': tensor([0.6527])}
6


# **2. Define VAE Model**

In [18]:
class VariationalAutoEncoder(nn.Module):

    def __init__(
            self,
            n_char,
            n_hidden,
            n_rnn_layer
            ):
        super().__init__()

        self.n_char = n_char
        self.n_hidden = n_hidden
        self.n_rnn_layer = n_rnn_layer

        self.sos = nn.Parameter(torch.ones(1, 1, n_hidden))
        self.sos.requires_grad = False

        self.embedding = nn.Embedding(n_char, n_hidden)
        self.encoder = nn.GRU(input_size=n_hidden, hidden_size=n_hidden, \
                          num_layers=n_rnn_layer, batch_first=True)
        self.decoder = nn.GRU(input_size=n_hidden, hidden_size=n_hidden, \
                          num_layers=n_rnn_layer, batch_first=True)
        self.mu_layer = nn.Linear(n_hidden, n_hidden)
        self.logvar_layer = nn.Linear(n_hidden, n_hidden)
        self.final_layer = nn.Linear(n_hidden, n_char)

        self.rec_loss_fn = nn.CrossEntropyLoss()
        self.vae_loss_fn = self.vae_loss

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.rand_like(std)
        return mu + eps * std

    def vae_loss(self, mu, logvar):
        return -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())

    def forward(self, x, length):
        '''
        x (torch.Tensor): [B, L]
        '''
        B = x.shape[0] # batch size

        # 1. Embedding x[B L] -> h[B L F]
        h = self.embedding(x) # [B L F]

        # 2. Encoder h[B L F] -> enc[B L F] -> z[B F]
        enc, _ = self.encoder(h) # [B L F]
        z = torch.stack([enc[i, length[i] - 1, :] for i in range(B)], dim=0) # [B F]

        # 3. Get mu and logvar from z and reparameterize
        mu = self.mu_layer(z)
        logvar = self.logvar_layer(z)
        rep_z = self.reparameterize(mu, logvar)

        # 4. Decoder z[B F], dec_in[B 1+L F] -> dec[B 1+L F]
        sos_vec = self.sos.repeat(B, 1, 1) # [B 1 F]
        dec_in = torch.cat([sos_vec, h], dim=1) # [B 1+L F]
        dec, _ = self.decoder(dec_in, rep_z.unsqueeze(0).repeat(self.n_rnn_layer, 1, 1))

        # 5. Predict the probability of character dec[B 1+L F] -> dec_final[B 1+L N_CHAR]
        dec_final = self.final_layer(dec) # [B 1+L N_CHAR]

        # 6. Calculate loss
        rec_losses = torch.stack(
            [self.rec_loss_fn(dec_final[i, :length[i], :], x[i, :length[i]]) \
             / length[i] for i in range(B)], dim=0
        ) # length due to <SOS>
        rec_loss = rec_losses.mean()
        vae_loss = self.vae_loss_fn(mu, logvar)

        loss = rec_loss + vae_loss
        return loss, rec_loss, vae_loss

    def generate(self, max_length, batch_size=1):
        z = torch.randn(self.n_rnn_layer, batch_size, self.n_hidden).to(self.sos.device)
        dec_in = self.sos.repeat(batch_size, 1, 1) # [B 1 F]

        gen_smis = []
        i = 0
        while i < max_length:
            dec_out, z = self.decoder(dec_in, z)
            dec_final = nn.functional.softmax(self.final_layer(dec_out), dim=-1)
            index = torch.argmax(dec_final)

            if index == 0: break # <eos>
            gen_smis.append(int(index))
            dec_in = self.embedding(index.view(1, -1).long())
            i += 1
        if i == max_length: return [] # generation failed
        return gen_smis

In [19]:
dataloader = DataLoader(sample_dataset, batch_size=3, shuffle=False, collate_fn=sample_collate_fn)
dataiter = iter(dataloader)
batch = next(dataiter)
print(batch)

sample_model = VariationalAutoEncoder(6, 64, 1)
loss, rec_loss, vae_loss = sample_model(batch["input"], batch["length"])
print(loss)
print(rec_loss)
print(vae_loss)

{'input': tensor([[1, 5, 1, 1, 1, 1, 1, 5, 0, 0, 0, 0, 0, 0],
        [2, 4, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 0]]), 'length': tensor([ 9,  5, 14]), 'logp': tensor([1.6866, 0.6527, 3.8660])}
tensor(0.2547, grad_fn=<AddBackward0>)
tensor(0.2255, grad_fn=<MeanBackward0>)
tensor(0.0292, grad_fn=<MulBackward0>)


# **3. Hyperparameter Settings**

In [20]:
########## DO NOT CHANGE ##########
NUM_EPOCH = 25
LR = 1e-3
N_HIDDEN = 256
N_RNN_LAYER = 2
BATCH_SIZE = 256
DATA_DIR = "./assignment_4/PubChem_30K.txt"
####################################

# **4. Build the Dataset**

In [21]:
with open(DATA_DIR, 'r') as f:
    smi_list = [l.strip().split()[1] for l in f.readlines()]
    smi_list = [x for x in smi_list if len(x) > 10 and len(x)< 60]

dataset = MolDataset(smi_list)
train_dataset, valid_dataset, test_dataset = \
        random_splitter(dataset, 0.8, 0.2, 0.0)

N_CHAR = dataset._get_num_char()

print(len(train_dataset))
print(train_dataset[0])

20667
{'input': tensor([57, 45, 14, 57,  6, 54, 43, 10, 15, 10, 10, 10, 14, 45,  6, 45, 10, 13,
        10, 10, 10, 10, 10, 13, 43, 10, 10, 15,  0]), 'length': tensor([29]), 'logp': tensor([3.6946])}


# **5. Build the DataLoader**

In [22]:
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=sample_collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=sample_collate_fn)
tr_N = len(train_dataset)
va_N = len(valid_dataset)

# **6. Set Model and Optimizer**

In [23]:
model = VariationalAutoEncoder(
            N_CHAR,
            N_HIDDEN,
            N_RNN_LAYER
        )
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
model = model.cuda()

# **7. Train with Mini-batches**

In [None]:
save_dir = "./save/"
if not os.path.exists(save_dir):
    os.mkdir(save_dir)

train_loss_history, valid_loss_history = [], []
best_loss = 1e6
for i in range(1, NUM_EPOCH + 1):

    model.train()
    train_batch_losses = []
    for batch_idx, batch in tqdm(enumerate(train_dataloader), total=tr_N // BATCH_SIZE):
        x_batch = batch["input"].long().cuda()
        l_batch = batch["length"].long().cuda()

        loss, rec_loss, vae_loss = model(x_batch, l_batch)
        train_batch_losses.append(loss.data.cpu().numpy())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        valid_batch_losses = []
        for batch_idx, batch in tqdm(enumerate(valid_dataloader), total=va_N // BATCH_SIZE):
            x_batch = batch["input"].long().cuda()
            l_batch = batch["length"].long().cuda()

            loss, rec_loss, vae_loss = model(x_batch, l_batch)
            valid_batch_losses.append(loss.data.cpu().numpy())

    train_avg_loss = np.mean(np.array(train_batch_losses))
    valid_avg_loss = np.mean(np.array(valid_batch_losses))
    train_loss_history.append(train_avg_loss)
    valid_loss_history.append(valid_avg_loss)

    if valid_avg_loss < best_loss:
        best_epoch = i
        best_loss = valid_avg_loss

    print(f"\t{i}th EPOCH --- TRAIN LOSS: {train_avg_loss:.4f} || VALIDATION LOSS: {valid_avg_loss:.4f} || BEST EPOCH: {best_epoch}", flush=True)

    torch.save(model.state_dict(), os.path.join(save_dir, f"save_{i}.pt"))

81it [00:28,  2.87it/s]
21it [00:04,  4.56it/s]                        

	1th EPOCH --- TRAIN LOSS: 0.0621 || VALIDATION LOSS: 0.0446 || BEST EPOCH: 1



 71%|███████▏  | 57/80 [00:17<00:06,  3.67it/s]

# **8. Plot the Loss Histories**

In [None]:
import matplotlib.pyplot as plt
x_axis = np.arange(NUM_EPOCH)
fig, ax = plt.subplots()
ax.plot(x_axis, train_loss_history, label='train loss')
ax.plot(x_axis, valid_loss_history, label='validation loss')
ax.set_xlabel('Num epoch')
ax.set_ylabel('Loss')
ax.set_title('Loss History')
ax.legend()
fig.show()

# **9. Generate SMILES with VAE**

In [None]:
def index_list_to_smiles(ind_list, i_to_c):
    return "".join([i_to_c[x] for x in ind_list])


save_state_dict = torch.load(os.path.join(save_dir, f"save_{best_epoch}.pt"))
model.load_state_dict(save_state_dict)
model.eval()

max_length = 64
num_sample = 200
i_to_c = dataset._get_i_to_c()

gen_smis = []
for _ in tqdm(range(num_sample), total=num_sample):
    gen_smi = model.generate(max_length)
    gen_smis.append(index_list_to_smiles(gen_smi, i_to_c))

print(gen_smis)

In [None]:
# code for plotting molecules (just for sanitiy check)

print(gen_smis)


filtered_gen_smis = filter_smiles(gen_smis)[0]
print("Number of valid SMILES:", len(filtered_gen_smis))
mols = [Chem.MolFromSmiles(smi) for smi in gen_smis][:10]

from rdkit import Chem
from rdkit.Chem import Draw
from IPython.display import display

mol_list = [Chem.MolFromSmiles(smiles) for smiles in gen_smis]
image_size = (150, 150)  # Adjust the dimensions as needed

# Plot and display the molecules with the specified size
for mol in mol_list[:10]:
    try:
        display(Draw.MolToImage(mol, size=image_size))
    except:
        pass


In [None]:
from rdkit.Chem import Draw, AllChem
from rdkit.Chem.Draw import rdMolDraw2D
from rdkit.Chem.Draw.MolDrawing import DrawingOptions
from IPython.display import SVG

def filter_smiles(smi_list, ref_smi_list=None):

    # 1. Valid smiles?
    valid_smi_list = []
    for smi in smi_list:
        if len(smi) == 0: continue
        try:
            mol = Chem.MolFromSmiles(smi)
            assert mol is not None
        except:
            continue
        valid_smi_list.append(smi)

    # 2. Unique smiles?
    unique_smi_list = []
    for smi in valid_smi_list:
        pass
        ## Implement Here ## TODO

    # 3. Novel smiles?
    novel_smi_list = []
    for smi in unique_smi_list:
        pass
        ## Implement Here ##

    return valid_smi_list, unique_smi_list, novel_smi_list

filtered_gen_smis = filter_smiles(gen_smis)[0]
print("Number of valid SMILES:", len(filtered_gen_smis))
mols = [Chem.MolFromSmiles(smi) for smi in filtered_gen_smis][:10]


from rdkit import Chem
from rdkit.Chem import Draw
from IPython.display import display
image_size = (150, 150)  # Adjust the dimensions as needed
for mol in mols:
    try:
        display(Draw.MolToImage(mol, size=image_size))
    except:
        pass