In [None]:
import os
import re
import yaml

import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt

from PIL import Image
from numpy import asarray

In [None]:
# Name of the embedding/clustering run to plot. Later cells read the UMAP
# coordinates from main.csv columns f"{clustering}_emb1" / f"{clustering}_emb2"
# and derive the output folder name from it.
clustering = "VAE9_ROT_VIG18"

In [None]:
# Paths and input

# Read single-cell sample for VAE analysis
main = pd.read_csv(os.path.join(os.getcwd(), 'input/main.csv'))

markers = pd.read_csv(os.path.join(os.getcwd(), 'input/CRC-097_mcmicro_markers.csv'))

# Marker name of channel 1 (the first DNA channel). Use positional .iloc[0]:
# label-based [0] only works if that row happens to carry index label 0.
dna1 = markers.loc[markers['channel_number'] == 1, 'marker_name'].iloc[0]
# Leading letters/underscores of that name (e.g. 'DNA1' -> 'DNA'); any column
# containing this moniker is treated as a DNA channel and excluded below.
dna_moniker = str(re.search(r'[^\W\d]+', dna1).group())

# NOTE(review): hard-coded absolute path to an external drive; it is not used by
# the visible cells (the config is read from input/ below). Kept in case later
# cells reference it.
cylinter_config_path = '/Volumes/T7 Shield/cylinter_input/clean_quant/config.yml'
with open(os.path.join(os.getcwd(), 'input/CRC-97_cylinter_config.yml')) as f:
    config = yaml.safe_load(f)
# An empty YAML entry loads as None; treat that as "exclude nothing" instead of
# failing later with `in None`.
markers_to_exclude = config['channelExclusionsClustering'] or []

# Antibody channels: columns of `main` that are listed markers, are not DNA
# channels, and are not excluded from clustering by the CyLinter config.
marker_names = set(markers['marker_name'])  # built once, O(1) membership
abx_channels = [
    col for col in main.columns
    if col in marker_names
    and dna_moniker not in col
    and col not in markers_to_exclude
]

# Output folder for the per-channel UMAP PNGs (created if missing).
out = os.path.join(os.getcwd(), f'output/{clustering}_umap_channels')
os.makedirs(out, exist_ok=True)

In [None]:
# Generate individual channel PNGs
plt.rcParams['font.family'] = 'Arial'

for ch in abx_channels:
    print(ch)

    # UMAP coordinates for this clustering run; the point size shrinks as the
    # number of cells in `main` grows.
    x_coords = main[f'{clustering}_emb1']
    y_coords = main[f'{clustering}_emb2']
    marker_size = 144000 / len(main)

    fig, ax = plt.subplots()
    ax.scatter(
        x_coords, y_coords,
        c=main[ch],  # colour each cell by this channel's value
        s=marker_size,
        linewidth=0.1,
    )
    ax.set_axisbelow(True)  # keep grid lines underneath the points
    ax.grid(True, lw=1)
    ax.tick_params(labelsize=10)

    png_path = os.path.join(out, f'{ch}.png')
    fig.savefig(png_path, bbox_inches='tight', dpi=800)
    plt.show()
    plt.close('all')

In [None]:
# Generate single facetgrid showing all channels

# Long-format (embedding x/y, channel, value) table on a 10% subsample.
# The embedding columns follow `clustering`, so they match the per-channel PNGs.
# A fixed random_state makes the subsample reproducible.
# NOTE(review): df_melt is not used by the FacetGrid below.
emb_cols = [f'{clustering}_emb1', f'{clustering}_emb2']
df_melt = (
    main[abx_channels + emb_cols]
    .sample(frac=0.1, random_state=0)
    .reset_index(drop=True)
    .melt(id_vars=emb_cols, var_name='abx')
)

# Channel column name -> display title for each facet.
ch_dict = {
    'anti_CD3': 'CD3\u03B5', 'anti_CD45RO': 'CD45RO', 'Keratin_570': 'Keratin', 'aSMA_660': '\u03B1SMA',
    'CD4_488': 'CD4', 'CD45_PE': 'CD45', 'PD1_647': 'PD1', 'CD20_488': 'CD20', 'CD68_555': 'CD68',
    'CD8a_660': 'CD8\u03B1', 'CD163_488': 'CD163', 'FOXP3_570': 'FOXP3', 'PDL1_647': 'PDL1',
    'Ecad_488': 'ECAD', 'Vimentin_555': 'Vimentin', 'CDX2_647': 'CDX2', 'LaminABC_488': 'LaminABC',
    'Desmin_555': 'Desmin', 'CD31_647': 'CD31', 'PCNA_488': 'PCNA', 'CollagenIV_647': 'CollagenIV'
}

# Load the per-channel PNGs written by the previous cell into a
# (channel, image) table. Files are visited in sorted order so the panel order
# is reproducible (os.listdir order is arbitrary). Hidden files, the combined
# figure itself and non-PNG files are skipped.
records = []
for file in sorted(os.listdir(out)):
    if file.startswith('.') or file == 'combined.png' or not file.endswith('.png'):
        continue
    stem = os.path.splitext(file)[0]
    # Fall back to the raw channel name for channels missing from ch_dict.
    channel_name = ch_dict.get(stem, stem)
    with Image.open(os.path.join(out, file)) as img:
        arr = asarray(img)  # read pixels while the file is still open
    records.append({'channel': channel_name, 'image': arr})
# Build the frame in one step rather than growing it row by row.
long_table = pd.DataFrame(records, columns=['channel', 'image'])

g = sns.FacetGrid(long_table, col='channel', col_wrap=4, sharex=False, sharey=False)

# Each facet receives the single image row for its channel; draw it without a grid.
g.map(
    lambda image, **kwargs: (plt.imshow(image.values[0]), plt.grid(False)), 'image'
)

for ax in g.axes.flatten():
    ax.axis('off')

g.set_titles(col_template="{col_name}", fontweight='normal', size=10)
g.fig.tight_layout()
g.fig.subplots_adjust(hspace=-0.1)  # negative spacing pulls the rows closer
g.fig.savefig(
    os.path.join(out, 'combined.png'), bbox_inches='tight', dpi=800
)
plt.show()
plt.close('all')