In [1]:
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import torch.nn.functional as F
import torch
import torch.optim as optim
from dataset import DrugDataset
from torch.utils.data import Dataset, DataLoader
import os
import pandas as pd
from MolCLR.dataset.dataset import MoleculeDatasetWrapper
from MolCLR.models.ginet_finetune import GINet

class ContrastiveLearningWithBioBERT(nn.Module):
    def __init__(self, molecule_model):
        super(ContrastiveLearningWithBioBERT, self).__init__()
        
        # Load the tokenizer and model
        self.tokenizer = BertTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
        self.text_model = BertModel.from_pretrained("dmis-lab/biobert-base-cased-v1.1")

        # Molecule model
        self.molecule_model = molecule_model
        # Freeze the molecule model
        for param in self.molecule_model.parameters():
            param.requires_grad = False

    def forward(self, text, molecule_representation):
        # Tokenize the input text
        encoded_input = self.tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=512)
        
        # Get the representation from the text model (bioBERT)
        text_features = self.text_model(**encoded_input).last_hidden_state.mean(dim=1)

        # Calculate the cosine similarity between text and molecule representations
        similarities = F.cosine_similarity(text_features.unsqueeze(0), molecule_representation.unsqueeze(1), dim=2)
        
        # Compute the loss
        loss = F.cross_entropy(similarities, torch.zeros(similarities.shape[0], dtype=torch.long).to(similarities.device))
        
        # normalized features
        # molecule_representation = molecule_representation / molecule_representation.norm(dim=1, keepdim=True)
        # text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        # logit_scale = self.logit_scale.exp()
        # logits_per_image = logit_scale * molecule_representation @ text_features.t()
        # logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        # return logits_per_image, logits_per_text
        return loss

# Placeholder for the molecule model (for demonstration purposes)
class MoleculeModelMock(nn.Module):
    def __init__(self, embedding_dim):
        super(MoleculeModelMock, self).__init__()
        self.encoder = nn.Linear(100, embedding_dim) # Just a mock encoder
        
    def forward(self, molecule):
        return self.encoder(molecule)

# Initialize models
molecule_model = MoleculeModelMock(embedding_dim=768) # Using 768 to match bioBERT's output dimension
model_with_biobert = ContrastiveLearningWithBioBERT(molecule_model)

# Test with random data
texts = ["Metformin is a first-line oral hypoglycemic agent.", "Aspirin is an analgesic."] * 16
molecule_representations = torch.randn(32, 768) # Corresponding molecules for the batch

loss = model_with_biobert(texts, molecule_representations)
print(loss)

# Training setup
print("here")
# Hyperparameters
num_epochs = 10
learning_rate = 1e-4

# Initialize the model and optimizer

smiles_data = pd.read_csv("/home/luli/MolCLR/output_smiles.csv")

torch.backends.cudnn.deterministic = True
dataset = MoleculeDatasetWrapper(6,4,0.1, "/home/luli/MolCLR/output_smiles.csv")
molecule_model = GINet( "classification" ).to("cuda:0")
checkpoints_folder = os.path.join('MolCLR/ckpt', "pretrained_gin", 'checkpoints')
state_dict = torch.load(os.path.join(checkpoints_folder, 'model.pth'), map_location= "cuda:0")
molecule_model.load_my_state_dict(state_dict)
print("here3")
# Assuming you have already defined the molecule model and loaded the pretrained weights
model = ContrastiveLearningWithBioBERT(molecule_model)
optimizer = optim.Adam(model.text_model.parameters(), lr=learning_rate)

# Test the DrugDataset
# For demonstration purposes, I'm using placeholder paths. Replace with your actual paths.
description_csv_path = "/home/luli/drugBank/drugbank.csv"
molecule_csv_path = "/home/luli/drugBank/output_smiles.csv"
dataset = DrugDataset(description_csv_path, molecule_csv_path)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

print("here2")


Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.1 were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tensor(3.4654, grad_fn=<NllLossBackward>)
here
here3


Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.1 were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


here2


In [2]:
batch_texts, batch_molecules = dataloader.dataset.__getitem__(0)

In [3]:
batch_texts = [batch_texts, batch_texts]

In [4]:
tokenizer = BertTokenizer.from_pretrained("dmis-lab/biobert-base-cased-v1.1")
text_model = BertModel.from_pretrained("dmis-lab/biobert-base-cased-v1.1")

encoded_input = tokenizer(batch_texts, return_tensors='pt', padding=True, truncation=True, max_length=512)
       
        # Get the representation from the text model (bioBERT)


Some weights of the model checkpoint at dmis-lab/biobert-base-cased-v1.1 were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [5]:
text_features = text_model(**encoded_input).last_hidden_state.mean(dim=1)


In [8]:
text_features

tensor([[ 0.1142,  0.1500, -0.0950,  ...,  0.2557,  0.1438,  0.0029],
        [ 0.1142,  0.1500, -0.0950,  ...,  0.2557,  0.1438,  0.0029]],
       grad_fn=<MeanBackward1>)

In [7]:
batch_molecules.unsqueeze(0)

AttributeError: 'list' object has no attribute 'unsqueeze'

In [6]:
similarities = F.cosine_similarity(text_features.unsqueeze(0), batch_molecules.unsqueeze(1), dim=2)
        


AttributeError: 'list' object has no attribute 'unsqueeze'

In [None]:
        # Compute the loss
loss = F.cross_entropy(similarities, torch.zeros(similarities.shape[0], dtype=torch.long).to(similarities.device))