In [1]:
import torch
from transformer.Models import Transformer, LowRankTransformer

In [2]:
# Vocabulary sizes and padding indices (source and target use the same values).
opt_src_vocab_size = opt_trg_vocab_size = 9521
opt_src_pad_idx = opt_trg_pad_idx = 1

# Weight-sharing switches forwarded to the model constructors.
opt_proj_share_weight = True
opt_embs_share_weight = True

# Per-head key/value sizes and the model / embedding / feed-forward widths.
opt_d_k = opt_d_v = 64
opt_d_model = opt_d_word_vec = 512
opt_d_inner_hid = 2048

# Depth, number of attention heads and dropout rate.
opt_n_layers = 6
opt_n_head = 8
opt_dropout = 0.1

In [3]:
# Instantiate the low-rank model from the shared `opt_*` hyperparameters.
lr_transformer = LowRankTransformer(
    opt_src_vocab_size,
    opt_trg_vocab_size,
    # padding indices
    src_pad_idx=opt_src_pad_idx,
    trg_pad_idx=opt_trg_pad_idx,
    # weight-sharing switches
    trg_emb_prj_weight_sharing=opt_proj_share_weight,
    emb_src_trg_weight_sharing=opt_embs_share_weight,
    # layer widths
    d_model=opt_d_model,
    d_word_vec=opt_d_word_vec,
    d_inner=opt_d_inner_hid,
    d_k=opt_d_k,
    d_v=opt_d_v,
    # depth, heads, regularisation
    n_layers=opt_n_layers,
    n_head=opt_n_head,
    dropout=opt_dropout,
)

In [4]:
# List every state-dict entry of the low-rank model: position, name and shape.
for param_index, (param_name, param) in enumerate(lr_transformer.state_dict().items()):
    print(param_index, param_name, param.shape)

0 encoder.src_word_emb.weight torch.Size([9521, 512])
1 encoder.position_enc.pos_table torch.Size([1, 200, 512])
2 encoder.layer_stack.0.slf_attn.w_qs.weight torch.Size([512, 512])
3 encoder.layer_stack.0.slf_attn.w_ks.weight torch.Size([512, 512])
4 encoder.layer_stack.0.slf_attn.w_vs.weight torch.Size([512, 512])
5 encoder.layer_stack.0.slf_attn.fc.weight torch.Size([512, 512])
6 encoder.layer_stack.0.slf_attn.layer_norm.weight torch.Size([512])
7 encoder.layer_stack.0.slf_attn.layer_norm.bias torch.Size([512])
8 encoder.layer_stack.0.pos_ffn.w_1.weight torch.Size([2048, 512])
9 encoder.layer_stack.0.pos_ffn.w_1.bias torch.Size([2048])
10 encoder.layer_stack.0.pos_ffn.w_2.weight torch.Size([512, 2048])
11 encoder.layer_stack.0.pos_ffn.w_2.bias torch.Size([512])
12 encoder.layer_stack.0.pos_ffn.layer_norm.weight torch.Size([512])
13 encoder.layer_stack.0.pos_ffn.layer_norm.bias torch.Size([512])
14 encoder.layer_stack.1.slf_attn.w_qs_u.weight torch.Size([128, 512])
15 encoder.layer_st

In [5]:
# Instantiate the full-rank reference model from the same `opt_*` hyperparameters.
transformer = Transformer(
    opt_src_vocab_size,
    opt_trg_vocab_size,
    # padding indices
    src_pad_idx=opt_src_pad_idx,
    trg_pad_idx=opt_trg_pad_idx,
    # weight-sharing switches
    trg_emb_prj_weight_sharing=opt_proj_share_weight,
    emb_src_trg_weight_sharing=opt_embs_share_weight,
    # layer widths
    d_model=opt_d_model,
    d_word_vec=opt_d_word_vec,
    d_inner=opt_d_inner_hid,
    d_k=opt_d_k,
    d_v=opt_d_v,
    # depth, heads, regularisation
    n_layers=opt_n_layers,
    n_head=opt_n_head,
    dropout=opt_dropout,
)

In [6]:
# List every state-dict entry of the full-rank model: position, name and shape.
for param_index, (param_name, param) in enumerate(transformer.state_dict().items()):
    print(param_index, param_name, param.shape)

0 encoder.src_word_emb.weight torch.Size([9521, 512])
1 encoder.position_enc.pos_table torch.Size([1, 200, 512])
2 encoder.layer_stack.0.slf_attn.w_qs.weight torch.Size([512, 512])
3 encoder.layer_stack.0.slf_attn.w_ks.weight torch.Size([512, 512])
4 encoder.layer_stack.0.slf_attn.w_vs.weight torch.Size([512, 512])
5 encoder.layer_stack.0.slf_attn.fc.weight torch.Size([512, 512])
6 encoder.layer_stack.0.slf_attn.layer_norm.weight torch.Size([512])
7 encoder.layer_stack.0.slf_attn.layer_norm.bias torch.Size([512])
8 encoder.layer_stack.0.pos_ffn.w_1.weight torch.Size([2048, 512])
9 encoder.layer_stack.0.pos_ffn.w_1.bias torch.Size([2048])
10 encoder.layer_stack.0.pos_ffn.w_2.weight torch.Size([512, 2048])
11 encoder.layer_stack.0.pos_ffn.w_2.bias torch.Size([512])
12 encoder.layer_stack.0.pos_ffn.layer_norm.weight torch.Size([512])
13 encoder.layer_stack.0.pos_ffn.layer_norm.bias torch.Size([512])
14 encoder.layer_stack.1.slf_attn.w_qs.weight torch.Size([512, 512])
15 encoder.layer_stac

In [18]:
def decompose_vanilla_model(vanilla_model, low_rank_model, rank_ratio=0.25, skip_indices=None):
    """Initialise ``low_rank_model`` from a truncated SVD of ``vanilla_model``'s weights.

    The state dict of ``vanilla_model`` is walked in order. Every 2-D entry
    whose position is not in ``skip_indices`` is factorised as
    ``W ~= (U * sqrt(S))[:, :r] @ (V * sqrt(S))[:, :r].T`` and replaced by two
    consecutive tensors: first ``(V * sqrt(S))[:, :r].T`` (shape ``(r, in)``),
    then ``(U * sqrt(S))[:, :r]`` (shape ``(out, r)``), where
    ``r = int(min(W.shape) * rank_ratio)``. All other entries are passed
    through unchanged. The resulting list is then matched *positionally*
    against ``low_rank_model.state_dict()`` and loaded into it.

    Args:
        vanilla_model: module providing the full-rank weights.
        low_rank_model: module that receives the weights; it is modified in
            place via ``load_state_dict``.
        rank_ratio: fraction of the full rank kept for each factorised matrix.
        skip_indices: iterable of state-dict positions (in ``vanilla_model``)
            that must be copied as-is even if they are 2-D. ``None`` (the
            default) keeps the original hard-coded set: positions 0-13,
            76-95 and 188.

    Returns:
        ``low_rank_model`` itself, after its state dict has been loaded.

    Raises:
        AssertionError: if a generated tensor's shape differs from the shape
            of the entry at the same position in ``low_rank_model``.
        IndexError: if ``low_rank_model`` has more state-dict entries than
            tensors were generated.
    """
    if skip_indices is None:
        # Default tailored to the notebook's models: positions 0-13 are the
        # embedding, positional table and first encoder layer, which the
        # low-rank model keeps at full rank.
        # NOTE(review): 76-95 and 188 are presumably the matching decoder
        # entries and the final projection -- confirm against the model.
        skip_indices = set(range(0, 14)) | set(range(76, 96)) | {188}
    else:
        skip_indices = set(skip_indices)

    collected_weights = []
    for p_index, param in enumerate(vanilla_model.state_dict().values()):
        if param.dim() == 2 and p_index not in skip_indices:
            sliced_rank = int(min(param.size()) * rank_ratio)
            # NOTE(review): torch.svd is documented as deprecated in favour of
            # torch.linalg.svd; kept here to stay compatible with the
            # notebook's torch version.
            u, s, v = torch.svd(param)
            # Split the singular values evenly between the two factors.
            sqrt_s = torch.sqrt(s)
            u_weight = u * sqrt_s
            v_weight = sqrt_s * v
            # Order matters: the (r, in) factor is applied to the input first,
            # so it must precede the (out, r) factor in the low-rank model.
            collected_weights.append(v_weight[:, :sliced_rank].t())
            collected_weights.append(u_weight[:, :sliced_rank])
        else:
            collected_weights.append(param)

    reconstructed_state_dict = {}
    for p_index, (name, param) in enumerate(low_rank_model.state_dict().items()):
        candidate = collected_weights[p_index]
        assert param.size() == candidate.size(), (
            "shape mismatch at low-rank entry {} ({}): expected {}, got {}".format(
                p_index, name, tuple(param.size()), tuple(candidate.size())))
        reconstructed_state_dict[name] = candidate
    low_rank_model.load_state_dict(reconstructed_state_dict)
    return low_rank_model

In [19]:
# Sanity check of the sqrt-split SVD factorisation on a random 2048x512 matrix.
# The matrix is drawn from a dedicated, seeded generator so the printed
# distances are reproducible on re-run without touching the global RNG state.
a = torch.randn(2048, 512, generator=torch.Generator().manual_seed(0))
print(a.size())
rank = min(a.size()[0], a.size()[1])
sliced_rank = int(rank * 0.25)  # keep a quarter of the full rank
u, s, v = torch.svd(a)
# Full reconstruction U @ diag(S) @ V^T: the distance should be close to zero.
print("dist: {}".format(torch.dist(a, torch.mm(torch.mm(u, torch.diag(s)), v.t()))))
print("u size: {}, s size: {}, v size: {}".format(u.size(), s.size(), v.size()))
# Scale the columns of U and V by sqrt(S) so that u_weight @ v_weight^T == U @ diag(S) @ V^T.
u_weight = u * torch.sqrt(s)
v_weight = torch.sqrt(s) * v
u_weight_sliced, v_weight_sliced = u_weight[:, 0:sliced_rank], v_weight[:, 0:sliced_rank]
print("u weight slided: {}, v weight sliced: {}".format(u_weight_sliced.size(), v_weight_sliced.size()))
# Rank-truncated reconstruction: a large distance is expected here because a
# random Gaussian matrix has no low-rank structure.
print("dist approx: {}".format(torch.dist(a, torch.mm(u_weight_sliced, v_weight_sliced.t()))))

torch.Size([2048, 512])
dist: 0.0017176233232021332
u size: torch.Size([2048, 512]), s size: torch.Size([512]), v size: torch.Size([512, 512])
u weight slided: torch.Size([2048, 128]), v weight sliced: torch.Size([512, 128])
dist approx: 775.8424682617188


In [20]:
# Initialise the low-rank model from the full-rank one. The function loads the
# factorised weights into `lr_transformer` in place and returns that same object.
loaded_lr_model = decompose_vanilla_model(
    vanilla_model=transformer,
    low_rank_model=lr_transformer,
    rank_ratio=0.25,
)

cw_index: 0, cw: torch.Size([9521, 512])
cw_index: 1, cw: torch.Size([1, 200, 512])
cw_index: 2, cw: torch.Size([512, 512])
cw_index: 3, cw: torch.Size([512, 512])
cw_index: 4, cw: torch.Size([512, 512])
cw_index: 5, cw: torch.Size([512, 512])
cw_index: 6, cw: torch.Size([512])
cw_index: 7, cw: torch.Size([512])
cw_index: 8, cw: torch.Size([2048, 512])
cw_index: 9, cw: torch.Size([2048])
cw_index: 10, cw: torch.Size([512, 2048])
cw_index: 11, cw: torch.Size([512])
cw_index: 12, cw: torch.Size([512])
cw_index: 13, cw: torch.Size([512])
cw_index: 14, cw: torch.Size([128, 512])
cw_index: 15, cw: torch.Size([512, 128])
cw_index: 16, cw: torch.Size([128, 512])
cw_index: 17, cw: torch.Size([512, 128])
cw_index: 18, cw: torch.Size([128, 512])
cw_index: 19, cw: torch.Size([512, 128])
cw_index: 20, cw: torch.Size([128, 512])
cw_index: 21, cw: torch.Size([512, 128])
cw_index: 22, cw: torch.Size([512])
cw_index: 23, cw: torch.Size([512])
cw_index: 24, cw: torch.Size([128, 512])
cw_index: 25, cw: 