<a href="https://colab.research.google.com/github/domschl/ALU_Net/blob/main/ALU_Net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
""" A neural net that tries to become an ALU (arithmetic logic unit) """

In [None]:
import sys
import os
import random
import numpy as np
import tensorflow as tf
import keras
from keras import layers, regularizers, callbacks
from tensorflow.python.client import device_lib

In [None]:
%load_ext tensorboard

try:
    %tensorflow_version 2.x
except:
    pass

try: # Colab instance?
    from google.colab import drive
    is_colab = True
except: # Not? ignore.
    is_colab = False
    pass

In [None]:
# Hardware check:
# Detects which accelerators TensorFlow can see and sets the module-level
# flags is_tpu / is_gpu that the rest of the notebook relies on.

is_tpu = False
is_gpu = False
tpu_is_init = False
# gRPC address of the Colab TPU worker. It is defined unconditionally so that
# later code can reference it even when the TPU was found through the device
# listing below; it stays None when COLAB_TPU_ADDR is not set.
TPU_ADDRESS = None

for hw in ["CPU", "GPU", "TPU"]:
    hw_list = tf.config.experimental.list_logical_devices(hw)
    if hw_list:
        if hw == 'TPU':
            is_tpu = True
        if hw == 'GPU':
            is_gpu = True
    print(f"{hw}: {hw_list}")

if is_colab:
    colab_tpu_addr = os.environ.get('COLAB_TPU_ADDR')
    if colab_tpu_addr:
        TPU_ADDRESS = 'grpc://' + colab_tpu_addr
    if is_tpu:
        print(f"TPU available, already connected to {TPU_ADDRESS}")
    elif TPU_ADDRESS is None:
        # No TPU address in the environment: nothing to connect to.
        print("No TPU available")
    else:
        try:
            tf.config.experimental_connect_to_host(TPU_ADDRESS)
        except Exception as err:
            print("No TPU available")
            print(f"Connecting to {TPU_ADDRESS} failed: {err}")
        else:
            is_tpu = True
            print(f"TPU available at {TPU_ADDRESS}")

if is_tpu:
    tf.compat.v1.disable_eager_execution()
    print("TPU: eager execution disabled!")
elif is_gpu:
    print("GPU available")
else:
    print("WARNING: You have neither TPU nor GPU, this is going to be very slow!")

In [None]:
class GenSamplesALU():
    """ Generate training data for all ALU operations.

    The ALU takes two integers and applies one of the supported
    model_ops. Eg op1=123, op2=100, op='-' -> result 23
    The net is supposed to learn to 'calculate' the results for
    arbitrary op1, op2 (positive integers, 0..32767) and
    the twelve supported ops.

    A sample input is a float vector of 36 elements: 16 bits of op1,
    4 bits of the op index and 16 bits of op2, each least significant
    bit first. The expected output is the 32-bit encoding of the result.
    """

    def __init__(self):
        self.model_ops = ["+", "-", "*", "/", "%",
                          "AND", "OR", "XOR", ">", "<", "=", "!="]
        # Weights for creating a sample for each of the ops. (They are
        # re-weighted by check_results() so that more samples are generated
        # for 'difficult' ops.)
        self.model_dis = [10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
        # Result functions, index-aligned with model_ops.
        self.model_funcs = [self.add_smpl, self.diff_smpl, self.mult_smpl,
                            self.div_smpl, self.mod_smpl, self.and_smpl,
                            self.bor_smpl, self.xor_smpl, self.greater_smpl,
                            self.lesser_smpl, self.eq_smpl, self.neq_smpl]
        self.bit_count = 15  # random operands are drawn from 0..2**15-1
        self.all_bits_one = 0x7fffffff  # the 31 low bits set
        self.true_vect = self.all_bits_one  # result of a true comparison
        self.false_vect = 0  # result of a false comparison

    @staticmethod
    def int_to_binary_vect(num_int, num_bits=8):
        """ Return num_int as a float32 vector of num_bits bits, LSB first """
        num_vect = np.zeros(num_bits, dtype=np.float32)
        for i in range(num_bits):
            if (num_int >> i) & 1:
                num_vect[i] = 1.0
        return num_vect

    @staticmethod
    def get_random_bits(bits):
        """ Return a random int in the range 0...2**bits-1 """
        return random.randint(0, 2**bits-1)

    def op_string_to_index(self, op_string):
        """ Transform op_string (e.g. '+' -> 0) into its index, -1 if unknown """
        try:
            return self.model_ops.index(op_string)
        except ValueError:
            return -1

    def _shape_operands(self, op1, op2, op_index, short_math=False):
        """ Adjust random operands so that an op yields useful training cases.

        The adjusted operands are returned so that the caller encodes exactly
        the values the result is computed from. Doing this inside the result
        functions produced labels that did not match the encoded input.
        """
        op = self.model_ops[op_index]
        if op == "*":
            if short_math:
                # keep products small
                op1 %= 1000
                op2 %= 1000
        elif op in ("/", "%"):
            while op2 == 0:  # never generate a division by zero
                op2 = self.get_random_bits(self.bit_count)
            # In two of three cases put the larger operand first, otherwise
            # most quotients would be 0 and most remainders equal to op1.
            if op1 < op2 and op1 != 0 and random.randint(0, 2) != 0:
                op1, op2 = op2, op1
        elif op in ("=", "!="):
            # Random operands are almost never equal: force equality in
            # half of the cases.
            if random.randint(0, 1) == 0:
                op2 = op1
        return op1, op2

    def get_data_point(self, equal_distrib=False, short_math=False):
        """ Get a random example of one ALU operation for training.

        With equal_distrib all ops are equally likely, otherwise ops are
        drawn according to the weights in model_dis. Returns the tuple
        described in encode_op().
        """
        op1 = self.get_random_bits(self.bit_count)
        op2 = self.get_random_bits(self.bit_count)
        if equal_distrib:
            op_index = random.randrange(len(self.model_ops))
        else:  # make 'difficult' ops more present in training samples:
            op_index = random.choices(
                range(len(self.model_ops)), weights=self.model_dis)[0]
        op1, op2 = self._shape_operands(op1, op2, op_index, short_math)
        return self.encode_op(op1, op2, op_index, short_math)

    def encode_op(self, op1, op2, op_index, short_math=False):
        """ Turn two ints and an operation index into training data.

        Returns (inp, oup, result, op_index, sym): the 36-element input
        vector, the 32-element expected output vector, the integer result,
        the op index and a readable string such as '5+3=8'.
        Raises ZeroDivisionError for '/' and '%' when op2 is 0.
        """
        result = self.model_funcs[op_index](op1, op2, short_math)
        sym = f"{op1}{self.model_ops[op_index]}{op2}={result}"
        inp = np.concatenate(
            [self.int_to_binary_vect(op1, num_bits=16),
             self.int_to_binary_vect(op_index, num_bits=4),
             self.int_to_binary_vect(op2, num_bits=16)])
        oup = self.int_to_binary_vect(result, num_bits=32)
        return inp, oup, result, op_index, sym

    @staticmethod
    def add_smpl(op1, op2, _):
        """ addition training example """
        return op1+op2

    @staticmethod
    def diff_smpl(op1, op2, _):
        """ subtraction training example (absolute difference, never negative) """
        if op2 > op1:
            op2, op1 = op1, op2
        return op1-op2

    @staticmethod
    def mult_smpl(op1, op2, short_math=False):
        """ multiplication training example (operands modulo 1000 if short_math) """
        if short_math:
            op1 = op1 % 1000
            op2 = op2 % 1000
        return op1*op2

    def div_smpl(self, op1, op2, _):
        """ integer division training example, raises ZeroDivisionError if op2 is 0 """
        return op1//op2

    def mod_smpl(self, op1, op2, _):
        """ modulo (remainder) training example, raises ZeroDivisionError if op2 is 0 """
        return op1 % op2

    @staticmethod
    def and_smpl(op1, op2, _):
        """ bitwise AND training example """
        return op1 & op2

    @staticmethod
    def bor_smpl(op1, op2, _):
        """ bitwise OR training example """
        return op1 | op2

    @staticmethod
    def xor_smpl(op1, op2, _):
        """ bitwise XOR training example """
        return op1 ^ op2

    def greater_smpl(self, op1, op2, _):
        """ integer comparison > training example """
        return self.true_vect if op1 > op2 else self.false_vect

    def lesser_smpl(self, op1, op2, _):
        """ integer comparison < training example """
        return self.true_vect if op1 < op2 else self.false_vect

    def eq_smpl(self, op1, op2, _):
        """ integer comparison == training example """
        return self.true_vect if op1 == op2 else self.false_vect

    def neq_smpl(self, op1, op2, _):
        """ integer comparison != training example """
        return self.true_vect if op1 != op2 else self.false_vect

    def create_data_point(self, op1, op2, op_string):
        """ Create training data from given ints op1, op2 and op_string.

        Returns the tuple described in encode_op(). For an unknown op_string
        or a division by zero a message is printed and
        (empty array, empty array, -1, -1, None) is returned.
        """
        op_index = self.op_string_to_index(op_string)
        if op_index == -1:
            print(f"Invalid operation {op_string}")
            return np.array([]), np.array([]), -1, -1, None
        try:
            return self.encode_op(op1, op2, op_index)
        except ZeroDivisionError:
            print(f"Invalid operation {op1}{op_string}{op2}: division by zero")
            return np.array([]), np.array([]), -1, -1, None

    def create_training_data(self, samples=10000, short_math=False):
        """ Create `samples` training samples, returns (inputs, outputs) arrays """
        # One throw-away sample determines the vector sizes.
        x, y, _, _, _ = self.get_data_point()
        dpx = np.zeros((samples, len(x)), dtype=np.float32)
        dpy = np.zeros((samples, len(y)), dtype=np.float32)
        print(f"Creating {samples} data points (. = 1000 progress)")
        for i in range(samples):
            if i % 100000 == 0:
                print(f"{i:>10} ", end="")
            if (i+1) % 1000 == 0:
                print(".", end="")
                sys.stdout.flush()
                if (i+1) % 100000 == 0:
                    print()
            x, y, _, _, _ = self.get_data_point(
                equal_distrib=False, short_math=short_math)
            dpx[i, :] = x
            dpy[i, :] = y
        print()
        return dpx, dpy

    def create_dataset(self, samples=10000, batch_size=2000, short_math=False):
        """ Create an endlessly repeating, shuffled and batched tf.data.Dataset """
        x, Y = self.create_training_data(samples=samples, short_math=short_math)
        shuffle_buffer = 10000
        dataset = tf.data.Dataset.from_tensor_slices((x, Y)).cache()
        dataset = dataset.shuffle(shuffle_buffer, reshuffle_each_iteration=True)
        dataset = dataset.repeat()  # Mandatory for Keras for now
        # drop_remainder is important on TPU, batch size must be fixed
        dataset = dataset.batch(batch_size, drop_remainder=True)
        # fetch next batches while training on the current one
        # (-1: autotune prefetch buffer size)
        dataset = dataset.prefetch(-1)
        return dataset

    @staticmethod
    def decode_results(result_int_vects):
        """ Convert an array of 32-float results from the neural net to ints.

        A value above 0.5 counts as a set bit. Vectors that do not have
        exactly 32 elements are reported and skipped.
        """
        result_vect_ints = []
        for vect in result_int_vects:
            if len(vect) != 32:
                print(f"Ignoring unexpected vector of length {len(vect)}")
            else:
                int_result = 0
                for i in range(32):
                    if vect[i] > 0.5:
                        int_result += 2**i
                result_vect_ints.append(int_result)
        return result_vect_ints

    def check_results(self, model, samples=1000, short_math=False, verbose=False):
        """ Run a number of tests on the trained model.

        Prints overall and per-op success rates and updates model_dis so
        that ops with many errors get more weight in new training data.
        """
        ok = 0
        err = 0
        operr = [0]*len(self.model_ops)
        opok = [0]*len(self.model_ops)
        points = [self.get_data_point(equal_distrib=True, short_math=short_math)
                  for _ in range(samples)]
        # A single batched predict() call instead of one call per sample.
        if points:
            predictions = model.predict(np.array([p[0] for p in points]))
        else:
            predictions = []
        for (_, _, z, op, s), prediction in zip(points, predictions):
            res = self.decode_results([prediction])
            if res[0] == z:
                ok += 1
                opok[op] += 1
                r = "OK"
            else:
                err += 1
                operr[op] += 1
                r = "Error"
            if verbose:
                print(f"{s} == {res[0]}: {r}")
                print(bin(res[0]))
                print(bin(z))
        opsum = ok+err
        if opsum == 0:
            opsum = 1
        print(f"Ok: {ok}, Error: {err} -> {ok/opsum*100.0}%")
        print("")
        for i in range(len(self.model_ops)):
            opsum = opok[i]+operr[i]
            if opsum == 0:
                opsum = 1
            # modify the distribution of training-data generated to favour
            # ops with bad test results, so that more training data is
            # generated on difficult cases:
            self.model_dis[i] = int(operr[i]/opsum*100)+10
            print(
                f"OP{self.model_ops[i]}: Ok: {opok[i]}, Error: {operr[i]}", end="")
            print(f" -> {opok[i]/opsum*100.0}%")
        print("Change probability for ops in new training data:")
        print(f"Ops:    {self.model_ops}")
        print(f"Weight: {self.model_dis}")

In [None]:
def create_load_model(model_file='math_model'):
    """ Load the saved model from model_file, or build and compile a new one.

    A new model is built when model_file is None, when model_file does not
    exist, or when running on TPU. The model maps the 36 input bits to 32
    sigmoid outputs using a convolutional branch and a dense branch whose
    outputs are concatenated. The model summary is printed before returning.
    """
    can_load = model_file is not None and os.path.exists(model_file)
    if can_load and is_tpu is not True:
        print("Loading standard-format model")
        model = tf.keras.models.load_model(model_file)
        print("Continuing training from existing model")
        model.summary()
        return model

    conv_l2 = 1e-8
    dense_l2 = 1e-8
    width = 368

    def dense_relu():
        # Hidden dense layer shared by both branches and the merge stage.
        return layers.Dense(width,
                            kernel_regularizer=regularizers.l2(dense_l2),
                            activation="relu")

    inputs = keras.Input(shape=(36,))  # depends on encoding of op-code!

    # Convolutional branch: treat the 36 input bits as a 1-channel sequence.
    conv_out = layers.Reshape(target_shape=(36, 1,), input_shape=(36,))(inputs)
    for n_filters in (48, 64, 128, 128, 128, 128, 128):
        conv_layer = layers.Conv1D(filters=n_filters, kernel_size=6,
                                   kernel_regularizer=regularizers.l2(conv_l2),
                                   activation="relu")
        conv_out = conv_layer(conv_out)
    flat = layers.Flatten()(conv_out)
    conv_features = dense_relu()(flat)

    # Dense branch: three fully connected layers directly on the input bits.
    dense_features = inputs
    for _ in range(3):
        dense_features = dense_relu()(dense_features)

    # Merge both branches and map to the 32 result bits.
    merged = layers.Concatenate()([conv_features, dense_features])
    merged = dense_relu()(merged)
    outputs = layers.Dense(32, activation="sigmoid")(merged)

    model = keras.Model(inputs=inputs, outputs=outputs, name="maths")
    # , metrics=["accuracy"])
    model.compile(loss="mean_squared_error", optimizer="adam")
    print("Compiling new model")
    # NOTE: on TPU, saved weights are currently not injected into the new
    # model; it always starts from fresh weights.
    model.summary()
    return model

def get_model(model_file='math_model', on_tpu=False):
    """ Create or load the model, inside a TPU strategy scope if requested.

    The model is built via create_load_model(). When a TPU is available
    (is_tpu) and on_tpu is set, the TPU system is initialized once, the
    resulting strategy is kept in the module-level tpu_strategy and the model
    is created inside that strategy's scope. Otherwise a standard model is
    created.
    """
    global tpu_is_init, tpu_strategy
    if not (is_tpu and on_tpu):
        print("Creating standard-scope model")
        return create_load_model(model_file=model_file)

    # Initialize the TPU only once: repeating it on every call would reset
    # the TPU system, and the strategy has to outlive this function call.
    if not globals().get("tpu_is_init") or globals().get("tpu_strategy") is None:
        # TPU_ADDRESS may be missing or None when the TPU was detected through
        # the device listing; the resolver then has to find the TPU itself.
        cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(
            tpu=globals().get("TPU_ADDRESS"))
        # tf.config.experimental_connect_to_cluster(cluster_resolver) # eager mode only! not TPU!
        tf.tpu.experimental.initialize_tpu_system(cluster_resolver)
        tpu_strategy = tf.distribute.TPUStrategy(cluster_resolver)
        tpu_is_init = True
    with tpu_strategy.scope():
        print("Creating TPU-scope model")
        return create_load_model(model_file=model_file)

In [None]:
def math_train(model, dataset, batch_size=8192, epochs=5000, steps_per_epoch=2000):
    """ Training loop.

    Fits model on dataset for the given number of epochs and steps per
    epoch, logging to TensorBoard in ./logs. batch_size is accepted for
    compatibility but not used: the dataset is already batched.

    Returns True if training was stopped by a KeyboardInterrupt, otherwise
    False. Other exceptions raised by fit() are printed and not re-raised.
    """
    interrupted = False
    tensorboard_callback = callbacks.TensorBoard(
        log_dir="./logs",
        histogram_freq=1,
        write_images=True,
        update_freq='batch')
    try:
        model.fit(dataset, epochs=epochs, steps_per_epoch=steps_per_epoch,
                  verbose=1, callbacks=[tensorboard_callback])  # validation_split=0.03 (not datasets!)
    except KeyboardInterrupt:
        print("")
        print("")
        print("---------INTERRUPT----------")
        print("")
        print("Training interrupted")
        interrupted = True
    except Exception as e:
        print(f"Exception {e}")
    # Return after the try statement: a `return` inside `finally` would
    # silently discard any exception that is still propagating.
    return interrupted

In [None]:
# Show TensorBoard inline, reading the logs directory (math_train logs to ./logs).
%tensorboard --logdir logs

In [None]:
# Decide where the model is stored: on Google Drive when running on Colab,
# otherwise below the current directory.
save_model = True
if not is_colab:
    root_path = '.'
else:
    mountpoint = '/content/drive'
    root_path = '/content/drive/My Drive'
    drive_ready = os.path.exists(root_path)
    if not drive_ready:
        # Drive is not mounted yet: mount it and check again.
        drive.mount(mountpoint)
        drive_ready = os.path.exists(root_path)
    if not drive_ready:
        print(f"Something went wrong with Google Drive access. Cannot save model to {root_path}")
        save_model = False

# model_file stays None when the model cannot be saved.
model_file = None
if save_model:
    project_path = os.path.join(root_path, "Colab Notebooks/ALU_Net")
    model_file = os.path.join(project_path, 'math_model')
print(f"Saving model to: {model_file}")

In [None]:
# Number of samples per training batch (passed to create_dataset).
BATCH_SIZE = 2000
# Number of data points generated for each master cycle.
SAMPLES = 500000
# Number of epochs trained on one generated dataset before a new one is created.
EPOCHS_PER_MASTER_CYCLE = 50
# Number of generate / train / check rounds of the training loop.
MASTER_CYCLES = 100
# Steps per epoch, chosen so that one epoch covers about SAMPLES data points.
STEPS_PER_EPOCH = SAMPLES // BATCH_SIZE
# Number of test samples check_results uses to re-weight the ops.
REWEIGHT_SIZE = 1024

In [None]:
# Initialize model(s)
# Sample generator used for both training data and result checks.
math_data = GenSamplesALU()
if is_tpu:
    # Generate a second CPU model for testing:
    # (the training loop copies the trained weights into it before checking)
    test_model = get_model(model_file=None, on_tpu=False)
# The model that gets trained; created in TPU scope when a TPU is in use.
math_model = get_model(model_file=model_file, on_tpu=is_tpu)

In [None]:
# Training
# Each master cycle generates fresh training data (weighted towards the ops
# that did badly in the previous check), trains on it and checks the result.
for _ in range(MASTER_CYCLES):
    dataset = math_data.create_dataset(
        samples=SAMPLES, batch_size=BATCH_SIZE, short_math=False)
    interrupted = math_train(math_model, dataset, epochs=EPOCHS_PER_MASTER_CYCLE,
                             steps_per_epoch=STEPS_PER_EPOCH)
    if is_tpu:
        # Results are checked with the CPU copy of the model.
        print("Injecting weights into test_model:")
        test_model.set_weights(math_model.get_weights())
        # NOTE: the test model is currently not saved on TPU.
        print("Done")
        check_model = test_model
    else:
        if model_file is not None:
            print("Saving math-model")
            math_model.save(model_file)
            print("Done")
        else:
            # model_file is None when the model location is not accessible;
            # calling save(None) would abort the training loop.
            print("Saving disabled, math-model is not saved")
        check_model = math_model
    math_data.check_results(check_model, samples=REWEIGHT_SIZE, short_math=False, verbose=False)
    if interrupted:
        break