# Manually explore Arevalo2023 pipeline results

In [1]:
import pandas as pd
import numpy as np
# import pyarrow.parquet as pq
from tqdm.contrib.concurrent import thread_map
import anndata as ad
from scvi.external import SysVI

  from .autonotebook import tqdm as notebook_tqdm


In [1]:
import pyarrow

### io.py

In [None]:
def to_anndata(parquet_path):
    '''Load a profiles parquet file as an ``AnnData`` object.

    Feature columns become ``X`` (float32, as produced by ``split_parquet``),
    metadata columns become ``obs`` and the feature names become
    ``var_names``. The obs index is cast to ``str`` because AnnData expects
    string observation names.
    '''
    obs, matrix, feature_names = split_parquet(parquet_path)
    obs.index = obs.index.astype(str)
    adata = ad.AnnData(X=matrix, obs=obs)
    adata.var_names = feature_names
    return adata


def split_parquet(dframe_path,
                  features=None) -> tuple[pd.DataFrame, np.ndarray, list[str]]:
    '''Read a parquet file and split it into metadata, features and names.

    Parameters
    ----------
    dframe_path : path to the parquet file.
    features : feature column names to extract; when ``None`` every column
        not recognised as metadata by ``find_feat_cols`` is used.

    Returns
    -------
    (metadata frame copy, float32 feature matrix, feature column names)
    '''
    table = pd.read_parquet(dframe_path)
    feature_names = find_feat_cols(table) if features is None else features
    # Columns are copied one at a time into a preallocated float32 buffer
    # instead of converting the whole frame in a single call.
    matrix = np.empty((len(table), len(feature_names)), dtype=np.float32)
    for col_idx, name in enumerate(feature_names):
        matrix[:, col_idx] = table[name]
    metadata = table[find_meta_cols(table)].copy()
    return metadata, matrix, feature_names


def merge_parquet(meta, vals, features, output_path) -> None:
    '''Write features and metadata side by side to a parquet file.

    Rows of ``meta`` are paired with rows of ``vals`` by position: the index of
    ``meta`` is discarded so the written frame has a fresh ``RangeIndex``.
    '''
    out = pd.DataFrame(vals, columns=features)
    positional_meta = meta.reset_index(drop=True)
    for name in positional_meta.columns:
        out[name] = positional_meta[name]
    out.to_parquet(output_path)


def get_num_rows(path) -> int:
    '''Return the number of rows in a parquet file.

    The count comes from the file's parquet metadata (``metadata.num_rows``).
    '''
    # Imported locally so this function does not depend on a module-level
    # ``pq`` alias being defined.
    import pyarrow.parquet as pq

    with pq.ParquetFile(path) as parquet_file:
        return parquet_file.metadata.num_rows


def prealloc_params(sources, plate_types):
    '''
    Compute the parquet paths to load and the row range each one fills.

    Returns
    -------
    paths : array with one path per unique (source, batch, plate), built by
        ``build_path``.
    slices : int array of shape (n_paths, 2) holding ``[start, end)`` row
        offsets of every file inside the concatenated table.
    '''
    plate_keys = ['Metadata_Source', 'Metadata_Batch', 'Metadata_Plate']
    meta = load_metadata(sources, plate_types)
    unique_plates = meta[plate_keys].drop_duplicates()
    paths = unique_plates.apply(build_path, axis=1).values

    counts = thread_map(get_num_rows, paths, leave=False, desc='counts')
    ends = np.cumsum(counts)
    slices = np.zeros((len(paths), 2), dtype=int)
    slices[:, 1] = ends
    # Each file starts where the previous one ends; the first starts at 0.
    slices[1:, 0] = ends[:-1]
    return paths, slices


def load_data(sources, plate_types):
    '''Load all plates given the params into a single DataFrame.

    Feature columns are float32 and metadata columns are ``category``. The
    column names are taken from the first plate file, so every plate is
    expected to contain those columns.

    Raises
    ------
    ValueError
        If no plate matches ``sources`` / ``plate_types``.
    '''
    # Imported locally so this function does not depend on a module-level
    # ``pq`` alias being defined.
    import pyarrow.parquet as pq

    paths, slices = prealloc_params(sources, plate_types)
    if len(paths) == 0:
        raise ValueError(f'No plates found for sources={sources!r} '
                         f'and plate_types={plate_types!r}')
    total = slices[-1, 1]

    with pq.ParquetFile(paths[0]) as f:
        meta_cols = find_meta_cols(f.schema.names)
        feat_cols = find_feat_cols(f.schema.names)
    # Preallocate the whole output so every worker writes its own,
    # non-overlapping row range in place.
    # NOTE(review): metadata is buffered as fixed-width bytes ('|S128');
    # longer values get truncated and non-ASCII text cannot be stored —
    # confirm this holds for all metadata columns.
    meta = np.empty([total, len(meta_cols)], dtype='|S128')
    feats = np.empty([total, len(feat_cols)], dtype=np.float32)

    def read_parquet(params):
        path, start, end = params
        df = pd.read_parquet(path)
        meta[start:end] = df[meta_cols].values
        feats[start:end] = df[feat_cols].values

    thread_map(read_parquet, list(zip(paths, slices[:, 0], slices[:, 1])))

    meta = pd.DataFrame(data=meta.astype(str),
                        columns=meta_cols,
                        dtype='category')
    dframe = pd.DataFrame(columns=feat_cols, data=feats)
    for col in meta_cols:
        dframe[col] = meta[col]
    return dframe


def add_pert_type(meta: pd.DataFrame, col: str = 'Metadata_PertType'):
    '''Label every row as 'trt', 'poscon' or 'negcon' in column ``col``.

    Rows whose ``Metadata_JCP2022`` does not start with "JCP" (i.e. codes
    already replaced by a readable name) are positive controls, "DMSO" rows
    are negative controls, and everything else is a treatment. The new
    column is categorical and ``meta`` is modified in place.
    '''
    source_col = 'Metadata_JCP2022'
    meta[col] = 'trt'
    not_jcp = ~meta[source_col].str.startswith('JCP')
    meta.loc[not_jcp, col] = 'poscon'
    # DMSO also fails the "JCP" prefix test, so it must be set last.
    is_dmso = meta[source_col] == 'DMSO'
    meta.loc[is_dmso, col] = 'negcon'
    meta[col] = meta[col].astype('category')


def add_row_col(meta: pd.DataFrame):
    '''Split Metadata_Well (e.g. "A01") into Metadata_Row and Metadata_Column.

    A well must be one or two letters followed by one or two digits; wells
    that do not match get NaN. Both new columns are categorical and ``meta``
    is modified in place.
    '''
    parts = meta['Metadata_Well'].str.extract(
        r'^(?P<row>[a-zA-Z]{1,2})(?P<column>[0-9]{1,2})$')
    for group, target in (('row', 'Metadata_Row'),
                          ('column', 'Metadata_Column')):
        meta[target] = parts[group].astype('category')

def add_microscopy_info(meta: pd.DataFrame):
    '''Add a categorical Metadata_Microscope column looked up per source.

    Each Metadata_Source value is mapped through ``MICRO_CONFIG``; sources
    missing from it get NaN. ``meta`` is modified in place.
    '''
    meta['Metadata_Microscope'] = (
        meta['Metadata_Source'].map(MICRO_CONFIG).astype('category'))

def write_parquet(sources, plate_types, output_file):
    '''Write the annotated parquet dataset given the params.

    Profiles from ``load_data`` are annotated with well/plate metadata plus
    perturbation type, row/column and microscope. Rows without a
    ``Metadata_JCP2022`` annotation are dropped, and the result is written to
    ``output_file`` with a fresh ``RangeIndex``.

    Raises
    ------
    pandas.errors.MergeError
        If the metadata has more than one row for the same
        (source, plate, well). Such duplicates would otherwise add rows to the
        merge and silently misalign annotations with profiles.
    '''
    dframe = load_data(sources, plate_types)
    # Efficient merge
    meta = load_metadata(sources, plate_types)
    add_pert_type(meta)
    add_row_col(meta)
    add_microscopy_info(meta)
    foreign_key = ['Metadata_Source', 'Metadata_Plate', 'Metadata_Well']
    # The left merge keeps profile row order. many_to_one guarantees exactly
    # one merged row per profile, so both frames keep a default RangeIndex of
    # the same length and the column assignment below lines up row by row.
    meta = dframe[foreign_key].merge(meta, on=foreign_key, how='left',
                                     validate='many_to_one')
    for c in meta:
        dframe[c] = meta[c].astype('category')
    # Dropping samples with no metadata
    dframe.dropna(subset=['Metadata_JCP2022'], inplace=True)
    dframe.reset_index(drop=True, inplace=True)
    dframe.to_parquet(output_file)


## metadata.py

In [None]:
"""
Functions to load metadata information
"""
import logging
from collections.abc import Iterable

import pandas as pd

logger = logging.getLogger(__name__)

# Readable names for the JCP2022 identifiers of controls and non-treatment
# codes. get_well_metadata replaces matching Metadata_JCP2022 values with these
# names, then drops wells labelled "UNTREATED", "UNKNOWN" or "BAD CONSTRUCT".
# Downstream, add_pert_type labels "DMSO" as negcon. Any other value not
# starting with "JCP" (so the remaining names here) is labelled poscon.
MAPPER = {
    "JCP2022_085227": "Aloxistatin",
    "JCP2022_037716": "AMG900",
    "JCP2022_025848": "Dexamethasone",
    "JCP2022_046054": "FK-866",
    "JCP2022_035095": "LY2109761",
    "JCP2022_064022": "NVS-PAK1-1",
    "JCP2022_050797": "Quinidine",
    "JCP2022_012818": "TC-S-7004",
    "JCP2022_033924": "DMSO",
    "JCP2022_999999": "UNTREATED",
    "JCP2022_UNKNOWN": "UNKNOWN",
    "JCP2022_900001": "BAD CONSTRUCT",
}

# Microscope name for each data source, indexed by "source_<N>" so it can be
# looked up directly with Metadata_Source values. The CSV is fetched over the
# network, from a pinned commit of jump-cellpainting/datasets, when this runs.
MICRO_CONFIG = (
    pd.read_csv(
        "https://raw.githubusercontent.com/jump-cellpainting/datasets/181fa0dc96b0d68511b437cf75a712ec782576aa/metadata/microscope_config.csv"
    )
    .assign(
        Metadata_Source=lambda cfg: "source_" + cfg["Metadata_Source"].astype(str)
    )
    .set_index("Metadata_Source")["Metadata_Microscope_Name"]
)


def find_feat_cols(cols: Iterable[str]):
    """Return the feature column names, i.e. every name not starting with "Meta"."""
    return list(filter(lambda name: not name.startswith("Meta"), cols))


def find_meta_cols(cols: Iterable[str]):
    """Return the metadata column names, i.e. every name starting with "Meta"."""
    return list(filter(lambda name: name.startswith("Meta"), cols))


def get_source_4_plate_redlist(plate_types: list[str]):
    """Return the set of source_4 plate barcodes to exclude from the analysis.

    The low-concentration plates and BR00123528A are always excluded. When
    ``plate_types`` contains "ORF", every Batch12 plate is also excluded.
    When it contains "TARGET2", every plate with a recorded anomaly is
    excluded too. Both lookups use ``inputs/experiment-metadata.tsv``, which
    is read only when one of those plate types is requested.
    """
    # https://github.com/jump-cellpainting/jump-orf-analysis/issues/1#issuecomment-921888625
    # Low concentration plates
    redlist = {"BR00127147", "BR00127148", "BR00127145", "BR00127146"}
    # https://github.com/jump-cellpainting/aws/issues/70#issuecomment-1182444836
    redlist.add("BR00123528A")

    # Query selecting the bad plates for each plate type that needs one.
    queries = {
        "ORF": 'Batch=="Batch12"',  # filter ORF plates
        "TARGET2": 'Anomaly!="none"',  # filter TARGET2 plates
    }
    wanted = [query for ptype, query in queries.items() if ptype in plate_types]
    if not wanted:
        # Nothing to look up: skip reading the experiment metadata entirely.
        return redlist

    metadata = pd.read_csv("inputs/experiment-metadata.tsv", sep="\t")
    for query in wanted:
        redlist |= set(metadata.query(query).Assay_Plate_Barcode)
    return redlist


# source_3 batches without DMSO. get_plate_metadata drops plates from these
# batches unless their Metadata_PlateType is "TARGET2".
SOURCE3_BATCH_REDLIST = {
    "CP_32_all_Phenix1",
    "CP_33_all_Phenix1",
    "CP_34_mix_Phenix1",
    "CP_35_all_Phenix1",
    "CP_36_all_Phenix1",
    "CP59",
    "CP60",
}


def build_path(row: pd.Series) -> str:
    """Return the profile parquet path for a (source, batch, plate) row.

    The row must provide Metadata_Source, Metadata_Batch and Metadata_Plate;
    extra fields are ignored.
    """
    fields = row.to_dict()
    profile_dir = "./inputs/{Metadata_Source}/workspace/profiles/".format(**fields)
    plate_file = "{Metadata_Batch}/{Metadata_Plate}/{Metadata_Plate}.parquet".format(
        **fields
    )
    return profile_dir + plate_file


def get_plate_metadata(sources: list[str], plate_types: list[str]) -> pd.DataFrame:
    """Load plate metadata restricted to ``sources`` and ``plate_types``.

    Besides that selection, this drops redlisted source_4 plates (only when
    "source_4" is requested). It also drops plates from the batches in
    SOURCE3_BATCH_REDLIST unless they are TARGET2 plates.
    """
    plates = pd.read_csv("./inputs/metadata/plate.csv.gz")
    keep = pd.Series(True, index=plates.index)
    if "source_4" in sources:
        source_4_redlist = get_source_4_plate_redlist(plate_types)
        keep &= ~plates["Metadata_Plate"].isin(source_4_redlist)

    # source_3 batches without DMSO are excluded, except for TARGET2 plates.
    keep &= (~plates["Metadata_Batch"].isin(SOURCE3_BATCH_REDLIST)) | (
        plates["Metadata_PlateType"] == "TARGET2"
    )
    keep &= plates["Metadata_Source"].isin(sources)
    keep &= plates["Metadata_PlateType"].isin(plate_types)
    return plates[keep]


def get_well_metadata(plate_types: list[str]):
    """Load well metadata with readable names for control compounds.

    For "ORF" plate types, only wells also present in the ORF metadata are
    kept (inner merge on the columns both tables share). Metadata_JCP2022
    codes found in MAPPER are replaced by their readable name. Wells named
    "UNTREATED", "UNKNOWN" or "BAD CONSTRUCT" are then dropped.
    """
    wells = pd.read_csv("./inputs/metadata/well.csv.gz")
    if "ORF" in plate_types:
        orfs = pd.read_csv("./inputs/metadata/orf.csv.gz")
        wells = wells.merge(orfs, how="inner")

    wells["Metadata_JCP2022"] = wells["Metadata_JCP2022"].map(
        lambda code: MAPPER.get(code, code)
    )
    excluded = ["UNTREATED", "UNKNOWN", "BAD CONSTRUCT"]
    return wells[~wells["Metadata_JCP2022"].isin(excluded)]


def load_metadata(sources: list[str], plate_types: list[str]):
    """Return well metadata joined with the filtered plate metadata.

    The join is pandas' default inner merge on source and plate, so only wells
    on plates kept by get_plate_metadata are returned.
    """
    plates = get_plate_metadata(sources, plate_types)
    wells = get_well_metadata(plate_types)
    join_keys = ["Metadata_Source", "Metadata_Plate"]
    return wells.merge(plates, on=join_keys)


In [4]:
batch_key = ["Metadata_Batch", "Metadata_Plate"]
# Must be a plain string; a trailing comma would turn it into a 1-tuple.
label_key = "Metadata_JCP2022"

n_latent = 30

adata = to_anndata("../outputs/scenario_1/mad_int_featselect.parquet")
meta = adata.obs.reset_index(drop=True).copy()

# setup_anndata receives a single batch column; any further batch columns
# are passed as categorical covariates.
if isinstance(batch_key, list):
    actual_batch_key, *categorical_covariate_keys = batch_key
else:
    actual_batch_key = batch_key
    categorical_covariate_keys = []

SysVI.setup_anndata(adata, batch_key=actual_batch_key,
                    categorical_covariate_keys=categorical_covariate_keys)
vae = SysVI(adata, prior="standard_normal")
# vae.view_anndata_setup(adata=adata)
vae.train()


[34mINFO    [0m Using column names from columns of adata.obsm[1m[[0m[32m'system'[0m[1m][0m                                                   
[34mINFO    [0m Using column names from columns of adata.obsm[1m[[0m[32m'covariates'[0m[1m][0m                                               
[34mINFO    [0m The model has been initialized                                                                            


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
SLURM auto-requeueing enabled. Setting signal handlers.


Epoch 320/400:  80%|███████▉  | 319/400 [1:24:33<15:18, 11.34s/it, v_num=1, loss_train=-152]

In [None]:
import torch

# Report whether PyTorch can see a CUDA device.
print("GPU is available" if torch.cuda.is_available() else "GPU is not available")


GPU is not available


In [None]:
# min_value = adata.X.min()
# adata.X -= min_value

# setup_anndata receives a single batch column; any further batch columns
# are passed as categorical covariates.
if isinstance(batch_key, list):
    actual_batch_key, *categorical_covariate_keys = batch_key
else:
    actual_batch_key = batch_key
    categorical_covariate_keys = []

SysVI.setup_anndata(adata, batch_key=actual_batch_key,
                    categorical_covariate_keys=categorical_covariate_keys)
vae = SysVI(adata, n_layers=2, n_latent=n_latent, prior="standard_normal")
# This is a freshly built model: train it before reading its latent space,
# otherwise the embedding comes from randomly initialised weights.
vae.train()

vals = vae.get_latent_representation()
features = [f'sysvi_{i}' for i in range(vals.shape[1])]