In [2]:
import numpy as np
import scipy as sp
import pandas as pd

In [24]:
import numpy as np
import scipy as sp
import scipy.special as special
import scipy.stats as stats
import pandas as pd


def fisher_exact_test(df, label, col_c, col_t):
    """Return the hypergeometric probability of the observed 2x2 table.

    The value computed is

        C(sum_c, x_c) * C(sum_t, x_t) / C(sum_c + sum_t, x_c + x_t)

    i.e. the probability of exactly the observed table given its margins.
    It is NOT summed over more extreme tables, so it is a point probability
    rather than a tail p-value (the name ``pvalue`` is kept for the printed
    output and for backward compatibility).

    Args:
        df: table with one row per outcome; ``df[label] == 1`` marks the
            event row.
        label: name of the outcome indicator column.
        col_c: name of the control-group count column.
        col_t: name of the treatment-group count column.

    Returns:
        The point probability as a float.

    Raises:
        IndexError: if no row has ``df[label] == 1``.
    """
    group_cols = [col_c, col_t]

    totals = df[group_cols].sum()
    sum_c = totals[col_c]
    sum_t = totals[col_t]

    # Pick the event row by position: a label-based ``[0]`` lookup only works
    # when the event row happens to carry index label 0.
    events = df.loc[df[label] == 1, group_cols].iloc[0]
    x_c = events[col_c]
    x_t = events[col_t]

    factor1 = special.comb(sum_c, x_c)
    factor2 = special.comb(sum_t, x_t)
    denominator = special.comb(sum_c + sum_t, x_c + x_t)
    pvalue = factor1 * factor2 / denominator
    print('pvalue: {0}'.format(pvalue))
    return pvalue


def gen_data(xc=12, xt=7, totalc=15, totalt=15):
    """Build a 2x2 outcome table as a DataFrame.

    Row 0 is the event row (``infection == 1``) and row 1 the non-event row
    (``infection == 0``).

    Args:
        xc: number of events in the control group.
        xt: number of events in the treatment group.
        totalc: control group size.
        totalt: treatment group size.

    Returns:
        DataFrame with columns ``infection``, ``control``, ``treatment``.
    """
    control_counts = [xc, totalc - xc]
    treatment_counts = [xt, totalt - xt]
    return pd.DataFrame(data={
        'infection': [1, 0],
        'control': control_counts,
        'treatment': treatment_counts,
    })


def fisher_exact_test_scipy(df, label, col_c, col_t):
    """Run ``scipy.stats.fisher_exact`` on the table three times.

    The contingency table is the ``[col_c, col_t]`` slice of ``df`` in row
    order. The test is run with SciPy's default alternative, then with
    ``'less'``, then with ``'greater'``; each result is printed.

    Args:
        df: table holding the count columns.
        label: unused; kept so the signature matches the sibling tests.
        col_c: name of the control-group count column.
        col_t: name of the treatment-group count column.

    Returns:
        The result of the last call (``alternative='greater'``), as returned
        by ``stats.fisher_exact``.
    """
    table = df[[col_c, col_t]].values

    print('xs: {0}'.format(table))
    print('xs[0, 1]: {0}'.format(table[0, 1]))

    outcome = None
    # First pass uses SciPy's default alternative, exactly as a bare call.
    for extra in ({}, {'alternative': 'less'}, {'alternative': 'greater'}):
        outcome = stats.fisher_exact(table, **extra)
        print('pvalue: {0}'.format(outcome))
    return outcome


def fisher_exact_test_scipy_mine(df, label, col_c, col_t):
    """Compute a lower-tail hypergeometric probability for the treatment count.

    Evaluates ``stats.hypergeom.cdf(x_t, M, n, N)`` where ``x_t`` is the
    treatment count in the ``df[label] == 1`` row and ``M, N, n`` come from
    ``_get_M_N_n``. The parameters and the result are printed.

    Args:
        df: table with the outcome indicator and the two count columns.
        label: name of the outcome indicator column.
        col_c: name of the control-group count column.
        col_t: name of the treatment-group count column.

    Returns:
        The cumulative probability returned by ``stats.hypergeom.cdf``.
    """
    observed_t = df.loc[df[label] == 1, col_t][0]
    # N, n, K
    population, draws, successes = _get_M_N_n(df, label, col_c, col_t)
    print('M, N, n: {0}, {1}, {2}'.format(population, draws, successes))
    tail = stats.hypergeom.cdf(observed_t, population, successes, draws)
    print('pvalue: {0}'.format(tail))
    return tail


def fisher_exact_test_k_alpha(df, label, col_c, col_t, alpha):
    """Find the largest count whose lower-tail probability is below ``alpha``.

    For ``i`` in ``0..N`` the value ``stats.hypergeom.cdf(i, M, n, N)`` is
    compared with ``alpha``; the largest ``i`` with a value strictly below
    ``alpha`` is returned. When no ``i`` qualifies, ``N`` itself is returned.

    Args:
        df: table with the outcome indicator and the two count columns.
        label: name of the outcome indicator column.
        col_c: name of the control-group count column.
        col_t: name of the treatment-group count column.
        alpha: threshold the cumulative probability is compared against.

    Returns:
        The selected count ``k_alpha``.
    """
    # N, n, K
    M, N, n = _get_M_N_n(df, label, col_c, col_t)
    print('M, N, n: {0}, {1}, {2}'.format(M, N, n))

    # Walk downwards so the first hit is the largest qualifying count;
    # the fallback N covers the case where nothing is below alpha.
    k_alpha = next(
        (i for i in range(N, -1, -1)
         if stats.hypergeom.cdf(i, M, n, N) < alpha),
        N,
    )
    print('k_alpha: {0}'.format(k_alpha))
    return k_alpha


def _get_M_N_n(df, label, col_c, col_t):
    """Derive the three hypergeometric parameters from a 2x2 outcome table.

    Args:
        df: table with one row per outcome; ``df[label] == 1`` marks the
            event row.
        label: name of the outcome indicator column.
        col_c: name of the control-group count column.
        col_t: name of the treatment-group count column.

    Returns:
        Tuple ``(M, N, n)`` where
        ``M`` is the total of both count columns,
        ``N`` is the total of ``col_t`` and
        ``n`` is the number of events summed over both columns.

    Raises:
        IndexError: if no row has ``df[label] == 1``.
    """
    group_cols = [col_c, col_t]

    totals = df[group_cols].sum()
    # Take the event row by position; a label-based ``[0]`` lookup fails
    # whenever the event row does not carry index label 0.
    events = df.loc[df[label] == 1, group_cols].iloc[0]

    # N
    M = totals[col_c] + totals[col_t]
    # n
    N = totals[col_t]
    # K
    n = events[col_c] + events[col_t]
    return M, N, n


def wald_statistics(df, label, col_c, col_t):
    """Return the Wald statistic for the difference of two proportions.

    The statistic is ``(pi_c - pi_t) / se`` with the unpooled standard error
    ``se = sqrt(pi_c (1 - pi_c) / sum_c + pi_t (1 - pi_t) / sum_t)``, where
    ``pi_c`` and ``pi_t`` are the event proportions in each group.
    Intermediate values are printed.

    Args:
        df: table with one row per outcome; ``df[label] == 1`` marks the
            event row.
        label: name of the outcome indicator column.
        col_c: name of the control-group count column.
        col_t: name of the treatment-group count column.

    Returns:
        ``numerator / denominator`` -- the same value that is printed.

    Raises:
        IndexError: if no row has ``df[label] == 1``.
    """
    group_cols = [col_c, col_t]

    totals = df[group_cols].sum()
    sum_c = totals[col_c]
    sum_t = totals[col_t]

    # Read the event counts directly and by position, instead of rebuilding
    # them from proportions and relying on the event row having index 0.
    events = df.loc[df[label] == 1, group_cols].iloc[0]
    x_c = events[col_c]
    x_t = events[col_t]

    pi = (x_c + x_t) / (sum_c + sum_t)
    print('pi: {0}'.format(pi))
    pi_c = x_c / sum_c
    pi_t = x_t / sum_t
    print('pi_c: {0}'.format(pi_c))
    print('pi_t: {0}'.format(pi_t))
    denominator = np.sqrt(pi_c * (1.0 - pi_c) / sum_c + pi_t * (1.0 - pi_t) / sum_t)
    numerator = (pi_c - pi_t)
    statistic = numerator / denominator
    print('denominator: {0}'.format(denominator))
    print('numerator: {0}'.format(numerator))
    print('numerator / denominator: {0}'.format(statistic))
    # The statistic is difference / standard error, not its reciprocal.
    return statistic


def score_statistics(df, label, col_c, col_t):
    """Return the score statistic for the difference of two proportions.

    The statistic is ``(pi_c - pi_t) / se`` with the pooled standard error
    ``se = sqrt(pi (1 - pi) (1 / sum_c + 1 / sum_t))``, where ``pi`` is the
    event proportion over both groups combined.

    Args:
        df: table with one row per outcome; ``df[label] == 1`` marks the
            event row.
        label: name of the outcome indicator column.
        col_c: name of the control-group count column.
        col_t: name of the treatment-group count column.

    Returns:
        The score statistic. When the pooled proportion is 0 or 1 the
        standard error is 0 and the division is 0 / 0.

    Raises:
        IndexError: if no row has ``df[label] == 1``.
    """
    group_cols = [col_c, col_t]

    totals = df[group_cols].sum()
    sum_c = totals[col_c]
    sum_t = totals[col_t]

    # Read the event counts directly and by position, instead of rebuilding
    # them from proportions and relying on the event row having index 0.
    events = df.loc[df[label] == 1, group_cols].iloc[0]
    x_c = events[col_c]
    x_t = events[col_t]

    pi = (x_c + x_t) / (sum_c + sum_t)
    pi_c = x_c / sum_c
    pi_t = x_t / sum_t
    denominator = np.sqrt(pi * (1.0 - pi) * (1.0 / sum_c + 1.0 / sum_t))
    numerator = (pi_c - pi_t)
    return numerator / denominator


def barnard_test_pdf(df, label, col_c, col_t, pi):
    """Return the probability of the table under a common event rate ``pi``.

    Computes the product of two binomial probabilities that share the same
    success probability:

        C(sum_c, x_c) * C(sum_t, x_t) * pi**(x_c + x_t)
            * (1 - pi)**(sum_c + sum_t - x_c - x_t)

    Each factor and the result are printed.

    Args:
        df: table with one row per outcome; ``df[label] == 1`` marks the
            event row.
        label: name of the outcome indicator column.
        col_c: name of the control-group count column.
        col_t: name of the treatment-group count column.
        pi: common event probability used for both groups.

    Returns:
        The probability of the table as a float.

    Raises:
        IndexError: if no row has ``df[label] == 1``.
    """
    group_cols = [col_c, col_t]

    totals = df[group_cols].sum()
    sum_c = totals[col_c]
    sum_t = totals[col_t]

    # Use the stored counts themselves. Recovering them as proportion * total
    # passes floats with possible rounding error to ``special.comb``.
    events = df.loc[df[label] == 1, group_cols].iloc[0]
    x_c = events[col_c]
    x_t = events[col_t]

    factor_c = special.comb(sum_c, x_c)
    factor_t = special.comb(sum_t, x_t)
    factor1 = pi ** (x_c + x_t)
    factor2 = (1.0 - pi) ** (sum_c + sum_t - x_c - x_t)
    print('factor_c: {0}'.format(factor_c))
    print('factor_t: {0}'.format(factor_t))
    print('factor1: {0}'.format(factor1))
    print('factor2: {0}'.format(factor2))
    pdf = factor_c * factor_t * factor1 * factor2
    print('pdf: {0}'.format(pdf))
    return pdf


def barnard_test(df, label, col_c, col_t, pi):
    """Sum table probabilities over all tables at least as extreme as ``df``.

    Every table with ``0..sum_c`` events in ``col_c`` and ``0..sum_t`` events
    in ``col_t`` is enumerated. A table counts as "at least as extreme" when
    its ``score_statistics`` value is greater than or equal to the observed
    one; for those tables ``barnard_test_pdf`` (evaluated at ``pi``) is
    accumulated. Progress is printed for every table that is included.

    Args:
        df: observed table; ``df[label] == 1`` marks the event row.
        label: name of the outcome indicator column.
        col_c: name of the control-group count column.
        col_t: name of the treatment-group count column.
        pi: common event probability passed to ``barnard_test_pdf``.

    Returns:
        The accumulated probability as a float.
    """
    group_cols = [col_c, col_t]
    totals = df[group_cols].sum()
    sum_c = totals[col_c]
    sum_t = totals[col_t]

    statistics = score_statistics(df, label, col_c, col_t)
    print('statistics: {0}'.format(statistics))
    print('sum_t: {0}'.format(sum_t))
    print('sum_c: {0}'.format(sum_c))
    summand = 0.0
    for xc in range(sum_c + 1):
        for xt in range(sum_t + 1):
            # Build the candidate table with the observed group sizes and
            # the caller's column names, rather than fixed 15/15 totals and
            # fixed 'infection'/'control'/'treatment' names.
            df_new = pd.DataFrame(data={
                label: [1, 0],
                col_c: [xc, sum_c - xc],
                col_t: [xt, sum_t - xt],
            })
            statistics_new = score_statistics(df_new, label, col_c, col_t)
            # NOTE(review): tables with no events or only events appear to
            # give a 0 / 0 statistic; a nan compares False and is skipped.
            if statistics <= statistics_new:
                summand += barnard_test_pdf(df_new, label, col_c, col_t, pi)
                print('xc: {0}, xt: {1}'.format(xc, xt))
                print('statistics_new: {0}'.format(statistics_new))
                print('summand: {0}'.format(summand))

    return summand


# Driver: run every test on the default table from gen_data()
# (12 of 15 events in 'control', 7 of 15 events in 'treatment').
df = gen_data()
columns = ['control', 'treatment']
# Large-sample statistics for the difference in event proportions.
wald_statistics(df, 'infection', columns[0], columns[1])
score_statistics(df, 'infection', columns[0], columns[1])
# Hand-rolled hypergeometric probability, then SciPy's fisher_exact.
fisher_exact_test(df, 'infection', columns[0], columns[1])
fisher_exact_test_scipy(df, 'infection', columns[0], columns[1])
# Hypergeometric CDF computed directly, and the count threshold at alpha = 0.05.
fisher_exact_test_scipy_mine(df, 'infection', columns[0], columns[1])
fisher_exact_test_k_alpha(df, 'infection', columns[0], columns[1], 0.05)
# barnard_test evaluated at two values of pi; as the last bare expression,
# the second call's return value is what a notebook cell displays.
barnard_test(df, 'infection', columns[0], columns[1], 0.3365)
barnard_test(df, 'infection', columns[0], columns[1], 0.5)


pi: 0.6333333333333333
pi_c: 0.8
pi_t: 0.4666666666666667
denominator: 0.16510378329783743
numerator: 0.33333333333333337
numerator / denominator: 2.018932132718121
pvalue: 0.053598200899550225
xs: [[12  7]
 [ 3  8]]
xs[0, 1]: 7
pvalue: (4.571428571428571, 0.128135932033983)
pvalue: (4.571428571428571, 0.989530234882559)
pvalue: (4.571428571428571, 0.0640679660169915)
M, N, n: 30, 15, 19
pvalue: 0.0640679660169915
M, N, n: 30, 15, 19
k_alpha: 6
statistics: 1.8943380760602064
sum_t: 15
sum_c: 15




factor_c: 1365.0
factor_t: 1.0
factor1: 0.012821542440062504
factor2: 2.3327280264884053e-05
pdf: 0.00040826018951096017
xc: 4, xt: 0
statistics_new: 2.1483446221182985
summand: 0.00040826018951096017
factor_c: 3003.0
factor_t: 1.0
factor1: 0.004314449031081032
factor2: 3.515792051979511e-05
pdf: 0.00045551622953272626
xc: 5, xt: 0
statistics_new: 2.449489742783178
summand: 0.0008637764190436864
factor_c: 5005.0
factor_t: 1.0
factor1: 0.0014518120989587676
factor2: 5.29885765181539e-05
pdf: 0.00038503192976076966
xc: 6, xt: 0
statistics_new: 2.73861278752583
summand: 0.001248808348804456
factor_c: 5005.0
factor_t: 15.0
factor1: 0.0004885347712996253
factor2: 7.986221027604204e-05
pdf: 0.0029290861574491106
xc: 6, xt: 1
statistics_new: 2.158329236508578
summand: 0.0041778945062535666
factor_c: 6435.0
factor_t: 1.0
factor1: 0.0004885347712996253
factor2: 7.986221027604204e-05
pdf: 0.0002510645277813523
xc: 7, xt: 0
statistics_new: 3.0216609311120095
summand: 0.004428959034034919
factor_c

factor_c: 15.0
factor_t: 3003.0
factor1: 1.0297132535765207e-09
factor2: 0.010971100646640105
pdf: 5.088773173459239e-07
xc: 14, xt: 5
statistics_new: 3.4098085369083715
summand: 0.026550748275590633
factor_c: 15.0
factor_t: 5005.0
factor1: 3.4649850982849927e-10
factor2: 0.016535193137362632
pdf: 4.3013619012033014e-07
xc: 14, xt: 6
statistics_new: 3.0983866769659336
summand: 0.026551178411780754
factor_c: 15.0
factor_t: 6435.0
factor1: 1.1659674855729e-10
factor2: 0.024921165240938403
pdf: 2.8047528297543755e-07
xc: 14, xt: 7
statistics_new: 2.788866755113585
summand: 0.026551458887063728
factor_c: 15.0
factor_t: 6435.0
factor1: 3.923480588952809e-11
factor2: 0.03756015861482804
pdf: 1.4224556551806293e-07
xc: 14, xt: 8
statistics_new: 2.4771684715343114
summand: 0.026551601132629245
factor_c: 15.0
factor_t: 5005.0
factor1: 1.3202512181826203e-11
factor2: 0.05660913129589758
pdf: 5.610975962116676e-08
xc: 14, xt: 9
statistics_new: 2.158329236508578
summand: 0.026551657242388865
facto

factor_c: 15.0
factor_t: 15.0
factor1: 3.0517578125e-05
factor2: 3.0517578125e-05
pdf: 2.0954757928848267e-07
xc: 14, xt: 1
statistics_new: 4.74692883171144
summand: 0.025101877748966217
factor_c: 15.0
factor_t: 105.0
factor1: 1.52587890625e-05
factor2: 6.103515625e-05
pdf: 1.4668330550193787e-06
xc: 14, xt: 2
statistics_new: 4.3915503282684
summand: 0.025103344582021236
factor_c: 15.0
factor_t: 455.0
factor1: 7.62939453125e-06
factor2: 0.0001220703125
pdf: 6.356276571750641e-06
xc: 14, xt: 3
statistics_new: 4.052818694009867
summand: 0.025109700858592987
factor_c: 15.0
factor_t: 1365.0
factor1: 3.814697265625e-06
factor2: 0.000244140625
pdf: 1.9068829715251923e-05
xc: 14, xt: 4
statistics_new: 3.72677996249965
summand: 0.02512876968830824
factor_c: 15.0
factor_t: 3003.0
factor1: 1.9073486328125e-06
factor2: 0.00048828125
pdf: 4.195142537355423e-05
xc: 14, xt: 5
statistics_new: 3.4098085369083715
summand: 0.025170721113681793
factor_c: 15.0
factor_t: 5005.0
factor1: 9.5367431640625e-07

0.025520332157611847