In [1]:
import torch
import numpy
from ptflops import get_model_complexity_info
import models



In [2]:
# Look up the `hybrid_resnet50_extra_bns` constructor by name in the
# project-local `models` module and call it with its default arguments.
low_rank_model = vars(models)["hybrid_resnet50_extra_bns"]()

In [3]:
# Look up the `resnet50` constructor by name in the project-local `models`
# module and call it with its default arguments.
vanilla_model = vars(models)["resnet50"]()

In [4]:
# Report multiply-accumulate count and parameter count of `vanilla_model`
# for a single 3x224x224 input, as human-readable strings.
acs, params = get_model_complexity_info(
    vanilla_model,
    (3, 224, 224),
    as_strings=True,
    print_per_layer_stat=False,
    verbose=False,
)
print("acs: {}, params: {}".format(acs, params))

acs: 4.12 GMac, params: 25.56 M


In [5]:
# Report multiply-accumulate count and parameter count of `low_rank_model`
# for a single 3x224x224 input, as human-readable strings.
acs, params = get_model_complexity_info(
    low_rank_model,
    (3, 224, 224),
    as_strings=True,
    print_per_layer_stat=False,
    verbose=False,
)
print("acs: {}, params: {}".format(acs, params))

acs: 3.61 GMac, params: 15.2 M


In [6]:
def _factorize_conv_weight(param, rank_factor):
    """Split one 4-D conv weight into two truncated-SVD factors.

    The weight of size (out, in, kh, kw) is flattened to (out, in*kh*kw),
    decomposed with SVD, and the square root of each singular value is folded
    into both sides so the two factors are balanced.

    Args:
        param: 4-D weight tensor.
        rank_factor: the full rank ``min(out, in*kh*kw)`` is divided by this
            value (and truncated to int) to get the number of kept components.

    Returns:
        ``(model_weight_u, model_weight_v)`` where ``model_weight_u`` has size
        (r, in, kh, kw) and is built from the right singular vectors, and
        ``model_weight_v`` has size (out, r, 1, 1) and is built from the left
        singular vectors.
    """
    out_channels, in_channels, kernel_h, kernel_w = param.size()

    # reshape instead of view: also works when the weight is not contiguous
    # (e.g. channels_last memory format), where view would raise.
    param_reshaped = param.reshape(out_channels, -1)
    full_rank = min(param_reshaped.size())
    sliced_rank = int(full_rank / rank_factor)

    # torch.svd returns U, S, V with param_reshaped == U @ diag(S) @ V.T.
    # NOTE: torch.svd is documented as deprecated in favour of
    # torch.linalg.svd; it is kept so the notebook still runs on older PyTorch.
    u, s, v = torch.svd(param_reshaped)
    sqrt_s = torch.sqrt(s)

    # Broadcasting scales column i of U and of V by sqrt(s[i]).
    u_weight_sliced = (u * sqrt_s)[:, 0:sliced_rank]
    v_weight_sliced = (sqrt_s * v)[:, 0:sliced_rank]

    # Use the actual number of kept columns (slicing clamps to what exists).
    kept_rank = u_weight_sliced.size(1)

    model_weight_v = u_weight_sliced.reshape(out_channels, kept_rank, 1, 1)
    model_weight_u = v_weight_sliced.t().reshape(kept_rank, in_channels,
                                                 kernel_h, kernel_w)
    return model_weight_u, model_weight_v


def decompose_weights(model, low_rank_model, rank_factor,
                      first_low_rank_index=258, verbose=True):
    """Initialise ``low_rank_model`` from ``model`` via truncated SVD.

    Walks ``model.state_dict()`` in order. Every 4-D tensor whose position is
    at or beyond ``first_low_rank_index`` is replaced by two factor tensors
    (see ``_factorize_conv_weight``); every other entry is passed through
    unchanged. The resulting tensor list is then assigned, in order, to the
    entries of ``low_rank_model.state_dict()``. Entries whose name contains
    ``bn1_u``, ``bn2_u`` or ``bn3_u`` have no counterpart in ``model`` and keep
    their current value.

    Args:
        model: source module providing the full-rank weights.
        low_rank_model: target module; it is modified in place through
            ``load_state_dict``.
        rank_factor: divisor applied to the full rank of each factorized
            weight.
        first_low_rank_index: state-dict entries of ``model`` with a smaller
            index are never factorized. The default of 258 keeps the original
            hard-coded cut-off.
            NOTE(review): presumably chosen so that only the last stage of
            ResNet-50 is factorized -- confirm against the model definition.
        verbose: when True (default), print one line per target entry.

    Returns:
        ``low_rank_model`` (the same object that was passed in).

    Raises:
        AssertionError: a reconstructed tensor and the target entry differ in
            size.
        IndexError: ``model`` yields fewer tensors than ``low_rank_model``
            needs.
    """
    # SVD version
    reconstructed_aggregator = []
    for item_index, (param_name, param) in enumerate(model.state_dict().items()):
        if param.dim() == 4 and item_index >= first_low_rank_index:
            # resize --> svd --> two layers (u first, then v)
            reconstructed_aggregator.extend(
                _factorize_conv_weight(param, rank_factor))
        else:
            reconstructed_aggregator.append(param)

    extra_bn_tags = ("bn1_u", "bn2_u", "bn3_u")
    model_counter = 0
    reload_state_dict = {}
    for item_index, (param_name, param) in enumerate(low_rank_model.state_dict().items()):
        # Handle the extra BN entries first: they consume nothing from the
        # aggregator, so the aggregator must not be indexed for them (it may
        # already be exhausted when such entries come last).
        if any(tag in param_name for tag in extra_bn_tags):
            if verbose:
                print("#### {}, {}, kept from low_rank_model, param: {}".format(
                    item_index, param_name, param.size()))
            reload_state_dict[param_name] = param
            continue

        if model_counter >= len(reconstructed_aggregator):
            raise IndexError(
                "no reconstructed tensor left for '{}' (index {}); the source "
                "model produced only {} tensors".format(
                    param_name, item_index, len(reconstructed_aggregator)))

        candidate = reconstructed_aggregator[model_counter]
        if verbose:
            print("#### {}, {}, recons agg: {}， param: {}".format(
                item_index, param_name, candidate.size(), param.size()))

        assert candidate.size() == param.size(), (
            "size mismatch for '{}': reconstructed {} vs expected {}".format(
                param_name, tuple(candidate.size()), tuple(param.size())))
        reload_state_dict[param_name] = candidate
        model_counter += 1

    low_rank_model.load_state_dict(reload_state_dict)
    return low_rank_model

In [7]:
# Load SVD-factorized weights from `vanilla_model` into `low_rank_model`.
# The call is left as the cell's last expression, so the returned model is
# what the cell displays.
decompose_weights(
    model=vanilla_model,
    low_rank_model=low_rank_model,
    rank_factor=4,
)

ra index: 0, ra size: torch.Size([64, 3, 7, 7])
ra index: 1, ra size: torch.Size([64])
ra index: 2, ra size: torch.Size([64])
ra index: 3, ra size: torch.Size([64])
ra index: 4, ra size: torch.Size([64])
ra index: 5, ra size: torch.Size([])
ra index: 6, ra size: torch.Size([64, 64, 1, 1])
ra index: 7, ra size: torch.Size([64])
ra index: 8, ra size: torch.Size([64])
ra index: 9, ra size: torch.Size([64])
ra index: 10, ra size: torch.Size([64])
ra index: 11, ra size: torch.Size([])
ra index: 12, ra size: torch.Size([64, 64, 3, 3])
ra index: 13, ra size: torch.Size([64])
ra index: 14, ra size: torch.Size([64])
ra index: 15, ra size: torch.Size([64])
ra index: 16, ra size: torch.Size([64])
ra index: 17, ra size: torch.Size([])
ra index: 18, ra size: torch.Size([256, 64, 1, 1])
ra index: 19, ra size: torch.Size([256])
ra index: 20, ra size: torch.Size([256])
ra index: 21, ra size: torch.Size([256])
ra index: 22, ra size: torch.Size([256])
ra index: 23, ra size: torch.Size([])
ra index: 24, 

HybridResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), strid