In [1]:
import numpy as np

In [2]:
def get_initial_poss(n_cells, eq_cells):
    return np.array([[*['']*i, '=', *['']*(n_cells-1-i)] for i in eq_cells])

def expand_poss(poss, col_num, explist):
    newvals = np.where(poss[:, col_num:col_num+1] == '',
                       np.expand_dims(explist, 0),
                       np.column_stack((poss[:, col_num:col_num+1], 
                                        np.full((poss.shape[0], 
                                                 explist.shape[0]-1), ''))))    
    poss_out = np.repeat(poss, np.sum(newvals != '', axis=1), axis=0)
    poss_out[:, col_num] = newvals.reshape(-1)[newvals.reshape(-1) != '']
    return poss_out

def get_filters(poss, col_num):
    filters = [np.ones((poss.shape[0]), dtype=bool)]
    if col_num > 0:
        # Back to back operators
        filters += [~(is_operator(poss, col_num) & is_operator(poss, col_num-1))]
    if 1 <= col_num <= 6:
        # Number starting with zero
        filters += [~((poss[:, col_num] == '0') & is_operator(poss, col_num-1))]
    if col_num in[3, 4, 5]:
        # Operator before equals
        filters += [~(is_operator(poss, col_num) & (poss[:, col_num+1] == '='))]
        # No operator on LHS
        filters += [~((poss[:, col_num+1] == '=') 
                      & (~np.any(np.column_stack([is_operator(poss, c) 
                                                  for c in range(col_num+1)]), axis=1)))]
    return np.all(np.column_stack(filters), axis=1)

def is_operator(poss, col_num):
    return (((poss[:, col_num] != '') 
             & (poss[:, col_num].view(np.uint32) < ord('0')))
            | (poss[:, col_num].view(np.uint32) > ord('9')))

def complete_rhs(poss, col_num):
    lhs_comp_ix = np.arange(poss.shape[0])[poss[:, col_num+1] == '=']

    lhs = poss[lhs_comp_ix][:, :col_num+1].reshape(-1).view(f'<U{col_num+1}')
    lhs_result = np.round(np.array([eval(i) for i in lhs]), 6)
    targ_res_len = poss.shape[1] - col_num - 2

    filt = np.all(np.column_stack((lhs_result >= 0,
                                   np.mod(lhs_result, 1) == 0,
                                   np.maximum(np.floor(np.log10(np.maximum(lhs_result, 1e-6)))+1, 1)
                                       == targ_res_len)),
                  axis=1)
    poss[lhs_comp_ix[filt], col_num+2:] = lhs_result[filt].astype(int)\
                                                          .astype(f'<U{targ_res_len}')\
                                                          .view('<U1')\
                                                          .reshape(-1, targ_res_len)
    return np.delete(poss, lhs_comp_ix[~filt], axis=0)

def get_all_poss():
    poss = get_initial_poss(8, [4, 5, 6])
    explist = np.arange(1, 10).astype('<U1')
    for col_num in range(6):
        poss = expand_poss(poss, col_num, explist)
        poss = poss[get_filters(poss, col_num)]
        if col_num == 0:
            explist = np.append(explist, list('0+-*/'))
        elif col_num >= 3:
            poss = complete_rhs(poss, col_num)
    return poss

In [3]:
poss = get_all_poss()
poss.shape[0], poss[:10]

(17723,
 array([['1', '2', '*', '9', '=', '1', '0', '8'],
        ['1', '3', '*', '8', '=', '1', '0', '4'],
        ['1', '3', '*', '9', '=', '1', '1', '7'],
        ['1', '4', '*', '8', '=', '1', '1', '2'],
        ['1', '4', '*', '9', '=', '1', '2', '6'],
        ['1', '5', '*', '7', '=', '1', '0', '5'],
        ['1', '5', '*', '8', '=', '1', '2', '0'],
        ['1', '5', '*', '9', '=', '1', '3', '5'],
        ['1', '6', '*', '7', '=', '1', '1', '2'],
        ['1', '6', '*', '8', '=', '1', '2', '8']], dtype='<U1'))

In [166]:
i = 5
ix[i] = np.arange(poss.shape[0])[poss[:, i] == '=']
used = np.column_stack([np.where(np.any(poss[ix[i]] == c, axis=1), c, ' ') 
                        for c in '*+-/0123456789'])
sortorder = np.argsort(used.reshape(-1).view('<U14'))

In [167]:
tv1 = poss[ix[i][used.reshape(-1).view('<U14') == ' +   12345   9']].reshape(-1).view('<U8')
tv1

array(['14+25=39', '15+24=39', '24+15=39', '25+14=39', '41+52=93',
       '42+51=93', '51+42=93', '52+41=93'], dtype='<U8')

In [169]:
np.sum(poss[ix[5]].reshape(-1).view('<U8') == '14+25=39')

np.int64(1)

In [187]:
combos.reshape(-1).view('<U8').reshape(-1, 2) 

(14594820, 2)

In [203]:
n_elim[tv1] == 13

array([False, False, False, ..., False, False, False], shape=(1388,))

In [206]:
combos[tv1].reshape(-1).view('<U8').reshape(-1, 2)[n_elim[tv1] == 13]

array([['14+25=39', '56-7*8=0'],
       ['14+25=39', '56/7-8=0'],
       ['14+25=39', '60-7*8=4'],
       ['14+25=39', '60/4-8=7'],
       ['14+25=39', '70-8*8=6'],
       ['14+25=39', '90/6-8=7']], dtype='<U8')

In [196]:
tv1 = np.any(combos.reshape(-1).view('<U8').reshape(-1, 2) == '14+25=39', axis=1)

tv1[n_elim[tv1] == 13]


IndexError: boolean index did not match indexed array along axis 0; size of axis is 14594820 but size of corresponding boolean axis is 1388

In [178]:
tv1 = np.repeat(poss[ix[i]], ix[j].shape[0], axis=0).reshape(-1).view('<U8')
np.arange(tv1.shape[0])[tv1 == '14+25=39']

array([444160, 444161, 444162, ..., 445545, 445546, 445547], shape=(1388,))

In [182]:
np.repeat(poss[ix[i]], ix[j].shape[0], axis=0).shape

(14594820, 8)

In [184]:
combos.shape

(14594820, 16)

In [None]:
np.sum(np.repeat(poss[ix[i]], ix[j].shape[0], axis=0).reshape(-1).view('<U8')  == '14+25=39')

In [176]:
i, j

(5, 6)

In [183]:
# ix = dict()
# for i in range(5, 7):
#     ix[i] = np.arange(poss.shape[0])[poss[:, i] == '=']
#     used = np.column_stack([np.where(np.any(poss[ix[i]] == c, axis=1), c, ' ') 
#                             for c in '*+-/0123456789'])
#     sortorder = np.argsort(used.reshape(-1).view('<U14'))
#     ix[i] = ix[i][np.any(used[sortorder]
#                          != np.append([[' ']*14], 
#                                       used[sortorder][:-1], 
#                                       axis=0), 
#                          axis=1)[np.argsort(sortorder)]]

i = 5
j = 6

combos = np.column_stack((np.repeat(poss[ix[i]], ix[j].shape[0], axis=0),
                          np.tile(poss[ix[j]], (ix[i].shape[0], 1))))
n_elim = np.sum(np.column_stack([np.any(combos == c, axis=1) for c in '*+-/0123456789']), axis=1)
combos[n_elim == 14].reshape(-1).view('<U8').reshape(-1, 2)

array([['1+5*9=46', '30/2-7=8'],
       ['1+6*8=49', '7-20/5=3'],
       ['1+7*9=64', '8-30/5=2'],
       ['1+8*6=49', '7-20/5=3'],
       ['1+8*9=73', '40/5-2=6'],
       ['1+9*5=46', '30/2-7=8'],
       ['1+9*7=64', '8-30/5=2'],
       ['1+9*8=73', '40/5-2=6'],
       ['2+5*9=47', '18/6-3=0'],
       ['2+9*5=47', '18/6-3=0'],
       ['2*3+4=10', '96/8-5=7'],
       ['2*3+9=15', '60/4-8=7'],
       ['2*4+5=13', '90/6-8=7'],
       ['2*7+4=18', '60/5-9=3'],
       ['2*8-6=10', '35/7+4=9'],
       ['3+5*9=48', '10-6/2=7'],
       ['3+6*7=45', '2-18/9=0'],
       ['3+6*9=57', '10-4/2=8'],
       ['3+7*6=45', '2-18/9=0'],
       ['3+7*8=59', '6-20/4=1'],
       ['3+8*7=59', '6-20/4=1'],
       ['3+8*9=75', '6-20/4=1'],
       ['3+9*5=48', '10-6/2=7'],
       ['3+9*6=57', '10-4/2=8'],
       ['3+9*8=75', '6-20/4=1'],
       ['3*2+4=10', '96/8-5=7'],
       ['3*2+9=15', '60/4-8=7'],
       ['3*4+8=20', '9-56/7=1'],
       ['3*4+9=21', '56/7-8=0'],
       ['3*6-8=10', '45/9+2=7'],
       ['3

In [108]:
filt = np.all(np.column_stack([
                  poss[:, 5] == '=',
                  np.all(poss != '5', axis=1),
                  np.all(poss != '*', axis=1),
                  np.all(poss != '0', axis=1),
                  np.all(poss != '/', axis=1),
                  np.all(poss != '2', axis=1),
                  np.all(poss != '-', axis=1),
                  np.all(poss != '7', axis=1),
                  poss[:, 0] != '1',
                  np.sum(poss == '1', axis=1) >= 1,
                  poss[:, 1] != '+',
                  np.sum(poss == '+', axis=1) >= 1,
                  poss[:, 4] != '9',
                  np.sum(poss == '9', axis=1) >= 1,
                  poss[:, 6] != '4',
                  np.sum(poss == '4', axis=1) >= 1,
                  poss[:, 7] != '6',
                  np.sum(poss == '6', axis=1) >= 1,
                  poss[:, 0] != '3',
                  np.sum(poss == '3', axis=1) >= 1,
                  poss[:, 7] != '8',
                  np.sum(poss == '8', axis=1) >= 1,
              ]),
              axis=1)

poss[filt]

array([['6', '9', '+', '1', '4', '=', '8', '3']], dtype='<U1')

In [124]:
def poss_after_2(fb, g2='1+5*9=4630/2-7=8'):
    filt = []
    for i in range(16):
        if fb[i] == '2':
            filt += [poss[:, i % 8] == g2[i]]
        elif fb[i] == '0':
            filt += [np.all(poss != g2[i], axis=1)]
        else:
            filt += [poss[:, i % 8] != g2[i]]
            filt += [np.any(poss == g2[i], axis=1)]    
    filt = np.all(np.column_stack(filt), axis=1)
    return poss[filt].reshape(-1).view('<U8')

In [129]:
poss_after_2('1010020100001111')

array(['68-11=57', '68-17=51', '68-51=17', '68-57=11', '85-18=67',
       '85-68=17', '86-11=75', '86-15=71', '86-71=15', '86-75=11'],
      dtype='<U8')

In [130]:
poss_after_2('1000220110011011')

array(['82-19=63', '82-69=13'], dtype='<U8')

In [131]:
poss_after_2('0012022010000010')

array(['5*3*3=45'], dtype='<U8')

In [132]:
poss_after_2('1000120020011012')

array(['39-11=28'], dtype='<U8')

In [503]:
# words = poss.reshape(-1).view('<U8').tolist()

# From https://github.com/pedrokkrause/Nerdle-Equations/blob/main/wordle.py:
from math import log2
from tqdm import tqdm
from collections import defaultdict
from time import sleep
from copy import deepcopy

def genmaskW(expression,correct):
    mask = ['N']*len(expression)
    count = defaultdict(lambda: 0)
    for i, x in enumerate(expression):
        if x == correct[i]:
            mask[i] = '2'
            count[x] += 1
    for i,x in enumerate(expression):
        if x not in correct:
            mask[i] = '0'
        elif x != correct[i]:
            if count[x] < correct.count(x):
                mask[i] = '1'
                count[x] += 1
            else:
                mask[i] = '0'
    return(''.join(mask))

def checkmaskW(mask,expression,correct):
    if mask == genmaskW(expression,correct):
        return(True)
    else:
        return(False)

def allmasksW(expression,dictionary):
    masks = defaultdict(lambda: 0)
    for expression2 in dictionary:
        mask = genmaskW(expression,expression2)
        masks[mask] += 1
    return(masks)

def filtermaskW(mask,expression,dictionary):
    local = [x for x in dictionary if checkmaskW(mask,expression,x)]
    return(local)

def searchW(dictionary,possible=None):
    if possible==None:
        possible = dictionary
    if len(possible) == 1 or len(dictionary) == 1:
        return([possible[0],0,possible[0] in dictionary])
    best = None
    expected = 0
    length = len(possible)
    for expression in tqdm(dictionary):
        localexp = 0
        allmaskss = allmasksW(expression,possible)
        for mask in allmaskss:
            prob = allmaskss[mask]/length
            localexp += -prob*log2(prob)
        if localexp > expected or (localexp >= expected and expression in possible):
            best = expression
            expected = localexp
            if len(possible) > 10000:
                print([best,expected,best in possible])
    return([best,expected,best in possible])

possiblewords = deepcopy(words)
print("Number of possible words:",len(possiblewords))
best = ['48-32=16']
while True:
    ans1 = input("What you inserted (enter 'b' if it was the best choice): ")
    ans2 = input("Mask: ")
    if ans1 == 'b':
        ans1 = best[0]
    possiblewords = filtermaskW(ans2,ans1,possiblewords)
    print("Number of possible words:",len(possiblewords))
    best = searchW(words,possiblewords)
    print("Best choice:",best[0])
    print("Average information:",best[1])
    if len(possiblewords) < 7:
        print("Possible words:",possiblewords)
    if ans2 == '22222222':
        break
    print("============")

Number of possible words: 17723


What you inserted (enter 'b' if it was the best choice):  14+25=39
Mask:  20011110


Number of possible words: 20


100%|██████████| 17723/17723 [00:00<00:00, 18198.02it/s]


Best choice: 13-5-6=2
Average information: 4.321928094887363


What you inserted (enter 'b' if it was the best choice):  b
Mask:  22220021


Number of possible words: 1
Best choice: 13-5*2=3
Average information: 0
Possible words: ['13-5*2=3']


What you inserted (enter 'b' if it was the best choice):  b
Mask:  22222222


Number of possible words: 1
Best choice: 13-5*2=3
Average information: 0
Possible words: ['13-5*2=3']
