In [None]:
import numpy as np
from copy import deepcopy

import torch
import torch.backends.cudnn as cudnn
from models.mobilenetv2 import *
from net_transform import proj_wider_cout_expansion_wider_cin


In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = MobileNetV2()
net = net.to(device)
# Wrap unconditionally: later cells reach into the model through `net.module`, which only
# exists on the DataParallel wrapper. Without any GPU the wrapper simply forwards to the
# wrapped module, so this is safe on CPU-only machines as well.
net = torch.nn.DataParallel(net)
if device == 'cuda':
    # Let cuDNN benchmark and pick convolution algorithms for the observed input sizes.
    cudnn.benchmark = True


In [None]:
# Restore the trained parent network together with its bookkeeping values.
# map_location remaps the stored tensors onto the device chosen above, so a checkpoint
# written on a GPU machine can still be opened on a CPU-only one.
checkpoint = torch.load('./checkpoint/ckpt.pth', map_location=device)
net.load_state_dict(checkpoint['net'])
best_acc = checkpoint['acc']
start_epoch = checkpoint['epoch']

In [None]:
from models.mobilenetv2_wider import MobileNetV2_wider

In [None]:
# Build the wider child network with its default constructor arguments; no checkpoint is
# loaded into it in this cell.
net_wider = MobileNetV2_wider()

In [None]:
# Module.cpu() moves parameters and buffers in place, so calling it directly on a sub-module
# of `net` would leave the network split across devices. Inspect a detached CPU copy instead.
bn1 = deepcopy(net.module.layers[12].bn1).cpu()

In [None]:
# Size of the scale vector stored in the copied BatchNorm layer.
bn1.state_dict()['weight'].size()

In [None]:
# Throw-away BatchNorm layer used to look at the layout of a BatchNorm state dict.
bn = torch.nn.BatchNorm2d(num_features=10)

In [None]:
# Print every entry of the BatchNorm state dict next to its tensor shape.
for name, tensor in bn.state_dict().items():
    print(name, tensor.size())

In [None]:
# Throw-away convolution used to look at which keys a Conv2d state dict carries.
conv_ = torch.nn.Conv2d(in_channels=5, out_channels=10, kernel_size=3)
conv_.state_dict().keys()

In [None]:
# Shape of the first tensor in the state dict of the parent's block-12 conv1.
next(iter(net.module.layers[12].conv1.state_dict().values())).shape

In [None]:
# Shape of the first tensor in the state dict of the parent's block-12 conv3.
next(iter(net.module.layers[12].conv3.state_dict().values())).shape

In [None]:
# Shape of the first tensor in the state dict of the child's block-13 conv3.
next(iter(net_wider.layers[13].conv3.state_dict().values())).shape

In [None]:
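# The new weights have to be collected and transplanted in one pass: some layers only get a wider c_in, so they cannot be written back one at a time as they are produced.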

In [None]:
# Size of the weight tensor of the child's block-12 conv3.
net_wider.layers[12].conv3.state_dict()['weight'].size()

In [None]:

# `block_num` was first assigned in a later cell, so on a fresh kernel this cell raised
# NameError. Define it here with the same value the later cell uses.
block_num = 12
# Each item is a (key, tensor) tuple from the parent's bn3 state dict; print it as-is.
for entry in net.module.layers[block_num].bn3.state_dict().items():
    print(entry)


In [None]:
from net_transform import proj_wider_cout_expansion_wider_cin
# Note: Module.to() works in place, so this also moves `net` itself onto the CPU.
parent_net = net.to("cpu")
child_net = net_wider
block_num = 12

# State dicts of the projection conv, its BatchNorm, and the next block's expansion conv,
# taken from the child (wider) network and from the parent (trained) network.
child_l3_weights, child_bn3, child_l4_weights = [
    child_net.layers[block_num].conv3.state_dict(),
    child_net.layers[block_num].bn3.state_dict(),
    child_net.layers[block_num + 1].conv1.state_dict(),
]
l3_weights, bn3, l4_weights = [
    parent_net.module.layers[block_num].conv3.state_dict(),
    parent_net.module.layers[block_num].bn3.state_dict(),
    parent_net.module.layers[block_num + 1].conv1.state_dict(),
]

# The two lists used to be named the wrong way round (the "parent" list held the child's
# state dicts and vice versa). The names now match their contents; the values handed to the
# helper and their positional order are exactly what the original call passed.
parent_weights = [l3_weights, bn3, l4_weights]
child_weights = [child_l3_weights, child_bn3, child_l4_weights]

# NOTE(review): parent states go first here, as in the original call; a later draft cell
# passes the child arguments first — confirm against the helper's signature.
new_weight_list = proj_wider_cout_expansion_wider_cin(parent_weights, child_weights)

In [None]:
# Assigning to `<module>.data` only creates a plain attribute on the nn.Module; the layer's
# parameters and buffers stay untouched. Copy the new values in via load_state_dict instead.
# NOTE(review): assumes each entry of new_weight_list is a state-dict-like mapping
# (key -> tensor or ndarray) mirroring the state dicts that were passed in — TODO confirm.
widened_modules = [
    child_net.layers[block_num].conv3,
    child_net.layers[block_num].bn3,
    child_net.layers[block_num + 1].conv1,
]
for module, new_state in zip(widened_modules, new_weight_list):
    module.load_state_dict({key: torch.as_tensor(value) for key, value in new_state.items()})


In [None]:
# Shape tuple of the parent's projection-conv weight after conversion to a NumPy array.
np.shape(parent_net.module.layers[block_num].conv3.state_dict()['weight'].numpy())


In [None]:
# Random probe pushed through parent conv3 -> bn3 -> next block's conv1.
x = torch.rand(1, 576, 32, 32)
# Call the modules themselves rather than .forward() so any registered hooks still run, and
# disable autograd because only the output is inspected.
with torch.no_grad():
    parent_result = parent_net.module.layers[block_num].conv3(x)
    parent_result = parent_net.module.layers[block_num].bn3(parent_result)
    parent_result = parent_net.module.layers[block_num + 1].conv1(parent_result)
parent_result.shape

In [None]:
# Push the same probe `x` through the child's conv3 -> bn3 -> next block's conv1.
# Modules are called directly (not via .forward()) so hooks run, with autograd disabled
# because only the output is inspected.
with torch.no_grad():
    child_result = child_net.layers[block_num].conv3(x)
    child_result = child_net.layers[block_num].bn3(child_result)
    child_result = child_net.layers[block_num + 1].conv1(child_result)
child_result.shape

In [None]:
# Display the bound `forward` method of the parent block's conv3 (it is not called here).
parent_net.module.layers[block_num].conv3.forward

In [None]:
# Shape tuple of the child's bn3 weight after conversion to a NumPy array.
np.shape(child_net.layers[block_num].bn3.state_dict()['weight'].numpy())

In [None]:
# Shape tuple of the parent's bn3 weight after conversion to a NumPy array.
np.shape(net.module.layers[block_num].bn3.state_dict()['weight'].numpy())

In [None]:
# Widen blocks 10-12. layers[5]
# NOTE(review): this cell is an unfinished draft and does not parse: the `if block != 0`
# branch contains `parent_net[]` and assignments with no right-hand side.
# NOTE(review): `block`, `flag`, `new_block_weights`, `child_proj_and_bn3_weights`,
# `parent_proj_and_bn3_weights`, `set_block_weights`, `child_cc` and `filtered_keys` are not
# defined anywhere in the visible part of the notebook.
stage_block_weights = []
block_list = range(10,13)
child_net = net_wider
parent_net = net

for block_num in block_list:
    # NOTE(review): re-initialised on every iteration, so after the loop only the last
    # block's weights remain in the list — confirm whether that is intended.
    stage_block_weights = []

    if block != 0 : # if this is not the first block
        parent_block_weights = parent_net[]
        child_block_weights = 
        parent_next_block_weight = 
        child_next_block_weight = 
        new_block_weight = 

    # widen proj c_out, bn3 ~ next expansion c_in
    child_proj_weights = child_net.layers[block_num].conv3.state_dict()['weight'].numpy()
    # NOTE(review): reads conv3 again although the name says bn3 — presumably a copy-paste
    # slip; confirm.
    child_bn3_weights = child_net.layers[block_num].conv3.state_dict()['weight'].numpy()
                                 
    parent_proj_weights = parent_net.module.layers[block_num].conv3.state_dict()['weight'].numpy()
    # NOTE(review): same as above — reads conv3, not bn3.
    parent_bn3_weights = parent_net.module.layers[block_num].conv3.state_dict()['weight'].numpy()

    if flag:  # handle the case where the logic above ran and proj c_in was widened
        parent_proj_and_bn3_weights[0] = new_block_weights[-2]
        
    # NOTE(review): both "next expansion" weights are read from this block's conv3; an
    # earlier cell used layers[block_num + 1].conv1 for the next expansion layer — confirm.
    child_next_expansion_weight = child_net.layers[block_num].conv3.state_dict()['weight'].numpy()
    parent_next_expansion_weight = parent_net.module.layers[block_num].conv3.state_dict()['weight'].numpy()

    # NOTE(review): called with four arguments here, whereas an earlier cell calls the same
    # helper with two lists — confirm which signature is current.
    new_weights = proj_wider_cout_expansion_wider_cin(child_proj_and_bn3_weights,
                                                        parent_proj_and_bn3_weights,
                                                        child_next_expansion_weight,
                                                        parent_next_expansion_weight)
    # Replace the last two collected entries with the first two new ones, or start the
    # collection with them when nothing has been collected yet.
    if new_block_weights:
        new_block_weights[-2:] = new_weights[:2]
    else:
        new_block_weights = new_weights[:2]

    stage_block_weights.extend(new_block_weights)

set_block_weights(child_cc, stage_block_weights, stage=True, parent=False, filtered_keys=filtered_keys)

In [None]:
def wider_MBconv_block(block_weights, child_block_weights, next_block_weight=None, child_next_block_weights=None,
                       use_SE=True, next_stage_first_block=False):
    """
    Widen the weights of one MBConv block by replicating randomly chosen existing channels.

    Both weight collections are unpacked with ``decompose_key_val`` into ten entries each:
    expansion conv, bn1, depthwise conv, bn2, se1 weight, se1 bias, se2 weight, se2 bias,
    projection conv, bn3. Target widths are read from the last axis of the child's arrays;
    the difference to the parent's width is filled with copies of randomly drawn parent
    channels (``np.random.randint``). Wherever the input side of a following layer is
    widened, the replicated slices (and the slices they were copied from) are divided by
    the number of copies.

    Args:
        block_weights: weights of the parent block, in the form ``decompose_key_val`` accepts.
        child_block_weights: weights of the wider child block; only used to read target widths.
        next_block_weight: not used in the function body.
        child_next_block_weights: not used in the function body.
        use_SE: when true, the squeeze-excite weights and biases are widened and emitted too.
            NOTE(review): all ten entries are unpacked regardless of this flag.
        next_stage_first_block: when true, the projection conv's output side and bn3 are
            widened as well, using a separate random draw.

    Returns:
        ``process_key_weight_to_result(keys, new_weights)`` where ``new_weights`` holds, in
        order: expansion, bn1, depthwise, bn2, [se1, se1 bias, se2, se2 bias if use_SE],
        projection, bn3 — each cast with ``astype('f')``.
        NOTE(review): the original docstring also promised the next block's expansion layer,
        but nothing for the next block is computed or returned here.
    """

    # ex) l1_weights = [{k:List[np.array()]} * 4]
    keys, [l1_weights, bn1, l2_weights, bn2, se1_weights, se1_bias, se2_weights, se2_bias, l3_weights,
           bn3] = decompose_key_val(block_weights)

    # Of the child entries only child_l1_weights, child_se1_weights and child_l3_weights are
    # read below (for their target widths); child_keys and the rest are unused.
    child_keys, [child_l1_weights, child_bn1, child_l2_weights, child_bn2, child_se1_weights, child_se1_bias,
                 child_se2_weights, child_se2_bias, child_l3_weights,
                 child_bn3] = decompose_key_val(child_block_weights)

    new_weights = []

    # expansion
    # new_width = int(width_coeff * l1_weights[0].shape[3])
    new_width = child_l1_weights[0].shape[-1]
    # Indices of the existing expansion channels that will be duplicated.
    rand = np.random.randint(l1_weights[0].shape[-1], size=(new_width - l1_weights[0].shape[-1]))
    # How often each index was drawn; +1 counts the original channel itself.
    replication_factor = np.bincount(rand)
    factor = replication_factor[rand] + 1

    student_w1 = np.array(deepcopy(l1_weights))
    student_bn1 = np.array(deepcopy(bn1))
    student_w2 = np.array(deepcopy(l2_weights))
    student_bn2 = np.array(deepcopy(bn2))

    if use_SE:
        student_se1 = np.array(deepcopy(se1_weights))
        student_se_bias1 = np.array(deepcopy(se1_bias))

        student_se2 = np.array(deepcopy(se2_weights))
        student_se_bias2 = np.array(deepcopy(se2_bias))

    student_w3 = np.array(deepcopy(l3_weights))
    student_bn3 = np.array(deepcopy(bn3))

    # Expansion layer c_out update
    # ex) [(1, 1, 24, 72)*B] + (4, 1, 1, 24, 7)
    new_weight = np.array(student_w1)[:, :, :, :, rand]  # copies of the drawn output channels
    student_w1 = np.concatenate((student_w1, new_weight), axis=-1)

    # BN1 update
    # ex) [(72,)*B] + (12,7)
    new_weight = np.array(student_bn1)[:, rand]
    student_bn1 = np.concatenate((student_bn1, new_weight), axis=-1)

    # se2 c_out update
    if use_SE:
        new_weight = np.array(student_se2)[:, :, :, :, rand]
        student_se2 = np.concatenate((student_se2, new_weight), axis=-1)
        student_se_bias2 = np.concatenate((student_se_bias2, np.array(se2_bias)[:, rand]), axis=-1)

    # c_in update : depthwise & proj, se1 if used SE
    new_weight = np.array(student_w2)[:, :, :, rand, :]  # depthwise needs no normalization
    student_w2 = np.concatenate((student_w2, new_weight), axis=-2)

    # bn2
    new_weight = np.array(student_bn2)[:, rand]
    student_bn2 = np.concatenate((student_bn2, new_weight), axis=-1)

    # proj c_in update
    # Replicated input slices are divided by their copy count, and the source slices are
    # overwritten with the same scaled values.
    new_weight = np.array(student_w3)[:, :, :, rand, :] / factor.reshape(-1, 1)
    student_w3 = np.concatenate((student_w3, new_weight), axis=-2)
    student_w3[:, :, :, rand, :] = new_weight

    # se1 c_in update
    if use_SE:
        new_weight = np.array(student_se1)[:, :, :, rand, :] / factor.reshape(-1, 1)
        student_se1 = np.concatenate((student_se1, new_weight), axis=-2)
        student_se1[:, :, :, rand, :] = new_weight

    # se1 c_out, se2 c_in update
    if use_SE:
        # when the output side changes, the bias has to change as well
        # `rand` and `replication_factor` are redrawn here for the squeeze-excite width.
        new_width = child_se1_weights[0].shape[-1]
        rand = np.random.randint(se1_weights[0].shape[-1], size=(new_width - se1_weights[0].shape[-1]))
        replication_factor = np.bincount(rand)

        # se1 c_out update
        new_weight = np.array(student_se1)[:, :, :, :, rand]
        student_se1 = np.concatenate((student_se1, new_weight), axis=-1)
        student_se_bias1 = np.concatenate((student_se_bias1, np.array(se1_bias)[:, rand]), axis=-1)

        # se2 c_in update
        factor = replication_factor[rand] + 1

        new_weight = np.array(student_se2)[:, :, :, rand, :] / factor.reshape(-1, 1)
        student_se2 = np.concatenate((student_se2, new_weight), axis=-2)
        student_se2[:, :, :, rand, :] = new_weight
        # student_se_bias2 = np.concatenate((student_se_bias2, np.array(se2_bias)[:, rand]), axis=-1) !! the bias does not need updating for a c_in update!

    # add changed weight to result_dict : first_conv
    new_weights.append(student_w1.astype('f'))
    new_weights.append(student_bn1.astype('f'))

    # add changed weight to result_dict : depthwise
    new_weights.append(student_w2.astype('f'))
    new_weights.append(student_bn2.astype('f'))

    # SE
    if use_SE:
        new_weights.append(student_se1.astype('f'))
        new_weights.append(student_se_bias1.astype('f'))
        new_weights.append(student_se2.astype('f'))
        new_weights.append(student_se_bias2.astype('f'))

    if next_stage_first_block:
        # proj c_out update
        new_width = child_l3_weights[0].shape[-1]
        rand_proj = np.random.randint(l3_weights[0].shape[-1], size=(new_width - l3_weights[0].shape[-1]))

        new_weight = student_w3[:, :, :, :, rand_proj]
        # Scale the copied output channels by 1.005 so they are not exact duplicates.
        new_weight *= 1.005
        # new_weight = np.random.randn(*new_weight.shape)

        student_w3 = np.concatenate((student_w3, new_weight), axis=-1)

        # bn3 entries for the new channels are copied from the parent's bn3 unchanged.
        bn3_new_weight = np.array(bn3)[:, rand_proj]

        # bn3_new_weight[0:4, :] = np.zeros_like(bn3_new_weight[0:4, :])  # beta
        # bn3_new_weight[4:8, :] = np.ones_like(bn3_new_weight[4:8, :])  # gamma
        # bn3_new_weight[(8, 9), :] = new_weight.mean(axis=1, keepdims=True)  # moving mean, exp_mean
        # bn3_new_weight[(10, 11), :] = new_weight.var(axis=1, keepdims=True)  # moving var, exp_var

        student_bn3 = np.concatenate((student_bn3, bn3_new_weight), axis=-1)

    new_weights.append(student_w3.astype('f'))
    new_weights.append(student_bn3.astype('f'))

    return process_key_weight_to_result(keys, new_weights)

In [None]:
def wider_MBConvBlock(parent_block_weights, child_block_weights):
       """
       Unfinished draft of a module-based MBConv widening routine.

       Despite their names, both arguments are accessed as objects exposing
       ``conv1.out_channels``. The target width is read from the child and random parent
       channel indices are drawn to fill the difference.

       NOTE(review): `l1_weights`, `l2_weights`, `bn2`, `keys` and `new_weights` are never
       assigned in this function, and `bn1` would resolve to the notebook-level variable of
       that name. As written, the body cannot complete unless those names exist as globals.
       NOTE(review): the name differs from `wider_MBconv_block` only by capitalisation —
       confirm which of the two is meant to be kept.
       """
       new_width = child_block_weights.conv1.out_channels
       # Indices of parent conv1 output channels to duplicate, one per missing channel.
       rand = np.random.randint(parent_block_weights.conv1.out_channels, size=(new_width - parent_block_weights.conv1.out_channels))
       # Copy count per drawn index; +1 counts the original channel. `factor` is not used yet.
       replication_factor = np.bincount(rand)
       factor = replication_factor[rand] + 1

       # Widen the expansion layer
       student_w1 = np.array(deepcopy(l1_weights))
       student_bn1 = np.array(deepcopy(bn1))
       student_w2 = np.array(deepcopy(l2_weights))
       student_bn2 = np.array(deepcopy(bn2))
       # Depthwise layer Normalize

       return process_key_weight_to_result(keys, new_weights)