In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os,sys,time,platform
#import openslide
from PIL import Image
#from tqdm import tqdm
#import tensorflow as tf
#from tensorflow.keras.applications.resnet50 import preprocess_input
import torch
from torch import nn
import torchvision.transforms as transforms
from torchvision.models import resnet50, ResNet50_Weights

#from torchvision import models as torch_models
#from utils_preprocessing import *
#import utils_color_norm
#color_norm = utils_color_norm.macenko_normalizer()

## pick the GPU when CUDA is usable, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device:", device)

device: cpu


In [6]:
## number of tiles sent through the model per forward pass in the feature-extraction loop
batch_size = 64

In [2]:
class ResNet_extractor(nn.Module):
    """Feature extractor wrapping a torchvision ResNet-50 with ImageNet weights.

    The forward pass applies every stage of the wrapped network up to and
    including its average pooling, skips the final classifier, and returns
    the pooled activations flattened to one row per input image.

    Args:
        layers: Accepted for interface compatibility but currently ignored;
            a ResNet-50 backbone is always constructed.
    """

    def __init__(self, layers=101):
        super().__init__()
        # NOTE(review): `layers` is never read, so the default of 101 does not
        # select a ResNet-101 -- the backbone below is always ResNet-50.
        self.resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

    def forward(self, x):
        """Run `x` through the backbone and return flattened pooled features.

        Args:
            x: Batch of input images (batch dimension first).

        Returns:
            Tensor with the batch dimension kept and all remaining dimensions
            flattened into one.
        """
        backbone = self.resnet
        stages = (
            # stem
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            # residual stages
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
            # global pooling (the classifier head is intentionally not applied)
            backbone.avgpool,
        )
        for stage in stages:
            x = stage(x)
        return x.flatten(start_dim=1)

In [3]:
#path2tiles = "/Volume/test_10tiles/"
path2tiles = "/Volumes/TaiHoang5T6/LUSC_tiles/LUSC_00000_TCGA-18-3406-01Z-00-DX1/"

## keep only the tile files of this slide folder, in alphabetical order
tile_names = np.array(sorted(
    entry for entry in os.listdir(path2tiles) if entry.startswith("tile_")
))
print(tile_names)

n_tiles = tile_names.shape[0]
print("n_tiles:", n_tiles)

['tile_00002_00083_00333_016.png' 'tile_00003_00069_00444_016.png'
 'tile_00003_00070_00445_016.png' ... 'tile_00082_00074_10324_016.png'
 'tile_00082_00075_10325_016.png' 'tile_00082_00076_10326_016.png']
n_tiles: 5256


In [4]:
## build the extractor on the selected device and put it in evaluation mode
model = ResNet_extractor().to(device).eval()

In [5]:
## preprocessing applied to every tile: resize, convert to tensor, then
## normalize each channel with the mean/std constants below
## NOTE(review): Resize(224) with a single int presumably rescales only the
## shorter edge, so stacking below assumes all tiles are square -- confirm
data_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


start_time = time.time()
##----------
image_tensors = []
for tile_name in tile_names:
    ## os.path.join works whether or not path2tiles ends with a separator
    tile_path = os.path.join(path2tiles, tile_name)
    ## the context manager closes the file handle as soon as the tile is decoded
    with Image.open(tile_path) as tile_file:
        image_tensors.append(data_transform(tile_file.convert('RGB')))

## stack the per-tile tensors along a new leading (tile) dimension
images = torch.stack(image_tensors, dim=0)
del image_tensors  ## drop the per-tile list so only the stacked tensor stays alive
print("images.shape:", images.shape)

print("time:", (time.time() - start_time))

images.shape: torch.Size([5256, 3, 224, 224])
time: 241.22160601615906


In [7]:
## number of tiles held in the stacked image tensor (its leading dimension)
n_tiles_selected = images.size(0)
print("n_tiles_selected:", n_tiles_selected)

n_tiles_selected: 5256


In [10]:
##----------
## extract features batch by batch; each list entry is a NumPy array of
## features for one batch of tiles
start_time = time.time()

features_list = []
## inference only: disable autograd so no graph is built or kept per batch
with torch.no_grad():
    for idx_start in range(0, n_tiles_selected, batch_size):
        ## the last batch may be shorter than batch_size
        idx_end = min(idx_start + batch_size, n_tiles_selected)

        batch = images[idx_start:idx_end].to(device)
        features = model(batch)

        features_list.append(features.cpu().numpy())

print("len(features_list):", len(features_list))

print("time:", (time.time() - start_time))

len(features_list): 83
time: 514.7612488269806
