In [3]:
import warnings

import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch

### 1. Load the metabolite-to-SMILES table and keep rows with a SMILES string

In [4]:
df = pd.read_csv("metabolites_to_SMILES.csv")

# Keep only rows that have a non-empty SMILES string.
# The comparison must be parenthesised: `&` binds tighter than `>`, so
# `notna() & str.len() > 0` parses as `(notna() & str.len()) > 0` and only
# happened to work because pandas coerces the float lengths to bool.
# Stripping first also rejects whitespace-only entries.
has_smiles = df["SMILES"].notna() & (df["SMILES"].str.strip().str.len() > 0)
# Later cells pair rows with embeddings by position, so give the filtered
# frame a clean 0..n-1 index.
df = df[has_smiles].reset_index(drop=True)
df.head()

Unnamed: 0.1,Unnamed: 0,Exact Match to Standard (* = isomer family),SMILES
0,HILIC-neg_Cluster_0622,"1,2,3,4-tetrahydro-1-methyl-beta-carboline-3-c...",CC1NC(Cc2c1[nH]c3ccccc23)C(O)=O
1,C18-neg_Cluster_0183,"1,2,3,4-tetrahydro-b-carboline-1,3-dicarboxyli...",OC(=O)C1Cc2c([nH]c3ccccc23)C(N1)C(O)=O
2,C18-neg_Cluster_0393,12.13-diHOME,CCCCCC(O)C(O)C\C=C/CCCCCCCC(O)=O
3,HILIC-neg_Cluster_0480,1-3-7-trimethylurate,CN1C(=O)N(C)C2=C(N(C)C(=O)N2)C1=O
4,C18-neg_Cluster_0530,13-docosenoate,CCCCCCCCC=CCCCCCCCCCCCC([O-])=O


### 2. Load the pretrained ChemBERTa tokenizer and model

In [5]:
model_name = "seyonec/ChemBERTa-zinc-base-v1"  # Hugging Face Hub checkpoint id

# Tokenizer and encoder come from the same checkpoint so token ids index the
# model's own embedding table (tokenizer is loaded first, then the model).
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModel.from_pretrained(model_name),
)

# Inference only: eval() switches the Dropout layers off so embeddings are
# deterministic; the returned module is what the cell displays.
model.eval()

RobertaModel(
  (embeddings): RobertaEmbeddings(
    (word_embeddings): Embedding(767, 768, padding_idx=1)
    (position_embeddings): Embedding(514, 768, padding_idx=1)
    (token_type_embeddings): Embedding(1, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): RobertaEncoder(
    (layer): ModuleList(
      (0-5): 6 x RobertaLayer(
        (attention): RobertaAttention(
          (self): RobertaSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): RobertaSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (dropou

### 3. Define a function that embeds a single SMILES string

In [6]:
def chemberta_embed(smiles: str):
    """Embed one SMILES string with the global ChemBERTa ``tokenizer``/``model``.

    Parameters
    ----------
    smiles : str
        SMILES string to embed. ``truncation=True`` clips it to the tokenizer's
        maximum length (presumably the model's 512-token position limit --
        TODO confirm ``tokenizer.model_max_length`` for this checkpoint).

    Returns
    -------
    numpy.ndarray
        1-D vector of length ``model.config.hidden_size``: the last-layer
        hidden state of the first token (used as the CLS representation).
        If tokenization or the forward pass raises, a warning naming the
        SMILES is emitted and an all-zero vector is returned instead, so one
        bad input does not abort a whole column and callers still get exactly
        one vector per input.
    """
    try:
        tokens = tokenizer(smiles, return_tensors="pt", truncation=True)
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            output = model(**tokens)
        # batch item 0, token 0 (CLS position), every hidden dimension
        return output.last_hidden_state[0, 0, :].numpy()
    except Exception as exc:
        # Catch Exception rather than everything: a bare `except:` also
        # swallowed KeyboardInterrupt/SystemExit and left no trace of which
        # SMILES produced a zero vector.
        warnings.warn(
            f"ChemBERTa embedding failed for SMILES {smiles!r} "
            f"({type(exc).__name__}: {exc}); returning a zero vector."
        )
        return np.zeros(model.config.hidden_size)

### 4. Embed every SMILES string

In [7]:
# Embed every SMILES in row order, so row i of `embeddings` lines up with row i of `df`.
embedding_rows = [chemberta_embed(smi) for smi in df["SMILES"]]

# np.vstack([]) raises "need at least one array to concatenate"; an empty
# table should instead yield a (0, hidden_size) matrix for the save step.
embeddings = (
    np.vstack(embedding_rows)
    if embedding_rows
    else np.empty((0, model.config.hidden_size))
)

# chemberta_embed returns an all-zero vector when a SMILES cannot be embedded;
# report those rows here instead of letting them reach the CSV unnoticed.
failed = ~embeddings.any(axis=1)
if failed.any():
    print(f"{int(failed.sum())} of {len(failed)} SMILES fell back to zero vectors:")
    print(df["SMILES"].to_numpy()[failed].tolist())

### 5. Save the metabolite names, SMILES and embeddings to CSV

In [8]:
# Identifier columns first, then one column per embedding dimension; both parts
# share the same 0..n-1 index, so the column-wise concat lines rows up 1:1.
id_part = pd.DataFrame(
    {
        "Metabolite": df["Exact Match to Standard (* = isomer family)"].values,
        "SMILES": df["SMILES"].values,
    }
)
emb_part = pd.DataFrame(
    embeddings,
    columns=[f"emb_{i}" for i in range(embeddings.shape[1])],
)
out = pd.concat([id_part, emb_part], axis=1)

out.to_csv("ChemBERTa_embeddings-2.csv", index=False)

In [9]:
# Spot-check the first rows of the table that was just written to CSV.
out.head(n=5)

Unnamed: 0,Metabolite,SMILES,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,...,emb_758,emb_759,emb_760,emb_761,emb_762,emb_763,emb_764,emb_765,emb_766,emb_767
0,"1,2,3,4-tetrahydro-1-methyl-beta-carboline-3-c...",CC1NC(Cc2c1[nH]c3ccccc23)C(O)=O,1.42942,-0.540536,-0.331599,-0.229404,0.028993,-0.455408,-0.457653,-0.201152,...,1.083394,-0.191145,1.621734,-0.103927,0.167003,-0.213973,-0.367873,-1.237645,-0.410275,1.608706
1,"1,2,3,4-tetrahydro-b-carboline-1,3-dicarboxyli...",OC(=O)C1Cc2c([nH]c3ccccc23)C(N1)C(O)=O,1.789039,0.397368,-0.280884,-0.42107,0.193951,0.313314,-0.849916,-0.765328,...,0.898589,-0.729976,1.310638,-0.284856,0.530808,-0.026889,-0.620363,-1.258979,-1.336268,1.819879
2,12.13-diHOME,CCCCCC(O)C(O)C\C=C/CCCCCCCC(O)=O,0.664841,0.127098,0.489494,-0.912496,0.85627,-0.149675,-0.299197,0.153961,...,0.73586,-0.030524,0.441598,0.280812,0.653006,1.726362,-0.213789,-2.027552,0.036159,2.003242
3,1-3-7-trimethylurate,CN1C(=O)N(C)C2=C(N(C)C(=O)N2)C1=O,1.164944,-0.644483,-1.374942,-1.656458,0.395046,-1.281229,1.034291,0.097224,...,0.067419,-0.655262,0.774642,-1.138976,1.161253,0.718392,-0.841113,-1.142369,0.072901,0.391751
4,13-docosenoate,CCCCCCCCC=CCCCCCCCCCCCC([O-])=O,0.337044,-0.780372,-0.388524,-1.431474,0.977176,-1.671496,0.320499,0.471153,...,0.493602,0.031439,0.291707,-0.383697,1.354304,0.987944,-1.429884,-1.313846,0.257583,2.293484
