In [7]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import pandas as pd

# Load the pretrained ChemBERTa tokenizer and masked-LM checkpoint.
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model = AutoModelForMaskedLM.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

# Pass the bare SMILES string: the tokenizer call adds the <s>/</s> special
# tokens by default, so writing "<s>" by hand would yield a doubled start token.
sample_SMILES = "C"

t = tokenizer(sample_SMILES, return_tensors="pt")

# Smoke-test forward pass only; nothing is trained here, so skip autograd
# bookkeeping.
with torch.no_grad():
    output = model(**t)


In [15]:
import torch
import torch.nn as nn

class BERT_GCxGC(nn.Module):
    """Base language model followed by three independent MLP regression heads.

    The heads predict, in order: m/z (or molecular weight), retention time 1
    and retention time 2.

    Note: the heads are fed the base model's *logits* at sequence position 0,
    so their input width is ``base_model.config.vocab_size`` (not a
    hidden-state embedding).

    Args:
        base_model: module exposing ``config.vocab_size`` and returning an
            object with a ``logits`` attribute when called with
            ``(input_ids, attention_mask=...)``.
        hidden_dim: width of the single hidden layer in every head.
        output_dim: number of values produced by every head.
    """

    def __init__(self, base_model, hidden_dim, output_dim):
        super().__init__()
        self.base_model = base_model
        head_in_features = base_model.config.vocab_size

        def build_head():
            # Linear -> ReLU -> Linear; the layout is shared by all three heads.
            return nn.Sequential(
                nn.Linear(head_in_features, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim),
            )

        # Creation order (m_z, rt1, rt2) is kept so seeded initialisation and
        # state-dict layout stay the same.
        self.m_z = build_head()   # molecular weight or m/z
        self.rt1 = build_head()   # retention time 1
        self.rt2 = build_head()   # retention time 2

    def forward(self, input_ids, attention_mask=None):
        """Run the base model and apply each head to the position-0 logits.

        Returns:
            Tuple ``(m_z, rt1, rt2)``; each element has shape
            ``(batch, output_dim)``.
        """
        base_out = self.base_model(input_ids, attention_mask=attention_mask)
        # Position 0 is where the <s> token sits when the tokenizer prepends it.
        first_position = base_out.logits[:, 0, :]
        return tuple(
            head(first_position) for head in (self.m_z, self.rt1, self.rt2)
        )

# Load the base model (reloaded here so re-running this cell always wraps a
# freshly loaded checkpoint)
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model = AutoModelForMaskedLM.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

# Create the model with MLPs
hidden_dim = 128  # Hidden dimension size for each MLP
output_dim = 1  # Output dimension for each MLP
custom_model = BERT_GCxGC(model, hidden_dim, output_dim)

# Example usage
# Pass the bare SMILES string: the tokenizer call adds the <s>/</s> special
# tokens by default, so a hand-written "<s>" prefix would put two start tokens
# in front of the molecule.
sample_SMILES = "CC1=CC(=CC=C1)S(=O)(=O)NC2=CC=C(C=C2)S(=O)(=O)NC(C)C"
inputs = tokenizer(sample_SMILES, return_tensors="pt")
outputs = custom_model(**inputs)
print(outputs)  # Outputs from the three MLPs (heads are untrained at this point)


(tensor([[0.2239]], grad_fn=<AddmmBackward0>), tensor([[0.2485]], grad_fn=<AddmmBackward0>), tensor([[-0.1168]], grad_fn=<AddmmBackward0>))


In [12]:
# Read the combined extracted-info CSV and keep only the rows whose
# Canonical_SMILES value is present (not NaN).
all_dataset = (
    pd.read_csv('results/pubchem_spectrabase_combined_extracted_info.csv')
    .loc[lambda frame: frame['Canonical_SMILES'].notna()]
)


In [13]:
# Show the filtered frame via the notebook's rich display (bare last expression).
all_dataset

Unnamed: 0,Compound,PubChem_Link,InChI,InChIKey,Canonical_SMILES,Spectrabase_Link
0,Acenaphthene,https://pubchem.ncbi.nlm.nih.gov/compound/6734,InChI=1S/C12H10/c1-3-9-4-2-6-11-8-7-10(5-1)12(...,CWRYPZZKDGJXCA-UHFFFAOYSA-N,C1CC2=CC=CC3=C2C1=CC=C3,
4,Verapamil,https://pubchem.ncbi.nlm.nih.gov/compound/2520,"InChI=1S/C27H38N2O4/c1-20(2)27(19-28,22-10-12-...",SGTNSNPWRIOYBX-UHFFFAOYSA-N,CC(C)C(CCCN(C)CCC1=CC(=C(C=C1)OC)OC)(C#N)C2=CC...,
6,1-(4-Trimethylsilylmethyl-3-cyclohexenyl)-5-me...,https://pubchem.ncbi.nlm.nih.gov/compound/1016...,InChI=1S/C17H32OSi/c1-14(2)7-6-8-17(18)16-11-9...,GYSUXFPJPGFPTA-UHFFFAOYSA-N,CC(=CCCC(C1CCC(=CC1)C[Si](C)(C)C)O)C,
9,"Phosphorodifluoridothioic hydrazide, 2,2-dimet...",https://pubchem.ncbi.nlm.nih.gov/compound/548096,"InChI=1S/C2H7F2N2PS/c1-6(2)5-7(3,4)8/h1-2H3,(H...",UPHGGRDGJYIORA-UHFFFAOYSA-N,CN(C)NP(=S)(F)F,
10,"Fluoro(methyl)(2,4,6-tri-tert-butylphenyl)silanol",https://pubchem.ncbi.nlm.nih.gov/compound/1361...,"InChI=1S/C19H33FOSi/c1-17(2,3)13-11-14(18(4,5)...",UGWOVUFMECJNIU-UHFFFAOYSA-N,CC(C)(C)C1=CC(=C(C(=C1)C(C)(C)C)[Si](C)(O)F)C(...,
...,...,...,...,...,...,...
11693,"2,11-Dodecanedione",https://pubchem.ncbi.nlm.nih.gov/compound/522378,InChI=1S/C12H22O2/c1-11(13)9-7-5-3-4-6-8-10-12...,IALFUWZSWAKBDF-UHFFFAOYSA-N,CC(=O)CCCCCCCCC(=O)C,
11695,"benzene, 1,1',1'',1'''-(1,2-ethenediylidene)te...",https://pubchem.ncbi.nlm.nih.gov/compound/9174...,InChI=1S/C34H36O4/c1-5-35-29-17-9-25(10-18-29)...,ULKJGRGKIGOYGV-UHFFFAOYSA-N,CCOC1=CC=C(C=C1)C(=C(C2=CC=C(C=C2)OCC)C3=CC=C(...,
11696,"2',4'-Dihydroxy-2,3-dimethoxychalcone",https://pubchem.ncbi.nlm.nih.gov/compound/5377844,InChI=1S/C17H16O5/c1-21-16-5-3-4-11(17(16)22-2...,JUCNRAJYHMZLOT-RMKNXTFCSA-N,COC1=CC=CC(=C1OC)C=CC(=O)C2=C(C=C(C=C2)O)O,
11700,"Propanamide, 2,3,3,3-tetrafluoro-2-heptafluoro...",https://pubchem.ncbi.nlm.nih.gov/compound/561513,"InChI=1S/C13H8F11NO2/c14-9(11(17,18)19,8(26)25...",MBYGYAYBAZPQOT-UHFFFAOYSA-N,C1=CC=C(C=C1)CNC(=O)C(C(F)(F)F)(OC(C(C(F)(F)F)...,


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

# Assuming you have a dataset in the form of a list of tuples [(smiles, m_z, rt1, rt2), ...]
class SMILESDataset(Dataset):
    """Map-style dataset over ``(smiles, m_z, rt1, rt2)`` records.

    Args:
        data: indexable collection of 4-tuples ``(smiles, m_z, rt1, rt2)``.
        tokenizer: callable invoked as ``tokenizer(smiles, return_tensors="pt")``
            whose result exposes an ``input_ids`` tensor.
    """

    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        """Number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, targets)`` for record ``idx``.

        ``input_ids`` has its leading dimension squeezed away and is not
        padded, so its length varies with the SMILES string. ``targets`` is a
        float32 tensor holding ``[m_z, rt1, rt2]``.
        """
        smiles, mz_value, rt1_value, rt2_value = self.data[idx]
        encoded = self.tokenizer(smiles, return_tensors="pt")
        token_ids = encoded.input_ids.squeeze(0)
        targets = torch.tensor(
            [mz_value, rt1_value, rt2_value], dtype=torch.float32
        )
        return token_ids, targets

# Prepare the dataset and dataloader
# NOTE: `your_dataset` is a placeholder and must be defined before this cell
# runs: a list of (smiles, m_z, rt1, rt2) tuples, as SMILESDataset expects.
dataset = SMILESDataset(your_dataset, tokenizer)


def collate_smiles_batch(batch):
    """Pad variable-length token-id tensors so a batch can be stacked.

    The default collate function cannot stack 1-D ``input_ids`` tensors of
    different lengths, which is what SMILESDataset yields for different
    molecules.

    Args:
        batch: list of ``(input_ids, targets)`` pairs; ``input_ids`` is a 1-D
            tensor of token ids, ``targets`` a 1-D tensor of three values.

    Returns:
        ``(input_ids, attention_mask, targets)``: ids right-padded to the
        longest sequence in the batch, a mask with 1 on real tokens and 0 on
        padding, and the targets stacked to shape ``(batch, 3)``.
    """
    id_tensors, target_tensors = zip(*batch)
    # The pad value itself is irrelevant to the model because padded positions
    # are masked out; fall back to 0 if the tokenizer defines no pad token.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
    input_ids = nn.utils.rnn.pad_sequence(
        list(id_tensors), batch_first=True, padding_value=pad_id
    )
    # Build the mask from the true lengths rather than from `input_ids != pad_id`,
    # so a genuine token that happens to equal pad_id is never masked.
    lengths = torch.tensor([ids.size(0) for ids in id_tensors])
    positions = torch.arange(input_ids.size(1)).unsqueeze(0)
    attention_mask = (positions < lengths.unsqueeze(1)).long()
    return input_ids, attention_mask, torch.stack(target_tensors)


dataloader = DataLoader(
    dataset, batch_size=32, shuffle=True, collate_fn=collate_smiles_batch
)

# Loss Function
criterion = nn.MSELoss()

# Optimizer
optimizer = Adam(custom_model.parameters(), lr=1e-4)

# Number of passes over the data (placeholder value; tune for your dataset).
num_epochs = 10

# Training Loop
# from_pretrained() returns the base model in eval mode; switch the whole
# wrapper (base model + heads) to training mode so dropout is active.
custom_model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    num_batches = 0
    for input_ids, attention_mask, targets in dataloader:
        optimizer.zero_grad()
        m_z_pred, rt1_pred, rt2_pred = custom_model(
            input_ids, attention_mask=attention_mask
        )
        # Each head returns (batch, 1) while each target column is (batch,).
        # Squeeze the trailing dimension so MSELoss compares element-wise
        # instead of broadcasting to a (batch, batch) matrix.
        loss = (
            criterion(m_z_pred.squeeze(-1), targets[:, 0])
            + criterion(rt1_pred.squeeze(-1), targets[:, 1])
            + criterion(rt2_pred.squeeze(-1), targets[:, 2])
        )
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
    # Report the mean loss over the epoch, not just the last batch's loss.
    print(f"Epoch {epoch} loss: {running_loss / max(num_batches, 1)}")

# Save the fine-tuned model
torch.save(custom_model.state_dict(), "chemberta_with_mlps_finetuned.pth")
