In [1]:
import torch
import pandas as pd
from transformers import RobertaTokenizer, RobertaModel

# ChemBERTa checkpoint identifier. It is defined once so the tokenizer and the
# model are always loaded from the same checkpoint.
CHEMBERTA_CHECKPOINT = "seyonec/ChemBERTa-zinc-base-v1"

tokenizer = RobertaTokenizer.from_pretrained(CHEMBERTA_CHECKPOINT)
# The model is only used for inference in this notebook, so request eval mode
# explicitly (dropout off). `.eval()` returns the module itself, and chaining it
# keeps the cell from echoing the model repr as its last expression.
model = RobertaModel.from_pretrained(CHEMBERTA_CHECKPOINT).eval()




In [2]:
# Load the SMILES table. The path is relative, so it resolves against the
# notebook's current working directory.
# NOTE(review): the preview of this frame shows an "Unnamed: 0" column, which
# usually means the CSV was written with its index. Consider `index_col=0` if
# nothing relies on that column — confirm before changing.
smiles = pd.read_csv("../ZafrensData/smiles.csv")

In [3]:
# Preview the loaded table as the cell's output.
smiles

Unnamed: 0,control_rx_id,bb1_id,bb2_id,bb3_id,bb4_id,SMILES
0,1,-1,-1,-1,-1,O=C1C=C(C(NCC2=CC3=C(C=C(CNCC4CCC4)N3)C=C2)=O)...
1,2,-1,-1,-1,-1,O=C(NCC1=C[N]2C=C(CNCC3CCCCC3)C=CC2=N1)C4=CC(=...
2,3,-1,-1,-1,-1,CC1CCN(Cc2nc(C(NC(CC3)CCN3c3c(cc[n]4C)c4nc(Cl)...
3,4,-1,-1,-1,-1,CC1CCN(Cc2nc(C(NC(CC3)CCN3c3c(cc[n]4S(c5ccc(C)...
4,5,-1,-1,-1,-1,CC1CCN(Cc2nc(C(NC(CC3)CCN3c3c(cc[n]4S(c5ccccc5...
...,...,...,...,...,...,...
14116,-1,215,195,2,1462,CC(=O)NC[C@H]1CCCN(c2ccnc(-c3cnn(CC(=O)N(C)C)c...
14117,-1,222,195,2,1462,CC(=O)NCC1CCN(c2ccnc(-c3cnn(CC(=O)N(C)C)c3)n2)CC1
14118,-1,238,195,2,1462,CC(=O)NC[C@@H]1CCN(c2ccnc(-c3cnn(CC(=O)N(C)C)c...
14119,-1,269,195,2,1462,CC(=O)NC[C@H]1CN(c2ccnc(-c3cnn(CC(=O)N(C)C)c3)...


In [6]:

# ID columns that are joined together to form the per-row `sample` key.
columns_to_concat = [
    'control_rx_id',
    'bb1_id',
    'bb2_id',
    'bb3_id',
    'bb4_id',
]


In [7]:

# Build a string key for every row by casting the ID columns to str and
# joining them with underscores (e.g. "1_-1_-1_-1_-1").
smiles['sample'] = (
    smiles[columns_to_concat]
    .astype(str)
    .apply(lambda id_row: '_'.join(id_row), axis=1)
)


In [None]:

# Pull the whole SMILES column out of the table as a plain Python list
smiles_data = smiles['SMILES'].to_list()

# Function to get embeddings for a collection of SMILES strings
def get_chemberta_embeddings(smiles_list, batch_size=32, max_length=512):
    """Embed SMILES strings with the notebook-level ``tokenizer`` and ``model``.

    The strings are processed in mini-batches. Each batch is padded only to the
    length of its longest sequence instead of always to ``max_length``, which
    avoids running the model over hundreds of padding tokens per molecule.
    Padded positions are expected to be excluded through the attention mask the
    tokenizer returns, so the vectors should match the pad-to-max-length ones
    up to floating-point noise.

    Args:
        smiles_list: Iterable of SMILES strings.
        batch_size: Number of strings sent through the model per forward pass.
        max_length: Token limit; longer inputs are truncated to this length.

    Returns:
        A list with one 1-D NumPy array per input string, in input order. Each
        array is the last-layer hidden state of the first token (``<s>``,
        RoBERTa's counterpart of ``[CLS]``). An empty input gives ``[]``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    # Materialise once so any iterable (list, Series, generator) can be sliced.
    smiles_list = list(smiles_list)

    embeddings = []
    for start in range(0, len(smiles_list), batch_size):
        batch = smiles_list[start:start + batch_size]

        # padding=True pads to the longest sequence in this batch only.
        inputs = tokenizer(
            batch,
            return_tensors="pt",
            max_length=max_length,
            truncation=True,
            padding=True,
        )

        # Inference only: skip gradient tracking.
        with torch.no_grad():
            outputs = model(**inputs)

        # First-token embedding for every sequence: (batch, hidden_size).
        cls_embeddings = outputs.last_hidden_state[:, 0, :]

        # Iterating the 2-D array yields one 1-D row per molecule.
        embeddings.extend(cls_embeddings.numpy())

    return embeddings

# Get embeddings for every SMILES string in the dataset
embeddings = get_chemberta_embeddings(smiles_data)

# One row per molecule, one column per embedding dimension
embeddings_df = pd.DataFrame(embeddings)
assert len(embeddings_df) == len(smiles_data), "expected one embedding per SMILES string"

# Leave the preview as the cell's last expression so Jupyter renders it as a
# rich table instead of the plain-text output of print().
embeddings_df.head()