# SCAFOLD Function
SCAFOLD identifies local structures ('clusters') of selected cell types in spatial data, based on user-defined parameters (number of neighbors, distance cutoff, and cell types of interest).

#### Inputs: 
1. adata = anndata object containing x,y coordinates of cells under adata.obsm['spatial'] and cell type annotations under adata.obs['ct']
2. k = number of nearest neighbors for knn
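3. r = distance cutoff between two cells, in the units of adata.obsm['spatial']; two cells are linked only if their distance is strictly less than r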
4. valuelist = list of cell types of interest

#### Output: 
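Adds cluster assignments to adata.obs['cluster']. Each cluster is labelled with the positional index of one of its member cells; cells whose type is not in valuelist are labelled with their own index (a cluster of size 1).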

In [1]:
import numpy as np
import pandas as pd
import scanpy as sc
import pickle
from anndata import AnnData
import warnings
import squidpy as sq
from scipy.cluster.hierarchy import DisjointSet
import copy
from sklearn.neighbors import NearestNeighbors
warnings.filterwarnings("ignore")

In [2]:
def runScafold(adata, k, r, valuelist):
    """Cluster the cells of the selected types and record the result on ``adata``.

    The neighbor graph produced by ``makeneighborgraph`` (filtered by cell
    type and by the distance cutoff) is handed to ``joinsets``, which writes
    the cluster assignment to ``adata.obs['cluster']``.

    Args:
        adata: Object with cell coordinates in ``adata.obsm['spatial']`` and
            cell type annotations in ``adata.obs['ct']``.
        k: Number of nearest neighbors queried for each cell.
        r: Distance cutoff between two cells.
        valuelist: Cell types of interest.

    Returns:
        None. ``adata`` is modified in place.
    """
    joinsets(adata, makeneighborgraph(adata, k, r, valuelist))
    
def makeneighborgraph(adata, k, r, valuelist):
    """Build the neighbor graph restricted to the cell types of interest.

    For every cell whose ``adata.obs['ct']`` value is in ``valuelist``, keep
    those of its ``k`` nearest spatial neighbors that are also of a listed
    type and lie strictly closer than ``r``.

    Args:
        adata: Object with cell coordinates in ``adata.obsm['spatial']`` and
            cell types in ``adata.obs['ct']``; both are assumed to list the
            cells in the same order, since cells are identified by position.
        k: Number of nearest neighbors queried for each cell.
        r: Distance cutoff; a neighbor is kept only if its distance is ``< r``.
        valuelist: Cell types of interest (an iterable of labels; a single
            string is treated as one label).

    Returns:
        dict mapping the positional index of each selected cell to the set of
        positional indices of its retained neighbors. The set may be empty.
    """
    knn = NearestNeighbors(n_neighbors=k)
    knn.fit(adata.obsm['spatial'])
    # No query points are passed, so the neighbors of every fitted point are
    # returned and a point is not reported as its own neighbor.
    distances, indices = knn.kneighbors()

    # A bare string would otherwise be matched character by character.
    if isinstance(valuelist, str):
        valuelist = [valuelist]
    wanted = set(valuelist)

    # Iterate over the column's values instead of indexing the Series with
    # integers: integer keys on a label-indexed Series are ambiguous
    # (label vs. position) and the positional fallback is deprecated.
    keep_mask = np.array([ct in wanted for ct in adata.obs['ct']], dtype=bool)

    filtered_neighs = {}
    for idx in np.flatnonzero(keep_mask).tolist():
        neighbor_ids = indices[idx]
        # Keep neighbors that are of a selected type AND closer than r.
        selected = neighbor_ids[keep_mask[neighbor_ids] & (distances[idx] < r)]
        filtered_neighs[idx] = set(selected.tolist())

    return filtered_neighs

def joinsets(adata, filtered_neighs):
    """Merge linked cells into clusters (union-find) and store the labels.

    Every key of ``filtered_neighs`` becomes a node and every (key, neighbor)
    pair an edge; connected nodes end up in the same cluster. The result is
    written to ``adata.obs['cluster']``: clustered cells receive the
    representative (root) cell index of their cluster, all other cells keep
    their own positional index.

    Args:
        adata: Object with an ``obs`` table and a ``shape`` whose first entry
            is the number of cells.
        filtered_neighs: dict mapping a cell's positional index to the set of
            positional indices of its neighbors. Every neighbor must itself
            be a key of the dict.

    Returns:
        None. ``adata.obs`` gains (or overwrites) the ``'cluster'`` column.
    """
    nodes = set()
    edges = set()
    for cell, neighbours in filtered_neighs.items():
        nodes.add(cell)
        for other in neighbours:
            edges.add((cell, other))

    disjoint_set = DisjointSet(nodes)
    for a, b in edges:
        disjoint_set.merge(a, b)

    # Start with every cell as its own cluster, then overwrite the members of
    # each subset with that subset's root. Writing into an array is linear in
    # the number of cells, unlike DataFrame.replace with a per-cell mapping.
    labels = np.arange(adata.shape[0], dtype=np.int64)
    for members in disjoint_set.subsets():
        members = list(members)
        labels[members] = disjoint_set[members[0]]

    adata.obs['cluster'] = labels



In [3]:
# Load the pickled toy dataset; the context manager closes the file handle
# as soon as the object has been read.
with open('/data1/greenbab/users/zhangb2/xenium_files/ND83_toy.pickle', 'rb') as f:
    adata = pickle.load(f)

In [4]:
# Per-cell annotation table; the 'ct' column holds the cell type labels.
adata.obs

Unnamed: 0,cell_id,ct
aaaaffnp-1-0,aaaaffnp-1,Megakaryocyte
aaaaojnj-1-0,aaaaojnj-1,Erythroid
aaaapmda-1-0,aaaapmda-1,Maturing/Mature Myeloid
aaabbcdp-1-0,aaabbcdp-1,T Cell
aaabkmdn-1-0,aaabkmdn-1,Erythroid
...,...,...
oiliadop-1-0,oiliadop-1,Erythroid
oiligppf-1-0,oiligppf-1,Erythroid
oiliholf-1-0,oiliholf-1,Macrophage
oilihpgc-1-0,oilihpgc-1,Erythroid


In [5]:
# Spatial coordinates of the cells, one (x, y) row per cell.
adata.obsm['spatial']

array([[ 1735.67320709,  5907.53197209],
       [ 8052.32246051,  5706.30418283],
       [ 8289.39414416, 11536.8386676 ],
       ...,
       [ 2520.78167927,  8581.99538172],
       [ 7978.56171806,  6760.57096967],
       [ 8298.63895903,  6458.61162463]])

In [6]:
# k=10 nearest neighbors per cell, distance cutoff r=20 (in the units of
# adata.obsm['spatial']); only the four listed cell types are clustered.
runScafold(adata, 10, 20, ['T Cell', 'B Cell', 'pDC', 'NK']) # to find lymphoid aggregates

In [7]:
# Size of clusters: number of cells per cluster id, largest first.
# Cells that were not clustered keep their own index and show up with size 1.
adata.obs['cluster'].value_counts()

102831    196
69558     172
33605     118
34416     107
100988     50
         ... 
57144       1
57145       1
57146       1
57147       1
171041      1
Name: cluster, Length: 163912, dtype: int64

In [8]:
# Cell types within largest cluster
# The id is looked up rather than hard-coded: a cluster id is whichever member
# cell the union-find chose as root, so it is specific to one dataset and run.
largest_cluster = adata.obs['cluster'].value_counts().idxmax()
adata.obs.loc[adata.obs['cluster'] == largest_cluster, 'ct'].value_counts()

T Cell                     111
B Cell                      51
pDC                         32
NK                           2
Megakaryocyte                0
VSMC                         0
Plasma Cell                  0
Osteoblast                   0
Mesenchymal/Stromal          0
Adipocyte                    0
Mast                         0
Macrophage                   0
HSPC                         0
Erythroid                    0
Endothelial                  0
Early Myeloid                0
Maturing/Mature Myeloid      0
Name: ct, dtype: int64