In [None]:
# ! pip install transformers==4.38.1
# ! pip install rdkit==2023.9.4
# ! pip install accelerate==0.27.2
# ! pip install flash-attn
# ! pip install -q -U bitsandbytes
# ! pip install datasets
# ! pip install loralib
# ! pip install git+https://github.com/huggingface/peft.git

In [None]:
# ! pip install tensorflow==2.10.0

In [1]:
import json
import os
import pickle
import random
import sys

import torch
import torch.nn as nn
from datasets import Dataset
# import bitsandbytes as bnb
# from peft import PeftModelForCausalLM

# Make the local credentials module importable; it is expected to define
# HF_CREDENTIALS (the Hugging Face access token used by the loading cells).
sys.path.append('../credentials/')
from HF_credentials import *

In [2]:
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM

# Tokenizer

In [3]:
# Mistral-Instruct tokenizer; 1024 tokens is its default max length for truncation/padding.
tokenizer = AutoTokenizer.from_pretrained(
    'mistralai/Mistral-7B-Instruct-v0.2',
    model_max_length=1024,
    token=HF_CREDENTIALS,
)
# Reuse EOS as the padding token (presumably because the tokenizer ships without one).
# NOTE(review): if a collator later masks pad ids out of the labels, the real </s>
# is masked too, and the model may never learn to stop. Confirm how labels are built.
tokenizer.pad_token = tokenizer.eos_token

In [4]:
# Render an empty user/assistant exchange to inspect the raw prompt format the
# tokenizer's chat template produces (the rendered string is the cell output).
chat = [dict(role=speaker, content="") for speaker in ("user", "assistant")]

tokenizer.apply_chat_template(chat, tokenize=False)

'<s>[INST]  [/INST]</s>'

# Data

In [21]:
def _boolean_answer(output):
    """Wrap a yes/no label in the <BOOLEAN> tags used for binary-property answers."""
    return f"<BOOLEAN> {output} </BOOLEAN>"


# One entry per property-prediction task, in the order the corpus is assembled:
# (JSONL file stem, user-question template with a {smiles} slot, answer builder).
# Each answer builder receives the record's raw 'output' field.
_PROPERTY_TASKS = [
    ('property_prediction-bbbp',
     "Is blood-brain barrier permeability (BBBP) a property of <SMILES> {smiles} </SMILES>?",
     _boolean_answer),
    ('property_prediction-clintox',
     "Is <SMILES> {smiles} </SMILES> toxic?",
     _boolean_answer),
    # NOTE(review): the value is presumably log10 solubility, so the unit is really
    # log(mol/L). The text is kept verbatim so the training strings are unchanged.
    ('property_prediction-esol',
     "How soluble is <SMILES> {smiles} </SMILES>?",
     lambda output: f"Its log solubility is <NUMBER> {output} </NUMBER> mol/L"),
    ('property_prediction-hiv',
     "Can <SMILES> {smiles} </SMILES> serve as an inhibitor of HIV replication?",
     _boolean_answer),
    ('property_prediction-lipo',
     "Predict the octanol/water distribution coefficient logD under the circumstances of pH 7.4 for <SMILES> {smiles} </SMILES>",
     lambda output: f"<NUMBER> {output} </NUMBER>"),
    # NOTE(review): the question asks about the heart, but the label used is the
    # 'Vascular disorders' category. Confirm this pairing is intended (e.g. that it
    # mirrors the upstream LlaSMol prompt) before training or evaluating on it.
    ('property_prediction-sider',
     "Are there any known side effects of <SMILES> {smiles} </SMILES> affecting the heart?",
     lambda output: _boolean_answer(output['Vascular disorders'])),
    # Molecule captioning (molecule_captioning.jsonl) was tried here earlier and is
    # disabled. To re-enable it, add an entry like:
    # ('molecule_captioning', "Describe the molecule: <SMILES> {smiles} </SMILES>", str),
    # (the earlier attempt skipped lines that failed to parse as JSON).
]


def create_datasets(split='train', data_dir='../data/LlaSMol', chat_tokenizer=None, tasks=None):
    """Build chat-formatted strings for the LlaSMol property-prediction tasks.

    Each record in ``{data_dir}/{split}/{file_stem}.jsonl`` becomes a one-turn
    user/assistant exchange rendered with the tokenizer's chat template. As a
    sanity check, the last rendered conversation is printed after each task and
    the total count is printed at the end.

    Args:
        split: Sub-directory of ``data_dir`` to read (e.g. 'train' or 'test').
        data_dir: Root directory of the LlaSMol data layout.
        chat_tokenizer: Object exposing ``apply_chat_template``. Defaults to the
            notebook-level ``tokenizer``.
        tasks: Sequence of ``(file_stem, question_template, answer_fn)`` triples.
            Defaults to ``_PROPERTY_TASKS``.

    Returns:
        ``(conversations, input_smiles)``: parallel lists of rendered chat strings
        and the SMILES each one was built from.

    Raises:
        FileNotFoundError: if a task's JSONL file is missing.
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    # The conditional is lazy, so the global `tokenizer` is only looked up when
    # no tokenizer is passed in.
    tok = chat_tokenizer if chat_tokenizer is not None else tokenizer
    task_specs = _PROPERTY_TASKS if tasks is None else tasks

    conversations = []
    input_smiles = []

    for file_stem, question, answer in task_specs:
        path = os.path.join(data_dir, split, f'{file_stem}.jsonl')
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing lines in the JSONL file
                record = json.loads(line)
                # Build a fresh local chat per record instead of mutating a shared global.
                chat = [
                    {"role": "user", "content": question.format(smiles=record['input'])},
                    {"role": "assistant", "content": answer(record['output'])},
                ]
                conversations.append(tok.apply_chat_template(chat, tokenize=False))
                input_smiles.append(record['input'])
        if conversations:
            print(conversations[-1])

    print(len(conversations))
    return conversations, input_smiles

In [22]:
# Build the chat corpora for both splits; each call prints a sample conversation
# and the resulting total.
print('Train:')
train_conversations, train_input_smiles = create_datasets(split='train')

print('Test:')
test_conversations, test_input_smiles = create_datasets(split='test')

Train:
<s>[INST] Is blood-brain barrier permeability (BBBP) a property of <SMILES> CC1(C)NC(=O)C(/C=C/C2=CC=CC=C2)O1 </SMILES>? [/INST]<BOOLEAN> Yes </BOOLEAN></s>
<s>[INST] Is <SMILES> CCCCOCCOCCOCC1=CC2=C(C=C1CCC)OCO2 </SMILES> toxic? [/INST]<BOOLEAN> No </BOOLEAN></s>
<s>[INST] How soluble is <SMILES> OCC1=CC=CC=C1OC1OC(CO)C(O)C(O)C1O </SMILES>? [/INST]Its log solubility is <NUMBER> -0.85 </NUMBER> mol/L</s>
<s>[INST] Can <SMILES> COC(=O)C1C2=C(CC3C4=CC=CC=C4NC(=O)C31)C1=CC=CC=C1N2 </SMILES> serve as an inhibitor of HIV replication? [/INST]<BOOLEAN> No </BOOLEAN></s>
<s>[INST] Predict the octanol/water distribution coefficient logD under the circumstances of pH 7.4 for <SMILES> C[C@H]1O[C@@H](N2C=NC3=C(N)N=C(OCC45CC6CC(CC(C6)C4)C5)N=C32)[C@H](O)[C@@H]1O </SMILES> [/INST]<NUMBER> 3.58 </NUMBER></s>
<s>[INST] Are there any known side effects of <SMILES> CCCCCOC(=O)NC1=NC(=O)N([C@@H]2O[C@H](C)[C@@H](O)[C@H]2O)C=C1F </SMILES> affecting the heart? [/INST]<BOOLEAN> Yes </BOOLEAN></s>
4096

# Chem encoder

# LLM

In [None]:
# Base Mistral-7B in bf16, sharded across available devices.
# NOTE(review): the tokenizer cell loads 'mistralai/Mistral-7B-Instruct-v0.2',
# while this loads the base v0.1 model. Confirm the pairing is intended.
model = AutoModelForCausalLM.from_pretrained(
    'mistralai/Mistral-7B-v0.1',
    token=HF_CREDENTIALS,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

In [None]:
# Prepare the frozen base model for memory-efficient adapter training.
model.config.use_cache = False  # KV cache off while training (the Eval cell turns it back on)
model.config.pretraining_tp = 1

for param in model.parameters():
    param.requires_grad = False  # freeze the model - train adapters later
    if param.ndim == 1:
        # cast the small parameters (e.g. layernorm) to fp32 for stability
        # NOTE(review): the model is loaded in bf16 without quantization, so fp32 norm
        # weights appear to yield fp32 activations feeding bf16 Linear layers. This
        # presumably relies on autocast during training/eval; confirm.
        param.data = param.data.to(torch.float32)

model.gradient_checkpointing_enable()  # reduce number of stored activations
model.enable_input_require_grads()


class CastOutputToFloat(nn.Sequential):
    """Run the wrapped module(s) and upcast the output to fp32 (applied to lm_head below)."""

    def forward(self, x):
        return super().forward(x).to(torch.float32)


# Wrap only once, so re-running this cell does not nest wrappers around lm_head.
# Compare by class name: re-running the cell redefines CastOutputToFloat, so an
# isinstance check against the new class would miss a wrapper built by the old one.
if type(model.lm_head).__name__ != CastOutputToFloat.__name__:
    model.lm_head = CastOutputToFloat(model.lm_head)

In [None]:
# from peft import get_peft_model, LoraConfig, TaskType

# lora_config = LoraConfig(
#     r=32,
#     # target_modules=["q_proj", "v_proj"],
#     task_type=TaskType.CAUSAL_LM,
#     lora_alpha=64,
#     lora_dropout=0.1
# )

# base_model_with_new_adapter = get_peft_model(model, lora_config)
# base_model_with_new_adapter.print_trainable_parameters()

# Train

In [None]:
import gc

# First collect unreachable Python objects (releasing any tensors they held),
# then return PyTorch's cached-but-unused CUDA blocks.
for release_memory in (gc.collect, torch.cuda.empty_cache):
    release_memory()

# Eval

In [None]:
# Re-enable the KV cache for generation (it was switched off in the training-prep cell).
model.config.update({"use_cache": True})