# Offline: expand the prompt templates for each zero-shot task and save them

In [None]:
import os
import json

ZEROSHOT_BASE_DIR = '/cpfs01/projects-HDD/cfff-bb5d866c17c2_HDD/taoyuhui/RenalCLIP/RenalCLIP/zero_shot'

# All templates and per-label attribute candidates are stored in
# zeroshot_prompt.json under the "templates" and "attributes" keys.
json_file_path = os.path.join(ZEROSHOT_BASE_DIR, 'zeroshot_prompt.json')

# On a missing or malformed input file, stop with the error message and a
# non-zero exit status. A bare `exit()` is an interactive-shell helper from the
# `site` module: without an argument it exits with status 0 (i.e. "success"),
# and in an IPython/Jupyter session it may shut down the kernel rather than
# just aborting this cell. Chaining with `from err` keeps the original cause.
try:
    with open(json_file_path, 'r', encoding='utf-8') as json_file:
        data = json.load(json_file)
except FileNotFoundError as err:
    raise SystemExit(
        f"Error: JSON file not found at {json_file_path}. Please ensure it exists."
    ) from err
except json.JSONDecodeError as err:
    raise SystemExit(
        f"Error: Could not decode JSON from {json_file_path}. Check file format."
    ) from err

templates = data["templates"]
attributes = data["attributes"]

# Build the task -> label -> prompt-list mapping:
# {
#   "TASK_NAME": {
#     "LABEL_VALUE": [
#       "template1_desc1", "template1_desc2", "template1_desc3",
#       "template2_desc1", "template2_desc2", "template2_desc3",
#       ...
#     ]
#   }
# }
# Each prompt is one template with every "____" placeholder replaced by one of
# the label's description candidates. Templates vary slowest and descriptions
# fastest, so each label gets len(templates) * len(descriptions) prompts.
result_dict = {}
for task, labels in attributes.items():  # e.g. "IC", "BMC"
    result_dict[task] = {
        label: [
            template.replace("____", description)
            for template in templates
            for description in descriptions_list
        ]
        for label, descriptions_list in labels.items()  # e.g. "0", "1"
    }

# Save the expanded prompts (UTF-8, non-ASCII kept as-is) for the embedding step.
output_file = os.path.join(ZEROSHOT_BASE_DIR, 'expanded_zeroshot_prompts.json')
with open(output_file, 'w', encoding='utf-8') as outfile:
    json.dump(result_dict, outfile, indent=4, ensure_ascii=False)
# Message reads: "Expanded prompts successfully saved to <output_file>".
print(f"Expanded prompts 已成功保存到 {output_file}")

# Offline: compute and save LLM2Vec embeddings of the expanded zero-shot prompts

In [None]:
import os
import json
import pandas as pd
import random
import numpy as np
import torch
from llm2vec import LLM2Vec
from tqdm import tqdm
import sys

# Add the parent directory to the import search path; presumably so that
# project-local modules one level up can be imported by later cells — confirm.
sys.path.append("..")

In [None]:
# Load the LLM2Vec text encoder from local disk only: a Llama-3-8B-Instruct
# base model plus a PEFT adapter checkpoint (radiology-tuned, per the
# directory names).
TEXT_PRETRAINED_DIR = "/cpfs01/projects-HDD/cfff-bb5d866c17c2_HDD/taoyuhui/RenalCLIP/pretrained_models/language_family"
llm2vec_base_name = "hub/Meta-Llama-3-8B-Instruct-radiology-ext-long"
llm2vec_peft_name = "hub/Meta-Llama-3-8B-Instruct-radiology-simcse/checkpoint-1000"

base_model_path = os.path.join(TEXT_PRETRAINED_DIR, llm2vec_base_name)
peft_model_path = os.path.join(TEXT_PRETRAINED_DIR, llm2vec_peft_name)

# Use the GPU when one is visible, otherwise fall back to CPU; weights are
# loaded in bfloat16 in both cases.
llm2vec_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

l2v = LLM2Vec.from_pretrained(
    base_model_path,
    peft_model_name_or_path=peft_model_path,
    device_map=llm2vec_device,
    torch_dtype=torch.bfloat16,
    local_files_only=True,  # never reach out to the Hugging Face Hub
    pooling_mode="mean",
    max_length=224,  # presumably the token limit per input text — confirm in LLM2Vec docs
)

In [None]:
# Embedding layout to write:
#   True  -> one L2-normalized mean embedding per label, array shape (1, dim),
#            saved under "templates_with_avg"
#   False -> every L2-normalized prompt embedding per label, shape
#            (num_prompts, dim), saved under "templates_wo_avg"
avg_before_align = False

prefix = "templates_with_avg" if avg_before_align else "templates_wo_avg"
ZEROSHOT_BASE_DIR = '/cpfs01/projects-HDD/cfff-bb5d866c17c2_HDD/taoyuhui/RenalCLIP/RenalCLIP/zero_shot'
TEXT_EMBEDDINGS_DIR = "/cpfs01/projects-SSD/cfff-bb5d866c17c2_SSD/public/RenalCLIP/zeroshot_text_embeddings/llm2vec-rad"
os.makedirs(TEXT_EMBEDDINGS_DIR, exist_ok=True)

# Load the expanded zero-shot prompts (task -> label -> list of prompt strings).
# The file is written as UTF-8 with ensure_ascii=False, so read it back as UTF-8
# explicitly instead of relying on the platform's default encoding.
zeroshot_file_path = os.path.join(ZEROSHOT_BASE_DIR, 'expanded_zeroshot_prompts.json')
with open(zeroshot_file_path, 'r', encoding='utf-8') as json_file:
    prompt_data = json.load(json_file)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    for task_name, labels in prompt_data.items():
        print(f"Processing task: {task_name}")
        task_dir = os.path.join(TEXT_EMBEDDINGS_DIR, prefix, task_name)
        os.makedirs(task_dir, exist_ok=True)

        for label, texts in labels.items():
            # Encode all prompts for this label and L2-normalize each row.
            # F.normalize divides by max(norm, eps), so a degenerate all-zero
            # row gives zeros instead of NaNs.
            template_embeddings = l2v.encode(texts, show_progress_bar=False).float()  # (num_prompts, dim)
            template_embeddings = torch.nn.functional.normalize(template_embeddings, dim=-1)

            if avg_before_align:
                # Average the unit vectors, then re-normalize so the single
                # stored class embedding is itself unit length.
                template_embeddings = torch.nn.functional.normalize(
                    template_embeddings.mean(dim=0, keepdim=True), dim=-1
                )  # (1, dim)

            # .cpu() is a no-op for CPU tensors and keeps .numpy() valid if the
            # encoder ever returns a GPU tensor.
            embedding_file = os.path.join(task_dir, f"{label}.npy")
            np.save(embedding_file, template_embeddings.cpu().numpy())

    print("All embeddings have been processed and saved.")