diff --git a/nn_meter/predictor/prediction/extract_feature.py b/nn_meter/predictor/prediction/extract_feature.py index 217021ac..68c47401 100644 --- a/nn_meter/predictor/prediction/extract_feature.py +++ b/nn_meter/predictor/prediction/extract_feature.py @@ -56,16 +56,17 @@ def get_predict_features(config): elif "concat" in op: # maximum 4 branches itensors = item["input_tensors"] inputh = itensors[0][1] - features = [inputh, len(itensors)] + #features = [inputh, len(itensors)] + features = [inputh] for it in itensors: - co = it[-1] + co = it[-2] features.append(co) - if len(features) < 6: - features = features + [0] * (6 - len(features)) - elif len(features) > 6: - nf = features[0:6] + if len(features) < 5: + features = features + [0] * (5 - len(features)) + elif len(features) > 5: + nf = features[0:5] features = nf - features[1] = 6 + #features[1] = 5 elif op in ["hswish"]: if "inputh" in item: inputh = item["inputh"]