In [58]:
import pandas as pd
import numpy.linalg as la
import numpy as np
import patsy
import sys

In [59]:
def design_mat(mod, numerical_covariates, batch_levels):
    # require levels to make sure they are in the same order as we use in the
    # rest of the script.
    design = patsy.dmatrix("~ 0 + C(batch, levels=%s)" % str(batch_levels),
                                                  mod, return_type="dataframe")

    mod = mod.drop(["batch"], axis=1)
    numerical_covariates = list(numerical_covariates)
    sys.stderr.write("found %i batches\n" % design.shape[1])
    other_cols = [c for i, c in enumerate(mod.columns)
                  if not i in numerical_covariates]
    factor_matrix = mod[other_cols]
    design = pd.concat((design, factor_matrix), axis=1)
    if numerical_covariates is not None:
        sys.stderr.write("found %i numerical covariates...\n"
                            % len(numerical_covariates))
        for i, nC in enumerate(numerical_covariates):
            cname = mod.columns[nC]
            sys.stderr.write("\t{0}\n".format(cname))
            design[cname] = mod[mod.columns[nC]]
    sys.stderr.write("found %i categorical variables:" % len(other_cols))
    sys.stderr.write("\t" + ", ".join(other_cols) + '\n')
    return design

In [60]:
def aprior(gamma_hat):
    m = gamma_hat.mean()
    s2 = gamma_hat.var()
    return (2 * s2 +m**2) / np.float64(s2)

In [61]:
def bprior(gamma_hat):
    m = gamma_hat.mean()
    s2 = gamma_hat.var()
    return (m*s2+m**3) / np.float64(s2)

In [62]:
#Used to calculate gamma_star - appears in it_sol and it_sol_loc functions.
def postmean(g_hat, g_bar, n, d_star, t2):
    return (t2*n*g_hat+d_star * g_bar) / (t2*n+d_star)

In [63]:
def postvar(sum2, n, a, b):
    return (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)

In [64]:
def it_sol(sdat, g_hat, d_hat, g_bar, t2, a, b, conv=0.0001):
    n = (1 - np.isnan(sdat)).sum(axis=1)
    g_old = g_hat.copy()
    d_old = d_hat.copy()

    change = 1
    count = 0
    while change > conv:
        #print g_hat.shape, g_bar.shape, t2.shape
        g_new = postmean(g_hat, g_bar, n, d_old, t2)
        sum2 = ((sdat - np.dot(g_new.reshape((g_new.shape[0], 1)), np.ones((1, sdat.shape[1])))) ** 2).sum(axis=1)
        d_new = postvar(sum2, n, a, b)
       
        change = max((abs(g_new - g_old) / g_old).max(), (abs(d_new - d_old) / d_old).max())
        g_old = g_new #.copy()
        d_old = d_new #.copy()
        count = count + 1
    adjust = (g_new, d_new)
    return adjust

In [65]:
def it_sol_loc(g_hat, g_bar, n, d_hat, t2):
    n = 1
    g_old = g_hat.copy()
    d_old = d_hat.copy()

    g_new = postmean(g_hat, g_bar, n, d_old, t2)
    g_old = g_new #.copy()
    adjust = (g_new)
    return adjust

In [66]:
def it_sol_scale(sdat, g_hat, d_hat, g_bar, t2, a, b, conv=0.0001):
    n = (1 - np.isnan(sdat)).sum(axis=1)
    g_old = g_hat.copy()
    d_old = d_hat.copy()

    change = 1
    count = 0
    while change > conv:
        #print g_hat.shape, g_bar.shape, t2.shape
        g_new = postmean(g_hat, g_bar, n, d_old, t2)
        sum2 = ((sdat - np.dot(g_new.reshape((g_new.shape[0], 1)), np.ones((1, sdat.shape[1])))) ** 2).sum(axis=1)
        d_new = postvar(sum2, n, a, b)
       
        change = (abs(d_new - d_old) / d_old).max()
        g_old = g_new #.copy()
        d_old = d_new #.copy()
        count = count + 1
    adjust = (d_new, g_new)
    return adjust

In [67]:
def combat(data, batch, model=None, numerical_covariates=None, mean_only=False, ref_batch=None):
    """Correct for batch effects in a dataset
    Parameters
    ----------
    data : pandas.DataFrame
        A (n_features, n_samples) dataframe of the expression to batch correct
    batch : List-like
        A column corresponding to the batches in the data, in the same order
        as the samples in ``data``
    model : patsy.design_info.DesignMatrix, optional
        A model matrix describing metadata on the samples which could be
        causing batch effects. If not provided, then will attempt to coarsely
        correct just from the information provided in ``batch``
    numerical_covariates : list-like
        List of covariates in the model which are numerical, rather than
        categorical
    mean_only : (T/F)
        Only adjusts for mean across the batches, forgoing scale (variance) adjustment. 
        Default = False
    ref_batch : int
        Adjusts all batches to the specified referance batch. Specified batch will not be adjusted.
        Default = None
        
    Returns
    -------
    corrected : pandas.DataFrame
        A (n_features, n_samples) dataframe of the batch-corrected data
    """
    if isinstance(numerical_covariates, str):
        numerical_covariates = [numerical_covariates]
    
    if numerical_covariates is None:
        numerical_covariates = []

    if model is not None and isinstance(model, pd.DataFrame):
        model["batch"] = list(batch)
    
    else:
        model = pd.DataFrame({'batch': batch})
        
    if ref_batch is not None and isinstance(ref_batch, int):
        ref = int(ref_batch)

    batch_items = model.groupby("batch").groups.items()
    batch_levels = [k for k, v in batch_items]
    batch_info = [v for k, v in batch_items]
    n_batch = len(batch_info)
    n_batches = np.array([len(v) for v in batch_info])
    n_array = sum(n_batches)
    
    drop_cols = [cname for cname, inter in  ((model == 1).all()).iteritems() if inter == True]
    drop_idxs = [list(model.columns).index(cdrop) for cdrop in drop_cols]
    model = model[[c for c in model.columns if not c in drop_cols]]
    numerical_covariates = [list(model.columns).index(c) if isinstance(c, str) else c
            for c in numerical_covariates if not c in drop_cols]
    
    design = design_mat(model, numerical_covariates, batch_levels)
    
    if ref_batch is not None:
        design.iloc[:,ref-1] = 1.
    
    sys.stderr.write("Standardizing Data across genes.\n")
    B_hat = np.dot(np.dot(la.inv(np.dot(design.T, design)), design.T), data.T)
    
    if ref_batch is not None:
        grand_mean = B_hat[ref-1]
        refdat = dft.loc[dft['batch'] == ref].drop('batch', axis=1)
        reflist=refdat.index.tolist()
        desbatch=design.T[reflist].T
        var_pooled = np.dot(((refdat - np.dot(desbatch, B_hat))**2).T, np.ones((n_batches[ref-1], 1)) / n_batches[ref-1])
            
    else:
        grand_mean = np.dot((n_batches / n_array).T, B_hat[:n_batch,:])
        var_pooled = np.dot(((data - np.dot(design, B_hat).T)**2), np.ones((n_array, 1)) / n_array)

    stand_mean = np.dot(grand_mean.T.reshape((len(grand_mean), 1)), np.ones((1, n_array)))
    tmp = np.array(design.copy())
    tmp[:,:n_batch] = 0
    stand_mean  += np.dot(tmp, B_hat).T
    
    s_data = ((data - stand_mean) / np.dot(np.sqrt(var_pooled), np.ones((1, n_array))))

    sys.stderr.write("Fitting L/S model and finding priors\n")
    batch_design = design[design.columns[:n_batch]]
    gamma_hat = np.dot(np.dot(la.inv(np.dot(batch_design.T, batch_design)), batch_design.T), s_data.T)

    delta_hat = []

    for i, batch_idxs in enumerate(batch_info):
        if mean_only is False:
            delta_hat.append(s_data[batch_idxs].var(axis=1))
        
        else:
            x=s_data[batch_idxs].var(axis=1)
            x[x<100]=1
            np.float64(delta_hat.append(x))
            
    gamma_bar = gamma_hat.mean(axis=1) 
    t2 = gamma_hat.var(axis=1)
    
    a_prior = list(map(aprior, delta_hat))
    b_prior = list(map(bprior, delta_hat))
    
    sys.stderr.write("Finding parametric adjustments\n")
    gamma_star, delta_star = [], []
    for i, batch_idxs in enumerate(batch_info):
    
        if mean_only is False:
            
            temp = it_sol(s_data[batch_idxs], gamma_hat[i],
            delta_hat[i], gamma_bar[i], t2[i], a_prior[i], b_prior[i])

            gamma_star.append(temp[0])
            delta_star.append(temp[1])
        
        else:
            
            temploc = it_sol_loc(gamma_hat[i], gamma_bar[i], 1, delta_hat[i], t2[i])

            gamma_star.append(temploc)
        
            tempscale = it_sol_scale(s_data[batch_idxs], gamma_hat[i],
                     delta_hat[i], gamma_bar[i], t2[i], a_prior[i], b_prior[i])
        
            y=tempscale[1]
            y[y<100]=1
            delta_star.append(y)
    
    sys.stdout.write("Adjusting data\n")
    bayesdata = s_data
    gamma_star = np.array(gamma_star)
    delta_star = np.array(delta_star)
    
    if ref_batch is not None:
        gamma_star[ref-1]=0
        delta_star[ref-1]=1

    for j, batch_idxs in enumerate(batch_info):

        dsq = np.sqrt(delta_star[j,:])
        dsq = dsq.reshape((len(dsq), 1))
        denom =  np.dot(dsq, np.ones((1, n_batches[j])))
        numer = np.array(bayesdata[batch_idxs] - np.dot(batch_design.ix[batch_idxs], gamma_star).T)

        bayesdata[batch_idxs] = numer / denom
   
    vpsq = np.sqrt(var_pooled).reshape((len(var_pooled), 1))
    bayesdata = bayesdata * np.dot(vpsq, np.ones((1, n_array))) + stand_mean
   
    return bayesdata

In [68]:
#Import data.
pheno = pd.read_table('../Data/Python_Data/bladder-pheno.txt', index_col=0)
dat = pd.read_table('../Data/Python_Data/bladder-expr.txt', index_col=0)

In [69]:
dat.head()

Unnamed: 0,GSM71019.CEL,GSM71020.CEL,GSM71021.CEL,GSM71022.CEL,GSM71023.CEL,GSM71024.CEL,GSM71025.CEL,GSM71026.CEL,GSM71028.CEL,GSM71029.CEL,...,GSM71068.CEL,GSM71069.CEL,GSM71070.CEL,GSM71071.CEL,GSM71072.CEL,GSM71073.CEL,GSM71074.CEL,GSM71075.CEL,GSM71076.CEL,GSM71077.CEL
1007_s_at,10.11517,8.628044,8.779235,9.248569,10.256841,10.023133,9.108034,8.735616,9.803271,10.168602,...,10.582892,10.009028,9.912853,9.096809,9.011927,8.396062,8.903465,9.501538,9.540766,9.039143
1053_at,5.345168,5.063598,5.113116,5.17941,5.181383,5.248418,5.252312,5.220931,5.595771,5.02518,...,5.615926,5.151548,5.237126,5.093278,5.353248,5.214357,5.251383,5.223598,5.191392,5.23588
117_at,6.348024,6.663625,6.465892,6.116422,5.980457,5.796155,6.414849,6.846798,5.841478,6.352257,...,5.913488,5.904794,5.960948,6.394089,6.425034,6.37252,6.095344,5.811968,6.007461,6.314809
121_at,8.901739,9.439977,9.540738,9.254368,8.798086,8.00287,9.093704,9.263386,7.78924,9.834564,...,8.049998,8.407351,8.985741,8.817789,8.866083,8.704385,9.375736,8.580523,8.848099,9.663298
1255_g_at,3.967672,4.466027,4.144885,4.189338,4.078509,3.91974,4.40259,4.173666,3.590649,4.338196,...,3.775194,3.995371,4.32238,4.141255,3.997644,4.21936,4.454771,4.18831,4.284053,4.877523


In [70]:
#Run ComBat using only batch designation for adjustment, adjusting for both
#location and scale.
cmdf = combat(dat, pheno.batch, None, None, False, None)

found 5 batches
found 0 numerical covariates...
found 0 categorical variables:	
Standardizing Data across genes.
Fitting L/S model and finding priors
Finding parametric adjustments
  # This is added back by InteractiveShellApp.init_path()


Adjusting data


.ix is deprecated. Please use
.loc for label based indexing or
.iloc for positional indexing

See the documentation here:
http://pandas.pydata.org/pandas-docs/stable/indexing.html#deprecate_ix


In [71]:
cmdf.head()

Unnamed: 0,GSM71019.CEL,GSM71020.CEL,GSM71021.CEL,GSM71022.CEL,GSM71023.CEL,GSM71024.CEL,GSM71025.CEL,GSM71026.CEL,GSM71028.CEL,GSM71029.CEL,...,GSM71068.CEL,GSM71069.CEL,GSM71070.CEL,GSM71071.CEL,GSM71072.CEL,GSM71073.CEL,GSM71074.CEL,GSM71075.CEL,GSM71076.CEL,GSM71077.CEL
1007_s_at,10.086494,8.593022,8.73554,8.904954,10.279651,9.961009,9.045478,8.694423,9.899763,10.04521,...,10.560488,10.268861,10.140819,9.301039,9.229102,8.707159,9.137181,9.593213,9.645439,8.9776
1053_at,5.519204,5.113332,5.158352,5.26221,5.26527,5.3692,5.284904,5.256373,5.510408,5.078403,...,5.528647,5.30409,5.43777,5.055689,5.290943,5.165257,5.198763,5.416638,5.366329,5.435822
117_at,6.56174,6.432694,6.270199,6.220634,6.020383,5.748941,6.228252,6.583224,5.959971,6.176815,...,6.027851,6.024178,6.104956,6.480889,6.510059,6.460556,6.199277,5.890645,6.171866,6.613992
121_at,8.789979,9.223918,9.316884,9.194759,8.670996,7.758172,8.90443,9.060987,8.070841,9.587982,...,8.305276,8.055463,8.689644,8.995561,9.038979,8.893604,9.497184,8.245339,8.538726,9.43256
1255_g_at,3.899327,4.403982,4.092681,4.221117,4.060228,3.829744,4.342489,4.120581,3.7059,4.280069,...,3.870269,3.807746,4.059236,4.196309,4.068399,4.265876,4.475549,3.956128,4.02976,4.486175
