In [None]:
import numpy as np
import pandas as pd
import ms_feature_validation as mfv
import bokeh.plotting
bokeh.plotting.output_notebook()
import matplotlib.pyplot as plt

In [None]:
# Read the Progenesis CSV export into the data container.
data = mfv.fileio.read_progenesis("../examples/SuerosRCC_ESi_neg_default_SepOct2017.csv")

# adding order and batch information
# Sample names are split on "_": the last field is parsed as the run order
# and the second field (called a date below) identifies the batch.
temp = pd.Series(data=data.sample_metadata.index.str.split("_"),
                 index=data.sample_metadata.index)
order = temp.map(lambda fields: fields[-1]).astype(int)
dates = temp.map(lambda fields: fields[1])
# Batches are numbered from 1, following the first appearance of each date.
dates_to_batch = {date: number
                  for number, date in enumerate(dates.unique(), start=1)}
batch = dates.map(dates_to_batch).astype(int)

def convert_to_global_run_order(order, batch):
    """
    Convert a per-batch run order into a run order shared by all batches.

    The order of each sample is shifted by the sum of the maximum orders of
    every batch whose label sorts before the sample's own batch, so the first
    batch keeps its values and each later batch continues where the previous
    one ended.

    Parameters
    ----------
    order: pd.Series
        Run order of each sample inside its batch.
    batch: pd.Series
        Batch label of each sample, indexed like `order`. Labels only need to
        be sortable; they do not have to be consecutive integers starting
        at 1.

    Returns
    -------
    pd.Series
        Global run order, indexed like `order`.
    """
    max_order = order.groupby(batch).max().sort_index()
    # Exclusive cumulative sum: for each batch, the total of the maxima of
    # all the batches that come before it (0 for the first batch).
    offset = max_order.cumsum() - max_order
    return order + batch.map(offset)

# Store on the data container: the run order made unique across batches,
# the batch number of each sample, and the sample ids (taken from the
# index of the sample metadata).
data.order = convert_to_global_run_order(order, batch)
data.batch = batch
data.id = data.sample_metadata.index

# setup sample types: each sample type ("qc", "blank", "sample") is mapped
# to the list of class labels that belong to it.
sample_mapping = {"qc": ["QC"],
                  "blank": ["SV"],
                  "sample": ["CS", "EI", "EII", "EIII", "EIV", "EI2", "EII2",
                             "EIII2", "EIV2", "Crb", "Pa"]}
data.mapping = sample_mapping

In [None]:
# Only the two batch checkers are run in this cell; the class removal /
# blank correction / prevalence / variation pipeline is left disabled.
# NOTE(review): what each checker does to `data` is defined in mfv.filter
# and is not visible here - confirm before relying on its side effects.
# cr = mfv.filter.ClassRemover(["Sarc", "B", "O", "Onc", "EII2", "Pa"])
# bc = mfv.filter.BlankCorrector()
# pf = mfv.filter.PrevalenceFilter(lb=0.8)
# vf = mfv.filter.VariationFilter(ub=0.25)
batch_checker = mfv.filter.BatchSchemeChecker()
batch_checker.process(data)
batch_prevalence = mfv.filter.BatchPrevalenceChecker()
batch_prevalence.process(data)
# pipeline = mfv.filter.Pipeline([cr, bc, pf, vf])
# pipeline.process(data)

In [None]:
# Shape of the data matrix after the batch checkers were applied.
data.data_matrix.shape

In [None]:
# PCA scores plot colored by "type"; the trailing ";" hides the cell's repr.
data.plot.pca_scores(color_by="type");

In [None]:
# PCA loadings plot for the same data.
data.plot.pca_loadings()

In [None]:
# Class labels used by the QC prevalence check below: the "sample" entry of
# the mapping stored on the data container, and the QC class.
sample_classes = data.mapping["sample"]
qc_classes = ["QC"]

def get_ext_per_batch(order, batch, classes, class_list, ext):
    """
    Find the lowest or highest run order in each batch for a set of classes.

    Parameters
    ----------
    order: pd.Series
        Run order of each sample.
    batch: pd.Series
        Batch label of each sample, indexed like `order`.
    classes: pd.Series
        Class label of each sample, indexed like `order`.
    class_list: list
        Classes to take into account; samples of other classes are ignored.
    ext: {"min", "max"}
        Extreme to compute.

    Returns
    -------
    pd.Series
        Extreme run order cast to int, indexed by batch. Batches without any
        sample from `class_list` are absent from the result.

    Raises
    ------
    KeyError
        If `ext` is neither "min" nor "max".
    """
    aggregations = {"min": "min", "max": "max"}
    agg_name = aggregations[ext]    # KeyError on any other value
    mask = classes.isin(class_list)
    # Grouping on batch alone yields the extreme over all selected classes,
    # and does not require the input Series to be named.
    ext_order = order[mask].groupby(batch[mask]).agg(agg_name).astype(int)
    # The result is left unnamed instead of inheriting the name of `order`.
    ext_order.name = None
    return ext_order


def check_qc_prevalence(data_matrix, order, batch, classes, qc_classes, sample_classes,
                        thresh=0, min_n_qc=4):
    """
    Find the features that are not detected often enough in the QC samples.

    Each batch is divided in three blocks according to the run order:

        | start block | middle block      | end block |
          q   q         ssss q ssss q ssss  q  q

    where q is a QC sample and s is a biological sample. A feature is valid
    in a batch if it is detected:

    * in at least one QC sample of the start block (QC samples analyzed
      before the first biological sample),
    * in at least `min_n_qc` QC samples of the middle block (QC samples
      analyzed between the first and the last biological sample),
    * in at least one QC sample of the end block (QC samples analyzed after
      the last biological sample).

    A feature is considered valid only if it is valid in all of the batches.

    Parameters
    ----------
    data_matrix: pd.DataFrame
        Feature values, samples in rows and features in columns.
    order: pd.Series
        Run order of each sample, indexed like `data_matrix`.
    batch: pd.Series
        Batch label of each sample, indexed like `data_matrix`.
    classes: pd.Series
        Class label of each sample, indexed like `data_matrix`.
    qc_classes: list
        Classes of the QC samples.
    sample_classes: list
        Classes of the biological samples.
    thresh: number
        A feature is detected in a sample if its value is greater than
        `thresh`.
    min_n_qc: int
        Minimum number of QC samples of the middle block where a feature
        must be detected.

    Returns
    -------
    pd.Index
        Names of the invalid features.
    """
    # Only the function arguments are used, never a notebook-level `data`.
    min_qc_order = get_ext_per_batch(order, batch, classes, qc_classes, "min")
    min_sample_order = get_ext_per_batch(order, batch, classes, sample_classes, "min")
    max_qc_order = get_ext_per_batch(order, batch, classes, qc_classes, "max")
    max_sample_order = get_ext_per_batch(order, batch, classes, sample_classes, "max")

    qc_mask = classes.isin(qc_classes)
    detected = data_matrix > thresh     # hoisted: the same for every batch
    valid_features = data_matrix.columns

    def detected_in(block_mask):
        """Rows of `detected` for the samples selected by `block_mask`."""
        return detected.loc[block_mask[block_mask].index]

    for k_batch in batch.unique():
        first_qc = min_qc_order[k_batch]
        last_qc = max_qc_order[k_batch]
        first_sample = min_sample_order[k_batch]
        last_sample = max_sample_order[k_batch]

        # Restricting to the current batch changes nothing when `order` is
        # unique across batches, and keeps the blocks correct when it is not.
        qc_in_batch = qc_mask & (batch == k_batch)
        start_block = qc_in_batch & (order >= first_qc) & (order < first_sample)
        middle_block = qc_in_batch & (order > first_sample) & (order < last_sample)
        end_block = qc_in_batch & (order > last_sample) & (order <= last_qc)

        start_ok = detected_in(start_block).any()
        middle_ok = detected_in(middle_block).sum() >= min_n_qc
        end_ok = detected_in(end_block).any()

        batch_ok = start_ok & middle_ok & end_ok
        valid_features = valid_features.intersection(batch_ok[batch_ok].index)

    invalid_features = data_matrix.columns.difference(valid_features)
    return invalid_features

In [None]:
%%time
# `a` holds the features reported as invalid by the QC prevalence check.
a = check_qc_prevalence(data.data_matrix, data.order, data.batch, data.classes, qc_classes, sample_classes)

In [None]:
# Apply mfv.utils.cv (presumably the coefficient of variation - confirm)
# to the QC rows of the data matrix, one batch at a time.
qc_cv_per_batch = (data.data_matrix[data.classes.isin(["QC"])]
                   .groupby(data.batch)
                   .apply(mfv.utils.cv))

In [None]:
# Display the per-batch result computed from the QC samples.
qc_cv_per_batch

In [None]:
# For every feature take the batch where the value is highest, then count
# how many features point to each batch.
qc_cv_per_batch.idxmax().value_counts()

In [None]:
# mfv.utils.cv applied to every (batch, class) group of the data matrix.
# The histogram of the per-feature minimum is left disabled.
# fig, ax = plt.subplots(figsize=(12, 8))
data.data_matrix.groupby([data.batch, data.classes]).apply(mfv.utils.cv) #.min().hist(bins=100, ax=ax)
# ax.set_xlim(0, 1)

In [None]:
# Sort the sample metadata by the "order" column and reorder the rows of the
# data matrix to match. This rebinds both attributes of `data`, so cells
# below see the sorted versions.
data.sample_metadata = data.sample_metadata.sort_values("order")
data.data_matrix = data.data_matrix.loc[data.sample_metadata.index, :]

In [None]:
from statsmodels.nonparametric.smoothers_lowess import lowess
from scipy.interpolate import CubicSpline

def _coov_loess(x, y, interpolator, frac=None) -> np.ndarray:
    """
    Helper function for batch correction. Smooths y against x with LOESS,
    choosing the smoothing fraction by leave-one-out cross validation (LOOCV)
    when it is not given.

    Parameters
    ----------
    x: pd.Series
        Position of each observation (callers in this file pass the run
        order of the QC samples).
    y: pd.Series
        Values to smooth (callers in this file pass the intensities of one
        feature), indexed like `x`.
    interpolator: callable
        Called as ``interpolator(x, y_smoothed)``; must return a callable
        that predicts values at new positions. Only used during LOOCV.
    frac: float, optional
        Fraction of the observations used in each local regression. If None,
        the candidates k / N for k = 4, ..., N (N = number of observations)
        are compared by LOOCV and the one with the lowest squared error is
        used. With fewer than 4 observations there are no candidates and
        frac = 1 is used.

    Returns
    -------
    np.ndarray
        LOESS smoothed values of `y`, one per observation, in input order.
        The selected frac is not returned.
    """
    if frac is None:
        n_obs = x.size
        best_frac = 1
        best_sse = np.inf   # lowest sum of squared LOOCV errors found so far
        # candidate fractions, from 4/N up to N/N = 1
        for candidate in (k / n_obs for k in range(4, n_obs + 1)):
            sse = 0
            # The first and last observations are never held out, so the
            # interpolator is not evaluated beyond the ends of the training
            # positions (assuming x is sorted - TODO confirm).
            for held_out in x.index[1:-1]:
                x_train = x.drop(held_out)
                y_train = y.drop(held_out)
                y_fit = lowess(y_train, x_train, return_sorted=False,
                               frac=candidate)
                predict = interpolator(x_train, y_fit)
                sse += (y[held_out] - predict(x[held_out])) ** 2
            if best_sse > sse:
                best_frac = candidate
                best_sse = sse
        frac = best_frac
    return lowess(y, x, return_sorted=False, frac=frac)

def batch_corrector_func(df_batch, order, classes, frac, interpolator, qc_classes, sample_classes):
    """
    Replace the rows of `df_batch` that belong to `sample_classes` with the
    values produced by `loess_interp`, feature by feature.

    Parameters
    ----------
    df_batch: pd.DataFrame
        Samples in rows, features in columns. Modified in place.
    order: pd.Series
        Run order of each sample.
    classes: pd.Series
        Class label of each sample.
    frac: float or None
        Passed to `loess_interp`.
    interpolator: callable
        Passed to `loess_interp`.
    qc_classes: list
        Classes whose samples are used to fit the correction.
    sample_classes: list
        Classes whose samples are replaced.

    Returns
    -------
    pd.DataFrame
        The same `df_batch` object that was passed in.
    """
    qc_index = classes.index[classes.isin(qc_classes).values]
    sample_index = classes.index[classes.isin(sample_classes).values]

    def correct_feature(ft_data):
        return loess_interp(ft_data, order, qc_index, sample_index,
                            frac, interpolator)

    corrected = df_batch.apply(correct_feature)
    # Assignment aligns on row labels: only rows in sample_index are written.
    df_batch.loc[sample_index, :] = corrected
    return df_batch
    

def loess_interp(ft_data, order, qc_index, sample_index, frac, interpolator):
    """
    Replace the values of one feature at `sample_index` by an interpolation
    of the LOESS fit made on the samples at `qc_index`.

    Parameters
    ----------
    ft_data: pd.Series
        Values of one feature. Not modified.
    order: pd.Series
        Run order of each sample, indexed like `ft_data`.
    qc_index: pd.Index
        Labels of the samples used to fit the LOESS curve.
    sample_index: pd.Index
        Labels of the samples whose values are replaced.
    frac: float or None
        Passed to `_coov_loess`.
    interpolator: callable
        Called as ``interpolator(x, y)``; must return a callable that
        predicts values at new positions.

    Returns
    -------
    pd.Series
        Copy of `ft_data` with the values at `sample_index` replaced.
    """
    qc_order = order[qc_index]
    qc_loess = _coov_loess(qc_order, ft_data[qc_index], interpolator, frac=frac)
    interp = interpolator(qc_order, qc_loess)
    # Work on a copy: this function is used inside DataFrame.apply, where
    # mutating the object that is passed in is not supported by pandas.
    corrected = ft_data.copy()
    corrected.loc[sample_index] = interp(order[sample_index])
    return corrected

def intebatch_correction(df_group, order, batch, classes, qc_index, sample_indx, frac, intepolator):
    """
    Unfinished stub: selects the run order and the classes of the samples
    whose batch equals `df_group.name`, and returns None.

    TODO: the remaining arguments are not used yet.
    """
    in_group = batch == df_group.name
    group_order = order[in_group]
    group_classes = classes[in_group]
    



In [None]:
%%time
# Correct the first batch only. batch_corrector_func writes into the frame
# it receives, so an explicit copy is taken: the original data matrix stays
# available for the comparison done further below, and pandas does not warn
# about assigning to a slice.
df_batch = data.data_matrix[data.batch == 1].copy()
order_batch = data.order[data.batch == 1]
classes_batch = data.classes[data.batch == 1]
qc_classes = ["QC"]
# "QC" is listed here as well, so the QC rows are also replaced by the
# values of the fitted curve.
sample_classes = ['CS', 'EI', 'EII', 'EIII', 'EIV', 'EI2', 'EII2', 'EIII2', 'EIV2', 'Crb', 'Pa', "QC"]
batch_corrector_func(df_batch, order_batch, classes_batch, None, CubicSpline, qc_classes, sample_classes)

In [None]:
# Values of one feature in the QC samples of batch 1, before and after
# LOESS smoothing, plotted against run order.
y = data.data_matrix.loc[(data.batch == 1) & (data.classes == "QC"), "0.04_304.9041n"]
x = data.order[y.index]
y_loess = _coov_loess(x, y, interpolator=CubicSpline)

fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(x, y, label="QC samples")
ax.plot(x, y_loess, label="LOESS fit")
ax.set_xlabel("run order")
ax.set_ylabel("intensity")
ax.set_title("{}: QC samples of batch 1".format(y.name))
ax.legend();

In [None]:
# Compare std / mean of every feature in the QC samples of batch 1 between
# df_batch (the frame passed to batch_corrector_func) and the data matrix.
# NOTE: `a` is rebound here; an earlier cell used it for the result of
# check_qc_prevalence.
a = df_batch[classes_batch == "QC"]
acv = a.std() / a.mean()
b = data.data_matrix[(data.batch == 1) & (data.classes == "QC")]
bcv = b.std() / b.mean()
bcv = bcv[acv.index]
# keep only the features where the value from df_batch is the larger one
acv = acv[acv > bcv]

In [None]:
# Histogram of the mean value, in the QC rows of df_batch, of the features
# selected in the previous cell.
a = df_batch.loc[classes_batch == "QC", acv.index]
a.mean().hist()