In [1]:
import torch
from torch import nn
# Print torch version
print(torch.__version__)

2.3.0.post301


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cuda")
device

device(type='cuda')

In [3]:
# Load SMILES data
smiles = open('data/1to6.dmu.smi','r').read().splitlines()
smiles
print(len(smiles))
max_len = max(len(w) for w in smiles)
print(max_len)
print(smiles[:8])

35466
20
['C', 'N', 'O', 'C#C', 'C#N', 'N#N', 'C=C', 'C=N']


In [4]:
# build the vocabulary of characters and mappings to/from integers
chars = sorted(list(set(''.join(smiles))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: '#', 2: '(', 3: ')', 4: '1', 5: '2', 6: '3', 7: '4', 8: '5', 9: '=', 10: 'C', 11: 'N', 12: 'O', 0: '.'}
13


In [5]:
# shuffle the smiles
import random
random.seed(42)
random.shuffle(smiles)

In [6]:
SEQ_LEN = 20
EMB_DIM = 20
LAT_DIM = 10

# build the dataset
def build_dataset(smiles):
    X = []
    for s in smiles:
        x = []
        for ch in s:
            ix = stoi[ch]
            x.append(ix)
        while len(x) < SEQ_LEN:
            x.append(0)
        X.append(x)
    X = torch.tensor(X)
    print(X.shape)
    return X
n1 = int(0.8 * len(smiles))
n2 = int(0.9 * len(smiles))
Xtr = build_dataset(smiles[:n1])
Xdev = build_dataset(smiles[n1:n2])
Xte = build_dataset(smiles[n2:])

torch.Size([28372, 20])
torch.Size([3547, 20])
torch.Size([3547, 20])


In [7]:
for x in Xtr[-20:]:
    print(''.join(itos[ix.item()] for ix in x), '-->', x.tolist())

CC1CC(=C)C1......... --> [10, 10, 4, 10, 10, 2, 9, 10, 3, 10, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0]
C#CN1OCO1........... --> [10, 1, 10, 11, 4, 12, 10, 12, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
N=CC=CN=O........... --> [11, 9, 10, 10, 9, 10, 11, 9, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
N=C=C1ON=C1......... --> [11, 9, 10, 9, 10, 4, 12, 11, 9, 10, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0]
O=C1NN=NO1.......... --> [12, 9, 10, 4, 11, 11, 9, 11, 12, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
CCN(N)NN............ --> [10, 10, 11, 2, 11, 3, 11, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
C1NN2ONC12.......... --> [10, 4, 11, 11, 5, 12, 11, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
ONC1=C(O)N1......... --> [12, 11, 10, 4, 9, 10, 2, 12, 3, 11, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0]
NN1C2NC12........... --> [11, 11, 4, 10, 5, 11, 10, 4, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
NC(=O)N=C=N......... --> [11, 10, 2, 9, 12, 3, 11, 9, 10, 9, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0]
CC#CNCN............. --> [10, 10, 1, 10, 11, 10, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [8]:
# Create Variational Autoencoder model
class VAE_smiles(nn.Module):
    def __init__(self, seq_len = SEQ_LEN, vocab_size=13, emb_dim = EMB_DIM, hidden_dim=100, latent_dim=LAT_DIM):
        super().__init__()
    
        # ecoder    
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)  
        self.rnn_emb2hid = nn.GRU(emb_dim, hidden_dim, batch_first=True)   
        self.fc_hid2mean = nn.Linear(hidden_dim, latent_dim)   
        self.fc_hid2logvar = nn.Linear(hidden_dim, latent_dim)  
        
        # decoder
        self.fc_lat2hid = nn.Linear(latent_dim, hidden_dim)  
        self.rnn_hid2emb = nn.GRU(emb_dim, hidden_dim, batch_first=True)  
        self.fc_emb2out = nn.Linear(hidden_dim, vocab_size)   
    
    def encode(self, x):
        x = self.emb(x)   # (B,20,20)
        _, hn_e = self.rnn_emb2hid(x)  # (1, B, 200) 
        hn_e = hn_e.squeeze(0)  # (B, 200)
        mean = self.fc_hid2mean(hn_e)  #   (B,10)
        logvar = self.fc_hid2logvar(hn_e)  #  (B,10)
        return mean, logvar
    
    def reparameterization(self, mean, logvar):
        epsilon = torch.randn_like(logvar)
        z = mean + logvar * epsilon
        return z
    
    def decode(self, z):  # (B, 10)
        hn_d = self.fc_lat2hid(z) # (B, 200)
        hn_d = hn_d.unsqueeze(0)  # (1, B, 200)
        h0 = torch.zeros(z.size(0), SEQ_LEN, EMB_DIM).to(device)  # (B, 20, 20)
        z, _ = self.rnn_hid2emb(h0, hn_d) # (B, 20, 20) 
        x = self.fc_emb2out(z) # (B, 20, 13)
        return x
    
    def forward(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterization(mean, logvar)
        x_hat = self.decode(z)
        return x_hat, mean, logvar   

# Loss function
def loss_function(x, x_hat, mean, log_var):
    reproduction_loss = nn.functional.cross_entropy(x_hat, x)
    KLD = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())
    return reproduction_loss + KLD                

In [9]:
batch_size = 64
Xb = torch.randint(0,vocab_size,(batch_size,SEQ_LEN)).to(device)
print(Xb.shape)
#model = VAE_smiles_WaveNet(seq_len=SEQ_LEN)
model = VAE_smiles(seq_len=SEQ_LEN).to(device)
x, _, _ = model(Xb)
total_params = sum(p.numel() for p in model.parameters())
print(x.shape, total_params)

torch.Size([64, 20])
torch.Size([64, 20, 13]) 77893


In [10]:
model = VAE_smiles().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
epochs = 310000
batch_size = 64
lossi = []
for epoch in range(epochs):
    # Sample batch
    idx = torch.randint(0, Xtr.shape[0], (batch_size,))
    Xb = Xtr[idx].to(device)
    
    # Train the model
    model.train()
    optimizer.zero_grad()
    x_hat, mean, log_var = model(Xb)
    loss = loss_function(Xb.view(-1), x_hat.view(-1, vocab_size), mean, log_var)
    lossi.append(loss.item())
    loss.backward()
    optimizer.step()
    if epoch % 10000 == 0:
        print("\tEpoch", epoch, "\tLoss: ", loss.item())
    
    modelName = 'VAE_smiles'
    if epoch % 50000 == 0 and epoch > 0:
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
            'iteration': epoch
        }
        torch.save(checkpoint, f'models/{modelName}_checkpoint_{epoch}.pt')
        print(f'Checkpoint saved at iteration {epoch}')
    if epoch > 150000:
        lr = 1e-4
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr


  from .autonotebook import tqdm as notebook_tqdm


	Epoch 0 	Loss:  5.747844696044922
	Epoch 10000 	Loss:  0.8063761591911316
	Epoch 20000 	Loss:  0.6139052510261536
	Epoch 30000 	Loss:  0.6027733087539673
	Epoch 40000 	Loss:  0.5145543813705444
	Epoch 50000 	Loss:  0.48320817947387695
Checkpoint saved at iteration 50000
	Epoch 60000 	Loss:  0.41425931453704834
	Epoch 70000 	Loss:  0.4497325122356415
	Epoch 80000 	Loss:  0.3373427093029022
	Epoch 90000 	Loss:  0.3248225152492523
	Epoch 100000 	Loss:  0.2829974293708801
Checkpoint saved at iteration 100000
	Epoch 110000 	Loss:  0.2786678075790405
	Epoch 120000 	Loss:  0.2688663601875305
	Epoch 130000 	Loss:  0.23895548284053802
	Epoch 140000 	Loss:  0.5180830955505371
	Epoch 150000 	Loss:  0.2092927247285843
Checkpoint saved at iteration 150000
	Epoch 160000 	Loss:  0.20241442322731018
	Epoch 170000 	Loss:  0.1529400795698166
	Epoch 180000 	Loss:  0.14882642030715942
	Epoch 190000 	Loss:  0.13575328886508942
	Epoch 200000 	Loss:  0.13754743337631226
Checkpoint saved at iteration 200000


In [72]:
from rdkit import Chem
model = VAE_smiles().to(device)
checkpoint = torch.load('models/VAE_smiles_checkpoint_300000.pt')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
latent_dim = LAT_DIM
def generate_smiles(z):
    z_sample = torch.tensor([z], dtype=torch.float).to(device)
    logits = model.decode(z_sample)
    logits = logits.view(-1, vocab_size)
    prob = nn.functional.softmax(logits, dim=1)
    indices = torch.argmax(prob, dim=-1)
    #indices = torch.multinomial(prob, num_samples=1).squeeze(-1)
    #print(indices)
    return ''.join(itos[ix.item()] for ix in indices).replace('.','')
    #return indices
samp = torch.randn(latent_dim)
#samp = torch.tensor([0,0])
print(samp)
gen_smiles = generate_smiles(samp.tolist())
print(gen_smiles, Chem.MolToSmiles(Chem.MolFromSmiles(gen_smiles)))

tensor([-0.4861,  1.5379,  0.5911, -0.6273, -0.3451,  0.9384, -0.6507,  0.9511,
         1.0580,  2.8108])
CCCCCCCCCCCCCCCCCCCC CCCCCCCCCCCCCCCCCCCC


In [78]:
# Suppress RDKit warnings
from rdkit import RDLogger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

uniqueList = []
validCount = 0
for i in range(10000):
    samp = torch.rand(LAT_DIM)
    smi = generate_smiles(samp.tolist())
    if Chem.MolFromSmiles(smi):
        validCount+=1
        canon_smi = Chem.MolToSmiles(Chem.MolFromSmiles(smi))
        if canon_smi not in uniqueList:
            uniqueList.append(canon_smi)

print(f'Valid SMILES: {validCount}/10000')
print(f'Unique SMILES: {len(uniqueList)}/{validCount}')

Valid SMILES: 1728/10000
Unique SMILES: 248/1728


In [79]:
print(uniqueList)

['CCCCCCCCC', 'CC#CCCC', 'CCCC', 'C#CCCCCCCCCCCCCC', 'CCCCC', 'CCCCCCCCCCCCCCCCCCCC', 'C#CC', 'CC#CC', 'CCCCNNNNNNNNNNNNNNNN', 'C#CCC', 'CCNCC', 'CC', 'C1CCC1', 'C1CC1', 'C=CCC', 'CC=CCCCCCC', 'CCOCC', 'CCC', 'CCCCCCCC', 'CCCCCC', 'CC#CCC', 'CCNC', 'CCC1CCCC1', 'CCCCCCC', 'CCCC1CC1', 'CN', 'C1CO1', 'CCCCCNN', 'CCCCCCCNCC', 'C=CC#CC', 'CC#CCCCCCC', 'CCCOC', 'C=CCNON', 'C#CCCC', 'CC=CCC', 'CCCCCCNC', 'C#CCCCCC', 'CCCCCCNNNNNNNNNNNNNC', 'CCCCCNNC', 'CC#CCCCCCCC', 'CCCCCCNNC', 'CCCCNC', 'C=CCNNCC', 'CCCCCNC', 'C#CCCCC', 'CCC1CCC1', 'C1COONO1', 'CC1CC1', 'C=CCCCCCCCCCCCCCCCCC', 'CCCCCCCCNC', 'CCC1CCCCNC1', 'CC#N', 'C1=CCCCC1', 'CC#CCCCCC', 'CC#CCCCC', 'CCNNCC1OO1', 'CCCNN', 'CCCC1CCC1', 'C', 'CCCNCCNCC', 'CC#CCCCCCCCCCCCC', 'C#CCCCCCC', 'C#CCCCCCCCCCCCCCCCCC', 'CCCCCCCNN', 'C#C', 'CCC1C=CC1', 'CCCC#CN', 'CCCCCCCCCCCCCCCCCCC', 'CCOO', 'OOOOOOOOOOOOOOOOOOOO', 'C1CC12CC2', 'CCCCNN', 'CCOC', 'CCCCCC1CC1', 'CCCCCCCCOC', 'CCCCCNNNNNNCCCCNNNNC', 'CC1CCN1', 'C#N', 'C#CCCC=C', 'C=CCCCCC', 'CON', 'CC