In [None]:
import os
import re

import anndata
import numpy as np
import pandas as pd
import seaborn as sns
import scanpy as sc
from umap import UMAP
from sklearn.preprocessing import StandardScaler


In [None]:
import torch
from torchvision import datasets, transforms

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
import violet
from violet.utils.attention import plot_image_attention, get_image_attention
from violet.utils.model import predict, load_pretrained_model
from violet.utils.dataloaders import get_dataloader, listfiles, image_classification_dataloaders
from violet.utils.analysis import plot_image_umap

In [None]:
# Tile directory and model checkpoint (absolute, machine-specific paths).
# NOTE(review): img_dir names a 'ccrcc' dataset while the checkpoint and the result paths
# used later say 'pda' — confirm these are meant to be paired.
img_dir = '/home/estorrs/violet/data/he/ffpe/tcia_ccrcc_normalized/'
weights = '/home/estorrs/violet/sandbox/dino_runs/he_ffpe_pda_xcit_p16/checkpoint.pth'

In [None]:
# NOTE(review): neither of these two dataloaders is referenced again in the cells visible here.
train_dataloader, val_dataloader = image_classification_dataloaders(img_dir)

In [None]:
# Make GPU index 3 (hard-coded) the current CUDA device, then echo the active device index.
torch.cuda.set_device(3)
torch.cuda.current_device()

In [None]:
# Load the checkpoint as an 'xcit_small' model and move it to the current CUDA device.
model = load_pretrained_model(weights, model_name='xcit_small')
model = model.cuda()

In [None]:
# Sorted listing of img_dir filtered by the pattern '.jpeg$'; the cell displays the count.
# NOTE(review): the '.' is unescaped — if listfiles treats this as a regex it matches any
# character before 'jpeg'; confirm against violet.utils.dataloaders.
fps = sorted(listfiles(img_dir, regex='.jpeg$'))
len(fps)

In [None]:
# Attention for the first tile in the sorted file list; the cell displays the result's .shape.
attn = get_image_attention(fps[0], model)
attn.shape

In [None]:
# Attention plot for the tile at position 50 of the sorted file list (default display mode).
plot_image_attention(fps[50], model)

In [None]:
# Same tile with display='mean'.
# NOTE(review): presumably averages the attention maps — confirm in violet.utils.attention.
plot_image_attention(fps[50], model, display='mean')

In [None]:
# A hand-picked image that lives outside img_dir.
plot_image_attention('/home/estorrs/violet/sandbox/tmp/tumor_boundary_4.png', model)

In [None]:
# Same hand-picked image with display='mean'.
plot_image_attention('/home/estorrs/violet/sandbox/tmp/tumor_boundary_4.png', model, display='mean')

In [None]:
# Re-display the number of tile files.
len(fps)

In [None]:
# Run the model over img_dir in batches of 1024 without shuffling and display the result's shape.
# NOTE(review): shuffle=False presumably keeps the rows of `embs` in the same order as
# dataloader.dataset.samples, which the AnnData cell below relies on — confirm in get_dataloader/predict.
dataloader = get_dataloader(img_dir, batch_size=1024, shuffle=False)
embs = predict(dataloader, model)
embs.shape

In [None]:
# Persist the embeddings to disk.
np.save('/home/estorrs/violet/sandbox/results/xcit_dino_tcia_pda/embs.npy', embs)

In [None]:
# Reload the saved embeddings. Running only this cell (skipping the predict cell) leaves
# `dataloader` undefined, and the AnnData cell below still needs it for the row names.
embs = np.load('/home/estorrs/violet/sandbox/results/xcit_dino_tcia_pda/embs.npy')

In [None]:
# x = UMAP(n_components=2).fit_transform(embs)

In [None]:
# One AnnData row per tile, with the embedding as X.
adata = anndata.AnnData(X=embs)
# Row name = file basename up to its first '.'.
adata.obs.index = [s.split('/')[-1].split('.')[0] for s, _ in dataloader.dataset.samples]
# 'sample' = row name with its last '_'-separated field removed.
adata.obs['sample'] = ['_'.join(s.split('_')[:-1]) for s in adata.obs.index]
# adata.obsm['X_umap_emb'] = x
adata

In [None]:
# Cache the annotated embeddings as h5ad.
adata.write_h5ad('/home/estorrs/violet/sandbox/results/xcit_dino_tcia_pda/emb_adata.h5ad')

In [None]:
# Reload the cached AnnData written above. scanpy's reader is `read_h5ad`.
adata = sc.read_h5ad('/home/estorrs/violet/sandbox/results/xcit_dino_tcia_pda/emb_adata.h5ad')

In [None]:
# Random subset of up to 10000 row positions used to fit UMAP (no seed is set, so the
# subset — and therefore the embedding — differs between runs).
idxs = np.random.permutation(np.arange(adata.shape[0]))[:10000]
x_train = adata.X[idxs]
x_train.shape

In [None]:
# Fit a 2-component UMAP on the subset only, then project every row through the fitted model.
u = UMAP(n_components=2)
u.fit(x_train)
adata.obsm['X_umap_emb'] = u.transform(adata.X)

In [None]:
# Displays 10000 row names drawn without replacement; the result is not stored.
# With replace=False this raises ValueError if there are fewer than 10000 rows.
np.random.choice(adata.obs.index, replace=False, size=10000)

In [None]:
# Scatter of the UMAP coordinates for a fresh random draw of 10000 rows.
sc.pl.embedding(adata[np.random.choice(np.arange(adata.shape[0]), replace=False, size=10000)], basis='X_umap_emb')

In [None]:
# 'slide' = the part of the row name before the first '_'.
adata.obs['slide'] = [x.split('_')[0] for x in adata.obs.index]

In [None]:
# Number of distinct slides.
len(sorted(set(adata.obs['slide'])))

In [None]:
# Draw 20 distinct slide ids at random (unseeded) and keep only the rows belonging to them.
# `samples` is reused by a later cell to rebuild the same subset.
slide_ids = sorted(set(adata.obs['slide']))
samples = np.random.choice(slide_ids, replace=False, size=20)
f = adata[adata.obs['slide'].isin(samples).tolist()]
samples

In [None]:

# UMAP of the 20-slide subset, coloured by slide.
sc.pl.embedding(f, basis='X_umap_emb', color=['slide'])

In [None]:
# Map tile name (file basename up to its first '.') -> file path. Rebinds `fps` (unsorted this time).
fps = listfiles(img_dir, regex='.jpeg$')
sample_to_fp = {fp.split('/')[-1].split('.')[0]:fp for fp in fps}

In [None]:
# Pick up to n random rows of the subset and hand their UMAP x/y plus image paths to plot_image_umap.
n = 2000
idxs = np.random.permutation(f.obs.index.to_list())[:n]
filtered = f[idxs]
plot_image_umap(
    filtered.obsm['X_umap_emb'][:, 0].flatten(),
    filtered.obsm['X_umap_emb'][:, 1].flatten(),
    [sample_to_fp[s] for s in filtered.obs.index]
)

In [None]:
# Raise the default figure resolution for every figure drawn from here on.
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['figure.dpi'] = 300

In [None]:
from violet.utils.preprocessing import normalize_counts, get_svs_tile_shape, get_svs_array
# Scale passed to get_svs_array below.
# NOTE(review): presumably a downsampling factor relative to full resolution — confirm in get_svs_array.
scale = .1

In [None]:
# The next three cells differ only in the slide id: read /data/tcia/PDA/<id>.svs at `scale` and show it.
s = 'C3N-00303-22'
img = get_svs_array(f'/data/tcia/PDA/{s}.svs', scale=scale)
plt.imshow(img)

In [None]:
s = 'C3L-03628-26'
img = get_svs_array(f'/data/tcia/PDA/{s}.svs', scale=scale)
plt.imshow(img)

In [None]:
s = 'C3L-01662-23'
img = get_svs_array(f'/data/tcia/PDA/{s}.svs', scale=scale)
plt.imshow(img)

In [None]:
# Display the AnnData summary.
adata

In [None]:
import hdbscan
from collections import Counter
from hdbscan import approximate_predict
# idxs = np.random.permutation(np.arange(adata.shape[0]))[:10000]

# clf = hdbscan.HDBSCAN(
#     min_samples=1,
#     min_cluster_size=100,
#     prediction_data=True,
# #     cluster_selection_epsilon=.1,
# )
# clf.fit(adata.obsm['X_umap_emb'][idxs])
# labels = approximate_predict(clf, adata.obsm['X_umap_emb'][idxs])
# sorted(set(labels[0])), Counter(labels[0])[-1]
# labels = clf.predict(adata.obsm['X_umap_emb'])

In [None]:
from sklearn.cluster import KMeans
# Fit KMeans with 100 clusters on a random (unseeded) subset of up to 10000 rows of the raw
# embeddings, then assign every row to its nearest centroid.
idxs = np.random.permutation(np.arange(adata.shape[0]))[:10000]
n_clusters = 100
clf = KMeans(n_clusters=n_clusters)
# clf.fit(adata.obsm['X_umap_emb'][idxs])
# labels = clf.predict(adata.obsm['X_umap_emb'])
clf.fit(adata.X[idxs])
labels = clf.predict(adata.X)
# Cluster sizes, largest first. `Counter` comes from the import in the preceding cell.
Counter(labels).most_common()

In [None]:
# Store the labels as strings; later cells compare against string ids such as '9'.
adata.obs['cluster'] = [str(x) for x in labels]

In [None]:
# Rebuild the subset for the previously drawn `samples` so that it carries the new 'cluster' column.
f = adata[adata.obs['slide'].isin(samples).tolist()]
f.shape

In [None]:

# UMAP of the slide subset, coloured by slide.
sc.pl.embedding(f, basis='X_umap_emb', color=['slide'])

In [None]:

# Same points, coloured by KMeans cluster.
sc.pl.embedding(f, basis='X_umap_emb', color=['cluster'])

In [None]:
# 'case_id' = the first two '-'-separated fields of the slide id.
adata.obs['case_id'] = ['-'.join(x.split('-')[:2]) for x in adata.obs['slide']]
adata.obs

In [None]:
# For every case, the fraction of its rows that fall in each cluster:
# case_to_proportions[case][cluster] = (#rows of case in cluster) / (#rows of case).
# Clusters a case never hits get 0.0 because Counter returns 0 for missing keys.
clusters = sorted(set(adata.obs['cluster']))
case_to_proportions = {}
for case in sorted(set(adata.obs['case_id'])):
    filtered = adata[adata.obs['case_id']==case]
    counts = Counter(filtered.obs['cluster'])
    n_rows = filtered.shape[0]
    case_to_proportions[case] = {c: counts[c] / n_rows for c in clusters}
        

In [None]:
# Tab-separated clinical table. Later cells read its 'case_id', 'survival_time' and 'status' columns.
clinical = pd.read_csv('/data/tcia/clinical/PDA_cleaned.txt', sep='\t')
clinical

In [None]:

# Add one column per cluster id to `clinical`, holding each row's per-case cluster fraction.
# A case_id with no entry in case_to_proportions raises KeyError.
case_ids = list(clinical['case_id'])
for c in clusters:
    per_case = (case_to_proportions[case] for case in case_ids)
    clinical[c] = [proportions[c] for proportions in per_case]

In [None]:
clinical

In [None]:
# NOTE(review): CoxPHFitter is not used in any cell visible here.
from lifelines import CoxPHFitter

In [None]:
# `source` is an alias of `clinical`, not a copy: columns added to one appear on the other.
source = clinical
source

In [None]:
# 2-D UMAP of every column from position 4 onward.
# NOTE(review): presumably the first four columns are clinical fields and the rest are the
# per-cluster fractions — confirm. Re-running this cell after the cells below have executed
# also feeds 'UMAP1', 'UMAP2' and 'classification' into UMAP, because they are appended to
# the same frame.
umap_X = UMAP().fit_transform(source.iloc[:, 4:])
source['UMAP1'], source['UMAP2'] = umap_X[:, 0], umap_X[:, 1]
plt.scatter(umap_X[:, 0], umap_X[:, 1])

In [None]:
# Group the rows into 5 KMeans clusters in UMAP space; labels are stored as strings '0'..'4'.
# This rebinds `labels`, which previously held the per-tile cluster labels.
labels = KMeans(n_clusters=5).fit_predict(source[['UMAP1', 'UMAP2']])
source['classification'] = [str(x) for x in labels]
source

In [None]:
# Rows assigned to class '4'.
source[source['classification']=='4']

In [None]:
# UMAP scatter coloured by class, with a fixed legend order.
sns.scatterplot(data=source, x='UMAP1', y='UMAP2', hue='classification', hue_order=['0', '1', '2', '3', '4'])

In [None]:
# Colour each row by its fraction of tiles in cluster '9'. That column is numeric, so the
# categorical hue_order of classification labels ('0'..'4') does not apply and is left out.
sns.scatterplot(data=source, x='UMAP1', y='UMAP2', hue='9')

In [None]:
from lifelines import KaplanMeierFitter

# One Kaplan-Meier curve per 'classification' value, fitted on 'survival_time' (durations)
# and 'status' (event indicator) after dropping rows without a classification.
kmf = KaplanMeierFitter()
ft = source
ft = ft[pd.notnull(ft['classification'])]
T, E = ft['survival_time'], ft['status']
groups = ft['classification']

for cluster in sorted(set(groups)):
    in_group = groups == cluster
    kmf.fit(T[in_group], E[in_group], label=cluster)
    ax = kmf.plot(show_censors=True, ci_show=False)

In [None]:
# Columns '0'..'14' plus the class label.
# NOTE(review): the tile clustering above uses n_clusters = 100, so this looks at the first
# 15 cluster columns only — confirm that is intended.
cols = [str(x) for x in range(15)]
cols += ['classification']
source[cols]

In [None]:
# Mean of each selected cluster fraction within each class.
source[cols].groupby('classification').mean()

In [None]:
adata.obs

In [None]:
# Map each tile name (file basename up to its first '.') to its path, taken from the
# dataloader's dataset. If two paths share a name, the later one wins.
s2i = {
    path.split('/')[-1].split('.')[0]: path
    for path, _ in dataloader.dataset.imgs
}
    
# dataloader.dataset.imgs

In [None]:
from einops import rearrange

In [None]:
# default_loader is imported here because the notebook's own import of it sits in a later
# cell, which makes this cell fail with NameError on a fresh top-to-bottom run.
from torchvision.datasets.folder import default_loader

# Show 10 tiles drawn at random from cluster '9'. np.random.choice samples with replacement
# by default, so the same tile can appear more than once.
idxs = np.random.choice(adata[adata.obs['cluster']=='9'].obs.index, size=10)
for i in idxs:
    img = default_loader(s2i[i])
    plt.imshow(img)
    plt.show()

In [None]:
from torchvision.datasets.folder import default_loader

In [None]:
# Number of distinct row names; it equals the row count shown by the next cell only when
# every name is unique.
len(set(adata.obs.index))

In [None]:
adata.obs.shape

In [None]:
# Image-UMAP for up to n random members of cluster '9'. Names are drawn first, then turned
# into row positions (kept in adata's row order) so obsm and obs are indexed consistently.
n = 1000
cluster_members = adata[adata.obs['cluster']=='9'].obs.index.to_list()
idxs = np.random.permutation(cluster_members)[:n]
pool = set(idxs)
idxs = [pos for pos, name in enumerate(adata.obs.index) if name in pool]
filtered = adata.obs.iloc[idxs]
xy = adata.obsm['X_umap_emb'][idxs]
plot_image_umap(
    xy[:, 0].flatten(),
    xy[:, 1].flatten(),
    [s2i[s] for s in filtered.index]
)

In [None]:
# merge


In [None]:
# Kaplan-Meier curve per 'classification' value, same recipe as the earlier survival cell
# but reading from `p`.
# NOTE(review): `p` is not assigned anywhere above this cell in the visible notebook. Its
# only visible assignment (in a cell near the end) is a list of RGB colour lists, which
# cannot be indexed with 'classification'. This cell presumably depended on a table built
# in a removed "merge" step — TODO restore that step or point `ft` at the intended frame.
kmf = KaplanMeierFitter()
# ft = integrated.obs.copy()
ft = p
ft = ft[~pd.isnull(ft['classification'])]
T = ft['survival_time']
E = ft['status']
groups = ft[f'classification']

for cluster in sorted(set(ft['classification'])): 
    kmf.fit(T[groups==cluster], E[groups==cluster], label=cluster)
    ax = kmf.plot(show_censors=True, ci_show=False, )

In [None]:
# Independent copy of the rows belonging to one slide.
s = 'C3L-00017-21'
a = adata[adata.obs['slide']==s].copy()
# sc.pp.neighbors(a)
# sc.tl.umap(a)
# sc.tl.leiden(a)
# a

In [None]:
adata.obs

In [None]:
a

In [None]:
from violet.utils.analysis import display_2d_scatter
import matplotlib.pyplot as plt
# Scatter of this slide's rows coloured by 'cluster', with the legend in sorted label order.
# NOTE(review): which obs columns display_2d_scatter uses for x/y is not visible here — confirm.
display_2d_scatter(a.obs, 'cluster', hue_order=sorted(set(a.obs['cluster'])))
plt.tight_layout()

In [None]:
from violet.utils.preprocessing import normalize_counts, get_svs_tile_shape, get_svs_array
scale = .1
# NOTE(review): `res` is not read by any cell visible here.
res = 55.
# Read the same slide's SVS file at `scale` and display the array shape.
img = get_svs_array(f'/data/tcia/PDA/{s}.svs', scale=scale)
img.shape

In [None]:
import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 300
plt.imshow(img)

In [None]:
# Sample map indexed by its first column. Later cells read its 'tissue_type', 'disease'
# and 'spaceranger_output' columns.
fmap = pd.read_csv('/home/estorrs/spatial-analysis/data/sample_map.txt', sep='\t', index_col=0)
fmap

In [None]:
# Keep FFPE samples only. This overwrites `fmap`, so the unfiltered table is gone until the
# read cell is re-run.
fmap = fmap[fmap['tissue_type']=='ffpe']

In [None]:
from collections import Counter
# Table of sample counts per disease.
counts = Counter(fmap['disease'])
xs, ys = zip(*counts.items())
pd.DataFrame.from_dict({'disease': xs, 'sample count': ys}).set_index('disease')

In [None]:
# Look up each row's disease by its 'sample' value in fmap's index; raises KeyError for any
# sample that is missing from fmap.
# NOTE(review): `adata` was built earlier from img_dir tiles — confirm its 'sample' values
# are actually keys of this sample map.
adata.obs['disease'] = [fmap.loc[x, 'disease'] for x in adata.obs['sample']]
sc.pl.embedding(adata, basis='X_umap_emb', color=['disease'])

In [None]:
# lets do co_met
# NOTE(review): the comment above says co_met but the filter keeps 'pdac' — confirm which is intended.
keep = ['pdac']
co = adata[[True if x in keep else False for x in adata.obs['disease']]].copy()
co

In [None]:
# Gene table of the first sample in fmap, shown for inspection. `var` is rebound by the
# stacking cell that follows.
var = sc.read_visium(fmap.iloc[0]['spaceranger_output']).var
var

In [None]:
# Stack the Visium spots of every sample whose disease is in `keep`:
#   obs  - per-spot metadata incl. QC metrics; index prefixed with the sample id
#   var  - gene table taken from the first loaded sample
#   data - dense expression matrix, rows in the same order as obs
# The pieces are collected in lists and concatenated once after the loop. Growing the frame
# and array inside the loop re-copies everything accumulated so far on every iteration.
obs_parts = []
data_parts = []
var = None
for i, row in fmap.iterrows():
    if row['disease'] not in keep:
        continue
    a = sc.read_visium(row['spaceranger_output'])
    a.var_names_make_unique()
    a.var["mt"] = a.var_names.str.startswith("MT-")
    sc.pp.calculate_qc_metrics(a, qc_vars=["mt"], inplace=True)

    # Prefix spot names with the sample id so they stay unique across samples.
    a.obs.index = [f'{i}_{x}' for x in a.obs.index]
    a.obs['sample'] = i

    if var is None:
        # NOTE(review): assumes every sample has the same genes in the same order — confirm.
        var = a.var
    obs_parts.append(a.obs)
    data_parts.append(a.X.toarray())

if not obs_parts:
    raise ValueError(f'no sample in fmap has a disease in {keep}')
obs = pd.concat(obs_parts, axis=0)
data = np.concatenate(data_parts, axis=0)
obs.shape, var.shape, data.shape

In [None]:
# Assemble the stacked expression matrix with its gene (var) and spot (obs) tables.
adata_exp = anndata.AnnData(X=data)
adata_exp.var = var
adata_exp.obs = obs
adata_exp

In [None]:
# Distribution of total counts per spot, split by sample.
sns.histplot(adata_exp.obs, x='total_counts', kde=False, hue='sample')

In [None]:
# Drop the listed samples. `.copy()` turns the subset into an independent AnnData so that
# the in-place scanpy preprocessing applied afterwards works on real storage, not a view.
exclude_samples = ['HT213C1A4_U1', 'HT165C1A3', 'HT250C1', 'HT253C1T1']
adata_exp = adata_exp[~adata_exp.obs['sample'].isin(exclude_samples).to_numpy()].copy()

In [None]:
# In-place preprocessing of adata_exp: drop spots with fewer than 1000 counts, normalise
# totals, log1p, then flag the 2000 most variable genes. Every step mutates adata_exp, so
# re-running this cell applies normalisation and log1p a second time.
sc.pp.filter_cells(adata_exp, min_counts=1000)

sc.pp.normalize_total(adata_exp, inplace=True)
sc.pp.log1p(adata_exp)
sc.pp.highly_variable_genes(adata_exp, flavor="seurat", n_top_genes=2000)

In [None]:
# Print the id and draw the spatial plot of every sample whose disease is in `keep`.
# Rebinds the notebook-level names `i`, `row` and `a`.
for i, row in fmap.iterrows():
    if row['disease'] in keep:
        print(i)
        a = sc.read_visium(row['spaceranger_output'])
        a.var_names_make_unique()
        sc.pl.spatial(a)

In [None]:
# Use one sample as the reference: PCA, neighbours, UMAP and leiden are computed on it alone.
s_id = 'HT112C1_U2'
ref = co[co.obs['sample']==s_id].copy()
sc.pp.pca(ref)
sc.pp.neighbors(ref)
sc.tl.umap(ref)
sc.tl.leiden(ref)
# Every other sample is passed through sc.tl.ingest against the reference with obs=['leiden'];
# the reference itself is appended last.
# NOTE(review): later cells read obsm['X_umap'] for all spots, so ingest is presumed to also
# project the non-reference samples into the reference UMAP — confirm against the scanpy docs.
new = []
for s in sorted(set(co.obs['sample'])):
    small = co[co.obs['sample']==s].copy()
    if s!=s_id:
        sc.tl.ingest(small, ref, obs=['leiden'])
        new.append(small)
new.append(ref)
corrected = anndata.concat(new)
corrected

In [None]:
sc.pl.umap(corrected, color='leiden')

In [None]:
# Spot names present in both the image-embedding object and the expression object.
# `.copy()` makes corrected_exp an independent AnnData, so the obs/obsm assignments made to
# it afterwards write to real storage rather than to a view of adata_exp.
overlap = sorted(set(corrected.obs.index).intersection(set(adata_exp.obs.index)))
corrected_exp = adata_exp[overlap].copy()

In [None]:
# Carry 'disease' and 'leiden' over from `corrected`, matched by spot name, then attach the
# UMAP coordinates of the overlapping spots.
src_obs = corrected.obs
for col in ('disease', 'leiden'):
    corrected_exp.obs[col] = [src_obs.loc[x, col] for x in corrected_exp.obs.index]
corrected_exp.obsm['X_umap'] = corrected[overlap].obsm['X_umap']
corrected_exp

In [None]:
import matplotlib.pyplot as plt
# Global figure defaults for the plots below (overrides the dpi of 300 set earlier).
plt.rcParams["figure.figsize"] = (5, 5)
plt.rcParams['figure.dpi'] = 180

In [None]:
# One UMAP panel per annotation, stacked in a single column.
sc.pl.umap(corrected, color=['sample', 'disease', 'leiden'], ncols=1, size=10)

In [None]:
# Rebuild the tile-name -> file-path map from img_dir (same construction as earlier in the notebook).
fps = listfiles(img_dir, regex='.jpeg$')
sample_to_fp = {fp.split('/')[-1].split('.')[0]:fp for fp in fps}

In [None]:
# Image-UMAP for up to n random spots; a spot name with no entry in sample_to_fp raises KeyError.
n = 500
idxs = np.random.permutation(corrected_exp.obs.index.to_list())[:n]
filtered = corrected_exp[idxs]
plot_image_umap(
    filtered.obsm['X_umap'][:, 0].flatten(),
    filtered.obsm['X_umap'][:, 1].flatten(),
    [sample_to_fp[s] for s in filtered.obs.index]
)

In [None]:
# UMAP coloured by leiden label and by expression of CD3G, IL7R and CD8A.
genes = ['leiden', 'CD3G', 'IL7R', 'CD8A']
sc.pl.umap(corrected_exp, color=genes, size=10)

In [None]:
# UMAP coloured by leiden label and by expression of EPCAM and CDH1.
genes = ['leiden', 'EPCAM', 'CDH1']
sc.pl.umap(corrected_exp, color=genes, size=10)

In [None]:
# visualize
# Reload the reference sample from its spaceranger output, prefix spot names with the sample
# id (the same naming used when the samples were stacked), and preprocess it in place.
s_id = 'HT112C1_U2'
a = sc.read_visium(fmap.loc[s_id, 'spaceranger_output'])
a.var_names_make_unique()
a.obs.index = [f'{s_id}_{x}' for x in a.obs.index]
sc.pp.normalize_total(a, inplace=True)
sc.pp.log1p(a)
sc.pp.highly_variable_genes(a, flavor="seurat", n_top_genes=2000)
a

In [None]:
# Attach leiden labels and UMAP coordinates from `corrected`, matched by spot name.
# `a` is not filtered here, so any spot absent from `corrected` raises KeyError.
a.obs['leiden'] = [corrected.obs.loc[x, 'leiden'] for x in a.obs.index]
a.obsm['X_umap'] = corrected[a.obs.index].obsm['X_umap'].copy()

In [None]:
# Spatial plot coloured by leiden label.
sc.pl.spatial(a, color='leiden')

In [None]:
# Spatial plot without colouring.
sc.pl.spatial(a)

In [None]:
# Spatial plots: leiden label plus CD3G, IL7R, CD8A expression.
genes = ['leiden', 'CD3G', 'IL7R', 'CD8A']
sc.pl.spatial(a, color=genes)

In [None]:
# Spatial plots: leiden label plus EPCAM, CDH1 expression.
genes = ['leiden', 'EPCAM', 'CDH1']
sc.pl.spatial(a, color=genes)

In [None]:
# The same two gene panels in UMAP space.
genes = ['leiden', 'CD3G', 'IL7R', 'CD8A']
sc.pl.umap(a, color=genes)

In [None]:
genes = ['leiden', 'EPCAM', 'CDH1']
sc.pl.umap(a, color=genes)

In [None]:
# UMAP coloured by leiden label and AFP expression.
genes = ['leiden', 'AFP']
sc.pl.umap(a, color=genes)

In [None]:
# grab some tumor cells and plot attention
# Up to 5 random spots with leiden label '0'.
# NOTE(review): that label '0' corresponds to tumor is the original author's reading; nothing
# in this cell establishes it.
ts1 = np.random.permutation(a[a.obs['leiden']=='0'].obs.index)[:5]
for s in ts1:
    print(s)
    plot_image_attention(sample_to_fp[s], model)
    plt.show()

In [None]:
# Same five spots, this time with overlay_only=True.
for s in ts1:
    print(s)
    plot_image_attention(sample_to_fp[s], model, overlay_only=True)
    plt.show()

In [None]:
# The next three cells are identical except for the leiden label ('1', '9', '2'); each one
# rebinds `ts2` to up to 5 random spots with that label.
ts2 = np.random.permutation(a[a.obs['leiden']=='1'].obs.index)[:5]

for s in ts2:
    print(s)
    plot_image_attention(sample_to_fp[s], model, overlay_only=True)
    plt.show()

In [None]:
ts2 = np.random.permutation(a[a.obs['leiden']=='9'].obs.index)[:5]

for s in ts2:
    print(s)
    plot_image_attention(sample_to_fp[s], model, overlay_only=True)
    plt.show()

In [None]:
ts2 = np.random.permutation(a[a.obs['leiden']=='2'].obs.index)[:5]

for s in ts2:
    print(s)
    plot_image_attention(sample_to_fp[s], model, overlay_only=True)
    plt.show()

In [None]:
# Attention plots for four hand-picked ROI images. The cells differ only in the file name.
roi_fp = '/home/estorrs/sandbox/co_immune_tumor_roi.png'
plot_image_attention(roi_fp, model, overlay_only=False)

In [None]:
roi_fp = '/home/estorrs/sandbox/co_immune_tumor_roi2.png'
plot_image_attention(roi_fp, model, overlay_only=False)

In [None]:
roi_fp = '/home/estorrs/sandbox/co_immune_tumor_roi3.png'
plot_image_attention(roi_fp, model, overlay_only=False)

In [None]:
roi_fp = '/home/estorrs/sandbox/co_immune_tumor_roi4.png'
plot_image_attention(roi_fp, model, overlay_only=False)

In [None]:
# seaborn's current colour palette with every channel multiplied by 255, as a list of lists.
# NOTE(review): this is the only visible assignment of `p`, yet an earlier survival cell
# reads `p` as if it were a table — see that cell.
p = [[x*255 for x in pal] for pal in sns.color_palette()]

In [None]:
p

In [None]:
p