In [None]:
import os
import re

import anndata
import numpy as np
import pandas as pd
import seaborn as sns
import scanpy as sc
from umap import UMAP
from sklearn.preprocessing import StandardScaler


In [None]:
import torch

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
import violet
from violet.utils.attention import plot_image_attention, plot_multichannel_attention, get_multichannel_images_attentions
from violet.utils.model import predict
from violet.utils.dataloaders import multichannel_image_dataloader, listfiles, dino_he_transform
from violet.utils.analysis import plot_image_umap, display_2d_scatter
from violet.utils.model import load_pretrained_model
from violet.utils.multichannel import create_pseudocolor_image, retile_multichannel_image

In [None]:
# Select CUDA device index 3 for this process and echo the active device.
# NOTE(review): the index is hard-coded; assumes at least four visible GPUs.
torch.cuda.set_device(3)
torch.cuda.current_device()

In [None]:
# Input tile directory and DINO checkpoint used throughout the notebook.
# NOTE(review): both are absolute, user-specific paths; adjust before running elsewhere.
img_dir = '/home/estorrs/violet/data/codex/pdac_codex_256res/'
weights = '/home/estorrs/violet/sandbox/dino_runs/codex_multiplex_pdac_xcit_p8_256res/checkpoint.pth'

In [None]:
# Main loader over the multichannel tiles. Later cells pair per-tile outputs
# (embeddings, attentions) with dataloader.dataset.samples by position, so
# shuffling is disabled here.
dataloader = multichannel_image_dataloader(img_dir, shuffle=False, pad=True)

In [None]:
# Load the pretrained backbone from `weights`; the number of input channels is
# taken from the dataset so the model matches the multichannel tiles.
model = load_pretrained_model(weights, in_chans=len(dataloader.dataset.channels), model_name='xcit_small',
                             patch_size=8)

In [None]:
dataloader.dataset.channels

In [None]:
len(dataloader.dataset.samples)

In [None]:
# Grab the first batch from the loader for quick visual inspection below.
b = next(iter(dataloader))
b.shape

In [None]:
sns.color_palette()

In [None]:
import matplotlib.pyplot as plt
# Render a pseudocolor preview for every tile in the batch `b`, printing the
# batch position before each figure. The previews are kept in `pseudos` so
# later cells can look them up by the same batch position.
pseudos = []
for i, x in enumerate(b):
    print(i)
    pseudo = create_pseudocolor_image(x, dataloader.dataset, ['DAPI', 'CD20', 'CD8', 'pancytok', 'CD31'])
    plt.imshow(pseudo)
    plt.show()
    pseudos.append(pseudo)

In [None]:
model = model.cuda()

In [None]:
# Pick batch position 7 (tile tensor plus its pseudocolor preview) as the
# example for the attention plots that follow.
test, pseudo = b[7], pseudos[7]
# test, pseudo = dataloader.dataset[14216], create_pseudocolor_image(dataloader.dataset[14216], dataloader.dataset, ['Histone H3', 'CD20', 'CD3', 'pan Cytokeratin'])

In [None]:
plot_multichannel_attention(test, pseudo, model)

In [None]:
plot_multichannel_attention(test, pseudo, model, display='mean')

In [None]:
embs = predict(dataloader, model)

In [None]:
embs.shape

In [None]:
# 2-D UMAP of the tile embeddings. The seed is fixed (same value as the
# per-head UMAP used for clustering further down) so the layout stored in
# adata.obsm['X_umap_emb'] is reproducible across kernel restarts.
x = UMAP(n_components=2, random_state=42).fit_transform(embs)

In [None]:
# Wrap the embeddings in an AnnData object: one observation per tile, indexed
# by tile name, with the originating sample (tile name minus its last two
# underscore-separated fields) and the precomputed UMAP coordinates attached.
adata = anndata.AnnData(X=embs)
adata.obs.index = list(dataloader.dataset.samples)
adata.obs['sample'] = [
    '_'.join(tile_name.split('_')[:-2])
    for tile_name in adata.obs.index
]
adata.obsm['X_umap_emb'] = x
adata

In [None]:
# Persist the embedding AnnData. The path is relative to the notebook's
# working directory; the target directory is assumed to exist already.
adata.write_h5ad('../sandbox/results/xcit_dino_multichannel_codex_inhouse/emb_adata.h5ad')

In [None]:
adata.obs.index

In [None]:
# Draw a random subset of n tiles and plot their pseudocolor thumbnails at
# their UMAP coordinates. The subset is unseeded, so it changes between runs.
n = 200
idxs = np.random.permutation(adata.obs.index.to_list())[:n]
filtered = adata[idxs]

# Map each tile name to its first position in the dataset once. The previous
# per-tile `np.where(samples == s)` scan raised IndexError when `samples` is a
# plain list (list == str is just False) and rescanned the dataset per tile.
sample_to_position = {}
for position, name in enumerate(dataloader.dataset.samples):
    sample_to_position.setdefault(name, position)

imgs = [dataloader.dataset[sample_to_position[s]] for s in filtered.obs.index]
pseudos = [create_pseudocolor_image(x, dataloader.dataset, ['DAPI', 'CD20', 'CD3e', 'pancytok'])
          for x in imgs]
plot_image_umap(
    filtered.obsm['X_umap_emb'][:, 0].flatten(),
    filtered.obsm['X_umap_emb'][:, 1].flatten(),
    pseudos
)

In [None]:
# Extract attention maps for every tile in the loader. Later cells index the
# result as (tile, head, token), with token 0 treated as the CLS token.
attns = get_multichannel_images_attentions(dataloader, model)
attns.shape

In [None]:
np.save('../sandbox/results/xcit_dino_multichannel_codex_inhouse/cls_attns.npy', attns)

In [None]:
np.sqrt(784)

In [None]:
# 28 x 28 to 224 x 224
from skimage.transform import resize
from scipy.stats import pearsonr, spearmanr

In [None]:
# Pool, for every (attention head, channel) pair, the pixel intensities of the
# channel and the matching attention values (attention map resized to the
# image resolution) across tiles. Result layout:
#   head_to_channel[head][channel] = {'image': 1-D array, 'attention': 1-D array}
#
# Fix: the per-head dict used to be re-created for every image, which threw
# away all previously pooled tiles (only the last tile survived). Entries are
# now created once via setdefault, and pieces are concatenated once at the end
# instead of re-concatenating the growing arrays on every iteration.
head_to_channel = {}
# Attention tokens (minus CLS at position 0) form a square grid.
side = int(np.sqrt(attns.shape[2] - 1))
# NOTE(review): if `dataloader` is a torch DataLoader, len(dataloader) is the
# number of batches, not tiles, so only the first len(dataloader) tiles are
# pooled. Kept as-is because pooling every tile may not fit in memory - confirm intent.
for i in range(len(dataloader)):
    img = dataloader.dataset[i].numpy()
    for h in range(attns.shape[1]):
        attn = attns[i, h, 1:].reshape(side, side)
        attn = resize(attn, img.shape[1:]).flatten()
        per_channel = head_to_channel.setdefault(h, {})
        for j, c in enumerate(dataloader.dataset.channels):
            entry = per_channel.setdefault(c, {'image': [], 'attention': []})
            entry['image'].append(img[j, :, :].flatten())
            entry['attention'].append(attn)

# Collapse the collected per-tile pieces into single flat arrays.
for per_channel in head_to_channel.values():
    for entry in per_channel.values():
        entry['image'] = np.concatenate(entry['image'])
        entry['attention'] = np.concatenate(entry['attention'])
        
#             corr, p = pearsonr(img[j, :, :].flatten(), attn.flatten())
#             head_to_channel[h][c] = {
#                 'r2': corr,
#                 'p-value': p,
#                 'image': img[j, :, :],
#                 'attention': attn
#             }
        
        
    
    

In [None]:
# Attach the Spearman correlation (and its p-value) between pooled pixel
# intensity and pooled attention to every head/channel entry.
for per_channel in head_to_channel.values():
    for stats in per_channel.values():
        rho, p_value = spearmanr(stats['image'], stats['attention'])
        stats['r'] = rho
        stats['p-value'] = p_value

In [None]:
for h, v in head_to_channel.items():
    for c, d in v.items():
        print(c, d['r'])

In [None]:
channel = 'pancytok'
for h, v in head_to_channel.items():
    print(h, v[channel]['r'])

In [None]:
# For the 'pancytok' channel, print each head's correlation and show, per head,
# a scatter of pixel intensity vs. attention followed by the intensity
# distribution.
channel = 'pancytok'
for h, v in head_to_channel.items():
    print(h, v[channel]['r'])
    plt.scatter(v[channel]['image'], v[channel]['attention'])
    plt.show()
    sns.displot(v[channel]['image'])
    plt.show()

In [None]:
# Same per-head inspection for the 'Ki67' channel; additionally prints the flat
# position of the brightest pooled pixel.
channel = 'Ki67'
for h, v in head_to_channel.items():
    print(h, v[channel]['r'])
    print(np.argmax(v[channel]['image']))
    plt.scatter(v[channel]['image'], v[channel]['attention'])
    plt.show()

In [None]:
import umap
import hdbscan

In [None]:
# (n, h, p)

In [None]:
# cluster based on attn feature maps
# For each attention head: take the per-tile attention vector (token 0
# dropped), embed it with a seeded 2-D UMAP, cluster that embedding with
# HDBSCAN, and additionally run scanpy neighbors / UMAP / leiden on the raw
# attention matrix. One AnnData per head is appended to head_to_results, so
# list position equals head index.
head_to_results = []
for h in range(attns.shape[1]):
    print(h)
    
    # attns is indexed as (tile, head, token) here.
    m = attns[:, h, 1:]
    
    clusterable_embedding = umap.UMAP(
        n_neighbors=30,
        min_dist=0.0,
        n_components=2,
        random_state=42,
    ).fit_transform(m)
    
    labels = hdbscan.HDBSCAN(
        min_samples=10,
        min_cluster_size=100,
    ).fit_predict(clusterable_embedding)
    
    # Observations are indexed by tile name, in dataset order.
    a = anndata.AnnData(X=m)
    a.obs.index = list(dataloader.dataset.samples)
    a.obsm['X_umap_emb'] = clusterable_embedding
    a.obs['hdbscan_cluster'] = labels
    
    # scanpy's own graph-based pipeline, run with default parameters.
    sc.pp.neighbors(a)
    sc.tl.umap(a)
    sc.tl.leiden(a)
    
    head_to_results.append(a)

In [None]:
# Derive a '<BaselTMA_SP..>_<X..Y..>' sample key for every tile of every head
# and merge the clinical table onto it.
# NOTE(review): `basel` is only created in a later cell of this notebook, so
# this cell fails on a fresh top-to-bottom run. `combined` is overwritten on
# every iteration and not used inside this cell.
for h in range(len(head_to_results)):
    print(h)
    a = head_to_results[h]

    a.obs['sample'] = [re.sub(r'^(BaselTMA_SP[0-9]+).*(X[0-9]+Y[0-9]+).*$', r'\1_\2', x) for x in a.obs.index]
    combined = pd.merge(basel, a.obs, left_index=True, right_on='sample')

In [None]:
basel

In [None]:
from lifelines import KaplanMeierFitter, CoxPHFitter
# For each attention head, join the clinical table to the per-tile clusters and
# draw one Kaplan-Meier curve per leiden cluster (one figure per head).
# NOTE(review): depends on `basel`, which is created in a later cell.
for h in range(len(head_to_results)):
    print(h)
    a = head_to_results[h]

    a.obs['sample'] = [re.sub(r'^(BaselTMA_SP[0-9]+).*(X[0-9]+Y[0-9]+).*$', r'\1_\2', x) for x in a.obs.index]
    combined = pd.merge(basel, a.obs, left_index=True, right_on='sample')
    
    
    kmf = KaplanMeierFitter()
    ft = combined
    # Drop rows without a cluster label before fitting.
    ft = ft[~pd.isnull(ft['leiden'])]
    T = ft['survival_time']
    E = ft['status']
    groups = ft[f'leiden']

    # The same fitter is refit per cluster; each fit is plotted before the next one.
    for cluster in sorted(set(ft['leiden'])): 
        kmf.fit(T[groups==cluster], E[groups==cluster], label=cluster)
        ax = kmf.plot(show_censors=True, ci_show=False, )
    plt.show()

In [None]:
from lifelines import KaplanMeierFitter, CoxPHFitter
# Kaplan-Meier curves per HDBSCAN cluster, one figure per attention head.
for h in range(len(head_to_results)):
    print(h)
    a = head_to_results[h]

    # Reduce each tile name to its '<BaselTMA_SP..>_<X..Y..>' sample key.
    a.obs['sample'] = [
        re.sub(r'^(BaselTMA_SP[0-9]+).*(X[0-9]+Y[0-9]+).*$', r'\1_\2', x)
        for x in a.obs.index
    ]
    combined = pd.merge(basel, a.obs, left_index=True, right_on='sample')

    kmf = KaplanMeierFitter()
    # Keep only rows that carry a cluster label.
    ft = combined
    ft = ft[ft['hdbscan_cluster'].notna()]
    T, E = ft['survival_time'], ft['status']
    groups = ft['hdbscan_cluster']

    for cluster in sorted(set(ft['hdbscan_cluster'])):
        in_cluster = groups == cluster
        kmf.fit(T[in_cluster], E[in_cluster], label=cluster)
        ax = kmf.plot(show_censors=True, ci_show=False)
    plt.show()

In [None]:
# For each head, colour the attention-UMAP by HDBSCAN cluster and by the
# patient's survival time (tiles without a clinical match are hidden).
#
# The de-duplicated clinical table is computed once, outside the loop, and is
# deliberately NOT named `b`: that name holds the image batch taken from the
# dataloader earlier in the notebook and must not be overwritten.
basel_unique = basel.drop_duplicates()
for h in range(len(head_to_results)):
    print(h)
    a = head_to_results[h]

    # Reduce each tile name to its '<BaselTMA_SP..>_<X..Y..>' sample key.
    a.obs['sample'] = [re.sub(r'^(BaselTMA_SP[0-9]+).*(X[0-9]+Y[0-9]+).*$', r'\1_\2', x) for x in a.obs.index]
    a.obs['survival_time'] = [basel_unique.loc[x, 'survival_time'] if x in basel_unique.index else np.nan
                              for x in a.obs['sample']]

    sc.pl.embedding(a[~pd.isnull(a.obs['survival_time'])], basis='X_umap_emb', color=['hdbscan_cluster', 'survival_time'])
    plt.show()

In [None]:
# For each head, colour scanpy's UMAP by leiden cluster and by the patient's
# survival time (tiles without a clinical match are hidden).
#
# The de-duplicated clinical table is computed once, outside the loop, and is
# deliberately NOT named `b`: that name holds the image batch taken from the
# dataloader earlier in the notebook and must not be overwritten.
basel_unique = basel.drop_duplicates()
for h in range(len(head_to_results)):
    print(h)
    a = head_to_results[h]

    # Reduce each tile name to its '<BaselTMA_SP..>_<X..Y..>' sample key.
    a.obs['sample'] = [re.sub(r'^(BaselTMA_SP[0-9]+).*(X[0-9]+Y[0-9]+).*$', r'\1_\2', x) for x in a.obs.index]
    a.obs['survival_time'] = [basel_unique.loc[x, 'survival_time'] if x in basel_unique.index else np.nan
                              for x in a.obs['sample']]

    sc.pl.umap(a[~pd.isnull(a.obs['survival_time'])], color=['leiden', 'survival_time'])
    plt.show()

In [None]:
# Positional indices of the tiles that head 0 assigns to leiden cluster 7.
# Fixes two defects: the comprehension used to yield a stale global `i` rather
# than anything from the loop, and it compared the labels with the integer 7
# although scanpy stores leiden labels as strings (so nothing ever matched).
# All per-head AnnData objects share the same tile order, so the positions are
# valid for any head.
idxs = [i for i, c in enumerate(head_to_results[0].obs['leiden']) if str(c) == '7']
idxs

In [None]:

# Inspect dataset tile 14216. The former `b[36], pseudos[36]` assignment was
# dead code (both names were overwritten on the very next line) and has been
# dropped; the tile is now loaded once and reused for the pseudocolor preview.
test = dataloader.dataset[14216]
pseudo = create_pseudocolor_image(test, dataloader.dataset, ['Histone H3', 'CD20', 'CD3', 'pan Cytokeratin'])

In [None]:
plot_multichannel_attention(test, pseudo, model)

In [None]:
# Second loader over the same directory, shuffled and without padding; used
# below only for browsing tiles via dl.dataset.
dl = multichannel_image_dataloader(img_dir, shuffle=True, pad=False)

In [None]:
import matplotlib.pyplot as plt
# Browse tiles in random order, printing each tile's dataset index and name
# above its pseudocolor preview.
# NOTE(review): the loop walks the entire dataset (one figure per tile); it is
# meant to be interrupted manually.
idxs = np.random.permutation(np.arange(len(dl.dataset.samples)))
for i in idxs:
    x = dl.dataset[i]
    print(i, dl.dataset.samples[i])
    pseudo = create_pseudocolor_image(x, dl.dataset, ['Histone H3', 'CD20', 'CD3', 'pan Cytokeratin'])
    plt.imshow(pseudo)
    plt.show()

In [None]:
# Scratch note: a bare tuple of dataset indices (presumably tiles of interest
# found while browsing - TODO confirm). Has no effect other than being displayed.
1408, 14216, 10347

In [None]:
dl.dataset.samples[14216]

In [None]:
## clustering feature heads

In [None]:
sns.color_palette()

In [None]:
# Graph-based clustering of the tile embeddings: neighbours are computed on
# the raw embedding matrix (use_rep='X'), followed by a low-resolution leiden
# clustering and scanpy's own UMAP layout.
sc.pp.neighbors(adata, use_rep='X')
sc.tl.leiden(adata, resolution=.1)
sc.tl.umap(adata, )

In [None]:
# adata.uns.pop('leiden_colors')

In [None]:
sc.pl.embedding(adata, basis='X_umap_emb', color='leiden')

In [None]:
sc.pl.umap(adata)

In [None]:
# Recompute the attention maps for all tiles (same call as earlier in the
# notebook); this rebinds `attns` to the full array including token 0.
attns = get_multichannel_images_attentions(dataloader, model)
attns.shape

In [None]:
np.save('../sandbox/results/xcit_dino_multichannel_bodenmiller/cls_attns.npy', attns)

In [None]:
# remove cls self attention
# NOTE(review): this rebinds `attns` in place, so re-running the cell drops one
# more token each time, and earlier cells that slice `attns[..., 1:]` themselves
# would then skip a real patch token. Recompute `attns` before re-running.
attns = attns[:, :, 1:]

In [None]:
# Total attention per tile and head: sum over the last (token) axis.
sums = np.sum(attns, axis=-1)
sums.shape

In [None]:
# Load the Basel TMA patient metadata table.
# NOTE(review): absolute, machine-specific path. Several earlier cells already
# use `basel`, so this cell has to be run before them.
basel = pd.read_csv('/data/multiplex/bodenmiller_2019/Data_publication/BaselTMA/Basel_PatientMetadata.csv')
basel

In [None]:
# Reduce the clinical table to the columns needed for survival analysis and
# index it by core. `.copy()` makes the selection an independent frame so the
# column assignments below do not write into a slice of the original table.
basel = basel[['core', 'Patientstatus', 'OSmonth', 'Subtype']].copy()
# Event indicator for lifelines: 1 = death observed, 0 = censored (alive).
# The previous 1 (alive) / 2 (dead) coding is the R `survival` convention;
# lifelines expects 0/1 (or False/True), so with 1/2 every patient was treated
# as an observed event.
basel['status'] = [0 if 'alive' in x else 1 for x in basel['Patientstatus']]
basel['survival_time'] = basel['OSmonth']
basel = basel.set_index('core')
basel

In [None]:
# Collapse each core name to '<Basel...SP<digits>>_<X<digits>Y<digits>>' so it
# can be matched against the sample keys derived from tile names. Names that do
# not match the pattern are left unchanged by re.sub.
basel.index = [re.sub(r'^(Basel.*SP[0-9]+).*(X[0-9]+Y[0-9]+).*$', r'\1_\2', x) for x in basel.index]

In [None]:
list(basel.index)

In [None]:
{re.sub(r'^(Basel.*SP[0-9]+).*(X[0-9]+Y[0-9]+).*$', r'\1_\2', x) for x in basel.index}

In [None]:
len({re.sub(r'^(Basel.*SP[0-9]+).*(X[0-9]+Y[0-9]+).*$', r'\1_\2', x) for x in basel.index})

In [None]:
list(basel.index)

In [None]:
# Distinct tile-name prefixes (text before the first '.').
# Reads the names from adata.obs.index - the index `df` is built from - because
# `df` itself is only created in a later cell and would raise NameError here on
# a fresh top-to-bottom run.
set([x.split('.')[0] for x in adata.obs.index])

In [None]:
# Per-tile table of summed attention, one 'embedding_<head>' column per head,
# indexed by tile name.
df = pd.DataFrame(data=sums, index=adata.obs.index, columns=[f'embedding_{x}' for x in range(sums.shape[1])])
df

In [None]:
# Rebuild the per-tile table and average it per sample: tile names are reduced
# to a '<BaselTMA_SP..>_<X..Y..>' key and all tiles sharing a key are averaged.
df = pd.DataFrame(data=sums, index=adata.obs.index, columns=[f'embedding_{x}' for x in range(sums.shape[1])])
df['sample'] = [re.sub(r'^(BaselTMA_SP[0-9]+).*(X[0-9]+Y[0-9]+).*$', r'\1_\2', x) for x in df.index]
df = df.groupby('sample').mean()
df

In [None]:
# Join the per-sample attention features with the clinical table on their
# shared index, then discard the clinical columns that are not needed for the
# survival model (absent columns are simply skipped).
combined = pd.merge(df, basel, left_index=True, right_index=True)
combined = combined.drop(columns=['Patientstatus', 'OSmonth', 'Subtype'], errors='ignore')
combined

In [None]:
# Cox proportional-hazards fit using every remaining column of `combined`
# (the per-head attention sums) as covariates.
# NOTE(review): CoxPHFitter is imported in the Kaplan-Meier cells, which must
# have been run first.
cph = CoxPHFitter()
cph.fit(combined, duration_col='survival_time', event_col='status')

cph.print_summary()  # access the individual results using cph.summary

In [None]:
sns.regplot(data=combined, x='survival_time', y='embedding_0')

In [None]:
sns.regplot(data=combined, x='survival_time', y='embedding_8')

In [None]:
# Split samples into two groups by thresholding the summed attention of head 8.
# NOTE(review): the .9995 cut-off is a hand-picked constant - confirm its origin.
p = combined.copy()
p['classification'] = ['1' if x < .9995 else '0' for x in p['embedding_8']]

In [None]:
from lifelines import KaplanMeierFitter, CoxPHFitter

In [None]:
# Kaplan-Meier curves per clinical subtype, drawn from the clinical table alone.
kmf = KaplanMeierFitter()
# ft = integrated.obs.copy()
ft = basel
# Rows without a subtype cannot be grouped and are dropped.
ft = ft[~pd.isnull(ft['Subtype'])]
T = ft['survival_time']
E = ft['status']
groups = ft[f'Subtype']

# The same fitter is refit per subtype; each fit is plotted before the next one.
for cluster in sorted(set(ft['Subtype'])): 
    kmf.fit(T[groups==cluster], E[groups==cluster], label=cluster)
    ax = kmf.plot(show_censors=True, ci_show=False, )

In [None]:
# Kaplan-Meier curves for the two groups produced by the head-8 attention
# threshold stored in p['classification'].
kmf = KaplanMeierFitter()
ft = p
ft = ft[ft['classification'].notna()]
T, E = ft['survival_time'], ft['status']
groups = ft['classification']

for cluster in sorted(set(ft['classification'])):
    in_cluster = groups == cluster
    kmf.fit(T[in_cluster], E[in_cluster], label=cluster)
    ax = kmf.plot(show_censors=True, ci_show=False)

In [None]:
display_2d_scatter(adata.obs, 'leiden', hue_order=sorted(set(adata.obs['leiden'])))
plt.tight_layout()

In [None]:
# Collect (dataset index, tile name) pairs for every tile whose name contains
# the given core identifier.
samples = [(i, x) for i, x in enumerate(dl.dataset.samples)
          if 'ZTMA208_slide_28.23kx22.4ky_7000x7000_5_20171115_96_1_Ay12x8_283_a0_full' in x]
samples

In [None]:
# Pseudocolor preview for each selected tile; `x` is the dataset index from the
# (index, name) pairs in `samples`. Rebinds the `pseudos` list used earlier.
pseudos = [create_pseudocolor_image(dl.dataset[x], dl.dataset, ['Histone H3', 'CD20', 'CD3', 'pan Cytokeratin'])
           for x, _ in samples]

In [None]:
# Stitch the tile previews back into one image, passing the tile names
# alongside the previews (presumably the names encode tile positions - TODO confirm).
retiled = retile_multichannel_image([x for _, x in samples], pseudos)

In [None]:
plt.imshow(retiled)

In [None]:
# Total intensity of one channel per tile, for every tile in the dataset.
# NOTE(review): the channel is addressed by the hard-coded position -7 and is
# presumably DAPI given the variable name - verify against
# dataloader.dataset.channels.
dapis = [np.sum(dataloader.dataset[x][-7].numpy())
         for x in range(len(dataloader.dataset.samples))]

In [None]:
# Attach the per-tile channel totals to the embedding table and colour the 2-D
# scatter by them. Relies on `dapis` being in the same tile order as adata.obs.
adata.obs['dapi'] = dapis
display_2d_scatter(adata.obs, 'dapi', legend=True)

In [None]:
# Flag tiles whose channel total exceeds 100 and colour the scatter by the flag.
# NOTE(review): the threshold of 100 is a hand-picked constant.
adata.obs['highlight'] = ['yes' if d>100 else 'no' for d in dapis]
display_2d_scatter(adata.obs, 'highlight', legend=True)

In [None]:
dataloader.dataset.channels