In [None]:
# ! pip install transformers==4.38.1
# ! pip install rdkit==2023.9.4
# ! pip install accelerate==0.27.2
# ! pip install flash-attn
# ! pip install -q -U bitsandbytes
# ! pip install datasets
# ! pip install loralib
# ! pip install git+https://github.com/huggingface/peft.git
# ! pip install sentencepiece

In [None]:
# ! pip install tensorflow==2.10.0

In [2]:
import random, pickle, json, os
from tqdm import tqdm
from datasets import Dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import bitsandbytes as bnb
from peft import PeftModelForCausalLM
from peft import get_peft_model, LoraConfig
from sklearn.metrics import accuracy_score, mean_squared_error
from math import sqrt
from torch.cuda.amp import autocast

import sys
sys.path.append('../credentials/')
from HF_credentials import *

In [3]:
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoConfig, BitsAndBytesConfig

# Tokenizer

In [4]:
# Chat tokenizer for Mistral-7B-Instruct, capped at 256 tokens.
# EOS is reused as the padding token and sequences are padded on the right.
llm_tokenizer = AutoTokenizer.from_pretrained(
    'mistralai/Mistral-7B-Instruct-v0.2',
    token=HF_CREDENTIALS,
    model_max_length=256,
    add_prefix_space=False,
)
llm_tokenizer.padding_side = "right"
llm_tokenizer.pad_token = llm_tokenizer.eos_token

In [5]:
# An empty user/assistant exchange, rendered as text to inspect the raw
# chat-template markup the tokenizer wraps around each turn.
chat = [{"role": role, "content": ""} for role in ("user", "assistant")]

llm_tokenizer.apply_chat_template(chat, tokenize=False)

'<s>[INST]  [/INST]</s>'

# Data

In [6]:
def create_datasets(split='train', data_dir='../data/LlaSMol', tokenizer=None):
    """Render the LlaSMol property-prediction tasks as single-turn chat strings.

    For each task file ``{data_dir}/{split}/property_prediction-<task>.jsonl``,
    every JSON record (``input`` = SMILES, ``output`` = label) becomes one
    user/assistant exchange. That exchange is rendered (untokenized) with the
    tokenizer's chat template. After each task file, the last rendered
    conversation is printed, and the total count is printed at the end as a
    sanity check.

    Args:
        split: split sub-directory, e.g. 'train', 'val' or 'test'.
        data_dir: root directory of the LlaSMol data.
        tokenizer: object exposing ``apply_chat_template``; defaults to the
            notebook-global ``llm_tokenizer``.

    Returns:
        (conversations, input_smiles): parallel lists of rendered chat strings
        and the raw SMILES inputs.
    """
    if tokenizer is None:
        tokenizer = llm_tokenizer

    # (task file suffix, builder turning one JSON record into (question, answer))
    tasks = [
        ('bbbp', lambda rec: (
            f"Is blood-brain barrier permeability (BBBP) a property of <SMILES> {rec['input']} </SMILES>?",
            f"<BOOLEAN> {rec['output']} </BOOLEAN>")),
        ('clintox', lambda rec: (
            f"Is <SMILES> {rec['input']} </SMILES> toxic?",
            f"<BOOLEAN> {rec['output']} </BOOLEAN>")),
        ('esol', lambda rec: (
            f"How soluble is <SMILES> {rec['input']} </SMILES>?",
            f"Its log solubility is <NUMBER> {rec['output']} </NUMBER> mol/L")),
        ('hiv', lambda rec: (
            f"Can <SMILES> {rec['input']} </SMILES> serve as an inhibitor of HIV replication?",
            f"<BOOLEAN> {rec['output']} </BOOLEAN>")),
        ('lipo', lambda rec: (
            f"Predict the octanol/water distribution coefficient logD under the circumstances of pH 7.4 for <SMILES> {rec['input']} </SMILES>",
            f"<NUMBER> {rec['output']} </NUMBER>")),
        # NOTE(review): the question asks about the heart, but the label is read from
        # the 'Vascular disorders' entry of the record's output — confirm this is intended.
        ('sider', lambda rec: (
            f"Are there any known side effects of <SMILES> {rec['input']} </SMILES> affecting the heart?",
            f"<BOOLEAN> {rec['output']['Vascular disorders']} </BOOLEAN>")),
    ]

    conversations = []
    input_smiles = []
    for task, build_turns in tasks:
        path = os.path.join(data_dir, split, f'property_prediction-{task}.jsonl')
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing lines in the JSONL file
                record = json.loads(line)
                question, answer = build_turns(record)
                # A fresh chat per example, instead of mutating the notebook-global `chat`.
                chat_turns = [
                    {"role": "user", "content": question},
                    {"role": "assistant", "content": answer},
                ]
                conversations.append(tokenizer.apply_chat_template(chat_turns, tokenize=False))
                input_smiles.append(record['input'])
        # Sanity check: show the most recently rendered example.
        print(conversations[-1])
    print(len(conversations))

    return conversations, input_smiles

In [7]:
# Build every split. create_datasets prints one rendered example per task
# followed by the number of conversations in the split.
print('Train:')
train_conversations, train_input_smiles = create_datasets(split='train')

print('Val:')
val_conversations, val_input_smiles = create_datasets(split='val')

print('Test:')
test_conversations, test_input_smiles = create_datasets(split='test')

Train:
<s>[INST] Is blood-brain barrier permeability (BBBP) a property of <SMILES> CC1(C)NC(=O)C(/C=C/C2=CC=CC=C2)O1 </SMILES>? [/INST]<BOOLEAN> Yes </BOOLEAN></s>
<s>[INST] Is <SMILES> CCCCOCCOCCOCC1=CC2=C(C=C1CCC)OCO2 </SMILES> toxic? [/INST]<BOOLEAN> No </BOOLEAN></s>
<s>[INST] How soluble is <SMILES> OCC1=CC=CC=C1OC1OC(CO)C(O)C(O)C1O </SMILES>? [/INST]Its log solubility is <NUMBER> -0.85 </NUMBER> mol/L</s>
<s>[INST] Can <SMILES> COC(=O)C1C2=C(CC3C4=CC=CC=C4NC(=O)C31)C1=CC=CC=C1N2 </SMILES> serve as an inhibitor of HIV replication? [/INST]<BOOLEAN> No </BOOLEAN></s>
<s>[INST] Predict the octanol/water distribution coefficient logD under the circumstances of pH 7.4 for <SMILES> C[C@H]1O[C@@H](N2C=NC3=C(N)N=C(OCC45CC6CC(CC(C6)C4)C5)N=C32)[C@H](O)[C@@H]1O </SMILES> [/INST]<NUMBER> 3.58 </NUMBER></s>
<s>[INST] Are there any known side effects of <SMILES> CCCCCOC(=O)NC1=NC(=O)N([C@@H]2O[C@H](C)[C@@H](O)[C@H]2O)C=C1F </SMILES> affecting the heart? [/INST]<BOOLEAN> Yes </BOOLEAN></s>
4096

In [8]:
class CombinedDataset(Dataset):
    """Pairs each SMILES string with its rendered chat conversation.

    Each item is ``(smiles_encoding, conversation_tokenized)``:

    * ``smiles_encoding``: dict of encoder-tokenizer outputs (e.g. ``input_ids``,
      ``attention_mask``) with the tokenizer's batch dimension removed, padded
      or truncated to ``max_length``.
    * ``conversation_tokenized``: the LLM tokenizer output, which keeps its leading
      batch dimension of 1. That is why downstream code ``squeeze(1)``s it after
      collation.

    Both are moved to ``device``.
    """

    def __init__(self, smiles_list, conversations, encoder_tokenizer, llm_tokenizer, max_length=256, device='cuda'):
        if len(smiles_list) != len(conversations):
            raise ValueError(
                f'smiles_list and conversations must have the same length '
                f'({len(smiles_list)} != {len(conversations)})'
            )
        self.smiles_list = smiles_list
        self.conversations = conversations
        self.encoder_tokenizer = encoder_tokenizer
        self.llm_tokenizer = llm_tokenizer
        self.max_length = max_length
        self.device = device

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        smiles_encoding = self.encoder_tokenizer(smiles, return_tensors='pt', truncation=True, padding='max_length', max_length=self.max_length)
        # The conversation text already contains the chat template's markers,
        # so no extra special tokens are added here.
        conversation_tokenized = self.llm_tokenizer(self.conversations[idx], truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt', add_special_tokens=False)
        # Drop the batch dimension of the SMILES tensors so the DataLoader can stack them.
        smiles_item = {key: tensor[0].to(self.device) for key, tensor in smiles_encoding.items()}
        return smiles_item, conversation_tokenized.to(self.device)

In [9]:
# Load tokenizers
chemberta_tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
# Same repo and settings as `llm_tokenizer`; pass the HF token here too so loading
# does not depend on the repo being publicly accessible.
mistral_tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', token=HF_CREDENTIALS, add_prefix_space=False)
mistral_tokenizer.pad_token = mistral_tokenizer.eos_token
mistral_tokenizer.padding_side = "right"

# Create combined dataset
train_dataset = CombinedDataset(train_input_smiles, train_conversations, chemberta_tokenizer, mistral_tokenizer)
val_dataset = CombinedDataset(val_input_smiles, val_conversations, chemberta_tokenizer, mistral_tokenizer)
test_dataset = CombinedDataset(test_input_smiles, test_conversations, chemberta_tokenizer, mistral_tokenizer)

# Define DataLoader
batch_size = 1
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation code zips these loaders with the matching *_conversations lists and
# decodes one example at a time, so they must stay unshuffled with batch size 1.
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

## Test

In [None]:
# x, y = next(iter(combined_loader))

# mol_encoder = AutoModel.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
# llm_model = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2',
#             torch_dtype=torch.bfloat16,
#             # quantization_config=bnb_config,
#             device_map="auto",
#             token=HF_CREDENTIALS
# )

# mol_encoder(**x)['last_hidden_state'];
# llm_model.model.embed_tokens(y['input_ids'].to('cuda'));

# Model

In [None]:
class MolEncoderLLMPipeline(nn.Module):
    """ChemBERTa molecule encoder feeding a 4-bit Mistral-7B-Instruct with LoRA adapters.

    The encoder's first-position embedding (``last_hidden_state[:, 0, :]``) of the
    SMILES string is linearly projected to the LLM hidden size. It is then prepended
    as a single "soft token" in front of the embedded chat conversation. The
    encoder and the LLM base weights are frozen; the projection layer and the LoRA
    adapters added by ``get_peft_model`` are the trainable parts.

    Args:
        lora_rank: LoRA rank ``r``.
        lora_alpha: LoRA ``lora_alpha`` hyper-parameter.
    """

    def __init__(self, lora_rank=32, lora_alpha=64):
        super().__init__()
        # Load molecule encoder
        self.mol_encoder = AutoModel.from_pretrained("DeepChem/ChemBERTa-77M-MTR", torch_dtype=torch.bfloat16).to('cuda')

        # 4-bit NF4 quantization of the LLM (per the author's note, roughly
        # 15GB -> 7GB of GPU memory).
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
        )
        self.llm_config = AutoConfig.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', token=HF_CREDENTIALS)
        self.llm_model = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2',
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config,
            device_map="auto",
            token=HF_CREDENTIALS
        )

        # KV cache is not needed during training.
        self.llm_model.config.use_cache = False
        self.llm_model.config.pretraining_tp = 1

        # Initialize LoRA layers for Mistral
        self.lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=0.05,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )

        # Freeze encoder and LLM weights
        for param in self.mol_encoder.parameters():
            param.requires_grad = False
        for param in self.llm_model.parameters():
            param.requires_grad = False

        # Trainable bridge from encoder hidden size to LLM hidden size.
        self.linear_project = nn.Linear(self.mol_encoder.config.hidden_size, self.llm_config.hidden_size, dtype=torch.bfloat16)

        # Apply LoRA modification
        self.llm_model = get_peft_model(self.llm_model, self.lora_config)

    def forward(self, smiles_tokens, text_tokens):
        """Run the LLM on ``[projected SMILES embedding] + conversation embeddings``.

        Args:
            smiles_tokens: keyword inputs for ``mol_encoder`` (presumably
                ``input_ids``/``attention_mask`` of shape (batch, max_length)).
            text_tokens: mapping whose ``input_ids`` has shape (batch, 1, seq_len),
                as produced by ``CombinedDataset`` plus default collation.

        Returns:
            The LLM output. Its sequence dimension is ``seq_len + 1`` because
            position 0 is the SMILES token.
        """
        # First-position encoder embedding -> (batch, 1, llm_hidden)
        mol_encoder_output = self.mol_encoder(**smiles_tokens)
        smiles_embedding = mol_encoder_output['last_hidden_state'][:, 0, :]
        smiles_projection = self.linear_project(smiles_embedding).unsqueeze(1)

        # Embed the conversation with the LLM's own input-embedding table,
        # on whichever device that table lives.
        embedding_layer = self.llm_model.get_input_embeddings()
        input_ids = text_tokens['input_ids'].to(embedding_layer.weight.device)
        llm_embeddings = embedding_layer(input_ids).squeeze(1)  # (batch, seq_len, llm_hidden)

        # Under fp16 autocast the projection comes out as float16 while the
        # embedding table stays bf16; align them so torch.cat does not promote.
        smiles_projection = smiles_projection.to(device=llm_embeddings.device, dtype=llm_embeddings.dtype)
        combined_embeddings = torch.cat((smiles_projection, llm_embeddings), dim=1)

        # Causal mask over [SMILES, text...]: position i attends to positions 0..i,
        # so every token sees the SMILES token. Shape (batch, 1, L, L), 1 = attend.
        # The batch size is taken from the actual inputs, not a notebook global.
        batch, seq_len = combined_embeddings.shape[:2]
        causal = torch.ones(seq_len, seq_len, device=combined_embeddings.device).tril()
        attention_mask = causal.expand(batch, 1, seq_len, seq_len).contiguous()

        # Pass through Mistral's transformer layers with LoRA adjustments
        return self.llm_model(inputs_embeds=combined_embeddings, attention_mask=attention_mask)

In [None]:
# Build the pipeline with LoRA rank 16 / alpha 16, then move every submodule
# (including the trainable projection layer) onto the GPU.
model = MolEncoderLLMPipeline(lora_rank=16, lora_alpha=16)
model = model.to('cuda')

In [None]:
# Report how many LLM parameters the PEFT/LoRA wrapper leaves trainable
# (the projection layer lives outside llm_model and is not counted here).
model.llm_model.print_trainable_parameters()

In [None]:
# x, y = next(iter(combined_loader))
# model(x,y)

# Train

In [10]:
# Collect unreachable Python objects and release unoccupied memory held by
# PyTorch's CUDA caching allocator before starting training.
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
def parse_answer(text, is_boolean):
    """Extract the payload of a ``<BOOLEAN>…</BOOLEAN>`` or ``<NUMBER>…</NUMBER>`` span.

    Args:
        text: assistant-turn text to search.
        is_boolean: search for the BOOLEAN tags if True, the NUMBER tags otherwise.

    Returns:
        The stripped text between the tags, or None if the opening tag is missing
        or no closing tag follows it (i.e. a malformed prediction).
    """
    tag = 'BOOLEAN' if is_boolean else 'NUMBER'
    open_tag, close_tag = f'<{tag}>', f'</{tag}>'
    open_at = text.find(open_tag)
    if open_at == -1:
        return None
    start = open_at + len(open_tag)
    end = text.find(close_tag, start)
    if end == -1:
        return None
    return text[start:end].strip()


def _assistant_turn(sentence):
    """Return the text between the first and second '[/INST]', or None if absent."""
    parts = sentence.split('[/INST]')
    return parts[1] if len(parts) > 1 else None


def get_answer(true_sentence, pred_sentence):
    """Return ``(y_true, y_pred)`` answer strings from the assistant turns.

    The answer type (BOOLEAN or NUMBER) is decided by the reference sentence.
    Returns ``(None, None)`` if the reference has no recognisable answer. ``y_pred``
    is None when the prediction lacks '[/INST]' or the expected tags.
    """
    true_answer = _assistant_turn(true_sentence)
    if true_answer is None:
        return None, None

    if 'BOOLEAN' in true_answer:
        is_boolean = True
    elif 'NUMBER' in true_answer:
        is_boolean = False
    else:
        return None, None

    pred_answer = _assistant_turn(pred_sentence)
    y_true = parse_answer(true_answer, is_boolean)
    y_pred = None if pred_answer is None else parse_answer(pred_answer, is_boolean)
    return y_true, y_pred

In [None]:
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-5)
# Padding positions get label -100 (see below) instead of being ignored by token id:
# the pad token *is* EOS here, so ignore_index=pad_token_id would also drop the real
# closing </s> and the model would never learn to stop.
loss_fn = nn.CrossEntropyLoss(ignore_index=-100)

epochs = 5
accumulation_steps = 32
warmup_steps = 100
eval_every = 500  # batches between validation passes

# The scheduler advances once per *optimizer* step (every `accumulation_steps`
# batches plus a final partial window per epoch), so size it in optimizer steps.
steps_per_epoch = (len(train_loader) + accumulation_steps - 1) // accumulation_steps
total_steps = steps_per_epoch * epochs
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)


def evaluate_split(loader, conversations):
    """Teacher-forced evaluation; prints per-task accuracy / RMSE and invalid counts.

    Assumes ``loader`` is unshuffled with batch size 1 so it lines up with
    ``conversations``. The prediction at each position is the argmax next token
    given the ground-truth prefix, so multi-token numeric answers score optimistically.
    """
    categories = ["BBBP", "side effects", "logD", "soluble", "toxic", "HIV"]
    binary_categories = ["BBBP", "side effects", "toxic", "HIV"]
    to_bool = {'Yes': True, 'No': False}
    preds = {cat: [] for cat in categories}
    trues = {cat: [] for cat in categories}
    invalid_count = {cat: 0 for cat in categories}

    was_training = model.training
    model.eval()  # disable LoRA dropout while evaluating
    with torch.no_grad(), autocast(dtype=torch.bfloat16):
        for (smiles_data, conversation_data), true_sentence in zip(loader, conversations):
            output = model(smiles_data, conversation_data)
            output_ids = output.logits.argmax(dim=-1)[0]  # decode() expects a 1-D id sequence
            pred_sentence = mistral_tokenizer.decode(output_ids)
            y_true, y_pred = get_answer(true_sentence, pred_sentence)
            for category in categories:
                if category not in true_sentence:
                    continue
                if category in binary_categories:
                    if y_pred in to_bool:
                        preds[category].append(to_bool[y_pred])
                        trues[category].append(to_bool.get(y_true))
                    else:
                        invalid_count[category] += 1
                else:
                    # Convert both before appending so preds/trues never get out of step.
                    try:
                        pred_value, true_value = float(y_pred), float(y_true)
                    except (TypeError, ValueError):
                        invalid_count[category] += 1
                    else:
                        preds[category].append(pred_value)
                        trues[category].append(true_value)
    if was_training:
        model.train()

    for key in categories:
        if not preds[key]:  # nothing valid to score
            continue
        if key in binary_categories:
            print(f'{key} accuracy: {accuracy_score(trues[key], preds[key]):.4f}')
        else:
            print(f'{key} RMSE: {sqrt(mean_squared_error(trues[key], preds[key])):.4f}')
    print('Invalid count:')
    print(invalid_count)


model.train()
optimizer.zero_grad()
for epoch in range(epochs):
    total_loss = 0.0
    # Train on the *training* split (this loop previously iterated test_loader).
    tprog = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, (smiles_data, conversation_data) in tprog:
        # bf16 autocast to match the model's bf16 weights / 4-bit compute dtype;
        # the fp16 default would mix float16 and bfloat16 tensors.
        with autocast(dtype=torch.bfloat16):
            output = model(smiles_data, conversation_data)
            # Dropping position 0 (the SMILES slot): logits[:, m] is the prediction
            # made at text token m, i.e. for text token m + 1.
            logits = output.logits[:, 1:, :]

            # Labels: text tokens shifted left by one; padding positions and the
            # final slot are set to -100 so the loss ignores them.
            input_ids = conversation_data['input_ids'].squeeze(1)
            text_mask = conversation_data['attention_mask'].squeeze(1)
            labels = input_ids[:, 1:].masked_fill(text_mask[:, 1:] == 0, -100)
            labels = torch.cat([labels, labels.new_full((labels.size(0), 1), -100)], dim=1)

            loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1).to(logits.device))

        # Average (not sum) gradients over the accumulation window.
        (loss / accumulation_steps).backward()
        total_loss += loss.item()
        tprog.set_description(f'train step loss: {loss.item():.4f}')

        is_last_batch = (i + 1) == len(train_loader)
        if (i + 1) % accumulation_steps == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()

            # Clean
            gc.collect()
            torch.cuda.empty_cache()

        # Validation on the held-out validation split, not the test set.
        if i % eval_every == 0:
            evaluate_split(val_loader, val_conversations)
            gc.collect()
            torch.cuda.empty_cache()

    print(f'epoch {epoch + 1}/{epochs} mean train loss: {total_loss / max(len(train_loader), 1):.4f}')

# Eval

In [24]:
categories = ["BBBP", "side effects", "logD", "soluble", "toxic", "HIV"]
binary_categories = ["BBBP", "side effects", "toxic", "HIV"]
preds, invalid_count, trues = {cat: [] for cat in categories}, {cat: 0 for cat in categories}, {cat: [] for cat in categories}


def convert_to_boolean(input_string):
    """Map 'Yes' -> True, 'No' -> False and anything else -> None."""
    return True if input_string == 'Yes' else False if input_string == 'No' else None


# NOTE(review): the scores saved under this cell cannot come from this code (it used to
# call an undefined `decode(...)`). They presumably match the commented-out cell further
# down that compares the ground truth with itself. Re-run for real numbers.
model.eval()  # disable LoRA dropout
with torch.no_grad(), autocast(dtype=torch.bfloat16):
    # test_loader is unshuffled with batch size 1, so it lines up with test_conversations.
    for (smiles_data, conversation_data), true_sentence in zip(test_loader, test_conversations):
        # Teacher-forced prediction: argmax next token at every position given the
        # ground-truth prefix. This is optimistic for multi-token numeric answers;
        # use autoregressive generation for a stricter evaluation.
        output = model(smiles_data, conversation_data)
        pred_sentence = mistral_tokenizer.decode(output.logits.argmax(dim=-1)[0])
        y_true, y_pred = get_answer(true_sentence, pred_sentence)
        for category in categories:
            if category not in true_sentence:
                continue
            if category in binary_categories:
                if y_pred in ['Yes', 'No']:
                    preds[category].append(convert_to_boolean(y_pred))
                    trues[category].append(convert_to_boolean(y_true))
                else:
                    invalid_count[category] += 1
            else:  # continuous categories
                # Convert both before appending so preds/trues never get out of step.
                try:
                    pred_value, true_value = float(y_pred), float(y_true)
                except (TypeError, ValueError):
                    invalid_count[category] += 1
                else:
                    preds[category].append(pred_value)
                    trues[category].append(true_value)

for key in preds:
    if len(preds[key]) > 0:  # nothing valid to score otherwise
        if key in binary_categories:
            accuracy = accuracy_score(trues[key], preds[key])
            print(f'{key} accuracy: {accuracy:.4f}')
        else:  # continuous categories
            rmse = sqrt(mean_squared_error(trues[key], preds[key]))
            print(f'{key} RMSE: {rmse:.4f}')

print(invalid_count)

BBBP accuracy: 1.0000
side effects accuracy: 1.0000
logD RMSE: 0.0000
soluble RMSE: 0.0000
toxic accuracy: 1.0000
HIV accuracy: 1.0000


In [None]:
# Deliberate stop: raising SystemExit aborts "Run All" here so the exploratory
# cells below are not executed automatically.
raise SystemExit

# Test set

In [None]:
# Re-enable the KV cache for inference. `model` is the plain nn.Module wrapper and has
# no `.config` (the old `model.config` line raised AttributeError). The flag lives on
# the underlying Mistral model's config, where MolEncoderLLMPipeline.__init__ turned it
# off for training.
model.llm_model.get_base_model().config.use_cache = True

In [19]:
# # Initialize lists to store mean accuracies
# preds = {
#     "BBBP": [],
#     "side effects": [],
#     "logD": [],
#     "soluble": [],
#     "toxic": [],
#     "HIV": []
# }

# invalid_count = {
#     "BBBP": 0,
#     "side effects": 0,
#     "logD": 0,
#     "soluble": 0,
#     "toxic": 0,
#     "HIV": 0
# }

# trues = {
#     "BBBP": [],
#     "side effects": [],
#     "logD": [],
#     "soluble": [],
#     "toxic": [],
#     "HIV": []
# }

# def convert_to_boolean(input_string):
#     if input_string == 'Yes':
#         return True
#     elif input_string == 'No':
#         return False
#     else:
#         return None

# # for batch, true_sentence in zip(test_loader, test_conversations):
# for pred_sentence, true_sentence in zip(test_conversations, test_conversations):

#     # pred_sentence = decode(...)
#     y_true, y_pred = get_answer(true_sentence, pred_sentence)

#     # Determine the category based on the conversation
#     if 'BBBP' in true_sentence:
#         if (y_pred=='Yes') or (y_pred=='No'):
#             preds["BBBP"].append(convert_to_boolean(y_pred))
#             trues["BBBP"].append(convert_to_boolean(y_true))
#         else:
#             invalid_count["BBBP"] += 1

#     elif 'side effects' in true_sentence:
#         if (y_pred=='Yes') or (y_pred=='No'):
#             preds["side effects"].append(convert_to_boolean(y_pred))
#             trues["side effects"].append(convert_to_boolean(y_true))
#         else:
#             invalid_count["side effects"] += 1

#     elif 'logD' in true_sentence:
#         try:
#             preds["logD"].append(float(y_pred))
#             trues["logD"].append(float(y_true))
#         except:
#             invalid_count["logD"] += 1
        
#     elif 'soluble' in true_sentence:
#         try:
#             preds["soluble"].append(float(y_pred))
#             trues["soluble"].append(float(y_true))
#         except:
#             invalid_count["soluble"] += 1

#     elif 'toxic' in true_sentence:
#         if (y_pred=='Yes') or (y_pred=='No'):
#             preds["toxic"].append(convert_to_boolean(y_pred))
#             trues["toxic"].append(convert_to_boolean(y_true))
#         else:
#             invalid_count["toxic"] += 1

#     elif 'HIV' in true_sentence:
#         if (y_pred=='Yes') or (y_pred=='No'):
#             preds["HIV"].append(convert_to_boolean(y_pred))
#             trues["HIV"].append(convert_to_boolean(y_true))
#         else:
#             invalid_count["HIV"] += 1

# for key in preds:
#     if key in ["BBBP", "side effects", "toxic", "HIV"]:  # binary categories
#         if len(preds[key]) > 0:  # to avoid division by zero
#             accuracy = accuracy_score(trues[key], preds[key])
#             print(f'{key} accuracy: {accuracy:.4f}')
#     else:  # continuous categories
#         if len(preds[key]) > 0:  # to avoid division by zero
#             rmse = sqrt(mean_squared_error(trues[key], preds[key]))
#             print(f'{key} RMSE: {rmse:.4f}')

BBBP accuracy: 1.0000
side effects accuracy: 1.0000
logD RMSE: 0.0000
soluble RMSE: 0.0000
toxic accuracy: 1.0000
HIV accuracy: 1.0000
