In [None]:
import os
import re

import anndata
import numpy as np
import pandas as pd
import seaborn as sns
import scanpy as sc
from umap import UMAP
from sklearn.preprocessing import StandardScaler


In [None]:
import torch
from torchvision import datasets, transforms

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
import violet
from violet.utils.attention import plot_image_attention, get_image_attention
from violet.utils.model import predict, load_pretrained_model
from violet.utils.dataloaders import get_dataloader, listfiles, image_classification_dataloaders
from violet.utils.analysis import plot_image_umap

In [None]:
# Input locations. The defaults are the original absolute paths; set
# VIOLET_IMG_DIR / VIOLET_WEIGHTS in the environment to run the notebook on
# another machine without editing this cell.
img_dir = os.environ.get('VIOLET_IMG_DIR', '/home/estorrs/violet/data/st/human_he_06252021')
weights = os.environ.get(
    'VIOLET_WEIGHTS',
    '/home/estorrs/violet/sandbox/dino_runs/he_st_xcit_p16/checkpoint0400.pth',
)

In [None]:
# Build the train / validation dataloaders from the image directory.
# NOTE(review): how the helper splits the data is not visible here — confirm.
train_dataloader, val_dataloader = image_classification_dataloaders(img_dir)

In [None]:
# Load the pretrained 'xcit_small' checkpoint and move it to the GPU in one step.
model = load_pretrained_model(weights, model_name='xcit_small').cuda()

In [None]:
# Sorted list of every JPEG tile under img_dir. The dot is escaped so the
# pattern matches the literal '.jpeg' extension rather than any character
# followed by 'jpeg'.
fps = sorted(listfiles(img_dir, regex=r'\.jpeg$'))
len(fps)

In [None]:
# Compute the model's attention for the first image and inspect its shape.
attn = get_image_attention(fps[0], model)
attn.shape

In [None]:
# Plot the attention for fps[2] with the helper's default display mode.
plot_image_attention(fps[2], model)

In [None]:
# Same plot for fps[0] with display='mean'.
# NOTE(review): presumably this averages the attention maps — confirm in violet.utils.attention.
plot_image_attention(fps[0], model, display='mean')

In [None]:
# Run the model over every image in img_dir to get one embedding per tile.
# shuffle=False keeps the row order of `embs` aligned with dataloader.dataset.samples,
# which the AnnData construction further down relies on.
dataloader = get_dataloader(img_dir, batch_size=1024, shuffle=False)
embs = predict(dataloader, model)
embs.shape

In [None]:
# Project the image embeddings to 2-D for plotting. UMAP is stochastic, so a
# fixed random_state keeps the layout identical across kernel restarts.
x = UMAP(n_components=2, random_state=0).fit_transform(embs)

In [None]:
# Tile ids are the image file names without directory or extension, in dataloader order.
spot_ids = [path.split('/')[-1].split('.')[0] for path, _ in dataloader.dataset.samples]

# Wrap the embeddings in an AnnData: one row per tile, UMAP coordinates in obsm.
adata = anndata.AnnData(X=embs)
adata.obs.index = spot_ids
# The sample id is everything before the last '_'-separated token of the tile id.
adata.obs['sample'] = ['_'.join(spot_id.split('_')[:-1]) for spot_id in adata.obs.index]
adata.obsm['X_umap_emb'] = x
adata

In [None]:
# Image-embedding UMAP coloured by sample of origin.
sc.pl.embedding(adata, basis='X_umap_emb', color=['sample'])

In [None]:
# Tab-separated sample map, indexed by its first column. Later cells read its
# 'disease' and 'spaceranger_output' columns and look rows up by sample id.
fmap = pd.read_csv('/home/estorrs/spatial-analysis/data/sample_map.txt', sep='\t', index_col=0)
fmap

In [None]:
from collections import Counter

# Tally samples per disease label, skipping any label containing 'mouse' or 'pdx'.
counts = Counter(
    disease
    for disease in fmap['disease']
    if not ('mouse' in disease or 'pdx' in disease)
)
xs, ys = zip(*counts.items())
pd.DataFrame({'disease': xs, 'sample count': ys}).set_index('disease')

In [None]:
# Total number of samples whose disease label contains neither 'mouse' nor 'pdx'.
sum(1 for disease in fmap['disease'] if not ('mouse' in disease or 'pdx' in disease))

In [None]:
# Attach each tile's disease label via its sample id (assumes fmap's index is
# unique), then colour the image-embedding UMAP by disease.
sample_ids = adata.obs['sample'].to_list()
adata.obs['disease'] = fmap.loc[sample_ids, 'disease'].to_numpy()
sc.pl.embedding(adata, basis='X_umap_emb', color=['disease'])

In [None]:
# lets do co_met
# Keep only tiles whose disease label is one of `keep`, as an independent copy.
keep = ['co', 'co_met']
in_keep = adata.obs['disease'].isin(keep).to_numpy()
co = adata[in_keep].copy()
co

In [None]:
# Peek at the gene (var) table of the first sample listed in fmap.
# `var` is rebound by the loading loop in the next cell, so this is inspection only.
var = sc.read_visium(fmap.iloc[0]['spaceranger_output']).var
var

In [None]:
# Load the Visium expression data for every sample whose disease is in `keep`
# and stack the spots into one obs table and one dense (spots x genes) matrix.
#
# Per-sample pieces are collected in lists and concatenated once at the end;
# concatenating inside the loop re-copies everything accumulated so far on
# every iteration.
obs_parts = []
data_parts = []
var = None
for i, row in fmap.iterrows():
    if row['disease'] not in keep:
        continue

    a = sc.read_visium(row['spaceranger_output'])
    a.var_names_make_unique()
    a.var["mt"] = a.var_names.str.startswith("MT-")
    sc.pp.calculate_qc_metrics(a, qc_vars=["mt"], inplace=True)

    # Prefix barcodes with the sample id so spot ids stay unique across samples.
    a.obs.index = [f'{i}_{x}' for x in a.obs.index]
    a.obs['sample'] = i

    if var is None:
        # The gene table (including its per-gene QC columns) comes from the
        # first kept sample only.
        var = a.var
    elif not a.var_names.equals(var.index):
        # Rows are stacked positionally, so a different gene order would
        # silently misalign columns.
        raise ValueError(f'gene names/order of sample {i} differ from the first sample')

    obs_parts.append(a.obs)
    data_parts.append(a.X.toarray())

obs = pd.concat(obs_parts, axis=0)
data = np.concatenate(data_parts, axis=0)
obs.shape, var.shape, data.shape

In [None]:
# Assemble the stacked expression matrix with its spot (obs) and gene (var) tables.
adata_exp = anndata.AnnData(X=data, obs=obs, var=var)
adata_exp

In [None]:
# Per-sample distribution of total counts per spot, used to spot low-quality samples.
sns.histplot(adata_exp.obs, x='total_counts', kde=False, hue='sample')

In [None]:
# Drop the listed samples. `.copy()` materialises the subset so the in-place
# scanpy calls in the following cell operate on a real AnnData rather than a
# view of the unfiltered object (consistent with the other subsetting cells).
exclude_samples = ['HT213C1A4_U1', 'HT165C1A3', 'HT250C1', 'HT253C1T1']
is_excluded = adata_exp.obs['sample'].isin(exclude_samples).to_numpy()
adata_exp = adata_exp[~is_excluded].copy()

In [None]:
# Remove spots with fewer than 1000 total counts.
sc.pp.filter_cells(adata_exp, min_counts=1000)

# Normalise per-spot totals, log1p-transform, then flag the top 2000 highly
# variable genes (Seurat flavour).
# NOTE(review): these calls modify adata_exp in place, so re-running this cell
# without rebuilding adata_exp would normalise / log-transform a second time.
sc.pp.normalize_total(adata_exp, inplace=True)
sc.pp.log1p(adata_exp)
sc.pp.highly_variable_genes(adata_exp, flavor="seurat", n_top_genes=2000)

In [None]:
# Show the tissue image for every sample whose disease is in `keep`.
for i, row in fmap.iterrows():
    if row['disease'] not in keep:
        continue
    print(i)
    a = sc.read_visium(row['spaceranger_output'])
    a.var_names_make_unique()
    sc.pl.spatial(a)

In [None]:
# Use one sample as the reference: cluster its image embeddings (PCA ->
# neighbours -> UMAP -> Leiden), then map every other sample onto that
# reference with ingest so all tiles share one UMAP and one set of labels.
s_id = 'HT112C1_U2'
sample_of = co.obs['sample']

ref = co[sample_of == s_id].copy()
if ref.n_obs == 0:
    raise ValueError(f'reference sample {s_id!r} has no tiles in `co`')
sc.pp.pca(ref)
sc.pp.neighbors(ref)
sc.tl.umap(ref)
sc.tl.leiden(ref)

new = []
for s in sorted(set(sample_of)):
    if s == s_id:
        # The reference is appended once below; no need to copy it here.
        continue
    small = co[sample_of == s].copy()
    # ingest writes the projected embedding and the 'leiden' labels onto `small`.
    sc.tl.ingest(small, ref, obs=['leiden'])
    new.append(small)
new.append(ref)
corrected = anndata.concat(new)
corrected

In [None]:
# Shared UMAP of all tiles, coloured by the Leiden label.
sc.pl.umap(corrected, color='leiden')

In [None]:
# Spots present in both the image-embedding object and the expression object.
overlap = sorted(set(corrected.obs.index) & set(adata_exp.obs.index))
# Copy so the obs/obsm assignments in the next cell write to an independent
# AnnData instead of a view of adata_exp.
corrected_exp = adata_exp[overlap].copy()

In [None]:
# Carry the disease / Leiden labels and the shared UMAP coordinates over from the
# image-embedding object onto the matching expression spots.
matched_obs = corrected.obs.loc[corrected_exp.obs.index, ['disease', 'leiden']]
corrected_exp.obs['disease'] = matched_obs['disease'].to_numpy()
corrected_exp.obs['leiden'] = matched_obs['leiden'].to_numpy()
# Both objects are indexed by `overlap` in the same order, so rows line up.
corrected_exp.obsm['X_umap'] = corrected[overlap].obsm['X_umap']
corrected_exp

In [None]:
import matplotlib.pyplot as plt

# Square, high-resolution figures for all plots from here on.
plt.rcParams.update({
    "figure.figsize": (5, 5),
    'figure.dpi': 180,
})

In [None]:
# Shared UMAP coloured by sample, disease and Leiden label, one panel per row.
sc.pl.umap(corrected, color=['sample', 'disease', 'leiden'], ncols=1, size=10)

In [None]:
# Map tile id (file name without directory or extension) -> image path.
# `fps` is kept sorted, matching the earlier cell that first defines it, so
# fps[i] means the same thing whichever cell ran last. The dot in the pattern
# is escaped so only a literal '.jpeg' extension matches.
fps = sorted(listfiles(img_dir, regex=r'\.jpeg$'))
sample_to_fp = {fp.split('/')[-1].split('.')[0]: fp for fp in fps}

In [None]:
# Draw the tile images at their UMAP positions for a random subset of n spots.
# A locally seeded generator makes the subset (and so the figure) reproducible
# without touching NumPy's global random state; change the seed to resample.
n = 500
idxs = np.random.default_rng(0).permutation(corrected_exp.obs.index.to_list())[:n]
filtered = corrected_exp[idxs]
umap_xy = filtered.obsm['X_umap']
plot_image_umap(
    umap_xy[:, 0].flatten(),
    umap_xy[:, 1].flatten(),
    [sample_to_fp[s] for s in filtered.obs.index]
)

In [None]:
# Shared UMAP coloured by the Leiden label plus expression of CD3G, IL7R and CD8A
# (presumably chosen as T-cell markers — confirm with the author).
genes = ['leiden', 'CD3G', 'IL7R', 'CD8A']
sc.pl.umap(corrected_exp, color=genes, size=10)

In [None]:
# Shared UMAP coloured by the Leiden label plus expression of EPCAM and CDH1
# (presumably chosen as epithelial markers — confirm with the author).
genes = ['leiden', 'EPCAM', 'CDH1']
sc.pl.umap(corrected_exp, color=genes, size=10)

In [None]:
# visualize
# Reload the reference sample's Visium data, rename its spots to the
# '<sample>_<barcode>' ids used elsewhere in the notebook, and preprocess it
# on its own (normalise, log1p, top-2000 highly variable genes).
s_id = 'HT112C1_U2'
a = sc.read_visium(fmap.loc[s_id, 'spaceranger_output'])
a.var_names_make_unique()
a.obs.index = [f'{s_id}_{x}' for x in a.obs.index]
sc.pp.normalize_total(a, inplace=True)
sc.pp.log1p(a)
sc.pp.highly_variable_genes(a, flavor="seurat", n_top_genes=2000)
a

In [None]:
# Copy the Leiden labels and the shared UMAP coordinates onto this sample's spots.
spot_ids = a.obs.index
a.obs['leiden'] = corrected.obs.loc[spot_ids, 'leiden'].to_numpy()
a.obsm['X_umap'] = corrected[spot_ids].obsm['X_umap'].copy()

In [None]:
# Leiden labels drawn on the tissue section of the reference sample.
sc.pl.spatial(a, color='leiden')

In [None]:
# Tissue image alone, for visual comparison with the labelled plot.
sc.pl.spatial(a)

In [None]:
# Spatial view of the Leiden label next to CD3G / IL7R / CD8A expression.
genes = ['leiden', 'CD3G', 'IL7R', 'CD8A']
sc.pl.spatial(a, color=genes)

In [None]:
# Spatial view of the Leiden label next to EPCAM / CDH1 expression.
genes = ['leiden', 'EPCAM', 'CDH1']
sc.pl.spatial(a, color=genes)

In [None]:
# Same colouring on the UMAP, restricted to the reference sample's spots.
genes = ['leiden', 'CD3G', 'IL7R', 'CD8A']
sc.pl.umap(a, color=genes)

In [None]:
# Reference-sample UMAP coloured by the Leiden label and EPCAM / CDH1 expression.
genes = ['leiden', 'EPCAM', 'CDH1']
sc.pl.umap(a, color=genes)

In [None]:
# Reference-sample UMAP coloured by the Leiden label and AFP expression.
genes = ['leiden', 'AFP']
sc.pl.umap(a, color=genes)

In [None]:
# grab some tumor cells and plot attention
# Pick 5 spots from Leiden cluster '0'. The obs index is masked directly rather
# than via an intermediate AnnData view, and a locally seeded generator makes
# the pick reproducible (change the seed to draw a different set).
cluster_spots = a.obs.index[(a.obs['leiden'] == '0').to_numpy()]
ts1 = np.random.default_rng(0).permutation(cluster_spots.to_numpy())[:5]
for s in ts1:
    print(s)
    plot_image_attention(sample_to_fp[s], model)
    plt.show()

In [None]:
# Same tiles as the previous cell (ts1), plotted with overlay_only=True.
for s in ts1:
    print(s)
    plot_image_attention(sample_to_fp[s], model, overlay_only=True)
    plt.show()

In [None]:
# Attention overlays for 5 spots from Leiden cluster '1'. The pick uses a
# locally seeded generator so the cell shows the same tiles on every run.
cluster_spots = a.obs.index[(a.obs['leiden'] == '1').to_numpy()]
ts2 = np.random.default_rng(0).permutation(cluster_spots.to_numpy())[:5]

for s in ts2:
    print(s)
    plot_image_attention(sample_to_fp[s], model, overlay_only=True)
    plt.show()

In [None]:
# Attention overlays for 5 spots from Leiden cluster '9'. The pick uses a
# locally seeded generator so the cell shows the same tiles on every run.
cluster_spots = a.obs.index[(a.obs['leiden'] == '9').to_numpy()]
ts2 = np.random.default_rng(0).permutation(cluster_spots.to_numpy())[:5]

for s in ts2:
    print(s)
    plot_image_attention(sample_to_fp[s], model, overlay_only=True)
    plt.show()

In [None]:
# Attention overlays for 5 spots from Leiden cluster '2'. The pick uses a
# locally seeded generator so the cell shows the same tiles on every run.
cluster_spots = a.obs.index[(a.obs['leiden'] == '2').to_numpy()]
ts2 = np.random.default_rng(0).permutation(cluster_spots.to_numpy())[:5]

for s in ts2:
    print(s)
    plot_image_attention(sample_to_fp[s], model, overlay_only=True)
    plt.show()

In [None]:
# Attention plots for the hand-picked ROI crops. The four copy-pasted cells
# are collapsed into one loop, using the same print / plot / plt.show()
# pattern as the per-cluster cells earlier in the notebook.
roi_fps = [
    '/home/estorrs/sandbox/co_immune_tumor_roi.png',
    '/home/estorrs/sandbox/co_immune_tumor_roi2.png',
    '/home/estorrs/sandbox/co_immune_tumor_roi3.png',
    '/home/estorrs/sandbox/co_immune_tumor_roi4.png',
]
for roi_fp in roi_fps:
    print(roi_fp)
    plot_image_attention(roi_fp, model, overlay_only=False)
    plt.show()

In [None]:
# Current seaborn palette with every colour component scaled by 255.
p = [[255 * channel for channel in rgb] for rgb in sns.color_palette()]

In [None]:
# Display the scaled palette values.
p

In [None]:
# Display the scaled palette values (same output as the previous cell).
p