In [2]:
import itertools
import re

In [3]:
# One clue line: a run of digits, a single space, ';', a second run of digits,
# then arbitrary trailing text. Group 1 is the first digit run, group 2 the second.
line_re = re.compile(r'(\d+) ;(\d+).*')

In [257]:
def parse_known(s):
    """Parse a block of clue lines into ``(guess, correct)`` pairs.

    Every non-blank line of ``s`` is stripped and matched against the
    module-level ``line_re``. The first captured group is kept as a string
    of digits, the second is converted to ``int``.

    Args:
        s: multi-line string, one clue per line; blank lines are ignored.

    Returns:
        A list of ``(guess_digits, correct_count)`` tuples in input order.

    Raises:
        ValueError: if a non-blank line does not match ``line_re``.
    """
    known = []
    for raw in s.splitlines():
        line = raw.strip()
        if not line:
            continue
        match = line_re.match(line)
        if match is None:
            # Name the offending line instead of failing on ``None.groups()``.
            raise ValueError(f'unparseable clue line: {line!r}')
        guess, correct = match.groups()
        known.append((guess, int(correct)))
    return known

def is_compatible(guess, known):
    """Return True when ``guess`` agrees with every recorded clue.

    For each ``(prev_guess, correct)`` pair in ``known`` the number of
    positions at which ``guess`` and ``prev_guess`` hold the same item must
    equal ``correct``. An empty ``known`` is trivially satisfied.
    """
    for prev_guess, correct in known:
        matches = 0
        for i in range(len(guess)):
            if guess[i] == prev_guess[i]:
                matches += 1
        # Stop at the first clue that disagrees.
        if matches != correct:
            return False
    return True

In [258]:
# Small example puzzle: six 5-digit guesses, each followed by the number of
# digits that sit in the correct position.
short_known = parse_known("""
    90342 ;2 correct
    70794 ;0 correct
    39458 ;2 correct
    34109 ;1 correct
    51545 ;2 correct
    12531 ;1 correct
""")

In [259]:
# Echo the parsed clues.
short_known

[('90342', 2),
 ('70794', 0),
 ('39458', 2),
 ('34109', 1),
 ('51545', 2),
 ('12531', 1)]

In [26]:
# Check that the candidate 39542 satisfies all six clues of the small example.
is_compatible('39542', short_known)

True

In [243]:
import numpy as np
import util

def init(n):
    """Return an ``n`` x 10 float array with every entry set to 0.1.

    One row per position, one column per decimal digit.
    """
    return np.full(shape=(n, 10), fill_value=0.1, dtype=float)
    
def normalize(v):
    """Divide every entry of ``v`` by the sum of ``v``, in place.

    ``v`` must support iteration and item assignment. Nothing is returned.
    A zero total is not special-cased.
    """
    total = sum(v)
    for idx, value in enumerate(v):
        v[idx] = float(value) / total

def update(s, x, d):
    """Apply one clue to the weight table ``d`` in place.

    Args:
        s: the guess, iterated item by item; each item is converted with ``int``.
        x: the clue's correct-count.
        d: table indexed as ``d[position][digit]``.

    When ``x`` is 0 every (position, digit) pair of the guess is set to 0.
    Otherwise ``x`` is added to each pair whose current weight is not 0;
    pairs already at 0 are left alone.
    """
    zero_clue = x == 0
    for pos, ch in enumerate(s):
        digit = int(ch)
        if zero_clue:
            d[pos][digit] = 0
        elif d[pos][digit] != 0:
            d[pos][digit] += x

def possibilities(d):
    """Lazily enumerate candidate strings from a weight table.

    For each row of ``d`` the digit indices are ordered by descending weight
    (ties keep ascending index order) and indices whose weight is not above
    1e-9 are dropped. The result is a generator over the Cartesian product of
    the per-row digit lists, each combination joined into one string.
    """
    per_position = []
    for row in d:
        ranked = list(enumerate(row))
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        per_position.append([str(digit) for digit, weight in ranked if weight > 1e-9])
    return (''.join(combo) for combo in itertools.product(*per_position))

def is_compatible(guess, known):
    """Tell whether ``guess`` reproduces the match count of every clue.

    ``known`` holds ``(prev_guess, correct)`` pairs; the result is True only
    if, for each pair, exactly ``correct`` positions of ``guess`` equal the
    same position of ``prev_guess``.
    """
    positions = range(len(guess))

    def hits(prev):
        # Number of positions where the candidate and the earlier guess agree.
        return sum(1 for i in positions if guess[i] == prev[i])

    return not any(hits(prev) != correct for prev, correct in known)

def common(a, b):
    """Return a 0/1 list marking the positions where ``a`` and ``b`` agree.

    The list has one entry per index of ``a``.
    """
    flags = []
    for i in range(len(a)):
        if a[i] == b[i]:
            flags.append(1)
        else:
            flags.append(0)
    return flags

def solve(known):
    """Heuristic exploration of a clue set; prints digit weights per position.

    Steps, all driven by ``known`` (a list of ``(guess, correct)`` pairs):

    1. Start from the uniform table returned by ``init``.
    2. Add each clue's correct-count to every (position, digit) it contains.
    3. For each pair of clues, add both counts to every position at which
       the two guesses show the same digit.
    4. Run ``update`` for each clue, normalise each row and print it.
    5. Zero out all digits of clues with 0 correct.
    6. Pick the 1-correct clue whose best digit has the largest weight, keep
       only that digit of the clue in a copy of the table, and print the
       number of remaining combinations before and after.

    Nothing is returned; all results are printed.
    """
    n = len(known[0][0])
    d = init(n)

    # Every clue credits each of its (position, digit) pairs with its count.
    for g, c in known:
        d[range(len(g)), [int(a) for a in g]] += c

    # Digits shared by two clues at the same position get both counts on top.
    for (g1, c1), (g2, c2) in itertools.combinations(known, 2):
        for i, use in enumerate(common(g1, g2)):
            if not use:
                continue
            d[i, int(g1[i])] += c1 + c2

    for guess, correct in known:
        update(guess, correct, d)

    for i in range(len(d)):
        normalize(d[i])
        print(i, ':', ' | '.join(f'{digit}:{p:.03f}' for digit, p in enumerate(d[i])))

    def eliminate(guess, probs):
        """Zero the weight of every (position, digit) pair of ``guess``."""
        probs[range(len(guess)), [int(a) for a in guess]] = 0

    def scores(guess, probs):
        """Weights of the digits of ``guess`` at their own positions."""
        # Must index with the ``guess`` argument: reading the enclosing loop
        # variable instead would score the same clue for every candidate.
        return [probs[i, int(guess[i])] for i in range(len(guess))]

    # Zero correct
    for g, c in known:
        if c == 0:
            eliminate(g, d)

    # One correct: choose the clue whose strongest digit has the highest weight.
    g, _ = max(known, key=lambda kv: max(scores(kv[0], d)) if kv[1] == 1 else 0)
    # Largest (weight, position, digit) triple of that clue.
    s, i, n = max(zip(scores(g, d), range(len(g)), g))
    probs = d.copy()
    print(f'guessing: {n}@{i} from {g} with prob {s}')
    # Eliminate the rest.
    for j, m in enumerate(g):
        if i == j:
            continue
        probs[j, int(m)] = 0

    # ``util`` is project-local; ``product`` presumably multiplies the counts
    # of still-possible digits per position -- TODO confirm.
    print('possibilities:', util.product(len(list(filter(lambda x: x > 0, p))) for p in d))

    print('possibilities:', util.product(len(list(filter(lambda x: x > 0, p))) for p in probs))
    for i in range(len(d)):
        print(i, ':', ' | '.join(f'{digit}:{p:.03f}' for digit, p in enumerate(probs[i])))

In [244]:
# NOTE(review): in the visible part of this notebook `long_known` is first
# assigned in a later cell, so this call relies on out-of-order execution --
# confirm the intended cell order.
solve(long_known)

0 : 0:0.001 | 1:0.138 | 2:0.000 | 3:0.154 | 4:0.184 | 5:0.215 | 6:0.123 | 7:0.047 | 8:0.092 | 9:0.047
1 : 0:0.037 | 1:0.037 | 2:0.090 | 3:0.000 | 4:0.025 | 5:0.025 | 6:0.270 | 7:0.108 | 8:0.396 | 9:0.013
2 : 0:0.000 | 1:0.148 | 2:0.000 | 3:0.000 | 4:0.444 | 5:0.169 | 6:0.000 | 7:0.026 | 8:0.000 | 9:0.211
3 : 0:0.146 | 1:0.000 | 2:0.122 | 3:0.073 | 4:0.017 | 5:0.283 | 6:0.227 | 7:0.049 | 8:0.049 | 9:0.033
4 : 0:0.035 | 1:0.023 | 2:0.256 | 3:0.000 | 4:0.069 | 5:0.137 | 6:0.035 | 7:0.035 | 8:0.376 | 9:0.035
5 : 0:0.029 | 1:0.015 | 2:0.015 | 3:0.198 | 4:0.085 | 5:0.381 | 6:0.106 | 7:0.128 | 8:0.000 | 9:0.043
6 : 0:0.083 | 1:0.154 | 2:0.132 | 3:0.022 | 4:0.001 | 5:0.176 | 6:0.000 | 7:0.001 | 8:0.011 | 9:0.422
7 : 0:0.075 | 1:0.000 | 2:0.001 | 3:0.057 | 4:0.280 | 5:0.112 | 6:0.336 | 7:0.025 | 8:0.038 | 9:0.075
8 : 0:0.000 | 1:0.054 | 2:0.000 | 3:0.041 | 4:0.409 | 5:0.297 | 6:0.000 | 7:0.180 | 8:0.000 | 9:0.018
9 : 0:0.126 | 1:0.126 | 2:0.018 | 3:0.101 | 4:0.000 | 5:0.076 | 6:0.051 | 7:0.376 

In [187]:
# Run `solve` on the 5-digit example.
solve(short_known)

0 : 0:0.005 | 1:0.106 | 2:0.005 | 3:0.457 | 4:0.005 | 5:0.206 | 6:0.005 | 7:0.000 | 8:0.005 | 9:0.206
1 : 0:0.000 | 1:0.318 | 2:0.163 | 3:0.008 | 4:0.163 | 5:0.008 | 6:0.008 | 7:0.008 | 8:0.008 | 9:0.318
2 : 0:0.005 | 1:0.106 | 2:0.005 | 3:0.206 | 4:0.206 | 5:0.457 | 6:0.005 | 7:0.000 | 8:0.005 | 9:0.005
3 : 0:0.100 | 1:0.005 | 2:0.005 | 3:0.100 | 4:0.579 | 5:0.196 | 6:0.005 | 7:0.005 | 8:0.005 | 9:0.000
4 : 0:0.006 | 1:0.124 | 2:0.243 | 3:0.006 | 4:0.000 | 5:0.243 | 6:0.006 | 7:0.006 | 8:0.243 | 9:0.124
possibilities: 59049


In [170]:
# Full puzzle: 22 guesses of 16 digits each, with the number of digits in the
# correct position for every guess.
long_known = parse_known("""
    5616185650518293 ;2 correct
    3847439647293047 ;1 correct
    5855462940810587 ;3 correct
    9742855507068353 ;3 correct
    4296849643607543 ;3 correct
    3174248439465858 ;1 correct
    4513559094146117 ;2 correct
    7890971548908067 ;3 correct
    8157356344118483 ;1 correct
    2615250744386899 ;2 correct
    8690095851526254 ;3 correct
    6375711915077050 ;1 correct
    6913859173121360 ;1 correct
    6442889055042768 ;2 correct
    2321386104303845 ;0 correct
    2326509471271448 ;2 correct
    5251583379644322 ;2 correct
    1748270476758276 ;3 correct
    4895722652190306 ;1 correct
    3041631117224635 ;3 correct
    1841236454324589 ;3 correct
    2659862637316867 ;2 correct
""")

In [21]:
# Position-wise overlap of two clues from the 16-digit puzzle that each have 1 correct.
common('3847439647293047', '3174248439465858')

[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [739]:
import copy

def parse_known(s):
    """Parse a block of clue lines into ``(digits, correct)`` pairs.

    Every non-blank line of ``s`` is stripped and matched against the
    module-level ``line_re``. The first captured group becomes a tuple of
    ints (one per character), the second is converted to ``int``.

    Args:
        s: multi-line string, one clue per line; blank lines are ignored.

    Returns:
        A list of ``(digit_tuple, correct_count)`` tuples in input order.

    Raises:
        ValueError: if a non-blank line does not match ``line_re``.
    """
    known = []
    for raw in s.splitlines():
        line = raw.strip()
        if not line:
            continue
        match = line_re.match(line)
        if match is None:
            # Name the offending line instead of failing on ``None.groups()``.
            raise ValueError(f'unparseable clue line: {line!r}')
        guess, correct = match.groups()
        known.append((tuple(map(int, guess)), int(correct)))
    return known


class Mastermind:
    """Search state for a digit puzzle with per-guess "correct position" counts.

    Attributes (all set in ``__init__``):
        known: list of ``(guess, correct, credited)`` triples sorted by
            ``correct``; ``credited`` is the set of positions of ``guess``
            that ``assign`` has confirmed as matching.
        pegs: number of positions, taken from the first guess.
        possibilities: boolean array of shape ``(pegs, 10)``;
            ``possibilities[pos, n]`` is True while digit ``n`` is still
            allowed at position ``pos``.
    """

    def __init__(self, known):
        """Build the initial state from ``(guess, correct)`` pairs.

        Each guess is a sequence of ints. Digits of guesses with 0 correct
        are ruled out immediately.
        """
        self.known = [(g, c, set()) for (g, c) in known]
        self.known.sort(key=lambda k: k[1])
        self.pegs = len(known[0][0])
        self.possibilities = np.full((self.pegs, 10), True, dtype=bool)
        # pos -> digit for every position fixed through assign(). A position
        # narrowed to one digit by eliminations alone is NOT recorded here.
        self._confirmed = {}
        for guess, correct in known:
            if correct == 0:
                self._eliminate_guess(guess)

    def __str__(self):
        """Render the partial answer, the possibility table and every clue."""
        lines = [''.join(str(ps.argmax()) if sum(ps) == 1 else '?' for ps in self.possibilities)]
        lines.append('possibilities:')
        for i, ps in enumerate(self.possibilities):
            ps_formatted = '|'.join(str(j) if p else 'X' for j,p in enumerate(ps))
            lines.append(f'  {i} : {ps_formatted} | c:{sum(ps)}')
        lines.append('known:')
        for g, c, a in self.known:
            line = []
            for i, m in enumerate(g):
                if i in a:
                    # Position credited to this clue.
                    line.append(f'✓{m} ')
                elif not self.possibilities[i, m]:
                    # Digit ruled out at this position.
                    line.append(f'!{m} ')
                else:
                    line.append(f' {m}?')
            lines.append(f'  {" ".join(line)} | c:{c} | a:{len(a)} | u:{self.num_unknowns(g)}' )
        return '\n'.join(lines)

    def _eliminate_guess(self, guess):
        """Rule out every digit of ``guess`` at its own position."""
        self.possibilities[range(self.pegs), guess] = False

    def is_compatible(self, guess, correct):
        """True if at least ``correct`` digits of ``guess`` are still allowed."""
        return sum(self.possibilities[range(self.pegs), guess]) >= correct

    def remaining(self):
        """Number of still-allowed digits per position."""
        return self.possibilities.sum(axis=1)

    def num_unknowns(self, guess):
        """Count of positions at which the digit of ``guess`` is still allowed."""
        return sum(self.unknowns(guess))

    def unknowns(self, guess):
        """Boolean vector: is the digit of ``guess`` still allowed at each position."""
        return self.possibilities[range(self.pegs), guess]

    def best_guess(self):
        """Return the guess of the clue with the smallest heuristic score.

        Scores are tuples compared lexicographically:
        ``(-1, 0)`` when the count of allowed digits equals the number of
        matches still missing, ``(missing, allowed)`` when matches are
        missing, and ``(100, 0)`` when none are missing.
        """
        scores = []
        for g, c, a in self.known:
            remaining = c - len(a)
            unknowns = self.num_unknowns(g)
            # NOTE(review): num_unknowns() also counts positions already
            # credited, so this equality only detects the forced case while
            # nothing has been credited yet -- confirm whether that is intended.
            if unknowns == remaining and remaining > 0:
                score = (-1, 0)
            elif remaining > 0:
                score = (remaining, unknowns)
            else:
                score = (100, 0)
            scores.append((g, score))
        return min(scores, key=lambda kv: kv[1])[0]

    def solvable(self):
        """Cheap consistency check of the current state.

        Requires at least one allowed digit per position, and for every clue
        no more credited positions than ``correct`` and at least ``correct``
        digits still allowed.
        """
        # is_compatible() already is the "enough allowed digits" test, so it
        # is evaluated once per clue.
        return (all(r > 0 for r in self.remaining()) and
                all(len(a) <= c and self.is_compatible(g, c) for (g, c, a) in self.known))

    def solved(self):
        """True when every position has one digit and every clue is fully credited."""
        return all(r == 1 for r in self.remaining()) and all(len(a) == c for (g, c, a) in self.known)

    def support(self, pos, n):
        """Number of clues whose guess has digit ``n`` at position ``pos``."""
        return sum(g[pos] == n for (g, _, _) in self.known)

    def assigned(self, pos, n):
        """True once ``assign(pos, n)`` fixed digit ``n`` at ``pos``.

        A position that merely has a single allowed digit left because of
        eliminations does not count: no clue has been credited for it yet,
        so it must stay available as a move.
        """
        return self._confirmed.get(pos) == n and bool(self.possibilities[pos, n])

    def eliminate(self, pos, n):
        """Rule out digit ``n`` at position ``pos`` (in place)."""
        self.possibilities[pos, n] = False

    def assign(self, pos, n):
        """Return a deep copy of this state with digit ``n`` fixed at ``pos``.

        Every clue showing ``n`` at ``pos`` is credited with that position.
        For every clue whose credited count has reached its ``correct``
        value, all of its non-credited digits are ruled out.
        """
        new = copy.deepcopy(self)
        new.possibilities[pos, :] = False
        new.possibilities[pos, n] = True
        new._confirmed[pos] = n
        for g, c, a in new.known:
            if g[pos] == n:
                a.add(pos)

            if len(a) == c:
                # Eliminate row
                for pos2, m in enumerate(g):
                    if pos2 in a:
                        continue
                    new.eliminate(pos2, m)
        return new
    

def possible_guesses(m, g):
    """List the ``(position, digit)`` moves that guess ``g`` still offers.

    A position qualifies when ``m.unknowns(g)`` is truthy there and
    ``m.assigned(position, digit)`` is false. The moves are returned sorted
    by ascending ``m.support(position, digit)``.
    """
    candidates = [
        (pos, g[pos])
        for pos, can in enumerate(m.unknowns(g))
        if can and not m.assigned(pos, g[pos])
    ]
    return sorted(candidates, key=lambda move: m.support(move[0], move[1]))

In [740]:
def solve(known):
    """Depth-first search with backtracking over ``Mastermind`` states.

    Repeatedly takes the last move offered by ``possible_guesses`` for the
    current state's ``best_guess``, applies it with ``assign`` and, when the
    new state has no moves left or is not ``solvable``, steps back and rules
    the failed move out in the parent state. Prints ``solved!`` and the final
    state on success; nothing is returned.

    Raises:
        ValueError: if every branch has been ruled out, i.e. no assignment
            is consistent with ``known``.
    """
    m = Mastermind(known)
    guesses = possible_guesses(m, m.best_guess())
    history = []
    while True:
        # Dead end: undo moves until a consistent state with untried moves remains.
        while not guesses or not m.solvable():
            if not history:
                # Nothing left to undo -- report it instead of popping an empty list.
                raise ValueError('no assignment is consistent with the known guesses')
            m, guesses, prev_move = history.pop()
            # Eliminate the previous possibility.
            m.eliminate(*prev_move)
        move = guesses.pop()
        history.append((m, guesses, move))
        m = m.assign(*move)
        if m.solved():
            print('solved!')
            print(m)
            break
        guesses = possible_guesses(m, m.best_guess())

In [717]:
# Small example puzzle: six 5-digit guesses, each followed by the number of
# digits that sit in the correct position.
short_known = parse_known("""
    90342 ;2 correct
    70794 ;0 correct
    39458 ;2 correct
    34109 ;1 correct
    51545 ;2 correct
    12531 ;1 correct
""")

In [706]:
%%time
# Time `solve` on the 5-digit example.
solve(short_known)

solved!
39542
possibilities:
  0 : X|X|X|3|X|X|X|X|X|X | c:1
  1 : X|X|X|X|X|X|X|X|X|9 | c:1
  2 : X|X|X|X|X|5|X|X|X|X | c:1
  3 : X|X|X|X|4|X|X|X|X|X | c:1
  4 : X|X|2|X|X|X|X|X|X|X | c:1
known:
  !9  !0  !3  ✓4  ✓2  | c:2 | a:2
  !7  !0  !7  !9  !4  | c:0 | a:0
  ✓3  ✓9  !4  !5  !8  | c:2 | a:2
  ✓3  !4  !1  !0  !9  | c:1 | a:1
  !5  !1  ✓5  ✓4  !5  | c:2 | a:2
  !1  !2  ✓5  !3  !1  | c:1 | a:1
CPU times: user 19 ms, sys: 95 µs, total: 19 ms
Wall time: 15.1 ms


In [745]:
# Full puzzle: 22 guesses of 16 digits each, with the number of digits in the
# correct position for every guess.
long_known = parse_known("""
    5616185650518293 ;2 correct
    3847439647293047 ;1 correct
    5855462940810587 ;3 correct
    9742855507068353 ;3 correct
    4296849643607543 ;3 correct
    3174248439465858 ;1 correct
    4513559094146117 ;2 correct
    7890971548908067 ;3 correct
    8157356344118483 ;1 correct
    2615250744386899 ;2 correct
    8690095851526254 ;3 correct
    6375711915077050 ;1 correct
    6913859173121360 ;1 correct
    6442889055042768 ;2 correct
    2321386104303845 ;0 correct
    2326509471271448 ;2 correct
    5251583379644322 ;2 correct
    1748270476758276 ;3 correct
    4895722652190306 ;1 correct
    3041631117224635 ;3 correct
    1841236454324589 ;3 correct
    2659862637316867 ;2 correct
""")

In [743]:
# Two-clue experiment that reuses the name `long_known`.
# NOTE(review): run top to bottom, this rebinding replaces the 22-clue puzzle
# before the next `solve(long_known)` cell -- confirm which data set is intended.
long_known = parse_known("""
    5616185650518293 ;15 correct
    0123456789097654 ;1 correct
""")

In [746]:
%%time
# The recorded run of this cell ended with a KeyboardInterrupt, i.e. it was
# stopped by hand before `solve` finished.
solve(long_known)

KeyboardInterrupt: 