In [None]:
import json
import os
import re

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import torch

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
from violet.utils.dataloaders import listfiles
from violet.utils.st import predict_he_tiles, predict_visium, predict_svs
from violet.utils.preprocessing import normalize_counts, get_svs_tile_shape, get_svs_array
from violet.utils.analysis import display_predictions, display_2d_scatter

In [None]:
# High-resolution inline figures for every plot in this notebook
# (`mpl` is imported in the top import cell).
mpl.rcParams['figure.dpi'] = 300

In [None]:
# NOTE(review): img_dir is not referenced by any visible cell; presumably a
# directory of normalized ccRCC FFPE tiles kept for reference.
img_dir = '/home/estorrs/violet/data/st/ccrcc_ffpe_08032021_normalized/'

# Model run under evaluation. Weights and summary are derived from a single run
# directory so switching runs is a one-line change and the two can never point
# at different runs.
run_dir = '/data/violet/sandbox/runs/he_ffpe_ccrcc_xcit_p8_pda_start'
weights = os.path.join(run_dir, 'checkpoints', 'final.pth')
summary = os.path.join(run_dir, 'summary.json')

###### Predictions on Visium samples

In [None]:
# Sample manifest indexed by sample_id, keeping only rows that have a
# `highres_image` entry.
filemap = (
    pd.read_csv('/home/estorrs/spatial-analysis/data/sample_map.txt', sep='\t')
    .dropna(subset=['highres_image'])
    .set_index('sample_id')
)
filemap

In [None]:
# (sample_id, spaceranger_output, highres_image) triples for the ccRCC FFPE samples.
is_ccrcc_ffpe = (filemap['disease'] == 'ccrcc') & (filemap['tissue_type'] == 'ffpe')
tups = list(
    filemap.loc[is_ccrcc_ffpe, ['spaceranger_output', 'highres_image']]
    .itertuples(index=True, name=None)
)
tups

In [None]:
# Run the model on each selected Visium sample, keyed by sample_id. A plain loop
# (not a comprehension) keeps already-finished samples if a later one fails.
adata_map = {}
for sample_id, spaceranger_dir, highres_image in tups:
    # Results are stored as returned; an earlier experiment additionally
    # applied normalize_counts() to each result at this point.
    adata_map[sample_id] = predict_visium(
        spaceranger_dir, highres_image, weights, summary,
        tmp_dir='../sandbox/tmp', ref='ccrcc',
    )
adata_map.keys()

In [None]:
# Pick one Visium sample to inspect; `adata` is the predict_visium result for it.
sample = 'HT282N1_S1H3Fs4U1'
adata = adata_map[sample]
sample

In [None]:
adata

In [None]:
# Per-spot metadata table.
adata.obs

In [None]:
# Expression matrix; presumably un-normalized, since normalize_counts is
# commented out in the prediction cell.
adata.X

In [None]:
# Baseline spatial plot (no `color=`), for orientation.
sc.pl.spatial(adata)

In [None]:
# Observed vs. predicted expression, side by side, for a panel of marker genes.
# vmin=0. is applied to every gene (the copy-pasted CD3E/CD3G cells had omitted
# it), so all colour scales start at zero and panels stay comparable.
marker_genes = ['CA9', 'EPCAM', 'KRT18', 'PTPRC', 'CD8A', 'CD3E', 'CD3G']
for gene in marker_genes:
    sc.pl.spatial(adata, color=[gene, f'predicted_{gene}'], vmin=0.)

In [None]:
# Second Visium sample; rebinds `sample`/`adata` so the plots below use it.
sample = 'HT293N1_S1H3Fs1U1'
adata = adata_map[sample]
sample

In [None]:
# Baseline spatial plot (no `color=`), for orientation.
sc.pl.spatial(adata)

In [None]:
# Observed vs. predicted expression for each marker; every colour scale starts at 0.
for gene in ('CA9', 'EPCAM', 'PTPRC', 'CDH1'):
    sc.pl.spatial(adata, color=[gene, 'predicted_' + gene], vmin=0.)

###### Predictions on whole-slide images (SVS)

In [None]:
# Whole-slide .svs image to run the model on.
svs_fp = '/data/tcia/CCRCC/C3L-00610-21.svs'

In [None]:
# Model predictions for the slide. `df` feeds both display_predictions
# (overlay on the slide image) and the per-column 2D scatter plots below;
# presumably one row per tile — confirm against predict_svs.
df = predict_svs(svs_fp, weights, summary, tmp_dir='../sandbox/tmp')
df

In [None]:
# Scale factor for the slide image returned by get_svs_array; the same value is
# passed to display_predictions below.
scale = .1
# Resolution passed to get_svs_tile_shape. Hard-coded rather than read from the
# run summary (json.load(...)['dataset']['resolution']).
# NOTE(review): confirm 55. matches the resolution the model was trained at.
res = 55.

img = get_svs_array(svs_fp, scale=scale)
(n_rows, n_cols), tile_size = get_svs_tile_shape(svs_fp, resolution=res)

# Leftover pixels when the image height/width is divided by the number of tile
# rows/columns; only used by the disabled offset arguments of display_predictions.
row_offset, col_offset = img.shape[0] % n_rows, img.shape[1] % n_cols

In [None]:
# Overlay the 'CA9' predictions on the slide image. matplotlib and figure dpi
# are already set up at the top of the notebook, so they are not re-imported here.
# row_offset / col_offset (computed above) can be passed as
# row_offset=..., col_offset=... — they were disabled in the original analysis.
display_predictions(img, df, tile_size, 'CA9', scale,
                    alpha=1., s=.05)

In [None]:
# 2D scatter of the 'CA9' column of the slide predictions
# (the unused `from matplotlib import cm` import was removed).
display_2d_scatter(df, 'CA9')

In [None]:
# One 2D scatter per column of df, each preceded by the column name.
for col_name in df.columns:
    print(col_name)
    display_2d_scatter(df, col_name)
    # Render now so every column gets its own figure.
    plt.show()

In [None]:
# Plain view of the slide image loaded by get_svs_array above; the trailing `;`
# suppresses the AxesImage repr that would otherwise be printed under the figure.
plt.imshow(img);