In [None]:
import argparse
import os
import sys
import warnings

import numpy as np
import torch
from tqdm import tqdm

# Make the project root importable. `__file__` is undefined inside a notebook, so the
# current working directory (expected to be a direct sub-folder of the project root)
# stands in for the notebook's location.
current_dir = os.getcwd()
project_root = os.path.dirname(current_dir)
# Put the root first on sys.path without piling up duplicate entries when this cell
# is re-run.
sys.path[:] = [path for path in sys.path if path != project_root]
sys.path.insert(0, project_root)

from utils.util import *
# Inside Jupyter, sys.argv holds the kernel launcher's arguments (e.g. `-f <connection file>`).
# The project's argument parsers presumably read sys.argv, so reset it to a bare program
# name; they then return their defaults, which later cells override field by field.
sys.argv = ['run.py']

CUSTOM_PRETRAINED_DIR = r"/cpfs01/projects-HDD/cfff-bb5d866c17c2_HDD/taoyuhui/RenalCLIP/clip_output"
from models.RenalModel import RenalModel
from utils.parser import get_downstream_args_img
from utils.data_util import custom_collate_fn_downstream_img
from datasets.data_loader_RenalCLIP_downstream_img import DatasetRenalCLIPDownstreamImg

In [None]:
def get_image_encoder(args):
    """Construct the RenalCLIP image encoder in ``'pretrain'`` mode.

    Args:
        args: Parsed configuration. The attributes read are ``pretrained_exp_name``,
            ``pretrained_metric``, ``model_type``, ``clip_output_dim`` and
            ``clip_hidden_dim``.

    Returns:
        The freshly built ``RenalModel``. It is not moved to a device or switched to
        eval mode here; the caller does that.
    """
    # NOTE(review): `model_type` is forwarded but `img_encoder_type` is not -- confirm
    # that this is the attribute RenalModel expects.
    encoder_kwargs = dict(
        mode='pretrain',
        logger=None,
        pretrained_exp_name=args.pretrained_exp_name,
        pretrained_metric=args.pretrained_metric,
        model_type=args.model_type,
        clip_output_dim=args.clip_output_dim,
        clip_hidden_dim=args.clip_hidden_dim,
    )
    return RenalModel(**encoder_kwargs)

# Offline step: compute and save the image embeddings produced by our model (RenalCLIP)

In [None]:
# --- Configuration: override the parser defaults for the RenalCLIP image encoder ---
args = get_downstream_args_img()
args.img_encoder_type = 'cnn'
args.pretrained_exp_name = 'RenalCLIP-image-encoder'

args.pretrained_metric = 'acc'
args.pre_processed_type = 'one_kidney_tc_v3'
args.split_file_name = '/cpfs01/projects-HDD/cfff-bb5d866c17c2_HDD/taoyuhui/RenalCLIP/RenalCLIP/data/data_split_latest.json'
args.cropsize_3d = 128
args.crop_slices = 32
args.in_features = 512
args.clip_output_dim = 4096

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_encoder = get_image_encoder(args)
image_encoder = image_encoder.to(device)
image_encoder.eval()  # evaluation mode: this cell only runs inference

output_base_dir = "/cpfs01/projects-SSD/cfff-bb5d866c17c2_SSD/public/RenalCLIP/image_embeddings_demo"
model_name = 'ours'
hospitals = ["internal", "external", "TCIA"]


def _build_test_loader(args, hospital):
    """Return an ordered DataLoader over the test split of one hospital cohort."""
    dataset = DatasetRenalCLIPDownstreamImg(args, split='test', hospital=hospital)
    return torch.utils.data.DataLoader(
        dataset,
        # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU-only machines.
        pin_memory=torch.cuda.is_available(),
        shuffle=False,
        drop_last=False,  # every test patient must get an embedding
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        collate_fn=custom_collate_fn_downstream_img(args),
    )


def _encode_images(image_encoder, images):
    """Run the encoder and its global-embedding head, then L2-normalize along the last dim."""
    img_feats = image_encoder(images).float()
    img_feats = image_encoder.image_encoder.global_embedding(img_feats)
    # Unlike x / x.norm(), F.normalize clamps the norm at eps, so a zero vector
    # gives zeros instead of NaNs; all other vectors are unchanged.
    return torch.nn.functional.normalize(img_feats, dim=-1)


def _save_patient_embeddings(patient_ids, embeddings, out_dir, saved_paths):
    """Write each embedding to ``out_dir/<patient_id>/image_embedding.npy``.

    Args:
        patient_ids: One identifier per row of ``embeddings``.
        embeddings: Array-like with one embedding per patient (indexed along axis 0).
        out_dir: Root directory for this model's embeddings.
        saved_paths: Set of files already written in this run; updated in place.
            The hospital is not part of the output path, so a patient ID that appears
            in more than one cohort would silently overwrite an earlier file. That
            case is reported with a warning.

    Raises:
        ValueError: If the number of patient IDs differs from the number of embeddings.
    """
    if len(patient_ids) != len(embeddings):
        raise ValueError(
            f"Got {len(patient_ids)} patient IDs for {len(embeddings)} embeddings"
        )
    for patient_id, embedding in zip(patient_ids, embeddings):
        patient_output_dir = os.path.join(out_dir, str(patient_id))
        os.makedirs(patient_output_dir, exist_ok=True)
        save_path = os.path.join(patient_output_dir, 'image_embedding.npy')
        if save_path in saved_paths:
            warnings.warn(f"Overwriting embedding already saved in this run: {save_path}")
        saved_paths.add(save_path)
        np.save(save_path, embedding)


print(f"Saving image embeddings for model '{model_name}' to: {output_base_dir}")

model_output_dir = os.path.join(output_base_dir, model_name)
saved_paths = set()
with torch.no_grad():
    for hospital in hospitals:
        dataloader = _build_test_loader(args, hospital)
        progress = tqdm(
            dataloader,
            desc="Processing and Saving Image Features",
            postfix={'hospital': hospital},  # tell the three cohorts' bars apart
        )
        for batch in progress:
            images = batch['imgs'].to(device)
            img_feats_np = _encode_images(image_encoder, images).cpu().numpy()
            _save_patient_embeddings(batch['patients'], img_feats_np, model_output_dir, saved_paths)

print("所有图片特征已成功保存到磁盘！")
print(f"特征保存根目录: {output_base_dir}/{model_name}/")