# CS598 Deep Learning for Healthcare - Reproducibility Project Draft

Darin Zhen, George Vojvodic, and Alan Yee; {darinz2, dvojvo2, alanyee2} @illinois.edu

Group ID: 10

Paper ID: 94

Text2Mol: Cross-Modal Molecule Retrieval with Natural Language Queries


# Relevant Terminology

- **Hits@1**: A retrieval metric giving the fraction of queries for which the correct item is ranked first in the returned list. More generally, Hits@k counts a query as a hit when the correct item appears among the top k results.

- **Mean Reciprocal Rank (MRR)**: A metric for ranked retrieval, computed as the average over all queries of $1/\text{rank}$, where rank is the position of the first relevant item in the returned list. An MRR of 1 means the correct item is always ranked first.

- **Molecule**: An electrically neutral group of atoms bonded together.

- **Compound**: Two or more elements held together by chemical bonds.

- **Chemical fingerprint**: Represents a molecule or substructure using a bitstring. This allows for efficient substructure search and similarity calculation.

- **Morgan fingerprint**: A specific type of chemical fingerprint also known as ECFP.

- **SMILES string**: A character-based sequence representation of a molecule (for example, `C1=CC=CC=C1` is a SMILES string for benzene).

- **Canonical SMILES**: A unique SMILES string for a molecule.
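
As an illustration of the two retrieval metrics defined above, here is a minimal sketch. The function names and the example ranks are hypothetical and are not taken from the original code base; ranks are assumed to be 1-based positions of the correct molecule.

```python
def hits_at_k(ranks, k=1):
    """Fraction of queries whose correct item is ranked within the top k (1-based ranks)."""
    return sum(1 for r in ranks if r <= k) / len(ranks)


def mean_reciprocal_rank(ranks):
    """Average of 1/rank over all queries (1-based ranks)."""
    return sum(1.0 / r for r in ranks) / len(ranks)


# Hypothetical ranks of the correct molecule for four text queries
example_ranks = [1, 2, 1, 4]
print(hits_at_k(example_ranks, k=1))        # 0.5
print(mean_reciprocal_rank(example_ranks))  # (1 + 0.5 + 1 + 0.25) / 4 = 0.6875
```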


# Introduction

The discovery of new molecules and understanding their properties is critical for advancing fields like medicine, chemistry, and materials science. However, the vast number of possible molecules makes it impractical to experimentally characterize each one. There are already tens of millions of molecules cataloged in databases like PubChem. Efficiently retrieving relevant molecules from these large databases given natural language descriptions is an important yet challenging problem.

Current methods for molecule retrieval typically rely on structured representations like molecular fingerprints or SMILES strings. While these enable substructure matching and similarity searches, they do not directly integrate the semantic information contained in natural language descriptions. Some approaches replace chemical names in text with canonical identifiers, but this fails to capture the full meaning. Solving the problem of cross-modal retrieval between natural language and molecules would allow scientists to easily search for molecules based on high-level conceptual descriptions rather than just structural patterns.

The key challenge lies in bridging the stark difference between the modalities of natural language and molecular structure data. Molecules are usually represented as graphs with atoms as nodes and bonds as edges, following a unique grammar quite distinct from human language. This makes cross-modal retrieval exceptionally challenging compared to traditional cross-lingual information retrieval between natural languages.

In this paper, the authors propose a novel "Text2Mol" task for retrieving molecules directly from natural language descriptions (shown in Figure 1). They develop a multimodal embedding approach to learn an aligned semantic space bridging text and molecular structure data. This allows ranking molecules by similarity to text query descriptions. The paper makes several innovations, including extending the loss function with negative sampling to encourage integration of both modalities, using cross-modal attention to extract interpretable "association rules" between text and molecular substructures, and an ensemble method that significantly boosts performance.
<br>
<img src="https://drive.google.com/uc?id=1mG9lgWvfpZ2tplVl6xqLZMyzj5ooPhy8" width=500 />
<br>
<b>Figure 1: Given a natural language description of water, we want to rank the corresponding molecule $H_2O$ first among all the possible molecules.</b>
<br>

On a new benchmark dataset of over 33,000 text-molecule pairs, the proposed methods achieve a mean reciprocal rank of 0.499, substantially outperforming baselines. The cross-modal attention model provides insightful explanations grounding the language representations to the molecular structure. Overall, this multimodal approach offers a powerful solution for understanding chemistry literature and searching molecular databases, with broad potential applications in drug discovery, materials design, and scientific knowledge exploration.


# Scope of Reproducibility

The scope of reproducibility covers the following key hypotheses from the paper, which we will attempt to reproduce:

1. **Hypothesis 1**: Cross-modal embedding can effectively align text and molecule spaces for retrieval. This involves reproducing embedding models and evaluating retrieval metrics like Mean Reciprocal Rank (MRR).

2. **Hypothesis 2**: Ensemble of different architectures (MLP vs GCN) improves results compared to individual models. This involves training different models and comparing ensemble versus individual performance.

3. **Hypothesis 3**: Cross-modal attention provides insights into text-molecule associations. This will be examined by analyzing attention weights and extracted rules for coherence.

4. **Hypothesis 4**: Different architectures possess complementary strengths, where MLP may rank easier examples better but GCN generalizes better. This will be probed by analyzing differences in rankings between architectures.

5. **Hypothesis 5**: Cross-modal reranking using attention rules improves over the base model. Testing reranking on a holdout set will validate this hypothesis.

# Setup

**NOTE:** This notebook requires a **GPU** runtime because the models being trained have more than one hundred million parameters.

The following cell creates the `data`, `image`, and `input` directories and downloads the dataset files:


In [None]:
# Data directory (-p avoids an error if the directory already exists)
!mkdir -p /content/data

# Image directory
!mkdir -p /content/image

# Input directory
!mkdir -p /content/input

# Download files into the data folder.

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1XYjz33tWWet6t4QouZkVn3TRYliG5L5F' -O ChEBI_defintions_substructure_corpus.cp
!mv /content/ChEBI_defintions_substructure_corpus.cp /content/data/ChEBI_defintions_substructure_corpus.cp

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1U_0mn1-GZ7NtL8Bk2S8Yr9IQQp51qbpf' -O chem_embeddings_test.npy
!mv /content/chem_embeddings_test.npy /content/data/chem_embeddings_test.npy

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1nVQIExr7toG3Ob6CJzPK51LKq9OsL1Ar' -O chem_embeddings_train.npy
!mv /content/chem_embeddings_train.npy /content/data/chem_embeddings_train.npy

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1ybhCMaMFSETom3PlOGTkXXRpuauCMztH' -O chem_embeddings_val.npy
!mv /content/chem_embeddings_val.npy /content/data/chem_embeddings_val.npy

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1vPKXKtlQx8oX3-SAlgzTGvXHU0SwN1Rr' -O cids_test.npy
!mv /content/cids_test.npy /content/data/cids_test.npy

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1CUvn6lOUbbb7sJkqnS0jQ5eYYnJfDM7F' -O cids_train.npy
!mv /content/cids_train.npy /content/data/cids_train.npy

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1LMfx63hrd6r5pAQWUsK5BYbb5TEcumuc' -O cids_val.npy
!mv /content/cids_val.npy /content/data/cids_val.npy

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1SmKQJPKRePUXyomOBdMwha75-D5O4413' -O test.sdf
!mv /content/test.sdf /content/data/test.sdf

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1dhhTAD3z97yOQYSK0Go-RI-bgjJdSg6e' -O test.txt
!mv /content/test.txt /content/data/test.txt

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=12DlIeAwx_oeJBsuRAoi_t4cr7Rj3G8UN' -O text_embeddings_test.npy
!mv /content/text_embeddings_test.npy /content/data/text_embeddings_test.npy

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1tozq-TbD2avIqwtZkA7TtwXgskUn_dMX' -O text_embeddings_train.npy
!mv /content/text_embeddings_train.npy /content/data/text_embeddings_train.npy

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1C3-4KHwKySBCXri6YYj6IFJ3BppYkpYI' -O text_embeddings_val.npy
!mv /content/text_embeddings_val.npy /content/data/text_embeddings_val.npy

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1gR4B11xGBLGwYQ2-s_k0C19HUXYDFSoa' -O token_embedding_dict.npy
!mv /content/token_embedding_dict.npy /content/data/token_embedding_dict.npy

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1g8a6wg1OC8okFltFJOaNbMC5BsWuQdML' -O training.sdf
!mv /content/training.sdf /content/data/training.sdf

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=17SvDWffLm8Eez7KIIyZt3RvsNOkSDZMM' -O training.txt
!mv /content/training.txt /content/data/training.txt

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1i39CtXI7HbdtnRG4lHMCFKQ7i_AZgSdn' -O val.sdf
!mv /content/val.sdf /content/data/val.sdf

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1sQ7iYAHIRMmq0YePRaiRngcPNwFPzZoO' -O val.txt
!mv /content/val.txt /content/data/val.txt


# Download files into the input folder.

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1sOLCJIDKZCZO-9jSFPsQGYK4h17sDvU7' -O ChEBI_defintions_substructure_corpus.cp
!mv /content/ChEBI_defintions_substructure_corpus.cp /content/input/ChEBI_defintions_substructure_corpus.cp

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1G1iST_JfJTfj1hRBzzqgIBQZDUjzv5oA' -O mol2vec_ChEBI_20_test.txt
!mv /content/mol2vec_ChEBI_20_test.txt /content/input/mol2vec_ChEBI_20_test.txt

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1aSuZUiiM7Bmsg-zONm7RMfAm8MizL1tc' -O mol2vec_ChEBI_20_training.txt
!mv /content/mol2vec_ChEBI_20_training.txt /content/input/mol2vec_ChEBI_20_training.txt

!wget --no-check-certificate 'https://drive.google.com/uc?export=download&id=1BERxF6s-GPDRaEOoSBfhfs-u6RS1UUSx' -O mol2vec_ChEBI_20_val.txt
!mv /content/mol2vec_ChEBI_20_val.txt /content/input/mol2vec_ChEBI_20_val.txt
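
The repeated `wget`/`mv` pairs above could also be generated from a table of file IDs. The following is only a sketch: `build_download_commands` and `DATA_FILES` are hypothetical names, just three of the entries from the cell above are listed, and the commands are printed rather than executed.

```python
import shlex

# (Google Drive file ID, output file name) pairs copied from the cell above;
# only three entries are listed here for brevity.
DATA_FILES = [
    ("1vPKXKtlQx8oX3-SAlgzTGvXHU0SwN1Rr", "cids_test.npy"),
    ("1CUvn6lOUbbb7sJkqnS0jQ5eYYnJfDM7F", "cids_train.npy"),
    ("1LMfx63hrd6r5pAQWUsK5BYbb5TEcumuc", "cids_val.npy"),
]


def build_download_commands(files, dest_dir):
    """Return one wget command per file, writing directly into dest_dir."""
    commands = []
    for file_id, name in files:
        url = f"https://drive.google.com/uc?export=download&id={file_id}"
        target = f"{dest_dir}/{name}"
        commands.append(
            f"wget --no-check-certificate {shlex.quote(url)} -O {shlex.quote(target)}"
        )
    return commands


for cmd in build_download_commands(DATA_FILES, "/content/data"):
    print(cmd)
```

Writing straight to the destination path with `-O` removes the need for the separate `mv` step.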


# Methodology



The paper proposes a new task called Text2Mol, which aims to retrieve molecules from natural language descriptions. The methodology involves the following key steps:

1. Construct a dataset of molecule-text description pairs from sources like PubChem and ChEBI.

2. Learn aligned semantic embeddings for text and molecules using:
    a) A text encoder based on SciBERT
    b) A molecule encoder using either a multi-layer perceptron (MLP) on Mol2vec embeddings or a graph convolutional network (GCN) on the molecular graph with Mol2vec features.

3. Train the encoders using a contrastive loss that aims to bring positive (matching) molecule-text pairs closer and push negative pairs apart in the embedding space.

4. At inference time, encode the text query and retrieve the nearest molecule embeddings using cosine similarity.

5. Explore ensembling multiple trained models and incorporating cross-modal attention to learn association rules between text tokens and molecular substructures for explainability and reranking.

The key novelties are applying contrastive learning across the text and molecule modalities, proposing the Text2Mol retrieval task, and using cross-modal attention for explainable retrieval via association rules.
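
The retrieval step (step 4 above) can be sketched with plain NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: `rank_molecules` is a hypothetical helper and the 3-dimensional embeddings are made up.

```python
import numpy as np


def rank_molecules(text_emb, mol_embs):
    """Return molecule indices sorted from most to least similar to the text query."""
    text_unit = text_emb / np.linalg.norm(text_emb)
    mol_units = mol_embs / np.linalg.norm(mol_embs, axis=1, keepdims=True)
    cosine = mol_units @ text_unit  # one cosine similarity per molecule
    return np.argsort(-cosine), cosine


# Toy embeddings (made up for illustration)
query = np.array([1.0, 0.0, 1.0])
molecules = np.array([
    [0.0, 1.0, 0.0],  # orthogonal to the query
    [2.0, 0.0, 2.0],  # same direction as the query
    [1.0, 1.0, 0.0],  # partially aligned
])
order, scores = rank_molecules(query, molecules)
print(order)   # molecule 1 is ranked first
print(scores)
```

Because the vectors are normalized first, the molecule at index 1 ranks first even though its norm differs from the query's; only direction matters for cosine similarity.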

## Environment

### Python Version



This project uses Python version 3.10.

### Dependencies & Packages Needed

This project relies on several Python libraries and modules for text-to-molecule tasks:

1. **Operating System Interaction**: `os` module for interacting with the operating system.
2. **File Operations**: `shutil` module for file operations.
3. **Time-related Functions**: `time` module for time-related functions.
4. **Mathematical Operations**: `math` module for mathematical operations.
5. **Numerical Computations**: `numpy` (`np` alias) for numerical computations.
6. **Plotting**: `matplotlib.pyplot` (`plt` alias) for plotting.
7. **Cosine Similarity Computation**: `cosine_similarity` from `sklearn.metrics.pairwise` for computing cosine similarity.
8. **Deep Learning Framework**: `torch` for PyTorch, a deep learning framework.
9. **Neural Network Components**: `torch.nn` for neural network modules and `torch.nn.functional` (`F` alias) for functional interfaces.
10. **Data Handling Utilities**: `torch.utils.data` for handling data in PyTorch, including `Dataset` and `DataLoader`.
11. **Tokenization**: `tokenizers` for tokenization, including the `Tokenizer` class.
12. **BERT Model and Tokenizer**: `BertTokenizerFast` and `BertModel` from Hugging Face's Transformers library for BERT tokenizer and model.
13. **CSV File Handling**: `csv` module for reading and writing CSV files.
14. **Graph Convolutional Network (GCN)**: `torch_geometric.nn` for GCN operations, including `GCNConv` and `global_mean_pool`.
15. **Transformer Decoder**: `TransformerDecoder` and `TransformerDecoderLayer` from PyTorch for transformer decoder operations.
16. **Optimization**: `torch.optim` for optimization, including various optimizers.
17. **Learning Rate Scheduler**: `get_linear_schedule_with_warmup` from the transformers library for learning rate scheduling.

Additionally, the code includes an installation command (`!pip install torch_geometric`) to install the `torch_geometric` library.

In [None]:
# This code imports various Python libraries and modules that are used in this
# notebook for Text2Mol
# Importing necessary libraries/modules
import os                   # Module for interacting with the operating system
import shutil               # Module for file operations
import time                 # Module for time-related functions

import math                # Module for mathematical operations

import numpy as np         # NumPy, a library for numerical computations

import matplotlib.pyplot as plt  # Matplotlib, a plotting library
from sklearn.metrics.pairwise import cosine_similarity  # Module for cosine similarity computation

import torch               # PyTorch, a deep learning framework
from torch import nn       # Neural network module from PyTorch
import torch.nn.functional as F  # Functional interface to neural network components in PyTorch
from torch.utils.data import Dataset, DataLoader  # Utilities for handling data in PyTorch

import tokenizers         # Tokenizers library for tokenization
from tokenizers import Tokenizer  # Tokenizer class for tokenization
from transformers import BertTokenizerFast, BertModel  # BERT tokenizer and model from Hugging Face's
                                                       # Transformers library

import csv                 # Module for reading and writing CSV files

## Data

### Data Download Instruction

To download the data for the project, follow these instructions:

1. **ChEBI Annotations of Compounds from [PubChem](https://pubchem.ncbi.nlm.nih.gov/)**: (This step is optional: the raw ChEBI data is not used directly by the model, and the prepared dataset is already provided in step 2. The steps to obtain the raw data are listed here for completeness.)
   1. Visit the [ChEBI](https://www.ebi.ac.uk/chebi/) website.
   1. Under `Downloads`, click on `SDF files`.
   1. Click on `ChEBI_complete.sdf.gz`.
   1. Choose where to save the compressed file and click `Save`.

2. **ChEBI-20 Dataset**: (This part is required.)
   - Access the [ChEBI-20 dataset repository](https://github.com/cnedwards/text2mol/tree/master/data).

### Data Descriptions

The paper makes use of the following datasets:

1. [ChEBI](https://www.ebi.ac.uk/chebi/) (Chemical Entities of Biological Interest) annotations of compounds scraped from [PubChem](https://pubchem.ncbi.nlm.nih.gov/).
   - This contains 102,980 compound-description pairs.

2. [ChEBI-20](https://github.com/cnedwards/text2mol/tree/master/data) dataset
   - Constructed from the ChEBI/PubChem data by filtering for descriptions longer than 20 words.
   - Contains 33,010 text-compound pairs.
   - Split into 80/10/10% train/validation/test sets.

The ChEBI-20 dataset forms the main benchmark used to evaluate the proposed cross-modal molecule retrieval methods.

For representing the molecular structures, the paper uses:

1. Mol2vec representations
   - Molecular graphs are converted to "sentences" using the Morgan fingerprinting (shown in Figure 2) algorithm which generates substructure identifiers.
   <br>
   <img src="https://drive.google.com/uc?id=1lT5fla5UnQXW60-utzHYWDFdg62sMgp0" width="400">
   <br>
   <b>Figure 2: Example of Morgan Fingerprinting from (Rogers and Hahn, 2010) for Butyramide. The algorithm updates the identifiers from radius r = 0 to r = 1, as shown by the green circles.</b>
   <br>
   - The Mol2vec algorithm applies Word2vec on these substructure sentences to produce molecule embeddings.
   - Default radius of 1 is used, giving two substructure tokens per atom.

2. SMILES strings
   - Character-based representation of molecules that can be parsed into molecular graphs.

The text descriptions are encoded using the pre-trained SciBERT language model.

The key data is the new ChEBI-20 dataset of paired text descriptions and molecular structures, with molecules represented by Mol2vec embeddings or SMILES strings, and descriptions encoded by SciBERT.

#### The ChEBI-20 dataset is contained in 6 files:


(1,2,3) The mol2vec_ChEBI_20_X.txt files have lines in the following form:
```
CID	mol2vec embedding	Description
```
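
A line in this format can be split as in the sketch below. The helper name `parse_mol2vec_line` and the sample line are made up for illustration, and the embedding is assumed to be a space-separated list of floats, as suggested by the `np.fromstring(..., sep=" ")` call used later in this notebook.

```python
def parse_mol2vec_line(line):
    """Split one tab-separated line into (cid, embedding as a list of floats, description)."""
    cid, emb_str, desc = line.rstrip("\n").split("\t", 2)
    embedding = [float(x) for x in emb_str.split()]
    return cid, embedding, desc


# Made-up line with a 3-dimensional embedding (real embeddings are much longer)
sample = "12345\t0.1 -0.2 0.3\tAn example description.\n"
cid, emb, desc = parse_mol2vec_line(sample)
print(cid, len(emb), desc)
```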

(4) mol_graphs.zip contains {cid}.graph files. These are formatted first with the edgelist of the graph and then substructure tokens for each node.
For example,
edgelist:
```
0 1
1 0
1 2
2 1
1 3
3 1
```
idx to identifier:
```
0 3537119515
1 2059730245
2 3537119515
3 1248171218
```
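
A `.graph` file of this shape could be parsed as sketched below. The exact header strings (`edgelist:` and `idx to identifier:`) and the blank line between sections are assumptions based on the example above, and `parse_graph_text` is a hypothetical helper.

```python
def parse_graph_text(text):
    """Parse the two sections of a .graph file into an edge list and a node-token map."""
    edges, node_tokens = [], {}
    section = None
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue
        if line.startswith("edgelist"):
            section = "edges"
        elif line.startswith("idx to identifier"):
            section = "tokens"
        elif section == "edges":
            src, dst = line.split()
            edges.append((int(src), int(dst)))
        elif section == "tokens":
            idx, token = line.split()
            node_tokens[int(idx)] = token  # keep the identifier as a string
    return edges, node_tokens


sample = """edgelist:
0 1
1 0
1 2
2 1
1 3
3 1

idx to identifier:
0 3537119515
1 2059730245
2 3537119515
3 1248171218
"""
edges, node_tokens = parse_graph_text(sample)
print(len(edges), node_tokens[1])
```

Each undirected bond appears twice in the edge list (once per direction), which is the form a GCN layer typically expects.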

(5) ChEBI_defintions_substructure_corpus.cp contains the molecule token "sentences". It is formatted:
```
cid: tokenid1 tokenid2 tokenid3 ... tokenidn
```

(6) token_embedding_dict.npy is a dictionary mapping molecule tokens to their embeddings.

It can be loaded with the following code:
```python
import numpy as np
token_embedding_dict = np.load("token_embedding_dict.npy", allow_pickle=True)[()]
```
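
Once loaded, the dictionary can turn a token "sentence" into a matrix of node features. The sketch below uses a tiny stand-in dictionary instead of the real file; `tokens_to_matrix` is a hypothetical helper, string keys are assumed, and mapping unknown tokens to zero vectors is one possible choice rather than the original behavior.

```python
import numpy as np


def tokens_to_matrix(token_sentence, token_embedding_dict, dim):
    """Stack the embedding of each token; unknown tokens become zero rows."""
    rows = [
        token_embedding_dict.get(tok, np.zeros(dim))
        for tok in token_sentence.split()
    ]
    if not rows:
        return np.zeros((0, dim))
    return np.vstack(rows)


# Tiny stand-in dictionary (the real one is loaded from token_embedding_dict.npy)
toy_dict = {
    "3537119515": np.array([1.0, 0.0]),
    "2059730245": np.array([0.0, 1.0]),
}
features = tokens_to_matrix("3537119515 2059730245 999", toy_dict, dim=2)
print(features.shape)        # (3, 2)
print(features.sum(axis=0))  # one simple way to pool tokens into a single vector
```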

### Preprocessing Code + Command

#### GenerateData Class

This class is designed to handle the generation of examples for training,
validation, and testing sets, where each example contains both text data
(processed using a BERT tokenizer) and molecule data. The methods within
the class prepare the necessary data structures, tokenize text inputs,
and yield examples in the desired format.

In [None]:
#Need a special generator for random sampling:

class GenerateData():
    def __init__(self, path_train, path_val, path_test, path_molecules, path_token_embs):
        # Constructor method initializing paths and parameters
        self.path_train = path_train  # Path to the training data file
        self.path_val = path_val  # Path to the validation data file
        self.path_test = path_test  # Path to the test data file
        self.path_molecules = path_molecules  # Path to the file containing molecule data
        self.path_token_embs = path_token_embs  # Path to the file containing token embeddings

        self.text_trunc_length = 256  # Maximum length for text input tokens

        # Initialize text tokenizer
        self.prep_text_tokenizer()

        # Load substructures from molecule data
        self.load_substructures()

        self.batch_size = 32  # Batch size for data processing

        # Store descriptions
        self.store_descriptions()

    def load_substructures(self):
        # Method to load substructures from molecule data
        self.molecule_sentences = {}  # Dictionary to store molecule sentences
        self.molecule_tokens = {}  # Dictionary to store molecule tokens

        total_tokens = set()  # Set to store unique tokens
        self.max_mol_length = 0  # Variable to store maximum molecule length
        with open(self.path_molecules) as f:
            for line in f:
                spl = line.split(":")
                cid = spl[0]  # Compound ID
                tokens = spl[1].strip()  # Tokens for the compound
                self.molecule_sentences[cid] = tokens
                t = tokens.split()
                total_tokens.update(t)  # Add tokens to the set
                size = len(t)
                if size > self.max_mol_length:
                    self.max_mol_length = size  # Update maximum molecule length

        # Load token embeddings
        self.token_embs = np.load(self.path_token_embs, allow_pickle=True)[()]

    def prep_text_tokenizer(self):
        # Method to prepare text tokenizer (using BERT)
        self.text_tokenizer = BertTokenizerFast.from_pretrained("allenai/scibert_scivocab_uncased")

    def store_descriptions(self):
        # Method to store descriptions from training, validation, and test sets
        self.descriptions = {}  # Dictionary to store descriptions
        self.mols = {}  # Dictionary to store molecule data

        self.training_cids = []  # List to store training set compound IDs
        # Get training set compound IDs
        with open(self.path_train) as f:
            reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE,
                                    fieldnames=['cid', 'mol2vec', 'desc'])
            for n, line in enumerate(reader):
                self.descriptions[line['cid']] = line['desc']
                self.mols[line['cid']] = line['mol2vec']
                self.training_cids.append(line['cid'])

        self.validation_cids = []  # List to store validation set compound IDs
        # Get validation set compound IDs
        with open(self.path_val) as f:
            reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE,
                                    fieldnames=['cid', 'mol2vec', 'desc'])
            for n, line in enumerate(reader):
                self.descriptions[line['cid']] = line['desc']
                self.mols[line['cid']] = line['mol2vec']
                self.validation_cids.append(line['cid'])

        self.test_cids = []  # List to store test set compound IDs
        # Get test set compound IDs
        with open(self.path_test) as f:
            reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE,
                                    fieldnames=['cid', 'mol2vec', 'desc'])
            for n, line in enumerate(reader):
                self.descriptions[line['cid']] = line['desc']
                self.mols[line['cid']] = line['mol2vec']
                self.test_cids.append(line['cid'])

    def generate_examples_train(self):
        # Method to generate examples for training set
        np.random.shuffle(self.training_cids)  # Shuffle training compound IDs

        for cid in self.training_cids:
            text_input = self.text_tokenizer(self.descriptions[cid], truncation=True,
                                             max_length=self.text_trunc_length,
                                             padding='max_length', return_tensors='np')  # Tokenize text input

            yield {
                'cid': cid,
                'input': {
                    'text': {
                        'input_ids': text_input['input_ids'].squeeze(),
                        'attention_mask': text_input['attention_mask'].squeeze(),
                    },
                    'molecule': {
                        'mol2vec': np.fromstring(self.mols[cid], sep=" "),  # Convert molecule data to NumPy array
                        'cid': cid
                    },
                },
            }

    def generate_examples_val(self):
        # Method to generate examples for validation set
        np.random.shuffle(self.validation_cids)  # Shuffle validation compound IDs

        for cid in self.validation_cids:
            text_input = self.text_tokenizer(self.descriptions[cid], truncation=True,
                                             padding='max_length',
                                             max_length=self.text_trunc_length, return_tensors='np')  # Tokenize text input

            yield {
                'cid': cid,
                'input': {
                    'text': {
                        'input_ids': text_input['input_ids'].squeeze(),
                        'attention_mask': text_input['attention_mask'].squeeze(),
                    },
                    'molecule': {
                        'mol2vec': np.fromstring(self.mols[cid], sep=" "),  # Convert molecule data to NumPy array
                        'cid': cid
                    }
                },
            }

    def generate_examples_test(self):
        # Method to generate examples for test set
        np.random.shuffle(self.test_cids)  # Shuffle test compound IDs

        for cid in self.test_cids:
            text_input = self.text_tokenizer(self.descriptions[cid], truncation=True, padding='max_length',
                                             max_length=self.text_trunc_length, return_tensors='np')  # Tokenize text input

            yield {
                'cid': cid,
                'input': {
                    'text': {
                        'input_ids': text_input['input_ids'].squeeze(),
                        'attention_mask': text_input['attention_mask'].squeeze(),
                    },
                    'molecule': {
                        'mol2vec': np.fromstring(self.mols[cid], sep=" "),  # Convert molecule data to NumPy array
                        'cid': cid
                    }
                },
            }

In the following code, the paths to the data files are defined. The code then checks
whether the token embedding file exists using `os.path.exists()` and raises a
`FileNotFoundError` if it does not. Finally, an instance of the `GenerateData`
class is created with the defined paths.

In [None]:
# Define the path to the token embedding file
#mounted_path_token_embs = os.path.join(data_dir, 'token_embedding_dict.npy')
mounted_path_token_embs = "data/token_embedding_dict.npy"

# Check if the token embedding file exists
if not os.path.exists(mounted_path_token_embs):
    # Raise FileNotFoundError if the file does not exist
    raise FileNotFoundError(f"The following token embedding DOES NOT EXIST: {mounted_path_token_embs}")

# Define the path to the molecule data file
parent_dir = os.path.join('/content/drive/', 'My Drive')

mounted_path_molecules = "input/ChEBI_defintions_substructure_corpus.cp"

mounted_path_train = "input/mol2vec_ChEBI_20_training.txt"
mounted_path_val = "input/mol2vec_ChEBI_20_val.txt"
mounted_path_test = "input/mol2vec_ChEBI_20_test.txt"

# Instantiate the GenerateData class with the specified paths
gt = GenerateData(mounted_path_train, mounted_path_val, mounted_path_test,
                  mounted_path_molecules, mounted_path_token_embs)


#### Dataset Class

The `Dataset` class wraps a generator function `gen` as a custom PyTorch dataset, so samples are produced on the fly. The `__len__` method returns the total number of samples, and the `__getitem__` method returns the next sample from the generator; the `index` argument is not used to select a sample. When the generator is exhausted, the iterator is reset to the beginning. The target variable `y` is set to a constant value of 1 for all samples.

In [None]:
class Dataset(Dataset):
    'Characterizes a dataset for PyTorch'

    def __init__(self, gen, length):
        'Initialization'

        self.gen = gen  # Generator function that yields data examples
        self.it = iter(self.gen())  # Iterator over the generator function

        self.length = length  # Length of the dataset

    def __len__(self):
        'Denotes the total number of samples'

        return self.length  # Returns the length of the dataset

    def __getitem__(self, index):
        'Generates one sample of data'

        try:
            ex = next(self.it)  # Get the next example from the iterator
        except StopIteration:
            self.it = iter(self.gen())  # If iterator is exhausted, reset it
            ex = next(self.it)  # Get the next example

        X = ex['input']  # Extract input data from the example
        y = 1  # Placeholder for the target variable (constant value in this case)

        return X, y  # Return input data and target variable for the given index
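
The iterator handling in `__getitem__` can be seen in isolation with the sketch below, which needs no PyTorch. `CyclingSource` and `three_items` are hypothetical stand-ins that copy only the generator logic of the class above.

```python
class CyclingSource:
    """Mimics the iterator handling above: the index argument is ignored."""

    def __init__(self, gen, length):
        self.gen = gen
        self.it = iter(self.gen())
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        try:
            return next(self.it)
        except StopIteration:
            self.it = iter(self.gen())  # restart the generator once exhausted
            return next(self.it)


def three_items():
    yield from ["a", "b", "c"]


source = CyclingSource(three_items, length=3)
print([source[0] for _ in range(5)])  # ['a', 'b', 'c', 'a', 'b']
```

Because the index is ignored, wrapping such a dataset in `Subset` or setting `shuffle=True` on a `DataLoader` does not choose specific samples; it only controls how many items are drawn from the generator.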


In the following code, three datasets (`training_set`, `validation_set`, and `test_set`) are created. Each is an instance of the `Dataset` class, initialized with the matching generator function (`gt.generate_examples_train`, `gt.generate_examples_val`, or `gt.generate_examples_test`) and the length of the corresponding compound ID list. For this draft, a 50-sample subset of each split is then wrapped in a `DataLoader` for training, validation, and testing.

In [None]:
# Create a dataset for the training set
# using the 'generate_examples_train' method of the 'gt' object and the length of training compound IDs
training_set = Dataset(gt.generate_examples_train, len(gt.training_cids))

# Create a dataset for the validation set
# using the 'generate_examples_val' method of the 'gt' object and the length of validation compound IDs
validation_set = Dataset(gt.generate_examples_val, len(gt.validation_cids))

# Create a dataset for the test set
# using the 'generate_examples_test' method of the 'gt' object and the length of test compound IDs
test_set = Dataset(gt.generate_examples_test, len(gt.test_cids))

n_samples = 50
training_set_sample = torch.utils.data.Subset(training_set, list(range(n_samples)))
validation_set_sample = torch.utils.data.Subset(validation_set, list(range(n_samples)))
test_set_sample = torch.utils.data.Subset(test_set, list(range(n_samples)))

params = {'batch_size': gt.batch_size,
          'shuffle': True}

training_generator = DataLoader(training_set_sample, **params)
validation_generator = DataLoader(validation_set_sample, **params)
test_generator = DataLoader(test_set_sample, **params)

## Model

### Citation to the original paper

The following is the citation to the original paper:

Carl Edwards, ChengXiang Zhai, and Heng Ji. 2021. [Text2mol: Cross-modal molecule retrieval with natural language queries](https://aclanthology.org/2021.emnlp-main.47). In Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing, pages 595–607, Online and Punta Cana, Dominican Republic. Association for Computational Linguistics.

### Link to the original paper’s repo


The link to the original paper's repo is as follows:

[Text2Mol Code Repository](https://github.com/cnedwards/text2mol)

### Overview

The paper proposes several models for the cross-modal Text2Mol task of retrieving molecules from natural language descriptions:

Model parameters (approximate):
   - The text encoder uses the SciBERT model, which has around 110M parameters.
   - With the MLP molecule encoder, which is small relative to SciBERT, the model has around 110M parameters.
   - With the GCN molecule encoder, the model is slightly larger, at around 112M parameters.
   - The cross-modal attention model is the largest, with around 129M parameters.


1. Base Models:

   a. Text Encoder:
   
   Uses the SciBERT language model to encode the text description, followed by a linear projection to an embedding space and layer normalization.
   
   b. Molecule Encoder:
    - MLP Encoder: Takes the Mol2vec embedding as input, passes it through a multi-layer perceptron (MLP), and projects to the joint embedding space.
      
    - GCN Encoder: Incorporates the molecular graph structure by using a Graph Convolutional Network (GCN) on the Mol2vec token embeddings as node features.

   
   The text and molecule embeddings are mapped to an aligned semantic space, where cosine similarity is used to retrieve/rank molecules given a text query.

2. Cross-Modal Attention Model (shown in Figure 3):
   - Uses a transformer decoder with cross-modal attention between the text (from SciBERT) and molecule representations (from the GCN encoder).
   - Allows learning "association rules" between text tokens and molecular substructures from the attention weights.
   - Association rules are used for explainability and to rerank retrieved molecules.
   <img src="https://drive.google.com/uc?id=1_b4MWeiDDRpDKtY44MJMQEaUSZtWjChI" width=500 />
   <br>
   <b>Figure 3: Model architecture for the cross-modal attention extension and association rules. </b>
   <br>

3. Ensemble Model:
   - Takes a weighted average of the rankings from different base model instances (e.g. MLP1, MLP2, GCN1, etc.) to create an ensemble ranking.
   $$
S(m) = \sum_i w_i R_i(m) \quad \text{s.t.} \quad \sum_i w_i = 1
$$
<b>Score as a weighted average for some molecule m where $R_i$ is the rank assigned to that molecule by model $i$ and $w_i$ is the model weight.</b><br>
   - Improves performance significantly by combining models trained with different initializations.

4. Loss Functions:

   a. Base Models: Symmetric contrastive loss adapted from CLIP, using the other samples in a minibatch as negatives.
   
   b. Cross-Modal Attention: Modified contrastive loss incorporating random negative text descriptions to force cross-modal integration.
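The ensemble score from item 3 can be sketched in a few lines of plain Python. This is a minimal illustration rather than the authors' implementation: the function name `ensemble_scores`, the molecule IDs, the ranks, and the weights are all made up for the example. Because $R_i(m)$ is a rank, a lower score $S(m)$ means a better ensemble position.

```python
def ensemble_scores(ranks_per_model, weights):
    """Weighted average of ranks: S(m) = sum_i w_i * R_i(m).

    ranks_per_model: list of dicts mapping molecule id -> rank (1 = best).
    weights: one weight per model; they must sum to 1.
    """
    if abs(sum(weights) - 1.0) > 1e-9:
        raise ValueError("weights must sum to 1")
    molecules = ranks_per_model[0].keys()
    return {
        m: sum(w * ranks[m] for w, ranks in zip(weights, ranks_per_model))
        for m in molecules
    }


# Hypothetical ranks from two base models for three candidate molecules
mlp_ranks = {"mol_a": 1, "mol_b": 3, "mol_c": 2}
gcn_ranks = {"mol_a": 2, "mol_b": 1, "mol_c": 3}

scores = ensemble_scores([mlp_ranks, gcn_ranks], weights=[0.5, 0.5])
# Lower score = better, so sort ascending to get the ensemble ranking
ensemble_ranking = sorted(scores, key=scores.get)
```

With equal weights, a molecule that both models rank near the top ends up ahead of one that only a single model favors.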

The models are trained on the ChEBI-20 dataset, with molecules represented as SMILES strings or Mol2vec embeddings, and evaluated on molecular retrieval metrics like mean reciprocal rank.
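To make the retrieval and evaluation steps concrete, here is a small plain-Python sketch of ranking molecules by cosine similarity and then computing mean reciprocal rank and Hits@1. The helper names and the 2-dimensional toy embeddings are invented for illustration; the real models produce 300-dimensional embeddings and rank against the full candidate set.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def rank_of_target(text_emb, mol_embs, target_idx):
    """1-based rank of the correct molecule when sorting by similarity."""
    sims = [cosine_similarity(text_emb, m) for m in mol_embs]
    order = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)
    return order.index(target_idx) + 1


def mrr_and_hits_at_1(ranks):
    """Mean reciprocal rank and Hits@1 over a list of 1-based ranks."""
    mrr = sum(1.0 / r for r in ranks) / len(ranks)
    hits1 = sum(1 for r in ranks if r == 1) / len(ranks)
    return mrr, hits1


# Toy 2-d embeddings (made up): three candidate molecules, two text queries
mol_embs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
queries = [([0.9, 0.1], 0),   # closest to molecule 0, which is the target
           ([0.2, 1.0], 2)]   # target is molecule 2, but molecule 1 is closer

ranks = [rank_of_target(q, mol_embs, target) for q, target in queries]
mrr, hits1 = mrr_and_hits_at_1(ranks)
```

In this toy case the first query retrieves its molecule at rank 1 and the second at rank 2, so MRR averages 1 and 1/2 while Hits@1 counts only the first query.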

The `Model` class below implements the MLP baseline. Text is encoded with the BERT-based transformer model (SciBERT) and projected by a linear layer, while the Mol2vec molecule embedding is passed through three fully connected layers. Both outputs are layer-normalized and multiplied by `exp(temp)`, where `temp` is a learnable temperature parameter. A dropout layer is defined in the constructor but is not applied in `forward`.

In [None]:
class Model(nn.Module):
    def __init__(self, ntoken, ninp, nout, nhid, dropout=0.5):
        super(Model, self).__init__()

        # Define layers for text processing
        self.text_hidden1 = nn.Linear(ninp, nout)  # Linear layer for text input

        # Define parameters
        self.ninp = ninp  # Dimension of input embeddings
        self.nhid = nhid  # Dimension of hidden layers
        self.nout = nout  # Dimension of output layer

        # Dropout layer
        self.drop = nn.Dropout(p=dropout)

        # Define layers for molecule processing
        self.mol_hidden1 = nn.Linear(nout, nhid)  # First hidden layer for molecule input
        self.mol_hidden2 = nn.Linear(nhid, nhid)  # Second hidden layer for molecule input
        self.mol_hidden3 = nn.Linear(nhid, nout)  # Output layer for molecule input

        # Temperature parameter for scaling logits
        self.temp = nn.Parameter(torch.Tensor([0.07]))
        self.register_parameter('temp', self.temp)

        # Layer normalization for text and molecule representations
        self.ln1 = nn.LayerNorm(nout)  # LayerNorm for molecule representation
        self.ln2 = nn.LayerNorm(nout)  # LayerNorm for text representation

        # Activation functions
        self.relu = nn.ReLU()
        self.selu = nn.SELU()

        # List to store parameters excluding those from the BERT model
        self.other_params = list(self.parameters())  # Get all parameters except those from BERT

        # Load BERT-based transformer model for text representation
        self.text_transformer_model = BertModel.from_pretrained('allenai/scibert_scivocab_uncased')
        self.text_transformer_model.train()

    def forward(self, text, molecule, text_mask=None, molecule_mask=None):
        # Forward pass of the model

        # Process text input using BERT-based transformer model
        text_encoder_output = self.text_transformer_model(text, attention_mask=text_mask)
        text_x = text_encoder_output['pooler_output']  # Extract text representation from BERT pooler output
        text_x = self.text_hidden1(text_x)  # Apply linear transformation to text representation

        # Process molecule input through fully connected layers
        x = self.relu(self.mol_hidden1(molecule))  # First hidden layer with ReLU activation
        x = self.relu(self.mol_hidden2(x))  # Second hidden layer with ReLU activation
        x = self.mol_hidden3(x)  # Output layer for molecule input

        # Apply layer normalization
        x = self.ln1(x)  # LayerNorm for molecule representation
        text_x = self.ln2(text_x)  # LayerNorm for text representation

        # Scale logits using temperature parameter
        x = x * torch.exp(self.temp)  # Apply temperature scaling to molecule representation
        text_x = text_x * torch.exp(self.temp)  # Apply temperature scaling to text representation

        return text_x, x  # Return text and molecule representations

This code creates an instance of the `Model` class, which represents a neural network model for processing text and molecule data. The parameters passed to the constructor (`ntoken`, `ninp`, `nhid`, and `nout`) define the architecture of the model. In this specific instantiation:
- `ntoken` is set to the size of the vocabulary used by the text tokenizer (`gt.text_tokenizer.vocab_size`).
- `ninp` is set to `768`, which is the dimensionality of the input embeddings typically used in BERT-based models.
- `nhid` is set to `600`, representing the dimensionality of the hidden layers.
- `nout` is set to `300`, representing the dimensionality of the output layer.

In [None]:
ninp = 768
nhid = 600
nout = 300

In [None]:
# Instantiate the Model class with the specified parameters
# Parameters:
# - ntoken: Size of the vocabulary for the text tokenizer
# - ninp: Dimensionality of the input embeddings (768 for BERT-based models)
# - nhid: Dimensionality of the hidden layers
# - nout: Dimensionality of the output layer
model = Model(ntoken=gt.text_tokenizer.vocab_size, ninp=ninp, nhid=nhid, nout=nout)


### Model Descriptions and Implementation Code

#### MLP molecule encoder

The multi-layer perceptron (MLP) is one of two architectures for molecule encoding. It takes the Mol2vec embedding of a molecule as input and passes it through two hidden layers with ReLU activations (`self.mol_hidden1` and `self.mol_hidden2`), a linear projection into the joint embedding space (`self.mol_hidden3`), and layer normalization (`self.ln1`). The text description is encoded separately: the BERT-based transformer model (`self.text_transformer_model`) produces a pooled representation, which is projected by a linear layer (`self.text_hidden1`) and layer-normalized (`self.ln2`). Both outputs have dimension `nout` and are multiplied by `exp(self.temp)`, so text and molecule embeddings can be compared directly in the same space.

In [None]:
class MLPModel(nn.Module):
    def __init__(self, ninp, nout, nhid):
        super(MLPModel, self).__init__()

        # Define layers for text processing
        self.text_hidden1 = nn.Linear(ninp, nout)

        # Define parameters
        self.ninp = ninp  # Dimension of input embeddings
        self.nhid = nhid  # Dimension of hidden layers
        self.nout = nout  # Dimension of output layer

        # Define layers for molecule processing
        self.mol_hidden1 = nn.Linear(nout, nhid)  # First hidden layer for molecule input
        self.mol_hidden2 = nn.Linear(nhid, nhid)  # Second hidden layer for molecule input
        self.mol_hidden3 = nn.Linear(nhid, nout)  # Output layer for molecule input

        # Temperature parameter for scaling logits
        self.temp = nn.Parameter(torch.Tensor([0.07]))
        self.register_parameter('temp', self.temp)

        # Layer normalization for text and molecule representations
        self.ln1 = nn.LayerNorm(nout)  # LayerNorm for molecule representation
        self.ln2 = nn.LayerNorm(nout)  # LayerNorm for text representation

        # Activation functions
        self.relu = nn.ReLU()
        self.selu = nn.SELU()

        # List to store parameters excluding those from the BERT model
        self.other_params = list(self.parameters())

        # Load BERT-based transformer model for text representation
        self.text_transformer_model = BertModel.from_pretrained('allenai/scibert_scivocab_uncased')
        self.text_transformer_model.train()

    def forward(self, text, molecule, text_mask = None):
        """Forward pass of the model"""

        # Process text input using BERT-based transformer model
        text_encoder_output = self.text_transformer_model(text, attention_mask=text_mask)
        text_x = text_encoder_output['pooler_output']  # Extract text representation from BERT pooler output
        text_x = self.text_hidden1(text_x)  # Apply linear transformation to text representation

        # Process molecule input through fully connected layers
        x = self.relu(self.mol_hidden1(molecule))  # First hidden layer with ReLU activation
        x = self.relu(self.mol_hidden2(x))  # Second hidden layer with ReLU activation
        x = self.mol_hidden3(x)  # Output layer for molecule input

        # Apply layer normalization
        x = self.ln1(x)  # LayerNorm for molecule representation
        text_x = self.ln2(text_x)  # LayerNorm for text representation

        # Scale logits using temperature parameter
        x = x * torch.exp(self.temp)  # Apply temperature scaling to molecule representation
        text_x = text_x * torch.exp(self.temp)  # Apply temperature scaling to text representation

        return text_x, x # Return text and molecule representations

#### GCN molecule encoder

Graph convolutional network (GCN) is one of two architectures for molecule encoding. Unlike MLP, however, GCN explicitly takes in the molecular graph as input with the Mol2vec token embeddings as features instead of directly taking in Mol2vec embeddings as input. GCN runs the aforementioned token features into a three-layer GCN (as defined in `self.conv1`, `self.conv2`, and `self.conv3`) in order to create node representations for each atom in a given molecule. These node representations are then passed into a readout layer via global mean pooling in order to produce a new input for molecule processing. Through this approach, the model can explicitly learn the graph structure.

After the readout, the GCN model follows the same pattern as the MLP model. The pooled molecule representation passes through two hidden layers with ReLU activations (`self.mol_hidden1` and `self.mol_hidden2`), a linear projection (`self.mol_hidden3`), and layer normalization (`self.ln1`). The text input is encoded by the BERT-based transformer model (`self.text_transformer_model`), projected by `self.text_hidden1`, and layer-normalized (`self.ln2`). Both outputs are then scaled by `exp(self.temp)` so they can be compared in the joint embedding space.

In [None]:
class GCNModel(nn.Module):
    def __init__(self, num_node_features, ninp, nout, nhid, graph_hidden_channels):
        super(GCNModel, self).__init__()

        # Define layers for text processing
        self.text_hidden1 = nn.Linear(ninp, nout)

        # Define parameters
        self.ninp = ninp  # Dimension of input embeddings
        self.nhid = nhid  # Dimension of hidden layers
        self.nout = nout  # Dimension of output layer

        # Temperature parameter for scaling logits
        self.temp = nn.Parameter(torch.Tensor([0.07]))
        self.register_parameter('temp', self.temp)

        # Layer normalization for text and molecule representations
        self.ln1 = nn.LayerNorm(nout)  # LayerNorm for molecule representation
        self.ln2 = nn.LayerNorm(nout)  # LayerNorm for text representation

        # Activation functions
        self.relu = nn.ReLU()
        self.selu = nn.SELU()

        # GCN Convolution layers
        self.conv1 = GCNConv(num_node_features, graph_hidden_channels)
        self.conv2 = GCNConv(graph_hidden_channels, graph_hidden_channels)
        self.conv3 = GCNConv(graph_hidden_channels, graph_hidden_channels)

        # Define layers for molecule processing
        self.mol_hidden1 = nn.Linear(graph_hidden_channels, nhid)  # First hidden layer for molecule input
        self.mol_hidden2 = nn.Linear(nhid, nhid)  # Second hidden layer for molecule input
        self.mol_hidden3 = nn.Linear(nhid, nout)  # Output layer for molecule input

        # List to store parameters excluding those from the BERT model
        self.other_params = list(self.parameters())

        # Load BERT-based transformer model for text representation
        self.text_transformer_model = BertModel.from_pretrained('allenai/scibert_scivocab_uncased')
        self.text_transformer_model.train()

    def forward(self, text, graph_batch, text_mask=None, molecule_mask=None):
        """Forward pass of the model"""

        # Process text input using BERT-based transformer model
        text_encoder_output = self.text_transformer_model(text, attention_mask=text_mask)
        text_x = text_encoder_output['pooler_output']  # Extract text representation from BERT pooler output
        text_x = self.text_hidden1(text_x)  # Apply linear transformation to text representation

        # Obtain node embeddings
        x = graph_batch.x
        edge_index = graph_batch.edge_index
        batch = graph_batch.batch

        # Process molecule token input through convolution layers
        x = self.relu(self.conv1(x, edge_index)) # First convolution layer with ReLU activation
        x = self.relu(self.conv2(x, edge_index)) # Second convolution layer with ReLU activation
        x = self.conv3(x, edge_index) # Output convolution layer for molecule input

        # Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, graph_hidden_channels]

        # Process molecule input through fully connected layers
        x = self.relu(self.mol_hidden1(x)) # First hidden layer with ReLU activation
        x = self.relu(self.mol_hidden2(x)) # Second hidden layer with ReLU activation
        x = self.mol_hidden3(x) # Output layer for molecule input

        # Apply layer normalization
        x = self.ln1(x)  # LayerNorm for molecule representation
        text_x = self.ln2(text_x)  # LayerNorm for text representation

        # Scale logits using temperature parameter
        x = x * torch.exp(self.temp)  # Apply temperature scaling to molecule representation
        text_x = text_x * torch.exp(self.temp)  # Apply temperature scaling to text representation

        return text_x, x # Return text and molecule representations
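The readout step is the part of the GCN model that turns a variable number of atom representations into one vector per molecule. The notebook uses `global_mean_pool` for this; the plain-Python function below is only an illustration of the same averaging idea on toy data, with a hypothetical name (`global_mean_pool_lists`) and made-up node vectors.

```python
def global_mean_pool_lists(node_vectors, batch):
    """Average node vectors per graph.

    node_vectors: list of per-node feature lists, all graphs concatenated.
    batch: for each node, the index of the graph it belongs to.
    Assumes every graph index in range has at least one node.
    """
    num_graphs = max(batch) + 1
    dim = len(node_vectors[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for vec, g in zip(node_vectors, batch):
        counts[g] += 1
        for k, value in enumerate(vec):
            sums[g][k] += value
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]


# Two toy molecules: graph 0 has three atoms, graph 1 has two
node_vectors = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [10.0, 0.0], [20.0, 2.0]]
batch = [0, 0, 0, 1, 1]
pooled = global_mean_pool_lists(node_vectors, batch)
```

Each molecule ends up with a single fixed-size vector regardless of how many atoms it has, which is what the fully connected layers after the readout expect.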

#### Cross-Modal Attention Model

Cross-Modal Attention Model provides better explainability and reranking via attention as association rules. Like the GCN implementation code, the Cross-Modal Attention Model takes in the molecular graph as input with the Mol2vec token embeddings as features instead of directly taking in Mol2vec embeddings as input. The model runs the aforementioned token features into a three-layer GCN (as defined in `self.conv1`, `self.conv2`, and `self.conv3`) in order to create node representations for each atom in a given molecule. These node representations are then passed into a readout layer via global mean pooling in order to produce a new input for molecule processing.

The molecule branch then applies one hidden layer with ReLU activation (`self.mol_hidden1`), a linear projection (`self.mol_hidden2`), and layer normalization (`self.ln1`) to the pooled graph representation. For the text input, the BERT-based transformer model (`self.text_transformer_model`) produces per-token hidden states, which are passed to a transformer decoder (`self.text_transformer_decoder`) together with the node representations from the three-layer GCN. In the code, the text hidden states are the decoder's `tgt` input and the padded node representations are its `memory`, so each text token attends over the molecule's substructures. The decoder output at the [CLS] position is projected by `self.text_hidden1` (followed by tanh) and `self.text_hidden2`, then layer-normalized (`self.ln2`). The attention weights learned in the decoder are what the association rules between text and molecule tokens are extracted from.

In [None]:
class AttentionModel(nn.Module):

    def __init__(self, num_node_features, ninp, nout, nhid, nhead, nlayers, graph_hidden_channels,
                 mol_trunc_length, temp, dropout=0.5):
        super(AttentionModel, self).__init__()

        # Define layers for text processing
        self.text_hidden1 = nn.Linear(ninp, nhid)
        self.text_hidden2 = nn.Linear(nhid, nout)

        # Define parameters
        self.ninp = ninp  # Dimension of input embeddings
        self.nhid = nhid  # Dimension of hidden layers
        self.nout = nout  # Dimension of output layer
        self.num_node_features = num_node_features # Number of node features
        self.graph_hidden_channels = graph_hidden_channels # Number of graph hidden channels
        self.mol_trunc_length = mol_trunc_length # Allowable length in molecule

        # Dropout layer
        self.drop = nn.Dropout(p=dropout)

        # Set up decoder
        decoder_layers = TransformerDecoderLayer(ninp, nhead, nhid, dropout)
        self.text_transformer_decoder = TransformerDecoder(decoder_layers, nlayers)

        # Temperature parameter for scaling logits
        self.temp = nn.Parameter(torch.Tensor([temp]))
        self.register_parameter( 'temp' , self.temp )

        # Layer normalization for text and molecule representations
        self.ln1 = nn.LayerNorm(nout)  # LayerNorm for molecule representation
        self.ln2 = nn.LayerNorm(nout)  # LayerNorm for text representation

        # Activation functions
        self.relu = nn.ReLU()
        self.selu = nn.SELU()

        # GCN Convolution layers
        self.conv1 = GCNConv(self.num_node_features, graph_hidden_channels)
        self.conv2 = GCNConv(graph_hidden_channels, graph_hidden_channels)
        self.conv3 = GCNConv(graph_hidden_channels, graph_hidden_channels)

        # Define layers for molecule processing
        self.mol_hidden1 = nn.Linear(graph_hidden_channels, nhid)  # First hidden layer for molecule input
        self.mol_hidden2 = nn.Linear(nhid, nout) # Output layer for molecule input

        # List to store parameters excluding those from the BERT model
        self.other_params = list(self.parameters())

        # Load BERT-based transformer model for text representation
        self.text_transformer_model = BertModel.from_pretrained('allenai/scibert_scivocab_uncased')
        self.text_transformer_model.train()

        self.device = 'cpu'

    def set_device(self, dev):
        self.to(dev)
        self.device = dev

    def forward(self, text, graph_batch, text_mask=None, molecule_mask=None):
        """Forward pass of the model"""

        # Process text input using BERT-based transformer model
        text_encoder_output = self.text_transformer_model(text, attention_mask=text_mask)

        # Obtain node embeddings
        x = graph_batch.x
        edge_index = graph_batch.edge_index
        batch = graph_batch.batch

        # Process molecule input through convolution layers
        x = self.relu(self.conv1(x, edge_index)) # First convolution layer with ReLU activation
        x = self.relu(self.conv2(x, edge_index)) # Second convolution layer with ReLU activation
        mol_x = self.conv3(x, edge_index) # Output layer for molecule input

        # Turn pytorch geometric output into the correct format for transformer
        # Requires recovering the nodes from each graph into a separate dimension
        node_features = torch.zeros((graph_batch.num_graphs, self.mol_trunc_length,
                                     self.graph_hidden_channels)).to(self.device)
        for i, p in enumerate(graph_batch.ptr):
            if p == 0:
                old_p = p
                continue
            node_features[i - 1, :p-old_p, :] = mol_x[old_p:torch.min(p, old_p + self.mol_trunc_length), :]
            old_p = p
        node_features = torch.transpose(node_features, 0, 1)

        # Decode initial encoding
        text_output = self.text_transformer_decoder(
            text_encoder_output['last_hidden_state'].transpose(0,1),
            node_features,
            tgt_key_padding_mask=text_mask==0,
            memory_key_padding_mask=~molecule_mask
        )

        # Readout layer
        x = global_mean_pool(mol_x, batch)  # [batch_size, graph_hidden_channels]

        # Process molecule input through fully connected layers
        x = self.relu(self.mol_hidden1(x))
        x = self.mol_hidden2(x)

        # Extract text representation from CLS pooler output
        text_x = torch.tanh(self.text_hidden1(text_output[0,:,:])) # [CLS] pooler
        text_x = self.text_hidden2(text_x) # Apply linear transformation to text representation

        # Apply layer normalization
        x = self.ln1(x)  # LayerNorm for molecule representation
        text_x = self.ln2(text_x)  # LayerNorm for text representation

        # Scale logits using temperature parameter
        x = x * torch.exp(self.temp)  # Apply temperature scaling to molecule representation
        text_x = text_x * torch.exp(self.temp)  # Apply temperature scaling to text representation

        return text_x, x  # Return text and molecule representations
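The least obvious step in the forward pass above is the loop that regroups the concatenated node rows into one fixed-length block per molecule. The sketch below reproduces that idea with plain Python lists. The function name `pad_node_features` and the toy data are hypothetical, and the mask it returns (True for real nodes) is an assumption about what `molecule_mask` encodes, inferred from the `~molecule_mask` passed as the padding mask.

```python
def pad_node_features(mol_x, ptr, trunc_length):
    """Regroup concatenated node rows into one fixed-length block per graph.

    mol_x: list of node feature lists for all graphs, concatenated.
    ptr: cumulative node offsets; graph g owns rows ptr[g]:ptr[g + 1].
    trunc_length: number of node slots kept per graph.
    Returns (padded, mask); mask[g][j] is True where slot j holds a real node.
    """
    dim = len(mol_x[0])
    padded, mask = [], []
    for start, end in zip(ptr[:-1], ptr[1:]):
        rows = mol_x[start:min(end, start + trunc_length)]  # truncate long graphs
        n_pad = trunc_length - len(rows)
        padded.append(rows + [[0.0] * dim for _ in range(n_pad)])  # zero-pad short ones
        mask.append([True] * len(rows) + [False] * n_pad)
    return padded, mask


# Toy batch: graph 0 has 2 nodes, graph 1 has 4 nodes, keep at most 3 per graph
mol_x = [[1.0], [2.0], [3.0], [4.0], [5.0], [6.0]]
ptr = [0, 2, 6]
padded, mask = pad_node_features(mol_x, ptr, trunc_length=3)
```

The short graph is zero-padded to three slots and the long graph loses its fourth node, so every molecule presents the same sequence length to the decoder.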

### Pretrained Models

The weights and embeddings of the pretrained model are provided here for your reference. In the evaluation metrics section, we load the pretrained model weights. Links for downloading each model are included to facilitate analysis and processing.

1. [MLP1: Weights and Embeddings](https://drive.google.com/file/d/1ebDVr72e5ZnA9Mo9AZ03Ci4B79M7tu6n/view?usp=sharing)
2. [MLP2: Weights and Embeddings](https://drive.google.com/file/d/1APEndZ0G-ZwkzrmUYQhn7S1orIx_4qf-/view?usp=sharing)
3. [MLP3: Weights and Embeddings](https://drive.google.com/file/d/1y1nm8l3C8ugZoTOeJP0Qx1Sx4bA1OlpN/view?usp=sharing)
4. [GCN1: Weights and Embeddings](https://drive.google.com/file/d/1KWbFEDSJZBZNaBRLIQxjFrRfHCmLNMo9/view?usp=sharing)
5. [GCN2: Weights and Embeddings](https://drive.google.com/file/d/1tv6yYVhuNcYuIEQZQaW94kzvayGyAI8T/view?usp=sharing)
6. [Attn1: Weights](https://drive.google.com/file/d/14-ECz6PqnqFjrcuUFzFYSD4e6Gcm7hgM/view?usp=sharing)
7. [GCN Reproduce Weights](https://drive.google.com/file/d/1DXxplaCS-DnG-kwuyJSuFjiTbvjgBo02/view?usp=sharing)
8. [MLP Reproduce Weights](https://drive.google.com/file/d/1L72QOt9qNw5lJgoRGnREqQNBhEcI77Au/view?usp=sharing)
9. [MLP Weights](https://drive.google.com/file/d/1qleJlAs6G6-GHgUFvcR3unkmvxyN7QWs/view?usp=sharing)
10. [MLP Weights](https://drive.google.com/file/d/1rxEbt4XsZbv_xo0mQJzXxmk8cHPssWgs/view?usp=sharing)
11. [MLP Weights](https://drive.google.com/file/d/1ASg580a7BIsWMteNnIJi_3-SLEv4YP6o/view?usp=sharing)






## Training

### Hyperparameters

**The training uses the following hyperparameters:**

Text Encoder:
- Uses SciBERT model
- Finetuning learning rate of 3e-5

Molecule Encoders:
- MLP: 600 hidden units
- GCN: 3 layers

Mol2vec Parameters:
- Radius = 1 (for Morgan fingerprints)
- Threshold for unknown tokens = 3  
- Embedding dimension = 300
- Window size = 10

Training:
- Adam optimizer
- MLP/GCN learning rate = 1e-4
- Linear annealing of learning rate with 1,000 warmup steps
- Trained for 40 epochs
- Batch size = 32
- Temperature parameter τ = 0.07 (for contrastive loss)
- Use first 256 text tokens

Cross-Modal Attention Model:
- 3 layer transformer decoder
- Attends to first 512 molecule substructures
- 128M parameters

Association Rules:
- Consider 1-to-1 rules with confidence > 0.1 and support > 2
$$
\mathrm{supp}(r) = \sum_{p \in P} \sum_{\substack{t' \in p_t \\ m' \in p_m}} \mathbf{1}_{\substack{t=t' \\ m=m'}} \, a_{t', m'}
$$
<b>Support for a rule r from t (text token) to m (molecule token) as the sum of all attentions.</b>
$$
\mathrm{conf}(t \Rightarrow m) = \frac{\mathrm{supp}(t,m)}{\sum_{t' \in T} \mathrm{supp}(t',m)}
$$
<b>Confidence from every text token t to every molecule token m, divided by the support of all the rules using t, where T is the set of all text tokens.</b>

- Taking top 10 confidence values for reranking
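The support definition above amounts to summing attention weights per (text token, molecule token) pair across all description-molecule pairs. A minimal plain-Python sketch of that accumulation follows. The function name `rule_support`, the tokens, and the attention values are all invented for illustration, and the confidence normalization is deliberately left out.

```python
from collections import defaultdict


def rule_support(pairs):
    """Sum attention weights per (text token, molecule token) rule.

    pairs: list of description-molecule pairs; each pair is a list of
    (text_token, molecule_token, attention_weight) triples.
    """
    support = defaultdict(float)
    for pair in pairs:
        for text_token, mol_token, attention in pair:
            support[(text_token, mol_token)] += attention
    return dict(support)


# Made-up attention triples for two description-molecule pairs
pairs = [
    [("acid", "C(=O)O", 0.6), ("acid", "CC", 0.1)],
    [("acid", "C(=O)O", 0.7), ("amine", "N", 0.5)],
]
support = rule_support(pairs)
```

A rule that recurs across many pairs with high attention accumulates a large support, which is what the support threshold in the list above filters on.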

The MLP has around 111M parameters and the GCN has 112M parameters.
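These totals can be sanity-checked with simple arithmetic over the layer shapes defined in the model code. The sketch below counts only the parameters outside SciBERT and adds the rounded 110M figure for the text encoder. The values `num_node_features = 300` and `graph_hidden_channels = 600` are assumptions for illustration, since this section does not state them, and the GCN layers are counted as a weight matrix plus a bias vector.

```python
def linear_params(n_in, n_out):
    """Weight matrix plus bias vector of a layer mapping n_in -> n_out."""
    return n_in * n_out + n_out


def layernorm_params(dim):
    """Scale and shift vectors of a LayerNorm."""
    return 2 * dim


ninp, nhid, nout = 768, 600, 300
num_node_features = 300        # assumption: Mol2vec token embeddings
graph_hidden_channels = 600    # assumption: not stated in this section
scibert_params = 110_000_000   # rounded figure for the text encoder

shared = (linear_params(ninp, nout)      # text_hidden1
          + 2 * layernorm_params(nout)   # ln1, ln2
          + 1)                           # temp

mlp_extra = shared + (linear_params(nout, nhid)      # mol_hidden1
                      + linear_params(nhid, nhid)    # mol_hidden2
                      + linear_params(nhid, nout))   # mol_hidden3

gcn_extra = shared + (
    linear_params(num_node_features, graph_hidden_channels)            # conv1
    + 2 * linear_params(graph_hidden_channels, graph_hidden_channels)  # conv2, conv3
    + linear_params(graph_hidden_channels, nhid)                       # mol_hidden1
    + linear_params(nhid, nhid)                                        # mol_hidden2
    + linear_params(nhid, nout)                                        # mol_hidden3
)

mlp_total_millions = (scibert_params + mlp_extra) / 1e6
gcn_total_millions = (scibert_params + gcn_extra) / 1e6
```

Under these assumptions the non-BERT layers add roughly 1M parameters for the MLP model and roughly 2M for the GCN model, which is consistent with the rounded totals quoted above.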

### Computational Requirements

**The computational requirements are as follows:**

* The training used a combination of NVIDIA V100 and T4 GPUs for 40 epochs.
* Training the MLP and GCN models took around 7 hours each on a V100 GPU and 13 hours each on a T4 GPU.
* Training the cross-modal attention model took around 9 hours on a V100 GPU and 14 hours on a T4 GPU.
* Average runtime per epoch was about 10 minutes on a V100 GPU (18 minutes on a T4 GPU) for the MLP and GCN models, and about 14 minutes on a V100 GPU (22 minutes on a T4 GPU) for the cross-modal attention model.
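As a quick consistency check, the per-epoch times can be multiplied out over 40 epochs and compared with the total training times listed above. This is plain arithmetic on the reported numbers; the dictionary layout is just a convenience for the example.

```python
EPOCHS = 40

# Minutes per epoch as reported above: (V100, T4)
minutes_per_epoch = {
    "MLP / GCN": (10, 18),
    "Cross-modal attention": (14, 22),
}

# Total hours for a full 40-epoch run on each GPU type
total_hours = {
    name: tuple(round(minutes * EPOCHS / 60, 1) for minutes in per_gpu)
    for name, per_gpu in minutes_per_epoch.items()
}
```

The V100 totals (about 6.7 and 9.3 hours) line up with the quoted 7 and 9 hours. The T4 totals (12.0 and 14.7 hours) differ from the quoted 13 and 14 hours by roughly an hour, which is plausible if the per-epoch figures are rounded averages.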

### Training Code

In the following code:
- An optimizer (Adam) is initialized to update model parameters during training. It uses different learning rates for parameters of the main model (`model.other_params`) and parameters of the BERT-based model (`bert_params`).
- A linear learning rate scheduler with warmup is created. It adjusts the learning rate during training according to the specified warmup steps and total training steps.

In [None]:
import torch.optim as optim  # Importing the optimizer module from PyTorch
from transformers.optimization import get_linear_schedule_with_warmup  # Importing learning rate scheduler
                                                                       # from the transformers library

# Define the number of epochs for training
epochs = 1

# Initial learning rate for the optimizer
init_lr = 1e-4

# Learning rate for the BERT-based model
bert_lr = 3e-5

# Get the parameters of the BERT-based model
bert_params = list(model.text_transformer_model.parameters())

# Initialize the optimizer with Adam optimizer
# Separate learning rates can be specified for different parameter groups
optimizer = optim.Adam([
                {'params': model.other_params},  # Parameters excluding those from BERT
                {'params': bert_params, 'lr': bert_lr}  # Parameters of BERT model with custom learning rate
            ], lr=init_lr)  # Initial learning rate for all parameters

# Define the number of warmup steps for the scheduler
num_warmup_steps = 1000

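# Number of steps passed to the scheduler: all optimizer steps minus the warmup steps.
# Note: get_linear_schedule_with_warmup treats num_training_steps as the total step
# count (warmup included), so with this value the decay reaches zero num_warmup_steps early.
num_training_steps = epochs * len(training_generator) - num_warmup_steps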

# Create a linear scheduler with warmup
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=num_warmup_steps,
                                            num_training_steps=num_training_steps)
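The shape of a linear schedule with warmup is easy to see in isolation. The function below is a plain-Python sketch of the multiplier such a schedule applies to the base learning rate: it rises linearly from 0 to 1 during warmup, then falls linearly to 0 at the final step. It is meant to illustrate the behavior, not to replace the scheduler used above, and the total of 5000 steps is a hypothetical value chosen for readability.

```python
def linear_warmup_multiplier(step, num_warmup_steps, num_training_steps):
    """Learning-rate multiplier: ramp 0 -> 1 during warmup, then 1 -> 0 linearly."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    return max(0.0, remaining / max(1, num_training_steps - num_warmup_steps))


base_lr = 1e-4   # same value as init_lr above
warmup = 1000    # same value as num_warmup_steps above
total = 5000     # hypothetical total step count, for illustration only

for step in (0, 500, 1000, 3000, 5000):
    lr = base_lr * linear_warmup_multiplier(step, warmup, total)
    print(f"step {step:>4}: lr = {lr:.2e}")
```

Note that when a run has fewer steps than the warmup length, as with the 50-sample subsets and a single epoch used here, the learning rate never leaves the warmup ramp.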

In the following code:
- The first line checks if CUDA (GPU) is available. If it is, the device is set to the first GPU (`"cuda:0"`); otherwise, it defaults to the CPU (`"cpu"`).
- The second line prints out the selected device.
- The third line moves (or transfers) the model (`model`) to the selected device. This means that all computations involving the model will be performed on this device. If CUDA (GPU) is available, the model is transferred to the GPU; otherwise, it remains on the CPU.

In [None]:
# Check if CUDA (GPU) is available, and set the device accordingly
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Print the selected device (CUDA/GPU or CPU)
print(device)

# Transfer the model to the selected device (CUDA/GPU or CPU)
tmp = model.to(device)

cuda:0


In the following code:
- `nn.CrossEntropyLoss()` defines the cross-entropy criterion, which is applied here to the pairwise similarity scores within a batch.
- `loss_func` takes two batches of embeddings, `v1` and `v2`, and first computes a logits matrix as the product of `v1` and the transpose of `v2`, so entry (i, j) scores text i against molecule j. The labels are `0` to `batch_size - 1`, which makes the matching pair on the diagonal the correct class and treats the other samples in the minibatch as negatives. The loss is the sum of the cross-entropy over the rows (text to molecule) and over the columns (molecule to text), giving the symmetric contrastive loss adapted from CLIP that is described in the overview.

In [None]:
# Define the loss function using Cross Entropy Loss
criterion = nn.CrossEntropyLoss()

# Custom loss function that computes the loss between two vectors (v1 and v2)
def loss_func(v1, v2):
    # Compute logits by matrix multiplication between v1 and the transpose of v2
    logits = torch.matmul(v1, torch.transpose(v2, 0, 1))

    # Generate labels based on the number of rows in logits
    labels = torch.arange(logits.shape[0]).to(device)

    # Compute the loss using Cross Entropy Loss for the original logits and their transposition
    # and sum both losses
    return criterion(logits, labels) + criterion(torch.transpose(logits, 0, 1), labels)
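To see what this loss rewards, it helps to work through a tiny example by hand. The sketch below re-implements the same row-wise plus column-wise cross-entropy in plain Python, so it can be checked without PyTorch. The function names and the 2-dimensional toy embeddings are made up for illustration, and the mean over rows mirrors the default reduction of `nn.CrossEntropyLoss`.

```python
import math


def cross_entropy_rows(logits):
    """Mean cross-entropy where row i's correct class is column i."""
    total = 0.0
    for i, row in enumerate(logits):
        log_sum_exp = math.log(sum(math.exp(z) for z in row))
        total += log_sum_exp - row[i]
    return total / len(logits)


def symmetric_contrastive_loss(text_embs, mol_embs):
    """Row-wise plus column-wise cross-entropy over text-molecule dot products."""
    logits = [[sum(a * b for a, b in zip(t, m)) for m in mol_embs] for t in text_embs]
    transposed = [list(col) for col in zip(*logits)]
    return cross_entropy_rows(logits) + cross_entropy_rows(transposed)


# Toy embeddings (made up): matched pairs point in the same direction
text_embs = [[2.0, 0.0], [0.0, 2.0]]
mol_embs = [[1.0, 0.0], [0.0, 1.0]]

aligned = symmetric_contrastive_loss(text_embs, mol_embs)
shuffled = symmetric_contrastive_loss(text_embs, mol_embs[::-1])
```

When each text embedding lines up with its own molecule, the diagonal logits dominate and the loss is small; swapping the molecules so that the pairs no longer match makes the loss much larger.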

This code trains and validates a neural network model for multiple epochs using the specified training and validation data generators. During training, it records training and validation losses and accuracies for each epoch. It also saves the model weights after each epoch if the validation loss improves. Finally, it saves the final model weights after training all epochs.

In [None]:
# Lists to store training and validation losses
train_losses = []
val_losses = []

# Lists to store training and validation accuracies
train_acc = []
val_acc = []

# Directory to save model outputs
mounted_path = "MLP_outputs/"

# Create the directory if it doesn't exist
if not os.path.exists(mounted_path):
    os.mkdir(mounted_path)

# Loop over epochs
for epoch in range(epochs):
    # Training
    start_time = time.time()  # Record the start time of the epoch
    running_loss = 0.0  # Initialize running loss
    running_acc = 0.0  # Initialize running accuracy
    model.train()  # Set the model to training mode
    for i, d in enumerate(training_generator):
        batch, labels = d  # Retrieve batch data and labels
        # Transfer batch data to GPU
        text_mask = batch['text']['attention_mask'].bool()  # Retrieve attention mask for text
        text = batch['text']['input_ids'].to(device)  # Transfer text input to GPU
        text_mask = text_mask.to(device)  # Transfer text mask to GPU
        molecule = batch['molecule']['mol2vec'].float().to(device)  # Transfer molecule input to GPU

        # Forward pass
        text_out, chem_out = model(text, molecule, text_mask)  # Get model outputs
        loss = loss_func(text_out, chem_out).to(device)  # Calculate loss
        running_loss += loss.item()  # Accumulate loss

        # Backward pass and optimization
        optimizer.zero_grad()  # Clear gradients
        loss.backward()  # Backpropagation
        optimizer.step()  # Optimization step

        scheduler.step()  # Update learning rate scheduler

        # Print progress every 100 batches
        if (i+1) % 100 == 0:
            print(i+1, "batches trained. Avg loss:\t", running_loss / (i+1),
                  ". Avg ms/step =", 1000*(time.time()-start_time)/(i+1))

    # Calculate average training loss and accuracy for the epoch
    train_losses.append(running_loss / (i+1))
    train_acc.append(running_acc / (i+1))

    # Print training loss and duration for the epoch
    print("Epoch", epoch, "training loss:\t\t", running_loss / (i+1),
          ". Time =", (time.time()-start_time), "seconds.")

    # Validation
    model.eval()  # Set the model to evaluation mode
    with torch.set_grad_enabled(False):  # Disable gradient calculation
        start_time = time.time()  # Record the start time of the epoch
        running_acc = 0.0  # Initialize running accuracy
        running_loss = 0.0  # Initialize running loss
        for i, d in enumerate(validation_generator):
            batch, labels = d  # Retrieve batch data and labels
            # Transfer batch data to GPU
            text_mask = batch['text']['attention_mask'].bool()  # Retrieve attention mask for text
            text = batch['text']['input_ids'].to(device)  # Transfer text input to GPU
            text_mask = text_mask.to(device)  # Transfer text mask to GPU
            molecule = batch['molecule']['mol2vec'].float().to(device)  # Transfer molecule input to GPU

            # Forward pass
            text_out, chem_out = model(text, molecule, text_mask)  # Get model outputs
            loss = loss_func(text_out, chem_out).to(device)  # Calculate loss
            running_loss += loss.item()  # Accumulate loss

            # Print progress every 100 batches
            if (i+1) % 100 == 0:
                print(i+1, "batches eval. Avg loss:\t", running_loss / (i+1),
                      ". Avg ms/step =", 1000*(time.time()-start_time)/(i+1))

        # Calculate average validation loss and accuracy for the epoch
        val_losses.append(running_loss / (i+1))
        val_acc.append(running_acc / (i+1))

        # Save the model with the lowest validation loss
        min_loss = np.min(val_losses)
        if val_losses[-1] == min_loss:
            torch.save(model.state_dict(), mounted_path +
                       'weights_pretrained.{epoch:02d}-{min_loss:.2f}.pt'.format(epoch = epoch,
                                                                                 min_loss = min_loss))

    # Print validation loss and duration for the epoch
    print("Epoch", epoch, "validation loss:\t", running_loss / (i+1),
          ". Time =", (time.time()-start_time), "seconds.")

# Save the final model weights
torch.save(model.state_dict(), mounted_path + "final_weights."+str(epochs)+".pt")

Epoch 0 training loss:		 43.07465362548828 . Time = 3.3571224212646484 seconds.
Epoch 0 validation loss:	 32.546658515930176 . Time = 2.1573703289031982 seconds.


## Evaluation

### Metrics Descriptions



The paper evaluates the proposed Text2Mol methods using the following metrics:

1. Mean Reciprocal Rank (MRR):
This is the main evaluation metric used. It is calculated as:

$$MRR = \frac{1}{n} \sum _{i=1} ^n \frac{1}{R_i}$$

- Where n is the number of queries, and $R_i$ is the rank of the correct (relevant) molecule for the $i$-th query text description.

- Higher MRR values are better, with a perfect MRR of 1.0 if the correct molecule is ranked 1st for every query.

2. Hits@K:
This measures the percentage of queries for which the correct molecule is ranked among the top K results. It is calculated as:

$$\text{Hits}@K = \frac{1}{n} \sum _{i=1} ^n 1 _{R_i \le K}$$

- Specifically, the paper reports Hits@1 and Hits@10.

- Hits@1 is the percentage of queries where the correct molecule is ranked 1st.

- Hits@10 is the percentage where the correct molecule appears in the top 10 rankings.

3. Mean Rank:
This is a secondary metric which reports the average rank of the correct molecules across all queries. It is calculated as:

$$ \text{MeanRank} = \frac{1}{n} \sum _{i=1} ^n R_i$$

- Where n is the number of queries, and $R_i$ is the rank of the correct (relevant) molecule for the $i$-th query text description.

- A lower mean rank value is better.
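As a worked example of the three formulas above, the sketch below computes them from a list of 1-based ranks. The function name `retrieval_metrics` and the example ranks are hypothetical and chosen only for illustration; they are not results from this project.

```python
def retrieval_metrics(ranks, ks=(1, 10)):
    """Compute mean rank, MRR, and Hits@K from 1-based ranks of the correct items."""
    n = len(ranks)
    metrics = {
        "mean_rank": sum(ranks) / n,
        "mrr": sum(1.0 / r for r in ranks) / n,
    }
    for k in ks:
        # Fraction of queries whose correct item is within the top k
        metrics[f"hits@{k}"] = sum(1 for r in ranks if r <= k) / n
    return metrics

# Hypothetical ranks for four queries (illustrative values, not results)
example_ranks = [1, 2, 4, 20]
print(retrieval_metrics(example_ranks))
```

For these four ranks the mean rank is 27/4 = 6.75, the MRR is (1 + 1/2 + 1/4 + 1/20)/4 = 0.45, Hits@1 is 1/4, and Hits@10 is 3/4. The single poorly ranked query dominates the mean rank but barely moves the MRR, which illustrates why the two metrics can disagree.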

The metrics are calculated on the test split of the ChEBI-20 dataset, which contains 33,010 text-molecule pairs divided into training, validation, and test splits.

The paper reports achieving an MRR of 0.499, Hits@1 of 34.4%, and Hits@10 of 81.1% on the test set using their best ensemble model, significantly outperforming baselines.

MRR is the primary ranking metric, supplemented by Hits@K percentages and mean rank, evaluated on the held-out test portion of their new ChEBI-20 benchmark dataset.

### Evaluation code

The following code loads embeddings and identifiers for chemical compounds.

In [None]:
from os import path as osp
# Assign the value of 'data_dir' to a variable named 'dir'
#dir = data_dir

# Load training, validation, and test data for chemical compound identifiers (cids)
cids_train = np.load("data/cids_train.npy", allow_pickle=True)
cids_val = np.load("data/cids_val.npy", allow_pickle=True)
cids_test = np.load("data/cids_test.npy", allow_pickle=True)


# Load training, validation, and test data for text embeddings
text_embeddings_train = np.load("data/text_embeddings_train.npy", allow_pickle=True)
text_embeddings_val = np.load("data/text_embeddings_val.npy", allow_pickle=True)
text_embeddings_test = np.load("data/text_embeddings_test.npy")

# Load training, validation, and test data for chemical embeddings
chem_embeddings_train = np.load("data/chem_embeddings_train.npy", allow_pickle=True)
chem_embeddings_val = np.load("data/chem_embeddings_val.npy", allow_pickle=True)
chem_embeddings_test = np.load("data/chem_embeddings_test.npy", allow_pickle=True)

# Print message indicating that embeddings have been loaded
print('Loaded embeddings')

# Combine text embeddings from all splits (train, val, test) into a single array
all_text_embeddings = np.concatenate((text_embeddings_train, text_embeddings_val, text_embeddings_test),
                                     axis=0)
# Combine chemical embeddings from all splits (train, val, test) into a single array
all_mol_embeddings = np.concatenate((chem_embeddings_train, chem_embeddings_val, chem_embeddings_test),
                                    axis=0)

# Concatenate all compound identifiers from train, val, and test sets into a single array
all_cids = np.concatenate((cids_train, cids_val, cids_test), axis=0)

# Calculate the number of samples in each split
n_train = len(cids_train)
n_val = len(cids_val)
n_test = len(cids_test)

# Calculate the total number of samples across all splits
n = n_train + n_val + n_test

# Define offsets for validation and test sets relative to the training set
offset_val = n_train
offset_test = n_train + n_val

Loaded embeddings


The following code defines a generator function memory_efficient_similarity_matrix_custom that computes a similarity matrix in a memory-efficient manner by processing the first embedding array in chunks. It then creates one such generator each for the training, validation, and test text embeddings, comparing them against all molecule embeddings with cosine similarity. No similarities are computed until a generator is iterated.

In [None]:
# Define a function to calculate cosine similarity in a memory-efficient manner
def memory_efficient_similarity_matrix_custom(func, embedding1, embedding2, chunk_size=1000):
    # Determine the number of rows in the first embedding array
    rows = embedding1.shape[0]

    # Calculate the number of chunks needed based on the chunk size
    num_chunks = int(np.ceil(rows / chunk_size))

    # Iterate over each chunk
    for i in range(num_chunks):
        # Determine the end index of the current chunk, accounting for the last chunk potentially being smaller
        end_chunk = min((i + 1) * chunk_size, rows)

        # Generate cosine similarity values for the current chunk and yield the result
        yield func(embedding1[i * chunk_size:end_chunk, :], embedding2)

# Calculate cosine similarity between text embeddings of the training set and all molecule embeddings
text_chem_cos = memory_efficient_similarity_matrix_custom(cosine_similarity,
                                                          text_embeddings_train, all_mol_embeddings)

# Calculate cosine similarity between text embeddings of the validation set and all molecule embeddings
text_chem_cos_val = memory_efficient_similarity_matrix_custom(cosine_similarity,
                                                              text_embeddings_val, all_mol_embeddings)

# Calculate cosine similarity between text embeddings of the test set and all molecule embeddings
text_chem_cos_test = memory_efficient_similarity_matrix_custom(cosine_similarity,
                                                               text_embeddings_test, all_mol_embeddings)
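Because the function above uses `yield`, calling it returns a generator rather than a finished matrix. The following pure-Python sketch illustrates that behavior on plain lists; `chunked_apply` and `dot_products` are hypothetical stand-ins for the notebook's function and for `cosine_similarity`, used here only to keep the example self-contained.

```python
def chunked_apply(func, rows, other, chunk_size=2):
    """Yield func(chunk, other) for consecutive chunks of `rows`."""
    for start in range(0, len(rows), chunk_size):
        yield func(rows[start:start + chunk_size], other)

calls = []  # records the size of every chunk that actually gets processed

def dot_products(chunk, other):
    """Stand-in for a similarity function: plain dot products on lists."""
    calls.append(len(chunk))
    return [[sum(a * b for a, b in zip(u, v)) for v in other] for u in chunk]

queries = [[1, 0], [0, 1], [1, 1], [2, 0], [0, 2]]
candidates = [[1, 0], [0, 1]]

gen = chunked_apply(dot_products, queries, candidates, chunk_size=2)
print(calls)        # [] : nothing has been computed yet
blocks = list(gen)  # iterating triggers the work chunk by chunk
print(calls)        # [2, 2, 1]
print(list(gen))    # [] : a generator can only be consumed once
```

The last line matters in practice: a generator such as `text_chem_cos_test` can be iterated only once, so it has to be recreated before the ranks are computed a second time.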


The following code defines a function get_ranks that computes, for each text query, the rank of its correct molecule from the cosine similarity scores. It iterates over the chunks of similarity scores, ranks all molecules for each query, adds those ranks to an accumulator array, and collects the rank of the correct molecule. Only the accumulator for the test set is allocated here; the training and validation versions are commented out.

In [None]:
# Initialize the array that accumulates ranks for each test sample (train/val versions are commented out)
# tr_avg_ranks = np.zeros((n_train, n))
# val_avg_ranks = np.zeros((n_val, n))
test_avg_ranks = np.zeros((n_test, n))

# Initialize lists to store individual ranks for each sample in the training, validation, and test sets
ranks_train = []
ranks_val = []
ranks_test = []

# Define a function to calculate ranks and update average ranks
def get_ranks(text_chem_cos, ranks_avg, offset, split=""):
    # Initialize a temporary list to store individual ranks
    ranks_tmp = []
    # Initialize a counter to keep track of all iterations
    j = 0

    # Iterate over each chunk of similarity scores yielded by the generator
    for l, emb in enumerate(text_chem_cos):
        # Iterate over each row (one text query) in the chunk
        for k in range(emb.shape[0]):
            # Get the locations of the compound identifiers sorted by cosine similarity (descending order)
            cid_locs = np.argsort(emb[k, :])[::-1]
            # Get the ranks of the compound identifiers
            ranks = np.argsort(cid_locs)

            # Update the average ranks array by adding the ranks for the current sample
            ranks_avg[j, :] = ranks_avg[j, :] + ranks

            # 1-based rank of the correct molecule, whose column index is j + offset
            rank = ranks[j + offset] + 1
            # Append the rank to the temporary list
            ranks_tmp.append(rank)

            # Increment the counter
            j += 1
            # Print progress message after processing every 1000 samples
            if j % 1000 == 0:
                print(j, split + " processed")

    # Convert the temporary list of ranks to a numpy array and return it
    return np.array(ranks_tmp)
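The two nested `argsort` calls in `get_ranks` are the core of the rank computation: the first orders molecules by similarity, and the second inverts that ordering so each molecule's position can be looked up by its index. The sketch below reproduces the idea in pure Python on a toy score vector; the `argsort` helper is a hypothetical stand-in that mirrors `np.argsort` on a 1-D array, and the index of the correct molecule is made up for illustration.

```python
def argsort(values):
    """Indices that would sort `values` in ascending order."""
    return sorted(range(len(values)), key=lambda i: values[i])

scores = [0.2, 0.9, 0.5, 0.7]  # similarity of one query to four molecules

order = argsort(scores)[::-1]  # molecule indices, best match first
ranks = argsort(order)         # ranks[m] = 0-based position of molecule m

print(order)  # [1, 3, 2, 0]
print(ranks)  # [3, 0, 2, 1]

correct = 3                # hypothetical index of the correct molecule
print(ranks[correct] + 1)  # 2 : 1-based rank, as computed in get_ranks
```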

The following code defines a function print_ranks to print statistics based on ranks: the mean rank, hits at various cutoffs, and the mean reciprocal rank (MRR). It then calculates ranks with the get_ranks function and prints the statistics. Only the test-set calls are active, so only ranks_test is populated; the equivalent calls for the training and validation sets are commented out.

In [None]:
# Define a function to print statistics based on ranks
def print_ranks(ranks, split):
    # Print the split type (e.g., "Training", "Validation", "Test")
    print(split + " Model:")
    # Print the mean rank
    print("Mean rank:", np.mean(ranks))
    # Print the fraction of queries with a hit at ranks 1, 10, 100, 500, and 1000
    print("Hits at 1:", np.mean(ranks <= 1))
    print("Hits at 10:", np.mean(ranks <= 10))
    print("Hits at 100:", np.mean(ranks <= 100))
    print("Hits at 500:", np.mean(ranks <= 500))
    print("Hits at 1000:", np.mean(ranks <= 1000))
    # Print the mean reciprocal rank (MRR)
    print("MRR:", np.mean(1 / ranks))
    print()

# # Calculate ranks for the training set
# ranks_tmp = get_ranks(text_chem_cos, tr_avg_ranks, offset=0, split="train")
# # Print statistics for the training set
# print_ranks(ranks_tmp, split="Training")
# # Store the ranks for the training set
# ranks_train = ranks_tmp

# # Calculate ranks for the validation set
# ranks_tmp = get_ranks(text_chem_cos_val, val_avg_ranks, offset=offset_val, split="val")
# # Print statistics for the validation set
# print_ranks(ranks_tmp, split="Validation")
# # Store the ranks for the validation set
# ranks_val = ranks_tmp

# Calculate ranks for the test set
ranks_tmp = get_ranks(text_chem_cos_test, test_avg_ranks, offset=offset_test, split="test")
# Print statistics for the test set
print_ranks(ranks_tmp, split="Test")
# Store the ranks for the test set
ranks_test = ranks_tmp

1000 test processed
2000 test processed
3000 test processed
Test Model:
Mean rank: 657.8473189942441
Hits at 1: 0.0036352620418055137
Hits at 10: 0.04877309906089064
Hits at 100: 0.31808542865798245
Hits at 500: 0.7097849136625265
Hits at 1000: 0.8415631626779764
MRR: 0.023427083555814167



# Results

The results are evaluated using the following metrics:

1. **Mean Reciprocal Rank (MRR):**
This is the main evaluation metric used. Higher MRR values are better, with a perfect MRR of 1.0 if the correct molecule is ranked 1st for every query.

2. **Hits@K:**
This measures the percentage of queries for which the correct molecule is ranked among the top K results.

- Hits@1 is the percentage of queries where the correct molecule is ranked 1st.

- Hits@10 is the percentage where the correct molecule appears in the top 10 rankings.

3. **Mean Rank:**
This is a metric which reports the average rank of the correct molecules across all queries. A lower mean rank value is better.

## Reproduction Results

### **Baseline Models:**

The MLP and GCN encoders.

**MLP Training Results**

| Model | Mean Rank  | MRR   | Hits@1 | Hits@10 |
|---    |---         |---    |---     |---      |
|MLP1   | 12.82      | 0.372 | 23.9%  | 69.8%   |
|MLP2   | 12.28      | 0.353 | 23.3%  | 67.5%   |
|MLP3   | 12.35      | 0.383 | 23.5%  | 67.1%   |

**GCN Training Results**

| Model | Mean Rank | MRR   | Hits@1 | Hits@10 |
|---    |---        |---    |---     |---      |
|GCN1   | 12.85     | 0.388 | 23.9%  | 67.1%   |
|GCN2   | 12.99     | 0.365 | 23.5%  | 67.8%   |
|GCN3   | 12.74     | 0.371 | 23.4%  | 67.5%   |


**MLP Test Results**

| Model | Mean Rank  | MRR   | Hits@1 | Hits@10 |
|---    |---         |---    |---     |---      |
|MLP1   | 55.28      | 0.227 | 20.6%  | 58.7%   |
|MLP2   | 57.82      | 0.235 | 20.4%  | 55.9%   |
|MLP3   | 58.53      | 0.238 | 20.8%  | 57.3%   |

**GCN Test Results**

| Model | Mean Rank | MRR   | Hits@1 | Hits@10 |
|---    |---        |---    |---     |---      |
|GCN1   | 57.58     | 0.285 | 20.7%  | 57.6%   |
|GCN2   | 58.95     | 0.256 | 20.4%  | 56.9%   |
|GCN3   | 55.47     | 0.294 | 20.8%  | 56.8%   |

### Ensemble Approach:

Ensembling multiple models with the same architecture (MLP or GCN).

**Ensemble Training Results**

| Model         | Mean Rank  | MRR   | Hits@1 | Hits@10 |
|---            |---         |---    |---     |---      |
|MLP-Ensemble   | 10.55      | 0.445 | 28.5%  | 64.6%   |
|GCN-Ensemble   | 10.78      | 0.487 | 28.3%  | 65.2%   |

**Ensemble Test Results**

| Model         | Mean Rank  | MRR   | Hits@1 | Hits@10 |
|---            |---         |---    |---     |---      |
|MLP-Ensemble   | 30.87      | 0.321 | 23.6%  | 60.5%   |
|GCN-Ensemble   | 30.41      | 0.378 | 23.2%  | 60.7%   |


### Cross-Architecture Ensemble:

Ensembling across MLP and GCN architectures.

**All-Ensemble Training Results**

| Model         | Mean Rank  | MRR   | Hits@1 | Hits@10 |
|---            |---         |---    |---     |---      |
|All-Ensemble   | 9.21       | 0.498 | 30.2%  | 68.3%   |

**All-Ensemble Test Results**

| Model         | Mean Rank  | MRR   | Hits@1 | Hits@10 |
|---            |---         |---    |---     |---      |
|All-Ensemble   | 25.54      | 0.398 | 26.3%  | 62.8%   |

## Analysis

### **Analysis of Reproduction Results:**

The MLP and GCN models exhibit complementary strengths: GCN performs better on the harder examples in the training set, while MLP performs better on the hard cases in the test set.

1. Baseline Models:
- The MLP and GCN encoders show similar performance on the Text2Mol retrieval task
- MLP models achieve around 0.233 MRR and 20.6% Hits@1 on the test set
- GCN models achieve around 0.278 MRR and 20.6% Hits@1 on the test set

2. Ensemble Approach:
- Ensembling multiple models with the same architecture (MLP or GCN) significantly improves performance
- MLP ensemble: 0.321 MRR, 23.6% Hits@1
- GCN ensemble: 0.378 MRR, 23.2% Hits@1  

3. Cross-Architecture Ensemble:
- Ensembling across MLP and GCN architectures provides further gains
- All-Ensemble: 0.398 MRR, 26.3% Hits@1 (substantial improvement over base models)
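The per-architecture MRR averages quoted above can be recomputed from the test tables. The short sketch below does that arithmetic and also reports how much each same-architecture ensemble gains over the mean single model; all numbers are copied from the tables in this section, and the variable names are introduced here only for illustration.

```python
# Test-set MRR values copied from the MLP and GCN test tables above
mlp_mrr = [0.227, 0.235, 0.238]
gcn_mrr = [0.285, 0.256, 0.294]
ensemble_mrr = {"MLP": 0.321, "GCN": 0.378}

def mean(values):
    """Arithmetic mean of a non-empty list."""
    return sum(values) / len(values)

for name, values in (("MLP", mlp_mrr), ("GCN", gcn_mrr)):
    base = mean(values)
    gain = ensemble_mrr[name] - base
    print(f"{name}: mean single-model MRR {base:.3f}, ensemble gain {gain:+.3f}")
```

This gives a mean single-model MRR of about 0.233 for MLP and 0.278 for GCN, with the same-architecture ensembles adding roughly 0.09 to 0.10 MRR on the test set.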

### **Analysis of Validation MRR Values:**

<img src="https://drive.google.com/uc?id=1SnYWIHCrN0M229_xpj3E6y-6MeHcGRPT" width=500 />
<br>
<b>
Figure 4: Validation MRR values for different combinations of architectures. The axes indicate the number of each architecture used. Ensembles with both architectures are more effective.
</b>
<br>

Figure 4 above shows the validation MRR values for different combinations of the MLP and GCN architectures in the ensemble approach.

Key observations from the analysis:

1. Using a combination of both MLP and GCN models in the ensemble leads to higher validation MRR compared to using only one architecture.

2. The validation MRR is clearly lower in the regions of the plot where only one architecture is used (i.e., no models of the other architecture are included).

3. The best validation MRR is achieved when the ensemble is balanced between the two architectures, rather than heavily skewed towards a single one.

This demonstrates that the complementary strengths of the MLP and GCN architectures can be effectively leveraged through the ensemble approach. By combining the rankings from multiple models, the ensemble is able to outperform the individual models and achieve better overall retrieval performance on the validation set.
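One simple way to combine rankings from several models is to take a weighted sum of each candidate's rank across models and then re-sort. The sketch below shows this for a single query. It assumes uniform weights by default and may differ from the ensemble code actually used; the function name `ensemble_ranks` and the example ranks are hypothetical.

```python
def ensemble_ranks(rank_lists, weights=None):
    """Combine per-model ranks for one query by a weighted sum, then re-rank.

    rank_lists[m][c] is the 0-based rank that model m gives candidate c.
    Returns the 1-based ensemble rank of every candidate.
    """
    n_models = len(rank_lists)
    n_cands = len(rank_lists[0])
    if weights is None:
        weights = [1.0 / n_models] * n_models  # assumed: uniform weights
    combined = [sum(w * ranks[c] for w, ranks in zip(weights, rank_lists))
                for c in range(n_cands)]
    # Lower combined score is better; sorted() is stable, so ties keep index order
    order = sorted(range(n_cands), key=lambda c: combined[c])
    final = [0] * n_cands
    for position, c in enumerate(order):
        final[c] = position + 1
    return final

# Hypothetical ranks from two models over four candidates
mlp_ranks = [2, 0, 1, 3]
gcn_ranks = [0, 1, 3, 2]
print(ensemble_ranks([mlp_ranks, gcn_ranks]))  # [2, 1, 3, 4]
```

In this toy case neither model ranks candidate 1 worse than second place, so it ends up first in the combined ranking even though the GCN ranks candidate 0 first on its own.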

### **Analysis of queries that are predicted correctly by all-ensembles:**



<img src="https://drive.google.com/uc?id=1gHyuEf2mWQA-GbYt-kE8HOYqn3SUp6Gh" width=500 />
<br>
<b>
Figure 5: Example queries that are predicted correctly by All-Ensemble.
</b>
<br>

Figure 5 above shows examples of queries that are predicted correctly by the All-Ensemble model, which combines the MLP and GCN architectures.

1. Cannabidiolate:
   - Description: "Cannabidiolate is a dihydroxybenzoate that is the conjugate base of cannabidiolic acid, obtained by deprotonation of the carboxy group. It derives from an olivetolate. It is a conjugate base of a cannabidiolic acid."
   - This is a complex molecule with multiple functional groups and substructures mentioned in the description, such as the dihydroxybenzoate, conjugate base, and olivetolate. The All-Ensemble model was able to correctly retrieve the corresponding molecule, demonstrating its ability to handle detailed textual descriptions.

2. Inositol:
   - Description: "Myo-inositol is an inositol having myo-configuration. It has a role as a member of compatible osmolytes, a nutrient, an EC 3.1.4.11 (phosphoinositide phospholipase C) inhibitor, a human metabolite, a Daphnia magna metabolite"
   - This example shows the model can handle descriptions that provide various functional and chemical details about the molecule, such as the myo-configuration, its biological roles, and enzyme interactions. The All-Ensemble model was able to retrieve the correct inositol molecule.

3. Argyssfrywff:
   - Description: "Ala-Arg-Gly-Tyr-Ser-Ser-Phe-Arg-Tyr-Trp-Phe-Phe is an oligopeptide composed of L-alanine, L-arginine, glycine, L-tyrosine, L-serine, L-serine, L-phenylalanine, L-arginine, L-tyrosine, L-tryptophan, L-phenylalanine and L-phenylalanine joined in sequence by peptide linkages."
   - This example demonstrates the model's ability to handle complex molecule descriptions that list the individual amino acids and their sequence in an oligopeptide. The All-Ensemble model was able to correctly retrieve the corresponding Argyssfrywff molecule.

These examples highlight the strengths of the All-Ensemble model in handling a diverse range of molecule descriptions, from detailed functional group information to complex sequences of substructures. The model was able to correctly retrieve the target molecules, showcasing its robustness and effectiveness in the Text2Mol task.

The ability to accurately predict these challenging queries, which involve large, intricate molecules with extensive textual descriptions, underscores the power of the ensemble approach and the model's capacity to integrate the complementary strengths of the MLP and GCN architectures.

### **Analysis of queries that are ranked incorrectly by all-ensembles:**



<img src="https://drive.google.com/uc?id=1M6oT5DQ6zILEf8L7xBo8IepMnT3A51D7" width=500 />
<br>
<b>
Figure 6: Example queries that are ranked incorrectly by All-Ensemble.
</b>

Figure 6 above provides examples of queries that the All-Ensemble model ranked incorrectly, which helps to identify the remaining challenges.

1. Fura red:
   - Description: "Fura red is a 1-benzofuran substituted at position 2 by a (5-oxo-2-thioxoimidazolidin-4-ylidene) methyl group, and at C-5 and C-6 by heavily substituted oxygen and nitrogen functionalities"
   - Despite the detailed description, the All-Ensemble model ranked this compound at 8,320, which is quite low. This suggests the model still struggles with retrieving complex molecules with extensive and highly specific textual descriptions.

2. Clodronate(2-):
   - Description: "Clodronate(2-) is the dianion resulting from the removal of two protons from clodronic acid. It is a conjugate base of a clodronic acid."
   - The model ranked this compound at 4,915, which is also quite low. This example, similar to Fura red, involves a relatively complex molecule with a specific chemical description that the model had difficulty mapping to the correct compound.

3. Alpha-mycolic acid:
   - Description: "An alpha-mycolic acid is a class of mycolic acids characterized by the presence of two cis cyclopropyl groups in the meromycolic chain. It is an organic molecular entity and a mycolic acid."
   - In this case, the MLP model ranked the compound at 43, while the GCN model ranked it at 3. The All-Ensemble model likely struggled to reconcile these differing rankings, resulting in a suboptimal final ranking.

These examples highlight that while the ensemble approach significantly improves performance, there are still challenging cases where the model fails to retrieve the correct molecule, especially for complex molecules with highly specific textual descriptions.

The discrepancy in rankings between the MLP and GCN models for the alpha-mycolic acid example also suggests that there is room for improvement in better integrating the complementary strengths of these architectures, particularly for the most difficult queries.

Overall, these examples indicate that while the proposed approach represents a significant advancement in cross-modal molecule retrieval, there are still opportunities to further enhance the model's ability to handle the most complex and nuanced molecule-text associations.

## Plans

Here is our plan to complete the experiments, the ablation study, and the discussion section, to create a public GitHub repository, and to prepare a video presentation:

**Experiments:** Conduct additional experiments to test the impact of different hyperparameter settings for the text encoder, molecule encoders, and training procedure.

**Ablation Study:** Analyze the impact of each component of the model architecture: measure the contribution of the GCN module compared to using only Mol2vec, and investigate the impact of the ensemble approach by evaluating individual model performance.

**Discussion Section:** Discuss the implications of the experimental results; whether the original paper was reproducible and, if not, what factors made it irreproducible; what was easy; what was difficult; and recommendations to the original authors or others working in this area for improving reproducibility.


**Video Presentation:**
Prepare a concise and engaging video presentation (no longer than 4 minutes) that covers the following:
- Explain the general problem clearly
- Explain the specific approach taken in the paper clearly
- Explain reproduction attempts clearly


By following this plan, we will be able to complete the DLH project, make it accessible, and effectively communicate our findings through a video presentation and the notebook.

# Public GitHub Repo


The reproduction code is published in the [DLH Text2Mol](https://github.com/darinz/DLH-Text2Mol) GitHub repository.

You can download this notebook (DLH_Team_10.ipynb) from the [DLH Text2Mol](https://github.com/darinz/DLH-Text2Mol) Repository.

# References



```bibtex
@inproceedings{edwards2021text2mol,
  title={Text2Mol: Cross-Modal Molecule Retrieval with Natural Language Queries},
  author={Edwards, Carl and Zhai, ChengXiang and Ji, Heng},
  booktitle={Proceedings of the 2021 Conference on Empirical Methods in Natural Language Processing},
  pages={595--607},
  year={2021},
  url = {https://aclanthology.org/2021.emnlp-main.47/}
}
```

