# Experiment on a simulated data set with 100 taxa


In what follows, we demonstrate the efficiency of the adaptive (multistep) LASSO at recovering sparsity on a fixed phylogenetic tree topology. To reduce the number of parameters, we use the Jukes-Cantor (JC) model of evolution.

In [1]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are reloaded before each cell executes and edits to the local source files
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Put the parent directory on the import path. The project modules imported
# further down (`model`, `optimizer`, `utils`) presumably live there -- confirm.
import sys
sys.path.append('../')

In [3]:
import numpy as np
import phyloinfer as pinf
import matplotlib.pyplot as plt
from time import clock
import pandas as pd
import random
%matplotlib inline

In [4]:
from model import PHY
from optimizer import fista, adaLasso, raxml
from utils import readTree, detection

In [5]:
# Load the reference tree and the simulated alignment.
tree_path = '../data/simulation/true_tree_100tips_final.newick'
fasta_path = '../data/simulation/simulated_100tips_final.fasta'

# readTree returns an indexable result; only its first entry is used here.
true_tree = readTree(tree_path, tree_format=3)[0]
# NOTE(review): presumably assigns names/indices to the interior nodes -- confirm
# against phyloinfer's tree.init.
pinf.tree.init(true_tree, name='interior')
data, taxa = pinf.data.loadData(fasta_path, data_type='fasta')

In [6]:
# Branch lengths read off the true tree, and the positions of the branches
# whose length is exactly zero. `shrunken_idx` is the reference set passed to
# `detection` in the result analysis below.
true_branch = pinf.branch.get(true_tree)
is_zero_branch = (true_branch == 0)
shrunken_idx = np.nonzero(is_zero_branch)[0]

## Set up the model. 

In [7]:
# Four equal weights of 0.25 -- presumably the stationary nucleotide
# frequencies expected by PHY; confirm against model.PHY.
pden = np.full(4, 0.25)

# Decomposition of the JC rate matrix as returned by phyloinfer.
D, U, U_inv, rate_matrix = pinf.rateM.decompJC()

# Model built from the frequencies, the ('JC', 1.0) model spec and the loaded data.
model = PHY(pden, ('JC', 1.0), data)

## Run adaptive LASSO

In [None]:
# set up some hyperparameters
#
# Naming used in this notebook: `gamma_phy` is the value swept over and
# reported as "lambda" in the output; `lam` is forwarded to adaLasso as its
# fourth positional argument. NOTE(review): `lam` is presumably an initial step
# size, since the tuned value adaLasso returns is printed as "step size" --
# confirm against optimizer.adaLasso.
seed = 2018  # fixed so that the random starting point below is reproducible
np.random.seed(seed)
random.seed(seed)

lam = 1e-06
beta = 0.5
gamma_ada_penalized = 1.0
msteps = 4                          # number of adaptive LASSO cycles per fit
n_branches = 2*model.ntips - 3      # length of the branch-length vector used below
wts = np.ones(n_branches)

gamma_phy_list = [0, 10, 20, 30, 40, 50]
brlen_ada_list = []       # per lambda: branch lengths returned by adaLasso
p_ada_list = []           # per lambda: last value of the penalised objective trace
ada_time_phy_list = []    # per lambda: time spent inside adaLasso
n_zeros_ada_list = []     # per lambda: zero-branch counts returned by adaLasso

# set random starting branch lengths
brlen_init = np.random.exponential(scale=0.1, size=n_branches)

for gamma_phy in gamma_phy_list:
    # NOTE(review): time.clock() reports processor time on Unix, so the figure
    # printed below is CPU time rather than wall-clock time there.
    start = clock()
    (brlen_ada_lasso, objval_ll_ada_lasso, objval_lasso_ada_lasso,
     n_zeros_ada_lasso, lam_tuned_ada_lasso) = adaLasso(
        model, true_tree, brlen_init, lam,
        gamma=gamma_phy, beta=beta, prox='l1', msteps=msteps,
        gamma_ada_penalized=gamma_ada_penalized, sparsity_monitor=True)
    # Stop the timer as soon as the solver returns so bookkeeping is not timed.
    ada_time_phy_list.append(clock() - start)

    brlen_ada_list.append(brlen_ada_lasso)
    p_ada_list.append(objval_lasso_ada_lasso[-1])
    n_zeros_ada_list.append(n_zeros_ada_lasso)

    # Single-argument print(...) behaves identically under Python 2 and 3.
    print("\nlambda = {}; step size: {}; elapsed time: {:.04f} second".format(gamma_phy, lam_tuned_ada_lasso, ada_time_phy_list[-1]))

    # Objective traces for this lambda: log-likelihood alone and with the penalty.
    plt.plot(objval_ll_ada_lasso, label="LL")
    plt.plot(objval_lasso_ada_lasso, label="LL+Penalty")
    plt.title("lambda = {}".format(gamma_phy))
    plt.legend(loc=4)
    plt.show()

## Result Analysis

In [9]:
# Regroup the fits from "one entry per lambda" to "one entry per cycle":
# element k collects the k-th item of every adaLasso result, in
# gamma_phy_list order. Materialised as a list so the result can be iterated
# more than once (and measured with len) even where zip() is a one-shot iterator.
brlen_ada_list_transpose = list(zip(*brlen_ada_list))

In [None]:
# miss detection and false alarm for adaLasso and threshold
#
# `res` has one row per (cycle, non-zero lambda) pair, ordered cycle-major:
#   columns 0-3: miss_zeros, false_alarm, lambda, cycle   (adaptive LASSO)
#   columns 4-6: miss_zeros, false_alarm, threshold       (hard thresholding)
# The thresholding baseline does not depend on the cycle, so each of its
# results is repeated once per cycle to line up with the LASSO rows.
brlen_by_cycle = list(brlen_ada_list_transpose)
penalised_lambdas = gamma_phy_list[1:]   # skip the first entry of the sweep (lambda = 0)
n_lam = len(penalised_lambdas)
n_cycles = len(brlen_by_cycle)

# NaN-filled instead of np.empty: a row that is never written stays visibly
# missing rather than holding arbitrary memory contents.
res = np.full((n_lam*n_cycles, 7), np.nan)
for i, brlen_ada_cycle in enumerate(brlen_by_cycle):
    print("cycle: {}\n".format(i+1))
    for j, (gamma_val, brlen_ada) in enumerate(zip(penalised_lambdas, brlen_ada_cycle[1:])):
        miss_zeros, false_alarm = detection(shrunken_idx, brlen_ada)
        res[j + n_lam*i, :4] = [miss_zeros, false_alarm, gamma_val, i+1]
        print("lambda: {}; # miss: {}; # false alarm: {}".format(gamma_val, miss_zeros, false_alarm))
    print("")

print('------------------------------------------------')


# Baseline: hard-threshold the last entry of the first (lambda = 0) fit.
brlen_orginal = brlen_ada_list[0][-1]
# Candidate cut-offs come from the strictly positive branch lengths below 0.01.
# np.percentile does not need its input sorted.
thresholds = [brlen for brlen in brlen_orginal if brlen < 0.01 and brlen > 0.0]
# One more grid point than thresholds needed: the 0th percentile is skipped,
# which leaves exactly one threshold per non-zero lambda.
thr_nbin = n_lam + 1
thresholds_quant = np.percentile(thresholds, q=np.linspace(0, 100, thr_nbin))
for j, quant in enumerate(thresholds_quant[1:]):
    brlen_threshold = np.copy(brlen_orginal)
    brlen_threshold[brlen_threshold < quant] = 0.0
    miss_zeros, false_alarm = detection(shrunken_idx, brlen_threshold)
    for i in range(n_cycles):
        res[j + n_lam*i, 4:] = [miss_zeros, false_alarm, quant]
    print("threshold: {:.04f}; # miss: {}; # false alarm {}".format(quant, miss_zeros, false_alarm))

In [11]:
# Wrap the result table in a DataFrame. The first four columns are the
# adaptive LASSO results, the last three the thresholding baseline.
result_columns = [
    'miss_lasso', 'false_lasso', 'lambda', 'cycle',
    'miss_detection', 'false_alarm', 'threshold',
]
df = pd.DataFrame(data=res, columns=result_columns)

In [12]:
# One shade of red per cycle, keyed by the cycle number as a float
# (1.0, 2.0, ...) so it can be looked up with the 'cycle' group keys.
red_color_bar = ['#fdcab5', '#fc8a6a', '#f14432', '#bc141a']
colors = {float(cycle): shade for cycle, shade in enumerate(red_color_bar, 1)}

# Rows of the result table split by cycle, for per-cycle plotting.
grouped = df.groupby('cycle')

In [None]:
# miss detection vs false alarm

fig, ax = plt.subplots(figsize=(8, 8))

# Thresholding baseline: blue squares.
df.plot.scatter(x='miss_detection', y='false_alarm', color='blue', marker='s', s=40, label="thresholding", ax=ax)

# Adaptive LASSO: stars, one colour per cycle.
for cycle, cycle_rows in grouped:
    cycle_label = "adaLASSO, cycle {}".format(int(cycle))
    cycle_rows.plot.scatter(x='miss_lasso', y='false_lasso', color=colors[cycle], marker='*', s=100, label=cycle_label, ax=ax)

# Axis cosmetics are applied last because every pandas scatter call resets the labels.
ax.legend(loc='best', fontsize=16)
ax.set_xlabel('miss detection', fontsize=18)
ax.set_ylabel('false alarm', fontsize=18)
ax.tick_params(axis='both', labelsize=14)
ax.set_ylim(bottom=-1, top=15)

plt.show()

In [None]:
# identified zero branches for multistep adaptive phylogenetic LASSO of different cycles.

cycle_ticks = [1, 2, 3, 4]

fig, ax = plt.subplots(figsize=(8, 8))

# One dashed curve per non-zero lambda (entries 1..5 of the sweep):
# number of zero branches recorded at each cycle.
for idx in range(1, 6):
    curve_label = r'$\lambda={}$'.format(gamma_phy_list[idx])
    ax.plot(cycle_ticks, n_zeros_ada_list[idx], '^--', markersize=10, label=curve_label, alpha=0.5)

ax.set_xlabel('cycles', fontsize=18)
ax.set_ylabel('# zero branches', fontsize=18)
ax.set_xticks(cycle_ticks)
ax.set_xticklabels(cycle_ticks)
ax.tick_params(axis='both', labelsize=14)
ax.legend(loc='best', fontsize=16)

plt.show()