In [None]:
# ! pip install transformers==4.38.1
# ! pip install rdkit==2023.9.4
# ! pip install accelerate==0.27.2
# ! pip install flash-attn
# ! pip install -q -U bitsandbytes
# ! pip install datasets
# ! pip install loralib
# ! pip install git+https://github.com/huggingface/peft.git
# ! pip install sentencepiece

In [None]:
# ! pip install tensorflow==2.10.0

In [None]:
import random, pickle, json, os
from tqdm import tqdm
# NOTE: `Dataset` is bound here first, then re-bound a few lines below by the
# torch.utils.data import; after this cell runs, `Dataset` is the torch class.
from datasets import Dataset
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader  # this `Dataset` replaces the one from `datasets`
import bitsandbytes as bnb
from peft import PeftModelForCausalLM
from peft import get_peft_model, LoraConfig

import sys
# Make the sibling credentials directory importable.
sys.path.append('../credentials/')
# Star import; presumably the source of HF_CREDENTIALS used in later cells — verify.
from HF_credentials import *

In [None]:
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel, AutoConfig, BitsAndBytesConfig

# Tokenizer

In [None]:
# Tokenizer used to render chat templates for the instruction-tuned Mistral model.
llm_tokenizer = AutoTokenizer.from_pretrained(
    'mistralai/Mistral-7B-Instruct-v0.2',
    token=HF_CREDENTIALS,
    model_max_length=256,
    add_prefix_space=False,
)
# Pad on the right and reuse the EOS token as the padding token.
llm_tokenizer.padding_side = "right"
llm_tokenizer.pad_token = llm_tokenizer.eos_token

In [None]:
# Empty single-turn exchange (one user message, one assistant message).
chat = [
    dict(role="user", content=""),
    dict(role="assistant", content=""),
]

# Render the template as a string (no tokenization) so the cell displays it.
llm_tokenizer.apply_chat_template(chat, tokenize=False)

# Data

In [None]:
# One entry per property-prediction task:
#   (file suffix, question template, answer template, key into `output` or None)
# When the key is not None, the record's `output` is indexed with it before
# being formatted into the answer.
_PROPERTY_TASKS = (
    ('bbbp',
     "Is blood-brain barrier permeability (BBBP) a property of <SMILES> {smiles} </SMILES>?",
     "<BOOLEAN> {answer} </BOOLEAN>",
     None),
    ('clintox',
     "Is <SMILES> {smiles} </SMILES> toxic?",
     "<BOOLEAN> {answer} </BOOLEAN>",
     None),
    ('esol',
     "How soluble is <SMILES> {smiles} </SMILES>?",
     "Its log solubility is <NUMBER> {answer} </NUMBER> mol/L",
     None),
    ('hiv',
     "Can <SMILES> {smiles} </SMILES> serve as an inhibitor of HIV replication?",
     "<BOOLEAN> {answer} </BOOLEAN>",
     None),
    ('lipo',
     "Predict the octanol/water distribution coefficient logD under the circumstances of pH 7.4 for <SMILES> {smiles} </SMILES>",
     "<NUMBER> {answer} </NUMBER>",
     None),
    # NOTE(review): the question mentions the heart but the label is read from
    # the 'Vascular disorders' entry — confirm this pairing is intended.
    ('sider',
     "Are there any known side effects of <SMILES> {smiles} </SMILES> affecting the heart?",
     "<BOOLEAN> {answer} </BOOLEAN>",
     'Vascular disorders'),
)


def create_datasets(split='train'):
    """Build chat-formatted strings for every property-prediction task of a split.

    For each task in ``_PROPERTY_TASKS`` the file
    ``../data/LlaSMol/{split}/property_prediction-{task}.jsonl`` is read line by
    line; every JSON record contributes one user/assistant exchange rendered
    with ``llm_tokenizer.apply_chat_template(..., tokenize=False)``.

    Args:
        split: Name of the sub-directory to read (e.g. ``'train'`` or ``'test'``).

    Returns:
        ``(conversations, input_smiles)`` — two lists of equal length, where
        ``conversations[i]`` is the rendered chat string and
        ``input_smiles[i]`` is the record's ``input`` field.

    Raises:
        FileNotFoundError: if one of the task files does not exist.
    """
    conversations = []
    input_smiles = []

    for task, question_tpl, answer_tpl, output_key in _PROPERTY_TASKS:
        path = f'../data/LlaSMol/{split}/property_prediction-{task}.jsonl'
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                txt = json.loads(line)
                answer = txt['output'] if output_key is None else txt['output'][output_key]
                # Fresh message list per record: nothing outside the function is mutated.
                messages = [
                    {"role": "user", "content": question_tpl.format(smiles=txt['input'])},
                    {"role": "assistant", "content": answer_tpl.format(answer=answer)},
                ]
                conversations.append(llm_tokenizer.apply_chat_template(messages, tokenize=False))
                input_smiles.append(txt['input'])
        # Show the latest rendered example after each task as a visual sanity check.
        if conversations:
            print(conversations[-1])

    print(len(conversations))

    return conversations, input_smiles

In [None]:
# Build the chat strings and matching SMILES lists for the training split.
print('Train:')
train_conversations, train_input_smiles = create_datasets(split='train')

# Same for the held-out test split.
print('Test:')
test_conversations, test_input_smiles = create_datasets(split='test')

In [None]:
class CombinedDataset(Dataset):
    """Pairs each SMILES string with its chat-formatted conversation.

    Every item is tokenized on access: the SMILES string with
    ``encoder_tokenizer`` and the conversation with ``llm_tokenizer``, both
    padded/truncated to ``max_length``.

    Args:
        smiles_list: Sequence of SMILES strings.
        conversations: Sequence of conversation strings, parallel to ``smiles_list``.
        encoder_tokenizer: Callable tokenizer applied to the SMILES string.
        llm_tokenizer: Callable tokenizer applied to the conversation string.
        max_length: Padding/truncation length passed to both tokenizers.
        device: Device the returned tensors are moved to. Defaults to ``'cuda'``
            (the previously hard-coded value).

    Raises:
        ValueError: if the two input sequences differ in length.
    """

    def __init__(self, smiles_list, conversations, encoder_tokenizer, llm_tokenizer, max_length=256, device='cuda'):
        if len(smiles_list) != len(conversations):
            raise ValueError(
                f"smiles_list and conversations must have the same length "
                f"(got {len(smiles_list)} and {len(conversations)})"
            )
        self.smiles_list = smiles_list
        self.conversations = conversations
        self.encoder_tokenizer = encoder_tokenizer
        self.llm_tokenizer = llm_tokenizer
        self.max_length = max_length
        self.device = device

    def __len__(self):
        """Number of (SMILES, conversation) pairs."""
        return len(self.smiles_list)

    def __getitem__(self, idx):
        """Return ``(smiles_inputs, conversation_inputs)`` for item ``idx``.

        ``smiles_inputs`` is a plain dict whose tensors have their leading
        dimension dropped (``tensor[0]``). ``conversation_inputs`` is whatever
        ``llm_tokenizer`` returned, moved to ``self.device`` with its leading
        dimension kept.
        """
        smiles = self.smiles_list[idx]
        smiles_encoding = self.encoder_tokenizer(
            smiles,
            return_tensors='pt',
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
        )
        # add_special_tokens=False: the conversation string is already rendered by the chat template.
        conversation_tokenized = self.llm_tokenizer(
            self.conversations[idx],
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
            add_special_tokens=False,
        )
        smiles_inputs = {key: tensor[0].to(self.device) for key, tensor in smiles_encoding.items()}
        return smiles_inputs, conversation_tokenized.to(self.device)

In [None]:
# Load tokenizers
chemberta_tokenizer = AutoTokenizer.from_pretrained('DeepChem/ChemBERTa-77M-MTR')
# Pass the Hub token here as well, consistent with the other Mistral loads in this notebook.
mistral_tokenizer = AutoTokenizer.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', token=HF_CREDENTIALS, add_prefix_space=False)
# Reuse EOS as the padding token and pad on the right.
mistral_tokenizer.pad_token = mistral_tokenizer.eos_token
mistral_tokenizer.padding_side = "right"

# Create combined dataset
train_dataset = CombinedDataset(train_input_smiles, train_conversations, chemberta_tokenizer, mistral_tokenizer)
test_dataset = CombinedDataset(test_input_smiles, test_conversations, chemberta_tokenizer, mistral_tokenizer)

# Define DataLoader
batch_size = 1
# Training batches are shuffled; the test loader keeps dataset order with one sample per batch.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

## Test

In [None]:
# x, y = next(iter(combined_loader))

# mol_encoder = AutoModel.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
# llm_model = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2',
#             torch_dtype=torch.bfloat16,
#             # quantization_config=bnb_config,
#             device_map="auto",
#             token=HF_CREDENTIALS
# )

# mol_encoder(**x)['last_hidden_state'];
# llm_model.model.embed_tokens(y['input_ids'].to('cuda'));

# Model

In [None]:
class MolEncoderLLMPipeline(nn.Module):
    """Molecule encoder + LoRA-adapted causal LLM joined by a linear projection.

    The first-position hidden state of the molecule encoder is projected to the
    LLM hidden size and prepended, as a single extra position, to the LLM token
    embeddings of the conversation. The combined sequence is then run through
    the LLM with a causal attention mask.

    Args:
        lora_rank: LoRA rank ``r``.
        lora_alpha: LoRA scaling factor.
    """

    def __init__(self, lora_rank=32, lora_alpha=64):
        super().__init__()
        # Load molecule encoder
        self.mol_encoder = AutoModel.from_pretrained("DeepChem/ChemBERTa-77M-MTR", torch_dtype=torch.bfloat16).to('cuda')

        # 4-bit NF4 quantization with double quantization; compute in bfloat16.
        # (Original author's note: brings memory down from ~15GB to ~7GB.)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit= True,
            bnb_4bit_quant_type= "nf4",
            bnb_4bit_compute_dtype= torch.bfloat16,
            bnb_4bit_use_double_quant= True,
        )
        self.llm_config = AutoConfig.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2', token=HF_CREDENTIALS)
        self.llm_model = AutoModelForCausalLM.from_pretrained('mistralai/Mistral-7B-Instruct-v0.2',
            torch_dtype=torch.bfloat16,
            quantization_config=bnb_config,
            device_map="auto",
            token=HF_CREDENTIALS
        )

        # KV cache is turned off here for training.
        self.llm_model.config.use_cache = False
        self.llm_model.config.pretraining_tp = 1

        # Initialize LoRA layers for Mistral
        self.lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            lora_dropout=0.05,
            target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
            bias="none",
            task_type="CAUSAL_LM",
        )

        # Freeze encoder and LLM weights
        for param in self.mol_encoder.parameters():
            param.requires_grad = False
        for param in self.llm_model.parameters():
            param.requires_grad = False

        # Created after the freezing loops above, so this projection stays trainable.
        self.linear_project = nn.Linear(self.mol_encoder.config.hidden_size, self.llm_config.hidden_size, dtype=torch.bfloat16)

        # Apply LoRA modification
        self.llm_model = get_peft_model(self.llm_model, self.lora_config)

    def forward(self, smiles_tokens, text_tokens):
        """Run the combined [molecule, conversation] sequence through the LLM.

        Args:
            smiles_tokens: Mapping of keyword arguments forwarded to the
                molecule encoder (``self.mol_encoder(**smiles_tokens)``).
            text_tokens: Mapping with an ``'input_ids'`` tensor. A size-1
                dimension at index 1 (as produced when the dataset's per-item
                leading dimension is batched) is squeezed away after embedding.

        Returns:
            Whatever ``self.llm_model`` returns for
            ``inputs_embeds`` / ``attention_mask``.
        """
        # Encoder forward pass; keep only the first position of the last hidden
        # state -> (batch, encoder_hidden).
        mol_encoder_output = self.mol_encoder(**smiles_tokens)
        smiles_embedding = mol_encoder_output['last_hidden_state'][:, 0, :]
        # Project to the LLM hidden size and add a length-1 sequence axis
        # -> (batch, 1, llm_hidden).
        smiles_projection = self.linear_project(smiles_embedding).unsqueeze(1)

        # Token embeddings from the LLM's own embedding table. Inputs are moved
        # to the table's device rather than a hard-coded one.
        embedding_layer = self.llm_model.model.model.embed_tokens
        input_ids = text_tokens['input_ids'].to(embedding_layer.weight.device)
        llm_embeddings = embedding_layer(input_ids).squeeze(1)  # (batch, seq, llm_hidden)

        # Prepend the molecule position to the text positions.
        combined_embeddings = torch.cat((smiles_projection, llm_embeddings), dim=1)

        # Causal mask: position i may attend to positions 0..i (the molecule
        # position is index 0). Sized from the actual batch, not from a
        # notebook-level global, so any batch size works.
        # NOTE(review): how a 4D 0/1 mask is interpreted depends on the
        # installed transformers version — confirm against the pinned release.
        n_batch, seq_len = combined_embeddings.shape[0], combined_embeddings.shape[1]
        attention_mask = torch.ones(
            n_batch, seq_len, seq_len, device=combined_embeddings.device
        ).tril().unsqueeze(1)  # (batch, 1, seq, seq)

        # Pass through Mistral's transformer layers with LoRA adjustments
        output = self.llm_model(inputs_embeds=combined_embeddings, attention_mask=attention_mask)

        return output

In [None]:
# Build the pipeline with rank-16 LoRA adapters (alpha 16), then move it to the GPU.
model = MolEncoderLLMPipeline(lora_rank=16, lora_alpha=16)
model = model.to('cuda')

In [None]:
# Report the trainable-parameter count of the LoRA-wrapped LLM only
# (the pipeline's linear projection is not part of `llm_model`).
model.llm_model.print_trainable_parameters()

In [None]:
# x, y = next(iter(combined_loader))
# model(x,y)

# Train

In [None]:
import gc
# Run a Python garbage-collection pass, then ask PyTorch to release cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# from datasets import load_dataset
  
# dataset = load_dataset('osunlp/SMolInstruct')
# train_set = dataset['train']
# validation_set = dataset['validation']
# test_set = dataset['test']

In [None]:
# for split, split_dataset in dataset.items():
#     split_dataset.to_json(f"squad-{split}.jsonl")

In [None]:
import gc

from torch.cuda.amp import autocast


# Only hand the optimizer parameters that can actually receive gradients
# (the frozen encoder/base-LLM weights are skipped).
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = bnb.optim.AdamW8bit(trainable_params, lr=1e-5)
# Padding positions are excluded from the loss.
# NOTE(review): pad_token was set to the EOS token, so EOS targets are ignored too — confirm this is intended.
loss_fn = nn.CrossEntropyLoss(ignore_index=mistral_tokenizer.pad_token_id)

epochs = 5
accumulation_steps = 32
warmup_steps = 100

# The scheduler is stepped once per optimizer step, i.e. once per accumulation
# window (plus one flush for a partial window at the end of each epoch), so its
# horizon is counted in optimizer steps rather than micro-batches.
optimizer_steps_per_epoch = -(-len(train_loader) // accumulation_steps)  # ceil division
total_steps = optimizer_steps_per_epoch * epochs

# Create the learning rate scheduler
scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_steps)

model.train()
optimizer.zero_grad()
for epoch in range(epochs):
    total_loss = 0
    num_batches = len(train_loader)
    # Train on the training split (the test split is kept for evaluation).
    tprog = tqdm(enumerate(train_loader), total=num_batches)
    for i, batch in tprog:
        smiles_data, conversation_data = batch

        # Forward pass. bfloat16 autocast matches the dtype the model weights
        # are loaded in; the fp16 default would need a GradScaler.
        with autocast(dtype=torch.bfloat16):
            output = model(smiles_data, conversation_data)
            # Drop the logit at the prepended molecule position so that the
            # remaining logits line up one-to-one with the text tokens.
            logits = output.logits[:, 1:, :]

            # Prepare labels: shift the text tokens left by one (next-token
            # targets) and fill the last slot with the pad id (ignored by the loss).
            labels = conversation_data['input_ids'].squeeze(1)
            labels = torch.cat([labels[:, 1:], labels.new_full((labels.size(0), 1), mistral_tokenizer.pad_token_id)], dim=1)

            # Compute loss
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

        # Backward and accumulate gradients. Scale so the accumulated gradient
        # is the mean over the window instead of the sum.
        (loss / accumulation_steps).backward()
        total_loss += loss.item()
        tprog.set_description(f'train step loss: {loss.item():.4f}')

        # Step every `accumulation_steps` micro-batches, and once more at the
        # end of the epoch so leftover gradients are not carried into the next one.
        if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
            optimizer.step()
            optimizer.zero_grad()

            # Step the scheduler
            scheduler.step()

            # Clean
            gc.collect()
            torch.cuda.empty_cache()

    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader)}")

# Eval

In [None]:
def parse_answer(text, is_boolean):
    """Extract the payload between ``<BOOLEAN>`` or ``<NUMBER>`` tags.

    Args:
        text: String to search.
        is_boolean: If true, look for ``<BOOLEAN>...</BOOLEAN>``; otherwise
            for ``<NUMBER>...</NUMBER>``.

    Returns:
        The stripped text between the first opening tag and the next closing
        tag, or ``None`` when either tag is missing.
    """
    tag = 'BOOLEAN' if is_boolean else 'NUMBER'
    open_tag, close_tag = f'<{tag}>', f'</{tag}>'

    start = text.find(open_tag)
    if start == -1:
        return None
    start += len(open_tag)

    # Search for the closing tag only after the opening one.
    end = text.find(close_tag, start)
    if end == -1:
        return None
    return text[start:end].strip()


def _assistant_part(sentence):
    """Return the text following the first ``[/INST]`` marker, or the whole string if absent."""
    parts = sentence.split('[/INST]')
    return parts[1] if len(parts) > 1 else sentence


def get_answer(true_sentence, pred_sentence):
    """Parse the tagged answers out of a reference and a predicted conversation.

    The answer type (boolean vs. number) is decided from the reference
    sentence and applied to both strings.

    Args:
        true_sentence: Reference conversation string.
        pred_sentence: Predicted conversation string.

    Returns:
        ``(y_true, y_pred)`` as strings. An element is ``None`` when its tags
        cannot be found; ``(None, None)`` when the reference contains neither
        a BOOLEAN nor a NUMBER answer.
    """
    true_answer = _assistant_part(true_sentence)
    pred_answer = _assistant_part(pred_sentence)

    if 'BOOLEAN' in true_answer:
        is_boolean = True
    elif 'NUMBER' in true_answer:
        is_boolean = False
    else:
        return None, None

    return parse_answer(true_answer, is_boolean), parse_answer(pred_answer, is_boolean)

In [None]:
# Smoke-test inputs: the same training conversation is used as both the
# reference and the "prediction".
true = train_conversations[0]
pred = train_conversations[0]

In [None]:
# Display the first rendered training conversation.
train_conversations[0]

In [None]:
# Both arguments hold the same conversation, so the parsed answers are expected to match.
# Requires `get_answer` to be defined in the running kernel.
get_answer(true, pred)

In [None]:
# # Initialize lists to store mean accuracies
# mean_accs = {
#     "BBBP": [],
#     "side effects": [],
#     "logD": [],
#     "soluble": [],
#     "toxic": [],
#     "HIV": []
# }

# # Iterate through the dataset
# for _, y in tqdm(combined_dataset):
#     # Initialize lists to store results for each category
#     category_results = {category: [] for category in mean_accs}

#     # Iterate through ground truth and prediction
#     for gt, pred in zip(ground_truth, prediction):
#         # Get answer for the current category
#         _, _, result = get_answer(gt, pred)
#         category = None

#         # Determine the category based on the conversation
#         if 'BBBP' in pred:
#             category = "BBBP"
#         elif 'side effects' in pred:
#             category = "side effects"
#         elif 'logD' in pred:
#             category = "logD"
#         elif 'soluble' in pred:
#             category = "soluble"
#         elif 'toxic' in pred:
#             category = "toxic"
#         else:
#             category = "HIV"

#         # Append result to the respective category list
#         category_results[category].append(result)

#     # Calculate mean accuracy for each category and append to mean_accs
#     for category, results in category_results.items():
#         mean_acc = sum(results) / len(results) if results else 0
#         mean_accs[category].append(mean_acc)

# # Compute the final mean accuracies for each category
# final_mean_accs = {category: sum(accs) / len(accs) for category, accs in mean_accs.items()}

# # Print or use final mean accuracies as needed
# for category, acc in final_mean_accs.items():
#     print(f"Final {category} Mean Accuracy: {acc}")

In [None]:
# # Initialize lists to store mean accuracies
# mean_accs = {
#     "BBBP": [],
#     "side effects": [],
#     "logD": [],
#     "soluble": [],
#     "toxic": [],
#     "HIV": []
# }

# for batch, true_sentence in zip(test_loader, test_conversations):
#     print(text)

#     pred_sentence = decode(...)

#     category_results = {category: [] for category in mean_accs}

#     get_answer(true_sentence, pred_sentence)

#     # Determine the category based on the conversation
#     if 'BBBP' in true_sentence:
#         category = "BBBP"
#     elif 'side effects' in true_sentence:
#         category = "side effects"
#     elif 'logD' in true_sentence:
#         category = "logD"
#     elif 'soluble' in true_sentence:
#         category = "soluble"
#     elif 'toxic' in true_sentence:
#         category = "toxic"
#     else:
#         category = "HIV"

In [None]:
# Break after this
# sys.exit() raises SystemExit; it is used here to stop a "Run All" at this cell
# so the cells below are only executed manually.
import sys
sys.exit()

# Train

In [None]:
import gc
# Same cleanup as earlier in the notebook: collect garbage, then release cached CUDA memory.
gc.collect()
torch.cuda.empty_cache()

# Eval

In [None]:
# Re-enable the KV cache for generation. `model` is the MolEncoderLLMPipeline
# wrapper, which has no `config` of its own; the flag lives on the wrapped LLM.
model.llm_model.config.use_cache = True