In [2]:
# Load model directly
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import pandas as pd

tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model = AutoModelForMaskedLM.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")

sample_SMILES = "<s>C"

t = tokenizer(sample_SMILES, return_tensors="pt")

output = model(**t)


In [36]:
import torch
import torch.nn as nn
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "6"

class BERT_GCxGC(nn.Module):
    def __init__(self, base_model, hidden_dim, output_dim):
        super(BERT_GCxGC, self).__init__()
        self.base_model = base_model

        # Predicts Molecular Weight or M/Z
        self.m_z = nn.Sequential(
            nn.Linear(base_model.config.vocab_size, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        ).to(device)
        # Predicting retention time 1
        self.rt1 = nn.Sequential(
            nn.Linear(base_model.config.vocab_size, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        ).to(device)
        # Predicting retention time 2
        self.rt2 = nn.Sequential(
            nn.Linear(base_model.config.vocab_size, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        ).to(device)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        cls_token = outputs.logits[:, 0, :]  # Get the CLS token(<s> token)
        m_z = self.m_z(cls_token)
        rt1 = self.rt1(cls_token)
        rt2 = self.rt2(cls_token)
        return m_z, rt1, rt2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the base model
tokenizer = AutoTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model = AutoModelForMaskedLM.from_pretrained("seyonec/ChemBERTa-zinc-base-v1").to(device)

# Create the model with MLPs
hidden_dim = 64  # Hidden dimension size for each MLP
output_dim = 1  # Output dimension for each MLP
custom_model = BERT_GCxGC(model, hidden_dim, output_dim).to(device)


# Example usage
# sample_SMILES = "<s>CC1=CC(=CC=C1)S(=O)(=O)NC2=CC=C(C=C2)S(=O)(=O)NC(C)C"
# inputs = tokenizer(sample_SMILES, return_tensors="pt")
# outputs = custom_model(**inputs)
# print(outputs)  # Outputs from the three MLPs


In [37]:
all_dataset = pd.read_csv('results/training_set_march20.csv')


In [40]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

# Assuming you have a dataset in the form of a list of tuples [(smiles, m_z, rt1, rt2), ...]
class SMILESDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        smiles = self.data.iloc[idx]['Canonical_SMILES']
        m_z = self.data.iloc[idx]['m_z']
        rt1 = self.data.iloc[idx]['1st Dimension Time (s)']
        rt2 = self.data.iloc[idx]['2nd Dimension Time (s)']

        # Tokenize the SMILES and pad to the max length
        inputs = self.tokenizer(smiles, padding='max_length', truncation=True, return_tensors="pt")

        # Remove the batch dimension that the tokenizer adds by default
        input_ids = inputs.input_ids.squeeze(0)

        # Your targets as a tensor
        targets = torch.tensor([m_z, rt1, rt2], dtype=torch.float32, device=device)
        
        return input_ids.to(device), targets.to(device)


# Prepare the dataset and dataloader
dataset = SMILESDataset(all_dataset, tokenizer)
dataloader = DataLoader(dataset, batch_size=256, shuffle=True)

# Loss Function
criterion = nn.MSELoss()

# Optimizer
optimizer = Adam(custom_model.parameters(), lr=1e-5)

# Training Loop
for epoch in range(2):
    for input_ids, targets in dataloader:
        input_ids = input_ids.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        m_z_pred, rt1_pred, rt2_pred = custom_model(input_ids)
        loss = criterion(m_z_pred, targets[:, 0]) + criterion(rt1_pred, targets[:, 1]) + criterion(rt2_pred, targets[:, 2])
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch} loss: {loss.item()}")

# Save the fine-tuned model
# torch.save(custom_model.state_dict(), "chemberta_with_mlps_finetuned.pth")


Epoch 0 loss: 5384391.0
Epoch 1 loss: 7054440.5
