In [None]:
import scanpy as sc

counts = sc.datasets.pbmc3k().X.todense().T
%load_ext rpy2.ipython

In [None]:
%%R -i counts

library(tidyverse)
library(Seurat)
library(scDEED)

# t-SNE perplexity; reused later when calling scDEED.
perplexity <- 40

# Build the Seurat object and run the preprocessing steps one call at a time,
# ending with a t-SNE embedding computed at the perplexity chosen above.
data <- CreateSeuratObject(counts)
data <- FindVariableFeatures(data)
data <- NormalizeData(data)
data <- ScaleData(data)
data <- RunPCA(data)
data <- RunTSNE(data, perplexity = perplexity)

In [None]:
%%R

# t-SNE coordinates stored on the Seurat object.
embeddings <- Embeddings(data, "tsne")

# Despite the name, this holds the "scale.data" layer, converted to a base
# matrix and transposed.
normalized_counts <- t(as.matrix(GetAssayData(data, layer = "scale.data")))

In [None]:
import numpy as np
import rpy2.robjects

# Copy the two matrices created in the preceding R cell out of R's global
# environment and into NumPy arrays of the same names.
embeddings, normalized_counts = (
    np.array(rpy2.robjects.globalenv[r_name])
    for r_name in ("embeddings", "normalized_counts")
)

In [None]:
%%R

# Run scDEED on the existing t-SNE reduction, extract the cells it lists as
# dubious for the perplexity in use, and record each such cell's positive
# entries in the RNA_nn graph.
K <- 8
result <- scDEED(data, K = K, reduction.method = 'tsne', rerun = FALSE, perplexity = perplexity)
dubious <- result$full_results |>
    # `.env$perplexity` refers to the global variable. A bare
    # `perplexity == perplexity` is resolved entirely inside the data frame,
    # compares the column with itself, and therefore keeps every row.
    filter(perplexity == .env$perplexity) |>
    pull(dubious_cells) |>
    str_split(",")

# `dubious_cells` is a comma-separated string; take the split of the first
# matching row and convert the pieces to integer cell indices.
dubious <- as.integer(dubious[[1]])

# Build the neighbor graph on the variable features with k.param = 50.
data <- FindNeighbors(data, features = VariableFeatures(data), k.param = 50)
G <- data@graphs$RNA_nn
# For every dubious cell, the column positions with a positive entry in that
# cell's row of G; the list is named by the dubious cell index.
N <- map(dubious, ~ which(G[.x, ] > 0)) |>
    set_names(dubious)

In [None]:
from distortions.geometry import Geometry, bind_metric, local_distortions

# Geometry configuration passed to local_distortions.
geom = Geometry(
    "brute",
    laplacian_method="geometric",
    affinity_kwds={"radius": 20},
    adjacency_kwds={"radius": 50},
    laplacian_kwds={"scaling_epps": 5},
)

# Compute the local distortion quantities for the embedding and attach them
# to it with bind_metric.
# NOTE: `embeddings` is rebound here, so re-running this cell on its own feeds
# the bound result back into local_distortions; re-run the extraction cell first.
H, Hvv, Hs = local_distortions(embeddings, normalized_counts, geom)
embeddings = bind_metric(embeddings, Hvv, Hs)

# Convert the R list `N` (dubious cell -> neighbor positions) into a plain
# dict with integer keys and list values.
# NOTE(review): these indices come from R's `which`, so they are presumably
# 1-based — confirm that this matches the indexing dplot expects.
N = rpy2.robjects.globalenv["N"]
N_dict = {int(cell): list(neighbors) for cell, neighbors in N.items()}

In [None]:
from distortions.visualization import dplot

# Registry of figures keyed by output name. The single entry draws the
# embedding with ellipse glyphs and interactive links from each dubious cell
# to the neighbors listed in N_dict.
plots = {
    "scdeed_distort": (
        dplot(embeddings, width=440, height=440)
        .mapping(x="embedding_0", y="embedding_1")
        .geom_ellipse(radiusMax=10, radiusMin=1)
        .inter_edge_link(
            N=N_dict,
            stroke="#F25E7A",
            highlightColor="#C83F58",
            strokeWidth=.4,
            highlightStrokeWidth=5,
            threshold=10,
            backgroundOpacity=0.5,
        )
    ),
}

In [None]:
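# Disabled by default: uncomment to save every figure in `plots` as an SVG under ../paper/figures/.
#[p.save(f"../paper/figures/{k}.svg") for k, p in plots.items()]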

In [None]:
# Render every stored figure. Passing them all to one display() call avoids
# building a throwaway list of None values that would otherwise be echoed as
# the cell's output.
display(*plots.values())