amd/compiler: split vectors only locally if it is not safe to reuse the result
daniel-schuermann committed Jul 12, 2019
1 parent e711773 commit 9d29e77
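The commit swaps one idiom for another throughout instruction selection. A rough, hedged sketch of the two patterns, distilled from the hunks below rather than copied from the commit, assuming an isel_context *ctx, a Builder bld, and a two-dword Temp src are in scope as in the surrounding file:

   // Old pattern (removed): emit_split_vector() records the split halves in the
   // selection context, and emit_extract_vector() may later reuse that cached
   // result, which is not always safe.
   {
      emit_split_vector(ctx, src, 2);
      Temp lo = emit_extract_vector(ctx, src, 0, v1);
      Temp hi = emit_extract_vector(ctx, src, 1, v1);
   }

   // New pattern (added): split into fresh temporaries locally, right where the
   // halves are consumed, so nothing cached is reused.
   {
      Temp lo = bld.tmp(v1), hi = bld.tmp(v1);
      bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), src);
   }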
Showing 1 changed file with 49 additions and 64 deletions.
113 changes: 49 additions & 64 deletions src/amd/compiler/aco_instruction_selection.cpp
@@ -575,18 +575,13 @@ void emit_bcsel(isel_context *ctx, nir_alu_instr *instr, Temp dst)

bld.vop2(aco_opcode::v_cndmask_b32, Definition(dst), els, then, cond);
} else if (dst.size() == 2) {
- emit_split_vector(ctx, then, 2);
- emit_split_vector(ctx, els, 2);
+ Temp then_lo = bld.tmp(v1), then_hi = bld.tmp(v1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(then_lo), Definition(then_hi), then);
+ Temp else_lo = bld.tmp(v1), else_hi = bld.tmp(v1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(else_lo), Definition(else_hi), els);

- Temp dst0 = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1),
-                      emit_extract_vector(ctx, els, 0, v1),
-                      emit_extract_vector(ctx, then, 0, v1),
-                      cond);
-
- Temp dst1 = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1),
-                      emit_extract_vector(ctx, els, 1, v1),
-                      emit_extract_vector(ctx, then, 1, v1),
-                      cond);
+ Temp dst0 = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), else_lo, then_lo, cond);
+ Temp dst1 = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), else_hi, then_hi, cond);

bld.pseudo(aco_opcode::p_create_vector, Definition(dst), dst0, dst1);
} else {
@@ -993,12 +988,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
}

assert(src0.size() == 2 && src1.size() == 2);
- emit_split_vector(ctx, src0, 2);
- emit_split_vector(ctx, src1, 2);
- Temp src00 = emit_extract_vector(ctx, src0, 0, RegClass(src0.type(), 1));
- Temp src10 = emit_extract_vector(ctx, src1, 0, RegClass(src1.type(), 1));
- Temp src01 = emit_extract_vector(ctx, src0, 1, RegClass(dst.type(), 1));
- Temp src11 = emit_extract_vector(ctx, src1, 1, RegClass(dst.type(), 1));
+ Temp src00 = bld.tmp(src0.type(), 1);
+ Temp src01 = bld.tmp(dst.type(), 1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(src00), Definition(src01), src0);
+ Temp src10 = bld.tmp(src1.type(), 1);
+ Temp src11 = bld.tmp(dst.type(), 1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(src10), Definition(src11), src1);

if (dst.regClass() == s2) {
Temp carry = bld.tmp(s1);
@@ -1061,12 +1056,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
break;
}

- emit_split_vector(ctx, src0, 2);
- emit_split_vector(ctx, src1, 2);
- Temp src00 = emit_extract_vector(ctx, src0, 0, RegClass(src0.type(), 1));
- Temp src10 = emit_extract_vector(ctx, src1, 0, RegClass(src1.type(), 1));
- Temp src01 = emit_extract_vector(ctx, src0, 1, RegClass(dst.type(), 1));
- Temp src11 = emit_extract_vector(ctx, src1, 1, RegClass(dst.type(), 1));
+ Temp src00 = bld.tmp(src0.type(), 1);
+ Temp src01 = bld.tmp(dst.type(), 1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(src00), Definition(src01), src0);
+ Temp src10 = bld.tmp(src1.type(), 1);
+ Temp src11 = bld.tmp(dst.type(), 1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(src10), Definition(src11), src1);
if (dst.regClass() == s2) {
Temp carry = bld.tmp(s1);
bld.sop2(aco_opcode::s_add_u32, bld.def(s1), bld.scc(Definition(carry)), src00, src10);
@@ -1097,12 +1092,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
break;
}

- emit_split_vector(ctx, src0, 2);
- emit_split_vector(ctx, src1, 2);
- Temp src00 = emit_extract_vector(ctx, src0, 0, RegClass(src0.type(), 1));
- Temp src10 = emit_extract_vector(ctx, src1, 0, RegClass(src1.type(), 1));
- Temp src01 = emit_extract_vector(ctx, src0, 1, RegClass(dst.type(), 1));
- Temp src11 = emit_extract_vector(ctx, src1, 1, RegClass(dst.type(), 1));
+ Temp src00 = bld.tmp(src0.type(), 1);
+ Temp src01 = bld.tmp(dst.type(), 1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(src00), Definition(src01), src0);
+ Temp src10 = bld.tmp(src1.type(), 1);
+ Temp src11 = bld.tmp(dst.type(), 1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(src10), Definition(src11), src1);
if (dst.regClass() == s2) {
Temp carry = bld.tmp(s1);
Temp dst0 = bld.sop2(aco_opcode::s_sub_u32, bld.def(s1), bld.scc(Definition(carry)), src00, src10);
@@ -1133,12 +1128,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
break;
}

- emit_split_vector(ctx, src0, 2);
- emit_split_vector(ctx, src1, 2);
- Temp src00 = emit_extract_vector(ctx, src0, 0, RegClass(src0.type(), 1));
- Temp src10 = emit_extract_vector(ctx, src1, 0, RegClass(src1.type(), 1));
- Temp src01 = emit_extract_vector(ctx, src0, 1, RegClass(dst.type(), 1));
- Temp src11 = emit_extract_vector(ctx, src1, 1, RegClass(dst.type(), 1));
+ Temp src00 = bld.tmp(src0.type(), 1);
+ Temp src01 = bld.tmp(dst.type(), 1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(src00), Definition(src01), src0);
+ Temp src10 = bld.tmp(src1.type(), 1);
+ Temp src11 = bld.tmp(dst.type(), 1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(src10), Definition(src11), src1);
if (dst.regClass() == s2) {
Temp borrow = bld.tmp(s1);
bld.sop2(aco_opcode::s_sub_u32, bld.def(s1), bld.scc(Definition(borrow)), src00, src10);
@@ -1668,9 +1663,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
Temp borrow = emit_v_sub32(ctx, exponent, Operand(63u), Operand(exponent), true);
mantissa = bld.vop3(aco_opcode::v_lshrrev_b64, bld.def(v2), exponent, mantissa);
Temp saturate = bld.vop1(aco_opcode::v_bfrev_b32, bld.def(v1), Operand(0xfffffffeu));
- emit_split_vector(ctx, mantissa, 2);
- Temp lower = emit_extract_vector(ctx, mantissa, 0, v1);
- Temp upper = emit_extract_vector(ctx, mantissa, 1, v1);
+ Temp lower = bld.tmp(v1), upper = bld.tmp(v1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(lower), Definition(upper), mantissa);
lower = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), lower, Operand(0xffffffffu), borrow);
upper = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), upper, saturate, borrow);
lower = bld.vop2(aco_opcode::v_xor_b32, bld.def(v1), sign, lower);
@@ -1696,9 +1690,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
Temp cond = bld.sopc(aco_opcode::s_cmp_eq_u32, bld.def(s1, scc), exponent, Operand(0xffffffffu)); // exp >= 64
Temp saturate = bld.sop1(aco_opcode::s_brev_b64, bld.def(s2), Operand(0xfffffffeu));
mantissa = bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), saturate, mantissa, cond);
- emit_split_vector(ctx, mantissa, 2);
- Temp lower = emit_extract_vector(ctx, mantissa, 0, s1);
- Temp upper = emit_extract_vector(ctx, mantissa, 1, s1);
+ Temp lower = bld.tmp(s1), upper = bld.tmp(s1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(lower), Definition(upper), mantissa);
lower = bld.sop2(aco_opcode::s_xor_b32, bld.def(s1), bld.def(s1, scc), sign, lower);
upper = bld.sop2(aco_opcode::s_xor_b32, bld.def(s1), bld.def(s1, scc), sign, upper);
Temp borrow = bld.tmp(s1);
@@ -1726,9 +1719,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
mantissa = bld.pseudo(aco_opcode::p_create_vector, bld.def(v2), Operand(0u), mantissa);
Temp cond_small = emit_v_sub32(ctx, exponent, Operand(exponent), Operand(24u), true);
mantissa = bld.vop3(aco_opcode::v_lshlrev_b64, bld.def(v2), exponent, mantissa);
- emit_split_vector(ctx, mantissa, 2);
- Temp lower = emit_extract_vector(ctx, mantissa, 0, v1);
- Temp upper = emit_extract_vector(ctx, mantissa, 1, v1);
+ Temp lower = bld.tmp(v1), upper = bld.tmp(v1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(lower), Definition(upper), mantissa);
lower = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), lower, small, cond_small);
upper = bld.vop2_e64(aco_opcode::v_cndmask_b32, bld.def(v1), upper, Operand(0u), cond_small);
lower = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v1), Operand(0xffffffffu), lower, exponent_in_range);
@@ -1750,9 +1742,8 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
mantissa = bld.sop2(aco_opcode::s_lshl_b64, bld.def(s2), bld.def(s1, scc), mantissa, exponent_large);
Temp cond = bld.sopc(aco_opcode::s_cmp_ge_i32, bld.def(s1, scc), Operand(64u), exponent);
mantissa = bld.sop2(aco_opcode::s_cselect_b64, bld.def(s2), mantissa, Operand(0xffffffffu), cond);
- emit_split_vector(ctx, mantissa, 2);
- Temp lower = emit_extract_vector(ctx, mantissa, 0, s1);
- Temp upper = emit_extract_vector(ctx, mantissa, 1, s1);
+ Temp lower = bld.tmp(s1), upper = bld.tmp(s1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(lower), Definition(upper), mantissa);
Temp cond_small = bld.sopc(aco_opcode::s_cmp_le_i32, bld.def(s1, scc), exponent, Operand(24u));
lower = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), small, lower, cond_small);
upper = bld.sop2(aco_opcode::s_cselect_b32, bld.def(s1), Operand(0u), upper, cond_small);
@@ -1869,19 +1860,18 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr)
break;
}
case nir_op_unpack_64_2x32_split_x:
- case nir_op_unpack_64_2x32_split_y: {
-    Temp src = get_alu_src(ctx, instr->src[0]);
-    emit_split_vector(ctx, src, 2);
-    emit_extract_vector(ctx, src, instr->op == nir_op_unpack_64_2x32_split_x ? 0 : 1, dst);
+    bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(dst.regClass()), get_alu_src(ctx, instr->src[0]));
break;
+ case nir_op_unpack_64_2x32_split_y:
+    bld.pseudo(aco_opcode::p_split_vector, bld.def(dst.regClass()), Definition(dst), get_alu_src(ctx, instr->src[0]));
+    break;
- }
case nir_op_pack_half_2x16: {
Temp src = get_alu_src(ctx, instr->src[0], 2);
- emit_split_vector(ctx, src, 2);

if (dst.regClass() == v1) {
- Temp src0 = emit_extract_vector(ctx, src, 0, v1);
- Temp src1 = emit_extract_vector(ctx, src, 1, v1);
+ Temp src0 = bld.tmp(v1);
+ Temp src1 = bld.tmp(v1);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(src0), Definition(src1), src);
bld.vop3(aco_opcode::v_cvt_pkrtz_f16_f32, Definition(dst), src0, src1);

} else {
@@ -4820,10 +4810,9 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
}
case nir_intrinsic_load_barycentric_at_offset: {
Temp offset = get_ssa_temp(ctx, instr->src[0].ssa);
- emit_split_vector(ctx, offset, 2);
RegClass rc = RegClass(offset.type(), 1);
- Temp pos1 = emit_extract_vector(ctx, offset, 0, rc);
- Temp pos2 = emit_extract_vector(ctx, offset, 1, rc);
+ Temp pos1 = bld.tmp(rc), pos2 = bld.tmp(rc);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(pos1), Definition(pos2), offset);
emit_interp_center(ctx, get_ssa_temp(ctx, &instr->dest.ssa), pos1, pos2);
break;
}
@@ -5308,11 +5297,10 @@ void visit_intrinsic(isel_context *ctx, nir_intrinsic_instr *instr)
}
case nir_intrinsic_mbcnt_amd: {
Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
- emit_split_vector(ctx, src, 2);
RegClass rc = RegClass(src.type(), 1);
- Temp mask_lo = bld.as_uniform(emit_extract_vector(ctx, src, 0, rc));
+ Temp mask_lo = bld.tmp(rc), mask_hi = bld.tmp(rc);
+ bld.pseudo(aco_opcode::p_split_vector, Definition(mask_lo), Definition(mask_hi), src);
Temp tmp = bld.vop3(aco_opcode::v_mbcnt_lo_u32_b32, bld.def(v1), mask_lo, Operand(0u));
- Temp mask_hi = bld.as_uniform(emit_extract_vector(ctx, src, 1, rc));
Temp dst = get_ssa_temp(ctx, &instr->dest.ssa);
Temp wqm_tmp = bld.vop3(aco_opcode::v_mbcnt_hi_u32_b32, bld.def(v1), mask_hi, tmp);
emit_wqm(ctx, wqm_tmp, dst);
@@ -5445,8 +5433,6 @@ void prepare_cube_coords(isel_context *ctx, Temp* coords, Temp* ddx, Temp* ddy,
{
Builder bld(ctx->program, ctx->block);
Temp coord_args[4], ma, tc, sc, id;
- aco_ptr<Instruction> tmp;
- emit_split_vector(ctx, *coords, is_array ? 4 : 3);
for (unsigned i = 0; i < (is_array ? 4 : 3); i++)
coord_args[i] = emit_extract_vector(ctx, *coords, i, v1);

@@ -5512,7 +5498,6 @@ void prepare_cube_coords(isel_context *ctx, Temp* coords, Temp* ddx, Temp* ddy,
Temp apply_round_slice(isel_context *ctx, Temp coords, unsigned idx)
{
Temp coord_vec[3];
- emit_split_vector(ctx, coords, coords.size());
for (unsigned i = 0; i < coords.size(); i++)
coord_vec[i] = emit_extract_vector(ctx, coords, i, v1);

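One detail from the nir_op_unpack_64_2x32_split hunk above, noted here as a reading of the diff rather than text from the commit: instead of extracting one half of a cached split into dst, the new code has p_split_vector write the wanted half straight into dst and routes the other half to a fresh, otherwise unused definition:

   // split_x: the low dword lands directly in dst, the high dword is discarded
   bld.pseudo(aco_opcode::p_split_vector, Definition(dst), bld.def(dst.regClass()), get_alu_src(ctx, instr->src[0]));
   // split_y: definitions swapped, so dst receives the high dword
   bld.pseudo(aco_opcode::p_split_vector, bld.def(dst.regClass()), Definition(dst), get_alu_src(ctx, instr->src[0]));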
