In [1]:
# Unsloth must be imported before transformers/peft so its patches are applied.
import unsloth
from unsloth import FastLanguageModel, is_bfloat16_supported

import pandas as pd
from datasets import load_dataset
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments


Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth import FastLanguageModel, is_bfloat16_supported


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [2]:
# Generation model: 4-bit Llama-3.2-3B-Instruct with the SFT "generation" LoRA adapter on top.
sft_adapter_path = "./outputs/sft_lora_adapter_generation"

g_model, g_tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit",
    max_seq_length=1024,
    load_in_4bit=True,
    dtype=None,  # presumably lets Unsloth auto-select the dtype — confirm against Unsloth docs
)

# NOTE: PEFT injects the adapter's LoRA layers into g_model itself, so after this
# line g_model no longer behaves like the untouched base model.
peft_g_model = PeftModel.from_pretrained(g_model, sft_adapter_path)

==((====))==  Unsloth 2025.6.4: Fast Llama patching. Transformers: 4.52.4.
   \\   /|    NVIDIA GeForce RTX 3090. Num GPUs = 1. Max memory: 23.684 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/2.35G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/234 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/54.7k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/3.83k [00:00<?, ?B/s]

In [3]:
from transformers import pipeline

# Text-generation pipeline around the generation adapter. return_full_text=False
# makes the returned 'generated_text' contain only the new tokens, not the prompt.
pipe = pipeline(
    "text-generation",
    model=peft_g_model,
    tokenizer=g_tokenizer,
    max_new_tokens=500,
    return_full_text=False,
)

Device set to use cuda:0


In [4]:
# Summarization model: a separate copy of the base model carrying the SFT
# "summarization" LoRA adapter.
sft_adapter_path2 = "./outputs/sft_lora_adapter_summarization"

s_model, s_tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit",
    max_seq_length=1024,
    load_in_4bit=True,
    dtype=None,
)
# Must load the summarization adapter (path2). Loading `sft_adapter_path` here
# would make this pipeline a second copy of the generation model.
peft_s_model = PeftModel.from_pretrained(s_model, sft_adapter_path2)

# return_full_text=False: 'generated_text' holds only the newly generated tokens.
pipe2 = pipeline(
    "text-generation",
    model=peft_s_model,
    tokenizer=s_tokenizer,
    max_new_tokens=500,
    return_full_text=False,
)

==((====))==  Unsloth 2025.6.4: Fast Llama patching. Transformers: 4.52.4.
   \\   /|    NVIDIA GeForce RTX 3090. Num GPUs = 1. Max memory: 23.684 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Device set to use cuda:0


In [5]:
# Evaluation (judge) model: base model plus the SFT "evaluation" LoRA adapter.
sft_adapter_path3 = "./outputs/sft_lora_adapter_evaluation"

e_model, e_tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit",
    max_seq_length=1024,
    load_in_4bit=True,
    dtype=None,
)
peft_e_model = PeftModel.from_pretrained(e_model, sft_adapter_path3)

# Judge pipeline; like the others it returns only newly generated text.
pipe3 = pipeline(
    "text-generation",
    model=peft_e_model,
    tokenizer=e_tokenizer,
    max_new_tokens=500,
    return_full_text=False,
)

==((====))==  Unsloth 2025.6.4: Fast Llama patching. Transformers: 4.52.4.
   \\   /|    NVIDIA GeForce RTX 3090. Num GPUs = 1. Max memory: 23.684 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 8.6. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Device set to use cuda:0


In [6]:
import requests
from bs4 import BeautifulSoup

In [7]:
import torch
import re

In [8]:
def get_uniprot_id(gene, timeout=30):
    """Look up the UniProt accession of the first human (organism_id 9606) hit for a gene.

    Args:
        gene: Gene symbol, e.g. "ACTA2". Surrounding whitespace is ignored, so
            " ACTA2" (as produced by splitting "TAGLN, ACTA2" on ",") works.
        timeout: Seconds to wait for the UniProt REST API.

    Returns:
        The ``primaryAccession`` of the first search result, or None when the
        symbol is empty or the search yields no results.
    """
    gene = (gene or "").strip()
    if not gene:
        # An empty symbol would produce the malformed query "gene_exact: AND ...".
        return None

    params = {
        "query": f"gene_exact:{gene} AND organism_id:9606",
        "fields": "accession,gene_primary,organism_name",
        "format": "json",
        "size": 1,
    }
    # Without a timeout a stalled connection would block the notebook indefinitely.
    r = requests.get("https://rest.uniprot.org/uniprotkb/search", params=params, timeout=timeout)
    results = r.json().get("results")
    return results[0]["primaryAccession"] if results else None


def search_gene(gene_name, timeout=30):
    """Fetch the full UniProtKB JSON entry for a human gene symbol.

    Args:
        gene_name: Gene symbol passed to ``get_uniprot_id``.
        timeout: Seconds to wait for the entry request.

    Returns:
        The parsed JSON of the UniProt entry, or None when no accession was
        found for the symbol.
    """
    uniprot_id = get_uniprot_id(gene_name)
    if uniprot_id is None:
        # Previously this went on to request ".../uniprotkb/None" and returned
        # whatever came back, so callers' `is None` checks could never trigger.
        return None

    r = requests.get(
        f"https://rest.uniprot.org/uniprotkb/{uniprot_id}",
        headers={"Accept": "application/json"},
        timeout=timeout,
    )
    return r.json()

def construct_context(markers, max_comments=11):
    """Build the retrieval context describing each marker gene from UniProt.

    For every marker a header line "Marker gene i: <gene>" is written. When
    ``search_gene`` returns a usable entry, this adds a "This gene is <name>" line
    plus the first text of each of the first ``max_comments`` comment entries
    that carry text. Each marker section ends with a blank line.

    Args:
        markers: Sequence of gene symbols.
        max_comments: Number of UniProt comment entries scanned per gene
            (the original hard-coded cutoff scanned 11).

    Returns:
        The context string that is interpolated into the downstream prompts.
    """
    context_rag = "This is the context that you should consider to identify the cell type\n\n"
    for position, gene in enumerate(markers, start=1):
        context_rag += f"Marker gene {position}: {gene}\n"
        entry = search_gene(gene)
        if entry is not None:
            context_rag += _describe_uniprot_entry(entry, max_comments)
        context_rag += "\n"
    return context_rag


def _describe_uniprot_entry(entry, max_comments):
    """Render one UniProt entry as a name line plus comment texts; '' if it lacks the needed fields."""
    if not isinstance(entry, dict):
        return ""
    description = entry.get("proteinDescription")
    comments = entry.get("comments")
    if description is None or comments is None:
        # Same skip as before: an entry missing either field contributes nothing.
        return ""

    lines = []
    if "recommendedName" in description:
        name_entry = description["recommendedName"]
    else:
        submitted = description.get("submissionNames") or []
        name_entry = submitted[0] if submitted else None
    full_name = ((name_entry or {}).get("fullName") or {}).get("value")
    if full_name:
        lines.append(f"This gene is {full_name}")

    for comment in comments[:max_comments]:
        texts = comment.get("texts") if isinstance(comment, dict) else None
        if texts:
            value = texts[0].get("value")
            if value is not None:
                lines.append(value)

    return "".join(f"{line}\n" for line in lines)

def SuRe_Generate(llm, tokenizer, markers, context_rag, temperature=0.9):
    """Sample two independent candidate cell-type answers from the generation pipeline.

    Args:
        llm: Text-generation pipeline, called as ``llm(prompt, **kwargs)`` and
            indexed as ``[0]['generated_text']``.
        tokenizer: Unused; kept so existing call sites keep working.
        markers: Marker genes, interpolated verbatim into the prompt.
        context_rag: Retrieved context from ``construct_context``.
        temperature: Sampling temperature used for both candidates.

    Returns:
        Tuple ``(result1, result2)`` of the raw generated texts.
    """
    input_prompt = f"""You are an intelligent expert to identify the most appropriate cell type from the given marker genes.
    You will be given marker genes and context that you should consider.

    **Do NOT explain your answer. You must only return a single cell type and then finish your answer.**

    Given the expression of genes {markers}, identify the most appropriate cell type.

    ### Context:
    {context_rag}

    Answer: 
    """

    # do_sample=True is explicit: temperature only applies when sampling, and
    # without sampling both candidates would be the same greedy decode, which
    # defeats the two-candidate comparison.
    result1, result2 = (
        llm(
            input_prompt,
            temperature=temperature,
            do_sample=True,
            no_repeat_ngram_size=2,
            return_full_text=False,
        )[0]['generated_text']
        for _ in range(2)
    )
    return result1, result2


def SuRe_Summarize(llm, tokenizer, markers, context_rag, ans1, ans2, temperature=0.9):
    """Ask the summarization pipeline how well each candidate answer is supported by the context.

    Args:
        llm: Summarization text-generation pipeline.
        tokenizer: Unused; kept so existing call sites keep working.
        markers: Marker genes, interpolated verbatim into the prompt.
        context_rag: Retrieved context used as the supporting text.
        ans1, ans2: The two candidate cell types.
        temperature: Sampling temperature for the summaries.

    Returns:
        Tuple ``(summary1, summary2)``, one generated justification per answer.
    """
    summaries = []
    for answer in (ans1, ans2):
        summary_prompt = f"""You are an intelligent summarizer to check whether the given answer is proper for the given question.
    
    You should produce a summary that shows how well the given answer aligns with the given supporting text.


    ### Question:
    Given the expression of genes {markers}, identify the most likely cell type.

    ### Answer:

    {answer}

    ### Supporting Text:

    {context_rag}

    ### Your output should:

    1. Explain which parts of the given supporting text the answer draws on.
    2. Never be too long
    """

        # Sampling is requested explicitly because `temperature` has no effect otherwise.
        output = llm(summary_prompt, temperature=temperature, do_sample=True, return_full_text=False)
        summaries.append(output[0]['generated_text'])

    return summaries[0], summaries[1]

def get_score(response):
    """Extract the number following "score:" or "score is" (case-insensitive).

    The old pattern captured one digit only, so "score: 10" became "1". Its
    fixed-offset slice also turned "score is 7" into "is 7".

    Args:
        response: Text produced by the judge model.

    Returns:
        The score digits as a string (callers convert with ``int``), or None
        when no labelled score is present.
    """
    if not response:
        return None
    match = re.search(r"score(?:\s+is\b\s*:?|\s*:)\s*(\d+)", response, re.IGNORECASE)
    return match.group(1) if match else None


def SuRe_Validation(eval_llm, tokenizer, summary1, summary2, markers, ans1, ans2, max_retries=10, temperature=0.9):
    """Have the judge pipeline score each (answer, justification) pair 0-10 and return the better answer.

    Args:
        eval_llm: Evaluation (judge) text-generation pipeline.
        tokenizer: Unused; kept so existing call sites keep working.
        summary1, summary2: Justifications from ``SuRe_Summarize``.
        markers: Marker genes, interpolated verbatim into the prompt.
        ans1, ans2: The two candidate cell types.
        max_retries: Extra judge calls per candidate when no score can be parsed
            (after which that candidate scores 0).
        temperature: Sampling temperature for the judge.

    Returns:
        The answer with the higher score; ``ans1`` on a tie.
    """
    answers = [ans1, ans2]
    summaries = [summary1, summary2]
    scores = [0, 0]

    for i in range(2):
        eval_prompt = f"""You are an impartial judge evaluating the answer for the given question.

        Your task:

        1. Read the given question, answer, and its justification.
        2. Assign the consistency score from 0 to 10 (0 = not align at all, 10 = fully supported) based on how accurately the answer reflects the evidence.

        ** Note that you must not explain your answer **
        ** RETURN a only single number **

        ### Question:
        
        Given the expression of genes {markers}, identify the most likely cell type.

        ### Answer:

        {answers[i]}

        ### Justification:

        {summaries[i]}

        score: 
        """
        scores[i] = _judge_score(eval_llm, eval_prompt, max_retries, temperature)

    # max() returns the first maximal index, so a tie keeps ans1.
    index = max(range(len(scores)), key=lambda k: scores[k])
    return answers[index]


def _judge_score(eval_llm, eval_prompt, max_retries, temperature):
    """Return the judge's 0-10 score for one prompt, or 0 if none parses within max_retries retries."""
    for _ in range(max_retries + 1):
        # With return_full_text=False the output never contains the prompt, so it
        # is parsed as-is. The old `[len(eval_prompt):]` slice usually left an
        # empty string, making every score fall back to 0.
        response = eval_llm(
            eval_prompt, temperature=temperature, do_sample=True, return_full_text=False
        )[0]['generated_text'].strip()

        parsed = get_score(response)
        if parsed is None or not str(parsed).strip().isdecimal():
            # The prompt already ends with "score:" and asks for a single number,
            # so the reply may be just that number without a "score" label.
            leading = re.match(r"\W*(\d+)", response)
            parsed = leading.group(1) if leading else None

        if parsed is not None:
            value = int(parsed)
            if 0 <= value <= 10:
                return value
    return 0

def extract_answer(response):
    """Pull the cell type out of a generation such as "The answer is Smooth muscle."

    Matches "answer:", "answer is" or "answer is:" (case-insensitive) and takes
    the text up to the next . , ; ! ? or newline, or to the end of the text. The
    previous pattern required a terminator, so an answer ending the text with no
    punctuation was never found.

    Args:
        response: Raw generated text.

    Returns:
        The stripped answer, or None when no non-empty answer phrase is present.
    """
    if not response:
        return None
    match = re.search(
        r"answer(?:\s+is\b\s*:?|\s*:)\s*([^.,;!?\n]+)",
        response,
        re.IGNORECASE,
    )
    if match is None:
        return None
    answer = match.group(1).strip()
    return answer or None

def generate_response(llm, tokenizer, s_llm, tokenizer2, eval_llm, tokenizer3, markers, max_attempts=7):
    """Run the SuRe pipeline (generate, summarize, validate) for one marker list.

    Args:
        llm, tokenizer: Generation pipeline and its tokenizer.
        s_llm, tokenizer2: Summarization pipeline and its tokenizer.
        eval_llm, tokenizer3: Evaluation (judge) pipeline and its tokenizer.
        markers: Comma-separated gene symbols, e.g. "TAGLN, ACTA2, MYL9".
        max_attempts: Generation rounds to try when neither candidate contains
            a parsable answer (the original counter allowed 7).

    Returns:
        The selected cell type, or "Sorry, I don't know" if every attempt
        failed to yield an answer.
    """
    # Strip each symbol so "TAGLN, ACTA2" queries UniProt for "ACTA2", not " ACTA2".
    marker_genes = [gene.strip() for gene in markers.split(",") if gene.strip()]
    context_rag = construct_context(marker_genes)

    for _ in range(max_attempts):
        raw1, raw2 = SuRe_Generate(llm, tokenizer, markers, context_rag)
        answer1 = extract_answer(raw1)
        answer2 = extract_answer(raw2)

        if answer1 is None and answer2 is None:
            continue
        if answer1 is None:
            return answer2
        if answer2 is None:
            return answer1
        if answer1 == answer2:
            # Validation could only return this same string; skip 4 LLM calls.
            return answer1

        summary1, summary2 = SuRe_Summarize(s_llm, tokenizer2, markers, context_rag, answer1, answer2)
        return SuRe_Validation(eval_llm, tokenizer3, summary1, summary2, markers, answer1, answer2)

    return "Sorry, I don't know"

In [17]:
# Smoke test on one marker set (smooth-muscle-like markers).
markers = "TAGLN, ACTA2, MYL9, IGFBP5, MCAM, CALD1, TPM2, NOTCH3, IGFBP7, MAP1B"

predicted_cell_type = generate_response(
    pipe, g_tokenizer,
    pipe2, s_tokenizer,
    pipe3, e_tokenizer,
    markers,
)
print("Fin: " + predicted_cell_type)


Fin: Smooth muscle


In [124]:
import json

In [None]:
from itertools import islice

N_EVAL = 50  # number of test examples run through the SuRe pipeline

# Read the test set up front (explicit UTF-8 rather than the locale default) so the
# input file is closed before the long generation loop starts.
with open("test_data.json", encoding="utf-8") as f:
    data = json.load(f)

test_list = []
for idx, item in enumerate(islice(data, N_EVAL)):
    print(idx)
    res = generate_response(pipe, g_tokenizer, pipe2, s_tokenizer, pipe3, e_tokenizer, item['Instruction'])
    # One {item['Output'] (presumably the reference label): prediction} record per example.
    test_list.append({item['Output']: res})

with open("output.json", "w", encoding="utf-8") as f:
    json.dump(test_list, f, ensure_ascii=False, indent=2)


        
        
        

In [128]:
# Baseline: the plain base model with NO LoRA adapter.
# g_model cannot be reused here. PeftModel.from_pretrained(g_model, ...) injected the
# generation adapter's LoRA layers into g_model in place, so a pipeline over g_model
# would still apply the fine-tuned adapter. Load a fresh, unmodified copy instead.
base_model, base_tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.2-3B-Instruct-unsloth-bnb-4bit",
    max_seq_length=1024,
    load_in_4bit=True,
    dtype=None,
)

pipe_base = pipeline(
    "text-generation",
    model=base_model,
    tokenizer=base_tokenizer,
    max_new_tokens=500,
    return_full_text=False,
)

Device set to use cuda:0


In [None]:

from itertools import islice

N_EVAL = 50  # number of test examples for the baseline run

# Explicit UTF-8; input file is closed before the generation loop.
with open("test_data.json", encoding="utf-8") as f:
    data = json.load(f)

test_list = []
for idx, item in enumerate(islice(data, N_EVAL)):
    print(idx)

    prompt_test = f""" Given the expression of genes {item['Instruction']}, identify the most likely cell type.

        Answer: 
        """
    # return_full_text=False already excludes the prompt from 'generated_text'. The
    # old `res[len(prompt_test):]` slice discarded the first len(prompt_test)
    # characters of every baseline answer.
    res = pipe_base(prompt_test, temperature=0.9, return_full_text=False)[0]['generated_text'].strip()
    test_list.append({item['Output']: res})

with open("output2.json", "w", encoding="utf-8") as f:
    json.dump(test_list, f, ensure_ascii=False, indent=2)