In [None]:
# ! pip install transformers==4.38.1
# ! pip install rdkit==2023.9.4
# ! pip install accelerate==0.27.2
# ! pip install flash-attn
# ! pip install -q -U bitsandbytes
# ! pip install datasets
# ! pip install loralib
# ! pip install git+https://github.com/huggingface/peft.git
# ! pip install sentencepiece

In [None]:
# ! pip install tensorflow==2.10.0

In [None]:
import random, pickle, json, os
from datasets import Dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
# import bitsandbytes as bnb
# from peft import PeftModelForCausalLM

import sys
sys.path.append('../credentials/')
from HF_credentials import *

In [None]:
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoConfig

# Tokenizer

In [None]:
llm_tokenizer = AutoTokenizer.from_pretrained(
    'mistralai/Mistral-7B-Instruct-v0.2',
    token=HF_CREDENTIALS,
    model_max_length=256,
    add_prefix_space=False,
)
# Reuse EOS as the padding token and pad sequences on the right.
llm_tokenizer.pad_token = llm_tokenizer.eos_token
llm_tokenizer.padding_side = "right"

In [None]:
# One user/assistant exchange with empty contents, rendered as text so the raw
# chat template of the tokenizer is visible.
chat = [{"role": role, "content": ""} for role in ("user", "assistant")]

llm_tokenizer.apply_chat_template(chat, tokenize=False)

# Data

In [51]:
# Per-task prompt specs for the LlaSMol property-prediction files:
# (file suffix, user-prompt template, assistant-answer template, label extractor).
# The extractor maps a record's 'output' field to the value shown in the answer.
# NOTE(review): the SIDER prompt asks about side effects "affecting the heart", but the
# label is read from the 'Vascular disorders' entry; SIDER also has a 'Cardiac disorders'
# category. Kept unchanged to preserve the existing labels -- confirm which is intended.
_PROPERTY_TASKS = (
    ('bbbp',
     "Is blood-brain barrier permeability (BBBP) a property of <SMILES> {smiles} </SMILES>?",
     "<BOOLEAN> {answer} </BOOLEAN>",
     lambda output: output),
    ('clintox',
     "Is <SMILES> {smiles} </SMILES> toxic?",
     "<BOOLEAN> {answer} </BOOLEAN>",
     lambda output: output),
    ('esol',
     "How soluble is <SMILES> {smiles} </SMILES>?",
     "Its log solubility is <NUMBER> {answer} </NUMBER> mol/L",
     lambda output: output),
    ('hiv',
     "Can <SMILES> {smiles} </SMILES> serve as an inhibitor of HIV replication?",
     "<BOOLEAN> {answer} </BOOLEAN>",
     lambda output: output),
    ('lipo',
     "Predict the octanol/water distribution coefficient logD under the circumstances of pH 7.4 for <SMILES> {smiles} </SMILES>",
     "<NUMBER> {answer} </NUMBER>",
     lambda output: output),
    ('sider',
     "Are there any known side effects of <SMILES> {smiles} </SMILES> affecting the heart?",
     "<BOOLEAN> {answer} </BOOLEAN>",
     lambda output: output['Vascular disorders']),
)


def _load_property_task(path, prompt_template, answer_template, extract_label, tokenizer):
    """Read one JSONL task file and return ``(conversations, smiles_list)`` for its records.

    Each non-blank line is parsed as a JSON object with 'input' (the SMILES string) and
    'output' keys. Every record becomes one user/assistant exchange, rendered to text via
    ``tokenizer.apply_chat_template(..., tokenize=False)``.
    """
    conversations, smiles_list = [], []
    with open(path, 'r') as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing empty line
            record = json.loads(line)
            # Build a fresh message list per record instead of mutating a shared global.
            turns = [
                {"role": "user", "content": prompt_template.format(smiles=record['input'])},
                {"role": "assistant", "content": answer_template.format(answer=extract_label(record['output']))},
            ]
            conversations.append(tokenizer.apply_chat_template(turns, tokenize=False))
            smiles_list.append(record['input'])
    return conversations, smiles_list


def create_datasets(split='train', data_dir='../data/LlaSMol', tokenizer=None):
    """Build chat-formatted conversations for all LlaSMol property-prediction tasks of a split.

    Tasks are read in a fixed order (BBBP, ClinTox, ESOL, HIV, Lipo, SIDER) from
    ``{data_dir}/{split}/property_prediction-{task}.jsonl``. The last conversation of each
    non-empty task is printed as a preview, followed by the total count.

    Args:
        split: sub-directory of ``data_dir`` to read (e.g. 'train' or 'test').
        data_dir: root directory of the LlaSMol data.
        tokenizer: object providing ``apply_chat_template``; defaults to the notebook's
            global ``llm_tokenizer`` (looked up at call time).

    Returns:
        ``(conversations, input_smiles)``: two index-aligned lists, the chat strings and the
        raw SMILES inputs.

    Raises:
        FileNotFoundError: if a task file is missing.
    """
    if tokenizer is None:
        tokenizer = llm_tokenizer

    conversations, input_smiles = [], []
    for task, prompt_template, answer_template, extract_label in _PROPERTY_TASKS:
        path = os.path.join(data_dir, split, f'property_prediction-{task}.jsonl')
        task_conversations, task_smiles = _load_property_task(
            path, prompt_template, answer_template, extract_label, tokenizer
        )
        conversations.extend(task_conversations)
        input_smiles.extend(task_smiles)
        # Only preview tasks that produced rows (an empty first file would otherwise crash).
        if task_conversations:
            print(task_conversations[-1])
    print(len(conversations))

    return conversations, input_smiles

In [52]:
# Build both splits; each call prints a per-task preview and the total count.
_split_results = {}
for _split_name in ('train', 'test'):
    print(f'{_split_name.capitalize()}:')
    _split_results[_split_name] = create_datasets(_split_name)

train_conversations, train_input_smiles = _split_results['train']
test_conversations, test_input_smiles = _split_results['test']

Train:
<s>[INST] Is blood-brain barrier permeability (BBBP) a property of <SMILES> CC1(C)NC(=O)C(/C=C/C2=CC=CC=C2)O1 </SMILES>? [/INST]<BOOLEAN> Yes </BOOLEAN></s>
<s>[INST] Is <SMILES> CCCCOCCOCCOCC1=CC2=C(C=C1CCC)OCO2 </SMILES> toxic? [/INST]<BOOLEAN> No </BOOLEAN></s>
<s>[INST] How soluble is <SMILES> OCC1=CC=CC=C1OC1OC(CO)C(O)C(O)C1O </SMILES>? [/INST]Its log solubility is <NUMBER> -0.85 </NUMBER> mol/L</s>
<s>[INST] Can <SMILES> COC(=O)C1C2=C(CC3C4=CC=CC=C4NC(=O)C31)C1=CC=CC=C1N2 </SMILES> serve as an inhibitor of HIV replication? [/INST]<BOOLEAN> No </BOOLEAN></s>
<s>[INST] Predict the octanol/water distribution coefficient logD under the circumstances of pH 7.4 for <SMILES> C[C@H]1O[C@@H](N2C=NC3=C(N)N=C(OCC45CC6CC(CC(C6)C4)C5)N=C32)[C@H](O)[C@@H]1O </SMILES> [/INST]<NUMBER> 3.58 </NUMBER></s>
[1, 733, 16289, 28793, 4867, 736, 707, 2651, 2081, 6092, 302, 523, 28735, 5877, 20335, 28767, 334, 4020, 28743, 1998, 28743, 28732, 28746, 28762, 28731, 9419, 28740, 28746, 9419, 28732, 28

In [41]:
class CombinedDataset(Dataset):
    """Pairs each SMILES string with its chat-formatted conversation, tokenized on access.

    Item ``idx`` is a tuple ``(smiles_encoding, conversation_tokenized)``:

    - ``smiles_encoding``: dict of the encoder tokenizer's tensors with the leading
      batch dim removed (``tensor[0]``), padded/truncated to ``max_length``.
    - ``conversation_tokenized``: the LLM tokenizer's output as returned with
      ``return_tensors='pt'`` (its leading batch dim is kept), padded/truncated to
      ``max_length``.

    Both are moved to ``device`` before being returned.

    Args:
        smiles_list: SMILES strings; must be as long as ``conversations``.
        conversations: chat-formatted strings aligned index-by-index with ``smiles_list``.
        encoder_tokenizer: tokenizer for the molecule encoder.
        llm_tokenizer: tokenizer for the LLM.
        max_length: padding/truncation length used for both tokenizers.
        device: where returned tensors are placed (default ``'cuda'``).
            NOTE: moving tensors to CUDA inside ``__getitem__`` may fail with
            ``DataLoader(num_workers > 0)``; keep ``num_workers=0`` or use ``device='cpu'``.

    Raises:
        ValueError: if ``smiles_list`` and ``conversations`` differ in length.
    """

    def __init__(self, smiles_list, conversations, encoder_tokenizer, llm_tokenizer, max_length=256, device='cuda'):
        # The lists are indexed in lockstep; catch a mismatch here rather than as an
        # IndexError (or silently unused conversations) in the middle of an epoch.
        if len(smiles_list) != len(conversations):
            raise ValueError(
                f'smiles_list ({len(smiles_list)}) and conversations '
                f'({len(conversations)}) must have the same length'
            )
        self.smiles_list = smiles_list
        self.conversations = conversations
        self.encoder_tokenizer = encoder_tokenizer
        self.llm_tokenizer = llm_tokenizer
        self.max_length = max_length
        self.device = device

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        smiles_encoding = self.encoder_tokenizer(smiles, return_tensors='pt', truncation=True, padding='max_length', max_length=self.max_length)
        conversation_tokenized = self.llm_tokenizer(self.conversations[idx], truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')
        # Strip the size-1 batch dim from the encoder inputs only; the conversation keeps it,
        # so downstream code squeezes dim 1 after collation.
        smiles_item = {key: tensor[0].to(self.device) for key, tensor in smiles_encoding.items()}
        return smiles_item, conversation_tokenized.to(self.device)

In [42]:
# Load tokenizers for the two halves of the pipeline.
chemberta_tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
# Pass the HF token here as well, matching the other Mistral loads in this notebook
# (the Mistral repo may require authenticated access).
mistral_tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', token=HF_CREDENTIALS, add_prefix_space=False)
mistral_tokenizer.pad_token = mistral_tokenizer.eos_token
mistral_tokenizer.padding_side = "right"

# Create combined dataset
# NOTE(review): this is built from the *test* split, and the training loop below iterates
# `combined_loader` -- switch to train_input_smiles / train_conversations before real training.
combined_dataset = CombinedDataset(test_input_smiles, test_conversations, chemberta_tokenizer, mistral_tokenizer)

# Define DataLoader
batch_size = 2
combined_loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True)

## Test

In [None]:
# x, y = next(iter(combined_loader))

# mol_encoder = AutoModel.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
# llm_model = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2',
#             torch_dtype=torch.bfloat16,
#             # quantization_config=bnb_config,
#             device_map="auto",
#             token=HF_CREDENTIALS
# )

# mol_encoder(**x)['last_hidden_state'];
# llm_model.model.embed_tokens(y['input_ids'].to('cuda'));

# Model

In [None]:
class MolEncoderLLMPipeline(nn.Module):
    """Frozen ChemBERTa molecule encoder + frozen Mistral-7B LLM with a trainable linear bridge.

    The encoder's first-token (position 0) hidden state is projected to the LLM's hidden
    size by ``linear_project``, the only module whose parameters require grad.

    Args:
        lora_rank: currently unused (no LoRA adapter is attached yet).
        lora_alpha: currently unused (no LoRA adapter is attached yet).
    """

    def __init__(self, lora_rank=32, lora_alpha=64):
        super().__init__()
        # Load molecule encoder
        self.mol_encoder = AutoModel.from_pretrained("DeepChem/ChemBERTa-77M-MTR").to('cuda')

        # UNCOMMENT TO BRING DOWN FROM 15GB TO 7GB
        # bnb_config = BitsAndBytesConfig(
        #     load_in_4bit= True,
        #     bnb_4bit_quant_type= "nf4",
        #     bnb_4bit_compute_dtype= torch.bfloat16,
        #     bnb_4bit_use_double_quant= False,
        # )
        self.llm_config = AutoConfig.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', token=HF_CREDENTIALS)
        self.llm_model = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2',
            torch_dtype=torch.bfloat16,
            # quantization_config=bnb_config,
            device_map="auto",
            token=HF_CREDENTIALS
        )

        # Freeze encoder and LLM weights
        for frozen_module in (self.mol_encoder, self.llm_model):
            for param in frozen_module.parameters():
                param.requires_grad = False

        # Create the projection on the encoder's device: forward() feeds it encoder outputs,
        # and this wrapper is never moved to CUDA as a whole (the LLM is placed by
        # device_map="auto"), so a CPU-resident layer would raise a device-mismatch error.
        self.linear_project = nn.Linear(
            self.mol_encoder.config.hidden_size, self.llm_config.hidden_size
        ).to(self.mol_encoder.device)

    def train(self, mode=True):
        """Set training mode on this module but keep the frozen backbones in eval mode.

        The backbones receive no gradient updates, so train mode would only switch on any
        dropout they contain and make their outputs non-deterministic.
        """
        super().train(mode)
        self.mol_encoder.eval()
        self.llm_model.eval()
        return self

    def forward(self, smiles_tokens, text_tokens):
        """Encode molecules and look up LLM token embeddings.

        Args:
            smiles_tokens: mapping of molecule-encoder inputs (e.g. 'input_ids',
                'attention_mask'), unpacked as keyword arguments into the encoder.
            text_tokens: mapping with an 'input_ids' tensor for the LLM.

        Returns:
            ``(smiles_projection, llm_embeddings)``: the projected first-token encoder state,
            shape (batch, llm_hidden_size), and the LLM token embeddings with dim 1 squeezed
            when it has size 1 (e.g. (batch, 1, seq) ids -> (batch, seq, llm_hidden_size)).
        """
        # Both backbones are frozen, so skip building autograd graphs through them;
        # gradients still flow into linear_project below.
        with torch.no_grad():
            mol_encoder_output = self.mol_encoder(**smiles_tokens)
            # last_hidden_state is (batch, seq_len, enc_hidden); keep position 0 -> (batch, enc_hidden)
            smiles_embedding = mol_encoder_output['last_hidden_state'][:, 0, :]

            # Get embeddings from LLM for the question, on whatever device holds the table.
            embedding_layer = self.llm_model.model.embed_tokens
            input_ids = text_tokens['input_ids'].to(embedding_layer.weight.device)
            llm_embeddings = embedding_layer(input_ids).squeeze(1)

        smiles_projection = self.linear_project(smiles_embedding)
        return smiles_projection, llm_embeddings

In [None]:
# Instantiate the pipeline (downloads/loads both backbones); LoRA settings spelled out explicitly.
model = MolEncoderLLMPipeline(lora_rank=32, lora_alpha=64)

In [None]:
# x, y = next(iter(combined_loader))
# model(x,y)

# Eval

In [17]:
def parse_answer(text, is_boolean):
    """Extract the value between <BOOLEAN>...</BOOLEAN> or <NUMBER>...</NUMBER> tags.

    Args:
        text: answer text that may contain a tagged value.
        is_boolean: use the BOOLEAN tag pair if truthy, otherwise the NUMBER pair.

    Returns:
        The stripped text between the first opening tag and the next closing tag, or
        ``None`` when either tag is missing (e.g. a generation that ignored the format).
    """
    tag = 'BOOLEAN' if is_boolean else 'NUMBER'
    open_tag, close_tag = f'<{tag}>', f'</{tag}>'
    # str.find returns -1 on a miss; slicing with it would silently return garbage.
    start = text.find(open_tag)
    if start == -1:
        return None
    start += len(open_tag)
    # Look for the closing tag only after the opening one.
    end = text.find(close_tag, start)
    if end == -1:
        return None
    return text[start:end].strip()


def _answer_section(sentence):
    """Return the text between the first and second '[/INST]' markers ('' if absent)."""
    parts = sentence.split('[/INST]')
    return parts[1] if len(parts) > 1 else ''


def get_answer(true_sentence, pred_sentence):
    """Parse the reference and predicted answers from two Mistral chat strings.

    The answer type is decided by the reference: a BOOLEAN tag means a yes/no task,
    a NUMBER tag a regression task.

    Returns:
        ``(y_true, y_pred, is_boolean)``. ``y_pred`` is ``None`` when the prediction lacks
        the '[/INST]' marker or the expected tags.

    Raises:
        IndexError: if ``true_sentence`` contains no '[/INST]' marker.
        ValueError: if the reference answer has neither a BOOLEAN nor a NUMBER tag.
    """
    # The reference must be well-formed, so index directly (raises on a missing marker);
    # the prediction is model output and is parsed leniently.
    true_answer = true_sentence.split('[/INST]')[1]
    pred_answer = _answer_section(pred_sentence)

    if 'BOOLEAN' in true_answer:
        is_boolean = True
    elif 'NUMBER' in true_answer:
        is_boolean = False
    else:
        raise ValueError(f'Reference answer has no <BOOLEAN> or <NUMBER> tag: {true_answer!r}')

    return parse_answer(true_answer, is_boolean), parse_answer(pred_answer, is_boolean), is_boolean

In [54]:
# Use one conversation as both reference and "prediction" to sanity-check get_answer.
true = pred = train_conversations[0]

In [56]:
# Inspect the first training conversation (BBBP is the first task file read).
train_conversations[0]

'<s>[INST] Is blood-brain barrier permeability (BBBP) a property of <SMILES> CNC(C)C(=O)C1=CC=C(OC)C=C1 </SMILES>? [/INST]<BOOLEAN> No </BOOLEAN></s>'

In [55]:
# Sanity check: identical true/pred strings should give matching parsed answers.
get_answer(true, pred)

('No', 'No', True)

In [None]:
# Collect reference answers for the BBBP task from the evaluation dataset.
# Dataset items are already-tokenized tensors, so the task is identified from the raw
# conversation text: only the BBBP prompt template contains the literal 'BBBP'.
true_BBBP = []
pred_BBBP = []  # TODO: fill with parsed model generations once generation is implemented

for conversation in combined_dataset.conversations:
    if 'BBBP' in conversation:
        y_true, _, _ = get_answer(conversation, conversation)
        true_BBBP.append(y_true)
    



In [None]:
# Stop "Run All" here: the training cells below are still unfinished.
# (Equivalent to sys.exit(); Jupyter reports SystemExit and halts further cells.)
raise SystemExit(None)

# Train

In [None]:
import gc

# Free unreachable Python objects first (dropping references to CUDA tensors), then
# return PyTorch's cached-but-unused GPU blocks.
for _release_memory in (gc.collect, torch.cuda.empty_cache):
    _release_memory()

In [None]:
# Only parameters that require grad (the projection layer) are optimized; handing Adam the
# frozen encoder/LLM parameters just adds bookkeeping for tensors that never update.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-4)

# Training loop
# NOTE: `criterion` and the label construction are still undefined -- this cell is a
# scaffold and will fail until both are supplied.
epochs = 5
model.train()
for epoch in range(epochs):
    for smiles_data, conversation_data in combined_loader:
        # forward() unpacks smiles_tokens with ** and reads text_tokens['input_ids'], so both
        # arguments must be mappings; squeeze(1) drops the per-item batch dim kept by the
        # LLM tokenizer (a no-op for the already-squeezed SMILES tensors).
        smiles_tokens = {
            'input_ids': smiles_data['input_ids'].squeeze(1),
            'attention_mask': smiles_data['attention_mask'].squeeze(1),
        }
        text_tokens = {
            'input_ids': conversation_data['input_ids'].squeeze(1),
            'attention_mask': conversation_data['attention_mask'].squeeze(1),
        }

        optimizer.zero_grad()

        # Forward pass -> (smiles_projection, llm_embeddings)
        outputs = model(smiles_tokens, text_tokens)

        # TODO: define labels and `criterion` (e.g. LM loss on the assistant tokens).
        labels = ...  # Define how to obtain these
        loss = criterion(outputs, labels)

        # Backward and optimize
        loss.backward()
        optimizer.step()

        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# Eval

In [None]:
# The pipeline is a plain nn.Module wrapper with no `.config`; the KV-cache flag lives on the LLM.
model.llm_model.config.use_cache = True