## part 1 ##

In [None]:
# Worked example from the puzzle: line 0 is the earliest time you could
# leave, line 1 is the comma-separated schedule where 'x' marks a slot with no bus.
ex = ['939', '7,13,x,x,59,x,31,19']
ex

In [2]:
def parse_lines(lines):
    """Parse the two-line puzzle input for part 1.

    Returns (earliest_time, bus_ids): the integer on the first line, and the
    bus IDs from the second line in schedule order with every 'x' dropped.
    """
    earliest_time = int(lines[0])
    schedule = lines[1].split(',')
    bus_ids = list(map(int, filter(lambda entry: entry != 'x', schedule)))
    return earliest_time, bus_ids
parse_lines(ex)  # expected: (939, [7, 13, 59, 31, 19])

(939, [7, 13, 59, 31, 19])

In [3]:
def next_departure(t, bus):
    """Return the first departure of `bus` at or after time `t`.

    The bus departs at every multiple of its ID, so this is `t` itself when
    `t` is a multiple of `bus`, otherwise the next multiple above `t`.
    """
    # (-t) % bus is the wait until the next multiple of bus (0 if t already is one).
    return t + (-t) % bus

In [4]:
def earliest_bus(t, buses):
    """Find the bus you can catch first when arriving at time `t`.

    Returns (departure_time, bus_id). On a tie, the bus listed first wins.
    """
    departures = [next_departure(t, b) for b in buses]
    # min over indices keyed by departure keeps the first minimal index.
    best = min(range(len(buses)), key=departures.__getitem__)
    return departures[best], buses[best]

In [5]:
# Check part 1 on the worked example: prints earliest departure, bus ID,
# and bus ID * minutes waited (expected: 944 59 295).
ext, exbuses = parse_lines(ex)
exdepar, exbus = earliest_bus(ext, exbuses)
print(exdepar, exbus, exbus*(exdepar - ext))

944 59 295


In [6]:
# Load the real puzzle input as a list of lines (surrounding whitespace stripped).
with open('inputs/day13.input') as fp:
    data = fp.read().strip().split('\n')
datat, databuses = parse_lines(data)
datat, databuses

(1000067, [17, 37, 439, 29, 13, 23, 787, 41, 19])

In [7]:
# Part 1 answer for the real input: bus ID * minutes waited.
datadepar, databus = earliest_bus(datat, databuses)
print(datadepar, databus, databus*(datadepar - datat))

1000072 41 205


## part 2 ##

In [8]:
def parse_lines2(lines):
    """Parse the bus schedule (second input line) for part 2.

    Returns (buses, intervals): the bus IDs in schedule order and, for each
    one, its position in the comma-separated list. That position is the
    number of minutes after the target timestamp at which the bus must
    depart. 'x' entries are skipped but still occupy a position.
    """
    buses = []
    intervals = []
    # enumerate yields each entry's own position. list.index() would return
    # the first match (wrong if an ID repeats) and costs O(n) per lookup.
    for position, item in enumerate(lines[1].split(',')):
        if item != 'x':
            buses.append(int(item))
            intervals.append(position)
    return buses, intervals

In [9]:
# Part 2 parse of the worked example: bus IDs and their list positions.
exbuses, exintervals = parse_lines2(ex)
exbuses, exintervals

([7, 13, 59, 31, 19], [0, 1, 4, 6, 7])

In [10]:
def calc_intervals(t, buses):
    """For each bus, the minutes from `t` until its next departure (0 if it leaves at `t`)."""
    waits = []
    for bus in buses:
        waits.append(next_departure(t, bus) - t)
    return waits

In [11]:
calc_intervals(1068781, exbuses)  # sanity check: at t=1068781 the waits equal exintervals

[0, 1, 4, 6, 7]

In [12]:
def find_time(buses, intervals):
    """Brute-force the earliest t at which every bus departs at t + its offset.

    `buses[i]` must depart exactly `intervals[i]` minutes after t, i.e.
    (t + intervals[i]) is a multiple of buses[i]. Only timestamps that already
    satisfy the largest bus are visited (stepping by that bus ID). This is
    fine for small examples but far too slow for the real input.

    Offsets are reduced modulo their bus: the wait until a bus departs is
    always < bus, so an offset >= bus (e.g. bus 13 at offset 30) compared
    unreduced would never match and the loop would never end.
    """
    maxbus = max(buses)
    maxidx = buses.index(maxbus)
    # Required wait after t for each bus, in [0, bus).
    targets = [offset % bus for bus, offset in zip(buses, intervals)]
    # First candidate (> 0) whose wait for the largest bus is already correct.
    t = maxbus - targets[maxidx]
    while True:
        # (-t) % bus is the wait from t until that bus's next departure.
        if all((-t) % bus == target for bus, target in zip(buses, targets)):
            return t
        t += maxbus

In [13]:
find_time(exbuses, exintervals)  # brute force on the worked example

1068781

In [14]:
# Second example schedule; the dummy first line '111' is ignored by parse_lines2.
ex2 = '111\n1789,37,47,1889\n'.strip().split('\n')
ex2buses, ex2intervals = parse_lines2(ex2)
find_time(ex2buses, ex2intervals)

1202161486

In [15]:
# Part 2 parse of the real input (rebinds databuses to the same IDs parse_lines gave).
databuses, dataintervals = parse_lines2(data)
print(databuses, dataintervals)

[17, 37, 439, 29, 13, 23, 787, 41, 19] [0, 11, 17, 19, 30, 40, 48, 58, 67]


The naive approach (`find_time`) takes far too long on the real input. Instead, borrowing from nedbat's helpers (https://github.com/nedbat/adventofcode2020/blob/main/helpers.py), who notes that "Chinese remainder theorem" is trending on Google and that "modular multiplicative inverse" is the key Wikipedia page:

In [16]:
import functools
import math

def product(nums):
    """Return the product of all numbers in `nums`.

    An empty iterable gives 1 (the empty product) instead of the TypeError
    that functools.reduce raises without an initial value.
    """
    result = 1
    for n in nums:
        result *= n
    return result

def lcm2(a, b):
    """Least common multiple of two integers, computed exactly.

    Divides before multiplying, using integer floor division: gcd(a, b)
    divides a exactly. The previous float division (a * b / gcd) silently
    lost precision once a * b exceeded 2**53.
    """
    return a // math.gcd(a, b) * b

def lcm(nums):
    """Least common multiple of every number in `nums`, folded pairwise with lcm2."""
    return functools.reduce(lambda acc, n: lcm2(acc, n), nums)

def modular_inverse(a, m):
    """Return x in [0, m) with (a * x) % m == 1, or None if no such x exists.

    Uses the iterative extended Euclidean algorithm. An inverse exists only
    when gcd(a, m) == 1. `m` is expected to be a positive modulus.

    Special cases:
    - For m == 1 every integer is congruent to 0, so 0 is returned.
    - m == 0 returns None.

    `a` is reduced mod m first, so negative or oversized `a` work, and
    a ≡ 0 correctly yields None.
    """
    if m == 1:
        return 0
    if m == 0:
        return None
    # Invariants: old_s * a ≡ old_r (mod m) and s * a ≡ r (mod m).
    old_r, r = a % m, m
    old_s, s = 1, 0
    while r:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_s, s = s, old_s - q * s
    if old_r != 1:
        return None  # gcd(a, m) != 1: a has no inverse modulo m
    return old_s % m


In [17]:
def part2(bus_ids):
    """Earliest timestamp t at which bus_ids[i] departs at t + i for every real bus.

    `bus_ids` is the schedule in order, with None for 'x' slots. Solved with
    the Chinese remainder theorem: t = -i (mod bus_ids[i]) for each bus. This
    relies on the bus IDs being pairwise coprime so the modular inverses exist.
    """
    constraints = [
        (bus, offset) for offset, bus in enumerate(bus_ids) if bus is not None
    ]
    moduli = [bus for bus, _ in constraints]
    total = product(moduli)

    def term(bus, offset):
        # CRT basis element: congruent to (bus - offset) mod bus, 0 mod every other bus.
        partial = total // bus
        return modular_inverse(partial, bus) * partial * (bus - offset)

    combined = sum(term(bus, offset) for bus, offset in constraints)
    return combined % lcm(moduli)

In [18]:
part2([7,13,None,None,59,None,31,19])  # worked example; matches find_time's 1068781

1068781

In [19]:
# Real schedule for part2: int IDs in order, None for each 'x' slot.
bus_ids = [
    None if item == 'x' else int(item)
    for item in data[1].strip().split(',')
]


In [20]:
part2(bus_ids)  # part 2 answer for the real input

803025030761664

In [21]:
# Re-run part2's setup at top level to inspect the intermediate values.
base_mods = [(b, m) for m, b in enumerate(bus_ids) if b is not None]
bases = [b for b, _ in base_mods]
p = product(bases)

In [22]:
base_mods  # (bus ID, list offset) pairs

[(17, 0),
 (37, 11),
 (439, 17),
 (29, 19),
 (13, 30),
 (23, 40),
 (787, 48),
 (41, 58),
 (19, 67)]

In [23]:
p  # product of all bus IDs

1467900241541773

In [24]:
lcm(bases)  # equals p, so the bus IDs share no common factors

1467900241541773

In [25]:
# Step through part2's CRT sum, printing each partial product and the running total.
x = 0
for b, m in base_mods:
    pp = p // b
    x += modular_inverse(pp, b) * pp * (b - m)
    print(pp, x)

86347073031869 23486403864668368
39672979501129 33801378534961908
3343736313307 42267718880255232
50617249708337 44292408868588712
112915403195521 25096790325350142
63821749632251 12077153400370938
1865184550879 76860608406051245
35802444915653 59210003062634316
77257907449567 18417827929262940


In [26]:
x % p  # reduce the CRT sum; matches part2(bus_ids)

803025030761664