In [1]:
import itertools
import numpy as np
import random
from copy import deepcopy

In [2]:
# Read the puzzle input; the context manager closes the file handle promptly
# instead of leaving it to the garbage collector.
with open('inputs/24') as input_file:
    puzzle_input = input_file.read().strip()

In [3]:
# Example model number to experiment with (13579246899999 was tried earlier).
ex_model_num = 93539246899999
# One int per decimal digit, in the order they appear (leftmost first).
model_digits = [int(ch) for ch in str(ex_model_num)]

In [4]:
# Display the digits parsed from the example model number.
model_digits

[9, 3, 5, 3, 9, 2, 4, 6, 8, 9, 9, 9, 9, 9]

In [5]:
# Peek at the first 20 ALU instructions (the second `inp w` shows up at index 18).
puzzle_input.splitlines()[:20]

['inp w',
 'mul x 0',
 'add x z',
 'mod x 26',
 'div z 1',
 'add x 12',
 'eql x w',
 'eql x 0',
 'mul y 0',
 'add y 25',
 'mul y x',
 'add y 1',
 'mul z y',
 'mul y 0',
 'add y w',
 'add y 7',
 'mul y x',
 'add z y',
 'inp w',
 'mul x 0']

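(Program authors should be especially cautious; attempting to execute `div` with `b=0` or attempting to execute `mod` with `a<0` or `b<=0` will cause the program to crash and might even damage the ALU. These operations are never intended in any serious ALU program.) In other words, the simplifier below may assume that every `div` has a nonzero divisor and every `mod` has a non-negative dividend and a positive modulus.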

In [6]:
# Total number of ALU instructions in the program.
len(puzzle_input.splitlines())

252

Var = d1 | d2 | ... | d14
where each d_i is a single nonzero digit (1–9)

Expr = Int | Var | Opr Expr Expr

Opr Expr Expr = Add Expr Expr | Mul Expr Expr | Div Expr Expr | Mod Expr Expr | Eql Expr Expr

Add (Int x) (Int y) = Int (x + y)
Add (Int 0) x = x
Add x (Int 0) = x

Div (Int x) (Int y) = Int (x / y, truncated toward zero — the ALU's division)
Div e (Int 1) = e
Div e y = Div e y   (otherwise left unsimplified)

In [7]:
def _trunc_div(a, b):
    """Integer division rounding toward zero (the ALU's ``div``).

    Equivalent to ``int(a / b)`` for small operands, but computed purely in
    integers so it stays exact when |a| exceeds 2**53 (where float division
    silently loses precision). Raises ZeroDivisionError when ``b == 0``.
    """
    quotient = abs(a) // abs(b)
    # Signs agree -> non-negative quotient; otherwise negate (toward zero).
    return quotient if (a < 0) == (b < 0) else -quotient


# Concrete semantics of each binary ALU instruction, keyed by mnemonic.
op_lookup = {
    'add': lambda x, y: x + y,
    'mul': lambda x, y: x * y,
    'div': _trunc_div,
    'mod': lambda x, y: x % y,
    'eql': lambda x, y: int(x == y),
}

In [8]:
# Python operator used when an ``Opr`` node is rendered by repr(); the rendered
# string is later passed to eval(). Caveat: 'div' renders as floor division
# (//), while the ALU truncates toward zero -- the two only disagree when the
# operands have opposite signs.
op_str_lookup = dict(
    add="+",
    mul="*",
    div="//",
    mod="%",
    eql="==",
)

In [9]:
from dataclasses import dataclass


class Expression:
    """Base class for nodes of a symbolic ALU expression tree.

    Every node's repr() is a Python expression string (digits appear as the
    names d1, d2, ...), so a tree can be evaluated with eval() once those
    names are bound.
    """


# @dataclass generates field-wise __eq__ (same class + equal fields) and
# __match_args__ for the positional class patterns used by reduce(); the
# hand-written __repr__ methods below take precedence over generated ones.

@dataclass
class MyInt(Expression):
    """Integer literal; repr is the bare number."""
    n: int

    def __repr__(self):
        return str(self.n)


@dataclass
class Digit(Expression):
    """The ``digit_num``-th input digit (1-based), rendered as ``d<digit_num>``."""
    digit_num: int

    def __repr__(self):
        return f'd{self.digit_num}'


@dataclass
class Opr(Expression):
    """Binary ALU operation ``opr`` (a mnemonic key such as 'add') applied to two sub-expressions."""
    opr: str
    left: Expression
    right: Expression

    def __repr__(self):
        # Fully parenthesized so the string re-parses with the same structure.
        return f"({self.left!r} {op_str_lookup[self.opr]} {self.right!r})"

In [10]:
def reduce(op: Opr) -> Expression:
    match op:
        case Opr(op_name, MyInt(x), MyInt(y)):
            return MyInt(op_lookup[op_name](x, y))

        case Opr('add', MyInt(0), y):
            return y
        
        case Opr('add', x, MyInt(0)):
            return x
        
        # (d3 + 2) + 12)
        case Opr('add', Opr('add', e, MyInt(j)), MyInt(k)):
            return Opr('add', e, MyInt(j+k))
        
        # ???????
        # ((d1 * 26) + 182) + (d2 + 8)
        # => (d1*26) + d2 + 190
        case Opr('add', Opr('add', a, MyInt(b)), Opr('add', c, MyInt(d))):
            # is this extra reduction necessary here? hmm
            return Opr('add', reduce(Opr('add', a, c)), MyInt(b+d))

        case Opr('mul', x, MyInt(1)):
            return x
        case Opr('mul', MyInt(1), y):
            return y
        case Opr('mul', x, MyInt(0)):
            return MyInt(0)
        case Opr('mul', MyInt(0), y):
            return MyInt(0)

        # distributive
        # case Opr('mul', Opr('add', a, b), c):
        #      return reduce(Opr('add', reduce(Opr('mul', a, c)), reduce(Opr('mul', b, c))))
        
        # (d1*2)*2 = d1 * 2 
        case Opr('mul', Opr('mul', Digit(d), MyInt(x)), MyInt(y)):
            return Opr('mul', Digit(d), MyInt(x*y))

        case Opr('div', x, MyInt(1)):
            return x
        
        # ((((((((d1 + 7) * 26) + (d2 + 8)) * 26) + (d3 + 2)) * 26) + (d4 + 11)) // 26)
        # check to see if the thing on left is multiple of 26
        case Opr("div", Opr('add', Opr('mul', a, MyInt(j)), b), MyInt(k)):
            print('whacked out case div')
            if j == k:
                return reduce(Opr("add", a, reduce(Opr('div', b, MyInt(j)))))
            
            return op
        
        # ((d4 + 11) // 26)
        case Opr("div", Opr("add", Digit(d), MyInt(k)), MyInt(n)):
            if 10 + k < 26:
                return MyInt(0)
            
            return op
        
        case Opr('mod', MyInt(0), y):
            return MyInt(0)

        case Opr('mod', Digit(d), MyInt(y)):
            if y >= 10:
                return Digit(d)

            return op
        
        # (a + b) % n = (a % n + b % n) % n
        # actually.. this will cause infinite recursion. need something more specific
        # case Opr('mod', Opr('add', a, b), n):
        #    return reduce(Opr('mod', reduce(Opr('add', reduce(Opr('mod', a, n)), reduce(Opr('mod', b, n)))), n))
        
        
        # (((d1 + 7) * 26) + (d2 + 8)) % 26
        
        case Opr("mod", Opr('add', Opr('mul', a, MyInt(j)), b), MyInt(k)):
            print('whacked out case mod')
            if j == k:
                return reduce(Opr('mod', b, MyInt(k)))
            
            return op
        
        
        # (a % n) % n = a % n
        case Opr('mod', Opr('mod', a, n), k):
            print('mod collapse')
            if n == k:
                # need recursion?
                return Opr('mod', a, n)
            
            return op
        
        # (d1 * 17576) % 26)
        case Opr('mod', Opr('mul', e, MyInt(x)), MyInt(n)):
            print("new case")
            if x >= n:
                # reduces necessary?
                return reduce(Opr('mod', reduce(Opr('mul', e, MyInt(x % n))), MyInt(n)))
            
            return op
        
        
        case Opr('mod', Opr('mul', a, b), e):
            print('reducing mod to 0')
            if b == e:
                return MyInt(0)
            
            return op
        

        # (d4 + 11) % 26
        # 9 + j
        case Opr('mod', Opr('add', Digit(d), MyInt(j)), MyInt(k)):
            if 9 + j < k:
                return Opr('add', Digit(d), MyInt(j))
            
            return op

        case Opr('eql', Digit(d), MyInt(y)):
            if y <= 0 or y >= 10:
                return MyInt(0)

            return op

        case Opr('eql', MyInt(x), Digit(d)):
            if x <= 0 or x >= 10:
                return MyInt(0)
            
            return op

        # ((a % b) + 20) == d1
        case Opr('eql', Opr('add', Opr('mod', a, b), MyInt(x)), Digit(d)):
            if x >= 10:
                return MyInt(0)

            return op
        
        
        # (d1 + 20) == d2
        # smallest valid configuration is
        # (d1 + 8) == d2 where d1=1 and d2=9
        case Opr('eql', Opr('add', Digit(d1), MyInt(x)), Digit(d2)):
            if x > 8:
                return MyInt(0)

            return op


        case _:
            return op

In [11]:
# Small hand-written ALU programs for exercising the evaluator; each is the
# newline-joined instruction list (no trailing newline).
test_input = "\n".join([
    "inp z",
    "inp x",
    "mul z 3",
    "eql z x",
])

test_input2 = "\n".join([
    "inp w",
    "add z w",
    "mod z 2",
    "div w 2",
    "add y w",
    "mod y 2",
    "div w 2",
    "add x w",
    "mod x 2",
    "div w 2",
    "mod w 2",
])

In [12]:
# The program as a list of instruction strings, one per input line.
ins = puzzle_input.splitlines()

In [13]:
# Peek at the last 20 instructions.
ins[-20:]

['mul y x',
 'add z y',
 'inp w',
 'mul x 0',
 'add x z',
 'mod x 26',
 'div z 26',
 'add x -11',
 'eql x w',
 'eql x 0',
 'mul y 0',
 'add y 25',
 'mul y x',
 'add y 1',
 'mul z y',
 'mul y 0',
 'add y w',
 'add y 5',
 'mul y x',
 'add z y']

In [14]:
# The instruction 18 from the end -- the final `inp w`.
ins[-18]

'inp w'

# Symbolically evaluate the first SYMBOLIC_STEPS instructions, keeping every
# input as a variable d1, d2, ... and simplifying with reduce() after each
# instruction. general_expressions[i] is a snapshot of all four registers
# after instruction i.
SYMBOLIC_STEPS = 100

general_expressions = []

expressions = {
    'x': MyInt(0),
    'y': MyInt(0),
    'z': MyInt(0),
    'w': MyInt(0)
}

digit_counter = 0

for instruction in puzzle_input.splitlines()[:SYMBOLIC_STEPS]:
    instr, *rest = instruction.split()

    if instr == 'inp':
        letter = rest[0]
        digit_counter += 1
        expressions[letter] = Digit(digit_counter)
    else:
        var, other = rest
        left = expressions[var]
        # The second operand is either a register name or an integer literal.
        right = expressions[other] if other in expressions else MyInt(int(other))
        expressions[var] = reduce(Opr(instr, left, right))

    # Nothing in this notebook mutates Expression nodes in place (reduce()
    # returns new or existing nodes), so a shallow copy of the register map
    # is an independent snapshot without deep-copying whole trees each step.
    general_expressions.append(dict(expressions))

# Randomized consistency check for reduce(): run the program concretely on
# random model numbers and, after every instruction that was also evaluated
# symbolically, compare each register with the value obtained by eval()-ing
# the repr of the matching snapshot in general_expressions.
CHECK_TRIALS = 1000
CHECK_SEED = 24  # fixed so that a failing model number can be reproduced
check_rng = random.Random(CHECK_SEED)

for _ in range(CHECK_TRIALS):
    model_digits = [check_rng.randint(1, 9) for _ in range(14)]
    # Bindings for the names d1..d14 that appear in the symbolic reprs.
    digit_env = {f'd{pos}': value for pos, value in enumerate(model_digits, start=1)}

    expressions = {
        'x': MyInt(0),
        'y': MyInt(0),
        'z': MyInt(0),
        'w': MyInt(0)
    }

    digit_counter = 0

    expressions_memory = []

    # general_expressions only covers the instructions that were evaluated
    # symbolically; zip() stops there instead of indexing past its end.
    steps = zip(puzzle_input.splitlines(), general_expressions)
    for i, (instruction, general_step) in enumerate(steps):
        instr, *rest = instruction.split()

        if instr == 'inp':
            letter = rest[0]
            digit_counter += 1
            expressions[letter] = MyInt(model_digits[digit_counter - 1])
        else:
            var, other = rest
            left = expressions[var]
            right = expressions[other] if other in expressions else MyInt(int(other))
            expressions[var] = reduce(Opr(instr, left, right))

        # Check consistency with the general (symbolic) case.
        for k in expressions:
            # NOTE: repr renders ALU div as Python floor division (//), which
            # disagrees with the ALU's truncating division when operand signs
            # differ -- a mismatch there may be a false alarm.
            expected = eval(str(general_step[k]), digit_env)
            if expressions[k].n != expected:
                print(instruction)
                print(k)
                print(model_digits)
                print(f"step: {i}")

                # Slices (not [-2]/[i-2]) so early steps neither raise
                # IndexError nor wrap around to the end of the list.
                print("Previous concrete couple steps:")
                for snapshot in expressions_memory[-2:]:
                    print(snapshot)
                print("Current step:")
                print(expressions)

                print("Previous general couple steps:")
                for snapshot in general_expressions[max(i - 2, 0):i]:
                    print(snapshot)

                print("Failing step:")
                print(general_step)

                print(expected)
                raise AssertionError(
                    f"register {k} is {expressions[k].n} but the symbolic "
                    f"expression evaluates to {expected} at step {i}"
                )

        # Concrete values are MyInt nodes that are never mutated in place, so
        # a shallow copy is a sufficient snapshot.
        expressions_memory.append(dict(expressions))

In [None]:
# Candidate model number, one value per input digit (d1 is the value consumed
# by the first `inp`), bound as globals so eval() below can resolve them.
(d1, d2, d3, d4, d5, d6, d7,
 d8, d9, d10, d11, d12, d13, d14) = (9, 9, 9, 1, 9, 9, 9, 7, 1, 8, 6, 5, 9, 6)

# Evaluate the last recorded symbolic z with those digits.
# NOTE(review): general_expressions[-1] is the state after the last
# *symbolically evaluated* instruction; if the symbolic pass was truncated
# this is not the program's final z -- confirm before trusting the result.
eval(str(general_expressions[-1]['z']))