[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/geraschenko/examples/blob/master/Black_sock_problem.ipynb)

You have a bag of $n$ socks, $k$ of which are black. Suppose that if you pull out two random socks, there's a 1/2 probability they're both black.

* What can $n$ and $k$ be?
* Must $k$ be odd?
* What about for values other than 1/2?

In [0]:
from __future__ import absolute_import, division, print_function
import numpy as np
from scipy.linalg import hankel
import math
np.set_printoptions(precision=4, suppress=True)

In [0]:
def find_sock_values(numerator, denominator, max_k=1000000, verbose=False):
    '''Find numbers of socks with given probability of two black socks.
    
    Returns k's and n's (lists), so that the probability of picking two
    black socks from n socks (k of which are black) is
    numerator/denominator. Note: we use explicit numerator an denominator
    so that we don't run into floating point error issues.'''
    
    k_list, n_list = [], []
    # Our goal is k*(k-1)/(n*(n-1)) = numerator/denominator.
    for k in range(2, max_k+1):
        n = int((denominator*k*(k-1)/numerator)**0.5) + 1
        if numerator*n*(n-1) == denominator*k*(k-1):
            k_list.append(k)
            n_list.append(n)
            if verbose:
                print(k, n)
    return k_list, n_list

def check_sock_values(numerator, denominator, k_list, n_list):
    for k, n in zip(k_list, n_list):
        if numerator*n*(n-1) != denominator*k*(k-1):
            print('%d * %d * %d != %d * %d * %d (%d vs %d)' %
                  (numerator,n,n-1,denominator,k,k-1,
                   numerator*n*(n-1), denominator*k*(k-1)))
            return False
    return True

def find_recurrence(numbers):
    '''Tries to find a recurrence relation in a list of numbers.'''
    A = hankel(numbers)
    n = len(numbers)
    for k in range(2, math.floor(n/2)+1):
        # try to find a k-term recurrence
        x = np.linalg.solve(A[:k, :k], A[:k, k])
        # If we don't round, floats near an int sometimes drop by 1 when
        # they're converted to ints.
        x = np.array([round(y) for y in x], dtype=np.int)
        recurrence_error = np.matmul(A[:n-k-1,:k], x) - A[:n-k-1, k]
        if np.linalg.norm(recurrence_error) == 0.0:
            # We found a recurrence that works as far as we can tell.
            return np.hstack((x, [-1]))
    return None

def run_recurrence(recurrence, initial, max_length=20):
    k = len(recurrence) - 1
    if len(initial) < k:
        raise ValueError('Not enough initial values (%s) for recurrence '
                         '%s' % (str(initial), str(recurrence)))
    if recurrence[-1] != -1:
        raise ValueError('Recurrence doesn\'t end in -1.')
    result = list(initial)  # hacky way to make a copy
    while len(result) < max_length:
        new_val = 0
        for a, b in zip(result[-k:], recurrence[:-1]):
            # We explicitly turn things into python ints to avoid
            # integer overflow.
            new_val += int(a) * int(b)
        value = -new_val/recurrence[-1]
        assert value == int(value)
        result.append(int(value))
    return result

def run_analysis_for(numberator, denominator, **kwargs):
    k, n = find_sock_values(numberator, denominator, **kwargs)
    print('k:', k)
    print('n:', n)
    r = find_recurrence(k)
    if r is None:
        s = ' ... maybe increase max_k?' if k else ''
        k = np.array(k, dtype=np.int)
        s += '\nConsecutive ratios: %s' % (k[1:] / k[:-1])
        print('No recurrence found' + s)
        return
    print('Found recurrence ', r)
    k_big = run_recurrence(r, k)
    n_big = run_recurrence(r, n)
    if check_sock_values(numberator, denominator, k_big, n_big):
        print('It keeps working at least up to (k,n)=(%d, %d)'
              % (k_big[-1], n_big[-1]))
    else:
        print('It breaks after a while.')

In [3]:
for denominator in range(2, 10):
  for numerator in range(1, denominator):
    if math.gcd(numerator, denominator) != 1:
      continue
    print('\nanalyzing for probability %d/%d' % (numerator, denominator))
    run_analysis_for(numerator, denominator)


analyzing for probability 1/2
k: [3, 15, 85, 493, 2871, 16731, 97513, 568345]
n: [4, 21, 120, 697, 4060, 23661, 137904, 803761]
Found recurrence  [ 1 -7  7 -1]
It keeps working at least up to (k,n)=(873430010034205, 1235216565974041)

analyzing for probability 1/3
k: [2, 6, 21, 77, 286, 1066, 3977, 14841, 55386, 206702, 771421]
n: [3, 10, 36, 133, 495, 1846, 6888, 25705, 95931, 358018, 1336140]
Found recurrence  [ 1 -5  5 -1]
It keeps working at least up to (k,n)=(108347552061, 187663465045)

analyzing for probability 2/3
k: [5, 45, 441, 4361, 43165, 427285]
n: [6, 55, 540, 5341, 52866, 523315]
Found recurrence  [  1 -11  11  -1]
2 * 46801248083088248 * 46801248083088247 != 3 * 38213059042991848 * 38213059042991847 (4380713644269542722338833745242512 vs 4380713644269543035763911854389768)
It breaks after a while.

analyzing for probability 1/4
k: []
n: []
No recurrence found
Consecutive ratios: []

analyzing for probability 3/4
k: [7, 91, 1261, 17557, 244531]
n: [8, 105, 1456, 20273, 