In [2]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
from time import time
import itertools
import random

In [12]:
def create_instance(n, has_prior=False):
    # set up a problem instance
    if not has_prior:
        s = np.random.binomial(1,0.5,size=n)
    else:
        prior_probs = np.random.choice((1/3,2/3), size=n)
        s = np.random.binomial(1, prior_probs, size=n)
        prior = np.zeros(n, dtype='int')
        prior[prior_probs == 2/3] = 1
    counter = np.zeros(n, dtype='int')
    counter[0] = s[0]
    for i in range(1,n):
        counter[i] = counter[i-1] + s[i]
    z = np.random.randint(0,2,size=n)
    a = counter + z
    
    if has_prior:
        return a, s, z, counter, prior
    else:
        return a, s, z, counter
        
def fixed_counts(a):
    # non-exhaustive linear-time solving for when we know the true count exactly
    n = len(a)
    count = -1*np.ones(n, dtype='int')
    
    if a[0] == 0:
        count[0] = 0
    elif a[0] == 2:
        count[0] = 1

    for i in range(1,n):
        if a[i] == a[i-1] + 2:
            count[i] = a[i] - 1
            count[i-1] = a[i-1] 
        elif a[i] == a[i-1] - 1:
            count[i] = a[i]
            count[i-1] = a[i]

    passes = 0
    modified = -1
    while passes < 10 and modified != 0:
        modified = 0
        passes += 1
        for i in range(1,n):
            if count[i] == -1:
                if a[i] == count[i-1]:
                    count[i] = count[i-1]
                    modified += 1
                if a[i] == count[i-1] + 2:
                    count[i] = count[i-1] + 1
                    modified += 1
        for i in range(n-1):
            if count[i] == -1:
                if a[i] == count[i+1] + 1:
                    count[i] = count[i+1]
                    modified += 1

    # this closes any size-one holes
    for i in range(1,n-1):
        if (count[i]==-1) and (count[i+1]!=-1) and (count[i-1]!=-1):
            if count[i+1]==count[i-1]:
                count[i] = count[i-1]
            if count[i+1]==count[i-1]+2:
                count[i] = count[i-1] + 1

    return count

def find_gaps(partial_counts):
    """find stretches in which we don't know the counts"""
    n = len(partial_counts)
    indices = []
    # first, check the start
    i = 1
    if partial_counts[0] == -1:
        for i in range(1,n):
            if partial_counts[i] != -1:
                indices.append([0,i])
                break
                
    # now for general
    on = False
    for j in range(i,n):
        if not on and partial_counts[j] == -1:
            on = True
            start = j
        if on and partial_counts[j] != -1:
            on = False
            indices.append([start,j])
        if on and j == n-1:
            indices.append([start,n])
            break  # we're at the end
            
    indices.sort(key= lambda x: x[0]-x[1]) # sort, shortest first
    
    return indices

def guess_single_holes(counts):
    """just do a random guess on the size-one gaps, rather than feed to solver"""
    n = len(counts)
    if counts[0]==-1 and counts[1]!=-1:
        counts[0] = 0
    if counts[-1]==-1 and counts[-2]!=-1:
        counts[-1] = counts[-2] 
    
    for i in range(1,n-1):
        if (counts[i]==-1) and (counts[i+1]!=-1) and (counts[i-1]!=-1):
            counts[i] = counts[i-1]
    return counts

def solver(obs, prior, total, maxtime):
    # brute force search for a consistent explanation of a subarray
    # this is minimum working code, folks
    n = len(obs)

    start = time()
    
    # these keep track of the best we've seen so far
    best_diff = float('Inf')
    best_count = None
    highest_prior = 0
    best_match = None
    
    if prior is None:
        if total >= 2:  
            nonzeros = itertools.chain(itertools.combinations(range(n), total),
                                        itertools.combinations(range(n), total-1))
            for indices in nonzeros:
                x_hat = np.zeros(n, dtype='int')
                x_hat[np.array(indices)] = 1
                count_hat = np.cumsum(x_hat)
                diff = obs - count_hat
                deviation = max(np.max(diff),np.max(-diff)+1)
                if deviation <= 1:
                    return count_hat
                else:
                    if deviation < best_diff:
                        best_count = count_hat
                if time() - start > maxtime:
                    break
        else:  # we're solve without a fixed upper total, via permutation
            bitstrings = itertools.product(range(2), repeat=n)
            for bits in bitstrings:
                x_hat = np.array(bits)
                count_hat = np.cumsum(x_hat)
                diff = obs - count_hat
                deviation = max(np.max(diff),np.max(-diff)+1)
                if deviation <= 1:
                    return count_hat
                if time() - start > maxtime:
                    break
    else: # have a prior
        if total >= 2:  
            # first, just check if the max-prior is possible
            count_hat = np.cumsum(prior)
            diff = obs - count_hat
            deviation = max(np.max(diff),np.max(-diff)+1)
            if deviation <= 1 and abs(count_hat[-1] - total) <= 1:
                return count_hat
            
            # otherwise, brute-force as normal
            nonzeros = itertools.chain(itertools.combinations(range(n), total),
                                        itertools.combinations(range(n), total-1))
            for indices in nonzeros:
                x_hat = np.zeros(n, dtype='int')
                x_hat[np.array(indices)] = 1
                count_hat = np.cumsum(x_hat)
                diff = obs - count_hat
                deviation = max(np.max(diff),np.max(-diff)+1)
                match_prior = np.sum(x_hat == prior)
                if deviation <= 1:
                    if match_prior == n-1:  # because we know matching the prior wasn't feasible
                        return count_hat
                    else:
                        if match_prior > highest_prior:
                            highest_prior = match_prior
                            best_match = count_hat 
                else:
                    if deviation < best_diff:
                        best_diff = deviation
                        best_count = count_hat
                if time() - start > maxtime:
                    break
                    
        else:  # we're solving without an upper number, i.e. at the end
            # check if the max-prior is feasible
            count_hat = np.cumsum(prior)
            diff = obs - count_hat
            deviation = max(np.max(diff),np.max(-diff)+1)
            if deviation <= 1:
                return count_hat
            
            bitstrings = itertools.product(range(2), repeat=n)
            for bits in bitstrings:
                x_hat = np.array(bits)
                count_hat = np.cumsum(x_hat)
                diff = obs - count_hat
                deviation = max(np.max(diff),np.max(-diff)+1)
                match_prior = np.sum(x_hat == prior)
                if deviation <= 1:
                    if match_prior == n-1:
                        return count_hat
                    else:
                        if match_prior > highest_prior:
                            highest_prior = match_prior
                            best_match = count_hat
                else:
                    if deviation < best_diff:
                        best_diff = deviation
                        best_count = count_hat
                if time() - start > maxtime:
                    break
    # breaking takes us here, return the best guess we have so far

    if best_match is not None:
        return best_match
    elif best_count is not None:
        return best_count
    else:
        return count_hat
            
def recover(a, prior=None):
    """base recovery"""
    n = len(a)
    
    partial_counts = fixed_counts(a)
    counts = guess_single_holes(partial_counts)
    
    indices = find_gaps(counts)
    # iterate over the short gaps
    for index in indices:
        start, end = index[0], index[1]
        if start == 0:
            lo = 0
        else:
            lo = counts[start-1]
        if end == n:
            hi = -1
        else:
            hi = counts[end] - lo

        obs = a[start:end] - lo
        
        if n == 50000:
            if prior is not None:
                counts[start:end] = solver(obs, prior[start:end], hi, 0.1) + lo
            else:
                counts[start:end] = solver(obs, prior=None, total=hi, maxtime=0.1) + lo            
        else:
            if prior is not None:
                counts[start:end] = solver(obs, prior[start:end], hi, 1) + lo
            else:
                counts[start:end] = solver(obs, prior=None, total=hi, maxtime=1) + lo            
                
    s_hat = counts_to_bits(counts)
    return s_hat

def counts_to_bits(counts):
    """turn a guessed count vector into bits"""
    n = len(counts)
    x_hat = np.zeros(n, dtype='int')
    x_hat[0] = counts[0]
    for i in range(1,n):
        x_hat[i] = counts[i] - counts[i-1]
    return x_hat

In [13]:
# No prior, 20 runs, should take ~10 seconds
np.random.seed(591)
n = 1000
runs = 20

accs = np.zeros(runs)

start = time()
for run in range(runs):
    a, true_s, true_z, true_counts = create_instance(n, has_prior=False)

    s_hat = recover(a, prior=None)

    accs[run] = 1 - np.sum(np.abs(s_hat - true_s)) / n

end = time()
print('n:', n)
print('indep. runs:', runs)
print('total time:', end-start)
print('avg accuracy:', np.sum(accs)/runs)
print('std. dev.:', np.std(accs))

n: 1000
indep. runs: 20
total time: 9.772082805633545
avg accuracy: 0.7779
std. dev.: 0.012161003248087719


In [14]:
# with prior knowledge, 20 runs, should take ~30 seconds
np.random.seed(591)
n = 1000
runs = 20

accs = np.zeros(runs)

start = time()
for run in range(runs):
    a, true_s, true_z, true_counts, prior = create_instance(n, has_prior=True)

    s_hat = recover(a, prior=prior)

    accs[run] = 1 - np.sum(np.abs(s_hat - true_s)) / n

end = time()
print('n:', n)
print('indep. runs:', runs)
print('total time:', end-start)
print('avg accuracy:', np.sum(accs)/runs)
print('std. dev.:', np.std(accs))

n: 1000
indep. runs: 20
total time: 27.760793685913086
avg accuracy: 0.8137500000000001
std. dev.: 0.01654652531500193
