In [1]:
import deepchem
import deepchem.molnet
import os
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
import warnings
warnings.filterwarnings("ignore")

from tqdm import tqdm
from typing import Literal, Callable, Any, Tuple, List
import random
import numpy as np
import datetime
from rdkit import Chem, rdBase
from rdkit.Chem.Fingerprints import FingerprintMols
from rdkit.Chem import AllChem, MACCSkeys
from rdkit import DataStructs
from sklearn.metrics import roc_auc_score
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, LlamaTokenizer, LlamaForCausalLM
import argparse

rdBase.DisableLog('rdApp.warning')

def set_seed(seed):
    torch.manual_seed(seed)  # 设置CPU的随机种子
    torch.cuda.manual_seed(seed)  # 设置当前GPU的随机种子
    torch.cuda.manual_seed_all(seed)  # 设置所有GPU的随机种子
    np.random.seed(seed)  # 设置numpy的随机种子
    random.seed(seed)  # 设置Python的随机种子
    torch.backends.cudnn.deterministic = True  # 确保cudnn的确定性
    torch.backends.cudnn.benchmark = False  # 确保cudnn的确定性

set_seed(1111)

tasks, datasets, transformers = deepchem.molnet.load_clintox(splitter='scaffold', reload=True,
                                                             data_dir='./data/clintox_data',
                                                             save_dir='./data/clintox_datasets')

train_dataset, valid_dataset, test_dataset = datasets

batch_size = 1
total_batches = len(test_dataset) // batch_size

class Model:
    def __init__(self, model_name_or_id: Literal["AI4Chem/ChemLLM-20B-Chat-SFT", "AI4Chem/ChemLLM-20B-Chat-DPO", "X-LANCE/ChemDFM-13B-v1.0"], **kwargs):
        assert model_name_or_id in ["AI4Chem/ChemLLM-20B-Chat-SFT", "AI4Chem/ChemLLM-20B-Chat-DPO", "X-LANCE/ChemDFM-13B-v1.0"], \
            "model must be one of 'AI4Chem/ChemLLM-20B-Chat-SFT', 'AI4Chem/ChemLLM-20B-Chat-DPO', 'X-LANCE/ChemDFM-13B-v1.0'"
        self.model_name_or_id = model_name_or_id

        self.yes_token_ids = [
            [7560,], # 7560为ChemLLM词表中Yes对应的id
            [3869,], # 3869为ChemDFM词表中Yes对应的id
            ]
        self.no_token_ids = [
            [2458, 2783],  # 2458为ChemLLM词表中No对应的id, 2783-Not
            [1939,],  # 1939为ChemDFM词表中No对应的id
            ]
        
        if "AI4Chem" in model_name_or_id:
            self.model = AutoModelForCausalLM.from_pretrained(model_name_or_id, torch_dtype=torch.float16, trust_remote_code=True, device_map='auto')
            self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_id, trust_remote_code=True)
            self.yes_token_id = self.yes_token_ids[0]
            self.no_token_id = self.no_token_ids[0]

        else: # ChemDFM        
            self.model = LlamaForCausalLM.from_pretrained(model_name_or_id, torch_dtype=torch.float16, device_map='auto')
            self.tokenizer = LlamaTokenizer.from_pretrained(model_name_or_id)
            self.yes_token_id = self.yes_token_ids[1]
            self.no_token_id = self.no_token_ids[1]

        self.generation_config = GenerationConfig(
            do_sample=True,
            top_k=1,
            **kwargs,
            repetition_penalty=1.5,
            pad_token_id=self.tokenizer.eos_token_id,
            output_scores=True,
            return_dict_in_generate=True
        )

    def __call__(self, prompt: str, debug_mode: bool = False):
        torch.cuda.empty_cache()
        inputs = self.tokenizer(prompt, return_tensors="pt").to("cuda")

        input_ids = inputs['input_ids']
        outputs = self.model.generate(**inputs, generation_config=self.generation_config)
        logits = outputs.scores
        generated_ids = outputs.sequences
        probs = [torch.softmax(log, dim=-1) for log in logits]

        output_token_ids = generated_ids[0][ len(input_ids[0]): ]
        response = self.tokenizer.decode(output_token_ids, skip_special_tokens=True)
        
        if debug_mode:
            for i, token_id in enumerate(output_token_ids):
                token_prob = probs[i][0, token_id].item()
                print(f"Token ID: {token_id}, Probability: {token_prob}")

        total_yesp, total_nop = 0., 0.
        for x in self.yes_token_id:
            total_yesp += probs[0][0, x].item()
        for y in self.no_token_id:
            total_nop += probs[0][0, y].item()
        
        sump = (total_yesp + total_nop) + 1e-14
        y_score = total_yesp / sump


        inputs, input_ids, outputs, logits, generated_ids, probs, output_token_ids = None, None, None, None, None, None, None
        del inputs
        del input_ids
        del outputs
        del logits
        del generated_ids
        del probs
        del output_token_ids
        torch.cuda.empty_cache()

        return [response], [y_score]

def smiles2maccs_fp(smiles: str):
    return MACCSkeys.GenMACCSKeys(Chem.MolFromSmiles(smiles))

def smiles2rdk_fp(smiles: str):
    return Chem.RDKFingerprint(Chem.MolFromSmiles(smiles))

def smiles2morgan_fp(smiles: str):
    return AllChem.GetMorganFingerprint(Chem.MolFromSmiles(smiles), 2)

def calc_tanimoto_similarity(fp1, fp2) -> float:
    return DataStructs.TanimotoSimilarity(fp1, fp2)

def calc_cosine_similarity(fp1, fp2) -> float:
    return DataStructs.CosineSimilarity(fp1, fp2)

def calc_dice_similarity(fp1, fp2) -> float:
    return DataStructs.DiceSimilarity(fp1, fp2)

class BasePrompter(object):
    def __init__(self, system_instruction: str = "", template: str = "", verbose: bool = False):
        self.system_instruction = system_instruction
        self.template = template
        self.verbose = verbose

    def generate_prompt(self, query_smiles):
        raise NotImplementedError
    
    def canonicalize_smiles(self, smiles):
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        return Chem.MolToSmiles(mol, isomericSmiles=True, kekuleSmiles=False)
    
class ZeroShotPrompter(BasePrompter):
    def __init__(self, system_instruction: str = "", template: str = "", verbose: bool = False):
        super().__init__(system_instruction, template, verbose)

    @staticmethod
    def generate_prompt(self, query_smiles):
        query_smiles = self.canonicalize_smiles(query_smiles)

        prompt = f"{self.template}\nSMILES:{query_smiles}\nAnswer:"

        if self.verbose:
            print(prompt)

        return prompt

class FewShotPrompter(BasePrompter):
    def __init__(self, system_instruction: str = "", template: str = "", verbose: bool = False, *, 
                 sample_dataset, sample_molecule_format: Literal['smiles', 'maccs_fp', 'rdk_fp', 'morgan_fp'], 
                 sample_mode: Literal['random', 'cosine_similarity', 'tanimoto_similarity', 'dice_similarity'], sample_num: int):
        super(FewShotPrompter, self).__init__(system_instruction, template, verbose)
        
        assert sample_mode in ['random', 'cosine_similarity', 'tanimoto_similarity', 'dice_similarity'], "mode must be either 'random' or 'cosine_similarity' or 'tanimoto_similarity', 'dice_similarity'\n"
        self.sample_mode = sample_mode
        assert sample_molecule_format in ['smiles', 'maccs_fp', 'rdk_fp', 'morgan_fp'], "mode must be either 'smiles' or 'maccs_fp' or 'rdk_fp' or 'morgan_fp'\n"
        self.sample_molecule_format = sample_molecule_format
        self.sample_num = sample_num
        self.sample_dataset = self.convert_molecule_format(sample_dataset)

    def convert_molecule_format(self, sample_dataset):
        sample_dataset = sample_dataset.to_dataframe()
        if self.sample_molecule_format == 'smiles':
            return sample_dataset
        else:
            sample_dataset[self.sample_molecule_format] = sample_dataset['ids'].apply(lambda x: eval(f'smiles2{self.sample_molecule_format}')(x))
            return sample_dataset
            
    def get_demonstrations(self, query_smiles: str) -> List:
        if self.sample_mode == 'random':
            return self.random_sample_examples(query_smiles)
        else:
            return self.similar_sample_examples(query_smiles, self.sample_num)

    def random_sample_examples(self, query_smiles: str) -> List[Tuple[str, str]]:
        # y2 = CT_TOX
        positive_examples = self.sample_dataset[self.sample_dataset["y2"] == 1].sample(int(self.sample_num/2))
        negative_examples = self.sample_dataset[self.sample_dataset["y2"] == 0].sample(int(self.sample_num/2))

        smiles = positive_examples["ids"].tolist() + negative_examples["ids"].tolist()
        smiles = [self.canonicalize_smiles(i) for i in smiles]
        class_label = positive_examples["y2"].tolist() + negative_examples["y2"].tolist()
        #convert 1 to "Yes" and 0 to "No"" in class_label
        class_label = ["Yes" if i == 1 else "No" for i in class_label]
        sample_examples = list(zip(smiles, class_label))
        return sample_examples

    def similar_sample_examples(self, query_smiles: str, top_k: int) -> List[Tuple[str, str]]:
        query_smiles = eval(f'smiles2{self.sample_molecule_format}')(query_smiles)
        similarities = []
        for k in self.sample_dataset[self.sample_molecule_format].tolist():
            similarities.append(eval(f'calc_{self.sample_mode}')(query_smiles, k))
        sample_idx = np.argsort(-np.array(similarities))[:top_k]

        smiles = []
        class_label = []
        for i in sample_idx:
            smiles.append(self.canonicalize_smiles(self.sample_dataset.iloc[i]['ids']))
            class_label.append(self.sample_dataset.iloc[i]['y2'])
        class_label = ["Yes" if i == 1 else "No" for i in class_label]
        sample_examples = list(zip(smiles, class_label))
        return sample_examples


    def generate_prompt(self, query_smiles: str) -> str:
        query_smiles = self.canonicalize_smiles(query_smiles)

        few_shot = ""
        demonstrations = self.get_demonstrations(query_smiles)
        for example in demonstrations:
            few_shot += f"SMILES:{example[0]}\nAnswer:{example[-1]}\n"

        prompt = f"{self.template}\n{few_shot}SMILES:{query_smiles}\nAnswer:"

        if self.verbose:
            print(prompt)
            
        return prompt
    
class FewShotPrompter1(FewShotPrompter):
    def generate_prompt(self, query_smiles):
        query_smiles = self.canonicalize_smiles(query_smiles)

        few_shot = ""
        demonstrations = self.get_demonstrations(query_smiles)
        for example in demonstrations:
            few_shot += f"SMILES:{example[0]}\nAnswer:{example[-1]}\n"

        prompt = f"{self.template}\n{few_shot}Is this molecule Clinically-trail-Toxic (Yes) or Not Clinically-trail-toxic (No)?\nSMILES:{query_smiles}\nAnswer:"

        if self.verbose:
            print(prompt)
            
        return prompt

def main(dataset: Any, 
         batch_size: int, 
         total_batches: int, 
         model: Callable[[str, bool], Tuple],
         prompt_generator: Callable[[str], str],
         ):
    y_trues = []
    y_scores = []
    responses = []
    cnt = 0
    
    for X, Y, W, ids in tqdm(dataset.iterbatches(batch_size=batch_size), total=total_batches):
        input_X = [prompt_generator(id) for id in ids]
        
        y_trues.extend(Y[:, -1])

        bs_responses, bs_y_scores = model(input_X)
        
        print(bs_responses, bs_y_scores)

        responses.extend(bs_responses)
        y_scores.extend(bs_y_scores)

        cnt += 1

        # if cnt > 3:
        #     break

    print(responses)
    print(y_trues)
    print(y_scores)
    print(cnt)

    roc = roc_auc_score(y_trues, y_scores)
    print(roc)

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x76254075be20>>
Traceback (most recent call last):
  File "/home/fangmiaoNLP/.conda/envs/LZZ/lib/python3.8/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 
No normalization for SPS. Feature removed!
No normalization for AvgIpc. Feature removed!
2024-08-04 11:03:51.956984: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Skipped loading modules with pytorch-geometric dependency, missing a dependency. No module named 'torch_geometric'
Skipped loading modules with pytorch-geometric dependency, missing a dependency. cannot import name 'DMPNN' from 'd

In [None]:
zero_shot_prompt = \
    "You are an expert chemist, your task is to predict the property of molecule using your experienced chemical property prediction knowledge.\nPlease strictly follow the format, no other information can be provided. Given the SMILES string of a molecule, the task focuses on predicting molecular properties, specifically wether a molecule is Clinically-trail-Toxic(Yes) or Not Clinically-trail-toxic (No) based on the SMILES string representation of each molecule. The task is to predict the binary label for a given molecule, please answer with only 'Yes' or 'No'."

few_shot_prompt = \
    "Please strictly follow the format, no other information can be provided. Given the SMILES string of a molecule, the task focuses on predicting molecular properties, specifically wether a molecule is Clinically-trail-Toxic (Yes) or Not Clinically-trail-toxic (No). The task is to predict the binary label for a given molecule, please answer with only 'Yes' or 'No'."


# ChemLLM

# temperature 0.9

In [None]:
model = 'AI4Chem/ChemLLM-20B-Chat-SFT'
temperature = 0.9
m = Model(model_name_or_id=model, temperature=temperature, max_new_tokens=10)

Loading checkpoint shards:   0%|          | 0/41 [00:00<?, ?it/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### sample_mode cosine_similarity

#### sample_molecule_format maccs_fp

##### sample_num 2

In [None]:
# few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
#                                      sample_molecule_format='maccs_fp', 
#                                      sample_mode='cosine_similarity', 
#                                      sample_num=2)

# main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

##### sample_num 4

In [None]:
# few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
#                                      sample_molecule_format='maccs_fp', 
#                                      sample_mode='cosine_similarity', 
#                                      sample_num=4)

# main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

##### sample_num 8

In [None]:
# few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
#                                      sample_molecule_format='maccs_fp', 
#                                      sample_mode='cosine_similarity', 
#                                      sample_num=8)

# main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

"""
OOM
"""

#### sample_molecule_format rdk_fp

##### sample_num 2

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='rdk_fp', 
                                     sample_mode='cosine_similarity', 
                                     sample_num=2)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

  0%|          | 0/148 [00:00<?, ?it/s]

  1%|          | 1/148 [00:18<45:41, 18.65s/it]

[' No'] [0.0]


  1%|▏         | 2/148 [00:27<31:16, 12.85s/it]

['Not clin'] [0.0]


  2%|▏         | 3/148 [00:35<25:42, 10.64s/it]

[' No'] [0.0]


  3%|▎         | 4/148 [00:53<32:35, 13.58s/it]

[' No'] [0.0]


  3%|▎         | 5/148 [01:03<29:31, 12.39s/it]

[' No'] [0.0]


  4%|▍         | 6/148 [01:12<25:59, 10.98s/it]

[' No'] [0.0]


  5%|▍         | 7/148 [01:22<25:12, 10.73s/it]

['Not'] [0.0]


  5%|▌         | 8/148 [01:38<29:09, 12.49s/it]

[' Yes'] [0.499999999999995]


  6%|▌         | 9/148 [01:46<25:41, 11.09s/it]

['Not'] [0.0]


  7%|▋         | 10/148 [02:03<29:24, 12.79s/it]

[' No'] [0.0]


  7%|▋         | 11/148 [02:13<27:28, 12.03s/it]

[' No'] [0.0]


  8%|▊         | 12/148 [02:29<30:14, 13.34s/it]

[' No'] [0.0]


  9%|▉         | 13/148 [02:41<28:50, 12.82s/it]

['Not Clinical trial toxic'] [0.0]


  9%|▉         | 14/148 [02:48<24:31, 10.98s/it]

['Not'] [0.0]


 10%|█         | 15/148 [02:57<23:09, 10.45s/it]

['Not Clinical trial toxic'] [0.0]


 11%|█         | 16/148 [03:05<21:24,  9.73s/it]

['Not'] [0.0]


 11%|█▏        | 17/148 [03:21<25:37, 11.74s/it]

[' Yes'] [0.99999999999999]


 12%|█▏        | 18/148 [03:38<28:34, 13.19s/it]

[' No'] [0.0]


 13%|█▎        | 19/148 [03:54<30:20, 14.11s/it]

[' No'] [0.0]


 14%|█▎        | 20/148 [04:04<27:39, 12.96s/it]

[' No'] [0.0]


 14%|█▍        | 21/148 [04:13<24:29, 11.57s/it]

[' No'] [0.0]


 15%|█▍        | 22/148 [04:31<28:38, 13.64s/it]

[' No'] [0.0]


 16%|█▌        | 23/148 [04:39<24:54, 11.96s/it]

[' No'] [0.0]


 16%|█▌        | 24/148 [04:47<22:23, 10.84s/it]

['Not'] [0.0]


 17%|█▋        | 25/148 [04:56<20:52, 10.18s/it]

['Not clin'] [0.0]


 18%|█▊        | 26/148 [05:13<24:36, 12.10s/it]

[' No'] [0.0]


 18%|█▊        | 27/148 [05:21<22:06, 10.96s/it]

[' No'] [0.0]


 19%|█▉        | 28/148 [05:37<25:09, 12.58s/it]

[' No'] [0.0]


 20%|█▉        | 29/148 [05:44<21:25, 10.81s/it]

[' Yes'] [0.99999999999999]


 20%|██        | 30/148 [05:57<22:16, 11.33s/it]

['Not'] [0.0]


 21%|██        | 31/148 [06:13<25:01, 12.83s/it]

[' No'] [0.0]


 22%|██▏       | 32/148 [06:21<22:06, 11.43s/it]

[' No'] [0.0]


 22%|██▏       | 33/148 [06:38<24:48, 12.94s/it]

[' No'] [0.0]


 23%|██▎       | 34/148 [06:54<26:37, 14.01s/it]

[' No'] [0.0]


 24%|██▎       | 35/148 [07:03<23:44, 12.61s/it]

['Not Clinical trial toxic'] [0.0]


 24%|██▍       | 36/148 [07:22<26:39, 14.28s/it]

[' No'] [0.0]


 25%|██▌       | 37/148 [07:38<27:42, 14.98s/it]

[' Yes'] [0.499999999999995]


 26%|██▌       | 38/148 [07:55<28:11, 15.38s/it]

[' No'] [0.0]


 26%|██▋       | 39/148 [08:09<27:24, 15.08s/it]

[' Yes'] [0.99999999999999]


 27%|██▋       | 40/148 [08:17<23:31, 13.07s/it]

[' No'] [0.0]


 28%|██▊       | 41/148 [08:25<20:41, 11.60s/it]

[' No'] [0.0]


 28%|██▊       | 42/148 [08:42<22:51, 12.94s/it]

[' No'] [0.0]


 29%|██▉       | 43/148 [08:54<22:24, 12.80s/it]

[' No'] [0.0]


 30%|██▉       | 44/148 [09:03<19:59, 11.53s/it]

[' No'] [0.0]


 30%|███       | 45/148 [09:19<22:14, 12.96s/it]

[' Yes'] [0.99999999999999]


 31%|███       | 46/148 [09:35<23:44, 13.97s/it]

[' No'] [0.0]


 32%|███▏      | 47/148 [09:44<20:48, 12.36s/it]

['Not clin'] [0.0]


 32%|███▏      | 48/148 [09:52<18:43, 11.24s/it]

['Not Yes'] [0.0]


 33%|███▎      | 49/148 [10:13<23:05, 13.99s/it]

[' Yes'] [0.99999999999999]


 34%|███▍      | 50/148 [10:33<25:53, 15.85s/it]

[' No'] [0.0]


 34%|███▍      | 51/148 [10:45<23:36, 14.61s/it]

['Not Clinical trial toxic'] [0.0]


 35%|███▌      | 52/148 [11:01<24:10, 15.11s/it]

[' No'] [0.0]


 36%|███▌      | 53/148 [11:13<22:37, 14.29s/it]

['Not'] [0.0]


 36%|███▋      | 54/148 [11:24<20:32, 13.11s/it]

[' No'] [0.0]


 37%|███▋      | 55/148 [11:33<18:39, 12.04s/it]

['Not Clinical trial toxic'] [0.0]


 38%|███▊      | 56/148 [11:42<16:43, 10.91s/it]

['Not'] [0.0]


 39%|███▊      | 57/148 [11:58<19:00, 12.53s/it]

[' No'] [0.0]


 39%|███▉      | 58/148 [12:14<20:35, 13.73s/it]

[' Yes'] [0.99999999999999]


 40%|███▉      | 59/148 [12:23<17:57, 12.11s/it]

['Not'] [0.0]


 41%|████      | 60/148 [12:30<15:27, 10.54s/it]

['Not clin'] [0.0]


 41%|████      | 61/148 [12:38<14:24,  9.93s/it]

['Not clin'] [0.0]


 42%|████▏     | 62/148 [12:45<12:47,  8.93s/it]

['Not'] [0.0]


 43%|████▎     | 63/148 [12:51<11:40,  8.24s/it]

[' No'] [0.0]


 43%|████▎     | 64/148 [12:58<10:49,  7.73s/it]

['Not'] [0.0]


 44%|████▍     | 65/148 [13:08<11:44,  8.49s/it]

['Not'] [0.0]


 45%|████▍     | 66/148 [13:26<15:33, 11.38s/it]

[' Yes'] [0.99999999999999]


 45%|████▌     | 67/148 [13:36<14:34, 10.80s/it]

['Not Clinical trial toxic'] [0.0]


 46%|████▌     | 68/148 [13:52<16:39, 12.49s/it]

[' No'] [0.0]


 47%|████▋     | 69/148 [14:00<14:44, 11.20s/it]

['Not'] [0.0]


 47%|████▋     | 70/148 [14:13<14:59, 11.53s/it]

[' No'] [0.0]


 48%|████▊     | 71/148 [14:20<13:02, 10.16s/it]

['Not clin'] [0.0]


 49%|████▊     | 72/148 [14:30<12:52, 10.16s/it]

[' No'] [0.0]


 49%|████▉     | 73/148 [14:40<12:44, 10.20s/it]

[' No'] [0.0]


 50%|█████     | 74/148 [14:48<11:38,  9.44s/it]

['Not Clinical trial toxic'] [0.0]


 51%|█████     | 75/148 [14:54<10:23,  8.54s/it]

[' No'] [0.0]


 51%|█████▏    | 76/148 [15:11<13:04, 10.90s/it]

[' No'] [0.0]


 52%|█████▏    | 77/148 [15:27<14:54, 12.60s/it]

[' No'] [0.0]


 53%|█████▎    | 78/148 [15:36<13:14, 11.35s/it]

[' No'] [0.0]


 53%|█████▎    | 79/148 [15:52<14:46, 12.84s/it]

[' No'] [0.0]


 54%|█████▍    | 80/148 [16:04<14:25, 12.72s/it]

['Not'] [0.0]


 55%|█████▍    | 81/148 [16:15<13:23, 12.00s/it]

[' Yes'] [0.99999999999999]


 55%|█████▌    | 82/148 [16:21<11:27, 10.42s/it]

[' No'] [0.0]


 56%|█████▌    | 83/148 [16:36<12:33, 11.60s/it]

[' No'] [0.0]


 57%|█████▋    | 84/148 [16:46<11:55, 11.19s/it]

[' No'] [0.0]


 57%|█████▋    | 85/148 [16:54<10:47, 10.29s/it]

[' No'] [0.0]


 58%|█████▊    | 86/148 [17:01<09:35,  9.29s/it]

['Not clin'] [0.0]


 59%|█████▉    | 87/148 [17:08<08:37,  8.48s/it]

['Not'] [0.0]


 59%|█████▉    | 88/148 [17:17<08:46,  8.77s/it]

['Not Clinical trial toxic'] [0.0]


 60%|██████    | 89/148 [17:24<08:03,  8.20s/it]

['Not clin'] [0.0]


 61%|██████    | 90/148 [17:35<08:52,  9.18s/it]

['Not Clinical trial toxic'] [0.0]


 61%|██████▏   | 91/148 [17:52<10:45, 11.33s/it]

[' No'] [0.0]


 62%|██████▏   | 92/148 [18:01<10:00, 10.72s/it]

['Not Clinical trial toxic'] [0.0]


 63%|██████▎   | 93/148 [18:09<09:06,  9.94s/it]

[' No'] [0.0]


 64%|██████▎   | 94/148 [18:26<10:41, 11.87s/it]

[' No'] [0.0]


 64%|██████▍   | 95/148 [18:44<12:10, 13.79s/it]

[' No'] [0.0]


 65%|██████▍   | 96/148 [19:00<12:39, 14.61s/it]

[' No'] [0.0]


 66%|██████▌   | 97/148 [19:09<10:53, 12.82s/it]

['Not clin'] [0.0]


 66%|██████▌   | 98/148 [19:17<09:31, 11.43s/it]

[' No'] [0.0]


 67%|██████▋   | 99/148 [19:30<09:32, 11.68s/it]

['Not'] [0.0]


 68%|██████▊   | 100/148 [19:42<09:30, 11.88s/it]

[' No'] [0.0]


 68%|██████▊   | 101/148 [19:58<10:21, 13.23s/it]

[' No'] [0.0]


 69%|██████▉   | 102/148 [20:06<08:59, 11.73s/it]

[' No'] [0.0]


 70%|██████▉   | 103/148 [20:15<08:05, 10.78s/it]

['Not clin'] [0.0]


 70%|███████   | 104/148 [20:23<07:18,  9.97s/it]

[' No'] [0.0]


 71%|███████   | 105/148 [20:31<06:44,  9.40s/it]

[' No'] [0.0]


 72%|███████▏  | 106/148 [20:39<06:19,  9.03s/it]

['Not'] [0.0]


 72%|███████▏  | 107/148 [20:47<05:59,  8.76s/it]

['Not'] [0.0]


 73%|███████▎  | 108/148 [21:00<06:32,  9.82s/it]

['Not'] [0.0]


 74%|███████▎  | 109/148 [21:08<06:08,  9.46s/it]

['Not clin'] [0.0]


 74%|███████▍  | 110/148 [21:18<05:58,  9.43s/it]

['Not Clinical trial toxic'] [0.0]


 75%|███████▌  | 111/148 [21:28<05:57,  9.67s/it]

[' No'] [0.0]


 76%|███████▌  | 112/148 [21:36<05:32,  9.22s/it]

[' No'] [0.0]


 76%|███████▋  | 113/148 [21:43<04:57,  8.51s/it]

['Not clin'] [0.0]


 77%|███████▋  | 114/148 [21:51<04:39,  8.23s/it]

['Not Clinical trial toxic'] [0.0]


 78%|███████▊  | 115/148 [22:07<05:53, 10.71s/it]

[' No'] [0.0]


 78%|███████▊  | 116/148 [22:25<06:54, 12.94s/it]

[' No'] [0.0]


 79%|███████▉  | 117/148 [22:42<07:13, 13.99s/it]

[' No'] [0.0]


 80%|███████▉  | 118/148 [22:50<06:07, 12.24s/it]

['Not'] [0.0]


 80%|████████  | 119/148 [23:06<06:31, 13.49s/it]

[' No'] [0.0]


 81%|████████  | 120/148 [23:23<06:43, 14.40s/it]

[' No'] [0.0]


 82%|████████▏ | 121/148 [23:29<05:25, 12.05s/it]

['Not'] [0.0]


 82%|████████▏ | 122/148 [23:36<04:34, 10.57s/it]

['Not clin'] [0.0]


 83%|████████▎ | 123/148 [23:46<04:14, 10.17s/it]

['Not Clinical trial toxic'] [0.0]


 84%|████████▍ | 124/148 [23:52<03:38,  9.12s/it]

['Not'] [0.0]


 84%|████████▍ | 125/148 [24:09<04:20, 11.31s/it]

[' No'] [0.0]


 85%|████████▌ | 126/148 [24:18<03:56, 10.75s/it]

['Not Clinical trial toxic'] [0.0]


 86%|████████▌ | 127/148 [24:27<03:33, 10.14s/it]

['Not clin'] [0.0]


 86%|████████▋ | 128/148 [24:37<03:24, 10.21s/it]

['Not'] [0.0]


 87%|████████▋ | 129/148 [24:52<03:39, 11.55s/it]

[' No'] [0.0]


 88%|████████▊ | 130/148 [24:58<03:00, 10.04s/it]

[' No'] [0.0]


 89%|████████▊ | 131/148 [25:07<02:41,  9.48s/it]

[' No'] [0.0]


 89%|████████▉ | 132/148 [25:23<03:05, 11.60s/it]

[' No'] [0.0]


 90%|████████▉ | 133/148 [25:40<03:15, 13.05s/it]

[' No'] [0.0]


 91%|█████████ | 134/148 [25:49<02:46, 11.92s/it]

['Not Clinical trial toxic'] [0.0]


 91%|█████████ | 135/148 [25:57<02:19, 10.76s/it]

[' No'] [0.0]


 92%|█████████▏| 136/148 [26:15<02:36, 13.02s/it]

[' No'] [0.0]


 93%|█████████▎| 137/148 [26:23<02:07, 11.57s/it]

[' No'] [0.0]


 93%|█████████▎| 138/148 [26:33<01:48, 10.85s/it]

['Not Clinical trial toxic'] [0.0]


 94%|█████████▍| 139/148 [26:49<01:52, 12.51s/it]

[' No'] [0.0]


 95%|█████████▍| 140/148 [27:02<01:40, 12.62s/it]

['Not clin'] [0.0]


 95%|█████████▌| 141/148 [27:12<01:23, 11.88s/it]

[' No'] [0.0]


 96%|█████████▌| 142/148 [27:28<01:19, 13.19s/it]

[' No'] [0.0]


 97%|█████████▋| 143/148 [27:45<01:10, 14.20s/it]

[' No'] [0.0]


 97%|█████████▋| 144/148 [27:53<00:50, 12.54s/it]

['Not clin'] [0.0]


 98%|█████████▊| 145/148 [28:02<00:33, 11.26s/it]

[' No'] [0.0]


 99%|█████████▊| 146/148 [28:18<00:25, 12.77s/it]

[' No'] [0.0]


 99%|█████████▉| 147/148 [28:36<00:14, 14.42s/it]

[' No'] [0.0]


100%|██████████| 148/148 [28:44<00:00, 11.65s/it]

['Not Clinical trial toxic'] [0.0]
[' No', 'Not clin', ' No', ' No', ' No', ' No', 'Not', ' Yes', 'Not', ' No', ' No', ' No', 'Not Clinical trial toxic', 'Not', 'Not Clinical trial toxic', 'Not', ' Yes', ' No', ' No', ' No', ' No', ' No', ' No', 'Not', 'Not clin', ' No', ' No', ' No', ' Yes', 'Not', ' No', ' No', ' No', ' No', 'Not Clinical trial toxic', ' No', ' Yes', ' No', ' Yes', ' No', ' No', ' No', ' No', ' No', ' Yes', ' No', 'Not clin', 'Not Yes', ' Yes', ' No', 'Not Clinical trial toxic', ' No', 'Not', ' No', 'Not Clinical trial toxic', 'Not', ' No', ' Yes', 'Not', 'Not clin', 'Not clin', 'Not', ' No', 'Not', 'Not', ' Yes', 'Not Clinical trial toxic', ' No', 'Not', ' No', 'Not clin', ' No', ' No', 'Not Clinical trial toxic', ' No', ' No', ' No', ' No', ' No', 'Not', ' Yes', ' No', ' No', ' No', ' No', 'Not clin', 'Not', 'Not Clinical trial toxic', 'Not clin', 'Not Clinical trial toxic', ' No', 'Not Clinical trial toxic', ' No', ' No', ' No', ' No', 'Not clin', ' No', 'Not', ' 




##### sample_num 4

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='rdk_fp', 
                                     sample_mode='cosine_similarity', 
                                     sample_num=4)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

  0%|          | 0/148 [00:00<?, ?it/s]

##### sample_num 8

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='rdk_fp', 
                                     sample_mode='cosine_similarity', 
                                     sample_num=8)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

#### sample_molecule_format morgan_fp

##### sample_num 2

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='morgan_fp', 
                                     sample_mode='cosine_similarity', 
                                     sample_num=2)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

##### sample_num 4

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='morgan_fp', 
                                     sample_mode='cosine_similarity', 
                                     sample_num=4)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

##### sample_num 8

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='morgan_fp', 
                                     sample_mode='cosine_similarity', 
                                     sample_num=8)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

### sample_mode tanimoto_similarity

#### sample_molecule_format maccs_fp

##### sample_num 2

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='maccs_fp', 
                                     sample_mode='tanimoto_similarity', 
                                     sample_num=2)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

##### sample_num 4

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='maccs_fp', 
                                     sample_mode='tanimoto_similarity', 
                                     sample_num=4)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

##### sample_num 8

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='maccs_fp', 
                                     sample_mode='tanimoto_similarity', 
                                     sample_num=8)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

#### sample_molecule_format rdk_fp

##### sample_num 2

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='rdk_fp', 
                                     sample_mode='tanimoto_similarity', 
                                     sample_num=2)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

##### sample_num 4

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='rdk_fp', 
                                     sample_mode='tanimoto_similarity', 
                                     sample_num=4)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

##### sample_num 8

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='rdk_fp', 
                                     sample_mode='tanimoto_similarity', 
                                     sample_num=8)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

#### sample_molecule_format morgan_fp

##### sample_num 2

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='morgan_fp', 
                                     sample_mode='tanimoto_similarity', 
                                     sample_num=2)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

##### sample_num 4

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='morgan_fp', 
                                     sample_mode='tanimoto_similarity', 
                                     sample_num=4)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)

##### sample_num 8

In [None]:
few_shot_prompter = FewShotPrompter1(template=few_shot_prompt, sample_dataset=train_dataset, 
                                     sample_molecule_format='morgan_fp', 
                                     sample_mode='tanimoto_similarity', 
                                     sample_num=8)

main(dataset=test_dataset, batch_size=batch_size, total_batches=total_batches, model=m, prompt_generator=few_shot_prompter.generate_prompt)