In [None]:
import os
import pickle
import pathlib
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from matplotlib.colors import TwoSlopeNorm
import seaborn as sns
import dask.array as da
from PIL import Image
from numpy import asarray

In [None]:
def ravel_data(channel_patches, max_pixels, rng=None):
    """Flatten a stack of patches into a 1-D pixel array, capping its size.

    Parameters
    ----------
    channel_patches : np.ndarray
        Array whose first axis indexes patches; all remaining axes hold the
        pixels of each patch.
    max_pixels : int
        Cap on the number of returned pixels. When the input holds more
        pixels, whole patches are drawn at random without replacement so
        that at most ``max_pixels`` pixels are returned. The exception is
        when a single patch is already larger than the cap: one full patch
        is still returned, so the result is never empty.
    rng : np.random.Generator, optional
        Source of randomness for the patch draw. ``None`` (default) uses
        numpy's global RNG via ``np.random.choice``, as before. Pass a seeded
        Generator for reproducible draws.

    Returns
    -------
    np.ndarray
        1-D array of pixel values (C order within each selected patch).
    """
    total_pixels = channel_patches.size

    # Small enough (or empty): return every pixel. This is identical to
    # flattening per patch and then raveling.
    if total_pixels == 0 or total_pixels <= max_pixels:
        return channel_patches.ravel()

    # int() is needed because np.prod(()) is the float 1.0 for 1-D input.
    pixels_per_patch = int(np.prod(channel_patches.shape[1:]))
    # One row per patch, so that whole patches can be selected.
    data = channel_patches.reshape(-1, pixels_per_patch)

    # Never draw zero patches: an empty result breaks downstream percentiles.
    num_patches = max(1, int(max_pixels / pixels_per_patch))
    choice = np.random.choice if rng is None else rng.choice
    patch_selections = choice(data.shape[0], num_patches, replace=False)
    return data[patch_selections].ravel()

In [None]:
def save_figs(dpi=300, format='pdf', out_dir=None, prefix=None, close=True):
    """Save every open matplotlib figure into ``out_dir``.

    Each figure is written as ``<figure number>-<suptitle>.<format>``. The
    suptitle part is empty when the figure has none, and any ``/`` in it is
    replaced by ``_`` so that it cannot form a bogus sub-directory.

    Parameters
    ----------
    dpi : int
        Resolution passed to ``Figure.savefig``.
    format : str
        Output format / file extension, e.g. ``'pdf'`` or ``'png'``.
    out_dir : str or path-like, optional
        Destination directory; created (with parents) if missing. Defaults
        to the current working directory when ``None`` (previously ``None``
        raised a TypeError).
    prefix : str, optional
        If given, it is prepended to each figure's suptitle, or used as the
        suptitle when a figure has none. This modifies the figures in place.
    close : bool
        Close each figure after it has been saved.
    """
    out_dir = pathlib.Path('.' if out_dir is None else out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)

    # Snapshot the open figures first, so closing them inside the loop
    # cannot change what is being iterated over.
    figs = [plt.figure(num) for num in plt.get_fignums()]
    for fig in figs:
        # _suptitle is matplotlib's (private) handle to the suptitle Text,
        # or None when the figure has no suptitle.
        if prefix is not None:
            if fig._suptitle:
                fig.suptitle(f'{prefix} {fig._suptitle.get_text()}')
            else:
                fig.suptitle(prefix)
        name = fig._suptitle.get_text() if fig._suptitle else ""
        name = name.replace('/', '_')
        fig.savefig(
            out_dir / f'{fig.number}-{name}.{format}',
            dpi=dpi, bbox_inches='tight'
        )
        if close:
            plt.close(fig)

In [None]:
# which histogram-check variant to inspect; selects the hist_check_* folder
check = 'encode'

# Seed numpy's global RNG so that the random patch subsampling in ravel_data
# (np.random.choice) is reproducible on Restart & Run All.
seed = 0
np.random.seed(seed)

# Every input lives under one experiment directory; derive all paths from it.
base_dir = '/n/scratch/users/g/gjb15/VAE9_VIG7_multi-tissue/test/combined'
root = os.path.join(base_dir, '6_latent_space_LD412', f'hist_check_{check}')

# NOTE: pickle.load can execute arbitrary code; only load trusted files.
# NOTE(review): `limits` is not used by any of the cells shown here.
with open(
    os.path.join(base_dir, '4_histogram_alignment', 'limits.pkl'), 'rb'
) as handle:
    limits = pickle.load(handle)

# cap on the pixels per sample that ravel_data feeds into each histogram
max_pixels = 2_000_000
X = np.load(os.path.join(root, 'X_transform.npy'))
# per-patch labels, compared against the entries of `samples` below
y = pd.Series(np.load(os.path.join(root, 'labels.npy')))
main = pd.read_csv(
    os.path.join(base_dir, '6_latent_space_LD412', 'main.csv')
)
samples = ['C9', 'CRC097', 'CRC102']
# marker name -> channel index along the last axis of X
markers = {
    'anti_CD3': 0, 'anti_CD45RO': 1, 'Keratin_570': 2, 'aSMA_660': 3, 
    'CD4_488': 4, 'CD45_PE': 5, 'PD1_647': 6, 'CD20_488': 7, 'CD68_555': 8, 
    'CD8a_660': 9, 'CD163_488': 10, 'FOXP3_570': 11, 'PDL1_647': 12, 
    'Ecad_488': 13, 'Vimentin_555': 14, 'CDX2_647': 15, 'LaminABC_488': 16,
    'Desmin_555': 17, 'CD31_647': 18, 'PCNA_488': 19, 'CollagenIV_647': 20
}
# marker name -> short label used in plot titles, tick labels and file names
antibody_abbrs = {
    'anti_CD3': 'CD3', 'anti_CD45RO': 'CD45RO', 'Keratin_570': 'Keratin', 
    'aSMA_660': 'aSMA', 'CD4_488': 'CD4', 'CD45_PE': 'CD45', 'PD1_647': 'PD1', 
    'CD20_488': 'CD20', 'CD68_555': 'CD68', 'CD8a_660': 'CD8a', 
    'CD163_488': 'CD163', 'FOXP3_570': 'FOXP3', 'PDL1_647': 'PDL1', 
    'Ecad_488': 'ECAD', 'Vimentin_555': 'Vimentin', 'CDX2_647': 'CDX2', 
    'LaminABC_488': 'LaminABC', 'Desmin_555': 'Desmin', 'CD31_647': 'CD31', 
    'PCNA_488': 'PCNA', 'CollagenIV_647': 'CollagenIV'
}

In [None]:
# visualize histogram alignment: for each marker, overlay the per-sample
# histograms of subsampled positive pixel intensities

n_cols = int(np.ceil(len(markers) / 3))
fig, axs = plt.subplots(3, n_cols, figsize=(15, 7))
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
# The suptitle also names the saved file in save_figs, so set it once up front.
fig.suptitle('processed_masked', fontsize=12)

for (marker, ch), ax in zip(markers.items(), axs.ravel()):
    print(f'Plotting channel {marker}')
    handles = []
    for sample, color in zip(samples, colors):
        label = sample.split('-')[-1]
        idx = y.index[y == sample]
        channel_patches = X[idx, :, :, ch]
        raveled_data = ravel_data(
            channel_patches=channel_patches,
            max_pixels=max_pixels
        )
        # keep positive pixels only (zeros presumably masked background)
        raveled_data = raveled_data[raveled_data > 0]
        if raveled_data.size == 0:
            # np.percentile raises on empty input; nothing to draw here
            continue
        # clip the extreme 0.01% tails so outliers don't stretch the bins
        bins = np.linspace(*np.percentile(raveled_data, [0.01, 99.99]), 200)
        counts, bin_edges = np.histogram(raveled_data, bins=bins)
        # Pass the colour explicitly so that the line always matches its legend handle.
        ax.step(
            bin_edges[:-1], counts, where='mid', color=color, label=label
        )
        handles.append(mlines.Line2D(
            [], [], color=color, marker='s', markersize=8,
            linestyle='None', label=label
        ))
    ax.set_title(antibody_abbrs[marker], fontsize=10)
    ax.tick_params(axis='both', which='major', labelsize=7)
    ax.legend(
        handles=handles,
        fontsize=8,
        markerscale=1,
        loc='best',
        frameon=False,
        handletextpad=0.1
    )

# hide grid cells left over when len(markers) is not a multiple of 3
for ax in axs.ravel()[len(markers):]:
    ax.set_visible(False)

fig.tight_layout()
save_figs(format='pdf', out_dir=root)
print()

In [None]:
# compute channel z-scores heatmap

# cluster label qualified by its sample: '<cluster>_<Sample>'
main['call'] = main['cluster'].astype(str) + '_' + main['Sample']

# Wrap the patches under a *new* name instead of rebinding X. Rebinding made
# this cell non-idempotent and turned X into a dask array for any cell re-run
# afterwards.
X_dask = da.from_array(X)

# per patch, per channel median pixel intensity -> (n_patches, n_channels)
medians = da.median(X_dask, axis=(1, 2)).compute()
patch_medians = pd.DataFrame(columns=list(markers.keys()), data=medians)

# Median of the patch medians within each '<cluster>_<Sample>' group,
# restricted to the heatmap samples. Filtering on main['Sample'] directly
# avoids re-parsing the group label, which breaks if a name contains '_'.
heatmap_samples = ['CRC097', 'CRC102']
clustermap_input = (
    patch_medians
    .assign(cluster=main['call'])
    .loc[main['Sample'].isin(heatmap_samples)]
    .groupby('cluster')
    .median()
)

# Per-channel z-scores across clusters; zero-variance channels give NaN -> 0.
clustermap_input = (
    (clustermap_input - clustermap_input.mean()) / clustermap_input.std()
).fillna(0)

# zero-center colorbar
norm = TwoSlopeNorm(
    vcenter=0, vmin=clustermap_input.min().min(), 
    vmax=clustermap_input.max().max()
)

g = sns.clustermap(
    clustermap_input, cmap='coolwarm', standard_scale=None, 
    yticklabels=1, xticklabels=1, linewidths=0.1, linecolor='k', 
    cbar=True, norm=norm 
)
# columns are full marker names; display the short abbreviations
g.ax_heatmap.set_xticklabels(
    [antibody_abbrs[i.get_text()] for i 
     in g.ax_heatmap.get_xticklabels()], rotation=90
)
g.ax_heatmap.set_yticklabels(
    [i.get_text() for i in g.ax_heatmap.get_yticklabels()], 
    rotation=0, fontsize=5
)
g.ax_cbar.set_position([1.01, 0.75, 0.05, 0.2])
g.fig.suptitle('Channel z-scores', y=0.995, fontsize=10)
g.ax_heatmap.yaxis.set_tick_params(length=0.05, width=0.01)
# save the clustermap's own figure rather than whatever figure is current
g.fig.savefig(
    os.path.join(root, 'channel_z-scores.pdf'), 
    bbox_inches='tight'
)

In [None]:
# generate channel intensity plots: colour the embedding (emb1, emb2) by each
# channel's per-patch median intensity, then tile every channel into one figure

intensities = pd.DataFrame(columns=list(markers.keys()), data=medians)

# Generate individual channel PNG image files. An existing PNG is reused as a
# cache, so delete it to force regeneration after the data change.
for name in markers.keys():
    png_path = os.path.join(root, f'{antibody_abbrs[name]}.png')
    if os.path.exists(png_path):
        continue

    fig, ax = plt.subplots()
    ax.scatter(
        main['emb1'],
        main['emb2'],
        c=intensities[name], linewidth=0.1,
        # shrink the markers as the number of points grows
        s=150000 / len(main),
    )
    ax.set_aspect('equal')
    ax.tick_params(labelsize=10)
    ax.grid(False)
    fig.savefig(png_path, bbox_inches='tight', dpi=800)
    plt.close(fig)

# Load exactly the per-channel PNGs written above, not every *.png that
# happens to sit in root. Build the table from records in one step instead of
# growing a DataFrame cell by cell, and close each image file after reading it.
records = []
for name in markers.keys():
    channel_name = antibody_abbrs[name]
    with Image.open(os.path.join(root, f'{channel_name}.png')) as img:
        records.append({'channel': channel_name, 'image': asarray(img)})

# FacetGrid lays facets out in order of appearance, so sort by channel name
long_table = (
    pd.DataFrame(records, columns=['channel', 'image'])
    .sort_values(by='channel')
    .reset_index(drop=True)
)


def _show_image(image, **kwargs):
    """Draw the single image stored in this facet's ``image`` column."""
    plt.imshow(image.values[0])
    plt.grid(False)


g = sns.FacetGrid(
    long_table, col='channel', col_wrap=7, sharex=False, sharey=False,
    height=3.0, aspect=2.0
)
g.map(_show_image, 'image')

for ax in g.axes.flatten():
    ax.axis('off')

g.set_titles(col_template="{col_name}", size=14, fontweight='normal', y=1.0)

g.fig.tight_layout()

g.fig.savefig(
    os.path.join(root, 'combined.png'), bbox_inches='tight', dpi=800
)
plt.show()
plt.close('all')

print()