In [26]:
import re
from typing import List

In [19]:
# Worked example for part 1.
s_example = """xmul(2,4)%&mul[3,7]!@^do_not_mul(5,5)+mul(32,64]then(mul(11,8)mul(8,5))"""

# Strip trailing newlines only. Slicing with [:-1] would silently drop a real
# character (possibly the ')' of a final instruction) if input.txt does not
# end with a newline.
with open('input.txt', 'r') as f:
    s_input = f.read().rstrip('\n')

# Part 1

## With regex

In [20]:
def mult_from_match(s: str) -> int:
    """
    Multiply the two operands of one matched instruction.

    s is expected to have the form mul(a,b), where a and b are integers of
    1, 2 or 3 digits. The leading 'mul(' and the trailing ')' are dropped and
    the remainder is split on the comma.
    Return a*b
    """

    operands = s[4:-1]  # text between 'mul(' and ')'
    left, right = operands.split(',')
    return int(left) * int(right)
    
def compute_sum(s: str) -> int:
    """
    Add up a*b over every mul(a,b) instruction that the regex finds in s,
    where a and b each have 1 to 3 digits.
    """

    pattern = r"mul\(\d{1,3},\d{1,3}\)"
    total = 0
    for instruction in re.findall(pattern, s):
        total += mult_from_match(instruction)
    return total

In [21]:
# Regex version on the worked example string.
compute_sum(s_example)

161

In [22]:
# Regex version on the contents of input.txt.
compute_sum(s_input)

165225049

## Without regex

In [30]:
def check_digits(s: str) -> bool:
    """
    Check whether s is a run of 1 to 3 decimal digits.

    str.isdecimal() is used rather than str.isdigit(): isdigit() also accepts
    characters such as superscript digits, which int() cannot parse and which
    the regex class \\d (used by the regex-based solution) does not match.
    isdecimal() accepts exactly the characters that \\d does.

    Returns True when every character is a decimal digit and the length is
    between 1 and 3 inclusive, False otherwise (including for the empty string).
    """

    return s.isdecimal() and (1 <= len(s) <= 3)
    
def valid_between(s: str) -> bool:
    """
    Tell whether s reads as a,b with a and b both passing check_digits
    (integers of 1 to 3 digits).

    The split point is the first comma located at index 1 or later; when no
    such comma exists the answer is False.
    """

    comma = s.find(',', 1)  # a comma at index 0 would leave an empty left operand
    if comma == -1:
        return False
    left, right = s[:comma], s[comma+1:]
    return check_digits(left) and check_digits(right)

def findall(s: str) -> List[str]:
    """
    Return every non-overlapping substring of s of the form mul(a,b), where a
    and b are integers of 1 to 3 digits, in order of appearance. This is the
    hand-written counterpart of re.findall(r"mul\\(\\d{1,3},\\d{1,3}\\)", s).

    For each occurrence of 'mul(' at index i, the closing parenthesis of a
    well-formed instruction can only sit at i+7 (one-digit operands) through
    i+11 (three-digit operands), so only that window is inspected.
    """

    res = []
    n = len(s)
    i = s.find('mul(')
    while i != -1:
        # By default only skip the 4 characters of 'mul(' itself: if this
        # candidate turns out to be invalid, another 'mul(' may start right
        # after it (e.g. "mul(mul(1,2)") and must still be examined.
        resume = i + 4
        # Clamp the window to the string so a truncated 'mul(' near the end
        # of s cannot index out of range.
        for j in range(i + 7, min(i + 12, n)):
            if s[j] == ')':
                if valid_between(s[i+4:j]):  # valid digits and comma in between
                    res.append(s[i:j+1])
                    resume = j + 1  # a valid match is consumed entirely
                # A later ')' cannot help: the text in between would then
                # contain this ')', so stop at the first one either way.
                break
        i = s.find('mul(', resume)
    return res

def compute_sum2(s):
    """
    Same total as compute_sum, but the instructions are located with the
    hand-written findall instead of a regex.
    """

    total = 0
    for instruction in findall(s):
        total += mult_from_match(instruction)
    return total

In [31]:
# Non-regex version on the worked example string.
compute_sum2(s_example)

161

In [32]:
# Non-regex version on the contents of input.txt.
compute_sum2(s_input)

165225049

# Part 2

In [33]:
def findall_corrupted(s: str) -> List[str]:
    """
    Scan s from left to right and return, in order of appearance, a list of
    tokens:
      - the full text 'mul(a,b)' of every well-formed instruction
        (a and b integers of 1 to 3 digits),
      - 'do'   for every occurrence of do(),
      - 'dont' for every occurrence of don't().

    A 'mul(' that does not start a well-formed instruction only consumes its
    own 4 characters. Whatever follows (a do(), a don't() or another 'mul(')
    is then picked up by the normal scan, so nothing hidden inside an invalid
    candidate is lost.
    """

    res = []
    n = len(s)
    i = 0
    while i < n:
        if s.startswith('mul(', i):  # found mul(
            # The closing parenthesis of a valid instruction lies between
            # i+7 and i+11 inclusive; clamp to the string so a truncated
            # 'mul(' near the end of s cannot index out of range.
            match_end = -1
            for j in range(i + 7, min(i + 12, n)):
                if s[j] == ')':
                    if valid_between(s[i+4:j]):  # valid digits and comma
                        match_end = j
                    break  # a later ')' could never yield a valid match
            if match_end != -1:
                res.append(s[i:match_end+1])
                i = match_end + 1  # consume the whole instruction
            else:
                i += 4  # only throw away the substring mul(

        elif s.startswith('do()', i):  # found do()
            res.append('do')
            i += 4  # look further

        elif s.startswith("don't()", i):  # found don't()
            res.append('dont')
            i += 7  # look further

        else:
            i += 1
    return res

def compute_sum_corrupted(s):
    """
    Sum a*b over the mul(a,b) tokens returned by findall_corrupted(s) that
    are enabled when they are met.

    Instructions start out enabled; a 'dont' token disables the ones that
    follow and a 'do' token enables them again.
    """

    total = 0
    enabled = True
    for token in findall_corrupted(s):
        if token == "dont":
            enabled = False
        elif token == "do":
            enabled = True
        elif enabled:
            total += mult_from_match(token)
    return total

In [34]:
# Part 2 worked example: it contains a don't() and, later, a do() (inside "undo()").
s_example_corrupted = "xmul(2,4)&mul[3,7]!^don't()_mul(5,5)+mul(32,64](mul(11,8)undo()?mul(8,5))"
compute_sum_corrupted(s_example_corrupted)

48

In [35]:
# Part 2 on the contents of input.txt.
compute_sum_corrupted(s_input)

108830766