# Tissue Patterning

In [None]:
# pip install ipywidgets

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patheffects as patheffects

# conda install ipywidgets for movie 
from ipywidgets import interact, FloatSlider, IntSlider

In [None]:
%matplotlib inline

## 1. Emergent properties of collective gene-expression patterns in multicellular systems
Smart, M., & Zilman, A. (2023). Emergent properties of collective gene-expression patterns in multicellular systems. *Cell Reports Physical Science*, 101247.
[Paper Link](https://doi.org/10.1016/j.xcrp.2023.101247)
### 1.1 Single-cell model
Assume cell state is described by the activity of $N$ genes in the cell. Let $\xi_1, \ldots, \xi_p$ denote the desired cell types, with $\xi_i\in\{1,-1\}^N$, and let $\xi=[\xi_1,\ldots,\xi_p]$ be the $N\times p$ matrix whose columns are these patterns. Let $\hat{J}=\xi(\xi^{T}\xi)^{-1}\xi^T$ represent the gene-gene interactions. 
* First, we want to show that each $\xi_i$ is an eigenvector of the matrix $\hat{J}$, i.e. that $$\hat{J}\xi_i=\lambda_i\xi_i.$$
Let $A=(\xi^{T}\xi)^{-1}\xi^T$. Note that 
$$A\xi=(\xi^{T}\xi)^{-1}\xi^T\xi=I.$$
This implies the following,
\begin{align*}
\hat{J}\xi&=\xi A \xi\\
&=\xi I\\
&=\xi.
\end{align*}
Therefore, the $\xi_i$ are eigenvectors of the matrix $\hat{J}$ for the repeated eigenvalue $\lambda=1$.


* Next, to get the final gene-gene interaction matrix J, the authors set the diagonal value to zero. That is,
$$J=\hat J - \text{diag}(\hat J).$$
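Assuming all diagonal entries of matrix $\hat J$ are the same, say equal to $c$, we have $J=\hat J - cI$, so the above transformation only shifts every eigenvalue by $-c$ and won't change the eigenspaces of the matrix. But is it guaranteed that all diagonal entries are the same? 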

In [None]:
import numpy as np
import random
from numpy.linalg import inv
from scipy.stats import bernoulli
from numpy import linalg as LA

N = 9  # number of genes
assert N == 9  # the three patterns below are laid out for exactly 9 genes (3 blocks of 3)

# Each column of xi is one target cell type: all genes off (-1) except one block of three genes on (+1).
xi = np.full((N, 3), -1.0)
for pattern, genes in enumerate((slice(0, 3), slice(3, 6), slice(6, None))):
    xi[genes, pattern] = 1
print('Cell Patter xi= \n', xi)

# Projection onto the span of the patterns, then the same matrix with its diagonal zeroed out.
gram_inverse = inv(xi.T @ xi)
Jhat = xi @ gram_inverse @ xi.T
J = Jhat - np.diag(Jhat.diagonal())
print('Matrix hat_J= \n',np.round(Jhat,2))  # rounded for display purpose

### 1.2 Single Cell Hamiltonian

$$\mathcal{H}(s)= -\frac12(s^TJs) - h^Ts$$

where $h$ is an external field acting on the genes (the code below evaluates the case $h=0$).

In [None]:
# Enumerate every possible single-cell state and evaluate its energy H(s) = -s^T J s / 2,
# then report the distinct energy levels and the states attaining the global extrema.
# NOTE: the enumeration is exponential in N (2^N states); parallelise or sample if N is large.

comb = np.tile([-1,1], (N, 1)).tolist()
S = np.stack(np.meshgrid(*comb), -1).reshape(-1, N).T       # S is a N by 2^N matrix that stores all possible patterns

# Column-wise quadratic form s_n^T J s_n. This avoids materialising the full
# 2^N x 2^N matrix S.T @ J @ S only to read off its diagonal.
H = -0.5 * np.sum(S * (J @ S), axis=0)

H_unique, counts = np.unique(np.round(H,2), return_counts=True)

np.set_printoptions(suppress=True)
print("Different energy states in the system:", H_unique)
print("Counts of cell states with above value:", counts)

print("Different States and its energy value:\n",np.vstack((S, H)))

# Global extrema of the energy landscape (column indices into S).
# They are kept as module-level names so that later cells can reuse them.
ind_xi = np.where(abs(H - H.min())<1e-6)[0]
print("The lowest energy states in a single cell are the column vectors of matrix:\n", S[:, ind_xi])
print("The global minimum are cell states with ID:", ind_xi)

ind_glob_max = np.where(abs(H - H.max())<1e-6)[0]
print("The global maximum are cell states with ID:", ind_glob_max)

# Local extrema search (disabled: builds a 2^N x 2^N neighbourhood table).
# ind_neighb = [sum(abs(np.tile(S[:,i],(2**N,1)).T - S)) == 2 for i in range(2**N)]
# local_min = np.array([sum((H[i]-1e-6)<=H[ind_neighb[i]]) for i in range(2**N)])
# print("The local minimum are cell states with ID:", np.where(local_min==N)[0])

# local_max = np.array([sum((H[i]+1e-6)>=H[ind_neighb[i]]) for i in range(2**N)])
# print("The local maximum are cell states with ID:", np.where(local_max==N)[0])

In [None]:
# Build the "best response" graph over all 2^N cell states: every state s gets one
# outgoing edge to the state s' that minimises -s^T W s'. States whose paths end up
# in the same place share a group id (one id per connected subgraph).

# Symmetric random coupling matrix W = (W0 + W0^T) / 2, seeded so the cell is reproducible.
np.random.seed(100)
W0 = np.random.uniform(-1,1, (N,N))
W = W0/2 + W0.transpose()/2

# Draw one random column index of S, i.e. one random cell state.
ind_rand = np.random.randint(len(S.T))  
# s = np.reshape(S[:,ind_rand],(N,1))

# H = -s.transpose()@W@S
# ind_min = np.argmin(H)
# s_hat= np.sign(s.transpose()@W).transpose()
# # s_hat.flatten() == S[:,ind_min]

# H_hat = -s_hat.transpose()@W@S
# ind_hat_min = np.argmin(H)
# s.flatten() == S[:,ind_hat_min]

s = np.reshape(S[:,ind_rand],(N,1))   # the randomly drawn state as an (N, 1) column
states_covered = np.array([])         # NOTE(review): not read anywhere in the visible code
edges = []                            # directed edges [from_state, to_state]
group_ids = np.array([],dtype='int')  # group_ids[k] is the group of node_ids[k]
group_id = 0                          # highest group id handed out so far
node_ids = []                         # state indices in the order they were visited

for j in range(2**N):
    # Start a new path only from states that no earlier path has visited.
    if ~np.isin(j,node_ids):
        s_path = np.zeros((2**N,), dtype=int)-1   # -1 marks "not filled in yet"
        s_path[0] = j
        group_id=group_id+1
        # s_path[:-1] is a view: entry i+1 is written below before the iterator reads it.
        for i,ind_i in enumerate(s_path[:-1]):
            si = np.reshape(S[:,ind_i],(N,1))
            node_ids.append(ind_i)
            group_ids = np.append(group_ids,group_id)
            
            # Successor of state ind_i: the column of S minimising -si^T W s'.
            ind_hat_si = np.argmin(-si.transpose()@W@S)
            s_path[i+1] = ind_hat_si
            edges.append([ind_i,ind_hat_si]) # each row [i, j] represents a directed edge from i to j
            
            # Stop as soon as the successor has already been visited (on this path or an earlier one).
            if max(ind_hat_si == node_ids):
                id_end_state = group_ids[node_ids.index(ind_hat_si)]
                if id_end_state!=group_id:
                    # The path ran into an older group: relabel this path with the older
                    # id and give the current id back so group ids stay consecutive.
                    group_ids[group_ids == group_id] = id_end_state
                    group_id=group_id-1
                break

In [None]:
# Summarise the inputs of the graph plot in the next cell.
print('group stats for plot below:')
for label, value in (
    ('\tS.shape', S.shape),
    ('\tind_rand', ind_rand),
    ('\ts (one cell, anchored to ind=%d)' % ind_rand, s.T),
):
    print(label, value)


In [None]:
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.gridspec as gridspec

# Left panel: directed graph of state -> successor state, coloured by connected subgraph.
# Right panel: number of cell states in each connected subgraph, using the same colours.
fig = plt.figure(figsize=(10, 5))
gs = gridspec.GridSpec(1, 2, width_ratios=[2, 1])

# Create a directed graph
G = nx.DiGraph()

# One colour per connected subgraph. Group ids run from 1 to group_id, so group g
# uses colour index g - 1. plt.get_cmap is used because plt.cm.get_cmap was removed
# in Matplotlib 3.9; a dedicated name avoids shadowing the matplotlib.colors import.
group_cmap = plt.get_cmap('Spectral', group_id)

# Constant-time lookup from a state index to its group id (each state is visited once).
group_of_node = dict(zip(node_ids, group_ids))

# Add edges from matrix
for source, target in edges:
    G.add_edge(source, target)

color_map = [group_cmap(group_of_node[node] - 1) for node in G]

# Draw the graph
ax0 = fig.add_subplot(gs[0])
nx.draw(G,ax=ax0, node_color=color_map, pos=nx.spring_layout(G),with_labels=False, node_size=25, arrowstyle='-|>', arrowsize=5,linewidths=0,width= 0.1,)
ax0.set_title("s -> s' if s' minimizes f(s,s_i)")

ax1 = fig.add_subplot(gs[1],adjustable='box', aspect=0.3)
# One unit-width bin centred on every group id 1..group_id, so no group is dropped
# or merged with its neighbour.
n, bins, patches = ax1.hist(group_ids, bins=np.arange(1, group_id + 2) - 0.5)
for i,patch in enumerate(patches):
    # patch i counts group i + 1, which is drawn with colour index i in the graph
    patch.set_facecolor(group_cmap(i))
ax1.set_xlabel("Index for Each Connected Subgraph")
ax1.set_ylabel("Number of Cell States")
plt.savefig('two_cell.png', bbox_inches='tight',dpi=600)
plt.show()

### 1.3 Multicellular Model

Assume a given tissue (multicellular system) has $M$ different cells. Let $s_i\in\{1,-1\}^N$ represent the $i$-th cell. The Hamiltonian of the system is calculated as the following,
$$\mathcal{H}(s_1,s_2,\ldots, s_M)=\sum_i^M -\frac12(s_i^TJs_i)+\gamma \sum_i\sum_j A_{ij}f(s_i,s_j).$$
* The first term describes the summation of the Hamiltonian of each cell
* The second term describes the Hamiltonian from cell-cell interaction,
$$f(s_i,s_j)=-\frac12 s_i^TWs_j,$$
with strength $\gamma$ and randomly sampled matrix $W$.


In [None]:
# Tissue layout: a grid with m1 rows and m2 columns, cells numbered row by row.
m1 = 4   # no. of rows
m2 = 6    # no. of columns 
M = m1 * m2  # total number of cells

# Neighbourhood matrix A: A[a, b] = 1 when cells a and b touch (including diagonally).
# For every cell a 3x3 window is stamped onto a zero grid padded by one ring of
# cells, the centre is cleared, and the padding is cropped off again.
A = np.zeros((M, M))
for cell, (row, col) in enumerate(np.ndindex(m1, m2)):
    padded = np.zeros((m1 + 2, m2 + 2))
    padded[row:row + 3, col:col + 3] = 1   # window centred on (row + 1, col + 1) in padded coordinates
    padded[row + 1, col + 1] = 0           # a cell is not its own neighbour
    A[cell] = padded[1:-1, 1:-1].ravel()

# Random inter-cellular communication matrix W, entries uniform in [-1, 1) (seeded for reproducibility)
np.random.seed(100)
W = np.random.uniform(low=-1, high=1, size=(N, N))

# # Plot the adjacency matrix and the neighborhoood for a random cell
# fig, axs = plt.subplots(1,2,layout='constrained')

# axs[0].title.set_text('Adjacency matrix for %d = %d x %d cells' % (M, m1, m2))
# axs[0].imshow(A)

# np.random.seed(seed=None)
# i_rand = np.random.randint(M)  # pick a random cell
# i_neigb = A[i_rand,]
# i_neigb[i_rand] = 2

# extent = (0, m2, m1, 0)
# axs[1].imshow(i_neigb.reshape(m1,m2), extent=extent)
# axs[1].title.set_text('Visual Inspection of\n Cell %d (yellow)\'s Neighbors (cyan)'  % i_rand)
# axs[1].set(xticks=range(0, m2, 1), yticks=range(0, m1, 1))
# axs[1].grid(color='w', linewidth=2)   
# axs[1].set_frame_on(False)
# plt.show()

In [None]:
# Stochastic (Glauber-type) simulation of the tissue. At every time step each cell
# updates its genes one after another; neighbours enter through their states at the
# previous time step.
T = 100                              # Simulation Time Steps
tissue_state_traj = np.zeros((N, M, T+1))        # Storing Transitions: N - # of genes; M - # of cells
gamma = 10                           # Cell-cell communication strength
beta = 2000                          # Inverse noise strength (large beta -> almost deterministic updates)

# Initial tissue state by randomly choose M different cell types from S: N by 2^N matrix that stores all possible cell states
ind_rand = np.random.randint(len(S.T),size = M)  
tissue_state_traj[:,:,0] = S[:,ind_rand]

# Field that neighbours exert on the genes of cell i: the negative gradient of
#   gamma * sum_i sum_j A_ij * (-1/2) s_i^T W s_j
# with respect to s_i, i.e. gamma * (W + W^T)/2 @ sum_j A_ij s_j. For a symmetric W
# this is simply gamma * W @ (sum of neighbour states). The previous code added the
# scalar s_i^T W sum_j s_j to every gene's field instead of this per-gene vector.
W_sym = 0.5 * (W + W.T)
neighbor_mask = A != 0

# Stochastic update one gene at a time
for k in range(1, T+1):
    prev_states = tissue_state_traj[:, :, k-1]
    for i in range(M):
        s = np.copy(prev_states[:, i])
        # Neighbours are read from step k-1, so their field is constant during the sweep over genes.
        h_neighbors = gamma * (W_sym @ prev_states[:, neighbor_mask[i]].sum(axis=1))
        for j in range(N):
            # Local field on gene j; s already contains the genes updated earlier in this sweep.
            h_j = J[j] @ s + h_neighbors[j]
            # P(s_j = +1) = 1 / (1 + exp(-2*beta*h_j)), written with tanh so that a large
            # beta cannot overflow the exponential.
            transition_prob = 0.5 * (1 + np.tanh(beta * h_j))
            s[j] = 2 * bernoulli.rvs(transition_prob) - 1
            tissue_state_traj[j,i,k] = s[j]


In [None]:
# Definition of the plot_tissue function, our "callback function".
def plot_tissue_integer(t):
    """Draw the tissue at time step ``t`` with every cell coloured by its state.

    Each cell state is encoded as an integer (gene k contributes 2**k when it is on)
    and shown with ``imshow``. Cell labels are green when the cell sits in a
    single-cell global energy minimum, red for a global maximum and white otherwise.

    Reads the notebook globals ``N``, ``m1``, ``m2``, ``S``, ``H`` and
    ``tissue_state_traj``.

    t: index of the time step to display (0 .. T).
    """
    gene_weights = 2**np.arange(N)
    cell_color = (gene_weights @ np.heaviside(tissue_state_traj[:,:,t], 0)).reshape((m1, m2))

    text_kwargs = dict(ha='center', va='center', fontsize=10, path_effects=[patheffects.withStroke(linewidth=1, foreground='black')])
    plt.imshow(cell_color, vmin=0, vmax=2**N-1)
#    plt.colorbar(ticks=np.arange(0, 2**4, 4))

    # Single-cell states with extreme energy, taken directly from H so that the
    # callback does not depend on index arrays that no cell defines.
    lowest_states = S[:, np.abs(H - H.min()) < 1e-6]
    highest_states = S[:, np.abs(H - H.max()) < 1e-6]
    color_index_lowH = gene_weights @ np.heaviside(lowest_states, 0)    # colour index of the states with the lowest Hamiltonian
    col_lowH = 'green'
    color_index_highH = gene_weights @ np.heaviside(highest_states, 0)  # colour index of the states with the highest Hamiltonian
    col_highH = 'red'

    # add state text to each cell
    for i in range(m1):
        for j in range(m2):
            cell_idx = j + i * m2
            cell_color_ij = cell_color[i, j]
            if cell_color_ij in color_index_lowH:
                col= col_lowH
            elif cell_color_ij in color_index_highH:
                col= col_highH
            else:
                col = 'white'
            # cell_state = list(np.heaviside(tissue_state_traj[:, cell_idx, t], 0))
            # plt.text(j, i, 'c%d\n%d' % (cell_idx, cell_color[i, j]), **text_kwargs, c=col)
            plt.text(j, i, 'c%d' % cell_idx, **text_kwargs, c=col)
            
    # Legend for the label colours, placed to the right of the grid.
    plt.text(m2+1, 0.5, 'Global Max.', **text_kwargs, c=col_highH)
    plt.text(m2+1, 1.5, 'Global Min.', **text_kwargs, c=col_lowH)
    plt.show()


## Interactive viewer: a slider over the time steps 0..T drives plot_tissue_integer.
time_slider = IntSlider(value=0, min=0, max=T, step=1)
interact(plot_tissue_integer, t=time_slider)

### (Work in progress) Additional plotting tool -- scatterplot visualization of cell state (circle = gene on; absence = off)
- step 1: push state size onto nearest square; pad with zeros and mask the dummy values
- step 2: plot backdrop 'square' on gene indices of square (nothing for masked dummy entries)
- step 3: plot circle marker for genes that are ON (val = 1)

In [None]:
def pad_to_len_square(x, square_size):
    """Return ``x`` as a float vector of length ``square_size ** 2``, zero-padded at the end.

    x:           1-D sequence of values; its entries fill the front of the result
    square_size: side length of the square the result is meant to be reshaped into
    """
    n_slots = square_size ** 2
    padded = np.zeros(n_slots)
    padded[slice(0, len(x))] = x
    return padded

In [None]:
def plot_rna_square(vals, ns, top_left, marker_size=100, extent=1.0, ax=None):
    """Draw one cell state as an ns-by-ns grid of markers.

    Every defined gene gets a blue square backdrop; genes that are ON (value 1)
    additionally get a yellow circle on top. Entries equal to 0 are padding and
    get no marker at all.

    vals:        np.array of length ns ** 2 with values in {-1, 1}, or 0 for padding
    ns:          side length of the marker grid
    top_left:    (x, y) position of the first gene; the grid extends to the right
                 in x and towards smaller y
    marker_size: accepted for compatibility, not used by the body
    extent:      width/height of the grid in data coordinates; it also scales the
                 marker areas (square area = 200 * extent, circle area a quarter of that)
    ax:          axes to draw on; a new figure is created when omitted
    """
    n_genes = ns ** 2
    backdrop_sizes = np.full(n_genes, extent * 200, dtype=float)
    dot_sizes = 0.25 * backdrop_sizes

    # Marker positions in row-major gene order: first gene at top_left, last gene
    # at the opposite corner of the grid.
    columns = np.linspace(top_left[0], top_left[0] + extent, ns)
    rows = np.linspace(top_left[1], top_left[1] - extent, ns)
    grid_x, grid_y = np.meshgrid(columns, rows, indexing='xy')

    if ax is None:
        ax = plt.figure().gca()

    # (marker, colour, entries to hide, marker areas): squares first, circles on top
    layers = (
        ('s', 'blue', vals == 0, backdrop_sizes),
        ('o', 'yellow', vals != 1.0, dot_sizes),
    )
    for marker, colour, hidden, sizes in layers:
        ax.scatter(grid_x, grid_y, np.ma.masked_where(hidden, sizes), marker=marker, c=colour)

    return

# Disabled manual demo of pad_to_len_square / plot_rna_square. It is wrapped in a
# bare string literal, so none of the statements inside are executed.
# NOTE(review): as the last expression of a notebook cell the string itself may be
# echoed as the cell output -- confirm whether that is intended.
"""
s = np.array(
    [1,1,1,-1,-1,-1,1,1]
)
n = len(s)

square_size = np.ceil(np.sqrt(n)).astype(int)

#s_pad = np.zeros(square_size ** 2)
#s_pad[:n] = s
s_pad = pad_to_len_square(s, square_size)

s_sqr = s_pad.reshape(square_size, square_size)

print(s)
print(s_pad)
print(s_sqr)

plt.imshow(s_sqr)

plot_rna_square(s_pad, 3, (5, 5))
plt.show()
"""

In [None]:
# Definition of the plot_tissue function, our "callback function".
def plot_tissue_symbolic(t):
    """Draw the tissue at time step ``t`` with a gene-level symbol on top of every cell.

    The background colour encodes each cell state as an integer (gene k contributes
    2**k when it is on). On top of every cell, ``plot_rna_square`` draws a small
    grid with one marker per gene.

    The encoding and the colour scale follow the global gene count ``N``; for
    N = 4 they coincide with the former hard-coded weights [1, 2, 4, 8], vmax = 15
    and ticks 0, 4, 8, 12.

    Reads the notebook globals ``N``, ``m1``, ``m2`` and ``tissue_state_traj``.

    t: index of the time step to display (0 .. T).
    """
    n_states = 2**N
    gene_weights = 2**np.arange(N)
    cell_color = (gene_weights @ np.heaviside(tissue_state_traj[:,:,t], 0)).reshape((m1, m2))
    plt.imshow(cell_color, vmin=0, vmax=n_states - 1)
    # four evenly spaced ticks; max(1, ...) keeps the step valid for N < 2
    plt.colorbar(ticks=np.arange(0, n_states, max(1, n_states // 4)))
    
    # smallest square grid that can hold all N genes
    square_size = np.ceil(np.sqrt(N)).astype(int)
    
    for i in range(m1):
        for j in range(m2):
            cell_idx = j + i * m2
            cell_state = tissue_state_traj[:, cell_idx, t]
            cell_state_padded = pad_to_len_square(cell_state, square_size)
            
            # centre the symbol on the imshow pixel at (column j, row i)
            symbol_extent = 0.2
            top_left = (j - 0.5 * symbol_extent, 
                        i + 0.5 * symbol_extent)
            plot_rna_square(cell_state_padded, square_size, top_left, ax=plt.gca(), extent=symbol_extent)
    plt.show()

## Interactive viewer: a slider over the time steps 0..T drives plot_tissue_symbolic.
time_slider = IntSlider(value=0, min=0, max=T, step=1)
interact(plot_tissue_symbolic, t=time_slider)

#### Tinkering/delete: Imshow testing (data rotation, x vs y)

In [None]:
# Test grid for the imshow orientation check: mx rows, my columns, filled row by row with 0..mx*my-1.
mx, my = 6, 10
x_vec = np.arange(mx * my)
x_arr = np.reshape(x_vec, (mx, my))

In [None]:
# Show the index grid: entry [i, j] holds the flat index j + i * my.
print(x_arr)

In [None]:
# Orientation check: label every imshow pixel with its flat index. Row i runs along
# the vertical axis and column j along the horizontal axis.
im = plt.imshow(x_arr)
plt.colorbar(im)
label_style = dict(fontsize=13, ha='center', va='center', c='white',
                   path_effects=[patheffects.withStroke(linewidth=2, foreground='black')])
for i, j in np.ndindex(mx, my):
    cell_idx = j + i * my
    plt.text(j, i, cell_idx, **label_style)
    # alternative styling (black text with a white outline):
    #plt.text(j, i, cell_idx, fontsize=12, ha='center', c='k',
    #            path_effects=[patheffects.withStroke(linewidth=2, foreground='white')])