In [1]:
import os
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from PIL import Image
import models_vit as models
from huggingface_hub import hf_hub_download
# Reproducibility / display settings, kept next to the imports so every later cell sees them.
SEED = 1  # single source of truth for the RNG seeds used in this notebook
np.set_printoptions(threshold=np.inf)  # print numpy arrays in full -- can produce very long outputs
np.random.seed(SEED)
torch.manual_seed(SEED);  # trailing ';' suppresses the returned Generator repr in the cell output





In [None]:
def prepare_model(chkpt_dir, arch='RETFound_mae', num_classes=2):
    """Build a RETFound model and load weights from a checkpoint file.

    Args:
        chkpt_dir: Path to the checkpoint *file* (despite the name, not a directory).
        arch: Name of the model constructor looked up in ``models_vit``.
            ``'RETFound_mae'`` loads ``checkpoint['model']``; any other arch loads
            ``checkpoint['teacher']``.
        num_classes: Size of the classification head to build. Should match the head
            stored in a fine-tuned checkpoint.

    Returns:
        The constructed model with the checkpoint weights loaded. It is neither moved
        to a device nor switched to eval mode here.
    """
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    # Newer PyTorch releases change the default of `weights_only`; checkpoints that also
    # pickle training args may need `weights_only=False` -- confirm for the installed version.
    checkpoint = torch.load(chkpt_dir, map_location='cpu')

    if arch == 'RETFound_mae':
        model = models.__dict__[arch](
            img_size=224,
            num_classes=num_classes,
            drop_path_rate=0,
            global_pool=True,
        )
        state_dict = checkpoint['model']
    else:
        # Other variants (e.g. the dinov2 checkpoints) keep their weights under 'teacher'.
        model = models.__dict__[arch](
            num_classes=num_classes,
            drop_path_rate=0,
            args=None,
        )
        state_dict = checkpoint['teacher']

    # strict=False tolerates key mismatches, but silently missing weights (for example the
    # classification head) would leave randomly initialised layers behind, so report them
    # instead of discarding the load result.
    msg = model.load_state_dict(state_dict, strict=False)
    if msg.missing_keys or msg.unexpected_keys:
        print(f"[prepare_model] missing keys: {msg.missing_keys}; "
              f"unexpected keys: {msg.unexpected_keys}")
    return model

def run_one_image(img, model, arch, device=None):
    """Extract the latent feature tensor for one pre-processed image.

    Args:
        img: Normalised image as a channel-last (H, W, C) numpy array.
        model: Model exposing ``forward_features``.
        arch: Architecture name; for ``'dinov2_large'`` the tokens after the first one
            are mean-pooled and passed through a non-affine LayerNorm.
        device: Device to run on. Defaults to the device of the model's parameters
            (CPU if it has none) instead of relying on a notebook-global ``device``.

    Returns:
        The feature tensor with all size-1 dimensions squeezed away. It is computed
        without autograd and stays on ``device``.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # (H, W, C) -> (1, C, H, W); cast to float32 before the host->device transfer.
    x = torch.tensor(img).unsqueeze(0).permute(0, 3, 1, 2).float()
    x = x.to(device, non_blocking=True)

    # Inference only: skipping the autograd graph saves memory for every image.
    with torch.no_grad():
        latent = model.forward_features(x)

        if arch == 'dinov2_large':
            # Mean-pool every token except the first (presumably CLS), then normalise.
            # functional.layer_norm without weight/bias equals a freshly initialised
            # nn.LayerNorm (weight=1, bias=0) without allocating a module per call.
            latent = latent[:, 1:, :].mean(dim=1, keepdim=True)
            latent = nn.functional.layer_norm(latent, latent.shape[-1:], eps=1e-6)

    return torch.squeeze(latent)

def run_one_image_for_prediction(img, model, arch): # not sure about this: 还没有写如果是dinov2_large的情况

    x = torch.tensor(img)
    x = x.unsqueeze(dim=0)
    x = torch.einsum('nhwc->nchw', x)
    
    x = x.to(device, non_blocking=True)
    with torch.no_grad():
        predictions = model(x.float())
        probs = torch.softmax(predictions, dim=1)
        pred_class = torch.argmax(probs, dim=1).item()

        probs_pred_class = probs[0][pred_class].item(),
        probs_1_class = probs[0][1].item()
    
    return {
        'predi predictions.cpu().numpy(), 
        'probs'      : probs.cpu().numpy(),
        # 'predicted_class': pred_class,
        # 'pred_class_probability': pred_class_probability,
        # 'positive_class_probability': positive_class_probability
    }
    
Can you explain the following code to me?

In [None]:
# Smoke test: extract features and a prediction for one sample image.
# Depends on chkpt_dir / arch / device from the configuration cell, which must run first.
model_ = prepare_model(chkpt_dir, arch)
model_.to(device)
model_.eval()

sample_img_path = "/mnt/d/3.dlProject/bdrv/data/retinal_vasculitis_crop/test/bd/1209_2013.3.19_od_1.png"
with Image.open(sample_img_path) as pil_img:
    img = np.array(pil_img.convert('RGB').resize((224, 224))) / 255.

# Per-image, per-channel standardisation; a constant channel would otherwise divide by zero.
channel_std = img.std(axis=(0, 1))
img = (img - img.mean(axis=(0, 1))) / np.where(channel_std > 0, channel_std, 1.0)
assert img.shape == (224, 224, 3)

latent_feature = run_one_image(img, model_, arch)
pred_result = run_one_image_for_prediction(img, model_, arch)
pred_result['predictions']



In [20]:
def _load_normalized_image(img_path):
    """Load an image as RGB at 224x224 and standardise each channel of that image."""
    with Image.open(img_path) as pil_img:
        img = np.array(pil_img.convert('RGB').resize((224, 224))) / 255.
    channel_std = img.std(axis=(0, 1))
    # Guard against a constant channel, which would otherwise produce NaN features.
    img = (img - img.mean(axis=(0, 1))) / np.where(channel_std > 0, channel_std, 1.0)
    assert img.shape == (224, 224, 3)
    return img


def get_feature(data_path,
                chkpt_dir,
                device,
                arch='RETFound_mae'):
    """Extract latent features and predictions for every file in a folder.

    Args:
        data_path: Folder containing the images; sub-directories are skipped.
        chkpt_dir: Path to the checkpoint file passed to ``prepare_model``.
        device: torch device the model is moved to.
        arch: Architecture name passed to ``prepare_model`` and the run helpers.

    Returns:
        ``[name_list, feature_list, pred_class_list, confidence_list]`` -- four lists
        aligned by index: file names (sorted), feature vectors as numpy arrays,
        predicted class indices, and the probability of each predicted class.
    """
    model_ = prepare_model(chkpt_dir, arch)
    model_.to(device)
    model_.eval()

    # Sorted so the output order is reproducible (os.listdir order is arbitrary).
    img_list = sorted(
        name for name in os.listdir(data_path)
        if os.path.isfile(os.path.join(data_path, name))
    )

    name_list = []
    feature_list = []
    pred_class_list = []  # predicted class index per image
    confidence_list = []  # probability of the predicted class per image

    for finished_num, img_name in enumerate(img_list, start=1):
        img = _load_normalized_image(os.path.join(data_path, img_name))

        latent_feature = run_one_image(img, model_, arch)
        pred_result = run_one_image_for_prediction(img, model_, arch)

        # Derive class and confidence from the 'probs' array rather than relying on
        # optional summary keys of the prediction result.
        probs = np.asarray(pred_result['probs']).reshape(-1)
        pred_class = int(probs.argmax())
        pred_class_list.append(pred_class)
        confidence_list.append(float(probs[pred_class]))

        name_list.append(img_name)
        feature_list.append(latent_feature.detach().cpu().numpy())

        if finished_num % 1000 == 0:
            print(f"{finished_num} finished")

    return [name_list, feature_list, pred_class_list, confidence_list]



In [3]:
# Configuration used by the feature-extraction cells: checkpoint, data folder, device, arch.
# Alternative pretrained checkpoint from the Hugging Face Hub:
# chkpt_dir = hf_hub_download(repo_id="YukunZhou/RETFound_dinov2_meh", filename="RETFound_dinov2_meh.pth")
chkpt_dir = "/mnt/d/3.dlProject/bdrv/output/LEO/1.RETFound_mae_natureOCT0-bdrv/output_dir/1.RETFound_mae_natureOCT0-bdrv/checkpoint-best.pth"
data_path = "/mnt/d/3.dlProject/bdrv/data/retinal_vasculitis_crop/test/bd"
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
arch = 'RETFound_mae'

In [4]:
# Diagnostic cell to check image formats in the dataset
def check_image_formats(data_path, max_samples=10):
    """Print mode, array shape and size for the first few files in ``data_path``.

    Diagnostic only: files that cannot be opened or decoded are reported on stdout
    instead of raising.

    Args:
        data_path: Folder to inspect.
        max_samples: Maximum number of files to inspect, taken in sorted name order.
    """
    # Slicing already handles folders with fewer than max_samples entries.
    img_list = sorted(os.listdir(data_path))[:max_samples]

    print(f"Checking up to {max_samples} images from {data_path}")
    for i, img_name in enumerate(img_list, start=1):
        img_path = os.path.join(data_path, img_name)
        try:
            # The context manager closes the file handle even when decoding fails.
            with Image.open(img_path) as img:
                img_array = np.array(img)
                print(f"[{i}] {img_name}: mode={img.mode}, shape={img_array.shape}, size={img.size}")
        except Exception as e:
            print(f"[{i}] {img_name}: Error - {e}")
    
# Inspect a handful of files (default sample size) before running the full extraction.
check_image_formats(data_path, max_samples=10)



In [None]:
# get_feature returns four index-aligned lists; unpacking only two would raise ValueError.
name_list, feature, pred_class_list, confidence_list = get_feature(data_path, chkpt_dir, device, arch=arch)

In [8]:
# Assemble image names and feature vectors into one table (one row per image).
df_feature = pd.DataFrame(feature)
df_imgname = pd.DataFrame(name_list)
assert len(df_feature) == len(df_imgname), "feature rows and image names are misaligned"
# Derive the column count from the data instead of hard-coding 1024, so backbones with a
# different embedding width also work.
column_name_list = [f"feature_{i}" for i in range(df_feature.shape[1])]
df_visualization = pd.concat([df_imgname, df_feature], axis=1)
df_visualization.columns = ["name"] + column_name_list
# Saving is currently disabled; uncomment to write the table to disk.
# df_visualization.to_csv("Feature.csv",index=False)