In [None]:
import mindspore
from mindspore import ops
from mindnlp.transformers import AutoModelForSequenceClassification
from mindnlp.transformers.models.bert.modeling_bert import BertDualForSequenceClassification
from collections import OrderedDict
from mindnlp._legacy.hypercomplex.tensor_decomposition.hypercomplex_td import LinearTDLayer, decompose_linear_parameters, set_new_dict_names, calculate_parameters
import re

In [None]:
# Load the pretrained real-valued classifier. In the cells shown here only its
# `config` attribute is read afterwards.
real_model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-cased',
    num_labels=2,
)

# Create the dual model from the same configuration.
dual_config = real_model.config
model = BertDualForSequenceClassification(dual_config)

# Report the parameter count before any layer is replaced.
initial_param_count = calculate_parameters(model)
print('Parameters of model:', initial_param_count)

In [None]:
# Settings consumed by the compression loop.
model_name = "model"  # name of the variable holding the network; used to build the attribute path of replaced layers
threshold = 0.2  # forwarded as-is to decompose_linear_parameters
layers_list = [  # substrings that mark a parameter name as a compression target
    "query",
    "key",
    "value",
    "dense",
]

# Name -> parameter mapping of the current model, and the (initially empty)
# mapping that will receive the parameters of the compressed model.
params_dict = OrderedDict(model.parameters_and_names())
new_dict = OrderedDict()

In [None]:
# Walk every parameter of `model`.
#   * Parameters that are not inside a numbered hidden layer, or whose name does
#     not contain one of `layers_list`, are copied to `new_dict` unchanged.
#   * For each targeted "...weight_x" parameter, the x/y weights are stacked,
#     decomposed with `decompose_linear_parameters`, registered in `new_dict`
#     via `set_new_dict_names`, and the owning module inside `model` is replaced
#     by a `LinearTDLayer`.
#   * The remaining targeted entries (weight_y, bias_x, bias_y) are skipped here;
#     they are fetched from `params_dict` when their "weight_x" sibling is handled.

WEIGHT_X_SUFFIX = "weight_x"

# Hoisted out of the loop: ".0.", ".1.", ... one tag per hidden layer index.
layer_tags = ["." + str(idx) + "." for idx in range(real_model.config.num_hidden_layers)]

for name, p in params_dict.items():
    needed_layer = any(tag in name for tag in layer_tags)
    needed_qkv = any(key in name for key in layers_list)
    if not (needed_layer and needed_qkv):
        new_dict[name] = p
        continue
    print(name, p.shape)

    # Only the "weight_x" entry triggers the compression of a layer.
    if not name.endswith(WEIGHT_X_SUFFIX):
        continue

    mod_name = name.replace(".weight_x", "")
    print('Considered layer:', name)
    print("\tCompressing of", mod_name, "...")

    # Derive the sibling parameter names from the suffix only, so that an "_x"
    # or "weight_x" occurring earlier in the path is never rewritten.
    stem = name[:-len(WEIGHT_X_SUFFIX)]
    wy_name = stem + "weight_y"
    bx_name = stem + "bias_x"
    by_name = stem + "bias_y"

    w_y = params_dict[wy_name]
    # Stack the transposed x and y weights along a new trailing axis.
    param = ops.cat([ops.unsqueeze(p.T, -1), ops.unsqueeze(w_y.T, -1)], -1).asnumpy()

    p1, p2 = decompose_linear_parameters(param, threshold)
    rk = p1._width

    # Biases are used only when BOTH components exist. The previous `or` guard
    # indexed both keys and raised KeyError when just one of them was present.
    b_x = None
    b_y = None
    if bx_name in params_dict and by_name in params_dict:
        b_x = params_dict[bx_name]
        b_y = params_dict[by_name]
    bias_flag = (b_x is not None) and (b_y is not None)

    set_new_dict_names(p1, p2, name, new_dict, b_x, b_y)

    # Turn "layer.3.attention" into "layer[3].attention" so the path is a valid
    # Python expression. Raw strings avoid the invalid "\." escape sequence.
    mod_name = re.sub(r'\.([0-9]+)(\.)?', r'[\1]\2', mod_name)

    op = LinearTDLayer(p.shape[1], p.shape[0], bias_flag, rk)

    # Split at the LAST dot only; str.replace would also strip any earlier
    # occurrence of "." + op_name from the parent path.
    op_prefix, _, op_name = (model_name + "." + mod_name).rpartition(".")

    # NOTE: eval() is applied to a path built from the model's own parameter
    # names (not external input) to resolve the parent module.
    setattr(eval(op_prefix), op_name, op)

In [None]:
# Write the collected parameters to a checkpoint file, read the file back and
# load its contents into `model`.
ckpt_path = './dual_bert.ckpt'
mindspore.save_checkpoint(new_dict, ckpt_path)

param_dict = mindspore.load_checkpoint(ckpt_path)
mindspore.load_param_into_net(model, param_dict)

# Report the parameter count after loading the checkpoint.
obtained_param_count = calculate_parameters(model)
print('Parameters of obtained model:', obtained_param_count)