# Running KimiaNet

Moving the `KimiaNet_Keras_Feature...` script to a notebook to make my life easier!

## Setup

First, we set up some configuration variables and define the helper functions (most of this comes from the original script).

In [6]:
# dataset identifier; used to build the dataset-specific input and output paths in this notebook
prefix = "combo"

In [7]:
# config variables

# inputs: folder holding the exported patches, the network weights file, and the patch file extension
patch_dir = f"./patches/{prefix}/"
network_weights_address = "./weights/KimiaNetKerasWeights.h5"
img_format = 'jpg'

# output: pickle file that receives the extracted features
extracted_features_save_adr = f"./extracted_features_{prefix}.pickle"

# model / runtime settings
network_input_patch_width = 250   # side length of the square input the network is built for
batch_size = 30                   # number of patches sent through the network per step
use_gpu = False                   # controls the CUDA environment variables set in the import cell


In [8]:
# importing libraries
import os

# Device selection via environment variables; this runs before TensorFlow is imported below.
if not use_gpu:
    # CPU run: no CUDA device is exposed
    os.environ.update({
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "-1",
    })
else:
    # GPU run: expose device 0 only
    os.environ.update({
        'NVIDIA_VISIBLE_DEVICES': '0',
        'CUDA_VISIBLE_DEVICES': '0',
    })


import tensorflow as tf
from tensorflow.keras.applications import DenseNet121
from tensorflow.keras import Model, Sequential
from tensorflow.keras.layers import GlobalAveragePooling2D, Lambda
# from tensorflow.keras.applications.densenet import preprocess_input
from tensorflow.keras.backend import bias_add, constant    

import glob, pickle, skimage.io, pathlib
import numpy as np
import pandas as pd
from tqdm import tqdm

In [4]:
# defining functions
# feature extractor preprocessing function
def preprocessing_fn(input_batch, network_input_patch_width):
    """Scale, optionally resize, and normalize a batch of image patches.

    Args:
        input_batch: Batch of image patches; axis 1 is compared with the target width and
            the last axis is treated as the channel axis ("channels_last").
        network_input_patch_width: Side length the patches are resized to when axis 1 of
            the batch has a different size.

    Returns:
        The batch divided by 255, resized if needed, with the per-channel mean subtracted
        and divided by the per-channel standard deviation.
    """
    # Per-channel statistics; the original notes that this normalization is equal to
    # tf.keras.applications.densenet.preprocess_input().
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    incoming_width = tf.shape(input_batch)[1]

    # bring the pixel values into the [0, 1] range
    unit_range_batch = tf.cast(input_batch, 'float') / 255.

    def _keep_size():
        return unit_range_batch

    def _resize_to_network_width():
        target_size = (network_input_patch_width, network_input_patch_width)
        return tf.image.resize(unit_range_batch, target_size)

    # resize only when the incoming patches do not already have the expected width
    sized_batch = tf.cond(tf.equal(incoming_width, network_input_patch_width),
                          _keep_size,
                          _resize_to_network_width)

    # subtract the mean by adding its negative as a per-channel bias, then divide by std
    negative_mean = constant(-np.array(channel_mean))
    centred_batch = bias_add(sized_batch, negative_mean, "channels_last")
    normalized_batch = centred_batch / channel_std

    return normalized_batch

# feature extractor initialization function
def kimianet_feature_extractor(network_input_patch_width, weights_address):
    """Build the feature-extraction model: preprocessing followed by a DenseNet121 backbone.

    Args:
        network_input_patch_width: Side length of the square input the backbone is built for.
        weights_address: Value passed as ``weights`` to ``DenseNet121``.

    Returns:
        A ``Sequential`` model that takes uint8 batches of shape (None, None, 3) per sample,
        applies ``preprocessing_fn``, and outputs the globally average-pooled activations of
        the backbone's third-from-last layer.
    """
    side = network_input_patch_width

    # backbone without the classification head
    backbone = DenseNet121(include_top=False, weights=weights_address,
                           input_shape=(side, side, 3), pooling='avg')

    # features are tapped from the third-from-last backbone layer and average-pooled
    tapped_output = backbone.layers[-3].output
    pooled_features = GlobalAveragePooling2D()(tapped_output)
    backbone_to_features = Model(inputs=backbone.input, outputs=pooled_features)

    # preprocessing runs inside the model, so raw uint8 patches can be fed directly
    preprocessing_layer = Lambda(preprocessing_fn,
                                 arguments={'network_input_patch_width': side},
                                 input_shape=(None, None, 3), dtype=tf.uint8)

    pipeline = Sequential([preprocessing_layer])
    pipeline.add(backbone_to_features)

    return pipeline

# feature extraction function
def extract_features(patch_dir, extracted_features_save_adr, network_weights_address, 
                     network_input_patch_width, batch_size, img_format):
    """Extract features for every patch image in a folder and pickle them.

    Args:
        patch_dir: Folder containing the patch images. It is concatenated directly with the
            glob pattern, so it must end with a path separator.
        extracted_features_save_adr: Path of the pickle file to write.
        network_weights_address: Weights file forwarded to ``kimianet_feature_extractor``.
        network_input_patch_width: Patch width forwarded to ``kimianet_feature_extractor``.
        batch_size: Number of patches read and passed to ``predict`` per step.
        img_format: File extension of the patch images, without the dot.

    Returns:
        None. The output is a pickled dict mapping each patch file stem to its feature
        vector. If no patch files are found, a warning is issued and no file is written.
    """
    import warnings

    # glob returns matches in arbitrary order; sorting makes the order of the saved
    # features (and of anything derived from them) reproducible between runs.
    patch_adr_list = sorted(pathlib.Path(x) for x in glob.glob(patch_dir+'*.'+img_format))

    if not patch_adr_list:
        # Nothing to do; say so instead of finishing silently, and skip the model build.
        warnings.warn("No '*." + img_format + "' files found in '" + patch_dir
                      + "'; no features were extracted and '" + extracted_features_save_adr
                      + "' was not written.")
        return

    feature_extractor = kimianet_feature_extractor(network_input_patch_width, network_weights_address)
    feature_dict = {}

    for batch_st_ind in tqdm(range(0, len(patch_adr_list), batch_size)):
        batch_end_ind = min(batch_st_ind+batch_size, len(patch_adr_list))
        batch_patch_adr_list = patch_adr_list[batch_st_ind:batch_end_ind]

        # assumes all patches in a batch have the same shape so they stack into one array
        patch_batch = np.array([skimage.io.imread(x) for x in batch_patch_adr_list])
        batch_features = feature_extractor.predict(patch_batch)
        feature_dict.update(dict(zip([x.stem for x in batch_patch_adr_list], list(batch_features))))

        # The full dict is re-written after every batch, so the file on disk always holds
        # the features extracted so far.
        with open(extracted_features_save_adr, 'wb') as output_file:
            pickle.dump(feature_dict, output_file, pickle.HIGHEST_PROTOCOL)

## Exporting patches to use for KimiaNet input
First, we read in the image we want to segment. We also read in the count information output from the Seurat file, so that we can exclude patches that don't contain any transcript information.

In [5]:
from PIL import Image

In [39]:
# load the PNG that belongs to the current dataset prefix
fname = f"./pngs/{prefix}.png"
full_image = skimage.io.imread(fname)

In [40]:
# load the tab-separated count table for the current dataset prefix
count_data = pd.read_csv(f"./CountsForSquares_{prefix}.txt", sep="\t")

# show the table dimensions
count_data.shape

(529, 36601)

In [41]:
# keep only the rows (patches) whose counts do not sum to zero
# (drops white space and areas not covered by ST)
has_transcripts = count_data.sum(axis=1).ne(0)
nonzero = count_data.loc[has_transcripts]

# show the dimensions of the filtered table
nonzero.shape

(296, 36601)

In [42]:
# acquire patch names
# (the row index of the filtered count table; the patch-export loop compares the patch
# names it generates against these values)
nonzero_patches = nonzero.index.values

In [43]:
output_patch_size = 250

# make sure the destination folder exists before any patch is saved into it
os.makedirs(patch_dir, exist_ok=True)

# a set gives constant-time membership tests inside the nested loop
patches_with_counts = set(nonzero_patches)

# image arrays are indexed (row, column): axis 0 is the height (y), axis 1 is the width (x)
image_height, image_width = full_image.shape[:2]

# iterate over the image's x/y coordinates
for x in range(0, image_width, output_patch_size):
    for y in range(0, image_height, output_patch_size):
        x_end = x + output_patch_size
        y_end = y + output_patch_size

        # skip patches that would run past the image; a patch that ends exactly on the
        # border is complete and is kept
        if x_end > image_width or y_end > image_height:
            continue

        # formulate the patch name based on x/y coords
        patch_name = '(' + str(x) + ',' + str(y) + ')_(' + str(x_end) + ',' + str(y_end) + ')'

        # we only want to output the patch if it has count info!
        if patch_name not in patches_with_counts:
            continue

        # written where extract_features looks for its input (patch_dir, img_format)
        output_name = patch_dir + 'patch_' + patch_name + '.' + img_format

        # do not regenerate patches that were already exported
        if os.path.exists(output_name):
            continue

        # rows are selected with y and columns with x, matching the loop ranges above
        # NOTE(review): assumes the patch names in the count table use x = column and
        # y = row - confirm against how CountsForSquares was generated
        patch = full_image[y:y_end, x:x_end]

        Image.fromarray(patch).convert('RGB').save(output_name)



## Feature Extraction
Feature extraction is super simple - only one command!

In [9]:
# run the extraction with the settings from the config cell
extract_features(
    patch_dir=patch_dir,
    extracted_features_save_adr=extracted_features_save_adr,
    network_weights_address=network_weights_address,
    network_input_patch_width=network_input_patch_width,
    batch_size=batch_size,
    img_format=img_format,
)

100%|██████████| 20/20 [00:53<00:00,  2.69s/it]


## Convert feature info to a more readable file format
Pickle is a Python-specific format, and the downstream analysis is done in R, so we convert the extracted features to a CSV file.

In [10]:
import os
import pickle

In [11]:
# Read back every object stored in the feature pickle. The path comes from the config
# cell so that this cell always reads the file extract_features was told to write.
# NOTE: pickle.load can execute arbitrary code; only load pickle files you created yourself.
objects = []
with open(extracted_features_save_adr, "rb") as openfile:
    while True:
        try:
            objects.append(pickle.load(openfile))
        except EOFError:
            # raised once the end of the file is reached: everything has been read
            break

In [12]:
# objects[0] is the first object read from the pickle file; extract_features writes a dict
# mapping patch file stems to feature vectors, so each patch becomes one column here
df = pd.DataFrame(objects[0])

In [13]:
# transpose the feature table, then write it out as CSV for use outside Python
df.T.to_csv(f"./extracted_features_{prefix}.csv")