# Calculate ChemBERTa features from a pretrained Hugging Face model and register them in a local AMPL install
- Requires the `transformers` package: `pip install transformers`


In [1]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Pretrained ChemBERTa checkpoint on the Hugging Face Hub.
CHEMBERTA_CHECKPOINT = "DeepChem/ChemBERTa-77M-MTR"

# Only the hidden states are used downstream, so the "newly initialized lm_head" warning
# printed by this cell does not affect the extracted features.
tokenizer = AutoTokenizer.from_pretrained(CHEMBERTA_CHECKPOINT)
model = AutoModelForMaskedLM.from_pretrained(
    CHEMBERTA_CHECKPOINT,
    output_hidden_states=True,  # expose per-layer hidden states; the last layer is used as features
)

  from .autonotebook import tqdm as notebook_tqdm
Some weights of RobertaForMaskedLM were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MTR and are newly initialized: ['lm_head.bias', 'lm_head.decoder.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [2]:
import pandas as pd

# Dataset list: one row per dataset to featurize. `dataset_key` is the path of that dataset's
# CSV and `dataset_name` is the short name used in the output filenames below.
SPLIT_LIST_PATH = "../datasets/sl_test.csv"
sl = pd.read_csv(SPLIT_LIST_PATH)
sl

Unnamed: 0,bucket_name,dataset_key,dataset_name,response_cols,collection,scaffold_split_uuid
0,public,/Users/echun/repos/DILI/datasets/training_data...,ROS,active,SUP,c5e78a93-b3e6-41e0-9d15-3208c7809db5
1,public,/Users/echun/repos/DILI/datasets/training_data...,ROS,active,SUP,0a722b6f-f9af-4321-a7b6-7002830d6282
2,public,/Users/echun/repos/DILI/datasets/training_data...,ROS,active,SUP,10bcfc3d-7072-48dc-a3dd-20e9ebf6e34d


In [8]:
import os

import torch
import torch.nn.functional as F

# Directory where the per-dataset "<name>_with_chemberta_descriptors2.csv" files are written.
OUTPUT_DIR = "/Users/echun/repos/DILI/datasets/training_data/scaled_descriptors"

for _, dataset_row in sl.iterrows():
    output_path = os.path.join(
        OUTPUT_DIR, f"{dataset_row.dataset_name}_with_chemberta_descriptors2.csv"
    )

    # Skip datasets that were already featurized so re-running this cell is cheap.
    if os.path.exists(output_path):
        continue

    data = pd.read_csv(dataset_row.dataset_key)
    smiles = data.base_rdkit_smiles.tolist()

    tokens = tokenizer(smiles, padding=True, truncation=True, return_tensors="pt")

    # Inference only: no_grad avoids building an autograd graph over the whole batch.
    with torch.no_grad():
        outputs = model(**tokens)

    last_hidden = outputs.hidden_states[-1]
    print(last_hidden.size())

    # Take the first-token ([CLS]/<s>) embedding of the last layer and L2-normalize each row.
    embeddings = F.normalize(last_hidden[:, 0, :], p=2, dim=1)
    print(embeddings.size())

    # Embedding row k corresponds to data row k (same order as `smiles`), so align by position.
    # NOTE(review): duplicate compound_ids collapse to the last occurrence here — confirm ids are unique.
    feat_dict = {
        compound_id: embeddings[k]
        for k, compound_id in enumerate(data.compound_id)
    }

    cbfeat = pd.DataFrame(feat_dict).T
    # Derive the width from the model output instead of hard-coding the hidden size.
    cbfeat.columns = [f"cbert_{j}" for j in range(embeddings.shape[1])]
    cbfeat = cbfeat.reset_index(names="compound_id")

    data = data.merge(cbfeat, on="compound_id")
    data.to_csv(output_path, index=False)

torch.Size([802, 207, 384])
torch.Size([802, 384])


In [5]:
## Register the ChemBERTa descriptors in your AMPL repo (cells below)

In [9]:
# Rebuild the ChemBERTa feature table (compound_id + cbert_* columns) from `feat_dict`.
# NOTE(review): `feat_dict` is left over from the last dataset processed by the feature loop;
# on a fresh kernel where every output file already exists, that loop skips everything and
# `feat_dict` is never defined.
cbfeat = pd.DataFrame(feat_dict).T
# Derive the column count from the data instead of hard-coding 384.
cbfeat.columns = [f"cbert_{i}" for i in range(cbfeat.shape[1])]
cbfeat = cbfeat.reset_index(names="compound_id")
cbfeat

Unnamed: 0,compound_id,cbert_0,cbert_1,cbert_2,cbert_3,cbert_4,cbert_5,cbert_6,cbert_7,cbert_8,...,cbert_374,cbert_375,cbert_376,cbert_377,cbert_378,cbert_379,cbert_380,cbert_381,cbert_382,cbert_383
0,SPID236,0.002251,0.028562,-0.016843,-0.009047,-0.053455,-0.029690,-0.054598,0.035769,0.019233,...,-0.064403,0.007601,0.039732,0.030598,0.024841,0.095832,-0.014733,0.018610,-0.047235,-0.002703
1,SPID231,0.003803,0.026089,-0.017280,-0.012377,-0.055483,-0.033093,-0.055724,0.034364,0.017688,...,-0.059642,0.008624,0.031380,0.033059,0.021762,0.094865,-0.022049,0.019294,-0.053656,0.001162
2,SPID339,0.064925,0.034863,0.005250,0.017845,-0.034566,0.090750,-0.034707,0.070891,-0.005289,...,0.021493,-0.032002,0.159562,0.092612,-0.005058,0.040746,-0.027926,0.032228,-0.000982,0.019543
3,FOUWCSDKDDHKQP,-0.023970,0.110951,-0.073772,0.012170,-0.012353,-0.078822,0.050759,0.111099,0.013458,...,0.016436,-0.029264,0.016473,0.023791,0.004634,0.031802,-0.019214,-0.024088,-0.041581,0.003462
4,SPID475,-0.003037,0.060185,0.007600,0.011861,-0.004303,0.080625,-0.033400,0.028681,0.019726,...,-0.005417,0.031222,0.161364,0.041321,0.003640,0.017013,0.000566,0.058733,0.054995,0.035561
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
797,SPID539,-0.006399,0.113224,0.003144,0.064388,0.023331,0.022815,0.062002,-0.069140,0.044119,...,0.016613,-0.064291,0.093962,0.093828,-0.054510,-0.021189,-0.108586,-0.109512,0.003312,-0.055201
798,SPID593,-0.015379,0.018089,-0.005725,0.137144,-0.027008,0.029869,-0.055053,0.047046,-0.020512,...,0.002080,-0.007837,0.000157,0.000200,-0.016261,-0.077247,0.039074,-0.008788,0.032375,0.008873
799,SPID471,0.063563,0.062372,0.005832,-0.058277,0.093874,-0.074001,-0.004721,-0.021318,0.052022,...,-0.051454,0.001915,0.142196,-0.057371,-0.033600,-0.089699,0.033473,0.007166,0.041151,-0.070947
800,XRXDAJYKGWNHTQ,0.027656,0.004316,0.006526,0.073648,0.048484,0.011988,-0.003525,-0.057530,0.119560,...,0.018549,0.042012,0.146661,-0.054191,-0.010068,-0.065169,-0.049448,-0.003017,0.043927,0.037914


In [14]:
# Attach the ChemBERTa features to `data` and save.
# `data` may already carry cbert_* columns (the feature loop merges them in), and merging again
# would create duplicated cbert_*_x / cbert_*_y columns. Drop any existing cbert_* columns
# first so the merge is clean and this cell can be re-run safely.
# NOTE(review): the original comment here just said "rdkit" — presumably the source file also
# carries RDKit descriptors; confirm.
existing_cbert_cols = [c for c in data.columns if c.startswith("cbert_")]
data = data.drop(columns=existing_cbert_cols).merge(cbfeat, on="compound_id")
data.to_csv("/Users/echun/repos/DILI/datasets/ROS_hits_up_class_cur3.csv", index=False)

In [15]:
# Repair cbert_* columns that a duplicate merge left suffixed as cbert_N_x / cbert_N_y.
# In this notebook both copies were merged from `cbfeat` tables built from the same `feat_dict`,
# so keep the `_x` copy under its plain name and drop the `_y` copy (the original version
# renamed `_x` but left every `_y` duplicate in the saved file). No-op if no suffixes exist.
data = pd.read_csv("/Users/echun/repos/DILI/datasets/ROS_hits_up_class_cur3.csv")
dup_cbert_cols = [c for c in data.columns if c.startswith("cbert_") and c.endswith("_y")]
data = data.drop(columns=dup_cbert_cols)
data = data.rename(
    columns=lambda c: c[:-2] if c.startswith("cbert_") and c.endswith("_x") else c
)

# Save back out
data.to_csv("/Users/echun/repos/DILI/datasets/ROS_hits_up_class_cur3.csv", index=False)

In [13]:
# Load the descriptor-set table (one row per descr_type, with ';'-joined descriptor names)
# from the main AMPL checkout.
ampl_descriptor_table = "/Users/echun/repos/AMPL/atomsci/ddm/data/descriptor_sets_sources_by_descr_type.csv"
desc = pd.read_csv(ampl_descriptor_table)
desc

Unnamed: 0,descr_type,descriptors,scaled,source
0,moe_raw,ASA;ASA+;ASA-;ASA_H;ASA_P;ast_fraglike;ast_fra...,0,moe
1,moe_norm,ASA+_per_atom;ASA-;ASA_H_per_atom;ASA_P;ASA_pe...,1,moe
2,moe,ASA+_per_atom;ASA-;ASA_H_per_atom;ASA_P;ASA_pe...,1,moe
3,moe_filtered,ASA;ASA+;ASA-;ASA_H;ASA_P;BCUT_PEOE_0;BCUT_PEO...,0,moe
4,moe_scaled,ASA+_per_atom;ASA-;ASA_H_per_atom;ASA_P;ASA_pe...,1,moe
5,moe_scaled_filtered,ASA+_per_atom;ASA-;ASA_H_per_atom;ASA_P;ASA_pe...,1,moe
6,moe_informative,ASA+;ASA-;ASA_H;ASA_P;ASA;BCUT_PEOE_0;BCUT_PEO...,0,moe
7,mordred_raw,ABC;ABCGG;nAcid;nBase;SpAbs_A;SpMax_A;SpDiam_A...,0,mordred
8,mordred_filtered,AATS0Z;AATS0are;AATS0d;AATS0dv;AATS0i;AATS0m;A...,0,mordred
9,rdkit_raw,MaxEStateIndex;MinEStateIndex;MaxAbsEStateInde...,0,rdkit


In [36]:
# Build a one-row entry for the AMPL descriptor-set table that lists the ChemBERTa features.
# Select feature columns by their "cbert_" prefix rather than by position, so compound_id
# (or any other non-feature column) can never leak into the descriptor list.
chemberta_cols = [c for c in cbfeat.columns if c.startswith("cbert_")]
d2 = pd.DataFrame({
    "descr_type": "chemberta",
    "descriptors": ";".join(chemberta_cols),
    "scaled": 0,
    "source": "chemberta",
}, index=[0])
d2


Unnamed: 0,descr_type,descriptors,scaled,source
0,chemberta,cbert_0;cbert_1;cbert_2;cbert_3;cbert_4;cbert_...,0,chemberta


In [37]:
# Add the ChemBERTa row to the bottom of the descriptor table.
# Any existing "chemberta" row is dropped first, so re-running this cell (or running the
# notebook against an already-updated AMPL file) replaces the entry instead of duplicating it.
desc = pd.concat([desc[desc.descr_type != "chemberta"], d2], ignore_index=True)
desc

Unnamed: 0,descr_type,descriptors,scaled,source
0,moe_raw,ASA;ASA+;ASA-;ASA_H;ASA_P;ast_fraglike;ast_fra...,0,moe
1,moe_norm,ASA+_per_atom;ASA-;ASA_H_per_atom;ASA_P;ASA_pe...,1,moe
2,moe,ASA+_per_atom;ASA-;ASA_H_per_atom;ASA_P;ASA_pe...,1,moe
3,moe_filtered,ASA;ASA+;ASA-;ASA_H;ASA_P;BCUT_PEOE_0;BCUT_PEO...,0,moe
4,moe_scaled,ASA+_per_atom;ASA-;ASA_H_per_atom;ASA_P;ASA_pe...,1,moe
5,moe_scaled_filtered,ASA+_per_atom;ASA-;ASA_H_per_atom;ASA_P;ASA_pe...,1,moe
6,moe_informative,ASA+;ASA-;ASA_H;ASA_P;ASA;BCUT_PEOE_0;BCUT_PEO...,0,moe
7,mordred_raw,ABC;ABCGG;nAcid;nBase;SpAbs_A;SpMax_A;SpDiam_A...,0,mordred
8,mordred_filtered,AATS0Z;AATS0are;AATS0d;AATS0dv;AATS0i;AATS0m;A...,0,mordred
9,rdkit_raw,MaxEStateIndex;MinEStateIndex;MaxAbsEStateInde...,0,rdkit


In [38]:
# Write the updated descriptor table back over the AMPL copy with a clean 0..n-1 index.
# The index itself is not written to the file.
desc = desc.reset_index(drop=True)
ampl_descr_csv = "/Users/echun/repos/AMPL/atomsci/ddm/data/descriptor_sets_sources_by_descr_type.csv"
desc.to_csv(ampl_descr_csv, index=False)