In [1]:
import warnings

warnings.filterwarnings("ignore")

import copy
import json
import os
import re
import sys
from collections import Counter
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from rich import print
from tqdm import tqdm

In [2]:
def read_file_to_str_li(fp, print_exp=True):
    """Read a text file and return its lines as a list of strings.

    Lines are split on newline characters and returned without them. A single
    newline at the very end of the file is treated as the terminator of the
    last line, not as the start of an extra empty line; interior blank lines
    are preserved.

    Args:
        fp: path of the file to read.
        print_exp: if True, print the path and the first/last line as a quick
            sanity check of the parsed input.

    Returns:
        list[str]: the file's lines (always at least one element).
    """
    with open(fp, "r") as f:
        lines = f.read().split("\n")
    # Text files conventionally end with "\n"; without this, split() leaves a
    # spurious "" as the final element, which downstream parsing chokes on.
    # The length guard keeps an empty file as [""] so lines[0] stays valid.
    if len(lines) > 1 and lines[-1] == "":
        lines.pop()
    if print_exp:
        print(f"Read from {fp}:")
        print(f"First line: {lines[0]} | Last line: {lines[-1]}")
        print("-" * 6)

    return lines


# define the function blocks
def convert_to_int(input_str):
    """Convert one token to an int, mapping blank tokens to None.

    Any string that is empty or consists only of whitespace (e.g. "", " ",
    "  ", "\\t") is treated as a missing value and returns None. Everything
    else is passed to ``int()``, which tolerates surrounding whitespace and
    raises ValueError for non-numeric text.
    """
    # Whitespace-only tokens can appear with custom separators or per-letter
    # parsing; int() would raise on them, so report them as missing instead.
    if isinstance(input_str, str) and not input_str.strip():
        return None
    return int(input_str)


def convert_to_str(input_str):
    """Return the token as a string (``str()`` of the input)."""
    as_text = str(input_str)
    return as_text


# Dispatch table: pattern letter -> converter applied to each parsed item.
converts = dict(
    i=convert_to_int,
    s=convert_to_str,
)


def convert_str_li_to_other_li(
    str_li, pattern="i", per_letter=False, sep=" ", start_row=0, end_row=None
):
    """Convert a list of strings into a list of lists of typed items.

    Each row of ``str_li[start_row:end_row]`` is split into items, and every
    item is passed through the converter selected by ``pattern``.

    pattern: a string with one type letter per item position, looked up in
        ``converts``: 'i' for int, 's' for string.
        'si' means: convert the 1st item to string, the rest to integer.
        If a row has more items than the pattern has letters, the last
        letter of the pattern is reused for the remaining items.
    per_letter: if True, ignore ``sep`` and treat every character as an item.
    sep: item separator. The default " " splits on any run of whitespace
        (``str.split()``); any other value is passed to ``str.split(sep)``.
    start_row, end_row: slice bounds selecting which rows to convert.

    Returns an empty list when the selected slice contains no rows.
    """
    target_str_li = str_li[start_row:end_row]

    def _split_row(s):
        # Split one row into raw string items according to per_letter / sep.
        if per_letter:
            return list(s)
        if sep == " ":
            return s.split()
        return s.split(sep)

    # Split each row once and reuse the result for both sizing and converting.
    split_rows = [_split_row(s) for s in target_str_li]

    # Extend the pattern with its last letter so it covers the longest row;
    # ``default`` keeps an empty selection from crashing max().
    max_item_num = max((len(items) for items in split_rows), default=0)
    if max_item_num > len(pattern):
        pattern = pattern.ljust(max_item_num, pattern[-1])

    return [
        [converts[pattern[idx]](item) for idx, item in enumerate(items)]
        for items in split_rows
    ]

In [3]:
fp = "input.txt"
lines = read_file_to_str_li(fp)

print("Convert to:")

# Keep every row as whitespace-separated string items; the next cell decodes
# item 0 of each row with json.loads (presumably each line holds a single
# snailfish number with no internal spaces — confirm against the input).
data = convert_str_li_to_other_li(
    lines,
    pattern="s",
    per_letter=False,
    sep=" ",
    start_row=None,
    end_row=None,
)

first_row, last_row = data[0], data[-1]
print(f"First line: {first_row}")
print(f"Last line: {last_row}")
print("-" * 6)

In [4]:
def find_first_explode(snailfish_number):
    """Find the leftmost pair nested inside four pairs.

    Scans every index path of length 4 in lexicographic order, which is
    left-to-right order for elements at the same depth, and returns the first
    element found there that is not a regular (int) number.

    Returns:
        (path, pair): the 4-tuple index path and the element at that path, or
        (None, None) if no such element exists.
    """
    for i in range(2):
        for j in range(2):
            for k in range(2):
                for l in range(2):
                    try:
                        candidate = snailfish_number[i][j][k][l]
                    except (TypeError, IndexError):
                        # A regular number occurs higher on this path (ints are
                        # not subscriptable), so nothing sits at depth 4 here.
                        continue
                    if not isinstance(candidate, int):
                        return (i, j, k, l), candidate
    return None, None


def find_reg_num(snailfish_number):
    """List every element of a snailfish number in pre-order (left to right).

    Walks the nested pairs depth-first, recording a pair before its children,
    down to index paths of length 5.

    Returns:
        (reg_mask, pointers): two parallel lists. ``pointers[n]`` is the index
        path of the n-th visited element, always a tuple (top-level elements
        are 1-tuples such as ``(0,)``, so every path is non-empty and can be
        walked by index). ``reg_mask[n]`` is True when that element is a
        regular (int) number and False otherwise.
    """
    reg_mask = []
    pointers = []
    # Length 5 reaches the two regular numbers inside a pair at depth 4,
    # i.e. the children of a pair that is about to explode.
    max_path_len = 5

    def _visit(node, path):
        # Record each child of ``node`` and descend into child pairs.
        for idx, child in enumerate(node):
            child_path = path + (idx,)
            is_regular = isinstance(child, int)
            reg_mask.append(is_regular)
            pointers.append(child_path)
            if isinstance(child, list) and len(child_path) < max_path_len:
                _visit(child, child_path)

    _visit(snailfish_number, ())
    return reg_mask, pointers


def find_reg_pos(snailfish_number, number_pos):
    """Locate the regular numbers immediately left and right of a pair.

    Args:
        snailfish_number: the nested-list number to search.
        number_pos: index path of the pair, as listed by ``find_reg_num``.

    Returns:
        (left_reg_pos, right_reg_pos): index paths of the nearest regular
        number before the pair and after the pair in pre-order, or None on a
        side that has no regular number.
    """
    reg_mask, pointers = find_reg_num(snailfish_number)
    pair_idx = pointers.index(number_pos)

    left_hits = [
        ptr for ptr, is_reg in zip(pointers[:pair_idx], reg_mask[:pair_idx]) if is_reg
    ]
    left_reg_pos = left_hits[-1] if left_hits else None

    # The two entries right after the pair are its own children (assumed to be
    # regular numbers, as for an exploding pair), so the rightward search
    # starts three entries past the pair.
    right_start = pair_idx + 3
    right_reg_pos = next(
        (
            ptr
            for ptr, is_reg in zip(pointers[right_start:], reg_mask[right_start:])
            if is_reg
        ),
        None,
    )
    return left_reg_pos, right_reg_pos


def update_reg_num(snailfish_number, number_pos, add_num):
    """Add ``add_num`` in place to the element at index path ``number_pos``."""
    container = snailfish_number
    # Walk down to the list that directly holds the target element.
    for step in number_pos[:-1]:
        container = container[step]
    container[number_pos[-1]] += add_num


def replace_by_zero(snailfish_number, number_pos):
    """Overwrite the element at index path ``number_pos`` with 0, in place."""
    container = snailfish_number
    # Walk down to the list that directly holds the target element.
    for step in number_pos[:-1]:
        container = container[step]
    container[number_pos[-1]] = 0


def update_exp(snailfish_number):
    """Perform explosions on ``snailfish_number`` in place until none remain.

    Each round, the pair located by ``find_first_explode`` has its two values
    added to the regular numbers that ``find_reg_pos`` reports on its left and
    right (a side reported as None is skipped), and the pair is then replaced
    by 0.
    """
    while True:
        exp_pos, add_to_reg = find_first_explode(snailfish_number)
        if exp_pos is None:
            return
        reg_positions = find_reg_pos(snailfish_number, exp_pos)
        for reg_pos, amount in zip(reg_positions, add_to_reg):
            # Test for None explicitly: a truthiness check would also treat a
            # valid top-level position 0 as "no neighbor".
            if reg_pos is None:
                continue
            if isinstance(reg_pos, int):
                # Accept a bare index for a top-level position as well as a
                # tuple path, so update_reg_num can always walk it.
                reg_pos = (reg_pos,)
            update_reg_num(snailfish_number, reg_pos, amount)
        replace_by_zero(snailfish_number, exp_pos)


def spilt_to_pair(num):
    """Split a number into a pair: left half rounded down, right rounded up."""
    value = int(num)
    half_down = value // 2
    return [half_down, value - half_down]


def split_snailfish_number(snailfish_number):
    """Split the leftmost regular number greater than 9, in place.

    Elements are scanned in pre-order (left to right) down to index paths of
    length 4. The first int above 9 is replaced by ``spilt_to_pair`` of it.

    Returns:
        bool: True if a split was performed, False if no regular number
        greater than 9 was found within that depth.
    """
    # Deepest index-path length scanned for regular numbers.
    max_path_len = 4

    def _split_first(node, path_len):
        # Scan node's children left to right; stop right after the first split.
        for idx, child in enumerate(node):
            if isinstance(child, int):
                if child > 9:
                    node[idx] = spilt_to_pair(child)
                    return True
            elif (
                isinstance(child, list)
                and path_len < max_path_len
                and _split_first(child, path_len + 1)
            ):
                return True
        return False

    return _split_first(snailfish_number, 1)


def add_snailfish_number(lef_snailfish_number, right_snailfish_number):
    """Add two snailfish numbers and fully reduce the result.

    Both operands are deep-copied independently, so the caller's numbers are
    never mutated. Reduction alternates: resolve all explosions, then try a
    single split; stop as soon as no split happens.
    """
    left_copy = copy.deepcopy(lef_snailfish_number)
    right_copy = copy.deepcopy(right_snailfish_number)
    result = [left_copy, right_copy]
    while True:
        update_exp(result)
        if not split_snailfish_number(result):
            break
    return result


def sum_mag(mag_li):
    """Magnitude of a snailfish pair: 3 * left + 2 * right, recursively."""

    def _side_value(node):
        # Regular numbers count as themselves; nested pairs recurse.
        return node if isinstance(node, int) else sum_mag(node)

    left, right = mag_li[0], mag_li[1]
    return 3 * _side_value(left) + 2 * _side_value(right)


# Decode each row's first item (a JSON-style nested list) into a snailfish number.
snailfish_number_li = [json.loads(row[0]) for row in data]

# Q1: add all numbers in order and report the magnitude of the final sum.
snailfish_number = snailfish_number_li[0]
for next_number in snailfish_number_li[1:]:
    snailfish_number = add_snailfish_number(snailfish_number, next_number)
print(f"Answer to Q1: {sum_mag(snailfish_number)}")

# Q2: largest magnitude over all ordered pairs of distinct numbers
# (addition is not commutative, so both orders are tried).
pair_magnitudes = [
    sum_mag(add_snailfish_number(left, right))
    for i, left in enumerate(snailfish_number_li)
    for j, right in enumerate(snailfish_number_li)
    if i != j
]
print(f"Answer to Q2: {np.max(pair_magnitudes)}")