In [132]:
import yaml
import torch
from pathlib import Path
from src.data import load

with open("config.yaml", "r") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

    # Update config
    config['data_path'] = Path('data/' + config['data_set'])
    if not torch.cuda.is_available():
        config['device'] = 'cpu'

train_data, val_data, tokenizer = load(config['data_path'], config['batch_size'])

config['src_vocab_size'] = config['tgt_vocab_size'] = tokenizer.get_vocab_size()
print(config['src_vocab_size'])


298


In [133]:
from src.models.transformer import TransformerModel
transformer = TransformerModel(
    config['num_encoder_layers'],
    config['num_decoder_layers'],
    config['emb_size'],
    config['nhead'],
    config['src_vocab_size'],
    config['tgt_vocab_size'],
    config['ffn_hid_dim'],
)

In [194]:
transformer.load_state_dict(torch.load('src/models/transformer/pretrained/USPTO_MIT_OCT_18/Batch_140.pth', map_location = torch.device(config['device'])))
transformer.eval()

TransformerModel(
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-3): 4 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=512, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (decoder): TransformerDecoder(
      (layers): ModuleList(
        (0-3): 4 x TransformerDecoderLayer(
          (self_attn): MultiheadAttentio

In [195]:
import rdkit
from rdkit import Chem
from tqdm import tqdm
# from rpCHEM.Common.Util import smi_to_unique_smi_fast, smi_to_unique_smi_map, exact_mass

def canonicalize_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is not None:
        return Chem.MolToSmiles(mol, isomericSmiles=True)
    else:
        return ''
#     
# def canonicalize_smiles(smiles):
#     return smi_to_unique_smi_fast(smiles)


def compute_accuracy_greedy(source,target):

    with open(source) as source_file:
        source_lines = source_file.readlines()

    with open(target) as target_file:
        target_lines = target_file.readlines()

    matches = 0

    for source_line, target_line in zip(source_lines, target_lines[:100]):
        tgt = target_line.strip()
        tgt = tgt.replace(" ", "")
        tgt = canonicalize_smiles(tgt)

        source_line = '[START] ' + source_line.strip() + ' [END]'
        prediction = greedy_decode(transformer, source_line)
        prediction = prediction.replace("[END]", "")
        prediction = prediction.replace("[START]", "")
        prediction = prediction.strip()
        prediction = prediction.replace(" ", "")
        prediction = canonicalize_smiles(prediction)

        if target == prediction:
                matches = matches + 1
                #print("target is " + target)
                #print("prediction is " + prediction)

    return matches / float(len(target_lines))



def compute_accuracy_beam(source,target,beam_size):

    with open(source) as source_file:
        source_lines = source_file.readlines()

    with open(target) as target_file:
        target_lines = target_file.readlines()

    matches = 0

    for source_line, target_line in zip(source_lines, target_lines[:100]):
        tgt = target_line.strip()
        tgt = tgt.replace(" ", "")
        tgt = canonicalize_smiles(tgt)

        source_line = '[START] ' + source_line.strip() + ' [END]'
        predictions = beam_search(transformer, source_line,beam_size)
        correct = False

        for prediction in predictions:
            prediction = prediction.replace("[END]", "")
            prediction = prediction.replace("[START]", "")
            prediction = prediction.strip()
            prediction = prediction.replace(" ", "")
            prediction = canonicalize_smiles(prediction)

            if target == prediction:
                    correct = True
        if correct:
            matches = matches + 1

    return matches / float(len(target_lines))

In [201]:
def greedy_decode(model, input_sequence):
    model.eval()
    with torch.no_grad():
        # Encode the input sequence
        src = tokenizer.encode(input_sequence).ids
        src = torch.tensor(src).to(config['device'])
        src = src.unsqueeze(1).T
        
        print(src)
        memory = model.encode(src, src_mask=None)
        
        
        print(f'{memory.shape=}')
        print(memory)
        

        # Prepare the initial input for the decoder
        decoder_input = torch.tensor([0]).to(config['device'])
        decoder_input = decoder_input.unsqueeze(1)

        output_sequence = []
        MAX_LENGTH = 100
        for _ in range(MAX_LENGTH):
            # Generate the next token
            output = model.decode(decoder_input, memory, tgt_mask=None)
            output = model.generator(output)
            print(f'{output[:, -1, :]=}')
            next_token = output.argmax(2)[:, -1].item()
            output_sequence.append(next_token)
            
            if next_token == 1:
                break

            # Prepare the next input for the decoder
            decoder_input = torch.cat([decoder_input, torch.tensor([[next_token]]).to(config['device'])], 1)

    return tokenizer.decode(output_sequence)

In [202]:
canonicalize_smiles("CNCC(C1=CC(=C(C=C1)O)O)O")

'CNCC(O)c1ccc(O)c(O)c1'

In [203]:
greedy_decode(transformer,'C ( C ( C ) C ) [Mg+] . O c 1 n c c ( c c 1 ) C ( = O ) N ( C ) O C > C 1 C O C C 1 . [Cl-]')

# C C C C C ( C O ) c 1 c c c ( C ) c c 1 C

tensor([[ 23,   6,  23,   6,  23,   7,  23,   7, 147,   9,  28, 293,  10, 294,
         293, 293,   6, 293, 293,  10,   7,  23,   6,  19,  28,   7,  27,   6,
          23,   7,  28,  23,  20,  23,  10,  23,  28,  23,  23,  10,   9,  86]])
memory.shape=torch.Size([1, 42, 512])
tensor([[[-1.8711, -0.9703,  0.1766,  ...,  2.4295,  0.6192, -0.1062],
         [-1.9460, -1.4832,  0.0630,  ...,  2.2756,  0.6844,  0.0259],
         [-1.7488, -1.0970,  0.2565,  ...,  2.4035,  0.5092, -0.0955],
         ...,
         [-1.6326, -1.1198,  0.1559,  ...,  2.3430,  1.1389, -0.0879],
         [-1.6902, -0.7421,  0.5550,  ...,  2.3674,  0.5808,  0.4074],
         [-1.7978, -1.1126,  0.3300,  ...,  2.1973,  1.0896, -0.1157]]])
output[:, -1, :]=tensor([[-1.1612e+00,  6.5887e+00, -1.6040e+00, -2.1401e+00,  2.9669e+00,
          2.0579e+00,  4.1972e+00,  3.8274e+00,  1.4095e+00,  2.2577e+00,
          4.2941e+00,  3.1080e+00,  3.8029e+00,  1.2807e+00,  1.6363e+00,
          3.5328e-01, -3.4849e-01, -2.3038

'C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C'

In [208]:
greedy_decode(transformer,'[START] C C C C C C . Br Br [END]')

tensor([[ 0, 23, 23, 23, 23, 23, 23,  9, 22, 22,  1]])
memory.shape=torch.Size([1, 11, 512])
tensor([[[-1.5966, -0.3690, -0.1153,  ...,  1.6717,  1.4651, -0.1527],
         [-1.9144, -0.5769,  0.1755,  ...,  1.6654,  0.5826, -0.1472],
         [-1.8879, -0.6965,  0.1894,  ...,  1.6578,  0.5188, -0.1269],
         ...,
         [-1.6104, -0.2887,  0.0813,  ...,  1.5339,  0.9499, -0.1984],
         [-1.6761, -0.3766,  0.0956,  ...,  1.5275,  0.9534, -0.1569],
         [-2.1971, -0.7145,  0.5731,  ...,  1.7209,  0.9559, -0.2519]]])
output[:, -1, :]=tensor([[-1.1349e+00,  7.0345e+00, -1.8560e+00, -2.4043e+00,  3.0779e+00,
          1.8756e+00,  4.6675e+00,  4.4139e+00,  1.4033e+00,  2.0329e+00,
          4.6979e+00,  3.1770e+00,  3.9421e+00,  1.3324e+00,  1.6384e+00,
          6.7196e-02, -8.4919e-01, -2.1185e+00, -1.6655e+00,  5.8106e+00,
         -3.9200e-01, -7.5913e-01,  4.8296e+00,  7.2039e+00,  4.7624e+00,
          1.5139e+00,  1.3495e+00,  3.3154e+00,  6.0748e+00,  1.0323e+00,
    

'C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C C'

In [205]:
m1, m2

NameError: name 'm1' is not defined

In [185]:
beam_search(transformer,'[START] C C C C C 1 C O 1 . C c 1 c c c c ( C ) c 1 > Cl . [Cl-] [END]',5)

# C C C C C ( C O ) c 1 c c c ( C ) c c 1 C

NameError: name 'beam_search' is not defined

In [9]:
compute_accuracy_greedy("/baldig/chemistry/2023_rp/USPTO_MIT_TOKENIZED/MIT_separated_augm/src-test.txt","/baldig/chemistry/2023_rp/USPTO_MIT_TOKENIZED/MIT_separated_augm/tgt-test.txt")

[12:54:49] SMILES Parse Error: extra close parentheses while parsing: CO)=O.c1cc2[nH]ccc2cc1CCO
[12:54:49] SMILES Parse Error: Failed parsing SMILES 'CO)=O.c1cc2[nH]ccc2cc1CCO' for input: 'CO)=O.c1cc2[nH]ccc2cc1CCO'
[12:54:49] SMILES Parse Error: unclosed ring for input: 'O=c1cnc2ccc(F)cn2CC(CN3CCC4O)ccc3c(cn2)CN(C2)CO2'
[12:54:49] SMILES Parse Error: unclosed ring for input: 'C=C(C)Oc1ccc(N2C(=O)C(=C(C)C)sc2-c2ccccc2)C(=O)N2'
[12:54:50] Can't kekulize mol.  Unkekulized atoms: 3 4 5 6 12 13 14 15 25 26 28
[12:54:50] SMILES Parse Error: ring closure 1 duplicates bond between atom 1 and atom 2 for input: 'Fc1-c1c(CN2CCCCC2)oc(-c3ccc(O)c2)N1CCCCC2'
[12:54:50] SMILES Parse Error: extra open parentheses for input: 'Cc1ccc(NC2CS2)c1nc(NCc1ccc(C)o1'
[12:54:50] SMILES Parse Error: extra close parentheses while parsing: CCCCCC(C)C1CCC2CC=C2CC(OCCCCCCC(=N)C)C)CCC1
[12:54:50] SMILES Parse Error: Failed parsing SMILES 'CCCCCC(C)C1CCC2CC=C2CC(OCCCCCCC(=N)C)C)CCC1' for input: 'CCCCCC(C)C1CCC2CC=C2CC

0.0

In [8]:
compute_accuracy_beam("/baldig/chemistry/2023_rp/USPTO_MIT_TOKENIZED/MIT_separated_augm/src-test.txt","/baldig/chemistry/2023_rp/USPTO_MIT_TOKENIZED/MIT_separated_augm/tgt-test.txt",5)

[12:53:37] SMILES Parse Error: extra close parentheses while parsing: Nc1cc(F)ccc1C#NO)cc(
[12:53:37] SMILES Parse Error: Failed parsing SMILES 'Nc1cc(F)ccc1C#NO)cc(' for input: 'Nc1cc(F)ccc1C#NO)cc('
[12:53:37] SMILES Parse Error: extra close parentheses while parsing: Nc1cc(F)ccc1C#NF)cc(
[12:53:37] SMILES Parse Error: Failed parsing SMILES 'Nc1cc(F)ccc1C#NF)cc(' for input: 'Nc1cc(F)ccc1C#NF)cc('
[12:53:37] SMILES Parse Error: extra close parentheses while parsing: Nc1cc(F)ccc1C#NO)cc(
[12:53:37] SMILES Parse Error: Failed parsing SMILES 'Nc1cc(F)ccc1C#NO)cc(' for input: 'Nc1cc(F)ccc1C#NO)cc('
[12:53:38] SMILES Parse Error: extra close parentheses while parsing: CO)=C.c1cn(-c2ccc(C(F)(F)F)cc2
[12:53:38] SMILES Parse Error: Failed parsing SMILES 'CO)=C.c1cn(-c2ccc(C(F)(F)F)cc2' for input: 'CO)=C.c1cn(-c2ccc(C(F)(F)F)cc2'
[12:53:38] SMILES Parse Error: extra close parentheses while parsing: COC)c1cn(C(F)(F)Fccc2n(C)c1cc
[12:53:38] SMILES Parse Error: Failed parsing SMILES 'COC)c1cn(C(F

0.0