In [1]:
import itertools
import os
import sys
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage.io

from collections import defaultdict
from tqdm.auto import tqdm
from joblib import Parallel, delayed
import re
import h5py
import napari
import glob
import natsort
import tifffile as tiff


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
data_dir = (Path().cwd().parents[0] / 'data').absolute()
data_raw = r'Y:\coskun-lab\Shuangyi\ERK, YAP project_2022\PLA\HCC827 cell culture'


# Extract count 

In [4]:
import pickle 
from sklearn.neighbors import NearestNeighbors
import networkx as nx
import scipy 

def read_PPI(path):
    with open(path, 'rb') as file:
        PPI_dict = pickle.load(file)

    return PPI_dict

def get_NN_radius(data, r):
    fit = NearestNeighbors(radius=r).fit(data)
    m = fit.radius_neighbors(data, return_distance=True, sort_results=True)

    # Put in dataframe format
    neighbours = pd.DataFrame(m[1].tolist(), index = data.index)
    
    A = fit.radius_neighbors_graph(data)
    return neighbours, A

def plot_spot_on_image(spots, spacing, radius):
    spot_img = np.zeros(spots.max(0))
    coords = (spots[:, :3] / spacing).astype(int)
    r = radius  # shorthand
    for coord in coords:
        slc = tuple(slice(x-r, x+r) if i != 0 else slice(x, x+1) for i, x in enumerate(coord) )
        spot_img[slc] = 1
    return spot_img

def plot_label_on_image(spots, spacing, radius, labels):
    spot_img = np.zeros(spots.max(0))
    coords = (spots[:, :3] / spacing).astype(int)
    r = radius  # shorthand
    for i, coord in enumerate(coords):
        slc = tuple(slice(x-r, x+r) if i != 0 else slice(x, x+1) for i, x in enumerate(coord) )
        spot_img[slc] = labels[i]
    return spot_img

def create_PPI_df(df_loc_filtered, mask_cyto, mask_nuclei, name):
    _, y_max, x_max = mask_cyto.shape
    z, y, x = df_loc_filtered['z'].to_numpy(), df_loc_filtered['y'].to_numpy(), df_loc_filtered['x'].to_numpy()
    y = np.clip(y, a_min=0, a_max=y_max-1)
    x = np.clip(x, a_min=0, a_max=x_max-1)

    labels = mask_cyto[z, y, x]
    labels_nuclei = mask_nuclei[z, y, x].astype(np.uint8)
    
    df_per_cell = pd.DataFrame({'Cell': labels, 'Nuclei': labels_nuclei, 'z': z, 'y':y, 'x':x})
    df_per_cell['PPI'] = name
    return df_per_cell

In [5]:
data_dir = (Path().cwd().parents[0] / 'data').absolute()

df_meta_path = data_dir / 'OCT Cell Culture' / '3D_Whole' / 'metadata' / 'imgs_sti.csv'
df_imgs = pd.read_csv(df_meta_path)

PPI_save_path =  data_dir / 'OCT Cell Culture' / '3D_Whole'  / 'PPI'

In [6]:
mask_filt_dir = data_dir / 'OCT Cell Culture' / '3D_Whole' / 'imgs' / 'masks_3D_filtered'

masks_path = defaultdict(dict) 
for path in os.listdir(mask_filt_dir):
    name = path.split('.')[0]
    if 'Nuclei' in name:
        masks_path[name[7:]]['nuclei'] = mask_filt_dir / path
    elif 'Cyto' in name:
        masks_path[name[5:]]['cyto'] =mask_filt_dir / path
    elif 'Cell' in name:
        masks_path[name[5:]]['cell'] =mask_filt_dir / path    
    elif 'df' in name:
        masks_path[name[3:]]['df'] =mask_filt_dir / path
    else:
        pass

In [10]:
PPI_save_path =  data_dir / 'OCT Cell Culture' / '3D_Whole' / 'PPI'
group = df_imgs.groupby(['Timepoint', 'FOV'])

for names, df_group in group:
    name = '_'.join(names)
    # Read PPi
    PPI_dict = read_PPI(PPI_save_path / f'{name}.pkl')
    
    # Read mask
    mask_cyto_path = masks_path[name]['cell']
    mask_nuclei_path = masks_path[name]['nuclei']
    df_path =  masks_path[name]['df']
    
    mask_cyto = skimage.io.imread(mask_cyto_path)
    mask_nuclei = skimage.io.imread(mask_nuclei_path)
    _, y_max, x_max = mask_cyto.shape

    df_PPIs = []
    for k in PPI_dict.keys():
        PPI_loc = PPI_dict[k][:, :3].astype(np.uint32)
        
        # Get PPI loc in panda dataframe format
        df_loc = pd.DataFrame(PPI_loc, columns=['z', 'y', 'x'])

        # Get dot neighboring graph with user defined radius
        nei, A = get_NN_radius(df_loc[['y', 'x']], r=2.5)

        # Assign new labels based on connected components label
        labels = scipy.sparse.csgraph.connected_components(A, directed=False)[1]
        df_loc['CC_label'] = labels

        # Group by CC and extract mean position
        df_loc_filtered = df_loc.groupby(['CC_label']).mean().astype(np.uint32)
        
        # Extract Cell info per PPI
        df_PPI = create_PPI_df(df_loc_filtered, mask_cyto, mask_nuclei, k)
        df_PPIs.append(df_PPI)
    df_PPI = pd.concat(df_PPIs)
    df_PPI['Condition'] = names[0]
    df_PPI['FOV'] = names[1]
    
    # Save dataframe
    path = PPI_save_path / f'{name}.csv'
    df_PPI.to_csv(path, index=False)

In [8]:
# img_spot = plot_label_on_image(df_loc_filtered.values, 1, 1, labels)
# img_spot = img_spot.astype(np.uint16)

In [9]:
# import napari

# viewer = napari.view_labels(img_spot)
# viewer.add_labels(mask_cyto)

# Generate PPI network

In [27]:
from sklearn.preprocessing import OneHotEncoder
import networkx as nx
from sklearn import preprocessing
import scipy 
from scipy.spatial import Delaunay
import itertools
import pickle

def plot_tri_simple(ax, points, tri):
    for tr in tri.simplices:
        pts = points[tr, :]
        ax.plot3D(pts[[0,1],0], pts[[0,1],1], pts[[0,1],2], color='g', lw='0.1')
        ax.plot3D(pts[[0,2],0], pts[[0,2],1], pts[[0,2],2], color='g', lw='0.1')
        ax.plot3D(pts[[0,3],0], pts[[0,3],1], pts[[0,3],2], color='g', lw='0.1')
        ax.plot3D(pts[[1,2],0], pts[[1,2],1], pts[[1,2],2], color='g', lw='0.1')
        ax.plot3D(pts[[1,3],0], pts[[1,3],1], pts[[1,3],2], color='g', lw='0.1')
        ax.plot3D(pts[[2,3],0], pts[[2,3],1], pts[[2,3],2], color='g', lw='0.1')

    ax.scatter(points[:,0], points[:,1], points[:,2], color='b')
    
def create_network(df, t=50, scale=0.2):   
    z,y,x = df['z'].to_numpy()*scale, df['y'].to_numpy(), df['x'].to_numpy()

    # Get coordinates
    coordinates = np.vstack([x, y, z]).T
    points = coordinates - np.mean(coordinates, axis=0)
    
    # Delaunay
    tri = Delaunay(points)
    G = nx.Graph()
    for path in tri.simplices:
        G.add_nodes_from(path)
        edges = list(itertools.combinations(path, 2))
        G.add_edges_from(edges)
    
    # Euclidian distance
    distance = scipy.spatial.distance_matrix(points, points)
    l = t**2/np.log(t)
    distance_norm = np.exp(-distance**2/t)

    # distance_norm = np.where(distance>t, 0, np.exp(-distance*2/l))
    
    g_dist = nx.from_numpy_array(distance_norm)
    g = g_dist.edge_subgraph(G.edges()).copy() # Only keep delaunay edges
    nx.set_node_attributes(g, dict(zip(g.nodes(), df.Labels)), "labels")
    nuclei = (df.Nuclei > 0).astype(int).tolist()
    nx.set_node_attributes(g, dict(zip(g.nodes(), nuclei)), "nuclei")
    nx.set_node_attributes(g, dict(zip(g.nodes(), coordinates)), "pos")
    return g

In [28]:
PPI_save_path =  data_dir / 'OCT Cell Culture' / '3D_Whole' / 'PPI'

# Read PPi info 
dfs = []
for path in os.listdir(PPI_save_path):
    if 'csv' in path:
        df = pd.read_csv(PPI_save_path / path)
        dfs.append(df)
df = pd.concat(dfs)

# Create label 
enc = OneHotEncoder(handle_unknown='ignore')
labels = enc.fit_transform(df['PPI'].to_numpy().reshape(-1, 1)).toarray().astype(np.uint8)
df['Labels'] = labels.tolist()

In [29]:
df

Unnamed: 0,Cell,Nuclei,z,y,x,PPI,Condition,FOV,Labels
0,254,216,2,701,6431,TEAD1 & YAP1,HCC827Ctrl,FW1,"[0, 0, 0, 1, 0]"
1,697,0,3,1813,479,TEAD1 & YAP1,HCC827Ctrl,FW1,"[0, 0, 0, 1, 0]"
2,0,0,3,3387,7570,TEAD1 & YAP1,HCC827Ctrl,FW1,"[0, 0, 0, 1, 0]"
3,242,182,0,617,3815,TEAD1 & YAP1,HCC827Ctrl,FW1,"[0, 0, 0, 1, 0]"
4,117,72,1,264,2572,TEAD1 & YAP1,HCC827Ctrl,FW1,"[0, 0, 0, 1, 0]"
...,...,...,...,...,...,...,...,...,...
95739,0,0,15,1675,7804,Mcl-1 & BAK,HCC827Osim,FW2,"[0, 1, 0, 0, 0]"
95740,813,0,15,3427,6970,Mcl-1 & BAK,HCC827Osim,FW2,"[0, 1, 0, 0, 0]"
95741,0,0,15,2078,527,Mcl-1 & BAK,HCC827Osim,FW2,"[0, 1, 0, 0, 0]"
95742,684,0,15,2856,6507,Mcl-1 & BAK,HCC827Osim,FW2,"[0, 1, 0, 0, 0]"


In [30]:
graph_save_path =  data_dir / 'OCT Cell Culture' / '3D_Whole' / 'graphs' / 'raw'
graph_save_path .mkdir(parents=True, exist_ok=True)

scale = 0.4/0.18872

group = df.groupby(['Condition', 'FOV', 'Cell'])
for name, df_group in group:
    if name[2] == 0:
        continue
    if len(df_group) < 5:
        continue
    
    g = create_network(df_group, t=70, scale=scale)
    if nx.number_of_isolates(g)>0:
        print('Isolated nodes')
    
    # Save pgrahs
    save_path = graph_save_path / ('_'.join([str(n) for n in name])+'.pkl')
    with open(save_path, 'wb') as f:
        pickle.dump(g, f, protocol=pickle.HIGHEST_PROTOCOL)


In [12]:
# scale = 0.4/0.18872
# z,y,x = df_group['z'].to_numpy()*scale, df_group['y'].to_numpy(), df_group['x'].to_numpy()

# points = np.vstack([x, y, z]).T
# tri = Delaunay(points)

# G = nx.Graph()
# for path in tri.simplices:
#     G.add_nodes_from(path)
#     edges = list(itertools.combinations(path, 2))
#     G.add_edges_from(edges)
    
# pos = points
# node_xyz = np.array([pos[v] for v in sorted(G)])
# edge_xyz = np.array([(pos[u], pos[v]) for u, v in G.edges()])

# # Create the 3D figure
# fig = plt.figure()
# ax = fig.add_subplot(111, projection="3d")

# # Plot the nodes - alpha is scaled by "depth" automatically
# ax.scatter(*node_xyz.T, color='b')
# for vizedge in edge_xyz:
#     ax.plot(*vizedge.T, color='g', lw=0.5)

In [13]:
# fig = plt.figure()
# ax = plt.axes(projection='3d')
# plot_tri_simple(ax, points, tri)