In [None]:
import os
import json
import re

import numpy as np
import pandas as pd
import torch
import scanpy as sc
import matplotlib.pyplot as plt
import matplotlib as mpl

from torchvision.datasets.folder import default_loader

In [None]:
# Figure resolution plus font embedding for vector exports:
# fonttype 42 selects TrueType (Type 42) fonts in PDF/PS output instead of Type 3.
mpl.rcParams.update({
    'figure.dpi': 200,
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
})

In [None]:
# Back-of-envelope estimate: 7,000,000 / 10,000 * 40, then / 60 / 60 (presumably seconds -> hours).
# NOTE(review): what each factor stands for is not recorded here — TODO name these constants.
7_000_000 / 10_000 * 40 / 60 / 60

In [None]:
# Enable IPython's autoreload extension (its mode is set in the following cell).
%load_ext autoreload

In [None]:
# Mode 2: reload all modules (e.g. the violet.utils imports) before executing each cell.
%autoreload 2

In [None]:
from violet.utils.dataloaders import listfiles
from violet.utils.st import predict_he_tiles, predict_visium, predict_svs, load_trained_st_regressor
from violet.utils.preprocessing import normalize_counts, get_svs_tile_shape, get_svs_array
from violet.utils.analysis import display_predictions
from violet.utils.attention import plot_attention_rollup

In [None]:
# Input/output locations. Override with VIOLET_IMG_DIR / VIOLET_RUN_DIR to run on another
# machine; the defaults reproduce the original hard-coded paths exactly.
img_dir = os.environ.get('VIOLET_IMG_DIR', '/home/estorrs/violet/data/st/human_he_06252021')
run_dir = os.environ.get('VIOLET_RUN_DIR',
                         '/home/estorrs/violet/sandbox/runs/test_run_on_human_he_06252021')
# Checkpoint and summary both live inside the same training-run directory.
weights = os.path.join(run_dir, 'checkpoints', 'final.pth')
summary = os.path.join(run_dir, 'summary.json')

In [None]:
# Sample sheet indexed by sample_id, restricted to rows that have a high-res image path.
filemap = (
    pd.read_csv('/home/estorrs/spatial-analysis/data/sample_map.txt', sep='\t')
    .dropna(subset=['highres_image'])
    .set_index('sample_id')
)
filemap

In [None]:
# Parse the training-run summary; the context manager closes the file handle promptly.
with open(summary) as fh:
    s = json.load(fh)
s

In [None]:
# Restrict the analysis to the samples listed in the run summary's validation split.
val_split = s['dataset']['val_dataset']
samples = val_split['samples']

In [None]:
# Quick broadcasting check: a (1, 10) array divides every row of a (5, 10) array.
# A seeded Generator keeps the displayed values reproducible across Restart & Run All.
rng = np.random.default_rng(0)
xs = rng.random((5, 10))
ys = rng.random((1, 10))
xs / ys

In [None]:
# Show both operand shapes used in the broadcasting check above.
tuple(arr.shape for arr in (xs, ys))

In [None]:
def standardize_predictions(adata):
    """Rescale predicted and observed expression to the [0, 1] range, in place.

    Every ``adata.obs`` column whose name contains ``'predicted_'`` is clipped
    at zero and divided by its maximum. Each column of ``adata.X`` is divided
    by its column maximum, and ``adata.X`` is replaced by a dense ndarray.

    Columns whose maximum is zero (an all-zero column, or a prediction that is
    entirely non-positive) stay at zero instead of turning into NaN through a
    0 / 0 division. NaN predictions stay NaN and are ignored when finding the
    maximum.

    Args:
        adata: AnnData-like object with a pandas ``obs`` DataFrame and an
            expression matrix ``X`` that is either scipy-sparse (anything with
            ``.toarray()``) or already dense.

    Returns:
        The same ``adata`` object, modified in place.
    """
    for col in adata.obs.columns:
        if 'predicted_' not in col:
            continue
        # Predictions can go negative; cut them at zero before scaling.
        values = np.clip(np.asarray(adata.obs[col], dtype=float), 0., None)
        finite = values[~np.isnan(values)]
        peak = finite.max() if finite.size else 0.
        adata.obs[col] = values / peak if peak > 0 else values

    # Accept both sparse and dense matrices, and densify only once.
    X = adata.X.toarray() if hasattr(adata.X, 'toarray') else np.asarray(adata.X)
    col_max = X.max(axis=0)
    # Dividing zero-max columns by 1 leaves them at 0 rather than NaN.
    adata.X = X / np.where(col_max == 0, 1, col_max)

    return adata
    
    

In [None]:
# Predict expression for every validation sample that has a high-res image, then normalize
# the observed counts and rescale observed/predicted values column by column.
missing_samples = [sid for sid in samples if sid not in filemap.index]
if missing_samples:
    # filemap keeps only sample_map rows with a highres_image, so these cannot be predicted.
    print(f'skipping validation samples absent from filemap: {missing_samples}')

adata_map = {}
for sample_id, row in filemap[filemap.index.isin(samples)].iterrows():
    print(sample_id)
    adata = predict_visium(row['spaceranger_output'], row['highres_image'],
                           weights, summary, tmp_dir='../sandbox/tmp')
    adata = normalize_counts(adata)
    adata = standardize_predictions(adata)
    adata_map[sample_id] = adata

In [None]:
# Default sc.pl.spatial view (no color argument) for every predicted sample.
for sample in adata_map:
    adata = adata_map[sample]
    sc.pl.spatial(adata)
    plt.show()

In [None]:
# Observed (first row) vs predicted (second row) spatial maps for ESR1, PGR, ERBB2, MKI67.
# The panel list is loop-invariant, so build it once; ncols follows the gene count.
genes = ['ESR1', 'PGR', 'ERBB2', 'MKI67']
markers = genes + [f'predicted_{g}' for g in genes]
for sample, adata in adata_map.items():
    print(sample)
    sc.pl.spatial(adata, color=markers, ncols=len(genes), return_fig=True, vmin=0.)
    plt.show()

In [None]:
# Observed (first row) vs predicted (second row) spatial maps for EPCAM, CDH1.
# The panel list is loop-invariant, so build it once; ncols follows the gene count.
genes = ['EPCAM', 'CDH1']
markers = genes + [f'predicted_{g}' for g in genes]
for sample, adata in adata_map.items():
    print(sample)
    sc.pl.spatial(adata, color=markers, ncols=len(genes), return_fig=True, vmin=0.)
    plt.show()

In [None]:
# Observed (first row) vs predicted (second row) spatial maps for CD3G, CD4, IL7R, CD8A.
# The panel list is loop-invariant, so build it once; ncols follows the gene count.
genes = ['CD3G', 'CD4', 'IL7R', 'CD8A']
markers = genes + [f'predicted_{g}' for g in genes]
for sample, adata in adata_map.items():
    print(sample)
    sc.pl.spatial(adata, color=markers, ncols=len(genes), return_fig=True, vmin=0.)
    plt.show()

In [None]:
# Observed (first row) vs predicted (second row) spatial maps for BGN, FAP, SPARC.
# The panel list is loop-invariant, so build it once; ncols follows the gene count.
genes = ['BGN', 'FAP', 'SPARC']
markers = genes + [f'predicted_{g}' for g in genes]
for sample, adata in adata_map.items():
    print(sample)
    sc.pl.spatial(adata, color=markers, ncols=len(genes), return_fig=True, vmin=0.)
    plt.show()

In [None]:
# Observed (first row) vs predicted (second row) spatial maps for ITGAX, LYZ, CD68, CD14.
# The panel list is loop-invariant, so build it once; ncols follows the gene count.
genes = ['ITGAX', 'LYZ', 'CD68', 'CD14']
markers = genes + [f'predicted_{g}' for g in genes]
for sample, adata in adata_map.items():
    print(sample)
    sc.pl.spatial(adata, color=markers, ncols=len(genes), return_fig=True, vmin=0.)
    plt.show()

In [None]:
# Observed (first row) vs predicted (second row) spatial maps for SDC1, PECAM1.
# The panel list is loop-invariant, so build it once; ncols follows the gene count.
genes = ['SDC1', 'PECAM1']
markers = genes + [f'predicted_{g}' for g in genes]
for sample, adata in adata_map.items():
    print(sample)
    sc.pl.spatial(adata, color=markers, ncols=len(genes), return_fig=True, vmin=0.)
    plt.show()

In [None]:
# Spots of sample HT206B1_H8_U2 ordered by predicted EPCAM. Sort the selected sample `a`,
# not the leftover loop variable `adata`, which points at whichever sample was iterated last.
a = adata_map['HT206B1_H8_U2']
a.obs.sort_values('predicted_EPCAM')

In [None]:
# Parse the run summary (closing the file handle) and load the trained ST regressor
# from `weights`, passing the parsed summary as its metadata.
with open(summary) as fh:
    meta = json.load(fh)
regressor = load_trained_st_regressor(weights, meta)

In [None]:
# Tile image paths under img_dir that listfiles matches against sample HT206B1_H8_U2, sorted.
tile_fps = list(listfiles(img_dir, regex='HT206B1_H8_U2'))
tile_fps.sort()
tile_fps

In [None]:
# Pick the first tile whose path contains this spot barcode; stop at the first match
# and fail with a descriptive error instead of an IndexError when there is none.
barcode = 'AAACCCGAACGAAATC-1'
fp = next((path for path in tile_fps if barcode in path), None)
if fp is None:
    raise ValueError(f'no tile in tile_fps contains barcode {barcode!r}')
fp

In [None]:
# Load the selected tile with torchvision's default image loader for a visual check.
tile_image = default_loader(fp)
tile_image

In [None]:
# Attention-rollup plot for the selected tile using the regressor's `vit` submodule.
vit_module = regressor.vit
plot_attention_rollup(fp, vit_module)