In [1]:
import torch
from torch import nn
import selfies as sf
# Print torch version
print(torch.__version__)

2.3.0.post301


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cuda")
device

device(type='cuda')

In [3]:
# dmu9 dataset downloaded from https://gdb.unibe.ch/downloads/
# read in all the smiles
smiles_sf = open('data/1to6.dmu.selfies', 'r').read().splitlines()
print(len(smiles_sf))
print(max(len(w) for w in smiles_sf))
print(smiles_sf[-8:])

35466
120
['[N][C][C][C][C][Ring1][Ring2][Ring1][Ring1][C][Ring1][Branch1][Ring1][Ring2][Ring1][Ring1]', '[N][C][C][N][C][Ring1][Ring2][Ring1][Ring1][C][Ring1][Branch1][Ring1][Ring2][Ring1][Ring1]', '[O][C][C][C][C][Ring1][Ring2][Ring1][Ring1][C][Ring1][Branch1][Ring1][Ring2][Ring1][Ring1]', '[O][C][C][N][C][Ring1][Ring2][Ring1][Ring1][C][Ring1][Branch1][Ring1][Ring2][Ring1][Ring1]', '[C][C][C][C][Ring1][Ring2][C][Ring1][Branch1][Ring1][Ring2][C][Ring1][Branch1][Ring1][Ring2][Ring1][Ring1]', '[C][C][N][C][Ring1][Ring2][C][Ring1][Branch1][Ring1][Ring2][C][Ring1][Branch1][Ring1][Ring2][Ring1][Ring1]', '[N][C][N][C][Ring1][Ring2][C][Ring1][Branch1][Ring1][Ring2][C][Ring1][Branch1][Ring1][Ring2][Ring1][Ring1]', '[C][C][C][Ring1][Ring1][C][Ring1][Ring2][C][Ring1][Branch1][Ring1][Ring2][C][Ring1][Branch1][Ring1][Ring2][Ring1][Ring1]']


In [4]:
# build the vocabulary of selfies tokens and mappings to/from integers
alphabet = sf.get_alphabet_from_selfies(smiles_sf)
#chars = sorted(list(set(''.join(smiles))))
stoi = {s:i+1 for i,s in enumerate(alphabet)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
VOCAB_SIZE = len(itos)
print(itos)
print(VOCAB_SIZE)

{1: '[Branch1]', 2: '[N]', 3: '[=O]', 4: '[#C]', 5: '[=C]', 6: '[Ring2]', 7: '[=Branch1]', 8: '[O]', 9: '[=N]', 10: '[Ring1]', 11: '[=Ring1]', 12: '[#N]', 13: '[C]', 0: '.'}
14


In [5]:
# shuffle the smiles
import random
random.seed(42)
random.shuffle(smiles_sf)

In [6]:
SEQ_LEN = 120
EMB_DIM = 20
LAT_DIM = 10

# build the dataset
def build_dataset(smiles_sf):
    X = []
    for s in smiles_sf:
        x = []
        for vocab in list(sf.split_selfies(s)) + ['.']:
            ix = stoi[vocab]
            x.append(ix)
        while len(x) < SEQ_LEN:
            x.append(0)
        X.append(x)
    X = torch.tensor(X)
    print(X.shape)
    return X
n1 = int(0.8 * len(smiles_sf))
n2 = int(0.9 * len(smiles_sf))
Xtr = build_dataset(smiles_sf[:n1])
Xdev = build_dataset(smiles_sf[n1:n2])
Xte = build_dataset(smiles_sf[n2:])

torch.Size([28372, 120])
torch.Size([3547, 120])
torch.Size([3547, 120])


In [7]:
for x in Xtr[-20:]:
    print(''.join(itos[ix.item()] for ix in x), '-->', x.tolist())

[C][C][C][C][=Branch1][C][=C][C][Ring1][Branch1].............................................................................................................. --> [13, 13, 13, 13, 7, 13, 5, 13, 10, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
[C][#C][N][O][C][O][Ring1][Ring2]................................................................................................................ --> [13, 4, 2, 8, 13, 8, 10, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [8]:
# Create Variational Autoencoder model
class VAE_selfies(nn.Module):
    def __init__(self, seq_len = SEQ_LEN, vocab_size=VOCAB_SIZE, emb_dim = EMB_DIM, hidden_dim=100, latent_dim=LAT_DIM):
        super().__init__()
    
        # ecoder    
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=0)  
        self.rnn_emb2hid = nn.GRU(emb_dim, hidden_dim, batch_first=True)   
        self.fc_hid2mean = nn.Linear(hidden_dim, latent_dim)   
        self.fc_hid2logvar = nn.Linear(hidden_dim, latent_dim)  
        
        # decoder
        self.fc_lat2hid = nn.Linear(latent_dim, hidden_dim)  
        self.rnn_hid2emb = nn.GRU(emb_dim, hidden_dim, batch_first=True)  
        self.fc_emb2out = nn.Linear(hidden_dim, vocab_size)   
    
    def encode(self, x):
        x = self.emb(x)   # (B,20,20)
        _, hn_e = self.rnn_emb2hid(x)  # (1, B, 200) 
        hn_e = hn_e.squeeze(0)  # (B, 200)
        mean = self.fc_hid2mean(hn_e)  #   (B,10)
        logvar = self.fc_hid2logvar(hn_e)  #  (B,10)
        return mean, logvar
    
    def reparameterization(self, mean, logvar):
        epsilon = torch.randn_like(logvar)
        z = mean + logvar * epsilon
        return z
    
    def decode(self, z):  # (B, 10)
        hn_d = self.fc_lat2hid(z) # (B, 200)
        hn_d = hn_d.unsqueeze(0)  # (1, B, 200)
        h0 = torch.zeros(z.size(0), SEQ_LEN, EMB_DIM).to(device)  # (B, 20, 20)
        z, _ = self.rnn_hid2emb(h0, hn_d) # (B, 20, 20) 
        x = self.fc_emb2out(z) # (B, 20, 13)
        return x
    
    def forward(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterization(mean, logvar)
        x_hat = self.decode(z)
        return x_hat, mean, logvar   

# Loss function
def loss_function(x, x_hat, mean, log_var):
    reproduction_loss = nn.functional.cross_entropy(x_hat, x)
    KLD = - 0.5 * torch.sum(1+ log_var - mean.pow(2) - log_var.exp())
    return reproduction_loss + KLD                

In [10]:
batch_size = 64
Xb = torch.randint(0,VOCAB_SIZE,(batch_size,SEQ_LEN)).to(device)
print(Xb.shape)
model = VAE_selfies(seq_len=SEQ_LEN).to(device)
x, mean, logvar = model(Xb)
total_params = sum(p.numel() for p in model.parameters())
print(x.shape, total_params)

torch.Size([64, 120])
torch.Size([64, 120, 14]) 78014


In [12]:
model = VAE_selfies().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
epochs = 210000
batch_size = 64
lossi = []
for epoch in range(epochs):
    # Sample batch
    idx = torch.randint(0, Xtr.shape[0], (batch_size,))
    Xb = Xtr[idx].to(device)
    
    # Train the model
    model.train()
    optimizer.zero_grad()
    x_hat, mean, log_var = model(Xb)
    loss = loss_function(Xb.view(-1), x_hat.view(-1, VOCAB_SIZE), mean, log_var)
    lossi.append(loss.item())
    loss.backward()
    optimizer.step()
    if epoch % 10000 == 0:
        print("\tEpoch", epoch, "\tLoss: ", loss.item())
    
    modelName = 'VAE_selfies'
    if epoch % 50000 == 0 and epoch > 0:
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),
            'iteration': epoch
        }
        torch.save(checkpoint, f'models/{modelName}_checkpoint_{epoch}.pt')
        print(f'Checkpoint saved at iteration {epoch}')
    if epoch > 100000:
        lr = 1e-4
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr


	Epoch 0 	Loss:  4.053007125854492
	Epoch 10000 	Loss:  0.1456565111875534
	Epoch 20000 	Loss:  0.1493128091096878
	Epoch 30000 	Loss:  0.1483093500137329
	Epoch 40000 	Loss:  0.14494749903678894
	Epoch 50000 	Loss:  0.14369513094425201
Checkpoint saved at iteration 50000
	Epoch 60000 	Loss:  0.1498894840478897
	Epoch 70000 	Loss:  0.14054633677005768
	Epoch 80000 	Loss:  0.14324727654457092
	Epoch 90000 	Loss:  0.14776615798473358
	Epoch 100000 	Loss:  0.1476602554321289
Checkpoint saved at iteration 100000
	Epoch 110000 	Loss:  0.15285947918891907
	Epoch 120000 	Loss:  0.13754811882972717
	Epoch 130000 	Loss:  0.1402951329946518
	Epoch 140000 	Loss:  0.14642149209976196
	Epoch 150000 	Loss:  0.15036843717098236
Checkpoint saved at iteration 150000
	Epoch 160000 	Loss:  0.13958072662353516
	Epoch 170000 	Loss:  0.1451232135295868
	Epoch 180000 	Loss:  0.14306317269802094
	Epoch 190000 	Loss:  0.14286252856254578
	Epoch 200000 	Loss:  0.1486712247133255
Checkpoint saved at iteration 20

In [33]:
from rdkit import Chem
model = VAE_selfies().to(device)
checkpoint = torch.load('models/VAE_selfies_checkpoint_200000.pt')
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
latent_dim = LAT_DIM
def generate_selfies(z):
    z_sample = torch.tensor([z], dtype=torch.float).to(device)
    logits = model.decode(z_sample)
    logits = logits.view(-1, VOCAB_SIZE)
    prob = nn.functional.softmax(logits, dim=1)
    #indices = torch.argmax(prob, dim=-1)
    indices = torch.multinomial(prob, num_samples=1).squeeze(-1)
    #print(indices)
    return ''.join(itos[ix.item()] for ix in indices).replace('.','')
    #return indices
samp = torch.randn(latent_dim)
#samp = torch.tensor([0,0])
print(samp)
gen_selfies = generate_selfies(samp.tolist())
print(gen_selfies)
gen_smiles = sf.decoder(gen_selfies)
print(gen_smiles, Chem.MolToSmiles(Chem.MolFromSmiles(gen_smiles)))

tensor([-0.9199, -0.7488,  0.5164, -0.0688, -0.5017, -0.9367, -1.6985,  0.0731,
         0.5540, -0.8736])
[O][C][C][C][=N][C][C][=N][C][Ring2][Ring1][Ring2]
O1CCC=NCC=NC1 C1=NCC=NCOCC1


In [36]:
# Suppress RDKit warnings
from rdkit import RDLogger
lg = RDLogger.logger()
lg.setLevel(RDLogger.CRITICAL)

uniqueList = []
validCount = 0
for i in range(10000):
    samp = torch.rand(LAT_DIM)
    smi_sf = generate_selfies(samp.tolist())
    smi = sf.decoder(smi_sf)
    if Chem.MolFromSmiles(smi):
        validCount+=1
        canon_smi = Chem.MolToSmiles(Chem.MolFromSmiles(smi))
        if canon_smi not in uniqueList:
            uniqueList.append(canon_smi)

print(f'Valid SMILES: {validCount}/10000')
print(f'Unique SMILES: {len(uniqueList)}/{validCount}')

Valid SMILES: 10000/10000
Unique SMILES: 6348/10000


In [35]:
print(uniqueList)

['CN=NC#N', 'C1=CNOC1', 'C1=CN=CONC1', 'O=CC=CCC=NCC=O', 'C=NN=NN', 'C=CN=C=O', 'NNN=CN=O', '', 'CC=NNCOCO', '[nH]1oo1', 'C#N', 'C=CONC#CC', 'CN=CC1=NC1', 'N=CNCN', 'C1#CN=C1', 'NNCCON', 'NC1NN1', 'CON=C=O', 'NN1CC=N1', 'C1=CC=NNN=1', 'CC=CN=C=N', 'N#CNCN=N', 'CN1CN1', 'C1=CNC=1', 'N=CC=O', 'C1NNO1', 'CCOC=O', 'C1C2OC12', 'CNNCNNCNN=N', 'NC=O', 'NNCN1C=N1', 'C1=NNN2NC12', 'C1CONO1', 'C1=NOCNN1', 'C=NC=CONN', 'C=NNNCN', 'C#CC=NCNN=N', 'NC1=C2CC12', 'C=CC', 'C1=NNC=NO1', 'C=NC=CO', 'C1=CNNN=N1', 'N#CCCC=N', 'C1=CNON=C1', 'C1=NCC=NC1', 'ONCC1=C=CNO1', 'C1=CCC=NC=1', 'c1c2[nH]n1-2', 'C1#COCN1', 'CN=C=CO', 'CCN=C=O', 'O=CN1CC1', 'C=O', 'C1=NNCCCCC1', 'OON=C1OC1O', 'N#CON', 'OCC=NOO', 'N=CC=NN', 'C1CNNON=NN1', 'C1=COC1', 'NN=CNCNC=O', 'C1=CNOC=NNCN=1', 'N=C=CON=O', 'C1#CCOC1', 'C1C2CN12', 'C1N=NN=N1', 'O=c1c#c1', 'CCC(C)O', 'CC=C=NN=COO', 'NN1CN1', 'N=C1CC1OO', 'C=NCOO', 'C#CCCNN', 'CCOCOC', 'N=NC=NN', 'C#CN', 'NOCCCN=O', 'NC1CC=N1', 'C1NCO1', 'C=NCCCNC', 'C=CN=CNNC', 'COONOONN=CN', 'N#N', '