In [None]:
# ! pip install transformers==4.38.1
# ! pip install rdkit==2023.9.4
# ! pip install accelerate==0.27.2
# ! pip install flash-attn
# ! pip install -q -U bitsandbytes
# ! pip install datasets
# ! pip install loralib
# ! pip install git+https://github.com/huggingface/peft.git
# ! pip install sentencepiece

In [None]:
# ! pip install tensorflow==2.10.0

In [1]:
import random, pickle, json, os
from tqdm import tqdm
from datasets import Dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import bitsandbytes as bnb
from peft import PeftModelForCausalLM
from peft import get_peft_model, LoraConfig
from sklearn.metrics import accuracy_score, mean_squared_error
from math import sqrt
from torch.cuda.amp import autocast

import sys
sys.path.append('../credentials/')
from HF_credentials import *

  from .autonotebook import tqdm as notebook_tqdm
The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.


In [2]:
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoConfig, BitsAndBytesConfig

# Tokenizer

In [3]:
llm_tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', token=HF_CREDENTIALS, model_max_length=256, add_prefix_space=False)
llm_tokenizer.pad_token = llm_tokenizer.eos_token
llm_tokenizer.padding_side = "right"

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


In [4]:
chat = [
  {"role": "user", "content": ""},
  {"role": "assistant", "content": ""}
]

llm_tokenizer.apply_chat_template(chat, tokenize=False)

'<s>[INST]  [/INST]</s>'

# Data

In [6]:
bbbp_prompts = [
    "Does the compound <SMILES> {smiles} </SMILES> have good blood-brain barrier permeability (BBBP)?",
    "Is <SMILES> {smiles} </SMILES> BBBP positive?",
    "Is blood-brain barrier permeability a property of <SMILES> {smiles} </SMILES>?",
    "Is BBBP a property of <SMILES> {smiles} </SMILES>?",
    "Does <SMILES> {smiles} </SMILES> have high BBBP?",
]
bbbp_responses = [
    "<BOOLEAN> {res} </BOOLEAN>"
]

clintox_prompts = [
    "Is <SMILES> {smiles} </SMILES> toxic?",
    "Does <SMILES> {smiles} </SMILES> have high toxicity?",
    "Is this compound toxic: <SMILES> {smiles} </SMILES>?",
    "Is this molecule toxic to human: <SMILES> {smiles} </SMILES>?",
    "Is <SMILES> {smiles} </SMILES> toxic to human?",
]
clintox_responses = [
    "<BOOLEAN> {res} </BOOLEAN>"
]

esol_prompts = [
    "Estimate the solubility of this molecule: <SMILES> {smiles} </SMILES>.",
    # "How soluble is <SMILES> {smiles} </SMILES>?",
    "Please predict the log solubility of <SMILES> {smiles} </SMILES>.",
    "<SMILES> {smiles} </SMILES> - what is its solubility?",
    "Can you predict the solubility of this molecule: <SMILES> {smiles} </SMILES>?",
    "Calculate the log solubility of <SMILES> {smiles} </SMILES>.",
    "Estimate the solubility of <SMILES> {smiles} </SMILES>.",
    "What is the solubility of this compound: <SMILES> {smiles} </SMILES>?",
    "What is the solubility of <SMILES> {smiles} </SMILES>?",
    "How much is the log solubility of <SMILES> {smiles} </SMILES>?"
]
esol_responses = [
    "Its log solubility is <NUMBER> {res} </NUMBER> mol/L"
]

hiv_prompts = [
    "Can <SMILES> {smiles} </SMILES> serve as an inhibitor of HIV replication?",
    "Can <SMILES> {smiles} </SMILES> inhibit HIV replication?",
    "Is <SMILES> {smiles} </SMILES> an inhibitor of HIV replication?",
    "Can the following molecule serve as an inhibitor of HIV replication: <SMILES> {smiles} </SMILES>?",
    "<SMILES> {smiles} </SMILES> - Is this compound an inhibitor of HIV replication?"
]
hiv_responses = [
    "<BOOLEAN> {res} </BOOLEAN>"
]

lipo_prompts = [
    "Predict the octanol/water distribution coefficient logD under the circumstances of pH 7.4 for <SMILES> {smiles} </SMILES>.",
    "Estimate the logD coefficient at pH=7.4 for the compound <SMILES> {smiles} </SMILES>.",
    "For <SMILES> {smiles} </SMILES>, calculate its logD coefficient at pH=7.4.",
    "Calculate the octanol/water distribution coefficient logD at pH=7.4 for this compound: <SMILES> {smiles} </SMILES>.",
    "At pH of 7.4, approximate the logD coefficient for <SMILES> {smiles} </SMILES>."
]
lipo_responses = [
    "<NUMBER> {res} </NUMBER>"
]

sider_prompts = [
    "Are there any known side effects of <SMILES> {smiles} </SMILES> affecting the heart?",
    "Are there any known heart-related side effects of <SMILES> {smiles} </SMILES>?",
    "Can <SMILES> {smiles} </SMILES> cause vascular disorders side effects?",
    "Does <SMILES> {smiles} </SMILES> have any side effects that affect the heart?",
    "Does <SMILES> {smiles} </SMILES> exhibit any heart-related side effects?",
    "Does <SMILES> {smiles} </SMILES> associate with any heart-related side effects?"
]
sider_reponses = [
    "<BOOLEAN> {res} </BOOLEAN>"
]

def prompt_sample(task: str):    
    if task == "bbbp":
        return random.choice(bbbp_prompts), random.choice(bbbp_responses)
    elif task == "clintox":
        return random.choice(clintox_prompts), random.choice(clintox_responses)
    elif task == "esol":
        return random.choice(esol_prompts), random.choice(esol_responses)
    elif task == "hiv":
        return random.choice(hiv_prompts), random.choice(hiv_responses)
    elif task == "lipo":
        return random.choice(lipo_prompts), random.choice(lipo_responses)
    elif task == "sider":
        return random.choice(sider_prompts), random.choice(sider_reponses)
    else:
        raise ValueError(f"Unrecognized task: {task}")

In [None]:
def create_flexible_datasets(split='train'):
    conversations = [] # [(task, SMILES, output),...]
    input_smiles = []
    
    with open(f'./data/LlaSMol/{split}/property_prediction-bbbp.jsonl', 'r') as f:
        for line in f:
            txt = json.loads(line)
            conversations.append(("bbbp", txt['input'], txt['output']))
            input_smiles.append(txt['input'])
    
    print(conversations[-1])
    
    with open(f'./data/LlaSMol/{split}/property_prediction-clintox.jsonl', 'r') as f:
        for line in f:
            txt = json.loads(line)
            conversations.append(("clintox", txt['input'], txt['output']))
            input_smiles.append(txt['input'])
    
    print(conversations[-1])
    
    with open(f'./data/LlaSMol/{split}/property_prediction-esol.jsonl', 'r') as f:
        for line in f:
            txt = json.loads(line)
            conversations.append(("esol", txt['input'], txt['output']))
            input_smiles.append(txt['input'])
    
    print(conversations[-1])
    
    with open(f'./data/LlaSMol/{split}/property_prediction-hiv.jsonl', 'r') as f:
        for line in f:
            txt = json.loads(line)
            conversations.append(("hiv", txt['input'], txt['output']))
            input_smiles.append(txt['input'])
    
    print(conversations[-1])
    
    with open(f'./data/LlaSMol/{split}/property_prediction-lipo.jsonl', 'r') as f:
        for line in f:
            txt = json.loads(line)
            conversations.append(("lipo", txt['input'], txt['output']))
            input_smiles.append(txt['input'])
    
    print(conversations[-1])
    
    with open(f'./data/LlaSMol/{split}/property_prediction-sider.jsonl', 'r') as f:
        for line in f:
            txt = json.loads(line)
            conversations.append(("sider", txt['input'], txt['output']['Vascular disorders']))
            input_smiles.append(txt['input'])
    
    print(conversations[-1])
    print(len(conversations))
    
    return conversations, input_smiles

In [7]:
print('Train:')
train_conversations, train_input_smiles = create_flexible_datasets('train')
print('Val:')
val_conversations, val_input_smiles = create_flexible_datasets('val')
print('Test:')
test_conversations, test_input_smiles = create_flexible_datasets('test')

Train:
<s>[INST] Is blood-brain barrier permeability (BBBP) a property of <SMILES> CC1(C)NC(=O)C(/C=C/C2=CC=CC=C2)O1 </SMILES>? [/INST]<BOOLEAN> Yes </BOOLEAN></s>
<s>[INST] Is <SMILES> CCCCOCCOCCOCC1=CC2=C(C=C1CCC)OCO2 </SMILES> toxic? [/INST]<BOOLEAN> No </BOOLEAN></s>
<s>[INST] How soluble is <SMILES> OCC1=CC=CC=C1OC1OC(CO)C(O)C(O)C1O </SMILES>? [/INST]Its log solubility is <NUMBER> -0.85 </NUMBER> mol/L</s>
<s>[INST] Can <SMILES> COC(=O)C1C2=C(CC3C4=CC=CC=C4NC(=O)C31)C1=CC=CC=C1N2 </SMILES> serve as an inhibitor of HIV replication? [/INST]<BOOLEAN> No </BOOLEAN></s>
<s>[INST] Predict the octanol/water distribution coefficient logD under the circumstances of pH 7.4 for <SMILES> C[C@H]1O[C@@H](N2C=NC3=C(N)N=C(OCC45CC6CC(CC(C6)C4)C5)N=C32)[C@H](O)[C@@H]1O </SMILES> [/INST]<NUMBER> 3.58 </NUMBER></s>
<s>[INST] Are there any known side effects of <SMILES> CCCCCOC(=O)NC1=NC(=O)N([C@@H]2O[C@H](C)[C@@H](O)[C@H]2O)C=C1F </SMILES> affecting the heart? [/INST]<BOOLEAN> Yes </BOOLEAN></s>
4096

In [8]:
class VariableCombinedDataset(Dataset):
    def __init__(
        self, 
        smiles_list,
        conversations,
        encoder_tokenizer,
        llm_tokenizer,
        max_length=256,
        ):
        self.smiles_list = smiles_list
        self.conversations = conversations
        self.encoder_tokenizer = encoder_tokenizer
        self.llm_tokenizer = llm_tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.smiles_list)
    
    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        smiles_encoding = self.encoder_tokenizer(smiles, return_tensors='pt', truncation=True, padding='max_length', max_length=self.max_length)
        
        chat = [
            {"role": "user", "content": ""},
            {"role": "assistant", "content": ""}
        ]
        
        task, input_, output_ = self.conversations[idx]
        prompt_format, response_format = prompt_sample(task)
        prompt = prompt_format.format(smiles=input_)
        response = response_format.format(res=output_)
        chat[0]['content'] = prompt
        chat[1]['content'] = response
        conversation = self.llm_tokenizer.apply_chat_template(chat, tokenize=False)
        conversation_tokenized = self.llm_tokenizer(conversation, truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt', add_special_tokens=False)
        
        if torch.cuda.is_available():
            return {key: tensor[0].to('cuda') for key, tensor in smiles_encoding.items()}, conversation_tokenized.to('cuda')
        else:
            return {key: tensor[0] for key, tensor in smiles_encoding.items()}, conversation_tokenized

In [9]:
# Load tokenizers
chemberta_tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
mistral_tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', add_prefix_space=False)
mistral_tokenizer.pad_token = mistral_tokenizer.eos_token
mistral_tokenizer.padding_side = "right"

# Create combined dataset
train_dataset = VariableCombinedDataset(train_input_smiles, train_conversations, chemberta_tokenizer, mistral_tokenizer)
val_dataset = VariableCombinedDataset(val_input_smiles, val_conversations, chemberta_tokenizer, mistral_tokenizer)
test_dataset = VariableCombinedDataset(test_input_smiles, test_conversations, chemberta_tokenizer, mistral_tokenizer)

# Define DataLoader
batch_size = 2
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Alternative Prompt dataset

In [8]:
bbbp_prompts = [
    "Does the compound <SMILES> {smiles} </SMILES> have good blood-brain barrier permeability (BBBP)?",
    "Is <SMILES> {smiles} </SMILES> BBBP positive?",
    "Is blood-brain barrier permeability a property of <SMILES> {smiles} </SMILES>?",
    "Is BBBP a property of <SMILES> {smiles} </SMILES>?",
    "Does <SMILES> {smiles} </SMILES> have high BBBP?",
]
bbbp_responses = [
    "<BOOLEAN> {res} </BOOLEAN>"
]

clintox_prompts = [
    "Is <SMILES> {smiles} </SMILES> toxic?",
    "Does <SMILES> {smiles} </SMILES> have high toxicity?",
    "Is this compound toxic: <SMILES> {smiles} </SMILES>?",
    "Is this molecule toxic to human: <SMILES> {smiles} </SMILES>?",
    "Is <SMILES> {smiles} </SMILES> toxic to human?",
]
clintox_responses = [
    "<BOOLEAN> {res} </BOOLEAN>"
]

esol_prompts = [
    "Estimate the solubility of this molecule: <SMILES> {smiles} </SMILES>.",
    "How soluble is <SMILES> {smiles} </SMILES>?",
    "Please predict the log solubility of <SMILES> {smiles} </SMILES>.",
    "<SMILES> {smiles} </SMILES> - what is its solubility?",
    "Can you predict the solubility of this molecule: <SMILES> {smiles} </SMILES>?",
    "Calculate the log solubility of <SMILES> {smiles} </SMILES>.",
    "Estimate the solubility of <SMILES> {smiles} </SMILES>.",
    "What is the solubility of this compound: <SMILES> {smiles} </SMILES>?",
    "What is the solubility of <SMILES> {smiles} </SMILES>?",
    "How much is the log solubility of <SMILES> {smiles} </SMILES>?"
]
esol_responses = [
    "Its log solubility is <NUMBER> {res} </NUMBER> mol/L"
]

hiv_prompts = [
    "Can <SMILES> {smiles} </SMILES> serve as an inhibitor of HIV replication?",
    "Can <SMILES> {smiles} </SMILES> inhibit HIV replication?",
    "Is <SMILES> {smiles} </SMILES> an inhibitor of HIV replication?",
    "Can the following molecule serve as an inhibitor of HIV replication: <SMILES> {smiles} </SMILES>?",
    "<SMILES> {smiles} </SMILES> - Is this compound an inhibitor of HIV replication?"
]
hiv_responses = [
    "<BOOLEAN> {res} </BOOLEAN>"
]

lipo_prompts = [
    "Predict the octanol/water distribution coefficient logD under the circumstances of pH 7.4 for <SMILES> {smiles} </SMILES>.",
    "Estimate the logD coefficient at pH=7.4 for the compound <SMILES> {smiles} </SMILES>.",
    "For <SMILES> {smiles} </SMILES>, calculate its logD coefficient at pH=7.4.",
    "Calculate the octanol/water distribution coefficient logD at pH=7.4 for this compound: <SMILES> {smiles} </SMILES>.",
    "At pH of 7.4, approximate the logD coefficient for <SMILES> {smiles} </SMILES>."
]
lipo_responses = [
    "<NUMBER> {res} </NUMBER>"
]

sider_prompts = [
    "Are there any known side effects of <SMILES> {smiles} </SMILES> affecting the heart?",
    "Are there any known heart-related side effects of <SMILES> {smiles} </SMILES>?",
    "Can <SMILES> {smiles} </SMILES> cause vascular disorders?",
    "Does <SMILES> {smiles} </SMILES> have any side effects that affect the heart?",
    "Does <SMILES> {smiles} </SMILES> exhibit any heart-related side effects?",
    "Does <SMILES> {smiles} </SMILES> associate with any heart-related side effects?"
]
sider_reponses = [
    "<BOOLEAN> {res} </BOOLEAN>"
]

def prompt_sample(task: str):    
    if task == "bbbp":
        return random.choice(bbbp_prompts), random.choice(bbbp_responses)
    elif task == "clintox":
        return random.choice(clintox_prompts), random.choice(clintox_responses)
    elif task == "esol":
        return random.choice(esol_prompts), random.choice(esol_responses)
    elif task == "hiv":
        return random.choice(hiv_prompts), random.choice(hiv_responses)
    elif task == "lipo":
        return random.choice(lipo_prompts), random.choice(lipo_responses)
    elif task == "sider":
        return random.choice(sider_prompts), random.choice(sider_reponses)
    else:
        raise ValueError(f"Unrecognized task: {task}")

In [6]:
def create_flexible_datasets(split='train'):
    conversations = [] # [(task, SMILES, output),...]
    input_smiles = []
    
    with open(f'../data/LlaSMol/{split}/property_prediction-bbbp.jsonl', 'r') as f:
        for line in f:
            txt = json.loads(line)
            conversations.append(("bbbp", txt['input'], txt['output']))
            input_smiles.append(txt['input'])
    
    print(conversations[-1])
    
    with open(f'../data/LlaSMol/{split}/property_prediction-clintox.jsonl', 'r') as f:
        for line in f:
            txt = json.loads(line)
            conversations.append(("clintox", txt['input'], txt['output']))
            input_smiles.append(txt['input'])
    
    print(conversations[-1])
    
    with open(f'../data/LlaSMol/{split}/property_prediction-esol.jsonl', 'r') as f:
        for line in f:
            txt = json.loads(line)
            conversations.append(("esol", txt['input'], txt['output']))
            input_smiles.append(txt['input'])
    
    print(conversations[-1])
    
    with open(f'../data/LlaSMol/{split}/property_prediction-hiv.jsonl', 'r') as f:
        for line in f:
            txt = json.loads(line)
            conversations.append(("hiv", txt['input'], txt['output']))
            input_smiles.append(txt['input'])
    
    print(conversations[-1])
    
    with open(f'../data/LlaSMol/{split}/property_prediction-lipo.jsonl', 'r') as f:
        for line in f:
            txt = json.loads(line)
            conversations.append(("lipo", txt['input'], txt['output']))
            input_smiles.append(txt['input'])
    
    print(conversations[-1])
    
    with open(f'../data/LlaSMol/{split}/property_prediction-sider.jsonl', 'r') as f:
        for line in f:
            txt = json.loads(line)
            conversations.append(("sider", txt['input'], txt['output']['Vascular disorders']))
            input_smiles.append(txt['input'])
    
    print(conversations[-1])
    print(len(conversations))
    
    return conversations, input_smiles

In [7]:
print('Train:')
train_conversations, train_input_smiles = create_flexible_datasets('train')
print('Val:')
val_conversations, val_input_smiles = create_flexible_datasets('val')
print('Test:')
test_conversations, test_input_smiles = create_flexible_datasets('test')

Train:
('bbbp', 'CC1(C)NC(=O)C(/C=C/C2=CC=CC=C2)O1', 'Yes')
('clintox', 'CCCCOCCOCCOCC1=CC2=C(C=C1CCC)OCO2', 'No')
('esol', 'OCC1=CC=CC=C1OC1OC(CO)C(O)C(O)C1O', '-0.85')
('hiv', 'COC(=O)C1C2=C(CC3C4=CC=CC=C4NC(=O)C31)C1=CC=CC=C1N2', 'No')
('lipo', 'C[C@H]1O[C@@H](N2C=NC3=C(N)N=C(OCC45CC6CC(CC(C6)C4)C5)N=C32)[C@H](O)[C@@H]1O', '3.58')
('sider', 'CCCCCOC(=O)NC1=NC(=O)N([C@@H]2O[C@H](C)[C@@H](O)[C@H]2O)C=C1F', 'Yes')
40966
Val:
('bbbp', 'COC1=CC=C(CC2=NC=CC3=CC(OC)=C(OC)C=C23)C=C1OC', 'No')
('clintox', 'O=C1N(CCC[NH+]2CCN(C3=CC=CC(Cl)=C3)CC2)N=C2C=CC=CN12', 'No')
('esol', 'C1=CC=C2C(=C1)NC1=CC=CC=C12', '-5.27')
('hiv', 'CCOC(=O)C(=CN1C(=O)C(=CC2=CC=CC([N+](=O)[O-])=C2)SC1=S)C(=O)C1=CC=CC=C1', 'No')
('lipo', 'ClC1=CC(NC2=NC=NC3=C2C(OCCN2CCCC2)=NN3)=CC=C1OCC1=CC=CC=N1', '2.87')
('sider', 'O=C(NC1=C(Cl)C=NC=C1Cl)C1=CC=C(OC(F)F)C(OCC2CC2)=C1', 'No')
5117
Test:
('bbbp', 'COC1=CC=C2OC3=CC=CC=C3C=C(N3CCN(C)CC3)C2=C1', 'Yes')
('clintox', 'CC(C)[C@@]1(C(=O)N[C@H]2CC(=O)O[C@]2(O)CF)CC(C2=NC=CC3=CC=

In [9]:
class VariableCombinedDataset(Dataset):
    def __init__(
        self, 
        smiles_list,
        conversations,
        encoder_tokenizer,
        llm_tokenizer,
        max_length=256,
        ):
        self.smiles_list = smiles_list
        self.conversations = conversations
        self.encoder_tokenizer = encoder_tokenizer
        self.llm_tokenizer = llm_tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.smiles_list)
    
    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        smiles_encoding = self.encoder_tokenizer(smiles, return_tensors='pt', truncation=True, padding='max_length', max_length=self.max_length)
        
        chat = [
            {"role": "user", "content": ""},
            {"role": "assistant", "content": ""}
        ]
        
        task, input_, output_ = self.conversations[idx]
        prompt_format, response_format = prompt_sample(task)
        prompt = prompt_format.format(smiles=input_)
        response = response_format.format(res=output_)
        chat[0]['content'] = prompt
        chat[1]['content'] = response
        conversation = self.llm_tokenizer.apply_chat_template(chat, tokenize=False)
        conversation_tokenized = self.llm_tokenizer(conversation, truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt', add_special_tokens=False)
        
        if torch.cuda.is_available():
            return {key: tensor[0].to('cuda') for key, tensor in smiles_encoding.items()}, conversation_tokenized.to('cuda')
        else:
            return {key: tensor[0] for key, tensor in smiles_encoding.items()}, conversation_tokenized
        

In [11]:
# Load tokenizers
chemberta_tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR', cache_dir="hf_cache/")
mistral_tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', cache_dir="hf_cache/", add_prefix_space=False)
mistral_tokenizer.pad_token = mistral_tokenizer.eos_token
mistral_tokenizer.padding_side = "right"

# Create combined dataset
train_dataset = VariableCombinedDataset(train_input_smiles, train_conversations, chemberta_tokenizer, mistral_tokenizer)
val_dataset = VariableCombinedDataset(val_input_smiles, val_conversations, chemberta_tokenizer, mistral_tokenizer)
test_dataset = VariableCombinedDataset(test_input_smiles, test_conversations, chemberta_tokenizer, mistral_tokenizer)

# Define DataLoader
batch_size = 3
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)



In [20]:
llm_tokenizer.batch_decode(train_dataset[10][1]["input_ids"])

['<s> [INST] Is BBBP a property of <SMILES> CC1=CC=CC(C)=C1 </SMILES>? [/INST]<BOOLEAN> Yes </BOOLEAN></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

In [18]:
llm_tokenizer.batch_decode(train_dataset[10][1]["input_ids"])

['<s> [INST] Does <SMILES> CC1=CC=CC(C)=C1 </SMILES> have high BBBP? [/INST]<BOOLEAN> Yes </BOOLEAN></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>']

## Test

In [None]:
# x, y = next(iter(combined_loader))

# mol_encoder = AutoModel.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
# llm_model = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2',
#             torch_dtype=torch.bfloat16,
#             # quantization_config=bnb_config,
#             device_map="auto",
#             token=HF_CREDENTIALS
# )

# mol_encoder(**x)['last_hidden_state'];
# llm_model.model.embed_tokens(y['input_ids'].to('cuda'));

# Model

In [None]:
class MolEncoderLLMPipeline(nn.Module):
    def __init__(self, lora_rank=32, lora_alpha=64):
        super().__init__()
        # Load molecule encoder
        self.mol_encoder = AutoModel.from_pretrained("DeepChem/ChemBERTa-77M-MTR", torch_dtype=torch.bfloat16).to('cuda')

        # UNCOMMENT TO BRING DOWN FROM 15GB TO 7GB
        bnb_config = BitsAndBytesConfig(
            load_in_4bit= True,
            bnb_4bit_quant_type= "nf4",
            bnb_4bit_compute_dtype= torch.bfloat16,
            bnb_4bit_use_double_quant= False,
        )
        self.llm_config = AutoConfig.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', token=HF_CREDENTIALS)
        self.llm_model = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2',
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config,
            token=HF_CREDENTIALS
        )#.to('cuda')

        self.llm_model.config.use_cache = False
        self.llm_model.config.pretraining_tp = 1

        # Initialize LoRA layers for Mistral
        self.lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=0.05,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )

        # Freeze encoder and LLM weights
        for param in self.mol_encoder.parameters():
            param.requires_grad = False
        for param in self.llm_model.parameters():
            param.requires_grad = False

        self.linear_project = nn.Linear(self.mol_encoder.config.hidden_size, self.llm_config.hidden_size, dtype=torch.bfloat16).to('cuda')

        # Apply LoRA modification
        self.llm_model = get_peft_model(self.llm_model, self.lora_config)
        self.llm_model.to('cuda')

    def forward(self, smiles_tokens, text_tokens):
        # Encoder forward pass / Get SMILES embeddings
        smiles_tokens = {k: v.to('cuda') for k, v in smiles_tokens.items()}
        text_tokens = {k: v.to('cuda') for k, v in text_tokens.items()}

        mol_encoder_output = self.mol_encoder(**smiles_tokens)
        smiles_embedding = mol_encoder_output['last_hidden_state'][:,0,:] # torch.Size([batch, max_length, 384])
        smiles_projection = self.linear_project(smiles_embedding).unsqueeze(1)#.to('cuda')
        # print(smiles_projection.shape)

        # Get embeddings from LLM for the question
        embedding_layer = self.llm_model.model.model.embed_tokens
        llm_embeddings = embedding_layer(text_tokens['input_ids']).squeeze(1)#.to('cuda') # torch.Size([batch, max_length, 4096])
        # print(llm_embeddings.shape)

        # Concatenate encoder and LLM embeddings
        combined_embeddings = torch.cat((smiles_projection, llm_embeddings), dim=1)#.to('cuda')
        # print(combined_embeddings.shape)

        # Custom attention mask
        attention_mask = torch.zeros(smiles_projection.shape[0], combined_embeddings.shape[1], combined_embeddings.shape[1], device='cuda')
        attention_mask[:, 0, 0] = 1 # SMILES mask for itself
        for i in range(1, combined_embeddings.shape[1]):
            attention_mask[:, i, 0:i+1] = 1 # From SMILES to current token (inclusive)

        attention_mask = attention_mask.unsqueeze(1)
        # print(attention_mask.shape)

        # Pass through Mistral's transformer layers with LoRA adjustments
        output = self.llm_model(inputs_embeds=combined_embeddings, attention_mask=attention_mask)

        return output

In [None]:
model = MolEncoderLLMPipeline(lora_rank=16, lora_alpha=16).to('cuda')

In [None]:
model.llm_model.print_trainable_parameters()

In [None]:
# x, y = next(iter(combined_loader))
# model(x,y)

# Train

In [10]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
def parse_answer(text, is_boolean):
    if is_boolean:
        start = text.find('<BOOLEAN>') + len('<BOOLEAN>')
        end = text.find('</BOOLEAN>')
        return text[start:end].strip()
    else:
        start = text.find('<NUMBER>') + len('<NUMBER>')
        end = text.find('</NUMBER>')
        return text[start:end].strip()

def get_answer(true_sentence, pred_sentence):
    true_answer = true_sentence.split('[/INST]')[1]
    pred_answer = pred_sentence.split('[/INST]')[1]

    if 'BOOLEAN' in true_answer:
        y_true = parse_answer(true_answer, True)
        y_pred = parse_answer(pred_answer, True)
        return y_true, y_pred
    
    elif 'NUMBER' in true_answer:
        y_true = parse_answer(true_answer, False)
        y_pred = parse_answer(pred_answer, False)
        return y_true, y_pred

In [None]:
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=1e-5)
loss_fn = nn.CrossEntropyLoss(ignore_index=mistral_tokenizer.pad_token_id)

# Define the total number of training steps and the number of warmup steps
epochs = 10
total_steps = len(test_loader) * epochs
warmup_steps = 1000

accumulation_steps = 32

# Create the learning rate scheduler
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

for epoch in range(epochs):
    total_loss = 0
    tprog = tqdm(enumerate(train_loader), total=len(train_loader))
    for i, batch in tprog:
        model.train();
        smiles_data, conversation_data = batch

        # Forward pass
        with autocast():
            output = model(smiles_data, conversation_data)
            logits = output.logits[:, 1:, :]

            # Prepare labels
            labels = conversation_data['input_ids'].squeeze(1)
            labels = torch.cat([labels[:, 1:], labels.new_full((labels.size(0), 1), mistral_tokenizer.pad_token_id)], dim=1)

            # Compute loss
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.view(-1))

        # Backward and accumulate gradients
        loss.backward()
        total_loss += loss.item()
        tprog.set_description(f'train step loss: {loss.item():.4f}')

        if (i+1) % accumulation_steps == 0:  # Step the optimizer every accumulation_steps
            optimizer.step()
            optimizer.zero_grad()

            # Step the scheduler
            scheduler.step()

            # Clean
            gc.collect()
            torch.cuda.empty_cache()

        # Validation step
        if (i % 5000 == 0) & (i != 0):
            with torch.no_grad():
                model.eval();

                categories = ["BBBP", "side effects", "logD", "solubility", "toxic", "HIV"]
                preds, invalid_count, trues = {cat: [] for cat in categories}, {cat: 0 for cat in categories}, {cat: [] for cat in categories}

                def convert_to_boolean(input_string):
                    return True if input_string == 'Yes' else False if input_string == 'No' else None

                val_dataset = VariableCombinedDataset(val_input_smiles[:100], val_conversations[:100], chemberta_tokenizer, mistral_tokenizer)
                val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

                for batch, true_sentence in zip(val_loader, val_conversations):
                    # Predict
                    smiles_data, conversation_data = batch
                    output = model(smiles_data, conversation_data)
                    output_ids = output.logits.argmax(dim=-1)
                    pred_sentence = mistral_tokenizer.decode(output_ids.tolist()[0])
                    y_true, y_pred = get_answer(true_sentence, pred_sentence)
                    for category in categories:
                        if category in true_sentence:
                            if category in ["BBBP", "side effects", "toxic", "HIV"]:  # binary categories
                                if y_pred in ['Yes', 'No']:
                                    preds[category].append(convert_to_boolean(y_pred))
                                    trues[category].append(convert_to_boolean(y_true))
                                else:
                                    invalid_count[category] += 1
                            else:  # continuous categories
                                try:
                                    preds[category].append(float(y_pred))
                                    trues[category].append(float(y_true))
                                except:
                                    invalid_count[category] += 1
                                    continue

                for key in preds:
                    if len(preds[key]) > 0:  # to avoid division by zero
                        if key in ["BBBP", "side effects", "toxic", "HIV"]:  # binary categories
                            accuracy = accuracy_score(trues[key], preds[key])
                            print(f'{key} accuracy: {accuracy:.4f}')
                        else:  # continuous categories
                            rmse = sqrt(mean_squared_error(trues[key], preds[key]))
                            print(f'{key} RMSE: {rmse:.4f}')
                print('Invalid count:')
                print(invalid_count)

                # Clean
                gc.collect()
                torch.cuda.empty_cache()

    # Save the model
    torch.save(model.state_dict(), f"output/model_{epoch}.pth")

# Test set

In [None]:
model.llm_model.config.use_cache = True
model.eval();

In [19]:
with torch.no_grad():
    categories = ["BBBP", "side effects", "logD", "soluble", "toxic", "HIV"]
    preds, invalid_count, trues = {cat: [] for cat in categories}, {cat: 0 for cat in categories}, {cat: [] for cat in categories}

    for batch, true_sentence in zip(test_loader, test_conversations):
        # Predict
        smiles_data, conversation_data = batch
        output = model(smiles_data, conversation_data)
        output_ids = output.logits.argmax(dim=-1)
        pred_sentence = mistral_tokenizer.decode(output_ids.tolist()[0])
        y_true, y_pred = get_answer(true_sentence, pred_sentence)
        for category in categories:
            if category in true_sentence:
                if category in ["BBBP", "side effects", "toxic", "HIV"]:  # binary categories
                    if y_pred in ['Yes', 'No']:
                        preds[category].append(convert_to_boolean(y_pred))
                        trues[category].append(convert_to_boolean(y_true))
                    else:
                        invalid_count[category] += 1
                else:  # continuous categories
                    try:
                        preds[category].append(float(y_pred))
                        trues[category].append(float(y_true))
                    except:
                        invalid_count[category] += 1

    for key in preds:
        if len(preds[key]) > 0:  # to avoid division by zero
            if key in ["BBBP", "side effects", "toxic", "HIV"]:  # binary categories
                accuracy = accuracy_score(trues[key], preds[key])
                print(f'{key} accuracy: {accuracy:.4f}')
            else:  # continuous categories
                rmse = sqrt(mean_squared_error(trues[key], preds[key]))
                print(f'{key} RMSE: {rmse:.4f}')
    print('Invalid count:')
    print(invalid_count)

BBBP accuracy: 1.0000
side effects accuracy: 1.0000
logD RMSE: 0.0000
soluble RMSE: 0.0000
toxic accuracy: 1.0000
HIV accuracy: 1.0000
