In [14]:
import pickle
import numpy as np
import pandas as pd
from pathlib import Path

In [None]:
import os
import warnings
warnings.filterwarnings('ignore')  # Suppress all other warnings
os.environ['TRANSFORMERS_VERBOSITY'] = 'error'  # Suppress transformer warnings

In [None]:
def make_prompt(task, role_description, prediction_instruction, output_instruction, use_descriptors=False):
    descriptor_line = (
        "The molecule is described by a SMILES string, optionally followed by molecular descriptors (e.g., QED, SPS, MolWt).\n"
        "Use all available information.\n"
        if use_descriptors else ""
    )
    return (
        f"You are a {role_description}.\n"
        f"{prediction_instruction}\n"
        f"{descriptor_line}"
        f"{output_instruction}\n"
        "Do not explain. Do not include any other text."
    )


def get_roles(use_descriptors=False):
    return {
        'bace': make_prompt(
            task='bace',
            role_description="medicinal chemist predicting BACE-1 inhibitory activity",
            prediction_instruction="Your task is to predict whether the given small molecule is an active inhibitor.",
            output_instruction="Return only a single integer: 1 for active, 0 for inactive.",
            use_descriptors=use_descriptors
        ),
        'bbbp': make_prompt(
            task='bbbp',
            role_description="pharmacologist predicting blood-brain barrier penetration ability",
            prediction_instruction="Your task is to predict whether the given small molecule can penetrate the blood-brain barrier.",
            output_instruction="Return only a single integer: 1 for penetration, 0 for no penetration.",
            use_descriptors=use_descriptors
        ),
        'esol': make_prompt(
            task='esol',
            role_description="physical chemist modeling solvation thermodynamics",
            prediction_instruction="Your task is to predict the hydration free energy (kcal/mol) of the given small molecule in water.",
            output_instruction="Return only a single float value.",
            use_descriptors=use_descriptors
        ),
        'lipo': make_prompt(
            task='lipo',
            role_description="physical chemist modeling lipophilicity",
            prediction_instruction="Your task is to predict the octanol/water distribution coefficient (logD at pH 7.4) of the given small molecule.",
            output_instruction="Return only a single float value.",
            use_descriptors=use_descriptors
        ),
    }

# 사용 예시
roles = {'desc_w': get_roles(use_descriptors=True), 'desc_wo': get_roles(use_descriptors=False)}
roles

{'bace': 'You are a medicinal chemist predicting BACE-1 inhibitory activity.\nYour task is to predict whether the given small molecule is an active inhibitor.\nReturn only a single integer: 1 for active, 0 for inactive.\nDo not explain. Do not include any other text.',
 'bbbp': 'You are a pharmacologist predicting blood-brain barrier penetration ability.\nYour task is to predict whether the given small molecule can penetrate the blood-brain barrier.\nReturn only a single integer: 1 for penetration, 0 for no penetration.\nDo not explain. Do not include any other text.',
 'esol': 'You are a physical chemist modeling solvation thermodynamics.\nYour task is to predict the hydration free energy (kcal/mol) of the given small molecule in water.\nReturn only a single float value.\nDo not explain. Do not include any other text.',
 'lipo': 'You are a physical chemist modeling lipophilicity.\nYour task is to predict the octanol/water distribution coefficient (logD at pH 7.4) of the given small mo

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    output_hidden_states=True,
)

messages = [
    {"role": "system", "content": role_ip},
    {"role": "user", "content": user_ip},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

with torch.inference_mode():
    generation = model.generate(
        input_ids,
        max_new_tokens=256,
        eos_token_id=terminators,
        do_sample=False,
    )
    response = generation[0][input_ids.shape[-1]:]
    
decoded = tokenizer.decode(response, skip_special_tokens=True)
print(decoded)

with torch.no_grad():
    outputs = model(input_ids)
    last_hidden_state = outputs['hidden_states'][-1]
print(last_hidden_state)

In [None]:
# pip install accelerate
from transformers import AutoProcessor, AutoModelForImageTextToText
import torch

model_id = "google/medgemma-4b-it"

model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    output_hidden_states=True,
)
processor = AutoProcessor.from_pretrained(model_id)

messages = [
    {"role": "system", "content": [{"type": "text", "text": role_ip}]},
    {"role": "user", "content": [{"type": "text", "text": user_ip}]}
]

inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(model.device, dtype=torch.bfloat16)

input_len = inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(**inputs, max_new_tokens=200, do_sample=False)
    response = generation[0][input_len:]

decoded = processor.decode(response, skip_special_tokens=True)
print(decoded)

with torch.no_grad():
    outputs = model(**inputs)
    last_hidden_state = outputs['hidden_states'][-1]
print(last_hidden_state)

In [None]:
dataset = ['bace', 'bbbp', 'esol', 'lipo']
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
descriptor = False

for dt in dataset:
    # if dt != 'esol':
    #     continue

    data = pd.read_csv(f'data/{dt}/desc_pre.csv')
    
    desc_cols = [c for c in data.columns if c not in ['smiles', 'set', 'y']] # 115
    test = data[data['set'] == 'test']
    print(len(data), len(test))

    result = []
    for _, row in test.iterrows():
        smi = row['smiles']
        y = row['y']

        user_ip = [f'SMILES: {smi}']
        if descriptor:
            for desc in desc_cols:
                user_ip.append(str(f'{desc}: {row[desc]}'))
            user_ip = ' | '.join(user_ip)
            role_ip = roles['desc_w'][dt]
        else:
            user_ip = user_ip[0]
            role_ip = roles['desc_wo'][dt]

        if model_id == "meta-llama/Meta-Llama-3-8B-Instruct":
            messages = [
                {"role": "system", "content": role_ip},
                {"role": "user", "content": user_ip},
            ]
        elif model_id == "google/medgemma-4b-it":
            messages = [
                {"role": "system", "content": [{"type": "text", "text": role_ip}]},
                {"role": "user", "content": [{"type": "text", "text": user_ip}]}
            ]

        sample_result = {'smiles': smi, 'y': y, 'input': user_ip, 'system': role_ip, 
                         'llm_name': 0, 'chat': 0, 'last_hidden_state': 0}
        result.append(sample_result)

    # save
    with open(f'output/{dt}_output.pkl', 'wb') as f:
        pickle.dump(result, f)

    # load
    with open(f'output/{dt}_output.pkl', 'rb') as f:
        tmp = pickle.load(f)
        print(tmp[0])
    
    break

1513 152
1972 198
1121 113
4200 420


In [8]:
smi

'O1CC[C@@H](NC(=O)[C@@H](Cc2cc3cc(ccc3nc2N)-c2ccccc2C)C)CC1(C)C'

In [10]:
bace

'You are a medicinal chemist predicting BACE-1 inhibitory activity.\nYour task is to predict whether the given small molecule is an active inhibitor.\nReturn only a single integer: 1 for active, 0 for inactive.\nDo not explain. Do not include any other text.'