In [4]:
import torch
import pandas as pd
from transformers import RobertaTokenizer, RobertaModel

# ChemBERTa checkpoint on the Hugging Face Hub (RoBERTa architecture).
# Defined once so the tokenizer and the model are always loaded from the same checkpoint.
CHEMBERTA_CHECKPOINT = "seyonec/ChemBERTa-zinc-base-v1"

tokenizer = RobertaTokenizer.from_pretrained(CHEMBERTA_CHECKPOINT)
# Used for inference only: .eval() disables dropout so embeddings are deterministic.
# .eval() returns the module itself, so chaining it keeps the assignment (and avoids
# displaying the model repr as cell output).
model = RobertaModel.from_pretrained(CHEMBERTA_CHECKPOINT).eval()


vocab.json:   0%|          | 0.00/9.43k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/3.21k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/179M [00:00<?, ?B/s]

In [5]:

# A handful of example molecules, written as SMILES, to run through ChemBERTa.
smiles_data = [
    # larger polycyclic example
    "O=C1C=C(C(NCC2=CC3=C(C=C(CNCC4CCC4)N3)C=C2)=O)N=C5C=CC=CN15",
    # aspirin (acetylsalicylic acid)
    "CC(=O)OC1=CC=CC=C1C(=O)O",
    # 2-methyl-1,4-benzoquinone
    "CC1=CC(=O)C=CC1=O",
]

# Function to get embeddings for a batch of SMILES
def get_chemberta_embeddings(smiles_list, hf_tokenizer=None, hf_model=None,
                             batch_size=32, max_length=512):
    """Return one position-0 ([CLS]-style) embedding per SMILES string.

    SMILES are tokenized and encoded in mini-batches. Each batch is padded
    only to its own longest sequence, rather than always to ``max_length``,
    so short molecules no longer pay for 512 positions of attention. The
    tokenizer's output, including its attention mask, is passed to the model
    unchanged.

    Parameters
    ----------
    smiles_list : iterable of str, or a single str
        SMILES strings to embed. A bare string is treated as a one-element
        list instead of being iterated character by character.
    hf_tokenizer : callable, optional
        Tokenizer to use. Defaults to the notebook-level ``tokenizer``.
    hf_model : callable, optional
        Model to use. Defaults to the notebook-level ``model``. Its output
        must expose ``last_hidden_state``.
    batch_size : int, default 32
        Number of SMILES encoded per forward pass.
    max_length : int, default 512
        Truncation length in tokens.

    Returns
    -------
    list of numpy.ndarray
        One 1-D array (length = hidden size) per input, in input order.
        An empty input yields an empty list.

    Raises
    ------
    ValueError
        If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Fall back to the tokenizer/model loaded earlier in the notebook.
    hf_tokenizer = tokenizer if hf_tokenizer is None else hf_tokenizer
    hf_model = model if hf_model is None else hf_model

    if isinstance(smiles_list, str):
        smiles_list = [smiles_list]
    smiles_list = list(smiles_list)

    # Hugging Face models expose `.device`. Move inputs there so a GPU-placed model works too.
    device = getattr(hf_model, "device", None)

    embeddings = []
    for start in range(0, len(smiles_list), batch_size):
        batch = smiles_list[start:start + batch_size]
        # padding=True pads to the longest sequence in this batch only.
        inputs = hf_tokenizer(batch, return_tensors="pt", max_length=max_length,
                              truncation=True, padding=True)
        if device is not None:
            inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        with torch.no_grad():
            outputs = hf_model(**inputs)

        # Position 0 is the tokenizer's leading special token (<s> for RoBERTa),
        # used as the sequence summary. Copy it to CPU before converting to NumPy.
        cls_batch = outputs.last_hidden_state[:, 0, :].cpu().numpy()
        # Iterating a 2-D array yields its 1-D rows: one embedding per SMILES.
        embeddings.extend(cls_batch)

    return embeddings

# Embed the example molecules: one row per SMILES, one column per hidden unit.
embeddings = get_chemberta_embeddings(smiles_data)
embeddings_df = pd.DataFrame(embeddings)

# Sanity check: every input molecule produced exactly one embedding row.
assert len(embeddings_df) == len(smiles_data), "expected one embedding row per SMILES"

# Last expression: Jupyter renders a rich table instead of print()'s plain text.
embeddings_df.head()

        0         1         2         3         4         5         6    \
0  0.812642  0.407679 -0.034935  0.192313 -0.297040 -0.310102 -0.575852   
1  0.537932  0.146034 -0.158660 -0.712971 -0.038075 -1.054943 -0.403816   
2 -0.724743 -0.385784 -0.228628 -0.961522  0.649793 -1.763603  0.086294   

        7         8         9    ...       758       759       760       761  \
0 -0.571581  0.265512 -0.649386  ...  0.773662 -0.504897  0.132977 -0.518708   
1  0.120810 -1.168498  0.707663  ...  1.150743 -0.552200  0.131314 -0.764380   
2 -0.333877 -0.561120  0.118316  ...  0.226266 -0.501020 -0.537967 -1.169865   

        762       763       764       765       766       767  
0 -0.693730  2.000358  0.187271 -0.262813 -0.602395  1.409950  
1  0.983986  0.256525 -0.830974 -1.576819 -0.002441  2.146913  
2  0.673772  1.377905 -0.614138 -0.868043  0.607155  2.236074  

[3 rows x 768 columns]
