In [1]:
#!/usr/bin/env python
# coding: utf-8
from __future__ import print_function
from collections import defaultdict
import pandas as pd
import time
import numpy as np


########################################################################
## helper functions
#######################################################################


def read_edges(edgefile):
    """Read an edge list from a tab-separated text file.

    Every non-blank line holds the integer node ids of one edge, separated by
    tabs. The ids of each edge are sorted ascending so an undirected edge
    always has the same tuple representation.

    Args:
        edgefile: path of the edge file to read.

    Returns:
        A list of tuples of ints, one tuple per edge, in file order.

    Raises:
        ValueError: if a field on a non-blank line is not an integer.
    """
    ans = []
    # `with` guarantees the handle is released even when parsing fails.
    with open(edgefile) as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. an extra newline at end of file).
                continue
            ans.append(tuple(sorted(int(x) for x in line.split("\t"))))
    return ans


def get_nodes(edges):
    """Return the distinct node ids that occur in `edges`, sorted ascending.

    Args:
        edges: iterable of edges, each an iterable of node ids.

    Returns:
        A sorted list with every node id appearing exactly once.
    """
    return sorted({node for edge in edges for node in edge})


def get_inhs(nodes, df, inh_index=None):
    """Collect the inhibitor value of every node in `nodes` from an open data file.

    The file is scanned once, front to back, with a pointer into `nodes`; a row
    is taken when its first column equals the node the pointer is on. This only
    finds every node if rows appear in the same order as `nodes`
    (NOTE(review): presumably both are sorted ascending by id -- confirm).

    Args:
        nodes: list of integer node ids to look up.
        df: open, iterable file-like object of tab-separated rows whose first
            column is the integer row id. It is always closed before returning.
        inh_index: column index of the inhibitor value. When omitted, the
            module-level `inh_ind` is used, as before.

    Returns:
        A list of floats, one per node, in the order of `nodes`.

    Raises:
        AssertionError: if some node was not found in the file.
        ValueError: if a row id or inhibitor value cannot be parsed.
    """
    if inh_index is None:
        # Backward-compatible fallback: the caller sets this global.
        inh_index = inh_ind
    inh_vals = []
    nodeptr = 0
    l = len(nodes)
    try:
        for line in df:
            if nodeptr >= l:
                # Every requested node has been found; no need to read on.
                break
            fields = line.split('\t')
            row_id = int(fields[0])  # data specific: the first column is the row index
            if row_id == nodes[nodeptr]:
                nodeptr += 1
                inh_vals.append(float(fields[inh_index]))
    finally:
        # Close the handle on every path, including parse and lookup failures.
        df.close()
    assert l == len(inh_vals)
    return inh_vals


def binarize(p):
    """Map every node of path `p` to its entry in the module-level `node_inhclass`.

    Args:
        p: iterable of node keys present in `node_inhclass`.

    Returns:
        A list with the `node_inhclass` value of each node, in path order.
    """
    labels = []
    for node in p:
        labels.append(node_inhclass[node])
    return labels


def last_continuous_ind(p, num):
    """Return the index of the last element of the leading run of `num` in `p`.

    Returns -1 when `p` is empty or does not start with `num`.
    """
    run_length = 0
    for item in p:
        if not item == num:
            break
        run_length += 1
    return run_length - 1


def first_ind(p, num):
    """Return the index of the first occurrence of `num` in `p`.

    Delegates to `p.index`, so a missing value raises whatever `p.index`
    raises (ValueError for lists).
    """
    position = p.index(num)
    return position


def monotone_increase(p):
    """Return True when no element of `p` is smaller than its predecessor.

    Empty and single-element sequences count as monotone.
    """
    # Pair each element with its successor instead of indexing by position.
    steps = [following >= current for current, following in zip(p, p[1:])]
    return all(steps)


def monotone_decrease(p):
    """Return True when no element of `p` is larger than its predecessor.

    Empty and single-element sequences count as monotone.
    """
    # Pair each element with its successor instead of indexing by position.
    steps = [following <= current for current, following in zip(p, p[1:])]
    return all(steps)


def percent(x, t):
    """Format `x` as a percentage of `t` with two decimals.

    A zero total yields "0.00" instead of dividing by zero.
    """
    if t == 0:
        share = 0
    else:
        share = x * 100.0 / t
    return "{0:.2f}".format(share)


def analyze(d):
    """Print a table of per-group category percentages built from a count dict.

    Keys of `d` are strings of the form "<num>_<label>": the first
    "_"-separated token is parsed as an int (`num`) and the remaining tokens,
    re-joined with "_", form the category label. Values are coerced with int().

    Counts whose last token is "inc", "dec" or "monotone" (so "not_monotone"
    qualifies too) are summed per `num`; every other count is added to one
    grand total. The two sums are asserted to be equal. For each `num` from 0
    to the largest one seen, every category in `all_keys` is then shown as a
    percentage of that group's total.

    Reads the module-level name `inh` for the printed headline.

    Raises:
        AssertionError: if the two sums disagree.
        ValueError: from max() when no key ended in "inc", "dec" or "monotone".
    """
    total = 0
    totals = defaultdict(int)
    # saperates[num] holds the label -> count mapping of group `num`.
    saperates = []
    monotones = ["inc", "dec", "monotone"]
    for k, v in d.items():
        v = int(v)
        tokens = k.split("_")
        num, meaning = int(tokens[0]), tokens[-1]
        key = "_".join(tokens[1:])
        # Grow the list so that index `num` exists.
        while len(saperates) < num + 1:
            saperates.append(defaultdict(int))
        saperates[num][key] = v
        if meaning not in monotones:
            total += v
        else:
            totals[num] += v
    # Both partitions of the counts are expected to cover the same paths.
    assert sum(totals.values()) == total
    print(inh, " total paths analyzed:", total)
    all_keys = ["inc", "dec", "not_monotone", 'all_zeros', 'all_ones', "zero_one", "one_zero", "spikes"]
    df = pd.DataFrame(columns=all_keys + ["total"], index=range(max(totals.keys()) + 1))

    for i in range(max(totals.keys()) + 1):
        # NOTE: rebinding `d` shadows the function argument from here on.
        d = saperates[i]
        t = totals[i]
        # Missing categories read as 0 because each group is a defaultdict(int).
        vals = [percent(d[k], t) for k in all_keys] + [t]
        df.iloc[i] = vals
    print(df)

def var(s):
    """Return `np.var(s)` rounded to four decimal places."""
    variance = np.var(s)
    return round(variance, 4)

In [1]:
########################################################################
## global variables
#######################################################################
inh = 'TPV'  # name of the inhibitor column to analyse
# Alternative setting: vecstring = 'dist_vec'
vecstring = 'count_vec'
num_nbhrs = 400
resistance_level = 3  # inhibitor values strictly above this count as "above threshold"
# Nodes with <= this number of mutations are selected as roots for spanning trees.
root_threshold = 2

# NOTE(review): absolute, machine-specific location of the split data.
path = '/home/dshah8/Documents/Summer19/Harrison/data_ten_chunks/'  # where split data lies
vecname = 'distvec' if vecstring == 'dist_vec' else 'countvec'
ff = vecname + "_data/"
# Where the spanning trees for the splits live, filtered by inhibitor != NA.
folder = ff + str(inh) + '_random_' + vecname + '_spanning_trees/'
# Where the shortest paths are stored.
spfolder = str(inh) + '_random_shortest_paths/'


def get_split_data_file(splitnum):
    """Return the path of the raw data file of split `splitnum` under `path`."""
    filename = 'PI_DataSet_6_19_random_split_' + str(splitnum) + '.txt'
    return path + filename

def get_shortest_path_folder():
    """Return the directory of the shortest-path files: `path` + `folder` + `spfolder`."""
    base_dir = path + folder
    return base_dir + spfolder


def get_shortest_paths_file(splitnum):
    """Return the path of the shortest-paths file that belongs to split `splitnum`.

    The file name is derived from the spanning-tree edge file of the same
    split: its base name up to the first "." in the full path, followed by a
    suffix that encodes `root_threshold`.
    """
    edge_path = get_spanning_trees_file(splitnum)
    # Everything before the first "." of the whole path, then the last component.
    stem = edge_path.split(".")[0]
    base_name = stem.split("/")[-1]
    suffix = "_upto" + str(root_threshold) + "mutsroots_shortestpaths_to_leaves.txt"
    return get_shortest_path_folder() + (base_name + suffix)


def get_spanning_trees_file(splitnum):
    """Return the path of the spanning-tree edge file of split `splitnum`.

    NOTE(review): the name always contains 'nn_dist_vec_', whatever `vecname`
    is set to -- confirm this matches the files on disk for count vectors.
    """
    split_part = 'PI_DataSet_6_19_random_split_' + str(splitnum)
    neighbour_part = '_' + str(num_nbhrs - 1) + 'nn_dist_vec_'
    inhibitor_part = str(inh) + 'filtered_spanning_tree_edges.txt'
    return path + folder + split_part + neighbour_part + inhibitor_part

def get_writefilename():
    """Return the path of the CSV file that collects the per-path statistics.

    The pieces are joined with "_", including the folder itself, so the
    resulting path is the folder string followed by "_all_...".
    """
    pieces = (
        get_shortest_path_folder(),
        'all',
        vecname,
        inh,
        "random_shortest_paths_stats.txt",
    )
    return "_".join(pieces)
#######################################################################

In [3]:
def list_to_csv(l):
    """Join the string form of every item of `l` with commas (no quoting or escaping)."""
    return ",".join(map(str, l))

In [4]:
# Column layout of the per-path statistics CSV.
cols = [
    "path_id",
    "inh_type",
    "vec_type",
    "root_mutation",
    "path_length",
    "path_type",
    "fraction_above_inh_threshold",
    "path_variance",
]
header = ",".join(cols)

# Open the output file and write the header without a trailing newline;
# later rows are written with a leading newline. `wf` stays open for the
# cells that follow.
writefilename = get_writefilename()
wf = open(writefilename, "w")
wf.write(header)

In [5]:
# Echo the CSV column names, one per line.
for column_name in cols:
    print(column_name)

path_id
inh_type
vec_type
root_mutation
path_length
path_type
fraction_above_inh_threshold
path_variance


In [6]:
# Walk every split, classify each shortest path by how its inhibitor values
# relate to `resistance_level`, and append one CSV row per path to the
# already-open stats file `wf`.
#
# Path types written to the file:
#   above  - every node is above the threshold
#   below  - no node is above the threshold
#   gains  - a run of below-threshold nodes followed only by above-threshold nodes
#   looses - a run of above-threshold nodes followed only by below-threshold nodes
#   spikes - any other mixture
path_id = -1  # incremented before each row is written, so ids start at 0
inh_type = inh
vec_type = vecname[0]  # first character of the vector name
try:
    for splitnum in range(10):
        t0 = time.time()
        print(str(inh), " split ", str(splitnum))
        datafile = get_split_data_file(splitnum)
        df = open(datafile)
        try:
            header = next(df).strip().split('\t')
            inh_ind = header.index(inh)  # read as a global by get_inhs
            # Not used in the visible cells; kept in case later cells rely on it.
            seq_start = header.index('P1')

            edgefile = get_spanning_trees_file(splitnum)
            edges = read_edges(edgefile)
            nodes = get_nodes(edges)
            inh_vals = get_inhs(nodes, df)
        finally:
            # get_inhs closes df on success; closing again is a no-op, and this
            # also releases the handle when an earlier step fails.
            df.close()
        node_ind = {str(v): i for i, v in enumerate(nodes)}
        # 1 when the node's inhibitor value is above the threshold, else 0.
        node_inhclass = {str(n): int(inh_vals[node_ind[str(n)]] > resistance_level) for n in nodes}

        spfile = get_shortest_paths_file(splitnum)
        print(vecname)
        mutation = ''
        with open(spfile) as f:
            for line in f:
                if "mutation" in line:
                    # Section marker: the text after ":" applies to the paths below.
                    if mutation:
                        print('mutation:', mutation)
                    mutation = str(line.strip().split(":")[1])
                    continue

                line = line.strip()
                if not line:
                    # A blank line carries no path; skipping avoids a KeyError on ''.
                    continue

                path_id += 1
                root_mutation = mutation
                line = line.split(",")
                vals = [inh_vals[node_ind[n]] for n in line]
                binary = binarize(line)
                path_length = len(line)
                s = sum(binary)  # number of nodes above the threshold
                fraction_above_inh_threshold = round(float(s) / path_length, 4)

                # var() already rounds to 4 decimals, so no extra round() is needed.
                path_variance = var(vals)

                if s == path_length:
                    path_type = "above"

                elif s == 0:
                    path_type = "below"

                else:
                    last_cont_zero = last_continuous_ind(binary, 0)
                    first_one = first_ind(binary, 1)
                    last_cont_one = last_continuous_ind(binary, 1)
                    first_zero = first_ind(binary, 0)

                    if last_cont_zero != -1 and last_cont_zero + 1 == first_one and s == len(binary[first_one:]):
                        path_type = "gains"
                        # Variance over the above-threshold tail only.
                        path_variance = var(vals[first_one:])

                    # -1 is the "no leading run" sentinel of last_continuous_ind.
                    # The previous comparison against 1 sent paths with exactly
                    # two leading ones (e.g. 1,1,0,0) to "spikes".
                    elif last_cont_one != -1 and last_cont_one + 1 == first_zero and s == len(binary[:first_zero]):
                        path_type = "looses"
                        # Variance over the above-threshold head only.
                        path_variance = var(vals[:first_zero])

                    else:
                        path_type = 'spikes'
                        # Same threshold as node_inhclass rather than a literal 3.0.
                        path_variance = var([v for v in vals if v > resistance_level])

                token = list_to_csv([path_id, inh_type, vec_type, root_mutation, path_length, path_type,
                                     fraction_above_inh_threshold, path_variance])
                # The header was written without a newline, so each row starts with one.
                wf.write('\n' + token)
        print('mutation,count', mutation)
        print(time.time() - t0)
finally:
    # Flush and release the stats file even if a split fails half-way.
    wf.close()

#analyze(inh_stats)


TPV  split  0
countvec
mutation: 1
mutation,count 2
28.9205198288
TPV  split  1
countvec
mutation: 1
mutation,count 2
31.3950850964
TPV  split  2
countvec
mutation: 1
mutation,count 2
37.1343250275
TPV  split  3
countvec
mutation: 0
mutation: 1
mutation,count 2
44.747410059
TPV  split  4
countvec
mutation: 1
mutation,count 2
27.6428639889
TPV  split  5
countvec
mutation: 0
mutation: 1
mutation,count 2
39.9254789352
TPV  split  6
countvec
mutation: 1
mutation,count 2
45.9823000431
TPV  split  7
countvec
mutation: 0
mutation: 1
mutation,count 2
43.0079228878
TPV  split  8
countvec
mutation: 0
mutation: 1
mutation,count 2
35.6183829308
TPV  split  9
countvec
mutation: 0
mutation: 1
mutation,count 2
32.9058411121


In [7]:
# Display the active vector name as this cell's output.
vecname

'countvec'

In [8]:
# Count the lines of the stats file. Subtracting 2 discounts the header line
# and turns the number of data rows into the zero-based id of the last row.
with open(writefilename) as rf:
    telly_count = sum(1 for _line in rf) - 2

In [9]:
# Sanity check: exactly one row was written per path below a single header line.
assert telly_count == path_id, (
    "stats file holds %d data rows but the last path_id is %d" % (telly_count + 1, path_id)
)