In [13]:
import os
import numpy as np

from skimage.transform import resize
from scipy.ndimage import rotate
from skimage.color import gray2rgb
import torch
import tensorflow_datasets as tfds
from segment_anything import sam_model_registry
from tqdm import tqdm
import pandas as pd
import h5py
import argparse

from visualization_utils import (crop_image,
                                 extract_coords,
                                 extract_roi,
                                 visualize_features,
                                 hu_to_rgb_vectorized)

from tfds_dense_descriptor import (prepare_image, 
                                   load_model)

In [8]:
ls '../../../Data/PET-CT'


[0m[01;34mdata[0m/  [01;34mlung_radiomics[0m/


In [10]:
# --- Run configuration -------------------------------------------------------
# Backbone identifier and checkpoint file handed to `load_model` further down.
model_name = "medsam"
model_path = "../../PET-CT/medsam_vit_b.pth"

# All dataset artefacts live in one folder: build that folder path once and
# derive the HDF5 / CSV locations from it instead of repeating the literal.
dataset_path = os.path.join('../../../Data/PET-CT', 'lung_radiomics')
ds_path = os.path.join(dataset_path, 'lung_radiomics_datasets.hdf5')
# (sic) the variable name is kept as-is because the metadata cell reads it.
df_metdata_path = os.path.join(dataset_path, 'lung_radiomics_datasets.csv')

# Modality paired with PET in the `modalities` list built later.
second_modality = 'ct'

# TFDS is only used when no HDF5 file is configured. `ds_path` is a string
# here, so this evaluates to False and the HDF5 code path is taken.
use_tfds = ds_path is None

# Forwarded unchanged to `get_voxels` in the extraction cell.
pet_liver = False

# Index of the CUDA device to make current.
gpu_device = 0
if torch.cuda.is_available():
    torch.cuda.set_device(int(gpu_device))
else:
    # Selecting a CUDA device on a machine without CUDA raises; report it
    # instead of aborting the whole configuration cell.
    print('Warning: CUDA is not available; skipping torch.cuda.set_device')

In [17]:
# Cohorts to process; 'santa_maria_dataset' is currently left out of the list.
dataset = ['stanford_dataset']

# PET always comes first; the second entry is taken from the config cell.
modalities = ['pet', second_modality]

# Build the feature-extraction model from the configured name and checkpoint.
model = load_model(model_name, model_path)

In [18]:
# Read the per-patient metadata table and derive a binary EGFR label:
# 1 when the `egfr` column equals 'Mutant', 0 for every other value.
# NOTE(review): any non-'Mutant' value (whatever it encodes) becomes 0 here —
# confirm that this is the intended labelling.
df_metadata = pd.read_csv(df_metdata_path)
df_metadata['label'] = df_metadata['egfr'].eq('Mutant').astype(int)

# patient_id -> label lookup, built on the full table before rows are dropped.
patient2label = dict(zip(df_metadata['patient_id'], df_metadata['label']))

# Keep only the rows whose availability flag matches the requested modalities:
# either PET flag when the second modality is PET, otherwise the combined
# `has_<modalities joined>` column. The index is then renumbered from zero.
df_metadata = df_metadata[
    np.logical_or(df_metadata['has_petct'], df_metadata['has_petchest'])
    if second_modality == 'pet'
    else df_metadata[f'has_{"".join(modalities)}']
].reset_index(drop=True)

In [None]:
# TFDS branch: executed only when `use_tfds` is truthy.
# NOTE(review): `dataset_name` is not assigned in any cell visible here (only
# the list `dataset` is) — confirm where it is expected to come from.
if use_tfds:
    # PET is loaded from '<dataset_name>/pet' for every cohort.
    ds_pet, info_pet = tfds.load(f'{dataset_name}/pet', data_dir=dataset_path, with_info=True)

    # The second config is 'ct' for the Stanford cohort, 'torax3d' otherwise.
    ds_ct, info_ct = tfds.load(
        f'{dataset_name}/ct' if dataset_name == 'stanford_dataset' else f'{dataset_name}/torax3d',
        data_dir=dataset_path,
        with_info=True,
    )

    # Only keys present in both loaded objects are kept as patient ids.
    patient_pet = set(ds_pet.keys())
    patient_ct = set(ds_ct.keys())
    patient_ids = list(patient_ct.intersection(patient_pet))

In [None]:
        # Extract and store per-slice features for every patient of `dataset_name`.
        #
        # NOTE(review): this cell is indented as if it were the body of an
        # enclosing loop/function that is not visible here. `dataset_name`,
        # `features_dir`, `feature_folder`, `preprocess_info` and the helpers
        # `tfds2voxels`, `get_voxels`, `apply_window_ct`, `flip_image`,
        # `rotate_image`, `generate_features` and `save_features` are not defined
        # in the visible cells — confirm they are in scope before running.

        # --- 1. Collect the patient ids to process -----------------------------
        if use_tfds:
            # PET is loaded from '<dataset_name>/pet' for every cohort; the second
            # config is 'ct' for the Stanford cohort and 'torax3d' otherwise.
            ds_pet, info_pet = tfds.load(f'{dataset_name}/pet', data_dir=dataset_path, with_info=True)
            ct_config = 'ct' if dataset_name == 'stanford_dataset' else 'torax3d'
            ds_ct, info_ct = tfds.load(f'{dataset_name}/{ct_config}', data_dir=dataset_path, with_info=True)

            # Only keys present in both loaded objects are kept.
            patient_pet = set(list(ds_pet.keys()))
            patient_ct = set(list(ds_ct.keys()))

            patient_ids = list(patient_ct.intersection(patient_pet))
        else:
            # The metadata table stores the dataset name without the '_dataset' suffix.
            dataset_name_sort = dataset_name.replace('_dataset', '')
            patient_ids = list(df_metadata[df_metadata['dataset'] == dataset_name_sort]['patient_id'].unique())

        # --- 2. Per-patient / per-modality extraction --------------------------
        for patient_id in tqdm(patient_ids, desc=dataset_name):
            for modality in [second_modality]:
                df_path = os.path.join(features_dir, f'{patient_id}_{modality}.parquet')
                features_file = os.path.join(feature_folder, f'features_masks_{modality}.hdf5')

                # The parquet file doubles as a "done" marker: patients that
                # already have one are skipped.
                if not os.path.exists(df_path):
                    if use_tfds:
                        # NOTE(review): nothing in this branch reaches the feature
                        # extraction / saving code below, which lives entirely in
                        # the `else:` branch. Confirm whether that is intentional.
                        if modality == 'pet':
                            img_raw, mask_raw, label, spatial_res = tfds2voxels(ds_pet, patient_id, pet=True)
                        else:
                            img_raw, mask_raw, label, spatial_res = tfds2voxels(ds_ct, patient_id)

                        label = label[0]
                        if label not in [0, 1]:  # ignore unknown (2) and not collected (3) labels
                            print(f'\nWarning: skip {patient_id} with label {label}')
                        else:
                            nodule_pixels = mask_raw.sum(axis=(0, 1)).round(2)
                            if not nodule_pixels.max():
                                print(f'\nWarning: {patient_id} has empty mask')

                            # normalize pixel values
                            if modality == 'ct':
                                # NOTE(review): width=800 here vs width=1800 in the
                                # HDF5 branch below — confirm which one is intended.
                                img_raw = apply_window_ct(img_raw, width=800, level=40)
                            else:
                                # NOTE(review): `img_mean` / `img_std` are not assigned
                                # in this branch in the visible code.
                                img_raw = (img_raw - img_mean)/ img_std
                    else:
                        label = patient2label[patient_id]
                        img_raw, mask_raw, _, spatial_res = get_voxels(ds_path, patient_id, modality,pet_liver)

                        # normalize pixel values

                        # TODO: document this step (original note: "to comment")
                        if modality == 'pet':
                            # Statistics are looked up by patient-id prefix: 'sm'
                            # when the id starts with it, else the first 3 characters.
                            if patient_id[:2]=="sm":
                                info_name="sm"
                            else:
                                info_name=patient_id[:3]
                            info_dataset=preprocess_info[preprocess_info["dataset"]==info_name]
                            # `.item()` requires exactly one matching row; calling
                            # float() directly on a Series is deprecated in pandas.
                            img_mean=float(info_dataset["mean"].item())
                            img_std=float(info_dataset["std"].item())
                            img_raw = (img_raw - img_mean)/ img_std
                        else:
                            img_raw = apply_window_ct(img_raw, width=1800, level=40)
                        #img_raw, mask_raw, spatial_res = get_voxels(ds_path, patient_id, modality)

                        # extract patch features of each slice
                        df = {'slice': [],
                              'angle': [],
                              'flip': []}

                        all_features = []
                        all_masks = []
                        # apply flip and rotation to use them as offline data augmentation
                        for flip_type in [None, 'horizontal', 'vertical']:
                            image_flip, mask_flip = flip_image(img_raw, mask_raw, flip_type)
                            for angle in [0, 90]:#range(0, 180, 45):
                                image, mask = rotate_image(image_flip, mask_flip, angle)
                                features, features_mask = generate_features(model=model,
                                                                            img_3d=image,
                                                                            mask_3d=mask,
                                                                            tqdm_text=f'{modality} {patient_id}',
                                                                            display=False)

                                all_masks += features_mask
                                all_features += features

                                # One metadata row per returned feature map.
                                df['angle'] += [angle] * len(features)
                                df['flip'] += [flip_type] * len(features)
                                df['slice'] += list(range(0, len(features)))

                        # store metadata of each featuremap in a dataframe
                        df = pd.DataFrame(df)
                        df.reset_index(drop=False, inplace=True)
                        df = df.rename(columns={'index': 'feature_id'})
                        df['patient_id'] = patient_id
                        df['label'] = label
                        df['dataset'] = dataset_name.replace('_dataset', '')
                        df['modality'] = modality
                        # A row is "augmented" unless it is the un-flipped, un-rotated
                        # original. The flip test must be element-wise: `df['flip'] is
                        # None` compares the Series object itself and is always False,
                        # which marked every row as augmented.
                        df['augmentation'] = ~(df['flip'].isna() & (df['angle'] == 0))
                        df['spatial_res'] = [spatial_res] * df.shape[0]
                        df.to_parquet(df_path)
                        save_features(features_file, all_features, all_masks, patient_id)