In [1]:
import math
import weakref
import random

In [2]:
class Word:
    """A fixed-width unsigned integer cell of the simulated Word-RAM machine.

    A Word is either a register (``isRegister=True``, normally created by
    ``Machine.allocateReg``) or a memory cell (``isRegister=False``, owned by
    a ``WordRAMMemory``). Cost accounting on the global ``Machine``:

    * ``get``/``set`` charge one unit on a memory cell and nothing on a
      register;
    * every arithmetic, bitwise, shift, comparison and ``lg`` operation
      charges one more unit for the operation itself.

    Binary operators take another Word as the right-hand operand. The plain
    operators return a freshly allocated register; the in-place forms
    (``+=``, ``&=``, ...) write the result back through ``set``. A value
    outside ``[0, 2**wordLength - 1]`` raises ``AssertionError`` instead of
    wrapping around.
    """

    def __init__(self, wordLength, x, isRegister=True):
        self.wordLength = wordLength
        self.maxValue = 2 ** wordLength - 1
        self.isRegister = isRegister

        self._x = x
        self.checkError()

    def __del__(self):
        """Release this register's slot on the global machine."""
        # getattr: __init__ may have raised before isRegister was assigned
        # (e.g. a non-integer wordLength). Machine is None until
        # CreateMachine() runs, so a hand-built Word must not touch it.
        if getattr(self, "isRegister", False) and Machine is not None:
            Machine.deallocateReg(self)

    def checkError(self):
        """Assert that the stored value fits in ``wordLength`` unsigned bits."""
        assert self._x <= self.maxValue and self._x >= 0, f"Tried to set value of {self._x} on the memory with {self.wordLength}-bit word length."

    def get(self):
        """Return the stored value; reading a memory cell costs one unit."""
        if not self.isRegister:
            Machine.totalExecution += 1
        return self._x

    def set(self, x):
        """Store ``x`` and range-check it; writing a memory cell costs one unit."""
        if not self.isRegister:
            Machine.totalExecution += 1
        self._x = x
        self.checkError()

    def _operands(self, rhs):
        """Read both operands of a binary operation and charge one unit for it."""
        v = self.get()
        rv = rhs.get()
        Machine.totalExecution += 1
        return v, rv

    def lg(self):
        """Return a new register holding floor(log2(self)).

        Raises:
            ValueError: if the value is 0 (the logarithm is undefined).
        """
        v = self.get()
        Machine.totalExecution += 1
        if v <= 0:
            raise ValueError("lg is undefined for 0")
        # bit_length is exact; math.log(v, 2) rounds through floating point
        # and overshoots just below large powers of two (2**53 - 1 -> 53).
        return Machine.allocateReg(v.bit_length() - 1)

    def pow(self, x):
        """Return a new register holding ``self ** x`` (x is a Word; x == 0 gives 1).

        Computed by repeated multiplication, so it costs O(x) operations.
        """
        nx = Reg(x.get())
        ans = Reg(1)
        zero = Reg(0)
        one = Reg(1)
        # Multiply exactly x times; Word operators only accept Word operands.
        while nx > zero:
            ans *= self
            nx -= one
        return ans

    def __add__(self, rhs):
        v, rv = self._operands(rhs)
        return Machine.allocateReg(v + rv)

    def __iadd__(self, rhs):
        v, rv = self._operands(rhs)
        self.set(v + rv)
        return self

    def __sub__(self, rhs):
        v, rv = self._operands(rhs)
        return Machine.allocateReg(v - rv)

    def __isub__(self, rhs):
        v, rv = self._operands(rhs)
        self.set(v - rv)
        return self

    def __mul__(self, rhs):
        v, rv = self._operands(rhs)
        return Machine.allocateReg(v * rv)

    def __imul__(self, rhs):
        v, rv = self._operands(rhs)
        self.set(v * rv)
        return self

    # `/` on words is integer (floor) division; `//` is provided as an alias.
    def __truediv__(self, rhs):
        v, rv = self._operands(rhs)
        return Machine.allocateReg(v // rv)

    def __itruediv__(self, rhs):
        v, rv = self._operands(rhs)
        self.set(v // rv)
        return self

    __floordiv__ = __truediv__
    __ifloordiv__ = __itruediv__

    # Modulo can be expressed as a - (a / b) * b in constant time, so it is
    # simplified here to Python's % and charged as a single operation.
    def __mod__(self, rhs):
        v, rv = self._operands(rhs)
        return Machine.allocateReg(v % rv)

    def __imod__(self, rhs):
        v, rv = self._operands(rhs)
        self.set(v % rv)
        return self

    def __and__(self, rhs):
        v, rv = self._operands(rhs)
        return Machine.allocateReg(v & rv)

    def __iand__(self, rhs):
        v, rv = self._operands(rhs)
        self.set(v & rv)
        return self

    def __or__(self, rhs):
        v, rv = self._operands(rhs)
        return Machine.allocateReg(v | rv)

    def __ior__(self, rhs):
        v, rv = self._operands(rhs)
        self.set(v | rv)
        return self

    def __xor__(self, rhs):
        v, rv = self._operands(rhs)
        return Machine.allocateReg(v ^ rv)

    def __ixor__(self, rhs):
        v, rv = self._operands(rhs)
        self.set(v ^ rv)
        return self

    def __lshift__(self, rhs):
        v, rv = self._operands(rhs)
        return Machine.allocateReg(v << rv)

    def __ilshift__(self, rhs):
        v, rv = self._operands(rhs)
        self.set(v << rv)
        return self

    def __rshift__(self, rhs):
        v, rv = self._operands(rhs)
        return Machine.allocateReg(v >> rv)

    def __irshift__(self, rhs):
        v, rv = self._operands(rhs)
        self.set(v >> rv)
        return self

    # Comparisons return plain Python bools and are charged one unit each.
    def __eq__(self, rhs):
        v, rv = self._operands(rhs)
        return v == rv

    def __lt__(self, rhs):
        v, rv = self._operands(rhs)
        return v < rv

    def __le__(self, rhs):
        v, rv = self._operands(rhs)
        return v <= rv

    def __gt__(self, rhs):
        v, rv = self._operands(rhs)
        return v > rv

    def __ge__(self, rhs):
        v, rv = self._operands(rhs)
        return v >= rv

    def __hash__(self):
        # Identity hash: __eq__ compares values, but each register must stay a
        # distinct member of the machine's register set.
        return hash(id(self))

    def __str__(self):
        return str(self._x)

    def _randomize(self):
        """Overwrite the value with a uniformly random one (not charged)."""
        self._x = random.randint(0, self.maxValue)

    def _getBit(self, i):
        """Return bit ``i`` counted from the most significant bit (i = 0); costs one unit."""
        Machine.totalExecution += 1
        return (self._x >> (self.wordLength - i - 1)) & 1

    def lgceil(self):
        """Return a new register holding ceil(log2(self)) for self >= 1."""
        # ceil(lg 1) = 0, but the general formula below would need lg(0).
        if self._x == 1:
            return Reg(0)
        return (self - Reg(1)).lg() + Reg(1)

    def divceil(self, rhs):
        """Return a new register holding ceil(self / rhs)."""
        # ceil(0 / rhs) = 0, but the general formula would compute 0 - 1 < 0.
        if self._x == 0:
            return Reg(0)
        return (self - Reg(1)) / rhs + Reg(1)

    def getOnes(self):
        """Return a new register holding the number of set bits (popcount)."""
        crt = Reg(self.get())
        res = Reg(0)
        while (crt > Reg(0)):
            res += crt & Reg(1)
            crt >>= Reg(1)
        return res

In [3]:
class WordRAMMemory:
    """A memory of `size` cells, each a non-register Word of `wordLength` bits.

    Cells are addressed with a Word index (usually a register). Reads and
    writes go through Word.get / Word.set, so each access is charged to the
    global Machine. Underscore-prefixed methods are instrumentation helpers
    that read raw cell values directly.
    """
    def __init__(self, size, wordLength):
        self.size = size
        self.wordLength = wordLength
        # Total capacity in bits.
        self.memSize = size * wordLength

        self._d = [Word(self.wordLength, 0, False) for _ in range(self.size)]

    def __getitem__(self, i):
        """Return the cell Word at address `i` (a Word, read via get())."""
        assert isinstance(i, Word), "Subscript type should be `Word`!"
        x = i.get()
        assert x < self.size, f"Tried to access at address {x} while the size of memory is {self.size}."
        return self._d[x]

    def __setitem__(self, i, word):
        """Copy the value of `word` into the cell at address `i` (both Words)."""
        assert isinstance(i, Word), "Subscript type should be `Word`!"
        assert isinstance(word, Word), "Value type should be `Word`!"
        x = i.get()
        # Same bounds check as __getitem__, instead of a bare IndexError.
        assert x < self.size, f"Tried to access at address {x} while the size of memory is {self.size}."
        self._d[x].set(word.get())

    def _randomize(self):
        """Fill every cell with a random value (not charged)."""
        for word in self._d:
            word._randomize()

    def _getBit(self, i):
        """Return bit `i` of the memory viewed as one bit string.

        Cell 0 comes first and each cell is read MSB-first; costs one unit.
        """
        blockId = i // self.wordLength
        bitIdInBlock = i - blockId * self.wordLength
        return self._d[blockId]._getBit(bitIdInBlock)

    def _getMultiBits(self, start, end):
        """Pack the 1-bit cells in [start, end) into one register, charged as one operation.

        The cell at `start` becomes the most significant bit; addresses past
        the end of memory read as 0. At most one machine word of bits may be
        requested.
        """
        assert self.wordLength == 1, "Can only obtain multiple bits at once in raw WordRAM memorys"
        # A full machine word (exactly Machine.wordLength bits) still fits in a register.
        assert end._x - start._x <= Machine.wordLength, "Try to access too much bits at once!"
        Machine.totalExecution += 1
        lo = start.get()
        hi = end.get()
        # Assemble from raw cell values so the single charge above is the whole
        # cost; going through Word operators would bill every bit separately.
        bits = 0
        for addr in range(lo, hi):
            bits <<= 1
            if addr < self.size:
                bits |= self._d[addr]._x
        return Reg(bits)

    def clear(self):
        """Set every cell to 0."""
        i = Reg(0)
        while i < Reg(self.size):
            self[i] = Reg(0)
            i += Reg(1)

    def clone(self):
        """Return a new memory with the same shape and cell values."""
        # Mem / allocateMem expect Word arguments (they call .get() on them).
        new = Mem(Reg(self.size), Reg(self.wordLength))
        i = Reg(0)
        while i < Reg(self.size):
            new[i] = self[i]
            i += Reg(1)
        return new

    def __str__(self):
        return ", ".join(map(lambda w: str(w._x) ,self._d))

In [4]:
class WordRAMMachine:
    """The simulated Word-RAM machine: word size, register budget and cost counters.

    Attributes:
        wordLength: bits per machine word.
        registerNum: maximum number of registers that may be live at once.
        crtRegisters: WeakSet of live registers; an entry disappears once its
            register is garbage-collected.
        totalMemSize: total bits allocated through allocateMem.
        totalExecution: number of operations charged so far.
    """
    def __init__(self, wordLength, registerNum):
        self.wordLength = wordLength

        self.registerNum = registerNum
        self.crtRegisters = weakref.WeakSet()

        self.totalMemSize = 0
        self.totalExecution = 0

    def allocateMem(self, size, wordLength):
        """Allocate a WordRAMMemory of `size` cells of `wordLength` bits (both Words).

        Returns None, after printing a message, if `wordLength` exceeds the
        machine word length.
        """
        size = size.get()
        wordLength = wordLength.get()
        if wordLength > self.wordLength:
            print(f"Tried to allocate memory with word length of {wordLength} bits on the machine with {self.wordLength}-bit word length.")
            return None
        mem = WordRAMMemory(size, wordLength)
        self.totalMemSize += size * wordLength
        return mem

    def allocateReg(self, x=0):
        """Allocate and track a new register holding `x`, enforcing the register budget."""
        assert len(self.crtRegisters) < self.registerNum, "Too many registers!"
        reg = Word(self.wordLength, x)
        self.crtRegisters.add(reg)
        return reg

    def deallocateReg(self, reg):
        """Stop tracking `reg` as a live register (called from Word.__del__)."""
        # Removing from a WeakSet compares weak references, and two live
        # weakrefs compare by calling __eq__ on their referents. Word.__eq__
        # charges one unit to the global Machine, but freeing a register is
        # bookkeeping, not a machine operation, so that charge is undone.
        machine = Machine
        charged = machine.totalExecution if machine is not None else None
        # discard rather than remove: `reg` may never have been tracked here
        # (its constructor failed before allocateReg added it, or it belongs
        # to a machine replaced by a later CreateMachine()).
        self.crtRegisters.discard(reg)
        if machine is not None:
            machine.totalExecution = charged

In [5]:
# The simulated machine shared by Word, WordRAMMemory and the helpers below.
# It stays None until CreateMachine() is called.
Machine = None

def CreateMachine(wordLength, registerNum):
    """Create a fresh WordRAMMachine and install it as the global `Machine`.

    Args:
        wordLength: bits per machine word (a plain int).
        registerNum: maximum number of registers that may be live at once.
    """
    global Machine
    Machine = WordRAMMachine(wordLength, registerNum)

def GetMachine():
    """Return the current global machine (None before CreateMachine())."""
    return Machine

def Reg(x=0):
    """Allocate a register on the global machine holding `x` (a plain int)."""
    return Machine.allocateReg(x)

def Mem(size, wordLength=None):
    """Allocate a WordRAMMemory on the global machine.

    Args:
        size: number of cells, as a Word (allocateMem reads it with .get()).
        wordLength: bits per cell, as a Word; defaults to the machine word length.

    Returns:
        The new memory, or None if `wordLength` exceeds the machine word length.
    """
    if wordLength is None:
        wordLength = Reg(Machine.wordLength)
    return Machine.allocateMem(size, wordLength)