In [1]:
def read_input(filename, split = False, convert_to_int = False, sep = '\n'):
    """Read a text file and optionally split it and convert the pieces to int.

    Parameters
    ----------
    filename : path of the file to read (opened in text mode).
    split : when true, split the text on ``sep`` and return a list of strings.
    convert_to_int : when true, apply ``int`` to every item of the data.
        Without ``split`` the data is still a string, so each character is
        converted on its own.
    sep : separator used when ``split`` is true (default: newline).

    Returns
    -------
    The file content as a ``str``, or a ``list`` when ``split`` and/or
    ``convert_to_int`` is set.

    Raises
    ------
    OSError if the file cannot be opened; ValueError if ``convert_to_int`` is
    set and an item is not a valid integer.
    """
    # The context manager closes the file even when read() raises.
    with open(filename) as f:
        raw = f.read()

    # Drop one trailing newline only when it is really there. The previous
    # unconditional ``[:-1]`` silently removed the last data character of
    # files that do not end with a newline.
    if raw.endswith('\n'):
        raw = raw[:-1]

    data = raw
    if split:
        data = data.split(sep)
    if convert_to_int:
        data = [int(item) for item in data]
    return data

In [2]:
# Load the worked example and the real puzzle input, each split into lines.
sample_data, data = (read_input(name, True) for name in ('sample.txt', 'input.txt'))

In [16]:
def count_simple_digits(data):
    """Count the output codes whose length is 2, 3, 4 or 7.

    Each entry of ``data`` is a string; only the text after the last
    ``' | '`` piece is inspected (the whole string when there is no
    separator), split on whitespace into codes.

    Returns the number of such codes over all entries.
    """
    wanted_lengths = (2, 3, 4, 7)
    return sum(
        1
        for entry in data
        for code in entry.split(' | ')[-1].split()
        if len(code) in wanted_lengths
    )

In [17]:
count_simple_digits(sample_data) == 26

True

In [18]:
count_simple_digits(data)

412

In [114]:
def str_diff(s1, s2):
    """Return the characters that appear in only one of the two strings.

    The result lists the characters of ``s1`` missing from ``s2`` (in ``s1``
    order), followed by the characters of ``s2`` missing from ``s1`` (in
    ``s2`` order). Repeated characters are kept as often as they occur.
    """
    only_in_first = ''.join(ch for ch in s1 if ch not in s2)
    only_in_second = ''.join(ch for ch in s2 if ch not in s1)
    return only_in_first + only_in_second
    
def str_intersection(s1, s2):
    """Return the characters of ``s1`` that also occur in ``s2``.

    Characters come out in the order of their first appearance in ``s1`` and
    each one is emitted at most once.
    """
    shared = ''
    for ch in s1:
        # Skip characters absent from s2 and ones already collected.
        if ch not in s2 or ch in shared:
            continue
        shared += ch
    return shared

def str_equal(s1, s2):
    """Return True when both strings hold the same characters, order ignored.

    Multiplicity matters: ``'aab'`` and ``'abb'`` are not equal.
    """
    canonical_1, canonical_2 = (''.join(sorted(text)) for text in (s1, s2))
    return canonical_1 == canonical_2

In [229]:
class LineDecoder:
    """Decode one scrambled seven-segment display entry.

    An entry looks like ``'<patterns> | <output codes>'``. The part before
    ``' | '`` is expected to hold the ten scrambled digit patterns, the part
    after it the codes of the number to decode. Letters inside a code may come
    in any order.

    Attributes
    ----------
    digits : list of pattern strings (left of the separator); never modified.
    number : list of output code strings (right of the separator).
    wires_map : dict letter -> scrambled wire; ``remap`` fills 'a', 'c', 'f'.
    number_codes : dict digit (0-9) -> pattern string, filled by ``remap``.
    codes_to_numbers : dict pattern string -> digit, built by ``remap``.
    output : decoded integer, set by ``decode_output``.
    """

    def __init__(self, line):
        # Split once; without a separator both halves are the whole line.
        parts = line.split(' | ')
        self.digits = parts[0].split()
        self.number = parts[-1].split()
        self.wires_map = {let:None for let in 'abcdefg'}
        self.number_codes = {i:None for i in range(10)}
        self.output = None
        self.codes_to_numbers = None

    @staticmethod
    def _take(candidates, matches, number):
        """Remove and return the first pattern of ``candidates`` accepted by ``matches``.

        Raises ValueError naming ``number`` when no pattern qualifies.
        """
        for code in candidates:
            if matches(code):
                candidates.remove(code)
                return code
        raise ValueError('no pattern found for digit {}'.format(number))

    def remap(self):
        """Work out which pattern stands for which digit.

        Fills ``number_codes``, ``codes_to_numbers`` and the 'a', 'c' and 'f'
        entries of ``wires_map`` (the other wires stay ``None``). The method
        may be called repeatedly: it works on a copy of ``self.digits``.

        Raises ValueError when a digit cannot be identified, i.e. the entry
        is not a complete, valid set of patterns.
        """
        # The deduction consumes patterns as it identifies them, so it runs
        # on a copy; self.digits stays intact for later calls.
        remaining = list(self.digits)
        codes = {}

        # find numbers 1, 7, 4, 8: the only ones lighting 2, 3, 4, 7 segments
        for number, size in ((1, 2), (7, 3), (4, 4), (8, 7)):
            codes[number] = self._take(
                remaining, lambda code, size=size: len(code) == size, number)
        one, four, seven, eight = (set(codes[n]) for n in (1, 4, 7, 8))

        # deduce wire A as 7-1
        wire_a = seven ^ one

        # F is the difference between 8-1 and 6 (6 is the only number whose
        # difference with 8-1 is a single segment)
        eight_without_one = eight ^ one
        codes[6] = self._take(
            remaining, lambda code: len(set(code) ^ eight_without_one) == 1, 6)
        six = set(codes[6])
        wire_f = six ^ eight_without_one

        # C is 8-6
        wire_c = eight ^ six

        # 5 is the only remaining digit one segment away from 6
        codes[5] = self._take(
            remaining, lambda code: len(set(code) ^ six) == 1, 5)

        # 0 and 9 both differ from 8 by one segment, but only the segment
        # missing from 0 is part of 4; once 0 is taken, 9 is the only
        # remaining pattern one segment away from 8
        codes[0] = self._take(
            remaining,
            lambda code: len(eight ^ set(code)) == 1 and (eight ^ set(code)) <= four,
            0)
        codes[9] = self._take(
            remaining, lambda code: len(eight ^ set(code)) == 1, 9)
        nine = set(codes[9])

        # 3 is one segment away from 9
        codes[3] = self._take(
            remaining, lambda code: len(set(code) ^ nine) == 1, 3)

        # 2 is the last number
        codes[2] = self._take(remaining, lambda code: True, 2)

        # Publish the results only after the whole deduction succeeded, so a
        # failed remap() never leaves a half-updated decoder behind.
        for number in range(10):
            self.number_codes[number] = codes[number]
        self.wires_map['a'] = ''.join(sorted(wire_a))
        self.wires_map['f'] = ''.join(sorted(wire_f))
        self.wires_map['c'] = ''.join(sorted(wire_c))
        self.codes_to_numbers = {code:num for num,code in self.number_codes.items()}

    def decode_value(self, s):
        """Return the digit whose pattern has the same letters as ``s``.

        Letter order is ignored. Returns ``None`` when no pattern matches.
        Runs ``remap`` first if it has not been run yet.
        """
        if self.codes_to_numbers is None:
            self.remap()
        wanted = sorted(s)
        for code, number in self.codes_to_numbers.items():
            if sorted(code) == wanted:
                return number
        return None

    def decode_output(self):
        """Decode the output codes and store the resulting integer in ``self.output``.

        Raises ValueError when an output code matches none of the patterns;
        ``self.output`` is left untouched in that case.
        """
        decoded = []
        for digit in self.number:
            value = self.decode_value(digit)
            if value is None:
                # Fail with the offending code instead of int('None...').
                raise ValueError(
                    'output code {!r} matches none of the patterns'.format(digit))
            decoded.append(str(value))
        self.output = int(''.join(decoded))

In [230]:
a = LineDecoder('acedgfb cdfbe gcdfa fbcad dab cefabd cdfgeb eafb cagedb ab | cdfeb fcadb cdfeb cdbaf')

In [231]:
a.remap()

In [232]:
a.decode_output()

In [234]:
class SubmarineDisplays:
    """Decode a collection of display entries and add up their output values.

    Attributes
    ----------
    displays : one ``LineDecoder`` per entry of ``data``.
    sum : total of the decoded outputs computed by the latest ``run`` call
        (0 before the first call).
    """

    def __init__(self, data):
        self.displays = [LineDecoder(line) for line in data]
        self.sum = 0

    def run(self):
        """Decode every display and return the sum of the decoded outputs.

        The total is computed from scratch on each call and stored in
        ``self.sum``; it is no longer added on top of the previous total,
        which made a repeated call return an inflated value.
        """
        total = 0
        for display in self.displays:
            display.remap()
            display.decode_output()
            total += display.output
        self.sum = total
        return total

In [235]:
%%time
SubmarineDisplays(sample_data).run() == 61229

CPU times: user 730 µs, sys: 1 µs, total: 731 µs
Wall time: 733 µs


True

In [236]:
%%time
SubmarineDisplays(data).run()

CPU times: user 16.3 ms, sys: 309 µs, total: 16.6 ms
Wall time: 16.5 ms


978171