In [19]:
'''
ENV PARAMS
'''
train_new_model = False
continue_training = None

if not train_new_model:
    continue_training = False


#PATH LOCAL
base_path = None #set path to folder containing this notebook here


if not train_new_model:
    load_path = f'{base_path}best_model.pth'
else:
    load_path = None
path_to_data = f'{base_path}data/USPTO_50K/'

        

In [20]:
'''
IMPORTS
'''

import sys
from models import LocalRetro

import torch


import torch.nn as nn
from torch.optim import Adam, lr_scheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
import os
import numpy as np
from datetime import date
import pandas as pd
import dgl
from dgllife.utils import WeaveAtomFeaturizer, \
            CanonicalBondFeaturizer, smiles_to_bigraph, EarlyStopping
from dgl.data.utils import save_graphs, load_graphs, Subset

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# device = torch.device('cpu')
print(f'Training on: {device}')


Training on: cpu


In [21]:
'''
HYPERPARAMETER
'''
batch_size = 16
num_workers = 0
lr = 1e-4
#lr scheduler
lr_step_size = 10
#early stopping
patience = 5
weight_decay = 1e-6
num_epochs = 1 # defual 50
print_every = 50
max_grad_clip = 20

In [22]:
'''
PREPROCESSING
'''
from collections import defaultdict
import re
import rdkit
from rdkit import Chem, RDLogger 
from rdkit.Chem import rdChemReactions
RDLogger.DisableLog('rdApp.*')
from LocalTemplate.template_extractor import extract_from_reaction
from extract_utils import get_reaction_template, reduce_template

'''
PREPROCESSING OF TRAINING DATA

creates reaction templates based on training data

IN:  
    raw_train.csv            train dataset

NOTE: 
    reaction class is being used!
    seperate class numbers are used for atom and bond templates
        

OUT: 
    template_rxnclass.csv    templates per reaction USPTO-class
    smiles2smarts.csv        contains smarts reaction templates
    atom_templates.csv       contains templates for atom change
    bond_templates.csv       contains templates for bond change

'''
if not os.path.exists(f'{base_path}data/USPTO_50K/atom_templates.csv'):
    
    from extract_utils import remove_reagents, dearomatic, fix_arom, destereo, \
                    clean_smarts, demap

    #extract templates
    print('extracting templates')
    rxns = pd.read_csv(f'{path_to_data}raw_train.csv')['reactants>reagents>production']
    class_train = (f'{path_to_data}class_train.csv')
    if os.path.exists(class_train):
        RXNHASCLASS = True
        rxn_class = pd.read_csv(class_train)['class']
        template_rxnclass = {i+1:set() for i in range(10)}
    else:
        RXNHASCLASS = False

    smiles2smarts = {}
    smiles2edit = {}
    smiles2Hs = {}
    atom_templates = defaultdict(int)
    bond_templates = defaultdict(int)
    unique_templates = set()
    
    for i, rxn in enumerate(rxns):
        if RXNHASCLASS:
            template_class = rxn_class[i]
        try:
            rxn, result = get_reaction_template(rxn, i)
            if 'reaction_smarts' not in result.keys():
                continue
            local_template = result['reaction_smarts']
            smi_template, sma_template = reduce_template(local_template)
            if smi_template not in smiles2smarts.keys(): # first come first serve
                template = sma_template
                smiles2smarts[smi_template] = sma_template
                smiles2edit[smi_template] = result['edit_sites'][2] # keep the map of changing idx
                smiles2Hs[smi_template] = result['H_change']
            else:
                template = smiles2smarts[smi_template]
                
            edit_sites = result['edit_sites'][0]
            
            atom_edit = False
            bond_edit = False
            
            for e in edit_sites:
                if type(e) == type(1):
                    atom_edit = True
                else:
                    bond_edit = True
            
            if atom_edit:
                atom_templates[template] += 1
            if bond_edit:
                bond_templates[template] += 1
            
            unique_templates.add(template)
            if RXNHASCLASS:
                template_rxnclass[template_class].add(template)
                
        except KeyboardInterrupt:
            print('Interrupted')
            try:
                sys.exit(0)
            except SystemExit:
                os._exit(0)
        except ValueError as e:
            print (i, e)
        if i % 100 == 0:
            print ('\r i = %s, # of atom template: %s, # of bond template: %s' \
                   % (i, len(atom_templates), len(bond_templates)), end='', flush=True)
    print (f'\n total # of template: {len(unique_templates)}')
    derived_templates = {'atom':atom_templates, 'bond': bond_templates}
    
    if RXNHASCLASS:
        pd.DataFrame.from_dict(template_rxnclass, orient = 'index').T.to_csv(
            f'{path_to_data}template_rxnclass.csv', index = None)
        print(f'written {path_to_data}template_rxnclass.csv')
    smiles2smarts = pd.DataFrame({'Smiles_template': k, 'Smarts_template': t,
                                  'edit_site':smiles2edit[k], 'change_H': \
                                  smiles2Hs[k]} for k, t in smiles2smarts.items())
    smiles2smarts.to_csv(f'{path_to_data}smiles2smarts.csv')
    print(f'written {path_to_data}smiles2smarts.csv')
    #export template
    print(f'exporting templates')
    for k in derived_templates.keys():
        local_templates = derived_templates[k]
        templates = []
        template_class = []
        template_freq = []
        sorted_tuples = sorted(local_templates.items(), key=lambda item: item[1])
        c = 1
        for t in sorted_tuples:
            templates.append(t[0])
            template_freq.append(t[1])
            template_class.append(c)
            c += 1
        template_dict = {templates[i]:i+1  for i in range(len(templates)) }
        template_df = pd.DataFrame({'Template' : templates, 'Frequency' : template_freq,
                                    'Class': template_class})

        template_df.to_csv(f'{path_to_data}{k}_templates.csv')
        print(f'written {path_to_data}{k}_templates.csv')

'''
CREATING LABELS FOR WHOLE DATASET

IN:  
    smiles2smarts.csv        contains smarts reaction templates
    atom_templates.csv       contains templates for atom change
    bond_templates.csv       contains templates for bond change


OUT: 
    preprocessed_train.csv   contains reaction, products, atom label and bond label
    preprocessed_valid.csv   contains reaction, products, atom label and bond label
    preprocessed_test.csv   contains reaction, products, atom label and bond label
    labeled_data.csv         combines all data of the 3 splits
'''    
if not os.path.exists(f'{base_path}data/USPTO_50K/labeled_data.csv'):
    
    #IMPORTS
    from preprocessing_utils import matchwithtemp, match_num, get_idx_map, get_edit_site
    num_max_edits = 8
    threshold = 1
    
    # load_template_dict
    template_dicts = {}
    for site in ['atom', 'bond']:
        template_df = pd.read_csv(f'{path_to_data}{site}_templates.csv')
        template_dict = {template_df['Template'][i]:template_df['Class'][i] 
                         for i in template_df.index  if template_df['Frequency'][i] >= threshold}
        print (f'loaded {len(template_dict)} {site} templates')
        template_dicts[site] = template_dict

    # load_smi2sma_dict
    template_df = pd.read_csv(f'{path_to_data}smiles2smarts.csv')
    smiles2smarts = {template_df['Smiles_template'][i]:template_df['Smarts_template'][i] \
                     for i in template_df.index}
    smiles2edit = {template_df['Smiles_template'][i]:template_df['edit_site'][i] \
                   for i in template_df.index}

    
    
    pre_sets = []
    split_names = ['train', 'val', 'test']
    # labeling_dataset   (args, split,
    for split in split_names:
        atom_templates = template_dicts['atom']
        bond_templates = template_dicts['bond']
        rxns = pd.read_csv(f'{path_to_data}raw_{split}.csv')['reactants>reagents>production']
        products = []
        atom_labels = []
        bond_labels = []
        masks = []
        success = 0
        for n, rxn in enumerate(rxns):
            product = rxn.split('>>')[1]
            try:
                rxn, result = get_reaction_template(rxn, n)
                local_template = result['reaction_smarts']
                smi_template, sma_template = reduce_template(local_template)
                if smi_template not in smiles2smarts.keys():
                    products.append(product)
                    atom_labels.append(0)
                    bond_labels.append(0)
                    masks.append(0)
                    continue
                else:
                    replace_dict = matchwithtemp(sma_template, smiles2smarts[smi_template], 
                                                 result['replacement_dict']) 
                    replace_dict = get_idx_map(product, replace_dict)
                    edit_sites = eval(match_num(smiles2edit[smi_template], replace_dict))
                    local_template = smiles2smarts[smi_template]

            except KeyboardInterrupt:
                print('Interrupted')
                try:
                    sys.exit(0)
                except SystemExit:
                    os._exit(0)
            except Exception as e:
                products.append(product)
                atom_labels.append(0)
                bond_labels.append(0)
                masks.append(0)
                continue

            if len(edit_sites) <= num_max_edits:
                atom_sites, bond_sites = get_edit_site(product)
                try:
                    if local_template not in atom_templates.keys() and \
                            local_template not in bond_templates.keys():
                        products.append(product)
                        atom_labels.append(0)
                        bond_labels.append(0)
                        masks.append(0)
                    else:
                        atom_label = [0] * len(atom_sites)
                        bond_label = [0] * len(bond_sites)
                        for edit_site in edit_sites:
                            if type(edit_site) == type(1):
                                atom_label[atom_sites.index(edit_site)] = atom_templates[local_template]
                            else:
                                bond_label[bond_sites.index(edit_site)] = bond_templates[local_template]
                        products.append(product)
                        atom_labels.append(atom_label)
                        bond_labels.append(bond_label)
                        masks.append(1)
                        success += 1
                except Exception as e:
                    products.append(product)
                    atom_labels.append(0)
                    bond_labels.append(0)
                    masks.append(0)
                    continue

                if n % 100 == 0:
                    print ('\r Processing USPTO_50K %s data..., success %s data (%s/%s)' \
                           % (split, success, n, len(rxns)), end='', flush=True)
            else:
                print ('\nReaction # %s has too many (%s) edits... may be wrong mapping!'
                       % (n, len(edit_sites)))
                products.append(product)
                atom_labels.append(0)
                bond_labels.append(0)
    # begin edit
                masks.append(0)
    # end edit
            
        print ('\nDerived tempalates cover %.3f of %s data reactions' % ((success/len(rxns)), split))
    # begin cheap trick
        if len(masks) != len(rxns):
            diff = len(rxns) - len(masks)
            count_diff = 0
            for _ in range(diff):
                count_diff += 1
                masks.append(0)
            print(f'cheap trick applied {count_diff} times')
    # end cheap trick
        df = pd.DataFrame({'Reaction': rxns, 'Products': products, 'Atom_label': atom_labels, \
                           'Bond_label': bond_labels, 'Mask': masks})
        df.to_csv(f'{base_path}data/USPTO_50K/preprocessed_{split}.csv')
        print(f'written {base_path}data/USPTO_50K/preprocessed_{split}.csv')
        pre_sets.append(df)


    # combine_preprocessed_data
    train_valid = pre_sets[0][pre_sets[0]['Mask'] != 0].reset_index()
    val_valid = pre_sets[1][pre_sets[1]['Mask'] != 0].reset_index()
    test_valid = pre_sets[2][pre_sets[2]['Mask'] != 0].reset_index()
    
    train_valid['Split'] = ['train'] * len(train_valid)
    val_valid['Split'] = ['val'] * len(val_valid)
    test_valid['Split'] = ['test'] * len(test_valid)

    all_valid = train_valid.append(val_valid, ignore_index=True)
    all_valid = all_valid.append(test_valid, ignore_index=True)
    print (f'Valid data size: {len(all_valid)}')
    all_valid.to_csv(f'{path_to_data}labeled_data.csv', index = None)
    print(f'written {path_to_data}labeled_data.csv')

In [23]:
'''
DATASET

loading graphs or create them if they don't exist

IN:  
    (labeled_data.csv)         combines all data of the 3 splits

OUT: 
    USPTO_50K_dglgraph.bin     contains graphs of labeled_data.csv
    train_loader
    val_loader
    test_loader
'''    
def flatten_list(t):
    return torch.LongTensor([item for sublist in t for item in sublist])

def collate_molgraphs(data):
    smiles, graphs, atom_labels, bond_labels = map(list, zip(*data))
    atom_labels = flatten_list(atom_labels)
    bond_labels = flatten_list(bond_labels)
    bg = dgl.batch(graphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)
    return smiles, bg, atom_labels, bond_labels

class USPTODataset(object):
    def __init__(self, base_path, load=True, log_every=1000):
        df = pd.read_csv(f'{path_to_data}labeled_data.csv')
        self.train_ids = df.index[df['Split'] == 'train'].values
        self.val_ids = df.index[df['Split'] == 'val'].values
        self.test_ids = df.index[df['Split'] == 'test'].values
        self.smiles = df['Products'].tolist()
        self.atom_labels = [eval(t) for t in df['Atom_label']]
        self.bond_labels = [eval(t) for t in df['Bond_label']]
        self.cache_file_path = f'{base_path}/data/saved_graphs/USPTO_50K_dglgraph.bin'
        self._pre_process(load, log_every)

    def _pre_process(self, load, log_every):
        if os.path.exists(self.cache_file_path) and load:
            print('Loading previously saved dgl graphs...')
            self.graphs, label_dict = load_graphs(self.cache_file_path)
        else:
            print('Processing dgl graphs from scratch...')
            self.graphs = []
            for i, s in enumerate(self.smiles):
                if (i + 1) % log_every == 0:
                    print('\rProcessing molecule %d/%d' % (i+1, 
                                                    len(self.smiles)), end='', flush=True)
                self.graphs.append(smiles_to_bigraph(s, add_self_loop=True, 
                            node_featurizer=WeaveAtomFeaturizer(), 
                            edge_featurizer=CanonicalBondFeaturizer(self_loop=True), 
                                                     canonical_atom_order=False))
            print ()
            save_graphs(self.cache_file_path, self.graphs)

    def __getitem__(self, item):
        return self.smiles[item], self.graphs[item], self.atom_labels[item], self.bond_labels[item]

    def __len__(self):
            return len(self.smiles)
        
        
dataset = USPTODataset(base_path)
print(f'Dataset loaded with len {len(dataset)}, creating subsets...')
train_set = Subset(dataset, dataset.train_ids)
val_set   = Subset(dataset, dataset.val_ids)
test_set  = Subset(dataset, dataset.test_ids)
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True,
                          collate_fn=collate_molgraphs, num_workers=num_workers)
val_loader = DataLoader(dataset=val_set, batch_size=batch_size, shuffle=True,
                          collate_fn=collate_molgraphs, num_workers=num_workers)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle=True,
                          collate_fn=collate_molgraphs, num_workers=num_workers)
print(f'''created loaders: train {len(train_loader)*batch_size} \
val {len(val_loader)*batch_size} test {len(test_loader)*batch_size}''')

Loading previously saved dgl graphs...
Dataset loaded with len 49944, creating subsets...
created loaders: train 40000 val 4976 test 4992


In [24]:
'''
CREATE/LOAD MODEL

IN:
    (load_path)

OUT: 
    model
'''    
def get_configure():
    config = {
              "attention_heads": 8,
              "attention_layers": 1,
              "edge_hidden_feats": 64,
              "node_out_feats": 320,
              "num_step_message_passing": 6
            }
    config['AtomTemplate_n'] = len(pd.read_csv(f'{path_to_data}atom_templates.csv'))
    config['BondTemplate_n'] = len(pd.read_csv(f'{path_to_data}bond_templates.csv'))
    config['in_node_feats'] = WeaveAtomFeaturizer().feat_size()
    config['in_edge_feats'] = CanonicalBondFeaturizer(self_loop=True).feat_size()
    config['GRA'] = True
    return config

def get_model():
    exp_config = get_configure()
    model = LocalRetro(
        node_in_feats=exp_config['in_node_feats'],
        edge_in_feats=exp_config['in_edge_feats'],
        node_out_feats=exp_config['node_out_feats'],
        edge_hidden_feats=exp_config['edge_hidden_feats'],
        num_step_message_passing=exp_config['num_step_message_passing'],
        attention_heads = exp_config['attention_heads'],
        attention_layers = exp_config['attention_layers'],
        AtomTemplate_n = exp_config['AtomTemplate_n'],
        BondTemplate_n = exp_config['BondTemplate_n'],
        GRA = exp_config['GRA'])
    model = model.to(device)
    return model

def load_model(load_path, base_path):

    return model, optimizer, epoch, loss

if train_new_model == False:
    assert os.path.exists(load_path)
    model = get_model()
    if torch.cuda.is_available():
        checkpoint = torch.load(load_path)
    else:
        checkpoint = torch.load(load_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])

    if continue_training:
        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epoch = checkpoint['epoch']
        loss = checkpoint['loss']     
        print(f'loaded model data, ready for training')
        
    else:
        print(f'loaded model data, ready for inference')


elif os.path.exists(f'{base_path}models/model_{date.today()}.pth'):
    print(f'this path already exists: {base_path}models/model_{date.today()}.pth')
    
else:
    print(f'no model found, created this new one:')
    model = get_model()
    print(model)
    loss_criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

loaded model data, ready for inference


In [25]:
'''
TRAIN MODEL

IN:
    

OUT: 
    model from early stopping
    additionally: model_state_dict to continue training
'''    

if continue_training or train_new_model:
    scheduler = lr_scheduler.StepLR(optimizer, step_size=lr_step_size)
    stopper = EarlyStopping(mode='lower', patience=patience)
    for epoch in range(num_epochs):
        ### TRAIN MODEL
        model.train()
        train_loss = 0
        train_acc = 0
        for batch_id, batch_data in enumerate(tqdm(train_loader)):
            smiles, bg, atom_labels, bond_labels = batch_data
            if len(smiles) == 1:
                continue
            atom_labels, bond_labels = atom_labels.to(device), bond_labels.to(device)

            #predict
            bg = bg.to(device)
            node_feats = bg.ndata.pop('h').to(device)
            edge_feats = bg.edata.pop('e').to(device)
            atom_logits, bond_logits, _ = model(bg, node_feats, edge_feats)

            loss_a = loss_criterion(atom_logits, atom_labels).mean()
            loss_b = loss_criterion(bond_logits, bond_labels).mean()
            total_loss = loss_a + loss_b

            train_loss += total_loss.item()

            optimizer.zero_grad()      
            total_loss.backward() 
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_clip)
            optimizer.step()

        print('\nepoch %d/%d, training loss: %.4f' % (epoch + 1, num_epochs, train_loss/batch_id))

        ### VALIDATE MODEL    
        model.eval()
        val_loss = 0
        val_acc = 0
        with torch.no_grad():
            for batch_id, batch_data in enumerate(tqdm(val_loader)):
                smiles, bg, atom_labels, bond_labels = batch_data
                atom_labels, bond_labels = atom_labels.to(device), bond_labels.to(device)

                bg = bg.to(device)
                node_feats = bg.ndata.pop('h').to(device)
                edge_feats = bg.edata.pop('e').to(device)
                atom_logits, bond_logits, _ = model(bg, node_feats, edge_feats)

                loss_a = loss_criterion(atom_logits, atom_labels).mean()
                loss_b = loss_criterion(bond_logits, bond_labels).mean()
                total_loss = loss_a + loss_b
                val_loss += total_loss.item()

        val_loss = val_loss/batch_id            
        early_stop = stopper.step(val_loss, model) #maybe detach loss here?
        scheduler.step()
        print('epoch %d/%d, validation loss: %.4f' %  (epoch + 1, num_epochs, val_loss))
        print('epoch %d/%d, Best val loss: %.4f' % (epoch + 1, num_epochs, stopper.best_score))
        if early_stop:
            print ('Early stopped!!')
            break

    stopper.load_checkpoint(model)

    ### TEST MODEL
    model.eval()
    test_loss = 0
    test_acc = 0
    with torch.no_grad():
        for batch_id, batch_data in enumerate(tqdm(test_loader)):
            smiles, bg, atom_labels, bond_labels = batch_data
            atom_labels, bond_labels = atom_labels.to(device), bond_labels.to(device)

            bg = bg.to(device)
            node_feats = bg.ndata.pop('h').to(device)
            edge_feats = bg.edata.pop('e').to(device)
            atom_logits, bond_logits, _ = model(bg, node_feats, edge_feats)

            loss_a = loss_criterion(atom_logits, atom_labels).mean()
            loss_b = loss_criterion(bond_logits, bond_labels).mean()
            total_loss = loss_a + loss_b
            test_loss += total_loss.item()

    test_loss = test_loss/batch_id 
    print('test loss: %.4f' % test_loss)

    torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss' : train_loss
                }, f'{base_path}models/model_train_{date.today()}.pth')
    print(f'best model additionally saved at: {base_path}models/model_train_{date.today()}.pth')

test_loader = None
test_set = None
print(f'prepared for inference: cleared test dataset variables')

prepared for inference: cleared test dataset variables


In [42]:
'''
PREPROCESSING TEST DATA

create graphs from raw test data

IN:
    raw_test.csv

OUT: 
    USPTO_50K_test_dglgraph.bin
'''    
canonicalize_data = True
use_reduced_testset = False
overwrite_inf = True


if use_reduced_testset:
    result_path = f'{base_path}outputs/reduced_testset/'
else:
    result_path = f'{base_path}outputs/'



def collate_molgraphs_test(data):
    smiles, graphs, rxns = map(list, zip(*data))
    bg = dgl.batch(graphs)
    bg.set_n_initializer(dgl.init.zero_initializer)
    bg.set_e_initializer(dgl.init.zero_initializer)
    return smiles, bg, rxns

def canonicalize_smi(smi, is_smarts=False, remove_atom_mapping=True):
    """
    Canonicalize SMARTS from https://github.com/rxn4chemistry/rxnfp/blob/master/rxnfp/tokenization.py#L249
    """
    mol = Chem.MolFromSmiles(smi)
    if not mol:
        raise ValueError("Molecule not canonicalizable")
    if remove_atom_mapping:
        for atom in mol.GetAtoms():
            if atom.HasProp("molAtomMapNumber"):
                atom.ClearProp("molAtomMapNumber")
    return Chem.MolToSmiles(mol)


class USPTOTestDataset(object):
    def __init__(self, canonicalize_data, use_reduced_testset, load=True, log_every=1000):
        self.canonicalize = canonicalize_data
        if use_reduced_testset:
            df = pd.read_csv(f'{path_to_data}raw_test_reduced.csv')
        else:
            df = pd.read_csv(f'{path_to_data}raw_test.csv')
        self.rxns = df['reactants>reagents>production'].tolist()
        self.smiles = [rxn.split('>>')[-1] for rxn in self.rxns]
        if self.canonicalize and not use_reduced_testset:
            for _ in range(2):
                self.smiles = [canonicalize_smi(smi) for smi in self.smiles]
            self.cache_file_path = f'{base_path}data/saved_graphs/USPTO_50K_test_dglgraph_can.bin'
        elif not self.canonicalize and not use_reduced_testset:
            self.cache_file_path = f'{base_path}data/saved_graphs/USPTO_50K_test_dglgraph.bin'
        elif self.canonicalize and use_reduced_testset:
            for _ in range(2):
                self.smiles = [canonicalize_smi(smi) for smi in self.smiles]
            self.cache_file_path = f'{base_path}data/saved_graphs/USPTO_50K_test_reduced_dglgraph_can.bin'
        elif not self.canonicalize and use_reduced_testset:
            self.cache_file_path = f'{base_path}data/saved_graphs/USPTO_50K_test_reduced_dglgraph.bin' 
        self._pre_process(load, log_every)
        

    def _pre_process(self,load, log_every):
        if os.path.exists(self.cache_file_path) and load:
            print('Loading previously saved test dgl graphs...')
            self.graphs, label_dict = load_graphs(self.cache_file_path)
        else:
            print('Processing test dgl graphs from scratch...')
            self.graphs = []
            for i, s in enumerate(self.smiles):
                if (i + 1) % log_every == 0:
                    print('Processing molecule %d/%d' % (i+1, len(self.smiles)))
                self.graphs.append(smiles_to_bigraph(s, add_self_loop=True, 
                                                   node_featurizer=WeaveAtomFeaturizer(),
                                                   edge_featurizer=CanonicalBondFeaturizer(self_loop=True), 
                                                   canonical_atom_order=False))
            save_graphs(self.cache_file_path, self.graphs)
            print(f'saved graphs to: {self.cache_file_path}')

    def __getitem__(self, item):
            return self.smiles[item], self.graphs[item], self.rxns[item]

    def __len__(self):
            return len(self.smiles)



test_set = USPTOTestDataset(canonicalize_data, use_reduced_testset)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size,
                         collate_fn=collate_molgraphs_test, num_workers=num_workers)

Loading previously saved test dgl graphs...


In [43]:
'''
INFERENCE


IN:
    test_loader

OUT: 
    raw_predictions.txt

''' 
#params for inference
top_num = 100 #default 100


configs = get_configure()

if use_reduced_testset and canonicalize_data:
    raw_path = f'{result_path}reduced_testset_raw_predictions_can.txt'
elif use_reduced_testset and not canonicalize_data:
    raw_path = f'{result_path}reduced_testset_raw_predictions.txt'
elif not use_reduced_testset and canonicalize_data:
    raw_path = f'{result_path}raw_predictions_can.txt'
elif not use_reduced_testset and not canonicalize_data:
    raw_path = f'{result_path}raw_predictions.txt'

def get_id_template(a, CLASS_NUM):
    CLASS_NUM = CLASS_NUM + 1 # no template
    edit_idx = a//CLASS_NUM
    template = a%CLASS_NUM
    return (edit_idx, template)

def output2edit(out, CLASS_NUM, top_num):
    readout = out.cpu().detach().numpy()
    readout = readout.reshape(-1)
    output_rank = np.flip(np.argsort(readout))
    output_rank = [r for r in output_rank if get_id_template(r, CLASS_NUM)[1] != 0][:top_num]
    
    selected_edit = [get_id_template(a, CLASS_NUM) for a in output_rank]
    selected_proba = [readout[a] for a in output_rank]
     
    return selected_edit, selected_proba
    
def combined_edit(graph, atom_out, bond_out, ATOM_CLASS, BOND_CLASS, top_num):
    edit_id_a, edit_proba_a = output2edit(atom_out, ATOM_CLASS, top_num)
    edit_id_b, edit_proba_b = output2edit(bond_out, BOND_CLASS, top_num)
    atom_pair_list = torch.transpose(graph.adjacency_matrix().coalesce().indices(), 0, 1).numpy()
    edit_id_b = [(list(atom_pair_list[edit_id[0]]), edit_id[1])  for edit_id in edit_id_b]
    edit_id_c = edit_id_a + edit_id_b
    edit_proba_c = edit_proba_a + edit_proba_b
    edit_rank_c = np.flip(np.argsort(edit_proba_c))[:top_num]
    edit_id_c = [edit_id_c[r] for r in edit_rank_c]
    edit_proba_c = [edit_proba_c[r] for r in edit_rank_c]
    
    return edit_id_c, edit_proba_c

if overwrite_inf:
    model.eval()
    with open(raw_path, 'w') as f:
        f.write('Test_id\tReaction\t%s\n' % '\t'.join(['Edit %s\tProba %s' % (i+1, i+1) \
                                                       for i in range(top_num)]))
        with torch.no_grad():
            for batch_id, data in enumerate(test_loader):
                _, bg, rxns = data

                #predict
                bg = bg.to(device)
                node_feats = bg.ndata.pop('h').to(device)
                edge_feats = bg.edata.pop('e').to(device)
                batch_atom_logits, batch_bond_logits, _ = model(bg, node_feats, edge_feats)

                batch_atom_logits = nn.Softmax(dim=1)(batch_atom_logits)
                batch_bond_logits = nn.Softmax(dim=1)(batch_bond_logits) 

                sg = bg.remove_self_loop()
                graphs = dgl.unbatch(sg, sg.batch_num_nodes(), (sg.batch_num_edges() - sg.batch_num_nodes()))
                nodes_sep = [0]
                edges_sep = [0]
                for g in graphs:
                    nodes_sep.append(nodes_sep[-1] + g.num_nodes())
                    edges_sep.append(edges_sep[-1] + g.num_edges())
                nodes_sep = nodes_sep[1:]
                edges_sep = edges_sep[1:]


                start_node = 0
                start_edge = 0
                print('\rWriting test molecule batch %s/%s' % (batch_id+1, len(test_loader)), end='', flush=True)
                for single_id, (graph, end_node, end_edge) in enumerate(zip(graphs, nodes_sep, edges_sep)):
                    rxn = rxns[single_id]
                    test_id = (batch_id * batch_size) + single_id
                    edit_id, edit_proba = combined_edit(graph, batch_atom_logits[start_node:end_node], 
                                    batch_bond_logits[start_edge:end_edge], configs['AtomTemplate_n'], 
                                                        configs['BondTemplate_n'], top_num)
                    start_node = end_node
                    start_edge = end_edge
                    f.write('%s\t%s\t%s\n' % (test_id, rxn, '\t'.join(['%s\t%.3f' % \
                                            (edit_id[i], edit_proba[i]) for i in range(top_num)])))

    print (f'\n written results to: {raw_path}')

Writing test molecule batch 313/313
 written results to: /home/dominik/AI_Master/Project/my_solution/outputs/raw_predictions_can.txt


In [44]:
'''
DECODING PREDICTIONS

applying templates to raw predictions

IN:
    raw_predictions.txt
    class_test.csv           raw_test.csv with class labels
    smiles2smarts.csv        contains smarts reaction templates
    atom_templates.csv       contains templates for atom change
    bond_templates.csv       contains templates for bond change

OUT: 
    decoded_predictions.txt  containing predicted reactants
    decoded_predictions_class.txt
                             decoded_predictions filtered by class value from USPTO database

''' 
top_k = 100 #default 50

from LocalTemplate.template_decoder import *



atom_templates = pd.read_csv(f'{path_to_data}atom_templates.csv')
bond_templates = pd.read_csv(f'{path_to_data}bond_templates.csv')
smiles2smarts = pd.read_csv(f'{path_to_data}smiles2smarts.csv')
class_test = f'{path_to_data}class_test.csv'
if os.path.exists(class_test):
    rxn_class_given = True
    templates_class = pd.read_csv(f'{path_to_data}template_rxnclass.csv')
    test_rxn_class = pd.read_csv(f'{path_to_data}class_test.csv')['class']
else:
    rxn_class_given = False 

    
if overwrite_inf:
    counter = 0
    atom_templates = {atom_templates['Class'][i]: atom_templates['Template'][i] 
                                                for i in atom_templates.index}
    bond_templates = {bond_templates['Class'][i]: bond_templates['Template'][i] 
                                                for i in bond_templates.index}
    smarts2E = {smiles2smarts['Smarts_template'][i]: eval(smiles2smarts['edit_site'][i]) 
                                                for i in smiles2smarts.index}
    smarts2H = {smiles2smarts['Smarts_template'][i]: eval(smiles2smarts['change_H'][i])
                                                for i in smiles2smarts.index}

    prediction = pd.read_csv(raw_path, sep = '\t')

    if use_reduced_testset and canonicalize_data:
        output_path = f'{result_path}reduced_testset_decoded_prediction_can.txt'
        output_path_class = f'{result_path}reduced_testset_decoded_prediction_class_can.txt'
    elif use_reduced_testset and not canonicalize_data:
        output_path = f'{result_path}reduced_testset_decoded_prediction.txt'
        output_path_class = f'{result_path}reduced_testset_decoded_prediction_class.txt'
    elif not use_reduced_testset and canonicalize_data:
        output_path = f'{result_path}decoded_prediction_can.txt'
        output_path_class = f'{result_path}decoded_prediction_class_can.txt'
    elif not use_reduced_testset and not canonicalize_data:
        output_path = f'{result_path}decoded_prediction.txt'
        output_path_class = f'{result_path}decoded_prediction_class.txt'    


    with open(output_path, 'w') as f1, open(output_path_class, 'w') as f2:
        for i in prediction.index:
            all_prediction = []
            class_prediction = []
            rxn = prediction['Reaction'][i]
            products = rxn.split('>>')[1]
            idx_map = get_idx_map(products)
            for K_prediciton in prediction.columns:
                if 'Edit' not in K_prediciton:
                    continue
                edition = eval(prediction[K_prediciton][i])
                edit_idx = edition[0]
                template_class = edition[1]
                if type(edit_idx) == type(0):
                    template = atom_templates[template_class]
                    if len(template.split('>>')[0].split('.')) > 1:
                        edit_idx = idx_map[edit_idx]
                else:
                    template = bond_templates[template_class]
                    edit_idx = tuple(edit_idx)
                    if len(template.split('>>')[0].split('.')) > 1:
                        edit_idx = (idx_map[edit_idx[0]], idx_map[edit_idx[1]])

                template_idx = smarts2E[template]
                H_change = smarts2H[template]
                try:
                    pred_reactants, _, _ = apply_template(products, template, edit_idx,
                                                          template_idx, H_change)
                except Exception as e:
                    # print (e)
                    counter += 1
                    pred_reactants = []

                if len(pred_reactants) == 0:
                    try:
                        template = dearomatic(template)
                        pred_reactants, _, _ = apply_template(products, template, edit_idx,
                                                              template_idx, H_change)
                    except:
                        pred_reactants = []

                all_prediction += [p for p in pred_reactants if p not in all_prediction]

                if rxn_class_given:
                    rxn_class = test_rxn_class[i]
                    if template in templates_class[str(rxn_class)].values:
                        class_prediction += [p for p in pred_reactants if p 
                                             not in class_prediction]
                    if len (class_prediction) >= top_k:
                        break

                elif len (all_prediction) >= top_k:
                    break

            f1.write('\t'.join(all_prediction) + '\n')
            f2.write('\t'.join(class_prediction) + '\n')
            print('\rDecoding LocalRetro predictions %d/%d' % \
                  (i+1, len(prediction)), end='', flush=True)
    print(f'\n num of excteptions when applying templates for decoding: {counter}')
    with open('log.txt', 'a') as f:
        f.write(f'\nnum of excteptions when applying templates for decoding: {counter}')

Decoding LocalRetro predictions 5007/5007
 num of excteptions when applying templates for decoding: 39


In [45]:
'''
TOP K

calculate top k accuracy


IN:
    decoded_predictions.txt  containing predicted reactants
    decoded_predictions_class.txt
                             decoded_predictions filtered by class value from USPTO database

OUT: 

''' 


from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers


def demap(smi):
    mol = Chem.MolFromSmiles(smi)
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(0)
    smi = Chem.MolToSmiles(mol)
    return Chem.MolToSmiles(Chem.MolFromSmiles(smi))
    
def get_isomers(smi):
    mol = Chem.MolFromSmiles(smi)
    isomers = tuple(EnumerateStereoisomers(mol))
    isomers_smi = [Chem.MolToSmiles(x, isomericSmiles=True) for x in isomers]
    return isomers_smi
    
def get_MaxFrag(smiles):
    return max(smiles.split('.'), key=len)

def isomer_match(preds, reac, MaxFrag = False):
    try:
        if MaxFrag:
            reac = get_MaxFrag(reac)
        reac_isomers = get_isomers(reac)
        for k, pred in enumerate(preds):
            if MaxFrag:
                pred = get_MaxFrag(pred)
            pred_isomers = get_isomers(pred)
            if(set(pred_isomers).issubset(set(reac_isomers))):
                return k+1
        return -1
    except:
        return -1
if use_reduced_testset:
    test_file = pd.read_csv(f'{path_to_data}raw_test_reduced.csv')    
else:
    test_file = pd.read_csv(f'{path_to_data}raw_test.csv')

rxn_ps = [rxn.split('>>')[1] for rxn in 
          test_file['reactants>reagents>production']]
ground_truth = [demap(rxn.split('>>')[0]) for rxn in 
                test_file['reactants>reagents>production']]


class_given = False

if class_given:
    if use_reduced_testset and canonicalize_data:
        result_file = f'{result_path}reduced_testset_decoded_prediction_class_can.txt'
    elif use_reduced_testset and not canonicalize_data:
        result_file = f'{result_path}reduced_testset_decoded_prediction_class.txt'
    elif not use_reduced_testset and canonicalize_data:
        result_file = f'{result_path}decoded_prediction_class_can.txt'
    elif not use_reduced_testset and not canonicalize_data:
        result_file = f'{result_path}decoded_prediction_class.txt'  
else:
    if use_reduced_testset and canonicalize_data:
        result_file = f'{result_path}reduced_testset_decoded_prediction_can.txt'
    elif use_reduced_testset and not canonicalize_data:
        result_file = f'{result_path}reduced_testset_decoded_prediction.txt'
    elif not use_reduced_testset and canonicalize_data:
        result_file = f'{result_path}decoded_prediction_can.txt'
    elif not use_reduced_testset and not canonicalize_data:
        result_file = f'{result_path}decoded_prediction.txt'

results = {}
with open(result_file, 'r') as f:
    for i, line in enumerate(f.readlines()):
        results[i] = line.split('\n')[0].split('\t')

        
        
Exact_matches = []
MaxFrag_matches = [] # Description in Supporting Information

Exact_matches_multi = []
MaxFrag_matches_mumlti = [] 
for i in range(len(results)):
    match_exact = isomer_match(results[i], ground_truth[i], False)
    match_maxfrag = isomer_match(results[i], ground_truth[i], True)
    if len(rxn_ps[i].split('.')) > 1:
        Exact_matches_multi.append(match_exact)
        MaxFrag_matches_mumlti.append(match_maxfrag)
    Exact_matches.append(match_exact)
    MaxFrag_matches.append(match_maxfrag)
    if i % 100 == 0:
        print ('\rCalculating accuracy... %s/%s' % (i, len(results)), end='', flush=True)
        
ks = [1, 3, 5, 10, 50]
exact_k = {k:0 for k in ks}
MaxFrag_k = {k:0 for k in ks}

print(len(Exact_matches))
for i in range(len(Exact_matches)):
    for k in ks:
        if Exact_matches[i] <= k and Exact_matches[i] != -1:
            exact_k[k] += 1
        if MaxFrag_matches[i] <= k and MaxFrag_matches[i] != -1:
            MaxFrag_k[k] += 1

with open('log.txt', 'a') as f:
    f.write(f'\n\n{date.today()} using canonicalized test data: {canonicalize_data}; len testset: {len(results)}')
    for k in ks:
        print ('Top-%d Exact accuracy: %.3f, MaxFrag accuracy: %.3f' % \
               (k, exact_k[k]/len(Exact_matches), MaxFrag_k[k]/len(MaxFrag_matches)))
        f.write(f'\n\tTop-{k} Exact accuracy: {exact_k[k]/len(Exact_matches):.3f}, MaxFrag accuracy: {MaxFrag_k[k]/len(MaxFrag_matches):.3f}')

Calculating accuracy... 5000/50075007
Top-1 Exact accuracy: 0.064, MaxFrag accuracy: 0.067
Top-3 Exact accuracy: 0.088, MaxFrag accuracy: 0.093
Top-5 Exact accuracy: 0.097, MaxFrag accuracy: 0.103
Top-10 Exact accuracy: 0.107, MaxFrag accuracy: 0.115
Top-50 Exact accuracy: 0.118, MaxFrag accuracy: 0.131
