In [1]:
import os
import pandas as pd
import numpy as np
import glob
import torch
import torchvision.models as models
from PIL import Image
import torchvision.transforms as transforms
import math


In [2]:
# Position of each Gleason grade token inside the 4-slot multi-hot label vector.
LABEL_TO_INDEX = {grade: slot for slot, grade in enumerate(('0', '3', '4', '5'))}
# Run on the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Number of images pushed through a model per forward pass.
BATCH_SIZE = 64

# Region identifiers (first four '_'-separated tokens of a tile file name)
# that are duplicates and should not be included.
DUPLICATES = [
    '16B0001851_Block_Region_3',
    '16B0003388_Block_Region_5',
    '16B0003394_Block_Region_1',
    '16B0022608_Block_Region_2',
    '16B0022786_Block_Region_0',
    '16B0023614_Block_Region_3',
    '16B0026792_Block_Region_3',
    '16B0027040_Block_Region_10',
    '18B0005478J_Block_Region_13',
    '18B0005478J_Block_Region_10',
]

# Image preprocessing pipeline: resize, centre-crop to 224x224, convert to a
# tensor, then normalise each channel with the given mean / std.
PREPROCESS = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [3]:
def extract_feature(model, paths, type='mobileNet'):
    """Run ``model`` over the images at ``paths`` in batches of ``BATCH_SIZE``.

    Each image is opened, converted to RGB, passed through ``PREPROCESS`` and
    stacked into a batch that is moved to ``DEVICE`` before the forward pass.

    Args:
        model: Callable (e.g. an ``nn.Module``) applied to a float image batch
            located on ``DEVICE``.
        paths: Sequence of image file paths.
        type: Label of the backbone being used. It does not influence the
            computation; the name shadows the builtin but is kept so existing
            keyword callers keep working.

    Returns:
        The per-batch model outputs concatenated along dim 0, in the order of
        ``paths``. When ``paths`` is empty an empty 1-D tensor is returned,
        because there is no output from which a feature shape could be taken.
    """
    if len(paths) == 0:
        # torch.cat() on an empty list raises, so short-circuit here.
        return torch.empty(0)

    feature_arr = []

    # Inference only: skip building the autograd graph for every batch.
    with torch.no_grad():
        for start in range(0, len(paths), BATCH_SIZE):
            batch_tensors = []
            for path in paths[start:start + BATCH_SIZE]:
                region = '_'.join(os.path.basename(path).split('_')[0:4])
                if region in DUPLICATES:
                    # Only a warning: the image is still processed.
                    print('.................DUPLICATES...................', path)

                # Close the file handle promptly and force three channels so
                # that the 3-channel Normalize step cannot fail on grayscale
                # or RGBA inputs.
                with Image.open(path) as img:
                    batch_tensors.append(PREPROCESS(img.convert('RGB')))

            # Add the batch dimension by stacking, then move the whole batch
            # to the device in a single transfer.
            input_batch = torch.stack(batch_tensors, dim=0).to(DEVICE)
            feature_arr.append(model(input_batch).detach())

    return torch.cat(feature_arr, dim=0)

In [4]:
# Backbones already built by extract_mobileNet_feature, keyed by the
# ``weights`` argument, so the model is constructed once instead of per call.
_MOBILENET_BACKBONES = {}


def extract_mobileNet_feature(paths, weights="IMAGENET1K_V1"):
    """Extract MobileNetV2 convolutional feature maps for the given images.

    The classifier (last child of the torchvision model) is removed, so the
    output is the feature map produced by the remaining layers; the recorded
    notebook output shows shape ``(N, 1280, 7, 7)``.

    Previously the model was created without weights on every call, i.e. a
    freshly randomly-initialised network per call, which makes features from
    different calls incomparable. Pretrained weights are now loaded by
    default (this downloads them on first use) and the backbone is cached.

    Args:
        paths: Sequence of image file paths.
        weights: Passed to ``models.mobilenet_v2(weights=...)``. Use ``None``
            to get a randomly initialised network as before.

    Returns:
        Whatever ``extract_feature`` returns for the backbone and ``paths``.
    """
    backbone = _MOBILENET_BACKBONES.get(weights)
    if backbone is None:
        full_model = models.mobilenet_v2(weights=weights)
        # Drop the last child (the classifier) and keep everything before it.
        backbone = torch.nn.Sequential(*list(full_model.children())[:-1])
        backbone.eval()
        # .to() is a no-op when DEVICE is the CPU, so no CUDA check is needed.
        backbone.to(DEVICE)
        _MOBILENET_BACKBONES[weights] = backbone

    return extract_feature(backbone, paths)

In [5]:
# Truncated ViT models already built by extract_ViT_feature, keyed by the
# ``weights`` argument, so the model is constructed once instead of per call.
_VIT_TRUNCATED_MODELS = {}


def extract_ViT_feature(paths, weights="IMAGENET1K_V1"):
    """Extract ViT-L/32 feature maps for the given images.

    The last two children of the torchvision model are removed; the recorded
    notebook output shows the result has shape ``(N, 1024, 7, 7)``.

    NOTE(review): for torchvision's VisionTransformer the children appear to
    be (conv_proj, encoder, heads), so ``[:-2]`` keeps only the patch-embedding
    convolution and the transformer encoder is never applied. This is left
    as-is because changing it would alter the output shape that saved feature
    files use; confirm whether encoder outputs were intended.

    Previously the model was created without weights on every call, i.e. a
    freshly randomly-initialised network per call, which makes features from
    different calls incomparable. Pretrained weights are now loaded by
    default (this downloads them on first use) and the truncated model is
    cached.

    Args:
        paths: Sequence of image file paths.
        weights: Passed to ``models.vit_l_32(weights=...)``. Use ``None`` to
            get a randomly initialised network as before.

    Returns:
        Whatever ``extract_feature`` returns for the truncated model and
        ``paths``.
    """
    truncated = _VIT_TRUNCATED_MODELS.get(weights)
    if truncated is None:
        full_model = models.vit_l_32(weights=weights)
        # Keep every child except the last two.
        truncated = torch.nn.Sequential(*list(full_model.children())[:-2])
        truncated.eval()
        # .to() is a no-op when DEVICE is the CPU, so no CUDA check is needed.
        truncated.to(DEVICE)
        _VIT_TRUNCATED_MODELS[weights] = truncated

    return extract_feature(truncated, paths, type="ViT")

In [6]:
def main():
    """Extract and save MobileNet and ViT features for every labelled slide.

    For each row of ``dataset/image_labels.csv`` the slide name (column 0)
    and the Gleason score (column 3, two grades joined by '+') are read, a
    4-slot multi-hot label is built, features are extracted for all tiles of
    that slide, and the result is written to ``dataset/features/<name>.pth``.

    Tiles are matched to a slide by the first '_'-separated token of the file
    name (exact match) rather than by substring search over the whole path,
    so a slide id that is a prefix of another id, or that happens to appear
    in a directory name, can no longer pick up foreign tiles. Tiles whose
    region id is in ``DUPLICATES`` are excluded. Paths are sorted so the row
    order of the saved features is reproducible. Slides without any tile are
    reported and skipped.
    """
    os.makedirs('dataset/features', exist_ok=True)

    # Index the tiles by slide id once instead of rescanning every file for
    # every CSV row.
    paths_by_slide = {}
    all_paths = sorted(glob.glob(os.path.join('dataset', 'SICAPv2', 'images', '*.jpg')))
    for path in all_paths:
        tokens = os.path.basename(path).split('_')
        if '_'.join(tokens[0:4]) in DUPLICATES:
            continue
        paths_by_slide.setdefault(tokens[0], []).append(path)

    image_labels = pd.read_csv('dataset/image_labels.csv')
    for i, row in image_labels.iterrows():
        name = row.iloc[0]
        gleason_grade = row.iloc[3].split('+')
        # Multi-hot label: one slot per grade present in the score.
        label = np.zeros(4, dtype=int)
        label[LABEL_TO_INDEX[gleason_grade[0]]] = 1
        label[LABEL_TO_INDEX[gleason_grade[1]]] = 1

        paths = paths_by_slide.get(name, [])
        if not paths:
            print(f'{i} {name}: no image tiles found, skipping')
            continue

        m_features = extract_mobileNet_feature(paths)
        v_features = extract_ViT_feature(paths)

        print(i, name, label, m_features.shape, v_features.shape)
        data = {
            'name': name,
            'label': label,
            'm_features': m_features,
            'v_features': v_features
        }
        torch.save(data, f'dataset/features/{name}.pth')

In [7]:
# Run the extraction for every row of the label CSV; one line is printed per slide.
main()

0 16B0001851 [0 0 1 1] torch.Size([44, 1280, 7, 7]) torch.Size([44, 1024, 7, 7])
1 16B0003388 [0 0 1 0] torch.Size([54, 1280, 7, 7]) torch.Size([54, 1024, 7, 7])
2 16B0003394 [0 1 0 0] torch.Size([55, 1280, 7, 7]) torch.Size([55, 1024, 7, 7])
3 16B0006668 [0 0 0 1] torch.Size([122, 1280, 7, 7]) torch.Size([122, 1024, 7, 7])
4 16B0006669 [0 0 0 1] torch.Size([60, 1280, 7, 7]) torch.Size([60, 1024, 7, 7])
5 16B0006694 [0 0 1 0] torch.Size([141, 1280, 7, 7]) torch.Size([141, 1024, 7, 7])
6 16B0006695 [0 0 1 1] torch.Size([165, 1280, 7, 7]) torch.Size([165, 1024, 7, 7])
7 16B0008045 [0 1 0 0] torch.Size([55, 1280, 7, 7]) torch.Size([55, 1024, 7, 7])
8 16B0008067 [0 0 1 1] torch.Size([143, 1280, 7, 7]) torch.Size([143, 1024, 7, 7])
9 16B0022608 [0 1 1 0] torch.Size([64, 1280, 7, 7]) torch.Size([64, 1024, 7, 7])
10 16B0022610 [0 1 0 0] torch.Size([86, 1280, 7, 7]) torch.Size([86, 1024, 7, 7])
11 16B0022612 [0 1 1 0] torch.Size([182, 1280, 7, 7]) torch.Size([182, 1024, 7, 7])
12 16B0022613 [0

In [8]:
# Sanity check: reload one saved bundle and print its name, label and the
# shapes of both feature tensors.
# NOTE(review): weights_only=False performs full unpickling; only load files
# produced by this notebook.
myFeature = torch.load(
    'dataset/features/16B0003388.pth',
    weights_only=False,
)
print(
    myFeature['name'],
    myFeature['label'],
    myFeature['m_features'].shape,
    myFeature['v_features'].shape,
)

16B0003388 [0 0 1 0] torch.Size([54, 1280, 7, 7]) torch.Size([54, 1024, 7, 7])
