In [1]:
from enum import Enum

# RISCV-32I 6 种类型
class OpCode(Enum):
    R = "0110011"
    I_JALR = "1100111"
    I_CALC = "0010011"
    I_LOAD = "0000011"
    S = "0100011"
    B = "1100011"
    U_LUI = "0110111"
    U_AUIPC = "0010111"
    J = "1101111"


# 部分 I 型指令
class IFunct3(Enum):
    ADDI = "000"
    SLTI = "010"
    SLTIU = "011"
    XORI = "100"
    ORI = "110"
    ANDI = "111"
    SLLI = "001"
    SRLI = "101"
    SRAI = "101"
    JALR = "000"
    LB = "000"
    LH = "001"
    LW = "010"
    LBU = "100"
    LHU = "101"


# 部分 S 型指令
class SFunct3(Enum):
    SB = "000"  # 本次需要
    SH = "001"
    SW = "010"


# B 型指令
class BFunct3(Enum):
    BEQ = "000"
    BNE = "001"
    BLT = "100"
    BGE = "101"
    BLTU = "110"
    BGEU = "111"


# 部分 R 型指令
class RFunct3(Enum):
    ADD = "000"  # 本次需要
    SUB = "000"
    SLL = "001"
    SLT = "010"
    SLTU = "011"
    XOR = "100"
    SRL = "101"
    SRA = "101"
    OR = "110"
    AND = "111"


class RFunct7(Enum):
    ADD = "0000000"
    SUB = "0100000"
    SRA = "0100000"


class InstructionInfo:
    opcode: OpCode
    rs1: int
    rs2: int
    rd: int
    funct3: Enum
    funct7: Enum
    imm: int


class PipeReg:
    rs1: int
    rs2: int
    value: int
    mem_value: int


class Instruction:
    def __init__(self, isa: "ISA") -> None:
        self.isa = isa
        self.pc_inc = True
        # print(self.__class__.__name__)

    def stage_ex(self):
        """
        EX-执行.对指令的各种操作数进行运算

        IF ID 阶段操作相对固定, EX MEM WB 阶段需要根据具体指令在做调整

        单指令可以继承 Instruction 类并重写此方法
        """

    def stage_mem(self):
        """
        MEM-存储器访问.将数据写入存储器或从存储器中读出数据

        单指令可以继承 Instruction 类并重写此方法
        """

    def stage_wb(self):
        """
        WB-写回.将指令运算结果存入指定的寄存器

        单指令可以继承 Instruction 类并重写此方法
        """
        if self.pc_inc:
            self.isa.pc += 4


class ISA:
    """
    基础处理器架构
    """

    def __init__(self) -> None:
        # 基础配置信息
        register_number = 32
        memory_range = 0x200

        self.pc = 0
        self.registers = [0] * register_number  # 寄存器组
        self.memory = [0] * memory_range  # 内存
        self.instruction: Instruction = None  # 当前指令
        self.instruction_info = InstructionInfo()  # 当前指令的信息拆分
        self.pipeline_register = PipeReg()

    def load_instructions(self, instructions, pc=0x100):
        self.pc = pc
        # 小端存储
        for inst in instructions:
            instruction_str = format(inst, "032b")
            self.memory[pc + 3] = int(instruction_str[:8], 2)
            self.memory[pc + 2] = int(instruction_str[8:16], 2)
            self.memory[pc + 1] = int(instruction_str[16:24], 2)
            self.memory[pc] = int(instruction_str[24:], 2)
            pc += 4

    def run(self):
        while True:
            self.stage_if()
            if self.pc == -1:
                break
            self.stage_id()
            self.stage_ex()
            self.stage_mem()
            self.stage_wb()

    def stage_if(self):
        """
        IF-取指令.根据PC中的地址在指令存储器中取出一条指令
        """
        # 小端取数
        self.instruction = ""
        self.instruction += format(self.memory[self.pc + 3], "08b")
        self.instruction += format(self.memory[self.pc + 2], "08b")
        self.instruction += format(self.memory[self.pc + 1], "08b")
        self.instruction += format(self.memory[self.pc], "08b")

        # 全 0 默认运行完所有指令, 退出
        if int(self.instruction) == 0:
            self.pc = -1

    def stage_id(self):
        """
        ID-译码 解析指令并读取寄存器的值"""
        raise NotImplementedError("should implement stage ID")

    def stage_ex(self):
        """
        EX-执行.对指令的各种操作数进行运算
        """
        self.instruction.stage_ex()

    def stage_mem(self):
        """
        MEM-存储器访问.将数据写入存储器或从存储器中读出数据
        """
        self.instruction.stage_mem()

    def stage_wb(self):
        """
        WB-写回.将指令运算结果存入指定的寄存器
        """
        self.instruction.stage_wb()

    def show_info(self, info=None):
        mem_range = 5
        register_range = 15

        if info is not None:
            print(info)
        print("#" * 20)
        for i in range(mem_range):
            print(f"mem[{i}] = {self.memory[i]}")
        print("#" * 20)
        for i in range(register_range):
            print(f"r{i} = {self.registers[i]}")
        print("#" * 20)

    def binary_str(self, imm: str):
        if imm[0] == "1":
            inverted_str = "".join("1" if bit == "0" else "0" for bit in imm)
            abs_value = int(inverted_str, 2) + 1
            return -abs_value
        else:
            return int(imm, 2)


In [2]:
# RISCV 32I 指令集介绍可以参考
# https://www.sunnychen.top/archives/riscvbasic

class R_ADD(Instruction):
    def stage_ex(self):
        self.isa.pipeline_register.value = (
            self.isa.pipeline_register.rs1 + self.isa.pipeline_register.rs2
        )

    def stage_wb(self):
        self.isa.registers[self.isa.instruction_info.rd] = self.isa.pipeline_register.value
        return super().stage_wb()


class R_SUB(Instruction):
    ...


class R_SLL(Instruction):
    ...


class R_SLT(Instruction):
    ...


class R_SLTU(Instruction):
    ...


class R_XOR(Instruction):
    def stage_ex(self):
        self.isa.pipeline_register.value = (
            self.isa.pipeline_register.rs1 ^ self.isa.pipeline_register.rs2
        )

    def stage_wb(self):
        self.isa.registers[self.isa.instruction_info.rd] = self.isa.pipeline_register.value
        return super().stage_wb()


class R_SRL(Instruction):
    ...


class R_SRA(Instruction):
    ...


class R_OR(Instruction):
    ...


class R_AND(Instruction):
    ...


class I_ADDI(Instruction):
    def stage_ex(self):
        self.isa.pipeline_register.value = (
            self.isa.pipeline_register.rs1 + self.isa.instruction_info.imm
        )

    def stage_wb(self):
        self.isa.registers[self.isa.instruction_info.rd] = self.isa.pipeline_register.value
        return super().stage_wb()


class I_SLTI(Instruction):
    ...


class I_SLTIU(Instruction):
    ...


class I_XORI(Instruction):
    ...


class I_ORI(Instruction):
    ...


class I_ANDI(Instruction):
    ...


class I_SLLI(Instruction):
    ...


class I_SRLI(Instruction):
    ...


class I_SRAI(Instruction):
    ...


class I_JALR(Instruction):
    ...


class I_LB(Instruction):
    def stage_ex(self):
        self.isa.pipeline_register.value = (
            self.isa.pipeline_register.rs1 + self.isa.instruction_info.imm
        )

    def stage_mem(self):
        self.isa.pipeline_register.mem_value = self.isa.memory[self.isa.pipeline_register.value]

    def stage_wb(self):
        self.isa.registers[self.isa.instruction_info.rd] = self.isa.pipeline_register.mem_value
        return super().stage_wb()


class I_LH(Instruction):
    ...


class I_LW(Instruction):
    ...


class I_LBU(Instruction):
    ...


class I_LHU(Instruction):
    ...


class S_SB(Instruction):
    def stage_ex(self):
        self.isa.pipeline_register.value = (
            self.isa.pipeline_register.rs1 + self.isa.instruction_info.imm
        )

    def stage_wb(self):
        self.isa.memory[self.isa.pipeline_register.value] = self.isa.pipeline_register.rs2 & 0xFF
        return super().stage_wb()


class S_SH(Instruction):
    ...


class S_SW(Instruction):
    ...


class B_BEQ(Instruction):
    ...


class B_BNE(Instruction):
    def stage_ex(self):
        if self.isa.pipeline_register.rs1 != self.isa.pipeline_register.rs2:
            self.isa.pc = self.isa.pc + self.isa.instruction_info.imm
            self.pc_inc = False


class B_BLT(Instruction):
    ...


class B_BGE(Instruction):
    ...


class B_BLTU(Instruction):
    ...


class B_BGEU(Instruction):
    ...


class U_LUI(Instruction):
    ...


class U_AUIPC(Instruction):
    ...


class J_JAL(Instruction):
    def stage_ex(self):
        self.isa.pc += self.isa.instruction_info.imm
        self.pc_inc = False


In [3]:

class RISCV32(ISA):
    """
    RISCV 32I 单周期五阶段
    """

    def stage_id(self):
        opcode = self.instruction[-7:]
        opcode_type = OpCode(opcode)
        self.instruction_info.opcode = opcode_type
        if opcode_type == OpCode.R:
            self.instruction_info.funct7 = RFunct7(self.instruction[:7])
            self.instruction_info.rs2 = int(self.instruction[7:12], 2)
            self.instruction_info.rs1 = int(self.instruction[12:17], 2)
            self.instruction_info.funct3 = RFunct3(self.instruction[17:20])
            self.instruction_info.rd = int(self.instruction[20:25], 2)
            self.instruction_info.imm = None
        elif opcode_type in (OpCode.I_LOAD, OpCode.I_CALC, OpCode.I_JALR):
            self.instruction_info.funct7 = None
            self.instruction_info.rs2 = None
            self.instruction_info.rs1 = int(self.instruction[12:17], 2)
            self.instruction_info.funct3 = IFunct3(self.instruction[17:20])
            self.instruction_info.rd = int(self.instruction[20:25], 2)
            self.instruction_info.imm = self.binary_str(self.instruction[:12])
        elif opcode_type == OpCode.S:
            self.instruction_info.funct7 = None
            self.instruction_info.rs2 = int(self.instruction[7:12], 2)
            self.instruction_info.rs1 = int(self.instruction[12:17], 2)
            self.instruction_info.funct3 = SFunct3(self.instruction[17:20])
            self.instruction_info.rd = None
            self.instruction_info.imm = self.binary_str(self.instruction[:7] + self.instruction[20:25])
        elif opcode_type == OpCode.B:
            self.instruction_info.funct7 = None
            self.instruction_info.rs2 = int(self.instruction[7:12], 2)
            self.instruction_info.rs1 = int(self.instruction[12:17], 2)
            self.instruction_info.funct3 = BFunct3(self.instruction[17:20])
            self.instruction_info.rd = None
            self.instruction_info.imm = self.binary_str(
                self.instruction[0]
                + self.instruction[24]
                + self.instruction[1:6]
                + self.instruction[20:24]
                + "0"
            )

        elif opcode_type in (OpCode.U_AUIPC, OpCode.U_LUI):
            self.instruction_info.funct7 = None
            self.instruction_info.rs2 = None
            self.instruction_info.rs1 = None
            self.instruction_info.funct3 = None
            self.instruction_info.rd = int(self.instruction[20:25], 2)
            self.instruction_info.imm = self.binary_str(self.instruction[:20] + "0" * 12)
        elif opcode_type == OpCode.J:
            self.instruction_info.funct7 = None
            self.instruction_info.rs2 = None
            self.instruction_info.rs1 = None
            self.instruction_info.funct3 = None
            self.instruction_info.rd = int(self.instruction[20:25], 2)
            self.instruction_info.imm = self.binary_str(
                self.instruction[0]
                + self.instruction[11:18]
                + self.instruction[10]
                + self.instruction[1:11]
                + "0"
            )
        else:
            raise ValueError("unknown opcode type")

        # 如果指令中包含 rs1 rs2, 则读取对应寄存器的值
        if self.instruction_info.rs1 is not None:
            self.pipeline_register.rs1 = self.registers[self.instruction_info.rs1]

        if self.instruction_info.rs2 is not None:
            self.pipeline_register.rs2 = self.registers[self.instruction_info.rs2]

        self.match_instruction()

    def match_instruction(self):
        """
        通过 ISA 在 ID 阶段解析指令得到的信息, 定位找到具体的指令
        """
        RISCV_32I_instructions = {
            OpCode.R: {
                RFunct3.ADD: R_ADD,
                # RFunct3.SUB: R_SUB,
                RFunct3.SLL: R_SLL,
                RFunct3.SLT: R_SLT,
                RFunct3.SLTU: R_SLTU,
                RFunct3.XOR: R_XOR,
                RFunct3.SRL: R_SRL,
                RFunct3.SRA: R_SRA,
                RFunct3.OR: R_OR,
                RFunct3.AND: R_AND,
            },
            OpCode.I_CALC: {
                IFunct3.ADDI: I_ADDI,
                IFunct3.SLTI: I_SLTI,
                IFunct3.SLTIU: I_SLTIU,
                IFunct3.XORI: I_XORI,
                IFunct3.ORI: I_ORI,
                IFunct3.ANDI: I_ANDI,
                IFunct3.SLLI: I_SLLI,
                IFunct3.SRLI: I_SRLI,
                IFunct3.SRAI: I_SRAI,
            },
            OpCode.I_JALR: {IFunct3.JALR: I_JALR},
            OpCode.I_LOAD: {
                IFunct3.LB: I_LB,
                IFunct3.LH: I_LH,
                IFunct3.LW: I_LW,
                IFunct3.LBU: I_LBU,
                IFunct3.LHU: I_LHU,
            },
            OpCode.S: {SFunct3.SB: S_SB, SFunct3.SH: S_SH, SFunct3.SW: S_SW},
            OpCode.B: {
                BFunct3.BEQ: B_BEQ,
                BFunct3.BNE: B_BNE,
                BFunct3.BLT: B_BLT,
                BFunct3.BGE: B_BGE,
                BFunct3.BLTU: B_BLTU,
                BFunct3.BGEU: B_BGEU,
            },
        }

        if (
            self.instruction_info.funct3 == RFunct3.SUB
            and self.instruction_info.funct7 == RFunct7.SUB
        ):
            self.instruction = R_SUB(self)
        elif (
            self.instruction_info.funct3 == RFunct3.SRA
            and self.instruction_info.funct7 == RFunct7.SRA
        ):
            self.instruction = R_SRA(self)
        elif self.instruction_info.opcode == OpCode.U_AUIPC:
            self.instruction = U_AUIPC(self)
        elif self.instruction_info.opcode == OpCode.U_LUI:
            self.instruction = U_LUI(self)
        elif self.instruction_info.opcode == OpCode.J:
            self.instruction = J_JAL(self)
        else:
            self.instruction = RISCV_32I_instructions[self.instruction_info.opcode][
                self.instruction_info.funct3
            ](self)


def main():
    # 汇编代码见 example.S

    #     xor a0, a0, a0
    #     lb a1, 0(a0)
    #     lb a2, 1(a0)
    # L1:
    #     addi a1, a1, 1
    #     addi a2, a2, 3
    #     bne a1, a2, L1
    #     jal a4, L2
    # L2:
    #     sb a2, 3(a0)

    # 编译为 32 位 RISCV 目标文件

    # riscv64-linux-gnu-gcc -march=rv32i -mabi=ilp32 -c example.S -o example.o
    # riscv64-linux-gnu-objdump example.o -d

    instructions = [
        0x00A54533,
        0x00050583,
        0x00150603,
        0x00360613,
        0x00158593,
        0xFEC59CE3,
        0x0040076F,
        0x00C501A3,
    ]

    isa = RISCV32()
    isa.memory[0] = 20
    isa.memory[1] = 0
    isa.show_info("before")

    isa.load_instructions(instructions)
    isa.run()
    isa.show_info("after")


if __name__ == "__main__":
    main()


before
####################
mem[0] = 20
mem[1] = 0
mem[2] = 0
mem[3] = 0
mem[4] = 0
####################
r0 = 0
r1 = 0
r2 = 0
r3 = 0
r4 = 0
r5 = 0
r6 = 0
r7 = 0
r8 = 0
r9 = 0
r10 = 0
r11 = 0
r12 = 0
r13 = 0
r14 = 0
####################
after
####################
mem[0] = 20
mem[1] = 0
mem[2] = 0
mem[3] = 30
mem[4] = 0
####################
r0 = 0
r1 = 0
r2 = 0
r3 = 0
r4 = 0
r5 = 0
r6 = 0
r7 = 0
r8 = 0
r9 = 0
r10 = 0
r11 = 30
r12 = 30
r13 = 0
r14 = 0
####################
