In [None]:
import pandas as pd
import numpy as np
import sklearn
import scipy

import anndata as ad
import scanpy as sc

from dask import delayed
from dask.distributed import Client, LocalCluster

import os, binascii

In [None]:
# Copy the pseudobulked AnnData file and the LINCS-id/compound mapping table from S3 into
# the current working directory; later cells read both files by their bare file names.
# `--no-sign-request` sends the requests unsigned, so no AWS credentials are needed.
! aws s3 cp s3://saturn-kaggle-datasets/open-problems-single-cell-perturbations-optional/train_or_control_bulk_by_cell_type_adata.h5ad --no-sign-request .
! aws s3 cp s3://saturn-kaggle-datasets/open-problems-single-cell-perturbations-optional/lincs_id_compound_mapping.parquet --no-sign-request .

## Loading expression data

Here we load the expression data (long format) and convert it into an AnnData object (wide sparse format).

You'll need to increase your instance RAM to at least 64 GB.

In [None]:
# Directory holding the competition input files (absolute path; adjust to your environment).
data_dir = '/home/jovyan/kaggle/input/open-problems-single-cell-perturbations'
# Long-format expression table; the columns used below are obs_id, gene, count and normalized_count.
adata_train_df = pd.read_parquet(os.path.join(data_dir, 'adata_train.parquet'))
# Per-observation metadata, joined onto `.obs` by obs_id at the end of this cell.
adata_obs_meta_df = pd.read_csv(os.path.join(data_dir, 'adata_obs_meta.csv'))

# Store the two key columns as categoricals (presumably to reduce memory on this large table).
adata_train_df['obs_id'] = adata_train_df['obs_id'].astype('category')
adata_train_df['gene'] = adata_train_df['gene'].astype('category')

# Give every obs_id an integer row position, following the order returned by `unique()`.
obs_ids = adata_train_df['obs_id'].unique()
obs_id_map = dict(zip(obs_ids, range(len(obs_ids))))

# Same for genes: integer column position per gene.
genes = adata_train_df['gene'].unique()
gene_map = dict(zip(genes, range(len(genes))))

# Row/column coordinate of every long-format record.
adata_train_df['obs_index'] = adata_train_df['obs_id'].map(obs_id_map)
adata_train_df['gene_index'] = adata_train_df['gene'].map(gene_map)

# NOTE(review): `normalized_counts_values` is not used anywhere else in the visible notebook.
normalized_counts_values = adata_train_df['normalized_count'].to_numpy()
counts_values = adata_train_df['count'].to_numpy()

row_indices = adata_train_df['obs_index'].to_numpy()
col_indices = adata_train_df['gene_index'].to_numpy()

# Sparse observation x gene matrix built from (value, (row, col)) triplets; no explicit
# shape is passed, so the shape is derived from the indices.
counts = scipy.sparse.csr_matrix((counts_values, (row_indices, col_indices)))

# obs/var frames use the same ordering as the matrix rows/columns (obs_ids / genes).
obs_df = pd.Series(obs_ids, name='obs_id').to_frame()
var_df = pd.Series(genes, name='gene').to_frame()

obs_df = obs_df.set_index('obs_id')
var_df = var_df.set_index('gene')

# Plain string indexes instead of the categorical ones inherited from the key columns.
obs_df.index = obs_df.index.astype('str')
var_df.index = var_df.index.astype('str')

counts_adata = ad.AnnData(
    X=counts,
    obs=obs_df,
    var=var_df,
    dtype=np.uint32,
)

# Attach the metadata and verify that the join left the observation order untouched,
# since `.X` rows are matched to `.obs` rows by position.
index_ordering_before_join = counts_adata.obs.index
counts_adata.obs = counts_adata.obs.join(adata_obs_meta_df.set_index('obs_id'))
index_ordering_after_join = counts_adata.obs.index
assert (index_ordering_before_join == index_ordering_after_join).all()

## Pseudobulking counts by cell type

In [None]:
from scipy import sparse

def sum_by(adata: ad.AnnData, col: str) -> ad.AnnData:
    """
    Sum the rows of ``adata.X`` within each category of ``adata.obs[col]``.

    Adapted from this forum post:
    https://discourse.scverse.org/t/group-sum-rows-based-on-jobs-feature/371/4

    Parameters
    ----------
    adata
        AnnData whose ``.X`` rows are summed per group.
    col
        Name of a categorical column in ``adata.obs`` defining the groups.

    Returns
    -------
    ad.AnnData
        One observation per category of ``col`` (in category order), the same
        ``var`` as ``adata``, and an ``.obs`` holding ``col`` plus every other
        obs column that takes a single value within each group (kept in their
        original column order). The obs index is a string range index.

    Raises
    ------
    AssertionError
        If ``adata.obs[col]`` is not categorical.
    ValueError
        If ``adata.obs[col]`` contains missing values.
    """
    group_values = adata.obs[col]
    assert isinstance(group_values.dtype, pd.CategoricalDtype)

    cat = group_values.values
    # A missing value has code -1 and cannot be used as a row position in the indicator matrix.
    if (cat.codes < 0).any():
        raise ValueError(f'adata.obs[{col!r}] contains missing values; cannot sum by it')

    # sum `.X` entries for each unique value in `col`:
    # indicator[g, i] is True when observation i belongs to category g
    indicator = sparse.coo_matrix(
        (
            np.broadcast_to(True, adata.n_obs),
            (cat.codes, np.arange(adata.n_obs))
        ),
        shape=(len(cat.categories), adata.n_obs),
    )
    sum_adata = ad.AnnData(
        indicator @ adata.X,
        var=adata.var,
        obs=pd.DataFrame(index=cat.categories),
        dtype=adata.X.dtype,
    )

    # copy over `.obs` values that have a one-to-one-mapping with `.obs[col]`;
    # iterate over the columns in their original order so the result is deterministic
    nunique_in_col = group_values.nunique()
    one_to_one_mapped_obs_cols = [
        other_col
        for other_col in adata.obs.columns
        if other_col != col
        and len(adata.obs[[col, other_col]].drop_duplicates()) == nunique_in_col
    ]

    joining_df = adata.obs[[col] + one_to_one_mapped_obs_cols].drop_duplicates().set_index(col)
    joined_obs = sum_adata.obs.join(joining_df)
    # the join must not reorder the groups, because rows of `.X` are matched by position
    assert (sum_adata.obs.index == joined_obs.index).all()

    sum_adata.obs = joined_obs
    sum_adata.obs.index.name = col
    sum_adata.obs = sum_adata.obs.reset_index()
    sum_adata.obs.index = sum_adata.obs.index.astype('str')

    return sum_adata

In [None]:
# Grouping key "<plate_name>_<well>_<cell_type>", stored as a categorical because
# `sum_by` asserts that its grouping column is categorical.
counts_adata.obs['plate_well_cell_type'] = (
    counts_adata.obs['plate_name'].astype('str')
    + '_'
    + counts_adata.obs['well'].astype('str')
    + '_'
    + counts_adata.obs['cell_type'].astype('str')
).astype('category')

# One pseudobulk observation per plate/well/cell-type combination.
bulk_adata = sum_by(counts_adata, 'plate_well_cell_type')
bulk_adata.obs = bulk_adata.obs.drop(columns=['plate_well_cell_type'])

# Convert the summed counts to a dense float64 array.
bulk_adata.X = np.array(bulk_adata.X.todense()).astype('float64')
bulk_adata = bulk_adata.copy()

In [None]:
# Relabelling of plate names that later cells apply to the `plate_name` columns via `.map`:
# plates 0-3 are rotated (0 -> 1 -> 2 -> 3 -> 0), plates 4 and 5 map to themselves.
plate_reordering = dict(
    plate_0='plate_1',
    plate_1='plate_2',
    plate_2='plate_3',
    plate_3='plate_0',
    plate_4='plate_4',
    plate_5='plate_5',
)

## Loading pseudobulked counts from correctly filtered AnnData (@Daniel: delete this section once the fixed RNA expression AnnData is uploaded to Kaggle)

In [None]:
# Read the pseudobulked AnnData from the working directory and use its 'counts' layer as `.X`.
fixed_bulk_adata = sc.read_h5ad('train_or_control_bulk_by_cell_type_adata.h5ad')
fixed_bulk_adata.X = fixed_bulk_adata.layers['counts']

In [None]:
# Translate the plate names of `fixed_bulk_adata` into the names used by `bulk_adata` by
# pairing the sorted unique names of both objects position by position.
# NOTE(review): `zip` stops at the shorter sequence, so this assumes both objects contain the
# same number of plates and that sorted order lines them up; unmapped names become NaN — confirm.
new_plate_names = bulk_adata.obs['plate_name'].sort_values().unique()
original_plate_names = fixed_bulk_adata.obs['plate_name'].sort_values().unique()
plate_name_map = dict(zip(original_plate_names, new_plate_names))

# First translate the names, then apply the `plate_reordering` relabelling on top.
fixed_bulk_adata.obs['plate_name'] = fixed_bulk_adata.obs['plate_name'].map(plate_name_map)
fixed_bulk_adata.obs['plate_name'] = fixed_bulk_adata.obs['plate_name'].map(plate_reordering)

In [None]:
# Derive `donor_id` for `fixed_bulk_adata` from its `raw_cell_id` column by pairing the sorted
# unique values of both columns position by position.
# NOTE(review): as with any `zip`-based pairing, this assumes equal counts and matching sorted
# order on both sides; values without a partner map to NaN — confirm.
new_donor_ids = bulk_adata.obs['donor_id'].sort_values().unique()
original_donor_ids = fixed_bulk_adata.obs['raw_cell_id'].sort_values().unique()
donor_id_map = dict(zip(original_donor_ids, new_donor_ids))

fixed_bulk_adata.obs['donor_id'] = fixed_bulk_adata.obs['raw_cell_id'].map(donor_id_map)

In [None]:
# Compound annotation table read from the working directory; the columns used further down
# are compound_id, sm_lincs_id, sm_name and smiles.
lincs_id_mapping_df = pd.read_parquet('lincs_id_compound_mapping.parquet')

In [None]:
# Index the annotation table by compound_id once and derive the three lookup dictionaries from it.
lincs_by_compound_id = lincs_id_mapping_df.set_index('compound_id')
compound_id_to_sm_lincs_id = lincs_by_compound_id['sm_lincs_id'].to_dict()
compound_id_to_sm_name = lincs_by_compound_id['sm_name'].to_dict()
compound_id_to_smiles = lincs_by_compound_id['smiles'].to_dict()

# Annotate every pseudobulk observation through its compound_id.
fixed_compound_ids = fixed_bulk_adata.obs['compound_id']
fixed_bulk_adata.obs['sm_lincs_id'] = fixed_compound_ids.map(compound_id_to_sm_lincs_id)
fixed_bulk_adata.obs['sm_name'] = fixed_compound_ids.map(compound_id_to_sm_name)
fixed_bulk_adata.obs['SMILES'] = fixed_compound_ids.map(compound_id_to_smiles)

In [None]:
# Reorder the observations by plate, compound name and cell type (presumably so rows can be
# compared position by position with `bulk_adata`, which is sorted the same way).
sorted_index = fixed_bulk_adata.obs.sort_values(['plate_name', 'sm_name', 'cell_type']).index
fixed_bulk_adata = fixed_bulk_adata[sorted_index].copy()

In [None]:
# Reorder the observations by plate, compound name and cell type. Note that `sorted_index`
# is reused here and overwrites the value computed for `fixed_bulk_adata`.
sorted_index = bulk_adata.obs.sort_values(['plate_name', 'sm_name', 'cell_type']).index
bulk_adata = bulk_adata[sorted_index].copy()

If the new expression data is uploaded to Kaggle (or placed in `data_dir`, i.e. `'/home/jovyan/kaggle/input/open-problems-single-cell-perturbations'`, in the correct format), the assertion below (currently commented out) should pass.

In [None]:
# Apply the `plate_reordering` relabelling to `bulk_adata` as well.
# NOTE(review): this cell is not idempotent — running it twice relabels the already-relabelled
# names again. Also, `bulk_adata` was sorted by plate name *before* this relabelling, whereas
# `fixed_bulk_adata` was sorted after its own relabelling, so the two row orders look like they
# can differ — confirm before relying on a positional comparison of the two matrices.
bulk_adata.obs['plate_name'] = bulk_adata.obs['plate_name'].map(plate_reordering)

In [None]:
# assert np.allclose(fixed_bulk_adata.X, bulk_adata.X)

In [None]:
# From here on `bulk_adata` refers to a copy of `fixed_bulk_adata`; the pseudobulk computed
# from the Kaggle files earlier in the notebook is discarded by this rebinding.
bulk_adata = fixed_bulk_adata.copy()

## Running Limma

In [None]:
# bulk_adata.obs['plate_name'] = bulk_adata.obs['plate_name'].map(plate_reordering)

In [None]:
# Name of the control compound; every other compound is contrasted against it.
control_compound = 'Dimethyl Sulfoxide'

# obs columns describing a perturbation. The first entry is used as the compound-name column
# when running limma; the full list is copied onto each differential-expression result.
de_pert_cols = ['sm_name', 'sm_lincs_id', 'SMILES', 'dose_uM', 'timepoint_hr', 'cell_type']

In [None]:
import importlib

import limma_utils

# Reload so that edits to limma_utils are picked up without restarting the kernel.
# `importlib.reload` replaces `imp.reload`: the `imp` module is deprecated and was
# removed in Python 3.12.
importlib.reload(limma_utils)


In [None]:
# Create the `output` directory (no error if it already exists); the limma fit and plot
# files are written under this relative path.
!mkdir -p output

In [None]:
def _run_limma_for_cell_type(bulk_adata):
    """
    Fit limma on one cell type's pseudobulk data and contrast every compound with the control.

    Parameters
    ----------
    bulk_adata
        AnnData restricted to a single cell type. Its ``.obs`` must contain the
        columns listed in the module-level ``de_pert_cols`` as well as the
        columns named in the design formula (``donor_id``, ``plate_name``, ``row``).
        The object is copied, so the caller's data is not modified.

    Returns
    -------
    pandas.DataFrame
        Concatenation of the ``limma_utils.limma_contrast`` results for every
        non-control compound, each extended with an ``Rpert`` column and the
        ``de_pert_cols`` values of that compound.

    Raises
    ------
    KeyError
        If the module-level ``control_compound`` does not occur in ``bulk_adata``.
    ValueError
        Raised by ``pd.concat`` when there is no compound other than the control.
    """
    # NOTE(review): imported inside the function, presumably so the module is resolved in
    # the dask worker process that executes the task — confirm.
    import limma_utils
    bulk_adata = bulk_adata.copy()

    compound_name_col = de_pert_cols[0]
    compound_names = bulk_adata.obs[compound_name_col]

    # limma doesn't like dashes etc. in the compound names, so every compound is
    # replaced by its position in order of first appearance
    rpert_mapping = {
        name: position
        for position, name in enumerate(compound_names.drop_duplicates())
    }
    if control_compound not in rpert_mapping:
        raise KeyError(
            f'control compound {control_compound!r} not found in obs[{compound_name_col!r}]'
        )

    # vectorised lookup; cast to object first so the result is a plain integer column
    bulk_adata.obs['Rpert'] = compound_names.astype(object).map(rpert_mapping).astype('str')
    ref_pert = str(rpert_mapping[control_compound])

    # random file stem so tasks running in parallel do not overwrite each other's files
    random_string = binascii.b2a_hex(os.urandom(15)).decode()
    fit_path = f'output/{random_string}_limma.rds'
    plot_path = f'output/{random_string}_voom'

    try:
        limma_utils.limma_fit(
            bulk_adata,
            design='~0+Rpert+donor_id+plate_name+row',
            output_path=fit_path,
            plot_output_path=plot_path,
            exec_path='limma_fit.r',
            verbose=True,
        )

        pert_de_dfs = []
        for pert in bulk_adata.obs['Rpert'].unique():
            if pert == ref_pert:
                continue

            pert_de_df = limma_utils.limma_contrast(
                fit_path=fit_path,
                contrast='Rpert'+pert+'-Rpert'+ref_pert,
                exec_path='limma_contrast.r',
            )

            pert_de_df['Rpert'] = pert

            # copy the perturbation metadata of this compound onto its result rows
            pert_obs = bulk_adata.obs[bulk_adata.obs['Rpert'].eq(pert)]
            for col in de_pert_cols:
                pert_de_df[col] = pert_obs[col].unique()[0]
            pert_de_dfs.append(pert_de_df)

        return pd.concat(pert_de_dfs, axis=0)
    finally:
        # Always clean up, also when limma fails, and remove each file independently so
        # that one missing file does not prevent the removal of the other.
        for path in (fit_path, plot_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass

run_limma_for_cell_type = delayed(_run_limma_for_cell_type)

In [None]:
%%capture

# Local dask cluster: six single-threaded worker processes, each with a 20GB memory limit.
cluster = LocalCluster(
    n_workers=6,
    processes=True,
    threads_per_worker=1,
    memory_limit='20GB',
)

# Client connected to that cluster; used below to compute the delayed limma tasks.
c = Client(cluster)

In [None]:
%%capture

cell_types = bulk_adata.obs['cell_type'].unique()
de_dfs = []

# Queue one delayed limma task per cell type; `run_limma_for_cell_type` is the
# `delayed` wrapper, so `de_df` here is a task, not yet a DataFrame.
for cell_type in cell_types:
    cell_type_selection = bulk_adata.obs['cell_type'].eq(cell_type)
    cell_type_bulk_adata = bulk_adata[cell_type_selection].copy()
    
    de_df = run_limma_for_cell_type(cell_type_bulk_adata)
    
    de_dfs.append(de_df)

# Compute all tasks on the cluster; with `sync=True` the concrete results are returned,
# so `de_dfs` now holds DataFrames. They are then stacked into a single frame.
de_dfs = c.compute(de_dfs, sync=True)
de_df = pd.concat(de_dfs)

## Converting DataFrame to Anndata

In [None]:
def convert_de_df_to_anndata(de_df, pert_cols, de_sig_cutoff):
    """
    Pivot a long differential-expression table into a perturbation x gene AnnData.

    Parameters
    ----------
    de_df
        Long-format results with the columns ``P.Value``, ``adj.P.Val`` and
        ``logFC``, the columns listed in ``pert_cols``, and ``gene`` available
        either as a column or as the index name. The frame is copied, not modified.
    pert_cols
        Columns that together identify one perturbation; they become the
        ``.obs`` columns of the result.
    de_sig_cutoff
        Threshold below which a (raw or adjusted) p-value counts as significant.

    Returns
    -------
    ad.AnnData
        ``.X`` holds ``sign(logFC) * -log10(P.Value)``; the layers ``is_de``,
        ``is_de_adj``, ``logFC``, ``P.Value`` and ``adj.P.Val`` hold the
        corresponding pivoted values. The obs index is a string range index.
    """
    de_df = de_df.copy()
    # a p-value of exactly 0 would give an infinite -log10; use machine epsilon instead
    zero_pval_selection = de_df['P.Value'].eq(0)
    de_df.loc[zero_pval_selection, 'P.Value'] = np.finfo(np.float64).eps

    de_df['sign_log10_pval'] = np.sign(de_df['logFC']) * -np.log10(de_df['P.Value'])
    de_df['is_de'] = de_df['P.Value'].lt(de_sig_cutoff)
    de_df['is_de_adj'] = de_df['adj.P.Val'].lt(de_sig_cutoff)

    # reset the index once instead of once per feature
    long_df = de_df.reset_index()
    layer_features = ['is_de', 'is_de_adj', 'logFC', 'P.Value', 'adj.P.Val']

    # one gene x perturbation table per feature (pivot_table's default aggregation applies
    # if a gene/perturbation pair occurs more than once)
    de_feature_dfs = {
        feature: long_df.pivot_table(
            index=['gene'],
            columns=pert_cols,
            values=[feature],
            dropna=True,
        )
        for feature in ['sign_log10_pval'] + layer_features
    }

    de_adata = ad.AnnData(de_feature_dfs['sign_log10_pval'].T, dtype=np.float64)
    # turn the pivoted column levels back into ordinary obs columns; 'level_0' is the
    # unnamed level introduced by passing `values` as a list
    de_adata.obs = de_adata.obs.reset_index()
    de_adata.obs = de_adata.obs.drop(columns=['level_0'])
    # plain 'str' index, consistent with how obs indexes are cast elsewhere in this notebook
    de_adata.obs.index = de_adata.obs.index.astype('str')

    for feature in layer_features:
        de_adata.layers[feature] = de_feature_dfs[feature].to_numpy().T

    return de_adata

In [None]:
# 0.05 is the significance cutoff applied to both the raw and the adjusted p-values.
de_adata = convert_de_df_to_anndata(de_df, de_pert_cols, 0.05)

In [None]:
# Display the observation metadata of the pseudobulk data for inspection.
bulk_adata.obs

In [None]:
# Load the competition's differential-expression table. Its first five columns
# (presumably the perturbation metadata — confirm) are moved into the index so that only
# the remaining columns form the AnnData matrix.
kaggle_train_de_df = pd.read_parquet(os.path.join(data_dir, 'de_train.parquet'))
kaggle_train_de_df = kaggle_train_de_df.set_index(list(kaggle_train_de_df.columns[:5]))

kaggle_train_de_adata = ad.AnnData(kaggle_train_de_df)
# Restore the index levels as ordinary obs columns and use a string range index instead.
kaggle_train_de_adata.obs = kaggle_train_de_adata.obs.reset_index()
kaggle_train_de_adata.obs.index = kaggle_train_de_adata.obs.index.astype('str')

In [None]:
# Order the Kaggle observations by compound name and cell type.
sorting_index = kaggle_train_de_adata.obs.sort_values(['sm_name', 'cell_type']).index
kaggle_train_de_adata = kaggle_train_de_adata[sorting_index].copy()

In [None]:
# Cast the obs index to plain 'str' before it is used to subset the AnnData by label.
de_adata.obs.index = de_adata.obs.index.astype('str')

# Order the recomputed results by compound name and cell type, the same sort keys used for
# `kaggle_train_de_adata`. `sorting_index` is reused and overwritten here.
sorting_index = de_adata.obs.sort_values(['sm_name', 'cell_type']).index
de_adata = de_adata[sorting_index].copy()

In [None]:
# Display the observation metadata of the recomputed results for inspection.
de_adata.obs

In [None]:
# Display the summary of the Kaggle AnnData for inspection.
kaggle_train_de_adata

In [None]:
# Element-wise comparison of the two var indexes; True only if every position matches.
# The result is displayed, not asserted.
(de_adata.var.index == kaggle_train_de_adata.var.index).all()

In [None]:
import seaborn as sns

In [None]:
# Scatter the values of row 100 of the Kaggle matrix (x) against row 100 of the recomputed
# matrix (y).
# NOTE(review): both objects were sorted by ['sm_name', 'cell_type'], but nothing visible
# here checks that row 100 is the same perturbation in both — confirm.
sns.scatterplot(
    x=kaggle_train_de_adata.X[100],
    y=de_adata.X[100],
)