In [None]:
import os
import json
import re

import numpy as np
import pandas as pd
import torch
import scanpy as sc

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
from violet.utils.dataloaders import listfiles
from violet.utils.st import predict_he_tiles, predict_visium, predict_svs
from violet.utils.preprocessing import normalize_counts, get_svs_tile_shape, get_svs_array
from violet.utils.analysis import display_predictions, display_2d_scatter

In [None]:
import matplotlib as mpl
# Raise the default figure resolution for every inline plot in this notebook.
mpl.rcParams.update({'figure.dpi': 180})

In [None]:
# Model run used for every prediction below. The checkpoint and its training summary are
# derived from one run directory so they cannot accidentally point at different runs.
run_dir = '/home/estorrs/violet/sandbox/runs/pdac_ffpe_tcia_raw_augmented_20samples'
img_dir = '/data/tcia/PDA_preprocessed_small_raw/'
weights = f'{run_dir}/checkpoints/final.pth'
summary = f'{run_dir}/summary.json'

###### Predictions for Visium samples

In [None]:
def get_target_df(folder):
    """Load and stack Tangram cell-type predictions for every ``*_sp.h5ad`` file in ``folder``.

    For each file, ``obsm['tangram_ct_pred']`` is read and every column (cell type) is
    divided by its own maximum, so each column peaks at 1. Rows are re-indexed as
    ``<sample>_<original index>``, where ``<sample>`` is the filename with its
    extension and ``_sp`` suffix stripped.

    Args:
        folder: directory passed to ``listfiles`` to find files ending in ``_sp.h5ad``.

    Returns:
        One DataFrame with all samples stacked row-wise, or ``None`` if no matching
        files were found.
    """
    frames = []
    # Escaped dot: match the literal ".h5ad" extension only.
    for fp in listfiles(folder, regex=r'_sp\.h5ad$'):
        a = sc.read_h5ad(fp)
        sample = os.path.basename(fp).split('.')[0].split('_sp')[0]
        preds = a.obsm['tangram_ct_pred']
        col_max = np.max(preds.values, axis=0)
        # A cell type with an all-zero column would otherwise become 0/0 = NaN; leave it at 0.
        col_max = np.where(col_max == 0, 1, col_max)
        frames.append(pd.DataFrame(data=preds.values / col_max,
                                   columns=preds.columns,
                                   index=[f'{sample}_{x}' for x in preds.index]))
    if not frames:
        return None
    # Concatenate once instead of re-copying the growing frame on every iteration.
    return pd.concat(frames)

In [None]:
# Tangram cell-type scores for all PDAC FFPE samples. Columns are namespaced with
# `tangram_` so they do not clash with the model's `predicted_*` columns in adata.obs.
target_df = (
    get_target_df('/home/estorrs/tangram_annotation/results/pdac_ffpe/')
    .add_prefix('tangram_')
)
target_df

In [None]:
# Sample metadata, restricted to samples that have a high-resolution image, indexed by sample id.
filemap = (
    pd.read_csv('/home/estorrs/spatial-analysis/data/sample_map.txt', sep='\t')
    .dropna(subset=['highres_image'])
    .set_index('sample_id')
)
filemap

In [None]:
# (sample_id, spaceranger output, high-res image) for every PDAC FFPE sample.
pdac_ffpe = filemap[(filemap['disease'] == 'pdac') & (filemap['tissue_type'] == 'ffpe')]
tups = list(zip(pdac_ffpe.index, pdac_ffpe['spaceranger_output'], pdac_ffpe['highres_image']))
tups

In [None]:
# ??predict_visium

In [None]:
# Run the model on each selected Visium sample; results are keyed by sample id.
adata_map = {
    sample_id: predict_visium(spaceranger_dir, highres_fp, weights, summary,
                              tmp_dir='../sandbox/tmp')
    for sample_id, spaceranger_dir, highres_fp in tups
}
adata_map.keys()

In [None]:
# Training summary stored next to the checkpoint (`json` is imported in the first cell).
# The context manager closes the file handle instead of leaking it.
with open(summary) as summary_fh:
    summ = json.load(summary_fh)
summ

In [None]:
# Select one PDAC FFPE sample's predict_visium output for inspection.
sample = 'HT270P1_S1H1Fs5U1'
adata = adata_map[sample]
sample

In [None]:
# Summary repr of the selected sample's AnnData.
adata

In [None]:
# Attach this sample's Tangram scores (`tangram_*` columns) to the spot metadata.
# target_df rows are indexed `<sample>_<barcode>`. Match on the exact prefix rather than a
# substring, so a sample id contained in another id cannot pull in the other sample's rows.
prefix = f'{sample}_'
t = target_df[target_df.index.str.startswith(prefix)].copy()
t.index = t.index.str[len(prefix):]
# Drop columns left over from a previous run of this cell; otherwise re-running duplicates them.
# A left join keeps every spot, so obs stays the same length as adata.X
# (missing spots get NaN instead of being dropped).
adata.obs = adata.obs.drop(columns=t.columns, errors='ignore').join(t)

In [None]:
# Spot-level metadata, including the `tangram_*` columns merged in the previous cell.
adata.obs

In [None]:
# Expression matrix as produced by predict_visium.
# NOTE(review): sparse vs dense and normalization state are not established here — confirm.
adata.X

In [None]:
# Tissue image with spot positions, no color key.
sc.pl.spatial(adata)

In [None]:
# Epithelial marker genes next to the model's predicted Malignant score; vmin=0 anchors each color scale at zero.
sc.pl.spatial(adata, color=['KRT18', 'CDH1', 'EPCAM', 'predicted_Malignant'], vmin=0., wspace=.0)

In [None]:
# Immune / T-cell marker genes.
sc.pl.spatial(adata, color=['PTPRC', 'CD3E', 'CD8A', 'CD4'], vmin=0., wspace=.0)

In [None]:
# Model-predicted lymphocyte scores.
sc.pl.spatial(adata, color=['predicted_CD8 T cell', 'predicted_CD4 T cell', 'predicted_Treg', 'predicted_B cell'],
             wspace=0)

In [None]:
# Tangram scores for the same cell types, for comparison with the model-predicted panels.
sc.pl.spatial(adata, color=['tangram_CD8 T cell', 'tangram_CD4 T cell', 'tangram_Treg', 'tangram_B cell'], wspace=0)

In [None]:
# Second PDAC FFPE sample.
# NOTE(review): Tangram columns are not merged into this sample's obs, so only predicted_* panels are shown.
sample = 'HT264P1_S1H2Fs1_U1'
adata = adata_map[sample]
sample

In [None]:
# Tissue image with spot positions, no color key.
sc.pl.spatial(adata)

In [None]:
# Epithelial marker genes next to the predicted Malignant score, color scales anchored at 0.
sc.pl.spatial(
    adata,
    color=['KRT18', 'CDH1', 'EPCAM', 'predicted_Malignant'],
    vmin=0.,
)

In [None]:
# Immune / T-cell marker genes.
sc.pl.spatial(adata, color=['PTPRC', 'CD3E', 'CD8A', 'CD4'], vmin=0., wspace=.001)

In [None]:
# Model-predicted lymphocyte scores.
sc.pl.spatial(adata, color=['predicted_CD8 T cell', 'predicted_CD4 T cell', 'predicted_Treg', 'predicted_B cell'])

###### Predictions for whole-slide SVS images

In [None]:
# Whole-slide image to run tile-level predictions on.
svs_fp = '/data/tcia/PDA/C3L-00401-22.svs'

In [None]:
# Per-tile predictions for the slide. With return_tiles=True it also returns the tile images and
# a parallel list of tile ids (img_ids[i] identifies imgs[i]); plot_top_tiles relies on both.
df, imgs, img_ids = predict_svs(svs_fp, weights, summary, tmp_dir='../sandbox/tmp', return_tiles=True)
df

In [None]:
# scale = .1

# img = get_svs_array(svs_fp, scale=scale)

# (n_rows, n_cols), tile_size = get_svs_tile_shape(svs_fp, resolution=res)
# row_offset = img.shape[0] % n_rows
# col_offset = img.shape[1] % n_cols

In [None]:
# display_predictions(img, df, tile_size, 'CA9', scale,
# #                    row_offset=row_offset, col_offset=col_offset,
#                    alpha=1., s=.05)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Overview of the whole slide.
# NOTE(review): scale=.1 presumably downsamples to 10% of full resolution — confirm in get_svs_array.
scale = .1
img = get_svs_array(svs_fp, scale=scale)
plt.imshow(img)

In [None]:
# NOTE(review): `cm` is not used anywhere in the visible notebook.
from matplotlib import cm
# Tile-level Malignant scores; presumably a 2D scatter of tile positions colored by the column — see violet.utils.analysis.
display_2d_scatter(df, 'Malignant')

In [None]:
# One scatter per prediction column, each flushed with plt.show() so they render in order.
# `plt` is already imported in an earlier cell; no need to re-import it here.
for c in df.columns:
    print(c)
    display_2d_scatter(df, c, legend=True)
    plt.show()

In [None]:
# Tiles with the highest Malignant score, highest first.
# `n` is defined here: nothing earlier in the notebook defines it, so a fresh run raised NameError.
# nlargest also skips NaN scores, which argsort would have ranked as the largest.
n = 10
df.nlargest(n, 'Malignant')

In [None]:
def plot_top_tiles(pred_df, cell_type, n=10):
    m = pred_df.iloc[np.flip(np.argsort(pred_df[cell_type].to_numpy())[-n:])]
    for img_id in m.index:
        print(img_id, df.loc[img_id, cell_type])
        plt.imshow(imgs[img_ids.index(img_id)])
        plt.show()

In [None]:
# Highest-scoring tiles for the Malignant class.
plot_top_tiles(df, 'Malignant')

In [None]:
# Highest-scoring tiles for CD4 T cells.
plot_top_tiles(df, 'CD4 T cell')

In [None]:
# Per-tile prediction table for the slide.
df