# Imports and Utilities

In [None]:
import numpy as np
import bisect
import re
from collections import Counter, defaultdict, deque
import itertools
import functools


In [None]:
def read_input(day, map_f=lambda x: x.strip(), newline='\n'):
    """Read ``input{day}.txt`` and return its records after applying ``map_f``.

    The file content is stripped of surrounding whitespace, split on
    ``newline``, and every resulting chunk is passed through ``map_f``.
    """
    path = f"input{day}.txt"
    with open(path, 'r') as handle:
        chunks = handle.read().strip().split(newline)
    return [map_f(chunk) for chunk in chunks]

def binary_search(a, x, low=0, hi=None):
    """Return the leftmost index of ``x`` within sorted ``a[low:hi]``, or -1.

    ``a`` must be sorted ascending over ``[low, hi)``; ``hi`` defaults to
    ``len(a)``. Only positions inside the searched window are reported.
    """
    if hi is None:
        hi = len(a)
    i = bisect.bisect_left(a, x, low, hi)
    # bisect_left returns a value in [low, hi]; compare against ``hi`` (not
    # len(a)) so that an equal element just past the window is not reported.
    if i < hi and a[i] == x:
        return i
    return -1

# Day 1

In [None]:
# Target sum for the expense report and a multiset (value -> occurrences) of its entries.
target_sum = 2020
numbers_count = Counter(read_input(1, map_f=int))

def find_pair(numbers_count, target_sum):
    """Print the first pair of entries that sums to ``target_sum``.

    ``numbers_count`` maps each value to how many times it occurs (e.g. a
    ``Counter``). A value can be paired with itself only if it occurs at
    least twice. Prints a "not found" message when no pair exists.
    Returns None.
    """
    for n in numbers_count:
        left = target_sum - n
        # Using ``n`` twice needs two copies of it; any other partner needs one.
        needed = 2 if left == n else 1
        if numbers_count.get(left, 0) >= needed:
            print(f"Numbers found: {n},{left}\tProduct: {left*n}")
            return
    print(f"Pair with sum equal to {target_sum} not found.")


def find_triple(numbers_count, target_sum):
    """Print the first triple of entries that sums to ``target_sum``.

    ``numbers_count`` maps each value to its multiplicity (e.g. a
    ``Counter``). A value may appear in the triple at most as many times as
    it occurs in the input. Prints a "not found" message when no triple
    exists. Returns None.
    """
    for n1 in numbers_count:
        for n2 in numbers_count:
            left = target_sum - n1 - n2
            # How many copies of each distinct value this triple consumes;
            # covers n1 == n2, left == n1, left == n2 and all three equal.
            needed = Counter((n1, n2, left))
            if all(numbers_count.get(v, 0) >= c for v, c in needed.items()):
                print(f"Numbers found: {n1},{n2},{left}\tProduct: {left*n1*n2}")
                return
    print(f"Triple with sum equal to {target_sum} not found.")

# Day 1 answers: part 1 (pair) and part 2 (triple).
find_pair(numbers_count=numbers_count, target_sum=target_sum)
find_triple(numbers_count=numbers_count, target_sum=target_sum)

# Day 2

In [None]:
# Each line "1-3 a: abcde" becomes the string tuple ("1", "3", "a", "abcde").
password_list = read_input(
    2,
    lambda line: re.match(r"(\d+)\-(\d+) ([a-z]): ([a-z]+)", line).groups(),
)

def count_valid_passwords(password_list):
    """Print, and return, how many passwords satisfy each policy.

    Each entry of ``password_list`` is ``(low, high, c, password)``, where
    ``low``/``high`` are ints or decimal strings.

    * First method: ``c`` occurs between ``low`` and ``high`` times (inclusive).
    * Second method: exactly one of the 1-based positions ``low``/``high`` of
      ``password`` holds ``c``. A position outside ``1..len(password)`` never
      holds ``c``. Previously, position 0 wrapped to the last character and an
      overlong position raised IndexError.

    Returns ``(valid_m1, valid_m2)``.
    """
    def holds(password, pos, c):
        # 1-based lookup that treats out-of-range positions as "no match".
        return 1 <= pos <= len(password) and password[pos - 1] == c

    valid_m1 = 0
    valid_m2 = 0
    for low, high, c, password in password_list:
        low = int(low)
        high = int(high)
        if low <= password.count(c) <= high:
            valid_m1 += 1
        if holds(password, low, c) != holds(password, high, c):
            valid_m2 += 1

    print(f"Valid passwords according to first method: {valid_m1}")
    print(f"Valid passwords according to second method: {valid_m2}")
    return valid_m1, valid_m2

# Day 2 answers for both password policies.
count_valid_passwords(password_list=password_list)

# Day 3

In [None]:
# Tree map: one string per row, '#' marks a tree; the pattern repeats to the right.
map_t = read_input(day=3)

def count_trees_on_slope(map_t, slope):
    """Count trees ('#') hit when sledding down from the top-left corner.

    ``map_t`` is a list of equal-length row strings that repeat horizontally.
    ``slope`` is ``(right, down)``. The starting square is not counted, and
    an empty map has no trees.

    Raises ValueError if ``down`` is not positive. The previous loop never
    terminated for ``down == 0``.
    """
    d_j, d_i = slope
    if d_i <= 0:
        raise ValueError(f"vertical step must be positive, got {d_i}")
    if not map_t:
        return 0
    cols = len(map_t[0])
    trees = 0
    # The row visited on step k is k * d_i; its column is k * d_j, wrapped.
    for step, i in enumerate(range(d_i, len(map_t), d_i), start=1):
        if map_t[i][(step * d_j) % cols] == '#':
            trees += 1
    return trees

# Part 1: a single slope of right 3, down 1.
trees_31 = count_trees_on_slope(map_t, (3, 1))
print(f"Trees count: {trees_31}")

# Part 2: multiply the tree counts across several slopes.
slopes = [(1, 1), (3, 1), (5, 1), (7, 1), (1, 2)]
trees_product = functools.reduce(
    lambda acc, s: acc * count_trees_on_slope(map_t, s), slopes, 1)
print(f"Trees product for different slopes: {trees_product}")

# Day 4

In [None]:
# Records are separated by blank lines. Inside a record, "key:value" pairs
# are separated by spaces or newlines. Each record becomes a dict.
passports = [
    dict(pair.split(':') for pair in pair_list)
    for pair_list in read_input(4, map_f=lambda x: re.split('\n| ', x), newline='\n\n')
]

def check_height(s):
    """Return True for a valid height: 150-193 followed by "cm", or 59-76 followed by "in"."""
    rules = (
        (r"^([0-9]+)cm$", 150, 193),
        (r"^([0-9]+)in$", 59, 76),
    )
    for pattern, low, high in rules:
        found = re.match(pattern, s)
        if found:
            return low <= int(found.group(1)) <= high
    return False
        
# Required passport fields paired with their validators. A validator takes the
# field's string value and returns something truthy when the value is valid.
keys_and_checks = [
    ('byr', lambda x: re.match(r"^[0-9]{4}$", x) and 1920 <= int(x) <= 2002),
    ('iyr', lambda x: re.match(r"^[0-9]{4}$", x) and 2010 <= int(x) <= 2020),
    ('eyr', lambda x: re.match(r"^[0-9]{4}$", x) and 2020 <= int(x) <= 2030),
    ('hgt', check_height),
    ('hcl', lambda x: re.match(r"^\#[0-9a-f]{6}$", x)),
    ('ecl', lambda x: x in ("amb", "blu", "brn", "gry", "grn", "hzl", "oth")),
    ('pid', lambda x: re.match(r"^[0-9]{9}$", x)),
]

def count_valid_passports(passports, necessary_keys):
    """Return ``(with_all_keys, with_all_keys_and_valid_values)`` counts.

    ``necessary_keys`` is a sequence of ``(key, check)`` pairs. ``check``
    receives the field's string value and returns something truthy if the
    value is valid. All checks are evaluated for every complete passport.
    """
    complete = [p for p in passports
                if all(key in p for key, _ in necessary_keys)]
    valid = [p for p in complete
             if all([check(p[key]) for key, check in necessary_keys])]
    return len(complete), len(valid)

# Day 4 answers: passports with all fields, and with all fields valid.
valid_m1, valid_m2 = count_valid_passports(passports=passports, necessary_keys=keys_and_checks)
print(f"Valid passports: {valid_m1}, {valid_m2}")

# Day 5

In [None]:
# Seat grid dimensions: 7 row bits (F/B) and 3 column bits (L/R).
rows, cols = 128, 8

def decode_row(s):
    """Decode a 7-character F/B row code into its row number (F = 0, B = 1)."""
    assert len(s) == 7
    return int(s.translate(str.maketrans('BF', '10')), 2)

def decode_col(s):
    """Decode a 3-character L/R column code into its column number (L = 0, R = 1)."""
    assert len(s) == 3
    return int(s.translate(str.maketrans('LR', '01')), 2)

def decode_seat(s):
    """Split a 10-character boarding pass into ``(row, col)``."""
    row_code, col_code = s[:7], s[7:]
    return decode_row(row_code), decode_col(col_code)

def seat2id(x):
    """Return the seat id ``row * cols + col`` for a ``(row, col)`` pair."""
    row, col = x[0], x[1]
    return row * cols + col

def id2seat(seat_id):
    """Inverse of ``seat2id``: return ``(row, col)`` for a seat id."""
    return divmod(seat_id, cols)

# Every boarding pass decoded to a (row, col) pair.
seats = set(read_input(5, map_f=decode_seat))

# Part 1: the highest seat id on any boarding pass.
max_id_seat = max(seats, key=lambda seat: seat2id(seat))
print(f"Max seat id: {seat2id(max_id_seat)}")

# Part 2: my seat is missing from the list, but the seat ids directly above
# and below it are both present. The first and last rows are skipped.
for i, j in itertools.product(range(1, rows - 1), range(cols)):
    if (i, j) in seats:
        continue
    seat_id = seat2id((i, j))
    if id2seat(seat_id + 1) in seats and id2seat(seat_id - 1) in seats:
        print(f"My seat id: {seat_id}")

# Day 6    

In [None]:
# Groups are separated by blank lines; each group is a list of per-person answer strings.
group_answers = read_input(6, newline='\n\n', map_f=lambda group: group.split("\n"))

def count_answers_anyone(group_answers):
    """Sum, over all groups, the number of distinct questions anyone in the group answered."""
    return sum(len(set().union(*group)) for group in group_answers)

def count_answers_everyone(group_answers):
    """Sum, over all groups, the number of questions every member of the group answered.

    Each group is a list of answer strings, one per person. A question counts
    for a group when it appears in every person's string. A letter repeated
    within one person's answers is counted once; previously, raw letter
    frequency could mistake one person's repeats for several people.
    Empty groups contribute nothing.
    """
    count = 0
    for group in group_answers:
        if not group:
            continue
        count += len(set.intersection(*(set(person) for person in group)))
    return count

anyone_total = count_answers_anyone(group_answers)
print(f"Count of questions anyone in group answered: {anyone_total}")
everyone_total = count_answers_everyone(group_answers)
print(f"Count of questions everyone in group answered: {everyone_total}")

# Day 7

In [None]:
# Each rule reads "<color> bags contain <n> <color> bag(s), ..." or "... no other bags".
rules_list = read_input(7, map_f=lambda x: re.match(r"([a-z ]+) bags contain ([0-9a-z, ]*)", x).groups())

# Forward edges (bag -> [(count, inner color)]) and reverse edges
# (inner color -> [bags that directly contain it]).
color2content = defaultdict(list)
color2parents = defaultdict(list)

for outer, content in rules_list:
    outer = outer.strip()
    for raw_count, inner in re.findall(r"([0-9]+)([a-z ]+) bag", content):
        inner = inner.strip()
        color2content[outer].append((int(raw_count), inner))
        color2parents[inner].append(outer)

def count_bags_containing_color(color, color2parents):
    """Return how many distinct bag colors can transitively contain ``color``.

    ``color2parents`` maps a color to the colors that directly contain it.
    Lookups use ``[]``, so colors with no parents need a default (e.g. a
    ``defaultdict(list)``).
    """
    seen = set()
    frontier = list(color2parents[color])
    while frontier:
        parent = frontier.pop()
        if parent not in seen:
            seen.add(parent)
            frontier.extend(color2parents[parent])
    return len(seen)

def count_total_bags_contained(color, color2content, _memo=None):
    """Return the total number of bags nested inside one ``color`` bag.

    ``color2content`` maps a color to ``[(count, inner_color), ...]``. Lookups
    use ``[]``, so leaf colors need a default (e.g. a ``defaultdict(list)``).

    Sub-totals are memoised in ``_memo`` (internal), so a color reachable
    along many paths is expanded only once. The unmemoised recursion is
    exponential on rule graphs with heavy sharing.
    """
    if _memo is None:
        _memo = {}
    if color in _memo:
        return _memo[color]
    total = 0
    for count, inner in color2content[color]:
        # Each of the ``count`` inner bags counts itself plus its own contents.
        total += count * (1 + count_total_bags_contained(inner, color2content, _memo))
    _memo[color] = total
    return total

n_containers = count_bags_containing_color('shiny gold', color2parents)
print(f"Total number of bags that can eventually contain shiny gold bag: {n_containers}")
n_inside = count_total_bags_contained('shiny gold', color2content)
print(f"Total number of bags shiny gold bag contains: {n_inside}")


# Day 8

In [None]:
# Boot code: each line "op arg" becomes the token list [op, arg].
instructions = read_input(8, map_f=str.split)

def acc_before_instr_repeats(instructions):
    """Run the boot code and report the accumulator and how execution stopped.

    ``instructions`` is a sequence of ``(op, arg)`` pairs, where ``arg`` is a
    signed decimal string such as ``"+3"``. ``acc`` adds to the accumulator,
    ``jmp`` jumps relatively, and any other op behaves like ``nop``.

    Returns ``(acc, status)``:
    * ``status`` is ``+1`` when execution moves past the end of the program
      (pc >= len).
    * ``status`` is ``-1`` when an instruction is about to run a second time,
      or when a jump leaves the program backwards (pc < 0). Previously a
      negative pc silently indexed from the end of the list.
    """
    acc = 0
    pc = 0
    executed = set()
    while True:
        if pc in executed or pc < 0:
            return acc, -1
        if pc >= len(instructions):
            return acc, +1
        op, arg = instructions[pc]
        executed.add(pc)
        if op == "acc":
            acc += int(arg)
            pc += 1
        elif op == "jmp":
            pc += int(arg)
        else:
            pc += 1

def fix_code(instructions):
    """Return the accumulator after repairing the single corrupted instruction.

    Tries swapping each ``jmp`` <-> ``nop`` one at a time until
    ``acc_before_instr_repeats`` reports normal termination (status ``+1``).
    Each trial patch is undone before moving on or returning, so the
    caller's program is left unmodified. Previously the successful patch was
    left in place. Returns None if no single swap fixes the program.
    """
    swap = {"jmp": "nop", "nop": "jmp"}
    for i, (op, arg) in enumerate(instructions):
        if op not in swap:
            continue
        original = instructions[i]
        instructions[i] = (swap[op], arg)
        try:
            acc, ret = acc_before_instr_repeats(instructions)
        finally:
            # Always restore the original instruction object.
            instructions[i] = original
        if ret == 1:
            return acc
    return None

looped_acc, _ = acc_before_instr_repeats(instructions)
print(f"Acc before instruction repeats: {looped_acc}")
repaired_acc = fix_code(instructions)
print(f"Acc after fixing the code: {repaired_acc}")

# Day 9

In [None]:
# XMAS data stream, one integer per line.
numbers = read_input(9, int)

def sum_exists(num_list, target):
    """Return True if two entries of ``num_list`` (at different positions) sum to ``target``.

    A value can pair with itself only when it appears at least twice.
    Previously, a single occurrence was enough.
    """
    num_count = Counter(num_list)
    for n1 in num_count:
        n2 = target - n1
        needed = 2 if n1 == n2 else 1
        # Counter returns 0 for missing keys without inserting them.
        if num_count[n2] >= needed:
            return True
    return False

def find_first_invalid_num(num_list, preamble_len):
    """Return the first number that is not a sum of two of the ``preamble_len`` numbers before it.

    Scans ``num_list`` from index ``preamble_len`` onward. Previously it read
    the global ``numbers`` instead of ``num_list``. Returns None if every
    number is valid.
    """
    for i in range(preamble_len, len(num_list)):
        if not sum_exists(num_list[i - preamble_len:i], num_list[i]):
            return num_list[i]
    return None

def find_contigous_sum(num_list, target):
    """Find a contiguous run of at least two numbers that sums to ``target``.

    Uses a sliding window, so ``num_list`` must be non-negative. Returns
    ``(l, r, min + max)``, where ``num_list[l:r]`` is the run. The previous
    version read the globals ``invalid_num``/``numbers`` instead of its
    arguments and could return a single-element run.

    Raises ValueError if no such run exists.
    """
    left = right = 0  # the window is num_list[left:right]
    window_sum = 0
    while True:
        width = right - left
        if width >= 2 and window_sum == target:
            window = num_list[left:right]
            return left, right, min(window) + max(window)
        if width < 2 or window_sum < target:
            if right == len(num_list):
                raise ValueError(f"no contiguous run of at least two numbers sums to {target}")
            window_sum += num_list[right]
            right += 1
        else:
            # width >= 2 and window_sum > target: drop the leftmost number.
            window_sum -= num_list[left]
            left += 1

# Part 1: first number that is not a sum of two of the previous 25.
invalid_num = find_first_invalid_num(numbers, preamble_len=25)
print(f"First invalid number: {invalid_num}")
# Part 2: contiguous run summing to that number.
left, right, min_max_sum = find_contigous_sum(num_list=numbers, target=invalid_num)
print(f"Min max sum: {min_max_sum}")

# Day 10

In [None]:
# Adapter joltage ratings in ascending order.
adapters = sorted(read_input(10, map_f=int))

def count_jolt_differences(adapters):
    """Return (number of 1-jolt gaps) * (number of 3-jolt gaps) along the adapter chain.

    ``adapters`` must be a sorted list. The chain starts at the outlet (0)
    and ends at the device (largest adapter + 3).
    """
    chain = [0] + adapters + [adapters[-1] + 3]
    gaps = Counter(b - a for a, b in zip(chain, chain[1:]))
    return gaps[1] * gaps[3]

def count_diff_arangements(adapters):
    """Return how many distinct adapter chains lead from the outlet (0) to the top adapter.

    ``adapters`` must be sorted ascending; consecutive chain elements may
    differ by 1 to 3 jolts. Counts are exact Python ints. Previously a float
    array was used, which loses precision above 2**53 and prints as e.g.
    ``8.0``. An empty adapter list has exactly one arrangement.
    """
    if not adapters:
        return 1
    ways = [0] * (adapters[-1] + 1)
    ways[0] = 1
    for adapter in adapters:
        # Reach ``adapter`` from any of the (up to) three joltages below it.
        ways[adapter] = sum(ways[max(0, adapter - 3):adapter])
    return ways[-1]

jolt_product = count_jolt_differences(adapters)
print(f"Multiplied count of differences equal to 3 and 1: {jolt_product}")
arrangement_count = count_diff_arangements(adapters)
print(f"Number of different arrangements: {arrangement_count}")

# Day 11

In [None]:
# Seat layout as a 2-D character array of '.', 'L' and '#'.
initial_map = np.array(read_input(11, map_f=list))

# The eight (d_row, d_col) offsets to neighbouring cells, in row-major order.
DIRECTIONS = [(d_i, d_j)
              for d_i in (-1, 0, 1)
              for d_j in (-1, 0, 1)
              if (d_i, d_j) != (0, 0)]

def count_occupied_at_equilibrium(current, neighbor_f, occupied_tresh):
    """Run the seating rules until nothing changes; return the number of occupied seats.

    ``current`` is a 2-D array of '.', 'L' and '#'. ``neighbor_f(grid, i, j)``
    returns an array of the cells that seat (i, j) considers. An empty seat
    with no occupied neighbours becomes occupied. An occupied seat with at
    least ``occupied_tresh`` occupied neighbours becomes empty. The argument
    array itself is not modified.
    """
    n_rows, n_cols = current.shape
    while True:
        nxt = current.copy()
        for i, j in itertools.product(range(n_rows), range(n_cols)):
            cell = current[i, j]
            if cell == '.':
                continue
            busy = np.count_nonzero(neighbor_f(current, i, j) == '#')
            if cell == 'L' and busy == 0:
                nxt[i, j] = '#'
            elif cell == '#' and busy >= occupied_tresh:
                nxt[i, j] = 'L'
        if np.array_equal(current, nxt):
            return np.count_nonzero(current == '#')
        current = nxt

def find_neighbors1(mat, i, j):
    """Return the 8 cells adjacent to (i, j) in ``DIRECTIONS`` order; off-grid cells read as '.'."""
    n_rows, n_cols = mat.shape
    cells = []
    for d_i, d_j in DIRECTIONS:
        r, c = i + d_i, j + d_j
        inside = 0 <= r < n_rows and 0 <= c < n_cols
        cells.append(mat[r, c] if inside else '.')
    return np.array(cells)

def find_neighbors2(mat, i, j, directions=None):
    """Return the first seat visible from (i, j) along each direction.

    Walks from (i, j) in steps of each ``(d_i, d_j)`` in ``directions``
    (default: the module-level ``DIRECTIONS``). Anything other than 'L' or
    '#' is skipped, and the first 'L' or '#' reached is recorded. A direction
    that leaves the grid first yields '.'. Returns an array with one entry per
    direction, in ``directions`` order.

    The previous version looked at most ``max(rows, cols) - 1`` steps. A line
    of floor reaching the far edge then left that direction unset and tripped
    its ``assert len(neighbors) == 8`` (for example, on any 1x1 grid).
    """
    if directions is None:
        directions = DIRECTIONS
    n_rows, n_cols = mat.shape
    seen = []
    for d_i, d_j in directions:
        if (d_i, d_j) == (0, 0):
            raise ValueError("direction (0, 0) never leaves the cell")
        r, c = i + d_i, j + d_j
        # The step is non-zero, so this walk always ends at a seat or the edge.
        while 0 <= r < n_rows and 0 <= c < n_cols and mat[r, c] not in ('L', '#'):
            r += d_i
            c += d_j
        inside = 0 <= r < n_rows and 0 <= c < n_cols
        seen.append(mat[r, c] if inside else '.')
    return np.array(seen)

occupied_adjacent = count_occupied_at_equilibrium(initial_map.copy(), find_neighbors1, 4)
print("Occupied space at equilibrium first part: ", occupied_adjacent)
occupied_visible = count_occupied_at_equilibrium(initial_map.copy(), find_neighbors2, 5)
print("Occupied space at equilibrium second part: ", occupied_visible)

# Day 12

In [None]:
# Navigation instructions such as "F10" become ("F", 10).
instructions = read_input(12, lambda line: (line[0], int(line[1:])))

# Unit vectors (east, north) for the compass commands.
directions_dict = {
    'N': np.array([0, 1]),
    'S': np.array([0, -1]),
    'E': np.array([1, 0]),
    'W': np.array([-1, 0]),
}

def rotate(x, theta):
    """Rotate 2-D vector(s) ``x`` counter-clockwise by ``theta`` degrees, rounding to whole numbers."""
    rad = np.deg2rad(theta)
    cos_t, sin_t = np.cos(rad), np.sin(rad)
    rotation = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
    return np.round(x @ rotation.T)

def find_final_position(instructions):
    """Apply part-1 navigation rules (the ship moves itself) and return its final position.

    Starts at the origin facing east. L/R rotate the heading by the given
    degrees, F moves along the heading, and N/S/E/W move in that compass
    direction.
    """
    pos = np.array([0., 0])
    heading = np.array([1., 0])
    for command, n in instructions:
        if command in ('L', 'R'):
            heading = rotate(heading, n if command == 'L' else -n)
        elif command == 'F':
            pos += n * heading
        else:
            pos += n * directions_dict[command]
    return pos

def find_final_position_following_waypoint(instructions):
    """Apply part-2 navigation rules (the ship follows a waypoint) and return the ship's position.

    The waypoint starts 10 east and 1 north of the ship and is tracked in
    absolute coordinates. L/R rotate it about the ship, F moves the ship
    ``n`` times by the ship-to-waypoint offset (the waypoint moves along),
    and N/S/E/W shift only the waypoint.
    """
    ship = np.array([0., 0])
    waypoint = np.array([10., 1])
    for command, n in instructions:
        offset = waypoint - ship
        if command in ('L', 'R'):
            angle = n if command == 'L' else -n
            waypoint = rotate(offset, angle) + ship
        elif command == 'F':
            ship += n * offset
            waypoint = ship + offset
        else:
            waypoint += n * directions_dict[command]
    return ship
    
# Part 1: the instructions move the ship directly.
pos = find_final_position(instructions)
print(f"Final position following rules from part 1: {pos}; Manhattan distance: {np.abs(pos).sum()}")
# Part 2: the instructions move a waypoint that the ship follows.
pos = find_final_position_following_waypoint(instructions)
print(f"Final position following rules from part 2: {pos}; Manhattan distance: {np.abs(pos).sum()}")


# Day 13

In [None]:
# Line 1 is my earliest departure time; line 2 is the comma-separated bus schedule.
earliest_timestamp, buses = read_input(day=13)
earliest_timestamp = int(earliest_timestamp)

def egcd(a, b):
    """Extended Euclidean algorithm: return ``(g, x, y)`` with ``a*x + b*y == g``.

    For non-negative inputs, ``g`` is ``gcd(a, b)``.
    """
    old_r, r = a, b
    old_s, s = 1, 0
    old_t, t = 0, 1
    while r != 0:
        quotient = old_r // r
        old_r, r = r, old_r % r
        old_s, s = s, old_s - s * quotient
        old_t, t = t, old_t - t * quotient
    return old_r, old_s, old_t
        
def invmod(a, m):
    """Return the multiplicative inverse of ``a`` modulo ``m``, in ``[0, m)``.

    Raises ValueError if ``a`` has no inverse (``a`` and ``m`` not coprime).
    The previous ``assert`` check disappears under ``python -O``; the
    built-in three-argument ``pow`` (Python 3.8+) validates unconditionally.
    """
    return pow(a, -1, m)
    
def chinese_remainder(divs, rems):
    """Solve ``x ≡ rems[k] (mod divs[k])`` for pairwise-coprime ``divs``.

    Returns the unique solution in ``[0, prod(divs))``.
    """
    modulus = functools.reduce(lambda acc, d: acc * d, divs, 1)
    x = sum((modulus // d) * r * invmod(modulus // d, d) for d, r in zip(divs, rems))
    return x % modulus

def find_next_departure(t, buses):
    """Return ``(departure_time, bus_id)`` for the first bus leaving at or after ``t``.

    ``buses`` is the comma-separated schedule; 'x' entries are ignored. A bus
    whose period divides ``t`` departs at ``t`` itself. The previous version
    used the global ``earliest_timestamp`` instead of ``t`` and returned
    ``t + id`` in that case.

    Raises ValueError if the schedule lists no buses.
    """
    known_buses = [int(x) for x in buses.split(',') if x != 'x']
    if not known_buses:
        raise ValueError("schedule contains no buses")
    # -(-t // x) is ceil(t / x) for positive x.
    departures = [-(-t // x) * x for x in known_buses]
    best = min(range(len(known_buses)), key=departures.__getitem__)
    return departures[best], known_buses[best]

def find_earliest_timestamp_offsets(buses):
    """Return the smallest ``t >= 0`` such that the bus at list index ``dt`` departs at ``t + dt``.

    ``buses`` is the comma-separated schedule; 'x' entries impose no
    constraint. Bus ids must be pairwise coprime, as ``chinese_remainder``
    requires. The constraint ``t + dt ≡ 0 (mod id)`` is encoded directly as
    ``t ≡ -dt (mod id)``. The previous ``dt_max`` shift could return a
    negative timestamp.
    """
    divs, rems = [], []
    for dt, bus_id in enumerate(buses.split(',')):
        if bus_id == 'x':
            continue
        period = int(bus_id)
        divs.append(period)
        rems.append(-dt % period)
    return chinese_remainder(divs, rems)

ts, bus_id = find_next_departure(earliest_timestamp, buses)
wait_product = (ts - earliest_timestamp) * bus_id

print(f"Bus with id {bus_id} departs at {ts}. Waiting time multiplied by id: {wait_product}")
print(f"Earliest timestamp such that all of the listed bus IDs\n"
      f"depart at offsets matching their positions in the list:", find_earliest_timestamp_offsets(buses))

# Day 14

In [None]:
# Initialization program: "mask = ..." and "mem[addr] = value" lines.
instructions = read_input(day=14)

def mask_num(mask, n):
    """Apply a part-1 bitmask to the value ``n``.

    ``mask`` is a string of '0', '1' and 'X', most significant bit first.
    '1' and '0' force that bit; 'X' keeps the corresponding bit of ``n``.
    Bits above the mask width are cleared. This is done with two integer bit
    operations instead of a round trip through numpy string arrays.
    """
    force_one = int(mask.replace("X", "0"), 2)  # bits the mask sets to 1
    keep = int(mask.replace("X", "1"), 2)       # bits that survive ('X' or '1')
    return (n & keep) | force_one

def get_all_addr(mask, addr):
    """Return every memory address a part-2 mask produces for ``addr``, as ints.

    In ``mask`` (most significant bit first), '1' forces the bit to 1, '0'
    leaves it unchanged, and 'X' floats (takes both values). Results are
    ordered as ``itertools.product("01", ...)`` over the floating bits, from
    most to least significant. A mask with no 'X' yields exactly one address.
    Previously that case returned a numpy character array instead of an int.
    """
    width = len(mask)
    base = addr | int(mask.replace("X", "0"), 2)
    floating = [width - 1 - k for k, ch in enumerate(mask) if ch == "X"]
    result = []
    for combo in itertools.product("01", repeat=len(floating)):
        value = base
        for bit, choice in zip(floating, combo):
            if choice == "1":
                value |= 1 << bit
            else:
                value &= ~(1 << bit)
        result.append(value)
    return result

def memory_sum_part1(instructions):
    """Run the program with part-1 semantics (the mask applies to values) and return the memory sum.

    Memory is keyed by the address text exactly as it appears in the input.
    """
    mask = "X" * 36
    memory = {}
    for line in instructions:
        if line.startswith("mask"):
            mask = line[7:]
            continue
        address, value = re.match(r"mem\[([0-9]+)\] = ([0-9]+)" ,line).groups()
        memory[address] = mask_num(mask, int(value))
    return sum(memory.values())

def memory_sum_part2(instructions):
    """Run the program with part-2 semantics (the mask decodes addresses) and return the memory sum.

    Every floating address is enumerated explicitly, so each write costs
    2**(number of 'X' bits). This is only practical for masks with few 'X's.
    """
    mask = "X" * 36
    memory = {}
    for line in instructions:
        if line.startswith("mask"):
            mask = line[7:]
            continue
        raw_addr, raw_val = re.match(r"mem\[([0-9]+)\] = ([0-9]+)" ,line).groups()
        value = int(raw_val)
        memory.update(dict.fromkeys(get_all_addr(mask, int(raw_addr)), value))
    return sum(memory.values())

part1_memory = memory_sum_part1(instructions)
print(f"Memory sum the first part: {part1_memory}")
part2_memory = memory_sum_part2(instructions)
print(f"Memory sum the second part: {part2_memory}")

# Day 15

In [None]:
# Starting numbers of the memory game (puzzle input).
input_numbers = [0, 13, 1, 8, 6, 15]

def play_game_2020num(nums, ret_iter=2020):
    """Play the elves' memory game starting from ``nums``; return the number spoken on turn ``ret_iter``.

    Each turn, if the previous number was new the player says 0; otherwise
    they say how many turns ago it was last spoken before that. Turns are
    1-based. Only the last turn each number was spoken is kept, so memory is
    O(distinct numbers). The previous version ignored ``nums`` (it read the
    global ``input_numbers``), kept every turn in lists, and could not answer
    turns that fall within the starting list.

    Raises ValueError for an empty ``nums`` or ``ret_iter < 1``.
    """
    if not nums:
        raise ValueError("need at least one starting number")
    if ret_iter < 1:
        raise ValueError("turns are numbered from 1")
    if ret_iter <= len(nums):
        return nums[ret_iter - 1]
    # Turn on which each number was last spoken, excluding the most recent turn.
    last_seen = {n: turn for turn, n in enumerate(nums[:-1], start=1)}
    last = nums[-1]
    for turn in range(len(nums), ret_iter):
        # ``last`` was spoken on ``turn``; work out what is said on ``turn + 1``.
        prev = last_seen.get(last)
        last_seen[last] = turn
        last = 0 if prev is None else turn - prev
    return last
for label, turn in (("2020th number: ", 2020), ("30000000th number: ", 30000000)):
    print(label, play_game_2020num(input_numbers, turn))


# Day 16

In [None]:
# Input layout: field rules, a blank line, "your ticket:" and one ticket,
# another blank line, then "nearby tickets:" and the remaining tickets.
lines = read_input(16)
fields = defaultdict(list)

for i, line in enumerate(lines):
    if line == '':
        break
    name, x1, x2, y1, y2 = re.match('([a-z ]+): ([0-9]+)-([0-9]+) or ([0-9]+)-([0-9]+)', line).groups()
    fields[name] = [[int(x1), int(x2)], [int(y1), int(y2)]]

# ``i`` is now the index of the first blank line.
my_ticket = [int(x) for x in lines[i+2].split(",")]
tickets = [[int(x) for x in line.split(",")] for line in lines[i+5:]]
# A leftover `functools.reduce(...)` over the field names was removed: its
# result was discarded, and it raised TypeError when there were no fields.

def is_element_of_any_interval(intervals, n):
    """Return True if ``low <= n <= high`` for some ``[low, high]`` in ``intervals``."""
    return any(interval[0] <= n <= interval[1] for interval in intervals)
    
def check_tickets(fields, tickets):
    """Separate out tickets that contain impossible values.

    A value is impossible if it fits none of the intervals of any field.
    Returns ``(error_rate, valid_tickets)``: ``error_rate`` is the sum of all
    impossible values, and ``valid_tickets`` keeps, in order, the tickets
    that have none.
    """
    all_intervals = functools.reduce(lambda a, b: a + b, fields.values())
    error_rate = 0
    valid_tickets = []
    for ticket in tickets:
        bad = [n for n in ticket if not is_element_of_any_interval(all_intervals, n)]
        error_rate += sum(bad)
        if not bad:
            valid_tickets.append(ticket)
    return error_rate, valid_tickets

def check_field(values, intervals):
    """Return True if every entry of the array ``values`` lies in at least one ``[low, high]`` interval."""
    inside = np.array([(low <= values) & (values <= high) for low, high in intervals])
    return inside.any(axis=0).all()

total_error, valid_tickets = check_tickets(fields=fields, tickets=tickets)
print("Total error: ", total_error)
# My own ticket joins the valid set as its last entry for field identification.
valid_tickets += [my_ticket]

def identify_fields(fields, tickets):
    """Work out which ticket column holds each field.

    ``fields`` maps a field name to its list of ``[low, high]`` intervals, and
    ``tickets`` is a list of equal-length integer lists. A column is a
    candidate for a field when ``check_field`` accepts every value in it.
    Fields are then settled one at a time: pick a field with exactly one
    candidate column left, then remove that column from all other fields.

    Returns ``{field name: column index}``. Raises ValueError if at some
    point no field has a unique candidate. In that situation the previous
    loop silently reused a stale loop variable and dropped an arbitrary
    field.
    """
    tickets = np.array(tickets)
    candidates = {
        name: {col for col in range(tickets.shape[1])
               if check_field(tickets[:, col], intervals)}
        for name, intervals in fields.items()
    }

    decoded = {}
    while candidates:
        settled = next((name for name, cols in candidates.items() if len(cols) == 1), None)
        if settled is None:
            raise ValueError(f"cannot assign columns uniquely to fields: {sorted(candidates)}")
        (col,) = candidates.pop(settled)
        decoded[settled] = col
        for cols in candidates.values():
            cols.discard(col)
    return decoded

def departure_product(fields, tickets):
    """Multiply my ticket's values for every field whose name contains "departure".

    ``tickets`` are the valid tickets, with my own ticket as the LAST entry.
    """
    column_of = identify_fields(fields, tickets)
    mine = tickets[-1]
    return functools.reduce(
        lambda acc, col: acc * mine[col],
        (col for name, col in column_of.items() if "departure" in name),
        1,
    )

# Part 2: product of the departure-related fields on my ticket.
print("Departure fields product: ", departure_product(fields=fields, tickets=valid_tickets))

# Day 17

In [None]:
# Initial 2-D slice of the pocket dimension; '#' marks an active cube.
initial_state = read_input(day=17)

def get_neighbors(pos):
    """Yield every position adjacent to ``pos`` (a tuple of ints) in any number of dimensions.

    Offsets follow ``itertools.product([-1, 0, 1], ...)`` order, and ``pos``
    itself is skipped.
    """
    for offset in itertools.product([-1, 0, 1], repeat=len(pos)):
        if any(offset):
            yield tuple(coord + delta for coord, delta in zip(pos, offset))

def count_active_neighbors(active, pos):
    """Return how many neighbours of ``pos`` are in the ``active`` set."""
    return sum(1 for neighbor in get_neighbors(pos) if neighbor in active)

def count_active(initial_state, cycles=6, dim=3):
    """Simulate the Conway-cube automaton and return the number of active cubes.

    ``initial_state`` is a list of strings ('#' = active) placed as the 2-D
    slice at the origin of a ``dim``-dimensional grid (``dim >= 2``). In each
    cycle, an active cube stays active with 2 or 3 active neighbours, and an
    inactive cube becomes active with exactly 3.
    """
    active = {
        tuple([i, j] + [0] * (dim - 2))
        for i, row in enumerate(initial_state)
        for j, cell in enumerate(row)
        if cell == '#'
    }
    for _ in range(cycles):
        # One pass over the active cubes tallies how many active neighbours
        # every nearby position has. This replaces re-scanning the whole
        # neighbourhood of each candidate cell.
        neighbor_counts = Counter(
            n_pos for pos in active for n_pos in get_neighbors(pos))
        active = {
            pos for pos, n in neighbor_counts.items()
            if n == 3 or (n == 2 and pos in active)
        }
    return len(active)

active_3d = count_active(initial_state, cycles=6, dim=3)
print(f"Cubes left active after 6 cycles in 3D space: {active_3d}")
active_4d = count_active(initial_state, cycles=6, dim=4)
print(f"Cubes left active after 6 cycles in 4D space: {active_4d}")

# Day 18

In [None]:
# Tokenise each expression: pad the parentheses with spaces, then split on whitespace.
expressions = read_input(
    18,
    map_f=lambda s: s.replace("(", " ( ").replace(")", " ) ").split(),
)

# Operator precedence tables: part 1 treats + and * equally; part 2 binds + tighter.
op_precedence1 = dict.fromkeys("+*", 1)
op_precedence2 = {"+": 2, "*": 1}
def is_int(s):
    """Return True if ``int(s)`` succeeds, False if it raises ValueError."""
    try:
        int(s)
    except ValueError:
        return False
    else:
        return True
    
def shunting_yard_parse(s, op_precedence):
    """Convert an infix token list to postfix (RPN) order using the shunting-yard algorithm.

    ``s`` contains integer literals, '(' , ')' and operators that are keys of
    ``op_precedence`` (a higher value binds tighter; equal precedence is
    left-associative). Returns a deque of tokens in postfix order.

    Raises ValueError on unbalanced parentheses or unknown tokens. Previously
    these caused an IndexError, or a '(' leaked into the output, or the token
    was silently dropped.
    """
    output_q = deque()
    operator_stack = []
    for token in s:
        if is_int(token):
            output_q.append(token)
        elif token == "(":
            operator_stack.append(token)
        elif token == ")":
            while operator_stack and operator_stack[-1] != "(":
                output_q.append(operator_stack.pop())
            if not operator_stack:
                raise ValueError("unbalanced ')' in expression")
            operator_stack.pop()  # discard the matching '('
        elif token in op_precedence:
            token_p = op_precedence[token]
            while (operator_stack and operator_stack[-1] != "("
                   and op_precedence[operator_stack[-1]] >= token_p):
                output_q.append(operator_stack.pop())
            operator_stack.append(token)
        else:
            raise ValueError(f"unexpected token {token!r}")
    while operator_stack:
        op = operator_stack.pop()
        if op == "(":
            raise ValueError("unbalanced '(' in expression")
        output_q.append(op)
    return output_q

def evaluate_postfix(expression):
    """Evaluate a postfix token sequence of integers, '+' and '*'; any other token is skipped."""
    apply = {"+": lambda a, b: a + b, "*": lambda a, b: a * b}
    stack = []
    for token in expression:
        if is_int(token):
            stack.append(int(token))
        elif token in apply:
            stack.append(apply[token](stack.pop(), stack.pop()))
    return stack.pop()

def sum_of_all_results(expressions, op_precedence):
    """Evaluate every tokenised expression using ``op_precedence`` and return the sum."""
    return sum(evaluate_postfix(shunting_yard_parse(tokens, op_precedence))
               for tokens in expressions)

flat_total = sum_of_all_results(expressions, op_precedence1)
print(f"Sum of all results with equal precedence: {flat_total}")
plus_first_total = sum_of_all_results(expressions, op_precedence2)
print(f"Sum of all results with addition higher precedence: {plus_first_total}")


# Day 19

In [None]:
# Rules, a blank line, then the messages to check.
lines = read_input(day=19)

def unary_rule(c, s):
    """Match the single character ``c`` at the start of ``s``.

    Returns ``[(True, rest_of_s)]`` on success, otherwise ``[(False, s)]``.
    """
    if s and s[0] == c:
        return [(True, s[1:])]
    return [(False, s)]

def and_rule(rule_ids, s):
    """Match the rules ``rule_ids`` in sequence against ``s``.

    Rules are looked up in the module-level ``rules`` dict. Only successful
    partial matches are carried into the next rule. Returns the list of
    ``(matched, remainder)`` results produced by the last rule.
    """
    states = [(True, s)]
    for rule_id in rule_ids:
        states = [result
                  for ok, remainder in states if ok
                  for result in rules[rule_id](remainder)]
    return states

def or_and_rule(rule_ids, s):
    """Try each alternative sequence in ``rule_ids`` and collect every successful match.

    Returns ``[(True, remainder), ...]`` for every way an alternative
    matched, or ``[(False, s)]`` if none did.
    """
    matches = [(True, rest)
               for alternative in rule_ids
               for ok, rest in and_rule(alternative, s)
               if ok]
    return matches or [(False, s)]

def load_rules_and_examples(lines):
    """Parse day-19 input into ``(rules, examples)``.

    ``rules`` maps a rule id to a callable ``f(s)`` that returns a list of
    ``(matched, remainder)`` pairs. A rule body is one of:
    * a quoted single character (e.g. ``"a"``),
    * a sequence of rule ids,
    * alternatives separated by `` | ``.
    Literal rules are recognised by their quotes rather than by containing
    'a' or 'b'. ``examples`` are the lines after the first blank line, or
    empty if there is none.
    """
    rules = {}
    for i, line in enumerate(lines):
        if line == "":
            break
        rule_id, rule = line.split(": ")
        rule = rule.strip()
        if rule.startswith('"'):
            # Literal rule such as "a": match exactly that character.
            rules[rule_id] = functools.partial(unary_rule, rule.strip('"'))
        elif "|" in rule:
            or_and_rule_sets = [x.split() for x in rule.split(" | ")]
            rules[rule_id] = functools.partial(or_and_rule, or_and_rule_sets)
        else:
            rules[rule_id] = functools.partial(and_rule, rule.split())
    else:
        # No blank separator (or no lines at all): there are no examples.
        return rules, []

    return rules, lines[i+1:]

def check_example(rules, ex):
    """Return True if rule "0" can consume the whole of ``ex``."""
    return any(matched and rest == "" for matched, rest in rules["0"](ex))

def count_matching_examples(rules, examples):
    """Return how many messages in ``examples`` fully match rule "0"."""
    return sum(1 for ex in examples if check_example(rules, ex))

rules, examples = load_rules_and_examples(lines)

print("Part 1: number of messages completely match rule 0:",  count_matching_examples(rules, examples))

# Part 2: rules 8 and 11 become self-referential. ``and_rule`` looks rules up
# in this global ``rules`` dict, so updating it in place is enough.
rules.update({
    "8": functools.partial(or_and_rule, [["42"], ["42", "8"]]),
    "11": functools.partial(or_and_rule, [["42", "31"], ["42", "11", "31"]]),
})

print("Part 2: number of messages completely match rule 0:",  count_matching_examples(rules, examples))

# Day 20

In [None]:
 class Tile:
        def __init__(self, id, data):
            self.data = data
            self.id = id
        
        def __str__(self):
            return "\n".join([" ".join(row) for row in self.data])
        
        def __repr__(self):
            return f"Tile({self.id},\n{str(self)})"
        
        def rotate_cw(self, times=1):
            self.data = np.rot90(self.data, -1 * times)
            return self
        
        def rotate_ccw(self, times=1):
            self.data = np.rot90(self.data, times)
            return self
        
        def flip_h(self):
            self.data = self.data[:, ::-1]
            return self
        
        def flip_v(self):
            self.data = self.data[::-1, :]
            return self
        
        def __str__(self):
            return "\n".join([" ".join(row) for row in self.data])
        
        def borders(self):
            return [self.left(), self.top(), self.right(), self.bottom()]
            
        def left(self):
            return "".join(self.data[:,0])
        
        def right(self):
            return "".join(self.data[:,-1])
        
        def top(self):
            return "".join(self.data[0,:])
        
        def bottom(self):
            return "".join(self.data[-1,:])
        
        def connect_left(self, tile):
            border_r = tile.right()
            if self.left() == border_r:
                return self
            elif self.left() == border_r[::-1]:
                return self.flip_v()
            elif self.top() == border_r:
                return self.rotate_ccw().flip_v()
            elif self.top() == border_r[::-1]:
                return self.rotate_ccw()
            elif self.right() == border_r:
                return self.flip_h()
            elif self.right() == border_r[::-1]:
                return self.rotate_ccw(2)
            elif self.bottom() == border_r:
                return self.rotate_cw()
            elif self.bottom() == border_r[::-1]:
                return self.rotate_cw().flip_v()
            else:
                return None
        
        def connect_top(self, tile):
            tile.rotate_ccw(1)
            res = self.connect_left(tile)
            tile.rotate_cw(1)
            if res is not None:
                res.rotate_cw(1)
                return res
            else:
                return None
        
        def border_counts(self, counts):
            return [max(counts[b], counts[b[::-1]]) for b in self.borders()]

def load_tiles(lines):
    """Parse the day-20 tiles and count how often each border occurs.

    Each tile is a "Tile <id>:" header followed by 10 rows of 10 characters.
    Returns ``(tiles, borders_counter)``. A border and its reverse share one
    counter entry, keyed by whichever orientation was seen first.
    """
    tiles = []
    idx = 0
    while idx < len(lines):
        header = re.match(r"Tile ([0-9]+):", lines[idx])
        if header is not None:
            grid = np.array([list(row) for row in lines[idx + 1:idx + 11]])
            assert grid.shape == (10, 10)
            tiles.append(Tile(int(header.group(1)), grid))
            idx += 11
        idx += 1

    borders_counter = Counter()
    for tile in tiles:
        for edge in tile.borders():
            seen_forward = edge in borders_counter
            seen_reversed = edge[::-1] in borders_counter
            key = edge if seen_forward or not seen_reversed else edge[::-1]
            borders_counter[key] += 1
    return tiles, borders_counter

def find_corner_tiles(tiles, borders_counter):
    """Return the tiles whose four border counts sum to 6.

    Each border is looked up in ``borders_counter`` under either reading
    direction; a border absent in both directions contributes nothing.
    """
    def edge_count(edge):
        # A border may be stored under either reading direction.
        if edge in borders_counter:
            return borders_counter[edge]
        return borders_counter.get(edge[::-1], 0)

    return [tile for tile in tiles
            if sum(edge_count(edge) for edge in tile.borders()) == 6]

def assemble_image(tiles, corner_tiles, borders_counter):
    """Assemble the tiles into the full image and return it as a Tile with id 42.

    ``tiles`` must form a square grid. The first corner tile is rotated until
    its border counts are ``[1, 1, 2, 2]`` and placed at (0, 0). Every other
    cell gets the first remaining tile that can be oriented to match its left
    neighbour (first row) or its upper neighbour (other rows). The border
    row/column of each tile is stripped before the pieces are joined.

    The caller's ``tiles`` list is no longer emptied (tile orientations still
    change). Raises ValueError if the corner cannot be oriented or a cell has
    no fitting tile. Previously these cases looped forever or removed an
    arbitrary tile.
    """
    remaining = list(tiles)
    rows = cols = int(np.sqrt(len(remaining)))

    corner = corner_tiles[0]
    for _ in range(4):
        if corner.border_counts(borders_counter) == [1, 1, 2, 2]:
            break
        corner.rotate_ccw(1)
    else:
        raise ValueError(f"tile {corner.id} cannot be oriented as the top-left corner")

    image = [[None] * cols for _ in range(rows)]
    image[0][0] = corner
    remaining.remove(corner)
    for i in range(rows):
        for j in range(cols):
            if (i, j) == (0, 0):
                continue
            for tile in remaining:
                fits = (tile.connect_left(image[i][j - 1]) if i == 0
                        else tile.connect_top(image[i - 1][j]))
                if fits is not None:
                    image[i][j] = tile
                    remaining.remove(tile)
                    break
            else:
                raise ValueError(f"no tile fits at position ({i}, {j})")
    return Tile(42, np.vstack([np.hstack([t.data[1:-1, 1:-1] for t in row]) for row in image]))

def count_monsters(tile):
    """Count sea monsters in ``tile``, trying all eight orientations.

    Rotates the tile clockwise up to four times, then mirrors it and tries
    four more rotations. Returns the count for the first orientation that
    contains at least one monster, leaving ``tile`` in that orientation.
    Returns 0 if no orientation contains a monster. Previously mirrored
    orientations were never tried, and None was returned in that case.
    """
    # Column offsets of the monster's '#' cells in its middle and bottom rows,
    # relative to its head (the single '#' of the top row, pattern column 18).
    middle_offsets = [-18, -13, -12, -7, -6, -1, 0, 1]
    bottom_offsets = [-17, -14, -11, -8, -5, -2]

    def _count(tile):
        count = 0
        arr = tile.data
        r, c = arr.shape
        for i in range(r - 2):
            for j in range(18, c - 1):
                if (arr[i, j] == '#'
                        and (arr[i + 1, [j + d for d in middle_offsets]] == '#').all()
                        and (arr[i + 2, [j + d for d in bottom_offsets]] == '#').all()):
                    count += 1
        return count

    for _ in range(2):
        for _ in range(4):
            tile.rotate_cw()
            count = _count(tile)
            if count != 0:
                return count
        # Four quarter turns restore the start; mirror and try again.
        tile.flip_h()
    return 0

lines = read_input(20)
tiles, borders_counter = load_tiles(lines)
corner_tiles = find_corner_tiles(tiles, borders_counter)

# Part 1: product of the corner tile ids.
mul = functools.reduce(lambda acc, tile: acc * tile.id, corner_tiles, 1)
print(f"Product of ids of four corner tiles: {mul}")

# Part 2: assemble the picture, count monsters, and subtract 15 '#' cells
# per monster from the total number of '#' cells.
image = assemble_image(tiles, corner_tiles, borders_counter)
monsters_count = count_monsters(image)
print(f"Monsters found: {monsters_count}")
print(f"Count of # not part of a monster: {np.sum(image.data == '#') - monsters_count * 15}")

# Day 21

In [None]:
# Food list: "<ingredients> (contains <allergens>)" per line.
lines = read_input(day=21)

def process_input21(lines):
    """Parse food lines of the form "<ingredients> (contains <allergens>)".

    Returns ``(ingredients_counter, allergens_counter, a2i)``. The first two
    count occurrences across all foods. ``a2i`` (a defaultdict) maps each
    allergen to the set of ingredients present in every food that lists it.
    """
    a2i = defaultdict(list)
    ingredients_counter = Counter()
    allergens_counter = Counter()
    for line in lines:
        food, contains = line[:-1].split(" (contains ")
        food_items = food.split()
        food_allergens = contains.strip().split(", ")
        ingredients_counter.update(food_items)
        allergens_counter.update(food_allergens)
        for allergen in food_allergens:
            a2i[allergen].append(set(food_items))

    for allergen, candidate_sets in a2i.items():
        a2i[allergen] = set.intersection(*candidate_sets)
    return ingredients_counter, allergens_counter, a2i

def map_allergen_to_ingredient(alergen2ingredients):
    """Resolve each allergen to its single ingredient by constraint propagation.

    ``alergen2ingredients`` maps an allergen to its set of candidate
    ingredients. It is not modified; previously its sets were consumed in
    place. The loop repeatedly fixes allergens that have exactly one
    candidate and removes those ingredients from every other candidate set.

    Returns ``{allergen: ingredient}``. Raises ValueError if the candidates
    never narrow to a unique assignment. Previously a partial mapping was
    returned silently.
    """
    pending = {allergen: set(cands) for allergen, cands in alergen2ingredients.items()}
    ret_dict = {}
    while pending:
        solved = {a: next(iter(c)) for a, c in pending.items() if len(c) == 1}
        if not solved:
            raise ValueError(f"cannot resolve allergens: {sorted(pending)}")
        taken = set(solved.values())
        if len(taken) != len(solved):
            raise ValueError("two allergens resolved to the same ingredient")
        for allergen, ingredient in solved.items():
            ret_dict[allergen] = ingredient
            del pending[allergen]
        for cands in pending.values():
            cands -= taken
    return ret_dict

def count_non_allergic_ingredients(ingredients_counter, allergic_ingredients):
    """Return the total number of occurrences of ingredients not in ``allergic_ingredients``."""
    return sum(count for ingredient, count in ingredients_counter.items()
               if ingredient not in allergic_ingredients)

def determine_danger_list(a2i, i2a):
    """Return the dangerous ingredients, sorted by their allergen name and comma-joined."""
    ordered = sorted(a2i.values(), key=i2a.__getitem__)
    return ",".join(ordered)
    
ingredients_counter, allergens_counter, a2possible_ingredients = process_input21(lines)

# Resolve the allergens, then build the reverse (ingredient -> allergen) map.
a2i = map_allergen_to_ingredient(a2possible_ingredients)
i2a = dict(zip(a2i.values(), a2i.keys()))

allergic_ingredients = a2i.values()

print("Occurences of non allergic ingredients: ", count_non_allergic_ingredients(ingredients_counter, allergic_ingredients))
print("Canonical dangerous ingredient list: ", determine_danger_list(a2i, i2a))