In [5]:
import yaml
import torch
from pathlib import Path
from src.data import load

# Read the experiment configuration from disk.
# NOTE(review): yaml.safe_load is the conservative choice if config.yaml only
# holds plain scalars/lists/dicts — confirm before switching loaders.
with open("config.yaml", "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

# Derive run-time settings from the values that were just loaded.
config['data_path'] = Path('data/' + config['data_set'])
if not torch.cuda.is_available():
    # No usable CUDA device: force CPU regardless of what the file requested.
    config['device'] = 'cpu'

train_data, val_data, tokenizer = load(config['data_path'], config['batch_size'])

# One tokenizer is used for both sides, so both vocabulary sizes get the same value.
config.update(
    dict.fromkeys(('src_vocab_size', 'tgt_vocab_size'), tokenizer.get_vocab_size())
)
print(config['src_vocab_size'])


DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 818070
    })
    val: Dataset({
        features: ['text'],
        num_rows: 30000
    })
})
298


In [6]:
from src.models.transformer import TransformerModel
# Instantiate the model from hyper-parameters stored in `config`.
# The key order below is exactly the positional-argument order passed to
# TransformerModel.
transformer = TransformerModel(
    *(
        config[name]
        for name in (
            'num_encoder_layers',
            'num_decoder_layers',
            'emb_size',
            'nhead',
            'src_vocab_size',
            'tgt_vocab_size',
            'ffn_hid_dim',
        )
    )
)



In [10]:
# Restore pretrained weights onto the configured device and switch to eval mode.
# NOTE: '<modelname>' is a placeholder; replace it with a real checkpoint file
# name, otherwise torch.load raises FileNotFoundError and the model keeps its
# initial parameters.
transformer.load_state_dict(
    torch.load(
        'src/models/transformer/pretrained/<modelname>.pth',
        map_location=torch.device(config['device']),
    )
)
transformer.eval()

FileNotFoundError: [Errno 2] No such file or directory: 'src/models/transformer/pretrained/<modelname>.pth'

In [29]:
import rdkit
from rdkit import Chem
from tqdm import tqdm
# from rpCHEM.Common.Util import smi_to_unique_smi_fast, smi_to_unique_smi_map, exact_mass

def canonicalize_smiles(smiles):
    """Return the canonical isomeric SMILES for ``smiles``.

    Args:
        smiles: SMILES string handed to ``Chem.MolFromSmiles``.

    Returns:
        The string produced by ``Chem.MolToSmiles(..., isomericSmiles=True)``,
        or ``''`` when ``Chem.MolFromSmiles`` returns ``None``.
    """
    parsed = Chem.MolFromSmiles(smiles)
    if parsed is None:
        # Parsing failed; callers treat the empty string as "invalid".
        return ''
    return Chem.MolToSmiles(parsed, isomericSmiles=True)
#     
# def canonicalize_smiles(smiles):
#     return smi_to_unique_smi_fast(smiles)


def compute_accuracy_greedy(source, target, max_examples=100):
    """Top-1 accuracy of greedy decoding over paired source/target files.

    Each source line is wrapped in ``[START]``/``[END]`` markers and decoded
    with ``greedy_decode(transformer, ...)``. Prediction and reference are
    stripped of spaces and markers, canonicalized with ``canonicalize_smiles``
    and compared for string equality.

    Args:
        source: Path to a text file with one space-tokenized input per line.
        target: Path to a text file with the reference output for each line
            of ``source``.
        max_examples: Maximum number of leading line pairs to evaluate
            (default 100). ``None`` evaluates every pair.

    Returns:
        Fraction (float) of evaluated pairs whose canonical prediction equals
        the canonical reference; ``0.0`` when no pair is available.
    """
    from itertools import islice

    # Read only as many paired lines as will actually be scored.
    with open(source) as source_file, open(target) as target_file:
        pairs = list(islice(zip(source_file, target_file), max_examples))

    if not pairs:
        return 0.0

    matches = 0
    for source_line, target_line in pairs:
        expected = canonicalize_smiles(target_line.strip().replace(" ", ""))
        if not expected:
            # An unparsable reference canonicalizes to ''; an unparsable
            # prediction does too, so never let '' == '' count as a hit.
            continue

        model_input = '[START] ' + source_line.strip() + ' [END]'
        prediction = greedy_decode(transformer, model_input)
        prediction = prediction.replace("[END]", "").replace("[START]", "")
        prediction = canonicalize_smiles(prediction.strip().replace(" ", ""))

        # Compare against the canonical reference string, not the `target`
        # file path.
        if prediction == expected:
            matches += 1

    # Normalize by the number of pairs that were scored, not the file length.
    return matches / float(len(pairs))



def compute_accuracy_beam(source, target, beam_size, max_examples=100):
    """Top-``beam_size`` accuracy of beam search over paired source/target files.

    Each source line is wrapped in ``[START]``/``[END]`` markers and decoded
    with ``beam_search(transformer, ..., beam_size)``. An example counts as
    correct when any returned candidate, after marker/space removal and
    ``canonicalize_smiles``, equals the canonicalized reference.

    Args:
        source: Path to a text file with one space-tokenized input per line.
        target: Path to a text file with the reference output for each line
            of ``source``.
        beam_size: Value forwarded to ``beam_search`` as its third argument.
        max_examples: Maximum number of leading line pairs to evaluate
            (default 100). ``None`` evaluates every pair.

    Returns:
        Fraction (float) of evaluated pairs with at least one matching
        candidate; ``0.0`` when no pair is available.
    """
    from itertools import islice

    # Read only as many paired lines as will actually be scored.
    with open(source) as source_file, open(target) as target_file:
        pairs = list(islice(zip(source_file, target_file), max_examples))

    if not pairs:
        return 0.0

    matches = 0
    for source_line, target_line in pairs:
        expected = canonicalize_smiles(target_line.strip().replace(" ", ""))
        if not expected:
            # An unparsable reference canonicalizes to ''; an unparsable
            # candidate does too, so never let '' == '' count as a hit.
            continue

        model_input = '[START] ' + source_line.strip() + ' [END]'
        for candidate in beam_search(transformer, model_input, beam_size):
            candidate = candidate.replace("[END]", "").replace("[START]", "")
            candidate = canonicalize_smiles(candidate.strip().replace(" ", ""))

            # Compare against the canonical reference string, not the
            # `target` file path; one hit per example is enough.
            if candidate == expected:
                matches += 1
                break

    # Normalize by the number of pairs that were scored, not the file length.
    return matches / float(len(pairs))

In [30]:
def greedy_decode(model, input_sequence, max_length=100, start_token_id=0, end_token_id=1):
    """Decode ``input_sequence`` by repeatedly taking the arg-max next token.

    Uses the notebook-level ``tokenizer`` and ``config['device']``.

    Args:
        model: Object exposing ``eval()``, ``encode(src, src_mask=...)``,
            ``decode(tgt, memory, tgt_mask=...)`` and ``generator(...)``.
        input_sequence: Text passed to ``tokenizer.encode``.
        max_length: Maximum number of tokens to generate (default 100).
        start_token_id: Token id used to seed the decoder input (default 0).
            Presumably the id of ``[START]`` — confirm against the tokenizer.
        end_token_id: Token id that stops generation (default 1). Presumably
            the id of ``[END]`` — confirm against the tokenizer.

    Returns:
        ``tokenizer.decode`` applied to the generated ids; the stop token,
        when produced, is included in the ids handed to the tokenizer.
    """
    device = config['device']

    model.eval()
    with torch.no_grad():
        # Encode the source once; unsqueeze(1) inserts a size-1 axis after the
        # sequence axis (presumably the batch axis — confirm with the model).
        src_ids = tokenizer.encode(input_sequence).ids
        src = torch.tensor(src_ids).to(device).unsqueeze(1)
        memory = model.encode(src, src_mask=None)

        # Decoder input starts as a single (1, 1) start token and grows by
        # one row per generated token.
        decoder_input = torch.tensor([[start_token_id]]).to(device)

        output_sequence = []
        for _ in range(max_length):
            # NOTE(review): decoding runs with tgt_mask=None; if the model was
            # trained with a causal target mask, pass the same mask here —
            # confirm against the training code.
            logits = model.generator(model.decode(decoder_input, memory, tgt_mask=None))

            # Arg-max over the last axis, taken at the final sequence position.
            next_token = logits.argmax(2)[-1, :].item()
            output_sequence.append(next_token)

            if next_token == end_token_id:
                break

            # Append the chosen token along the sequence axis for the next step.
            next_step = torch.tensor([[next_token]]).to(device)
            decoder_input = torch.cat([decoder_input, next_step], 0)

    return tokenizer.decode(output_sequence)

In [31]:
# Sanity check: a SMILES written with explicit ring double bonds should come
# back from canonicalize_smiles in canonical form.
canonicalize_smiles("CNCC(C1=CC(=C(C=C1)O)O)O")

'CNCC(O)c1ccc(O)c(O)c1'

In [26]:
# Smoke test of greedy decoding on a minimal two-token input.
# NOTE(review): the checkpoint-loading cell fails while its path still contains
# the '<modelname>' placeholder; in that case the model decodes with its
# initial parameters — confirm the weights are loaded before reading the output.
greedy_decode(transformer,'[START] C O [END]')

# NOTE(review): the tokenized SMILES below presumably is the reference product
# for the longer reaction string used in the beam-search cell, not for the
# input above — confirm.
# C C C C C ( C O ) c 1 c c c ( C ) c c 1 C

'[B+] [B+] [B+] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [B+] [B+] [B+] [B+] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [B+] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2] [Zn+2]'

In [101]:
# NOTE(review): greedy_decode as defined in this notebook returns a single
# value (the result of tokenizer.decode), so this two-name unpacking presumably
# relied on a temporary debugging version that returned two values — confirm,
# then update or remove this cell so the notebook survives Restart & Run All.
m2, v2 = greedy_decode(transformer,'[START] O C [END]')

In [102]:
# Display m1 and m2 together for comparison.
# NOTE(review): m1 is not assigned in any visible cell, so this depends on
# leftover kernel state — confirm where it came from or remove this cell.
m1, m2

(tensor([[[-0.1278,  0.7598,  1.1653,  ..., -0.0565,  0.5364,  0.7407]],
 
         [[-1.3394, -0.1681,  0.7162,  ...,  1.1068, -0.0721, -1.1558]],
 
         [[-0.3345, -0.2282,  0.5544,  ...,  1.7588, -0.4460, -0.1879]],
 
         [[-0.6921,  0.7790, -0.1657,  ...,  1.6773,  0.0471, -0.1246]]]),
 tensor([[[-0.1278,  0.7598,  1.1653,  ..., -0.0565,  0.5364,  0.7407]],
 
         [[-0.3345, -0.2282,  0.5544,  ...,  1.7588, -0.4460, -0.1879]],
 
         [[-1.3394, -0.1681,  0.7162,  ...,  1.1068, -0.0721, -1.1558]],
 
         [[-0.6921,  0.7790, -0.1657,  ...,  1.6773,  0.0471, -0.1246]]]))

In [27]:
# Beam-search decoding of one tokenized reaction string, with 5 passed as the
# beam-size argument.
# NOTE(review): beam_search is not defined or imported in any visible cell;
# define or import it before this cell and before compute_accuracy_beam is called.
beam_search(transformer,'[START] C C C C C 1 C O 1 . C c 1 c c c c ( C ) c 1 > Cl . [Cl-] [END]',5)

# Presumably the expected product for the input above — confirm.
# C C C C C ( C O ) c 1 c c c ( C ) c c 1 C

NameError: name 'beam_search' is not defined

In [9]:
# Greedy-decoding accuracy on the src-test/tgt-test files.
# NOTE(review): these are absolute, machine-specific paths; consider deriving
# them from a configurable data directory so the notebook runs elsewhere.
compute_accuracy_greedy("/baldig/chemistry/2023_rp/USPTO_MIT_TOKENIZED/MIT_separated_augm/src-test.txt","/baldig/chemistry/2023_rp/USPTO_MIT_TOKENIZED/MIT_separated_augm/tgt-test.txt")

[12:54:49] SMILES Parse Error: extra close parentheses while parsing: CO)=O.c1cc2[nH]ccc2cc1CCO
[12:54:49] SMILES Parse Error: Failed parsing SMILES 'CO)=O.c1cc2[nH]ccc2cc1CCO' for input: 'CO)=O.c1cc2[nH]ccc2cc1CCO'
[12:54:49] SMILES Parse Error: unclosed ring for input: 'O=c1cnc2ccc(F)cn2CC(CN3CCC4O)ccc3c(cn2)CN(C2)CO2'
[12:54:49] SMILES Parse Error: unclosed ring for input: 'C=C(C)Oc1ccc(N2C(=O)C(=C(C)C)sc2-c2ccccc2)C(=O)N2'
[12:54:50] Can't kekulize mol.  Unkekulized atoms: 3 4 5 6 12 13 14 15 25 26 28
[12:54:50] SMILES Parse Error: ring closure 1 duplicates bond between atom 1 and atom 2 for input: 'Fc1-c1c(CN2CCCCC2)oc(-c3ccc(O)c2)N1CCCCC2'
[12:54:50] SMILES Parse Error: extra open parentheses for input: 'Cc1ccc(NC2CS2)c1nc(NCc1ccc(C)o1'
[12:54:50] SMILES Parse Error: extra close parentheses while parsing: CCCCCC(C)C1CCC2CC=C2CC(OCCCCCCC(=N)C)C)CCC1
[12:54:50] SMILES Parse Error: Failed parsing SMILES 'CCCCCC(C)C1CCC2CC=C2CC(OCCCCCCC(=N)C)C)CCC1' for input: 'CCCCCC(C)C1CCC2CC=C2CC

0.0

In [8]:
# Beam-search accuracy on the src-test/tgt-test files, with 5 passed as beam_size.
# NOTE(review): these are absolute, machine-specific paths; consider deriving
# them from a configurable data directory so the notebook runs elsewhere.
compute_accuracy_beam("/baldig/chemistry/2023_rp/USPTO_MIT_TOKENIZED/MIT_separated_augm/src-test.txt","/baldig/chemistry/2023_rp/USPTO_MIT_TOKENIZED/MIT_separated_augm/tgt-test.txt",5)

[12:53:37] SMILES Parse Error: extra close parentheses while parsing: Nc1cc(F)ccc1C#NO)cc(
[12:53:37] SMILES Parse Error: Failed parsing SMILES 'Nc1cc(F)ccc1C#NO)cc(' for input: 'Nc1cc(F)ccc1C#NO)cc('
[12:53:37] SMILES Parse Error: extra close parentheses while parsing: Nc1cc(F)ccc1C#NF)cc(
[12:53:37] SMILES Parse Error: Failed parsing SMILES 'Nc1cc(F)ccc1C#NF)cc(' for input: 'Nc1cc(F)ccc1C#NF)cc('
[12:53:37] SMILES Parse Error: extra close parentheses while parsing: Nc1cc(F)ccc1C#NO)cc(
[12:53:37] SMILES Parse Error: Failed parsing SMILES 'Nc1cc(F)ccc1C#NO)cc(' for input: 'Nc1cc(F)ccc1C#NO)cc('
[12:53:38] SMILES Parse Error: extra close parentheses while parsing: CO)=C.c1cn(-c2ccc(C(F)(F)F)cc2
[12:53:38] SMILES Parse Error: Failed parsing SMILES 'CO)=C.c1cn(-c2ccc(C(F)(F)F)cc2' for input: 'CO)=C.c1cn(-c2ccc(C(F)(F)F)cc2'
[12:53:38] SMILES Parse Error: extra close parentheses while parsing: COC)c1cn(C(F)(F)Fccc2n(C)c1cc
[12:53:38] SMILES Parse Error: Failed parsing SMILES 'COC)c1cn(C(F

0.0