<a href="https://colab.research.google.com/github/jiaozihan/compiler/blob/main/take_home_task.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
# Fun project

# Unusual allowances:

- You should copy these files to your computer and edit in your own environment
  so that you can use the trace functionality. If the CodeSignal platform said
  anything about not doing this, ignore it. We'll ignore any warnings the
  platform gives due to this allowance.
- You're welcome to use any online resources, including language models. In
  fact, you're encouraged to install GitHub Copilot, though it'll likely only
  help a little. CodeSignal makes you promise that you'll only use language
  reference resources; there's no way to disable that promise, but you're
  welcome to use non-language resources anyway.
- You may use any libraries you want while debugging but your uploaded solution
  must use only the standard library.

# Rules

- You may not consult with anyone else while you're doing this problem.
- Please don't tell anyone else about the details of this problem other than the
  vague description communicated before you started.

# Task

- Optimize the kernel (in KernelBuilder.build_kernel) as much as possible in the
  available time.
    - The machine parameters are balanced so that how impressed we are by you
      doing an optimization is roughly proportional to its speedup factor.
      Although it's not perfect.
    - The exception to the above is doing basic multicore, which we recommend
      you start with as a warmup.
    - The input used in test_kernel_cycles is what your performance will be
      scored on.

# Tips

- You'll be evaluated primarily on the speed of your correct kernel submissions.
  You're encouraged to improve the debugging/trace tooling present in this file,
  but only implement tooling and write nice code to the extent that in your
  judgement it'll best help you achieve a fast and correct solution.
- Try to keep a copy of your first correct kernel somewhere so you can compare
  against it when debugging.
- Modifying the simulator is one of the most powerful debugging tools available,
  you should take advantage of it. But your solution will be tested with an
  unmodified simulator so don't change the behavior or API.

We recommend you read through problem.py next.
"""

import random
import unittest

from problem import *


class KernelBuilder:
    """
    Assembles the instruction stream and scratch layout for a kernel.

    Instructions are accumulated in ``self.instrs`` as single-slot bundles of
    the form ``{engine: [slot]}``. Scratch addresses are handed out by a simple
    bump allocator (``alloc_scratch``) and scalar constants are deduplicated
    through ``const_map`` (``scratch_const``).
    """

    def __init__(self):
        self.instrs = []  # emitted instruction bundles, in program order
        self.labels = {}  # label name -> index into self.instrs
        self.scratch = {}  # scratch name -> base scratch address
        self.scratch_debug = {}  # base scratch address -> (name, length)
        self.scratch_ptr = 0  # next free scratch address
        self.const_map = {}  # constant value -> scratch address holding it

    def debug_info(self):
        """Return a DebugInfo built from the named scratch allocations."""
        # Hint: This isn't consumed anywhere, but you should probably use it in some way for debugging
        return DebugInfo(scratch_map=self.scratch_debug)

    def build(self, slots: list[tuple[Engine, tuple]], vliw: bool = False):
        """
        Pack ``(engine, slot)`` pairs into bundles, one slot per bundle.

        ``vliw`` is accepted for interface compatibility but is not used by
        this simple packer. The bundles are returned, not appended to
        ``self.instrs``.
        """
        return [{engine: [slot]} for engine, slot in slots]

    def add(self, engine, slot):
        """Append a single-slot bundle to the program."""
        self.instrs.append({engine: [slot]})

    def label(self, name):
        """Record ``name`` as pointing at the next instruction to be emitted."""
        self.labels[name] = len(self.instrs)

    def alloc_scratch(self, name=None, length=1):
        """
        Reserve ``length`` consecutive scratch words and return the base address.

        Named allocations are also recorded in ``self.scratch`` and
        ``self.scratch_debug``; re-using a name overwrites the earlier entry in
        ``self.scratch``. Asserts that the allocator stays within SCRATCH_SIZE.
        """
        addr = self.scratch_ptr
        if name is not None:
            self.scratch[name] = addr
            self.scratch_debug[addr] = (name, length)
        self.scratch_ptr += length
        assert self.scratch_ptr <= SCRATCH_SIZE, "Out of scratch space"
        return addr

    def scratch_const(self, val, name=None):
        """
        Return the scratch address of a word holding the constant ``val``.

        The first request for a value allocates a word and immediately emits a
        ``load const`` into ``self.instrs``; later requests reuse that address
        (``name`` is only honoured on the first request).
        """
        if val not in self.const_map:
            addr = self.alloc_scratch(name)
            self.add("load", ("const", addr, val))
            self.const_map[val] = addr
        return self.const_map[val]

    def for_loop(self, iter_addr, limit_addr, body: list[Instruction]):
        """
        Wrap ``body`` in a counted loop and return the bundles (not appended).

        Emitted layout, starting at the current end of ``self.instrs``:

            iter += 1
            loop_cond = limit < iter
            cond_jump loop_cond -> end
            <body>
            jump -> start

        Assuming the counter word starts at 0 and ``cond_jump`` branches when
        its condition is non-zero (both are defined outside this file -- TODO
        confirm against the simulator), the body runs ``limit`` times.

        Jump targets are absolute indices derived from ``len(self.instrs)`` at
        call time, so the returned list must be appended to ``self.instrs``
        before any other instruction is emitted.
        """
        loop_cond = self.alloc_scratch()
        # May emit a `load const`, so it must happen before start_addr is read.
        one_constant = self.scratch_const(1)
        start_addr = len(self.instrs)
        prologue_len, epilogue_len = 3, 1
        end_addr = start_addr + prologue_len + len(body) + epilogue_len
        instrs = [
            {"alu": [("+", iter_addr, one_constant, iter_addr)]},
            {"alu": [("<", loop_cond, limit_addr, iter_addr)]},
            {"flow": [("cond_jump", loop_cond, end_addr)]},
        ]
        instrs.extend(body)
        instrs.append({"flow": [("jump", start_addr)]})
        return instrs

    def build_simple_test(self):
        """
        A simple test program that just counts to 10
        """
        # Scratch space addresses (allocation order fixes the addresses the
        # accompanying test inspects, so keep it stable).
        one_constant = self.scratch_const(1)
        accum = self.alloc_scratch("accum")
        iter_addr = self.alloc_scratch("iter")
        limit_addr = self.alloc_scratch("limit")
        self.add("load", ("const", limit_addr, 10))
        body = [("alu", ("+", accum, accum, one_constant))]
        self.instrs.extend(self.for_loop(iter_addr, limit_addr, self.build(body)))

    def build_hash_vectorized(self, v_val_hash_addr, v_tmp1, v_tmp2):
        """
        Vectorized version of the hash function.

        For every stage in HASH_STAGES this emits
        ``val = op2(op1(val, val1), op3(val, val3))`` using ``valu`` slots, with
        each scalar constant loaded into the first word of a temporary and then
        broadcast into a shared constant vector.

        Args:
            v_val_hash_addr: Scratch address of the *input vector* (and final output).
            v_tmp1:          Scratch address of a temporary vector.
            v_tmp2:          Scratch address of another temporary vector.

        Returns:
            List of ``(engine, slot)`` pairs; nothing is appended to ``self.instrs``.
        """
        # Allocate the broadcast vector once and reuse it on later calls, so
        # repeated invocations do not leak VLEN scratch words each time.
        if "v_const" not in self.scratch:
            self.alloc_scratch("v_const", VLEN)
        v_const = self.scratch["v_const"]

        slots = []
        for op1, val1, op2, op3, val3 in HASH_STAGES:
            # v_tmp1 = val <op1> broadcast(val1)
            slots.append(("load", ("const", v_tmp1, val1)))
            slots.append(("valu", ("vbroadcast", v_const, v_tmp1)))
            slots.append(("valu", (op1, v_tmp1, v_val_hash_addr, v_const)))

            # v_tmp2 = val <op3> broadcast(val3)
            slots.append(("load", ("const", v_tmp2, val3)))
            slots.append(("valu", ("vbroadcast", v_const, v_tmp2)))
            slots.append(("valu", (op3, v_tmp2, v_val_hash_addr, v_const)))

            # val = v_tmp1 <op2> v_tmp2. Ops are passed through unchanged (as
            # build_hash does) so an unexpected op is never silently skipped.
            slots.append(("valu", (op2, v_val_hash_addr, v_tmp1, v_tmp2)))

        return slots

    def build_hash(self, val_hash_addr, tmp1, tmp2):
        """
        Scalar hash: for each HASH_STAGES entry emit three ``alu`` slots that
        compute ``val = op2(op1(val, val1), op3(val, val3))`` in place.

        Stage constants are obtained through ``scratch_const``, which may emit
        ``load const`` instructions into ``self.instrs`` as a side effect.
        Returns the ``(engine, slot)`` pairs without appending them.
        """
        slots = []

        for op1, val1, op2, op3, val3 in HASH_STAGES:
            slots.append(("alu", (op1, tmp1, val_hash_addr, self.scratch_const(val1))))
            slots.append(("alu", (op3, tmp2, val_hash_addr, self.scratch_const(val3))))
            slots.append(("alu", (op2, val_hash_addr, tmp1, tmp2)))

        return slots

    def build_kernel(self, forest_height: int, n_nodes: int, batch_size: int):
        """
        Like reference_kernel2 but building actual instructions, vectorized.
        Just the simplest implementation and non-overlapping scheduling possible.

        batch_size is guaranteed to be a multiple of VLEN*N_CORES*16; items
        that do not fill a whole vector per core are not processed.

        The emitted program consists of:
          1. Setup: copy header words mem[0..6] into named scratch slots,
             materialise constants, compute this core's offset, then pause.
          2. One round body, fully unrolled over the core's vectors, followed
             by a pause.
          3. A ``for_loop`` wrapper that repeats the round body using the
             ``rounds`` value read from the header.

        Args:
            forest_height: Not used when generating code.
            n_nodes: Not used when generating code; the node count is read at
                run time from the ``n_nodes`` header slot.
            batch_size: Total number of items; fixes the unroll count.
        """
        tmp1 = self.alloc_scratch("tmp1")
        tmp2 = self.alloc_scratch("tmp2")
        # Scratch space addresses
        init_vars = [
            "rounds",
            "n_nodes",
            "batch_size",
            "forest_height",
            "forest_values_p",
            "inp_indices_p",
            "inp_values_p",
        ]
        for v in init_vars:
            self.alloc_scratch(v, 1)
        for i, v in enumerate(init_vars):
            # tmp1 = i (memory address of the i-th header word)
            self.add("load", ("const", tmp1, i))
            # scratch[v] = mem[tmp1]
            self.add("load", ("load", self.scratch[v], tmp1))

        # Scalar constants used below
        zero_const = self.scratch_const(0)
        one_const = self.scratch_const(1)
        two_const = self.scratch_const(2)

        # Constant vectors for the parity test, the child-offset select and the
        # wraparound select.
        v_zero = self.alloc_scratch("v_zero", VLEN)
        v_ones = self.alloc_scratch("v_ones", VLEN)
        v_twos = self.alloc_scratch("v_twos", VLEN)

        self.add("valu", ("vbroadcast", v_zero, zero_const))
        self.add("valu", ("vbroadcast", v_ones, one_const))
        self.add("valu", ("vbroadcast", v_twos, two_const))

        n_cores_const = self.scratch_const(N_CORES)

        # Get core ID
        self.add("flow", ("coreid", tmp1))

        # Calculate items per core
        items_per_core = self.alloc_scratch("items_per_core")
        self.add("alu", ("//", items_per_core, self.scratch["batch_size"], n_cores_const))

        # Calculate core offset (in items). tmp1 is free to be reused afterwards.
        core_offset = self.alloc_scratch("core_offset")
        self.add("alu", ("*", core_offset, tmp1, items_per_core))

        # Pause
        self.add("flow", ("pause",))
        self.add("debug", ("comment", "Starting loop after init"))

        # Per-round working vectors
        body = []
        v_idx = self.alloc_scratch("v_idx", VLEN)  # Vector of indices
        v_val = self.alloc_scratch("v_val", VLEN)  # Vector of values
        v_node_val = self.alloc_scratch("v_node_val", VLEN)  # Vector of node values
        v_tmp = self.alloc_scratch("v_tmp", VLEN)  # Temporary vector
        v_tmp2 = self.alloc_scratch("v_tmp2", VLEN)  # Temporary vector

        # Assuming batch_size is a multiple of N_CORES*VLEN*16, this code can't handle the remainder

        for i in range(batch_size // (N_CORES * VLEN)):
            # tmp1 = core_offset + i*VLEN: item offset of this vector. The
            # constant load is emitted into the setup code by scratch_const.
            body.append(("alu", ("+", tmp1, core_offset, self.scratch_const(i * VLEN))))

            # Load a vector of indices:  inp_indices[core_offset + i*VLEN : core_offset + i*VLEN + VLEN]
            body.append(("alu", ("+", tmp2, self.scratch["inp_indices_p"], tmp1)))  # Calculate address
            body.append(("load", ("vload", v_idx, tmp2)))  # Vector load

            # Load corresponding node values: forest_values[v_idx[0]], forest_values[v_idx[1]], ...
            # Done lane by lane: first the addresses, then the scalar loads.
            for lane in range(VLEN):
                body.append(("alu", ("+", v_tmp + lane, self.scratch["forest_values_p"], v_idx + lane)))
            for lane in range(VLEN):
                body.append(("load", ("load", v_node_val + lane, v_tmp + lane)))

            # Load a vector of values:  inp_values[core_offset + i*VLEN : core_offset + i*VLEN + VLEN]
            body.append(("alu", ("+", tmp2, self.scratch["inp_values_p"], tmp1)))  # Calculate address
            body.append(("load", ("vload", v_val, tmp2)))  # Vector load

            # Per-lane XOR: v_val = v_val ^ v_node_val
            for lane in range(VLEN):
                body.append(("alu", ("^", v_val + lane, v_val + lane, v_node_val + lane)))
            # Per-lane scalar hash; the first words of v_tmp / v_tmp2 serve as
            # scalar temporaries.
            for lane in range(VLEN):
                body.extend(self.build_hash(v_val + lane, v_tmp, v_tmp2))

            body.append(("debug", ("comment", "after hashing")))
            # Vectorized calculation of next index
            body.append(("valu", ("%", v_tmp, v_val, v_twos)))  # v_tmp = v_val % 2
            body.append(("valu", ("==", v_tmp, v_tmp, v_zero)))  # v_tmp = (v_tmp == 0) ? 1 : 0
            body.append(("flow", ("vselect", v_tmp, v_tmp, v_ones, v_twos)))  # v_tmp = (val % 2 == 0) ? 1 : 2
            body.append(("debug", ("comment", "debbbbbuggg")))
            body.append(("valu", ("*", v_idx, v_idx, v_twos)))  # v_idx = 2 * v_idx
            body.append(("valu", ("+", v_idx, v_idx, v_tmp)))  # v_idx = v_idx + v_tmp

            body.append(("debug", ("comment", "comparing to n_nodes")))
            # Wraparound: lanes whose index reached n_nodes are reset to 0
            for lane in range(VLEN):
                body.append(("alu", ("<", v_tmp + lane, v_idx + lane, self.scratch["n_nodes"])))  # v_tmp = (v_idx < n_nodes)
            body.append(("flow", ("vselect", v_idx, v_tmp, v_idx, v_zero)))  # v_idx = v_tmp ? v_idx : 0

            # Store updated vector of indices (tmp2 was overwritten above, so
            # the address is recomputed)
            body.append(("alu", ("+", tmp2, self.scratch["inp_indices_p"], tmp1)))
            body.append(("store", ("vstore", tmp2, v_idx)))

            # Store updated vector of values
            body.append(("alu", ("+", tmp2, self.scratch["inp_values_p"], tmp1)))
            body.append(("store", ("vstore", tmp2, v_val)))  # Store the hashed values

        body_instrs = self.build(body)
        # One pause per round so the harness can compare memory after each round.
        body_instrs.append({"flow": [("pause",)]})
        height_i = self.alloc_scratch("height_i")
        # for_loop bakes in absolute jump targets, so extend self.instrs
        # immediately with nothing emitted in between.
        loop = self.for_loop(height_i, self.scratch["rounds"], body_instrs)
        self.instrs.extend(loop)


def do_kernel_test(
    forest_height: int,
    rounds: int,
    batch_size: int,
    seed: int = 123,
    trace: bool = False,
    prints: bool = True,
):
    """
    Build the kernel, run it on a seeded random problem, and check it against
    reference_kernel2 after every step the reference yields.

    The machine is resumed once per reference yield. After each resume the
    values region (base pointer read from header word 6) must match the
    reference exactly; the indices region (header word 5) is only printed.

    Returns the machine's cycle counter after the final step.
    """
    print(f"{forest_height=}, {rounds=}, {batch_size=}")
    random.seed(seed)
    forest = Tree.generate(forest_height)
    inp = Input.generate(forest, batch_size, rounds)
    mem = build_mem_image(forest, inp)

    print("mem:", mem)

    builder = KernelBuilder()
    builder.build_kernel(forest.height, len(forest.values), len(inp.indices))

    machine = Machine(
        mem, builder.instrs, builder.debug_info(), n_cores=N_CORES, trace=trace
    )
    machine.prints = prints

    for round_no, ref_mem in enumerate(reference_kernel2(mem)):
        machine.run()

        # Values region: dumped and then required to match the reference.
        values_base = ref_mem[6]
        print("input values position: ", values_base)
        values_span = slice(values_base, values_base + len(inp.values))
        got_values = machine.mem[values_span]
        want_values = ref_mem[values_span]
        print("hello:", got_values)
        print("hi:", want_values)
        print("new mem:", machine.mem)
        assert got_values == want_values, f"Incorrect result on round {round_no}"

        # Indices region: dumped for inspection only. Updating it in memory
        # isn't required, but an equality assert on these two slices can be
        # added here when debugging.
        indices_base = ref_mem[5]
        indices_span = slice(indices_base, indices_base + len(inp.indices))
        print("hey:", machine.mem[indices_span])
        print("there:", ref_mem[indices_span])

    print("CYCLES: ", machine.cycle)
    return machine.cycle


class Tests(unittest.TestCase):
    """Checks for the reference kernels, the builder, and the kernel itself."""

    def test_ref_kernels(self):
        """
        Test the reference kernels against each other
        """
        random.seed(123)
        for _ in range(10):
            forest = Tree.generate(4)
            inp = Input.generate(forest, 10, 6)
            mem = build_mem_image(forest, inp)
            # Run the object-level reference, then drain the memory-level one.
            reference_kernel(forest, inp)
            for _step in reference_kernel2(mem):
                pass
            indices_base = mem[5]
            values_base = mem[6]
            assert inp.indices == mem[indices_base : indices_base + len(inp.indices)]
            assert inp.values == mem[values_base : values_base + len(inp.values)]

    def test_simple(self):
        """Run the count-to-10 program and check the accumulator."""
        # Build the program
        builder = KernelBuilder()
        builder.build_simple_test()

        # Run it on a small zeroed memory image
        machine = Machine([0] * 10, builder.instrs, builder.debug_info())
        machine.prints = True
        machine.run()
        # Scratch word 1 is the accumulator allocated by build_simple_test.
        assert machine.cores[0].scratch[1] == 10

    def test_kernel_trace(self):
        """Small traced run, intended for the trace-viewer debug loop."""
        # For a tiny correctness-debugging example, use instead:
        #   do_kernel_test(3, 1, 1, trace=True, prints=True)
        do_kernel_test(3, 3, 16, trace=True, prints=False)

    def test_kernel_correctness(self):
        """Sweep a few tree heights and batch sizes."""
        # Technically passing this test is not required for submission, see submission_tests.py for the actual correctness test
        # Feel free not to run this yourself if your compiler is slow at it
        for batch_multiplier in (1, 2):
            batch_size = batch_multiplier * 16 * VLEN * N_CORES
            for height in (2, 3, 4):
                do_kernel_test(height, height + 2, batch_size)

    def test_kernel_cycles(self):
        """Run the configuration used for performance scoring."""
        do_kernel_test(10, 16, 1024)

# To run all the tests:
#    python perf_takehome.py
# To run a specific test:
#    python perf_takehome.py Tests.test_kernel_cycles
# To view a hot-reloading trace of all the instructions:  **Recommended debug loop**
#    python perf_takehome.py Tests.test_kernel_trace
# Then run `python watch_trace.py` in another tab, it'll open a browser tab, then click "Open Perfetto"
# You can then keep that open and re-run the test to see a new trace.

# To test the actual submission tests that CodeSignal will run:
#    python tests/submission_tests.py

# Run the unittest suite when this file is executed as a script; unittest.main()
# also accepts test names on the command line, as in the examples above.
if __name__ == "__main__":
    unittest.main()
