In [83]:
import pandas as pd
import numpy as np
import itertools as it
import networkx as nx
from collections import Counter, defaultdict
from networkx.algorithms.approximation import min_weighted_vertex_cover
import random

In [2]:
df = pd.read_csv('example/data1.SC', sep='\t')

In [3]:
df.drop(columns=df.columns[0], inplace=True)
df.replace('?', 0, inplace=True)

In [10]:
A = df.to_numpy(dtype=np.bool)
m, n = A.shape

In [36]:
A.shape

(20, 20)

In [67]:
def make_graph(A):
    """Given a matrix A of mutations and cell samples, return a  graph G where nodes are columns,
    edges are column pairs with three-gamete rule violations and edge weights are the number of bit
    flips required to fix all violations between said column pair.
    """
    """Find a PP by reducing to weighted vertex cover where vertices are columns
    of A and vertex weights are the number of 0s in the corresponding column.
    """
    m, n = A.shape
    row_types_per_col_pair = {(p,q): defaultdict(list) for p,q in it.combinations(range(n), 2)}
    edge_list = []
    for p,q in it.combinations(range(n), 2):
        # Count the number of 01s, 10s, and 11s
        for i in range(m):
            if A[i, p] and not A[i, q]:
                row_types_per_col_pair[p,q][0, 1].append(i)
            elif not A[i, p] and A[i, q]:
                row_types_per_col_pair[p,q][1, 0].append(i)
            elif A[i, p] and A[i, q]:
                row_types_per_col_pair[p,q][1, 1].append(i)
        # Tally up the number of flips required to fix (p,q) violations
        pq_pair_counts = row_types_per_col_pair[p,q]
        count11 = len(pq_pair_counts[1, 1])
        count01 = len(pq_pair_counts[0, 1])
        count10 = len(pq_pair_counts[1, 0])
        if count11 and count01 and count10:
            edge_list.append((p, q, {'weight': min(count01, count10)}))
    G = nx.Graph(edge_list)
    nx.set_node_attributes(G, {p: n - np.sum(A[:,p]) for p in range(n)}, 'weight')
    print(G.nodes(data='weight'))
    return G, row_types_per_col_pair


def randPP(A):
    pass

In [68]:
def weighted_vertex_cover_pp(A):
    """Return a perfect phylogeny of matrix A by solving a related weighted
    vertex cover instance.
    """
    G, violations = make_graph(A)
    vc = min_weighted_vertex_cover(G, weight='weight')
    flipped_bits = sum(G.nodes[v]['weight'] for v in vc)
    flippable_bits = m*n - np.sum(A)
    print(f'Vertex cover deleted {len(vc)} / {len(G.nodes)} = {len(vc) / len(G.nodes):.2%} nodes')
    print(f'It flipped {flipped_bits} / {flippable_bits} = {flipped_bits/flippable_bits:.2%} bits')
weighted_vertex_cover_pp(A)

[(0, np.int64(2)), (1, np.int64(6)), (2, np.int64(3)), (3, np.int64(5)), (4, np.int64(5)), (5, np.int64(3)), (7, np.int64(6)), (8, np.int64(4)), (9, np.int64(8)), (11, np.int64(5)), (12, np.int64(4)), (13, np.int64(1)), (14, np.int64(5)), (15, np.int64(4)), (16, np.int64(4)), (17, np.int64(3)), (18, np.int64(4)), (19, np.int64(1)), (6, np.int64(3)), (10, np.int64(4))]
Vertex cover deleted 18 / 20 = 90.00% nodes
It flipped 75 / 80 = 93.75% bits


In [97]:
# Instead try the randomized generalization of vertex cover that samples greedily
def randomized_pp(A, shuffle=False):
    A = np.array(A, dtype=np.bool)
    m, n = A.shape
    num_flipped_bits = 0
    flippable_bits = m * n - np.sum(A)
    has_violations = True
    while has_violations:
        iterator = it.combinations(range(n), 2)
        if shuffle:
            iterator = list(iterator)
            random.shuffle(iterator)
        for p,q in iterator:
            ixs01s, ixs10s = [], []
            has11 = False 
            for i in range(m):
                has11 = has11 or (A[i, p] and A[i,q])
                if not A[i,p] and A[i,q]:
                    ixs01s.append(i)
                elif A[i,p] and not A[i,q]:
                    ixs10s.append(i)
            # A violation was found
            if has11 and ixs01s and ixs10s:
                prob_flip_01s = len(ixs10s) / (len(ixs01s) + len(ixs10s))
                if np.random.rand() < prob_flip_01s:
                    num_flipped_bits += len(ixs01s)
                    for r in ixs01s:
                        A[r,p] = 1
                else:
                    num_flipped_bits += len(ixs10s)
                    for r in ixs10s:
                        A[r,q] = 1
                continue
        # reaching here means no violations remain
        has_violations = False
    # print(f'Vertex cover flipped {num_flipped_bits} / {flippable_bits} = {num_flipped_bits/flippable_bits:.2%} bits')
    return A, num_flipped_bits / flippable_bits

        


In [98]:
np.random.seed(11)
bit_flip_pcts = []
for _ in range(100):
    pp, num_flipped_bits = randomized_pp(A)
    bit_flip_pcts.append(num_flipped_bits)
print(f'On average, flipped {np.mean(bit_flip_pcts):.2%} bits.')
print(f'Best performance: {min(bit_flip_pcts):.2%}')

On average, flipped 92.72% bits.
Best performance: 81.25%


In [101]:
np.random.seed(11)
bit_flip_pcts = []
least_bit_flips = float('inf')
best_pp = None
for _ in range(100):
    pp, num_flipped_bits = randomized_pp(A, shuffle=True)
    bit_flip_pcts.append(num_flipped_bits)
    if num_flipped_bits < least_bit_flips:
        least_bit_flips = num_flipped_bits
        best_pp = pp
print(f'On average, flipped {np.mean(bit_flip_pcts):.2%} bits.')
print(f'Best performance: {min(bit_flip_pcts):.2%}')
print('best pp:\n', np.astype(best_pp, np.int8))

On average, flipped 90.50% bits.
Best performance: 81.25%
best pp:
 [[1 1 1 1 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1]
 [1 0 1 1 1 1 1 1 1 0 1 0 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 0 1 1 1 1 1 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
 [1 0 1 1 0 1 1 1 1 0 1 0 1 1 0 1 1 1 1 1]
 [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]]


In [None]:
# TODO: Instead try sampling with probability proportional to the number of violations removed