In [9]:
import numpy as np
from scipy.stats import dirichlet, beta, uniform, gamma, bernoulli
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import normalize

In [1]:
## helper function to make the plot nice
def reorder(q):
    """
    Compute a row ordering of q that groups individuals by dominant component.

    q: (N, K) array; q[i, :] is the admixture proportion of individual i.

    Individuals are grouped by the column holding their largest entry
    (first column wins on ties, as with np.argmax). Groups are emitted in
    column order 0..K-1, and inside each group individuals are sorted by
    decreasing weight on that column.

    Returns a 1-D integer array of row indices into q.
    """
    _, K = q.shape
    # One argmax per row, computed once instead of once per (row, column) pair.
    dominant = np.argmax(q, axis=1)
    groups = []
    for k in range(K):
        # flatnonzero always yields integer indices, so a component that is
        # nobody's argmax gives an empty *integer* array that is still a
        # valid index (an empty float array would raise IndexError).
        members = np.flatnonzero(dominant == k)
        by_weight = np.argsort(q[members, k])[::-1]
        groups.append(members[by_weight])
    return np.hstack(groups)

In [10]:
def simulate_data(N=400, L=800, Q0=[.25,.25,.25,.25], η=0, r=.025):
    """
    simulate haploid data with the scheme in HDPStructure 2019

    N: number of individuals
    L: number of loci
    Q0: global admixture proportion; used as the Dirichlet parameter from
        which each individual's admixture proportion is drawn
    r: recombination rate, will sample recombination hot spots (linkage)
       by choosing randomly (L * r) points from 1 to L-1 (capped at L-1)
    η: noise; each observed allele is replaced by a fair coin flip with
       probability η

    Returns (X, q):
      X: (N, L) float array of 0/1 alleles
      q: (N, K) array, q[i,:] is admixture proportion of individual i
    """
    K = len(Q0)
    q = dirichlet(Q0).rvs(size=N)
    ## q[i,:] is admixture proportion of individual i

    ## hot spots live on sites 1..L-1: site 0 has no predecessor, so a hot
    ## spot there would never be used. The count is capped at the L-1
    ## available sites because sampling is without replacement.
    sites = np.arange(L)
    num_hotspot = min(int(L * r), max(L - 1, 0))
    HOT = np.random.choice(a = sites[1:], size = num_hotspot, replace = False)
    is_hot = np.zeros(L, dtype=bool)
    is_hot[HOT] = True

    ## sample z
    ## At every site the ancestry is copied from the previous site with
    ## probability p_stay, and redrawn from q[i,:] otherwise.
    ## NOTE(review): hot spots keep the previous state with probability 0.99,
    ## i.e. they switch *less* often than ordinary sites -- confirm this
    ## matches the intended scheme.
    ## NOTE(review): scipy's uniform(loc, scale) is supported on
    ## [loc, loc + scale], so lamb ranges over [0.01, 0.51] -- confirm.
    z = np.zeros((N, L), dtype=int)
    for i in range(N):
        fresh = np.random.choice(a = K, size = L, p = q[i, :])
        lamb = uniform(0.01, 0.5).rvs(size=L)
        p_stay = np.where(is_hot, 0.99, 1 - lamb)
        stay = uniform().rvs(size=L) < p_stay
        stay[0] = False  ## the first site always takes a fresh draw
        ## index of the most recent site (<= l) at which a fresh draw was taken
        last_switch = np.maximum.accumulate(np.where(stay, 0, sites))
        z[i, :] = fresh[last_switch]
    h = bernoulli(.25).rvs(size=(K, L))  ## latent haplotype

    ## sample X
    clean = h[z, sites]  ## clean[i,l] = h[z[i,l], l]
    noisy = uniform().rvs(size=(N, L)) < η
    coin = bernoulli(0.5).rvs(size=(N, L))
    X = np.where(noisy, coin, clean).astype(float)

    return X, q

In [3]:
def bar_plot(q, sorted_ind, colors, switch_label=None):
    """
    Draw a stacked bar plot of admixture proportions, one bar per individual.

    q: (N, K) array; row i holds the proportions of individual i
    sorted_ind: row order used along the x axis (indexes the rows of q)
    colors: colors[i] is the colour of the i-th plotted component
    switch_label: optional column order applied to q after the row sort

    Returns the matplotlib Axes the bars were drawn on.
    """
    q_sorted = q[sorted_ind,:]
    # `is not None` rather than `!= None`: comparing a numpy array with
    # `!=` is elementwise and its truth value is ambiguous.
    if switch_label is not None:
        q_sorted = q_sorted[:, switch_label]
    ax = plt.axes()
    # Number of bars comes from the data itself, not from a global N.
    num_bars = q_sorted.shape[0]
    xdata = np.arange(num_bars)
    bottom = np.zeros(num_bars, dtype=float)
    for i, ydata in enumerate(q_sorted.T):
        # each component is stacked on top of the ones already drawn
        ax.bar(xdata, ydata, bottom=bottom, color=colors[i], width=1.0, linewidth=0)
        bottom += ydata
    return ax

In [9]:
def several_bar_plot(qs, sorted_ind, colors, figsize, titles):
    """
    Draw one stacked bar plot per matrix in qs, stacked vertically in one figure.

    qs: sequence of (N, K) arrays of admixture proportions
    sorted_ind: row order applied to every matrix in qs
    colors: colors[i] is the colour of the i-th component
    figsize: figure size passed to plt.subplots
    titles: titles[j] is the title of the j-th panel
    """
    J = len(qs)
    # squeeze=False keeps the axes in an array even when J == 1, so that
    # indexing by panel number works for a single matrix as well.
    fig, axes = plt.subplots(J, figsize=figsize, squeeze=False)
    axes = axes[:, 0]
    for j in range(J):
        q_sorted = qs[j][sorted_ind,:]
        # Number of bars comes from the data itself, not from a global N.
        num_bars = q_sorted.shape[0]
        xdata = np.arange(num_bars)
        bottom = np.zeros(num_bars, dtype=float)
        for i, ydata in enumerate(q_sorted.T):
            # each component is stacked on top of the ones already drawn
            axes[j].bar(xdata, ydata, bottom=bottom, color=colors[i], width=1.0, linewidth=0)
            bottom += ydata
        axes[j].set_title(titles[j])