## Bootstrapping to determine the necessary set

1. Find a slide that is predicted correctly under our model.

2. Process the slide with random subsets of tiles. Keep N << K if K is the total possible tiles.

3. Track the attention for each subset in-place, i.e. indexed w.r.t. the global set of tiles.

    - We get back a sparse cube of attentions
    
    - Also track the predictions made for each subset
    
4. Take the set of attention images that yield a correct prediction

5. Intersect the tiles in the correctly predicted set. The intersection is the set of tiles consistently necessary to give a correct (or incorrect) result.

In [1]:
import sys
print(sys.executable)

import tensorflow as tf

from svs_reader import Slide, reinhard
from attimg import draw_attention
import numpy as np
import shutil
import cv2
import os

from milk.eager import MilkEager
from milk.encoder_config import get_encoder_args

tf.enable_eager_execution()

/home/ing/miniconda3/envs/milk/bin/python


In [2]:
%matplotlib inline
import matplotlib.pyplot as plt

In [3]:
def subset_indices(indices, n=100):
    subset = np.random.choice(indices, n, replace=False)
    return subset

In [4]:
def process_slide(svs, model, n=100):
    idx_subset = subset_indices(np.arange(len(svs.tile_list)), n=n)

    batches = 0
    zs = []
    indices = []
    for k,idx_ in enumerate(idx_subset):
        coords = svs.tile_list[idx_]
        img = svs._read_tile(coords)
        img = tf.constant(np.expand_dims(img, 0))
        z = model.encode_bag(img, return_z=True)
        zs.append(z)
        indices.append(idx_)

    zs = tf.concat(zs, axis=0)
    indices = np.array(indices)
    z_att, att = model.mil_attention(zs, verbose=False, return_att=True)
    att = np.squeeze(att)

    yhat = model.apply_classifier(z_att, verbose=False)
    print('yhat:', yhat)

    return yhat, att, indices

In [5]:
snapshot = '../../experiment/wide_model_pretrained/save/2019_03_28_19_01_12.h5'
encoder_args = get_encoder_args('wide')
model = MilkEager(encoder_args = encoder_args,
                  mil_type = 'attention',
                  batch_size = 32,
                  temperature = 0.5,
                  deep_classifier = True)
xpl = np.zeros((1, 1, 96, 96, 3), dtype=np.float32)
yhat = model(tf.constant(xpl), verbose=True)

model.load_weights(snapshot)

Instantiating a DenseNet with settings:
	depth_of_model           : 48
	growth_rate              : 64
	num_of_blocks            : 4
	num_layers_in_each_block : 12
	dropout_rate             : 0.3
	mcdropout                : False
	pool_initial             : True
Setting up classifier normalizing layers
(1, 1, 96, 96, 3)
Encoder Call:
n_x:  1
x_bag: (1, 1, 96, 96, 3)
	 z:  (1, 1456)
	z bag: (1, 1456)
attention: (1, 256)
attention: (1, 1)
attention: (1, 1)
features - attention: (1, 1456)
z: (1, 1456)
z_batch:  (1, 1456)
Classifier layer 0
Classifier layer 1
Classifier layer 2
Classifier layer 3
Classifier layer 4
returning (1, 2)


In [6]:
# slide_src = '/mnt/linux-data/va-pnbx/0a992a117147f8103c4be76c7b2b5155.svs'
slide_src = '/mnt/linux-data/va-pnbx/0a4441e55db6987df0844b3df8c08551.svs'
os.path.exists(slide_src)
ramdisk_path = '/dev/shm/tmp_svs.svs'
shutil.copyfile(slide_src, ramdisk_path)

# fgimg_path = '../../usable_area/inference/0a992a117147f8103c4be76c7b2b5155_fg.png'
fgimg_path = '../../usable_area/inference/0a4441e55db6987df0844b3df8c08551_fg.png'
fgimg = cv2.imread(fgimg_path, 0)

svs = Slide(slide_path = ramdisk_path,
           background_speed = 'image',
           background_image = fgimg,
           preprocess_fn = lambda x: (reinhard(x)/255.).astype(np.float32),
           process_mag = 5,
           process_size = 96,
           oversample_factor = 2,
           verbose = False)

Checking tile read function
Passed read check


In [87]:
yhats, atts, indices = [], [], [] 
for k in range(250):
    yhat, att, idx = process_slide(svs, model, n=50)
    yhats.append(yhat)
    atts.append(att)
    indices.append(idx)

yhat: tf.Tensor([[0.17028718 0.8297128 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.23250484 0.7674951 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.28131187 0.7186881 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.24541694 0.75458306]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.19210753 0.8078925 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.21157071 0.78842926]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.1578838  0.84211624]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.1316773  0.86832273]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.26447752 0.73552245]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.1295769  0.87042314]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.29458883 0.7054112 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.23809469 0.7619053 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.10467838 0.89532155]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.28197914 0.71802086]], shape=(1, 2), dtype=f

yhat: tf.Tensor([[0.16939336 0.83060664]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.15621272 0.8437873 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.46549246 0.5345076 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.20434041 0.7956596 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.26022878 0.73977125]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.0697713 0.9302287]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.11939174 0.88060826]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.21955636 0.78044367]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.24522904 0.75477093]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.16591577 0.8340843 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.16589236 0.83410764]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.20401679 0.79598325]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.22418794 0.77581203]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.18792327 0.81207675]], shape=(1, 2), dtype=flo

yhat: tf.Tensor([[0.088067 0.911933]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.15784788 0.8421521 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.23546879 0.76453125]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.14299963 0.85700035]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.18699539 0.81300455]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.20724873 0.7927512 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.4138466 0.5861534]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.14546458 0.85453546]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.23039101 0.769609  ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.2470204  0.75297964]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.17349468 0.8265053 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.36237586 0.6376242 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.16137329 0.8386267 ]], shape=(1, 2), dtype=float32)
yhat: tf.Tensor([[0.18257123 0.8174288 ]], shape=(1, 2), dtype=float32

In [94]:
# attvects = np.zeros(len(svs.tile_list))
attvects = []
timespicked = []
for k in range(250):
    if yhats[k].numpy()[0,1] > 0.85:
        attvect = np.zeros(len(svs.tile_list))
        attvect[indices[k]] = atts[k]
        attvect[attvect == 0] = np.nan
        attvects.append(attvect)
        timespicked.append(np.logical_not(np.isnan(attvect)))

In [95]:
print(len(timespicked))
total_timespicked = np.sum(timespicked, axis=0)
print(total_timespicked.min(), total_timespicked.max())

55
0 5


In [97]:
import cv2
attmean = np.nanmean(attvects, axis=0)
attmean[np.isnan(attmean)] = 0
cutoff = np.quantile(attmean[attmean > 0], 0.95)
for idx in np.argwhere(attmean > cutoff):
    img = svs._read_tile(svs.tile_list[idx[0]])
    cv2.imwrite('high-att/{:3.3f}_{}_{}.jpg'.format(attmean[idx[0]],idx[0],total_timespicked[idx][0]), 
                img[:,:,::-1] * 255)

  


In [91]:
attimgs = []
for k in range(50):
    if yhats[k].numpy()[0,1] > 0.9:
        output_name = 'aggr{}'.format(k)
        svs.initialize_output(name=output_name, dim=1, mode='tile')
        svs.place_batch(atts[k], indices[k], output_name, mode='tile')
        attimg = np.squeeze(svs.output_imgs[output_name])
        attimgs.append(attimg)

In [92]:
meanatt = np.sum(attimgs, axis=0)
print(meanatt.shape)

plt.figure(dpi=300)
plt.matshow(meanatt)
plt.colorbar()

()


ValueError: not enough values to unpack (expected 2, got 0)

<Figure size 1800x1200 with 0 Axes>