# Ligand embedding generation

In this notebook, we use Transformer-M to generate embeddings for small-molecule ligands. First, we obtain canonical SMILES (2D) and 3D structure data for those ligands from DrugBank. We then wrap them in a dataset format compatible with Transformer-M and run inference with the pretrained model.

In [None]:
!uv pip install rdkit omegaconf hydra-core bitarray cython python-algos ogb

In [16]:
import pandas as pd
import os
from torch_geometric.data import Data
from rdkit import Chem

from procyon.data.data_utils import DATA_DIR

## SMILES and 3D structure data

To generate Transformer-M embeddings for your own data, format it into a SMILES store and a 3D structure store like the ones built below. Inference can also be performed with SMILES alone.
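
As a rough guide to the expected layout, the sketch below mimics the two stores with plain Python containers. The molecule IDs, the placeholder structure string, and the helper `check_stores` are hypothetical. In this notebook the SMILES store is a pandas DataFrame and the 3D store maps IDs to RDKit `Mol` objects. The requirement that both stores share the same IDs is inferred from how the dataset class in this notebook consumes them.

```python
# Stand-in for the SMILES store: one record per molecule.
smiles_store = [
    {"drugbank_id": "MOL_A", "canonical_smiles": "CCO"},
    {"drugbank_id": "MOL_B", "canonical_smiles": "c1ccccc1"},
]

# Stand-in for the 3D store: one entry per molecule ID.
# None marks a molecule for which no 3D structure is available.
structure_store = {"MOL_A": "<3D structure for MOL_A>", "MOL_B": None}


def check_stores(smiles_store, structure_store):
    """Return the IDs that appear in one store but not in the other."""
    smiles_ids = {row["drugbank_id"] for row in smiles_store}
    structure_ids = set(structure_store)
    return sorted(smiles_ids ^ structure_ids)


mismatched = check_stores(smiles_store, structure_store)
n_without_3d = sum(v is None for v in structure_store.values())
```

An empty `mismatched` list means every molecule has an entry in both stores. For a SMILES-only run, every value in the 3D store would simply be `None`.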

In [18]:
drugbank_drugs = pd.read_pickle(DATA_DIR + "integrated_data/v1/drugbank/drugbank_info_filtered.pkl")
drugbank_df = pd.read_pickle(DATA_DIR + "integrated_data/v1/drugbank/raw_drugbank_df.pkl")
drugbank_drugs = drugbank_drugs.merge(drugbank_df[['drugbank_id', 'smiles']], on='drugbank_id')

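def canon(sm):
    # Return the RDKit canonical SMILES; inputs that cannot be parsed
    # (including missing values) are returned unchanged.
    try:
        return Chem.MolToSmiles(Chem.MolFromSmiles(sm, sanitize=True))
    except Exception:
        return sm
drugbank_drugs['canonical_smiles'] = drugbank_drugs['smiles'].apply(canon)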
print((~drugbank_drugs['canonical_smiles'].isna()).sum())
print(drugbank_drugs.shape[0])

2399
3018


In [19]:
drugbank_3d_data = Chem.SDMolSupplier(DATA_DIR + "integrated_data/v1/drugbank/drugbank_3d_structures.sdf")
print(len(drugbank_3d_data))

9468


In [20]:
drugbank_drugs_3d = [None] * len(drugbank_drugs)
drugbank_id_to_index = drugbank_drugs[['index', 'drugbank_id']].set_index('drugbank_id').to_dict()['index']

for mol_3d in drugbank_3d_data:
    if mol_3d is None:  # SDMolSupplier yields None for records it cannot parse
        continue
    drugbank_id = mol_3d.GetProp('DATABASE_ID')
    if drugbank_id in drugbank_id_to_index:
        index = drugbank_id_to_index[drugbank_id]
        drugbank_drugs_3d[index] = mol_3d

## Formatting into dataset and inference

Please clone the [Transformer-M fork](https://github.com/jasperhyp/Transformer-M/tree/main) before executing the code below.

The pretrained L18 checkpoint was downloaded following the Transformer-M instructions, and a copy was saved in the directory specified in the code below.

------------------------------------------------------------------

For reference, we made the following changes to the original repo:
1. Renamed the `Transformer-M` source subdirectory (not the project directory) to `Transformer_M`.
2. Transformer-M was developed in a very different environment from ProCyon. Instead of setting up a new environment, we patched the Transformer-M codebase so that inference can run in the ProCyon environment. The following files were edited:
    1. `Transformer-M/Transformer_M/data/algos.pyx`: Replace all `astype(long,` with `astype(int,`.
    2. `Transformer-M/fairseq/modules/__init__.py`: Comment out lines 39, 77, 78
    3. `Transformer-M/Transformer_M/tasks/graph_prediction.py`:
        - Comment out lines 33, 35-45, 161-end
        - Add `from fairseq.dataclass import FairseqDataclass`
    4. `Transformer-M/fairseq/data/indexed_dataset.py`: Replace all `np.float` with `float`
    5. `Transformer-M/fairseq/__init__.py`: Comment out lines 32-end.
    6. `Transformer-M/fairseq/dataclass/initialize.py`
        - Add `import dataclasses`
        - Replace `v = FairseqConfig.__dataclass_fields__[k].default` with 
            ```python
            field = FairseqConfig.__dataclass_fields__[k]
            v = field.default
            if v is dataclasses.MISSING and field.default_factory is not dataclasses.MISSING:
                v = field.default_factory()
            ```
    7. `Transformer-M/fairseq/dataclass/configs.py`: Replace all definitions such as `common: CommonConfig = CommonConfig()` in class `FairseqConfig` with `common: CommonConfig = field(default_factory=CommonConfig)`. That is, instead of using an instance as the default, declare the default through `field(default_factory=...)`.
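
To illustrate why items 6 and 7 go together, here is a minimal, self-contained sketch. `CommonConfig` and `FairseqConfigSketch` are stand-ins we define for illustration, not the fairseq classes, and `default_for` is a hypothetical helper. To our understanding, newer Python versions (3.11 and later) reject a dataclass instance used directly as a field default, which is why `default_factory` is needed. Once the default comes from a factory, `field.default` is `dataclasses.MISSING`, so the lookup has to call the factory instead.

```python
import dataclasses
from dataclasses import dataclass, field


@dataclass
class CommonConfig:
    # Hypothetical field, only here to give the stand-in some content.
    seed: int = 1


@dataclass
class FairseqConfigSketch:
    # Item 7: declare the default through a factory instead of an instance.
    common: CommonConfig = field(default_factory=CommonConfig)


def default_for(cls, name):
    """Item 6: read a field's default, calling the factory when needed."""
    f = cls.__dataclass_fields__[name]
    v = f.default
    if v is dataclasses.MISSING and f.default_factory is not dataclasses.MISSING:
        v = f.default_factory()
    return v


common_default = default_for(FairseqConfigSketch, "common")
```

Each call to `default_for` builds a fresh `CommonConfig`, so configurations no longer share one mutable default object.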

In [None]:
!cd /path/to/Transformer-M; python setup_cython.py build_ext --inplace
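
If the build fails on `long` or `np.float`, the two purely textual patches from the list above (`algos.pyx` and `indexed_dataset.py`) could be scripted roughly as follows. This is a sketch, not the procedure we used: `PATCHES`, `patch_text`, and `patch_repo` are hypothetical names. The word-boundary pattern for `np.float` is a deliberate narrowing of "replace all", so that names such as `np.float32` are left untouched.

```python
import re
from pathlib import Path

# File paths are relative to the Transformer-M project root.
PATCHES = {
    "Transformer_M/data/algos.pyx": [(r"astype\(long,", "astype(int,")],
    # \b keeps np.float32 / np.float64 intact.
    "fairseq/data/indexed_dataset.py": [(r"\bnp\.float\b", "float")],
}


def patch_text(text, rules):
    """Apply each (pattern, replacement) rule to the text in order."""
    for pattern, replacement in rules:
        text = re.sub(pattern, replacement, text)
    return text


def patch_repo(repo_root):
    """Rewrite the listed files in place under repo_root."""
    for rel_path, rules in PATCHES.items():
        path = Path(repo_root) / rel_path
        path.write_text(patch_text(path.read_text(), rules))
```

Calling `patch_repo("/path/to/Transformer-M")` would rewrite both files in place, so it is worth committing or backing up the checkout first.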

In [None]:
from pathlib import Path
import sys
PROJECT_ROOT = Path("/path/to/Transformer-M")

# Add the outer project folder to sys.path
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Import from the `Transformer_M` package, which Python now finds
# inside PROJECT_ROOT
from Transformer_M.modules.transformer_m_encoder import TransformerMEncoder
from Transformer_M.tasks.graph_prediction import GraphPredictionConfig
from Transformer_M.data.wrapper import (
    smiles2graph,
    mol2graph,
    preprocess_item,
)
from Transformer_M.data.collator import collator_3d

print("Success: Source folder imported correctly.")

Success: Source folder imported correctly.


In [None]:
import torch

from functools import lru_cache
import numpy as np
from tqdm import tqdm
import multiprocess as mp
from torch_geometric.data import InMemoryDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Prepare dataset

In [21]:
drugbank_drugs_w_smiles = drugbank_drugs[~drugbank_drugs['canonical_smiles'].isna()][['drugbank_id', 'canonical_smiles']]
drugbank_drugs_w_smiles_3d = dict(zip(drugbank_drugs_w_smiles['drugbank_id'].values, np.array(drugbank_drugs_3d)[~drugbank_drugs['canonical_smiles'].isna()].tolist()))

In [23]:
class PyGDataset(InMemoryDataset):
    def __init__(self, root_dir, drugbank_drugs_2d, drugbank_drugs_3d, use_3d=True, smiles2graph=smiles2graph, transform=None, pre_transform=None):
        self.drugbank_drugs_2d = drugbank_drugs_2d
        self.drugbank_drugs_3d = drugbank_drugs_3d
        self.use_3d = use_3d
        self.smiles2graph = smiles2graph

        super(PyGDataset, self).__init__(root_dir, transform, pre_transform)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return None

    @property
    def raw_dir(self) -> str:
        return self.root
    
    @property
    def processed_file_names(self):
        return 'all_molecules_transformer_m.pt'
    
    @property
    def processed_dir(self) -> str:
        return self.root

    def process(self):
        smiles_dict = self.drugbank_drugs_2d.set_index('drugbank_id').to_dict()['canonical_smiles']
        graph_pos_dict = self.drugbank_drugs_3d

        print('Converting SMILES strings into 2D graphs...')
        data_2d_list = []
        with mp.Pool(processes=16) as pool:
            # Use the converter passed to __init__ and avoid shadowing the built-in `iter`
            graph_iter = pool.imap(self.smiles2graph, list(smiles_dict.values()))

            for i, graph in tqdm(enumerate(graph_iter)):
                try:
                    data = Data()

                    assert (len(graph['edge_feat']) == graph['edge_index'].shape[1])
                    assert (len(graph['node_feat']) == graph['num_nodes'])

                    data.__num_nodes__ = int(graph['num_nodes'])
                    data.edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
                    data.edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
                    data.x = torch.from_numpy(graph['node_feat']).to(torch.int64)
                    data.pos = torch.zeros(data.__num_nodes__, 3).to(torch.float32)

                    data_2d_list.append(data)
                
                except Exception:
                    data_2d_list.append(None)
                    continue
        
        data_2d_dict = dict(zip(smiles_dict.keys(), data_2d_list))
        
        print('Extracting 3D positions...')
        data_3d_list = []
        
        with mp.Pool(processes=120) as pool:
            graph_iter = pool.imap(mol2graph, list(graph_pos_dict.values()))

            for i, graph in tqdm(enumerate(graph_iter), total=len(graph_pos_dict)):
                if graph is None:
                    data_3d_list.append(None)
                    continue
                
                try:
                    data = Data()

                    assert (len(graph['edge_feat']) == graph['edge_index'].shape[1])
                    assert (len(graph['node_feat']) == graph['num_nodes'])

                    data.__num_nodes__ = int(graph['num_nodes'])
                    data.edge_index = torch.from_numpy(graph['edge_index']).to(torch.int64)
                    data.edge_attr = torch.from_numpy(graph['edge_feat']).to(torch.int64)
                    data.x = torch.from_numpy(graph['node_feat']).to(torch.int64)
                    data.pos = torch.from_numpy(graph['position']).to(torch.float32)

                    data_3d_list.append(data)
                    
                except Exception:
                    data_3d_list.append(None)
                    continue
        
        data_3d_dict = dict(zip(graph_pos_dict.keys(), data_3d_list))
        data_dict = dict((k, v) if v is not None else (k, data_2d_dict[k]) for k, v in data_3d_dict.items())
        for k, v in list(data_dict.items()):
            if v is None:
                data_dict.pop(k)
                data_2d_dict.pop(k)
                data_3d_dict.pop(k)
        
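        # NOTE: these attributes are set only when process() actually runs,
        # i.e. when the processed file does not already exist on disk.
        self.data_3d_dict = data_3d_dict
        self.data_2d_dict = data_2d_dict
        self.data_dict = data_dict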

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_dict.values()]
        else:
            data_list = list(data_dict.values())
            
        data, slices = self.collate(data_list)

        print('Saving...')
        torch.save((data, slices), self.processed_paths[0])

    @lru_cache(maxsize=2048)
    def __getitem__(self, idx):
        # item = self.get(self.indices()[idx])
        item = self.get(idx)
        item.idx = idx
        return preprocess_item(item)
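
The way `process()` combines the two sources can be summarized on toy data: a molecule uses its 3D graph when one exists, falls back to the 2D graph otherwise, and is dropped when neither could be built. The sketch below uses plain dicts and a hypothetical helper `merge_with_fallback`; it is a simplified restatement, not the class's code. Unlike `process()`, it tolerates IDs missing from the 2D dict.

```python
def merge_with_fallback(graphs_3d, graphs_2d):
    """Prefer the 3D graph, fall back to 2D, and drop molecules with neither."""
    merged = {}
    for mol_id, graph_3d in graphs_3d.items():
        graph = graph_3d if graph_3d is not None else graphs_2d.get(mol_id)
        if graph is not None:
            merged[mol_id] = graph
    return merged


# Strings stand in for the graph objects built in the notebook.
graphs_3d = {"MOL_A": "3d-graph-A", "MOL_B": None, "MOL_C": None}
graphs_2d = {"MOL_A": "2d-graph-A", "MOL_B": "2d-graph-B", "MOL_C": None}

merged = merge_with_fallback(graphs_3d, graphs_2d)
```

In the class above, the 2D fallback graphs carry all-zero coordinates in `pos`, which is what lets them pass through the same 3D collator.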

In [24]:
all_mols_store = PyGDataset(
    root_dir=DATA_DIR+"integrated_data/v1/drugbank/", 
    drugbank_drugs_2d=drugbank_drugs_w_smiles, 
    drugbank_drugs_3d=drugbank_drugs_w_smiles_3d, 
    use_3d=True,
)

Processing...
Converting SMILES strings into 2D graphs...
2399it [00:00, 3145.25it/s]
Extracting 3D positions...
100%|██████████| 2399/2399 [00:00<00:00, 3024.10it/s]
Saving...
Done!


### Load model and inference

In [26]:
config = GraphPredictionConfig  # do not change since the pretrained models are trained with this config
multi_hop_max_dist = 5
spatial_pos_max = 1024

# config.num_atoms = 512*9
# config.num_in_degree = 512
# config.num_out_degree = 512
# config.num_edges = 512*3
# config.num_spatial = 512
# config.num_edge_dis = 128
# config.edge_type = "multihop"
config.multi_hop_max_dist = multi_hop_max_dist
config.encoder_layers = 18
config.encoder_embed_dim = 768
config.encoder_ffn_embed_dim = 768 
config.encoder_attention_heads = 32 
config.dropout = 0.0
config.attention_dropout = 0.1 
config.act_dropout = 0.1
# config.max_positions = 512
config.num_segment = 2
config.no_token_positional_embeddings = False
config.encoder_normalize_before = True
config.apply_init = True
config.activation_fn = "gelu"
config.encoder_learned_pos = True
config.sandwich_ln = False
config.droppath_prob = 0.1
config.add_3d = True
config.num_3d_bias_kernel = 128
config.mode_prob = "0.2,0.2,0.6"
config.no_2d = False
# config.noise_scale = 0.2  # 0.01
# config.criterion = "graph_prediction"
# config.arch = "transformer_m_base"

mol_encoder = TransformerMEncoder(
    num_atoms=config.num_atoms,
    num_in_degree=config.num_in_degree,
    num_out_degree=config.num_out_degree,
    num_edges=config.num_edges,
    num_spatial=config.num_spatial,
    num_edge_dis=config.num_edge_dis,
    edge_type=config.edge_type,
    multi_hop_max_dist=config.multi_hop_max_dist,
    num_encoder_layers=config.encoder_layers,
    embedding_dim=config.encoder_embed_dim,
    ffn_embedding_dim=config.encoder_ffn_embed_dim,
    num_attention_heads=config.encoder_attention_heads,
    dropout=config.dropout,
    attention_dropout=config.attention_dropout,
    activation_dropout=config.act_dropout,
    max_seq_len=config.max_positions,
    num_segments=config.num_segment,
    use_position_embeddings=not config.no_token_positional_embeddings,
    encoder_normalize_before=config.encoder_normalize_before,
    apply_init=config.apply_init,
    activation_fn=config.activation_fn,
    learned_pos_embedding=config.encoder_learned_pos,
    sandwich_ln=config.sandwich_ln,
    droppath_prob=config.droppath_prob,
    add_3d=config.add_3d,
    num_3d_bias_kernel=config.num_3d_bias_kernel,
    no_2d=config.no_2d,
    mode_prob=config.mode_prob,
)

mol_encoder_state_dict = torch.load(os.path.join(DATA_DIR, "model_weights/L18"), map_location="cpu")["model"]
remove_keys = []
# Iterate over a copy of the items, since new keys are added inside the loop
for k, v in list(mol_encoder_state_dict.items()):
    if 'molecule_encoder' in k:
        # Re-register the weight under its name without the wrapper prefix
        mol_encoder_state_dict[k[len('encoder.molecule_encoder.'):]] = v
    remove_keys.append(k)

for k in remove_keys: 
    mol_encoder_state_dict.pop(k)
mol_encoder.load_state_dict(mol_encoder_state_dict)

<All keys matched successfully>
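
The key remapping in the cell above can be illustrated on a toy dictionary. The key names and the helper `strip_prefix` below are hypothetical, chosen only to show the idea. The sketch selects keys with `startswith`, whereas the cell above tests for the substring `molecule_encoder`.

```python
PREFIX = "encoder.molecule_encoder."


def strip_prefix(state_dict, prefix=PREFIX):
    """Keep only the keys that start with `prefix`, with the prefix removed."""
    return {
        k[len(prefix):]: v
        for k, v in state_dict.items()
        if k.startswith(prefix)
    }


# Toy checkpoint: floats stand in for weight tensors.
checkpoint = {
    "encoder.molecule_encoder.layers.0.weight": 1.0,
    "encoder.lm_head.bias": 2.0,
}
encoder_only = strip_prefix(checkpoint)
```

Only the encoder weights survive, under the names that the bare encoder module expects.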

In [None]:
batch_size = 4

all_embeddings = []
mol_encoder.eval()  # disable dropout and other training-only randomness

with torch.no_grad():
    for start in range(0, len(all_mols_store), batch_size):
        raw_batch_mols = []

        # The final batch may be smaller than batch_size
        stop = min(start + batch_size, len(all_mols_store))
        for i in range(start, stop):
            item = all_mols_store[i]
            item.y = torch.tensor([0.0])  # dummy label to avoid a crash in the collator
            raw_batch_mols.append(item)
        raw_batch_mols = tuple(raw_batch_mols)

        batch_mols = collator_3d(raw_batch_mols, max_node=100000, multi_hop_max_dist=multi_hop_max_dist, spatial_pos_max=spatial_pos_max)
        temp, _ = mol_encoder(batch_mols, last_state_only=True)
        assert len(temp) == 1
        # temp[0] is [SEQ, BATCH, FEAT]; keep the first sequence position of each molecule
        all_embeddings.append(temp[0][0, :, :].detach().cpu())

all_embeddings = torch.cat(all_embeddings, dim=0)
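
When the number of molecules is not a multiple of `batch_size`, the last batch has to stop at the end of the dataset rather than at `start + batch_size`. A minimal sketch of that index arithmetic, with a hypothetical helper `batch_ranges` and toy sizes:

```python
def batch_ranges(n_items, batch_size):
    """Yield index ranges that cover n_items without running past the end."""
    for start in range(0, n_items, batch_size):
        yield range(start, min(start + batch_size, n_items))


# Toy example: 10 items in batches of 4 gives batches of size 4, 4 and 2.
batches = [list(r) for r in batch_ranges(10, 4)]
```

Because batches are visited in index order, the rows of the concatenated embedding tensor stay aligned with the dataset indices.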

In [45]:
drugbank_ids = list(all_mols_store.data_dict.keys())
drug_indices = [drugbank_id_to_index[drugbank_id] for drugbank_id in drugbank_ids]
input_3d = [all_mols_store.data_3d_dict[drugbank_id] is not None for drugbank_id in drugbank_ids]

In [None]:
torch.save({
    "embeds": all_embeddings,
    "drugbank_ids": drugbank_ids,
    "drugbank_indices": drug_indices,
    "input_3d": input_3d,
}, DATA_DIR+"integrated_data/v1/drugbank/drugbank_compound_embeddings_transformer_m_18.pt")