# COATI NLP model for encoding molecules

In [1]:
# Enable the autoreload extension; mode 2 reloads all modules before each cell
# is executed, so edits to the project's .py files are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path
from typing import List, Optional, Tuple

import hydra
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from lightning import Callback, LightningDataModule, LightningModule, Trainer
from lightning.pytorch.loggers import Logger
from omegaconf import DictConfig, OmegaConf, open_dict

from src import utils
from src.coati.models.io import load_e3gnn_smiles_clip_e2e
from src.modules.collate_fn import default_collate
from src.modules.losses import InfoNCE
from src.modules.molecules.coati import COATI

In [3]:
# Mount the three remote cpjump project shares over sshfs unless they are
# already available locally.
for i in range(1, 4):
    # The jump/ subdirectory is used as the marker that the share is mounted.
    if Path(f"../cpjump{i}/jump/").exists():
        print(f"cpjump{i} already mounted.")
        continue

    print(f"Mounting cpjump{i}...")
    exit_status = os.system(f"sshfs bioclust:/projects/cpjump{i}/ ../cpjump{i}")
    # os.system never raises on a failing command, so surface the failure
    # here instead of letting later cells fail on missing data.
    if exit_status != 0:
        print(f"Failed to mount cpjump{i} (exit status {exit_status}).")

cpjump1 already mounted.
cpjump2 already mounted.
cpjump3 already mounted.


## Developing the model

In [2]:
# Toy batch of ten short SMILES strings used to exercise the encoder below.
smiles: List[str] = [
    "CC1CC2=CCOC2O1",
    "OC1CC1(O)CC1CC1",
    "CC1N2C=NCC12C#C",
    "CC1COC11C(O)C1O",
    "CC12OCC(CO1)C2=O",
    "CC12CC(CO1)CC2=O",
    "CCN=COC",
    "CC1(CO)CO1",
    "C(C#N)C(=O)N",
    "CC(=O)OC=N",
]

In [60]:
# Build the project's COATI wrapper from the "grande_closed" pretrained
# checkpoint with freeze=False.
# NOTE(review): out_dim presumably sets the embedding size and padding_length
# the token padding length — confirm in src.modules.molecules.coati.
# NOTE(review): device="cuda" is hard-coded, so this cell presumably needs a GPU.
model = COATI(
    pretrained_name="grande_closed",
    out_dim=128,
    padding_length=250,
    freeze=False,
    device="cuda",
)

Loading model from s3://terray-public/models/grande_closed.pkl
Loading tokenizer may_closedparen from s3://terray-public/models/grande_closed.pkl
number of parameters: 12.64M
number of parameters Total: 2.44M xformer: 17.92M Total: 20.36M 
vocab_name not found in tokenizer_vocabs, trying to load from file


In [63]:
# NOTE(review): for a torch.nn.Module, .to() moves the module in place and
# returns the same object, so model2 is presumably an alias of model rather
# than an independent CPU copy — confirm whether COATI overrides .to().
model2 = model.to("cpu")

In [64]:
# Display the device reported by the moved model.
model2.device

'cpu'

In [66]:
# Encode the toy SMILES batch and show the shape of the result.
# NOTE(review): the recorded run of this cell was interrupted
# (KeyboardInterrupt) — re-check device placement / runtime before relying on it.
model(smiles).shape

KeyboardInterrupt: 

In [2]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# Load the pretrained COATI encoder together with its tokenizer from the
# public S3 checkpoint. The loader returns both objects as a pair.
# NOTE(review): freeze=False presumably leaves the weights trainable — confirm
# in src.coati.models.io.
encoder, tokenizer = load_e3gnn_smiles_clip_e2e(
    # model parameters to load.
    doc_url="s3://terray-public/models/grande_closed.pkl",
    freeze=False,
    # device=torch.device("cpu"),
    # print_debug=True,
)

Loading model from s3://terray-public/models/grande_closed.pkl
Loading tokenizer may_closedparen from s3://terray-public/models/grande_closed.pkl
number of parameters: 12.64M
number of parameters Total: 2.44M xformer: 17.92M Total: 20.36M 
vocab_name not found in tokenizer_vocabs, trying to load from file


In [48]:
# Inspect the tokenizer vocabulary (a mapping from token string to integer id).
tokenizer.vocab

{'[PAD]': 0,
 '[STOP]': 1,
 '[SMILES]': 2,
 '[MASK]': 3,
 '[PREFIX]': 4,
 '[SUFFIX]': 5,
 '[MIDDLE]': 6,
 '[UNK]': 7,
 '[CLIP]': 8,
 '[FORMULA]': 9,
 '[GRAPH]': 10,
 '[EDGES]': 11,
 '[EDGE1]': 12,
 '[EDGEC]': 13,
 '[EDGE2]': 14,
 '[EDGE3]': 15,
 '[SET]': 16,
 '[ISOMORPHIC]': 17,
 '[VALID]': 18,
 '[TRUE]': 19,
 '[FALSE]': 20,
 '[geom_drugs]': 21,
 '[mcule]': 22,
 '[tspace_real]': 23,
 '[tensormol]': 24,
 '[chembl_mols]': 25,
 '[bbspace]': 26,
 '[zinc22]': 27,
 '[tspace_enum]': 28,
 '[ELM1]': 29,
 '[ELM2]': 30,
 '[ELM3]': 31,
 '[ELM4]': 32,
 '[ELM5]': 33,
 '[ELM6]': 34,
 '[ELM7]': 35,
 '[ELM8]': 36,
 '[ELM9]': 37,
 '[ELM10]': 38,
 '[ELM11]': 39,
 '[ELM12]': 40,
 '[ELM13]': 41,
 '[ELM14]': 42,
 '[ELM15]': 43,
 '[ELM16]': 44,
 '[ELM17]': 45,
 '[ELM18]': 46,
 '[ELM19]': 47,
 '[ELM20]': 48,
 '[ELM21]': 49,
 '[ELM22]': 50,
 '[ELM23]': 51,
 '[ELM24]': 52,
 '[ELM25]': 53,
 '[ELM26]': 54,
 '[ELM27]': 55,
 '[ELM28]': 56,
 '[ELM29]': 57,
 '[ELM30]': 58,
 '[ELM31]': 59,
 '[ELM32]': 60,
 '[ELM33]': 

In [47]:
# Exploration only: list the attributes and methods exposed by the tokenizer.
tokenizer.__dir__()

['n_seq',
 'special_tokens',
 'smiles_tokens',
 'keys',
 'n_token',
 'vocab',
 'stop_token',
 'pad_token',
 'clip_token',
 'unk_token',
 'smiles_token',
 'suffix_token',
 'middle_token',
 'graph_token',
 'formula_token',
 'set_token',
 'smiles_trie',
 'special_trie',
 '__module__',
 '__doc__',
 '__init__',
 'pre_tokenize',
 'tokenize_text',
 'batch_smiles',
 'decode',
 '__dict__',
 '__weakref__',
 '__new__',
 '__repr__',
 '__hash__',
 '__str__',
 '__getattribute__',
 '__setattr__',
 '__delattr__',
 '__lt__',
 '__le__',
 '__eq__',
 '__ne__',
 '__gt__',
 '__ge__',
 '__reduce_ex__',
 '__reduce__',
 '__subclasshook__',
 '__init_subclass__',
 '__format__',
 '__sizeof__',
 '__dir__',
 '__class__']

In [28]:
# Exploration only: list the attributes of the transformer's ln_f submodule.
encoder.xformer.transformer.ln_f.__dir__()

['T_destination',
 '__annotations__',
 '__call__',
 '__class__',
 '__constants__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattr__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_apply',
 '_backward_hooks',
 '_backward_pre_hooks',
 '_buffers',
 '_call_impl',
 '_forward_hooks',
 '_forward_hooks_with_kwargs',
 '_forward_pre_hooks',
 '_forward_pre_hooks_with_kwargs',
 '_get_backward_hooks',
 '_get_backward_pre_hooks',
 '_get_name',
 '_is_full_backward_hook',
 '_load_from_state_dict',
 '_load_state_dict_post_hooks',
 '_load_state_dict_pre_hooks',
 '_maybe_warn_non_full_backward_hook',
 '_modules',
 '_named_members',
 '_non_persistent_buffers_set',
 '_parameters',
 '_register_load_state_dict_pre_hook

In [39]:
# Fix the padded sequence length before tokenizing so that this cell gives the
# same (len(smiles), 250) tensor in a fresh top-to-bottom run.
# NOTE(review): n_seq is presumably the length that pad=True pads to — confirm
# in the tokenizer implementation.
tokenizer.n_seq = 250

# Wrap each SMILES string in [SMILES] ... [STOP] and tokenize it with padding.
# A "*" placeholder is tokenized as a single carbon ("C") instead.
batch_tokens = torch.tensor(
    [
        tokenizer.tokenize_text("[SMILES]" + ("C" if s == "*" else s) + "[STOP]", pad=True)
        for s in smiles
    ],
    device="cpu",
    dtype=torch.int,
)

In [40]:
# Sanity check: one row per SMILES string, one column per (padded) token position.
batch_tokens.shape

torch.Size([10, 250])

In [38]:
# NOTE(review): this cell sits after the cell that builds batch_tokens, but it
# was executed before it in the recorded session (In [38] vs In [39]).
# Presumably n_seq controls the padded token length — confirm in the tokenizer.
tokenizer.n_seq = 250

In [32]:
# Exploration only: list the tokenizer attributes (same listing as shown earlier).
tokenizer.__dir__()

['n_seq',
 'special_tokens',
 'smiles_tokens',
 'keys',
 'n_token',
 'vocab',
 'stop_token',
 'pad_token',
 'clip_token',
 'unk_token',
 'smiles_token',
 'suffix_token',
 'middle_token',
 'graph_token',
 'formula_token',
 'set_token',
 'smiles_trie',
 'special_trie',
 '__module__',
 '__doc__',
 '__init__',
 'pre_tokenize',
 'tokenize_text',
 'batch_smiles',
 'decode',
 '__dict__',
 '__weakref__',
 '__new__',
 '__repr__',
 '__hash__',
 '__str__',
 '__getattribute__',
 '__setattr__',
 '__delattr__',
 '__lt__',
 '__le__',
 '__eq__',
 '__ne__',
 '__gt__',
 '__ge__',
 '__reduce_ex__',
 '__reduce__',
 '__subclasshook__',
 '__init_subclass__',
 '__format__',
 '__sizeof__',
 '__dir__',
 '__class__']

In [9]:
# Embed the token batch and project it into the latent space
# (per the encode_tokens docstring).
batch_embeds = encoder.encode_tokens(batch_tokens, tokenizer)

In [10]:
# Leftover IPython introspection: `?` shows the signature and docstring of
# encode_tokens. Safe to delete once the API is known.
encoder.encode_tokens?

[0;31mSignature:[0m [0mencoder[0m[0;34m.[0m[0mencode_tokens[0m[0;34m([0m[0mtoken_indices[0m[0;34m:[0m [0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m,[0m [0mtokenizer[0m[0;34m)[0m [0;34m->[0m [0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m Embeds the tokens, and projects into the latent space.
[0;31mFile:[0m      /mnt/2547d4d7-6732-4154-b0e1-17b0c1e0c565/Document-2/Projet2/Stage/workspace/jump_models/src/coati/models/encoding/clip_e2e.py
[0;31mType:[0m      method

In [56]:
# Move the encoder to the CPU; the cell output is the module's repr.
encoder.to("cpu")

e3gnn_smiles_clip_e2e(
  (point_encoder): e3gnn_clip(
    (act_fn): SiLU()
    (embedding): Linear(in_features=28, out_features=256, bias=True)
    (embedding_norm): InstanceNorm1d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (node_dec): Sequential(
      (0): Linear(in_features=256, out_features=256, bias=True)
      (1): SiLU()
      (2): Identity()
      (3): Linear(in_features=256, out_features=256, bias=True)
    )
    (gcl_0): e_gcl_sparse(
      (instance_norm): InstanceNorm1d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (edge_mlp): Sequential(
        (0): Linear(in_features=513, out_features=256, bias=True)
        (1): SiLU()
        (2): Identity()
        (3): Linear(in_features=256, out_features=256, bias=True)
        (4): SiLU()
        (5): Identity()
      )
      (node_mlp): Sequential(
        (0): Linear(in_features=512, out_features=256, bias=True)
        (1): SiLU()
        (2): Identity()
        (3):

In [57]:
# Score the toy SMILES batch. The recorded output is a pair: a float tensor
# with one value per molecule and a boolean tensor of the same length.
# NOTE(review): the boolean tensor presumably flags which inputs were
# processed successfully — confirm in the encoder implementation.
encoder.batch_smiles_to_s2s_likelihood(smiles, tokenizer)

(tensor([ 5.1528,  3.8842, 41.5446,  4.8210,  7.4189,  2.2844,  2.4983,  0.7494,
         21.5638,  5.3194], grad_fn=<SumBackward1>),
 tensor([True, True, True, True, True, True, True, True, True, True]))

In [63]:
# Run the token batch through the transformer. Call the module itself rather
# than .forward() so that nn.Module.__call__ runs any registered hooks.
encoder.xformer(batch_tokens)

tensor([[[ -4.7370,   0.3078,   1.6261,  ...,  -0.0516,  -3.1638,   0.5547],
         [-11.4987,  -0.3192,  -1.5464,  ...,  -6.1698, -10.2218,  -3.1609],
         [ -8.4435,   0.8021,  -0.8089,  ...,  -2.4167,  -7.5098,  -2.4759],
         ...,
         [-10.3736,   2.5105,  -0.7439,  ...,  -4.9863,  -8.6617,  -3.3883],
         [-10.2333,   2.5897,  -0.6997,  ...,  -4.8988,  -8.5454,  -3.3025],
         [-10.0865,   2.5194,  -0.6686,  ...,  -4.8681,  -8.4047,  -3.2709]],

        [[ -4.7370,   0.3078,   1.6261,  ...,  -0.0516,  -3.1638,   0.5547],
         [-11.6163,  -2.5803,  -1.9344,  ...,  -5.1848, -10.3400,  -5.3778],
         [ -9.6973,   3.3149,  -1.9919,  ...,  -2.9876,  -7.5202,  -5.2770],
         ...,
         [-10.5352,   4.5065,  -0.2985,  ...,  -4.9456,  -8.7423,  -3.3941],
         [-10.7862,   4.4251,  -0.2186,  ...,  -5.1446,  -8.9914,  -3.7423],
         [-10.8848,   4.3139,  -0.1930,  ...,  -5.3303,  -9.0748,  -3.8258]],

        [[ -4.7370,   0.3078,   1.6261,  ...

## Check config

In [4]:
# GlobalHydra.instance().clear()

In [5]:
# initialize() raises if the global Hydra instance is already initialized, so
# clear any state left by a previous run to keep this cell re-runnable.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()

initialize(version_base=None, config_path="../configs")

hydra.initialize()

In [6]:
# Compose the training config with the overrides for this debugging session.
# trainer.devices is specified exactly once: a duplicate scalar entry
# ("trainer.devices=1") was shadowed by the list form below and has been removed.
cfg = compose(
    config_name="train.yaml",
    overrides=[
        "evaluate=true",
        "eval=hint",
        "paths.projects_dir=..",
        "paths.output_dir=./tmp/21312FS12A",
        "seed=22123",
        "experiment=coati/med",
        "trainer=gpu",
        "trainer.devices=[1]",
        "trainer.max_epochs=200",
        "data.num_workers=12",
        "data.transform.size=224",
        "data.batch_size=4",
        "model.embedding_dim=256",
        "model/image_encoder=vit_base_16_224",
        "model/criterion=ntxent_reg",
        "model.criterion.alpha=0.2",
        "model.criterion.mse_reg=0.5",
        "model.criterion.variance_reg=1",
        "model.criterion.covariance_reg=0.25",
        "model.criterion.temperature=10",
        "model.criterion.temperature_requires_grad=True",
    ],
)
# print(OmegaConf.to_yaml(cfg))

In [7]:
# Instantiate the datamodule described by the `data` section of the config.
dm = instantiate(cfg.data)

vocab_name not found in tokenizer_vocabs, trying to load from file




In [10]:
# Optionally pull one small training batch to use as the model's example input.
# Without the config flag, no example input is provided.
example_input = None
# The key spelling matches the one used by the project configuration.
if cfg.get("load_first_bacth"):
    dm.prepare_data()
    dm.setup("fit")
    dl = dm.train_dataloader(batch_size=2)
    b = next(iter(dl))
    example_input = b

In [11]:
# Instantiate the model from the config, forwarding the optional example batch
# (None when the first batch was not loaded).
# NOTE(review): this rebinds `model`, which earlier referred to the COATI instance.
model = instantiate(cfg.model, example_input=example_input)

vocab_name not found in tokenizer_vocabs, trying to load from file
number of parameters: 12.64M
number of parameters Total: 2.44M xformer: 17.92M Total: 20.36M 


In [15]:
# Point the trainer at GPU index 0 by mutating the composed config.
# NOTE(review): compose() above requested trainer.devices=[1]; consider setting
# the intended value in the overrides instead of patching cfg afterwards.
cfg.trainer.devices = [0]

In [25]:
# Build everything the training run needs from the composed config:
# the model, then callbacks and loggers, then the Trainer that uses them.
model: LightningModule = instantiate(cfg.model)
callbacks: List[Callback] = utils.instantiate_callbacks(cfg.get("callbacks"))
logger: List[Logger] = utils.instantiate_loggers(cfg.get("logger"))
trainer: Trainer = instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

  rank_zero_warn(
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.rich_model_summary.RichModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


vocab_name not found in tokenizer_vocabs, trying to load from file
number of parameters: 12.64M
number of parameters Total: 2.44M xformer: 17.92M Total: 20.36M 


In [26]:
# Attach the example batch to the model. The preceding instantiation did not
# pass example_input, so it is set on the instance here.
model.example_input_array = example_input

In [27]:
# Smoke-test a forward pass on the example batch, unpacked as keyword arguments.
# The recorded output is a dict that includes an 'image_emb' tensor.
model(**model.example_input_array)

{'image_emb': tensor([[ 1.6610e-02, -4.1819e-01, -4.4254e-02,  1.2422e+00,  6.1211e-01,
          -2.6903e-01, -2.6012e-01,  4.2256e-01, -1.7027e-01,  3.6687e-01,
           8.9104e-01,  3.1795e-01,  2.8448e-01,  4.2938e-01, -2.4188e-01,
          -5.4162e-01,  4.1095e-01,  3.4153e-01, -4.0051e-01,  3.8637e-01,
          -3.0858e-01, -7.8681e-02, -1.7590e-01,  4.3911e-01, -5.6159e-01,
          -9.1884e-01,  7.0309e-02, -4.9634e-01,  7.9358e-01,  1.6343e-01,
          -8.4752e-01,  9.3032e-01, -3.9207e-01,  2.7210e-01, -5.9688e-02,
           9.2509e-02, -5.6339e-01,  4.5488e-01, -5.6979e-01, -8.1515e-02,
           4.2378e-01,  1.8099e-01, -3.3752e-01, -6.2349e-01, -8.4275e-02,
           3.3699e-02,  1.6890e-01, -4.9607e-02,  6.8864e-01, -3.5666e-01,
           4.0561e-01, -3.0659e-01,  2.1460e-01, -9.7925e-01,  3.7538e-03,
          -9.8630e-01,  2.2101e-01, -1.4444e-01, -2.0538e-01,  2.4341e-01,
           4.0621e-01, -6.1124e-01,  9.6070e-01, -7.7939e-01,  5.6610e-01,
           6

In [28]:
# Launch training of the model on the datamodule with the configured Trainer.
trainer.fit(model, dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Output()