In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os,sys,time

import torch
from model_MLP import *
from utils import *

## pick the compute device: the GPU when CUDA is usable, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device:", device)

device: cpu


In [2]:
## Run configuration and the partition of all genes into fixed-size groups
## (one model per group).
project = "BRCA"
## ik / il are substituted into data and model paths elsewhere in this notebook;
## presumably the outer / inner indices of the nested split -- TODO confirm.
ik = 4
il = 1

n_genes = 22884
n_genes_each = 512

## Ceiling division: when n_genes is an exact multiple of n_genes_each this must
## NOT add an extra, empty group (int(n_genes/n_genes_each) + 1 would).
n_models = -(-n_genes // n_genes_each)
print("n_models:", n_models)

## index of the first gene handled by each model
i_steps = np.arange(n_models) * n_genes_each
print("i_steps.shape:", i_steps.shape)
print(i_steps)

## number of genes predicted by each model; the last group takes whatever is left
n_outputs = np.full(n_models, n_genes_each)
if n_models > 0:
    n_outputs[-1] = n_genes - i_steps[-1]
print("n_outputs.shape:", n_outputs.shape)
print(n_outputs)

n_models: 45
i_steps.shape: (45,)
[    0   512  1024  1536  2048  2560  3072  3584  4096  4608  5120  5632
  6144  6656  7168  7680  8192  8704  9216  9728 10240 10752 11264 11776
 12288 12800 13312 13824 14336 14848 15360 15872 16384 16896 17408 17920
 18432 18944 19456 19968 20480 20992 21504 22016 22528]
n_outputs.shape: (45,)
[512 512 512 512 512 512 512 512 512 512 512 512 512 512 512 512 512 512
 512 512 512 512 512 512 512 512 512 512 512 512 512 512 512 512 512 512
 512 512 512 512 512 512 512 512 356]


In [3]:
## Load the per-slide feature records and find where the test slides sit among them.
features_path = "../14collect_features_nonrot/%s_features_all_AE_%s.npy"%(project,ik)
slide_features = np.load(features_path, allow_pickle=True)
n_slides = len(slide_features)

## slide ids of the test split selected by ik
split_path = "../16patients_split/i_slides_nested_%s.npz"%project
i_slides_train_valid_test = np.load(split_path, allow_pickle=True)
i_slides_test = i_slides_train_valid_test["i_slides_test"][ik]

## the slide id is the first field of every feature record
i_slides = np.array([record[0] for record in slide_features])

## position inside slide_features of each test slide id (first match;
## raises IndexError when a test id is absent from the feature file)
test_idx = np.array([np.flatnonzero(i_slides == slide_id)[0] for slide_id in i_slides_test])
print(test_idx.shape)
print(test_idx)

(206,)
[   0    2   11   28   31   35   43   45   50   54   55   64   70   79
   83   84   96  102  103  105  107  118  119  134  136  137  141  145
  148  151  154  159  163  165  170  181  183  184  185  186  187  189
  191  198  202  204  205  210  211  215  218  223  225  226  232  242
  244  251  255  256  261  282  283  284  285  296  307  321  326  330
  331  332  333  343  353  361  363  380  395  399  418  428  430  434
  435  438  440  441  445  447  450  453  454  456  464  469  482  488
  489  496  499  501  514  520  521  528  535  548  556  569  571  572
  575  596  599  605  612  618  626  628  633  636  638  652  668  674
  681  683  685  686  687  700  704  707  710  714  719  721  727  729
  733  734  736  738  743  752  755  757  763  765  767  769  772  785
  795  801  821  826  828  830  831  832  833  841  854  858  865  866
  870  872  873  874  878  881  885  890  897  906  909  916  921  927
  932  934  938  941  943  945  954  956  957  958  963  964  968  970

In [4]:
## predict tile score:
## For every test slide, run each per-gene-group model on the slide's tile features
## and concatenate the group predictions into one (n_tiles, n_genes) array.
n_inputs = 512
n_hiddens = 512
#n_outputs = 512
dropout=0.2

tile_scores = []

## Load every trained model once, up front: the weights do not depend on the slide,
## so reloading them inside the slide loop only repeats the same disk reads.
models = []
for iig, ig in enumerate(i_steps):
    model = MLP_regression(n_inputs, n_hiddens, n_outputs[iig], dropout, device, bias_init=None)
    model.to(device)

    ## load trained_model
    model.load_state_dict(torch.load("220Nov21BRCA_nonrot_ik%s_il%s/result_%s_%s_%s/model_trained.pth"%(ik,il,ik,il,ig),
                                     map_location=device))

    ## inference mode: modules start in training mode, which would keep any dropout
    ## inside layer0 / layer1 active and make the predictions stochastic
    model.eval()
    models.append(model)

## each slide
for i0 in test_idx:
    print(" ")

    i_slide = slide_features[i0][0]
    print("i_slide:", i_slide)

    slide_name = slide_features[i0][1]
    tile_features = slide_features[i0][2]

    x0 = tile_features[np.newaxis].transpose(1,2,0)  ## n_tiles, n_features, 1
    n_tiles = x0.shape[0]

    ## build the input once per slide, on the same device as the models
    x_in = torch.tensor(x0, dtype=torch.float32, device=device)

    ## each model
    y_parts = []
    with torch.no_grad():  ## no gradients are needed for prediction
        for iig, model in enumerate(models):
            x = model.layer0(x_in)

            ## spatial transcriptome
            ## .cpu() makes the conversion work when device is CUDA; the explicit reshape
            ## (instead of a bare squeeze) keeps the tile axis even for a single-tile slide
            y = model.layer1(x).cpu().numpy().reshape(n_tiles, int(n_outputs[iig]))
            y_parts.append(y)

    ## join the gene groups in one go instead of growing the array model by model
    y_all = np.hstack(y_parts)

    ## combine every slide:
    tile_scores.append((i_slide, slide_name, y_all))

 
i_slide: 0
 
i_slide: 2
 
i_slide: 11
 
i_slide: 28
 
i_slide: 31
 
i_slide: 35
 
i_slide: 43
 
i_slide: 45
 
i_slide: 50
 
i_slide: 54
 
i_slide: 55
 
i_slide: 65
 
i_slide: 72
 
i_slide: 82
 
i_slide: 87
 
i_slide: 88
 
i_slide: 101
 
i_slide: 108
 
i_slide: 109
 
i_slide: 111
 
i_slide: 113
 
i_slide: 128
 
i_slide: 129
 
i_slide: 144
 
i_slide: 146
 
i_slide: 147
 
i_slide: 151
 
i_slide: 155
 
i_slide: 158
 
i_slide: 162
 
i_slide: 165
 
i_slide: 170
 
i_slide: 174
 
i_slide: 176
 
i_slide: 181
 
i_slide: 194
 
i_slide: 196
 
i_slide: 197
 
i_slide: 198
 
i_slide: 199
 
i_slide: 200
 
i_slide: 202
 
i_slide: 204
 
i_slide: 211
 
i_slide: 215
 
i_slide: 217
 
i_slide: 218
 
i_slide: 223
 
i_slide: 224
 
i_slide: 228
 
i_slide: 231
 
i_slide: 236
 
i_slide: 238
 
i_slide: 239
 
i_slide: 245
 
i_slide: 255
 
i_slide: 257
 
i_slide: 264
 
i_slide: 268
 
i_slide: 269
 
i_slide: 274
 
i_slide: 295
 
i_slide: 296
 
i_slide: 297
 
i_slide: 298
 
i_slide: 309
 
i_slide: 320
 
i_slide: 33

In [5]:
## Save all records as an (n_slides, 3) object array: columns are
## (i_slide, slide_name, scores). Each record mixes scalars with a 2-D array, and
## recent NumPy releases raise ValueError when asked to build such an inhomogeneous
## array implicitly, so the object array is filled explicitly here.
tile_scores_packed = np.empty((len(tile_scores), 3), dtype=object)
for i_row, record in enumerate(tile_scores):
    for i_col in range(3):
        ## element-wise assignment stores each object as-is (no copy of the scores)
        tile_scores_packed[i_row, i_col] = record[i_col]

np.save("tile_scores_%s_ik%s_il%s.npy"%(project,ik,il), tile_scores_packed)

In [6]:
## first stored record: a tuple (i_slide, slide_name, y_all)
tile_scores[0]

(0,
 'BRCA_00000_0001a1fb-f388-41c6-bfe9-ecbb10429e37',
 array([[8.07952   , 8.024243  , 8.191448  , ..., 1.9282755 , 1.4302287 ,
         1.414066  ],
        [8.088963  , 8.199239  , 8.054601  , ..., 1.7036307 , 1.5774993 ,
         1.6212946 ],
        [8.187705  , 8.088943  , 8.20764   , ..., 2.039182  , 1.620887  ,
         1.7089056 ],
        ...,
        [7.961115  , 7.731134  , 7.909792  , ..., 2.4098053 , 1.0067791 ,
         1.2671643 ],
        [8.153152  , 7.9447    , 8.16727   , ..., 1.7984827 , 1.9147347 ,
         1.2591765 ],
        [8.547988  , 8.197728  , 8.122331  , ..., 2.0762725 , 1.8689097 ,
         0.97550946]], dtype=float32))

In [7]:
## slide id (i_slide) of the first record
tile_scores[0][0]

0

In [8]:
## slide name of the first record
tile_scores[0][1]

'BRCA_00000_0001a1fb-f388-41c6-bfe9-ecbb10429e37'

In [9]:
## shape of the first record's score array; the second dimension should equal n_genes
## (first dimension is presumably the number of tiles -- TODO confirm)
tile_scores[0][2].shape

(4435, 22884)

In [10]:
## number of records collected: one is appended per entry of test_idx
len(tile_scores)

206