In [None]:
import os
import re
import io
import sys
import glob
import enum
import json
import itertools
import psycopg2
import requests
import skimage
import seaborn as sns

import numpy as np
import pandas as pd
import sqlalchemy as db
from matplotlib import pyplot as plt

In [None]:
# database URL from a notebook on ESS
# Set the OPENCELL_DB_URL environment variable to supply the real connection
# string (including credentials) without writing it into the notebook; the
# literal below is only the fallback used when that variable is not set.
url = os.environ.get(
    'OPENCELL_DB_URL',
    'postgresql://postgres:password@keith-oc-db:5432/opencelldb',
)

In [None]:
# database URL from a remote notebook
# NOTE: this cell assigns the same `url` name as the ESS cell, so running both
# leaves this value in effect; run only the cell that matches where the notebook runs.
# Set the OPENCELL_DB_URL environment variable to supply the real connection
# string (including credentials) without writing it into the notebook; the
# literal below is only the fallback used when that variable is not set.
url = os.environ.get(
    'OPENCELL_DB_URL',
    'postgresql://postgres:password@cap.czbiohub.org:5434/opencelldb',
)

In [None]:
# SQLAlchemy engine for the database at `url`; it is the connection handed to pd.read_sql
engine = db.create_engine(url)

In [None]:
# cell lines joined to their CRISPR design and their annotation
annotation_query = '''
    select cell_line_id, well_id, plate_design_id as plate_id, target_name, categories
    from cell_line line
    left join crispr_design cd on cd.id = line.crispr_design_id
    left join cell_line_annotation ant on ant.cell_line_id = line.id;
    '''

# drop every row with a missing value, then expand the list-like `categories`
# column so that each row holds a single category
df = (
    pd.read_sql(annotation_query, engine)
    .dropna(axis=0, how='any')
    .explode('categories')
    .rename(columns={'categories': 'category'})
)

In [None]:
# parse the grade from the category: a trailing '_1', '_2' or '_3' is the grade,
# and annotations without that suffix (or with a missing category) get 'none'.
# The same pattern is used for parsing and for stripping, so a category is only
# treated as graded when the suffix that gets removed is actually present.
# NOTE: run this cell once per load; re-running it after the suffix has been
# stripped would reset every grade to 'none'.
grade_suffix = r'_([123])$'
df['grade'] = df.category.str.extract(grade_suffix, expand=False).fillna('none')

# remove the grade from the category names (missing categories stay missing)
df['category'] = df.category.str.replace(grade_suffix, '', regex=True)

In [None]:
# keep just the annotations graded 2 or 3
# (these are necessarily localization annotations)
is_grade_2_or_3 = df['grade'].isin(['2', '3'])
df = df[is_grade_2_or_3]

In [None]:
# Categories shown in the co-occurrence heatmap. The list is used as both the row
# index and the column labels of the pairwise-count matrix, so its order is the
# order of the heatmap axes. Names are compared against the `category` column,
# i.e. the annotation names with the grade suffix already removed.
categories_to_plot = [
    'nucleoplasm',
    'nuclear_membrane',
    'nuclear_punctae',
    'chromatin',
    'nucleolus_fc_dfc',
    'nucleolus_gc',
    'cytoplasmic',
    'cytoskeleton',
    'centrosome',
    'focal_adhesions',
    'membrane',
    'er',
    'vesicles',
    'mitochondria',
]

In [None]:
# every unordered pair of distinct plotted categories
# (not referenced by the other cells visible in this part of the notebook)
all_possible_pairs = list(itertools.combinations(categories_to_plot, 2))

# initialize a dataframe of pairwise counts, with the plotted categories as both
# rows and columns; it is created directly as a numeric frame of zeros so that
# the counting and normalization steps operate on numbers rather than on an
# object-dtype frame
pairwise_counts = pd.DataFrame(0.0, index=categories_to_plot, columns=categories_to_plot)

In [None]:
# explicitly count the pairs of categories: entry (row, col) is the number of
# cell lines annotated with both categories, so the diagonal holds the number of
# cell lines annotated with that category
grouped = df.groupby('cell_line_id')

# start from zero so that re-running this cell does not add to earlier counts
pairwise_counts.loc[:] = 0

for _, line_categories in grouped['category']:
    annotated = set(line_categories)

    # only the plotted categories present on this cell line can contribute
    present_rows = [name for name in pairwise_counts.index if name in annotated]
    present_cols = [name for name in pairwise_counts.columns if name in annotated]

    for row_category, col_category in itertools.product(present_rows, present_cols):
        pairwise_counts.at[row_category, col_category] += 1

In [None]:
# normalize each row by the frequency of its category, measured as the number of
# distinct cell lines carrying that category (the same unit the pairwise counts
# use, so a cell line listed twice under one category is not counted twice)
counts = df.groupby('category').cell_line_id.nunique()

# categories without any annotation have no frequency; their rows become NaN
# instead of raising a KeyError
row_totals = counts.reindex(pairwise_counts.index)
pairwise_counts = pairwise_counts.astype(float).div(row_totals, axis=0)

In [None]:
# draw the pairwise matrix as an annotated heatmap on a square 10x10 figure
fig, ax = plt.subplots(figsize=(10, 10))

heatmap_options = dict(
    cmap='YlGnBu',
    vmax=None,
    square=True,
    linewidths=.5,
    annot=True,
    fmt='0.2f',
)

sns.heatmap(pairwise_counts.astype(float), ax=ax, **heatmap_options)