diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index a6aa81be1e402d..529a5bf784f434 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -124,6 +124,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, addRegisterClass(RISCVVMVTs::vint64m4_t, &RISCV::VRM4RegClass); addRegisterClass(RISCVVMVTs::vint64m8_t, &RISCV::VRM8RegClass); + addRegisterClass(RISCVVMVTs::vfloat16mf4_t, &RISCV::VRRegClass); + addRegisterClass(RISCVVMVTs::vfloat16mf2_t, &RISCV::VRRegClass); + addRegisterClass(RISCVVMVTs::vfloat16m1_t, &RISCV::VRRegClass); + addRegisterClass(RISCVVMVTs::vfloat16m2_t, &RISCV::VRM2RegClass); + addRegisterClass(RISCVVMVTs::vfloat16m4_t, &RISCV::VRM4RegClass); + addRegisterClass(RISCVVMVTs::vfloat16m8_t, &RISCV::VRM8RegClass); + addRegisterClass(RISCVVMVTs::vfloat32mf2_t, &RISCV::VRRegClass); addRegisterClass(RISCVVMVTs::vfloat32m1_t, &RISCV::VRRegClass); addRegisterClass(RISCVVMVTs::vfloat32m2_t, &RISCV::VRM2RegClass); diff --git a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td index ea7357c9c073e4..b69cdde6c53269 100644 --- a/llvm/lib/Target/RISCV/RISCVRegisterInfo.td +++ b/llvm/lib/Target/RISCV/RISCVRegisterInfo.td @@ -314,6 +314,13 @@ defvar vint64m2_t = nxv2i64; defvar vint64m4_t = nxv4i64; defvar vint64m8_t = nxv8i64; +defvar vfloat16mf4_t = nxv1f16; +defvar vfloat16mf2_t = nxv2f16; +defvar vfloat16m1_t = nxv4f16; +defvar vfloat16m2_t = nxv8f16; +defvar vfloat16m4_t = nxv16f16; +defvar vfloat16m8_t = nxv32f16; + defvar vfloat32mf2_t = nxv1f32; defvar vfloat32m1_t = nxv2f32; defvar vfloat32m2_t = nxv4f32; @@ -391,6 +398,7 @@ class VReg regTypes, dag regList, int Vlmul> def VR : VReg<[vint8mf2_t, vint8mf4_t, vint8mf8_t, vint16mf2_t, vint16mf4_t, vint32mf2_t, vint8m1_t, vint16m1_t, vint32m1_t, vint64m1_t, + vfloat16mf4_t, vfloat16mf2_t, vfloat16m1_t, vfloat32mf2_t, vfloat32m1_t, vfloat64m1_t, vbool64_t, vbool32_t, vbool16_t, vbool8_t, vbool4_t, vbool2_t, vbool1_t], @@ -401,6 +409,7 @@ def VR : VReg<[vint8mf2_t, vint8mf4_t, vint8mf8_t, def VRNoV0 : VReg<[vint8mf2_t, vint8mf4_t, vint8mf8_t, vint16mf2_t, vint16mf4_t, vint32mf2_t, vint8m1_t, vint16m1_t, vint32m1_t, vint64m1_t, + vfloat16mf4_t, vfloat16mf2_t, vfloat16m1_t, vfloat32mf2_t, vfloat32m1_t, vfloat64m1_t, vbool64_t, vbool32_t, vbool16_t, vbool8_t, vbool4_t, vbool2_t, vbool1_t], @@ -409,29 +418,29 @@ def VRNoV0 : VReg<[vint8mf2_t, vint8mf4_t, vint8mf8_t, (sequence "V%u", 1, 7)), 1>; def VRM2 : VReg<[vint8m2_t, vint16m2_t, vint32m2_t, vint64m2_t, - vfloat32m2_t, vfloat64m2_t], + vfloat16m2_t, vfloat32m2_t, vfloat64m2_t], (add V26M2, V28M2, V30M2, V8M2, V10M2, V12M2, V14M2, V16M2, V18M2, V20M2, V22M2, V24M2, V0M2, V2M2, V4M2, V6M2), 2>; def VRM2NoV0 : VReg<[vint8m2_t, vint16m2_t, vint32m2_t, vint64m2_t, - vfloat32m2_t, vfloat64m2_t], + vfloat16m2_t, vfloat32m2_t, vfloat64m2_t], (add V26M2, V28M2, V30M2, V8M2, V10M2, V12M2, V14M2, V16M2, V18M2, V20M2, V22M2, V24M2, V2M2, V4M2, V6M2), 2>; def VRM4 : VReg<[vint8m4_t, vint16m4_t, vint32m4_t, vint64m4_t, - vfloat32m4_t, vfloat64m4_t], + vfloat16m4_t, vfloat32m4_t, vfloat64m4_t], (add V28M4, V8M4, V12M4, V16M4, V20M4, V24M4, V0M4, V4M4), 4>; def VRM4NoV0 : VReg<[vint8m4_t, vint16m4_t, vint32m4_t, vint64m4_t, - vfloat32m4_t, vfloat64m4_t], + vfloat16m4_t, vfloat32m4_t, vfloat64m4_t], (add V28M4, V8M4, V12M4, V16M4, V20M4, V24M4, V4M4), 4>; def VRM8 : VReg<[vint8m8_t, vint16m8_t, vint32m8_t, vint64m8_t, - vfloat32m8_t, vfloat64m8_t], + vfloat16m8_t, vfloat32m8_t, vfloat64m8_t], (add V8M8, V16M8, V24M8, V0M8), 8>; def VRM8NoV0 : VReg<[vint8m8_t, vint16m8_t, vint32m8_t, vint64m8_t, - vfloat32m8_t, vfloat64m8_t], + vfloat16m8_t, vfloat32m8_t, vfloat64m8_t], (add V8M8, V16M8, V24M8), 8>; defvar VMaskVTs = [vbool64_t, vbool32_t, vbool16_t, vbool8_t,