In [None]:
# gshare predictor sizing: number of 2-bit counters in the pattern table, and
# the mask applied to both the PC and the global history to index it
# (size - 1 is a valid mask only while the size is a power of two).
gsharesize = 4
gsharebits = gsharesize - 1
class Instruction:
    """Decoded fields of one 32-bit MIPS instruction word.

    Attributes:
        instructionVal: the word rendered with ``bin()`` (for display).
        opcode, rs, rt, rd, shamt, funct: the standard R-type bit fields.
        imm: the low 16 bits, unextended (callers sign-extend as needed).
        cycles: one flag per pipeline stage (IF, ID, EX, MEM, WB); each stage
            function sets its own slot to 1 and clears the earlier ones.
    """

    def __init__(self, instruction):
        word = instruction
        self.instructionVal = bin(word)
        self.opcode = (word >> 26) & 0x3F
        self.rs = (word >> 21) & 0x1F
        self.rt = (word >> 16) & 0x1F
        self.rd = (word >> 11) & 0x1F
        self.shamt = (word >> 6) & 0x1F
        self.funct = word & 0x3F
        self.imm = word & 0xFFFF
        self.cycles = [0 for _stage in range(5)]


In [2]:
class PipelineState:
    """Machine state shared by all pipeline stages (PC, registers, memory, predictor)."""

    def __init__(self):
        # Not read by the visible code; the driver loop keeps its own `cycle` counter.
        self.cycle = 0
        # Flush requests: execute sets both on a mispredicted branch, the driver
        # sets idexFlush on a load-use stall; decode/fetch insert bubbles.
        self.ifidFlush = 0
        self.idexFlush = 0
        # Statistics printed at the end of the run.
        self.branchflushcount = 0
        self.branchInstruction = 0
        # gshare predictor: global outcome history and the 2-bit counter table,
        # every counter starting at 0b11 (predict taken).
        self.histReg = 0
        self.pduStates = [0b11 for _entry in range(gsharesize)]
        # Not read by the visible code.
        self.ifidStall = 0
        self.pc = 0
        self.registers = [0 for _reg in range(32)]
        # Word-indexed data memory; resize as needed.
        self.memory = [0 for _word in range(1024)]
        

In [3]:
class PipelineRegister:
    """A latch between two pipeline stages.

    Only the fields common to every latch are created here. Each stage adds its
    own fields (rs, rt, destreg, thisPC, pcp1, ...) the first time it writes
    them, so ``vars()`` of a latch grows as the simulation runs.
    """

    def __init__(self):
        self.flush = 0
        self.instruction = 0  # an Instruction once a stage has written the latch
        self.instructionVal = 0
        self.prediction = 0  # gshare prediction bit carried along with a branch
        self.stateIndex = 0  # gshare table index used for that prediction
        self.last = 0  # 1 marks the end-of-program sentinel
        # Stage result (ALU value, loaded word, store data); there is no separate
        # control unit, so stages interpret the instruction directly.
        self.data = 0
        self.done = 0  # 1 while the latch holds a value the next stage has not consumed
        self.pc = 0

In [4]:
# Load instruction memory from imem.txt. Every line containing ':' contributes
# one word: the text between the first and second ':' (presumably the binary
# encoding, optionally ';'-terminated) is parsed as base 2. Other lines are ignored.
instructions = []
with open("imem.txt", 'r') as imem:
    for text in imem:
        fields = text.split(":")
        if len(fields) < 2:
            continue
        bit_string = fields[1].strip().rstrip(";")
        instructions.append(int(bit_string, 2))
    

In [5]:
# Pipeline latches and the shared machine state.
IF_ID = PipelineRegister()
ID_EX = PipelineRegister()
EX_MEM = PipelineRegister()
MEM_WB = PipelineRegister()
state = PipelineState()
# Initial data memory (word-indexed): words 0-9 are preloaded.
state.memory[0:10] = [0x5, 0x7, 0x2, 0xF, 0xA, 0x10, 0x30, 0x1, 0xFF, 0x55]

# Primary opcode field -> mnemonic.
opcodes = {
    0x0: 'rtype', 0x8: 'addi', 0xd: 'ori', 0xe: 'xori', 0xc: 'andi', 0xa: 'slti',
    0x23: 'lw', 0x2b: 'sw', 0x4: 'beq', 0x2: 'j', 0x3: 'jal', 0x5: 'bne',
}
# Opcodes that never write a register: sw, beq, j, jal, bne. Used by the
# forwarding unit; it must agree with the opcodes writeBack skips (which include
# jal), otherwise a jal's unused result could be forwarded.
notWrites = [0x2b, 0x4, 0x2, 0x3, 0x5]
# R-type funct field -> mnemonic.
rtypes = {
    0x20: 'add', 0x22: 'sub', 0x24: 'and', 0x25: 'or', 0x2a: 'slt', 0x26: 'xor',
    0x27: 'nor', 0x0: 'sll', 0x2: 'srl', 0x8: 'jr',
}
def signed(data):
    """Interpret a 16-bit field as a two's-complement integer.

    Values with bit 15 set are mapped to the negative range; others are
    returned unchanged.
    """
    return data - 0x10000 if data & 0x8000 else data

In [6]:
def forwarding_unit(EX_MEM, MEM_WB, ID_EX, state):
    if not EX_MEM.done and MEM_WB.done and ID_EX.done:
        writeex = 1 if EX_MEM.instruction.opcode not in notWrites else 0
        writewb = 1 if MEM_WB.instruction.opcode not in notWrites else 0
        if writeex and EX_MEM.destreg == ID_EX.rs and EX_MEM.destreg != 0:
            op1 = EX_MEM.data 
        elif writewb and MEM_WB.destreg == ID_EX.rs and MEM_WB.destreg !=0:
            op1 = MEM_WB.data
        else: op1 = state.registers[ID_EX.instruction.rs]

        if writeex and EX_MEM.destreg == ID_EX.rt and MEM_WB.destreg !=0:
            op2 = EX_MEM.data
        elif writewb and MEM_WB.destreg == ID_EX.rt:
            op2 = MEM_WB.data
        else: op2 = state.registers[ID_EX.instruction.rt]
        return op1, op2
    else: return state.registers[ID_EX.instruction.rs], state.registers[ID_EX.instruction.rt]



In [7]:


def hazard_unit(ID_EX, state, EX_MEM, MEM_WB, IF_ID):
    """Return True when the instruction in IF_ID reads the register a lw in ID_EX loads.

    The loaded word is only available after the MEM stage, so the driver reacts
    by refetching IF_ID's instruction and letting decode insert a bubble. rs is
    always a source operand. rt is a source for R-type instructions, beq/bne
    (the compared value) and sw (the stored value); for the other I-type
    instructions rt is the destination and is ignored.

    ``state``, ``EX_MEM`` and ``MEM_WB`` are unused; they are kept so existing
    callers keep working.
    """
    if not (IF_ID.done and ID_EX.done):
        return False
    if opcodes[ID_EX.instruction.opcode] != 'lw':
        return False
    load_dest = ID_EX.destreg
    if IF_ID.rs == load_dest:
        return True
    reads_rt = opcodes[IF_ID.instruction.opcode] in ('rtype', 'beq', 'bne', 'sw')
    return reads_rt and IF_ID.rt == load_dest
    
             
    
    
    

In [8]:
def fetch(state, instructions, IF_ID):
    """IF stage: latch the word at ``state.pc`` into IF_ID and pick the next PC.

    Does nothing while IF_ID still holds an instruction decode has not taken.
    With ``state.ifidFlush`` set, a bubble (word 0) is latched and the PC is
    left alone. Past the end of ``instructions`` a word-0 instruction flagged
    ``last`` is latched. beq/bne are predicted with the gshare table (and
    counted/printed); a predicted-taken branch, j/jal and jr redirect the PC,
    wrapped to 10 bits. Otherwise the PC advances by one.
    """
    # Cleared first so a bubble or skipped fetch never exposes stale operands.
    IF_ID.rs = 0
    IF_ID.rt = 0
    IF_ID.last = 0
    if IF_ID.done:
        return

    if state.ifidFlush:
        IF_ID.instruction = Instruction(0)
        IF_ID.instructionVal = 0
    else:
        pc = state.pc
        in_program = pc < len(instructions)
        if not in_program:
            IF_ID.last = 1  # end-of-program sentinel
        inst = Instruction(instructions[pc] if in_program else 0)
        IF_ID.thisPC = pc
        inst.cycles[0] = 1
        IF_ID.instruction = inst
        IF_ID.instructionVal = inst.instructionVal

        # gshare: XOR the low history bits with the low PC bits.
        index = (state.histReg & gsharebits) ^ (pc & gsharebits)
        IF_ID.stateIndex = index
        predict_taken = (state.pduStates[index] >> 1) & 0b1
        kind = opcodes[inst.opcode]
        is_branch = kind in ('beq', 'bne')
        if is_branch:
            state.branchInstruction += 1
            IF_ID.prediction = predict_taken
            print("\nTAKEN" if predict_taken else "\n NOT TAKEN")

        IF_ID.rs = inst.rs
        IF_ID.rt = inst.rt
        IF_ID.pcp1 = pc + 1

        if is_branch and predict_taken:
            target = pc + 1 + signed(inst.imm)
        elif kind in ('j', 'jal'):
            target = inst.imm
        elif kind == 'rtype' and rtypes[inst.funct] == 'jr':
            target = state.registers[inst.rs]
        else:
            target = None
        state.pc = pc + 1 if target is None else target & 0b1111111111

    IF_ID.nextpc = state.pc
    IF_ID.done = 1
        

In [9]:
def decode(state, IF_ID, ID_EX):
    """ID stage: read register operands for IF_ID's instruction and latch it into ID_EX.

    When ``state.idexFlush`` is set (mispredicted branch or load-use stall) a
    bubble (word 0, which decodes as ``sll $0, $0, 0``) is latched instead of
    the fetched instruction. In both cases IF_ID is released so fetch may
    overwrite it.
    """
    # Clear the fields this stage owns; this also fixes their order in printState.
    ID_EX.readData1 = 0
    ID_EX.readData2 = 0
    ID_EX.imm = 0
    ID_EX.last = 0
    ID_EX.rs = 0
    ID_EX.rt = 0
    ID_EX.rd = 0
    ID_EX.thisPC = 0
    ID_EX.pcp1 = 0
    ID_EX.destreg = 0
    ID_EX.instruction = Instruction(0)
    if not IF_ID.done:  # fetch has not produced an instruction yet
        return

    inst = IF_ID.instruction
    inst.cycles[0] = 0
    inst.cycles[1] = 1
    inst.cycles[2] = 0
    inst.cycles[3] = 0
    flushed = bool(state.idexFlush)
    if flushed:
        inst = Instruction(0)  # squash: replace with a no-op bubble

    kind = opcodes[inst.opcode]
    ID_EX.readData1 = state.registers[inst.rs]
    if kind == 'rtype':
        ID_EX.readData2 = state.registers[inst.rt]
    else:
        ID_EX.imm = signed(inst.imm)  # I-type immediates are sign-extended here
    ID_EX.instruction = inst
    ID_EX.prediction = IF_ID.prediction
    ID_EX.stateIndex = IF_ID.stateIndex
    ID_EX.instructionVal = inst.instructionVal
    ID_EX.rs = inst.rs
    ID_EX.rt = inst.rt
    # R-type writes rd; other I-types write rt. beq/bne write nothing, so their
    # destreg is never used for a register write.
    validRt = kind in ('rtype', 'bne', 'beq')
    ID_EX.destreg = inst.rd if validRt else inst.rt
    ID_EX.rd = inst.rd
    ID_EX.pcp1 = IF_ID.pcp1
    # A squashed slot must not carry the end-of-program marker: the squashed
    # instruction was on a wrong path or will be refetched, so propagating
    # `last` would stop the simulation early.
    ID_EX.last = 0 if flushed else IF_ID.last
    ID_EX.thisPC = IF_ID.thisPC
    ID_EX.done = 1  # ID_EX stays occupied until execute consumes it
    IF_ID.done = 0  # fetch may now overwrite IF_ID
    
    

In [10]:
def execute(state,ID_EX, EX_MEM):
    if not ID_EX.done: #checks if IDEX is decoding another instruction or not
        return
    EX_MEM.last =0
    EX_MEM.forwardBres = 0
    EX_MEM.imm = 0
    EX_MEM.readData1 = 0
    EX_MEM.instruction = Instruction(0)
    inst = ID_EX.instruction
    inst.cycles[0] = 0
    inst.cycles[1] = 0
    inst.cycles[2] = 1
    inst.cycles[3] = 0
    validRt = opcodes[inst.opcode] == 'rtype' or opcodes[inst.opcode] =='bne' or opcodes[inst.opcode] == 'beq'
    op1, op2 = forwarding_unit(EX_MEM=EX_MEM, MEM_WB=MEM_WB, ID_EX=ID_EX, state= state)
    if not validRt: op2 = ID_EX.imm
    if opcodes[inst.opcode] == 'rtype':
        if rtypes[inst.funct] =='add':
            EX_MEM.data = op1 + op2
        elif rtypes[inst.funct] =='sub':
            EX_MEM.data = op1 - op2
        elif rtypes[inst.funct] =='and':
            EX_MEM.data = op1 & op2
        elif rtypes[inst.funct] =='or':
            EX_MEM.data = op1 | op2
        elif rtypes[inst.funct] =='slt':
            EX_MEM.data = 1 if op1<op2 else 0
        elif rtypes[inst.funct] =='xor':
            EX_MEM.data = op1 ^ op2
        elif rtypes[inst.funct] =='nor':
            EX_MEM.data = ~(op1 | op2)
        elif rtypes[inst.funct] =='sll':
            EX_MEM.data = op1 <<inst.shamt
        elif rtypes[inst.funct] =='srl':
            EX_MEM.data = op1 >> inst.shamt
    elif opcodes[inst.opcode] == 'addi':
        EX_MEM.data = op1+op2
    elif opcodes[inst.opcode] == 'ori':
        EX_MEM.data = op1|op2
    elif opcodes[inst.opcode] == 'xori':
        EX_MEM.data = op1^op2
    elif opcodes[inst.opcode] == 'andi':
        EX_MEM.data = op1&op2
    elif opcodes[inst.opcode] == 'slti':
        EX_MEM.data = 1 if op1<op2 else 0
    elif opcodes[inst.opcode] == 'sw': EX_MEM.data = state.registers[inst.rt]
    else: EX_MEM.data = 0
    if(opcodes[inst.opcode] == 'bne'):
        if op1 == op2:
            taken = 0
        else: taken = 1
    elif(opcodes[inst.opcode] == 'beq'):
        if op1 == op2:
            taken = 1
        else: taken = 0
    else: taken = 0
    if((opcodes[inst.opcode] == 'bne') or (opcodes[inst.opcode] == 'beq')):
        state.histReg = (state.histReg << 1) + 1 if taken else  state.histReg << 1 
        if(taken):
            if state.pduStates[ID_EX.stateIndex] == 0b00:
                state.pduStates[ID_EX.stateIndex] == 0b01
            elif state.pduStates[ID_EX.stateIndex] == 0b01:
                state.pduStates[ID_EX.stateIndex] = 0b10
            elif state.pduStates[ID_EX.stateIndex] == 0b10:
                state.pduStates[ID_EX.stateIndex] = 0b11
            elif state.pduStates[ID_EX.stateIndex] == 0b11:
                state.pduStates[ID_EX.stateIndex] = 0b11
        else:
            if state.pduStates[ID_EX.stateIndex] == 0b00:
                state.pduStates[ID_EX.stateIndex] == 0b00
            elif state.pduStates[ID_EX.stateIndex] == 0b01:
                state.pduStates[ID_EX.stateIndex] = 0b00
            elif state.pduStates[ID_EX.stateIndex] == 0b10:
                state.pduStates[ID_EX.stateIndex] = 0b01
            elif state.pduStates[ID_EX.stateIndex] == 0b11:
                state.pduStates[ID_EX.stateIndex] = 0b10
    if opcodes[inst.opcode] == 'bne' or opcodes[inst.opcode] == 'beq':
        branchHazard = True if ID_EX.prediction ^ taken else False
        if branchHazard:
            if taken:
                state.pc = ID_EX.thisPC +1+ signed(inst.imm )& 0b1111111111
            else:
                state.pc = ID_EX.pcp1
            state.ifidFlush = 1
            state.idexFlush = 1
            print("flush detected")
            state.branchflushcount +=1
        else:
            state.ifidFlush = 0
            state.idexFlush = 0
    EX_MEM.taken = taken                
    EX_MEM.instruction = inst
    EX_MEM.instructionVal = inst.instructionVal
    EX_MEM.imm = ID_EX.imm
    EX_MEM.readData1 = ID_EX.readData1
    EX_MEM.rs = inst.rs
    EX_MEM.rt = inst.rt
    EX_MEM.rd = inst.rd
    EX_MEM.op1 = op1
    EX_MEM.op2 = op2
    EX_MEM.last = ID_EX.last
    EX_MEM.thisPC = ID_EX.thisPC
    EX_MEM.destreg = inst.rd if validRt else inst.rt
    EX_MEM.pcp1 = ID_EX.pcp1
    EX_MEM.done = 1
    ID_EX.done = 0
        
        
            
        
            
    

In [11]:
def memory (state, EX_MEM, MEM_WB):
    MEM_WB.last =0
    if not EX_MEM.done: #checks if exeucte is executing or not
        return
    inst = EX_MEM.instruction
    inst.cycles[0] = 0
    inst.cycles[1] = 0
    inst.cycles[2] = 0
    inst.cycles[3] = 1
    MEM_WB.data = EX_MEM.data
    if opcodes[inst.opcode] == 'lw':
         MEM_WB.data = state.memory[EX_MEM.op1 + EX_MEM.imm]
    elif opcodes[inst.opcode] == 'sw':
        state.memory[EX_MEM.op1 + EX_MEM.imm] = MEM_WB.data
    
    MEM_WB.instruction = inst
    MEM_WB.instructionVal = inst.instructionVal
    MEM_WB.rs = inst.rs
    MEM_WB.rt = inst.rt
    MEM_WB.rd = inst.rd
    MEM_WB.last = EX_MEM.last
    MEM_WB.thisPC = EX_MEM.thisPC
    MEM_WB.destreg = EX_MEM.destreg
    MEM_WB.done = 1
    EX_MEM.done =0
    

In [12]:
# Opcode field -> mnemonic. NOTE(review): this re-binds `opcodes` to exactly the
# same entries as the table in the setup cell; it looks redundant (perhaps kept
# so this cell can be re-run on its own) — consider keeping a single definition.
opcodes = {
    0x0 : 'rtype', 0x8: 'addi', 0xd :'ori', 0xe: 'xori', 0xc:'andi', 0xa:'slti',
    0x23: 'lw', 0x2b : 'sw', 0x4:'beq', 0x2:'j', 0x3 : 'jal', 0x5:'bne'
      
}


def writeBack(state, MEM_WB):
    """WB stage: commit MEM_WB's result to the register file and release the latch.

    sw, j, jal, beq and bne write no register. R-type instructions write rd;
    every other instruction writes rt. Writes to register 0 are discarded so
    $zero stays 0. Without this, bubbles, jr or ``addi $0, ...`` would store
    into it.
    """
    if not MEM_WB.done:  # nothing reached write-back this cycle
        return
    inst = MEM_WB.instruction
    inst.cycles[0] = 0
    inst.cycles[1] = 0
    inst.cycles[2] = 0
    inst.cycles[3] = 0
    inst.cycles[4] = 1
    kind = opcodes[inst.opcode]
    if kind not in ('sw', 'j', 'jal', 'bne', 'beq'):
        dest = inst.rd if kind == 'rtype' else inst.rt
        if dest != 0:
            state.registers[dest] = MEM_WB.data
    MEM_WB.done = 0
    
    
    

In [13]:
def printState(stages):
    """Print every attribute of each pipeline latch, one section per latch.

    Args:
        stages: mapping of display name -> latch object; attributes are listed
            in ``vars()`` order and each section ends with a dashed rule.
    """
    separator = "-" * 20
    for name in stages:
        latch = stages[name]
        print(f"{name}: \n")
        for field, value in vars(latch).items():
            print(f"{field}: {value}")
        print(separator)

In [14]:

# Simulation driver: one loop iteration per clock cycle, for at most `cycles`
# iterations or until the end-of-program marker (`last`) reaches MEM/WB.
cycle = 0
countofstalls = 0  # number of load-use hazards detected
cycles = 5000  # upper bound on simulated cycles
stages = {'IF_ID':IF_ID, 'ID_EX': ID_EX, 'EX_MEM': EX_MEM, 'MEM_WB': MEM_WB }
while (cycle < cycles):
    print(f"\n\n{'-' *50}")
    print(f"Cycle {cycle}")
    print(f"{'-'*50}")
    print(f"PC: {state.pc}")
    # Stop once the end-of-program marker has passed through MEM.
    if MEM_WB.last:break
    # Stages run back to front (WB, MEM, EX, ID, IF) so each stage consumes its
    # input latch before the preceding stage overwrites it in the same cycle.
    writeBack(state, MEM_WB)
    memory(state, EX_MEM, MEM_WB)
    ldHazard = hazard_unit(ID_EX, state, EX_MEM, MEM_WB, IF_ID)
    if ldHazard:
        # Load-use stall: rewind the PC to refetch IF/ID's instruction and have
        # decode latch a bubble into ID/EX instead.
        state.pc = IF_ID.thisPC
        state.idexFlush = 1
        print("load hazard detected")
        countofstalls +=1
    else:
        # No stall: clear flush requests, including any a branch set last cycle.
        state.ifidFlush = 0
        state.idexFlush = 0
    execute(state, ID_EX, EX_MEM)
    decode(state, IF_ID, ID_EX)
    fetch(state, instructions, IF_ID)
    
    #pipeline register states
    printState(stages)
    cycle +=1
# NOTE(review): resets the PC after the run; nothing printed below reads it.
state.pc = 0


def print_registers(registers):
    """Print each register as ``R<index>: <value>`` between dashed rules.

    Args:
        registers: sequence of register values, index 0 first.
    """
    rule = '-' * 50
    print("\nFinal Register Values:")
    print(rule)
    for idx, val in enumerate(registers):
        print(f"R{idx:2}: {val}")
    print(rule + "\n")

# Dump the 32-entry register file after the simulation loop.
print_registers(state.registers[0:32])

def printMemory(memory, count=15):
    """Print the first *count* data-memory words as ``M<index>: <value>``.

    Args:
        memory: sequence of word values, index 0 first.
        count: how many words to show (default 15); clamped to ``len(memory)``
            so a smaller memory does not raise IndexError.
    """
    # Single quotes inside the f-string: reusing the outer quote type is only
    # legal from Python 3.12 (PEP 701).
    rule = f"{'-' * 50}"
    print("\nFinal mem values:")
    print(rule)
    for i in range(min(count, len(memory))):
        print(f"M{i:2}: {memory[i]}")
    print(rule)
# Final data-memory dump followed by the run statistics gathered by the driver
# loop (cycle, countofstalls) and by the pipeline stages (state counters).
printMemory(state.memory)
print("cycle count: ", cycle)
print("ld hazard count: ", countofstalls)
print("branch flush count: ", state.branchflushcount)
print("branch instruction count: ", state.branchInstruction )

    


            
    



--------------------------------------------------
Cycle 0
--------------------------------------------------
PC: 0
IF_ID: 

flush: 0
instruction: <__main__.Instruction object at 0x000002BAD9B431D0>
instructionVal: 0b110100000000100000000000000000
prediction: 0
stateIndex: 0
last: 0
data: 0
done: 1
pc: 0
rs: 0
rt: 2
thisPC: 0
pcp1: 1
nextpc: 1
--------------------
ID_EX: 

flush: 0
instruction: <__main__.Instruction object at 0x000002BAD9A1B7A0>
instructionVal: 0
prediction: 0
stateIndex: 0
last: 0
data: 0
done: 0
pc: 0
readData1: 0
readData2: 0
imm: 0
rs: 0
rt: 0
rd: 0
thisPC: 0
pcp1: 0
destreg: 0
--------------------
EX_MEM: 

flush: 0
instruction: 0
instructionVal: 0
prediction: 0
stateIndex: 0
last: 0
data: 0
done: 0
pc: 0
--------------------
MEM_WB: 

flush: 0
instruction: 0
instructionVal: 0
prediction: 0
stateIndex: 0
last: 0
data: 0
done: 0
pc: 0
--------------------


--------------------------------------------------
Cycle 1
------------------------------------------------