In [None]:
import numpy
import collections
import sys
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas
import numpy
import numpy as np
import scanpy as sc

# Scanpy session configuration: logging verbosity level 3 (hint level on
# scanpy's 0-4 scale), a version header for reproducibility, and 80-dpi
# figures drawn on a white background.
sc.settings.verbosity = 3
sc.logging.print_header()
sc.settings.set_figure_params(facecolor="white", dpi=80)

from genevector.data import GeneVectorDataset
from genevector.model import GeneVector
from genevector.embedding import GeneEmbedding, CellEmbedding

In [None]:
# Load the PBMC AnnData object from the working directory.
adata = sc.read(filename="PBMC.h5ad")

In [None]:
# Wrap the loaded AnnData in genevector's dataset object; the GeneVector model
# and the GeneEmbedding/CellEmbedding objects below are all built from it.
dataset = GeneVectorDataset(adata)

In [None]:
# GeneVector model over `dataset`: 100-dimensional embeddings, initial
# learning rate 0.15. Vectors go to "pbmc.vec", which the training cell
# re-reads through GeneEmbedding.
cmps = GeneVector(
    dataset,
    emb_dimension=100,
    initial_lr=0.15,
    output_file="pbmc.vec",
)

In [None]:
# Train in 25 rounds of `cmps.train(20)` (presumably 20 epochs per call —
# confirm against genevector). After each round, reload the embedding from
# "pbmc.vec" and print CD8A's 10 nearest genes as a progress check.
for _ in range(25):
    cmps.train(20)
    embed = GeneEmbedding("pbmc.vec", dataset, vector="average")
    print("Similarity to CD8A:", embed.compute_similarities("CD8A")[:10], sep="\n")

In [None]:
# Cell embedding before any batch correction; UMAP coloured by sample.
# Note: this rebinds `adata` to the AnnData returned by the cell embedding.
cembed = CellEmbedding(dataset, embed)
adata = cembed.get_adata()
sc.pl.umap(
    adata,
    color=["sample"],
    title=["Uncorrected"],
    palette="Dark2",
    add_outline=True,
)

In [None]:
# Fresh cell embedding, batch-corrected on the "sample" column, then UMAPs
# coloured by sample and by cell type side by side.
cembed = CellEmbedding(dataset, embed)
cembed.batch_correct(column="sample")
adata = cembed.get_adata()
sc.pl.umap(
    adata,
    color=["sample", "celltype"],
    title=["Corrected", "Cell Type"],
    palette="Dark2",
    wspace=0.6,
    size=15,
    add_outline=True,
)

In [None]:
# Bar chart of the 10 genes most similar to CD8A in the trained gene embedding.
df = embed.compute_similarities("CD8A").head(10)
fig, ax = plt.subplots(1, 1, figsize=(3, 3))
# NOTE(review): seaborn >= 0.13 emits a FutureWarning when `palette` is given
# without `hue`; left as-is to stay compatible with older seaborn versions.
sns.barplot(data=df, y="Gene", x="Similarity", palette="Dark2", ax=ax)
# Give the figure a title so it stands alone when the notebook is skimmed;
# the trailing `;` suppresses the Text repr in the cell output.
ax.set_title("Genes most similar to CD8A");

In [None]:
# Gene-level AnnData from the trained embedding, and the metagenes derived
# from it. Later cells iterate `metagenes.items()` as (cluster, genes) pairs
# and pass `gdata` to `embed.plot_cluster`.
gdata = embed.get_adata()
metagenes = embed.get_metagenes(gdata)

In [None]:
# Normalize and log-transform the expression matrix, then score each metagene.
# Both scanpy preprocessing calls mutate `adata` in place, so re-running this
# cell would stack a second normalization and log1p on already-transformed
# data. `sc.pp.log1p` records itself in `adata.uns["log1p"]`; use that as a
# guard so the cell is idempotent.
# NOTE(review): assumes adata.X holds raw counts on the first run — confirm
# what cembed.get_adata() returns.
if "log1p" not in adata.uns:
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)
else:
    print("adata is already log1p-transformed; skipping normalize_total/log1p.")
embed.score_metagenes(adata, metagenes)

In [None]:
# Plot the metagene scores computed in the previous cell, grouped by the
# "detailed_celltype" annotation.
embed.plot_metagenes_scores(adata,metagenes,"detailed_celltype")

In [None]:
# Show (and plot) every metagene that contains CD8A.
for clust, genes in metagenes.items():
    if "CD8A" not in genes:
        continue
    print(clust, genes)
    embed.plot_cluster(gdata, cluster=clust, title="CD8 T MG")

In [None]:
# Locate the interferon-stimulated-gene metagene via its IFIT1 membership.
# `isg_sig` keeps the last matching cluster, or None if no metagene has IFIT1.
isg_sig = None
for clust, genes in metagenes.items():
    if "IFIT1" not in genes:
        continue
    isg_sig = clust
    print(clust, genes)
    embed.plot_cluster(gdata, cluster=clust, title="ISG MG")

In [None]:
def get_predictive_genes(self, adata, label, n_genes=10, gene_embedding=None):
    """Return the genes whose embedding vectors best match each label's cell profile.

    For every distinct value of ``adata.obs[label]`` the cell vectors of that
    group are averaged. Two things are then subtracted from that mean: the
    element-wise median of all group means, and ``self.dataset_vector``. The
    gene embedding is asked for the genes most similar to the result.

    Parameters
    ----------
    self : CellEmbedding-like object
        Must provide ``data`` (keys are cell barcodes, in the same order as the
        rows of ``matrix``), ``matrix`` and ``dataset_vector``.
    adata : AnnData-like
        Only ``adata.obs`` is read. Its index is matched against the keys of
        ``self.data``, and ``adata.obs[label]`` gives each cell's group.
    label : str
        Column of ``adata.obs`` to group cells by.
    n_genes : int, default 10
        Number of top genes kept per group.
    gene_embedding : object, optional
        Object exposing ``get_similar_genes(vector)`` that returns a DataFrame
        with a ``"Gene"`` column. Defaults to the notebook-level ``embed`` so
        existing calls keep working.

    Returns
    -------
    dict
        Maps each group value to a list of up to ``n_genes`` gene names.
        Empty if ``adata.obs`` has no rows.

    Raises
    ------
    KeyError
        If a barcode in ``adata.obs.index`` is missing from ``self.data``.
    """
    if gene_embedding is None:
        # Backward compatibility: earlier versions always read the global `embed`.
        gene_embedding = embed

    # barcode -> cell vector; relies on self.data keys and self.matrix rows sharing order.
    mapped_components = dict(zip(self.data.keys(), self.matrix))

    comps = collections.defaultdict(list)
    for barcode, group in zip(adata.obs.index, adata.obs[label]):
        comps[group].append(mapped_components[barcode])
    if not comps:
        # Nothing to compare against; also avoids numpy.median on an empty list.
        return {}

    # Compute each group's mean vector once. Previously every group mean was
    # recomputed inside the per-group loop, which is O(k^2) in the number of groups.
    group_means = {group: numpy.average(vecs, axis=0) for group, vecs in comps.items()}
    # Background shared by every group: element-wise median of all group means.
    background = numpy.median(list(group_means.values()), axis=0)

    markers = dict()
    for group, mean_vec in group_means.items():
        vector = numpy.subtract(numpy.subtract(mean_vec, background), self.dataset_vector)
        similar = gene_embedding.get_similar_genes(vector)
        markers[group] = similar[:n_genes]["Gene"].tolist()
    return markers
# Top predictive genes per cell type, displayed with one row per cell type.
markers = get_predictive_genes(cembed, adata, label="celltype")
pandas.DataFrame(markers).T

In [None]:
# Annotate cells using the per-cell-type marker lists from get_predictive_genes.
# The following cells read "Pseudo-probability" columns and a "genevector" key
# from the result — presumably added here; confirm against genevector.
annotated_adata = cembed.phenotype_probability(adata,markers)

In [None]:
# One UMAP panel per obs column whose name contains "Pseudo-probability"
# (presumably the columns written by phenotype_probability).
prob_cols = [col for col in annotated_adata.obs.columns if "Pseudo-probability" in col]
sc.pl.umap(annotated_adata, size=25, color=prob_cols)

In [None]:
# UMAP coloured by the "genevector" key (presumably the label assigned by
# phenotype_probability — confirm).
sc.pl.umap(annotated_adata, add_outline=True, size=25, color="genevector")

In [None]:
import pickle

# Persist `dataset.mi_scores` (presumably the mutual-information scores used by
# genevector — confirm) as a plain dict. The `with` block guarantees the file is
# flushed and closed; previously the handle was left open until garbage collection.
# NOTE: only load "mk.pkl" from trusted sources — pickle.load can execute arbitrary code.
with open("mk.pkl", "wb") as mi_file:
    pickle.dump(dict(dataset.mi_scores), mi_file)