## Calculate features for nuclei and generate .pt files for each graph
## 15 features in total

In [None]:
###TCGA-27-2528-01Z-00-DX1.8160EE21-D3C3-4FFF-8BB6-979235C1F963_2
# val:ZT76_39_B_1_5

In [1]:
import re, os
import cv2
import math
import random
import torch
import resnet
import skimage.feature
import pdb
from PIL import Image
from pyflann import *
from torch_geometric.data import Data
from collections import OrderedDict

import networkx as nx
import numpy as np
import pandas as pd
import torchvision.transforms.functional as F
import torch_geometric.data as data
import torch_geometric.utils as utils
import pdb
import torch_geometric

In [2]:
from model import CPC_model
device = torch.device('cuda:{}'.format('0'))
model = CPC_model(1024, 256)
encoder = model.encoder.to(device)
ckpt_dir = './pretrained_models/cpc_best.pt'
ckpt = torch.load(ckpt_dir)
encoder.load_state_dict(ckpt['encoder_state_dict'])

<All keys matched successfully>

In [3]:
encoder

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): Bottleneck(
      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (relu): ReLU(inplace=True)
    )
    (2): Bottleneck(
      (co

In [5]:
x = torch.randn((1, 3, 64, 64)).cuda()

In [4]:
def from_networkx(G):
    r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a
    :class:`torch_geometric.data.Data` instance.
    Args:
        G (networkx.Graph or networkx.DiGraph): A networkx graph.
    """

    G = G.to_directed() if not nx.is_directed(G) else G
    edge_index = torch.tensor(list(G.edges)).t().contiguous()

    keys = []
    keys += list(list(G.nodes(data=True))[0][1].keys())
    keys += list(list(G.edges(data=True))[0][2].keys())
    data = {key: [] for key in keys}

    for _, feat_dict in G.nodes(data=True):
        for key, value in feat_dict.items():
            data[key].append(value)

    for _, _, feat_dict in G.edges(data=True):
        for key, value in feat_dict.items():
            data[key].append(value)

    for key, item in data.items():
        data[key] = torch.tensor(item)

    data['edge_index'] = edge_index
    data = torch_geometric.data.Data.from_dict(data)
    data.num_nodes = G.number_of_nodes()

    return data

In [5]:
def get_cell_image_og(img, cx, cy):
    if cx < 32 and cy < 32:
        return img[0: cy+32, 0:cx+32, :]
    elif cx < 32:
        return img[cy-32: cy+32, 0:cx+32, :] 
    elif cy < 32:
        return img[0: cy+32, cx-32:cx+32, :]
    else:
        return img[cy-32: cy+32, cx-32:cx+32, :]
    
def my_transform(img):
    img = F.to_tensor(img)
    img = F.normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    return img

def get_cpc_features_og(cell):
    
    cell_R = np.squeeze(cell[:, :, 2:3])
    cell_G = np.squeeze(cell[:, :, 1:2])
    cell_B = np.squeeze(cell[:, :, 0:1])

    cell_R = np.pad(cell_R, [(0, 64-cell_R.shape[0]), (0, 64-cell_R.shape[1])], mode = 'constant')
    cell_G = np.pad(cell_G, [(0, 64-cell_G.shape[0]), (0, 64-cell_G.shape[1])], mode = 'constant')
    cell_B = np.pad(cell_B, [(0, 64-cell_B.shape[0]), (0, 64-cell_B.shape[1])], mode = 'constant')
    cell = np.stack((cell_R, cell_B, cell_G))
    
    cell = np.transpose(cell, (1, 2, 0))
    cell = my_transform(cell)
    cell = cell.unsqueeze(0)
    
    device = torch.device('cuda:{}'.format('0'))
    
    feats = encoder(cell.to(device)).cpu().detach().numpy()
    return feats
    #feats_cpu = [f.cpu().detach().numpy() for f in feats]
    #return feats_cpu

In [6]:
from torchvision import transforms
import itertools

def get_cell_image(img, cx, cy, size=512):
    cx = 32 if cx < 32 else size-32 if cx > size-32 else cx
    cy = 32 if cy < 32 else size-32 if cy > size-32 else cy
    return img[cy-32:cy+32, cx-32:cx+32, :]

def get_cpc_features(cell):
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    cell = transform(cell)
    cell = cell.unsqueeze(0)
    device = torch.device('cuda:{}'.format('0'))
    feats = encoder(cell.to(device)).cpu().detach().numpy()[0]
    return feats

def get_cell_features(img, contour):
    
    # Get contour coordinates from contour
    (cx, cy), (short_axis, long_axis), angle = cv2.fitEllipse(contour)
    cx, cy = int(cx), int(cy)
    
    # Get a 64 x 64 center crop over each cell    
    img_cell = get_cell_image(img, cx, cy)

    grey_region = cv2.cvtColor(img_cell, cv2.COLOR_RGB2GRAY)
    img_cell_grey = np.pad(grey_region, [(0, 64-grey_region.shape[0]), (0, 64-grey_region.shape[1])], mode = 'reflect') 


    # 1. Generating contour features
    eccentricity = math.sqrt(1-(short_axis/long_axis)**2)
    convex_hull = cv2.convexHull(contour)
    area, hull_area = cv2.contourArea(contour), cv2.contourArea(convex_hull)
    solidity = float(area)/hull_area
    arc_length = cv2.arcLength(contour, True)
    roundness = (arc_length/(2*math.pi))/(math.sqrt(area/math.pi))
    
    # 2. Generating GLCM features
    out_matrix = skimage.feature.greycomatrix(img_cell_grey, [1], [0])
    dissimilarity = skimage.feature.greycoprops(out_matrix, 'dissimilarity')[0][0]
    homogeneity = skimage.feature.greycoprops(out_matrix, 'homogeneity')[0][0]
    energy = skimage.feature.greycoprops(out_matrix, 'energy')[0][0]
    ASM = skimage.feature.greycoprops(out_matrix, 'ASM')[0][0]
    
    # 3. Generating CPC features
    cpc_feats = get_cpc_features(img_cell)
    

    # Concatenate + Return all features
    x = [[short_axis, long_axis, angle, area, arc_length, eccentricity, roundness, solidity],
         [dissimilarity, homogeneity, energy, ASM], 
         cpc_feats]
    
    return np.array(list(itertools.chain(*x)), dtype=np.float64), cx, cy


def seg2graph(img, contours):
    G = nx.Graph()
    
    contours = [c for c in contours if c.shape[0] > 5]

    for v, contour in enumerate(contours):

        features, cx, cy = get_cell_features(img, contour)
        G.add_node(v, centroid = [cx, cy], x = features)

    if v < 5: return None
    return G

In [20]:
data_dir = "/media/hdd1/Jingwen/codes/staintools/"
save_dir = os.path.join(data_dir, 'KIRC_st_cpc_blue')
img_dir = os.path.join(data_dir, 'KIRC_st')
seg_dir =  os.path.join(data_dir,'KIRC_st_seg')

roi1 = 'TCGA-B0-4839-01Z-00-DX1.0c13b082-d7e6-4327-b5cc-ab6bd99b78aa_roi_2_x_74016_y_24064_99.388.png'
roi2 = 'TCGA-B0-4694-01Z-00-DX1.beeee877-b7f3-4110-84e9-2db8be5667f5_roi_1_x_86528_y_47936_95.186.png'
assert roi1 in os.listdir(seg_dir)
assert roi2 in os.listdir(seg_dir)

In [25]:
data_dir = "/media/hdd1/Jingwen/codes/staintools/"
save_dir = os.path.join(data_dir, 'all_st_cpc_img_blue')
img_dir = os.path.join(data_dir, 'all_st')
seg_dir =  os.path.join(data_dir,'all_st_seg')

roi1 = 'TCGA-06-0174-01Z-00-DX3.23b6e12e-dfc1-4c6f-903e-170038a0e055_1.png'
roi2 = 'TCGA-HT-7470-01Z-00-DX4.204D0CF2-A22E-4428-8E8C-572432B86280_1.png'
roi3 = 'TCGA-26-1442-01Z-00-DX1.FD8D4EB7-AD5E-49E8-BD0B-6CDDEA8DDB35_1.png'

assert roi1 in os.listdir(seg_dir)
assert roi2 in os.listdir(seg_dir)
assert roi3 in os.listdir(seg_dir)

In [14]:
from tqdm import tqdm

In [43]:
G = seg2graph(img, contours)

In [41]:
G

array([], shape=(64, 0, 3), dtype=uint8)

In [19]:
img_dir = os.path.join(data_dir, 'KIRC_st')
for img_fname in tqdm(os.listdir(seg_dir)):
    if int(img_fname.split('_')[2]) > 2: continue
    os.system('cp %s %s' % (os.path.join(data_dir, img_dir, img_fname), os.path.join(data_dir, 'KIRC_st_small', img_fname)))

100%|██████████| 11961/11961 [01:25<00:00, 140.49it/s]


In [26]:

pt_dir = os.path.join(save_dir, 'pt_bi')
graph_dir = os.path.join(save_dir, 'graphs')
fail_list = []

from tqdm import tqdm

for img_fname in tqdm([roi1, roi2, roi3]):
    
    #if int(img_fname.split('_')[2]) > 2: continue
    #print("Processing...(%d/%d):\t%s" % (idx+1, len(os.listdir(seg_dir)), img_fname))
    
    img = np.array(Image.open(os.path.join(img_dir, img_fname)))
    seg = np.array(Image.open(os.path.join(seg_dir, img_fname)))
    ret, binary = cv2.threshold(seg, 127, 255, cv2.THRESH_BINARY) 
    contours, hierarchy = cv2.findContours(binary, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) < 1: continue
    
    G = seg2graph(img, contours)

    if G is None: 
        fail_list.append(img_fname)
        continue


    centroids = []
    for u, attrib in G.nodes(data=True):
        centroids.append(attrib['centroid'])
    
    cell_centroids = np.array(centroids).astype(np.float64)
    dataset = cell_centroids
    
    start = None
            
    for idx, attrib in list(G.nodes(data=True)):
        start = idx
        flann = FLANN()
        testset = np.array([attrib['centroid']]).astype(np.float64)
        results, dists = flann.nn(dataset, testset, num_neighbors=5, algorithm = 'kmeans', branching = 32, iterations = 100, checks = 16)
        results, dists = results[0], dists[0]
        nns_fin = []
       # assert (results.shape[0] < 6)
        for i in range(1, len(results)):
            G.add_edge(idx, results[i], weight = dists[i])
            nns_fin.append(results[i])
        #attrib['nn'] = list(nns_fin)

    G = G.subgraph(max(nx.connected_components(G), key=len))

    #for idx, attrib in list(G.nodes(data=True)):
    #    cv2.circle(img, tuple(attrib['centroid']), 8, (0, 255, 0), -1)
    
    cv2.drawContours(img, contours, -1, (0,255,0), 2)
    
    for n, nbrs in G.adjacency():
        for nbr, eattr in nbrs.items():
            cv2.line(img, tuple(G.nodes[n]['centroid']),  tuple(G.nodes[nbr]['centroid']), (0, 0, 255), 2)

    Image.fromarray(img).save(os.path.join(graph_dir, img_fname))
    
    G = from_networkx(G)
    
    edge_attr_long = (G.weight.unsqueeze(1)).type(torch.LongTensor)
    G.edge_attr = edge_attr_long 
    
    edge_index_long = G['edge_index'].type(torch.LongTensor)
    G.edge_index = edge_index_long
    
    x_float = G['x'].type(torch.FloatTensor)
    G.x = x_float
    
    G['weight'] = None
    G['nn'] = None
    torch.save(G, os.path.join(pt_dir, img_fname[:-4]+'.pt'))

100%|██████████| 3/3 [00:13<00:00,  4.41s/it]


In [None]:

for img_fname in tqdm(os.listdir(seg_dir)):
    
    if int(img_fname.split('_')[2]) > 2: continue
 

In [16]:
x = torch.load('../staintools/KIRC_st_cpc/pt_bi/TCGA-3Z-A93Z-01Z-00-DX1.79F4D1A6-ACDB-4AB1-B8A8-C1CEE617C734_roi_0_x_70176_y_35424_99.659.pt')

In [17]:
x

Data(centroid=[102, 2], edge_attr=[502, 1], edge_index=[2, 502], x=[102, 1036])

In [262]:
list(nx.connected_components(G))

[{0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  28,
  29,
  30,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  38,
  39,
  40,
  41,
  42,
  43,
  44,
  45,
  46,
  47,
  48,
  49,
  50,
  51,
  52,
  53,
  54,
  55,
  56,
  57,
  58,
  59,
  60,
  61,
  62,
  63,
  64,
  65,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  73,
  74,
  75,
  76,
  77,
  78,
  79,
  80,
  81,
  82,
  83,
  84,
  85,
  86,
  87,
  88,
  89,
  90,
  91,
  92,
  93,
  94,
  95,
  96,
  97,
  98,
  99,
  100,
  101,
  102,
  103,
  104,
  105,
  106,
  107,
  108,
  109,
  110,
  111,
  112,
  113,
  114,
  115,
  116,
  117,
  118,
  119,
  120,
  121,
  122,
  123,
  124,
  125,
  126,
  127,
  128,
  129,
  130,
  131,
  132,
  133}]

In [264]:
G

<networkx.classes.graph.Graph at 0x7fb3167cdc90>

<networkx.classes.graph.Graph at 0x7fb2c0301950>

In [255]:
G

{0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133}

In [224]:
results

array([ 0, 17, 27, 13, 28], dtype=int32)

In [231]:
dataset.shape

(134, 2)

In [200]:
np.linalg.norm(results[3]-results[0])

13.0

In [205]:
dists

array([   0., 2548., 5777., 6698., 7001.])

In [190]:
G.node(1)

AttributeError: 'Data' object has no attribute 'node'

In [None]:
results[]

In [94]:
idx, attrib = list(G.nodes(data=True))[0]

In [98]:
testset = np.array([attrib['centroid']]).astype(np.float64)
testset

array([[ 17., 487.]])

In [99]:
results, dists = flann.nn(dataset, testset, num_neighbors=5, algorithm = 'kmeans', branching = 32, iterations = 7, checks = 16)

In [107]:
dists

array([[   0., 2548., 5777., 6698., 7001.]])

In [105]:
results

array([[ 0, 17, 27, 13, 28]], dtype=int32)

In [86]:
print(len(G.nodes[n]), len(G.nodes[nbr]))

0 2


In [87]:
(len(G.nodes[n]) > 0)

False

In [76]:
len(G.nodes[n]) > 0

False

In [66]:
G.nodes[nbr]

{'centroid': [353, 482],
 'x': array([ 11.95111752,  14.6584177 , 150.54612732, ...,   0.        ,
          0.        ,   0.        ])}