In [1]:
import argparse
import os
import math
import slideio
import cv2
import json
import numpy as np
import torch
import torch.nn as nn
import numpy as np
import time
import glob
from torchvision import transforms
from torchvision.models import resnet50, ResNet50_Weights, resnet34, ResNet34_Weights, inception_v3, Inception_V3_Weights
from torchsummary import summary
from random import random
from tqdm import tqdm
from matplotlib import pyplot as plt

from torchvision import models
import torch.optim as optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import zipfile
from pathlib import Path
from s3fs import S3FileSystem
from concurrent.futures import ThreadPoolExecutor
import boto3
from botocore.client import ClientError
from tqdm import tqdm

import sys
sys.path.append('utils')
from stainlib.augmentation.augmenter import StainAugmentor
from stainlib.augmentation.augmenter import HedLighterColorAugmenter
from stainlib.augmentation.augmenter import GrayscaleAugmentor
from stainlib.utils.plot_utils import _plot_imagegrid
from S3FileManager import S3FileManager, S3UploadSync

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cpu


In [2]:
def create_folder(path):
    """Create the directory ``path`` (and any missing parents) if it does not exist.

    ``exist_ok=True`` makes the call idempotent and removes the gap between
    checking for the directory and creating it, so two workers creating the
    same folder at the same time cannot fail with ``FileExistsError``.

    Raises:
        OSError: if ``path`` exists but is not a directory, or cannot be created.
    """
    os.makedirs(path, exist_ok=True)

def fetch_processed_slides(path):
    """Return the set of slide names that already have a ``.pt`` file in ``path``.

    Output files are named ``<slide>--<aug>__<dataset>.pt``; the slide name is
    everything before the first ``--``.

    Only entries that actually end with ``.pt`` are considered, and only that
    trailing suffix is stripped. A substring test would also pick up files such
    as ``foo.pth`` or ``foo.pt.tmp``, and replacing every ``.pt`` occurrence
    would corrupt slide names that happen to contain ``.pt``.
    """
    suffix = '.pt'
    return {
        entry[:-len(suffix)].split('--')[0]
        for entry in os.listdir(path)
        if entry.endswith(suffix)
    }

def fetch_avg_saturation(image):
    """Return the mean HSV saturation of an RGB image after a 5x5 Gaussian blur."""
    smoothed = cv2.GaussianBlur(image, (5, 5), 0)
    # Channel 1 of the HSV conversion is the saturation plane.
    saturation_plane = cv2.cvtColor(smoothed, cv2.COLOR_RGB2HSV)[:, :, 1]
    return np.mean(saturation_plane)

def get_file_size(file_path):
    """Return the size of ``file_path`` in gigabytes (units of 10**9 bytes)."""
    size_in_bytes = os.path.getsize(file_path)
    return size_in_bytes / (10**9)

def num_estimated_patches(scene, stride):
    """Estimate how many patches a scene yields for a given stride.

    ``scene.size`` is read as ``(width, height)``. Along each axis the count is
    ``ceil((extent - stride) / stride)``; the result is the product of the
    per-axis counts.
    """
    def steps_along(extent):
        return math.ceil((extent - stride) / stride)

    scene_width, scene_height = scene.size[0], scene.size[1]
    return steps_along(scene_height) * steps_along(scene_width)

class Flatten(torch.nn.Module):
    """Module that collapses every dimension after the first into a single one."""

    def forward(self, x):
        leading_dim = x.size(0)
        return x.view(leading_dim, -1)

def load_feature_extractor_inception_v3():
    """Return an ImageNet-pretrained Inception v3 that outputs pooled features.

    The classification layer ``fc`` is replaced by ``Identity`` so that, in
    eval mode, the model returns the flattened, pooled activations that would
    normally feed the classifier.

    The model is deliberately NOT rebuilt as
    ``Sequential(*list(inception.children())[:-2])``: the children of the
    pretrained network include the auxiliary ``AuxLogits`` head, which a
    Sequential would run in-line between the main blocks and break the forward
    pass. A Sequential would also bypass the input transform that the model's
    own ``forward`` applies.

    Returns:
        torch.nn.Module: the network in eval mode, moved to ``device``.
    """
    inception = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
    # Drop only the classifier; keep the model's own forward() intact.
    inception.fc = torch.nn.Identity()
    inception.eval()
    feature_extractor = inception.to(device)
    return feature_extractor

def load_feature_extractor_resnet50():
    """Return an ImageNet-pretrained ResNet-50 without its last child module.

    The network is put in eval mode, every child except the final one is
    wrapped in a ``Sequential``, and the result is moved to ``device``.
    """
    backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
    backbone.eval()
    kept_layers = list(backbone.children())[:-1]
    return torch.nn.Sequential(*kept_layers).to(device)

def correct_img(patches):
    """Turn a batch of channel-last image arrays into a normalized tensor on ``device``.

    No resizing happens here, so the patches must already have the size the
    network expects. Each patch is moved to channel-first layout, scaled by
    1/255 and normalized with the ImageNet mean / std before stacking.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    normalized = []
    for patch in patches:
        channel_first = torch.from_numpy(patch).permute(2, 0, 1)
        normalized.append(normalize(channel_first.float() / 255.0))
    return torch.stack(normalized).to(device)

def get_patch_for_size(slide_handler, size):
    """Convert a physical patch size into a pixel size for scene 0 of a slide.

    Args:
        slide_handler: object exposing ``get_scene(0).resolution`` as an
            ``(x, y)`` pair (the fallback below is expressed in metres per pixel).
        size: ``(x, y)`` physical extent of the patch, in the same unit as the
            resolution.

    Returns:
        tuple: ``(patch_size_x, patch_size_y)`` in pixels, truncated to int.
    """
    resolutions = slide_handler.get_scene(0).resolution
    if resolutions == (0, 0):
        # Expected resolution, in microns per pixel, for the images with this problem
        fallback_microns = 0.467
        # Convert to metres per pixel
        fallback_meters = fallback_microns * 1e-6
        resolutions = (fallback_meters, fallback_meters)

    res_x, res_y = resolutions[0], resolutions[1]
    return (int(size[0]/res_x), int(size[1]/res_y))

def tensor_to_numpy_image(tensor):
    """Denormalize one channel-first image tensor, save it and return it.

    The ImageNet mean / std normalization is reverted, the result is rounded
    and clipped to ``uint8``, written to ``tensor.png`` and returned.

    Args:
        tensor: a single 3-D image tensor in channel-first layout, normalized
            with the ImageNet mean / std.
            NOTE(review): channels are assumed to be RGB, which is what the
            normalization helpers in this notebook produce - confirm for
            other callers.

    Returns:
        numpy.ndarray: the ``uint8`` image in height x width x channel (RGB) order.
    """
    # detach() so tensors that still track gradients can be converted as well.
    numpy_img = tensor.detach().to('cpu').numpy()

    # Channel-first -> channel-last
    numpy_img = np.moveaxis(numpy_img, [0, 1, 2], [2, 0, 1])

    # Denormalize pixel values
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    numpy_img = (numpy_img * std + mean) * 255

    # Make the 8-bit conversion explicit instead of relying on the image writer.
    numpy_img = np.ascontiguousarray(np.clip(np.rint(numpy_img), 0, 255).astype(np.uint8))

    # cv2.imwrite expects BGR channel order, so convert before writing.
    cv2.imwrite("tensor.png", cv2.cvtColor(numpy_img, cv2.COLOR_RGB2BGR))
    return numpy_img

In [3]:
def correct_img2(img):
    """Normalize a single channel-last image and return it as a tensor on ``device``.

    The image is converted to a NumPy array, moved to channel-first layout,
    scaled by 1/255 and normalized with the ImageNet mean / std.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    pixels = np.array(img)
    channel_first = torch.from_numpy(pixels).permute(2, 0, 1)
    return normalize(channel_first.float() / 255.0).to(device)

class CustomDataset(Dataset):
    """Dataset over an in-memory sequence of images with an optional transform."""

    def __init__(self, all_images, transform=None):
        self.all_images = all_images
        self.transform = transform

    def __len__(self):
        return len(self.all_images)

    def __getitem__(self, idx):
        item = self.all_images[idx]
        # A falsy transform (e.g. None) leaves the item untouched.
        return self.transform(item) if self.transform else item

In [4]:
def patch_generator_with_precision_scale(image_path, batch_size, size_to_fetch=(0.0003861256, 0.0003861256),
                                         final_size=224, min_saturation=10, aug_type=0):
    """Yield batches of sufficiently saturated patches read from an SVS slide.

    The slide is tiled on a regular grid whose cell size in pixels comes from
    ``get_patch_for_size``. The mean saturation of every grid cell is measured
    once on a downscaled copy of the slide, and cells whose saturation exceeds
    ``min_saturation`` are kept. The threshold is halved until the mean
    saturation of the kept cells is at least three times the threshold, or
    until no cell is kept at all.

    Args:
        image_path: path of the slide, opened with the ``'SVS'`` driver.
        batch_size: number of patches per yielded batch.
        size_to_fetch: physical ``(x, y)`` patch size passed to ``get_patch_for_size``.
        final_size: side length, in pixels, at which each patch is read.
        min_saturation: initial saturation threshold (lowered adaptively).
        aug_type: 0 = no augmentation, 1 = RGB->BGR channel swap,
            2 = ``GrayscaleAugmentor``, 3 = re-randomized ``HedLighterColorAugmenter``.

    Yields:
        tuple: ``(positions, images)`` where ``positions`` is a list of
        ``(column, row)`` grid indices and ``images`` is ``np.array`` of the
        matching patches. The two always have the same length; the last batch
        may hold fewer than ``batch_size`` patches.
    """
    slide_handler = slideio.open_slide(image_path, 'SVS')
    scene = slide_handler.get_scene(0)
    actual_width, actual_height = scene.size[0], scene.size[1]
    patch_x, patch_y = get_patch_for_size(slide_handler, size_to_fetch)

    # Downscaled copy of the slide in which one grid cell is about 10 pixels wide.
    width_to_use = 10*actual_width//patch_x
    smaller_image = scene.read_block(size=(width_to_use, 0))
    smaller_height, smaller_width = np.shape(smaller_image)[:2]
    smaller_patch_size_x = patch_x*smaller_width//actual_width
    smaller_patch_size_y = patch_y*smaller_height//actual_height

    hed_lighter_aug = HedLighterColorAugmenter()
    hed_lighter_aug.randomize()
    grayscale_aug = GrayscaleAugmentor()

    # Measure each grid cell once. The range stops before the slide border, so a
    # slide whose size is an exact multiple of the patch size no longer produces
    # an empty crop (which the blur in fetch_avg_saturation cannot handle).
    grid_saturations = []
    for current_y in range(0, actual_height, patch_y):
        small_current_y = (current_y * smaller_height)//actual_height
        bottom_border = min(small_current_y + smaller_patch_size_y, smaller_height)
        for current_x in range(0, actual_width, patch_x):
            small_current_x = (current_x * smaller_width)//actual_width
            right_border = min(small_current_x + smaller_patch_size_x, smaller_width)

            sub_image = smaller_image[small_current_y:bottom_border,
                                      small_current_x:right_border]
            if sub_image.size == 0:
                continue
            grid_saturations.append(((current_x, current_y), fetch_avg_saturation(sub_image)))

    # Lower the threshold until the selected cells are, on average, clearly
    # above it. An empty selection ends the search with no positions.
    while True:
        selected = [(pos, sat) for pos, sat in grid_saturations if sat > min_saturation]
        if not selected:
            break
        avg_saturation = np.mean([sat for _, sat in selected])
        if avg_saturation < min_saturation * 3:
            min_saturation = min_saturation / 2
        else:
            break
    positions = [pos for pos, _ in selected]

    final_img_batch = []
    normalized_pos_batch = []
    for x1, y1 in positions:
        # Only reading/augmenting is guarded. The yield stays outside the try so
        # closing the generator (GeneratorExit) or Ctrl-C is never swallowed.
        try:
            width = min(patch_x, actual_width - x1)
            height = min(patch_y, actual_height - y1)
            sub_image = scene.read_block((x1, y1, width, height), (final_size, final_size))
            # aug_type == 0: no augmentation
            if aug_type == 1:  # BGR augmentation
                sub_image = cv2.cvtColor(sub_image, cv2.COLOR_RGB2BGR)
            elif aug_type == 2:  # Grayscale augmentation
                grayscale_aug.fit(sub_image)
                sub_image = grayscale_aug.pop()
            elif aug_type == 3:  # Everything shuffled using Hematoxylin and Eosin intensity
                hed_lighter_aug.randomize()
                sub_image = hed_lighter_aug.transform(sub_image)
        except Exception as e:
            print('erro no yield: ' + str(e), image_path)
            continue

        # Record position and image together so the two lists cannot drift apart
        # when a patch fails to load.
        final_img_batch.append(sub_image)
        normalized_pos_batch.append((x1//patch_x, y1//patch_y))
        if len(final_img_batch) == batch_size:
            yield normalized_pos_batch, np.array(final_img_batch)
            final_img_batch = []
            normalized_pos_batch = []

    # Flush the last, partially filled batch even if the final position failed.
    if final_img_batch:
        yield normalized_pos_batch, np.array(final_img_batch)

In [5]:
def fetch_feature_array(model, patches):
    """Run ``model`` on a batch of patches and return 2-D features on the CPU.

    The forward pass runs under ``torch.no_grad()``: the output is only used as
    features, so building an autograd graph would just hold on to memory.

    Args:
        model: callable module producing an output whose second dimension is
            the feature size.
        patches: input batch tensor.

    Returns:
        torch.Tensor: CPU tensor reshaped to ``(-1, feature_size)``.
    """
    if device != torch.device('cpu'):
        patches = patches.to(device)
    with torch.no_grad():
        raw_output = model(patches)
    feature_size = raw_output.shape[1]
    features = raw_output.cpu().detach()
    return torch.reshape(features, [-1, feature_size])

In [11]:
def generate_tensor_from_slide(slide_path, aug_type):
    """Extract features for every selected patch of a slide.

    Patches come from ``patch_generator_with_precision_scale`` in batches of
    ``BATCH_SIZE``; each batch is normalized with ``correct_img`` and passed
    through the module-level ``feat_gen`` model.

    Args:
        slide_path: path of the slide to process.
        aug_type: augmentation selector forwarded to the patch generator.

    Returns:
        torch.Tensor: the per-batch feature tensors stacked vertically.

    Raises:
        RuntimeError: if the slide produced no patch at all (instead of the
            opaque error ``torch.vstack`` raises on an empty list).
    """
    patch_gen = patch_generator_with_precision_scale(slide_path, BATCH_SIZE, aug_type=aug_type)
    feat_arrays = [
        fetch_feature_array(feat_gen, correct_img(patch_batch))
        for _, patch_batch in patch_gen
    ]
    if not feat_arrays:
        raise RuntimeError('no patches extracted from slide: ' + str(slide_path))
    return torch.vstack(feat_arrays)

In [7]:
# Feature extractor: ImageNet-pretrained ResNet-50 in eval mode on `device`.
feat_gen = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2).eval().to(device)
# A 3-output head is attached here, but it is the last child module and is
# therefore dropped again when the children are re-wrapped just below.
num_ftrs = feat_gen.fc.in_features
feat_gen.fc = nn.Linear(num_ftrs, 3)
feat_gen = torch.nn.Sequential(*tuple(feat_gen.children())[:-1])
# NOTE(review): if the checkpoint load below is re-enabled, confirm that its
# keys match the Sequential wrapper rather than the original ResNet layout.
# feat_gen.load_state_dict(torch.load(r'patches\resnet50_98_all_layers2.pth'))


In [8]:
%%time
# Size limit in bytes: 0.40 GiB (about 429.5 million bytes).
SIZE_MAX = 0.40 * 1024 ** 3
# Number of patches per batch.
BATCH_SIZE = 8

def find_svs_files(directory):
    """Return every ``*.svs`` path found in ``directory`` or any of its sub-folders."""
    recursive_pattern = directory + '/**/*.svs'
    return glob.glob(recursive_pattern, recursive=True)

# Dataset being processed and where its slides live (remote and local).
folder = 'lusc'
bucket_name = f'oncodata-datasources/tcga/lung/{folder}'
directory = '/tmp/train-data'

# Must be created before the local directory is scanned below.
# NOTE(review): presumably this makes the bucket content available under
# `directory` - confirm against S3FileManager.
file_manager = S3FileManager(bucket_name=bucket_name, local_dir=directory)

# One record per slide found locally.
svs_files = find_svs_files(directory)
input_slides = [dict(dataset=folder, image_path=svs_path) for svs_path in svs_files]

# Feature tensors are written here; slides that already have one are collected.
output_folder = '/tmp/preprocessing/resnetmultizoomlusc'
create_folder(output_folder)
processed_images = fetch_processed_slides(output_folder)

# Alternative extractor, kept for reference:
# feat_gen = load_feature_extractor_resnet50()

CPU times: user 1.04 s, sys: 375 ms, total: 1.41 s
Wall time: 2.23 s


In [12]:
# List the feature files already written to the output folder.
!ls /tmp/preprocessing/resnetmultizoomlusc/

TCGA-33-4586-11A-01-TS1.b34fae7a-c25e-494b-a9a9-f97157a133c9--hechaos__lusc.pt
TCGA-33-4586-11A-01-TS1.b34fae7a-c25e-494b-a9a9-f97157a133c9--noaug__lusc.pt


In [10]:
%%time

# Augmentation names, indexed by the aug_type code passed to the generator.
aug_array = ['noaug', 'bgr', 'gray', 'hechaos']
found = False

for image_data in tqdm(input_slides):
    dataset = image_data['dataset']
    image_path = image_data['image_path']
    image_name = os.path.split(image_path)[-1].replace('.svs', '')

    # Skip slides that already have features on disk or exceed the size limit.
    if image_name in processed_images or file_manager.get_file_size(image_path) > SIZE_MAX:
        continue

    for aug_type in [0, 3]:
        output_name = '%s--%s__%s.pt' % (image_name, aug_array[aug_type], dataset)
        output_name = os.path.join(output_folder, output_name)
        # Best effort per slide/augmentation: report the failure and move on.
        # Only `Exception` is caught so Ctrl-C (KeyboardInterrupt) and
        # SystemExit still stop the run.
        try:
            featureset = generate_tensor_from_slide(image_path, aug_type)
            torch.save(featureset, output_name)
        except Exception as e:
            print(e, image_name)
    # NOTE(review): this stops after the first slide that was not skipped.
    # It looks like a debugging leftover - remove it to process every slide.
    break

  0%|          | 0/1608 [00:00<?, ?it/s]

checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
checkpoint2
chec

  0%|          | 0/1608 [02:08<?, ?it/s]

checkpoint2
CPU times: user 1min 57s, sys: 1.29 s, total: 1min 59s
Wall time: 2min 8s





In [68]:
# Count the entries directly inside the local slide directory (`directory`,
# i.e. /tmp/train-data). os.listdir is not recursive, so sub-folders count as
# one entry each.
print(len(os.listdir(directory)))

2


In [20]:
# Uploader imported from the local S3FileManager module.
# NOTE(review): the arguments look like (bucket, local folder, remote prefix),
# judging by the sync output - confirm against S3UploadSync.
s3syncher = S3UploadSync('oncodata-sagemaker-shared', 'code', 'testing')

In [21]:
# Run the sync; it prints one "file_path:" line per file it handles.
s3syncher.sync()

file_path:  model_def.py code/model_def.py testing/model_def.py
file_path:  train_pytorch_smdataparallel_mnist.py code/train_pytorch_smdataparallel_mnist.py testing/train_pytorch_smdataparallel_mnist.py
file_path:  train_dsmil.py code/train_dsmil.py testing/train_dsmil.py
file_path:  create_heatmaps.ipynb code/create_heatmaps.ipynb testing/create_heatmaps.ipynb
file_path:  README.md code/README.md testing/README.md
file_path:  requirements.txt code/requirements.txt testing/requirements.txt
file_path:  roc_stomach.png code/roc_stomach.png testing/roc_stomach.png
file_path:  seed_5708.pt code/seed_5708.pt testing/seed_5708.pt
file_path:  separate_jsons_by_patient.ipynb code/separate_jsons_by_patient.ipynb testing/separate_jsons_by_patient.ipynb
file_path:  stomach_json_test.json code/stomach_json_test.json testing/stomach_json_test.json
file_path:  stomach_json_train.json code/stomach_json_train.json testing/stomach_json_train.json
file_path:  stomach_json.json code/stomach_json.json tes