In [3]:
# # Making the tokenizer
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
# tokenizer = Tokenizer(WordLevel(vocab = dictionary, unk_token="[UNK]"))
# tokenizer.save("tokenizer.json")

tokenizer = Tokenizer.from_file('tokenizer_from_USPTO_MIT.json')

# Modifying the tokenizer
from tokenizers.pre_tokenizers import WhitespaceSplit
tokenizer.pre_tokenizer = WhitespaceSplit()
tokenizer.enable_padding(pad_id=2)

In [4]:
SRC_VOCAB_SIZE = 298
TGT_VOCAB_SIZE = 298
EMB_SIZE = 256
NHEAD = 8
FFN_HID_DIM = 2048
BATCH_SIZE = 64
NUM_ENCODER_LAYERS = 4
NUM_DECODER_LAYERS = 4
NUM_EPOCHS = 150
LEARNING_RATE = 0.0001
DEVICE_NAME = 'cuda:0'

In [97]:
## Functions

# In the data_files={...}, specify train, val, and test
from datasets import load_dataset
from torch.utils.data import Dataset
from typing import NamedTuple, List


class Batch(NamedTuple):
    src_text: List[str]
    tgt_text: List[str]

class InputDataset(Dataset):
    def __init__(self, train_data, target_data):
        self.train_data = train_data
        self.target_data = target_data

    def __len__(self):
        return len(self.target_data)
    
    def __getitem__(self, index):
        return Batch(
            src_text=self.train_data[index]["text"], 
            tgt_text=self.target_data[index]["text"]
        )


# Now we have the tokenizer, we need to define the collate function
# The collate function will be called by the DataLoader

import torch
from torch.utils.data.dataloader import default_collate
#from torch.utils.data import default_collate

class EncodedBatch(NamedTuple):
    src_sequence_ids: torch.tensor
    src_attention_mask: torch.tensor
    tgt_sequence_ids: torch.tensor
    tgt_attention_mask: torch.tensor


def collate_fn(batch:Batch):
    batch = default_collate(batch)
    encoded_src = tokenizer.encode_batch(batch.src_text)
    encoded_tgt = tokenizer.encode_batch(batch.tgt_text)
    src_sequence_ids = torch.tensor([elem.ids for elem in encoded_src]).T
    tgt_sequence_ids = torch.tensor([elem.ids for elem in encoded_tgt]).T
    src_attention_mask = torch.tensor([elem.attention_mask for elem in encoded_src]).T
    tgt_attention_mask = torch.tensor([elem.attention_mask for elem in encoded_tgt]).T
    encoded_batch = EncodedBatch(src_sequence_ids, src_attention_mask, tgt_sequence_ids, tgt_attention_mask)
    return encoded_batch


from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
DEVICE = torch.device(DEVICE_NAME if torch.cuda.is_available() else 'cpu')

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# MechTransformer Network
class MechTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(MechTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor = None):
        src_emb = self.src_tok_emb(src)
        tgt_emb = self.tgt_tok_emb(trg)
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(
                            self.src_tok_emb(src), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(
                          self.tgt_tok_emb(tgt), memory,
                          tgt_mask)




# Make the masks

def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    PAD_IDX = 2

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask


torch.manual_seed(0)


transformer = MechTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE,
                                 NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)




# Define Train,Val,Decode,Test

from torch.utils.data import DataLoader



def greedy_decode(model, input_sequence):
    model.eval()
    with torch.no_grad():
        # Encode the input sequence
        src = tokenizer.encode(input_sequence).ids

        src = torch.tensor(src).to(DEVICE)
        src = src.unsqueeze(1)
        memory = model.encode(src, src_mask=None)

        # Prepare the initial input for the decoder
        decoder_input = torch.tensor([0]).to(DEVICE)
        decoder_input = decoder_input.unsqueeze(1)

        output_sequence = []
        MAX_LENGTH = 100
        for _ in range(MAX_LENGTH):
            # Generate the next token
            output = model.decode(decoder_input, memory, tgt_mask=None)
            output = model.generator(output)
            next_token = output.argmax(2)[-1, :].item()
            output_sequence.append(next_token)

            if next_token == 1:
                break

            # Prepare the next input for the decoder
            decoder_input = torch.cat([decoder_input, torch.tensor([[next_token]]).to(DEVICE)], 0)

    return memory, tokenizer.decode(output_sequence)




def beam_search(model, input_sequence, beam_size):
    model.eval()

    k = beam_size
    iters = 100

    with torch.no_grad():
        # Encode the input sequence
        src = tokenizer.encode(input_sequence).ids
        src = torch.tensor(src).to(DEVICE)
        src = src.unsqueeze(1)
        memory = transformer.encode(src, src_mask=None)

    decoder_input = [(0, torch.tensor([0]).to(DEVICE))]
    decoder_output = []

    for _ in range(iters):
        decoder_input = sorted(decoder_input, reverse = True)
        decoder_input = decoder_input[:k]

        incomplete_sequence_exists = False
        for _, tokens in decoder_input:
            if torch.all(tokens != 1):
                incomplete_sequence_exists = True
        
        if not incomplete_sequence_exists:
            break

        for log_prob, inp in decoder_input:
            # print(inp.shape)
            output = transformer.decode(inp[:, None], memory, tgt_mask=None)
            output = transformer.generator(output)
            output = nn.functional.log_softmax(output,dim=-1)
            top_values, top_indices = output.topk(k)
            # print(top_indices.shape)
            top_indices = top_indices[-1].T
            top_values = top_values[-1].T
            for new_log_p, new_index in zip(top_values, top_indices):
                decoder_output.append((
                    (log_prob+new_log_p).item(),
                    torch.cat([inp,new_index], dim=0)
                ))
        decoder_input = decoder_output
        decoder_output = []

    final_outputs = []
    for log_prob, tokens in sorted(decoder_input, reverse = True):
        tokens = tokenizer.decode(tokens.tolist())
        final_outputs.append(tokens)

    return final_outputs



In [98]:
transformer.load_state_dict(torch.load('trained_models/USPTO_MIT_OCT_4.pth', map_location = torch.device('cpu')))
transformer.eval()

MechTransformer(
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-3): 4 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-3): 4 x TransformerDecoderLayer(
          (self_attn): MultiheadAttenti

In [99]:
import rdkit
from rdkit import Chem
from tqdm import tqdm
# from rpCHEM.Common.Util import smi_to_unique_smi_fast, smi_to_unique_smi_map, exact_mass

def canonicalize_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        return Chem.MolToSmiles(mol, isomericSmiles=True)
    else:
        return ''
    
# def amin_canonicalize_smiles(smiles):
#     return smi_to_unique_smi_fast(smiles)






def compute_accuracy_greedy(source,target):

    with open(source) as source_file:
        source_lines = source_file.readlines()

    with open(target) as target_file:
        target_lines = target_file.readlines()

    matches = 0

    for source_line, target_line in zip(source_lines, target_lines[:100]):
        tgt = target_line.strip()
        tgt = tgt.replace(" ", "")
        tgt = canonicalize_smiles(tgt)

        source_line = '[START] ' + source_line.strip() + ' [END]'
        prediction = greedy_decode(transformer, source_line)
        prediction = prediction.replace("[END]", "")
        prediction = prediction.replace("[START]", "")
        prediction = prediction.strip()
        prediction = prediction.replace(" ", "")
        prediction = canonicalize_smiles(prediction)

        if target == prediction:
                matches = matches + 1
                #print("target is " + target)
                #print("prediction is " + prediction)

    return matches / float(len(target_lines))



def compute_accuracy_beam(source,target,beam_size):

    with open(source) as source_file:
        source_lines = source_file.readlines()

    with open(target) as target_file:
        target_lines = target_file.readlines()

    matches = 0

    for source_line, target_line in zip(source_lines, target_lines[:100]):
        tgt = target_line.strip()
        tgt = tgt.replace(" ", "")
        tgt = canonicalize_smiles(tgt)

        source_line = '[START] ' + source_line.strip() + ' [END]'
        predictions = beam_search(transformer, source_line,beam_size)
        correct = False

        for prediction in predictions:
            prediction = prediction.replace("[END]", "")
            prediction = prediction.replace("[START]", "")
            prediction = prediction.strip()
            prediction = prediction.replace(" ", "")
            prediction = canonicalize_smiles(prediction)

            if target == prediction:
                    correct = True
        if correct:
            matches = matches + 1

    return matches / float(len(target_lines))

In [100]:
m1, v1 = greedy_decode(transformer,'[START] C O [END]')

# C C C C C ( C O ) c 1 c c c ( C ) c c 1 C

In [101]:
m2, v2 = greedy_decode(transformer,'[START] O C [END]')

In [102]:
m1, m2

(tensor([[[-0.1278,  0.7598,  1.1653,  ..., -0.0565,  0.5364,  0.7407]],
 
         [[-1.3394, -0.1681,  0.7162,  ...,  1.1068, -0.0721, -1.1558]],
 
         [[-0.3345, -0.2282,  0.5544,  ...,  1.7588, -0.4460, -0.1879]],
 
         [[-0.6921,  0.7790, -0.1657,  ...,  1.6773,  0.0471, -0.1246]]]),
 tensor([[[-0.1278,  0.7598,  1.1653,  ..., -0.0565,  0.5364,  0.7407]],
 
         [[-0.3345, -0.2282,  0.5544,  ...,  1.7588, -0.4460, -0.1879]],
 
         [[-1.3394, -0.1681,  0.7162,  ...,  1.1068, -0.0721, -1.1558]],
 
         [[-0.6921,  0.7790, -0.1657,  ...,  1.6773,  0.0471, -0.1246]]]))

In [25]:
beam_search(transformer,'[START] C C C C C 1 C O 1 . C c 1 c c c c ( C ) c 1 > Cl . [Cl-] [END]',5)

# C C C C C ( C O ) c 1 c c c ( C ) c c 1 C

KeyboardInterrupt: 

In [9]:
compute_accuracy_greedy("/baldig/chemistry/2023_rp/USPTO_MIT_TOKENIZED/MIT_separated_augm/src-test.txt","/baldig/chemistry/2023_rp/USPTO_MIT_TOKENIZED/MIT_separated_augm/tgt-test.txt")

[12:54:49] SMILES Parse Error: extra close parentheses while parsing: CO)=O.c1cc2[nH]ccc2cc1CCO
[12:54:49] SMILES Parse Error: Failed parsing SMILES 'CO)=O.c1cc2[nH]ccc2cc1CCO' for input: 'CO)=O.c1cc2[nH]ccc2cc1CCO'
[12:54:49] SMILES Parse Error: unclosed ring for input: 'O=c1cnc2ccc(F)cn2CC(CN3CCC4O)ccc3c(cn2)CN(C2)CO2'
[12:54:49] SMILES Parse Error: unclosed ring for input: 'C=C(C)Oc1ccc(N2C(=O)C(=C(C)C)sc2-c2ccccc2)C(=O)N2'
[12:54:50] Can't kekulize mol.  Unkekulized atoms: 3 4 5 6 12 13 14 15 25 26 28
[12:54:50] SMILES Parse Error: ring closure 1 duplicates bond between atom 1 and atom 2 for input: 'Fc1-c1c(CN2CCCCC2)oc(-c3ccc(O)c2)N1CCCCC2'
[12:54:50] SMILES Parse Error: extra open parentheses for input: 'Cc1ccc(NC2CS2)c1nc(NCc1ccc(C)o1'
[12:54:50] SMILES Parse Error: extra close parentheses while parsing: CCCCCC(C)C1CCC2CC=C2CC(OCCCCCCC(=N)C)C)CCC1
[12:54:50] SMILES Parse Error: Failed parsing SMILES 'CCCCCC(C)C1CCC2CC=C2CC(OCCCCCCC(=N)C)C)CCC1' for input: 'CCCCCC(C)C1CCC2CC=C2CC

0.0

In [8]:
compute_accuracy_beam("/baldig/chemistry/2023_rp/USPTO_MIT_TOKENIZED/MIT_separated_augm/src-test.txt","/baldig/chemistry/2023_rp/USPTO_MIT_TOKENIZED/MIT_separated_augm/tgt-test.txt",5)

[12:53:37] SMILES Parse Error: extra close parentheses while parsing: Nc1cc(F)ccc1C#NO)cc(
[12:53:37] SMILES Parse Error: Failed parsing SMILES 'Nc1cc(F)ccc1C#NO)cc(' for input: 'Nc1cc(F)ccc1C#NO)cc('
[12:53:37] SMILES Parse Error: extra close parentheses while parsing: Nc1cc(F)ccc1C#NF)cc(
[12:53:37] SMILES Parse Error: Failed parsing SMILES 'Nc1cc(F)ccc1C#NF)cc(' for input: 'Nc1cc(F)ccc1C#NF)cc('
[12:53:37] SMILES Parse Error: extra close parentheses while parsing: Nc1cc(F)ccc1C#NO)cc(
[12:53:37] SMILES Parse Error: Failed parsing SMILES 'Nc1cc(F)ccc1C#NO)cc(' for input: 'Nc1cc(F)ccc1C#NO)cc('
[12:53:38] SMILES Parse Error: extra close parentheses while parsing: CO)=C.c1cn(-c2ccc(C(F)(F)F)cc2
[12:53:38] SMILES Parse Error: Failed parsing SMILES 'CO)=C.c1cn(-c2ccc(C(F)(F)F)cc2' for input: 'CO)=C.c1cn(-c2ccc(C(F)(F)F)cc2'
[12:53:38] SMILES Parse Error: extra close parentheses while parsing: COC)c1cn(C(F)(F)Fccc2n(C)c1cc
[12:53:38] SMILES Parse Error: Failed parsing SMILES 'COC)c1cn(C(F

0.0