In [None]:
# ! pip install transformers==4.38.1
# ! pip install rdkit==2023.9.4
# ! pip install accelerate==0.27.2
# ! pip install flash-attn
# ! pip install -q -U bitsandbytes
# ! pip install datasets
# ! pip install loralib
# ! pip install git+https://github.com/huggingface/peft.git

In [None]:
# ! pip install tensorflow==2.10.0

In [1]:
import json
import os
import pickle
import random

import torch
import torch.nn as nn
from datasets import Dataset
from torch.nn.utils import parametrize
from torch.utils.data import Dataset as TorchDataset
# import bitsandbytes as bnb
# from peft import PeftModelForCausalLM

import sys
sys.path.append('../credentials/')
from HF_credentials import *

In [2]:
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoConfig

# Tokenizer

In [3]:
# Mistral-7B-Instruct-v0.2 tokenizer; model_max_length=1024 is its default length limit.
tokenizer = AutoTokenizer.from_pretrained(
    'mistralai/Mistral-7B-Instruct-v0.2',
    token=HF_CREDENTIALS,
    model_max_length=1024,
)
# Reuse EOS as the padding token so fixed-length padding (padding='max_length') has a token to use.
tokenizer.pad_token = tokenizer.eos_token

In [4]:
# A two-turn exchange with empty messages: rendering it exposes the raw chat-template markup.
chat = [
    dict(role="user", content=""),
    dict(role="assistant", content=""),
]

tokenizer.apply_chat_template(chat, tokenize=False)

'<s>[INST]  [/INST]</s>'

# Data

In [21]:
# One entry per LlaSMol property-prediction task, in the order the files are read:
# (file stem, user-prompt template, assistant-answer template, key inside `output` or None).
# NOTE(review): the SIDER prompt asks about side effects "affecting the heart" but reads the
# 'Vascular disorders' label; confirm whether a cardiac category was intended.
_PROPERTY_TASKS = (
    ("bbbp",
     "Is blood-brain barrier permeability (BBBP) a property of <SMILES> {smiles} </SMILES>?",
     "<BOOLEAN> {label} </BOOLEAN>",
     None),
    ("clintox",
     "Is <SMILES> {smiles} </SMILES> toxic?",
     "<BOOLEAN> {label} </BOOLEAN>",
     None),
    ("esol",
     "How soluble is <SMILES> {smiles} </SMILES>?",
     "Its log solubility is <NUMBER> {label} </NUMBER> mol/L",
     None),
    ("hiv",
     "Can <SMILES> {smiles} </SMILES> serve as an inhibitor of HIV replication?",
     "<BOOLEAN> {label} </BOOLEAN>",
     None),
    ("lipo",
     "Predict the octanol/water distribution coefficient logD under the circumstances of pH 7.4 for <SMILES> {smiles} </SMILES>",
     "<NUMBER> {label} </NUMBER>",
     None),
    ("sider",
     "Are there any known side effects of <SMILES> {smiles} </SMILES> affecting the heart?",
     "<BOOLEAN> {label} </BOOLEAN>",
     "Vascular disorders"),
)
# TODO: molecule captioning (molecule_captioning.jsonl, free-text answers) is not included yet.


def create_datasets(split='train', data_dir='../data/LlaSMol', llm_tokenizer=None):
    """Build chat-formatted text for the LlaSMol property-prediction tasks of one split.

    Reads ``{data_dir}/{split}/property_prediction-<task>.jsonl`` for every task in
    ``_PROPERTY_TASKS``. Each non-blank line is a JSON object whose ``input`` is a SMILES string
    and whose ``output`` is the label (for SIDER, a mapping from category to label). Every record
    is rendered as a user/assistant exchange through ``apply_chat_template``. The last example
    of each task and the total count are printed.

    Args:
        split: sub-directory to read, e.g. 'train' or 'test'.
        data_dir: root directory of the LlaSMol data.
        llm_tokenizer: object providing ``apply_chat_template``; defaults to the notebook-level
            ``tokenizer``.

    Returns:
        (conversations, input_smiles): parallel lists of rendered chat strings and raw SMILES.
    """
    if llm_tokenizer is None:
        llm_tokenizer = tokenizer

    conversations = []
    input_smiles = []
    for stem, question, answer, label_key in _PROPERTY_TASKS:
        path = os.path.join(data_dir, split, f'property_prediction-{stem}.jsonl')
        n_before = len(conversations)
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank/trailing lines instead of crashing json.loads
                record = json.loads(line)
                label = record['output'] if label_key is None else record['output'][label_key]
                # A fresh chat per record: never mutate the notebook-global `chat` used for previews.
                example_chat = [
                    {"role": "user", "content": question.format(smiles=record['input'])},
                    {"role": "assistant", "content": answer.format(label=label)},
                ]
                conversations.append(llm_tokenizer.apply_chat_template(example_chat, tokenize=False))
                input_smiles.append(record['input'])
        if len(conversations) > n_before:  # only show a sample when this task contributed rows
            print(conversations[-1])
    print(len(conversations))

    return conversations, input_smiles

In [22]:
# Build the parallel (conversation, SMILES) lists for each split; the builder prints samples.
print('Train:')
train_conversations, train_input_smiles = create_datasets(split='train')

print('Test:')
test_conversations, test_input_smiles = create_datasets(split='test')

Train:
<s>[INST] Is blood-brain barrier permeability (BBBP) a property of <SMILES> CC1(C)NC(=O)C(/C=C/C2=CC=CC=C2)O1 </SMILES>? [/INST]<BOOLEAN> Yes </BOOLEAN></s>
<s>[INST] Is <SMILES> CCCCOCCOCCOCC1=CC2=C(C=C1CCC)OCO2 </SMILES> toxic? [/INST]<BOOLEAN> No </BOOLEAN></s>
<s>[INST] How soluble is <SMILES> OCC1=CC=CC=C1OC1OC(CO)C(O)C(O)C1O </SMILES>? [/INST]Its log solubility is <NUMBER> -0.85 </NUMBER> mol/L</s>
<s>[INST] Can <SMILES> COC(=O)C1C2=C(CC3C4=CC=CC=C4NC(=O)C31)C1=CC=CC=C1N2 </SMILES> serve as an inhibitor of HIV replication? [/INST]<BOOLEAN> No </BOOLEAN></s>
<s>[INST] Predict the octanol/water distribution coefficient logD under the circumstances of pH 7.4 for <SMILES> C[C@H]1O[C@@H](N2C=NC3=C(N)N=C(OCC45CC6CC(CC(C6)C4)C5)N=C32)[C@H](O)[C@@H]1O </SMILES> [/INST]<NUMBER> 3.58 </NUMBER></s>
<s>[INST] Are there any known side effects of <SMILES> CCCCCOC(=O)NC1=NC(=O)N([C@@H]2O[C@H](C)[C@@H](O)[C@H]2O)C=C1F </SMILES> affecting the heart? [/INST]<BOOLEAN> Yes </BOOLEAN></s>
4096

# Dataloader

In [24]:
class CombinedDataset(TorchDataset):
    """Map-style dataset pairing each SMILES string with its chat-formatted conversation.

    Item ``idx`` is ``(encoder_tokenizer(smiles[idx], ...), llm_tokenizer(conversations[idx], ...))``.
    Both are truncated/padded to ``max_length`` and returned as PyTorch tensors.

    Subclasses ``torch.utils.data.Dataset`` rather than the Hugging Face ``datasets.Dataset``.
    The latter is Arrow-backed and expects its own constructor to have run.

    NOTE(review): with ``return_tensors='pt'`` on a single string, each tensor presumably carries
    a leading batch dimension of 1. The default DataLoader collate would then give
    (batch, 1, max_length); confirm and squeeze if needed. One ``max_length`` is shared by both
    tokenizers, so long prompts may be truncated before the answer.

    Raises:
        ValueError: if ``smiles`` and ``conversations`` differ in length.
    """

    def __init__(self, smiles, conversations, encoder_tokenizer, llm_tokenizer, max_length=256):
        if len(smiles) != len(conversations):
            raise ValueError(
                f"smiles ({len(smiles)}) and conversations ({len(conversations)}) must have the same length"
            )
        self.smiles = smiles
        self.conversations = conversations
        self.encoder_tokenizer = encoder_tokenizer
        self.llm_tokenizer = llm_tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.smiles)

    def __getitem__(self, idx):
        smiles_tokenized = self.encoder_tokenizer(self.smiles[idx], truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')
        conversation_tokenized = self.llm_tokenizer(self.conversations[idx], truncation=True, padding='max_length', max_length=self.max_length, return_tensors='pt')
        return smiles_tokenized, conversation_tokenized

# Chem encoder

# LLM

In [None]:
# model = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-v0.1',
#     torch_dtype=torch.bfloat16,
#     device_map="auto",
#     token=HF_CREDENTIALS
# )

# model.config.use_cache = False
# model.config.pretraining_tp = 1

In [None]:
class LoRA(nn.Module):
    """Low-rank additive update for a frozen weight matrix, usable as a weight parametrization.

    ``forward(original_weight)`` returns ``original_weight + dropout((alpha / rank) * A @ B)``.
    ``B`` starts at zero, so the adapted weight equals the pretrained one until training moves
    it; this is the LoRA initialization (random A, zero B).

    Args:
        embed_dim: number of columns of the adapted weight (its input dimension).
        rank: rank of the update; must be positive.
        alpha: LoRA alpha; the update is scaled by ``alpha / rank``.
        dropout_rate: dropout applied to the update matrix (active in train mode only).
        out_dim: number of rows of the adapted weight; defaults to ``embed_dim`` (square weight).

    Raises:
        ValueError: if ``rank`` is not positive.
    """

    def __init__(self, embed_dim, rank, alpha, dropout_rate=0.05, out_dim=None):
        super(LoRA, self).__init__()
        if rank <= 0:
            raise ValueError(f"rank must be positive, got {rank}")
        if out_dim is None:
            out_dim = embed_dim
        self.rank = rank
        self.alpha = alpha
        # Standard LoRA scaling: alpha / rank keeps the update magnitude comparable across ranks.
        self.scaling = alpha / rank

        # delta = A @ B has shape (out_dim, embed_dim). With B = 0 the delta is exactly zero at
        # init; randomising both factors would corrupt the pretrained weight before any training.
        self.A = nn.Parameter(torch.randn(out_dim, rank))
        self.B = nn.Parameter(torch.zeros(rank, embed_dim))

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, original_weight):
        """Return ``original_weight`` plus the scaled low-rank update (same shape as the input)."""
        delta_weight = self.scaling * torch.matmul(self.A, self.B)
        delta_weight = self.dropout(delta_weight)
        return original_weight + delta_weight

In [None]:
class MolEncoderLLMPipeline(nn.Module):
    """Condition a frozen causal LLM on a molecule through a frozen SMILES encoder.

    The encoder's per-token SMILES embeddings are mapped into the LLM embedding space by a
    trainable linear projection and prepended to the question's token embeddings. The LLM stays
    frozen. Trainable ``LoRA`` updates are attached to selected attention projections with
    ``torch.nn.utils.parametrize``, so each forward pass uses ``W + delta`` and the pretrained
    ``W`` is never overwritten.

    Trainable parameters: ``mol_projection`` and ``lora_layers``.
    """

    def __init__(self, mol_encoder=None, llm_model=None, llm_embedding_dim=None, lora_rank=32, lora_alpha=64,
                 mol_encoder_name="DeepChem/ChemBERTa-77M-MTR",
                 llm_name='mistralai/Mistral-7B-Instruct-v0.1',
                 lora_targets=('q_proj', 'o_proj')):
        """
        Args:
            mol_encoder: pre-loaded SMILES encoder, or None to load ``mol_encoder_name`` with
                ``AutoModel``. Its output is read via ``.last_hidden_state`` and its width via
                ``.config.hidden_size``.
            llm_model: pre-loaded causal LM, or None to load ``llm_name`` in bfloat16 with
                ``device_map="auto"``. Must expose ``.model.layers[i].self_attn.<target>``.
            llm_embedding_dim: optional sanity check; must equal ``llm_model.config.hidden_size``.
            lora_rank, lora_alpha: passed positionally to ``LoRA`` for every adapted projection.
            mol_encoder_name, llm_name: Hub ids, used only when the corresponding model is None.
            lora_targets: attention sub-module names to adapt. Each adapter is built as
                ``LoRA(hidden, rank, alpha)``, i.e. a square (hidden, hidden) update, so only
                square projections can be targeted. k_proj/v_proj are typically narrower under
                grouped-query attention; confirm before adding them.

        Raises:
            ValueError: if ``llm_embedding_dim`` disagrees with the LLM config, or a target
                weight is not (hidden, hidden).
        """
        super().__init__()
        if mol_encoder is None:
            mol_encoder = AutoModel.from_pretrained(mol_encoder_name)
        if llm_model is None:
            # NOTE(review): the tokenizer cell loads Mistral-7B-Instruct-v0.2; confirm the intended checkpoint.
            llm_model = AutoModelForCausalLM.from_pretrained(
                llm_name,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                token=HF_CREDENTIALS,
            )
        self.mol_encoder = mol_encoder
        self.llm_model = llm_model
        self.llm_model.config.use_cache = False
        self.llm_model.config.pretraining_tp = 1

        # Freeze both pretrained models BEFORE creating new modules, so only those train.
        for param in self.mol_encoder.parameters():
            param.requires_grad = False
        for param in self.llm_model.parameters():
            param.requires_grad = False

        hidden_size = self.llm_model.config.hidden_size
        if llm_embedding_dim is not None and llm_embedding_dim != hidden_size:
            raise ValueError(f"llm_embedding_dim={llm_embedding_dim} does not match LLM hidden_size={hidden_size}")

        # Place the projection next to the LLM input embeddings, in their dtype.
        embed_weight = self.llm_model.get_input_embeddings().weight
        self.mol_projection = nn.Linear(self.mol_encoder.config.hidden_size, hidden_size).to(
            device=embed_weight.device, dtype=embed_weight.dtype)

        # One LoRA per (decoder layer, target). Registered as a parametrization, it recomputes
        # W + delta on every access instead of permanently overwriting W.
        # NOTE(review): assumes weights are materialised (not offloaded to 'meta') — confirm with device_map="auto".
        self.lora_layers = nn.ModuleList()
        for decoder_layer in self.llm_model.model.layers:
            for target in lora_targets:
                proj = getattr(decoder_layer.self_attn, target)
                if tuple(proj.weight.shape) != (hidden_size, hidden_size):
                    raise ValueError(
                        f"{target} weight has shape {tuple(proj.weight.shape)}; LoRA here needs ({hidden_size}, {hidden_size})")
                lora = LoRA(hidden_size, lora_rank, lora_alpha).to(
                    device=proj.weight.device, dtype=proj.weight.dtype)
                parametrize.register_parametrization(proj, "weight", lora)
                self.lora_layers.append(lora)

    def forward(self, smiles_tokens, input_ids, attention_mask=None):
        """Run the LLM on [projected SMILES embeddings ; question token embeddings].

        Args:
            smiles_tokens: encoder inputs, in one of two forms. Either a mapping passed as
                keyword arguments to ``mol_encoder`` (e.g. the ``input_ids``/``attention_mask``
                from its tokenizer), or a bare tensor of encoder input ids. Assumed
                (batch, smiles_len); TODO confirm, because ``CombinedDataset`` items may carry
                an extra singleton dimension.
            input_ids: LLM token ids, assumed (batch, seq_len).
            attention_mask: optional LLM mask for ``input_ids``; all ones when omitted.

        Returns:
            The last entry of the LLM's ``hidden_states`` for the combined sequence, i.e.
            (batch, smiles_len + seq_len, hidden). Logits are not returned.
        """
        if isinstance(smiles_tokens, torch.Tensor):
            smiles_tokens = {'input_ids': smiles_tokens}

        embedding_layer = self.llm_model.get_input_embeddings()
        llm_device = embedding_layer.weight.device

        # The encoder is frozen, so no autograd graph is needed for it.
        encoder_device = next(self.mol_encoder.parameters()).device
        with torch.no_grad():
            encoder_out = self.mol_encoder(**{k: v.to(encoder_device) for k, v in smiles_tokens.items()})
        mol_embeddings = self.mol_projection(
            encoder_out.last_hidden_state.to(device=llm_device, dtype=self.mol_projection.weight.dtype))

        llm_embeddings = embedding_layer(input_ids.to(llm_device))
        combined_embeddings = torch.cat([mol_embeddings, llm_embeddings], dim=1)

        # Build a mask covering the molecule prefix followed by the question tokens.
        smiles_mask = smiles_tokens.get('attention_mask')
        if smiles_mask is None:
            smiles_mask = torch.ones(mol_embeddings.shape[:2], dtype=torch.long, device=llm_device)
        if attention_mask is None:
            attention_mask = torch.ones(llm_embeddings.shape[:2], dtype=torch.long, device=llm_device)
        combined_mask = torch.cat(
            [smiles_mask.to(llm_device).long(), attention_mask.to(llm_device).long()], dim=1)

        outputs = self.llm_model(
            inputs_embeds=combined_embeddings,
            attention_mask=combined_mask,
            output_hidden_states=True,
        )
        return outputs.hidden_states[-1]

# Train

In [None]:
import gc


def _release_memory():
    """Run Python's garbage collector, then return PyTorch's cached, unused CUDA blocks to the driver."""
    gc.collect()
    torch.cuda.empty_cache()


_release_memory()

# Eval

In [None]:
# Re-enable the key/value cache for generation (elsewhere in this notebook it is set to False).
# FIXME(review): `model` is not defined by any executed cell (its loader cell is commented out),
# so this raises NameError on Restart & Run All; point it at the trained model object.
model.config.use_cache = True