In [1]:
#!/usr/bin/env python
# coding: utf-8

#NOTE: use paimg9 env
import sys
import os
import numpy as np
import openslide
from fastai.vision.all import *
import PIL
matplotlib.use('Agg')
import pandas as pd
import warnings
sys.path.insert(0, '../Utils/')
from Preprocessing import preprocess_mutation_data, preprocess_site_data
from Utils import generate_deepzoom_tiles
from Utils import create_dir_if_not_exists
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from torchvision import transforms
import ResNet as ResNet
from torch.utils.data import Dataset
import time

# Per-channel (R, G, B) normalisation statistics; these are the widely used ImageNet values.
mean = (0.485, 0.456, 0.406)
std = (0.229, 0.224, 0.225)

# Evaluation-time transform applied to every tile before feature extraction:
# image -> float tensor (ToTensor), then per-channel (x - mean) / std.
trnsfrms_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])


class get_tile_representation(Dataset):
    """Dataset yielding (tile image, feature array) for every tile of one slide.

    Item ``idx`` reads the tile addressed by row ``idx`` of ``tile_info`` from a
    DeepZoom tile source, resizes it to ``SAVE_IMAGE_SIZE`` with Lanczos
    resampling, applies ``trnsfrms_val`` and runs it through ``pretrain_model``
    under ``torch.no_grad()``.

    Args:
        tile_info: pandas DataFrame with one row per tile. Reads the columns
            'TILE_XY_INDEXES' (a string such as "(x, y)"), 'MAG_EXTRACT' and
            'SAVE_IMAGE_SIZE'; the last two must hold exactly one distinct value.
        deepzoom_tiles: object exposing ``get_tile(level, (x, y))``, e.g. the
            tile generator returned by ``generate_deepzoom_tiles``.
        tile_levels: sequence of per-level magnifications; the level read is
            ``tile_levels.index(MAG_EXTRACT)``.
        pretrain_model: torch module used as the feature extractor. Each input
            is moved to the device of the model's parameters, so the model may
            live on CPU or GPU.

    Raises:
        ValueError: if 'MAG_EXTRACT' or 'SAVE_IMAGE_SIZE' does not have exactly
            one distinct value, or if MAG_EXTRACT is not in ``tile_levels``.
    """

    def __init__(self, tile_info, deepzoom_tiles, tile_levels, pretrain_model):
        super().__init__()
        self.transform = trnsfrms_val
        self.tile_info = tile_info
        self.deepzoom_tiles = deepzoom_tiles
        self.tile_levels = tile_levels
        self.mag_extract = self._single_value(tile_info, 'MAG_EXTRACT')
        self.save_image_size = self._single_value(tile_info, 'SAVE_IMAGE_SIZE')
        self.pretrain_model = pretrain_model
        # Loop-invariant: the DeepZoom level is the same for every tile.
        self.level_index = self.tile_levels.index(self.mag_extract)

    @staticmethod
    def _single_value(df, col):
        """Return the only distinct value of ``df[col]`` (as a Python scalar).

        Raises ValueError when the column has zero or several distinct values,
        instead of silently picking an arbitrary one.
        """
        values = set(df[col])
        if len(values) != 1:
            raise ValueError(
                f"Expected exactly one distinct value in column {col!r}, found {len(values)}: {values}"
            )
        return next(iter(values))

    def _model_device(self):
        """Device of the model's first parameter (CPU if the model has no parameters)."""
        first_param = next(self.pretrain_model.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def __len__(self):
        # One item per tile row.
        return len(self.tile_info)

    def __getitem__(self, idx):
        """Return ``(tile_image, features)`` for tile row ``idx``.

        ``tile_image`` is the resized tile. ``features`` is the model output
        converted to a numpy array, with a leading batch dimension of 1.
        """
        # 'TILE_XY_INDEXES' is stored as a string like "(x, y)"; splitting on ","
        # and letting int() strip whitespace accepts both "(3, 4)" and "(3,4)".
        raw_xy = self.tile_info['TILE_XY_INDEXES'].iloc[idx]
        x, y = (int(part) for part in raw_xy.strip("()").split(","))

        # Pull the tile at the extraction magnification and resize it.
        tile_pull = self.deepzoom_tiles.get_tile(self.level_index, (x, y))
        tile_pull = tile_pull.resize(size=(self.save_image_size, self.save_image_size), resample=PIL.Image.LANCZOS)

        # Transform, add the batch dimension, and put the input on the model's device.
        batch = self.transform(tile_pull).unsqueeze(0).to(self._model_device())

        # Inference only: eval mode (re-applied each call in case the model was
        # switched back to train mode elsewhere) and no autograd graph.
        self.pretrain_model.eval()
        with torch.no_grad():
            features = self.pretrain_model(batch).cpu().numpy()

        return tile_pull, features

##################
###### DIR  ######
##################
proj_dir = '/fh/scratch/delete90/etzioni_r/lucas_l/michael_project/mutation_pred/'
# proj_dir ends with '/', so sub-paths are appended without a leading slash; each
# sub-path also ends with '/' so file names can be concatenated directly.
wsi_path = proj_dir + 'data/OPX/'
label_path = proj_dir + 'data/MutationCalls/'
model_path = proj_dir + 'models/feature_extraction_models/'
tile_path = proj_dir + 'intermediate_data/cancer_prediction_results110224/IMSIZE250_OL0/'
ft_ids_path = proj_dir + 'intermediate_data/cd_finetune/cancer_detection_training/' #IDs used to fine-tune the cancer detection model; they must be excluded from the mutation study
pretrain_model_name = 'retccl' #also used as a suffix of the saved feature file name


##################
#Select GPU
##################
# CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')





############################################################################################################
#Select IDS
############################################################################################################

#All available IDs: one per "<ID>.tif" whole-slide image in wsi_path (other files are ignored)
opx_ids = sorted(os.path.splitext(f)[0] for f in os.listdir(wsi_path) if f.endswith('.tif')) #207

#Get IDs used to train the fine-tuned cancer detection model; they must be excluded from the mutation study
ft_ids_df = pd.read_csv(ft_ids_path + 'all_tumor_fraction_info.csv')
ft_train_ids = list(ft_ids_df.loc[ft_ids_df['Train_OR_Test'] == 'Train','sample_id'])

#OPX_182 – Exclude: possible colon adenocarcinoma
toexclude_ids = ft_train_ids + ['OPX_182']  #25


#Exclude ids in ft_train or flagged above (a set gives O(1) membership tests)
toexclude_set = set(toexclude_ids)
selected_ids = [x for x in opx_ids if x not in toexclude_set] #199

################################################
#Load mutation label data
################################################
# Read the mutation calls, pass them through preprocess_mutation_data
# (from ../Utils/Preprocessing) and keep only the selected slide IDs.
label_df = preprocess_mutation_data(pd.read_excel(label_path + "OPX_FH_original.xlsx"))
label_df = label_df[label_df['SAMPLE_ID'].isin(selected_ids)] #filter IDs


################################################
#Load Site data
################################################
# Same pattern for the anatomic-site sheet, via preprocess_site_data.
site_df = preprocess_site_data(pd.read_excel(label_path + "OPX_anatomic sites.xlsx"))
site_df = site_df[site_df['SAMPLE_ID'].isin(selected_ids)] #filter IDs


############################################################################################################
#Load tile info for selected_ids
############################################################################################################
# Each slide has tile_path/<ID>/<ID>_tiles.csv; stack them into one frame.
# ignore_index=True gives a fresh 0..N-1 index instead of repeating every file's own row numbers.
tile_info_list = [pd.read_csv(os.path.join(tile_path, cur_id, cur_id + "_tiles.csv")) for cur_id in selected_ids]
all_tile_info_df = pd.concat(tile_info_list, ignore_index=True)
print(len(set(all_tile_info_df['SAMPLE_ID']))) #199

print(all_tile_info_df.shape) #(1215198, 9) at last run

#Print stats
tile_counts = all_tile_info_df['SAMPLE_ID'].value_counts()
print("Max # tile/per pt:", tile_counts.max())
print("Min # tile/per pt:", tile_counts.min())
print("Median # tile/per pt:", tile_counts.median())


############################################################################################################
#Combine all info
############################################################################################################
# pandas merge defaults to an inner join: samples missing from label_df or site_df are dropped here.
all_comb_df = all_tile_info_df.merge(label_df, on = ['SAMPLE_ID'])
all_comb_df = all_comb_df.merge(site_df, on = ['SAMPLE_ID'])


def _single_value(df, col):
    """Return the only distinct value in ``df[col]``.

    The tiling parameters read below are applied to every slide when tiles are
    regenerated, so they must agree across all tiles; picking an arbitrary one
    when they differ would silently tile some slides with the wrong settings.

    Raises:
        ValueError: if ``df[col]`` has zero or more than one distinct value.
    """
    values = set(df[col])
    if len(values) != 1:
        raise ValueError(f"Expected exactly one distinct value in column {col!r}, found {len(values)}: {values}")
    return next(iter(values))


mag_extract = _single_value(all_comb_df, 'MAG_EXTRACT')
save_image_size = _single_value(all_comb_df, 'SAVE_IMAGE_SIZE')
pixel_overlap = _single_value(all_comb_df, 'PIXEL_OVERLAP')
limit_bounds = _single_value(all_comb_df, 'LIMIT_BOUNDS')


############################################################################################################
# Load Pretrained representation model
############################################################################################################
# Checkpoint weights, mapped onto the selected device while loading.
pretext_model = torch.load(model_path + 'best_ckpt.pth', map_location=device)

# ResNet-50 from the local ResNet module. Its fc head is replaced by an identity
# before loading (strict=True), so the forward pass returns whatever feeds the
# head (presumably the pooled backbone features — confirm in ResNet.py).
# NOTE(review): the model is not moved to `device` here, so it stays on CPU.
model = ResNet.resnet50(num_classes=128, mlp=False, two_branch=False, normlinear=True)
model.fc = nn.Identity()
model.load_state_dict(pretext_model, strict=True)


# NOTE(review): debug override. It replaces the filtered ID list built above, so only
# OPX_005 is processed; remove this line to extract features for every selected slide.
selected_ids= ['OPX_005']
############################################################################################################
#For each patient tile, get representation
############################################################################################################
for ct, cur_id in enumerate(selected_ids):
    print(cur_id)

    if ct % 10 == 0:
        print(ct)

    #Get tile info for this slide; skip slides with no tiles (np.concatenate fails on an empty list)
    comb_df = all_comb_df.loc[all_comb_df['SAMPLE_ID'] == cur_id]
    if comb_df.empty:
        print(f"No tiles found for {cur_id}; skipping")
        continue

    #Load slide; closed in `finally` so handles are not leaked across iterations
    oslide = openslide.OpenSlide(wsi_path + cur_id + ".tif")
    try:
        #Generate tiles
        tiles, tile_lvls, physSize, base_mag = generate_deepzoom_tiles(oslide, save_image_size, pixel_overlap, limit_bounds)

        #One feature row per tile, in the same row order as comb_df
        tile_img = get_tile_representation(comb_df, tiles, tile_lvls, model)
        start_time = time.time()
        feature_list = [tile_img[i][1] for i in range(comb_df.shape[0])]
        print("--- %s seconds ---" % (time.time() - start_time))
    finally:
        oslide.close()

    feature_df = pd.DataFrame(np.concatenate(feature_list))

    #Save features and the matching tile metadata to one HDF5 file (features first, mode='w' overwrites)
    save_location = tile_path + cur_id + '/' + 'features/'
    create_dir_if_not_exists(save_location)
    save_name = save_location + 'features_alltiles_nonoverlap' + pretrain_model_name + '.h5'
    feature_df.to_hdf(save_name, key='feature', mode='w')
    comb_df.to_hdf(save_name, key='tile_info', mode='a')



199
(1215198, 9)
Max # tile/per pt: 34689
Min # tile/per pt: 43
Median # tile/per pt: 1569.0
OPX_005
0
--- 150.22415018081665 seconds ---
Directory '/fh/scratch/delete90/etzioni_r/lucas_l/michael_project/mutation_pred/intermediate_data/cancer_prediction_results110224/IMSIZE250_OL0/OPX_005/features/' created.
