In [1]:
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from PIL import Image
import models_vit as models
from huggingface_hub import hf_hub_download
np.set_printoptions(threshold=np.inf)  # print whole arrays instead of the "..." summarised form
np.random.seed(1)  # seed NumPy's global RNG for reproducible runs
torch.manual_seed(1)  # seed PyTorch's RNG; the returned Generator is what this cell displays

  from .autonotebook import tqdm as notebook_tqdm


<torch._C.Generator at 0x7f1404212250>

In [2]:
def prepare_model(chkpt_dir, arch='RETFound_mae', num_classes=2, img_size=224):
    """Build a ``models_vit`` architecture and load weights from a checkpoint file.

    Args:
        chkpt_dir: Path to the checkpoint *file* (despite the name, not a directory).
        arch: Name of a constructor in ``models_vit``. For ``'RETFound_mae'`` the
            checkpoint's ``'model'`` entry is loaded; for any other arch the
            ``'teacher'`` entry is loaded.
        num_classes: Size of the classification head (default 2, as before).
        img_size: Input resolution passed to the ``RETFound_mae`` constructor
            (default 224, as before).

    Returns:
        The constructed model with weights loaded. It is not moved to a device
        and not switched to eval mode; the caller does that.
    """
    # NOTE(review): torch >= 2.6 defaults ``weights_only=True`` in torch.load, which
    # can reject checkpoints that also pickle non-tensor objects (e.g. training args).
    # If loading fails there, pass weights_only=False for this trusted local file.
    checkpoint = torch.load(chkpt_dir, map_location='cpu')

    if arch == 'RETFound_mae':
        model = models.__dict__[arch](
            img_size=img_size,
            num_classes=num_classes,
            drop_path_rate=0,
            global_pool=True,
        )
        state_dict = checkpoint['model']
    else:
        model = models.__dict__[arch](
            num_classes=num_classes,
            drop_path_rate=0,
            args=None,
        )
        state_dict = checkpoint['teacher']

    # strict=False tolerates mismatched keys silently; report them so a wrong
    # checkpoint/arch pairing is visible instead of yielding a partly-initialised model.
    msg = model.load_state_dict(state_dict, strict=False)
    if msg.missing_keys or msg.unexpected_keys:
        print(f"[prepare_model] missing keys: {msg.missing_keys}; "
              f"unexpected keys: {msg.unexpected_keys}")
    return model

def run_one_image(img, model, arch):
    """Extract a latent feature tensor for a single preprocessed image.

    Args:
        img: HWC numpy array. The notebook passes a standardised 224x224x3 float array.
        model: Module exposing ``forward_features``. The input is sent to the
            device of the model's first parameter (CPU if it has no parameters).
        arch: Architecture name. For ``'dinov2_large'`` the tokens after the
            first one are mean-pooled and layer-normalised.

    Returns:
        The latent tensor with size-1 dimensions squeezed out, computed without
        autograd tracking.
    """
    # Use the model's own device rather than a notebook-global ``device`` so the
    # function works no matter which device the caller moved the model to.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    x = torch.tensor(img).unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)
    x = x.to(target_device, non_blocking=True)

    # Inference only: no_grad avoids building an autograd graph for every image.
    with torch.no_grad():
        latent = model.forward_features(x.float())

        if arch == 'dinov2_large':
            latent = latent[:, 1:, :].mean(dim=1, keepdim=True)
            # Same result as a freshly constructed nn.LayerNorm (weight=1, bias=0),
            # without allocating a new module on every call.
            latent = nn.functional.layer_norm(latent, (latent.shape[-1],), eps=1e-6)

    return torch.squeeze(latent)

def run_one_image_for_prediction(img, model, arch):
    """Run the full model (classification head) on one preprocessed image.

    NOTE(review): the original author noted the ``'dinov2_large'`` case has not
    been handled specially here. ``arch`` is currently unused, and ``model(x)``
    is called for every architecture. Confirm this is correct for DINOv2 models.

    Args:
        img: HWC numpy array (same preprocessing as ``run_one_image``).
        model: Callable module returning logits. The input is sent to the device
            of the model's first parameter (CPU if it has no parameters).
        arch: Architecture name (unused, kept for interface symmetry).

    Returns:
        dict with, assuming the model returns (batch, num_classes) logits:
          'score_raw'        - raw logits before softmax, numpy array, usually (1, num_classes)
          'prob_all'         - softmax probabilities, numpy array, usually (1, num_classes)
          'pred_class'       - argmax class index (int)
          'pred_class_prob'  - probability of the predicted class (float)
          'pred_1class_prob' - probability of class 1 (float; needs >= 2 classes)
    """
    # Use the model's own device rather than a notebook-global ``device``.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device('cpu')

    x = torch.einsum('nhwc->nchw', torch.tensor(img).unsqueeze(dim=0))
    x = x.to(target_device, non_blocking=True)

    with torch.no_grad():
        logits = model(x.float())
        probs = torch.softmax(logits, dim=1)
        pred_class = torch.argmax(probs, dim=1).item()
        pred_class_prob = probs[0][pred_class].item()
        class1_prob = probs[0][1].item()

    return {
        'score_raw': logits.cpu().numpy(),
        'prob_all': probs.cpu().numpy(),
        'pred_class': pred_class,
        'pred_class_prob': pred_class_prob,
        'pred_1class_prob': class1_prob,
    }


In [3]:
def _load_and_preprocess(img_path):
    """Load an image as RGB, resize it to 224x224 and standardise each channel."""
    with Image.open(img_path) as pil_img:
        rgb = pil_img.convert('RGB').resize((224, 224))
    img = np.array(rgb) / 255.
    for c in range(3):
        channel = img[..., c]
        std = channel.std()
        # A constant channel (e.g. a blank image) has std 0; skip the scaling
        # instead of dividing by zero and producing NaNs.
        img[..., c] = (channel - channel.mean()) / (std if std > 0 else 1.0)
    assert img.shape == (224, 224, 3)
    return img


def get_feature(data_path,
                chkpt_dir,
                device,
                arch='RETFound_mae'):
    """Extract latent features and classifier predictions for every image in a folder.

    Args:
        data_path: Folder containing the images (not searched recursively).
        chkpt_dir: Checkpoint file passed to ``prepare_model``.
        device: torch device the model is moved to.
        arch: Architecture name passed to ``prepare_model`` and the run helpers.

    Returns:
        7-tuple of parallel lists with one entry per image, ordered by file name:
        (name_list, feature_list, score_raw_list, prob_all_list,
         pred_class_list, pred_class_prob_list, pred_1class_prob_list).
        feature_list holds numpy arrays (typically 1-D); the prediction entries
        mirror the keys returned by ``run_one_image_for_prediction``.
    """
    model_ = prepare_model(chkpt_dir, arch)
    model_.to(device)
    model_.eval()

    # Sort for a deterministic row order (os.listdir order is arbitrary) and skip
    # sub-directories / hidden files such as .DS_Store that PIL cannot open.
    img_list = sorted(
        name for name in os.listdir(data_path)
        if not name.startswith('.') and os.path.isfile(os.path.join(data_path, name))
    )

    name_list, feature_list = [], []
    score_raw_list, prob_all_list, pred_class_list = [], [], []
    pred_class_prob_list, pred_1class_prob_list = [], []

    for finished_num, img_name in enumerate(img_list, start=1):
        if finished_num % 1000 == 0:
            print(str(finished_num)+"finished")

        img = _load_and_preprocess(os.path.join(data_path, img_name))

        latent_feature = run_one_image(img, model_, arch)
        pred_dict = run_one_image_for_prediction(img, model_, arch)

        name_list.append(img_name)
        feature_list.append(latent_feature.detach().cpu().numpy())

        score_raw_list.append(pred_dict['score_raw'])                # numpy array, usually (1, num_classes)
        prob_all_list.append(pred_dict['prob_all'])                  # numpy array, usually (1, num_classes)
        pred_class_list.append(pred_dict['pred_class'])              # int
        pred_class_prob_list.append(pred_dict['pred_class_prob'])    # float
        pred_1class_prob_list.append(pred_dict['pred_1class_prob'])  # float

    return (
        name_list,
        feature_list,
        score_raw_list,
        prob_all_list,
        pred_class_list,
        pred_class_prob_list,
        pred_1class_prob_list
    )


In [None]:
# chkpt_dir = hf_hub_download(repo_id="YukunZhou/RETFound_dinov2_meh", filename="RETFound_dinov2_meh.pth")
chkpt_dir = "/mnt/d/3.dlProject/bdrv/output/LEO/1.RETFound_mae_natureOCT0.cwflmix-bdrv/output_dir/1.RETFound_mae_natureOCT0.cwflmix-bdrv/checkpoint-best.pth"
# Write the features next to the checkpoint so the two paths cannot drift apart.
output_path = os.path.join(os.path.dirname(chkpt_dir), "feature.csv")
data_path = "/mnt/d/3.dlProject/bdrv/data/retinal_vasculitis_crop/test/bd"
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
arch = 'RETFound_mae'

In [None]:
# # Diagnostic cell: check image formats in the dataset
# def check_image_formats(data_path, max_samples=10):
#     img_list = os.listdir(data_path)
#     if len(img_list) > max_samples:
#         img_list = img_list[:max_samples]  # Check only a few images
    
#     print(f"Checking up to {max_samples} images from {data_path}")
#     for i, img_name in enumerate(img_list):
#         try:
#             img_path = os.path.join(data_path, img_name)
#             img = Image.open(img_path)
#             img_array = np.array(img)
#             print(f"[{i+1}] {img_name}: mode={img.mode}, shape={img_array.shape}, size={img.size}")
#         except Exception as e:
#             print(f"[{i+1}] {img_name}: Error - {str(e)}")
    
# # Run this before processing to check your images
# check_image_formats(data_path)

Checking up to 10 images from /mnt/d/3.dlProject/bdrv/data/retinal_vasculitis_crop/test/bd
[1] 1209_2013.3.19_od_1.png: mode=L, shape=(496, 768), size=(768, 496)
[2] 1209_2013.3.19_od_2.png: mode=L, shape=(496, 768), size=(768, 496)
[3] 1517_2019.1.2_os_1.jpg: mode=L, shape=(496, 768), size=(768, 496)
[4] 1517_2019.1.2_os_2.jpg: mode=L, shape=(496, 768), size=(768, 496)
[5] 174_2013.8.8_od_1.png: mode=L, shape=(496, 768), size=(768, 496)
[6] 174_2013.8.8_od_2.png: mode=L, shape=(496, 768), size=(768, 496)
[7] 174_2013.8.8_os_1.png: mode=L, shape=(496, 768), size=(768, 496)
[8] 174_2013.8.8_os_2.png: mode=L, shape=(496, 768), size=(768, 496)
[9] 603_2016.4.12_od_1.jpg: mode=L, shape=(496, 768), size=(768, 496)
[10] 603_2016.4.12_od_2.jpg: mode=L, shape=(496, 768), size=(768, 496)


In [19]:
# Single-folder run: features + predictions for every image in data_path.
(name_list, feature_list, score_raw_list, prob_all_list,
 pred_class_list, pred_class_prob_list, pred_1class_prob_list) = get_feature(
    data_path, chkpt_dir, device, arch=arch
)

In [28]:
# Per-image metadata and predictions (one row per image).
df_info = pd.DataFrame({
    "name": name_list,
    "score_raw": score_raw_list,
    "prob_all": prob_all_list,
    "pred_class": pred_class_list,
    "pred_class_prob": pred_class_prob_list,
    "pred_1class_prob": pred_1class_prob_list
})

# One column per latent dimension: feature_0 ... feature_{D-1}.
# Guard the empty-folder case, where feature_list[0] would raise IndexError.
n_features = feature_list[0].shape[0] if feature_list else 0
df_feature = pd.DataFrame(
    feature_list,
    columns=["feature_{}".format(i) for i in range(n_features)]
)

df_final = pd.concat([df_info, df_feature], axis=1)

# Create the output folder if needed so to_csv does not fail on a fresh path.
output_dir = os.path.dirname(output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
print(f"Saving features to {output_path}")
df_final.to_csv(output_path, index=False)

Saving features to /mnt/d/3.dlProject/bdrv/output/LEO/1.RETFound_mae_natureOCT0-bdrv/output_dir/1.RETFound_mae_natureOCT0-bdrv/feature.csv


## 对根目录进行批量循环提取特征 (Batch feature extraction over every sub-folder of the root directory)

In [None]:
# Batch extraction over a parent folder such as
# /mnt/d/3.dlProject/bdrv/data/retinal_vasculitis_crop, covering every first-level
# sub-folder (train/val/test) and each of its sub-folders (bd[0]/rv[1]).
# The first output column is the first-level folder name, the second the
# second-level folder name; the remaining columns are as in get_feature's output.
def collect_features_from_all_subfolders(root_path, chkpt_dir, device, arch='RETFound_mae'):
    """Run ``get_feature`` on every ``root_path/<folder_1>/<folder_2>`` directory.

    For each second-level directory (e.g. ``test/bd``), features and predictions
    are extracted and two extra leading columns are added:
      - folder_1: first-level folder name (e.g. train / val / test)
      - folder_2: second-level folder name (e.g. bd / rv)

    Folders are visited in sorted order so the row order is reproducible.
    Sub-folders that yield no images are skipped with a message.

    NOTE(review): get_feature reloads the checkpoint for every sub-folder. That
    is fine for a handful of folders but slow for many.

    Returns:
        One DataFrame containing all rows (index reset), or an empty DataFrame if
        no sub-folder produced any rows.
    """
    # DataFrames collected from each second-level folder.
    all_dfs = []

    # 1) First-level folders (train / val / test, ...); plain files are skipped.
    for parent_dir_name in sorted(os.listdir(root_path)):
        parent_dir_path = os.path.join(root_path, parent_dir_name)
        if not os.path.isdir(parent_dir_path):
            continue

        # 2) Second-level folders (bd / rv, ...).
        for child_dir_name in sorted(os.listdir(parent_dir_path)):
            child_dir_path = os.path.join(parent_dir_path, child_dir_name)
            if not os.path.isdir(child_dir_path):
                continue

            # 3) Reuse the single-folder extractor unchanged.
            (
                name_list,
                feature_list,
                score_raw_list,
                prob_all_list,
                pred_class_list,
                pred_class_prob_list,
                pred_1class_prob_list
            ) = get_feature(child_dir_path, chkpt_dir, device, arch=arch)

            # An empty folder would otherwise crash on feature_list[0] below.
            if not name_list:
                print(f"Skipping empty folder: {child_dir_path}")
                continue

            # 4) Metadata columns: [folder_1, folder_2, name, predictions...].
            n_rows = len(name_list)
            df_info = pd.DataFrame({
                "folder_1": [parent_dir_name] * n_rows,
                "folder_2": [child_dir_name] * n_rows,
                "name": name_list,
                "score_raw": score_raw_list,
                "prob_all": prob_all_list,
                "pred_class": pred_class_list,
                "pred_class_prob": pred_class_prob_list,
                "pred_1class_prob": pred_1class_prob_list
            })

            # 5) One column per latent dimension: feature_0 ... feature_{D-1}.
            df_feature = pd.DataFrame(
                feature_list,
                columns=["feature_{}".format(i) for i in range(feature_list[0].shape[0])]
            )

            # 6) Side-by-side concat gives the full row for this folder.
            all_dfs.append(pd.concat([df_info, df_feature], axis=1))

    # 7) pd.concat([]) raises ValueError, so return an empty frame explicitly.
    if not all_dfs:
        return pd.DataFrame()
    return pd.concat(all_dfs, ignore_index=True)

root_path = "/mnt/d/3.dlProject/bdrv/data/retinal_vasculitis_crop"
chkpt_dir = "/mnt/d/3.dlProject/bdrv/output/LEO/1.RETFound_mae_natureOCT0.cwflmix-bdrv/output_dir/1.RETFound_mae_natureOCT0.cwflmix-bdrv/checkpoint-best.pth"
# Use a distinct file name: reusing "feature.csv" here would silently overwrite
# the single-folder (test/bd) export written earlier in this notebook.
output_path = os.path.join(os.path.dirname(chkpt_dir), "features_allfolders.csv")
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
arch = 'RETFound_mae'
df_final = collect_features_from_all_subfolders(
    root_path=root_path,
    chkpt_dir=chkpt_dir,
    device=device,
    arch=arch
)

1000finished


In [30]:
# Bare last expression: Jupyter renders a rich HTML table instead of wrapped plain text.
df_final.head()

  folder_1 folder_2                     name                   score_raw  \
0     test       bd  1209_2013.3.19_od_1.png  [[-0.6993152, 0.69788224]]   
1     test       bd  1209_2013.3.19_od_2.png  [[-0.39495006, 0.3931978]]   
2     test       bd   1517_2019.1.2_os_1.jpg   [[1.6192925, -1.6223154]]   
3     test       bd   1517_2019.1.2_os_2.jpg   [[1.5682701, -1.5715133]]   
4     test       bd    174_2013.8.8_od_1.png    [[1.236756, -1.2395748]]   

                      prob_all  pred_class  pred_class_prob  pred_1class_prob  \
0     [[0.1982612, 0.8017388]]           1         0.801739          0.801739   
1     [[0.3125665, 0.6874335]]           1         0.687433          0.687433   
2   [[0.96237034, 0.03762962]]           0         0.962370          0.037630   
3  [[0.95850426, 0.041495733]]           0         0.958504          0.041496   
4   [[0.9224658, 0.077534236]]           0         0.922466          0.077534   

   feature_0  feature_1  ...  feature_1014  feature_1015

In [31]:
# Persist the combined table, then report where it went.
df_final.to_csv(path_or_buf=output_path, index=False)
print(f"Saved to {output_path}")

Saved to /mnt/d/3.dlProject/bdrv/output/LEO/1.RETFound_mae_natureOCT0-bdrv/output_dir/features_allfolders.csv
