Skip to content

Commit

Permalink
Outsource clustering to new class
Browse files Browse the repository at this point in the history
  • Loading branch information
iosonofabio committed May 21, 2020
1 parent 7fe6370 commit 22e0ce8
Showing 1 changed file with 8 additions and 54 deletions.
62 changes: 8 additions & 54 deletions northstar/averages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from anndata import AnnData
import leidenalg
from .fetch_atlas import AtlasFetcher
from .cluster_with_annotations import ClusterWithAnnotations



Expand Down Expand Up @@ -643,60 +644,13 @@ def cluster_graph(self):
size N - n_fixed with the atlas cell types of all cells from the
new dataset.
'''
opt = leidenalg.Optimiser()

matrix = self.matrix
sizes = self.sizes
n_atlas = self.n_atlas
clustering_metric = self.clustering_metric
resolution_parameter = self.resolution_parameter
g = self.graph

N, L = matrix.shape
n_fixede = int(np.sum(sizes[:n_atlas]))
Ne = int(np.sum(sizes))

# NOTE: initial membership is singletons except for atlas nodes, which
# get the membership they have.
initial_membership = []
for isi in range(N):
if isi < n_atlas:
for ii in range(int(self.sizes[isi])):
initial_membership.append(isi)
else:
initial_membership.append(isi)

if len(initial_membership) != Ne:
raise ValueError('initial_membership list has wrong length!')

# Compute communities with semi-supervised Leiden
if clustering_metric == 'cpm':
partition = leidenalg.CPMVertexPartition(
g,
resolution_parameter=resolution_parameter,
initial_membership=initial_membership,
)
elif clustering_metric == 'modularity':
partition = leidenalg.ModularityVertexPartition(
g,
resolution_parameter=resolution_parameter,
initial_membership=initial_membership,
)
else:
raise ValueError(
'clustering_metric not understood: {:}'.format(clustering_metric))

fixed_nodes = [int(i < n_fixede) for i in range(Ne)]
opt.optimise_partition(partition, fixed_nodes=fixed_nodes)
membership = partition.membership[n_fixede:]

# Convert the known cell types
lstring = len(max(self.cell_types_atlas, key=len))
self.membership = np.array(
[str(x) for x in membership],
dtype='U{:}'.format(lstring))
for i, ct in enumerate(self.cell_types_atlas):
self.membership[self.membership == str(i)] = ct
clu = ClusterWithAnnotations(
self.graph,
self.cell_types_atlas_extended,
resolution_parameter=self.resolution_parameter,
metric=self.clustering_metric,
)
self.membership = clu.fit_transform()

def estimate_closest_atlas_cell_type(self):
'''Estimate atlas cell type closest to each new cluster'''
Expand Down

0 comments on commit 22e0ce8

Please sign in to comment.