In [None]:
# Figure 7 and 8: UMAP visualizations of the embeddings
import matplotlib.pyplot as plt
from rdkit import Chem
import torch
import umap
import os

def get_color_for_mol(mol, functional_group_smarts):
    """Return the name of the first functional group found in ``mol``.

    Args:
        mol: RDKit molecule to inspect. ``None`` (what ``Chem.MolFromSmiles``
            returns for an unparsable SMILES) is accepted and treated as
            matching nothing.
        functional_group_smarts: Mapping of group name -> SMARTS string.
            Entries are tried in the mapping's iteration order.

    Returns:
        The key of the first entry whose SMARTS pattern matches ``mol``, or
        ``"Other"`` when no pattern matches, the mapping is empty, or ``mol``
        is ``None``.

    Raises:
        ValueError: If one of the SMARTS strings cannot be parsed.
    """
    # A failed SMILES parse yields None; classify it instead of crashing on
    # the attribute lookup below.
    if mol is None:
        return "Other"

    for fg, smarts in functional_group_smarts.items():
        pattern = Chem.MolFromSmarts(smarts)
        if pattern is None:
            # MolFromSmarts signals a parse failure with None; fail with a
            # message that names the offending entry.
            raise ValueError(
                f"Invalid SMARTS for functional group {fg!r}: {smarts!r}"
            )
        if mol.HasSubstructMatch(pattern):
            return fg
    return "Other"

def plot_functional_umap(smiles, model_path, embedding_type):
    """Plot 2-D UMAP projections of saved embeddings, plain and per functional group.

    Loads ``<model_path>/<embedding_type>_embeddings.pt``, keeps the rows
    listed in the notebook-level ``valid_indices``, projects them to 2-D with
    UMAP, and saves (and shows) two figures inside ``model_path``:

    * ``<embedding_type>_umap_plain.png`` - every point drawn in gray.
    * ``<embedding_type>_umap_colorful.png`` - one panel per entry of the
      notebook-level ``functional_groups``; a point gets that group's colour
      from ``colors_dict`` when its molecule matches the group's SMARTS and
      the ``"Other"`` colour otherwise.

    Args:
        smiles: SMILES strings, one per retained embedding row and in the same
            order as ``valid_indices``.
        model_path: Directory holding the embeddings file; the figures are
            written here as well.
        embedding_type: Prefix of the embeddings file and of the output
            figure names (the notebook uses ``"smiles"`` and ``"spectrum"``).

    Raises:
        ValueError: If the number of SMILES differs from the number of
            embedding rows selected by ``valid_indices``.
    """
    # Automatically generate path_to_embeddings from model_path
    path_to_embeddings = os.path.join(model_path, f"{embedding_type}_embeddings.pt")

    # map_location="cpu" lets tensors that were saved from a GPU run load on
    # a CPU-only machine; .cpu() below covers the conversion to numpy.
    embeddings = torch.load(path_to_embeddings, map_location="cpu")
    embeddings = embeddings[valid_indices]
    if len(embeddings) != len(smiles):
        # Without this check the mismatch only shows up later as a
        # colour-array length error inside matplotlib.
        raise ValueError(
            f"Got {len(smiles)} SMILES but {len(embeddings)} embedding rows "
            f"from {path_to_embeddings}"
        )
    embeddings_numpy = embeddings.detach().cpu().numpy()

    # Fixed random_state (with n_jobs=1) keeps the layout reproducible.
    reducer = umap.UMAP(n_components=2, random_state=7, n_jobs=1)
    red = reducer.fit_transform(embeddings_numpy)

    # Show plain umap first
    fig, ax = plt.subplots(figsize=(5.5, 5))
    ax.scatter(red[:, 0], red[:, 1], marker='.', c='gray')
    ax.axis('off')
    ax.grid(False)
    fig.tight_layout()

    # Save figure
    output_filename = os.path.join(model_path, f"{embedding_type}_umap_plain.png")
    fig.savefig(output_filename, dpi=300)
    plt.show()
    plt.close(fig)

    # Parse every SMILES once instead of once per functional group.
    mols = [Chem.MolFromSmiles(s) for s in smiles]

    # Color umaps by functional group: one panel per group. The panel count
    # follows the dictionary, and squeeze=False keeps `axs` two-dimensional
    # even when there is a single group.
    n_groups = len(functional_groups)
    fig, axs = plt.subplots(
        nrows=1, ncols=n_groups, figsize=(2 * n_groups, 2), squeeze=False
    )
    for ax, (fg_name, fg_smarts) in zip(axs[0], functional_groups.items()):
        # Each panel tests a single group, so a molecule is either that
        # group's colour or the "Other" colour.
        colors = [
            colors_dict[get_color_for_mol(mol, {fg_name: fg_smarts})]
            for mol in mols
        ]
        ax.scatter(red[:, 0], red[:, 1], marker='.', c=colors, s=5)
        ax.axis('off')
        ax.grid(False)

    fig.tight_layout()

    # Save figure
    output_filename = os.path.join(model_path, f"{embedding_type}_umap_colorful.png")
    fig.savefig(output_filename, dpi=300)
    plt.show()

    plt.close(fig)

# Functional groups shown in the UMAP panels, one row per group:
# (display name, SMARTS pattern, plot colour). Row order is panel order.
_FUNCTIONAL_GROUP_TABLE = [
    ("Alcohol", "[OX2H][CX4;!$(C([OX2H])[O,S,#7,#15])]", "red"),
    ("Carboxylic Acid", "[CX3](=O)[OX2H1]", "orange"),
    ("Ether", "[OD2]([#6])[#6]", "green"),
    ("Alkene", "[CX3]=[CX3]", "blue"),
    ("Benzene", "c1ccccc1", "purple"),
    ("Primary Amine", "[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]", "brown"),
    ("Amide", "[NX3][CX3](=[OX1])[#6]", "pink"),
    ("Sulfide", "[#16X2H0]", "cyan"),
]

# Group name -> SMARTS pattern
functional_groups = {
    name: pattern for name, pattern, _ in _FUNCTIONAL_GROUP_TABLE
}

# Group name -> colour, plus the fallback colour for molecules that do not
# match the group being plotted
colors_dict = {name: colour for name, _, colour in _FUNCTIONAL_GROUP_TABLE}
colors_dict["Other"] = "lightgray"

# Read the SMILES file (one string per line) and keep only the entries that
# RDKit can parse. `valid_indices` records the original line number of every
# kept entry so the matching embedding rows can be selected later.
path_to_smiles = '../models/contrastive0.0/embeddings/smiles.txt'
with open(path_to_smiles, 'r') as f:
    smiles_list = f.readlines()

valid_indices, smiles = [], []
for idx, raw_line in enumerate(smiles_list):
    candidate = raw_line.strip()
    if not Chem.MolFromSmiles(candidate):
        # Unparsable entry: drop it from both lists.
        continue
    valid_indices.append(idx)
    smiles.append(candidate)

In [None]:
base_model_path = "../models/contrastive0.0/embeddings"

# Draw the UMAP figures for each embedding type stored for this model:
# first "smiles", then "spectrum".
for embedding_type in ("smiles", "spectrum"):
    plot_functional_umap(
        smiles=smiles,
        model_path=base_model_path,
        embedding_type=embedding_type,
    )

In [None]:
base_model_path = "../models/contrastive0.1/embeddings"

# Draw the UMAP figures for each embedding type stored for this model:
# first "smiles", then "spectrum".
for embedding_type in ("smiles", "spectrum"):
    plot_functional_umap(
        smiles=smiles,
        model_path=base_model_path,
        embedding_type=embedding_type,
    )

In [None]:
base_model_path = "../models/contrastive0.5/embeddings"

# Draw the UMAP figures for each embedding type stored for this model:
# first "smiles", then "spectrum".
for embedding_type in ("smiles", "spectrum"):
    plot_functional_umap(
        smiles=smiles,
        model_path=base_model_path,
        embedding_type=embedding_type,
    )

In [None]:
base_model_path = "../models/contrastive1.0/embeddings"

# Draw the UMAP figures for each embedding type stored for this model:
# first "smiles", then "spectrum".
for embedding_type in ("smiles", "spectrum"):
    plot_functional_umap(
        smiles=smiles,
        model_path=base_model_path,
        embedding_type=embedding_type,
    )