#include "wt_data.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <iterator>
#include <map>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "backend_utils.h"
#include "cmd_node.h"
#include "common/basic.h"
#include "common/hack.h"
#include "common/logging.h"
#include "config.h"
#include "hw_utils_basic.h"

// NOTE(review): template argument types in this file (e.g. vector<float> for raw
// float weights, vector<fp_t> for converted weights, tuple<int, int, int, int>)
// follow how the values are used below -- confirm against the declarations in
// wt_data.h.

#define DEF_LOG_CATEG LogCateg::Weight

#define WT_PRELU_BITW 16
#define WT_BIAS_BITW 16
#define WT_4B_MAX 7
#define WT_4B_MIN (-8)
#define DEF_PCONV_BIAS 0
/* Per-output-channel value from `vec`, or 0 when `vec` is empty. */
#define GET_PCONV_VAL(vec, out_ch) ((vec).empty() ? 0 : (vec).at(out_ch))

using namespace std;

// NOTE(review): BlockType::LUT / CONV_NCORE_SWITCH / PCONV_NCORE_SWITCH are used
// in this file but have no entry here; confirm Enum2Str() handles missing keys
// before dumping such blocks.
template <>
const std::map<BlockType, std::string> &EnumStrMap<BlockType>() {
  static const map<BlockType, string> str_map = {
      {BlockType::UNKNOWN, "unknown"},
      {BlockType::CONV, "conv"},
      {BlockType::PCONV, "pconv"},
      {BlockType::PFUNC, "pfunc"},
  };
  return str_map;
}

template <>
const std::map<WtSeg, std::string> &EnumStrMap<WtSeg>() {
  static const map<WtSeg, string> str_map = {
      {WtSeg::NONE, "none"},     {WtSeg::WHOLE_4b, "4b"},  {WtSeg::WHOLE_8b, "8b"},
      {WtSeg::WHOLE_16b, "16b"}, {WtSeg::LL_16b, "16bLL"}, {WtSeg::LH_16b, "16bLH"},
      {WtSeg::HL_16b, "16bHL"},  {WtSeg::HH_16b, "16bHH"},
  };
  return str_map;
}

/* Kernel positions as (row, col), listed in the order their weights are written
 * into a WtBlock by GenConv3x3Wt(). */
using KrPosMap = vector<pair<int, int>>;

/* 6 3 0
 * 7 4 1
 * 8 5 2 */
static const KrPosMap WT_3X3_ORDER = {
    {0, 2}, {1, 2}, {2, 2}, {0, 1}, {1, 1}, {2, 1}, {0, 0}, {1, 0}, {2, 0},
};

/* 2 1 0 */
static const KrPosMap WT_1X3_ORDER = {{0, 2}, {0, 1}, {0, 0}};

/* ###########################
 * ##    Local Functions    ##
 * ########################### */
namespace {

/** True if `row` lies in [GetInRowStr(), GetInRowStr() + shape.row). */
inline bool ValidInRow(int row, const CommandNode &cnode, const Shape3D &shape) {
  int in_str_row = cnode.GetInRowStr();
  return (in_str_row <= row && row < in_str_row + (int)shape.row);
}

/**
 * True if `ch` lies in the input-channel window of `cnode`.
 * Modes that skip in-channel group expansion always start the window at 0.
 */
inline bool ValidInCh(int ch, const CommandNode &cnode, const Shape3D &shape) {
  int ich_str = SkipWtInChnlGrpExp(cnode.conv_mode) ? 0 : cnode.GetInChStr();
  return (ich_str <= ch && ch < ich_str + (int)shape.channel);
}

/** True if a (row, col) kernel position falls inside a ksize_h x ksize_w kernel. */
template <typename T>
inline bool ValidKrWtPos(const T &pos, int ksize_h, int ksize_w) {
  return (pos.first < ksize_h) && (pos.second < ksize_w);
}

/**
 * Flat index into a feature-major FC weight array laid out as
 * [feat][chnl][row][col]. Arithmetic is done in size_t to avoid int overflow.
 */
inline size_t GetFcRawWtIdx(int r, int c, int ch, int f, int row, int col, int chnl, int feat) {
  assert((r < row) && (c < col) && (ch < chnl) && (f < feat));
  return (static_cast<size_t>(f) * row * col * chnl) + (static_cast<size_t>(ch) * row * col) +
         (static_cast<size_t>(r) * col) + c;
}

/** Bounds-checked read of one FC weight; see GetFcRawWtIdx() for the layout. */
inline fp_t GetFcRawWt(const vector<fp_t> &raw_wt, int r, int c, int ch, int f, int row, int col,
                       int chnl, int feat) {
  return raw_wt.at(GetFcRawWtIdx(r, c, ch, f, row, col, chnl, feat));
}

/**
 * Fold batch-norm parameters into a per-channel scale and bias.
 *
 * For each channel i of `gamma`, appends
 *   scale = GetConvertedScale(epsilon, gamma[i], var[i])   to update_gamma_w
 *   bias  = GetConvertedBias(beta[i], mean[i], scale)       to update_beta_w
 * Outputs are appended, not overwritten.
 */
void ConvertBn(const vector<float> *gamma, const vector<float> *beta, const vector<float> *mean,
               const vector<float> *var, vector<float> &update_gamma_w,
               vector<float> &update_beta_w, float epsilon) {
  assert(gamma && beta && mean && var);
  const size_t n = gamma->size();
  assert(beta->size() >= n && mean->size() >= n && var->size() >= n);

  update_gamma_w.reserve(update_gamma_w.size() + n);
  update_beta_w.reserve(update_beta_w.size() + n);
  for (size_t i = 0; i < n; ++i) {
    const float converted_scale =
        BatchNormalizationNode::GetConvertedScale(epsilon, (*gamma)[i], (*var)[i]);
    update_gamma_w.push_back(converted_scale);
    update_beta_w.push_back(
        BatchNormalizationNode::GetConvertedBias((*beta)[i], (*mean)[i], converted_scale));
  }
}

/** Kernel count used for weight packing; non-CONV2D layers always use (1, 1). */
inline pair<int, int> GetWtKernelCnt(const CommandNode &cnode) {
  if (cnode.conv_type != ConvolutionType::CONV2D) {
    return {1, 1};
  }
  return GetKernelCnt(cnode);
}

/**
 * Radix for output channel `out_ch`, or 0 when `radix` has no entry for it.
 * 4-bit weights subtract 4 from the stored radix.
 */
template <typename RadixVec>
inline int GetWtRadix(const RadixVec &radix, int out_ch, int wt_bitw) {
  const int frac_bit_shift = (wt_bitw == 4) ? 4 : 0;
  if (out_ch >= (int)radix.size()) {
    return 0;
  }
  return radix.at(out_ch) - frac_bit_shift;
}

/** FC weights need re-ordering unless the node has `trans_b` set. */
inline bool NeedTransFcWeight(const CommandNode &cnode) { return !cnode.trans_b; }

/**
 * @brief Re-order raw FC weight based on re-interpreted input FMAP.
 *
 * With cnt = in.row * in.column * in.channel and feat = out.channel (both from
 * the full shapes), the result satisfies out[f * cnt + i] == wt[i * feat + f].
 * The result has wt.size() entries; any entries beyond feat * cnt stay zero.
 */
vector<float> TransFcWeight(const CommandNode &cnode, const vector<float> &wt) {
  const auto &s = cnode.full_in_shape;
  const size_t feat = cnode.full_out_shape.channel;
  const size_t cnt = s.row * s.column * s.channel;
  // A smaller source would make the writes below run past the end of new_wt.
  assert(cnt * feat <= wt.size());

  vector<float> new_wt(wt.size());
  size_t idx = 0;
  for (size_t f = 0; f < feat; ++f) {
    for (size_t i = 0; i < cnt; ++i) {
      new_wt[idx++] = wt[i * feat + f];
    }
  }
  return new_wt;
}

/** [start, end) of the input channels a conv weight block iterates over. */
pair<int, int> GetConvWtIchRange(const WtBlock &block) {
  const auto &info = block.info;
  if (SkipWtInChnlGrpExp(info.conv_mode)) {
    return {0, info.ich_per_grp};
  }
  return {block.ich_str, block.ich_end};
}

/**
 * Kernel-position order for the kernel-unit modes that GenConv3x3Wt() handles.
 * Throws std::out_of_range (via map::at) for any other mode.
 */
inline const KrPosMap &GetKrUnitPosMap(CONV_MODE mode) {
  static const map<CONV_MODE, KrPosMap> pos_map = {
      {CONV_MODE::CM_3x3_DW, WT_3X3_ORDER},
      {CONV_MODE::CM_3x3_1CH, WT_3X3_ORDER},
      {CONV_MODE::CM_1x3_4CH, WT_1X3_ORDER},
  };
  return pos_map.at(mode);
}

/**
 * vals[ch] if present; otherwise vals[0] (broadcast of a single value);
 * otherwise `def_val` when `vals` is empty.
 */
template <typename Vec>
inline fp_t GetValFromVec(const Vec &vals, int ch, int def_val = 0) {
  if (ch < (int)vals.size()) {
    return vals[ch];
  }
  if (!vals.empty()) {
    return vals[0];
  }
  return def_val;
}

}  // namespace

/* ######################
 * ##    WtPackInfo    ##
 * ###################### */
WtPackInfo::WtPackInfo(const CommandNode &cnode) {
  conv_mode = cnode.conv_mode;
  pfunc_mode = cnode.pfunc_mode;
  pool_mode = cnode.pool_mode;
  pool_type = cnode.pool_type;
  relu_type = cnode.relu_type;
  std::tie(ich_grp, ich_per_grp, och_grp, och_per_grp) = WtPackInfo::GetInOutChnlGrp(
      cnode, cnode.in_shape.channel, cnode.out_shape.channel);
  kcnt = GetWtKernelCnt(cnode);
}

/** Human-readable summary of the channel grouping. */
string WtPackInfo::Dump() const {
  // {fmt} uses "{}" placeholders; printf-style "%d" would be printed verbatim.
  return fmt::format("och_per_grp: {} (x{}), ich_per_grp: {} (x{})", och_per_grp, och_grp,
                     ich_per_grp, ich_grp);
}

/**
 * Split input/output channels into groups.
 *
 * @param cnode     node whose conv/pfunc mode selects the group sizes
 * @param in_chnl   number of input channels to group
 * @param out_chnl  number of output channels to group
 * @return {ich_grp, ich_per_grp, och_grp, och_per_grp}, where *_grp is
 *         ceil(channels / *_per_grp).
 */
std::tuple<int, int, int, int> WtPackInfo::GetInOutChnlGrp(const CommandNode &cnode, int in_chnl,
                                                           int out_chnl) {
  auto conv_mode = cnode.conv_mode;
  auto pfunc_mode = cnode.pfunc_mode;
  ConvModeParam p = GetConvModeParam(conv_mode);

  int ich_grp = 0;
  int och_grp = 0;
  int ich_per_grp = 0;
  int och_per_grp = 0;

  if (cnode.IsPFuncNode()) {
    ich_per_grp = 1;
    ich_grp = 1;
    och_per_grp = GetPFuncOutChUnit(pfunc_mode);
    // Honour the caller's out_chnl, as the conv branch below does. The
    // constructor passes cnode.out_shape.channel, so its result is unchanged.
    och_grp = static_cast<int>(ceil(out_chnl / (double)och_per_grp));
  } else {
    switch (conv_mode) {
      case CONV_MODE::CM_BYPASS_FM:
      case CONV_MODE::CM_3x3_DW:
      case CONV_MODE::CM_BATCH_NORM:
      case CONV_MODE::CM_PRODUCT:
      case CONV_MODE::CM_SQUARE:
      case CONV_MODE::CM_UP_SAMPLE_H:
      case CONV_MODE::CM_ADD:
        ich_per_grp = 1;
        ich_grp = 1;
        break;
      default:
        ich_per_grp = p.in_ch;
        ich_grp = static_cast<int>(ceil(in_chnl / (double)ich_per_grp));
        break;
    }
    assert(cnode.GetOutChStr() % p.out_ch == 0);
    och_per_grp = p.out_ch;
    och_grp = static_cast<int>(ceil(out_chnl / (double)och_per_grp));
  }
  return {ich_grp, ich_per_grp, och_grp, och_per_grp};
}

/* #################
 * ##    WtRaw    ##
 * ################# */

/**
 * Point this WtRaw at the float weights of `cnode.wt_list[layer_index]` for
 * the given operation type.
 *
 * - Conv / Gemm: node 0 -> wt, node 1 -> bias.
 * - BatchNormalization: nodes 0..3 are gamma, beta, mean, var; they are folded
 *   into bn_gamma / bn_beta, which then become wt / bias.
 * - PRelu: the node -> relu_w.
 * - LeakyRelu: relu_w points at a vector filled with cnode.relu_coeff.
 */
void WtRaw::Update(OperationType op_type, uint32_t layer_index, const CommandNode &cnode) {
  float epsilon = 1e-3;
  if (op_type == OperationType::LeakyRelu) {
    /* to make later code aligns to prelu case, add a vector of same values here */
    const Shape3D &out_shape = cnode.full_out_shape;
    size_t size = out_shape.row * out_shape.column * out_shape.channel;
    float val = cnode.relu_coeff;
    // assign() (not resize()) so a previously filled vector is fully refreshed.
    this->leaky_relu_coef.assign(size, val);
    this->relu_w = &leaky_relu_coef;
  } else if (op_type == OperationType::BatchNormalization) {
    epsilon = cnode.bn_eplison;
  }

  const vector<float> *gamma = nullptr;
  const vector<float> *beta = nullptr;
  const vector<float> *mean = nullptr;
  const vector<float> *var = nullptr;

  const auto &wnodes = cnode.wt_list[layer_index];
  for (size_t i = 0; i < wnodes.size(); i++) {
    auto *wnode = wnodes[i];
    // TODO: check because there should not
    // have deconv after hw modification
    assert(op_type != OperationType::ConvTranspose);

    if (op_type == OperationType::Conv || op_type == OperationType::Gemm) {
      if (i == 0) {
        this->wt = &(wnode->value);
      } else if (i == 1) {
        this->bias = &(wnode->value);
      }
    } else if (op_type == OperationType::BatchNormalization) {
      if (i == 0) {
        gamma = &(wnode->value);
      } else if (i == 1) {
        beta = &(wnode->value);
      } else if (i == 2) {
        mean = &(wnode->value);
      } else if (i == 3) {
        var = &(wnode->value);
        ConvertBn(gamma, beta, mean, var, this->bn_gamma, this->bn_beta, epsilon);
        this->wt = &this->bn_gamma;
        this->bias = &this->bn_beta;
      }
    } else if (op_type == OperationType::PRelu) {
      this->relu_w = &(wnode->value);
    }
  }
}

// NOTE(review): as written, ConvertFpN expands to nothing, which would make the
// `ConvertFpN(...)` calls in WtRaw::Convert below ill-formed -- this looks like a
// placeholder for the real float -> fp_t conversion; confirm.
// kkk
#define ConvertFpN(...)
// kkk

/**
 * @brief convert weight from float to fp_t
 *
 * For every output channel of cnode.full_out_shape, appends to relu_fp,
 * bias_fp and wt_fp (each only when the matching source pointer is set),
 * using the per-channel radix from cnode.fp_w. Dense (FC) weights are first
 * re-ordered with TransFcWeight() when NeedTransFcWeight() says so.
 */
void WtRaw::Convert(const CommandNode &cnode) {
  const auto &fp_w = cnode.fp_w;
  const CONV_MODE mode = cnode.conv_mode;
  const int wt_bitw = cnode.GetWtBitw();
  const int out_ch = cnode.full_out_shape.channel;

  const vector<float> *wt = this->wt;
  // Owns the re-ordered FC weights (if any) for the duration of this call.
  vector<float> transposed_wt;
  const bool is_dense = (mode == CONV_MODE::CM_DENSE || mode == CONV_MODE::CM_DENSE_4B);
  if (wt && is_dense && NeedTransFcWeight(cnode)) {
    transposed_wt = TransFcWeight(cnode, *wt);
    wt = &transposed_wt;
  }

  for (int ch = 0; ch < out_ch; ++ch) {
    /* relu: relu_w may hold fewer entries than channels; scale the index down */
    if (relu_w) {
      const size_t idx = ch * relu_w->size() / out_ch;
      const int radix = GetWtRadix(fp_w.radix.relu, ch, WT_PRELU_BITW);
      fp_t val = ConvertFpN("relu", relu_w->at(idx), WT_PRELU_BITW, radix);
      relu_fp.push_back(val);
    }
    /* bias */
    if (bias) {
      float val_f = bias->at(ch);
      const int radix = GetWtRadix(fp_w.radix.bias, ch, WT_BIAS_BITW);
      fp_t val = ConvertFpN("bias", val_f, WT_BIAS_BITW, radix);
      bias_fp.push_back(val);
    }
    /* wt: each channel owns a contiguous slice of wt->size() / out_ch values */
    if (wt) {
      const size_t per_ch = wt->size() / out_ch;
      const int radix = GetWtRadix(fp_w.radix.conv, ch, wt_bitw);
      for (size_t k = 0, idx = ch * per_ch; k < per_ch; ++k, ++idx) {
        fp_t val = ConvertFpN("conv", wt->at(idx), wt_bitw, radix);
        wt_fp.push_back(val);
      }
    }
  }
}

/* ##################
 * ##    WtPack    ##
 * ################## */
WtPack::WtPack(const CommandNode &cnode) : list(), info(cnode) {}

/* #####################
 * ##    WtPackOch    ##
 * ##################### */

/** Move every element of `other` to the end of this list (other becomes empty). */
void WtPackOch::Merge(WtPackOch &other) { splice(end(), other); }

/* ###################
 * ##    WtBlock    ##
 * ################### */

/**
 * Weight block covering output channels [och_str, och_end) and input channels
 * [ich_str, ich_end). The weight segment (wt_seg) is chosen from the node's
 * weight/input bit widths; LUT blocks override it to WHOLE_8b / WHOLE_16b.
 */
WtBlock::WtBlock(BlockType type, const WtPackInfo &info, const CommandNode &cnode, int och_str,
                 int och_end, int ich_str, int ich_end)
    : vector(),
      type(type),
      och_str(och_str),
      och_end(och_end),
      ich_str(ich_str),
      ich_end(ich_end),
      info(info),
      fp_w(&cnode.fp_w) {
  /* wt_seg */
  const int wt_bitw = cnode.GetWtBitw();
  const int in_bitw = cnode.GetInBitw();
  if (wt_bitw == 4) {
    assert(cnode.conv_mode == CONV_MODE::CM_1x1_4B || cnode.conv_mode == CONV_MODE::CM_DENSE_4B);
    wt_seg = WtSeg::WHOLE_4b;
  } else if (in_bitw == 8) {
    wt_seg = WtSeg::WHOLE_8b;
  } else {
    assert(in_bitw == 15 || in_bitw == 16);
    wt_seg = SupportDirect16b(info.conv_mode) ? WtSeg::WHOLE_16b : WtSeg::LL_16b;
  }

  if (type == BlockType::LUT) {
    wt_seg = (in_bitw == 8) ? WtSeg::WHOLE_8b : WtSeg::WHOLE_16b;
  }
}

/** Marker block carrying only an ncore-switch value. */
WtBlock::WtBlock(BlockType type, int ncore_switch) : type(type), ncore_switch(ncore_switch) {
  assert(type == BlockType::CONV_NCORE_SWITCH || type == BlockType::PCONV_NCORE_SWITCH);
}

/** Widen the input-channel range to cover `other` and move its weights onto this block. */
void WtBlock::Merge(WtBlock &other) {
  ich_str = min(ich_str, other.ich_str);
  ich_end = max(ich_end, other.ich_end);
  reserve(size() + other.size());
  std::move(other.begin(), other.end(), back_inserter(*this));
}

/** True when the kernel is larger than one kernel unit in either dimension. */
inline bool IsBigKr(int ksize_h, int ksize_w, int kunit_h, int kunit_w) {
  return (kunit_h != ksize_h) || (ksize_w != kunit_w);
}

/** Input-channel stride in the raw weight array (1 for depth-wise 3x3). */
inline int GetWtBlkInChnlStep(CONV_MODE mode, int in_ch_before_cut) {
  return (mode == CONV_MODE::CM_3x3_DW) ? 1 : in_ch_before_cut;
}

/**
 * Fill this block with kernel-unit (3x3 / 1x3) weights in GetKrUnitPosMap()
 * order. For big kernels, (k_r, k_c) selects which kernel unit to emit;
 * positions outside the real kernel and invalid input channels are skipped,
 * leaving their slots untouched.
 */
void WtBlock::GenConv3x3Wt(const WtRaw &raw, const CommandNode &cnode, int k_r, int k_c) {
  const int ksize_h = cnode.conv_kernel_size_h;
  const int ksize_w = cnode.conv_kernel_size_w;
  const auto &kr_pos_map = GetKrUnitPosMap(cnode.conv_mode);
  auto [kr_unit_h, kr_unit_w] = GetKernelUnit(cnode.conv_mode);
  const size_t wt_per_kr = kr_unit_h * kr_unit_w;
  const bool bigk = IsBigKr(ksize_h, ksize_w, kr_unit_h, kr_unit_w);
  const int ich_step = GetWtBlkInChnlStep(info.conv_mode, cnode.full_in_shape.channel);
  const auto [blk_ich_str, blk_ich_end] = GetConvWtIchRange(*this);
  const auto &raw_wt = raw.wt_fp;

  int idx = 0;
  for (int out_ch = och_str; out_ch < och_end; ++out_ch) {
    if (!ValidOutCh(out_ch, cnode)) {
      break;
    }
    for (int in_ch = blk_ich_str; in_ch < blk_ich_end; ++in_ch) {
      if (!ValidInCh(in_ch, cnode, cnode.in_shape)) {
        idx += wt_per_kr;
        continue;
      }
      const int raw_idx_base = (out_ch * ich_step + in_ch) * ksize_h * ksize_w;
      for (auto pos : kr_pos_map) {  // copied: shifted below for big kernels
        if (bigk) {
          pos.first += k_r * kr_unit_h;
          pos.second += k_c * kr_unit_w;
          if (!ValidKrWtPos(pos, ksize_h, ksize_w)) {
            idx++;
            continue;
          }
        }
        const int raw_idx = raw_idx_base + pos.first * ksize_w + pos.second;
        this->at(idx++) = raw_wt.at(raw_idx);
      }
    }
  }
}

/**
 * Fill this block with the (k_r, k_c) tap of every (out_ch, in_ch) pair.
 * In 4-bit mode input channels are remapped by GetEffectiveInCh().
 */
void WtBlock::GenConv1x1Wt(const WtRaw &raw, const CommandNode &cnode, int k_r, int k_c) {
  const int ksize_h = cnode.conv_kernel_size_h;
  const int ksize_w = cnode.conv_kernel_size_w;
  const bool wt4b = (cnode.GetWtBitw() == 4);
  const int ich_step = GetWtBlkInChnlStep(info.conv_mode, cnode.full_in_shape.channel);
  const auto &raw_wt = raw.wt_fp;

  int idx = 0;
  for (int out_ch = och_str; out_ch < och_end; ++out_ch) {
    if (!ValidOutCh(out_ch, cnode)) {
      break;
    }
    for (int in_ch = ich_str; in_ch < ich_end; ++in_ch) {
      const int ich_idx = GetEffectiveInCh(wt4b, in_ch);
      if (!ValidInCh(ich_idx, cnode, cnode.in_shape)) {
        idx++;
        continue;
      }
      const int raw_idx = ((out_ch * ich_step + ich_idx) * ksize_h * ksize_w) + (k_r * ksize_w) + k_c;
      this->at(idx++) = raw_wt.at(raw_idx);
    }
  }
}

/** Fill this block with one (k_r, k_c) weight per valid output channel. */
void WtBlock::GenBatchNormWt(const WtRaw &raw, const CommandNode &cnode, int k_r, int k_c) {
  const auto &raw_wt = raw.wt_fp;
  const int ksize_h = cnode.conv_kernel_size_h;
  const int ksize_w = cnode.conv_kernel_size_w;

  int idx = 0;
  for (int out_ch = och_str; out_ch < och_end; ++out_ch) {
    if (!ValidOutCh(out_ch, cnode)) {
      break;
    }
    const size_t raw_idx = (out_ch * ksize_h * ksize_w) + (k_r * ksize_w) + k_c;
    // Bounds-checked, like the other Gen*Wt() writers.
    this->at(idx++) = raw_wt.at(raw_idx);
  }
}

/** Fill this block with add_bn1 for [ich_str, ich_end), followed by add_bn2 for the same range. */
void WtBlock::GenAddWt() {
  int idx = 0;
  for (int ch = ich_str; ch < ich_end; ++ch) {
    this->at(idx++) = GetValFromVec(fp_w->add_bn1, ch, 0);
  }
  for (int ch = ich_str; ch < ich_end; ++ch) {
    this->at(idx++) = GetValFromVec(fp_w->add_bn2, ch, 0);
  }
}

/**
 * Fill this block with FC weights for the input window
 * [in_row_str, in_row_end) x [in_col_str, in_col_end), ordered
 * row -> col -> out_ch -> in_ch. Invalid channels leave their slots untouched.
 */
void WtBlock::GenDenseWt(const WtRaw &raw, const CommandNode &cnode, int in_row_str, int in_row_end,
                         int in_col_str, int in_col_end) {
  const auto &raw_wt = raw.wt_fp;
  const auto &shape = cnode.in_shape;
  const int orig_och = cnode.full_out_shape.channel;
  const bool wt4b = (cnode.GetWtBitw() == 4);
  const int ich_step = ich_end - ich_str;

  int idx = 0;
  for (int r = in_row_str; r < in_row_end; ++r) {
    for (int c = in_col_str; c < in_col_end; ++c) {
      for (int och = och_str; och < och_end; ++och) {
        if (!ValidOutCh(och, cnode)) {
          idx += ich_step;
          continue;
        }
        for (int ich = ich_str; ich < ich_end; ++ich) {
          const int ich_idx = GetEffectiveInCh(wt4b, ich);
          if (!ValidInCh(ich_idx, cnode, shape)) {
            idx++;
            continue;
          }
          this->at(idx++) = GetFcRawWt(raw_wt, r, c, ich_idx, och, shape.row, shape.column,
                                       shape.channel, orig_och);
        }
      }
    }
  }
}

/** Append bias / relu coefficient / pconv bias per output channel (zeros for invalid channels). */
void WtBlock::GenPConvWt(const WtRaw &raw, const CommandNode &cnode) {
  for (int out_ch = och_str; out_ch < och_end; ++out_ch) {
    if (ValidOutCh(out_ch, cnode)) {
      bias.push_back(GET_PCONV_VAL(raw.bias_fp, out_ch));
      relu_coeff.push_back(GET_PCONV_VAL(raw.relu_fp, out_ch));
      pconv_bias.push_back(GetValFromVec(fp_w->pconv_bias, out_ch, DEF_PCONV_BIAS));
    } else {
      bias.push_back(0);
      relu_coeff.push_back(0);
      pconv_bias.push_back(0);
    }
  }
}

/** One-line description of this block for logging. */
string WtBlock::Dump(int k_r, int k_c, int in_row_str, int in_row_end, int in_col_str,
                     int in_col_end) const {
  string str = fmt::format("{}{} in_ch [{}-{}] out_ch [{}-{}] wt_num [{}]", Enum2Str(type),
                           Enum2Str(wt_seg), ich_str, ich_end, och_str, och_end, size());
  if (type == BlockType::PCONV) {
    assert(relu_coeff.size() == bias.size() && bias.size() == pconv_bias.size());
    str += fmt::format(", pconv_p_size: {}", relu_coeff.size());
  } else {
    assert(relu_coeff.empty() && bias.empty() && pconv_bias.empty());
  }

  const bool bigk = (info.kcnt.first > 1 || info.kcnt.second > 1);
  if (bigk) {
    str += fmt::format(", bigk_pos [{},{}]", k_r, k_c);
  } else if (info.conv_mode == CONV_MODE::CM_DENSE || info.conv_mode == CONV_MODE::CM_DENSE_4B) {
    str += fmt::format(", dense [in_row: {}-{}, in_col: {}-{}]", in_row_str, in_row_end, in_col_str,
                       in_col_end);
  }

  if (compact) {
    str += ", compact";
  }
  /* str += fmt::format(", data [{}]", Vals2Str(block)); */
  return str;
}

/* ###########################
 * ##    public function    ##
 * ########################### */

/** True if `ch` lies in [GetOutChStr(), GetOutChStr() + out_shape.channel). */
bool ValidOutCh(int ch, const CommandNode &cnode) {
  int out_ch = cnode.out_shape.channel;
  int och_str = cnode.GetOutChStr();
  return (och_str <= ch && ch < och_str + out_ch);
}

/** Modes whose weight blocks do not expand over input-channel groups. */
bool SkipWtInChnlGrpExp(CONV_MODE mode) {
  switch (mode) {
    case CONV_MODE::CM_BYPASS_FM:
    case CONV_MODE::CM_3x3_DW:
    case CONV_MODE::CM_BATCH_NORM:
    case CONV_MODE::CM_PRODUCT:
    case CONV_MODE::CM_SQUARE:
    case CONV_MODE::CM_ADD:
      return true;
    default:
      return false;
  }
}

/**
 * Number of weight entries in one conv weight block, padded up to a multiple
 * of WT_UNIT_SIZE_BYTE (x2 for 4-bit weights -- presumably two weights per
 * byte). Modes not listed below yield 0.
 */
int GetConvWtBlkSize(const CommandNode &cnode, bool compact, int in_row, int in_col, int out_ch,
                     int ich_per_grp, int och_per_grp) {
  int wt_num;  // unpadded weight count for the block

  // only skip padded och weight for direct mode (see #10396)
  if (compact) {
    out_ch = och_per_grp;
  }

  switch (cnode.conv_mode) {
    case CONV_MODE::CM_3x3_DW:
    case CONV_MODE::CM_3x3_1CH:
      wt_num = (3 * 3) * ich_per_grp * out_ch;
      break;
    case CONV_MODE::CM_1x3_4CH:
      wt_num = (1 * 3) * ich_per_grp * out_ch;
      break;
    case CONV_MODE::CM_1x1:
    case CONV_MODE::CM_1x1_4B:
    case CONV_MODE::CM_BATCH_NORM:
      wt_num = 1 * ich_per_grp * out_ch;
      break;
    case CONV_MODE::CM_DENSE:
    case CONV_MODE::CM_DENSE_4B:
      wt_num = in_row * in_col * ich_per_grp * och_per_grp;
      break;
    case CONV_MODE::CM_ADD:
      /* 2 set BN sign 8bits (s2.6), one for each input */
      wt_num = 2 * och_per_grp;
      break;
    default:
      wt_num = 0;
      break;
  }

  const bool wt4b = (cnode.GetWtBitw() == 4);
  const int wt_num_per_unit = wt4b ? (WT_UNIT_SIZE_BYTE * 2) : (WT_UNIT_SIZE_BYTE * 1);
  return PadToN(wt_num, wt_num_per_unit);
}

/** Input-channel remap for 4-bit weights: 2k -> k, 2k+1 -> k + 16; identity otherwise. */
int GetEffectiveInCh(bool wt4b, int in_ch) {
  // put ich n & ich n+16 together in 4bits mode
  return (wt4b) ? (in_ch % 2) * 16 + (in_ch / 2) : in_ch;
}